import os
import pickle
import numpy as np
from flask import Flask, request, jsonify
from flask_cors import CORS
from transformers import (
    AutoTokenizer, 
    AutoModelForSeq2SeqLM, 
    AutoModelForTokenClassification,
    AutoModelForCausalLM,
    pipeline
)
from sentence_transformers import SentenceTransformer, CrossEncoder
from sklearn.metrics.pairwise import cosine_similarity
from bs4 import BeautifulSoup
import nltk
import torch
import pandas as pd

app = Flask(__name__)
CORS(app)

# Global variables for models and data
models = {}
data = {}

def init_nltk():
    """Initialize NLTK resources"""
    try:
        nltk.download('punkt', quiet=True)
        return True
    except Exception as e:
        print(f"Error initializing NLTK: {e}")
        return False

def load_models():
    """Initialize all required models"""
    try:
        print("Loading models...")
        
        # Embedding models
        models['embedding'] = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
        models['cross_encoder'] = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2', max_length=512)
        
        # Translation models
        models['ar_to_en_tokenizer'] = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-ar-en")
        models['ar_to_en_model'] = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-ar-en")
        models['en_to_ar_tokenizer'] = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-ar")
        models['en_to_ar_model'] = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-ar")
        
        # NER model
        models['bio_tokenizer'] = AutoTokenizer.from_pretrained("blaze999/Medical-NER")
        models['bio_model'] = AutoModelForTokenClassification.from_pretrained("blaze999/Medical-NER")
        models['ner_pipeline'] = pipeline("ner", model=models['bio_model'], tokenizer=models['bio_tokenizer'])
        
        # LLM model
        model_name = "M4-ai/Orca-2.0-Tau-1.8B"
        models['llm_tokenizer'] = AutoTokenizer.from_pretrained(model_name)
        models['llm_model'] = AutoModelForCausalLM.from_pretrained(model_name)
        
        print("Models loaded successfully")
        return True
    except Exception as e:
        print(f"Error loading models: {e}")
        return False

def load_data():
    """Load embeddings and document data"""
    try:
        print("Loading data...")
        
        # Load embeddings
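        # Assumption (inferred from how the pickle is consumed in query_embeddings, not
        # documented in the original): embeddings.pkl holds a dict mapping HTML file names
        # in downloaded_articles/ to 1-D embedding vectors from the SentenceTransformer above.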
        with open('embeddings.pkl', 'rb') as f:
            data['embeddings'] = pickle.load(f)
        
        # Load document links
        data['df'] = pd.read_excel('finalcleaned_excel_file.xlsx')
        
        print("Data loaded successfully")
        return True
    except Exception as e:
        print(f"Error loading data: {e}")
        return False

def translate_text(text, source_to_target='ar_to_en'):
    """Translate text between Arabic and English"""
    try:
        if source_to_target == 'ar_to_en':
            tokenizer = models['ar_to_en_tokenizer']
            model = models['ar_to_en_model']
        else:
            tokenizer = models['en_to_ar_tokenizer']
            model = models['en_to_ar_model']
            
        inputs = tokenizer(text, return_tensors="pt", truncation=True)
        outputs = model.generate(**inputs)
        return tokenizer.decode(outputs[0], skip_special_tokens=True)
    except Exception as e:
        print(f"Translation error: {e}")
        return text

def query_embeddings(query_embedding, n_results=5):
    """Find relevant documents using embedding similarity"""
    doc_ids = list(data['embeddings'].keys())
    doc_embeddings = np.array(list(data['embeddings'].values()))
    similarities = cosine_similarity(query_embedding, doc_embeddings).flatten()
    top_indices = similarities.argsort()[-n_results:][::-1]
    return [(doc_ids[i], similarities[i]) for i in top_indices]

def retrieve_document_text(doc_id):
    """Retrieve document text from HTML file"""
    try:
        with open(f"downloaded_articles/{doc_id}", 'r', encoding='utf-8') as file:
            soup = BeautifulSoup(file, 'html.parser')
            return soup.get_text(separator=' ', strip=True)
    except Exception as e:
        print(f"Error retrieving document {doc_id}: {e}")
        return ""

def extract_entities(text):
    """Extract medical entities from text"""
    try:
        results = models['ner_pipeline'](text)
        return list({result['word'] for result in results if result['entity'].startswith("B-")})
    except Exception as e:
        print(f"Error extracting entities: {e}")
        return []

def generate_answer(query, context, max_length=860, temperature=0.2):
    """Generate answer using LLM"""
    try:
        prompt = f"""
        As a medical expert, answer the following question based only on the provided context:
        
        Context: {context}
        Question: {query}
        
        Answer:"""
        
        inputs = models['llm_tokenizer'](prompt, return_tensors="pt", truncation=True)
        outputs = models['llm_model'].generate(
            inputs.input_ids,
            max_length=max_length,
            num_return_sequences=1,
            do_sample=True,  # temperature is ignored by generate() unless sampling is enabled
            temperature=temperature,
            pad_token_id=models['llm_tokenizer'].eos_token_id
        )
        
        answer = models['llm_tokenizer'].decode(outputs[0], skip_special_tokens=True)
        return answer.split("Answer:")[-1].strip()
    except Exception as e:
        print(f"Error generating answer: {e}")
        return "Sorry, I couldn't generate an answer at this time."

@app.route('/health', methods=['GET'])
def health_check():
    """Health check endpoint"""
    return jsonify({'status': 'healthy'})

@app.route('/api/query', methods=['POST'])
def process_query():
    """Main query processing endpoint"""
    try:
        # Use a distinct name to avoid shadowing the module-level `data` dict of embeddings/documents
        request_data = request.get_json(silent=True)
        if not request_data or 'query' not in request_data:
            return jsonify({'error': 'No query provided', 'success': False}), 400

        query_text = request_data['query']
        language_code = request_data.get('language_code', 0)

        # Translate if Arabic
        if language_code == 0:
            query_text = translate_text(query_text, 'ar_to_en')

        # Get query embedding and find relevant documents
        query_embedding = models['embedding'].encode([query_text])
        relevant_docs = query_embeddings(query_embedding)
        
        # Retrieve and process documents
        doc_texts = [retrieve_document_text(doc_id) for doc_id, _ in relevant_docs]
        
        # Extract entities and generate context
        query_entities = extract_entities(query_text)
        contexts = []
        for text in doc_texts:
            doc_entities = extract_entities(text)
            if set(query_entities) & set(doc_entities):
                contexts.append(text)
        
        context = " ".join(contexts[:3])  # Use top 3 most relevant contexts
        
        # Generate answer
        answer = generate_answer(query_text, context)
        
        # Translate back if needed
        if language_code == 0:
            answer = translate_text(answer, 'en_to_ar')

        return jsonify({
            'answer': answer,
            'success': True
        })

    except Exception as e:
        return jsonify({
            'error': str(e),
            'success': False
        }), 500
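
# Illustrative client call for the /api/query endpoint above (a sketch, not part of the
# service; assumes the server is running locally on the port used in __main__ below).
# language_code 0 triggers Arabic<->English translation; any other value skips it.
#
#   import requests
#   resp = requests.post(
#       "http://localhost:7860/api/query",
#       json={"query": "What are the symptoms of diabetes?", "language_code": 1},
#   )
#   print(resp.json())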

# Initialize everything when the app starts
print("Initializing application...")
init_success = init_nltk() and load_models() and load_data()

if not init_success:
    print("Failed to initialize application")
    exit(1)

if __name__ == "__main__":
    app.run(host='0.0.0.0', port=7860)