import streamlit as st
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import os
from datasets import load_from_disk, Dataset
import torch
import logging
import pandas as pd
import arxiv
import requests
import xml.etree.ElementTree as ET

# Configure logging
logging.basicConfig(level=logging.INFO)

# Define data paths and constants
DATA_DIR = "/data" if os.path.exists("/data") else "."
DATASET_DIR = os.path.join(DATA_DIR, "rag_dataset")
DATASET_PATH = os.path.join(DATASET_DIR, "dataset")
MODEL_PATH = "t5-small"  # T5-small for better CPU compatibility
@st.cache_resource  # cache the model and tokenizer across reruns so they are not reloaded for every question
def load_local_model():
    """Load the local Hugging Face model"""
    try:
        tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
        model = AutoModelForSeq2SeqLM.from_pretrained(
            MODEL_PATH,
            device_map={"": "cpu"},  # Force CPU
            torch_dtype=torch.float32
        )
        return model, tokenizer
    except Exception as e:
        st.error(f"Error loading model: {str(e)}")
        return None, None
def fetch_arxiv_papers(query, max_results=5):
    """Fetch papers from arXiv"""
    client = arxiv.Client()

    # Always include autism in the search query
    search_query = f"(ti:autism OR abs:autism) AND (ti:\"{query}\" OR abs:\"{query}\") AND cat:q-bio"

    # Search arXiv
    search = arxiv.Search(
        query=search_query,
        max_results=max_results,
        sort_by=arxiv.SortCriterion.Relevance
    )

    papers = []
    for result in client.results(search):
        # Only include papers that mention autism in the title or abstract
        if ('autism' in result.title.lower() or
                'asd' in result.title.lower() or
                'autism' in result.summary.lower() or
                'asd' in result.summary.lower()):
            papers.append({
                'title': result.title,
                'abstract': result.summary,
                'url': result.pdf_url,
                'published': result.published.strftime("%Y-%m-%d"),
                'relevance_score': 1 if 'autism' in result.title.lower() else 0.5
            })

    return papers
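
# Illustrative only (not used at runtime): for a sample topic such as "sleep",
# the f-string above expands to the following arXiv API query, which restricts
# results to the quantitative-biology category and requires an autism mention:
#   (ti:autism OR abs:autism) AND (ti:"sleep" OR abs:"sleep") AND cat:q-bio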
def fetch_pubmed_papers(query, max_results=5):
    """Fetch papers from PubMed"""
    base_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils"

    # Always include autism in the search term
    search_term = f"(autism[Title/Abstract] OR ASD[Title/Abstract]) AND ({query}[Title/Abstract])"

    # Search for papers
    search_url = f"{base_url}/esearch.fcgi"
    search_params = {
        'db': 'pubmed',
        'term': search_term,
        'retmax': max_results,
        'sort': 'relevance',
        'retmode': 'xml'
    }

    papers = []
    try:
        # Get paper IDs
        response = requests.get(search_url, params=search_params)
        root = ET.fromstring(response.content)
        id_list = [id_elem.text for id_elem in root.findall('.//Id')]

        if not id_list:
            return papers

        # Fetch paper details
        fetch_url = f"{base_url}/efetch.fcgi"
        fetch_params = {
            'db': 'pubmed',
            'id': ','.join(id_list),
            'retmode': 'xml'
        }
        response = requests.get(fetch_url, params=fetch_params)
        articles = ET.fromstring(response.content)

        for article in articles.findall('.//PubmedArticle'):
            title = article.find('.//ArticleTitle')
            abstract = article.find('.//Abstract/AbstractText')
            year = article.find('.//PubDate/Year')
            pmid = article.find('.//PMID')

            # Guard against missing elements and empty text before lowercasing
            if title is not None and abstract is not None and title.text and abstract.text:
                title_text = title.text.lower()
                abstract_text = abstract.text.lower()

                # Only include papers that mention autism
                if ('autism' in title_text or 'asd' in title_text or
                        'autism' in abstract_text or 'asd' in abstract_text):
                    papers.append({
                        'title': title.text,
                        'abstract': abstract.text,
                        'url': f"https://pubmed.ncbi.nlm.nih.gov/{pmid.text}/",
                        'published': year.text if year is not None else 'Unknown',
                        'relevance_score': 1 if ('autism' in title_text or 'asd' in title_text) else 0.5
                    })
    except Exception as e:
        st.error(f"Error fetching PubMed papers: {str(e)}")

    return papers
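
# Both fetchers return the same record shape; an illustrative entry (values are
# placeholders, not real data) looks like:
# {
#     'title': 'Sleep problems in children with autism',
#     'abstract': '...',
#     'url': 'https://pubmed.ncbi.nlm.nih.gov/<pmid>/',
#     'published': '2021',
#     'relevance_score': 1,
# }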
def search_research_papers(query):
    """Search both arXiv and PubMed for papers"""
    arxiv_papers = fetch_arxiv_papers(query)
    pubmed_papers = fetch_pubmed_papers(query)

    # Combine and format papers
    all_papers = []
    for paper in arxiv_papers + pubmed_papers:
        if paper['abstract'] and len(paper['abstract'].strip()) > 0:
            # Check if the paper is actually about autism
            if ('autism' in paper['title'].lower() or
                    'asd' in paper['title'].lower() or
                    'autism' in paper['abstract'].lower() or
                    'asd' in paper['abstract'].lower()):
                all_papers.append({
                    'title': paper['title'],
                    'text': f"Title: {paper['title']}\n\nAbstract: {paper['abstract']}",
                    'url': paper['url'],
                    'published': paper['published'],
                    'relevance_score': paper.get('relevance_score', 0.5)
                })

    # Sort papers by relevance score and convert to DataFrame
    all_papers.sort(key=lambda x: x['relevance_score'], reverse=True)
    df = pd.DataFrame(all_papers)

    if df.empty:
        st.warning("No autism-related papers found. Please try a different search term.")
        return pd.DataFrame(columns=['title', 'text', 'url', 'published', 'relevance_score'])

    return df
def generate_answer(question, context, max_length=512):
    """Generate a comprehensive answer using the local model"""
    model, tokenizer = load_local_model()

    if model is None or tokenizer is None:
        return "Error: Could not load the model. Please try again later."

    # Format the context as a structured query
    prompt = f"""You are an expert in autism research. Provide a comprehensive answer about autism, incorporating both general knowledge and specific research findings when available.

Question: {question}

Recent Research Context:
{context}

Instructions: Provide a detailed response that:
1. Starts with a general overview of the topic as it relates to autism
2. Incorporates specific findings from the provided research papers when relevant
3. Discusses practical implications for individuals with autism and their families
4. Mentions any limitations in current understanding

If the research papers don't directly address the question, focus on providing general, well-established information about autism while noting what specific research would be helpful."""

    try:
        # Generate response
        inputs = tokenizer(prompt, return_tensors="pt", max_length=1024, truncation=True)

        with torch.inference_mode():
            outputs = model.generate(
                **inputs,
                max_length=max_length,
                min_length=150,  # Increased minimum length for more comprehensive answers
                num_beams=4,
                length_penalty=1.5,
                temperature=0.7,
                repetition_penalty=1.2,
                early_stopping=True
            )

        response = tokenizer.decode(outputs[0], skip_special_tokens=True)

        # If the response is too short or empty, provide a general overview
        if len(response.strip()) < 100:
            return f"""Here's what we know about autism in relation to your question about {question}:

1. General Understanding:
- Autism Spectrum Disorder (ASD) is a complex developmental condition
- It affects how a person communicates, learns, and interacts with others
- Each person with autism has unique strengths and challenges

2. Key Aspects:
- Communication and social interaction
- Repetitive behaviors and specific interests
- Sensory sensitivities
- Early intervention is important

3. Current Research:
While the provided research papers don't directly address your specific question, researchers are actively studying various aspects of autism to better understand its causes, characteristics, and effective interventions.

For more specific information, try asking about:
- Specific symptoms or characteristics
- Diagnostic processes
- Treatment approaches
- Current research in specific areas"""

        # Format the response for better readability
        formatted_response = response.replace(". ", ".\n").replace("• ", "\n• ")
        return formatted_response

    except Exception as e:
        st.error(f"Error generating response: {str(e)}")
        return "Error: Could not generate response. Please try again with a different question."
# Streamlit App
st.title("🧩 AMA Autism")
st.write("""
This app searches through scientific papers to answer your questions about autism.
For best results, be specific in your questions.
""")

query = st.text_input("Please ask me anything about autism ✨")

if query:
    with st.status("Searching for answers..."):
        # Search for papers
        st.write("Searching for data in PubMed and arXiv...")
        df = search_research_papers(query)
        st.write(f"Found {len(df)} relevant papers!")

        # Get relevant context from the top-ranked papers
        context = "\n".join([
            f"{text[:1000]}" for text in df['text'].head(3)
        ])

        # Generate answer
        st.write("Generating answer...")
        answer = generate_answer(query, context)

    # Display paper sources
    with st.expander("View source papers"):
        for _, paper in df.iterrows():
            st.markdown(f"- [{paper['title']}]({paper['url']}) ({paper['published']})")

    st.success("Answer found!")
    st.markdown(answer)
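
# To try the app outside the Space, install the dependencies used above
# (streamlit, transformers, torch, datasets, pandas, arxiv, requests) and run,
# assuming this file is saved as app.py:
#   streamlit run app.py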