Spaces:
Sleeping
Sleeping
Commit
·
95a7e5a
1
Parent(s):
37952cf
Update app.py
Browse files
app.py
CHANGED
@@ -1,8 +1,9 @@
|
|
1 |
import os
|
2 |
import pickle
|
3 |
import numpy as np
|
4 |
-
from
|
5 |
-
from
|
|
|
6 |
from transformers import (
|
7 |
AutoTokenizer,
|
8 |
AutoModelForSeq2SeqLM,
|
@@ -17,13 +18,26 @@ import nltk
|
|
17 |
import torch
|
18 |
import pandas as pd
|
19 |
|
20 |
-
|
21 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
22 |
|
23 |
# Global variables for models and data
|
24 |
models = {}
|
25 |
data = {}
|
26 |
|
|
|
|
|
|
|
|
|
27 |
def init_nltk():
|
28 |
"""Initialize NLTK resources"""
|
29 |
try:
|
@@ -78,7 +92,6 @@ def load_embeddings():
|
|
78 |
print(f"Error: {embeddings_path} not found")
|
79 |
return False
|
80 |
|
81 |
-
# Custom unpickler to handle potential compatibility issues
|
82 |
class CustomUnpickler(pickle.Unpickler):
|
83 |
def find_class(self, module, name):
|
84 |
if module == "__main__":
|
@@ -193,11 +206,9 @@ def generate_answer(query, context, max_length=860, temperature=0.2):
|
|
193 |
|
194 |
response = models['llm_tokenizer'].decode(outputs[0], skip_special_tokens=True)
|
195 |
|
196 |
-
# Clean up the response
|
197 |
if "Answer:" in response:
|
198 |
response = response.split("Answer:")[-1].strip()
|
199 |
|
200 |
-
# Remove incomplete sentences at the end
|
201 |
sentences = nltk.sent_tokenize(response)
|
202 |
if sentences:
|
203 |
return " ".join(sentences)
|
@@ -247,8 +258,8 @@ def rerank_documents(query, doc_texts):
|
|
247 |
print(f"Error reranking documents: {e}")
|
248 |
return np.zeros(len(doc_texts))
|
249 |
|
250 |
-
@app.
|
251 |
-
def health_check():
|
252 |
"""Health check endpoint"""
|
253 |
status = {
|
254 |
'status': 'healthy',
|
@@ -256,95 +267,79 @@ def health_check():
|
|
256 |
'embeddings_loaded': bool(data.get('embeddings')),
|
257 |
'documents_loaded': not data.get('df', pd.DataFrame()).empty
|
258 |
}
|
259 |
-
return
|
260 |
|
261 |
-
@app.
|
262 |
-
def process_query():
|
263 |
"""Main query processing endpoint"""
|
264 |
try:
|
265 |
-
|
266 |
-
|
267 |
-
|
268 |
-
data = request.json
|
269 |
-
if not data or 'query' not in data:
|
270 |
-
return jsonify({'error': 'No query provided', 'success': False}), 400
|
271 |
|
272 |
-
query_text = data['query']
|
273 |
-
language_code = data.get('language_code', 0)
|
274 |
-
|
275 |
-
# Basic response if no models or data are loaded
|
276 |
if not models or not data.get('embeddings'):
|
277 |
-
|
278 |
-
|
279 |
-
|
280 |
-
|
281 |
|
282 |
-
# Process query with available models and data
|
283 |
try:
|
284 |
-
# Handle Arabic queries
|
285 |
if language_code == 0:
|
286 |
query_text = translate_text(query_text, 'ar_to_en')
|
287 |
|
288 |
-
# Get query embedding and find relevant documents
|
289 |
query_embedding = models['embedding'].encode([query_text])
|
290 |
relevant_docs = query_embeddings(query_embedding)
|
291 |
|
292 |
if not relevant_docs:
|
293 |
-
return
|
294 |
'answer': 'No relevant information found. Please try a different query.',
|
295 |
'success': True
|
296 |
-
}
|
297 |
|
298 |
-
# Retrieve and process documents
|
299 |
doc_texts = [retrieve_document_text(doc_id) for doc_id, _ in relevant_docs]
|
300 |
doc_texts = [text for text in doc_texts if text.strip()]
|
301 |
|
302 |
if not doc_texts:
|
303 |
-
return
|
304 |
'answer': 'Unable to retrieve relevant documents. Please try again.',
|
305 |
'success': True
|
306 |
-
}
|
307 |
|
308 |
-
# Rerank documents
|
309 |
rerank_scores = rerank_documents(query_text, doc_texts)
|
310 |
ranked_texts = [text for _, text in sorted(zip(rerank_scores, doc_texts), reverse=True)]
|
311 |
|
312 |
-
# Combine top documents
|
313 |
context = " ".join(ranked_texts[:3])
|
314 |
-
|
315 |
-
# Generate answer
|
316 |
answer = generate_answer(query_text, context)
|
317 |
|
318 |
-
# Translate answer back to Arabic if needed
|
319 |
if language_code == 0:
|
320 |
answer = translate_text(answer, 'en_to_ar')
|
321 |
|
322 |
-
return
|
323 |
'answer': answer,
|
324 |
'success': True
|
325 |
-
}
|
326 |
|
327 |
except Exception as e:
|
328 |
print(f"Error processing query: {e}")
|
329 |
-
|
330 |
-
|
331 |
-
|
332 |
-
|
333 |
|
334 |
except Exception as e:
|
335 |
print(f"Error in process_query: {e}")
|
336 |
-
|
337 |
-
|
338 |
-
|
339 |
-
|
340 |
|
341 |
-
# Initialize
|
342 |
print("Initializing application...")
|
343 |
init_success = init_nltk() and load_models() and load_data()
|
344 |
|
345 |
if not init_success:
|
346 |
print("Warning: Application initialized with partial functionality")
|
347 |
|
|
|
348 |
if __name__ == "__main__":
|
349 |
-
|
350 |
-
|
|
|
1 |
import os
|
2 |
import pickle
|
3 |
import numpy as np
|
4 |
+
from fastapi import FastAPI, HTTPException
|
5 |
+
from fastapi.middleware.cors import CORSMiddleware
|
6 |
+
from pydantic import BaseModel
|
7 |
from transformers import (
|
8 |
AutoTokenizer,
|
9 |
AutoModelForSeq2SeqLM,
|
|
|
18 |
import torch
|
19 |
import pandas as pd
|
20 |
|
21 |
+
# Create the ASGI application object that the route decorators attach to.
app = FastAPI()

# Cross-origin settings: every origin, method and header is accepted, and
# credentialed requests are enabled.
# NOTE(review): the FastAPI CORS docs advise against combining wildcard
# origins with allow_credentials=True -- confirm whether credentials are
# actually needed, or list the allowed origins explicitly.
cors_settings = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **cors_settings)
|
32 |
|
33 |
# Global variables for models and data
|
34 |
models = {}
|
35 |
data = {}
|
36 |
|
37 |
+
class QueryRequest(BaseModel):
    """Request body for the query endpoint (`process_query`)."""

    # Text of the user's question; required.
    query: str

    # Language selector, defaulting to 0. `process_query` translates the
    # query with 'ar_to_en' (and the answer with 'en_to_ar') when this is 0.
    language_code: int = 0
|
40 |
+
|
41 |
def init_nltk():
|
42 |
"""Initialize NLTK resources"""
|
43 |
try:
|
|
|
92 |
print(f"Error: {embeddings_path} not found")
|
93 |
return False
|
94 |
|
|
|
95 |
class CustomUnpickler(pickle.Unpickler):
|
96 |
def find_class(self, module, name):
|
97 |
if module == "__main__":
|
|
|
206 |
|
207 |
response = models['llm_tokenizer'].decode(outputs[0], skip_special_tokens=True)
|
208 |
|
|
|
209 |
if "Answer:" in response:
|
210 |
response = response.split("Answer:")[-1].strip()
|
211 |
|
|
|
212 |
sentences = nltk.sent_tokenize(response)
|
213 |
if sentences:
|
214 |
return " ".join(sentences)
|
|
|
258 |
print(f"Error reranking documents: {e}")
|
259 |
return np.zeros(len(doc_texts))
|
260 |
|
261 |
+
@app.get("/health")
|
262 |
+
async def health_check():
|
263 |
"""Health check endpoint"""
|
264 |
status = {
|
265 |
'status': 'healthy',
|
|
|
267 |
'embeddings_loaded': bool(data.get('embeddings')),
|
268 |
'documents_loaded': not data.get('df', pd.DataFrame()).empty
|
269 |
}
|
270 |
+
return status
|
271 |
|
272 |
+
@app.post("/api/query")
async def process_query(request: QueryRequest):
    """Answer a user query from the loaded documents.

    Steps: translate the query with ``'ar_to_en'`` when ``language_code`` is
    0, embed it, look up relevant documents, rerank their texts, generate an
    answer from the three best-ranked texts, and translate the answer with
    ``'en_to_ar'`` when ``language_code`` is 0.

    Args:
        request: Validated request body carrying ``query`` and
            ``language_code``.

    Returns:
        A dict with the keys ``'answer'`` and ``'success'`` (``True``).

    Raises:
        HTTPException: 503 when models or embeddings are not loaded;
            500 when any step of the pipeline fails.
    """
    try:
        query_text = request.query
        language_code = request.language_code

        # Refuse early while the models/embeddings are not available.
        if not models or not data.get('embeddings'):
            raise HTTPException(
                status_code=503,
                detail="The system is currently initializing. Please try again in a few minutes."
            )

        try:
            # language_code 0: translate the incoming query before retrieval.
            if language_code == 0:
                query_text = translate_text(query_text, 'ar_to_en')

            # Embed the query and find candidate documents.
            query_embedding = models['embedding'].encode([query_text])
            relevant_docs = query_embeddings(query_embedding)

            if not relevant_docs:
                return {
                    'answer': 'No relevant information found. Please try a different query.',
                    'success': True
                }

            # Fetch the candidate texts and drop blank ones.
            doc_texts = [retrieve_document_text(doc_id) for doc_id, _ in relevant_docs]
            doc_texts = [text for text in doc_texts if text.strip()]

            if not doc_texts:
                return {
                    'answer': 'Unable to retrieve relevant documents. Please try again.',
                    'success': True
                }

            # Order texts by reranker score, best first.
            rerank_scores = rerank_documents(query_text, doc_texts)
            ranked_texts = [text for _, text in sorted(zip(rerank_scores, doc_texts), reverse=True)]

            # Only the three best-ranked texts are used as context.
            context = " ".join(ranked_texts[:3])
            answer = generate_answer(query_text, context)

            # language_code 0: translate the answer back for the caller.
            if language_code == 0:
                answer = translate_text(answer, 'en_to_ar')

            return {
                'answer': answer,
                'success': True
            }

        except Exception as e:
            print(f"Error processing query: {e}")
            raise HTTPException(
                status_code=500,
                detail="An error occurred while processing your query"
            ) from e

    except HTTPException:
        # HTTPException is a subclass of Exception. Without this clause the
        # generic handler below would catch the deliberate 503/500 responses
        # raised above and rewrite them as a 500 carrying str(e).
        raise
    except Exception as e:
        print(f"Error in process_query: {e}")
        raise HTTPException(
            status_code=500,
            detail=str(e)
        ) from e
|
334 |
|
335 |
+
# Run the start-up steps in order at import time. Each step runs only if the
# previous one returned a truthy value, so `init_success` ends up holding the
# first falsy result or the result of the last step.
print("Initializing application...")
init_success = init_nltk()
if init_success:
    init_success = load_models()
if init_success:
    init_success = load_data()

# A failed step is reported but does not stop the module from loading.
if not init_success:
    print("Warning: Application initialized with partial functionality")

# Start a uvicorn server only when this file is executed directly.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)
|