thechaiexperiment committed
Commit 35e1586 · 1 Parent(s): 59344e5

Create app.py

Files changed (1)
  1. app.py +267 -0
app.py ADDED
@@ -0,0 +1,267 @@
+ from fastapi import FastAPI, HTTPException
+ from pydantic import BaseModel
+ from typing import List, Optional, Dict
+ import pickle
+ import numpy as np
+ from sklearn.metrics.pairwise import cosine_similarity
+ from sentence_transformers import SentenceTransformer, CrossEncoder, util
+ from bs4 import BeautifulSoup
+ import os
+ import nltk
+ import torch
+ from transformers import (
+     AutoTokenizer,
+     AutoModelForTokenClassification,
+     BartForConditionalGeneration,
+     AutoModelForCausalLM,
+     AutoModelForSeq2SeqLM
+ )
+ import pandas as pd
+ import time
+
+ app = FastAPI()
+
+ # Container for all models and data loaded at startup
+ class GlobalModels:
+     embedding_model = None
+     cross_encoder = None
+     semantic_model = None
+     tokenizer = None
+     model = None
+     tokenizer_f = None
+     model_f = None
+     ar_to_en_tokenizer = None
+     ar_to_en_model = None
+     en_to_ar_tokenizer = None
+     en_to_ar_model = None
+     embeddings_data = None
+     file_name_to_url = None
+     bio_tokenizer = None
+     bio_model = None
+
+ global_models = GlobalModels()
+
+ # Download NLTK data
+ nltk.download('punkt')
+
+ # Pydantic models for request validation
+ class QueryInput(BaseModel):
+     query_text: str
+     language_code: int  # 0 for Arabic, 1 for English
+     query_type: str  # "profile" or "question"
+     previous_qa: Optional[List[Dict[str, str]]] = None
+
+ class DocumentResponse(BaseModel):
+     title: str
+     url: str
+     text: str
+     score: float
+
+ @app.on_event("startup")
+ async def load_models():
+     """Initialize all models and data on startup"""
+     try:
+         # Load embedding models
+         global_models.embedding_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
+         global_models.cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2', max_length=512)
+         global_models.semantic_model = SentenceTransformer('all-MiniLM-L6-v2')
+
+         # Load BART models
+         global_models.tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
+         global_models.model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")
+
+         # Load Orca model
+         model_name = "M4-ai/Orca-2.0-Tau-1.8B"
+         global_models.tokenizer_f = AutoTokenizer.from_pretrained(model_name)
+         global_models.model_f = AutoModelForCausalLM.from_pretrained(model_name)
+
+         # Load translation models
+         global_models.ar_to_en_tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-ar-en")
+         global_models.ar_to_en_model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-ar-en")
+         global_models.en_to_ar_tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-ar")
+         global_models.en_to_ar_model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-ar")
+
+         # Load Medical NER models
+         global_models.bio_tokenizer = AutoTokenizer.from_pretrained("blaze999/Medical-NER")
+         global_models.bio_model = AutoModelForTokenClassification.from_pretrained("blaze999/Medical-NER")
+
+         # Load embeddings data
+         with open('embeddings.pkl', 'rb') as file:
+             global_models.embeddings_data = pickle.load(file)
+
+         # Load URL mapping data
+         df = pd.read_excel('finalcleaned_excel_file.xlsx')
+         global_models.file_name_to_url = {f"article_{index}.html": url for index, url in enumerate(df['Unnamed: 0'])}
+
+     except Exception as e:
+         print(f"Error loading models: {e}")
+         raise
+
+ def translate_ar_to_en(text):
+     try:
+         inputs = global_models.ar_to_en_tokenizer(text, return_tensors="pt", truncation=True, padding=True)
+         translated_ids = global_models.ar_to_en_model.generate(
+             inputs.input_ids,
+             max_length=512,
+             num_beams=4,
+             early_stopping=True
+         )
+         translated_text = global_models.ar_to_en_tokenizer.decode(translated_ids[0], skip_special_tokens=True)
+         return translated_text
+     except Exception as e:
+         print(f"Error during Arabic to English translation: {e}")
+         return None
+
+ def translate_en_to_ar(text):
+     try:
+         inputs = global_models.en_to_ar_tokenizer(text, return_tensors="pt", truncation=True, padding=True)
+         translated_ids = global_models.en_to_ar_model.generate(
+             inputs.input_ids,
+             max_length=512,
+             num_beams=4,
+             early_stopping=True
+         )
+         translated_text = global_models.en_to_ar_tokenizer.decode(translated_ids[0], skip_special_tokens=True)
+         return translated_text
+     except Exception as e:
+         print(f"Error during English to Arabic translation: {e}")
+         return None
+
+ def process_query(query_text, language_code):
+     if language_code == 0:
+         return translate_ar_to_en(query_text)
+     return query_text
+
+ def embed_query_text(query_text):
+     return global_models.embedding_model.encode([query_text])
+
+ def query_embeddings(query_embedding, n_results=5):
+     doc_ids = list(global_models.embeddings_data.keys())
+     doc_embeddings = np.array(list(global_models.embeddings_data.values()))
+     similarities = cosine_similarity(query_embedding, doc_embeddings).flatten()
+     top_indices = similarities.argsort()[-n_results:][::-1]
+     return [(doc_ids[i], similarities[i]) for i in top_indices]
+
+ def retrieve_document_texts(doc_ids, folder_path='downloaded_articles'):
+     texts = []
+     for doc_id in doc_ids:
+         file_path = os.path.join(folder_path, doc_id)
+         try:
+             with open(file_path, 'r', encoding='utf-8') as file:
+                 soup = BeautifulSoup(file, 'html.parser')
+                 text = soup.get_text(separator=' ', strip=True)
+                 texts.append(text)
+         except FileNotFoundError:
+             texts.append("")
+     return texts
+
+ def extract_entities(text):
+     inputs = global_models.bio_tokenizer(text, return_tensors="pt")
+     outputs = global_models.bio_model(**inputs)
+     predictions = torch.argmax(outputs.logits, dim=2)
+     tokens = global_models.bio_tokenizer.convert_ids_to_tokens(inputs.input_ids[0])
+     return [tokens[i] for i in range(len(tokens)) if predictions[0][i].item() != 0]
+
+ def create_prompt(question, passage):
+     return f"""
+ As a medical expert, you are required to answer the following question based only on the provided passage.
+ Do not include any information not present in the passage. Your response should directly reflect the content
+ of the passage. Maintain accuracy and relevance to the provided information.
+
+ Passage: {passage}
+
+ Question: {question}
+
+ Answer:
+ """
+
+ def generate_answer(prompt, max_length=860, temperature=0.2):
+     inputs = global_models.tokenizer_f(prompt, return_tensors="pt", truncation=True)
+
+     start_time = time.time()
+     output_ids = global_models.model_f.generate(
+         inputs.input_ids,
+         max_length=max_length,
+         num_return_sequences=1,
+         do_sample=True,  # sampling must be enabled for temperature to take effect
+         temperature=temperature,
+         pad_token_id=global_models.tokenizer_f.eos_token_id
+     )
+     duration = time.time() - start_time
+
+     answer = global_models.tokenizer_f.decode(output_ids[0], skip_special_tokens=True)
+     return answer, duration
+
+ def clean_answer(answer):
+     # Keep only the text after the final "Answer:" marker and trim to the last complete sentence
+     answer_part = answer.split("Answer:")[-1].strip()
+     if not answer_part.endswith('.'):
+         last_period_index = answer_part.rfind('.')
+         if last_period_index != -1:
+             answer_part = answer_part[:last_period_index + 1].strip()
+     return answer_part
+
+ @app.post("/retrieve_documents")
+ async def retrieve_documents(input_data: QueryInput):
+     try:
+         # Process query
+         processed_query = process_query(input_data.query_text, input_data.language_code)
+         query_embedding = embed_query_text(processed_query)
+         results = query_embeddings(query_embedding)
+
+         # Get document texts and rerank with the cross-encoder
+         document_ids = [doc_id for doc_id, _ in results]
+         document_texts = retrieve_document_texts(document_ids)
+         scores = global_models.cross_encoder.predict([(processed_query, doc) for doc in document_texts])
+
+         # Prepare response
+         documents = []
+         for score, doc_id, text in zip(scores, document_ids, document_texts):
+             url = global_models.file_name_to_url.get(doc_id, "")
+             documents.append({
+                 "title": doc_id,
+                 "url": url,
+                 "text": text if input_data.language_code == 1 else translate_en_to_ar(text),
+                 "score": float(score)
+             })
+
+         # Order documents by cross-encoder score, best match first
+         documents.sort(key=lambda doc: doc["score"], reverse=True)
+
+         return {"status": "success", "documents": documents}
+
+     except Exception as e:
+         raise HTTPException(status_code=500, detail=str(e))
+
+ @app.post("/get_answer")
+ async def get_answer(input_data: QueryInput):
+     try:
+         # Process query
+         processed_query = process_query(input_data.query_text, input_data.language_code)
+
+         # Get relevant documents
+         query_embedding = embed_query_text(processed_query)
+         results = query_embeddings(query_embedding)
+         document_ids = [doc_id for doc_id, _ in results]
+         document_texts = retrieve_document_texts(document_ids)
+
+         # Extract entities and create context
+         entities = extract_entities(processed_query)
+         context = " ".join(document_texts)
+         enhanced_context = f"{context}\n\nEntities: {', '.join(entities)}"
+
+         # Generate answer
+         prompt = create_prompt(processed_query, enhanced_context)
+         answer, duration = generate_answer(prompt)
+         final_answer = clean_answer(answer)
+
+         # Translate if needed
+         if input_data.language_code == 0:
+             final_answer = translate_en_to_ar(final_answer)
+
+         return {
+             "status": "success",
+             "answer": final_answer,
+             "processing_time": duration
+         }
+
+     except Exception as e:
+         raise HTTPException(status_code=500, detail=str(e))
+
+ if __name__ == "__main__":
+     import uvicorn
+     uvicorn.run(app, host="0.0.0.0", port=7860)
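
For reference, a minimal client sketch (not part of the commit) of how the two endpoints might be called once the app is running. The base URL, the example query text, and the use of the requests library are assumptions; the request fields mirror the QueryInput model defined in app.py.

import requests

BASE_URL = "http://localhost:7860"  # assumed local deployment, matching the uvicorn port above

payload = {
    "query_text": "What are the symptoms of diabetes?",  # hypothetical example query
    "language_code": 1,        # 1 = English, 0 = Arabic (per QueryInput)
    "query_type": "question",  # "profile" or "question"
    "previous_qa": None,
}

# Retrieve the top documents scored by the cross-encoder
docs = requests.post(f"{BASE_URL}/retrieve_documents", json=payload).json()
for doc in docs.get("documents", []):
    print(doc["title"], round(doc["score"], 3), doc["url"])

# Generate a passage-grounded answer for the same query
result = requests.post(f"{BASE_URL}/get_answer", json=payload).json()
print(result.get("answer"), result.get("processing_time"))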