# ama-autism / app.py
# Source: Hugging Face Space "wakeupmh/ama-autism", commit 5a09d5c ("fix: add cache").
import streamlit as st
from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration
import faiss
import os
from datasets import load_from_disk
import torch
import logging
import warnings
# Configure logging: only WARNING and above reach the console, and library
# warnings (e.g. transformers/datasets deprecation notices) are suppressed
# so they do not clutter the Streamlit UI.
logging.basicConfig(level=logging.WARNING)
warnings.filterwarnings('ignore')
# Page title (fixes user-facing typo: "Austim" -> "Autism")
st.title("🧩 AMA Autism")

# Input: free-text question from the user; empty until they type something.
query = st.text_input("Please ask me anything about autism ✨")
@st.cache_resource
def load_rag_components(model_name="facebook/rag-sequence-nq"):
    """Load the RAG tokenizer, retriever, and model exactly once.

    Streamlit's ``cache_resource`` keeps these heavyweight objects alive
    across app reruns, so repeated queries do not re-download or reload them.

    Args:
        model_name: Hugging Face model identifier for the RAG checkpoint.

    Returns:
        A ``(tokenizer, retriever, model)`` tuple for *model_name*.
    """
    rag_tokenizer = RagTokenizer.from_pretrained(model_name)
    rag_retriever = RagRetriever.from_pretrained(
        model_name,
        index_name="custom",
        use_dummy_dataset=True,  # the real dataset/index are attached later
    )
    rag_model = RagSequenceForGeneration.from_pretrained(model_name)
    return rag_tokenizer, rag_retriever, rag_model
# Load or create RAG dataset
def load_rag_dataset(dataset_dir="rag_dataset"):
    """Load the paper dataset and its FAISS index, building both on first run.

    When *dataset_dir* does not exist yet, arXiv papers on autism research
    are fetched and indexed (a slow, one-time step shown behind a spinner).

    Args:
        dataset_dir: Directory holding the saved dataset and FAISS index.

    Returns:
        A ``(dataset, index)`` tuple: the HF dataset and its FAISS index.
    """
    if not os.path.exists(dataset_dir):
        with st.spinner("Building initial dataset from autism research papers..."):
            import faiss_index.index as faiss_index_index

            papers = faiss_index_index.fetch_arxiv_papers("autism research", max_results=100)
            dataset_dir = faiss_index_index.build_faiss_index(papers, dataset_dir)

    # Load the dataset and index from their conventional locations.
    dataset_path = os.path.join(dataset_dir, "dataset")
    index_path = os.path.join(dataset_dir, "embeddings.faiss")
    return load_from_disk(dataset_path), faiss.read_index(index_path)
# RAG Pipeline
def rag_pipeline(query, dataset, index):
    """Answer *query* with retrieval-augmented generation over *dataset*.

    Args:
        query: User question string.
        dataset: HF dataset of papers backing the retriever.
        index: FAISS index aligned with *dataset*.

    Returns:
        The generated answer string, or ``None`` if anything failed (the
        error is shown in the Streamlit UI instead of being raised).
    """
    try:
        # Load cached components
        tokenizer, retriever, model = load_rag_components()

        # Point the cached retriever at our real dataset/index instead of
        # the dummy ones it was created with.
        retriever.index.dataset = dataset
        retriever.index.index = index
        # Bug fix: `model.retriever = retriever` only set an attribute the
        # model never reads; set_retriever() wires the retriever into the
        # underlying RagModel so generate() actually retrieves documents.
        model.set_retriever(retriever)

        # Generate answer
        inputs = tokenizer(query, return_tensors="pt", max_length=512, truncation=True)
        with torch.no_grad():
            generated_ids = model.generate(
                inputs["input_ids"],
                attention_mask=inputs.get("attention_mask"),  # avoid attending to padding
                max_length=200,
                min_length=50,
                num_beams=5,
                early_stopping=True,
            )
        return tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    except Exception as e:
        # Top-level boundary: surface the failure in the UI rather than crash.
        st.error(f"An error occurred while processing your query: {str(e)}")
        return None
# Run the app
if query:
    with st.status("Looking for data in the best sources...", expanded=True) as status:
        # Bug fix: st.write_stream expects a generator/stream; handing it a
        # plain string streams it character-by-character. Use st.write for
        # these static status messages.
        st.write("Still looking... this may take a while as we look at some prestigious papers...")
        dataset, index = load_rag_dataset()
        st.write("Found the best sources!")
        answer = rag_pipeline(query, dataset, index)
        st.write("Now answering your question...")
        status.update(
            label="Searching complete!",
            state="complete",
            expanded=False,
        )

    if answer:
        st.write("### Answer:")
        st.write(answer)
        st.write("### Retrieved Papers:")
        # Show up to five retrieved papers with a short excerpt each.
        for i in range(min(5, len(dataset))):
            st.write(f"**Title:** {dataset[i]['title']}")
            st.write(f"**Summary:** {dataset[i]['text'][:200]}...")
            st.write("---")