# ama-autism / app.py — author: wakeupmh
# Commit 67c9f69: "fix: prompt" (file size 10.8 kB)
import logging
import os
import re
import unicodedata
import xml.etree.ElementTree as ET

import arxiv
import pandas as pd
import requests
import streamlit as st
import torch
from transformers import AutoTokenizer, T5ForConditionalGeneration
# Application-wide logging at INFO level.
logging.basicConfig(level=logging.INFO)

# Storage locations: use /data when that directory exists, otherwise the
# current working directory.
if os.path.exists("/data"):
    DATA_DIR = "/data"
else:
    DATA_DIR = "."
DATASET_DIR = os.path.join(DATA_DIR, "rag_dataset")
DATASET_PATH = os.path.join(DATASET_DIR, "dataset")

# Hugging Face model id; the small FLAN-T5 checkpoint is chosen for speed.
MODEL_PATH = "google/flan-t5-small"
@st.cache_resource
def _load_model_and_tokenizer():
    """Load the FLAN-T5 model and tokenizer once and cache them.

    Deliberately lets exceptions propagate: Streamlit does not cache a
    call that raises, so a failed load is retried on the next call instead
    of the failure being cached for the lifetime of the process.
    """
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
    model = T5ForConditionalGeneration.from_pretrained(
        MODEL_PATH,
        device_map={"": "cpu"},  # Force CPU
        torch_dtype=torch.float32,
    )
    return model, tokenizer


def load_local_model():
    """Load the local Hugging Face model.

    Returns:
        ``(model, tokenizer)`` on success. On failure the error is logged,
        shown in the UI, and ``(None, None)`` is returned; the failure is
        not cached, so a later call will try to load again.
    """
    try:
        return _load_model_and_tokenizer()
    except Exception as e:
        logging.exception("Error loading model %s", MODEL_PATH)
        st.error(f"Error loading model: {str(e)}")
        return None, None
def clean_text(text):
    """Clean and normalize text content to a compact ASCII string.

    Steps, in this order:

    1. Repair common mojibake (UTF-8 text decoded as cp1252) for quotes.
    2. Map typographic quotes and dashes to their ASCII equivalents.
    3. Fold accented letters to their base letter (e.g. "é" -> "e").
    4. Replace every character other than word characters, whitespace and
       ``. , ; : ( ) - ' "`` with a space, then drop any remaining
       non-ASCII characters.
    5. Collapse whitespace runs to a single space and strip the ends.

    Args:
        text: Input string. Falsy input (``None``, ``""``) yields ``""``.

    Returns:
        The cleaned string.
    """
    if not text:
        return ""
    # 1. Mojibake must be repaired first: its characters (â, €, ™, œ) would
    #    otherwise be altered by the filters below and could no longer match.
    for broken, fixed in (
        ("\u00e2\u20ac\u2122", "'"),  # mis-decoded right single quote
        ("\u00e2\u20ac\u0153", '"'),  # mis-decoded left double quote
        ("\u00e2\u20ac", '"'),        # any other mis-decoded quote prefix
    ):
        text = text.replace(broken, fixed)
    # 2. Typographic punctuation -> ASCII, so "don’t" stays "don't" instead
    #    of becoming "don t".
    text = text.translate({
        0x2018: "'", 0x2019: "'",  # left/right single quotes
        0x201C: '"', 0x201D: '"',  # left/right double quotes
        0x2013: "-", 0x2014: "-",  # en/em dashes
    })
    # 3. Decompose accented letters and drop the combining marks, keeping the
    #    base letter rather than deleting the whole character.
    text = "".join(
        ch for ch in unicodedata.normalize("NFKD", text)
        if not unicodedata.combining(ch)
    )
    # 4. Remove special characters; \w also matches non-Latin letters
    #    (e.g. Greek), so strip whatever non-ASCII survives afterwards.
    text = re.sub(r'[^\w\s.,;:()\-\'"]', ' ', text)
    text = ''.join(ch for ch in text if ord(ch) < 128)
    # 5. Normalize whitespace last so the removals above leave no gaps.
    text = re.sub(r'\s+', ' ', text)
    return text.strip()
def format_paper(title, abstract):
    """Render one paper as a three-line text record.

    The layout is ``Title: ...``, ``Abstract: ...`` and a ``---`` separator.
    Both fields go through ``clean_text``. An abstract longer than 1000
    characters is cut to its first 997 characters followed by ``...``.
    """
    limit = 1000
    cleaned_title = clean_text(title)
    cleaned_abstract = clean_text(abstract)
    if len(cleaned_abstract) > limit:
        cleaned_abstract = cleaned_abstract[:limit - 3] + "..."
    return "\n".join((
        f"Title: {cleaned_title}",
        f"Abstract: {cleaned_abstract}",
        "---",
    ))
def fetch_arxiv_papers(query, max_results=5):
    """Fetch autism-related papers from arXiv.

    The user query is combined with a mandatory autism filter and the
    ``q-bio`` category. Results are then kept only if their title or
    abstract mentions "autism" or "asd".

    Args:
        query: Free-text user query.
        max_results: Maximum number of results requested from arXiv.

    Returns:
        A list of dicts with keys ``title``, ``abstract``, ``url``,
        ``published`` (``YYYY-MM-DD``) and ``relevance_score``. The score is
        1 if "autism" is in the title, else 0.5. On a network/API error the
        error is logged and shown in the UI, and the papers collected so far
        are returned. This matches how PubMed errors are handled.
    """
    # Double quotes delimit phrases in arXiv query syntax; a quote inside the
    # user's text would break the query, so neutralise it.
    phrase = query.replace('"', " ").strip()
    # Always include autism in the search query
    search_query = (
        f"(ti:autism OR abs:autism) AND "
        f"(ti:\"{phrase}\" OR abs:\"{phrase}\") AND cat:q-bio"
    )
    search = arxiv.Search(
        query=search_query,
        max_results=max_results,
        sort_by=arxiv.SortCriterion.Relevance,
    )
    client = arxiv.Client()
    papers = []
    try:
        for result in client.results(search):
            title_lower = result.title.lower()
            summary_lower = result.summary.lower()
            # Only include papers that mention autism in title or abstract
            if not ('autism' in title_lower or 'asd' in title_lower or
                    'autism' in summary_lower or 'asd' in summary_lower):
                continue
            papers.append({
                'title': result.title,
                'abstract': result.summary,
                # pdf_url is optional in the arxiv client; fall back to the
                # entry's abstract-page URL so the source link is never "None".
                'url': result.pdf_url or result.entry_id,
                'published': result.published.strftime("%Y-%m-%d"),
                'relevance_score': 1 if 'autism' in title_lower else 0.5,
            })
    except Exception as e:
        logging.exception("Error fetching arXiv papers")
        st.error(f"Error fetching arXiv papers: {str(e)}")
    return papers
def _element_text(element):
    """Return all text inside an XML element, or '' when it is None.

    Uses ``itertext`` so text inside inline markup (e.g. ``<i>``, ``<sup>``)
    is included; ``element.text`` alone is ``None`` when the element starts
    with a child tag.
    """
    if element is None:
        return ""
    return "".join(element.itertext()).strip()


def fetch_pubmed_papers(query, max_results=5, timeout=15):
    """Fetch autism-related papers from PubMed via NCBI E-utilities.

    Runs ESearch to get PMIDs, then EFetch to get the article records.
    Articles without a title, abstract or PMID are skipped, as are those
    whose title/abstract does not mention "autism" or "asd".

    Args:
        query: Free-text user query, matched against title/abstract.
        max_results: Maximum number of PMIDs requested from ESearch.
        timeout: Seconds to wait for each HTTP request before giving up.

    Returns:
        A list of dicts with keys ``title``, ``abstract`` (all abstract
        sections joined), ``url``, ``published`` (year, ``MedlineDate``, or
        ``'Unknown'``) and ``relevance_score``. The score is 1 if
        autism/asd is in the title, else 0.5. On any error the error is
        logged and shown in the UI, and the papers collected so far are
        returned.
    """
    base_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils"
    # Always include autism in the search term
    search_term = f"(autism[Title/Abstract] OR ASD[Title/Abstract]) AND ({query}[Title/Abstract])"
    search_params = {
        'db': 'pubmed',
        'term': search_term,
        'retmax': max_results,
        'sort': 'relevance',
        'retmode': 'xml',
    }
    papers = []
    try:
        # Step 1: ESearch -> PMIDs.
        response = requests.get(f"{base_url}/esearch.fcgi", params=search_params, timeout=timeout)
        response.raise_for_status()
        root = ET.fromstring(response.content)
        id_list = [id_elem.text for id_elem in root.findall('.//Id')]
        if not id_list:
            return papers

        # Step 2: EFetch -> full article records.
        fetch_params = {
            'db': 'pubmed',
            'id': ','.join(id_list),
            'retmode': 'xml',
        }
        response = requests.get(f"{base_url}/efetch.fcgi", params=fetch_params, timeout=timeout)
        response.raise_for_status()
        articles = ET.fromstring(response.content)

        for article in articles.findall('.//PubmedArticle'):
            title = _element_text(article.find('.//ArticleTitle'))
            # Structured abstracts are split into several AbstractText
            # sections; keep all of them, not just the first.
            abstract = " ".join(
                section
                for section in map(_element_text, article.findall('.//Abstract/AbstractText'))
                if section
            )
            pmid = _element_text(article.find('.//PMID'))
            # Skip incomplete records instead of letting one bad article
            # abort the whole batch.
            if not (title and abstract and pmid):
                continue

            title_lower = title.lower()
            abstract_lower = abstract.lower()
            in_title = 'autism' in title_lower or 'asd' in title_lower
            # Only include papers that mention autism
            if not (in_title or 'autism' in abstract_lower or 'asd' in abstract_lower):
                continue

            published = (
                _element_text(article.find('.//PubDate/Year'))
                or _element_text(article.find('.//PubDate/MedlineDate'))
                or 'Unknown'
            )
            papers.append({
                'title': title,
                'abstract': abstract,
                'url': f"https://pubmed.ncbi.nlm.nih.gov/{pmid}/",
                'published': published,
                'relevance_score': 1 if in_title else 0.5,
            })
    except Exception as e:
        logging.exception("Error fetching PubMed papers")
        st.error(f"Error fetching PubMed papers: {str(e)}")
    return papers
def search_research_papers(query):
    """Query arXiv and PubMed and collect the autism-related hits.

    Titles and abstracts are cleaned with ``clean_text``. A hit is dropped
    when its abstract is empty or when neither its cleaned title nor its
    cleaned abstract mentions "autism" or "asd". Rows are ordered by
    ``relevance_score``, highest first; ties keep their fetch order
    (arXiv hits before PubMed hits).

    Returns:
        A DataFrame with columns ``title``, ``text``, ``url``,
        ``published`` and ``relevance_score``. When nothing survives the
        filters, a warning is shown and an empty DataFrame with those
        columns is returned.
    """
    candidates = fetch_arxiv_papers(query) + fetch_pubmed_papers(query)

    rows = []
    for candidate in candidates:
        raw_abstract = candidate['abstract']
        if not raw_abstract or not raw_abstract.strip():
            continue
        title = clean_text(candidate['title'])
        body = clean_text(raw_abstract)
        # Keep only papers that are actually about autism.
        haystacks = (title.lower(), body.lower())
        if not any(term in hay for hay in haystacks for term in ('autism', 'asd')):
            continue
        rows.append({
            'title': title,
            'text': format_paper(title, body),
            'url': candidate['url'],
            'published': candidate['published'],
            'relevance_score': candidate.get('relevance_score', 0.5),
        })

    ranked = sorted(rows, key=lambda row: row['relevance_score'], reverse=True)
    if not ranked:
        st.warning("No autism-related papers found. Please try a different search term.")
        return pd.DataFrame(columns=['title', 'text', 'url', 'published', 'relevance_score'])
    return pd.DataFrame(ranked)
# Returned when the model's answer is too short to be useful.
_FALLBACK_ANSWER = """Autism Spectrum Disorder (ASD) is a complex neurodevelopmental condition. Unfortunately, the provided papers don't contain specific information about this aspect of autism.
To get research-based information, try asking more specific questions about:
- Genetics and environmental factors
- Early intervention
- Treatments and therapies
- Neurological development
This will allow us to provide accurate information supported by recent scientific research."""


def _build_prompt(question, context):
    """Assemble the instruction prompt (task, question, papers, instructions)."""
    return "\n".join([
        "Generate a comprehensive answer about autism, using the research papers as references to support your explanations.",
        f"Question: {question}",
        "Research Papers:",
        context,
        "Instructions:",
        "1. Provide a general explanation about the topic",
        '2. Use the research papers as references, citing them in the format "According to [PAPER TITLE], ..."',
        "3. Integrate research findings naturally into the explanation",
        "4. Keep the focus on being informative and helpful",
        "Answer in a clear, informative way, using the papers as references.",
    ])


def _fit_context_to_budget(tokenizer, question, context, max_input_length):
    """Trim ``context`` so the whole prompt fits in ``max_input_length`` tokens.

    The instructions come after the papers in the prompt, so plain tokenizer
    truncation would cut the instructions first. Trimming only the context
    keeps the question and instructions intact.
    """
    overhead = len(tokenizer(_build_prompt(question, ""))["input_ids"])
    budget = max(max_input_length - overhead, 0)
    context_ids = tokenizer(context, add_special_tokens=False)["input_ids"]
    if len(context_ids) <= budget:
        return context
    return tokenizer.decode(context_ids[:budget], skip_special_tokens=True)


def generate_answer(question, context, max_length=512, max_input_length=1024):
    """Generate a comprehensive answer using the local model.

    Args:
        question: The user's question.
        context: Paper snippets to use as references.
        max_length: Maximum number of generated tokens.
        max_input_length: Token budget for the whole prompt. The paper
            context is trimmed to fit, so the question and instructions are
            never truncated away.

    Returns:
        One of the following:
        - the generated answer, split into one sentence per line;
        - a canned overview, if the model output is shorter than 100
          characters;
        - an error message string, if the model cannot be loaded or
          generation fails.
    """
    model, tokenizer = load_local_model()
    if model is None or tokenizer is None:
        return "Error: Could not load the model. Please try again later."

    clean_question = clean_text(question)
    # Note: clean_text collapses all whitespace, so the snippets end up on one line.
    clean_context = clean_text(context)

    try:
        clean_context = _fit_context_to_budget(
            tokenizer, clean_question, clean_context, max_input_length
        )
        inputs = tokenizer(
            _build_prompt(clean_question, clean_context),
            return_tensors="pt",
            max_length=max_input_length,
            truncation=True,  # safety net; the context budget should already fit
            padding=True,
        )
        with torch.inference_mode():
            outputs = model.generate(
                **inputs,
                max_length=max_length,
                min_length=200,
                num_beams=5,
                length_penalty=1.5,
                temperature=0.7,
                repetition_penalty=1.2,
                early_stopping=True,
                no_repeat_ngram_size=3,
                do_sample=True,
                top_k=50,
                top_p=0.95,
            )
        response = clean_text(tokenizer.decode(outputs[0], skip_special_tokens=True))

        # If response is too short or empty, provide a general overview
        if len(response.strip()) < 100:
            return _FALLBACK_ANSWER

        # One sentence per line. (The former "• " bullet replacement was dead
        # code: clean_text always removes "•" before this point.)
        return response.replace(". ", ".\n")
    except Exception as e:
        logging.exception("Error generating response")
        st.error(f"Error generating response: {str(e)}")
        return "Error: Could not generate response. Please try again with a different question."
# Streamlit App
st.title("🧩 AMA Autism")
st.write("""
This app searches through scientific papers to answer your questions about autism.
For best results, be specific in your questions.
""")
query = st.text_input("Please ask me anything about autism ✨")

if query:
    with st.status("Searching for answers..."):
        # Announce the search before running it; previously this message
        # appeared only after the (slow) search had already finished.
        st.write("Searching for data in PubMed and arXiv...")
        df = search_research_papers(query)
        st.write(f"Found {len(df)} relevant papers!")

        # Context = the three highest-ranked papers (df is sorted by
        # relevance), at most 1000 characters each.
        context = "\n".join(text[:1000] for text in df['text'].head(3))

        st.write("Generating answer...")
        answer = generate_answer(query, context)

    # Display paper sources. This stays outside st.status because Streamlit
    # does not support nesting expanders, and st.status is expander-based.
    with st.expander("View source papers"):
        for _, paper in df.iterrows():
            st.markdown(f"- [{paper['title']}]({paper['url']}) ({paper['published']})")

    st.success("Answer found!")
    # Markdown merges single newlines into one paragraph; two trailing spaces
    # make them hard line breaks so the one-sentence-per-line layout shows.
    st.markdown(answer.replace("\n", "  \n"))