import os
import logging
import re
import time
import gc
from datetime import datetime
from typing import Optional, List, Dict, Any
from collections import OrderedDict

import pandas as pd
from pydantic import BaseModel, Field, ValidationError, validator

# NLTK for input validation
import nltk
from nltk.corpus import words

try:
    english_words = set(words.words())
except LookupError:
    nltk.download('words')
    english_words = set(words.words())

# LangChain / Groq / LLM imports
from langchain_groq import ChatGroq
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.chains import RetrievalQA, LLMChain
from langchain.prompts import PromptTemplate
from langchain.docstore.document import Document

# Custom chain imports
from classification_chain import get_classification_chain
from refusal_chain import get_refusal_chain
from tailor_chain import get_tailor_chain
from cleaner_chain import get_cleaner_chain
from tailor_chain_wellnessBrand import get_tailor_chain_wellnessBrand

# Mistral moderation
from mistralai import Mistral

# Google Gemini LLM
from langchain_google_genai import ChatGoogleGenerativeAI

# Web search
# from smolagents import DuckDuckGoSearchTool, ManagedAgent, HfApiModel, CodeAgent
# from openinference.instrumentation.smolagents import SmolagentsInstrumentor
# from phoenix.otel import register
# register()
# SmolagentsInstrumentor().instrument(skip_dep_check=True)
from smolagents import (
    CodeAgent,
    DuckDuckGoSearchTool,
    HfApiModel,
    ToolCallingAgent,
    VisitWebpageTool,
)

# Import new prompts
from prompts import (
    selfharm_prompt, frustration_prompt, ethical_conflict_prompt,
    classification_prompt, refusal_prompt, tailor_prompt, cleaner_prompt
)

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# -------------------------------------------------------
# Basic Models
# -------------------------------------------------------
class QueryInput(BaseModel):
    query: str = Field(..., min_length=1)

    @validator("query")
    def check_query_is_string(cls, v):
        if not isinstance(v, str):
            raise ValueError("Query must be a valid string")
        if not v.strip():
            raise ValueError("Query cannot be empty or whitespace")
        return v.strip()
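# Example (assumes pydantic v1-style validators, matching the import above):
# QueryInput(query="  hello ").query == "hello"; QueryInput(query="   ") raises ValidationError.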
class ProcessingMetrics(BaseModel):
    total_requests: int = 0
    cache_hits: int = 0
    errors: int = 0
    average_response_time: float = 0.0
    last_reset: Optional[datetime] = None

    def update_metrics(self, processing_time: float, is_cache_hit: bool = False):
        self.total_requests += 1
        if is_cache_hit:
            self.cache_hits += 1
        self.average_response_time = (
            (self.average_response_time * (self.total_requests - 1) + processing_time)
            / self.total_requests
        )
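# Running-average illustration: with two prior requests averaging 1.0 s, a 4.0 s request
# updates average_response_time to (1.0 * 2 + 4.0) / 3 = 2.0 s.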
# -------------------------------------------------------
# Mistral Moderation
# -------------------------------------------------------
class ModerationResult(BaseModel):
    is_safe: bool
    categories: Dict[str, bool]
    original_text: str

mistral_api_key = os.environ.get("MISTRAL_API_KEY")
client = Mistral(api_key=mistral_api_key)

def moderate_text(query: str) -> ModerationResult:
    """
    Uses Mistral's moderation to detect unsafe content.
    """
    try:
        query_input = QueryInput(query=query)
        response = client.classifiers.moderate_chat(
            model="mistral-moderation-latest",
            inputs=[{"role": "user", "content": query_input.query}]
        )
        is_safe = True
        categories = {}
        if hasattr(response, 'results') and response.results:
            cats = response.results[0].categories
            categories = {
                "violence": cats.get("violence_and_threats", False),
                "hate": cats.get("hate_and_discrimination", False),
                "dangerous": cats.get("dangerous_and_criminal_content", False),
                "selfharm": cats.get("selfharm", False)
            }
            is_safe = not any(categories.values())
        return ModerationResult(
            is_safe=is_safe,
            categories=categories,
            original_text=query_input.query
        )
    except ValidationError as ve:
        raise ValueError(f"Moderation input validation failed: {ve}")
    except Exception as e:
        raise RuntimeError(f"Moderation failed: {e}")
def compute_moderation_severity(mresult: ModerationResult) -> float:
    severity = 0.0
    for flag in mresult.categories.values():
        if flag:
            severity += 0.3
    return min(severity, 1.0)
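# Severity heuristic: each flagged category adds 0.3, capped at 1.0
# (e.g. two flags -> 0.6; all four flags -> 1.0).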
# -------------------------------------------------------
# Models
# -------------------------------------------------------
GROQ_MODELS = {
    "default": "llama3-70b-8192",
    "classification": "mixtral-8x7b-32768",
    "moderation": "mistral-moderation-latest",
    "combination": "llama-3.3-70b-versatile"
}

MAX_RETRIES = 3
RATE_LIMIT_REQUESTS = 60
CACHE_SIZE_LIMIT = 1000

# Google Gemini (primary)
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY")
gemini_llm = ChatGoogleGenerativeAI(
    model="gemini-1.0-pro",
    temperature=0.5,
    max_retries=2,
    google_api_key=GEMINI_API_KEY
)

# Fallback
fallback_groq_api_key = os.environ.get("GROQ_API_KEY_FALLBACK", "YOUR_GROQ_API_KEY")
groq_fallback_llm = ChatGroq(
    model=GROQ_MODELS["default"],
    temperature=0.7,
    groq_api_key=fallback_groq_api_key,
    max_tokens=2048
)
# -------------------------------------------------------
# Rate-limit & Cache
# -------------------------------------------------------
def handle_rate_limiting(state: "PipelineState") -> bool:
    current_time = time.time()
    one_min_ago = current_time - 60
    state.request_timestamps = [t for t in state.request_timestamps if t > one_min_ago]
    if len(state.request_timestamps) >= RATE_LIMIT_REQUESTS:
        return False
    state.request_timestamps.append(current_time)
    return True

def manage_cache(state: "PipelineState", query: str, response: str = None) -> Optional[str]:
    cache_key = query.strip().lower()
    if response is None:
        return state.cache.get(cache_key)
    if cache_key in state.cache:
        state.cache.move_to_end(cache_key)
    state.cache[cache_key] = response
    if len(state.cache) > CACHE_SIZE_LIMIT:
        state.cache.popitem(last=False)
    return None
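# manage_cache doubles as getter and setter: called with only a query it returns any cached
# response; called with a response it stores it in the LRU-style OrderedDict and evicts the
# oldest entry once CACHE_SIZE_LIMIT is exceeded.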
def create_error_response(error_type: str, details: str = "") -> str:
    templates = {
        "validation": "I couldn't process your query: {details}",
        "processing": "I encountered an error while processing: {details}",
        "rate_limit": "Too many requests. Please try again soon.",
        "general": "Apologies, but something went wrong."
    }
    return templates.get(error_type, templates["general"]).format(details=details)
# -------------------------------------------------------
# Web Search
# -------------------------------------------------------
web_search_cache: Dict[str, str] = {}

def store_websearch_result(query: str, result: str):
    web_search_cache[query.strip().lower()] = result

def retrieve_websearch_result(query: str) -> Optional[str]:
    return web_search_cache.get(query.strip().lower())

def do_web_search(query: str) -> str:
    try:
        cached = retrieve_websearch_result(query)
        if cached:
            logger.info("Using cached web search result.")
            return cached
        logger.info("Performing a new web search for: '%s'", query)
        # Older ManagedAgent-based setup, kept for reference:
        # search_tool = DuckDuckGoSearchTool()
        # web_agent = CodeAgent(tools=[search_tool], model=model)
        # managed_web_agent = ManagedAgent(
        #     agent=web_agent,
        #     name="web_search",
        #     description="Runs a web search. Provide your query."
        # )
        model = HfApiModel()
        search_agent = ToolCallingAgent(
            tools=[DuckDuckGoSearchTool(), VisitWebpageTool()],
            model=model,
            name="search_agent",
            description="This is an agent that can do web search.",
        )
        manager_agent = CodeAgent(
            tools=[],
            model=model,
            managed_agents=[search_agent]
        )
        new_search_result = str(manager_agent.run(f"Search for information about: {query}")).strip()
        store_websearch_result(query, new_search_result)
        return new_search_result
    except Exception as e:
        logger.error(f"Web search failed: {e}")
        return ""
def is_greeting(query: str) -> bool:
    """
    Returns True if the query is a greeting. This check is designed to be
    lenient enough to catch common greetings even with minor spelling mistakes
    or punctuation.
    """
    # Define a set of common greeting words (add variants or use fuzzy matching if needed).
    greetings = {"hello", "hi", "hey", "hii", "hola", "greetings"}
    # Remove punctuation and extra whitespace, and lower the case.
    cleaned = re.sub(r'[^\w\s]', '', query).strip().lower()
    # Split the cleaned text into words.
    words_in_query = set(cleaned.split())
    # Return True if any of the greeting words are in the query.
    return not words_in_query.isdisjoint(greetings)
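# Example: "Hi!!" -> True (punctuation stripped, "hi" matches); "hiya team" -> False,
# since "hiya" is not in the greeting set.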
# -------------------------------------------------------
# Vector Stores & RAG
# -------------------------------------------------------
def build_or_load_vectorstore(csv_path: str, store_dir: str) -> FAISS:
    if os.path.exists(store_dir):
        logger.info(f"Loading existing FAISS store from {store_dir}")
        embeddings = HuggingFaceEmbeddings(
            model_name="sentence-transformers/multi-qa-mpnet-base-dot-v1"
        )
        return FAISS.load_local(store_dir, embeddings)
    else:
        logger.info(f"Building new FAISS store from {csv_path}")
        df = pd.read_csv(csv_path)
        df = df.loc[:, ~df.columns.str.contains('^Unnamed')]
        df.columns = df.columns.str.strip()
        if "Answer" in df.columns:
            df.rename(columns={"Answer": "Answers"}, inplace=True)
        if "Question " in df.columns and "Question" not in df.columns:
            df.rename(columns={"Question ": "Question"}, inplace=True)
        if "Question" not in df.columns or "Answers" not in df.columns:
            raise ValueError("CSV must have 'Question' and 'Answers' columns.")
        docs = []
        for _, row in df.iterrows():
            question_text = str(row["Question"]).strip()
            ans = str(row["Answers"]).strip()
            doc = Document(page_content=ans, metadata={"question": question_text})
            docs.append(doc)
        embeddings = HuggingFaceEmbeddings(
            model_name="sentence-transformers/multi-qa-mpnet-base-dot-v1"
        )
        vectorstore = FAISS.from_documents(docs, embedding=embeddings)
        vectorstore.save_local(store_dir)
        return vectorstore
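# Expected CSV layout (illustrative row; the real content lives in BrandAI.csv / AIChatbot.csv):
#   Question,Answers
#   "How can I sleep better?","Keep a consistent bedtime, limit screens before bed, ..."
# Each answer becomes a Document's page_content; the question is kept as metadata.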
# RAG chain for wellness queries
def build_rag_chain(vectorstore: FAISS, llm) -> RetrievalQA:
    prompt = PromptTemplate(
        template="""
[INST] You are a helpful AI specialized in Wellness & Well-being topics.
Please use the following context to provide a detailed, helpful answer.
If the context doesn't fully address the question, acknowledge this and provide the best possible information.
Context: {context}
Question: {question}
Guidelines for responses:
1. Start with a clear introduction establishing the wellness topic
2. Present information using numbered lists for actionable steps
3. Include evidence-based examples and practical applications
4. Provide specific, implementable suggestions
5. End with clear takeaways or next steps
Additional considerations:
- All recommendations should be grounded in current wellness research
- Focus on sustainable, long-term lifestyle modifications
- Acknowledge individual differences in wellness journeys
- Emphasize holistic approaches to health and well-being
- Include relevant studies or research when applicable
[/INST]
""",
        input_variables=["context", "question"]
    )
    retriever = vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": 3})
    chain = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
        return_source_documents=True,
        chain_type_kwargs={
            "prompt": prompt,
            "verbose": False,
            "document_variable_name": "context"
        }
    )
    return chain
# RAG chain for brand queries
def build_rag_chain2(vectorstore: FAISS, llm) -> RetrievalQA:
    prompt = PromptTemplate(
        template="""
[INST] You are the Brand Strategy Specialist for Daily Wellness AI.
Please provide detailed, strategic guidance specific to Daily Wellness AI's brand development and market positioning.
If additional context is needed, acknowledge this while maintaining focus on our company's objectives.
Context: {context}
Question: {question}
Guidelines for Daily Wellness AI specific responses:
1. Begin with addressing specific Daily Wellness AI brand challenges or opportunities
2. Align recommendations with our core mission of democratizing personalized wellness
3. Include competitive analysis within the AI wellness space
4. Provide actionable steps that reflect our technological capabilities
5. Conclude with KPIs aligned with our growth objectives
Brand Pillars to Address:
- AI-Driven Personalization
- Scientific Credibility
- User-Centric Design
- Innovation Leadership
- Community Building
[/INST]
""",
        input_variables=["context", "question"]
    )
    retriever = vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": 3})
    chain = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
        return_source_documents=True,
        chain_type_kwargs={
            "prompt": prompt,
            "verbose": False,
            "document_variable_name": "context"
        }
    )
    return chain
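# Usage note: these RetrievalQA chains are invoked the same way run_with_chain does below, e.g.
#   out = pipeline_state.wellness_rag_chain({"query": "..."}); answer = out["result"]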
# -------------------------------------------------------
# PipelineState
# -------------------------------------------------------
class PipelineState:
    _instance = None

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super(PipelineState, cls).__new__(cls)
            cls._instance._initialized = False
        return cls._instance

    def __init__(self):
        if self._initialized:
            return
        self._initialize()

    def _initialize(self):
        try:
            self.metrics = ProcessingMetrics()
            self.error_count = 0
            self.request_timestamps = []
            self.cache = OrderedDict()
            self._setup_chains()
            self._initialized = True
            self.metrics.last_reset = datetime.now()
            logger.info("Pipeline state initialized successfully.")
        except Exception as e:
            logger.error(f"Failed to initialize pipeline: {e}")
            raise RuntimeError("Pipeline initialization failed.") from e

    def _setup_chains(self):
        # Existing custom chains
        self.tailor_chainWellnessBrand = get_tailor_chain_wellnessBrand()
        self.classification_chain = get_classification_chain()
        self.refusal_chain = get_refusal_chain()
        self.tailor_chain = get_tailor_chain()
        self.cleaner_chain = get_cleaner_chain()

        # Specialized chain for self-harm
        self.self_harm_chain = LLMChain(llm=gemini_llm, prompt=selfharm_prompt, verbose=False)
        # Chain for frustration/harsh queries
        self.frustration_chain = LLMChain(llm=gemini_llm, prompt=frustration_prompt, verbose=False)
        # Chain for ethical conflict queries
        self.ethical_conflict_chain = LLMChain(llm=gemini_llm, prompt=ethical_conflict_prompt, verbose=False)

        # Build brand & wellness vectorstores
        brand_csv = "BrandAI.csv"
        brand_store = "faiss_brand_store"
        wellness_csv = "AIChatbot.csv"
        wellness_store = "faiss_wellness_store"
        brand_vs = build_or_load_vectorstore(brand_csv, brand_store)
        wellness_vs = build_or_load_vectorstore(wellness_csv, wellness_store)

        # Default LLM & fallback
        self.gemini_llm = gemini_llm
        self.groq_fallback_llm = groq_fallback_llm
        self.brand_rag_chain = build_rag_chain2(brand_vs, self.gemini_llm)
        self.wellness_rag_chain = build_rag_chain(wellness_vs, self.gemini_llm)
        self.brand_rag_chain_fallback = build_rag_chain2(brand_vs, self.groq_fallback_llm)
        self.wellness_rag_chain_fallback = build_rag_chain(wellness_vs, self.groq_fallback_llm)
    def handle_error(self, error: Exception) -> bool:
        self.error_count += 1
        self.metrics.errors += 1
        if self.error_count >= MAX_RETRIES:
            logger.warning("Max error count reached, resetting pipeline.")
            self.reset()
            return False
        return True

    def reset(self):
        try:
            logger.info("Resetting pipeline state.")
            old_metrics = self.metrics
            self._initialized = False
            self.__init__()
            self.metrics = old_metrics
            self.metrics.last_reset = datetime.now()
            self.error_count = 0
            gc.collect()
            logger.info("Pipeline state reset done.")
        except Exception as e:
            logger.error(f"Reset pipeline failed: {e}")
            raise RuntimeError("Failed to reset pipeline.") from e

    def get_metrics(self) -> Dict[str, Any]:
        uptime = (datetime.now() - self.metrics.last_reset).total_seconds() / 3600
        return {
            "total_requests": self.metrics.total_requests,
            "cache_hits": self.metrics.cache_hits,
            "error_rate": self.metrics.errors / max(self.metrics.total_requests, 1),
            "average_response_time": self.metrics.average_response_time,
            "uptime_hours": uptime
        }

    def update_metrics(self, start_time: float, is_cache_hit: bool = False):
        duration = time.time() - start_time
        self.metrics.update_metrics(duration, is_cache_hit)
pipeline_state = PipelineState()

# -------------------------------------------------------
# Helper checks: detect aggression or ethical conflict
# -------------------------------------------------------
def is_aggressive_or_harsh(query: str) -> bool:
    """
    Very naive check for users insulting the AI or complaining about worthless answers.
    Refine with better heuristics or a small LLM classifier as needed.
    """
    triggers = ["useless", "worthless", "you cannot do anything", "so bad at answering"]
    for t in triggers:
        if t in query.lower():
            return True
    return False

def is_ethical_conflict(query: str) -> bool:
    """
    Check whether the user is asking about lying, revenge, or other moral dilemmas.
    Expand or refine the keyword list as needed.
    """
    ethics_keywords = ["should i lie", "should i cheat", "revenge", "get back at", "hurt them back"]
    q_lower = query.lower()
    return any(k in q_lower for k in ethics_keywords)
# -------------------------------------------------------
# Main Pipeline
# -------------------------------------------------------
def run_with_chain(query: str) -> str:
    """
    Overall flow:
      1) Validate & rate-limit
      2) Mistral moderation =>
         - If self-harm => self_harm_chain
         - If hate => refusal
         - If violence/dangerous => still produce a guided response (ethics) unless it's extreme
      3) If not refused, check whether the query is aggressive or an ethical conflict => route to that chain
      4) Otherwise classify => brand/wellness/out-of-scope => RAG => tailor
    """
    start_time = time.time()
    try:
        # 1) Validate
        if not query or query.strip() == "":
            return create_error_response("validation", "Empty query.")
        if len(query.strip()) < 2:
            return create_error_response("validation", "Too short.")
        words_in_text = re.findall(r'\b\w+\b', query.lower())
        if not any(w in english_words for w in words_in_text):
            return create_error_response("validation", "Unclear words.")
        if len(query) > 500:
            return create_error_response("validation", "Too long (>500).")
        if not handle_rate_limiting(pipeline_state):
            return create_error_response("rate_limit")

        # Greeting shortcut
        if is_greeting(query):
            greeting_response = "Hello there!! Welcome to DailyWellness. How may I assist you today?"
            manage_cache(pipeline_state, query, greeting_response)
            pipeline_state.update_metrics(start_time)
            return greeting_response
        # Cache check
        cached = manage_cache(pipeline_state, query)
        if cached:
            pipeline_state.update_metrics(start_time, is_cache_hit=True)
            return cached

        # 2) Mistral moderation
        try:
            mod_res = moderate_text(query)
            severity = compute_moderation_severity(mod_res)

            # If self-harm => supportive response
            if mod_res.categories.get("selfharm", False):
                logger.info("Self-harm flagged => providing supportive chain response.")
                selfharm_resp = pipeline_state.self_harm_chain.run({"query": query})
                final_tailored = pipeline_state.tailor_chain.run({"response": selfharm_resp}).strip()
                manage_cache(pipeline_state, query, final_tailored)
                pipeline_state.update_metrics(start_time)
                return final_tailored

            # If hate => refuse
            if mod_res.categories.get("hate", False):
                logger.info("Hate content => refusal.")
                refusal_resp = pipeline_state.refusal_chain.run({"topic": "moderation_flagged"})
                manage_cache(pipeline_state, query, refusal_resp)
                pipeline_state.update_metrics(start_time)
                return refusal_resp

            # If "dangerous" or "violence" is flagged, we may still want to give
            # non-violent guidance (e.g. revenge queries), so we don't refuse
            # automatically; the is_ethical_conflict() check below handles routing.
        except Exception as e:
            logger.error(f"Moderation error: {e}")
            severity = 0.0
        # 3) Check for aggression or ethical conflict
        if is_aggressive_or_harsh(query):
            logger.info("Detected harsh/aggressive language => frustration_chain.")
            frustration_resp = pipeline_state.frustration_chain.run({"query": query})
            final_tailored = pipeline_state.tailor_chain.run({"response": frustration_resp}).strip()
            manage_cache(pipeline_state, query, final_tailored)
            pipeline_state.update_metrics(start_time)
            return final_tailored

        if is_ethical_conflict(query):
            logger.info("Detected ethical dilemma => ethical_conflict_chain.")
            ethical_resp = pipeline_state.ethical_conflict_chain.run({"query": query})
            final_tailored = pipeline_state.tailor_chain.run({"response": ethical_resp}).strip()
            manage_cache(pipeline_state, query, final_tailored)
            pipeline_state.update_metrics(start_time)
            return final_tailored
        # 4) Standard path: classification => brand/wellness/out-of-scope
        try:
            class_out = pipeline_state.classification_chain.run({"query": query})
            classification = class_out.strip().lower()
        except Exception as e:
            logger.error(f"Classification error: {e}")
            if not pipeline_state.handle_error(e):
                return create_error_response("processing", "Classification error.")
            return create_error_response("processing")

        if classification in ["outofscope", "out_of_scope"]:
            try:
                # Politely refuse if truly out-of-scope
                refusal_text = pipeline_state.refusal_chain.run({"topic": query})
                tailored_refusal = pipeline_state.tailor_chain.run({"response": refusal_text}).strip()
                manage_cache(pipeline_state, query, tailored_refusal)
                pipeline_state.update_metrics(start_time)
                return tailored_refusal
            except Exception as e:
                logger.error(f"Refusal chain error: {e}")
                if not pipeline_state.handle_error(e):
                    return create_error_response("processing", "Refusal error.")
                return create_error_response("processing")
        # Brand vs wellness
        if classification == "brand":
            rag_chain_main = pipeline_state.brand_rag_chain
            rag_chain_fallback = pipeline_state.brand_rag_chain_fallback
        else:
            rag_chain_main = pipeline_state.wellness_rag_chain
            rag_chain_fallback = pipeline_state.wellness_rag_chain_fallback

        # RAG with fallback
        try:
            try:
                rag_output = rag_chain_main({"query": query})
            except Exception as e_main:
                if "resource exhausted" in str(e_main).lower():
                    logger.warning("Gemini resource exhausted. Falling back to Groq.")
                    rag_output = rag_chain_fallback({"query": query})
                else:
                    raise
            if isinstance(rag_output, dict) and "result" in rag_output:
                csv_ans = rag_output["result"].strip()
            else:
                csv_ans = str(rag_output).strip()

            # If the answer is insufficient, supplement it with a web search
            if "not enough context" in csv_ans.lower() or len(csv_ans) < 40:
                logger.info("Insufficient RAG answer => web search.")
                web_info = do_web_search(query)
                if web_info:
                    csv_ans += f"\n\nAdditional info:\n{web_info}"
        except Exception as e:
            logger.error(f"RAG error: {e}")
            if not pipeline_state.handle_error(e):
                return create_error_response("processing", "RAG error.")
            return create_error_response("processing")
        # Tailor the final response
        try:
            final_tailored = pipeline_state.tailor_chainWellnessBrand.run({"response": csv_ans}).strip()
            if severity > 0.5:
                final_tailored += "\n\n(Please note: This may involve sensitive content.)"
            manage_cache(pipeline_state, query, final_tailored)
            pipeline_state.update_metrics(start_time)
            return final_tailored
        except Exception as e:
            logger.error(f"Tailor chain error: {e}")
            if not pipeline_state.handle_error(e):
                return create_error_response("processing", "Tailoring error.")
            return create_error_response("processing")

    except Exception as e:
        logger.error(f"Critical error in run_with_chain: {e}")
        pipeline_state.metrics.errors += 1
        return create_error_response("general")
# -------------------------------------------------------
# Health & Utility
# -------------------------------------------------------
# def reset_pipeline():
#     try:
#         pipeline_state.reset()
#         return {"status": "success", "message": "Pipeline reset successful"}
#     except Exception as e:
#         logger.error(f"Reset pipeline error: {e}")
#         return {"status": "error", "message": str(e)}

# def get_pipeline_health() -> Dict[str, Any]:
#     try:
#         stats = pipeline_state.get_metrics()
#         healthy = stats["error_rate"] < 0.1
#         return {
#             **stats,
#             "is_healthy": healthy,
#             "status": "healthy" if healthy else "degraded"
#         }
#     except Exception as e:
#         logger.error(f"Health check error: {e}")
#         return {"is_healthy": False, "status": "error", "error": str(e)}

# def health_check() -> Dict[str, Any]:
#     try:
#         _ = run_with_chain("Test query for pipeline health check.")
#         return {
#             "status": "ok",
#             "timestamp": datetime.now().isoformat(),
#             "metrics": get_pipeline_health()
#         }
#     except Exception as e:
#         return {
#             "status": "error",
#             "timestamp": datetime.now().isoformat(),
#             "error": str(e)
#         }
logger.info("Pipeline initialization complete!")
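# Minimal smoke-test sketch, not part of the original module (assumes MISTRAL_API_KEY and
# GEMINI_API_KEY are set and BrandAI.csv / AIChatbot.csv exist; the query is hypothetical).
if __name__ == "__main__":
    print(run_with_chain("What are some simple habits to improve sleep quality?"))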