ama-autism / app.py
wakeupmh's picture
feat: add tokenizer model and summarization model
03e43ae
raw
history blame
11.1 kB
import streamlit as st
import pandas as pd
import torch
import logging
import os
from transformers import AutoTokenizer, T5ForConditionalGeneration
import arxiv
import requests
import xml.etree.ElementTree as ET
import re
# Configure logging
logging.basicConfig(level=logging.INFO)

# Storage root: prefer a /data directory when one exists, otherwise fall back
# to the current working directory.
if os.path.exists("/data"):
    DATA_DIR = "/data"
else:
    DATA_DIR = "."
DATASET_DIR = os.path.join(DATA_DIR, "rag_dataset")
DATASET_PATH = os.path.join(DATASET_DIR, "dataset")

# Hugging Face Hub ids: the tokenizer and the seq2seq model are loaded from
# two separate repositories.
TOKENIZER_MODEL = "google/flan-t5-small"
SUMMARIZATION_MODEL = "Falconsai/text_summarization"
@st.cache_resource
def _load_model_cached():
    """Load the tokenizer and summarization model once per Streamlit process.

    Exceptions are deliberately NOT caught here: ``st.cache_resource`` only
    stores successful return values, so letting a failure propagate means the
    next call retries the load instead of reusing a broken result.

    Returns:
        ``(model, tokenizer)``.
    """
    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_MODEL)
    model = T5ForConditionalGeneration.from_pretrained(
        SUMMARIZATION_MODEL,
        device_map={"": "cpu"},  # Force CPU
        torch_dtype=torch.float32,
    )
    return model, tokenizer


def load_local_model():
    """Load the local Hugging Face model.

    Returns:
        ``(model, tokenizer)`` on success, or ``(None, None)`` after reporting
        the error in the Streamlit UI. The error is caught outside the cached
        loader, so a transient failure (e.g. a download error) is not cached
        as ``(None, None)`` for the rest of the process lifetime.
    """
    try:
        return _load_model_cached()
    except Exception as e:
        st.error(f"Error loading model: {str(e)}")
        return None, None
# UTF-8 punctuation mis-decoded as Windows-1252 ("mojibake"). "â€" is a
# prefix of the other two sequences, so the longer sequences must come first.
_MOJIBAKE_REPAIRS = (
    ("\u00e2\u20ac\u2122", "'"),  # "’" -> right single quote
    ("\u00e2\u20ac\u0153", '"'),  # "“" -> left double quote
    ("\u00e2\u20ac", '"'),        # "â€" -> remaining quote-like fragment
)

# Typographic punctuation mapped to ASCII so it survives the ASCII-only step
# instead of being turned into a space.
_TYPOGRAPHIC_TO_ASCII = str.maketrans({
    "\u2018": "'",  # left single quotation mark
    "\u2019": "'",  # right single quotation mark / apostrophe
    "\u201c": '"',  # left double quotation mark
    "\u201d": '"',  # right double quotation mark
    "\u2013": "-",  # en dash
    "\u2014": "-",  # em dash
})

# Anything that is not a word char, whitespace, or basic punctuation.
_DISALLOWED_CHARS = re.compile(r'[^\w\s.,;:()\-\'"]')
_WHITESPACE_RUN = re.compile(r'\s+')


def clean_text(text):
    """Normalize scraped text to clean, single-spaced ASCII.

    Steps, in order:
    1. repair common UTF-8-as-Windows-1252 mojibake for quotes;
    2. map typographic quotes and dashes to their ASCII equivalents;
    3. transliterate accented letters (NFKD) and drop remaining non-ASCII;
    4. replace characters other than word chars, whitespace and .,;:()-'"
       with spaces;
    5. collapse whitespace runs and strip the ends.

    The repairs must run before the character filter: the filter would
    otherwise break the mojibake sequences apart so they could never match.

    Args:
        text: input string; falsy values (None, "") are accepted.

    Returns:
        The cleaned string ("" for falsy input).
    """
    import unicodedata  # local import: only this helper needs it

    if not text:
        return ""
    for broken, fixed in _MOJIBAKE_REPAIRS:
        text = text.replace(broken, fixed)
    text = text.translate(_TYPOGRAPHIC_TO_ASCII)
    # NFKD splits "é" into "e" + combining accent, so the ASCII encode keeps
    # the base letter instead of discarding the whole character.
    text = unicodedata.normalize("NFKD", text).encode("ascii", "ignore").decode("ascii")
    text = _DISALLOWED_CHARS.sub(" ", text)
    text = _WHITESPACE_RUN.sub(" ", text)
    return text.strip()
def format_paper(title, abstract):
    """Render one paper as a plain-text snippet for the model prompt.

    Both fields go through ``clean_text``. An abstract longer than 1000
    characters is cut to its first 997 characters followed by "...", so the
    abstract part never exceeds 1000 characters.

    Returns:
        "Title: <title>\\nAbstract: <abstract>\\n---"
    """
    limit = 1000
    body = clean_text(abstract)
    if len(body) > limit:
        body = f"{body[:limit - 3]}..."
    lines = (f"Title: {clean_text(title)}", f"Abstract: {body}", "---")
    return "\n".join(lines)
def _build_arxiv_query(query):
    """Return the arXiv API search string for *query*, always scoped to autism.

    *query* is embedded inside a quoted phrase, so any double quotes in it are
    replaced by spaces (an unmatched quote would corrupt the search syntax)
    and whitespace runs are collapsed. If nothing is left, only the autism and
    category filters are applied.
    """
    phrase = " ".join(str(query or "").replace('"', " ").split())
    clauses = ["(ti:autism OR abs:autism)"]
    if phrase:
        clauses.append(f'(ti:"{phrase}" OR abs:"{phrase}")')
    clauses.append("cat:q-bio")
    return " AND ".join(clauses)


def fetch_arxiv_papers(query, max_results=5):
    """Fetch autism-related papers from arXiv.

    Args:
        query: free-text topic; combined with an autism filter and the q-bio
            category.
        max_results: maximum number of results requested from arXiv.

    Returns:
        A list of dicts with keys 'title', 'abstract', 'url', 'published'
        (YYYY-MM-DD) and 'relevance_score' (1 when the title mentions
        autism/ASD, else 0.5). If the arXiv request fails, the error is
        logged, shown via ``st.error``, and the papers collected so far are
        returned instead of crashing the app.
    """
    papers = []
    try:
        search = arxiv.Search(
            query=_build_arxiv_query(query),
            max_results=max_results,
            sort_by=arxiv.SortCriterion.Relevance,
        )
        for result in arxiv.Client().results(search):
            title_lower = result.title.lower()
            summary_lower = result.summary.lower()
            title_on_topic = 'autism' in title_lower or 'asd' in title_lower
            # Defensive re-check of the autism filter on the returned text.
            if not (title_on_topic or 'autism' in summary_lower or 'asd' in summary_lower):
                continue
            papers.append({
                'title': result.title,
                'abstract': result.summary,
                'url': result.pdf_url,
                'published': result.published.strftime("%Y-%m-%d"),
                # "ASD" in the title earns full relevance too, as for PubMed results.
                'relevance_score': 1 if title_on_topic else 0.5,
            })
    except Exception as e:
        logging.exception("arXiv search failed")
        st.error(f"Error fetching arXiv papers: {str(e)}")
    return papers
def _element_text(elem):
    """Return the full text of an XML element, including nested markup, stripped.

    ``Element.text`` holds only the text before the first child element, and
    is None when the element starts with one (e.g. ``<ArticleTitle><i>...``).
    ``itertext()`` collects all of it. A missing element (None) yields "".
    """
    if elem is None:
        return ""
    return "".join(elem.itertext()).strip()


def fetch_pubmed_papers(query, max_results=5, timeout=10):
    """Fetch autism-related papers from PubMed via NCBI E-utilities.

    Runs ``esearch`` to get PubMed IDs, then ``efetch`` for the article XML.

    Args:
        query: free-text topic, restricted to titles/abstracts and combined
            with an autism/ASD filter.
        max_results: maximum number of PubMed IDs requested.
        timeout: seconds to wait for each HTTP request before giving up.

    Returns:
        A list of dicts with keys 'title', 'abstract', 'url', 'published'
        (publication year or 'Unknown') and 'relevance_score' (1 when the
        title mentions autism/ASD, else 0.5). On any request or parse error
        the error is logged, shown via ``st.error``, and the papers collected
        so far are returned.
    """
    base_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils"
    # Always include autism in the search term
    search_term = f"(autism[Title/Abstract] OR ASD[Title/Abstract]) AND ({query}[Title/Abstract])"
    search_params = {
        'db': 'pubmed',
        'term': search_term,
        'retmax': max_results,
        'sort': 'relevance',
        'retmode': 'xml',
    }
    papers = []
    try:
        # Get paper IDs
        response = requests.get(f"{base_url}/esearch.fcgi", params=search_params, timeout=timeout)
        # Fail on HTTP errors instead of trying to parse an error page as XML.
        response.raise_for_status()
        root = ET.fromstring(response.content)
        id_list = [id_elem.text for id_elem in root.findall('.//Id') if id_elem.text]
        if not id_list:
            return papers

        # Fetch paper details
        fetch_params = {
            'db': 'pubmed',
            'id': ','.join(id_list),
            'retmode': 'xml',
        }
        response = requests.get(f"{base_url}/efetch.fcgi", params=fetch_params, timeout=timeout)
        response.raise_for_status()
        articles = ET.fromstring(response.content)

        for article in articles.findall('.//PubmedArticle'):
            title = _element_text(article.find('.//ArticleTitle'))
            # Structured abstracts are split into several AbstractText
            # sections; join them all instead of keeping only the first.
            sections = (_element_text(part) for part in article.findall('.//Abstract/AbstractText'))
            abstract = " ".join(section for section in sections if section)
            pmid = _element_text(article.find('.//PMID'))
            # A malformed article is skipped instead of aborting the whole batch.
            if not (title and abstract and pmid):
                continue
            title_lower = title.lower()
            abstract_lower = abstract.lower()
            title_on_topic = 'autism' in title_lower or 'asd' in title_lower
            # Only include papers that mention autism
            if not (title_on_topic or 'autism' in abstract_lower or 'asd' in abstract_lower):
                continue
            year = _element_text(article.find('.//PubDate/Year'))
            papers.append({
                'title': title,
                'abstract': abstract,
                'url': f"https://pubmed.ncbi.nlm.nih.gov/{pmid}/",
                'published': year or 'Unknown',
                'relevance_score': 1 if title_on_topic else 0.5,
            })
    except Exception as e:
        logging.exception("PubMed search failed")
        st.error(f"Error fetching PubMed papers: {str(e)}")
    return papers
def search_research_papers(query):
    """Collect autism-related papers for *query* from arXiv and PubMed.

    Papers without a non-blank abstract, or whose cleaned title and abstract
    mention neither "autism" nor "asd", are dropped. The rest are ordered by
    descending relevance_score; ties keep arXiv-then-PubMed order.

    Returns:
        A DataFrame with columns title, text, url, published, relevance_score.
        When nothing qualifies, a warning is shown and an empty DataFrame with
        those columns is returned.
    """
    keywords = ('autism', 'asd')
    rows = []
    for raw in fetch_arxiv_papers(query) + fetch_pubmed_papers(query):
        raw_abstract = raw['abstract']
        if not raw_abstract or not raw_abstract.strip():
            continue
        title = clean_text(raw['title'])
        abstract = clean_text(raw_abstract)
        # Re-check topicality on the cleaned text.
        lowered = (title.lower(), abstract.lower())
        if not any(word in field for field in lowered for word in keywords):
            continue
        rows.append({
            'title': title,
            'text': format_paper(title, abstract),
            'url': raw['url'],
            'published': raw['published'],
            'relevance_score': raw.get('relevance_score', 0.5),
        })

    if not rows:
        st.warning("No autism-related papers found. Please try a different search term.")
        return pd.DataFrame(columns=['title', 'text', 'url', 'published', 'relevance_score'])
    ranked = sorted(rows, key=lambda row: row['relevance_score'], reverse=True)
    return pd.DataFrame(ranked)
def generate_answer(question, context, max_length=512):
    """Generate a comprehensive answer using the local model.

    Args:
        question: the user's question.
        context: concatenated paper snippets used as references.
        max_length: upper bound on the generated sequence length in tokens
            (passed to ``model.generate``).

    Returns:
        Markdown text. This is one of:
        - the cleaned model answer with one sentence per line;
        - a fixed overview message, when the cleaned answer is shorter than
          100 characters;
        - an "Error: ..." message, when the model cannot be loaded or
          generation raises.
    """
    model, tokenizer = load_local_model()
    if model is None or tokenizer is None:
        return "Error: Could not load the model. Please try again later."
    # Clean and format the context
    clean_context = clean_text(context)
    clean_question = clean_text(question)
    # Format the input for T5 (it expects a specific format)
    input_text = f"""Generate a detailed and well-structured answer about autism, using the provided research papers as references to support your explanations.
Question: {clean_question}
Research Papers:
{clean_context}
Instructions:
1. Begin with a clear and concise overview of autism, explaining its key characteristics and significance.
2. Use the research papers to support your explanation, citing them in the format: "According to [PAPER TITLE], ...".
3. Integrate findings from the papers naturally into your response, ensuring the information is accurate and relevant.
4. Focus on providing informative, helpful, and easy-to-understand insights.
Write your answer in a professional and accessible tone, ensuring it is well-organized and grounded in the provided research."""
    # A minimum length above the maximum is contradictory; clamp it so a
    # caller-supplied small max_length stays consistent.
    min_length = min(200, max_length)
    try:
        inputs = tokenizer(input_text,
                           return_tensors="pt",
                           max_length=1024,
                           truncation=True,
                           padding=True)
        with torch.inference_mode():
            outputs = model.generate(
                **inputs,
                max_length=max_length,
                min_length=min_length,
                num_beams=5,
                length_penalty=1.5,
                temperature=0.7,
                repetition_penalty=1.2,
                early_stopping=True,
                no_repeat_ngram_size=3,
                do_sample=True,
                top_k=50,
                top_p=0.95
            )
        response = clean_text(tokenizer.decode(outputs[0], skip_special_tokens=True))
        # If response is too short or empty, provide a general overview
        if len(response.strip()) < 100:
            return """Autism Spectrum Disorder (ASD) is a complex neurodevelopmental condition. Unfortunately, the provided papers don't contain specific information about this aspect of autism.
To get research-based information, try asking more specific questions about:
- Genetics and environmental factors
- Early intervention
- Treatments and therapies
- Neurological development
This will allow us to provide accurate information supported by recent scientific research."""
        # One sentence per line. In Markdown a bare "\n" is a soft break that
        # renders as a space; two trailing spaces make it a hard line break.
        # No bullet handling is needed: clean_text drops "•" (non-ASCII).
        return response.replace(". ", ".  \n")
    except Exception as e:
        st.error(f"Error generating response: {str(e)}")
        return "Error: Could not generate response. Please try again with a different question."
# Streamlit App
st.title("🧩 AMA Autism")
st.write("""
This app searches through scientific papers to answer your questions about autism.
For best results, be specific in your questions.
""")

query = st.text_input("Please ask me anything about autism ✨")

if query:
    with st.status("Searching for answers...") as status:
        # Announce the search before running it, not after it has finished.
        st.write("Searching for data in PubMed and arXiv...")
        df = search_research_papers(query)
        st.write(f"Found {len(df)} relevant papers!")

        # Use the first three rows (sorted by relevance upstream) as context,
        # each capped at 1000 characters.
        context = "\n".join(text[:1000] for text in df['text'].head(3))

        st.write("Generating answer...")
        answer = generate_answer(query, context)
        status.update(label="Search complete!", state="complete")

    # Rendered after the status container closes, so the sources and the
    # answer are not nested inside it.
    with st.expander("View source papers"):
        for _, paper in df.iterrows():
            st.markdown(f"- [{paper['title']}]({paper['url']}) ({paper['published']})")
    st.success("Answer found!")
    st.markdown(answer)