import streamlit as st
from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration
import faiss  # required at runtime by the retriever's custom FAISS index
from datasets import load_from_disk
import torch
import logging

# Configure logging
logging.basicConfig(level=logging.INFO)

# Cache heavyweight resources so they load once per server process
@st.cache_resource  # keep models in memory across reruns
def load_models():
    tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq")
    retriever = RagRetriever.from_pretrained(
        "facebook/rag-sequence-nq",
        index_name="custom",
        passages_path="/data/rag_dataset/dataset",
        index_path="/data/rag_dataset/embeddings.faiss"
    )
    model = RagSequenceForGeneration.from_pretrained(
        "facebook/rag-sequence-nq",
        retriever=retriever,
        device_map="auto"
    )
    return tokenizer, retriever, model
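
# NOTE: the retriever above expects two artifacts produced offline: the Arrow
# passage dataset and a matching FAISS index. An illustrative build sketch is
# at the bottom of this file.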

@st.cache_resource  # the Arrow dataset is memory-mapped; cache the handle itself
def load_dataset():
    # Not needed for generation (the retriever reads the same files); kept for
    # inspecting the passage store.
    return load_from_disk("/data/rag_dataset/dataset")

# RAG pipeline: the retriever attached to the model fetches supporting
# passages inside generate(), so only the query string is needed here.
def rag_pipeline(query):
    tokenizer, retriever, model = load_models()
    logging.info("Generating answer for query: %s", query)
    inputs = tokenizer(query, return_tensors="pt", max_length=512, truncation=True)
    input_ids = inputs["input_ids"].to(model.device)  # match the model's device
    with torch.no_grad():
        outputs = model.generate(
            input_ids,
            max_length=200,
            min_length=50,
            num_beams=5,
            early_stopping=True,
            no_repeat_ngram_size=3,
        )
    return tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]

# Streamlit App
st.title("🧩 AMA Autism")
query = st.text_input("Please ask me anything about autism ✨")

if query:
    with st.status("Searching for answers..."):
        answer = rag_pipeline(query)
        if answer:
            st.success("Answer found!")
            st.write(answer)
        else:
            st.error("Failed to generate an answer.")