# Spaces: Sleeping  (status text captured from the hosting page along with this source; not part of the program)
import streamlit as st | |
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer | |
import os | |
from datasets import load_from_disk, Dataset | |
import torch | |
import logging | |
import pandas as pd | |
# Send INFO-and-above records to the root logger's default handler.
logging.basicConfig(level=logging.INFO)

# Storage layout: prefer /data when it exists (presumably a mounted persistent
# volume in the deployment — confirm), otherwise fall back to the working dir.
DATA_DIR = "." if not os.path.exists("/data") else "/data"
DATASET_DIR = os.path.join(DATA_DIR, "rag_dataset")
DATASET_PATH = os.path.join(DATA_DIR, "rag_dataset", "dataset")
# Load the model once and reuse it across reruns instead of on every question.
@st.cache_resource
def load_models():
    """Load the Flan-T5 tokenizer and seq2seq model.

    Decorated with ``st.cache_resource`` so the weights are loaded a single
    time and reused; without it every ``generate_answer`` call reloaded them.

    Returns:
        tuple: ``(tokenizer, model)`` for ``google/flan-t5-small``.
    """
    model_name = "google/flan-t5-small"  # Lighter model
    # Half precision only pays off on CUDA; on CPU many kernels lack (or are
    # slow in) float16, so use float32 there.
    dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(
        model_name,
        torch_dtype=dtype,
        low_cpu_mem_usage=True,
        device_map="auto",
        # NOTE(review): only a CPU budget is listed; confirm whether this also
        # keeps the model off a GPU when one is available.
        max_memory={"cpu": "1GB"},
    )
    return tokenizer, model
# Deliberately not cached: every query runs a fresh search and rebuilds the
# on-disk index under DATASET_DIR.
def load_dataset(query):
    """Search for papers matching *query*, re-index them, and return them as a DataFrame.

    The query is prefixed with "autism" unless it already mentions it. The
    fetched papers are handed to ``build_faiss_index`` (project-local helper),
    and the saved dataset at ``DATASET_PATH`` is read back.

    Args:
        query: Free-text user question.

    Returns:
        pandas.DataFrame with columns ``title``, ``text``, ``url``,
        ``published``. When no papers are found, a warning is shown and an
        empty DataFrame with those columns is returned.
    """
    columns = ["title", "text", "url", "published"]
    with st.spinner("Searching research papers from arXiv and PubMed..."):
        # Project-local module providing fetch_papers / build_faiss_index.
        import faiss_index.index as idx

        # Keep results on-topic: make sure "autism" is part of the search.
        search_query = query if "autism" in query.lower() else f"autism {query}"
        # Per the original note, this fetches from both arXiv and PubMed.
        papers = idx.fetch_papers(search_query, max_results=25)
        if not papers:
            st.warning("No relevant papers found. Please try rephrasing your question.")
            return pd.DataFrame(columns=columns)

        idx.build_faiss_index(papers, dataset_dir=DATASET_DIR)
        dataset = load_from_disk(DATASET_PATH)

        def _column(name):
            # Prefer the saved dataset's own column so every field comes from
            # the same rows; fall back to the fetched papers only when the
            # dataset does not store that field.
            if name in dataset.column_names:
                return dataset[name]
            return [paper[name] for paper in papers]

        return pd.DataFrame({name: _column(name) for name in columns})
def _build_prompt(question, context):
    """Return the instruction prompt for *question* grounded in *context*.

    The model is told to reply with a fixed refusal sentence when the context
    has nothing relevant; ``generate_answer`` checks for that phrase.
    """
    return (
        "Based on scientific research about autism, provide a comprehensive and structured summary answering the following question.\n"
        "Include the following aspects when relevant:\n"
        "1. Main findings and conclusions\n"
        "2. Supporting evidence or research methods\n"
        "3. Clinical implications or practical applications\n"
        "4. Any limitations or areas needing further research\n"
        "Use clear headings and bullet points when appropriate to organize the information.\n"
        "If the context doesn't contain relevant information about autism, respond with 'I cannot find specific information about this topic in the autism research papers.'\n"
        f"Question: {question}\n"
        f"Context: {context}\n"
        "Detailed summary:"
    )


def _fit_context(tokenizer, question, context, max_input_tokens):
    """Trim *context* so the whole prompt fits in roughly *max_input_tokens* tokens.

    Truncating the finished prompt cuts from the right and drops the trailing
    "Detailed summary:" cue; trimming only the context keeps the
    instructions, question and cue intact.
    """
    overhead = len(tokenizer(_build_prompt(question, ""))["input_ids"])
    budget = max_input_tokens - overhead
    if budget <= 0:
        return ""
    context_ids = tokenizer(
        context, add_special_tokens=False, truncation=True, max_length=budget
    )["input_ids"]
    if len(context_ids) < budget:
        return context  # Fits without trimming; keep the text verbatim.
    return tokenizer.decode(context_ids, skip_special_tokens=True)


def generate_answer(question, context, max_length=300, max_input_tokens=768):
    """Generate a structured answer to *question* from retrieved *context*.

    Args:
        question: The user's question.
        context: Concatenated paper text used as grounding.
        max_length: Maximum length of the generated output (passed to
            ``model.generate``).
        max_input_tokens: Token budget for the encoder prompt; only the
            context is trimmed to meet it.

    Returns:
        The answer with sentence and bullet line breaks inserted, or a fixed
        "cannot find" message when the output is blank or the model declines.
    """
    tokenizer, model = load_models()
    context = _fit_context(tokenizer, question, context, max_input_tokens)
    prompt = _build_prompt(question, context)

    # Truncation stays as a safety net: the overhead estimate can be off by a
    # token or two because the context was tokenized on its own.
    inputs = tokenizer(
        prompt, return_tensors="pt", truncation=True, max_length=max_input_tokens
    )
    # Inputs must be on the same device as the model weights.
    inputs = inputs.to(model.device)

    with torch.inference_mode():
        # Plain beam search. temperature/top_p were dropped: without
        # do_sample=True they are ignored and only trigger warnings.
        outputs = model.generate(
            **inputs,
            max_length=max_length,
            num_beams=4,
            repetition_penalty=1.3,
            length_penalty=1.2,
            early_stopping=True,
        )
    answer = tokenizer.decode(outputs[0], skip_special_tokens=True)

    # Release cached GPU memory between requests when running on CUDA.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    if not answer or answer.isspace() or "cannot find" in answer.lower():
        return "I cannot find specific information about this topic in the autism research papers."

    # Put each sentence and bullet on its own line for readability.
    return answer.replace(". ", ".\n").replace("• ", "\n• ")
# Streamlit App
st.title("🧩 AMA Autism")
st.write("This app searches through scientific papers to answer your questions about autism. For best results, be specific in your questions.")
query = st.text_input("Please ask me anything about autism ✨")

if query:
    answer = None
    with st.status("Searching for answers...") as status:
        # Each step is announced before it runs so the status log tracks progress.
        st.write("Searching for data in PubMed and arXiv...")
        df = load_dataset(query)
        if df.empty:
            # Nothing to ground an answer in: skip the model call entirely.
            status.update(
                label="No relevant papers found.", state="error", expanded=False
            )
        else:
            st.write("Data found!")
            # Use the first 1000 characters of each of the top 3 papers.
            context = "\n".join(text[:1000] for text in df["text"].head(3))
            st.write("Generating answer...")
            answer = generate_answer(query, context)
            status.update(
                label="Search complete!", state="complete", expanded=False
            )

    if df.empty:
        st.warning("I couldn't find any relevant research papers about this topic. Please try rephrasing your question or ask something else about autism.")
    elif answer and not answer.isspace():
        st.success("Answer found!")
        st.write(answer)
        st.write("### Sources used:")
        for _, row in df.head(3).iterrows():
            st.markdown(f"**[{row['title']}]({row['url']})** ({row['published']})")
            st.write(f"**Summary:** {row['text'][:200]}...")
            st.write("---")
    else:
        st.warning("I couldn't find a specific answer in the research papers. Try rephrasing your question.")