File size: 2,768 Bytes
f1586e3
0f8445a
f1586e3
f91cc3b
 
0f8445a
f1586e3
 
84e4514
f1586e3
 
f91cc3b
f1586e3
f91cc3b
 
 
13a46cd
db03170
13a46cd
 
 
f91cc3b
13a46cd
f91cc3b
 
 
 
 
99637f2
f1586e3
f91cc3b
 
0f8445a
 
 
 
f91cc3b
0f8445a
f91cc3b
 
0f8445a
 
f91cc3b
0f8445a
 
 
f1586e3
 
 
0f8445a
 
 
f1586e3
99637f2
f1586e3
 
 
0f8445a
 
 
 
 
 
 
 
 
f91cc3b
f1586e3
 
f91cc3b
 
 
 
f1586e3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import streamlit as st
from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration, DPRQuestionEncoder, DPRQuestionEncoderTokenizer
import faiss
import os
from datasets import load_from_disk
import torch

# Title
st.title("🧩 AMA Autism")

# Input: the user's free-text question. Streamlit reruns this whole script on
# every interaction, so `query` holds the text box's current contents
# (empty until the user types something).
query = st.text_input("Please ask me anything about autism ✨")

# Load or create RAG dataset
@st.cache_resource(show_spinner=False)
def load_rag_dataset(dataset_dir="rag_dataset"):
    """Return the passage dataset and FAISS index stored under *dataset_dir*.

    If either artifact is missing, an index is built first from arXiv papers
    matching "autism research", using the ``faiss_index.index`` module.

    The result is cached with ``st.cache_resource``. Streamlit reruns this
    script on every widget interaction. Without caching, the dataset and index
    would be re-read from disk (or even rebuilt) on each rerun.

    Args:
        dataset_dir: Directory expected to contain a ``dataset`` sub-directory
            (loadable with ``datasets.load_from_disk``) and an
            ``embeddings.faiss`` file (loadable with ``faiss.read_index``).

    Returns:
        A ``(dataset, index)`` tuple.
    """
    dataset_path = os.path.join(dataset_dir, "dataset")
    index_path = os.path.join(dataset_dir, "embeddings.faiss")

    # Check for the artifacts themselves, not just the directory. An
    # interrupted build can leave the directory behind without its contents,
    # which would otherwise make every later load fail instead of rebuilding.
    if not (os.path.exists(dataset_path) and os.path.exists(index_path)):
        # Imported lazily: this module is only needed when an index must be built.
        import faiss_index.index as faiss_index_index

        # Fetch some initial papers to build the index.
        initial_papers = faiss_index_index.fetch_arxiv_papers("autism research", max_results=100)
        # The return value is used as the output directory. Presumably it is
        # the directory the build wrote to; confirm in faiss_index/index.py.
        dataset_dir = faiss_index_index.build_faiss_index(initial_papers, dataset_dir)
        dataset_path = os.path.join(dataset_dir, "dataset")
        index_path = os.path.join(dataset_dir, "embeddings.faiss")

    # Load the dataset and index.
    dataset = load_from_disk(dataset_path)
    index = faiss.read_index(index_path)

    return dataset, index

# RAG Pipeline
@st.cache_resource(show_spinner="Loading the RAG model...")
def _load_rag_components(model_name, dataset_dir):
    """Load the RAG tokenizer and the model wired to the local custom index.

    This is the most expensive step of answering a query. Streamlit reruns the
    script on every interaction, so the result is cached per
    ``(model_name, dataset_dir)`` pair.
    """
    tokenizer = RagTokenizer.from_pretrained(model_name)

    # Point the retriever at the locally built passages and FAISS index
    # instead of the checkpoint's default index.
    retriever = RagRetriever.from_pretrained(
        model_name,
        index_name="custom",
        passages_path=os.path.join(dataset_dir, "dataset"),
        index_path=os.path.join(dataset_dir, "embeddings.faiss"),
        use_dummy_dataset=False,
    )

    # Initialize the model with the configured retriever.
    model = RagSequenceForGeneration.from_pretrained(model_name, retriever=retriever)
    return tokenizer, model


def rag_pipeline(query, dataset, index, dataset_dir="rag_dataset", max_length=200):
    """Answer *query* with the pretrained RAG-Sequence model over the local index.

    Args:
        query: The user's question.
        dataset: Unused; kept for backward compatibility. The retriever reads
            the passages from ``dataset_dir`` itself.
        index: Unused; kept for backward compatibility (see ``dataset``).
        dataset_dir: Directory holding the ``dataset`` passages and the
            ``embeddings.faiss`` index. This should be the same directory
            given to ``load_rag_dataset``.
        max_length: Maximum length of the generated sequence, forwarded to
            ``generate``.

    Returns:
        The decoded answer string for the single input query.
    """
    model_name = "facebook/rag-sequence-nq"
    tokenizer, model = _load_rag_components(model_name, dataset_dir)

    # Generate answer using RAG.
    inputs = tokenizer(query, return_tensors="pt")
    with torch.no_grad():
        generated_ids = model.generate(
            inputs["input_ids"],
            attention_mask=inputs.get("attention_mask"),
            max_length=max_length,
        )
    # One query in, so take the single decoded sequence.
    return tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]

# Run the app
# Ignore empty and whitespace-only input so we don't run the model on nothing.
if query and query.strip():
    with st.status("Looking for data in the best sources...", expanded=True) as status:
        st.write("Still looking... this may take a while as we look at some prestigious papers...")
        dataset, index = load_rag_dataset()
        st.write("Found the best sources!")
        # Generation (model loading, retrieval and decoding) is the slow step.
        # Keep it inside the status container so the progress indicator stays
        # up until the answer actually exists.
        answer = rag_pipeline(query, dataset, index)
        status.update(
            label="Answer ready!",
            state="complete",
            expanded=False,
        )
    st.write("### Answer:")
    st.write(answer)

    # NOTE(review): these are the first rows of the indexed corpus, not the
    # passages the retriever selected for this query. Showing the real sources
    # would require rag_pipeline to surface the retriever's doc ids.
    st.write("### Sample of indexed papers:")
    for i in range(min(5, len(dataset))):
        row = dataset[i]  # fetch the row once rather than per field
        summary = row["text"]
        if len(summary) > 200:
            summary = summary[:200] + "..."
        st.write(f"**Title:** {row['title']}")
        st.write(f"**Summary:** {summary}")
        st.write("---")