wakeupmh committed
Commit a226fb9 · 1 Parent(s): 8108db5

fix: hf format

Files changed (2)
  1. app.py +17 -14
  2. faiss_index/index.py +29 -13
app.py CHANGED
@@ -38,29 +38,32 @@ def load_rag_dataset(dataset_dir="rag_dataset"):
             dataset_dir = faiss_index_index.build_faiss_index(initial_papers, dataset_dir)
 
         # Load the dataset and index
-        dataset = load_from_disk(os.path.join(dataset_dir, "dataset"))
-        index = faiss.read_index(os.path.join(dataset_dir, "embeddings.faiss"))
-        return dataset, index
+        dataset_path = os.path.join(dataset_dir, "dataset")
+        index_path = os.path.join(dataset_dir, "embeddings.faiss")
+
+        dataset = load_from_disk(dataset_path)
+
+        # Add FAISS index back to dataset
+        dataset.load_faiss_index('embeddings', index_path)
+
+        return dataset, dataset_path, index_path
     except Exception as e:
         st.error(f"Error loading dataset: {str(e)}\n{traceback.format_exc()}")
-        return None, None
+        return None, None, None
 
 # RAG Pipeline
-def rag_pipeline(query, dataset, index):
+def rag_pipeline(query, dataset, dataset_path, index_path):
     try:
         # Initialize components
         model_name = "facebook/rag-sequence-nq"
         tokenizer = RagTokenizer.from_pretrained(model_name)
 
-        # Convert dataset to passages format
-        passages = prepare_rag_passages(dataset)
-
-        # Initialize retriever with passages
+        # Initialize retriever with correct paths
         retriever = RagRetriever.from_pretrained(
             model_name,
             index_name="custom",
-            passages=passages,
-            index=index
+            passages_path=dataset_path,
+            index_path=index_path
         )
 
         # Initialize model with retriever
@@ -93,15 +96,15 @@ if query:
     with st.status("Looking for data in the best sources...", expanded=True) as status:
         try:
             st.write("Still looking... this may take a while as we look at some prestigious papers...")
-            dataset, index = load_rag_dataset()
+            dataset, dataset_path, index_path = load_rag_dataset()
 
-            if dataset is None or index is None:
+            if dataset is None:
                 st.error("Failed to load or create the dataset.")
                 status.update(label="Error loading data", state="error")
             else:
                 st.write("Found the best sources!")
                 st.write("Now answering your question...")
-                answer = rag_pipeline(query, dataset, index)
+                answer = rag_pipeline(query, dataset, dataset_path, index_path)
 
             if answer:
                 status.update(label="Search complete!", state="complete", expanded=False)
faiss_index/index.py CHANGED
@@ -6,6 +6,7 @@ import os
 from transformers import DPRContextEncoder, DPRContextEncoderTokenizer
 import torch
 import logging
+from datasets.utils.file_utils import DownloadConfig
 
 # Configure logging
 logging.basicConfig(level=logging.INFO)
@@ -24,9 +25,9 @@ def fetch_arxiv_papers(query, max_results=10):
 
     for i, result in enumerate(results):
         papers.append({
-            "id": str(i),  # Unique identifier
-            "text": result.summary,  # Main content for embedding
-            "title": result.title,  # Title for display
+            "id": str(i),
+            "text": result.summary,
+            "title": result.title,
         })
 
     logging.info(f"Fetched {len(papers)} papers from arXiv")
@@ -39,16 +40,16 @@ def fetch_arxiv_papers(query, max_results=10):
 def build_faiss_index(papers, dataset_dir="rag_dataset"):
     """Build and save dataset with FAISS index for RAG"""
     try:
-        # Create dataset with required fields for RAG
+        # Create dataset
         dataset = Dataset.from_dict({
             "id": [p["id"] for p in papers],
-            "text": [p["text"] for p in papers],  # Main content field
-            "title": [p["title"] for p in papers],  # Additional metadata
+            "text": [p["text"] for p in papers],
+            "title": [p["title"] for p in papers],
         })
 
         logging.info(f"Created dataset with {len(dataset)} papers")
 
-        # Initialize DPR encoder (same as used by RAG)
+        # Initialize DPR encoder
         ctx_encoder = DPRContextEncoder.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
         ctx_tokenizer = DPRContextEncoderTokenizer.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
 
@@ -74,17 +75,32 @@ def build_faiss_index(papers, dataset_dir="rag_dataset"):
         embeddings = np.vstack(embeddings)
         logging.info(f"Created embeddings with shape {embeddings.shape}")
 
-        # Create FAISS index (L2 distance)
-        dimension = embeddings.shape[1]  # Should be 768 for DPR
+        # Create FAISS index
+        dimension = embeddings.shape[1]
         index = faiss.IndexFlatL2(dimension)
         index.add(embeddings.astype(np.float32))
 
-        # Save dataset and index
+        # Save everything
         os.makedirs(dataset_dir, exist_ok=True)
-        dataset.save_to_disk(os.path.join(dataset_dir, "dataset"))
-        faiss.write_index(index, os.path.join(dataset_dir, "embeddings.faiss"))
 
-        logging.info(f"Saved dataset and index to {dataset_dir}")
+        # Add embeddings to dataset
+        dataset = dataset.add_faiss_index(
+            column='embeddings',
+            custom_index=index,
+            device=0 if torch.cuda.is_available() else -1
+        )
+        dataset = dataset.add_column("embeddings", [emb.tolist() for emb in embeddings])
+
+        # Save dataset and index
+        dataset_path = os.path.join(dataset_dir, "dataset")
+        index_path = os.path.join(dataset_dir, "embeddings.faiss")
+
+        dataset.save_to_disk(dataset_path)
+        dataset.get_index('embeddings').save(index_path)
+
+        logging.info(f"Saved dataset to {dataset_path}")
+        logging.info(f"Saved index to {index_path}")
+
        return dataset_dir
 
     except Exception as e:
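A condensed sketch of the build-and-persist pattern on the index.py side, with a hard-coded toy paper standing in for the arXiv fetch. One ordering detail is worth hedging: in current datasets, save_to_disk refuses to write a dataset that still has an attached index, so this sketch saves the index first and drops it before writing the dataset; model names and paths match the commit, everything else is illustrative.

import os
import faiss
import numpy as np
import torch
from datasets import Dataset
from transformers import DPRContextEncoder, DPRContextEncoderTokenizer

# Toy input standing in for fetch_arxiv_papers(...)
papers = [{"id": "0", "title": "Sample paper", "text": "A short abstract about retrieval."}]
dataset = Dataset.from_dict({
    "id": [p["id"] for p in papers],
    "text": [p["text"] for p in papers],
    "title": [p["title"] for p in papers],
})

ctx_encoder = DPRContextEncoder.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
ctx_tokenizer = DPRContextEncoderTokenizer.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")

# Encode passages with DPR; pooler_output is one 768-dim vector per passage
with torch.no_grad():
    enc = ctx_tokenizer(dataset["text"], padding=True, truncation=True, return_tensors="pt")
    embeddings = ctx_encoder(**enc).pooler_output.numpy().astype(np.float32)

# Store raw vectors as a column, then index that column with an empty L2 index;
# add_faiss_index fills the custom index with the column's vectors
dataset = dataset.add_column("embeddings", [e.tolist() for e in embeddings])
index = faiss.IndexFlatL2(embeddings.shape[1])
dataset.add_faiss_index(column="embeddings", custom_index=index)

dataset_dir = "rag_dataset"
os.makedirs(dataset_dir, exist_ok=True)
dataset.get_index("embeddings").save(os.path.join(dataset_dir, "embeddings.faiss"))
dataset.drop_index("embeddings")  # save_to_disk cannot serialize attached indexes
dataset.save_to_disk(os.path.join(dataset_dir, "dataset"))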