"""AMA Autism: a Streamlit app that answers autism questions from arXiv/PubMed papers.

The app searches arXiv and PubMed for the user's question, builds a short
context from the top results, and asks a local FLAN-T5 model to summarise it.
"""
import logging
import os
import xml.etree.ElementTree as ET

import arxiv
import pandas as pd
import requests
import streamlit as st
import torch
from agno.embedder.huggingface import HuggingfaceCustomEmbedder
from agno.vectordb.lancedb import LanceDb, SearchType
from datasets import Dataset, load_from_disk
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Define data paths and constants
DATA_DIR = "/data" if os.path.exists("/data") else "."
DATASET_DIR = os.path.join(DATA_DIR, "rag_dataset")
DATASET_PATH = os.path.join(DATASET_DIR, "dataset")
MODEL_PATH = "google/flan-t5-base"  # Lighter model

# Seconds to wait for an NCBI E-utilities response before giving up.
REQUEST_TIMEOUT = 30

# Column schema of the DataFrame returned by search_research_papers. Declared
# explicitly so an empty result still has these columns.
PAPER_COLUMNS = ["title", "text", "url", "published"]

# Prompt sent to the model. {question} and {context} are filled with
# str.format, so braces inside the user's question are harmless.
PROMPT_TEMPLATE = (
    "Based on the following research papers about autism, provide a detailed answer:\n\n"
    "Question: {question}\n\n"
    "Research Context:\n{context}\n\n"
    "Please analyze:\n"
    "1. Main findings\n"
    "2. Research methods\n"
    "3. Clinical implications\n"
    "4. Limitations\n\n"
    "Answer:"
)


@st.cache_resource
def load_local_model():
    """Load the local Hugging Face seq2seq model and tokenizer (cached by Streamlit).

    Returns:
        A ``(model, tokenizer)`` tuple for ``MODEL_PATH``.
    """
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
    model = AutoModelForSeq2SeqLM.from_pretrained(
        MODEL_PATH,
        torch_dtype=torch.float32,  # Using float32 for CPU compatibility
        device_map="auto",
    )
    return model, tokenizer


def fetch_arxiv_papers(query, max_results=5):
    """Fetch papers from arXiv whose title or abstract matches *query*.

    Args:
        query: Free-text search string.
        max_results: Maximum number of papers to request.

    Returns:
        A list of dicts with keys ``title``, ``abstract``, ``url`` and
        ``published``. On failure the error is shown in the UI and whatever
        was collected so far (possibly nothing) is returned, mirroring
        ``fetch_pubmed_papers``.
    """
    client = arxiv.Client()

    # Parenthesise the OR so the category filter applies to both fields
    # instead of leaving operator precedence to the arXiv query parser.
    # NOTE(review): confirm that "cat:q-bio" matches the q-bio.* subcategories.
    search_query = f"(ti:{query} OR abs:{query}) AND cat:q-bio"

    search = arxiv.Search(
        query=search_query,
        max_results=max_results,
        sort_by=arxiv.SortCriterion.Relevance,
    )

    papers = []
    try:
        for result in client.results(search):
            papers.append({
                'title': result.title,
                'abstract': result.summary,
                'url': result.pdf_url,
                'published': result.published,
            })
    except Exception as e:  # best-effort source: report and continue with PubMed
        logger.exception("arXiv search failed")
        st.error(f"Error fetching arXiv papers: {str(e)}")

    return papers


def _xml_text(element, default):
    """Return all text inside *element*, including inline markup, or *default* if absent/empty."""
    if element is None:
        return default
    # itertext() also collects text nested in tags such as <i> or <sup>,
    # which element.text alone would drop.
    text = "".join(element.itertext()).strip()
    return text or default


def _parse_pubmed_article(article):
    """Convert one ``<PubmedArticle>`` element into the paper dict used by the app."""
    # Structured abstracts are split into several labelled AbstractText
    # sections; keep all of them, not just the first.
    abstract_parts = []
    for part in article.findall('.//Abstract/AbstractText'):
        text = _xml_text(part, "")
        if not text:
            continue
        label = part.get('Label')
        abstract_parts.append(f"{label}: {text}" if label else text)

    pmid = article.findtext('MedlineCitation/PMID')

    year = article.findtext('.//PubDate/Year')
    if not year:
        # Some records carry only a free-text MedlineDate such as "1998 Dec-1999 Jan".
        medline_date = article.findtext('.//PubDate/MedlineDate')
        year = medline_date[:4] if medline_date else None

    return {
        'title': _xml_text(article.find('.//ArticleTitle'), 'No title available'),
        'abstract': " ".join(abstract_parts) or 'No abstract available',
        'url': f"https://pubmed.ncbi.nlm.nih.gov/{pmid}/" if pmid else "https://pubmed.ncbi.nlm.nih.gov/",
        'published': year or 'Unknown',
    }


def fetch_pubmed_papers(query, max_results=5):
    """Fetch papers from PubMed via the NCBI E-utilities ESearch/EFetch endpoints.

    Args:
        query: Free-text search term.
        max_results: Maximum number of PubMed IDs to retrieve.

    Returns:
        A list of dicts with keys ``title``, ``abstract``, ``url`` and
        ``published``. Network/parse errors are shown in the UI and the
        papers parsed so far (possibly none) are returned.
    """
    base_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils"

    search_params = {
        'db': 'pubmed',
        'term': query,
        'retmax': max_results,
        'sort': 'relevance',
        'retmode': 'xml',
    }

    papers = []
    try:
        # Step 1: get matching paper IDs.
        response = requests.get(f"{base_url}/esearch.fcgi", params=search_params, timeout=REQUEST_TIMEOUT)
        response.raise_for_status()
        root = ET.fromstring(response.content)
        id_list = [id_elem.text for id_elem in root.findall('.//Id') if id_elem.text]

        if not id_list:
            return papers

        # Step 2: fetch full records for those IDs.
        fetch_params = {
            'db': 'pubmed',
            'id': ','.join(id_list),
            'retmode': 'xml',
        }
        response = requests.get(f"{base_url}/efetch.fcgi", params=fetch_params, timeout=REQUEST_TIMEOUT)
        response.raise_for_status()
        articles = ET.fromstring(response.content)

        for article in articles.findall('.//PubmedArticle'):
            papers.append(_parse_pubmed_article(article))

    except Exception as e:  # best-effort source: report and keep what we have
        logger.exception("PubMed search failed")
        st.error(f"Error fetching PubMed papers: {str(e)}")

    return papers


def search_research_papers(query):
    """Search both arXiv and PubMed and return the combined results.

    Returns:
        A DataFrame with columns ``PAPER_COLUMNS`` (arXiv results first). The
        columns are present even when no papers were found, so callers can
        index ``df['text']`` safely.
    """
    papers = fetch_arxiv_papers(query) + fetch_pubmed_papers(query)

    rows = [
        {
            'title': paper['title'],
            'text': f"Title: {paper['title']}\nAbstract: {paper['abstract']}",
            'url': paper['url'],
            'published': paper['published'],
        }
        for paper in papers
    ]
    return pd.DataFrame(rows, columns=PAPER_COLUMNS)


def generate_answer(question, context, max_length=512):
    """Generate an answer to *question* grounded in *context* using the local model.

    Args:
        question: The user's question.
        context: Paper text to ground the answer in.
        max_length: Token limit applied both to the encoded prompt and to the
            generated output.

    Returns:
        The decoded answer, with a line break after each sentence and before
        each bullet.
    """
    model, tokenizer = load_local_model()

    # Truncate only the context so the question, the analysis instructions
    # and the trailing "Answer:" cue always fit inside max_length. Plain
    # prompt truncation would cut the end of the prompt instead.
    overhead = len(tokenizer(PROMPT_TEMPLATE.format(question=question, context=""))["input_ids"])
    budget = max(max_length - overhead, 0)
    context_ids = tokenizer(context, add_special_tokens=False)["input_ids"]
    if len(context_ids) > budget:
        context = tokenizer.decode(context_ids[:budget], skip_special_tokens=True)

    prompt = PROMPT_TEMPLATE.format(question=question, context=context)

    # truncation=True remains as a safety net (re-encoding decoded text can
    # shift the token count slightly). Inputs must sit on the model's device
    # because device_map="auto" may have placed it on an accelerator.
    inputs = tokenizer(prompt, return_tensors="pt", max_length=max_length, truncation=True).to(model.device)

    with torch.inference_mode():
        # Beam search is deterministic; temperature/top_p only apply when
        # do_sample=True, so they are not passed.
        outputs = model.generate(
            **inputs,
            max_length=max_length,
            num_beams=4,
            repetition_penalty=1.2,
            early_stopping=True,
        )

    response = tokenizer.decode(outputs[0], skip_special_tokens=True)

    # Put each sentence and bullet on its own line for display.
    return response.replace(". ", ".\n").replace("• ", "\n• ")


# Streamlit App
st.title("🧩 AMA Autism")
st.write("This app searches through scientific papers to answer your questions about autism. For best results, be specific in your questions.")

query = st.text_input("Please ask me anything about autism ✨")

if query and query.strip():
    answer = None
    with st.status("Searching for answers...") as status:
        st.write("Searching for data in PubMed and arXiv...")
        df = search_research_papers(query)

        if df.empty:
            status.update(label="No papers found", state="error", expanded=False)
        else:
            st.write("Data found!")

            # Use the first 1000 characters of the top three papers as context.
            context = "\n".join(text[:1000] for text in df['text'].head(3))

            st.write("Generating answer...")
            answer = generate_answer(query, context)

            status.update(label="Search complete!", state="complete", expanded=False)

    if df.empty:
        st.warning("I couldn't find any relevant research papers about this topic. Please try rephrasing your question or ask something else about autism.")
    elif answer and not answer.isspace():
        st.success("Answer found!")
        st.write(answer)

        st.write("### Sources used:")
        for _, row in df.head(3).iterrows():
            st.markdown(f"**[{row['title']}]({row['url']})** ({row['published']})")
            st.write(f"**Summary:** {row['text'][:200]}...")
            st.write("---")
    else:
        st.warning("I couldn't find a specific answer in the research papers. Try rephrasing your question.")