# ama-autism / app.py
import streamlit as st
import torch
import logging
import os
from transformers import AutoTokenizer, T5ForConditionalGeneration
import arxiv
import requests
import xml.etree.ElementTree as ET
import re
from functools import lru_cache
from typing import List, Dict, Optional
from dataclasses import dataclass
from concurrent.futures import ThreadPoolExecutor
# Configure logging
logging.basicConfig(level=logging.INFO)
# Define data paths and constants
DATA_DIR = "/data" if os.path.exists("/data") else "."
DATASET_DIR = os.path.join(DATA_DIR, "rag_dataset")
DATASET_PATH = os.path.join(DATASET_DIR, "dataset")
MODEL_PATH = "google/flan-t5-small"
# Constants for better maintainability
MAX_ABSTRACT_LENGTH = 1000
MAX_PAPERS = 5
CACHE_SIZE = 128
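# MAX_ABSTRACT_LENGTH bounds how much of each abstract reaches the prompt,
# MAX_PAPERS caps results per source, and CACHE_SIZE sizes the lru_cache
# decorators on the fetch methods below.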
@dataclass
class Paper:
title: str
abstract: str
url: str
published: str
relevance_score: float
class TextProcessor:
@staticmethod
def clean_text(text: str) -> str:
"""Clean and normalize text content with improved handling"""
if not text:
return ""
        # Replace disallowed punctuation, collapse whitespace, then force ASCII
        text = re.sub(r'[^\w\s.,;:()\-\'"]', ' ', text)
        text = re.sub(r'\s+', ' ', text)
        text = text.encode('ascii', 'ignore').decode('ascii')  # note: drops accented/non-Latin characters
return text.strip()
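    # Illustrative behavior, not a spec: the ASCII pass drops accented and
    # Greek characters rather than transliterating them, e.g.
    #   clean_text("Café – μ-opioid!") -> "Caf -opioid"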
@staticmethod
def format_paper(title: str, abstract: str) -> str:
"""Format paper information with improved structure"""
title = TextProcessor.clean_text(title)
abstract = TextProcessor.clean_text(abstract)
if len(abstract) > MAX_ABSTRACT_LENGTH:
abstract = abstract[:MAX_ABSTRACT_LENGTH-3] + "..."
return f"""Title: {title}\nAbstract: {abstract}\n---"""
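    # Example of one formatted block (values are illustrative placeholders):
    #   Title: Early intervention outcomes in ASD
    #   Abstract: We examine...
    #   ---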
class ResearchFetcher:
def __init__(self):
self.session = requests.Session() # Reuse connection
@lru_cache(maxsize=CACHE_SIZE)
def fetch_arxiv_papers(self, query: str) -> List[Paper]:
"""Fetch papers from arXiv with improved filtering"""
client = arxiv.Client()
search_query = f"(ti:autism OR abs:autism) AND (ti:\"{query}\" OR abs:\"{query}\") AND cat:q-bio"
search = arxiv.Search(
query=search_query,
max_results=MAX_PAPERS,
sort_by=arxiv.SortCriterion.Relevance
)
papers = []
for result in client.results(search):
title_lower = result.title.lower()
summary_lower = result.summary.lower()
if any(term in title_lower or term in summary_lower
for term in ['autism', 'asd']):
papers.append(Paper(
title=result.title,
abstract=result.summary,
url=result.pdf_url,
published=result.published.strftime("%Y-%m-%d"),
relevance_score=1.0 if 'autism' in title_lower else 0.5
))
return papers
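    # Caching note: lru_cache on a bound method includes `self` in the cache
    # key, so results are only reused within one ResearchFetcher instance;
    # main() builds a fresh instance on every Streamlit rerun, so the cache
    # rarely persists between queries in practice.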
@lru_cache(maxsize=CACHE_SIZE)
def fetch_pubmed_papers(self, query: str) -> List[Paper]:
"""Fetch papers from PubMed with improved error handling"""
base_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils"
search_term = f"(autism[Title/Abstract] OR ASD[Title/Abstract]) AND ({query}[Title/Abstract])"
try:
# Fetch IDs efficiently
response = self.session.get(
f"{base_url}/esearch.fcgi",
params={
'db': 'pubmed',
'term': search_term,
'retmax': MAX_PAPERS,
'sort': 'relevance',
'retmode': 'xml'
},
timeout=10
)
response.raise_for_status()
root = ET.fromstring(response.content)
id_list = root.findall('.//Id')
if not id_list:
return []
# Fetch details in parallel
with ThreadPoolExecutor(max_workers=3) as executor:
paper_futures = [
executor.submit(self._fetch_paper_details, base_url, id_elem.text)
for id_elem in id_list
]
                results = (future.result() for future in paper_futures)
                return [paper for paper in results if paper is not None]
except Exception as e:
logging.error(f"Error fetching PubMed papers: {str(e)}")
return []
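    # Sketch of the two-step NCBI E-utilities flow used above (parameters as
    # in the code; the ID shown is a placeholder):
    #   GET {base_url}/esearch.fcgi?db=pubmed&term=...&retmax=5&retmode=xml
    #     -> <eSearchResult><IdList><Id>12345678</Id>...</IdList>
    #   GET {base_url}/efetch.fcgi?db=pubmed&id=12345678&retmode=xml
    #     -> full <PubmedArticle> records parsed by _fetch_paper_details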
def _fetch_paper_details(self, base_url: str, paper_id: str) -> Optional[Paper]:
"""Fetch individual paper details with timeout"""
try:
response = self.session.get(
f"{base_url}/efetch.fcgi",
params={
'db': 'pubmed',
'id': paper_id,
'retmode': 'xml'
},
timeout=5
)
response.raise_for_status()
article = ET.fromstring(response.content).find('.//PubmedArticle')
if article is None:
return None
title = article.find('.//ArticleTitle')
abstract = article.find('.//Abstract/AbstractText')
year = article.find('.//PubDate/Year')
            # Guard on .text as well: structured abstracts nest their text in
            # child tags, leaving AbstractText.text as None
            if title is not None and abstract is not None and title.text and abstract.text:
                title_text = title.text.lower()
                abstract_text = abstract.text.lower()
                if any(term in title_text or term in abstract_text
                       for term in ['autism', 'asd']):
                    return Paper(
                        title=title.text,
                        abstract=abstract.text,
                        url=f"https://pubmed.ncbi.nlm.nih.gov/{paper_id}/",
                        published=year.text if year is not None else 'Unknown',
                        relevance_score=1.0 if any(term in title_text
                                                   for term in ['autism', 'asd']) else 0.5
                    )
except Exception as e:
logging.error(f"Error fetching paper {paper_id}: {str(e)}")
return None
class ModelHandler:
def __init__(self):
self.model = None
self.tokenizer = None
self._initialize_model()
@staticmethod
@st.cache_resource
def _load_model():
"""Load FLAN-T5 Small model with optimized settings"""
try:
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
model = T5ForConditionalGeneration.from_pretrained(
MODEL_PATH,
device_map={"": "cpu"},
torch_dtype=torch.float32,
low_cpu_mem_usage=True
)
return model, tokenizer
except Exception as e:
logging.error(f"Error loading model: {str(e)}")
return None, None
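    # st.cache_resource stores one copy of the model per server process, so
    # Streamlit reruns and concurrent sessions share the loaded weights
    # instead of reloading flan-t5-small on every interaction.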
def _initialize_model(self):
"""Initialize model and tokenizer"""
self.model, self.tokenizer = self._load_model()
def generate_answer(self, question: str, context: str, max_length: int = 512) -> str:
"""Generate answer with FLAN-T5 optimized parameters"""
if self.model is None or self.tokenizer is None:
return "Error: Model loading failed. Please try again later."
try:
# FLAN-T5 responds better to direct instruction prompts
input_text = f"""Answer the following question about autism using the provided research context.
Research Context:
{context}
Question: {question}
Instructions:
- Be specific and evidence-based
- Use clear, accessible language
- Focus on practical implications
- Cite research when relevant
- Be respectful of neurodiversity
Answer:"""
inputs = self.tokenizer(
input_text,
return_tensors="pt",
max_length=1024,
truncation=True,
padding=True
)
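            # Decoding notes: num_beams=3 combined with do_sample=True gives
            # beam-multinomial sampling in transformers; min_length=100 forces
            # a substantive answer, while repetition_penalty and
            # no_repeat_ngram_size=2 curb repeated phrases.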
with torch.inference_mode():
outputs = self.model.generate(
**inputs,
max_length=max_length,
min_length=100,
num_beams=3,
length_penalty=1.0,
temperature=0.6,
repetition_penalty=1.2,
early_stopping=True,
no_repeat_ngram_size=2,
do_sample=True,
top_k=30,
top_p=0.92
)
response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
response = TextProcessor.clean_text(response)
if len(response.strip()) < 50:
return self._get_fallback_response()
return self._format_response(response)
except Exception as e:
logging.error(f"Error generating response: {str(e)}")
return "Error: Could not generate response. Please try again."
@staticmethod
def _get_fallback_response() -> str:
"""Provide a structured fallback response"""
return """Based on the available research, I cannot provide a specific answer to your question. Please try:
1. Rephrasing your question to be more specific
2. Asking about:
- Specific behaviors or characteristics
- Intervention strategies
- Research findings
- Support approaches
This will help me provide more accurate, research-based information."""
@staticmethod
def _format_response(response: str) -> str:
"""Format the response for better readability"""
sections = response.split('\n\n')
formatted_sections = []
for i, section in enumerate(sections):
if i == 0:
formatted_sections.append(f"### Overview\n{section}")
elif i == len(sections) - 1:
formatted_sections.append(f"### Key Takeaways\n{section}")
else:
formatted_sections.append(section)
return '\n\n'.join(formatted_sections)
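# Note: flan-t5-small usually emits a single paragraph with no blank lines,
# so _format_response often only prepends the "### Overview" heading; the
# "### Key Takeaways" branch applies only to multi-paragraph outputs.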
def main():
st.title("🧩 AMA Autism")
st.write("""
Ask questions about autism and get research-based answers from scientific papers.
For best results, be specific in your questions.
""")
query = st.text_input("What would you like to know about autism? ✨")
if query:
        with st.status("Researching your question..."):
# Initialize handlers
research_fetcher = ResearchFetcher()
model_handler = ModelHandler()
# Fetch papers concurrently
with ThreadPoolExecutor(max_workers=2) as executor:
arxiv_future = executor.submit(research_fetcher.fetch_arxiv_papers, query)
pubmed_future = executor.submit(research_fetcher.fetch_pubmed_papers, query)
papers = arxiv_future.result() + pubmed_future.result()
if not papers:
st.warning("No relevant research papers found. Please try a different search term.")
return
# Sort papers by relevance
papers.sort(key=lambda x: x.relevance_score, reverse=True)
# Prepare context from top papers
context = "\n".join(
TextProcessor.format_paper(paper.title, paper.abstract)
for paper in papers[:3]
)
# Generate answer
st.write("Analyzing research papers...")
answer = model_handler.generate_answer(query, context)
# Display sources
with st.expander("📚 View source papers"):
for paper in papers:
st.markdown(f"- [{paper.title}]({paper.url}) ({paper.published})")
st.success("Research analysis complete!")
st.markdown(answer)
if __name__ == "__main__":
main()
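# To run locally (assuming streamlit and the other imports are installed):
#   streamlit run app.py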