import streamlit as st
from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration
import faiss
import os
from datasets import load_from_disk
import torch
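
# Streamlit app: the user types a question, we load (or build) a local
# FAISS-backed dataset of arXiv papers, and a RAG model retrieves passages
# from it and generates an answer. Run locally with: streamlit run app.py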
# Title
st.title("🧩 AMA Autism")
# Input: Query
query = st.text_input("Please ask me anything about autism ✨")
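
# st.text_input returns an empty string until the user enters a value, so the
# `if query:` block at the bottom only runs once a question has been asked.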
# Load or create RAG dataset
def load_rag_dataset(dataset_dir="rag_dataset"):
    if not os.path.exists(dataset_dir):
        # Import the build function from the companion module
        import faiss_index.index as faiss_index_index
        # Fetch some initial papers to build the index
        initial_papers = faiss_index_index.fetch_arxiv_papers("autism research", max_results=100)
        dataset_dir = faiss_index_index.build_faiss_index(initial_papers, dataset_dir)
    # Load the saved dataset and FAISS index from disk
    dataset = load_from_disk(os.path.join(dataset_dir, "dataset"))
    index = faiss.read_index(os.path.join(dataset_dir, "embeddings.faiss"))
    return dataset, index
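
# Assumed on-disk layout produced by faiss_index.index.build_faiss_index
# (that module is not shown here): a Hugging Face dataset saved under
# rag_dataset/dataset and a serialized FAISS index at
# rag_dataset/embeddings.faiss -- the same paths the retriever below reads.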
# RAG Pipeline
def rag_pipeline(query, dataset, index):
    # Load the pre-trained RAG model and configure its retriever
    model_name = "facebook/rag-sequence-nq"
    tokenizer = RagTokenizer.from_pretrained(model_name)
    # Configure the retriever with the local passages and FAISS index paths
    retriever = RagRetriever.from_pretrained(
        model_name,
        index_name="custom",
        passages_path=os.path.join("rag_dataset", "dataset"),
        index_path=os.path.join("rag_dataset", "embeddings.faiss"),
        use_dummy_dataset=False
    )
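    # With index_name="custom", RagRetriever loads the passages dataset and
    # FAISS index from disk itself, so the `dataset`/`index` arguments above
    # are not used here (they feed the UI below). The saved dataset is
    # expected to carry "title", "text", and precomputed "embeddings" columns.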
    # Initialize the model with the configured retriever
    model = RagSequenceForGeneration.from_pretrained(model_name, retriever=retriever)
    # Generate an answer using RAG
    inputs = tokenizer(query, return_tensors="pt")
    with torch.no_grad():
        generated_ids = model.generate(inputs["input_ids"], max_length=200)
    answer = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return answer
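
# Note: facebook/rag-sequence-nq pairs a DPR question encoder with a BART
# generator; generate() embeds the query, pulls the top passages from the
# FAISS index, and conditions the generator on them.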
# Run the app
if query:
    with st.status("Looking for data in the best sources...", expanded=True) as status:
        st.write("Still looking... this may take a while as we look at some prestigious papers...")
        dataset, index = load_rag_dataset()
        st.write("Found the best sources!")
        status.update(
            label="Download complete!",
            state="complete",
            expanded=False
        )
    answer = rag_pipeline(query, dataset, index)
    st.write("### Answer:")
    st.write(answer)
    st.write("### Retrieved Papers:")
    for i in range(min(5, len(dataset))):
        st.write(f"**Title:** {dataset[i]['title']}")
        st.write(f"**Summary:** {dataset[i]['text'][:200]}...")
        st.write("---")