import logging
import asyncio
import json
import ast
import os
import configparser
from typing import List, Dict, Any, Union

from dotenv import load_dotenv

# LangChain imports
from langchain_openai import ChatOpenAI
from langchain_anthropic import ChatAnthropic
from langchain_cohere import ChatCohere
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
from langchain_core.messages import SystemMessage, HumanMessage
def getconfig(configfile_path: str):
    """
    Read the config file.

    Params
    ----------------
    configfile_path: file path of the .cfg file
    """
    config = configparser.ConfigParser()
    try:
        with open(configfile_path) as configfile:
            config.read_file(configfile)
        return config
    except FileNotFoundError:
        logging.warning("Config file not found: %s", configfile_path)
# ---------------------------------------------------------------------
# Provider-agnostic authentication and configuration
# ---------------------------------------------------------------------
def get_auth(provider: str) -> dict:
    """Get authentication configuration for different providers"""
    auth_configs = {
        "openai": {"api_key": os.getenv("OPENAI_API_KEY")},
        "huggingface": {"api_key": os.getenv("HF_TOKEN")},
        "anthropic": {"api_key": os.getenv("ANTHROPIC_API_KEY")},
        "cohere": {"api_key": os.getenv("COHERE_API_KEY")},
    }
    if provider not in auth_configs:
        raise ValueError(f"Unsupported provider: {provider}")
    auth_config = auth_configs[provider]
    api_key = auth_config.get("api_key")
    if not api_key:
        raise RuntimeError(
            f"Missing API key for provider '{provider}'. "
            "Please set the appropriate environment variable."
        )
    return auth_config
# ---------------------------------------------------------------------
# Model / client initialization (non-exhaustive list of providers)
# ---------------------------------------------------------------------
# Load environment variables (e.g. API keys) from a local .env file, if present
load_dotenv()

config = getconfig("model_params.cfg")
PROVIDER = config.get("generator", "PROVIDER")
MODEL = config.get("generator", "MODEL")
MAX_TOKENS = int(config.get("generator", "MAX_TOKENS"))
TEMPERATURE = float(config.get("generator", "TEMPERATURE"))

# Set up authentication for the selected provider
auth_config = get_auth(PROVIDER)
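# Illustrative model_params.cfg: the [generator] section and key names match the
# config.get(...) calls above, but the values shown here are placeholders, not
# the project's actual settings.
#
#   [generator]
#   PROVIDER = openai
#   MODEL = gpt-4o-mini
#   MAX_TOKENS = 512
#   TEMPERATURE = 0.2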
def get_chat_model():
    """Initialize the appropriate LangChain chat model based on provider"""
    common_params = {
        "temperature": TEMPERATURE,
        "max_tokens": MAX_TOKENS,
    }
    logging.info(f"provider is {PROVIDER}")
    if PROVIDER == "openai":
        return ChatOpenAI(
            model=MODEL,
            openai_api_key=auth_config["api_key"],
            **common_params
        )
    elif PROVIDER == "anthropic":
        return ChatAnthropic(
            model=MODEL,
            anthropic_api_key=auth_config["api_key"],
            **common_params
        )
    elif PROVIDER == "cohere":
        return ChatCohere(
            model=MODEL,
            cohere_api_key=auth_config["api_key"],
            **common_params
        )
    elif PROVIDER == "huggingface":
        # Initialize HuggingFaceEndpoint with explicit parameters
        llm = HuggingFaceEndpoint(
            repo_id=MODEL,
            huggingfacehub_api_token=auth_config["api_key"],
            task="text-generation",
            temperature=TEMPERATURE,
            max_new_tokens=MAX_TOKENS
        )
        return ChatHuggingFace(llm=llm)
    else:
        raise ValueError(f"Unsupported provider: {PROVIDER}")
# Initialize the provider-agnostic chat model
chat_model = get_chat_model()
# ---------------------------------------------------------------------
# Context processing - may need further refinement (e.g. to handle other data sources)
# ---------------------------------------------------------------------
def extract_relevant_fields(retrieval_results: Union[str, List[Dict[str, Any]]]) -> List[Dict[str, Any]]:
    """
    Extract only the relevant fields from retrieval results.

    Args:
        retrieval_results: List of result objects from the retriever, or a
            string representation of such a list

    Returns:
        List of processed objects with only the relevant fields
    """
    # The retriever may hand the results over as a serialized string; parse it if so
    if isinstance(retrieval_results, str):
        retrieval_results = ast.literal_eval(retrieval_results)
    processed_results = []
    for result in retrieval_results:
        # Extract the answer content
        answer = result.get('answer', '')
        # Extract document identification from metadata
        metadata = result.get('answer_metadata', {})
        doc_info = {
            'answer': answer,
            'filename': metadata.get('filename', 'Unknown'),
            'page': metadata.get('page', 'Unknown'),
            'year': metadata.get('year', 'Unknown'),
            'source': metadata.get('source', 'Unknown'),
            'document_id': metadata.get('_id', 'Unknown')
        }
        processed_results.append(doc_info)
    return processed_results
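# Illustrative shape of a single retrieval result as consumed above. Only the
# 'answer' and 'answer_metadata' keys are taken from the code; the field values
# are placeholders, not real data.
#
#   {
#       'answer': 'Total emissions fell by 4% between 2019 and 2021.',
#       'answer_metadata': {
#           'filename': 'climate_report.pdf',
#           'page': 12,
#           'year': 2022,
#           'source': 'national_reports',
#           '_id': 'doc-00042'
#       }
#   }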
def format_context_from_results(processed_results: List[Dict[str, Any]]) -> str:
    """
    Format processed retrieval results into a context string for the LLM.

    Args:
        processed_results: List of processed objects with relevant fields

    Returns:
        Formatted context string
    """
    if not processed_results:
        return ""
    context_parts = []
    for i, result in enumerate(processed_results, 1):
        doc_reference = f"[Document {i}: {result['filename']}"
        if result['page'] != 'Unknown':
            doc_reference += f", Page {result['page']}"
        if result['year'] != 'Unknown':
            doc_reference += f", Year {result['year']}"
        doc_reference += "]"
        context_part = f"{doc_reference}\n{result['answer']}\n"
        context_parts.append(context_part)
    return "\n".join(context_parts)
# ---------------------------------------------------------------------
# Core generation function for both the Gradio UI and MCP
# ---------------------------------------------------------------------
async def _call_llm(messages: list) -> str:
    """
    Provider-agnostic LLM call using LangChain.

    Args:
        messages: List of LangChain message objects

    Returns:
        Generated response content as a string
    """
    try:
        # Use async invoke for better performance
        response = await chat_model.ainvoke(messages)
        return response.content.strip()
    except Exception as e:
        logging.exception(f"LLM generation failed with provider '{PROVIDER}' and model '{MODEL}': {e}")
        raise
def build_messages(question: str, context: str) -> list:
    """
    Build messages in LangChain format.

    Args:
        question: The user's question
        context: The relevant context for answering

    Returns:
        List of LangChain message objects
    """
    system_content = (
        "You are an expert assistant. Answer the USER question using only the "
        "CONTEXT provided. If the context is insufficient, say 'I don't know.'"
    )
    user_content = f"### CONTEXT\n{context}\n\n### USER QUESTION\n{question}"
    return [
        SystemMessage(content=system_content),
        HumanMessage(content=user_content)
    ]
async def generate(query: str, context: Union[str, List[Dict[str, Any]]]) -> str:
    """
    Generate an answer to a query using the provided context through RAG.

    This function takes a user query and relevant context, then uses a language model
    to generate a comprehensive answer based on the provided information.

    Args:
        query (str): User query
        context (str | list): Pre-formatted context string (e.g. from the Gradio UI)
            or a list of retrieval result objects (dictionaries)

    Returns:
        str: The generated answer based on the query and context
    """
    if not query.strip():
        return "Error: Query cannot be empty"

    # Handle both string context (from the Gradio UI) and list context (from the retriever)
    if isinstance(context, list):
        if not context:
            return "Error: No retrieval results provided"
        # Process the retrieval results
        processed_results = extract_relevant_fields(context)
        formatted_context = format_context_from_results(processed_results)
        if not formatted_context.strip():
            return "Error: No valid content found in retrieval results"
    elif isinstance(context, str):
        if not context.strip():
            return "Error: Context cannot be empty"
        formatted_context = context
    else:
        return "Error: Context must be either a string or a list of retrieval results"

    try:
        messages = build_messages(query, formatted_context)
        answer = await _call_llm(messages)
        return answer
    except Exception as e:
        logging.exception("Generation failed")
        return f"Error: {str(e)}"