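"""Streamlit app for asking questions about autism research.

Answers are generated with a RAG model (facebook/rag-sequence-nq) over a
FAISS index of arXiv papers built by the local faiss_index package.
"""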
import streamlit as st
from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration
import faiss
import os
from datasets import load_from_disk
import torch
import logging
import warnings
# Configure logging
logging.basicConfig(level=logging.WARNING)
warnings.filterwarnings('ignore')
# Title
st.title("🧩 AMA Austim")
# Input: Query
query = st.text_input("Please ask me anything about autism ✨")
@st.cache_resource
def load_rag_components(model_name="facebook/rag-sequence-nq"):
    """Load and cache the RAG tokenizer, retriever, and model."""
    tokenizer = RagTokenizer.from_pretrained(model_name)
    retriever = RagRetriever.from_pretrained(
        model_name,
        index_name="exact",  # "custom" requires passages_path/index_path; "exact" works with the dummy dataset
        use_dummy_dataset=True  # placeholder index; the real dataset is attached in rag_pipeline()
    )
    model = RagSequenceForGeneration.from_pretrained(model_name)
    return tokenizer, retriever, model
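# An alternative, documented way to wire a custom corpus is to build the
# retriever directly from a dataset that already carries a FAISS index
# (a sketch, not what this app does; the app swaps the dataset in at query time):
#
#     dataset.load_faiss_index("embeddings", "rag_dataset/embeddings.faiss")
#     retriever = RagRetriever.from_pretrained(
#         "facebook/rag-sequence-nq", index_name="custom", indexed_dataset=dataset
#     )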
# Load or create the RAG dataset
def load_rag_dataset(dataset_dir="rag_dataset"):
    if not os.path.exists(dataset_dir):
        with st.spinner("Building initial dataset from autism research papers..."):
            import faiss_index.index as faiss_index_index
            initial_papers = faiss_index_index.fetch_arxiv_papers("autism research", max_results=100)
            dataset_dir = faiss_index_index.build_faiss_index(initial_papers, dataset_dir)
    # Load the dataset and its FAISS index from disk
    dataset = load_from_disk(os.path.join(dataset_dir, "dataset"))
    index = faiss.read_index(os.path.join(dataset_dir, "embeddings.faiss"))
    return dataset, index
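# Assumed on-disk layout produced by faiss_index.index.build_faiss_index:
#   rag_dataset/dataset/           HF dataset with "title", "text", and "embeddings" columns
#   rag_dataset/embeddings.faiss   FAISS index built over the "embeddings" column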
# RAG pipeline
def rag_pipeline(query, dataset, index):
    try:
        # Load cached components
        tokenizer, retriever, model = load_rag_components()
        # Attach the FAISS index to the dataset and point the retriever at it;
        # the retriever searches the dataset's "embeddings" index internally
        if "embeddings" not in dataset.list_indexes():
            dataset.add_faiss_index("embeddings", custom_index=index)
        retriever.index.dataset = dataset
        model.set_retriever(retriever)
        # Generate an answer
        inputs = tokenizer(query, return_tensors="pt", max_length=512, truncation=True)
        with torch.no_grad():
            generated_ids = model.generate(
                inputs["input_ids"],
                max_length=200,
                min_length=50,
                num_beams=5,
                early_stopping=True
            )
        answer = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
        return answer
    except Exception as e:
        st.error(f"An error occurred while processing your query: {str(e)}")
        return None
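# Quick sanity check outside Streamlit (assumes rag_dataset/ already exists):
#     ds, idx = load_rag_dataset()
#     print(rag_pipeline("What are early signs of autism?", ds, idx))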
# Run the app
if query:
    with st.status("Looking for data in the best sources...", expanded=True) as status:
        st.write("Still looking... this may take a while as we look at some prestigious papers...")
        dataset, index = load_rag_dataset()
        st.write("Found the best sources!")
        answer = rag_pipeline(query, dataset, index)
        st.write("Now answering your question...")
        status.update(
            label="Search complete!",
            state="complete",
            expanded=False
        )
    if answer:
        st.write("### Answer:")
        st.write(answer)
        st.write("### Retrieved Papers:")
        for i in range(min(5, len(dataset))):
            st.write(f"**Title:** {dataset[i]['title']}")
            st.write(f"**Summary:** {dataset[i]['text'][:200]}...")
            st.write("---")