# ama-autism / app.py
import streamlit as st
from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration
import faiss
import os
from datasets import load_from_disk
import torch
import logging
# Configure logging
logging.basicConfig(level=logging.INFO)
# Cache models and dataset
@st.cache_resource  # cache in memory so the heavy models load once per session
def load_models(
    passages_path: str = "/data/rag_dataset/dataset",
    index_path: str = "/data/rag_dataset/embeddings.faiss",
):
    """Load the RAG tokenizer, retriever, and generator model.

    Args:
        passages_path: Directory holding the saved passages dataset that
            backs the retriever. Defaults to the app's prebuilt location.
        index_path: Path to the prebuilt FAISS index over the passages.

    Returns:
        A ``(tokenizer, retriever, model)`` tuple ready for generation.
    """
    tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq")
    retriever = RagRetriever.from_pretrained(
        "facebook/rag-sequence-nq",
        index_name="custom",  # use the local custom index, not the hosted one
        passages_path=passages_path,
        index_path=index_path,
    )
    model = RagSequenceForGeneration.from_pretrained(
        "facebook/rag-sequence-nq",
        retriever=retriever,
        device_map="auto",  # place model layers on available devices automatically
    )
    return tokenizer, retriever, model
@st.cache_data  # cache the loaded dataset across Streamlit reruns
def load_dataset(dataset_path: str = "/data/rag_dataset/dataset"):
    """Load the saved passages dataset from disk.

    Args:
        dataset_path: Directory produced by ``Dataset.save_to_disk``.
            Defaults to the app's prebuilt location.

    Returns:
        The loaded ``datasets`` dataset object.
    """
    return load_from_disk(dataset_path)
# RAG Pipeline
def rag_pipeline(query, dataset=None, index=None):
    """Generate an answer for ``query`` using the cached RAG model.

    Args:
        query: Natural-language question to answer.
        dataset: Unused — retrieval is handled internally by the cached
            retriever built in ``load_models``. Kept (now optional) for
            backward compatibility with existing callers.
        index: Unused — the FAISS index is loaded inside ``load_models``.
            Kept (now optional) for backward compatibility.

    Returns:
        The generated answer string.
    """
    tokenizer, _retriever, model = load_models()  # retriever is wired into the model
    inputs = tokenizer(query, return_tensors="pt", max_length=512, truncation=True)
    with torch.no_grad():  # inference only — no gradients needed
        outputs = model.generate(
            inputs["input_ids"],
            # fix: the attention mask was being dropped, so padding tokens
            # would have been attended to; pass it through to generate()
            attention_mask=inputs.get("attention_mask"),
            max_length=200,
            min_length=50,
            num_beams=5,
            early_stopping=True,
            no_repeat_ngram_size=3,
        )
    # batch_decode returns a list; only one sequence was generated
    return tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
# Streamlit App — page layout and query handling
st.title("🧩 AMA Autism")

user_query = st.text_input("Please ask me anything about autism ✨")

if user_query:
    # Show progress while the (possibly slow) model call runs
    with st.status("Searching for answers..."):
        passages = load_dataset()
        result = rag_pipeline(user_query, passages, index=None)
    if result:
        st.success("Answer found!")
        st.write(result)
    else:
        st.error("Failed to generate an answer.")