import streamlit as st
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import os
from datasets import load_from_disk, Dataset
import torch
import logging
import pandas as pd
import arxiv
import requests
import xml.etree.ElementTree as ET
# Configure logging
logging.basicConfig(level=logging.INFO)

# Define data paths and constants
DATA_DIR = "/data" if os.path.exists("/data") else "."
DATASET_DIR = os.path.join(DATA_DIR, "rag_dataset")
DATASET_PATH = os.path.join(DATASET_DIR, "dataset")
MODEL_PATH = "facebook/bart-large-cnn"  # BART model, better suited to summarization
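# Note: on Hugging Face Spaces, persistent storage (when enabled) is mounted at
# /data; when it is absent, DATA_DIR above falls back to the working directory.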
@st.cache_resource
def load_local_model():
    """Load the local Hugging Face model, cached so it is only loaded once"""
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
    model = AutoModelForSeq2SeqLM.from_pretrained(
        MODEL_PATH,
        torch_dtype=torch.float32,
        device_map="auto"
    )
    return model, tokenizer
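# Note: device_map="auto" requires the `accelerate` package to be installed
# alongside transformers; the @st.cache_resource decorator keeps the model in
# memory instead of reloading it on every question.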
def fetch_arxiv_papers(query, max_results=5):
    """Fetch papers from arXiv"""
    client = arxiv.Client()

    # Ensure the query includes autism-related terms
    if 'autism' not in query.lower():
        search_query = f"(ti:{query} OR abs:{query}) AND (ti:autism OR abs:autism) AND cat:q-bio"
    else:
        search_query = f"(ti:{query} OR abs:{query}) AND cat:q-bio"

    # Search arXiv
    search = arxiv.Search(
        query=search_query,
        max_results=max_results,
        sort_by=arxiv.SortCriterion.Relevance
    )

    papers = []
    try:
        for result in client.results(search):
            papers.append({
                'title': result.title,
                'abstract': result.summary,
                'url': result.pdf_url,
                'published': result.published.strftime("%Y-%m-%d")
            })
    except Exception as e:
        st.error(f"Error fetching arXiv papers: {str(e)}")
    return papers
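# Rough usage sketch (the query string is illustrative only):
#   papers = fetch_arxiv_papers("sensory processing", max_results=3)
#   # -> [{'title': ..., 'abstract': ..., 'url': ..., 'published': 'YYYY-MM-DD'}, ...]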
def fetch_pubmed_papers(query, max_results=5):
    """Fetch papers from PubMed"""
    base_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils"

    # Ensure the query includes autism-related terms
    if 'autism' not in query.lower():
        search_term = f"({query}) AND (autism[Title/Abstract] OR ASD[Title/Abstract])"
    else:
        search_term = query

    # Search for papers
    search_url = f"{base_url}/esearch.fcgi"
    search_params = {
        'db': 'pubmed',
        'term': search_term,
        'retmax': max_results,
        'sort': 'relevance',
        'retmode': 'xml'
    }

    papers = []
    try:
        # Get paper IDs
        response = requests.get(search_url, params=search_params, timeout=30)
        root = ET.fromstring(response.content)
        id_list = [id_elem.text for id_elem in root.findall('.//Id')]

        if not id_list:
            return papers

        # Fetch paper details
        fetch_url = f"{base_url}/efetch.fcgi"
        fetch_params = {
            'db': 'pubmed',
            'id': ','.join(id_list),
            'retmode': 'xml'
        }
        response = requests.get(fetch_url, params=fetch_params, timeout=30)
        articles = ET.fromstring(response.content)

        for article in articles.findall('.//PubmedArticle'):
            title = article.find('.//ArticleTitle')
            abstract = article.find('.//Abstract/AbstractText')
            year = article.find('.//PubDate/Year')
            pmid = article.find('.//PMID')

            # Skip records missing any field needed to build an entry
            if title is not None and abstract is not None and pmid is not None:
                papers.append({
                    'title': title.text,
                    'abstract': abstract.text,
                    'url': f"https://pubmed.ncbi.nlm.nih.gov/{pmid.text}/",
                    'published': year.text if year is not None else 'Unknown'
                })
    except Exception as e:
        st.error(f"Error fetching PubMed papers: {str(e)}")
    return papers
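# Note: NCBI asks E-utilities clients to stay at or below ~3 requests/second
# without an API key; passing an 'api_key' in the request params raises that cap.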
def search_research_papers(query):
    """Search both arXiv and PubMed for papers"""
    arxiv_papers = fetch_arxiv_papers(query)
    pubmed_papers = fetch_pubmed_papers(query)

    # Combine and format papers, keeping only those with a non-empty abstract
    all_papers = []
    for paper in arxiv_papers + pubmed_papers:
        if paper['abstract'] and paper['abstract'].strip():
            all_papers.append({
                'title': paper['title'],
                'text': f"Title: {paper['title']}\n\nAbstract: {paper['abstract']}",
                'url': paper['url'],
                'published': paper['published']
            })

    # Pin the columns so downstream code works even when no papers were found
    return pd.DataFrame(all_papers, columns=['title', 'text', 'url', 'published'])
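# Rough usage sketch (query text is illustrative only):
#   df = search_research_papers("early intervention")
#   list(df.columns)  # -> ['title', 'text', 'url', 'published'], even when empty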
def generate_answer(question, context, max_length=512):
    """Generate a comprehensive answer using the local model"""
    model, tokenizer = load_local_model()

    # Format the context as a structured query
    prompt = f"""Summarize the following research about autism and answer the question.

Research Context:
{context}

Question: {question}

Provide a detailed answer that includes:
1. Main findings from the research
2. Research methods used
3. Clinical implications
4. Limitations of the studies

If the research doesn't address the question directly, explain what information is missing."""

    # Generate response (inputs are moved to the same device as the model)
    inputs = tokenizer(prompt, return_tensors="pt", max_length=1024, truncation=True)
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    with torch.inference_mode():
        outputs = model.generate(
            **inputs,
            max_length=max_length,
            min_length=200,          # Ensure longer responses
            num_beams=5,
            length_penalty=2.0,      # Encourage even longer responses
            no_repeat_ngram_size=3,  # Note: temperature is dropped; it has no effect under beam search
            repetition_penalty=1.3,
            early_stopping=True
        )

    response = tokenizer.decode(outputs[0], skip_special_tokens=True)

    # If the response is too short, provide a fallback message
    if len(response.strip()) < 100:
        return """I apologize, but I couldn't generate a specific answer from the research papers provided. This might be because:
1. The research papers don't directly address your question
2. The context needs more specific information
3. The question might need to be more specific

Please try rephrasing your question or ask about a more specific aspect of autism."""

    # Break sentences and bullets onto separate lines for readability
    formatted_response = response.replace(". ", ".\n").replace("• ", "\n• ")
    return formatted_response
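# Caveat: facebook/bart-large-cnn is fine-tuned for news summarization, not for
# instruction following, so in practice it tends to summarize the whole prompt
# rather than answer the numbered points requested above.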
# Streamlit App
st.title("🧩 AMA Autism")
st.write("""
This app searches through scientific papers to answer your questions about autism.
For best results, be specific in your questions.
""")

query = st.text_input("Please ask me anything about autism >")

if query:
    with st.status("Searching for answers..."):
        # Search for papers
        st.write("Searching for data in PubMed and arXiv...")
        df = search_research_papers(query)
        st.write(f"Found {len(df)} relevant papers!")

        if df.empty:
            st.warning("No relevant papers were found. Please try rephrasing your question.")
        else:
            # Build the context from the three most relevant papers
            context = "\n".join([
                text[:1000] for text in df['text'].head(3)
            ])

            # Generate answer
            st.write("Generating answer...")
            answer = generate_answer(query, context)

            # Display paper sources
            with st.expander("View source papers"):
                for _, paper in df.iterrows():
                    st.markdown(f"- [{paper['title']}]({paper['url']}) ({paper['published']})")

            st.success("Answer found!")
            st.markdown(answer)