Spaces:
Sleeping
Sleeping
| import os | |
| import getpass | |
| import spacy | |
| import pandas as pd | |
| from typing import Optional | |
| import subprocess | |
| from langchain.llms.base import LLM | |
| from langchain.docstore.document import Document | |
| from langchain.embeddings import HuggingFaceEmbeddings | |
| from langchain.vectorstores import FAISS | |
| from langchain.chains import RetrievalQA | |
| from smolagents import CodeAgent, DuckDuckGoSearchTool, ManagedAgent, LiteLLMModel | |
| from pydantic_ai import Agent # Import Pydantic AI's Agent | |
| from mistralai import Mistral | |
| import asyncio # Needed for managing async tasks | |
# Mistral API client, authenticated with the MISTRAL_API_KEY environment
# variable (None when the variable is unset).
mistral_api_key = os.getenv("MISTRAL_API_KEY")
client = Mistral(api_key=mistral_api_key)

# Pydantic AI agent backed by Mistral Large with a plain-string result type;
# moderate_text() runs the incoming query through it before moderation.
pydantic_agent = Agent("mistral:mistral-large-latest", result_type=str)
# Load the spaCy English pipeline used for NER, downloading it first if it is missing.
def install_spacy_model():
    """Make sure ``en_core_web_sm`` is available and return it loaded.

    Tries ``spacy.load("en_core_web_sm")``. If that raises ``OSError`` (the
    model package is not installed), the model is downloaded with
    ``<current interpreter> -m spacy download en_core_web_sm`` and then loaded.

    Returns:
        The loaded spaCy pipeline. (Callers that ignore the return value, as
        the original call site did, are unaffected.)

    Raises:
        subprocess.CalledProcessError: if the download command exits non-zero.
    """
    import importlib
    import sys

    try:
        model = spacy.load("en_core_web_sm")
    except OSError:
        print("Downloading spaCy model 'en_core_web_sm'...")
        # sys.executable, not "python" from PATH: install into the same
        # environment this process imports from (venvs, conda, Spaces images).
        subprocess.run(
            [sys.executable, "-m", "spacy", "download", "en_core_web_sm"],
            check=True,
        )
        print("spaCy model 'en_core_web_sm' downloaded successfully.")
        # The package appeared after interpreter start-up; drop cached
        # finder state so the import system can see it.
        importlib.invalidate_caches()
        model = spacy.load("en_core_web_sm")
    else:
        print("spaCy model 'en_core_web_sm' is already installed.")
    return model


# Module-level spaCy pipeline used by extract_main_topic(); loaded exactly once.
nlp = install_spacy_model()
# Function to extract the main topic from the query using spaCy NER
def extract_main_topic(query: str) -> str:
    """Return a short topic phrase for *query*.

    Preference order:
      1. the text of the first named entity labelled ORG, PRODUCT, PERSON,
         GPE or TIME;
      2. otherwise the text of the first token tagged NOUN or PROPN;
      3. otherwise the literal fallback ``"this topic"``.
    """
    parsed = nlp(query)
    wanted_labels = ("ORG", "PRODUCT", "PERSON", "GPE", "TIME")
    topic = next(
        (ent.text for ent in parsed.ents if ent.label_ in wanted_labels),
        None,
    )
    if not topic:
        topic = next(
            (tok.text for tok in parsed if tok.pos_ in ("NOUN", "PROPN")),
            None,
        )
    return topic or "this topic"
# Phrases that route a query straight to "Wellness" without calling the LLM
# classifier. Matched case-insensitively as substrings of the query.
_WELLNESS_KEYWORDS = (
    "box breathing",
    "meditation",
    "yoga",
    "mindfulness",
    "breathing exercises",
)


# Function to classify query based on wellness topics
def classify_query(query: str) -> str:
    """Return a routing label for *query*.

    If the lower-cased query contains any of ``_WELLNESS_KEYWORDS``, returns
    ``"Wellness"`` without calling the classifier. Otherwise it invokes
    ``classification_chain`` and returns the stripped ``"text"`` field of its
    result ("" if the field is absent). The pipeline branches on "Wellness",
    "Brand" and "OutOfScope"; presumably the chain emits one of these.
    """
    lowered = query.lower()  # computed once, not once per keyword
    if any(keyword in lowered for keyword in _WELLNESS_KEYWORDS):
        return "Wellness"
    class_result = classification_chain.invoke({"query": query})
    # The original ternary `c if c != "OutOfScope" else "OutOfScope"` always
    # evaluated to `c`; return it directly.
    return class_result.get("text", "").strip()
# Mistral moderation categories that cause a query to be rejected.
_BLOCKED_MODERATION_CATEGORIES = (
    "violence_and_threats",
    "hate_and_discrimination",
    "dangerous_and_criminal_content",
    "selfharm",
)


# Function to moderate text using Mistral moderation API (async version)
async def moderate_text(query: str) -> str:
    """Screen *query* with the Pydantic AI agent and Mistral chat moderation.

    Returns:
        ``"Invalid text format."`` if ``pydantic_agent.run`` raises;
        ``"OutOfScope"`` if the moderation result flags any category in
        ``_BLOCKED_MODERATION_CATEGORIES``; otherwise *query* unchanged.
    """
    try:
        # The agent's output is discarded; only whether the call succeeds matters.
        await pydantic_agent.run(query)
    except Exception as e:
        print(f"Error validating text: {e}")
        return "Invalid text format."

    # In the mistralai v1 SDK, moderate_chat() is synchronous and cannot be
    # awaited; the awaitable variant is moderate_chat_async().
    response = await client.classifiers.moderate_chat_async(
        model="mistral-moderation-latest",
        inputs=[{"role": "user", "content": query}],
    )
    # The SDK returns a model object rather than a dict, so use attribute access.
    categories = response.results[0].categories or {}
    if any(categories.get(name, False) for name in _BLOCKED_MODERATION_CATEGORIES):
        return "OutOfScope"
    return query
def _tailor(text: str) -> str:
    """Pass *text* through tailor_chain and strip surrounding whitespace."""
    return tailor_chain.run({"response": text}).strip()


def _refusal_response(query: str) -> str:
    """Build the tailored refusal for an out-of-scope *query*."""
    # extract_main_topic falls back to "this topic", the value the refusal
    # prompt was previously hard-coded with.
    refusal_text = refusal_chain.run({"topic": extract_main_topic(query)})
    return _tailor(refusal_text)


# Use the event loop to run the async functions properly
async def run_async_pipeline(query: str) -> str:
    """Moderate, classify and answer *query*, returning the final reply text.

    Flow:
      * moderate_text() rejects harmful queries ("OutOfScope") or reports a
        failed validation ("Invalid text format."); both short-circuit.
      * "Wellness": answer from wellness_rag_chain, falling back to
        do_web_search() only when the knowledge-base answer is empty.
      * "Brand": answer from brand_rag_chain only.
      * Anything else (including "OutOfScope"): a tailored refusal.
    """
    moderated_query = await moderate_text(query)
    if moderated_query == "OutOfScope":
        return "Sorry, this query contains harmful or inappropriate content."
    if moderated_query == "Invalid text format.":
        # moderate_text's failure sentinel. Without this check the sentinel
        # itself would be classified and answered as the user's question.
        return "Sorry, I couldn't process this query right now. Please try again."

    classification = classify_query(moderated_query)

    if classification == "Wellness":
        rag_result = wellness_rag_chain({"query": moderated_query})
        csv_answer = rag_result["result"].strip()
        # Search the web only when the knowledge base produced nothing.
        web_answer = "" if csv_answer else await do_web_search(moderated_query)
        return _tailor(cleaner_chain.merge(kb=csv_answer, web=web_answer))

    if classification == "Brand":
        rag_result = brand_rag_chain({"query": moderated_query})
        csv_answer = rag_result["result"].strip()
        return _tailor(cleaner_chain.merge(kb=csv_answer, web=""))

    # "OutOfScope" and any unrecognised label get the same refusal.
    return _refusal_response(moderated_query)
# Synchronous entry point: drive the async pipeline to completion.
def run_with_chain(query: str) -> str:
    """Run run_async_pipeline(query) in a fresh event loop and return its reply.

    Uses asyncio.run, so it must be called from code that is not already
    running inside an event loop.
    """
    pipeline = run_async_pipeline(query)
    return asyncio.run(pipeline)