import streamlit as st
import pandas as pd
import torch
import logging
import os
from transformers import AutoTokenizer, T5ForConditionalGeneration
import arxiv
import requests
import xml.etree.ElementTree as ET
import re

# Configure logging
logging.basicConfig(level=logging.INFO)
# Define data paths and constants
DATA_DIR = "/data" if os.path.exists("/data") else "."
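# On Hugging Face Spaces, /data is the persistent-storage mount;
# fall back to the working directory when running locally.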
DATASET_DIR = os.path.join(DATA_DIR, "rag_dataset")
DATASET_PATH = os.path.join(DATASET_DIR, "dataset")
TOKENIZER_MODEL = "google/flan-t5-small"
# Note: "HuggingFaceTB/SmolVLM-256M-Instruct" is a vision-language checkpoint
# and cannot be loaded with T5ForConditionalGeneration, so the summarizer
# must be a T5-family model, matching the tokenizer's checkpoint.
SUMMARIZATION_MODEL = "google/flan-t5-small"
# SUMMARIZATION_MODEL = "rhaymison/t5-portuguese-small-summarization"
@st.cache_resource  # Cache across reruns so the model is not reloaded for every question
def load_local_model():
    """Load the local Hugging Face model"""
    try:
        tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_MODEL)
        model = T5ForConditionalGeneration.from_pretrained(
            SUMMARIZATION_MODEL,
            device_map={"": "cpu"},  # Force CPU
            torch_dtype=torch.float32
        )
        return model, tokenizer
    except Exception as e:
        st.error(f"Error loading model: {str(e)}")
        return None, None
def clean_text(text):
    """Clean and normalize text content"""
    if not text:
        return ""
    # Normalize typographic quotes first, before the regex strips them
    text = text.replace('\u2019', "'").replace('\u201c', '"').replace('\u201d', '"')
    # Remove special characters and normalize spaces
    text = re.sub(r'[^\w\s.,;:()\-\'"]', ' ', text)
    text = re.sub(r'\s+', ' ', text)
    # Remove any remaining non-ASCII characters
    text = ''.join(char for char in text if ord(char) < 128)
    return text.strip()
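# Note: the ord(char) < 128 filter drops accented letters outright, e.g.
# clean_text('Résumé of “ASD”') returns 'Rsum of "ASD"'.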
def format_paper(title, abstract):
    """Format paper information consistently"""
    title = clean_text(title)
    abstract = clean_text(abstract)
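    # Cap abstracts at roughly 1,000 characters to keep the downstream prompt small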
    if len(abstract) > 1000:
        abstract = abstract[:997] + "..."
    return f"""Title: {title}
Abstract: {abstract}
---"""
def fetch_arxiv_papers(query, max_results=5):
    """Fetch papers from arXiv"""
    client = arxiv.Client()
    # Always include autism in the search query
    search_query = f"(ti:autism OR abs:autism) AND (ti:\"{query}\" OR abs:\"{query}\") AND cat:q-bio"
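    # arXiv query field prefixes: ti: searches titles, abs: searches abstracts,
    # and cat: filters by subject category (q-bio is quantitative biology).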
    # Search arXiv
    search = arxiv.Search(
        query=search_query,
        max_results=max_results,
        sort_by=arxiv.SortCriterion.Relevance
    )
    papers = []
    for result in client.results(search):
        # Only include papers that mention autism in title or abstract
        if ('autism' in result.title.lower() or
                'asd' in result.title.lower() or
                'autism' in result.summary.lower() or
                'asd' in result.summary.lower()):
            papers.append({
                'title': result.title,
                'abstract': result.summary,
                'url': result.pdf_url,
                'published': result.published.strftime("%Y-%m-%d"),
                'relevance_score': 1 if 'autism' in result.title.lower() else 0.5
            })
    return papers
def fetch_pubmed_papers(query, max_results=5):
    """Fetch papers from PubMed"""
    base_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils"
    # Always include autism in the search term
    search_term = f"(autism[Title/Abstract] OR ASD[Title/Abstract]) AND ({query}[Title/Abstract])"
    # Search for papers
    search_url = f"{base_url}/esearch.fcgi"
    search_params = {
        'db': 'pubmed',
        'term': search_term,
        'retmax': max_results,
        'sort': 'relevance',
        'retmode': 'xml'
    }
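    # NCBI E-utilities is a two-step API: esearch returns matching PMIDs,
    # and efetch then retrieves the full records for those IDs.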
    papers = []
    try:
        # Get paper IDs
        response = requests.get(search_url, params=search_params, timeout=30)
        root = ET.fromstring(response.content)
        id_list = [id_elem.text for id_elem in root.findall('.//Id')]
        if not id_list:
            return papers
        # Fetch paper details
        fetch_url = f"{base_url}/efetch.fcgi"
        fetch_params = {
            'db': 'pubmed',
            'id': ','.join(id_list),
            'retmode': 'xml'
        }
        response = requests.get(fetch_url, params=fetch_params, timeout=30)
        articles = ET.fromstring(response.content)
        for article in articles.findall('.//PubmedArticle'):
            title = article.find('.//ArticleTitle')
            abstract = article.find('.//Abstract/AbstractText')
            year = article.find('.//PubDate/Year')
            pmid = article.find('.//PMID')
            if title is not None and abstract is not None:
                # Element.text can be None for empty tags, so guard with a default
                title_text = (title.text or '').lower()
                abstract_text = (abstract.text or '').lower()
                # Only include papers that mention autism
                if ('autism' in title_text or 'asd' in title_text or
                        'autism' in abstract_text or 'asd' in abstract_text):
                    papers.append({
                        'title': title.text,
                        'abstract': abstract.text,
                        'url': f"https://pubmed.ncbi.nlm.nih.gov/{pmid.text}/",
                        'published': year.text if year is not None else 'Unknown',
                        'relevance_score': 1 if ('autism' in title_text or 'asd' in title_text) else 0.5
                    })
    except Exception as e:
        st.error(f"Error fetching PubMed papers: {str(e)}")
    return papers
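# Note: NCBI limits unauthenticated clients to about 3 requests per second;
# heavier use should send an api_key parameter with each E-utilities call.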
def search_research_papers(query):
    """Search both arXiv and PubMed for papers"""
    arxiv_papers = fetch_arxiv_papers(query)
    pubmed_papers = fetch_pubmed_papers(query)
    # Combine and format papers
    all_papers = []
    for paper in arxiv_papers + pubmed_papers:
        if paper['abstract'] and len(paper['abstract'].strip()) > 0:
            # Clean and format the paper content
            clean_title = clean_text(paper['title'])
            clean_abstract = clean_text(paper['abstract'])
            # Check if the paper is actually about autism
            if ('autism' in clean_title.lower() or
                    'asd' in clean_title.lower() or
                    'autism' in clean_abstract.lower() or
                    'asd' in clean_abstract.lower()):
                formatted_text = format_paper(clean_title, clean_abstract)
                all_papers.append({
                    'title': clean_title,
                    'text': formatted_text,
                    'url': paper['url'],
                    'published': paper['published'],
                    'relevance_score': paper.get('relevance_score', 0.5)
                })
    # Sort papers by relevance score and convert to DataFrame
    all_papers.sort(key=lambda x: x['relevance_score'], reverse=True)
    df = pd.DataFrame(all_papers)
    if df.empty:
        st.warning("No autism-related papers found. Please try a different search term.")
        return pd.DataFrame(columns=['title', 'text', 'url', 'published', 'relevance_score'])
    return df
def generate_answer(question, context, max_length=512):
    """Generate a comprehensive answer using the local model"""
    model, tokenizer = load_local_model()
    if model is None or tokenizer is None:
        return "Error: Could not load the model. Please try again later."
    # Clean and format the context
    clean_context = clean_text(context)
    clean_question = clean_text(question)
    # Format the input for T5 (it expects a specific format)
    input_text = f"""Objective:
Provide a clear, simple, and well-structured answer about autism that is easy to understand for a general audience. Use the provided research papers as references.
Question: {clean_question}
Research Papers:
{clean_context}
Instructions:
- Start with a simple definition: explain what autism is in a short and clear way, avoiding technical terms.
- Use real-life examples: give practical and relatable examples to help illustrate key points.
- Explain research in simple words: instead of just citing studies, summarize their key findings in a way that anyone can understand. Example: "A study from X University found that..."
- Avoid complex words: if a scientific term is needed, provide a short and simple explanation.
- Use clear formatting: write in short paragraphs, bullet points, or numbered lists to improve readability.
- Keep a friendly tone: make the response engaging and easy to follow, so people without prior knowledge can understand."""
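    # Note: flan-t5-small has on the order of 80M parameters; a long structured
    # prompt like the one above may be more than it can follow reliably.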
    try:
        # T5 expects a specific format for the input
        inputs = tokenizer(input_text,
                           return_tensors="pt",
                           max_length=1024,
                           truncation=True,
                           padding=True)
        with torch.inference_mode():
            outputs = model.generate(
                **inputs,
                max_length=max_length,
                min_length=200,
                num_beams=3,              # Fewer beams for more variety
                length_penalty=1.2,       # Balance concision against detail
                temperature=0.8,          # Slightly higher for more fluid text
                repetition_penalty=1.2,
                early_stopping=True,
                no_repeat_ngram_size=2,   # Keep the text varied
                do_sample=True,
                top_k=30,                 # Lower k for more coherent answers
                top_p=0.9                 # Balance diversity and precision
            )
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        response = clean_text(response)
        # If response is too short or empty, provide a general overview
        if len(response.strip()) < 100:
            return """Autism Spectrum Disorder (ASD) is a complex neurodevelopmental condition. Unfortunately, the provided papers don't contain specific information about this aspect of autism.
To get research-based information, try asking more specific questions about:
- Genetics and environmental factors
- Early intervention
- Treatments and therapies
- Neurological development
This will allow us to provide accurate information supported by recent scientific research."""
        # Format the response for better readability
        formatted_response = response.replace(". ", ".\n").replace("• ", "\n• ")
        return formatted_response
    except Exception as e:
        st.error(f"Error generating response: {str(e)}")
        return "Error: Could not generate response. Please try again with a different question."
# Streamlit App
st.title("🧩 AMA Autism")
st.write("""
This app searches through scientific papers to answer your questions about autism.
For best results, be specific in your questions.
""")
query = st.text_input("Please ask me anything about autism ✨")
if query:
    with st.status("Searching for answers..."):
        # Search for papers
        st.write("Searching for data in PubMed and arXiv...")
        df = search_research_papers(query)
        st.write(f"Found {len(df)} relevant papers!")
        if df.empty:
            st.stop()  # search_research_papers already showed a warning
        # Get relevant context
        context = "\n".join([
            text[:1000] for text in df['text'].head(3)
        ])
        # Generate answer
        st.write("Generating answer...")
        answer = generate_answer(query, context)
    # Display paper sources
    with st.expander("View source papers"):
        for _, paper in df.iterrows():
            st.markdown(f"- [{paper['title']}]({paper['url']}) ({paper['published']})")
    st.success("Answer found!")
    st.markdown(answer)