# ama-autism / _old_app.py
# Streamlit app that answers questions about autism by retrieving research
# papers from arXiv and PubMed and summarizing them with a local FLAN-T5 model.
import streamlit as st
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline
import os
from datasets import load_from_disk, Dataset
import torch
import logging
import pandas as pd
import arxiv
import requests
import xml.etree.ElementTree as ET
from agno.embedder.huggingface import HuggingfaceCustomEmbedder
from agno.vectordb.lancedb import LanceDb, SearchType
# Configure logging
logging.basicConfig(level=logging.INFO)
# Define data paths and constants
# NOTE(review): "/data" looks like a persistent mount (e.g. a Hugging Face
# Spaces volume) — prefer it when present, else use the working directory.
DATA_DIR = "/data" if os.path.exists("/data") else "."
DATASET_DIR = os.path.join(DATA_DIR, "rag_dataset")   # root for the RAG dataset
DATASET_PATH = os.path.join(DATASET_DIR, "dataset")   # on-disk dataset location
MODEL_PATH = "google/flan-t5-base" # Lighter model
@st.cache_resource
def load_local_model():
    """Load and cache the local Hugging Face seq2seq model and tokenizer.

    Returns:
        tuple: ``(model, tokenizer)`` for ``MODEL_PATH``, cached across
        Streamlit reruns via ``st.cache_resource``.
    """
    tok = AutoTokenizer.from_pretrained(MODEL_PATH)
    seq2seq = AutoModelForSeq2SeqLM.from_pretrained(
        MODEL_PATH,
        torch_dtype=torch.float32,  # Using float32 for CPU compatibility
        device_map="auto",
    )
    return seq2seq, tok
def fetch_arxiv_papers(query, max_results=5):
    """Fetch papers from arXiv matching *query* in the q-bio category.

    Args:
        query: Free-text search terms matched against title and abstract.
        max_results: Maximum number of papers to return.

    Returns:
        list[dict]: One dict per paper with ``title``, ``abstract``,
        ``url`` (PDF link) and ``published`` keys.
    """
    client = arxiv.Client()
    # Parenthesize the OR clause: in the arXiv query grammar AND binds
    # tighter than OR, so the unparenthesized form
    # "ti:q OR abs:q AND cat:q-bio" applied the category filter only to
    # the abstract term and let unrelated title matches through.
    search_query = f"(ti:{query} OR abs:{query}) AND cat:q-bio"
    # Search arXiv
    search = arxiv.Search(
        query=search_query,
        max_results=max_results,
        sort_by=arxiv.SortCriterion.Relevance
    )
    papers = []
    for result in client.results(search):
        papers.append({
            'title': result.title,
            'abstract': result.summary,
            'url': result.pdf_url,
            'published': result.published
        })
    return papers
def fetch_pubmed_papers(query, max_results=5):
    """Fetch papers from PubMed via the NCBI E-utilities API.

    Args:
        query: Free-text PubMed search term.
        max_results: Maximum number of papers to return.

    Returns:
        list[dict]: One dict per paper with ``title``, ``abstract``,
        ``url`` and ``published`` keys. Empty on error (the error is
        surfaced to the UI via ``st.error``).
    """
    base_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils"
    # Search for papers
    search_url = f"{base_url}/esearch.fcgi"
    search_params = {
        'db': 'pubmed',
        'term': query,
        'retmax': max_results,
        'sort': 'relevance',
        'retmode': 'xml'
    }
    papers = []
    try:
        # Get paper IDs. A timeout prevents a stalled NCBI endpoint from
        # hanging the whole Streamlit app; raise_for_status surfaces HTTP
        # errors instead of silently parsing an error page as XML.
        response = requests.get(search_url, params=search_params, timeout=10)
        response.raise_for_status()
        root = ET.fromstring(response.content)
        id_list = [id_elem.text for id_elem in root.findall('.//Id')]
        if not id_list:
            return papers
        # Fetch paper details
        fetch_url = f"{base_url}/efetch.fcgi"
        fetch_params = {
            'db': 'pubmed',
            'id': ','.join(id_list),
            'retmode': 'xml'
        }
        response = requests.get(fetch_url, params=fetch_params, timeout=10)
        response.raise_for_status()
        articles = ET.fromstring(response.content)
        for article in articles.findall('.//PubmedArticle'):
            title = article.find('.//ArticleTitle')
            # PubMed abstracts are often split into multiple labeled
            # <AbstractText> sections (Background, Methods, ...); join them
            # all rather than keeping only the first section.
            abstract_parts = [
                elem.text
                for elem in article.findall('.//Abstract/AbstractText')
                if elem.text
            ]
            # Guard every optional element: a missing PMID previously raised
            # AttributeError inside the loop and discarded all results.
            pmid = article.find('.//PMID')
            year = article.find('.//PubDate/Year')
            papers.append({
                'title': title.text if title is not None and title.text else 'No title available',
                'abstract': ' '.join(abstract_parts) if abstract_parts else 'No abstract available',
                'url': f"https://pubmed.ncbi.nlm.nih.gov/{pmid.text}/" if pmid is not None else 'https://pubmed.ncbi.nlm.nih.gov/',
                'published': year.text if year is not None else 'Unknown'
            })
    except Exception as e:
        st.error(f"Error fetching PubMed papers: {str(e)}")
    return papers
def search_research_papers(query):
    """Search both arXiv and PubMed and combine the results.

    Args:
        query: Free-text search term forwarded to both sources.

    Returns:
        pandas.DataFrame: Columns ``title``, ``text`` (title + abstract),
        ``url`` and ``published``; one row per paper.
    """
    combined = fetch_arxiv_papers(query) + fetch_pubmed_papers(query)
    # Flatten each paper into the row shape the UI and the generator expect.
    rows = [
        {
            'title': paper['title'],
            'text': f"Title: {paper['title']}\nAbstract: {paper['abstract']}",
            'url': paper['url'],
            'published': paper['published'],
        }
        for paper in combined
    ]
    return pd.DataFrame(rows)
def generate_answer(question, context, max_length=512):
    """Generate a structured answer from research context with the local model.

    Args:
        question: The user's question.
        context: Concatenated paper titles/abstracts used as evidence.
        max_length: Token budget applied both to input truncation and to
            the generated output.

    Returns:
        str: The decoded answer with light formatting (sentence and bullet
        line breaks) applied.
    """
    model, tokenizer = load_local_model()
    # Format the context as a structured query
    prompt = f"""Based on the following research papers about autism, provide a detailed answer:
Question: {question}
Research Context:
{context}
Please analyze:
1. Main findings
2. Research methods
3. Clinical implications
4. Limitations
Answer:"""
    # Generate response
    inputs = tokenizer(prompt, return_tensors="pt", max_length=max_length, truncation=True)
    with torch.inference_mode():
        # Beam search decoding. The original also passed temperature/top_p,
        # but without do_sample=True those sampling knobs are ignored and
        # only emit transformers warnings — removed as dead parameters.
        outputs = model.generate(
            **inputs,
            max_length=max_length,
            num_beams=4,
            repetition_penalty=1.2,
            early_stopping=True
        )
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # Format the response: break after sentences and before bullets for
    # readability in the Streamlit UI.
    formatted_response = response.replace(". ", ".\n").replace("• ", "\n• ")
    return formatted_response
# Streamlit App
# Flow: take a question, search both paper sources, build a short context
# from the top hits, generate an answer, then show the answer and sources.
st.title("🧩 AMA Autism")
st.write("This app searches through scientific papers to answer your questions about autism. For best results, be specific in your questions.")
query = st.text_input("Please ask me anything about autism ✨")
if query:
    with st.status("Searching for answers...") as status:
        # Announce each stage BEFORE doing it — the original wrote the
        # progress messages after the work had already finished.
        st.write("Searching for data in PubMed and arXiv...")
        df = search_research_papers(query)
        st.write("Data found!")
        # Use the three most relevant abstracts, truncated so the prompt
        # stays within the model's input window.
        context = "\n".join(text[:1000] for text in df['text'].head(3))
        st.write("Generating answer...")
        answer = generate_answer(query, context)
        status.update(
            label="Search complete!", state="complete", expanded=False
        )
    # Check for the empty-results case first: the original tested df.empty
    # last, so an empty search produced two overlapping warnings.
    if df.empty:
        st.warning("I couldn't find any relevant research papers about this topic. Please try rephrasing your question or ask something else about autism.")
    elif answer and not answer.isspace():
        st.success("Answer found!")
        st.write(answer)
        st.write("### Sources used:")
        for _, row in df.head(3).iterrows():
            st.markdown(f"**[{row['title']}]({row['url']})** ({row['published']})")
            st.write(f"**Summary:** {row['text'][:200]}...")
            st.write("---")
    else:
        st.warning("I couldn't find a specific answer in the research papers. Try rephrasing your question.")