import os
import getpass
import subprocess
from typing import Optional

import spacy
import pandas as pd
from langchain.docstore.document import Document
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from langchain.chains import RetrievalQA
from langchain.llms.base import LLM
from smolagents import CodeAgent, DuckDuckGoSearchTool, ManagedAgent, LiteLLMModel

# Mistral client setup
from mistralai import Mistral
from pydantic_ai import Agent  # Pydantic AI's Agent, used here for input validation

# Initialize the Mistral API client
mistral_api_key = os.environ.get("MISTRAL_API_KEY")  # Ensure your Mistral API key is set
client = Mistral(api_key=mistral_api_key)

# Initialize the Pydantic AI agent (used for text validation)
pydantic_agent = Agent('mistral:mistral-large-latest', result_type=str)
# Load the spaCy model for NER, downloading it first if it is not already installed
def install_spacy_model():
    try:
        spacy.load("en_core_web_sm")
        print("spaCy model 'en_core_web_sm' is already installed.")
    except OSError:
        print("Downloading spaCy model 'en_core_web_sm'...")
        subprocess.run(["python", "-m", "spacy", "download", "en_core_web_sm"], check=True)
        print("spaCy model 'en_core_web_sm' downloaded successfully.")

install_spacy_model()
nlp = spacy.load("en_core_web_sm")
# Extract the main topic from the query using spaCy NER
def extract_main_topic(query: str) -> str:
    """
    Extracts the main topic from the user's query using spaCy's NER.
    Returns the first relevant named entity, or the first (proper) noun as a fallback.
    """
    doc = nlp(query)

    # Try to extract the main topic as a named entity (organization, product, person, etc.)
    main_topic = None
    for ent in doc.ents:
        # Filter for specific entity types; adjust these labels to your needs
        if ent.label_ in ["ORG", "PRODUCT", "PERSON", "GPE", "TIME"]:
            main_topic = ent.text
            break

    # If no named entity was found, fall back to the first noun or proper noun
    if not main_topic:
        for token in doc:
            if token.pos_ in ["NOUN", "PROPN"]:
                main_topic = token.text
                break

    # Return the extracted topic, or a generic fallback if nothing was found
    return main_topic if main_topic else "this topic"
# Moderate text using the Mistral moderation API
def moderate_text(query: str) -> str:
    """
    Classifies the query as harmful or not using the Mistral moderation API.
    Returns "OutOfScope" if the query is flagged, otherwise returns the original query.
    """
    try:
        pydantic_agent.run_sync(query)  # Validate that the input is well-formed text
    except Exception as e:
        print(f"Error validating text: {e}")
        return "Invalid text format."

    response = client.classifiers.moderate_chat(
        model="mistral-moderation-latest",
        inputs=[{"role": "user", "content": query}],
    )
    # Read the category flags from the first moderation result
    categories = response.results[0].categories
    if categories.get("violence_and_threats", False) or \
       categories.get("hate_and_discrimination", False) or \
       categories.get("dangerous_and_criminal_content", False) or \
       categories.get("selfharm", False):
        return "OutOfScope"
    return query
# Build a new FAISS vectorstore from a CSV, or load an existing one from disk
def build_or_load_vectorstore(csv_path: str, store_dir: str) -> FAISS:
    if os.path.exists(store_dir):
        print(f"DEBUG: Found existing FAISS store at '{store_dir}'. Loading...")
        embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/multi-qa-mpnet-base-dot-v1")
        vectorstore = FAISS.load_local(store_dir, embeddings)
        return vectorstore
    else:
        print(f"DEBUG: Building new store from CSV: {csv_path}")
        df = pd.read_csv(csv_path)
        df = df.loc[:, ~df.columns.str.contains('^Unnamed')]
        df.columns = df.columns.str.strip()
        if "Answer" in df.columns:
            df.rename(columns={"Answer": "Answers"}, inplace=True)
        if "Question" not in df.columns and "Question " in df.columns:
            df.rename(columns={"Question ": "Question"}, inplace=True)
        if "Question" not in df.columns or "Answers" not in df.columns:
            raise ValueError("CSV must have 'Question' and 'Answers' columns.")
        docs = []
        for _, row in df.iterrows():
            q = str(row["Question"])
            ans = str(row["Answers"])
            doc = Document(page_content=ans, metadata={"question": q})
            docs.append(doc)
        embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/multi-qa-mpnet-base-dot-v1")
        vectorstore = FAISS.from_documents(docs, embedding=embeddings)
        vectorstore.save_local(store_dir)
        return vectorstore
# Build a RetrievalQA (RAG) chain that wraps the Gemini model as a LangChain LLM
def build_rag_chain(llm_model: LiteLLMModel, vectorstore: FAISS) -> RetrievalQA:
    class GeminiLangChainLLM(LLM):
        def _call(self, prompt: str, stop: Optional[list] = None, **kwargs) -> str:
            messages = [{"role": "user", "content": prompt}]
            return llm_model(messages, stop_sequences=stop)

        @property
        def _llm_type(self) -> str:
            return "custom_gemini"

    retriever = vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": 3})
    gemini_as_llm = GeminiLangChainLLM()
    rag_chain = RetrievalQA.from_chain_type(
        llm=gemini_as_llm,
        chain_type="stuff",
        retriever=retriever,
        return_source_documents=True,
    )
    return rag_chain
# Initialize the separate chains
from classification_chain import get_classification_chain
from refusal_chain import get_refusal_chain
from tailor_chain import get_tailor_chain
from cleaner_chain import get_cleaner_chain

classification_chain = get_classification_chain()
refusal_chain = get_refusal_chain()
tailor_chain = get_tailor_chain()
cleaner_chain = get_cleaner_chain()

# Build our vectorstores + RAG chains
wellness_csv = "AIChatbot.csv"
brand_csv = "BrandAI.csv"
wellness_store_dir = "faiss_wellness_store"
brand_store_dir = "faiss_brand_store"
wellness_vectorstore = build_or_load_vectorstore(wellness_csv, wellness_store_dir)
brand_vectorstore = build_or_load_vectorstore(brand_csv, brand_store_dir)

gemini_llm = LiteLLMModel(model_id="gemini/gemini-pro", api_key=os.environ.get("GEMINI_API_KEY"))
wellness_rag_chain = build_rag_chain(gemini_llm, wellness_vectorstore)
brand_rag_chain = build_rag_chain(gemini_llm, brand_vectorstore)

# Tools / agents for web search
search_tool = DuckDuckGoSearchTool()
web_agent = CodeAgent(tools=[search_tool], model=gemini_llm)
managed_web_agent = ManagedAgent(agent=web_agent, name="web_search", description="Runs web search for you.")
manager_agent = CodeAgent(tools=[], model=gemini_llm, managed_agents=[managed_web_agent])

def do_web_search(query: str) -> str:
    print("DEBUG: Attempting web search for more info...")
    search_query = f"Give me relevant info: {query}"
    response = manager_agent.run(search_query)
    return response
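# NOTE: run_with_chain() below calls classify_query(), which is not defined in this
# file. The helper below is a minimal sketch of what it is assumed to do: route the
# query through the existing classification_chain, with a manual special case so that
# box-breathing questions are recognized as Wellness (matching the comment inside
# run_with_chain). Replace it with the real implementation if one exists elsewhere.
def classify_query(query: str) -> str:
    # Hypothetical special case: treat box-breathing questions as Wellness up front
    if "box breathing" in query.lower():
        return "Wellness"
    # Otherwise defer to the classification chain, using the same result pattern
    # ("text" key) as the original chain-based orchestrator
    class_result = classification_chain.invoke({"query": query})
    classification = class_result.get("text", "").strip()
    return classification if classification else "OutOfScope"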
# Orchestrator: run_with_chain
def run_with_chain(query: str) -> str:
    print("DEBUG: Starting run_with_chain...")

    # Moderate the query for harmful content
    moderated_query = moderate_text(query)
    if moderated_query == "OutOfScope":
        return "Sorry, this query contains harmful or inappropriate content."

    # Classify the query, ensuring topics such as box breathing are recognized
    classification = classify_query(moderated_query)
    print("DEBUG: Classification =>", classification)

    if classification == "OutOfScope":
        refusal_text = refusal_chain.run({"topic": "this topic"})
        final_refusal = tailor_chain.run({"response": refusal_text})
        return final_refusal.strip()

    if classification == "Wellness":
        rag_result = wellness_rag_chain({"query": moderated_query})
        csv_answer = rag_result["result"].strip()
        # Fall back to web search when the knowledge base has no usable answer
        if not csv_answer:
            web_answer = do_web_search(moderated_query)
        else:
            lower_ans = csv_answer.lower()
            if any(phrase in lower_ans for phrase in ["i do not know", "not sure", "no context", "cannot answer"]):
                web_answer = do_web_search(moderated_query)
            else:
                web_answer = ""
        final_merged = cleaner_chain.merge(kb=csv_answer, web=web_answer)
        final_answer = tailor_chain.run({"response": final_merged})
        return final_answer.strip()

    if classification == "Brand":
        rag_result = brand_rag_chain({"query": moderated_query})
        csv_answer = rag_result["result"].strip()
        final_merged = cleaner_chain.merge(kb=csv_answer, web="")
        final_answer = tailor_chain.run({"response": final_merged})
        return final_answer.strip()

    # Any other classification falls through to a refusal
    refusal_text = refusal_chain.run({"topic": "this topic"})
    final_refusal = tailor_chain.run({"response": refusal_text})
    return final_refusal.strip()
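
# Example usage, assuming MISTRAL_API_KEY and GEMINI_API_KEY are set and the CSV files
# (AIChatbot.csv, BrandAI.csv) are present; the sample query is illustrative only.
if __name__ == "__main__":
    print(run_with_chain("What is box breathing and how does it help with stress?"))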