import os
import logging
import re
import time
import gc
from datetime import datetime
from typing import Optional, List, Dict, Any
from collections import OrderedDict

import pandas as pd
from pydantic import BaseModel, Field, ValidationError, validator

# NLTK for input validation
import nltk
from nltk.corpus import words

try:
    english_words = set(words.words())
except LookupError:
    nltk.download('words')
    english_words = set(words.words())

# LangChain / Groq / LLM imports
from langchain_groq import ChatGroq
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.chains import RetrievalQA, LLMChain
from langchain.prompts import PromptTemplate
from langchain.docstore.document import Document

# Custom chain imports
from classification_chain import get_classification_chain
from refusal_chain import get_refusal_chain
from tailor_chain import get_tailor_chain
from cleaner_chain import get_cleaner_chain
from tailor_chain_wellnessBrand import get_tailor_chain_wellnessBrand

# Mistral moderation
from mistralai import Mistral

# Google Gemini LLM
from langchain_google_genai import ChatGoogleGenerativeAI

# Web search
# from smolagents import DuckDuckGoSearchTool, ManagedAgent, HfApiModel, CodeAgent
# from openinference.instrumentation.smolagents import SmolagentsInstrumentor
# from phoenix.otel import register
# register()
# SmolagentsInstrumentor().instrument(skip_dep_check=True)
from smolagents import (
    CodeAgent,
    DuckDuckGoSearchTool,
    HfApiModel,
    ToolCallingAgent,
    VisitWebpageTool,
)

# Import new prompts
from prompts import (
    selfharm_prompt, frustration_prompt, ethical_conflict_prompt,
    classification_prompt, refusal_prompt, tailor_prompt, cleaner_prompt
)

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# -------------------------------------------------------
# Basic Models
# -------------------------------------------------------
class QueryInput(BaseModel):
    query: str = Field(..., min_length=1)

    @validator("query")
    def check_query_is_string(cls, v):
        if not isinstance(v, str):
            raise ValueError("Query must be a valid string")
        if not v.strip():
            raise ValueError("Query cannot be empty or whitespace")
        return v.strip()
class ProcessingMetrics(BaseModel):
    total_requests: int = 0
    cache_hits: int = 0
    errors: int = 0
    average_response_time: float = 0.0
    last_reset: Optional[datetime] = None

    def update_metrics(self, processing_time: float, is_cache_hit: bool = False):
        self.total_requests += 1
        if is_cache_hit:
            self.cache_hits += 1
        self.average_response_time = (
            (self.average_response_time * (self.total_requests - 1) + processing_time)
            / self.total_requests
        )

# -------------------------------------------------------
# Mistral Moderation
# -------------------------------------------------------
class ModerationResult(BaseModel):
    is_safe: bool
    categories: Dict[str, bool]
    original_text: str

mistral_api_key = os.environ.get("MISTRAL_API_KEY")
client = Mistral(api_key=mistral_api_key)
def moderate_text(query: str) -> ModerationResult:
    """
    Uses Mistral's moderation to detect unsafe content.
    """
    try:
        query_input = QueryInput(query=query)
        response = client.classifiers.moderate_chat(
            model="mistral-moderation-latest",
            inputs=[{"role": "user", "content": query_input.query}]
        )
        is_safe = True
        categories = {}
        if hasattr(response, 'results') and response.results:
            cats = response.results[0].categories
            categories = {
                "violence": cats.get("violence_and_threats", False),
                "hate": cats.get("hate_and_discrimination", False),
                "dangerous": cats.get("dangerous_and_criminal_content", False),
                "selfharm": cats.get("selfharm", False)
            }
            is_safe = not any(categories.values())
        return ModerationResult(
            is_safe=is_safe,
            categories=categories,
            original_text=query_input.query
        )
    except ValidationError as ve:
        raise ValueError(f"Moderation input validation failed: {ve}")
    except Exception as e:
        raise RuntimeError(f"Moderation failed: {e}")

def compute_moderation_severity(mresult: ModerationResult) -> float:
    severity = 0.0
    for flag in mresult.categories.values():
        if flag:
            severity += 0.3
    return min(severity, 1.0)
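# Illustrative usage (sketch, not executed at import time; assumes MISTRAL_API_KEY is set):
#   mres = moderate_text("Some user message")
#   print(mres.is_safe, mres.categories, compute_moderation_severity(mres))
# Each flagged category adds 0.3 to the severity score, capped at 1.0.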
# -------------------------------------------------------
# Models
# -------------------------------------------------------
GROQ_MODELS = {
    "default": "llama3-70b-8192",
    "classification": "mixtral-8x7b-32768",
    "moderation": "mistral-moderation-latest",
    "combination": "llama-3.3-70b-versatile"
}

MAX_RETRIES = 3
RATE_LIMIT_REQUESTS = 60
CACHE_SIZE_LIMIT = 1000

# Google Gemini (primary)
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY")
gemini_llm = ChatGoogleGenerativeAI(
    model="gemini-1.5-flash",
    temperature=0.5,
    max_retries=2,
    google_api_key=GEMINI_API_KEY
)

# Fallback
fallback_groq_api_key = os.environ.get("GROQ_API_KEY_FALLBACK", "YOUR_GROQ_API_KEY")
groq_fallback_llm = ChatGroq(
    model=GROQ_MODELS["default"],
    temperature=0.7,
    groq_api_key=fallback_groq_api_key,
    max_tokens=2048
)

# -------------------------------------------------------
# Rate-limit & Cache
# -------------------------------------------------------
def handle_rate_limiting(state: "PipelineState") -> bool:
    current_time = time.time()
    one_min_ago = current_time - 60
    state.request_timestamps = [t for t in state.request_timestamps if t > one_min_ago]
    if len(state.request_timestamps) >= RATE_LIMIT_REQUESTS:
        return False
    state.request_timestamps.append(current_time)
    return True

def manage_cache(state: "PipelineState", query: str, response: str = None) -> Optional[str]:
    cache_key = query.strip().lower()
    if response is None:
        return state.cache.get(cache_key)
    if cache_key in state.cache:
        state.cache.move_to_end(cache_key)
    state.cache[cache_key] = response
    if len(state.cache) > CACHE_SIZE_LIMIT:
        state.cache.popitem(last=False)
    return None

def create_error_response(error_type: str, details: str = "") -> str:
    templates = {
        "validation": "I couldn't process your query: {details}",
        "processing": "I encountered an error while processing: {details}",
        "rate_limit": "Too many requests. Please try again soon.",
        "general": "Apologies, but something went wrong."
    }
    return templates.get(error_type, templates["general"]).format(details=details)
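# Illustrative usage (not executed): manage_cache(state, "How to sleep?") returns a cached
# response or None; manage_cache(state, "How to sleep?", "Try a regular bedtime...") stores it,
# evicting the oldest entry once the cache grows past CACHE_SIZE_LIMIT (LRU via OrderedDict).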
# -------------------------------------------------------
# Web Search
# -------------------------------------------------------
web_search_cache: Dict[str, str] = {}

def store_websearch_result(query: str, result: str):
    web_search_cache[query.strip().lower()] = result

def retrieve_websearch_result(query: str) -> Optional[str]:
    return web_search_cache.get(query.strip().lower())
def do_web_search(query: str) -> str:
    try:
        cached = retrieve_websearch_result(query)
        if cached:
            logger.info("Using cached web search result.")
            return cached
        logger.info("Performing a new web search for: '%s'", query)
        # model = HfApiModel()
        # search_tool = DuckDuckGoSearchTool()
        # web_agent = CodeAgent(tools=[search_tool], model=model)
        # managed_web_agent = ManagedAgent(
        #     agent=web_agent,
        #     name="web_search",
        #     description="Runs a web search. Provide your query."
        # )
        search_agent = ToolCallingAgent(
            tools=[DuckDuckGoSearchTool(), VisitWebpageTool()],
            model=HfApiModel(),
            name="search_agent",
            description="This is an agent that can do web search.",
        )
        manager_agent = CodeAgent(
            tools=[],
            model=HfApiModel(),
            managed_agents=[search_agent]
        )
        new_search_result = manager_agent.run(f"Search for information about: {query}")
        store_websearch_result(query, str(new_search_result))
        return str(new_search_result).strip()
    except Exception as e:
        logger.error(f"Web search failed: {e}")
        return ""
def is_greeting(query: str) -> bool:
    """
    Returns True if the query is a greeting. This check is designed to be
    lenient enough to catch common greetings even with minor spelling mistakes
    or punctuation.
    """
    # Common greeting words (add variants or use fuzzy matching if needed).
    greetings = {"hello", "hi", "hey", "hii", "hola", "greetings"}
    # Remove punctuation and extra whitespace, and lowercase the text.
    cleaned = re.sub(r'[^\w\s]', '', query).strip().lower()
    # Split the cleaned text into words.
    words_in_query = set(cleaned.split())
    # Return True if any greeting word appears in the query.
    return not words_in_query.isdisjoint(greetings)
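# Illustrative behavior (not executed): is_greeting("Hey!!") -> True, while
# is_greeting("How can I sleep better?") -> False, since no greeting word appears.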
# -------------------------------------------------------
# Vector Stores & RAG
# -------------------------------------------------------
def build_or_load_vectorstore(csv_path: str, store_dir: str) -> FAISS:
    if os.path.exists(store_dir):
        logger.info(f"Loading existing FAISS store from {store_dir}")
        embeddings = HuggingFaceEmbeddings(
            model_name="sentence-transformers/multi-qa-mpnet-base-dot-v1"
        )
        return FAISS.load_local(store_dir, embeddings)
    else:
        logger.info(f"Building new FAISS store from {csv_path}")
        df = pd.read_csv(csv_path)
        df = df.loc[:, ~df.columns.str.contains('^Unnamed')]
        df.columns = df.columns.str.strip()
        if "Answer" in df.columns:
            df.rename(columns={"Answer": "Answers"}, inplace=True)
        if "Question " in df.columns and "Question" not in df.columns:
            df.rename(columns={"Question ": "Question"}, inplace=True)
        if "Question" not in df.columns or "Answers" not in df.columns:
            raise ValueError("CSV must have 'Question' and 'Answers' columns.")
        docs = []
        for _, row in df.iterrows():
            question_text = str(row["Question"]).strip()
            ans = str(row["Answers"]).strip()
            doc = Document(page_content=ans, metadata={"question": question_text})
            docs.append(doc)
        embeddings = HuggingFaceEmbeddings(
            model_name="sentence-transformers/multi-qa-mpnet-base-dot-v1"
        )
        vectorstore = FAISS.from_documents(docs, embedding=embeddings)
        vectorstore.save_local(store_dir)
        return vectorstore
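# Expected CSV layout (inferred from the column handling above; "Answer" or "Question "
# headers are normalized before indexing). Illustrative row, not real data:
#   Question,Answers
#   "How much water should I drink daily?","Most guidelines suggest roughly 2-3 liters..."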
# RAG chain for wellness queries
def build_rag_chain(vectorstore: FAISS, llm) -> RetrievalQA:
    prompt = PromptTemplate(
        template="""
[INST] You are a helpful AI specialized in Wellness & Well-being topics.
Please use the following context to provide a detailed, helpful answer.
If the context doesn't fully address the question, acknowledge this and provide the best possible information.
Context: {context}
Question: {question}
Guidelines for responses:
1. Start with a clear introduction establishing the wellness topic
2. Present information using numbered lists for actionable steps
3. Include evidence-based examples and practical applications
4. Provide specific, implementable suggestions
5. End with clear takeaways or next steps
Additional considerations:
- All recommendations should be grounded in current wellness research
- Focus on sustainable, long-term lifestyle modifications
- Acknowledge individual differences in wellness journeys
- Emphasize holistic approaches to health and well-being
- Include relevant studies or research when applicable
[/INST]
""",
        input_variables=["context", "question"]
    )
    retriever = vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": 3})
    chain = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
        return_source_documents=True,
        chain_type_kwargs={
            "prompt": prompt,
            "verbose": False,
            "document_variable_name": "context"
        }
    )
    return chain
# RAG chain for brand queries
def build_rag_chain2(vectorstore: FAISS, llm) -> RetrievalQA:
    prompt = PromptTemplate(
        template="""
[INST] You are the Brand Strategy Specialist for Daily Wellness AI.
Please provide detailed, strategic guidance specific to Daily Wellness AI's brand development and market positioning.
If additional context is needed, acknowledge this while maintaining focus on our company's objectives.
Context: {context}
Question: {question}
Guidelines for Daily Wellness AI specific responses:
1. Begin with addressing specific Daily Wellness AI brand challenges or opportunities
2. Align recommendations with our core mission of democratizing personalized wellness
3. Include competitive analysis within the AI wellness space
4. Provide actionable steps that reflect our technological capabilities
5. Conclude with KPIs aligned with our growth objectives
Brand Pillars to Address:
- AI-Driven Personalization
- Scientific Credibility
- User-Centric Design
- Innovation Leadership
- Community Building
[/INST]
""",
        input_variables=["context", "question"]
    )
    retriever = vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": 3})
    chain = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
        return_source_documents=True,
        chain_type_kwargs={
            "prompt": prompt,
            "verbose": False,
            "document_variable_name": "context"
        }
    )
    return chain
# -------------------------------------------------------
# PipelineState
# -------------------------------------------------------
class PipelineState:
    _instance = None

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super(PipelineState, cls).__new__(cls)
            cls._instance._initialized = False
        return cls._instance

    def __init__(self):
        if self._initialized:
            return
        self._initialize()

    def _initialize(self):
        try:
            self.metrics = ProcessingMetrics()
            self.error_count = 0
            self.request_timestamps = []
            self.cache = OrderedDict()
            self._setup_chains()
            self._initialized = True
            self.metrics.last_reset = datetime.now()
            logger.info("Pipeline state initialized successfully.")
        except Exception as e:
            logger.error(f"Failed to initialize pipeline: {e}")
            raise RuntimeError("Pipeline initialization failed.") from e

    def _setup_chains(self):
        # Existing custom chains
        self.tailor_chainWellnessBrand = get_tailor_chain_wellnessBrand()
        self.classification_chain = get_classification_chain()
        self.refusal_chain = get_refusal_chain()
        self.tailor_chain = get_tailor_chain()
        self.cleaner_chain = get_cleaner_chain()
        # Specialized chain for self-harm
        self.self_harm_chain = LLMChain(llm=gemini_llm, prompt=selfharm_prompt, verbose=False)
        # NEW: chain for frustration/harsh queries
        self.frustration_chain = LLMChain(llm=gemini_llm, prompt=frustration_prompt, verbose=False)
        # NEW: chain for ethical conflict queries
        self.ethical_conflict_chain = LLMChain(llm=gemini_llm, prompt=ethical_conflict_prompt, verbose=False)
        # Build brand & wellness vectorstores
        brand_csv = "BrandAI.csv"
        brand_store = "faiss_brand_store"
        wellness_csv = "AIChatbot.csv"
        wellness_store = "faiss_wellness_store"
        brand_vs = build_or_load_vectorstore(brand_csv, brand_store)
        wellness_vs = build_or_load_vectorstore(wellness_csv, wellness_store)
        # Default LLM & fallback
        self.gemini_llm = gemini_llm
        self.groq_fallback_llm = groq_fallback_llm
        self.brand_rag_chain = build_rag_chain2(brand_vs, self.gemini_llm)
        self.wellness_rag_chain = build_rag_chain(wellness_vs, self.gemini_llm)
        self.brand_rag_chain_fallback = build_rag_chain2(brand_vs, self.groq_fallback_llm)
        self.wellness_rag_chain_fallback = build_rag_chain(wellness_vs, self.groq_fallback_llm)
    def handle_error(self, error: Exception) -> bool:
        self.error_count += 1
        self.metrics.errors += 1
        if self.error_count >= MAX_RETRIES:
            logger.warning("Max error count reached; resetting pipeline.")
            self.reset()
            return False
        return True
    def reset(self):
        try:
            logger.info("Resetting pipeline state.")
            old_metrics = self.metrics
            self._initialized = False
            self.__init__()
            self.metrics = old_metrics
            self.metrics.last_reset = datetime.now()
            self.error_count = 0
            gc.collect()
            logger.info("Pipeline state reset done.")
        except Exception as e:
            logger.error(f"Reset pipeline failed: {e}")
            raise RuntimeError("Failed to reset pipeline.")

    def get_metrics(self) -> Dict[str, Any]:
        uptime = (datetime.now() - self.metrics.last_reset).total_seconds() / 3600
        return {
            "total_requests": self.metrics.total_requests,
            "cache_hits": self.metrics.cache_hits,
            "error_rate": self.metrics.errors / max(self.metrics.total_requests, 1),
            "average_response_time": self.metrics.average_response_time,
            "uptime_hours": uptime
        }

    def update_metrics(self, start_time: float, is_cache_hit: bool = False):
        duration = time.time() - start_time
        self.metrics.update_metrics(duration, is_cache_hit)

pipeline_state = PipelineState()

# -------------------------------------------------------
# Helper checks: detect aggression or ethical conflict
# -------------------------------------------------------
def is_aggressive_or_harsh(query: str) -> bool:
    """
    Very naive check: returns True if the user is insulting the AI, complaining
    about worthless answers, etc. Refine with better logic or a small LLM
    classifier if needed.
    """
    triggers = ["useless", "worthless", "you cannot do anything", "so bad at answering"]
    for t in triggers:
        if t in query.lower():
            return True
    return False
def is_ethical_conflict(query: str) -> bool:
    """
    Check if user is asking about lying, revenge, or other moral dilemmas.
    You can expand or refine as needed.
    """
    ethics_keywords = ["should i lie", "should i cheat", "revenge", "get back at", "hurt them back"]
    q_lower = query.lower()
    return any(k in q_lower for k in ethics_keywords)
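# Illustrative behavior (not executed): is_aggressive_or_harsh("this bot is useless") -> True;
# is_ethical_conflict("should I lie to my boss?") -> True. Both are plain substring checks on
# the lowercased query, so phrasings outside the keyword lists are not caught.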
# -------------------------------------------------------
# Main Pipeline
# -------------------------------------------------------
def run_with_chain(query: str) -> str:
    """
    Overall flow:
    1) Validate & rate-limit
    2) Mistral moderation =>
       - If self-harm => self_harm_chain
       - If hate => refusal
       - If violence/dangerous => we STILL produce a guided response (ethics) unless it's extreme
    3) If not refused, check if query is aggression/ethical => route to chain
    4) Otherwise classify => brand/wellness/out-of-scope => RAG => tailor
    """
    start_time = time.time()
    try:
        # 1) Validate
        if not query or query.strip() == "":
            return create_error_response("validation", "Empty query.")
        if len(query.strip()) < 2:
            return create_error_response("validation", "Too short.")
        words_in_text = re.findall(r'\b\w+\b', query.lower())
        if not any(w in english_words for w in words_in_text):
            return create_error_response("validation", "Unclear words.")
        if len(query) > 500:
            return create_error_response("validation", "Too long (>500).")
        if not handle_rate_limiting(pipeline_state):
            return create_error_response("rate_limit")

        # Check if the query is a simple greeting
        if is_greeting(query):
            greeting_response = "Hello there! Welcome to Daily Wellness. How may I assist you today?"
            manage_cache(pipeline_state, query, greeting_response)
            pipeline_state.update_metrics(start_time)
            return greeting_response

        # Cache check
        cached = manage_cache(pipeline_state, query)
        if cached:
            pipeline_state.update_metrics(start_time, is_cache_hit=True)
            return cached
        # 2) Mistral moderation
        try:
            mod_res = moderate_text(query)
            severity = compute_moderation_severity(mod_res)

            # If self-harm => supportive
            if mod_res.categories.get("selfharm", False):
                logger.info("Self-harm flagged => providing supportive chain response.")
                selfharm_resp = pipeline_state.self_harm_chain.run({"query": query})
                final_tailored = pipeline_state.tailor_chain.run({"response": selfharm_resp}).strip()
                manage_cache(pipeline_state, query, final_tailored)
                pipeline_state.update_metrics(start_time)
                return final_tailored

            # If hate => refuse
            if mod_res.categories.get("hate", False):
                logger.info("Hate content => refusal.")
                refusal_resp = pipeline_state.refusal_chain.run({"topic": "moderation_flagged"})
                manage_cache(pipeline_state, query, refusal_resp)
                pipeline_state.update_metrics(start_time)
                return refusal_resp

            # If "dangerous" or "violence" is flagged, we might still want to
            # provide a "non-violent advice" approach (like revenge queries).
            # So we won't automatically refuse. We'll rely on the
            # is_ethical_conflict() check below.
        except Exception as e:
            logger.error(f"Moderation error: {e}")
            severity = 0.0

        # 3) Check for aggression or ethical conflict
        if is_aggressive_or_harsh(query):
            logger.info("Detected harsh/aggressive language => frustration_chain.")
            frustration_resp = pipeline_state.frustration_chain.run({"query": query})
            final_tailored = pipeline_state.tailor_chain.run({"response": frustration_resp}).strip()
            manage_cache(pipeline_state, query, final_tailored)
            pipeline_state.update_metrics(start_time)
            return final_tailored

        if is_ethical_conflict(query):
            logger.info("Detected ethical dilemma => ethical_conflict_chain.")
            ethical_resp = pipeline_state.ethical_conflict_chain.run({"query": query})
            final_tailored = pipeline_state.tailor_chain.run({"response": ethical_resp}).strip()
            manage_cache(pipeline_state, query, final_tailored)
            pipeline_state.update_metrics(start_time)
            return final_tailored

        # 4) Standard path: classification => brand/wellness/out-of-scope
        try:
            class_out = pipeline_state.classification_chain.run({"query": query})
            classification = class_out.strip().lower()
        except Exception as e:
            logger.error(f"Classification error: {e}")
            if not pipeline_state.handle_error(e):
                return create_error_response("processing", "Classification error.")
            return create_error_response("processing")

        if classification in ["outofscope", "out_of_scope"]:
            try:
                # Politely refuse if truly out-of-scope
                refusal_text = pipeline_state.refusal_chain.run({"topic": query})
                tailored_refusal = pipeline_state.tailor_chain.run({"response": refusal_text}).strip()
                manage_cache(pipeline_state, query, tailored_refusal)
                pipeline_state.update_metrics(start_time)
                return tailored_refusal
            except Exception as e:
                logger.error(f"Refusal chain error: {e}")
                if not pipeline_state.handle_error(e):
                    return create_error_response("processing", "Refusal error.")
                return create_error_response("processing")

        # brand vs wellness
        if classification == "brand":
            rag_chain_main = pipeline_state.brand_rag_chain
            rag_chain_fallback = pipeline_state.brand_rag_chain_fallback
        else:
            rag_chain_main = pipeline_state.wellness_rag_chain
            rag_chain_fallback = pipeline_state.wellness_rag_chain_fallback

        # RAG with fallback
        try:
            try:
                rag_output = rag_chain_main({"query": query})
            except Exception as e_main:
                if "resource exhausted" in str(e_main).lower():
                    logger.warning("Gemini resource exhausted. Falling back to Groq.")
                    rag_output = rag_chain_fallback({"query": query})
                else:
                    raise
            if isinstance(rag_output, dict) and "result" in rag_output:
                csv_ans = rag_output["result"].strip()
            else:
                csv_ans = str(rag_output).strip()

            # If not enough => web
            if "not enough context" in csv_ans.lower() or len(csv_ans) < 40:
                logger.info("Insufficient RAG => web search.")
                web_info = do_web_search(query)
                if web_info:
                    csv_ans += f"\n\nAdditional info:\n{web_info}"
        except Exception as e:
            logger.error(f"RAG error: {e}")
            if not pipeline_state.handle_error(e):
                return create_error_response("processing", "RAG error.")
            return create_error_response("processing")

        # Tailor final
        try:
            final_tailored = pipeline_state.tailor_chainWellnessBrand.run({"response": csv_ans}).strip()
            if severity > 0.5:
                final_tailored += "\n\n(Please note: This may involve sensitive content.)"
            manage_cache(pipeline_state, query, final_tailored)
            pipeline_state.update_metrics(start_time)
            return final_tailored
        except Exception as e:
            logger.error(f"Tailor chain error: {e}")
            if not pipeline_state.handle_error(e):
                return create_error_response("processing", "Tailoring error.")
            return create_error_response("processing")

    except Exception as e:
        logger.error(f"Critical error in run_with_chain: {e}")
        pipeline_state.metrics.errors += 1
        return create_error_response("general")
# -------------------------------------------------------
# Health & Utility
# -------------------------------------------------------
# def reset_pipeline():
#     try:
#         pipeline_state.reset()
#         return {"status": "success", "message": "Pipeline reset successful"}
#     except Exception as e:
#         logger.error(f"Reset pipeline error: {e}")
#         return {"status": "error", "message": str(e)}

# def get_pipeline_health() -> Dict[str, Any]:
#     try:
#         stats = pipeline_state.get_metrics()
#         healthy = stats["error_rate"] < 0.1
#         return {
#             **stats,
#             "is_healthy": healthy,
#             "status": "healthy" if healthy else "degraded"
#         }
#     except Exception as e:
#         logger.error(f"Health check error: {e}")
#         return {"is_healthy": False, "status": "error", "error": str(e)}

# def health_check() -> Dict[str, Any]:
#     try:
#         _ = run_with_chain("Test query for pipeline health check.")
#         return {
#             "status": "ok",
#             "timestamp": datetime.now().isoformat(),
#             "metrics": get_pipeline_health()
#         }
#     except Exception as e:
#         return {
#             "status": "error",
#             "timestamp": datetime.now().isoformat(),
#             "error": str(e)
#         }

logger.info("Pipeline initialization complete!")