ama-autism / app.py
wakeupmh's picture
fix: response
97889da
raw
history blame
7.37 kB
import streamlit as st
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import os
from datasets import load_from_disk, Dataset
import torch
import logging
import pandas as pd
import arxiv
import requests
import xml.etree.ElementTree as ET
# Route INFO-level (and more severe) records through the root logger.
logging.basicConfig(level=logging.INFO)

# Storage layout: use "/data" when that directory exists, otherwise fall back
# to the current working directory.
if os.path.exists("/data"):
    DATA_DIR = "/data"
else:
    DATA_DIR = "."
DATASET_DIR = os.path.join(DATA_DIR, "rag_dataset")
DATASET_PATH = os.path.join(DATASET_DIR, "dataset")

# Hugging Face model identifier; a BART checkpoint was chosen because it is
# better suited to summarization.
MODEL_PATH = "facebook/bart-large-cnn"
@st.cache_resource
def load_local_model():
    """Load the seq2seq model and matching tokenizer named by ``MODEL_PATH``.

    The function is wrapped in ``st.cache_resource``.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    hf_tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
    # float32 weights; "auto" leaves device placement to the library.
    load_options = {"torch_dtype": torch.float32, "device_map": "auto"}
    seq2seq_model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_PATH, **load_options)
    return seq2seq_model, hf_tokenizer
def fetch_arxiv_papers(query, max_results=5):
    """Fetch papers from arXiv matching ``query`` in the ``q-bio`` category.

    When ``query`` does not already mention "autism", an extra clause requires
    "autism" in the title or abstract.

    Args:
        query: Free-text search terms.
        max_results: Maximum number of results requested from arXiv.

    Returns:
        A list of dicts with the keys ``title``, ``abstract``, ``url`` and
        ``published`` (``YYYY-MM-DD``, or ``'Unknown'`` when the result has
        no publication date). If the request fails, the error is shown with
        ``st.error`` and the papers collected so far are returned.
    """
    field_clause = f"(ti:{query} OR abs:{query})"
    # Ensure the query includes autism-related terms.
    if 'autism' not in query.lower():
        search_query = f"{field_clause} AND (ti:autism OR abs:autism) AND cat:q-bio"
    else:
        search_query = f"{field_clause} AND cat:q-bio"

    papers = []
    try:
        client = arxiv.Client()
        search = arxiv.Search(
            query=search_query,
            max_results=max_results,
            sort_by=arxiv.SortCriterion.Relevance,
        )
        for result in client.results(search):
            published = result.published
            papers.append({
                'title': result.title,
                'abstract': result.summary,
                'url': result.pdf_url,
                'published': published.strftime("%Y-%m-%d") if published else 'Unknown',
            })
    except Exception as e:
        # Same best-effort policy as the PubMed fetcher: report the problem in
        # the UI and keep whatever was collected instead of crashing the app.
        st.error(f"Error fetching arXiv papers: {str(e)}")
    return papers
def _element_text(element):
    """Return all text inside ``element``, including text nested in child tags."""
    return "".join(element.itertext()).strip()


def fetch_pubmed_papers(query, max_results=5, timeout=30):
    """Fetch papers from PubMed through the NCBI E-utilities endpoints.

    An ``esearch`` request collects matching PubMed IDs, then one ``efetch``
    request retrieves the article records. When ``query`` does not already
    mention "autism", the search term is restricted to records with "autism"
    or "ASD" in the title/abstract.

    Args:
        query: Free-text search terms.
        max_results: Maximum number of IDs requested from ``esearch``.
        timeout: Seconds to wait on each HTTP request before giving up.

    Returns:
        A list of dicts with the keys ``title``, ``abstract``, ``url`` and
        ``published`` (publication year, or ``'Unknown'``). Articles without
        a title, abstract or PMID are skipped. If anything fails, the error
        is shown with ``st.error`` and the papers collected so far are
        returned.
    """
    base_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils"
    # Ensure the query includes autism-related terms.
    if 'autism' not in query.lower():
        search_term = f"({query}) AND (autism[Title/Abstract] OR ASD[Title/Abstract])"
    else:
        search_term = query

    search_params = {
        'db': 'pubmed',
        'term': search_term,
        'retmax': max_results,
        'sort': 'relevance',
        'retmode': 'xml',
    }
    papers = []
    try:
        # Step 1: look up the IDs of matching papers.
        response = requests.get(
            f"{base_url}/esearch.fcgi", params=search_params, timeout=timeout
        )
        response.raise_for_status()
        root = ET.fromstring(response.content)
        id_list = [id_elem.text for id_elem in root.findall('.//Id') if id_elem.text]
        if not id_list:
            return papers

        # Step 2: fetch the full records for those IDs in a single request.
        fetch_params = {
            'db': 'pubmed',
            'id': ','.join(id_list),
            'retmode': 'xml',
        }
        response = requests.get(
            f"{base_url}/efetch.fcgi", params=fetch_params, timeout=timeout
        )
        response.raise_for_status()
        articles = ET.fromstring(response.content)

        for article in articles.findall('.//PubmedArticle'):
            title = article.find('.//ArticleTitle')
            # An abstract may be split into several AbstractText elements;
            # collect all of them rather than only the first.
            sections = article.findall('.//Abstract/AbstractText')
            year = article.find('.//PubDate/Year')
            pmid = article.find('.//PMID')
            if title is None or not sections:
                continue
            if pmid is None or not pmid.text:
                # Without a PMID no source link can be built; skip this
                # record instead of aborting the remaining articles.
                continue
            # itertext() keeps text that sits inside inline child tags, which
            # a plain ``.text`` would cut off (or return as None).
            abstract = " ".join(
                part for part in (_element_text(section) for section in sections) if part
            )
            papers.append({
                'title': _element_text(title),
                'abstract': abstract,
                'url': f"https://pubmed.ncbi.nlm.nih.gov/{pmid.text}/",
                'published': year.text if year is not None and year.text else 'Unknown',
            })
    except Exception as e:
        # Best effort: surface the failure in the UI and return what we have.
        st.error(f"Error fetching PubMed papers: {str(e)}")
    return papers
def search_research_papers(query):
    """Search both arXiv and PubMed and merge the results into one table.

    Papers whose abstract is missing or blank are dropped.

    Args:
        query: Free-text search terms passed to both fetchers.

    Returns:
        A ``pandas.DataFrame`` with the columns ``title``, ``text``
        (title and abstract combined), ``url`` and ``published``. The columns
        are present even when no paper is found, so callers can index
        ``df['text']`` on an empty result.
    """
    columns = ['title', 'text', 'url', 'published']
    candidates = fetch_arxiv_papers(query) + fetch_pubmed_papers(query)
    rows = [
        {
            'title': paper['title'],
            'text': f"Title: {paper['title']}\n\nAbstract: {paper['abstract']}",
            'url': paper['url'],
            'published': paper['published'],
        }
        for paper in candidates
        if paper['abstract'] and paper['abstract'].strip()
    ]
    # Explicit columns: a DataFrame built from an empty list would otherwise
    # have no columns at all.
    return pd.DataFrame(rows, columns=columns)
def generate_answer(question, context, max_length=512):
    """Generate an answer to ``question`` from ``context`` with the local model.

    Args:
        question: The user's question.
        context: Research text (titles/abstracts) the answer is based on.
        max_length: Upper bound on the generated sequence length, passed to
            ``model.generate``.

    Returns:
        The decoded model output with a line break inserted after each
        sentence, or a fixed apology message when the output is shorter than
        100 characters.
    """
    model, tokenizer = load_local_model()

    # Format the context as a structured query.
    prompt = (
        "Summarize the following research about autism and answer the question.\n"
        "Research Context:\n"
        f"{context}\n"
        f"Question: {question}\n"
        "Provide a detailed answer that includes:\n"
        "1. Main findings from the research\n"
        "2. Research methods used\n"
        "3. Clinical implications\n"
        "4. Limitations of the studies\n"
        "If the research doesn't address the question directly, "
        "explain what information is missing."
    )

    inputs = tokenizer(prompt, return_tensors="pt", max_length=1024, truncation=True)
    # The model is loaded with device_map="auto", so it may not live on the
    # CPU; the input tensors have to be on the same device as the model.
    inputs = inputs.to(model.device)

    with torch.inference_mode():
        outputs = model.generate(
            **inputs,
            max_length=max_length,
            # Ask for longer responses, but never more than max_length allows.
            min_length=min(200, max_length),
            num_beams=5,
            length_penalty=2.0,  # Encourage even longer responses
            # No ``temperature`` here: sampling is not enabled for this beam
            # search, so the value would be ignored and only cause a warning.
            no_repeat_ngram_size=3,
            repetition_penalty=1.3,
            early_stopping=True,
        )
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)

    # If the response is too short or empty, provide a fallback message.
    if len(response.strip()) < 100:
        return (
            "I apologize, but I couldn't generate a specific answer from the "
            "research papers provided.\n"
            "This might be because:\n"
            "1. The research papers don't directly address your question\n"
            "2. The context needs more specific information\n"
            "3. The question might need to be more specific\n"
            "Please try rephrasing your question or ask about a more specific "
            "aspect of autism."
        )

    # Format the response for better readability.
    formatted_response = response.replace(". ", ".\n").replace("• ", "\n• ")
    return formatted_response
# Streamlit App
st.title("🧩 AMA Autism")
st.write("""
This app searches through scientific papers to answer your questions about autism.
For best results, be specific in your questions.
""")
query = st.text_input("Please ask me anything about autism >")
if query:
    answer = None
    with st.status("Searching for answers..."):
        # Announce the search before starting it so the message is visible
        # while the requests are in flight.
        st.write("Searching for data in PubMed and arXiv...")
        df = search_research_papers(query)
        st.write(f"Found {len(df)} relevant papers!")

        # Only build a context and run the model when there is something to
        # summarize; an empty result may not even have a 'text' column.
        if not df.empty:
            # Use the first 1000 characters of the top three papers.
            context = "\n".join(text[:1000] for text in df['text'].head(3))
            st.write("Generating answer...")
            answer = generate_answer(query, context)

    if answer is None:
        st.warning(
            "No relevant papers were found for this question. "
            "Please try rephrasing it or ask about a more specific aspect of autism."
        )
    else:
        # Display paper sources.
        with st.expander("View source papers"):
            for _, paper in df.iterrows():
                # Collapse line breaks and repeated spaces so a multi-line
                # title does not break the markdown link.
                title = " ".join(str(paper['title']).split())
                st.markdown(f"- [{title}]({paper['url']}) ({paper['published']})")
        st.success("Answer found!")
        st.markdown(answer)