# AMA Autism — Streamlit RAG question-answering app (Hugging Face Spaces)
# Standard library
import logging
import os

# Third-party
import faiss
import streamlit as st
import torch
from datasets import load_from_disk
from transformers import RagRetriever, RagSequenceForGeneration, RagTokenizer

# Emit INFO-level logs (model/retriever loading progress) to the console.
logging.basicConfig(level=logging.INFO)
# Cache models in memory: without this decorator the comment's promise was
# unfulfilled and the models were re-loaded on every Streamlit rerun/query.
@st.cache_resource
def load_models():
    """Load and cache the RAG tokenizer, retriever, and generator model.

    The retriever is backed by the custom FAISS index and passage dataset
    stored under /data/rag_dataset.

    Returns:
        tuple: (tokenizer, retriever, model).
    """
    tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq")
    retriever = RagRetriever.from_pretrained(
        "facebook/rag-sequence-nq",
        index_name="custom",
        passages_path="/data/rag_dataset/dataset",
        index_path="/data/rag_dataset/embeddings.faiss",
    )
    model = RagSequenceForGeneration.from_pretrained(
        "facebook/rag-sequence-nq",
        retriever=retriever,
        device_map="auto",  # place model weights automatically across devices
    )
    return tokenizer, retriever, model
# Cache the dataset handle: without this decorator the comment's promise was
# unfulfilled and the dataset was re-opened on every Streamlit rerun.
@st.cache_resource
def load_dataset():
    """Return the cached passages dataset loaded from /data/rag_dataset/dataset."""
    return load_from_disk("/data/rag_dataset/dataset")
# RAG Pipeline
def rag_pipeline(query, dataset, index):
    """Generate an answer to *query* with the cached RAG model.

    Retrieval is handled internally by the model's retriever; the
    ``dataset`` and ``index`` parameters are unused and kept only for
    backward compatibility with existing callers.

    Args:
        query: User question as a plain string.
        dataset: Unused (see above).
        index: Unused (see above).

    Returns:
        str: The generated answer text.
    """
    tokenizer, _retriever, model = load_models()
    inputs = tokenizer(query, return_tensors="pt", max_length=512, truncation=True)
    with torch.no_grad():  # inference only — no gradients needed
        outputs = model.generate(
            inputs["input_ids"],
            # Pass the attention mask so padding tokens are not attended to.
            attention_mask=inputs.get("attention_mask"),
            max_length=200,
            min_length=50,
            num_beams=5,
            early_stopping=True,
            no_repeat_ngram_size=3,
        )
    return tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
# Streamlit App
st.title("🧩 AMA Autism")
query = st.text_input("Please ask me anything about autism ✨")
if query:
    with st.status("Searching for answers..."):
        dataset = load_dataset()
        try:
            answer = rag_pipeline(query, dataset, index=None)
        except Exception:
            # Surface failures via the existing error path instead of
            # crashing the app; keep the traceback in the server logs.
            logging.exception("RAG pipeline failed for query: %s", query)
            answer = None
    if answer:
        st.success("Answer found!")
        st.write(answer)
    else:
        st.error("Failed to generate an answer.")