import streamlit as st
import pandas as pd
import torch
import logging
import os
from transformers import AutoTokenizer, T5ForConditionalGeneration
import arxiv
import requests
import xml.etree.ElementTree as ET
import re

# Configure logging
logging.basicConfig(level=logging.INFO)
# Define data paths and constants
DATA_DIR = "/data" if os.path.exists("/data") else "."
DATASET_DIR = os.path.join(DATA_DIR, "rag_dataset")
DATASET_PATH = os.path.join(DATASET_DIR, "dataset")
TOKENIZER_MODEL = "google/flan-t5-small"
# The generation checkpoint must be a T5-family model, since it is loaded with
# T5ForConditionalGeneration below, and should match TOKENIZER_MODEL so the
# tokenizer vocabulary agrees with the model.
SUMMARIZATION_MODEL = "google/flan-t5-small"
# SUMMARIZATION_MODEL = "HuggingFaceTB/SmolVLM-256M-Instruct"  # vision-language model; not T5-compatible
# SUMMARIZATION_MODEL = "rhaymison/t5-portuguese-small-summarization"
@st.cache_resource
def load_local_model():
    """Load the local Hugging Face model, cached across Streamlit reruns"""
    try:
        tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_MODEL)
        model = T5ForConditionalGeneration.from_pretrained(
            SUMMARIZATION_MODEL,
            device_map={"": "cpu"},  # Force CPU
            torch_dtype=torch.float32
        )
        return model, tokenizer
    except Exception as e:
        st.error(f"Error loading model: {str(e)}")
        return None, None
def clean_text(text):
    """Clean and normalize text content"""
    if not text:
        return ""
    # Normalize curly quotes to ASCII first, otherwise the regex below
    # would strip them before they can be replaced
    text = text.replace('\u2019', "'").replace('\u201c', '"').replace('\u201d', '"')
    # Remove special characters and normalize spaces
    text = re.sub(r'[^\w\s.,;:()\-\'"]', ' ', text)
    text = re.sub(r'\s+', ' ', text)
    # Drop any remaining non-ASCII characters
    text = ''.join(char for char in text if ord(char) < 128)
    return text.strip()
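# Illustrative example (not executed by the app):
#   clean_text("Autism\u2019s  \u201ccore\u201d traits!")  ->  'Autism\'s "core" traits'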
def format_paper(title, abstract):
    """Format paper information consistently"""
    title = clean_text(title)
    abstract = clean_text(abstract)
    # Truncate very long abstracts, leaving room for the ellipsis
    if len(abstract) > 1000:
        abstract = abstract[:997] + "..."
    return f"""Title: {title}
Abstract: {abstract}
---"""
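# Illustrative output:
#   Title: Some Paper Title
#   Abstract: Some abstract text...
#   ---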
def fetch_arxiv_papers(query, max_results=5):
    """Fetch papers from arXiv"""
    client = arxiv.Client()
    # Always include autism in the search query
    search_query = f"(ti:autism OR abs:autism) AND (ti:\"{query}\" OR abs:\"{query}\") AND cat:q-bio"
    # Search arXiv
    search = arxiv.Search(
        query=search_query,
        max_results=max_results,
        sort_by=arxiv.SortCriterion.Relevance
    )
    papers = []
    for result in client.results(search):
        # Only include papers that mention autism in the title or abstract
        if ('autism' in result.title.lower() or
                'asd' in result.title.lower() or
                'autism' in result.summary.lower() or
                'asd' in result.summary.lower()):
            papers.append({
                'title': result.title,
                'abstract': result.summary,
                'url': result.pdf_url,
                'published': result.published.strftime("%Y-%m-%d"),
                # Rank title mentions above abstract-only mentions
                'relevance_score': 1 if 'autism' in result.title.lower() else 0.5
            })
    return papers
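# Quick smoke test for the arXiv fetcher (illustrative; run outside Streamlit):
#   for p in fetch_arxiv_papers("early intervention", max_results=3):
#       print(p['published'], p['title'])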
def fetch_pubmed_papers(query, max_results=5):
    """Fetch papers from PubMed"""
    base_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils"
    # Always include autism in the search term
    search_term = f"(autism[Title/Abstract] OR ASD[Title/Abstract]) AND ({query}[Title/Abstract])"
    # Search for papers
    search_url = f"{base_url}/esearch.fcgi"
    search_params = {
        'db': 'pubmed',
        'term': search_term,
        'retmax': max_results,
        'sort': 'relevance',
        'retmode': 'xml'
    }
    papers = []
    try:
        # Get paper IDs
        response = requests.get(search_url, params=search_params, timeout=10)
        root = ET.fromstring(response.content)
        id_list = [id_elem.text for id_elem in root.findall('.//Id')]
        if not id_list:
            return papers
        # Fetch paper details
        fetch_url = f"{base_url}/efetch.fcgi"
        fetch_params = {
            'db': 'pubmed',
            'id': ','.join(id_list),
            'retmode': 'xml'
        }
        response = requests.get(fetch_url, params=fetch_params, timeout=10)
        articles = ET.fromstring(response.content)
        for article in articles.findall('.//PubmedArticle'):
            title = article.find('.//ArticleTitle')
            abstract = article.find('.//Abstract/AbstractText')
            year = article.find('.//PubDate/Year')
            pmid = article.find('.//PMID')
            # Guard against missing elements and empty text nodes
            if (title is not None and abstract is not None
                    and title.text and abstract.text):
                title_text = title.text.lower()
                abstract_text = abstract.text.lower()
                # Only include papers that mention autism
                if ('autism' in title_text or 'asd' in title_text or
                        'autism' in abstract_text or 'asd' in abstract_text):
                    papers.append({
                        'title': title.text,
                        'abstract': abstract.text,
                        'url': f"https://pubmed.ncbi.nlm.nih.gov/{pmid.text}/",
                        'published': year.text if year is not None else 'Unknown',
                        'relevance_score': 1 if ('autism' in title_text or 'asd' in title_text) else 0.5
                    })
    except Exception as e:
        st.error(f"Error fetching PubMed papers: {str(e)}")
    return papers
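# The two-step E-utilities flow above issues requests equivalent to:
#   {base_url}/esearch.fcgi?db=pubmed&term=...&retmax=5&sort=relevance&retmode=xml
#   {base_url}/efetch.fcgi?db=pubmed&id=<comma-separated PMIDs>&retmode=xml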
def search_research_papers(query):
    """Search both arXiv and PubMed for papers"""
    arxiv_papers = fetch_arxiv_papers(query)
    pubmed_papers = fetch_pubmed_papers(query)
    # Combine and format papers
    all_papers = []
    for paper in arxiv_papers + pubmed_papers:
        if paper['abstract'] and len(paper['abstract'].strip()) > 0:
            # Clean and format the paper content
            clean_title = clean_text(paper['title'])
            clean_abstract = clean_text(paper['abstract'])
            # Re-check after cleaning that the paper is actually about autism
            if ('autism' in clean_title.lower() or
                    'asd' in clean_title.lower() or
                    'autism' in clean_abstract.lower() or
                    'asd' in clean_abstract.lower()):
                formatted_text = format_paper(clean_title, clean_abstract)
                all_papers.append({
                    'title': clean_title,
                    'text': formatted_text,
                    'url': paper['url'],
                    'published': paper['published'],
                    'relevance_score': paper.get('relevance_score', 0.5)
                })
    # Sort papers by relevance score and convert to a DataFrame
    all_papers.sort(key=lambda x: x['relevance_score'], reverse=True)
    df = pd.DataFrame(all_papers)
    if df.empty:
        st.warning("No autism-related papers found. Please try a different search term.")
        return pd.DataFrame(columns=['title', 'text', 'url', 'published', 'relevance_score'])
    return df
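# Resulting DataFrame shape (illustrative):
#   df.columns -> ['title', 'text', 'url', 'published', 'relevance_score']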
def generate_answer(question, context, max_length=512):
    """Generate a comprehensive answer using the local model"""
    model, tokenizer = load_local_model()
    if model is None or tokenizer is None:
        return "Error: Could not load the model. Please try again later."
    # Clean and format the context
    clean_context = clean_text(context)
    clean_question = clean_text(question)
    # Format the input for T5 (it expects a task-style prompt)
    input_text = f"""Objective:
Provide a clear, simple, and well-structured answer about autism that is easy to understand for a general audience. Use the provided research papers as references.
Question: {clean_question}
Research Papers:
{clean_context}
Instructions:
- Start with a simple definition: explain what autism is in a short and clear way, avoiding technical terms.
- Use real-life examples: give practical and relatable examples to help illustrate key points.
- Explain research in simple words: instead of just citing studies, summarize their key findings in a way that anyone can understand. Example: "A study from X University found that..."
- Avoid complex words: if a scientific term is needed, provide a short and simple explanation.
- Use clear formatting: write in short paragraphs, bullet points, or numbered lists to improve readability.
- Keep a friendly tone: make the response engaging and easy to follow, so people without prior knowledge can understand."""
    try:
        inputs = tokenizer(input_text,
                           return_tensors="pt",
                           max_length=1024,
                           truncation=True,
                           padding=True)
        with torch.inference_mode():
            outputs = model.generate(
                **inputs,
                max_length=max_length,
                min_length=200,
                num_beams=3,               # Fewer beams for more variety
                length_penalty=1.2,        # Balance conciseness and detail
                temperature=0.8,           # Slightly higher for more fluent text
                repetition_penalty=1.2,
                early_stopping=True,
                no_repeat_ngram_size=2,    # Keep variation in the text
                do_sample=True,
                top_k=30,                  # Lower for more coherent answers
                top_p=0.9                  # Balance diversity and precision
            )
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        response = clean_text(response)
        # If the response is too short or empty, provide a general overview
        if len(response.strip()) < 100:
            return """Autism Spectrum Disorder (ASD) is a complex neurodevelopmental condition. Unfortunately, the provided papers don't contain specific information about this aspect of autism.
To get research-based information, try asking more specific questions about:
- Genetics and environmental factors
- Early intervention
- Treatments and therapies
- Neurological development
This will allow us to provide accurate information supported by recent scientific research."""
        # Format the response for better readability
        formatted_response = response.replace(". ", ".\n").replace("• ", "\n• ")
        return formatted_response
    except Exception as e:
        st.error(f"Error generating response: {str(e)}")
        return "Error: Could not generate response. Please try again with a different question."
# Streamlit App
st.title("🧩 AMA Autism")
st.write("""
This app searches through scientific papers to answer your questions about autism.
For best results, be specific in your questions.
""")
query = st.text_input("Please ask me anything about autism ✨")

if query:
    with st.status("Searching for answers..."):
        # Search for papers
        st.write("Searching for data in PubMed and arXiv...")
        df = search_research_papers(query)
        st.write(f"Found {len(df)} relevant papers!")
        # Get relevant context from the top papers
        context = "\n".join([
            f"{text[:1000]}" for text in df['text'].head(3)
        ])
        # Generate answer
        st.write("Generating answer...")
        answer = generate_answer(query, context)
        # Display paper sources
        with st.expander("View source papers"):
            for _, paper in df.iterrows():
                st.markdown(f"- [{paper['title']}]({paper['url']}) ({paper['published']})")
    st.success("Answer found!")
    st.markdown(answer)
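# To run locally (illustrative; assumes this file is saved as app.py and that
# streamlit, pandas, torch, transformers, arxiv, and requests are installed):
#   streamlit run app.py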