ama-autism / app.py
wakeupmh's picture
refactor: improve prompt
cb9a068
raw
history blame
5.37 kB
import streamlit as st
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import os
from datasets import load_from_disk, Dataset
import torch
import logging
import pandas as pd
# Configure logging for the whole app at INFO level.
logging.basicConfig(level=logging.INFO)

# Define data paths.
# Use the "/data" directory when it exists; otherwise fall back to the
# current working directory.
if os.path.exists("/data"):
    DATA_DIR = "/data"
else:
    DATA_DIR = "."
# Directory handed to the index builder, and the saved dataset inside it.
DATASET_DIR = os.path.join(DATA_DIR, "rag_dataset")
DATASET_PATH = os.path.join(DATASET_DIR, "dataset")
# Cache models and dataset
@st.cache_resource
def load_models():
    """Load the FLAN-T5 tokenizer and model, cached via ``st.cache_resource``.

    Returns:
        tuple: ``(tokenizer, model)`` for ``google/flan-t5-small``.
    """
    model_name = "google/flan-t5-small"  # Lighter model
    # Half precision only pays off on a GPU. On CPU many fp16 kernels are slow
    # or unavailable and T5 checkpoints are prone to fp16 overflow, so fall back
    # to fp32 there (flan-t5-small is roughly 300 MB in fp32, under the 1GB cap).
    dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # NOTE(review): max_memory lists only 'cpu', so device_map='auto' may keep
    # the model on CPU even when CUDA is present — confirm the intended device.
    model = AutoModelForSeq2SeqLM.from_pretrained(
        model_name,
        torch_dtype=dtype,
        low_cpu_mem_usage=True,
        device_map='auto',
        max_memory={'cpu': '1GB'},
    )
    return tokenizer, model
@st.cache_data(ttl=3600)  # Cache for 1 hour
def load_dataset(query):
    """Fetch papers matching *query*, index them, and return them as a DataFrame.

    Results are cached per query for one hour (``ttl=3600``), so repeating a
    question within that window reuses the earlier result instead of
    fetching again.

    Args:
        query: The user's question; "autism" is prepended when it is missing.

    Returns:
        pandas.DataFrame with columns ``title``, ``text``, ``url`` and
        ``published``; empty (same columns) when no papers were found.
    """
    empty_columns = ['title', 'text', 'url', 'published']
    with st.spinner("Searching research papers from arXiv and PubMed..."):
        import faiss_index.index as idx

        # Keep every search anchored to autism.
        search_query = query if 'autism' in query.lower() else f"autism {query}"
        papers = idx.fetch_papers(search_query, max_results=25)  # arXiv + PubMed
        if not papers:
            st.warning("No relevant papers found. Please try rephrasing your question.")
            return pd.DataFrame(columns=empty_columns)

        # NOTE(review): every call rebuilds the index under the same
        # DATASET_DIR, so concurrent sessions may overwrite each other's data.
        idx.build_faiss_index(papers, dataset_dir=DATASET_DIR)
        dataset = load_from_disk(DATASET_PATH)

        # NOTE(review): title/text come from the saved dataset while
        # url/published come from `papers`; this assumes build_faiss_index
        # stores exactly one row per paper in the same order — confirm.
        columns = {
            'title': dataset['title'],
            'text': dataset['text'],
            'url': [paper['url'] for paper in papers],
            'published': [paper['published'] for paper in papers],
        }
        return pd.DataFrame(columns)
# Fixed reply used both in the prompt instructions and as the fallback answer.
_NO_INFO_ANSWER = "I cannot find specific information about this topic in the autism research papers."


def _build_prompt(question, context):
    """Return the instruction prompt for *question* grounded in *context*."""
    return (
        "Based on scientific research about autism, provide a comprehensive and structured summary answering the following question.\n"
        "Include the following aspects when relevant:\n"
        "1. Main findings and conclusions\n"
        "2. Supporting evidence or research methods\n"
        "3. Clinical implications or practical applications\n"
        "4. Any limitations or areas needing further research\n"
        "Use clear headings and bullet points when appropriate to organize the information.\n"
        f"If the context doesn't contain relevant information about autism, respond with '{_NO_INFO_ANSWER}'\n"
        f"Question: {question}\n"
        f"Context: {context}\n"
        "Detailed summary:"
    )


def _fit_context(tokenizer, question, context, max_input_tokens):
    """Trim *context* so the prompt built around it fits in *max_input_tokens*."""
    # Tokens used by the template and question alone, plus a small slack for
    # tokenization differences where the context is spliced in.
    overhead = len(tokenizer(_build_prompt(question, ""))["input_ids"]) + 8
    budget = max_input_tokens - overhead
    if budget <= 0:
        return ""
    context_ids = tokenizer(context, add_special_tokens=False)["input_ids"]
    if len(context_ids) <= budget:
        return context
    return tokenizer.decode(context_ids[:budget], skip_special_tokens=True)


def _format_answer(answer):
    """Break the answer into lines using Markdown hard line breaks.

    A bare "\\n" is a soft break in Markdown and renders as a space under
    ``st.write``; two trailing spaces before the newline force a visible break.
    """
    return answer.replace(". ", ".  \n").replace("• ", "  \n• ")


def generate_answer(question, context, max_length=300, max_input_tokens=768):
    """Generate a structured answer to *question* grounded in *context*.

    Args:
        question: The user's question.
        context: Excerpts from retrieved papers, joined into one string.
        max_length: Maximum length, in tokens, of the generated answer.
        max_input_tokens: Token budget for the whole prompt. The context is
            trimmed first so the question and the trailing "Detailed summary:"
            cue are not cut off by truncation.

    Returns:
        The answer formatted with Markdown line breaks, or the fixed
        "I cannot find ..." message when the output is empty or the model
        said it could not find the information.
    """
    tokenizer, model = load_models()
    prompt = _build_prompt(question, _fit_context(tokenizer, question, context, max_input_tokens))
    # truncation stays as a safety net; the context budget above should make it a no-op.
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=max_input_tokens)
    with torch.inference_mode():
        # Deterministic beam search. temperature/top_p were removed: they only
        # apply when do_sample=True and were silently ignored here.
        outputs = model.generate(
            **inputs,
            max_length=max_length,
            num_beams=4,
            repetition_penalty=1.3,
            length_penalty=1.2,
            early_stopping=True,
        )
    answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # Clear GPU memory if possible
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    if not answer or answer.isspace() or "cannot find" in answer.lower():
        return _NO_INFO_ANSWER
    return _format_answer(answer)
# Streamlit App
st.title("🧩 AMA Autism")
st.write("This app searches through scientific papers to answer your questions about autism. For best results, be specific in your questions.")
query = st.text_input("Please ask me anything about autism ✨")

if query:
    # Stays None when no papers are found and generation is skipped.
    answer = None
    with st.status("Searching for answers...") as status:
        # Announce each step before running it so the status log shows what
        # is actually in progress.
        st.write("Searching for data in PubMed and arXiv...")
        df = load_dataset(query)
        if not df.empty:
            st.write("Data found!")
            # Context = first 1000 characters of each of the first three rows.
            context = "\n".join(f"{text[:1000]}" for text in df['text'].head(3))
            st.write("Generating answer...")
            answer = generate_answer(query, context)
        status.update(
            label="Search complete!", state="complete", expanded=False
        )

    if df.empty:
        st.warning("I couldn't find any relevant research papers about this topic. Please try rephrasing your question or ask something else about autism.")
    # generate_answer returns a fixed "I cannot find ..." message when it has
    # nothing useful; treat that as "no answer" instead of a success.
    elif answer and not answer.isspace() and "cannot find" not in answer.lower():
        st.success("Answer found!")
        st.write(answer)
        st.write("### Sources used:")
        for _, row in df.head(3).iterrows():
            st.markdown(f"**[{row['title']}]({row['url']})** ({row['published']})")
            st.write(f"**Summary:** {row['text'][:200]}...")
            st.write("---")
    else:
        st.warning("I couldn't find a specific answer in the research papers. Try rephrasing your question.")