Spaces:
Sleeping
Sleeping
import streamlit as st | |
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer | |
import os | |
from datasets import load_from_disk, Dataset | |
import torch | |
import logging | |
import pandas as pd | |
import arxiv | |
import requests | |
import xml.etree.ElementTree as ET | |
# Emit INFO-level (and above) log records through the root logger.
logging.basicConfig(level=logging.INFO)

# Storage root: prefer "/data" when that path exists (presumably a mounted
# persistent volume on the hosting platform -- TODO confirm), otherwise fall
# back to the current working directory.
if os.path.exists("/data"):
    DATA_DIR = "/data"
else:
    DATA_DIR = "."

# Paths for a saved RAG dataset under the storage root (presumably meant for
# `load_from_disk`; NOTE(review): not referenced elsewhere in the visible code).
DATASET_DIR = os.path.join(DATA_DIR, "rag_dataset")
DATASET_PATH = os.path.join(DATASET_DIR, "dataset")

# Hugging Face model id; T5-small is used for better CPU compatibility.
MODEL_PATH = "t5-small"
@st.cache_resource(show_spinner="Loading language model...")
def _load_model_cached(model_path):
    """Load the tokenizer/model pair for *model_path* once and cache it.

    ``st.cache_resource`` shares the returned objects across reruns and
    sessions, so the weights are not re-read from disk for every question.
    Exceptions propagate instead of being turned into ``(None, None)``, so a
    failed load is not stored in the cache and is retried on the next call.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    # Without a device_map, from_pretrained keeps the weights on the CPU.
    # Passing device_map (as before) additionally requires `accelerate`.
    model = AutoModelForSeq2SeqLM.from_pretrained(
        model_path,
        torch_dtype=torch.float32,
    )
    return model, tokenizer


def load_local_model():
    """Load the local Hugging Face seq2seq model named by ``MODEL_PATH``.

    Returns:
        ``(model, tokenizer)`` on success. On failure the error is logged,
        shown in the Streamlit UI, and ``(None, None)`` is returned.
    """
    try:
        return _load_model_cached(MODEL_PATH)
    except Exception as e:
        logging.exception("Failed to load model %s", MODEL_PATH)
        st.error(f"Error loading model: {str(e)}")
        return None, None
def fetch_arxiv_papers(query, max_results=5):
    """Fetch autism-related papers from arXiv that match *query*.

    Args:
        query: Free-text user query, searched as a quoted phrase in titles
            and abstracts of the q-bio category.
        max_results: Maximum number of results requested from arXiv.

    Returns:
        A list of dicts with keys ``title``, ``abstract``, ``url``,
        ``published`` (``YYYY-MM-DD``) and ``relevance_score`` (1 when
        "autism" is in the title, else 0.5). If the arXiv request fails, the
        error is logged and shown in the UI, and the papers collected so far
        are returned.
    """
    # A double quote inside the phrase would close the quoted arXiv term early
    # and corrupt the query, so neutralise quotes before interpolating.
    phrase = query.replace('"', " ").strip()
    # Always include autism in the search query
    search_query = (
        f'(ti:autism OR abs:autism) AND (ti:"{phrase}" OR abs:"{phrase}") '
        "AND cat:q-bio"
    )
    search = arxiv.Search(
        query=search_query,
        max_results=max_results,
        sort_by=arxiv.SortCriterion.Relevance,
    )
    client = arxiv.Client()

    papers = []
    try:
        for result in client.results(search):
            title_lower = result.title.lower()
            summary_lower = result.summary.lower()
            # Only include papers that mention autism in title or abstract
            if not ('autism' in title_lower or 'asd' in title_lower or
                    'autism' in summary_lower or 'asd' in summary_lower):
                continue
            papers.append({
                'title': result.title,
                'abstract': result.summary,
                # pdf_url is optional in the arxiv client; fall back to the
                # entry URL so the source link is never None.
                'url': result.pdf_url or result.entry_id,
                'published': result.published.strftime("%Y-%m-%d"),
                'relevance_score': 1 if 'autism' in title_lower else 0.5,
            })
    except Exception as e:
        # Report the failure the same way fetch_pubmed_papers does, instead of
        # letting a network/API error abort the whole Streamlit run.
        logging.exception("arXiv search failed")
        st.error(f"Error fetching arXiv papers: {str(e)}")
    return papers
def _pubmed_element_text(elem):
    """Return all text inside *elem*, including text nested in inline tags.

    ``elem.text`` is None (or truncated) when an element such as ArticleTitle
    begins with or contains markup like ``<i>``; ``itertext`` collects every
    text node instead. Returns '' when *elem* is None.
    """
    if elem is None:
        return ""
    return "".join(elem.itertext()).strip()


def _pubmed_abstract_text(article):
    """Return the full abstract of a PubmedArticle element ('' if absent).

    Structured abstracts are split across several AbstractText elements
    (optionally labelled via a ``Label`` attribute); all sections are joined
    so none are silently dropped.
    """
    sections = []
    for part in article.findall('.//Abstract/AbstractText'):
        text = _pubmed_element_text(part)
        if not text:
            continue
        label = part.get('Label')
        sections.append(f"{label}: {text}" if label else text)
    return "\n".join(sections)


def fetch_pubmed_papers(query, max_results=5):
    """Fetch autism-related papers from PubMed via NCBI E-utilities.

    Uses ``esearch`` to find PMIDs whose title/abstract mention autism (or
    ASD) and *query*, then ``efetch`` to retrieve the article records.

    Args:
        query: Free-text term matched against Title/Abstract.
        max_results: Maximum number of PMIDs requested from esearch.

    Returns:
        A list of dicts with keys ``title``, ``abstract``, ``url``,
        ``published`` (PubDate year, else MedlineDate, else ``'Unknown'``)
        and ``relevance_score`` (1 if autism/asd is in the title, else 0.5).
        On any error the failure is logged and shown in the UI, and the
        papers collected so far are returned.
    """
    base_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils"
    # Per-request timeout in seconds; without one a stalled connection would
    # block the Streamlit script indefinitely.
    timeout_s = 15
    # Always include autism in the search term
    search_term = f"(autism[Title/Abstract] OR ASD[Title/Abstract]) AND ({query}[Title/Abstract])"
    search_params = {
        'db': 'pubmed',
        'term': search_term,
        'retmax': max_results,
        'sort': 'relevance',
        'retmode': 'xml',
    }
    papers = []
    try:
        # Step 1: get the matching PMIDs.
        response = requests.get(f"{base_url}/esearch.fcgi", params=search_params, timeout=timeout_s)
        response.raise_for_status()
        root = ET.fromstring(response.content)
        id_list = [id_elem.text for id_elem in root.findall('.//Id') if id_elem.text]
        if not id_list:
            return papers

        # Step 2: fetch the full records for those PMIDs.
        fetch_params = {
            'db': 'pubmed',
            'id': ','.join(id_list),
            'retmode': 'xml',
        }
        response = requests.get(f"{base_url}/efetch.fcgi", params=fetch_params, timeout=timeout_s)
        response.raise_for_status()
        articles = ET.fromstring(response.content)

        for article in articles.findall('.//PubmedArticle'):
            title = _pubmed_element_text(article.find('.//ArticleTitle'))
            abstract = _pubmed_abstract_text(article)
            pmid = _pubmed_element_text(article.find('.//MedlineCitation/PMID'))
            # Without a title, abstract and PMID there is nothing usable to show.
            if not (title and abstract and pmid):
                continue
            title_text = title.lower()
            abstract_text = abstract.lower()
            # Only include papers that mention autism
            if not ('autism' in title_text or 'asd' in title_text or
                    'autism' in abstract_text or 'asd' in abstract_text):
                continue
            # Some records carry a free-form MedlineDate instead of a Year.
            year = (_pubmed_element_text(article.find('.//PubDate/Year'))
                    or _pubmed_element_text(article.find('.//PubDate/MedlineDate')))
            papers.append({
                'title': title,
                'abstract': abstract,
                'url': f"https://pubmed.ncbi.nlm.nih.gov/{pmid}/",
                'published': year or 'Unknown',
                'relevance_score': 1 if ('autism' in title_text or 'asd' in title_text) else 0.5,
            })
    except Exception as e:
        logging.exception("PubMed search failed")
        st.error(f"Error fetching PubMed papers: {str(e)}")
    return papers
def search_research_papers(query):
    """Search arXiv and PubMed for autism-related papers matching *query*.

    Papers with a missing/blank abstract, or whose title and abstract never
    mention "autism" or "asd", are dropped. The rest are sorted by
    ``relevance_score`` (highest first), and papers whose normalised titles
    coincide (e.g. an arXiv preprint also indexed in PubMed) are kept once.

    Returns:
        DataFrame with columns ``title``, ``text`` ("Title: ...\\n\\nAbstract:
        ..."), ``url``, ``published`` and ``relevance_score``. If nothing
        matches, a warning is shown and an empty DataFrame with those columns
        is returned.
    """
    columns = ['title', 'text', 'url', 'published', 'relevance_score']

    candidates = []
    for paper in fetch_arxiv_papers(query) + fetch_pubmed_papers(query):
        title, abstract = paper['title'], paper['abstract']
        # Skip papers with a missing or whitespace-only abstract.
        if not abstract or not abstract.strip():
            continue
        title_lower, abstract_lower = title.lower(), abstract.lower()
        # Check if the paper is actually about autism
        if not ('autism' in title_lower or 'asd' in title_lower or
                'autism' in abstract_lower or 'asd' in abstract_lower):
            continue
        candidates.append({
            'title': title,
            'text': f"Title: {title}\n\nAbstract: {abstract}",
            'url': paper['url'],
            'published': paper['published'],
            'relevance_score': paper.get('relevance_score', 0.5),
        })

    # Highest relevance first; list.sort is stable, so ties keep source order
    # (arXiv results before PubMed results).
    candidates.sort(key=lambda p: p['relevance_score'], reverse=True)

    # Drop duplicate titles. Comparison is case-insensitive, collapses
    # whitespace and ignores a trailing period (PubMed titles usually end
    # with one, arXiv titles do not). The first -- highest-scored -- copy wins.
    seen_titles = set()
    unique_papers = []
    for paper in candidates:
        key = " ".join(paper['title'].casefold().split()).rstrip(".")
        if key in seen_titles:
            continue
        seen_titles.add(key)
        unique_papers.append(paper)

    if not unique_papers:
        st.warning("No autism-related papers found. Please try a different search term.")
        return pd.DataFrame(columns=columns)
    return pd.DataFrame(unique_papers, columns=columns)
# Prompt sent to the model; filled with str.format(question=..., context=...).
# Braces inside the substituted values are safe: only the template is parsed.
_ANSWER_PROMPT_TEMPLATE = """You are an expert in autism research. Provide a comprehensive answer about autism, incorporating both general knowledge and specific research findings when available.
Question: {question}
Recent Research Context:
{context}
Instructions: Provide a detailed response that:
1. Starts with a general overview of the topic as it relates to autism
2. Incorporates specific findings from the provided research papers when relevant
3. Discusses practical implications for individuals with autism and their families
4. Mentions any limitations in current understanding
If the research papers don't directly address the question, focus on providing general, well-established information about autism while noting what specific research would be helpful."""

# Canned overview returned when the model output is too short to be useful;
# filled with str.format(question=...).
_FALLBACK_ANSWER_TEMPLATE = """Here's what we know about autism in relation to your question about {question}:
1. General Understanding:
- Autism Spectrum Disorder (ASD) is a complex developmental condition
- It affects how a person communicates, learns, and interacts with others
- Each person with autism has unique strengths and challenges
2. Key Aspects:
- Communication and social interaction
- Repetitive behaviors and specific interests
- Sensory sensitivities
- Early intervention is important
3. Current Research:
While the provided research papers don't directly address your specific question, researchers are actively studying various aspects of autism to better understand its causes, characteristics, and effective interventions.
For more specific information, try asking about:
- Specific symptoms or characteristics
- Diagnostic processes
- Treatment approaches
- Current research in specific areas"""


def generate_answer(question, context, max_length=512):
    """Generate an answer to *question* with the model from load_local_model.

    Args:
        question: The user's question.
        context: Research text (paper titles/abstracts) inserted into the prompt.
        max_length: Maximum length, in tokens, of the generated output.

    Returns:
        The decoded model output with a newline after each ". " and before
        each "• ". If the output is shorter than 100 characters, a canned
        general overview is returned instead. If the model cannot be loaded,
        or generation fails, an error string is returned; a generation
        failure is also logged and shown via ``st.error``.
    """
    model, tokenizer = load_local_model()
    if model is None or tokenizer is None:
        return "Error: Could not load the model. Please try again later."

    prompt = _ANSWER_PROMPT_TEMPLATE.format(question=question, context=context)
    try:
        # The prompt is truncated to 1024 tokens; with a long context the
        # trailing instructions (which follow the context) are what get cut.
        inputs = tokenizer(prompt, return_tensors="pt", max_length=1024, truncation=True)
        with torch.inference_mode():
            outputs = model.generate(
                **inputs,
                max_length=max_length,
                # Never request a minimum longer than the allowed maximum.
                min_length=min(150, max_length),
                # Deterministic beam search (no do_sample), so a `temperature`
                # argument would be ignored -- and warned about -- by
                # transformers; it is therefore not passed.
                num_beams=4,
                length_penalty=1.5,
                repetition_penalty=1.2,
                early_stopping=True,
            )
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    except Exception as e:
        logging.exception("Answer generation failed")
        st.error(f"Error generating response: {str(e)}")
        return "Error: Could not generate response. Please try again with a different question."

    # If response is too short or empty, provide a general overview
    if len(response.strip()) < 100:
        return _FALLBACK_ANSWER_TEMPLATE.format(question=question)

    # Put each sentence and bullet on its own line for readability.
    return response.replace(". ", ".\n").replace("• ", "\n• ")
# Streamlit App
st.title("🧩 AMA Autism")
st.write("""
This app searches through scientific papers to answer your questions about autism.
For best results, be specific in your questions.
""")
query = st.text_input("Please ask me anything about autism ✨")

# Ignore empty or whitespace-only input instead of searching for it.
if query and query.strip():
    with st.status("Searching for answers...") as status:
        # Announce the search before running it (it used to be written only
        # after the search had already finished).
        st.write("Searching for data in PubMed and arXiv...")
        df = search_research_papers(query)
        st.write(f"Found {len(df)} relevant papers!")

        # Context for the model: at most the first 1000 characters of each of
        # the top three papers.
        context = "\n".join(text[:1000] for text in df['text'].head(3))

        st.write("Generating answer...")
        answer = generate_answer(query, context)
        status.update(label="Search complete", state="complete", expanded=False)

    # Kept outside st.status: Streamlit has rejected expanders nested inside
    # other expander-style containers (st.status included), so this avoids it.
    with st.expander("View source papers"):
        for _, paper in df.iterrows():
            st.markdown(f"- [{paper['title']}]({paper['url']}) ({paper['published']})")

    st.success("Answer found!")
    st.markdown(answer)