Spaces:
Sleeping
Sleeping
import streamlit as st | |
from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration, DPRQuestionEncoder, DPRQuestionEncoderTokenizer | |
import faiss | |
import os | |
from datasets import load_from_disk | |
import torch | |
# Page title (spelling of "Autism" corrected in the user-facing heading)
st.title("🧩 AMA Autism")
# Free-text question box; the run section further down only executes when
# this value is truthy.
query = st.text_input("Please ask me anything about autism ✨")
# Load or create RAG dataset | |
def load_rag_dataset(dataset_dir="rag_dataset"):
    """Load the passages dataset and its FAISS index, building them if absent.

    When ``dataset_dir`` does not exist, papers are fetched with
    ``fetch_arxiv_papers("autism research", max_results=100)`` and handed to
    ``build_faiss_index``; the directory returned by that call is the one
    loaded from afterwards.

    Args:
        dataset_dir: Directory expected to contain a ``dataset`` folder and an
            ``embeddings.faiss`` file.

    Returns:
        A ``(dataset, index)`` tuple: the result of ``load_from_disk`` on
        ``<dataset_dir>/dataset`` and of ``faiss.read_index`` on
        ``<dataset_dir>/embeddings.faiss``.
    """
    already_built = os.path.exists(dataset_dir)
    if not already_built:
        # Imported only on the build path, so the builder module is not
        # needed when the dataset is already on disk.
        import faiss_index.index as index_builder

        seed_papers = index_builder.fetch_arxiv_papers("autism research", max_results=100)
        # The builder reports where it wrote its output; load from there.
        dataset_dir = index_builder.build_faiss_index(seed_papers, dataset_dir)

    passages_path = os.path.join(dataset_dir, "dataset")
    faiss_path = os.path.join(dataset_dir, "embeddings.faiss")
    return load_from_disk(passages_path), faiss.read_index(faiss_path)
# RAG Pipeline | |
@st.cache_resource(show_spinner=False)
def _load_rag_components(model_name, dataset_dir):
    """Build the RAG tokenizer and retriever-backed model once per process.

    Streamlit re-executes the script on every interaction; caching here keeps
    the pretrained weights and the retriever from being reloaded for every
    question. The cache is keyed on ``(model_name, dataset_dir)``.

    Args:
        model_name: Hugging Face model identifier for the RAG checkpoint.
        dataset_dir: Directory holding ``dataset`` (passages) and
            ``embeddings.faiss`` (index) for the custom retriever.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    tokenizer = RagTokenizer.from_pretrained(model_name)
    # Custom index: the retriever reads passages and the FAISS index from disk.
    retriever = RagRetriever.from_pretrained(
        model_name,
        index_name="custom",
        passages_path=os.path.join(dataset_dir, "dataset"),
        index_path=os.path.join(dataset_dir, "embeddings.faiss"),
        use_dummy_dataset=False,
    )
    model = RagSequenceForGeneration.from_pretrained(model_name, retriever=retriever)
    return tokenizer, model


def rag_pipeline(query, dataset, index, dataset_dir="rag_dataset"):
    """Answer ``query`` with the ``facebook/rag-sequence-nq`` RAG model.

    Args:
        query: The user's question as a string.
        dataset: Not read by this function; kept so existing callers keep
            working. The retriever loads its passages from ``dataset_dir``.
        index: Not read by this function; kept for the same reason. The
            retriever loads its FAISS index from ``dataset_dir``.
        dataset_dir: Directory containing ``dataset`` and ``embeddings.faiss``.
            Defaults to ``"rag_dataset"``, the location previously hard-coded
            here, so it can follow a non-default ``load_rag_dataset`` directory.

    Returns:
        The first decoded generation as a string, with special tokens removed.
    """
    model_name = "facebook/rag-sequence-nq"
    tokenizer, model = _load_rag_components(model_name, dataset_dir)

    inputs = tokenizer(query, return_tensors="pt")
    # Inference only: skip gradient tracking while generating.
    with torch.no_grad():
        generated_ids = model.generate(inputs["input_ids"], max_length=200)
    answer = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return answer
# Run the app: nothing below executes until the user has entered a question.
if query:
    # Progress container shown while the dataset/index is loaded (or built).
    with st.status("Looking for data in the best sources...", expanded=True) as status:
        st.write("Still looking... this may take a while as we look at some prestigious papers...")
        dataset, index = load_rag_dataset()
        st.write("Found the best sources!")
        # Mark the step finished and collapse the container.
        status.update(label="Download complete!", state="complete", expanded=False)

    answer = rag_pipeline(query, dataset, index)
    st.write("### Answer:")
    st.write(answer)

    st.write("### Retrieved Papers:")
    # NOTE(review): this lists the first rows of the dataset (at most 5), not
    # necessarily the passages the retriever selected for this query.
    preview_count = min(5, len(dataset))
    for row_idx in range(preview_count):
        paper = dataset[row_idx]
        st.write(f"**Title:** {paper['title']}")
        # Only the first 200 characters of the text are shown.
        st.write(f"**Summary:** {paper['text'][:200]}...")
        st.write("---")