wakeupmh committed
Commit 975c327 · 1 Parent(s): 50dc0c2

fix: write

Files changed (2):
  1. app.py +27 -14
  2. faiss_index/index.py +8 -5
app.py CHANGED
@@ -6,6 +6,7 @@ from datasets import load_from_disk
 import torch
 import logging
 import warnings
+from pathlib import Path
 
 # Configure logging
 logging.basicConfig(level=logging.WARNING)
@@ -18,14 +19,23 @@ st.title("🧩 AMA Austim")
 query = st.text_input("Please ask me anything about autism ✨")
 
 @st.cache_resource
-def load_rag_components(model_name="facebook/rag-sequence-nq"):
+def load_rag_components(_dataset_path=None, _index_path=None):
     """Load and cache RAG components to avoid reloading."""
+    model_name = "facebook/rag-sequence-nq"
     tokenizer = RagTokenizer.from_pretrained(model_name)
-    retriever = RagRetriever.from_pretrained(
-        model_name,
-        index_name="custom",
-        use_dummy_dataset=True  # We'll configure the dataset later
-    )
+
+    retriever_config = {
+        "index_name": "custom",
+        "use_dummy_dataset": True
+    }
+
+    if _dataset_path and _index_path:
+        retriever_config.update({
+            "passages_path": _dataset_path,
+            "index_path": _index_path
+        })
+
+    retriever = RagRetriever.from_pretrained(model_name, **retriever_config)
     model = RagSequenceForGeneration.from_pretrained(model_name)
     return tokenizer, retriever, model
 
@@ -37,17 +47,20 @@ def load_rag_dataset(dataset_dir="rag_dataset"):
     initial_papers = faiss_index_index.fetch_arxiv_papers("autism research", max_results=100)
     dataset_dir = faiss_index_index.build_faiss_index(initial_papers, dataset_dir)
 
+    dataset_path = os.path.join(dataset_dir, "dataset")
+    index_path = os.path.join(dataset_dir, "embeddings.faiss")
+
     # Load the dataset and index
-    dataset = load_from_disk(os.path.join(dataset_dir, "dataset"))
-    index = faiss.read_index(os.path.join(dataset_dir, "embeddings.faiss"))
+    dataset = load_from_disk(dataset_path)
+    index = faiss.read_index(index_path)
 
-    return dataset, index
+    return dataset, index, dataset_path, index_path
 
 # RAG Pipeline
-def rag_pipeline(query, dataset, index):
+def rag_pipeline(query, dataset, index, dataset_path, index_path):
     try:
-        # Load cached components
-        tokenizer, retriever, model = load_rag_components()
+        # Load cached components with paths
+        tokenizer, retriever, model = load_rag_components(dataset_path, index_path)
 
         # Configure retriever with our dataset
         retriever.index.dataset = dataset
@@ -76,9 +89,9 @@ def rag_pipeline(query, dataset, index):
 if query:
     with st.status("Looking for data in the best sources...", expanded=True) as status:
         st.write("Still looking... this may take a while as we look at some prestigious papers...")
-        dataset, index = load_rag_dataset()
+        dataset, index, dataset_path, index_path = load_rag_dataset()
         st.write("Found the best sources!")
-        answer = rag_pipeline(query, dataset, index)
+        answer = rag_pipeline(query, dataset, index, dataset_path, index_path)
         st.write("Now answering your question...")
         status.update(
             label="Searching complete!",
faiss_index/index.py CHANGED
@@ -20,11 +20,12 @@ def fetch_arxiv_papers(query, max_results=10):
 
 # Build and save dataset with FAISS index
 def build_faiss_index(papers, dataset_dir="rag_dataset"):
-    # Create dataset
+    # Create dataset with required columns for RAG
     dataset = Dataset.from_dict({
         "id": [p["id"] for p in papers],
+        "text": [p["text"] for p in papers],  # RAG expects 'text' field
         "title": [p["title"] for p in papers],
-        "text": [p["text"] for p in papers],
+        "embeddings": None,  # Will be filled later
     })
 
     # Initialize DPR context encoder (same as used by RAG)
@@ -45,9 +46,6 @@ def build_faiss_index(papers, dataset_dir="rag_dataset"):
 
     embeddings = np.vstack(embeddings)
 
-    # Add embeddings to dataset
-    dataset = dataset.add_column("embeddings", [emb.tolist() for emb in embeddings])
-
     # Create FAISS index
     dimension = embeddings.shape[1]  # Should be 768 for DPR
     index = faiss.IndexFlatL2(dimension)
@@ -55,7 +53,12 @@ def build_faiss_index(papers, dataset_dir="rag_dataset"):
 
     # Save dataset and index
     os.makedirs(dataset_dir, exist_ok=True)
+
+    # Save dataset with embeddings
+    dataset = dataset.add_column("embeddings", [emb.tolist() for emb in embeddings])
     dataset.save_to_disk(os.path.join(dataset_dir, "dataset"))
+
+    # Save FAISS index
     faiss.write_index(index, os.path.join(dataset_dir, "embeddings.faiss"))
 
     return dataset_dir
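
For completeness, a small smoke test (hypothetical, not part of this commit) for checking that the artifacts written by build_faiss_index load back the way app.py expects; it assumes the default rag_dataset directory and the 768-dimensional DPR context embeddings noted above.

# Hypothetical smoke test for the saved artifacts; paths match the defaults above.
import os
import faiss
from datasets import load_from_disk

dataset_dir = "rag_dataset"
dataset = load_from_disk(os.path.join(dataset_dir, "dataset"))
index = faiss.read_index(os.path.join(dataset_dir, "embeddings.faiss"))

assert "embeddings" in dataset.column_names  # column added before save_to_disk
assert index.ntotal == len(dataset)          # one vector per paper
assert index.d == 768                        # DPR context embeddings are 768-d
print(f"{index.ntotal} passages indexed, dimension {index.d}")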