ama-autism / app.py
wakeupmh's picture
fix: emoji
84e4514
raw
history blame
2.77 kB
import streamlit as st
from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration, DPRQuestionEncoder, DPRQuestionEncoderTokenizer
import faiss
import os
from datasets import load_from_disk
import torch
# Title
st.title("🧩 AMA Autism")
# Input: the user's free-text question about autism
query = st.text_input("Please ask me anything about autism ✨")
# Load or create RAG dataset
@st.cache_resource(show_spinner=False)
def load_rag_dataset(dataset_dir="rag_dataset"):
    """Load the passage dataset and its FAISS index, building them on first use.

    Streamlit re-executes this script on every widget interaction, so the result
    is cached with ``st.cache_resource``. The dataset and index are read from
    disk (or built) once per ``dataset_dir`` per server process, not on every
    rerun. Changes made to the files on disk afterwards are not picked up until
    the cache is cleared or the server restarts.

    Args:
        dataset_dir: Directory expected to contain ``dataset/`` (loaded with
            ``datasets.load_from_disk``) and ``embeddings.faiss``. If the
            directory does not exist, an initial index is built from an arXiv
            search for "autism research" (up to 100 papers).

    Returns:
        A ``(dataset, index)`` tuple: the object returned by
        ``datasets.load_from_disk`` and the index returned by
        ``faiss.read_index``.
    """
    if not os.path.exists(dataset_dir):
        # Imported lazily: the index-building module is only needed when the
        # dataset has to be created.
        import faiss_index.index as faiss_index_index

        # Fetch some initial papers to build the index.
        initial_papers = faiss_index_index.fetch_arxiv_papers("autism research", max_results=100)
        # NOTE(review): assumes build_faiss_index returns the directory it wrote
        # the dataset and index to -- confirm against faiss_index/index.py.
        dataset_dir = faiss_index_index.build_faiss_index(initial_papers, dataset_dir)

    dataset = load_from_disk(os.path.join(dataset_dir, "dataset"))
    index = faiss.read_index(os.path.join(dataset_dir, "embeddings.faiss"))
    return dataset, index
# RAG Pipeline
@st.cache_resource(show_spinner=False)
def _load_rag_components(model_name, dataset_dir):
    """Load and cache the RAG tokenizer and a model wired to the local index.

    Loading the checkpoint and the retrieval index is the slowest step of the
    pipeline. It is therefore done once per ``(model_name, dataset_dir)`` per
    server process instead of on every query. The cached objects are shared
    across Streamlit sessions.

    Returns:
        A ``(tokenizer, model)`` tuple; the retriever is attached to ``model``.
    """
    tokenizer = RagTokenizer.from_pretrained(model_name)
    # Point the retriever at our own passages and FAISS index instead of the
    # checkpoint's default index.
    retriever = RagRetriever.from_pretrained(
        model_name,
        index_name="custom",
        passages_path=os.path.join(dataset_dir, "dataset"),
        index_path=os.path.join(dataset_dir, "embeddings.faiss"),
        use_dummy_dataset=False,
    )
    model = RagSequenceForGeneration.from_pretrained(model_name, retriever=retriever)
    return tokenizer, model


def rag_pipeline(query, dataset, index, dataset_dir="rag_dataset", max_length=200):
    """Answer ``query`` with a RAG-Sequence model over the local paper index.

    Args:
        query: The user's question.
        dataset: Unused; kept for call-site compatibility. Retrieval reads the
            passages from ``dataset_dir`` on disk instead.
        index: Unused; kept for call-site compatibility (see ``dataset``).
        dataset_dir: Directory holding ``dataset/`` and ``embeddings.faiss``.
            Defaults to ``"rag_dataset"``, the path this function used to
            hard-code.
        max_length: Maximum length passed to ``model.generate``.

    Returns:
        The decoded text of the first generated sequence.
    """
    model_name = "facebook/rag-sequence-nq"
    tokenizer, model = _load_rag_components(model_name, dataset_dir)

    inputs = tokenizer(query, return_tensors="pt")
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        generated_ids = model.generate(inputs["input_ids"], max_length=max_length)
    return tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
# Run the app
# Whitespace-only input is treated as "no question" so it does not trigger the
# expensive load + generation below.
question = query.strip()
if question:
    with st.status("Looking for data in the best sources...", expanded=True) as status:
        st.write("Still looking... this may take a while as we look at some prestigious papers...")
        dataset, index = load_rag_dataset()
        st.write("Found the best sources!")
        status.update(
            label="Download complete!",
            state="complete",
            expanded=False,
        )

    answer = rag_pipeline(question, dataset, index)
    st.write("### Answer:")
    st.write(answer)

    # These are the first rows of the indexed dataset, not the passages the
    # retriever selected for this query, so the heading says exactly that.
    st.write("### Papers in the knowledge base:")
    for i in range(min(5, len(dataset))):
        paper = dataset[i]  # fetch each row once instead of once per field
        text = paper["text"]
        # Only mark the summary as truncated when it actually was.
        summary = text if len(text) <= 200 else text[:200] + "..."
        st.write(f"**Title:** {paper['title']}")
        st.write(f"**Summary:** {summary}")
        st.write("---")