Update pipeline.py
pipeline.py  +9 -17  CHANGED
@@ -5,8 +5,14 @@ import getpass
 import pandas as pd
 from typing import Optional, Dict, Any
 
-#
-
+# Conditional import for Runnable from available locations
+try:
+    from langchain_core.runnables.base import Runnable
+except ImportError:
+    try:
+        from langchain.runnables.base import Runnable
+    except ImportError:
+        raise ImportError("Cannot find Runnable class. Please upgrade LangChain or check your installation.")
 
 from langchain.docstore.document import Document
 from langchain.embeddings import HuggingFaceEmbeddings
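The fallback chain above tries the current home of Runnable first (langchain_core, the package recent LangChain releases split the core abstractions into) before the legacy path. The same pattern generalizes to a data-driven helper; import_first_available below is a hypothetical sketch for illustration, not part of this patch:

from importlib import import_module

def import_first_available(attr, module_names):
    # Hypothetical helper: return `attr` from the first module that imports cleanly.
    for name in module_names:
        try:
            return getattr(import_module(name), attr)
        except (ImportError, AttributeError):
            continue
    raise ImportError(f"Cannot find {attr}; tried: {', '.join(module_names)}")

# Same candidates as the patch, in the same order.
Runnable = import_first_available(
    "Runnable", ["langchain_core.runnables.base", "langchain.runnables.base"]
)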
@@ -16,7 +22,6 @@ from langchain.chains import RetrievalQA
 from smolagents import CodeAgent, DuckDuckGoSearchTool, ManagedAgent, LiteLLMModel
 import litellm
 
-# Classification/Refusal/Tailor/Cleaner
 from classification_chain import get_classification_chain
 from refusal_chain import get_refusal_chain
 from tailor_chain import get_tailor_chain
@@ -129,15 +134,10 @@ def do_web_search(query: str) -> str:
 # 6) Orchestrator function: returns a dict => {"answer": "..."}
 ###############################################################################
 def run_with_chain_context(inputs: Dict[str, Any]) -> Dict[str, str]:
-    """
-    Called by the Runnable.
-    inputs: { "input": <user_query>, "chat_history": <list of messages> (optional) }
-    Output: { "answer": <final string> }
-    """
    user_query = inputs["input"]
     chat_history = inputs.get("chat_history", [])
 
-    #
+    # Classification step
     class_result = classification_chain.invoke({"query": user_query})
     classification = class_result.get("text", "").strip()
 
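Although the docstring was dropped, the contract it documented still holds: the orchestrator takes {"input": <user_query>, "chat_history": <optional list>} and returns {"answer": <string>}. A minimal usage sketch (the query text is illustrative):

# Sketch: exercising the documented input/output contract.
result = run_with_chain_context({
    "input": "example user question",  # required key
    "chat_history": [],                # optional; defaults to [] if omitted
})
print(result["answer"])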
@@ -157,7 +157,6 @@ def run_with_chain_context(inputs: Dict[str, Any]) -> Dict[str, str]:
         web_answer = do_web_search(user_query)
     else:
         web_answer = ""
-
     final_merged = cleaner_chain.merge(kb=csv_answer, web=web_answer)
     final_answer = tailor_chain.run({"response": final_merged}).strip()
     return {"answer": final_answer}
@@ -169,7 +168,6 @@ def run_with_chain_context(inputs: Dict[str, Any]) -> Dict[str, str]:
     final_answer = tailor_chain.run({"response": final_merged}).strip()
     return {"answer": final_answer}
 
-    # fallback
     refusal_text = refusal_chain.run({})
     final_refusal = tailor_chain.run({"response": refusal_text}).strip()
     return {"answer": final_refusal}
@@ -177,14 +175,8 @@ def run_with_chain_context(inputs: Dict[str, Any]) -> Dict[str, str]:
 ###############################################################################
 # 7) Build a "Runnable" wrapper so .with_listeners() works
 ###############################################################################
-
 class PipelineRunnable(Runnable[Dict[str, Any], Dict[str, str]]):
-    """
-    Wraps run_with_chain_context(...) in a Runnable
-    so that RunnableWithMessageHistory can attach listeners.
-    """
     def invoke(self, input: Dict[str, Any], config: Optional[Any] = None) -> Dict[str, str]:
         return run_with_chain_context(input)
 
-# Export an instance of PipelineRunnable for use in my_memory_logic.py
 pipeline_runnable = PipelineRunnable()
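The removed comments note that pipeline_runnable is exported for my_memory_logic.py and exists so RunnableWithMessageHistory can attach listeners. A minimal sketch of that wiring, assuming the stock langchain_core history API and an in-memory session store (none of this plumbing is from this repo):

# Assumed wiring, not code from this patch: wrap pipeline_runnable so each
# session keeps its own chat history.
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_community.chat_message_histories import ChatMessageHistory

_sessions = {}  # hypothetical in-memory store: session_id -> ChatMessageHistory

def get_session_history(session_id: str) -> ChatMessageHistory:
    return _sessions.setdefault(session_id, ChatMessageHistory())

chain_with_history = RunnableWithMessageHistory(
    pipeline_runnable,
    get_session_history,
    input_messages_key="input",           # matches inputs["input"] above
    history_messages_key="chat_history",  # matches inputs.get("chat_history", [])
    output_messages_key="answer",         # matches the {"answer": ...} return value
)
reply = chain_with_history.invoke(
    {"input": "hello"},
    config={"configurable": {"session_id": "demo"}},
)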