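"""AMA Autism: a Streamlit app that answers questions about autism by
retrieving papers from arXiv and PubMed and summarizing them with FLAN-T5."""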
import streamlit as st
import pandas as pd
import torch
import logging
import os
from transformers import AutoTokenizer, T5ForConditionalGeneration
import arxiv
import requests
import xml.etree.ElementTree as ET
import re
from functools import lru_cache
from typing import List, Dict, Optional
from dataclasses import dataclass
from concurrent.futures import ThreadPoolExecutor
# Configure logging
logging.basicConfig(level=logging.INFO)

# Define data paths and constants
DATA_DIR = "/data" if os.path.exists("/data") else "."
DATASET_DIR = os.path.join(DATA_DIR, "rag_dataset")
DATASET_PATH = os.path.join(DATASET_DIR, "dataset")
MODEL_PATH = "google/flan-t5-small"

# Constants for better maintainability
MAX_ABSTRACT_LENGTH = 1000
MAX_PAPERS = 5
CACHE_SIZE = 128
@dataclass
class Paper:
    title: str
    abstract: str
    url: str
    published: str
    relevance_score: float
class TextProcessor:
    @staticmethod
    def clean_text(text: str) -> str:
        """Clean and normalize text content."""
        if not text:
            return ""
        text = re.sub(r'[^\w\s.,;:()\-\'"]', ' ', text)  # Strip unusual punctuation
        text = re.sub(r'\s+', ' ', text)                 # Collapse whitespace
        text = text.encode('ascii', 'ignore').decode('ascii')  # Drop non-ASCII characters
        return text.strip()

    @staticmethod
    def format_paper(title: str, abstract: str) -> str:
        """Format paper information as a context block for the model."""
        title = TextProcessor.clean_text(title)
        abstract = TextProcessor.clean_text(abstract)
        if len(abstract) > MAX_ABSTRACT_LENGTH:
            abstract = abstract[:MAX_ABSTRACT_LENGTH - 3] + "..."
        return f"Title: {title}\nAbstract: {abstract}\n---"
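# Illustrative example (not part of the app flow):
#   TextProcessor.format_paper("A study of ASD", "Some abstract text...")
#   -> 'Title: A study of ASD\nAbstract: Some abstract text...\n---'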
class ResearchFetcher:
    def __init__(self):
        self.session = requests.Session()  # Reuse the HTTP connection across requests

    def fetch_arxiv_papers(self, query: str) -> List[Paper]:
        """Fetch autism-related papers from arXiv, restricted to the q-bio category."""
        client = arxiv.Client()
        # ti: matches titles, abs: matches abstracts, cat: restricts the category
        search_query = f"(ti:autism OR abs:autism) AND (ti:\"{query}\" OR abs:\"{query}\") AND cat:q-bio"
        search = arxiv.Search(
            query=search_query,
            max_results=MAX_PAPERS,
            sort_by=arxiv.SortCriterion.Relevance
        )
        papers = []
        for result in client.results(search):
            title_lower = result.title.lower()
            summary_lower = result.summary.lower()
            if any(term in title_lower or term in summary_lower
                   for term in ['autism', 'asd']):
                papers.append(Paper(
                    title=result.title,
                    abstract=result.summary,
                    url=result.pdf_url,
                    published=result.published.strftime("%Y-%m-%d"),
                    relevance_score=1.0 if 'autism' in title_lower else 0.5
                ))
        return papers
    def fetch_pubmed_papers(self, query: str) -> List[Paper]:
        """Fetch autism-related papers from PubMed via the E-utilities API."""
        base_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils"
        search_term = f"(autism[Title/Abstract] OR ASD[Title/Abstract]) AND ({query}[Title/Abstract])"
        try:
            # esearch returns the IDs of matching papers
            response = self.session.get(
                f"{base_url}/esearch.fcgi",
                params={
                    'db': 'pubmed',
                    'term': search_term,
                    'retmax': MAX_PAPERS,
                    'sort': 'relevance',
                    'retmode': 'xml'
                },
                timeout=10
            )
            response.raise_for_status()
            root = ET.fromstring(response.content)
            id_list = root.findall('.//Id')
            if not id_list:
                return []

            # Fetch the details for each ID in parallel
            with ThreadPoolExecutor(max_workers=3) as executor:
                paper_futures = [
                    executor.submit(self._fetch_paper_details, base_url, id_elem.text)
                    for id_elem in id_list
                ]
                results = [future.result() for future in paper_futures]
                return [paper for paper in results if paper is not None]
        except Exception as e:
            logging.error(f"Error fetching PubMed papers: {e}")
            return []
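    # Note: NCBI asks clients without an API key to stay under roughly three
    # requests per second, so max_workers=3 keeps the parallel fetches modest.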
    def _fetch_paper_details(self, base_url: str, paper_id: str) -> Optional[Paper]:
        """Fetch the details of a single paper, with a short timeout."""
        try:
            response = self.session.get(
                f"{base_url}/efetch.fcgi",
                params={
                    'db': 'pubmed',
                    'id': paper_id,
                    'retmode': 'xml'
                },
                timeout=5
            )
            response.raise_for_status()
            article = ET.fromstring(response.content).find('.//PubmedArticle')
            if article is None:
                return None

            title = article.find('.//ArticleTitle')
            abstract = article.find('.//Abstract/AbstractText')
            year = article.find('.//PubDate/Year')
            # Guard against elements that exist but carry no text
            if title is not None and title.text and abstract is not None and abstract.text:
                title_text = title.text.lower()
                abstract_text = abstract.text.lower()
                if any(term in title_text or term in abstract_text
                       for term in ['autism', 'asd']):
                    return Paper(
                        title=title.text,
                        abstract=abstract.text,
                        url=f"https://pubmed.ncbi.nlm.nih.gov/{paper_id}/",
                        published=year.text if year is not None else 'Unknown',
                        relevance_score=1.0 if any(term in title_text
                                                   for term in ['autism', 'asd']) else 0.5
                    )
        except Exception as e:
            logging.error(f"Error fetching paper {paper_id}: {e}")
        return None
class ModelHandler:
    def __init__(self):
        self.model = None
        self.tokenizer = None

    def load_model(self):
        """Lazily load the FLAN-T5 Small model with CPU-friendly settings."""
        if self.model is None:
            try:
                self.tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
                self.model = T5ForConditionalGeneration.from_pretrained(
                    MODEL_PATH,
                    device_map={"": "cpu"},
                    torch_dtype=torch.float32,
                    low_cpu_mem_usage=True
                )
                return True
            except Exception as e:
                logging.error(f"Error loading model: {e}")
                return False
        return True
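    # google/flan-t5-small is a small checkpoint (roughly 80M parameters), so
    # CPU inference in float32 is practical; larger variants would want a GPU.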
    def generate_answer(self, question: str, context: str, max_length: int = 512) -> str:
        """Generate an answer with FLAN-T5 using instruction-style prompting."""
        if not self.load_model():
            return "Error: Model loading failed. Please try again later."
        try:
            # FLAN-T5 responds better to direct instruction prompts
            input_text = f"""Answer the following question about autism using the provided research context.
Research Context:
{context}
Question: {question}
Instructions:
- Be specific and evidence-based
- Use clear, accessible language
- Focus on practical implications
- Cite research when relevant
- Be respectful of neurodiversity
Answer:"""
            inputs = self.tokenizer(
                input_text,
                return_tensors="pt",
                max_length=1024,
                truncation=True,
                padding=True
            )
            with torch.inference_mode():
                # num_beams together with do_sample enables beam-search
                # multinomial sampling; temperature/top_k/top_p shape each step
                outputs = self.model.generate(
                    **inputs,
                    max_length=max_length,
                    min_length=100,          # Reduced for FLAN-T5 Small
                    num_beams=3,             # Tuned for better performance
                    length_penalty=1.0,      # Neutral, favoring concise answers
                    temperature=0.6,         # More deterministic
                    repetition_penalty=1.2,
                    early_stopping=True,
                    no_repeat_ngram_size=2,
                    do_sample=True,
                    top_k=30,
                    top_p=0.92
                )
            response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            response = TextProcessor.clean_text(response)
            if len(response.strip()) < 50:  # Adjusted for FLAN-T5's shorter responses
                return self._get_fallback_response()
            return self._format_response(response)
        except Exception as e:
            logging.error(f"Error generating response: {e}")
            return "Error: Could not generate response. Please try again."
    @staticmethod
    def _get_fallback_response() -> str:
        """Provide a structured fallback response."""
        return """Based on the available research, I cannot provide a specific answer to your question. However, I can suggest:

1. Try rephrasing your question to focus on specific aspects of autism
2. Consider asking about:
   - Specific behaviors or characteristics
   - Intervention strategies
   - Research findings
   - Support approaches

This will help me provide more accurate, research-based information."""
    @staticmethod
    def _format_response(response: str) -> str:
        """Format the response for better readability."""
        # Add section headers around the first and last paragraphs
        sections = response.split('\n\n')
        formatted_sections = []
        for i, section in enumerate(sections):
            if i == 0:
                formatted_sections.append(f"### Overview\n{section}")
            elif i == len(sections) - 1:
                formatted_sections.append(f"### Key Takeaways\n{section}")
            else:
                formatted_sections.append(section)
        return '\n\n'.join(formatted_sections)
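# Illustrative usage outside Streamlit (hypothetical context string):
#   handler = ModelHandler()
#   print(handler.generate_answer("What is ASD?", "Title: ...\nAbstract: ...\n---"))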
def main():
    st.title("🧩 AMA Autism")
    st.write(
        "Ask questions about autism and get research-based answers from scientific papers. "
        "For best results, be specific in your questions."
    )
    query = st.text_input("What would you like to know about autism? ✨")

    if query:
        with st.status("Researching your question...") as status:
            # Initialize handlers
            research_fetcher = ResearchFetcher()
            model_handler = ModelHandler()

            # Fetch papers from both sources concurrently
            with ThreadPoolExecutor(max_workers=2) as executor:
                arxiv_future = executor.submit(research_fetcher.fetch_arxiv_papers, query)
                pubmed_future = executor.submit(research_fetcher.fetch_pubmed_papers, query)
                papers = arxiv_future.result() + pubmed_future.result()

            if not papers:
                st.warning("No relevant research papers found. Please try a different search term.")
                return

            # Sort papers by relevance
            papers.sort(key=lambda x: x.relevance_score, reverse=True)

            # Prepare context from the top three papers
            context = "\n".join(
                TextProcessor.format_paper(paper.title, paper.abstract)
                for paper in papers[:3]
            )

            # Generate answer
            st.write("Analyzing research papers...")
            answer = model_handler.generate_answer(query, context)

            # Display sources
            with st.expander("📚 View source papers"):
                for paper in papers:
                    st.markdown(f"- [{paper.title}]({paper.url}) ({paper.published})")

            st.success("Research analysis complete!")
        st.markdown(answer)

if __name__ == "__main__":
    main()
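# To run locally (assuming this file is saved as app.py):
#   streamlit run app.py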