Phoenix21 committed (verified)
Commit 7997061 · Parent(s): 6705f79

CREATED PIPELINE RUNNABLE

Files changed (1): pipeline.py (+39, −24)
pipeline.py CHANGED
@@ -1,9 +1,15 @@
 # pipeline.py
+
 import os
 import getpass
 import pandas as pd
 from typing import Optional, Dict, Any
 
+# (Optional) from langchain.schema import RunnableConfig
+# If you have the latest "langchain_core", use from langchain_core.runnables.base import Runnable
+# or from langchain.runnables.base import Runnable (depending on your version)
+from langchain.runnables.base import Runnable
+
 from langchain.docstore.document import Document
 from langchain.embeddings import HuggingFaceEmbeddings
 from langchain.vectorstores import FAISS
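A note on the new import block: the Runnable base class has moved between LangChain releases, which is why the diff hedges in its comments. A version-tolerant sketch, using the two import paths known from recent and older releases (adjust to your pin):

try:
    from langchain_core.runnables.base import Runnable  # newer releases split core out
except ImportError:
    from langchain.schema.runnable import Runnable      # older monolithic langchain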
@@ -12,7 +18,7 @@ from langchain.chains import RetrievalQA
 from smolagents import CodeAgent, DuckDuckGoSearchTool, ManagedAgent, LiteLLMModel
 import litellm
 
-# For classification/refusal/tailor/cleaner logic
+# Classification/Refusal/Tailor/Cleaner
 from classification_chain import get_classification_chain
 from refusal_chain import get_refusal_chain
 from tailor_chain import get_tailor_chain
@@ -21,7 +27,7 @@ from cleaner_chain import get_cleaner_chain
 from langchain.llms.base import LLM
 
 ###############################################################################
-# 1) Environment Setup
+# 1) Environment keys
 ###############################################################################
 if not os.environ.get("GEMINI_API_KEY"):
     os.environ["GEMINI_API_KEY"] = getpass.getpass("Enter your Gemini API Key: ")
@@ -29,11 +35,11 @@ if not os.environ.get("GROQ_API_KEY"):
     os.environ["GROQ_API_KEY"] = getpass.getpass("Enter your GROQ API Key: ")
 
 ###############################################################################
-# 2) VectorStore Building/Loading
+# 2) Build or load VectorStore
 ###############################################################################
 def build_or_load_vectorstore(csv_path: str, store_dir: str) -> FAISS:
     if os.path.exists(store_dir):
-        print(f"DEBUG: Found existing FAISS store at '{store_dir}'. Loading...")
+        print(f"DEBUG: Found existing FAISS store at '{store_dir}'. Loading from disk.")
         embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/multi-qa-mpnet-base-dot-v1")
         vectorstore = FAISS.load_local(store_dir, embeddings)
         return vectorstore
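One caveat on the FAISS.load_local call above: newer langchain releases refuse to unpickle a saved index unless the caller opts in explicitly. If loading fails with a deserialization error, the call typically needs an extra flag (safe only for stores you wrote yourself):

vectorstore = FAISS.load_local(
    store_dir,
    embeddings,
    allow_dangerous_deserialization=True,  # opt-in required by newer langchain versions
)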
@@ -64,7 +70,7 @@ def build_or_load_vectorstore(csv_path: str, store_dir: str) -> FAISS:
     return vectorstore
 
 ###############################################################################
-# 3) Build RAG chain for Gemini
+# 3) Build RAG chain
 ###############################################################################
 def build_rag_chain(llm_model: LiteLLMModel, vectorstore: FAISS) -> RetrievalQA:
     class GeminiLangChainLLM(LLM):
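The adapter body and chain assembly are collapsed in this hunk. For orientation only, a hypothetical sketch of what a wrapper-plus-RetrievalQA setup of this kind usually contains; the `_call` passthrough and the retriever settings are assumptions, not the file's actual code:

    class GeminiLangChainLLM(LLM):
        @property
        def _llm_type(self) -> str:
            return "gemini-litellm"

        def _call(self, prompt: str, stop=None, run_manager=None, **kwargs) -> str:
            # Assumption: the LiteLLM wrapper accepts the raw prompt text.
            return str(llm_model(prompt))

    rag_chain = RetrievalQA.from_chain_type(
        llm=GeminiLangChainLLM(),
        chain_type="stuff",
        retriever=vectorstore.as_retriever(search_kwargs={"k": 3}),
    )
    return rag_chain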
@@ -87,7 +93,7 @@ def build_rag_chain(llm_model: LiteLLMModel, vectorstore: FAISS) -> RetrievalQA:
     return rag_chain
 
 ###############################################################################
-# 4) Init Sub-Chains
+# 4) Initialize sub-chains
 ###############################################################################
 classification_chain = get_classification_chain()
 refusal_chain = get_refusal_chain()
@@ -95,15 +101,15 @@ tailor_chain = get_tailor_chain()
 cleaner_chain = get_cleaner_chain()
 
 ###############################################################################
-# 5) Build VectorStores & RAG
+# 5) Build vectorstores & RAG
 ###############################################################################
+gemini_llm = LiteLLMModel(model_id="gemini/gemini-pro", api_key=os.environ.get("GEMINI_API_KEY"))
+
 wellness_csv = "AIChatbot.csv"
 brand_csv = "BrandAI.csv"
 wellness_store_dir = "faiss_wellness_store"
 brand_store_dir = "faiss_brand_store"
 
-gemini_llm = LiteLLMModel(model_id="gemini/gemini-pro", api_key=os.environ.get("GEMINI_API_KEY"))
-
 wellness_vectorstore = build_or_load_vectorstore(wellness_csv, wellness_store_dir)
 brand_vectorstore = build_or_load_vectorstore(brand_csv, brand_store_dir)
 
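Between this hunk and the next, the collapsed region presumably builds wellness_rag_chain and brand_rag_chain via build_rag_chain and defines do_web_search, whose final `return response` survives below. A hypothetical smolagents wiring for that helper (the agent setup and prompt wording are assumptions, not the committed code):

search_tool = DuckDuckGoSearchTool()
web_agent = CodeAgent(tools=[search_tool], model=gemini_llm)

def do_web_search(query: str) -> str:
    # Assumption: one agent run that searches and condenses an answer.
    response = web_agent.run(f"Search the web and summarize an answer for: {query}")
    return response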
 
@@ -122,35 +128,27 @@ def do_web_search(query: str) -> str:
     return response
 
 ###############################################################################
-# 6) Orchestrator: run_with_chain_context
+# 6) Orchestrator function: returns a dict => {"answer": "..."}
 ###############################################################################
 def run_with_chain_context(inputs: Dict[str, Any]) -> Dict[str, str]:
     """
-    This function is called by the RunnableWithMessageHistory in my_memory_logic.py
-    inputs: { "input": <user_query>, "chat_history": <list of messages> }
-    Returns: { "answer": <final response> }
+    Called by the Runnable.
+    inputs: { "input": <user_query>, "chat_history": <list of messages> (optional) }
+    Output: { "answer": <final string> }
     """
-
-    user_query = inputs["input"]  # The user's new question
-    # You can optionally use inputs.get("chat_history") if needed
+    user_query = inputs["input"]
     chat_history = inputs.get("chat_history", [])
 
-    print("DEBUG: Starting run_with_chain_context...")
-    print(f"User query: {user_query}")
     # 1) Classification
     class_result = classification_chain.invoke({"query": user_query})
     classification = class_result.get("text", "").strip()
-    print("DEBUG: Classification =>", classification)
 
-    # 2) If OutOfScope => refusal => tailor => return
     if classification == "OutOfScope":
         refusal_text = refusal_chain.run({})
         final_refusal = tailor_chain.run({"response": refusal_text})
         return {"answer": final_refusal.strip()}
 
-    # 3) If Wellness => wellness RAG => if insufficient => web => unify => tailor
     if classification == "Wellness":
-        # pass chat_history if your chain can use it
         rag_result = wellness_rag_chain.invoke({"input": user_query, "chat_history": chat_history})
         csv_answer = rag_result["result"].strip()
         if not csv_answer:
@@ -161,11 +159,11 @@ def run_with_chain_context(inputs: Dict[str, Any]) -> Dict[str, str]:
             web_answer = do_web_search(user_query)
         else:
             web_answer = ""
+
         final_merged = cleaner_chain.merge(kb=csv_answer, web=web_answer)
         final_answer = tailor_chain.run({"response": final_merged}).strip()
         return {"answer": final_answer}
 
-    # 4) If Brand => brand RAG => tailor => return
     if classification == "Brand":
         rag_result = brand_rag_chain.invoke({"input": user_query, "chat_history": chat_history})
         csv_answer = rag_result["result"].strip()
@@ -173,7 +171,24 @@ def run_with_chain_context(inputs: Dict[str, Any]) -> Dict[str, str]:
         final_answer = tailor_chain.run({"response": final_merged}).strip()
         return {"answer": final_answer}
 
-    # 5) fallback => refusal
+    # fallback
     refusal_text = refusal_chain.run({})
     final_refusal = tailor_chain.run({"response": refusal_text}).strip()
     return {"answer": final_refusal}
+
+
+###############################################################################
+# 7) Build a "Runnable" wrapper so .with_listeners() works
+###############################################################################
+from langchain.runnables.base import Runnable
+
+class PipelineRunnable(Runnable[Dict[str, Any], Dict[str, str]]):
+    """
+    Wraps run_with_chain_context(...) in a Runnable
+    so that RunnableWithMessageHistory can attach listeners.
+    """
+    def invoke(self, input: Dict[str, Any], config: Optional[Any] = None) -> Dict[str, str]:
+        return run_with_chain_context(input)
+
+# Export an instance of PipelineRunnable for use in my_memory_logic.py
+pipeline_runnable = PipelineRunnable()
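With pipeline_runnable exported, the my_memory_logic.py side can attach message history to it. A minimal sketch, assuming a recent langchain and an in-memory session store (the store, session id, and import paths are illustrative, not taken from this repo):

from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_community.chat_message_histories import ChatMessageHistory
from pipeline import pipeline_runnable

_sessions = {}  # hypothetical in-memory session store

def get_session_history(session_id: str) -> ChatMessageHistory:
    if session_id not in _sessions:
        _sessions[session_id] = ChatMessageHistory()
    return _sessions[session_id]

chain_with_history = RunnableWithMessageHistory(
    pipeline_runnable,
    get_session_history,
    input_messages_key="input",
    history_messages_key="chat_history",
    output_messages_key="answer",
)

result = chain_with_history.invoke(
    {"input": "What herbs help with sleep?"},
    config={"configurable": {"session_id": "demo-user"}},
)
print(result["answer"])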