import asyncio
import configparser
import logging
import os
from typing import Any, Dict, List, Union

from dotenv import load_dotenv

# LangChain imports
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
from langchain_core.messages import SystemMessage, HumanMessage

# Load environment variables from a local .env file, if present.
# (load_dotenv was previously imported but never called, so keys in a
# .env file were never picked up by os.getenv.)
load_dotenv()


def getconfig(configfile_path: str) -> configparser.ConfigParser:
    """
    Read the config file.

    Params
    ----------------
    configfile_path: file path of the .cfg file
    """
    config = configparser.ConfigParser()
    try:
        with open(configfile_path) as f:
            config.read_file(f)
        return config
    except FileNotFoundError:
        # Re-raise instead of silently returning None: callers immediately
        # dereference the result, so a missing file should fail loudly here.
        logging.warning("Config file not found: %s", configfile_path)
        raise

# ---------------------------------------------------------------------
# Provider-agnostic authentication and configuration
# ---------------------------------------------------------------------

def get_auth(provider: str) -> dict:
    """Get authentication configuration for different providers"""
    auth_configs = {
        "openai": {"api_key": os.getenv("OPENAI_API_KEY")},
        "huggingface": {"api_key": os.getenv("HF_TOKEN")},
        "anthropic": {"api_key": os.getenv("ANTHROPIC_API_KEY")},
        "cohere": {"api_key": os.getenv("COHERE_API_KEY")},
    }
    
    if provider not in auth_configs:
        raise ValueError(f"Unsupported provider: {provider}")
    
    auth_config = auth_configs[provider]
    api_key = auth_config.get("api_key")
    
    if not api_key:
        raise RuntimeError(f"Missing API key for provider '{provider}'. Please set the appropriate environment variable.")
    
    return auth_config
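
# Example usage (a sketch; assumes HF_TOKEN is set in the environment):
#   auth = get_auth("huggingface")
#   auth["api_key"]  # -> the value read from the HF_TOKEN environment variable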

# ---------------------------------------------------------------------
# Model / client initialization (non-exhaustive list of providers)
# ---------------------------------------------------------------------
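
# The settings below are read from "model_params.cfg". An illustrative
# layout (placeholder values, not the actual configuration):
#
#   [generator]
#   PROVIDER = huggingface
#   MODEL = meta-llama/Meta-Llama-3-8B-Instruct
#   MAX_TOKENS = 512
#   TEMPERATURE = 0.2
#   INFERENCE_PROVIDER = hf-inference
#   ORGANIZATION = my-org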

config = getconfig("model_params.cfg")

PROVIDER = config.get("generator", "PROVIDER")
MODEL = config.get("generator", "MODEL")
MAX_TOKENS = int(config.get("generator", "MAX_TOKENS"))
TEMPERATURE = float(config.get("generator", "TEMPERATURE"))
INFERENCE_PROVIDER = config.get("generator", "INFERENCE_PROVIDER")
ORGANIZATION = config.get("generator", "ORGANIZATION")

# Set up authentication for the selected provider
auth_config = get_auth(PROVIDER)

def get_chat_model():
    """Initialize the appropriate LangChain chat model based on provider"""
    common_params = {
        "temperature": TEMPERATURE,
        "max_tokens": MAX_TOKENS,
    }
    
    # Currently, the option to use any other generator provider is disabled.
    # if PROVIDER == "openai":
    #     return ChatOpenAI(
    #         model=MODEL,
    #         openai_api_key=auth_config["api_key"],
    #         **common_params
    #     )
    # elif PROVIDER == "anthropic":
    #     return ChatAnthropic(
    #         model=MODEL,
    #         anthropic_api_key=auth_config["api_key"],
    #         **common_params
    #     )
    # elif PROVIDER == "cohere":
    #     return ChatCohere(
    #         model=MODEL,
    #         cohere_api_key=auth_config["api_key"],
    #         **common_params
    #     )
    if PROVIDER == "huggingface":
        # Initialize HuggingFaceEndpoint with explicit parameters
        llm = HuggingFaceEndpoint(
            repo_id=MODEL,
            huggingfacehub_api_token=auth_config["api_key"],
            task="text-generation",
            provider=INFERENCE_PROVIDER,     
            server_kwargs={"bill_to": ORGANIZATION},
            temperature=TEMPERATURE,
            max_new_tokens=MAX_TOKENS
        )
        return ChatHuggingFace(llm=llm)
    else:
        raise ValueError(f"Unsupported provider: {PROVIDER}")

# Initialize provider-agnostic chat model
chat_model = get_chat_model()


# ---------------------------------------------------------------------
# Core generation function for both Gradio UI and MCP
# ---------------------------------------------------------------------
async def _call_llm(messages: list) -> str:
    """
    Provider-agnostic LLM call using LangChain.
    
    Args:
        messages: List of LangChain message objects
        
    Returns:
        Generated response content as string
    """
    try:
        # Use async invoke so concurrent requests don't block each other
        response = await chat_model.ainvoke(messages)
        logging.debug("LLM raw response: %s", response)
        return response.content.strip()
    except Exception:
        logging.exception(
            "LLM generation failed with provider '%s' and model '%s'", PROVIDER, MODEL
        )
        raise
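
# Optional resilience sketch (not used by the functions below): retries
# around _call_llm with exponential backoff. The retry count and backoff
# base are illustrative choices, not part of the original design.
async def _call_llm_with_retry(messages: list, retries: int = 2) -> str:
    """Call _call_llm, retrying transient failures with exponential backoff."""
    for attempt in range(retries + 1):
        try:
            return await _call_llm(messages)
        except Exception:
            if attempt == retries:
                raise
            # Wait 1s, 2s, 4s, ... between attempts
            await asyncio.sleep(2 ** attempt)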

def build_messages(question: str, context: str) -> list:
    """
    Build messages in LangChain format.
    
    Args:
        question: The user's question
        context: The relevant context for answering
        
    Returns:
        List of LangChain message objects
    """
    system_content = (
        """
        You are an expert assistant. Your task is to generate accurate, helpful responses using only the 
        information contained in the "CONTEXT" provided.

        Instructions:
        - Answer based only on the provided context: Use only the information present in the CONTEXT below. Do not use any external knowledge or make assumptions beyond what is explicitly stated.
        - Language matching: Respond in the same language as the user's query.
        - Handle missing information: If the retrieved paragraphs do not contain sufficient information to answer the query, respond with "I don't know" or equivalent in the query language. If information is incomplete, state what you know and acknowledge limitations.
        - Be accurate and specific: When information is available, provide clear, specific answers. Include relevant details, useful facts, and numbers from the context.
        - Stay focused: Answer only what is asked. Do not provide additional information not requested.
        - Structure your response effectively:
                * Do not just summarize each passage one by one. Group your summaries to highlight the key parts in the explanation.
                * Use bullet points and lists when it makes sense to improve readability.
                * You do not need to use every passage. Only use the ones that help answer the question.
        - Format your response properly: Use markdown formatting (bullet points, numbered lists, headers) to make your response clear and easy to read; <br> may be used for line breaks.

        Input Format:
        The user message contains a "### CONTEXT" section holding the retrieved paragraphs, followed by a "### USER QUESTION" section holding the query.

        Generate your response based on these guidelines.

        """
    )
    
    user_content = f"### CONTEXT\n{context}\n\n### USER QUESTION\n{question}"
    
    return [
        SystemMessage(content=system_content),
        HumanMessage(content=user_content)
    ]
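
# Illustrative call (hypothetical inputs), showing the resulting structure:
#   msgs = build_messages("What are the NDC targets?", "Paragraph 1 ...")
#   msgs[0]  # SystemMessage carrying the answering guidelines
#   msgs[1]  # HumanMessage: "### CONTEXT\n...\n\n### USER QUESTION\n..."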

async def generate(query: str, context: Union[str, List[Dict[str, Any]]]) -> str:
    """
    Generate an answer to a query using provided context through RAG.
    
    This function takes a user query and relevant context, then uses a language model
    to generate a comprehensive answer based on the provided information.
    
    Args:
        query (str): User query
        context (str | list): A preformatted context string or a list of retrieval result objects (dictionaries)
    Returns:
        str: The generated answer based on the query and context
    """
    if not query.strip():
        return "Error: Query cannot be empty"
    
    # Handle both string context (for the Gradio UI) and list context (from the retriever)
    if isinstance(context, list):
        if not context:
            return "Error: No retrieval results provided"

        # Join the retrieval results into a single text block; each result
        # may be a dict or a plain string, so fall back to its string
        # representation.
        formatted_context = "\n\n".join(str(item) for item in context)
    
    elif isinstance(context, str):
        if not context.strip():
            return "Error: Context cannot be empty"
        formatted_context = context
    
    else:
        return "Error: Context must be either a string or list of retrieval results"
    
    try:
        messages = build_messages(query, formatted_context)
        answer = await _call_llm(messages)

        return answer
        
    except Exception as e:
        logging.exception("Generation failed")
        return f"Error: {str(e)}"