fix: add cache
app.py CHANGED
@@ -1,9 +1,15 @@
 import streamlit as st
 from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration
 import faiss
 import os
 from datasets import load_from_disk
 import torch
+import logging
+import warnings
+
+# Configure logging
+logging.basicConfig(level=logging.WARNING)
+warnings.filterwarnings('ignore')
 
 # Title
 st.title("🧩 AMA Austim")
@@ -11,15 +17,25 @@ st.title("🧩 AMA Austim")
 # Input: Query
 query = st.text_input("Please ask me anything about autism ✨")
 
+@st.cache_resource
+def load_rag_components(model_name="facebook/rag-sequence-nq"):
+    """Load and cache RAG components to avoid reloading."""
+    tokenizer = RagTokenizer.from_pretrained(model_name)
+    retriever = RagRetriever.from_pretrained(
+        model_name,
+        index_name="custom",
+        use_dummy_dataset=True  # We'll configure the dataset later
+    )
+    model = RagSequenceForGeneration.from_pretrained(model_name)
+    return tokenizer, retriever, model
+
 # Load or create RAG dataset
 def load_rag_dataset(dataset_dir="rag_dataset"):
     if not os.path.exists(dataset_dir):
-
-
-
-
-        initial_papers = faiss_index_index.fetch_arxiv_papers("autism research", max_results=100)
-        dataset_dir = faiss_index_index.build_faiss_index(initial_papers, dataset_dir)
+        with st.spinner("Building initial dataset from autism research papers..."):
+            import faiss_index.index as faiss_index_index
+            initial_papers = faiss_index_index.fetch_arxiv_papers("autism research", max_results=100)
+            dataset_dir = faiss_index_index.build_faiss_index(initial_papers, dataset_dir)
 
     # Load the dataset and index
     dataset = load_from_disk(os.path.join(dataset_dir, "dataset"))
@@ -29,46 +45,52 @@ def load_rag_dataset(dataset_dir="rag_dataset"):
 
 # RAG Pipeline
 def rag_pipeline(query, dataset, index):
-
-
-
-
-
-
-
-
-        passages_path=os.path.join("rag_dataset", "dataset"),
-        index_path=os.path.join("rag_dataset", "embeddings.faiss"),
-        use_dummy_dataset=False
-    )
-
-    # Initialize the model with the configured retriever
-    model = RagSequenceForGeneration.from_pretrained(model_name, retriever=retriever)
-
-    # Generate answer using RAG
-    inputs = tokenizer(query, return_tensors="pt")
-    with torch.no_grad():
-        generated_ids = model.generate(inputs["input_ids"], max_length=200)
-    answer = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
+    try:
+        # Load cached components
+        tokenizer, retriever, model = load_rag_components()
+
+        # Configure retriever with our dataset
+        retriever.index.dataset = dataset
+        retriever.index.index = index
+        model.retriever = retriever
 
-
+        # Generate answer
+        inputs = tokenizer(query, return_tensors="pt", max_length=512, truncation=True)
+        with torch.no_grad():
+            generated_ids = model.generate(
+                inputs["input_ids"],
+                max_length=200,
+                min_length=50,
+                num_beams=5,
+                early_stopping=True
+            )
+        answer = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
+
+        return answer
+
+    except Exception as e:
+        st.error(f"An error occurred while processing your query: {str(e)}")
+        return None
 
 # Run the app
 if query:
     with st.status("Looking for data in the best sources...", expanded=True) as status:
-        st.
+        st.write_stream("Still looking... this may take a while as we look at some prestigious papers...")
         dataset, index = load_rag_dataset()
-        st.
+        st.write_stream("Found the best sources!")
+        answer = rag_pipeline(query, dataset, index)
+        st.write_stream("Now answering your question...")
         status.update(
-            label="
+            label="Searching complete!",
             state="complete",
             expanded=False
         )
-
-
-
-
-
-
-
-
+
+    if answer:
+        st.write("### Answer:")
+        st.write_stream(answer)
+        st.write("### Retrieved Papers:")
+        for i in range(min(5, len(dataset))):
+            st.write(f"**Title:** {dataset[i]['title']}")
+            st.write(f"**Summary:** {dataset[i]['text'][:200]}...")
+            st.write("---")