Spaces:
Sleeping
Sleeping
refactor: use classes
Browse files
app.py
CHANGED
@@ -8,6 +8,10 @@ import arxiv
|
|
8 |
import requests
|
9 |
import xml.etree.ElementTree as ET
|
10 |
import re
|
|
|
|
|
|
|
|
|
11 |
|
12 |
# Configure logging
|
13 |
logging.basicConfig(level=logging.INFO)
|
@@ -16,392 +20,330 @@ logging.basicConfig(level=logging.INFO)
|
|
16 |
DATA_DIR = "/data" if os.path.exists("/data") else "."
|
17 |
DATASET_DIR = os.path.join(DATA_DIR, "rag_dataset")
|
18 |
DATASET_PATH = os.path.join(DATASET_DIR, "dataset")
|
19 |
-
|
20 |
-
|
21 |
-
#
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
# Remove any remaining weird characters
|
49 |
-
text = ''.join(char for char in text if ord(char) < 128)
|
50 |
-
|
51 |
-
return text.strip()
|
52 |
-
|
53 |
-
def format_paper(title, abstract):
|
54 |
-
"""Format paper information consistently"""
|
55 |
-
title = clean_text(title)
|
56 |
-
abstract = clean_text(abstract)
|
57 |
-
|
58 |
-
if len(abstract) > 1000:
|
59 |
-
abstract = abstract[:997] + "..."
|
60 |
-
|
61 |
-
return f"""Title: {title}
|
62 |
-
|
63 |
-
Abstract: {abstract}
|
64 |
-
|
65 |
-
---"""
|
66 |
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
max_results=max_results,
|
78 |
-
sort_by=arxiv.SortCriterion.Relevance
|
79 |
-
)
|
80 |
-
|
81 |
-
papers = []
|
82 |
-
for result in client.results(search):
|
83 |
-
# Only include papers that mention autism in title or abstract
|
84 |
-
if ('autism' in result.title.lower() or
|
85 |
-
'asd' in result.title.lower() or
|
86 |
-
'autism' in result.summary.lower() or
|
87 |
-
'asd' in result.summary.lower()):
|
88 |
-
papers.append({
|
89 |
-
'title': result.title,
|
90 |
-
'abstract': result.summary,
|
91 |
-
'url': result.pdf_url,
|
92 |
-
'published': result.published.strftime("%Y-%m-%d"),
|
93 |
-
'relevance_score': 1 if 'autism' in result.title.lower() else 0.5
|
94 |
-
})
|
95 |
-
|
96 |
-
return papers
|
97 |
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
# Always include autism in the search term
|
103 |
-
search_term = f"(autism[Title/Abstract] OR ASD[Title/Abstract]) AND ({query}[Title/Abstract])"
|
104 |
-
|
105 |
-
# Search for papers
|
106 |
-
search_url = f"{base_url}/esearch.fcgi"
|
107 |
-
search_params = {
|
108 |
-
'db': 'pubmed',
|
109 |
-
'term': search_term,
|
110 |
-
'retmax': max_results,
|
111 |
-
'sort': 'relevance',
|
112 |
-
'retmode': 'xml'
|
113 |
-
}
|
114 |
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
id_list = [id_elem.text for id_elem in root.findall('.//Id')]
|
121 |
|
122 |
-
|
123 |
-
|
|
|
|
|
|
|
124 |
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
132 |
|
133 |
-
|
134 |
-
|
|
|
|
|
|
|
|
|
|
|
135 |
|
136 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
137 |
title = article.find('.//ArticleTitle')
|
138 |
abstract = article.find('.//Abstract/AbstractText')
|
139 |
year = article.find('.//PubDate/Year')
|
140 |
-
pmid = article.find('.//PMID')
|
141 |
|
142 |
if title is not None and abstract is not None:
|
143 |
title_text = title.text.lower()
|
144 |
abstract_text = abstract.text.lower()
|
145 |
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
'
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
def
|
163 |
-
|
164 |
-
|
165 |
-
pubmed_papers = fetch_pubmed_papers(query)
|
166 |
-
|
167 |
-
# Combine and format papers
|
168 |
-
all_papers = []
|
169 |
-
for paper in arxiv_papers + pubmed_papers:
|
170 |
-
if paper['abstract'] and len(paper['abstract'].strip()) > 0:
|
171 |
-
# Clean and format the paper content
|
172 |
-
clean_title = clean_text(paper['title'])
|
173 |
-
clean_abstract = clean_text(paper['abstract'])
|
174 |
-
|
175 |
-
# Check if the paper is actually about autism
|
176 |
-
if ('autism' in clean_title.lower() or
|
177 |
-
'asd' in clean_title.lower() or
|
178 |
-
'autism' in clean_abstract.lower() or
|
179 |
-
'asd' in clean_abstract.lower()):
|
180 |
-
|
181 |
-
formatted_text = format_paper(clean_title, clean_abstract)
|
182 |
-
|
183 |
-
all_papers.append({
|
184 |
-
'title': clean_title,
|
185 |
-
'text': formatted_text,
|
186 |
-
'url': paper['url'],
|
187 |
-
'published': paper['published'],
|
188 |
-
'relevance_score': paper.get('relevance_score', 0.5)
|
189 |
-
})
|
190 |
-
|
191 |
-
# Sort papers by relevance score and convert to DataFrame
|
192 |
-
all_papers.sort(key=lambda x: x['relevance_score'], reverse=True)
|
193 |
-
df = pd.DataFrame(all_papers)
|
194 |
-
|
195 |
-
if df.empty:
|
196 |
-
st.warning("No autism-related papers found. Please try a different search term.")
|
197 |
-
return pd.DataFrame(columns=['title', 'text', 'url', 'published', 'relevance_score'])
|
198 |
-
|
199 |
-
return df
|
200 |
-
|
201 |
-
def generate_answer(question, context, max_length=512):
|
202 |
-
"""Generate a comprehensive answer using the local model"""
|
203 |
-
model, tokenizer = load_local_model()
|
204 |
-
|
205 |
-
if model is None or tokenizer is None:
|
206 |
-
return "Error: Could not load the model. Please try again later."
|
207 |
-
|
208 |
-
# Clean and format the context
|
209 |
-
clean_context = clean_text(context)
|
210 |
-
clean_question = clean_text(question)
|
211 |
-
|
212 |
-
# Format the input for T5 (it expects a specific format)
|
213 |
-
input_text = f"""Context
|
214 |
-
Input Question: {clean_question}
|
215 |
-
Source Materials: {clean_context}
|
216 |
-
Primary Objective
|
217 |
-
Generate a comprehensive yet accessible summary of autism research that bridges the gap between academic knowledge and public understanding. The response should be evidence-based while remaining engaging and practical for general readers.
|
218 |
-
Content Structure
|
219 |
-
1. Opening Overview
|
220 |
-
|
221 |
-
Begin with a concise, jargon-free definition of autism
|
222 |
-
Frame the topic within everyday experiences
|
223 |
-
Establish relevance to the reader's understanding
|
224 |
-
|
225 |
-
2. Key Concepts Breakdown
|
226 |
-
|
227 |
-
Transform complex research findings into digestible information
|
228 |
-
Structure information in a logical progression
|
229 |
-
Connect each point to real-world scenarios
|
230 |
-
|
231 |
-
3. Research Integration
|
232 |
-
Present research findings using this framework:
|
233 |
-
|
234 |
-
Main finding: [Clear statement of what was discovered]
|
235 |
-
Real-world meaning: [Practical implications]
|
236 |
-
Context: [How this fits into broader understanding]
|
237 |
-
|
238 |
-
4. Examples and Applications
|
239 |
-
Include:
|
240 |
-
|
241 |
-
Concrete, relatable scenarios
|
242 |
-
Day-to-day situations
|
243 |
-
Practical implications for families and individuals
|
244 |
-
|
245 |
-
Writing Guidelines
|
246 |
-
Language Requirements
|
247 |
-
|
248 |
-
Target reading level: 8th grade
|
249 |
-
Sentence length: Maximum 20 words
|
250 |
-
Paragraph length: 2-4 sentences
|
251 |
-
Technical terms: Must include plain language explanation in parentheses
|
252 |
-
|
253 |
-
Tone and Style
|
254 |
-
|
255 |
-
Empathetic and respectful
|
256 |
-
Solution-focused approach
|
257 |
-
Balanced perspective
|
258 |
-
Inclusive language
|
259 |
-
|
260 |
-
Formatting Specifications
|
261 |
-
|
262 |
-
Use headers for major sections
|
263 |
-
Include white space between concepts
|
264 |
-
Implement bullet points for lists
|
265 |
-
Bold key terms with immediate explanations
|
266 |
-
|
267 |
-
Research Citation Format
|
268 |
-
When referencing studies, follow this pattern:
|
269 |
-
"Research from [Institution] shows [finding in simple terms]. This means [practical interpretation]."
|
270 |
-
Quality Checks
|
271 |
-
Before finalizing, ensure the summary:
|
272 |
-
|
273 |
-
Answers the original question directly
|
274 |
-
Maintains scientific accuracy while being accessible
|
275 |
-
Provides actionable insights
|
276 |
-
Respects neurodiversity perspectives
|
277 |
-
Balances depth with clarity
|
278 |
-
|
279 |
-
Response Framework
|
280 |
-
|
281 |
-
Introduction (2-3 sentences)
|
282 |
-
|
283 |
-
Core definition
|
284 |
-
Relevance statement
|
285 |
-
|
286 |
-
|
287 |
-
Main Body (3-4 key points)
|
288 |
-
|
289 |
-
Evidence-based insights
|
290 |
-
Practical examples
|
291 |
-
Real-world applications
|
292 |
-
|
293 |
-
|
294 |
-
Conclusion (2-3 sentences)
|
295 |
-
|
296 |
-
Summary of key takeaways
|
297 |
-
Actionable next steps or implications
|
298 |
-
|
299 |
-
|
300 |
-
|
301 |
-
Engagement Elements
|
302 |
-
|
303 |
-
Include thought-provoking questions
|
304 |
-
Provide relatable scenarios
|
305 |
-
Connect to common experiences
|
306 |
-
Offer practical applications
|
307 |
-
|
308 |
-
Modified Output Analysis
|
309 |
-
The response should be evaluated against these criteria:
|
310 |
-
|
311 |
-
Clarity: Is the information immediately understandable?
|
312 |
-
Accuracy: Does it reflect the research correctly?
|
313 |
-
Relevance: Does it address the specific question?
|
314 |
-
Practicality: Are the insights actionable?
|
315 |
-
Engagement: Does it maintain reader interest?
|
316 |
-
|
317 |
-
Special Considerations
|
318 |
-
|
319 |
-
Acknowledge spectrum nature of autism
|
320 |
-
Respect diverse perspectives
|
321 |
-
Focus on strengths and challenges
|
322 |
-
Avoid deficit-based language
|
323 |
-
Include support-oriented information
|
324 |
-
|
325 |
-
Remember to adapt the depth and complexity based on the specific question while maintaining accessibility and scientific accuracy."""
|
326 |
-
|
327 |
-
try:
|
328 |
-
# T5 expects a specific format for the input
|
329 |
-
inputs = tokenizer(input_text,
|
330 |
-
return_tensors="pt",
|
331 |
-
max_length=1024,
|
332 |
-
truncation=True,
|
333 |
-
padding=True)
|
334 |
-
|
335 |
-
with torch.inference_mode():
|
336 |
-
outputs = model.generate(
|
337 |
-
**inputs,
|
338 |
-
max_length=max_length,
|
339 |
-
min_length=200,
|
340 |
-
num_beams=3, # Reduzindo para mais variedade
|
341 |
-
length_penalty=1.2, # Melhor equilíbrio entre concisão e detalhes
|
342 |
-
temperature=0.8, # Aumentando um pouco para mais fluidez
|
343 |
-
repetition_penalty=1.2,
|
344 |
-
early_stopping=True,
|
345 |
-
no_repeat_ngram_size=2, # Mantendo variação no texto
|
346 |
-
do_sample=True,
|
347 |
-
top_k=30, # Reduzindo para respostas mais coerentes
|
348 |
-
top_p=0.9 # Equilibrando diversidade e precisão
|
349 |
-
)
|
350 |
-
|
351 |
|
352 |
-
|
353 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
354 |
|
355 |
-
|
356 |
-
|
357 |
-
|
358 |
-
|
359 |
-
|
360 |
-
|
361 |
-
|
362 |
-
|
363 |
-
|
|
|
|
|
364 |
|
365 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
366 |
|
367 |
-
|
368 |
-
|
|
|
|
|
|
|
|
|
|
|
369 |
|
370 |
-
return
|
371 |
-
|
372 |
-
|
373 |
-
|
374 |
-
|
375 |
-
|
376 |
-
|
377 |
-
|
378 |
-
|
379 |
-
st.
|
380 |
-
|
381 |
-
|
382 |
-
""
|
383 |
-
|
384 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
385 |
|
386 |
-
if
|
387 |
-
|
388 |
-
# Search for papers
|
389 |
-
df = search_research_papers(query)
|
390 |
-
|
391 |
-
st.write("Searching for data in PubMed and arXiv...")
|
392 |
-
st.write(f"Found {len(df)} relevant papers!")
|
393 |
-
|
394 |
-
# Get relevant context
|
395 |
-
context = "\n".join([
|
396 |
-
f"{text[:1000]}" for text in df['text'].head(3)
|
397 |
-
])
|
398 |
-
|
399 |
-
# Generate answer
|
400 |
-
st.write("Generating answer...")
|
401 |
-
answer = generate_answer(query, context)
|
402 |
-
# Display paper sources
|
403 |
-
with st.expander("View source papers"):
|
404 |
-
for _, paper in df.iterrows():
|
405 |
-
st.markdown(f"- [{paper['title']}]({paper['url']}) ({paper['published']})")
|
406 |
-
st.success("Answer found!")
|
407 |
-
st.markdown(answer)
|
|
|
8 |
import requests
|
9 |
import xml.etree.ElementTree as ET
|
10 |
import re
|
11 |
+
from functools import lru_cache
|
12 |
+
from typing import List, Dict, Optional
|
13 |
+
from dataclasses import dataclass
|
14 |
+
from concurrent.futures import ThreadPoolExecutor
|
15 |
|
16 |
# Configure logging
|
17 |
logging.basicConfig(level=logging.INFO)
|
|
|
20 |
# Prefer the persistent /data volume when it exists (e.g. on a hosted Space);
# otherwise fall back to the current working directory.
DATA_DIR = "/data" if os.path.exists("/data") else "."
DATASET_DIR = os.path.join(DATA_DIR, "rag_dataset")
DATASET_PATH = os.path.join(DATASET_DIR, "dataset")
# NOTE(review): assumes a seq2seq checkpoint loadable via T5ForConditionalGeneration — confirm.
MODEL_PATH = "google/mt5-base"

# Constants for better maintainability
MAX_ABSTRACT_LENGTH = 1000  # abstracts longer than this are truncated with a "..." suffix
MAX_PAPERS = 5              # per-source result cap (applied to both arXiv and PubMed)
CACHE_SIZE = 128            # lru_cache bound for the fetcher methods
29 |
+
|
30 |
+
@dataclass
class Paper:
    """A single research-paper record, normalized across arXiv and PubMed."""
    title: str
    abstract: str
    url: str                 # PDF link for arXiv results; article page URL for PubMed
    published: str           # "YYYY-MM-DD" for arXiv; year string or 'Unknown' for PubMed
    relevance_score: float   # 1.0 when autism terms appear in the title, else 0.5
|
37 |
+
|
38 |
+
class TextProcessor:
    """Stateless helpers for normalizing and presenting paper text."""

    @staticmethod
    def clean_text(text: str) -> str:
        """Return *text* with odd symbols removed, whitespace collapsed, ASCII-only."""
        if not text:
            return ""

        # Keep word chars, whitespace and basic punctuation; blank out the rest.
        desymbolized = re.sub(r'[^\w\s.,;:()\-\'"]', ' ', text)
        # Collapse any run of whitespace to a single space.
        collapsed = re.sub(r'\s+', ' ', desymbolized)
        # Drop whatever cannot be represented in ASCII (accents, emoji, ...).
        ascii_only = collapsed.encode('ascii', 'ignore').decode('ascii')

        return ascii_only.strip()

    @staticmethod
    def format_paper(title: str, abstract: str) -> str:
        """Render one paper as a 'Title/Abstract' entry terminated by '---'."""
        clean_title = TextProcessor.clean_text(title)
        clean_abstract = TextProcessor.clean_text(abstract)

        # Cap the abstract so one verbose paper cannot dominate the context.
        if len(clean_abstract) > MAX_ABSTRACT_LENGTH:
            clean_abstract = clean_abstract[:MAX_ABSTRACT_LENGTH - 3] + "..."

        return f"Title: {clean_title}\nAbstract: {clean_abstract}\n---"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
62 |
|
63 |
+
class ResearchFetcher:
    """Fetches autism-related papers from arXiv and PubMed.

    Results are memoized per instance and per query string, and a shared
    requests.Session reuses HTTP connections across calls. Both public
    fetchers return [] on failure rather than raising, so a flaky source
    never crashes the app.
    """

    def __init__(self):
        self.session = requests.Session()  # Reuse connection
        # Per-instance memoization. NOTE: the previous @lru_cache decorators
        # on bound methods keyed the cache on `self` and kept every instance
        # alive for the cache's lifetime (flake8-bugbear B019).
        self._arxiv_cache: Dict[str, List[Paper]] = {}
        self._pubmed_cache: Dict[str, List[Paper]] = {}

    def fetch_arxiv_papers(self, query: str) -> List[Paper]:
        """Fetch papers from arXiv filtered to autism/ASD mentions.

        Returns [] on any network/API error, consistent with
        fetch_pubmed_papers (previously an error here propagated).
        """
        if query in self._arxiv_cache:
            return self._arxiv_cache[query]

        papers: List[Paper] = []
        try:
            client = arxiv.Client()
            search_query = f"(ti:autism OR abs:autism) AND (ti:\"{query}\" OR abs:\"{query}\") AND cat:q-bio"

            search = arxiv.Search(
                query=search_query,
                max_results=MAX_PAPERS,
                sort_by=arxiv.SortCriterion.Relevance
            )

            for result in client.results(search):
                title_lower = result.title.lower()
                summary_lower = result.summary.lower()

                if any(term in title_lower or term in summary_lower
                       for term in ('autism', 'asd')):
                    papers.append(Paper(
                        title=result.title,
                        abstract=result.summary,
                        url=result.pdf_url,
                        published=result.published.strftime("%Y-%m-%d"),
                        relevance_score=1.0 if 'autism' in title_lower else 0.5
                    ))
        except Exception as e:
            logging.error(f"Error fetching arXiv papers: {str(e)}")
            return []

        self._arxiv_cache[query] = papers
        return papers

    def fetch_pubmed_papers(self, query: str) -> List[Paper]:
        """Fetch papers from PubMed E-utilities with error handling."""
        if query in self._pubmed_cache:
            return self._pubmed_cache[query]

        base_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils"
        search_term = f"(autism[Title/Abstract] OR ASD[Title/Abstract]) AND ({query}[Title/Abstract])"

        try:
            # Step 1: search for matching article IDs.
            response = self.session.get(
                f"{base_url}/esearch.fcgi",
                params={
                    'db': 'pubmed',
                    'term': search_term,
                    'retmax': MAX_PAPERS,
                    'sort': 'relevance',
                    'retmode': 'xml'
                },
                timeout=10
            )
            response.raise_for_status()

            root = ET.fromstring(response.content)
            id_list = root.findall('.//Id')

            if not id_list:
                self._pubmed_cache[query] = []
                return []

            # Step 2: fetch per-article details concurrently to hide latency.
            with ThreadPoolExecutor(max_workers=3) as executor:
                futures = [
                    executor.submit(self._fetch_paper_details, base_url, id_elem.text)
                    for id_elem in id_list
                ]
                results = [future.result() for future in futures]

            papers = [paper for paper in results if paper is not None]
            self._pubmed_cache[query] = papers
            return papers

        except Exception as e:
            logging.error(f"Error fetching PubMed papers: {str(e)}")
            return []

    def _fetch_paper_details(self, base_url: str, paper_id: str) -> Optional[Paper]:
        """Fetch one PubMed article; None on error or when not autism-related."""
        try:
            response = self.session.get(
                f"{base_url}/efetch.fcgi",
                params={
                    'db': 'pubmed',
                    'id': paper_id,
                    'retmode': 'xml'
                },
                timeout=5
            )
            response.raise_for_status()

            article = ET.fromstring(response.content).find('.//PubmedArticle')
            if article is None:
                return None

            title = article.find('.//ArticleTitle')
            abstract = article.find('.//Abstract/AbstractText')
            year = article.find('.//PubDate/Year')

            if title is not None and abstract is not None:
                title_text = title.text.lower()
                abstract_text = abstract.text.lower()

                if any(term in title_text or term in abstract_text
                       for term in ('autism', 'asd')):
                    return Paper(
                        title=title.text,
                        abstract=abstract.text,
                        url=f"https://pubmed.ncbi.nlm.nih.gov/{paper_id}/",
                        published=year.text if year is not None else 'Unknown',
                        relevance_score=1.0 if any(term in title_text
                                                   for term in ('autism', 'asd')) else 0.5
                    )
            return None

        except Exception as e:
            logging.error(f"Error fetching paper {paper_id}: {str(e)}")
            return None
|
177 |
+
|
178 |
+
class ModelHandler:
    """Lazily loads the seq2seq model and turns question+context into an answer."""

    def __init__(self):
        # Both populated by load_model() on first use.
        self.model = None
        self.tokenizer = None

    def load_model(self):
        """Load tokenizer and model on first call; return True on success.

        NOTE: the previous @st.cache_resource decorator was removed — Streamlit
        cannot hash the bound method's `self` argument (it raises at runtime),
        and the `self.model is None` guard below already memoizes per instance.
        """
        if self.model is None:
            try:
                self.tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
                self.model = T5ForConditionalGeneration.from_pretrained(
                    MODEL_PATH,
                    device_map={"": "cpu"},
                    torch_dtype=torch.float32,
                    low_cpu_mem_usage=True
                )
                return True
            except Exception as e:
                logging.error(f"Error loading model: {str(e)}")
                return False
        return True

    def generate_answer(self, question: str, context: str, max_length: int = 512) -> str:
        """Generate an answer for *question* grounded in *context*.

        Returns a user-facing error string (never raises) when the model
        cannot be loaded or generation fails.
        """
        if not self.load_model():
            return "Error: Model loading failed. Please try again later."

        try:
            input_text = self._create_enhanced_prompt(question, context)

            inputs = self.tokenizer(
                input_text,
                return_tensors="pt",
                max_length=1024,
                truncation=True,
                padding=True
            )

            with torch.inference_mode():
                outputs = self.model.generate(
                    **inputs,
                    max_length=max_length,
                    min_length=200,
                    num_beams=4,
                    length_penalty=1.5,
                    temperature=0.7,
                    repetition_penalty=1.3,
                    early_stopping=True,
                    no_repeat_ngram_size=3,
                    do_sample=True,
                    top_k=40,
                    top_p=0.95
                )

            response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            # NOTE(review): clean_text collapses all whitespace, so the '\n\n'
            # section split in _format_response rarely fires after it — confirm
            # whether sectioning should happen before cleaning.
            response = TextProcessor.clean_text(response)

            # A very short generation is treated as a failure to answer.
            if len(response.strip()) < 100:
                return self._get_fallback_response()

            return self._format_response(response)

        except Exception as e:
            logging.error(f"Error generating response: {str(e)}")
            return "Error: Could not generate response. Please try again."

    @staticmethod
    def _create_enhanced_prompt(question: str, context: str) -> str:
        """Create an enhanced prompt for better response quality."""
        return f"""Context: {context}

Question: {question}

Instructions:
1. Provide a clear, evidence-based answer
2. Include specific findings from the research
3. Explain practical implications
4. Use accessible language
5. Address the question directly
6. Include relevant examples

Response should be:
- Accurate and scientific
- Easy to understand
- Practical and applicable
- Respectful of neurodiversity
- Supported by the provided research

Generate a comprehensive response:"""

    @staticmethod
    def _get_fallback_response() -> str:
        """Provide a structured fallback response when generation is too short."""
        return """Based on the available research, I cannot provide a specific answer to your question. However, I can suggest:

1. Try rephrasing your question to focus on specific aspects of autism
2. Consider asking about:
   - Specific behaviors or characteristics
   - Intervention strategies
   - Research findings
   - Support approaches

This will help me provide more accurate, research-based information."""

    @staticmethod
    def _format_response(response: str) -> str:
        """Add section headers: first paragraph = Overview, last = Key Takeaways."""
        sections = response.split('\n\n')
        formatted_sections = []

        for i, section in enumerate(sections):
            if i == 0:
                formatted_sections.append(f"### Overview\n{section}")
            elif i == len(sections) - 1:
                formatted_sections.append(f"### Key Takeaways\n{section}")
            else:
                formatted_sections.append(section)

        return '\n\n'.join(formatted_sections)
|
300 |
+
|
301 |
+
def main():
    """Streamlit entry point: take a question, gather papers, show an answer."""
    st.title("🧩 AMA Autism")
    st.write("""
    Ask questions about autism and get research-based answers from scientific papers.
    For best results, be specific in your questions.
    """)

    query = st.text_input("What would you like to know about autism? ✨")
    if not query:
        return

    with st.status("Researching your question..."):
        fetcher = ResearchFetcher()
        handler = ModelHandler()

        # Query both sources in parallel; concatenation order is
        # arXiv results first, then PubMed.
        with ThreadPoolExecutor(max_workers=2) as pool:
            arxiv_job = pool.submit(fetcher.fetch_arxiv_papers, query)
            pubmed_job = pool.submit(fetcher.fetch_pubmed_papers, query)
            papers = arxiv_job.result() + pubmed_job.result()

        if not papers:
            st.warning("No relevant research papers found. Please try a different search term.")
            return

        # Most relevant papers first.
        papers.sort(key=lambda paper: paper.relevance_score, reverse=True)

        # Build the model context from the top three papers only.
        entries = (
            TextProcessor.format_paper(paper.title, paper.abstract)
            for paper in papers[:3]
        )
        context = "\n".join(entries)

        st.write("Analyzing research papers...")
        answer = handler.generate_answer(query, context)

        with st.expander("📚 View source papers"):
            for paper in papers:
                st.markdown(f"- [{paper.title}]({paper.url}) ({paper.published})")

        st.success("Research analysis complete!")
        st.markdown(answer)


if __name__ == "__main__":
    main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|