import streamlit as st
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline
import os
from datasets import load_from_disk, Dataset
import torch
import logging
import pandas as pd
import arxiv
import requests
import xml.etree.ElementTree as ET
from agno.embedder.huggingface import HuggingfaceCustomEmbedder
from agno.vectordb.lancedb import LanceDb, SearchType

# Configure logging
logging.basicConfig(level=logging.INFO)

# Define data paths and constants
DATA_DIR = "/data" if os.path.exists("/data") else "."
DATASET_DIR = os.path.join(DATA_DIR, "rag_dataset")
DATASET_PATH = os.path.join(DATASET_DIR, "dataset")
MODEL_PATH = "google/flan-t5-base"  # Lighter model
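# Note: on Hugging Face Spaces, persistent storage (when enabled) is mounted
# at /data, which is why DATA_DIR falls back to the working directory when no
# persistent volume is attached.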

@st.cache_resource  # cache the model and tokenizer across Streamlit reruns instead of reloading on every query
def load_local_model():
    """Load the local Hugging Face model"""
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
    model = AutoModelForSeq2SeqLM.from_pretrained(
        MODEL_PATH,
        torch_dtype=torch.float32,  # Using float32 for CPU compatibility
        device_map="auto"
    )
    return model, tokenizer

def fetch_arxiv_papers(query, max_results=5):
    """Fetch papers from arXiv"""
    client = arxiv.Client()
    # Clean and prepare the search query; the parentheses matter because AND
    # binds tighter than OR, so without them the category filter would only
    # apply to the abs: clause
    search_query = f"(ti:{query} OR abs:{query}) AND cat:q-bio"
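    # arXiv query field prefixes: ti: matches titles, abs: matches abstracts,
    # and cat: filters by subject category (q-bio is quantitative biology).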
    # Search arXiv
    search = arxiv.Search(
        query=search_query,
        max_results=max_results,
        sort_by=arxiv.SortCriterion.Relevance
    )
    papers = []
    for result in client.results(search):
        papers.append({
            'title': result.title,
            'abstract': result.summary,
            'url': result.pdf_url,
            'published': result.published
        })
    return papers

def fetch_pubmed_papers(query, max_results=5):
    """Fetch papers from PubMed"""
    base_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils"
    # Search for papers
    search_url = f"{base_url}/esearch.fcgi"
    search_params = {
        'db': 'pubmed',
        'term': query,
        'retmax': max_results,
        'sort': 'relevance',
        'retmode': 'xml'
    }
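    # NCBI E-utilities etiquette: without an API key, clients are limited to
    # roughly 3 requests per second; registering a key raises the limit.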
    papers = []
    try:
        # Get paper IDs
        response = requests.get(search_url, params=search_params, timeout=30)
        root = ET.fromstring(response.content)
        id_list = [id_elem.text for id_elem in root.findall('.//Id')]
        if not id_list:
            return papers
        # Fetch paper details
        fetch_url = f"{base_url}/efetch.fcgi"
        fetch_params = {
            'db': 'pubmed',
            'id': ','.join(id_list),
            'retmode': 'xml'
        }
        response = requests.get(fetch_url, params=fetch_params, timeout=30)
        articles = ET.fromstring(response.content)
        for article in articles.findall('.//PubmedArticle'):
            title = article.find('.//ArticleTitle')
            abstract = article.find('.//Abstract/AbstractText')
            pmid = article.find('.//PMID')  # PMID and year may be absent, so guard before reading .text
            year = article.find('.//PubDate/Year')
            papers.append({
                'title': title.text if title is not None else 'No title available',
                'abstract': abstract.text if abstract is not None else 'No abstract available',
                'url': f"https://pubmed.ncbi.nlm.nih.gov/{pmid.text}/" if pmid is not None else '',
                'published': year.text if year is not None else 'Unknown'
            })
    except Exception as e:
        st.error(f"Error fetching PubMed papers: {str(e)}")
    return papers

def search_research_papers(query):
    """Search both arXiv and PubMed for papers"""
    arxiv_papers = fetch_arxiv_papers(query)
    pubmed_papers = fetch_pubmed_papers(query)
    # Combine and format papers
    all_papers = []
    for paper in arxiv_papers + pubmed_papers:
        all_papers.append({
            'title': paper['title'],
            'text': f"Title: {paper['title']}\nAbstract: {paper['abstract']}",
            'url': paper['url'],
            'published': paper['published']
        })
    return pd.DataFrame(all_papers)

def generate_answer(question, context, max_length=512):
    """Generate a comprehensive answer using the local model"""
    model, tokenizer = load_local_model()
    # Format the context as a structured query
    prompt = f"""Based on the following research papers about autism, provide a detailed answer:

Question: {question}

Research Context:
{context}

Please analyze:
1. Main findings
2. Research methods
3. Clinical implications
4. Limitations

Answer:"""
    # Generate response
    inputs = tokenizer(prompt, return_tensors="pt", max_length=max_length, truncation=True)
    with torch.inference_mode():
        outputs = model.generate(
            **inputs,
            max_length=max_length,
            num_beams=4,  # deterministic beam search; sampling knobs like temperature/top_p have no effect unless do_sample=True
            repetition_penalty=1.2,
            early_stopping=True
        )
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # Format the response
    formatted_response = response.replace(". ", ".\n").replace("• ", "\n• ")
    return formatted_response

# Streamlit App
st.title("🧩 AMA Autism")
st.write("This app searches through scientific papers to answer your questions about autism. For best results, be specific in your questions.")
query = st.text_input("Please ask me anything about autism ✨")

if query:
    with st.status("Searching for answers...") as status:
        # Search for papers, announcing each step before it runs
        st.write("Searching for data in PubMed and arXiv...")
        df = search_research_papers(query)
        answer = None
        if df.empty:
            status.update(label="No papers found.", state="error", expanded=False)
        else:
            st.write("Data found!")
            # Get relevant context from the top results
            context = "\n".join(text[:1000] for text in df['text'].head(3))
            # Generate answer
            st.write("Generating answer...")
            answer = generate_answer(query, context)
            status.update(
                label="Search complete!", state="complete", expanded=False
            )
    if df.empty:
        st.warning("I couldn't find any relevant research papers about this topic. Please try rephrasing your question or ask something else about autism.")
    elif answer and not answer.isspace():
        st.success("Answer found!")
        st.write(answer)
        st.write("### Sources used:")
        for _, row in df.head(3).iterrows():
            st.markdown(f"**[{row['title']}]({row['url']})** ({row['published']})")
            st.write(f"**Summary:** {row['text'][:200]}...")
            st.write("---")
    else:
        st.warning("I couldn't find a specific answer in the research papers. Try rephrasing your question.")