wakeupmh committed
Commit 5a09d5c · 1 Parent(s): 84e4514

fix: add cache
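The cache in question is Streamlit's resource cache: decorating the loader with `@st.cache_resource` makes the heavyweight RAG tokenizer, retriever, and model load once per server process and be reused across reruns instead of being rebuilt for every query. A minimal sketch of the pattern (the `load_model` helper below is illustrative, not code from this commit):

```python
import streamlit as st
from transformers import RagSequenceForGeneration, RagTokenizer

@st.cache_resource  # the body runs once per process; later calls return the cached objects
def load_model(model_name: str = "facebook/rag-sequence-nq"):
    tokenizer = RagTokenizer.from_pretrained(model_name)
    model = RagSequenceForGeneration.from_pretrained(model_name)
    return tokenizer, model

# Cheap on every Streamlit rerun after the first call.
tokenizer, model = load_model()
```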

Files changed (1)
  1. app.py +62 -40
app.py CHANGED
@@ -1,9 +1,15 @@
  import streamlit as st
- from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration, DPRQuestionEncoder, DPRQuestionEncoderTokenizer
+ from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration
  import faiss
  import os
  from datasets import load_from_disk
  import torch
+ import logging
+ import warnings
+
+ # Configure logging
+ logging.basicConfig(level=logging.WARNING)
+ warnings.filterwarnings('ignore')

  # Title
  st.title("🧩 AMA Austim")
@@ -11,15 +17,25 @@ st.title("🧩 AMA Austim")
  # Input: Query
  query = st.text_input("Please ask me anything about autism ✨")

+ @st.cache_resource
+ def load_rag_components(model_name="facebook/rag-sequence-nq"):
+     """Load and cache RAG components to avoid reloading."""
+     tokenizer = RagTokenizer.from_pretrained(model_name)
+     retriever = RagRetriever.from_pretrained(
+         model_name,
+         index_name="custom",
+         use_dummy_dataset=True  # We'll configure the dataset later
+     )
+     model = RagSequenceForGeneration.from_pretrained(model_name)
+     return tokenizer, retriever, model
+
  # Load or create RAG dataset
  def load_rag_dataset(dataset_dir="rag_dataset"):
      if not os.path.exists(dataset_dir):
-         # Import the build function from the other file
-         import faiss_index.index as faiss_index_index
-
-         # Fetch some initial papers to build the index
-         initial_papers = faiss_index_index.fetch_arxiv_papers("autism research", max_results=100)
-         dataset_dir = faiss_index_index.build_faiss_index(initial_papers, dataset_dir)
+         with st.spinner("Building initial dataset from autism research papers..."):
+             import faiss_index.index as faiss_index_index
+             initial_papers = faiss_index_index.fetch_arxiv_papers("autism research", max_results=100)
+             dataset_dir = faiss_index_index.build_faiss_index(initial_papers, dataset_dir)

      # Load the dataset and index
      dataset = load_from_disk(os.path.join(dataset_dir, "dataset"))
@@ -29,46 +45,52 @@ def load_rag_dataset(dataset_dir="rag_dataset"):

  # RAG Pipeline
  def rag_pipeline(query, dataset, index):
-     # Load pre-trained RAG model and configure retriever
-     model_name = "facebook/rag-sequence-nq"
-     tokenizer = RagTokenizer.from_pretrained(model_name)
-
-     # Configure retriever with correct paths and question encoder
-     retriever = RagRetriever.from_pretrained(
-         model_name,
-         index_name="custom",
-         passages_path=os.path.join("rag_dataset", "dataset"),
-         index_path=os.path.join("rag_dataset", "embeddings.faiss"),
-         use_dummy_dataset=False
-     )
-
-     # Initialize the model with the configured retriever
-     model = RagSequenceForGeneration.from_pretrained(model_name, retriever=retriever)
-
-     # Generate answer using RAG
-     inputs = tokenizer(query, return_tensors="pt")
-     with torch.no_grad():
-         generated_ids = model.generate(inputs["input_ids"], max_length=200)
-     answer = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
+     try:
+         # Load cached components
+         tokenizer, retriever, model = load_rag_components()
+
+         # Configure retriever with our dataset
+         retriever.index.dataset = dataset
+         retriever.index.index = index
+         model.retriever = retriever

-     return answer
+         # Generate answer
+         inputs = tokenizer(query, return_tensors="pt", max_length=512, truncation=True)
+         with torch.no_grad():
+             generated_ids = model.generate(
+                 inputs["input_ids"],
+                 max_length=200,
+                 min_length=50,
+                 num_beams=5,
+                 early_stopping=True
+             )
+         answer = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
+
+         return answer
+
+     except Exception as e:
+         st.error(f"An error occurred while processing your query: {str(e)}")
+         return None

  # Run the app
  if query:
      with st.status("Looking for data in the best sources...", expanded=True) as status:
-         st.write("Still looking... this may take a while as we look at some prestigious papers...")
+         st.write_stream("Still looking... this may take a while as we look at some prestigious papers...")
          dataset, index = load_rag_dataset()
-         st.write("Found the best sources!")
+         st.write_stream("Found the best sources!")
+         answer = rag_pipeline(query, dataset, index)
+         st.write_stream("Now answering your question...")
          status.update(
-             label="Download complete!",
+             label="Searching complete!",
              state="complete",
              expanded=False
          )
-         answer = rag_pipeline(query, dataset, index)
-         st.write("### Answer:")
-         st.write(answer)
-         st.write("### Retrieved Papers:")
-         for i in range(min(5, len(dataset))):
-             st.write(f"**Title:** {dataset[i]['title']}")
-             st.write(f"**Summary:** {dataset[i]['text'][:200]}...")
-             st.write("---")
+
+         if answer:
+             st.write("### Answer:")
+             st.write_stream(answer)
+             st.write("### Retrieved Papers:")
+             for i in range(min(5, len(dataset))):
+                 st.write(f"**Title:** {dataset[i]['title']}")
+                 st.write(f"**Summary:** {dataset[i]['text'][:200]}...")
+                 st.write("---")
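The new `rag_pipeline` builds the retriever with `use_dummy_dataset=True` and then patches `retriever.index.dataset` / `retriever.index.index` by hand. For comparison, a sketch of the more direct route that `transformers` documents for custom indexes, where the indexed dataset is passed at retriever construction time. This is an alternative, not what the commit does; the on-disk paths mirror the old code's `passages_path` / `index_path`, and the "embeddings" column name is an assumption about how `faiss_index.index` built the dataset:

```python
import os
import faiss
from datasets import load_from_disk
from transformers import RagRetriever, RagSequenceForGeneration

# Load the passages dataset and FAISS index the app stores on disk.
dataset = load_from_disk(os.path.join("rag_dataset", "dataset"))
index = faiss.read_index(os.path.join("rag_dataset", "embeddings.faiss"))

# Attach the loaded FAISS index to the dataset's embeddings column,
# then hand the indexed dataset to the retriever directly.
dataset.add_faiss_index("embeddings", custom_index=index)
retriever = RagRetriever.from_pretrained(
    "facebook/rag-sequence-nq",
    index_name="custom",
    indexed_dataset=dataset,
)
model = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-nq", retriever=retriever)
```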
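One thing to watch in the new display code: `st.write_stream` is meant for generators, iterables, or stream-like objects (for example token streams), and passing a plain string typically raises a StreamlitAPIException that points back to `st.write`. For the static status messages here, `st.write` is the simpler call; if a typewriter effect is wanted, the text can be wrapped in a small generator first. A sketch, with a hypothetical `stream_text` helper:

```python
import time

def stream_text(text: str, delay: float = 0.02):
    """Yield the text word by word so st.write_stream can animate it."""
    for word in text.split():
        yield word + " "
        time.sleep(delay)

# Usage inside the app, instead of st.write_stream("Found the best sources!"):
# st.write_stream(stream_text("Found the best sources!"))
```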