Spaces: GIZ /

leavoigt committed
Commit 773f59c · 1 Parent(s): 6e25856

add generator

Files changed (3)
  1. app.py +1 -1
  2. model_params.cfg +1 -9
  3. utils/generator.py +231 -22
app.py CHANGED
@@ -199,7 +199,7 @@ def retrieve_paragraphs(query):
     """Connect to retriever and retrieve paragraphs"""
     try:
         # Call the API with the uploaded file
-        client = Client("https://giz-chatfed-retriever.hf.space/")
+        client = Client("https://giz-eudr-retriever.hf.space/")
         result = client.predict(
             query=query,
             reports_filter="",
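Note: the only change in app.py is the retriever endpoint URL. Below is a minimal sketch of the surrounding call, assuming the gradio_client library that Client/predict come from; the api_name argument and the error handling are illustrative assumptions, not taken from the file.

# Sketch only: how the updated retriever endpoint is typically called via gradio_client.
from gradio_client import Client

def retrieve_paragraphs(query):
    """Connect to retriever and retrieve paragraphs"""
    try:
        # Call the retriever Space with the user query
        client = Client("https://giz-eudr-retriever.hf.space/")
        result = client.predict(
            query=query,
            reports_filter="",
            api_name="/predict",  # assumption: default Gradio endpoint name
        )
        return result
    except Exception as exc:
        # Assumption: fall back to an empty result on connection errors
        print(f"Retriever call failed: {exc}")
        return []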
model_params.cfg CHANGED
@@ -1,12 +1,3 @@
-[retriever]
-MODEL = BAAI/bge-m3
-NORMALIZE = 1
-TOP_K = 20
-
-[ranker]
-MODEL = BAAI/bge-reranker-v2-m3
-TOP_K = 5
-
 [generator]
 PROVIDER = huggingface
 MODEL = meta-llama/Meta-Llama-3-8B-Instruct
@@ -22,5 +13,6 @@ NVIDIA_MODEL = meta-llama/Llama-3.1-8B-Instruct
 NVIDIA_ENDPOINT = https://huggingface.co/api/integrations/dgx/v1
 MAX_TOKENS = 768
 INF_PROVIDER = nebius
+
 [app]
 dropdown_default = Annual Consolidated OAG 2024
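The [retriever] and [ranker] sections are dropped, leaving only the [generator] and [app] settings that this Space still reads. A short sketch of how the trimmed file is consumed with configparser, mirroring getconfig() in utils/generator.py (variable names here are illustrative):

import configparser

config = configparser.ConfigParser()
config.read("model_params.cfg")

provider = config.get("generator", "PROVIDER")            # "huggingface"
model = config.get("generator", "MODEL")                  # "meta-llama/Meta-Llama-3-8B-Instruct"
max_tokens = int(config.get("generator", "MAX_TOKENS"))   # 768
dropdown_default = config.get("app", "dropdown_default")  # "Annual Consolidated OAG 2024"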
utils/generator.py CHANGED
@@ -1,3 +1,18 @@
+import logging
+import asyncio
+import json
+import ast
+from typing import List, Dict, Any, Union
+from dotenv import load_dotenv
+
+# LangChain imports
+from langchain_openai import ChatOpenAI
+from langchain_anthropic import ChatAnthropic
+from langchain_cohere import ChatCohere
+from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
+from langchain_core.messages import SystemMessage, HumanMessage
+
+import os
 import configparser
 
 
@@ -15,42 +30,236 @@ def getconfig(configfile_path: str):
     except:
         logging.warning("config file not found")
 
+# ---------------------------------------------------------------------
+# Provider-agnostic authentication and configuration
+# ---------------------------------------------------------------------
+
+def get_auth(provider: str) -> dict:
+    """Get authentication configuration for different providers"""
+    auth_configs = {
+        "openai": {"api_key": os.getenv("OPENAI_API_KEY")},
+        "huggingface": {"api_key": os.getenv("HF_TOKEN")},
+        "anthropic": {"api_key": os.getenv("ANTHROPIC_API_KEY")},
+        "cohere": {"api_key": os.getenv("COHERE_API_KEY")},
+    }
+
+    if provider not in auth_configs:
+        raise ValueError(f"Unsupported provider: {provider}")
+
+    auth_config = auth_configs[provider]
+    api_key = auth_config.get("api_key")
+
+    if not api_key:
+        raise RuntimeError(f"Missing API key for provider '{provider}'. Please set the appropriate environment variable.")
+
+    return auth_config
+
+# ---------------------------------------------------------------------
+# Model / client initialization (non exaustive list of providers)
+# ---------------------------------------------------------------------
 
 config = getconfig("model_params.cfg")
+
 PROVIDER = config.get("generator", "PROVIDER")
 MODEL = config.get("generator", "MODEL")
 MAX_TOKENS = int(config.get("generator", "MAX_TOKENS"))
 TEMPERATURE = float(config.get("generator", "TEMPERATURE"))
 
+# Set up authentication for the selected provider
+auth_config = get_auth(PROVIDER)
+
+def get_chat_model():
+    """Initialize the appropriate LangChain chat model based on provider"""
+    common_params = {
+        "temperature": TEMPERATURE,
+        "max_tokens": MAX_TOKENS,
+    }
+    logging.info(f"provider is {PROVIDER}")
+
+    if PROVIDER == "openai":
+        return ChatOpenAI(
+            model=MODEL,
+            openai_api_key=auth_config["api_key"],
+            **common_params
+        )
+    elif PROVIDER == "anthropic":
+        return ChatAnthropic(
+            model=MODEL,
+            anthropic_api_key=auth_config["api_key"],
+            **common_params
+        )
+    elif PROVIDER == "cohere":
+        return ChatCohere(
+            model=MODEL,
+            cohere_api_key=auth_config["api_key"],
+            **common_params
+        )
+    elif PROVIDER == "huggingface":
+        # Initialize HuggingFaceEndpoint with explicit parameters
+        llm = HuggingFaceEndpoint(
+            repo_id=MODEL,
+            huggingfacehub_api_token=auth_config["api_key"],
+            task="text-generation",
+            temperature=TEMPERATURE,
+            max_new_tokens=MAX_TOKENS
+        )
+        return ChatHuggingFace(llm=llm)
+    else:
+        raise ValueError(f"Unsupported provider: {PROVIDER}")
+
+# Initialize provider-agnostic chat model
+chat_model = get_chat_model()
 
-def generate_response(chunks, user_query, model):
+# ---------------------------------------------------------------------
+# Context processing - may need further refinement (i.e. to manage other data sources)
+# ---------------------------------------------------------------------
+def extract_relevant_fields(retrieval_results: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
     """
-    Generator function to produce a response text to the initial user query
-    using the retrieved document chunks and a model.
+    Extract only relevant fields from retrieval results.
 
     Args:
-        chunks (list of dict): Retrieved chunks, each with 'answer' and 'score'.
-        user_query (str): The initial query text from the user.
-        model (callable or object): Language model or interface to generate text (stub).
+        retrieval_results: List of JSON objects from retriever
+
+    Returns:
+        List of processed objects with only relevant fields
+    """
+
+    retrieval_results = ast.literal_eval(retrieval_results)
+
+    processed_results = []
 
-    Yields:
-        str: Generated text response to the query.
+    for result in retrieval_results:
+        # Extract the answer content
+        answer = result.get('answer', '')
+
+        # Extract document identification from metadata
+        metadata = result.get('answer_metadata', {})
+        doc_info = {
+            'answer': answer,
+            'filename': metadata.get('filename', 'Unknown'),
+            'page': metadata.get('page', 'Unknown'),
+            'year': metadata.get('year', 'Unknown'),
+            'source': metadata.get('source', 'Unknown'),
+            'document_id': metadata.get('_id', 'Unknown')
+        }
+
+        processed_results.append(doc_info)
+
+    return processed_results
+
+def format_context_from_results(processed_results: List[Dict[str, Any]]) -> str:
     """
+    Format processed retrieval results into a context string for the LLM.
+
+    Args:
+        processed_results: List of processed objects with relevant fields
+
+    Returns:
+        Formatted context string
+    """
+    if not processed_results:
+        return ""
+
+    context_parts = []
+
+    for i, result in enumerate(processed_results, 1):
+        doc_reference = f"[Document {i}: {result['filename']}"
+        if result['page'] != 'Unknown':
+            doc_reference += f", Page {result['page']}"
+        if result['year'] != 'Unknown':
+            doc_reference += f", Year {result['year']}"
+        doc_reference += "]"
+
+        context_part = f"{doc_reference}\n{result['answer']}\n"
+        context_parts.append(context_part)
+
+    return "\n".join(context_parts)
 
-    # Aggregate text from relevant chunks for context
-    context_text = "\n\n".join(chunk['answer'] for chunk in relevant_chunks)
+# ---------------------------------------------------------------------
+# Core generation function for both Gradio UI and MCP
+# ---------------------------------------------------------------------
+async def _call_llm(messages: list) -> str:
+    """
+    Provider-agnostic LLM call using LangChain.
+
+    Args:
+        messages: List of LangChain message objects
+
+    Returns:
+        Generated response content as string
+    """
+    try:
+        # Use async invoke for better performance
+        response = await chat_model.ainvoke(messages)
+        return response.content.strip()
+    except Exception as e:
+        logging.exception(f"LLM generation failed with provider '{PROVIDER}' and model '{MODEL}': {e}")
+        raise
 
-    # Step 3: Compose prompt for model (simulate or call model)
-    prompt = (
-        f"You are an assistant responding to the query:\n\"{user_query}\"\n\n"
-        f"Use the following information to answer:\n{context_text}\n\n"
-        f"Provide a clear, concise, and informative answer."
+def build_messages(question: str, context: str) -> list:
+    """
+    Build messages in LangChain format.
+
+    Args:
+        question: The user's question
+        context: The relevant context for answering
+
+    Returns:
+        List of LangChain message objects
+    """
+    system_content = (
+        "You are an expert assistant. Answer the USER question using only the "
+        "CONTEXT provided. If the context is insufficient say 'I don't know.'"
    )
+
+    user_content = f"### CONTEXT\n{context}\n\n### USER QUESTION\n{question}"
+
+    return [
+        SystemMessage(content=system_content),
+        HumanMessage(content=user_content)
+    ]
 
-    # Step 4: Use model to generate the response text
-    # This is a placeholder; replace with actual model call, e.g.:
-    # response = model.generate_text(prompt)
-    # For demo, just yield the prompt as a stub.
-
-    # Yield the final response once
-    yield f"Simulated response based on retrieved info:\n\n{prompt}"
+
+async def generate(query: str, context: Union[str, List[Dict[str, Any]]]) -> str:
+    """
+    Generate an answer to a query using provided context through RAG.
+
+    This function takes a user query and relevant context, then uses a language model
+    to generate a comprehensive answer based on the provided information.
+
+    Args:
+        query (str): User query
+        context (list): List of retrieval result objects (dictionaries)
+    Returns:
+        str: The generated answer based on the query and context
+    """
+    if not query.strip():
+        return "Error: Query cannot be empty"
+
+    # Handle both string context (for Gradio UI) and list context (from retriever)
+    if isinstance(context, list):
+        if not context:
+            return "Error: No retrieval results provided"
+
+        # Process the retrieval results
+        processed_results = extract_relevant_fields(context)
+        formatted_context = format_context_from_results(processed_results)
+
+        if not formatted_context.strip():
+            return "Error: No valid content found in retrieval results"
+
+    elif isinstance(context, str):
+        if not context.strip():
+            return "Error: Context cannot be empty"
+        formatted_context = context
+
+    else:
+        return "Error: Context must be either a string or list of retrieval results"
+
+    try:
+        messages = build_messages(query, formatted_context)
+        answer = await _call_llm(messages)
+        return answer
+    except Exception as e:
+        logging.exception("Generation failed")
+        return f"Error: {str(e)}"