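# Streamlit app: an ask-me-anything assistant for autism questions. Answers are
# produced with a RAG (retrieval-augmented generation) pipeline: arXiv papers are
# indexed with FAISS and a facebook/rag-sequence-nq model generates responses
# grounded in the retrieved passages.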
import streamlit as st
from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration
import faiss
import os
from datasets import load_from_disk
import torch
import logging
import warnings

# Configure logging
logging.basicConfig(level=logging.WARNING)
warnings.filterwarnings('ignore')
# Title
st.title("🧩 AMA Autism")

# Input: Query
query = st.text_input("Please ask me anything about autism ✨")
@st.cache_resource  # cache the heavy tokenizer/retriever/model objects across reruns
def load_rag_components(model_name="facebook/rag-sequence-nq"):
    """Load and cache RAG components to avoid reloading."""
    tokenizer = RagTokenizer.from_pretrained(model_name)
    retriever = RagRetriever.from_pretrained(
        model_name,
        index_name="custom",
        use_dummy_dataset=True  # We'll configure the dataset later
    )
    model = RagSequenceForGeneration.from_pretrained(model_name)
    return tokenizer, retriever, model
# Load or create RAG dataset
def load_rag_dataset(dataset_dir="rag_dataset"):
    if not os.path.exists(dataset_dir):
        with st.spinner("Building initial dataset from autism research papers..."):
            import faiss_index.index as faiss_index_index
            initial_papers = faiss_index_index.fetch_arxiv_papers("autism research", max_results=100)
            dataset_dir = faiss_index_index.build_faiss_index(initial_papers, dataset_dir)

    # Load the dataset and index
    dataset = load_from_disk(os.path.join(dataset_dir, "dataset"))
    index = faiss.read_index(os.path.join(dataset_dir, "embeddings.faiss"))
    return dataset, index
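# Note: `faiss_index.index` is a local helper module that is not shown here. From
# the calls above it is assumed to expose `fetch_arxiv_papers(query, max_results)`,
# returning paper records with at least 'title' and 'text' fields, and
# `build_faiss_index(papers, dataset_dir)`, which is assumed to save a `datasets`
# dataset under `<dataset_dir>/dataset` and a FAISS index at
# `<dataset_dir>/embeddings.faiss`, matching the paths read in load_rag_dataset.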
# RAG Pipeline
def rag_pipeline(query, dataset, index):
    try:
        # Load cached components
        tokenizer, retriever, model = load_rag_components()

        # Configure retriever with our dataset
        retriever.index.dataset = dataset
        retriever.index.index = index
        model.retriever = retriever
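        # Note: wiring the custom dataset/index in by assigning to
        # retriever.index.dataset and retriever.index.index relies on the internal
        # layout of transformers' RagRetriever rather than a documented API, and
        # may need adjusting across transformers versions.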
        # Generate answer
        inputs = tokenizer(query, return_tensors="pt", max_length=512, truncation=True)
        with torch.no_grad():
            generated_ids = model.generate(
                inputs["input_ids"],
                max_length=200,
                min_length=50,
                num_beams=5,
                early_stopping=True
            )
        answer = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
        return answer
    except Exception as e:
        st.error(f"An error occurred while processing your query: {str(e)}")
        return None
# Run the app
if query:
    with st.status("Looking for data in the best sources...", expanded=True) as status:
        st.write("Still looking... this may take a while as we look at some prestigious papers...")
        dataset, index = load_rag_dataset()
        st.write("Found the best sources!")
        answer = rag_pipeline(query, dataset, index)
        st.write("Now answering your question...")
        status.update(
            label="Searching complete!",
            state="complete",
            expanded=False
        )

    if answer:
        st.write("### Answer:")
        st.write(answer)
        st.write("### Retrieved Papers:")
        for i in range(min(5, len(dataset))):
            st.write(f"**Title:** {dataset[i]['title']}")
            st.write(f"**Summary:** {dataset[i]['text'][:200]}...")
            st.write("---")