thechaiexperiment committed
Commit 2a5cca5 · 1 Parent(s): 9b4d106

Update app.py

Files changed (1):
  1. app.py +52 -11
app.py CHANGED
@@ -19,10 +19,12 @@ from transformers import (
 import pandas as pd
 import time
 
+# Initialize FastAPI app first
+app = FastAPI()
+
 class CustomUnpickler(pickle.Unpickler):
     def persistent_load(self, pid):
         try:
-            # Handle string encoding issues by decoding and re-encoding as ASCII
             if isinstance(pid, bytes):
                 pid = pid.decode('utf-8', errors='ignore')
             pid = str(pid).encode('ascii', errors='ignore').decode('ascii')
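
The point of hoisting app = FastAPI() above everything else is that every @app.on_event(...) and @app.get(...) decorator below executes at import time and looks up the name app; if the instance were created later in the module, Python would raise a NameError. A minimal sketch of the ordering constraint (the /health route is illustrative, not part of this commit):

from fastapi import FastAPI

app = FastAPI()  # must be bound before any decorator that references it

@app.get("/health")  # hypothetical route; the decorator runs at import time
async def health():
    return {"status": "ok"}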
@@ -39,11 +41,9 @@ def safe_load_embeddings():
             unpickler = CustomUnpickler(file)
             embeddings_data = unpickler.load()
 
-        # Verify the data structure
         if not isinstance(embeddings_data, dict):
             raise ValueError("Loaded data is not a dictionary")
 
-        # Verify the embeddings format
         first_key = next(iter(embeddings_data))
         if not isinstance(embeddings_data[first_key], (np.ndarray, list)):
             raise ValueError("Embeddings are not in the expected format")
@@ -54,6 +54,7 @@ def safe_load_embeddings():
         print(f"Error loading embeddings: {str(e)}")
         return None
 
+# Models and data structures
 class GlobalModels:
     embedding_model = None
     cross_encoder = None
@@ -71,8 +72,25 @@ class GlobalModels:
     bio_tokenizer = None
     bio_model = None
 
+# Initialize global models
 global_models = GlobalModels()
 
+# Download NLTK data
+nltk.download('punkt')
+
+# Pydantic models for request validation
+class QueryInput(BaseModel):
+    query_text: str
+    language_code: int  # 0 for Arabic, 1 for English
+    query_type: str  # "profile" or "question"
+    previous_qa: Optional[List[Dict[str, str]]] = None
+
+class DocumentResponse(BaseModel):
+    title: str
+    url: str
+    text: str
+    score: float
+
 @app.on_event("startup")
 async def load_models():
     """Initialize all models and data on startup"""
@@ -86,12 +104,36 @@ async def load_models():
             raise HTTPException(status_code=500, detail="Failed to load embeddings data")
         global_models.embeddings_data = embeddings_data
 
-        # Continue loading other models only if embeddings loaded successfully
+        # Load remaining models
         global_models.cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2', max_length=512)
         global_models.semantic_model = SentenceTransformer('all-MiniLM-L6-v2')
 
-        # Load remaining models...
-        # (rest of your model loading code remains the same)
+        # Load BART models
+        global_models.tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
+        global_models.model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")
+
+        # Load Orca model
+        model_name = "M4-ai/Orca-2.0-Tau-1.8B"
+        global_models.tokenizer_f = AutoTokenizer.from_pretrained(model_name)
+        global_models.model_f = AutoModelForCausalLM.from_pretrained(model_name)
+
+        # Load translation models
+        global_models.ar_to_en_tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-ar-en")
+        global_models.ar_to_en_model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-ar-en")
+        global_models.en_to_ar_tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-ar")
+        global_models.en_to_ar_model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-ar")
+
+        # Load Medical NER models
+        global_models.bio_tokenizer = AutoTokenizer.from_pretrained("blaze999/Medical-NER")
+        global_models.bio_model = AutoModelForTokenClassification.from_pretrained("blaze999/Medical-NER")
+
+        # Load URL mapping data
+        try:
+            df = pd.read_excel('finalcleaned_excel_file.xlsx')
+            global_models.file_name_to_url = {f"article_{index}.html": url for index, url in enumerate(df['Unnamed: 0'])}
+        except Exception as e:
+            print(f"Error loading URL mapping data: {e}")
+            raise HTTPException(status_code=500, detail="Failed to load URL mapping data.")
 
         print("All models loaded successfully")
 
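As a quick sanity check for the translation pair loaded above, standard MarianMT usage with these Helsinki-NLP checkpoints looks like this (a sketch independent of the app's own request flow; it only assumes the models download as in the diff):

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-ar-en")
model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-ar-en")

inputs = tokenizer("مرحبا بالعالم", return_tensors="pt")  # Arabic input text
outputs = model.generate(**inputs, max_new_tokens=40)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))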
@@ -99,11 +141,6 @@ async def load_models():
         print(f"Error during startup: {str(e)}")
         raise HTTPException(status_code=500, detail=f"Failed to initialize application: {str(e)}")
 
-# Rest of your FastAPI application code remains the same...
-
-@app.get("/")
-async def root():
-    return {"message": "Server is running"}
 
 # Models and data structures to store loaded models
 class GlobalModels:
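
The file_name_to_url comprehension added in the startup hunk above maps generated article file names to source URLs purely by row position, so it silently depends on the Excel row order. A toy equivalent with assumed values (the real finalcleaned_excel_file.xlsx is not shown in the commit):

import pandas as pd

# Hypothetical stand-in for finalcleaned_excel_file.xlsx
df = pd.DataFrame({"Unnamed: 0": ["https://example.com/a", "https://example.com/b"]})
file_name_to_url = {f"article_{index}.html": url for index, url in enumerate(df["Unnamed: 0"])}
# {'article_0.html': 'https://example.com/a', 'article_1.html': 'https://example.com/b'}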
@@ -356,6 +393,10 @@ async def get_answer(input_data: QueryInput):
     except Exception as e:
         raise HTTPException(status_code=500, detail=str(e))
 
+@app.get("/")
+async def root():
+    return {"message": "Server is running"}
+
 if __name__ == "__main__":
     import uvicorn
     uvicorn.run(app, host="0.0.0.0", port=7860)
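
The relocated root endpoint can be exercised without starting uvicorn by using FastAPI's test client; this sketch assumes app.py is importable as a module from the working directory:

from fastapi.testclient import TestClient
from app import app  # assumes app.py is on the import path

client = TestClient(app)
response = client.get("/")
assert response.status_code == 200
assert response.json() == {"message": "Server is running"}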
 