import logging
import asyncio
import json
import ast
import os
import configparser
from typing import List, Dict, Any, Union

from dotenv import load_dotenv

# LangChain imports
from langchain_openai import ChatOpenAI
from langchain_anthropic import ChatAnthropic
from langchain_cohere import ChatCohere
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
from langchain_core.messages import SystemMessage, HumanMessage

# Load environment variables (API keys) from a local .env file
load_dotenv()


def getconfig(configfile_path: str) -> configparser.ConfigParser:
    """
    Read the config file.

    Params
    ----------------
    configfile_path: file path of .cfg file
    """
    config = configparser.ConfigParser()
    # config.read() returns the list of files it could actually parse,
    # so an empty list means the config file was not found
    read_files = config.read(configfile_path)
    if not read_files:
        logging.warning(f"Config file not found: {configfile_path}")
    return config


# ---------------------------------------------------------------------
# Provider-agnostic authentication and configuration
# ---------------------------------------------------------------------
def get_auth(provider: str) -> dict:
    """Get the authentication configuration for a given provider."""
    auth_configs = {
        "openai": {"api_key": os.getenv("OPENAI_API_KEY")},
        "huggingface": {"api_key": os.getenv("HF_TOKEN")},
        "anthropic": {"api_key": os.getenv("ANTHROPIC_API_KEY")},
        "cohere": {"api_key": os.getenv("COHERE_API_KEY")},
    }

    if provider not in auth_configs:
        raise ValueError(f"Unsupported provider: {provider}")

    auth_config = auth_configs[provider]
    api_key = auth_config.get("api_key")

    if not api_key:
        raise RuntimeError(
            f"Missing API key for provider '{provider}'. "
            "Please set the appropriate environment variable."
        )

    return auth_config


# ---------------------------------------------------------------------
# Model / client initialization (non-exhaustive list of providers)
# ---------------------------------------------------------------------
config = getconfig("model_params.cfg")

PROVIDER = config.get("generator", "PROVIDER")
MODEL = config.get("generator", "MODEL")
MAX_TOKENS = config.getint("generator", "MAX_TOKENS")
TEMPERATURE = config.getfloat("generator", "TEMPERATURE")
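# For reference, a minimal `model_params.cfg` satisfying the reads above
# might look like the block below. The section and key names come straight
# from the config.get(...) calls; the concrete values are illustrative
# assumptions, not the project's actual settings:
#
#   [generator]
#   PROVIDER = openai
#   MODEL = gpt-4o-mini
#   MAX_TOKENS = 1024
#   TEMPERATURE = 0.2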
# Set up authentication for the selected provider
auth_config = get_auth(PROVIDER)


def get_chat_model():
    """Initialize the appropriate LangChain chat model based on the provider."""
    common_params = {
        "temperature": TEMPERATURE,
        "max_tokens": MAX_TOKENS,
    }
    logging.info(f"provider is {PROVIDER}")

    if PROVIDER == "openai":
        return ChatOpenAI(
            model=MODEL,
            openai_api_key=auth_config["api_key"],
            **common_params
        )
    elif PROVIDER == "anthropic":
        return ChatAnthropic(
            model=MODEL,
            anthropic_api_key=auth_config["api_key"],
            **common_params
        )
    elif PROVIDER == "cohere":
        return ChatCohere(
            model=MODEL,
            cohere_api_key=auth_config["api_key"],
            **common_params
        )
    elif PROVIDER == "huggingface":
        # Initialize HuggingFaceEndpoint with explicit parameters
        llm = HuggingFaceEndpoint(
            repo_id=MODEL,
            huggingfacehub_api_token=auth_config["api_key"],
            task="text-generation",
            temperature=TEMPERATURE,
            max_new_tokens=MAX_TOKENS,
        )
        return ChatHuggingFace(llm=llm)
    else:
        raise ValueError(f"Unsupported provider: {PROVIDER}")


# Initialize the provider-agnostic chat model
chat_model = get_chat_model()


# ---------------------------------------------------------------------
# Context processing - may need further refinement (e.g. to manage other
# data sources)
# ---------------------------------------------------------------------
def extract_relevant_fields(
    retrieval_results: Union[str, List[Dict[str, Any]]]
) -> List[Dict[str, Any]]:
    """
    Extract only the relevant fields from retrieval results.

    Args:
        retrieval_results: List of result objects from the retriever, or a
            string containing the Python-literal representation of that list

    Returns:
        List of processed objects with only the relevant fields
    """
    # The retriever may hand us the result list serialized as a string;
    # only parse in that case, since ast.literal_eval rejects non-strings
    if isinstance(retrieval_results, str):
        retrieval_results = ast.literal_eval(retrieval_results)

    processed_results = []
    for result in retrieval_results:
        # Extract the answer content
        answer = result.get('answer', '')

        # Extract document identification from metadata
        metadata = result.get('answer_metadata', {})
        doc_info = {
            'answer': answer,
            'filename': metadata.get('filename', 'Unknown'),
            'page': metadata.get('page', 'Unknown'),
            'year': metadata.get('year', 'Unknown'),
            'source': metadata.get('source', 'Unknown'),
            'document_id': metadata.get('_id', 'Unknown')
        }
        processed_results.append(doc_info)

    return processed_results


def format_context_from_results(processed_results: List[Dict[str, Any]]) -> str:
    """
    Format processed retrieval results into a context string for the LLM.

    Args:
        processed_results: List of processed objects with relevant fields

    Returns:
        Formatted context string
    """
    if not processed_results:
        return ""

    context_parts = []
    for i, result in enumerate(processed_results, 1):
        doc_reference = f"[Document {i}: {result['filename']}"
        if result['page'] != 'Unknown':
            doc_reference += f", Page {result['page']}"
        if result['year'] != 'Unknown':
            doc_reference += f", Year {result['year']}"
        doc_reference += "]"

        context_part = f"{doc_reference}\n{result['answer']}\n"
        context_parts.append(context_part)

    return "\n".join(context_parts)
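# A minimal sketch of the two-step transformation above. The field names
# match what extract_relevant_fields reads; the values are illustrative only:
#
#   results = [{
#       "answer": "Emissions fell by 12% between 2015 and 2020.",
#       "answer_metadata": {"filename": "report.pdf", "page": 4,
#                           "year": 2021, "source": "UNFCCC", "_id": "doc-1"},
#   }]
#   context = format_context_from_results(extract_relevant_fields(results))
#   # -> "[Document 1: report.pdf, Page 4, Year 2021]\nEmissions fell by ..."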
# ---------------------------------------------------------------------
# Core generation function for both Gradio UI and MCP
# ---------------------------------------------------------------------
async def _call_llm(messages: list) -> str:
    """
    Provider-agnostic LLM call using LangChain.

    Args:
        messages: List of LangChain message objects

    Returns:
        Generated response content as a string
    """
    try:
        # Use the async invoke for better performance
        response = await chat_model.ainvoke(messages)
        return response.content.strip()
    except Exception as e:
        logging.exception(
            f"LLM generation failed with provider '{PROVIDER}' and model '{MODEL}': {e}"
        )
        raise


def build_messages(question: str, context: str) -> list:
    """
    Build messages in LangChain format.

    Args:
        question: The user's question
        context: The relevant context for answering

    Returns:
        List of LangChain message objects
    """
    system_content = (
        "You are an expert assistant. Answer the USER QUESTION using only the "
        "CONTEXT provided. If the context is insufficient, say 'I don't know.'"
    )
    user_content = f"### CONTEXT\n{context}\n\n### USER QUESTION\n{question}"

    return [
        SystemMessage(content=system_content),
        HumanMessage(content=user_content)
    ]


async def generate(query: str, context: Union[str, List[Dict[str, Any]]]) -> str:
    """
    Generate an answer to a query using provided context through RAG.

    This function takes a user query and relevant context, then uses a
    language model to generate a comprehensive answer based on the provided
    information.

    Args:
        query (str): User query
        context (str | list): Pre-formatted context string (for the Gradio UI)
            or list of retrieval result objects (from the retriever)

    Returns:
        str: The generated answer based on the query and context
    """
    if not query.strip():
        return "Error: Query cannot be empty"

    # Handle both string context (for the Gradio UI) and list context
    # (from the retriever)
    if isinstance(context, list):
        if not context:
            return "Error: No retrieval results provided"

        # Process the retrieval results
        processed_results = extract_relevant_fields(context)
        formatted_context = format_context_from_results(processed_results)

        if not formatted_context.strip():
            return "Error: No valid content found in retrieval results"
    elif isinstance(context, str):
        if not context.strip():
            return "Error: Context cannot be empty"
        formatted_context = context
    else:
        return "Error: Context must be either a string or a list of retrieval results"

    try:
        messages = build_messages(query, formatted_context)
        answer = await _call_llm(messages)
        return answer
    except Exception as e:
        logging.exception("Generation failed")
        return f"Error: {str(e)}"
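# Minimal usage sketch: run the module directly to smoke-test generation.
# This assumes valid API credentials and a populated model_params.cfg are
# available; the query and context below are illustrative only.
if __name__ == "__main__":
    demo_context = (
        "[Document 1: report.pdf, Page 4]\n"
        "Emissions fell by 12% between 2015 and 2020.\n"
    )
    demo_answer = asyncio.run(
        generate("What happened to emissions?", demo_context)
    )
    print(demo_answer)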