thechaiexperiment committed on
Commit 95a7e5a · 1 Parent(s): 37952cf

Update app.py

Files changed (1)
  1. app.py +47 -52
app.py CHANGED
@@ -1,8 +1,9 @@
 import os
 import pickle
 import numpy as np
-from flask import Flask, request, jsonify
-from flask_cors import CORS
+from fastapi import FastAPI, HTTPException
+from fastapi.middleware.cors import CORSMiddleware
+from pydantic import BaseModel
 from transformers import (
     AutoTokenizer,
     AutoModelForSeq2SeqLM,
@@ -17,13 +18,26 @@ import nltk
 import torch
 import pandas as pd
 
-app = Flask(__name__)
-CORS(app)
+# Initialize FastAPI app
+app = FastAPI()
+
+# Add CORS middleware
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
 
 # Global variables for models and data
 models = {}
 data = {}
 
+class QueryRequest(BaseModel):
+    query: str
+    language_code: int = 0
+
 def init_nltk():
     """Initialize NLTK resources"""
     try:
@@ -78,7 +92,6 @@ def load_embeddings():
         print(f"Error: {embeddings_path} not found")
         return False
 
-# Custom unpickler to handle potential compatibility issues
 class CustomUnpickler(pickle.Unpickler):
     def find_class(self, module, name):
         if module == "__main__":
@@ -193,11 +206,9 @@ def generate_answer(query, context, max_length=860, temperature=0.2):
 
     response = models['llm_tokenizer'].decode(outputs[0], skip_special_tokens=True)
 
-    # Clean up the response
     if "Answer:" in response:
         response = response.split("Answer:")[-1].strip()
 
-    # Remove incomplete sentences at the end
     sentences = nltk.sent_tokenize(response)
     if sentences:
         return " ".join(sentences)
@@ -247,8 +258,8 @@ def rerank_documents(query, doc_texts):
         print(f"Error reranking documents: {e}")
         return np.zeros(len(doc_texts))
 
-@app.route('/health', methods=['GET'])
-def health_check():
+@app.get("/health")
+async def health_check():
     """Health check endpoint"""
     status = {
         'status': 'healthy',
@@ -256,95 +267,79 @@ def health_check():
         'embeddings_loaded': bool(data.get('embeddings')),
         'documents_loaded': not data.get('df', pd.DataFrame()).empty
     }
-    return jsonify(status)
+    return status
 
-@app.route('/api/query', methods=['POST'])
-def process_query():
+@app.post("/api/query")
+async def process_query(request: QueryRequest):
     """Main query processing endpoint"""
     try:
-        if not request.is_json:
-            return jsonify({'error': 'Request must be JSON', 'success': False}), 400
-
-        data = request.json
-        if not data or 'query' not in data:
-            return jsonify({'error': 'No query provided', 'success': False}), 400
+        query_text = request.query
+        language_code = request.language_code
 
-        query_text = data['query']
-        language_code = data.get('language_code', 0)
-
-        # Basic response if no models or data are loaded
         if not models or not data.get('embeddings'):
-            return jsonify({
-                'answer': 'The system is currently initializing. Please try again in a few minutes.',
-                'success': False
-            }), 503
+            raise HTTPException(
+                status_code=503,
+                detail="The system is currently initializing. Please try again in a few minutes."
+            )
 
-        # Process query with available models and data
         try:
-            # Handle Arabic queries
            if language_code == 0:
                 query_text = translate_text(query_text, 'ar_to_en')
 
-            # Get query embedding and find relevant documents
             query_embedding = models['embedding'].encode([query_text])
             relevant_docs = query_embeddings(query_embedding)
 
             if not relevant_docs:
-                return jsonify({
+                return {
                     'answer': 'No relevant information found. Please try a different query.',
                     'success': True
-                })
+                }
 
-            # Retrieve and process documents
             doc_texts = [retrieve_document_text(doc_id) for doc_id, _ in relevant_docs]
             doc_texts = [text for text in doc_texts if text.strip()]
 
             if not doc_texts:
-                return jsonify({
+                return {
                     'answer': 'Unable to retrieve relevant documents. Please try again.',
                     'success': True
-                })
+                }
 
-            # Rerank documents
             rerank_scores = rerank_documents(query_text, doc_texts)
             ranked_texts = [text for _, text in sorted(zip(rerank_scores, doc_texts), reverse=True)]
 
-            # Combine top documents
             context = " ".join(ranked_texts[:3])
-
-            # Generate answer
             answer = generate_answer(query_text, context)
 
-            # Translate answer back to Arabic if needed
             if language_code == 0:
                 answer = translate_text(answer, 'en_to_ar')
 
-            return jsonify({
+            return {
                 'answer': answer,
                 'success': True
-            })
+            }
 
         except Exception as e:
             print(f"Error processing query: {e}")
-            return jsonify({
-                'error': 'An error occurred while processing your query',
-                'success': False
-            }), 500
+            raise HTTPException(
+                status_code=500,
+                detail="An error occurred while processing your query"
            )
 
     except Exception as e:
         print(f"Error in process_query: {e}")
-        return jsonify({
-            'error': str(e),
-            'success': False
-        }), 500
+        raise HTTPException(
+            status_code=500,
+            detail=str(e)
        )
 
-# Initialize everything when the app starts
+# Initialize application
 print("Initializing application...")
 init_success = init_nltk() and load_models() and load_data()
 
 if not init_success:
     print("Warning: Application initialized with partial functionality")
 
+# For running locally
 if __name__ == "__main__":
-    app.run(host='0.0.0.0', port=7860)
-
+    import uvicorn
+    uvicorn.run(app, host="0.0.0.0", port=7860)
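
For reference, a minimal client sketch for the two FastAPI endpoints this commit introduces. The paths (GET /health, POST /api/query), the request fields (query, language_code, matching the QueryRequest model), the response keys (answer, success, and detail on HTTP errors) and port 7860 all come from the diff above; the base URL, the use of the requests library, and the example query string are illustrative assumptions, not part of the commit.

import requests

BASE_URL = "http://localhost:7860"  # assumption: app served locally via uvicorn as in the diff

# Health check: reports whether models, embeddings, and documents finished loading
print(requests.get(f"{BASE_URL}/health").json())

# Query endpoint: the JSON body must match the QueryRequest model.
# Per the diff, language_code == 0 triggers the Arabic<->English translation path;
# here an English query with language_code = 1 is used as a placeholder.
payload = {"query": "example question", "language_code": 1}
resp = requests.post(f"{BASE_URL}/api/query", json=payload)
if resp.ok:
    print(resp.json()["answer"])                          # {'answer': ..., 'success': True}
else:
    print(resp.status_code, resp.json().get("detail"))    # HTTPException detail (503 or 500)

If the app is started by a process manager instead of the __main__ block, an equivalent command would be "uvicorn app:app --host 0.0.0.0 --port 7860"; the module name app is assumed from the file name app.py.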