import streamlit as st
from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration
import faiss  # noqa: F401 -- not called directly, but must be installed for the dataset's FAISS index to load
import os
from datasets import load_from_disk
import torch
import logging
import warnings

# Configure logging
logging.basicConfig(level=logging.WARNING)
warnings.filterwarnings('ignore')

# Title
st.title("🧩 AMA Austim")

# Input: Query
query = st.text_input("Please ask me anything about autism ✨")

@st.cache_resource
def load_rag_components(model_name="facebook/rag-sequence-nq"):
    """Load and cache the tokenizer and generator model to avoid reloading."""
    tokenizer = RagTokenizer.from_pretrained(model_name)
    model = RagSequenceForGeneration.from_pretrained(model_name)
    # The retriever is built per query in rag_pipeline(), because it has to be
    # constructed around the dataset and index loaded at runtime.
    return tokenizer, model

# Load or create the RAG dataset: a datasets.Dataset with "title", "text",
# and "embeddings" columns, plus a FAISS index over "embeddings"
# (see the reference sketch of faiss_index/index.py at the bottom of this file)
def load_rag_dataset(dataset_dir="rag_dataset"):
    if not os.path.exists(dataset_dir):
        with st.spinner("Building initial dataset from autism research papers..."):
            import faiss_index.index as faiss_index_index
            initial_papers = faiss_index_index.fetch_arxiv_papers("autism research", max_results=100)
            dataset_dir = faiss_index_index.build_faiss_index(initial_papers, dataset_dir)
    
    # Load the dataset and attach its FAISS index so the retriever can query it
    dataset = load_from_disk(os.path.join(dataset_dir, "dataset"))
    dataset.load_faiss_index("embeddings", os.path.join(dataset_dir, "embeddings.faiss"))
    
    return dataset

# RAG pipeline: retrieve passages from the indexed dataset, then generate an answer
def rag_pipeline(query, dataset):
    try:
        # Load cached components
        tokenizer, model = load_rag_components()
        
        # Build a retriever around our own indexed dataset and attach it to the
        # model. `indexed_dataset` expects a datasets.Dataset with a FAISS index
        # loaded on its "embeddings" column, as prepared by load_rag_dataset().
        retriever = RagRetriever.from_pretrained(
            "facebook/rag-sequence-nq",
            index_name="custom",
            indexed_dataset=dataset,
        )
        model.set_retriever(retriever)

        # Tokenize the query and generate an answer
        inputs = tokenizer(query, return_tensors="pt", max_length=512, truncation=True)
        with torch.no_grad():
            generated_ids = model.generate(
                inputs["input_ids"],
                max_length=200,
                min_length=50,
                num_beams=5,
                early_stopping=True,
            )
        answer = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
        
        return answer
    
    except Exception as e:
        st.error(f"An error occurred while processing your query: {e}")
        return None

# Run the app
if query:
    with st.status("Looking for data in the best sources...", expanded=True) as status:
        # st.write_stream expects a generator; plain status messages use st.write
        st.write("Still looking... this may take a while as we look at some prestigious papers...")
        dataset = load_rag_dataset()
        st.write("Found the best sources!")
        answer = rag_pipeline(query, dataset)
        st.write("Now answering your question...")
        status.update(
            label="Searching complete!", 
            state="complete", 
            expanded=False
        )
    
    if answer:
        st.write("### Answer:")
        st.write(answer)
        # Note: this shows the first few papers in the dataset as examples,
        # not the passages actually retrieved for this particular query.
        st.write("### Example Papers from the Dataset:")
        for i in range(min(5, len(dataset))):
            st.write(f"**Title:** {dataset[i]['title']}")
            st.write(f"**Summary:** {dataset[i]['text'][:200]}...")
            st.write("---")
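
# ---------------------------------------------------------------------------
# Reference sketch (an assumption, not the original code): a minimal version
# of the faiss_index/index.py helper imported above, kept commented out so
# this file stays self-contained. Only the two function names, their call
# signatures, and the on-disk layout this app reads back (<dataset_dir>/dataset
# plus <dataset_dir>/embeddings.faiss, with "title", "text", and "embeddings"
# columns) are inferred from the calls in this app; the `arxiv` client and the
# DPR context encoder below are stand-ins chosen to match RAG's retriever.
#
# import os
# import arxiv
# from datasets import Dataset
# from transformers import DPRContextEncoder, DPRContextEncoderTokenizerFast
#
# def fetch_arxiv_papers(query, max_results=100):
#     """Return a list of {"title", "text"} dicts from an arXiv search."""
#     search = arxiv.Search(query=query, max_results=max_results)
#     return [{"title": r.title, "text": r.summary} for r in search.results()]
#
# def build_faiss_index(papers, dataset_dir="rag_dataset"):
#     """Embed the papers, save the dataset and its FAISS index, return dataset_dir."""
#     tok = DPRContextEncoderTokenizerFast.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
#     enc = DPRContextEncoder.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
#
#     def embed(batch):
#         inputs = tok(batch["text"], truncation=True, padding=True, return_tensors="pt")
#         return {"embeddings": enc(**inputs).pooler_output.detach().numpy()}
#
#     dataset = Dataset.from_list(papers).map(embed, batched=True, batch_size=16)
#     os.makedirs(dataset_dir, exist_ok=True)
#     dataset.add_faiss_index(column="embeddings")
#     dataset.save_faiss_index("embeddings", os.path.join(dataset_dir, "embeddings.faiss"))
#     dataset.drop_index("embeddings")  # attached indexes can't go through save_to_disk
#     dataset.save_to_disk(os.path.join(dataset_dir, "dataset"))
#     return dataset_dir
# ---------------------------------------------------------------------------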