Phoenix21 committed
Commit 81ce286 · 1 Parent(s): d65c6b1

Removed useless code and cleaned up pipeline.py

Files changed (1)
  1. pipeline.py +168 -361
pipeline.py CHANGED
@@ -10,7 +10,6 @@ from collections import OrderedDict
10
  import pandas as pd
11
  from pydantic import BaseModel, Field, ValidationError, validator
12
 
13
- # NLTK for input validation
14
  import nltk
15
  from nltk.corpus import words
16
  try:
@@ -19,7 +18,6 @@ except LookupError:
19
  nltk.download('words')
20
  english_words = set(words.words())
21
 
22
- # LangChain / Groq / LLM imports
23
  from langchain_groq import ChatGroq
24
  from langchain_community.embeddings import HuggingFaceEmbeddings
25
  from langchain_community.vectorstores import FAISS
@@ -28,35 +26,15 @@ from langchain.prompts import PromptTemplate
28
  from langchain.docstore.document import Document
29
  from langchain_core.caches import BaseCache
30
  from langchain_core.callbacks import Callbacks
31
- # from langchain_core.callbacks import CallbackManager
32
- # from langchain.callbacks.base import BaseCallbacks # Updated import
33
- # from langchain.callbacks.manager import CallbackManager
34
- # from langchain.callbacks import StdOutCallbackHandler
35
 
36
- # Custom chain imports
37
- # from groq_client import GroqClient
38
  from chain.classification_chain import get_classification_chain
39
  from chain.refusal_chain import get_refusal_chain
40
  from chain.tailor_chain import get_tailor_chain
41
  from chain.cleaner_chain import get_cleaner_chain
42
  from chain.tailor_chain_wellnessBrand import get_tailor_chain_wellnessBrand
43
 
44
- # Mistral moderation
45
  from mistralai import Mistral
46
 
47
- # Google Gemini LLM
48
- # from langchain_google_genai import ChatGoogleGenerativeAI
49
-
50
- # Web search
51
- # from smolagents import DuckDuckGoSearchTool, ManagedAgent, HfApiModel, CodeAgent
52
- # from openinference.instrumentation.smolagents import SmolagentsInstrumentor
53
- # from phoenix.otel import register
54
-
55
-
56
- # register()
57
- # SmolagentsInstrumentor().instrument(skip_dep_check=True)
58
-
59
-
60
  from smolagents import (
61
  CodeAgent,
62
  DuckDuckGoSearchTool,
@@ -65,9 +43,7 @@ from smolagents import (
65
  VisitWebpageTool,
66
  )
67
 
68
- # Import new prompts
69
- from chain.prompts import selfharm_prompt, frustration_prompt, ethical_conflict_prompt,classification_prompt, refusal_prompt, tailor_prompt, cleaner_prompt
70
-
71
 
72
  logging.basicConfig(level=logging.INFO)
73
  logger = logging.getLogger(__name__)
@@ -75,19 +51,12 @@ logger = logging.getLogger(__name__)
75
  from langchain_core.tracers import LangChainTracer
76
  from langsmith import Client
77
 
 
 
 
 
78
 
79
- os.environ["LANGCHAIN_TRACING_V2"]="true"
80
- os.environ["LANGSMITH_ENDPOINT"]="https://api.smith.langchain.com"
81
- # langsmith_client = Client()
82
- os.environ["LANGCHAIN_API_KEY"]=os.getenv("LANGCHAIN_API_KEY")
83
- os.environ["LANGCHAIN_PROJECT"]=os.getenv("LANGCHAIN_PROJECT")
84
- # tracer = LangChainTracer(project_name=os.environ.get("LANGCHAIN_PROJECT", "healthy_ai_expert"))
85
-
86
-
87
-
88
- # -------------------------------------------------------
89
  # Basic Models
90
- # -------------------------------------------------------
91
  class QueryInput(BaseModel):
92
  query: str = Field(..., min_length=1)
93
 
@@ -115,9 +84,7 @@ class ProcessingMetrics(BaseModel):
115
  / self.total_requests
116
  )
117
 
118
- # -------------------------------------------------------
119
  # Mistral Moderation
120
- # -------------------------------------------------------
121
  class ModerationResult(BaseModel):
122
  is_safe: bool
123
  categories: Dict[str, bool]
@@ -127,9 +94,7 @@ mistral_api_key = os.environ.get("MISTRAL_API_KEY")
127
  client = Mistral(api_key=mistral_api_key)
128
 
129
  def moderate_text(query: str) -> ModerationResult:
130
- """
131
- Uses Mistral's moderation to detect unsafe content.
132
- """
133
  try:
134
  query_input = QueryInput(query=query)
135
  response = client.classifiers.moderate_chat(
@@ -161,53 +126,27 @@ def moderate_text(query: str) -> ModerationResult:
161
  raise RuntimeError(f"Moderation failed: {e}")
162
 
163
  def compute_moderation_severity(mresult: ModerationResult) -> float:
 
164
  severity = 0.0
165
  for flag in mresult.categories.values():
166
  if flag:
167
  severity += 0.3
168
  return min(severity, 1.0)
169
 
170
- # -------------------------------------------------------
171
  # Models
172
- # -------------------------------------------------------
173
  GROQ_MODELS = {
174
- "default": "llama3-70b-8192",
175
  "classification": "qwen-qwq-32b",
176
- "moderation": "mistral-moderation-latest",
177
- "combination": "llama-3.3-70b-versatile"
178
  }
179
 
180
  MAX_RETRIES = 3
181
  RATE_LIMIT_REQUESTS = 60
182
  CACHE_SIZE_LIMIT = 1000
183
 
184
- # Google Gemini (primary)
185
- # GOOGLE_API_KEY = os.environ.get("GOOGLE_API_KEY")
186
- # gemini_llm = ChatGoogleGenerativeAI(
187
- # model="gemini-2.0-flash",
188
- # temperature=0.5,
189
- # max_tokens=None,
190
- # timeout=None,
191
- # max_retries=2,
192
- # )
193
-
194
- # # Fallback
195
- # fallback_groq_api_key = os.environ.get("GROQ_API_KEY_FALLBACK", "GROQ_API_KEY")
196
-
197
- # # Attempt to initialize ChatGroq without a cache
198
- # try:
199
- # groq_fallback_llm = ChatGroq(
200
- # model=GROQ_MODELS["default"],
201
- # temperature=0.7,
202
- # # groq_api_key=fallback_groq_api_key,
203
- # max_tokens=2048
204
- # )
205
- # except Exception as e:
206
- # logger.error(f"Failed to initialize ChatGroq: {e}")
207
- # raise RuntimeError("ChatGroq initialization failed.") from e
208
- # Define a simple no-op cache class
209
  class NoCache(BaseCache):
210
- """Simple no-op cache implementation."""
211
  def __init__(self):
212
  pass
213
 
@@ -220,28 +159,27 @@ class NoCache(BaseCache):
220
  def clear(self):
221
  pass
222
 
223
- # Rebuild the ChatGroq model after defining NoCache
224
  ChatGroq.model_rebuild()
225
- # Initialize ChatGroq with cache
226
  try:
227
  fallback_groq_api_key = os.environ.get("GROQ_API_KEY_FALLBACK", os.environ.get("GROQ_API_KEY"))
228
  if not fallback_groq_api_key:
229
  logger.warning("No Groq API key found for fallback LLM")
230
  groq_fallback_llm = ChatGroq(
231
- model=GROQ_MODELS["default"], # Replace with your actual model name if different
232
  temperature=0.7,
233
  groq_api_key=fallback_groq_api_key,
234
  max_tokens=2048,
235
- cache=NoCache(), # Set cache explicitly
236
- callbacks=[] # Explicitly set callbacks to an empty list
237
  )
238
  except Exception as e:
239
  logger.error(f"Failed to initialize fallback Groq LLM: {e}")
240
  raise RuntimeError("ChatGroq initialization failed.") from e
241
- # -------------------------------------------------------
242
  # Rate-limit & Cache
243
- # -------------------------------------------------------
244
  def handle_rate_limiting(state: "PipelineState") -> bool:
 
245
  current_time = time.time()
246
  one_min_ago = current_time - 60
247
  state.request_timestamps = [t for t in state.request_timestamps if t > one_min_ago]
@@ -251,6 +189,7 @@ def handle_rate_limiting(state: "PipelineState") -> bool:
251
  return True
252
 
253
  def manage_cache(state: "PipelineState", query: str, response: str = None) -> Optional[str]:
 
254
  cache_key = query.strip().lower()
255
  if response is None:
256
  return state.cache.get(cache_key)
@@ -262,17 +201,16 @@ def manage_cache(state: "PipelineState", query: str, response: str = None) -> Optional[str]:
262
  return None
263
 
264
  def create_error_response(error_type: str, details: str = "") -> str:
 
265
  templates = {
266
  "validation": "I couldn't process your query: {details}",
267
  "processing": "I encountered an error while processing: {details}",
268
  "rate_limit": "Too many requests. Please try again soon.",
269
- "general": "Apologies, but something went wrong."
270
  }
271
  return templates.get(error_type, templates["general"]).format(details=details)
272
 
273
- # -------------------------------------------------------
274
  # Web Search
275
- # -------------------------------------------------------
276
  web_search_cache: Dict[str, str] = {}
277
 
278
  def store_websearch_result(query: str, result: str):
@@ -282,6 +220,7 @@ def retrieve_websearch_result(query: str) -> Optional[str]:
282
  return web_search_cache.get(query.strip().lower())
283
 
284
  def do_web_search(query: str) -> str:
 
285
  try:
286
  cached = retrieve_websearch_result(query)
287
  if cached:
@@ -289,26 +228,17 @@ def do_web_search(query: str) -> str:
289
  return cached
290
 
291
  logger.info("Performing a new web search for: '%s'", query)
292
- # model = HfApiModel()
293
- # search_tool = DuckDuckGoSearchTool()
294
- # web_agent = CodeAgent(tools=[search_tool], model=model)
295
-
296
- # managed_web_agent = ManagedAgent(
297
- # agent=web_agent,
298
- # name="web_search",
299
- # description="Runs a web search. Provide your query."
300
- # )
301
  search_agent = ToolCallingAgent(
302
- tools=[DuckDuckGoSearchTool(), VisitWebpageTool()],
303
- model=HfApiModel(),
304
- name="search_agent",
305
- description="This is an agent that can do web search.",
306
  )
307
 
308
  manager_agent = CodeAgent(
309
  tools=[],
310
- model=model,
311
- managed_agents=[managed_web_agent]
312
  )
313
 
314
  new_search_result = manager_agent.run(f"Search for information about: {query}")
@@ -319,34 +249,21 @@ def do_web_search(query: str) -> str:
319
  return ""
320
 
321
  def is_greeting(query: str) -> bool:
322
- """
323
- Returns True if the query is a greeting. This check is designed to be
324
- lenient enough to catch common greetings even with minor spelling mistakes
325
- or punctuation.
326
- """
327
- # Define a set of common greeting words (you can add variants or use fuzzy matching if needed)
328
  greetings = {"hello", "hi", "hey", "hii", "hola", "greetings"}
329
-
330
- # Remove punctuation and extra whitespace, and lower the case.
331
  cleaned = re.sub(r'[^\w\s]', '', query).strip().lower()
332
-
333
- # Split the cleaned text into words.
334
  words_in_query = set(cleaned.split())
335
-
336
- # Return True if any of the greeting words are in the query.
337
  return not words_in_query.isdisjoint(greetings)
338
 
339
-
340
- # -------------------------------------------------------
341
  # Vector Stores & RAG
342
- # -------------------------------------------------------
343
  def build_or_load_vectorstore(csv_path: str, store_dir: str) -> FAISS:
 
344
  if os.path.exists(store_dir):
345
  logger.info(f"Loading existing FAISS store from {store_dir}")
346
  embeddings = HuggingFaceEmbeddings(
347
  model_name="sentence-transformers/multi-qa-mpnet-base-dot-v1"
348
  )
349
- return FAISS.load_local(store_dir, embeddings,allow_dangerous_deserialization=True)# It will allow to deserialize and use the faiss store created first time locally
350
  else:
351
  logger.info(f"Building new FAISS store from {csv_path}")
352
  df = pd.read_csv(csv_path)
@@ -373,8 +290,9 @@ def build_or_load_vectorstore(csv_path: str, store_dir: str) -> FAISS:
373
  vectorstore = FAISS.from_documents(docs, embedding=embeddings)
374
  vectorstore.save_local(store_dir)
375
  return vectorstore
376
- #rag chain is for wellness
377
  def build_rag_chain(vectorstore: FAISS, llm) -> RetrievalQA:
 
378
  prompt = PromptTemplate(
379
  template="""
380
  [INST] You are an AI wellness assistant speaking directly to a user who has asked: "{question}"
@@ -408,8 +326,9 @@ def build_rag_chain(vectorstore: FAISS, llm) -> RetrievalQA:
408
  }
409
  )
410
  return chain
411
- #rag chain to is for brand
412
  def build_rag_chain2(vectorstore: FAISS, llm) -> RetrievalQA:
 
413
  prompt = PromptTemplate(
414
  template="""
415
  [INST] You are the brand strategy advisor for Healthy AI Expert. A team member has asked: "{question}"
@@ -425,7 +344,6 @@ def build_rag_chain2(vectorstore: FAISS, llm) -> RetrievalQA:
425
 
426
  Remember our key brand pillars: AI-driven personalization, scientific credibility, user-centric design, and innovation leadership.
427
  [/INST]
428
-
429
  """,
430
  input_variables=["context", "question"]
431
  )
@@ -444,9 +362,7 @@ def build_rag_chain2(vectorstore: FAISS, llm) -> RetrievalQA:
444
  )
445
  return chain
446
 
447
- # -------------------------------------------------------
448
  # PipelineState
449
- # -------------------------------------------------------
450
  class PipelineState:
451
  _instance = None
452
 
@@ -462,6 +378,7 @@ class PipelineState:
462
  self._initialize()
463
 
464
  def _initialize(self):
 
465
  try:
466
  self.metrics = ProcessingMetrics()
467
  self.error_count = 0
@@ -478,52 +395,31 @@ class PipelineState:
478
  raise RuntimeError("Pipeline initialization failed.") from e
479
 
480
  def _setup_chains(self):
481
- # Existing custom chains
482
  self.tailor_chainWellnessBrand = get_tailor_chain_wellnessBrand()
483
  self.classification_chain = get_classification_chain()
484
- self.refusal_chain = get_refusal_chain()
485
- self.tailor_chain = get_tailor_chain()
486
- self.cleaner_chain = get_cleaner_chain()
487
 
488
- # Specialized chain for self-harm
489
- from chain.prompts import selfharm_prompt
490
- # self.self_harm_chain = LLMChain(llm=gemini_llm, prompt=selfharm_prompt, verbose=False)
491
-
492
  self.self_harm_chain = LLMChain(llm=groq_fallback_llm, prompt=selfharm_prompt, verbose=False)
493
-
494
-
495
- # NEW: chain for frustration/harsh queries
496
- from chain.prompts import frustration_prompt
497
- # self.frustration_chain = LLMChain(llm=gemini_llm, prompt=frustration_prompt, verbose=False)
498
  self.frustration_chain = LLMChain(llm=groq_fallback_llm, prompt=frustration_prompt, verbose=False)
499
-
500
-
501
- # NEW: chain for ethical conflict queries
502
- from chain.prompts import ethical_conflict_prompt
503
- # self.ethical_conflict_chain = LLMChain(llm=gemini_llm, prompt=ethical_conflict_prompt, verbose=False)
504
  self.ethical_conflict_chain = LLMChain(llm=groq_fallback_llm, prompt=ethical_conflict_prompt, verbose=False)
505
 
506
- # Build brand & wellness vectorstores
507
- brand_csv = "dataset/BrandAI.csv"
508
- brand_store = "faiss_brand_store"
509
  wellness_csv = "dataset/AIChatbot.csv"
510
  wellness_store = "faiss_wellness_store"
511
 
512
- brand_vs = build_or_load_vectorstore(brand_csv, brand_store)
513
  wellness_vs = build_or_load_vectorstore(wellness_csv, wellness_store)
514
 
515
- # Default LLM & fallback
516
- # self.gemini_llm = gemini_llm
517
  self.groq_fallback_llm = groq_fallback_llm
518
-
519
- # self.brand_rag_chain = build_rag_chain2(brand_vs, self.gemini_llm)
520
- # self.wellness_rag_chain = build_rag_chain(wellness_vs, self.gemini_llm)
521
- self.brand_rag_chain = build_rag_chain2(brand_vs, self.groq_fallback_llm)
522
  self.wellness_rag_chain = build_rag_chain(wellness_vs, self.groq_fallback_llm)
523
- # self.brand_rag_chain_fallback = build_rag_chain2(brand_vs, self.groq_fallback_llm)
524
- # self.wellness_rag_chain_fallback = build_rag_chain(wellness_vs, self.groq_fallback_llm)
525
 
526
  def handle_error(self, error: Exception) -> bool:
 
527
  self.error_count += 1
528
  self.metrics.errors += 1
529
  if self.error_count >= MAX_RETRIES:
@@ -533,6 +429,7 @@ class PipelineState:
533
  return True
534
 
535
  def reset(self):
 
536
  try:
537
  logger.info("Resetting pipeline state.")
538
  old_metrics = self.metrics
@@ -548,6 +445,7 @@ class PipelineState:
548
  raise RuntimeError("Failed to reset pipeline.")
549
 
550
  def get_metrics(self) -> Dict[str, Any]:
 
551
  uptime = (datetime.now() - self.metrics.last_reset).total_seconds() / 3600
552
  return {
553
  "total_requests": self.metrics.total_requests,
@@ -558,20 +456,15 @@ class PipelineState:
558
  }
559
 
560
  def update_metrics(self, start_time: float, is_cache_hit: bool = False):
 
561
  duration = time.time() - start_time
562
  self.metrics.update_metrics(duration, is_cache_hit)
563
 
564
  pipeline_state = PipelineState()
565
 
566
- # -------------------------------------------------------
567
- # Helper checks: detect aggression or ethical conflict
568
- # -------------------------------------------------------
569
-
570
  def is_aggressive_or_harsh(query: str) -> bool:
571
- """
572
- Very naive check: If user is insulting AI, complaining about worthless answers, etc.
573
- You can refine with better logic or a small LLM classifier.
574
- """
575
  triggers = ["useless", "worthless", "you cannot do anything", "so bad at answering"]
576
  for t in triggers:
577
  if t in query.lower():
@@ -579,226 +472,140 @@ def is_aggressive_or_harsh(query: str) -> bool:
579
  return False
580
 
581
  def is_ethical_conflict(query: str) -> bool:
582
- """
583
- Check if user is asking about lying, revenge, or other moral dilemmas.
584
- You can expand or refine as needed.
585
- """
586
  ethics_keywords = ["should i lie", "should i cheat", "revenge", "get back at", "hurt them back"]
587
  q_lower = query.lower()
588
  return any(k in q_lower for k in ethics_keywords)
589
 
590
-
591
- # -------------------------------------------------------
592
  # Main Pipeline
593
- # -------------------------------------------------------
594
  def run_with_chain(query: str) -> str:
595
- """
596
- Overall flow:
597
- 1) Validate & rate-limit
598
- 2) Mistral moderation =>
599
- - If self-harm => self_harm_chain
600
- - If hate => refusal
601
- - If violence/dangerous => we STILL produce a guided response (ethics) unless it's extreme
602
- 3) If not refused, check if query is aggression/ethical => route to chain
603
- 4) Otherwise classify => brand/wellness/out-of-scope => RAG => tailor
604
- """
605
- # with tracer.new_trace(name="wellness_pipeline_run") as run:
606
- start_time = time.time()
607
- try:
608
- # 1) Validate
609
- if not query or query.strip() == "":
610
- return create_error_response("validation", "Empty query.")
611
- if len(query.strip()) < 2:
612
- return create_error_response("validation", "Too short.")
613
- words_in_text = re.findall(r'\b\w+\b', query.lower())
614
- if not any(w in english_words for w in words_in_text):
615
- return create_error_response("validation", "Unclear words.")
616
- if len(query) > 500:
617
- return create_error_response("validation", "Too long (>500).")
618
- if not handle_rate_limiting(pipeline_state):
619
- return create_error_response("rate_limit")
620
- # New: Check if the query is a greeting
621
- if is_greeting(query):
622
- greeting_response = "Hello there!! Welcome to Healthy AI Expert, How may I assist you today?"
623
- manage_cache(pipeline_state, query, greeting_response)
624
- pipeline_state.update_metrics(start_time)
625
- return greeting_response
626
 
627
- if not handle_rate_limiting(pipeline_state):
628
- return create_error_response("rate_limit")
629
-
630
- # Cache check
631
- cached = manage_cache(pipeline_state, query)
632
- if cached:
633
- pipeline_state.update_metrics(start_time, is_cache_hit=True)
634
- return cached
635
-
636
- # 2) Mistral moderation
637
- try:
638
- mod_res = moderate_text(query)
639
- severity = compute_moderation_severity(mod_res)
640
-
641
- # If self-harm => supportive
642
- if mod_res.categories.get("selfharm", False):
643
- logger.info("Self-harm flagged => providing supportive chain response.")
644
- selfharm_resp = pipeline_state.self_harm_chain.run({"query": query})
645
- final_tailored = pipeline_state.tailor_chain.run({"response": selfharm_resp}).strip()
646
- manage_cache(pipeline_state, query, final_tailored)
647
- pipeline_state.update_metrics(start_time)
648
- return final_tailored
649
-
650
- # If hate => refuse
651
- if mod_res.categories.get("hate", False):
652
- logger.info("Hate content => refusal.")
653
- refusal_resp = pipeline_state.refusal_chain.run({"topic": "moderation_flagged"})
654
- manage_cache(pipeline_state, query, refusal_resp)
655
- pipeline_state.update_metrics(start_time)
656
- return refusal_resp
657
-
658
- # If "dangerous" or "violence" is flagged, we might still want to
659
- # provide a "non-violent advice" approach (like revenge queries).
660
- # So we won't automatically refuse. We'll rely on the
661
- # is_ethical_conflict() check below.
662
-
663
- except Exception as e:
664
- logger.error(f"Moderation error: {e}")
665
- severity = 0.0
666
-
667
- # 3) Check for aggression or ethical conflict
668
- if is_aggressive_or_harsh(query):
669
- logger.info("Detected harsh/aggressive language => frustration_chain.")
670
- frustration_resp = pipeline_state.frustration_chain.run({"query": query})
671
- final_tailored = pipeline_state.tailor_chain.run({"response": frustration_resp}).strip()
672
  manage_cache(pipeline_state, query, final_tailored)
673
  pipeline_state.update_metrics(start_time)
674
  return final_tailored
675
-
676
- if is_ethical_conflict(query):
677
- logger.info("Detected ethical dilemma => ethical_conflict_chain.")
678
- ethical_resp = pipeline_state.ethical_conflict_chain.run({"query": query})
679
- final_tailored = pipeline_state.tailor_chain.run({"response": ethical_resp}).strip()
680
- manage_cache(pipeline_state, query, final_tailored)
681
  pipeline_state.update_metrics(start_time)
682
- return final_tailored
683
-
684
- # 4) Standard path: classification => brand/wellness/out-of-scope
685
- try:
686
- class_out = pipeline_state.classification_chain.run({"query": query})
687
- classification = class_out.strip().lower()
688
- except Exception as e:
689
- logger.error(f"Classification error: {e}")
690
- if not pipeline_state.handle_error(e):
691
- return create_error_response("processing", "Classification error.")
692
- return create_error_response("processing")
693
-
694
- if classification in ["outofscope", "out_of_scope"]:
695
- try:
696
- # Politely refuse if truly out-of-scope
697
- refusal_text = pipeline_state.refusal_chain.run({"topic": query})
698
- tailored_refusal = pipeline_state.tailor_chain.run({"response": refusal_text}).strip()
699
- manage_cache(pipeline_state, query, tailored_refusal)
700
- pipeline_state.update_metrics(start_time)
701
- return tailored_refusal
702
- except Exception as e:
703
- logger.error(f"Refusal chain error: {e}")
704
- if not pipeline_state.handle_error(e):
705
- return create_error_response("processing", "Refusal error.")
706
- return create_error_response("processing")
707
-
708
- # brand vs wellness
709
- if classification == "brand":
710
- rag_chain_main = pipeline_state.brand_rag_chain
711
- # rag_chain_fallback = pipeline_state.brand_rag_chain_fallback
712
- else:
713
- rag_chain_main = pipeline_state.wellness_rag_chain
714
- # rag_chain_fallback = pipeline_state.wellness_rag_chain_fallback
715
-
716
- # RAG with fallback
717
- try:
718
- try:
719
- rag_output = rag_chain_main({"query": query})
720
- except Exception as e_main:
721
- if "resource exhausted" in str(e_main).lower():
722
- logger.warning("Gemini resource exhausted. Falling back to Groq.")
723
- # rag_output = rag_chain_fallback({"query": query})
724
- else:
725
- raise
726
-
727
- if isinstance(rag_output, dict) and "result" in rag_output:
728
- csv_ans = rag_output["result"].strip()
729
- else:
730
- csv_ans = str(rag_output).strip()
731
-
732
- # If not enough => web
733
- if "not enough context" in csv_ans.lower() or len(csv_ans) < 40:
734
- logger.info("Insufficient RAG => web search.")
735
- web_info = do_web_search(query)
736
- if web_info:
737
- csv_ans += f"\n\nAdditional info:\n{web_info}"
738
- except Exception as e:
739
- logger.error(f"RAG error: {e}")
740
- if not pipeline_state.handle_error(e):
741
- return create_error_response("processing", "RAG error.")
742
- return create_error_response("processing")
743
-
744
- # Tailor final
745
  try:
746
- final_tailored = pipeline_state.tailor_chainWellnessBrand.run({"response": csv_ans}).strip()
747
- if severity > 0.5:
748
- final_tailored += "\n\n(Please note: This may involve sensitive content.)"
749
-
750
- manage_cache(pipeline_state, query, final_tailored)
751
  pipeline_state.update_metrics(start_time)
752
- return final_tailored
753
  except Exception as e:
754
- logger.error(f"Tailor chain error: {e}")
755
  if not pipeline_state.handle_error(e):
756
- return create_error_response("processing", "Tailoring error.")
757
  return create_error_response("processing")
758
-
759
  except Exception as e:
760
- logger.error(f"Critical error in run_with_chain: {e}")
761
- pipeline_state.metrics.errors += 1
762
- return create_error_response("general")
763
-
764
- # -------------------------------------------------------
765
- # Health & Utility
766
- # -------------------------------------------------------
767
- # def reset_pipeline():
768
- # try:
769
- # pipeline_state.reset()
770
- # return {"status": "success", "message": "Pipeline reset successful"}
771
- # except Exception as e:
772
- # logger.error(f"Reset pipeline error: {e}")
773
- # return {"status": "error", "message": str(e)}
774
-
775
- # def get_pipeline_health() -> Dict[str, Any]:
776
- # try:
777
- # stats = pipeline_state.get_metrics()
778
- # healthy = stats["error_rate"] < 0.1
779
- # return {
780
- # **stats,
781
- # "is_healthy": healthy,
782
- # "status": "healthy" if healthy else "degraded"
783
- # }
784
- # except Exception as e:
785
- # logger.error(f"Health check error: {e}")
786
- # return {"is_healthy": False, "status": "error", "error": str(e)}
787
-
788
- # def health_check() -> Dict[str, Any]:
789
- # try:
790
- # _ = run_with_chain("Test query for pipeline health check.")
791
- # return {
792
- # "status": "ok",
793
- # "timestamp": datetime.now().isoformat(),
794
- # "metrics": get_pipeline_health()
795
- # }
796
- # except Exception as e:
797
- # return {
798
- # "status": "error",
799
- # "timestamp": datetime.now().isoformat(),
800
- # "error": str(e)
801
- # }
802
-
803
- logger.info("Pipeline initialization complete!")
804
 
 
 
10
  import pandas as pd
11
  from pydantic import BaseModel, Field, ValidationError, validator
12
 
 
13
  import nltk
14
  from nltk.corpus import words
15
  try:
 
18
  nltk.download('words')
19
  english_words = set(words.words())
20
 
 
21
  from langchain_groq import ChatGroq
22
  from langchain_community.embeddings import HuggingFaceEmbeddings
23
  from langchain_community.vectorstores import FAISS
 
26
  from langchain.docstore.document import Document
27
  from langchain_core.caches import BaseCache
28
  from langchain_core.callbacks import Callbacks
 
 
 
 
29
 
 
 
30
  from chain.classification_chain import get_classification_chain
31
  from chain.refusal_chain import get_refusal_chain
32
  from chain.tailor_chain import get_tailor_chain
33
  from chain.cleaner_chain import get_cleaner_chain
34
  from chain.tailor_chain_wellnessBrand import get_tailor_chain_wellnessBrand
35
 
 
36
  from mistralai import Mistral
37
 
38
  from smolagents import (
39
  CodeAgent,
40
  DuckDuckGoSearchTool,
 
43
  VisitWebpageTool,
44
  )
45
 
46
+ from chain.prompts import selfharm_prompt, frustration_prompt, ethical_conflict_prompt, classification_prompt, refusal_prompt, tailor_prompt, cleaner_prompt
 
 
47
 
48
  logging.basicConfig(level=logging.INFO)
49
  logger = logging.getLogger(__name__)
 
51
  from langchain_core.tracers import LangChainTracer
52
  from langsmith import Client
53
 
54
+ os.environ["LANGCHAIN_TRACING_V2"] = "true"
55
+ os.environ["LANGSMITH_ENDPOINT"] = "https://api.smith.langchain.com"
56
+ os.environ["LANGCHAIN_API_KEY"] = os.getenv("LANGCHAIN_API_KEY")
57
+ os.environ["LANGCHAIN_PROJECT"] = os.getenv("LANGCHAIN_PROJECT")
58
 
59
  # Basic Models
 
60
  class QueryInput(BaseModel):
61
  query: str = Field(..., min_length=1)
62
 
 
84
  / self.total_requests
85
  )
86
 
 
87
  # Mistral Moderation
 
88
  class ModerationResult(BaseModel):
89
  is_safe: bool
90
  categories: Dict[str, bool]
 
94
  client = Mistral(api_key=mistral_api_key)
95
 
96
  def moderate_text(query: str) -> ModerationResult:
97
+ """Moderates text using Mistral to detect unsafe content."""
 
 
98
  try:
99
  query_input = QueryInput(query=query)
100
  response = client.classifiers.moderate_chat(
 
126
  raise RuntimeError(f"Moderation failed: {e}")
127
 
128
  def compute_moderation_severity(mresult: ModerationResult) -> float:
129
+ """Computes severity score based on moderation flags."""
130
  severity = 0.0
131
  for flag in mresult.categories.values():
132
  if flag:
133
  severity += 0.3
134
  return min(severity, 1.0)
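A quick worked example of the scoring above (illustrative aside, not part of the commit; it only mirrors the arithmetic of compute_moderation_severity):

# Two flagged categories => 0.3 + 0.3 = 0.6, above the 0.5 threshold that later makes
# run_with_chain append its "(Please note: This may involve sensitive content.)" suffix.
flags = {"hate": True, "violence": True, "selfharm": False}
severity = min(sum(0.3 for hit in flags.values() if hit), 1.0)
assert abs(severity - 0.6) < 1e-9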
135
 
 
136
  # Models
 
137
  GROQ_MODELS = {
138
+ "default": "llama3-70b-8192",
139
  "classification": "qwen-qwq-32b",
140
+ "moderation": "mistral-moderation-latest",
141
+ "combination": "llama-3.3-70b-versatile"
142
  }
143
 
144
  MAX_RETRIES = 3
145
  RATE_LIMIT_REQUESTS = 60
146
  CACHE_SIZE_LIMIT = 1000
147
 
148
  class NoCache(BaseCache):
149
+ """No-op cache implementation for ChatGroq."""
150
  def __init__(self):
151
  pass
152
 
 
159
  def clear(self):
160
  pass
161
 
 
162
  ChatGroq.model_rebuild()
163
+
164
  try:
165
  fallback_groq_api_key = os.environ.get("GROQ_API_KEY_FALLBACK", os.environ.get("GROQ_API_KEY"))
166
  if not fallback_groq_api_key:
167
  logger.warning("No Groq API key found for fallback LLM")
168
  groq_fallback_llm = ChatGroq(
169
+ model=GROQ_MODELS["default"],
170
  temperature=0.7,
171
  groq_api_key=fallback_groq_api_key,
172
  max_tokens=2048,
173
+ cache=NoCache(),
174
+ callbacks=[]
175
  )
176
  except Exception as e:
177
  logger.error(f"Failed to initialize fallback Groq LLM: {e}")
178
  raise RuntimeError("ChatGroq initialization failed.") from e
179
+
180
  # Rate-limit & Cache
 
181
  def handle_rate_limiting(state: "PipelineState") -> bool:
182
+ """Enforces rate limiting based on request timestamps."""
183
  current_time = time.time()
184
  one_min_ago = current_time - 60
185
  state.request_timestamps = [t for t in state.request_timestamps if t > one_min_ago]
 
189
  return True
190
 
191
  def manage_cache(state: "PipelineState", query: str, response: str = None) -> Optional[str]:
192
+ """Manages cache for query responses."""
193
  cache_key = query.strip().lower()
194
  if response is None:
195
  return state.cache.get(cache_key)
 
201
  return None
202
 
203
  def create_error_response(error_type: str, details: str = "") -> str:
204
+ """Generates standardized error messages."""
205
  templates = {
206
  "validation": "I couldn't process your query: {details}",
207
  "processing": "I encountered an error while processing: {details}",
208
  "rate_limit": "Too many requests. Please try again soon.",
209
+ "general": "Apologies, but something went wrong."
210
  }
211
  return templates.get(error_type, templates["general"]).format(details=details)
212
 
 
213
  # Web Search
 
214
  web_search_cache: Dict[str, str] = {}
215
 
216
  def store_websearch_result(query: str, result: str):
 
220
  return web_search_cache.get(query.strip().lower())
221
 
222
  def do_web_search(query: str) -> str:
223
+ """Performs web search if no cached result exists."""
224
  try:
225
  cached = retrieve_websearch_result(query)
226
  if cached:
 
228
  return cached
229
 
230
  logger.info("Performing a new web search for: '%s'", query)
231
  search_agent = ToolCallingAgent(
232
+ tools=[DuckDuckGoSearchTool(), VisitWebpageTool()],
233
+ model=HfApiModel(),
234
+ name="search_agent",
235
+ description="This is an agent that can do web search.",
236
  )
237
 
238
  manager_agent = CodeAgent(
239
  tools=[],
240
+ model=HfApiModel(),
241
+ managed_agents=[search_agent]
242
  )
243
 
244
  new_search_result = manager_agent.run(f"Search for information about: {query}")
 
249
  return ""
250
 
251
  def is_greeting(query: str) -> bool:
252
+ """Detects if the query is a greeting."""
 
 
 
 
 
253
  greetings = {"hello", "hi", "hey", "hii", "hola", "greetings"}
 
 
254
  cleaned = re.sub(r'[^\w\s]', '', query).strip().lower()
 
 
255
  words_in_query = set(cleaned.split())
 
 
256
  return not words_in_query.isdisjoint(greetings)
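As an illustrative aside (not part of the commit, and assuming the is_greeting function above is in scope), the greeting check strips punctuation and matches whole words:

assert is_greeting("Hey!!") is True             # "hey" survives punctuation stripping
assert is_greeting("hola, amigo") is True       # any single greeting word is enough
assert is_greeting("How do I sleep better?") is False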
257
 
 
 
258
  # Vector Stores & RAG
 
259
  def build_or_load_vectorstore(csv_path: str, store_dir: str) -> FAISS:
260
+ """Builds or loads FAISS vector store from CSV data."""
261
  if os.path.exists(store_dir):
262
  logger.info(f"Loading existing FAISS store from {store_dir}")
263
  embeddings = HuggingFaceEmbeddings(
264
  model_name="sentence-transformers/multi-qa-mpnet-base-dot-v1"
265
  )
266
+ return FAISS.load_local(store_dir, embeddings, allow_dangerous_deserialization=True)
267
  else:
268
  logger.info(f"Building new FAISS store from {csv_path}")
269
  df = pd.read_csv(csv_path)
 
290
  vectorstore = FAISS.from_documents(docs, embedding=embeddings)
291
  vectorstore.save_local(store_dir)
292
  return vectorstore
293
+
294
  def build_rag_chain(vectorstore: FAISS, llm) -> RetrievalQA:
295
+ """Builds RAG chain for wellness queries."""
296
  prompt = PromptTemplate(
297
  template="""
298
  [INST] You are an AI wellness assistant speaking directly to a user who has asked: "{question}"
 
326
  }
327
  )
328
  return chain
329
+
330
  def build_rag_chain2(vectorstore: FAISS, llm) -> RetrievalQA:
331
+ """Builds RAG chain for brand strategy queries."""
332
  prompt = PromptTemplate(
333
  template="""
334
  [INST] You are the brand strategy advisor for Healthy AI Expert. A team member has asked: "{question}"
 
344
 
345
  Remember our key brand pillars: AI-driven personalization, scientific credibility, user-centric design, and innovation leadership.
346
  [/INST]
 
347
  """,
348
  input_variables=["context", "question"]
349
  )
 
362
  )
363
  return chain
364
 
 
365
  # PipelineState
 
366
  class PipelineState:
367
  _instance = None
368
 
 
378
  self._initialize()
379
 
380
  def _initialize(self):
381
+ """Initializes pipeline state and chains."""
382
  try:
383
  self.metrics = ProcessingMetrics()
384
  self.error_count = 0
 
395
  raise RuntimeError("Pipeline initialization failed.") from e
396
 
397
  def _setup_chains(self):
398
+ """Sets up all processing chains and vector stores."""
399
  self.tailor_chainWellnessBrand = get_tailor_chain_wellnessBrand()
400
  self.classification_chain = get_classification_chain()
401
+ self.refusal_chain = get_refusal_chain()
402
+ self.tailor_chain = get_tailor_chain()
403
+ self.cleaner_chain = get_cleaner_chain()
404
 
 
 
 
 
405
  self.self_harm_chain = LLMChain(llm=groq_fallback_llm, prompt=selfharm_prompt, verbose=False)
 
 
 
 
 
406
  self.frustration_chain = LLMChain(llm=groq_fallback_llm, prompt=frustration_prompt, verbose=False)
 
 
 
 
 
407
  self.ethical_conflict_chain = LLMChain(llm=groq_fallback_llm, prompt=ethical_conflict_prompt, verbose=False)
408
 
409
+ brand_csv = "dataset/BrandAI.csv"
410
+ brand_store = "faiss_brand_store"
 
411
  wellness_csv = "dataset/AIChatbot.csv"
412
  wellness_store = "faiss_wellness_store"
413
 
414
+ brand_vs = build_or_load_vectorstore(brand_csv, brand_store)
415
  wellness_vs = build_or_load_vectorstore(wellness_csv, wellness_store)
416
 
 
 
417
  self.groq_fallback_llm = groq_fallback_llm
418
+ self.brand_rag_chain = build_rag_chain2(brand_vs, self.groq_fallback_llm)
 
 
 
419
  self.wellness_rag_chain = build_rag_chain(wellness_vs, self.groq_fallback_llm)
 
 
420
 
421
  def handle_error(self, error: Exception) -> bool:
422
+ """Handles errors and triggers reset if needed."""
423
  self.error_count += 1
424
  self.metrics.errors += 1
425
  if self.error_count >= MAX_RETRIES:
 
429
  return True
430
 
431
  def reset(self):
432
+ """Resets pipeline state while preserving metrics."""
433
  try:
434
  logger.info("Resetting pipeline state.")
435
  old_metrics = self.metrics
 
445
  raise RuntimeError("Failed to reset pipeline.")
446
 
447
  def get_metrics(self) -> Dict[str, Any]:
448
+ """Returns pipeline performance metrics."""
449
  uptime = (datetime.now() - self.metrics.last_reset).total_seconds() / 3600
450
  return {
451
  "total_requests": self.metrics.total_requests,
 
456
  }
457
 
458
  def update_metrics(self, start_time: float, is_cache_hit: bool = False):
459
+ """Updates processing metrics."""
460
  duration = time.time() - start_time
461
  self.metrics.update_metrics(duration, is_cache_hit)
462
 
463
  pipeline_state = PipelineState()
464
 
465
+ # Helper Checks
 
 
 
466
  def is_aggressive_or_harsh(query: str) -> bool:
467
+ """Detects aggressive or harsh language in query."""
 
 
 
468
  triggers = ["useless", "worthless", "you cannot do anything", "so bad at answering"]
469
  for t in triggers:
470
  if t in query.lower():
 
472
  return False
473
 
474
  def is_ethical_conflict(query: str) -> bool:
475
+ """Detects ethical dilemmas in query."""
 
 
 
476
  ethics_keywords = ["should i lie", "should i cheat", "revenge", "get back at", "hurt them back"]
477
  q_lower = query.lower()
478
  return any(k in q_lower for k in ethics_keywords)
479
 
 
 
480
  # Main Pipeline
 
481
  def run_with_chain(query: str) -> str:
482
+ """Processes query through validation, moderation, and chains."""
483
+ start_time = time.time()
484
+ try:
485
+ if not query or query.strip() == "":
486
+ return create_error_response("validation", "Empty query.")
487
+ if len(query.strip()) < 2:
488
+ return create_error_response("validation", "Too short.")
489
+ words_in_text = re.findall(r'\b\w+\b', query.lower())
490
+ if not any(w in english_words for w in words_in_text):
491
+ return create_error_response("validation", "Unclear words.")
492
+ if len(query) > 500:
493
+ return create_error_response("validation", "Too long (>500).")
494
+ if not handle_rate_limiting(pipeline_state):
495
+ return create_error_response("rate_limit")
496
 
497
+ if is_greeting(query):
498
+ greeting_response = "Hello there!! Welcome to Healthy AI Expert, How may I assist you today?"
499
+ manage_cache(pipeline_state, query, greeting_response)
500
+ pipeline_state.update_metrics(start_time)
501
+ return greeting_response
502
+
503
+ cached = manage_cache(pipeline_state, query)
504
+ if cached:
505
+ pipeline_state.update_metrics(start_time, is_cache_hit=True)
506
+ return cached
507
+
508
+ try:
509
+ mod_res = moderate_text(query)
510
+ severity = compute_moderation_severity(mod_res)
511
+
512
+ if mod_res.categories.get("selfharm", False):
513
+ logger.info("Self-harm flagged => providing supportive chain response.")
514
+ selfharm_resp = pipeline_state.self_harm_chain.run({"query": query})
515
+ final_tailored = pipeline_state.tailor_chain.run({"response": selfharm_resp}).strip()
516
  manage_cache(pipeline_state, query, final_tailored)
517
  pipeline_state.update_metrics(start_time)
518
  return final_tailored
519
+
520
+ if mod_res.categories.get("hate", False):
521
+ logger.info("Hate content => refusal.")
522
+ refusal_resp = pipeline_state.refusal_chain.run({"topic": "moderation_flagged"})
523
+ manage_cache(pipeline_state, query, refusal_resp)
 
524
  pipeline_state.update_metrics(start_time)
525
+ return refusal_resp
526
+
527
+ except Exception as e:
528
+ logger.error(f"Moderation error: {e}")
529
+ severity = 0.0
530
+
531
+ if is_aggressive_or_harsh(query):
532
+ logger.info("Detected harsh/aggressive language => frustration_chain.")
533
+ frustration_resp = pipeline_state.frustration_chain.run({"query": query})
534
+ final_tailored = pipeline_state.tailor_chain.run({"response": frustration_resp}).strip()
535
+ manage_cache(pipeline_state, query, final_tailored)
536
+ pipeline_state.update_metrics(start_time)
537
+ return final_tailored
538
+
539
+ if is_ethical_conflict(query):
540
+ logger.info("Detected ethical dilemma => ethical_conflict_chain.")
541
+ ethical_resp = pipeline_state.ethical_conflict_chain.run({"query": query})
542
+ final_tailored = pipeline_state.tailor_chain.run({"response": ethical_resp}).strip()
543
+ manage_cache(pipeline_state, query, final_tailored)
544
+ pipeline_state.update_metrics(start_time)
545
+ return final_tailored
546
+
547
+ try:
548
+ class_out = pipeline_state.classification_chain.run({"query": query})
549
+ classification = class_out.strip().lower()
550
+ except Exception as e:
551
+ logger.error(f"Classification error: {e}")
552
+ if not pipeline_state.handle_error(e):
553
+ return create_error_response("processing", "Classification error.")
554
+ return create_error_response("processing")
555
+
556
+ if classification in ["outofscope", "out_of_scope"]:
557
  try:
558
+ refusal_text = pipeline_state.refusal_chain.run({"topic": query})
559
+ tailored_refusal = pipeline_state.tailor_chain.run({"response": refusal_text}).strip()
560
+ manage_cache(pipeline_state, query, tailored_refusal)
 
 
561
  pipeline_state.update_metrics(start_time)
562
+ return tailored_refusal
563
  except Exception as e:
564
+ logger.error(f"Refusal chain error: {e}")
565
  if not pipeline_state.handle_error(e):
566
+ return create_error_response("processing", "Refusal error.")
567
  return create_error_response("processing")
568
+
569
+ if classification == "brand":
570
+ rag_chain_main = pipeline_state.brand_rag_chain
571
+ else:
572
+ rag_chain_main = pipeline_state.wellness_rag_chain
573
+
574
+ try:
575
+ rag_output = rag_chain_main({"query": query})
576
+ if isinstance(rag_output, dict) and "result" in rag_output:
577
+ csv_ans = rag_output["result"].strip()
578
+ else:
579
+ csv_ans = str(rag_output).strip()
580
+
581
+ if "not enough context" in csv_ans.lower() or len(csv_ans) < 40:
582
+ logger.info("Insufficient RAG => web search.")
583
+ web_info = do_web_search(query)
584
+ if web_info:
585
+ csv_ans += f"\n\nAdditional info:\n{web_info}"
586
+ except Exception as e:
587
+ logger.error(f"RAG error: {e}")
588
+ if not pipeline_state.handle_error(e):
589
+ return create_error_response("processing", "RAG error.")
590
+ return create_error_response("processing")
591
+
592
+ try:
593
+ final_tailored = pipeline_state.tailor_chainWellnessBrand.run({"response": csv_ans}).strip()
594
+ if severity > 0.5:
595
+ final_tailored += "\n\n(Please note: This may involve sensitive content.)"
596
+
597
+ manage_cache(pipeline_state, query, final_tailored)
598
+ pipeline_state.update_metrics(start_time)
599
+ return final_tailored
600
  except Exception as e:
601
+ logger.error(f"Tailor chain error: {e}")
602
+ if not pipeline_state.handle_error(e):
603
+ return create_error_response("processing", "Tailoring error.")
604
+ return create_error_response("processing")
605
+
606
+ except Exception as e:
607
+ logger.error(f"Critical error in run_with_chain: {e}")
608
+ pipeline_state.metrics.errors += 1
609
+ return create_error_response("general")
610
 
611
+ logger.info("Pipeline initialization complete!")
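
For context, a minimal usage sketch of the cleaned module (a hypothetical caller, not part of this commit; it assumes pipeline.py is on the import path and that GROQ_API_KEY or GROQ_API_KEY_FALLBACK, MISTRAL_API_KEY, LANGCHAIN_API_KEY, and LANGCHAIN_PROJECT are set in the environment):

# demo.py -- hypothetical caller; run_with_chain and pipeline_state are the module's real symbols
from pipeline import run_with_chain, pipeline_state

if __name__ == "__main__":
    reply = run_with_chain("How can I build a better sleep routine?")
    print(reply)
    print(pipeline_state.get_metrics())  # request totals, error rate, uptime, etc.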