orionweller commited on
Commit
6b90dc3
·
1 Parent(s): e61608e
Files changed (1) hide show
  1. app.py +6 -1
app.py CHANGED
@@ -5,7 +5,7 @@ import glob
5
  import tqdm
6
  import torch
7
  import torch.nn.functional as F
8
- from transformers import AutoTokenizer, AutoModel
9
  from peft import PeftModel
10
  from tevatron.retriever.searcher import FaissFlatSearcher
11
  import logging
@@ -20,6 +20,8 @@ import peft
20
  import faiss
21
  import sys
22
 
 
 
23
  # Set up logging
24
  logging.basicConfig(level=logging.INFO)
25
  logger = logging.getLogger(__name__)
@@ -176,12 +178,15 @@ def load_corpus_lookups(dataset_name):
176
  global corpus_lookups
177
  corpus_path = f"{dataset_name}/corpus_emb.*.pkl"
178
  index_files = glob.glob(corpus_path)
 
 
179
 
180
  corpus_lookups[dataset_name] = []
181
  for file in index_files:
182
  with open(file, 'rb') as f:
183
  _, p_lookup = pickle.load(f)
184
  corpus_lookups[dataset_name] += p_lookup
 
185
  logger.info(f"Loaded corpus lookups for {dataset_name}. Total entries: {len(corpus_lookups[dataset_name])}")
186
  logger.info(f"Sample corpus lookup entry: {corpus_lookups[dataset_name][0]}")
187
 
 
5
  import tqdm
6
  import torch
7
  import torch.nn.functional as F
8
+ from transformers import AutoTokenizer, AutoModel, set_seed
9
  from peft import PeftModel
10
  from tevatron.retriever.searcher import FaissFlatSearcher
11
  import logging
 
20
  import faiss
21
  import sys
22
 
23
+ set_seed(42)
24
+
25
  # Set up logging
26
  logging.basicConfig(level=logging.INFO)
27
  logger = logging.getLogger(__name__)
 
178
  global corpus_lookups
179
  corpus_path = f"{dataset_name}/corpus_emb.*.pkl"
180
  index_files = glob.glob(corpus_path)
181
+ # sort them
182
+ index_files.sort(key=lambda x: int(x.split('.')[-2]))
183
 
184
  corpus_lookups[dataset_name] = []
185
  for file in index_files:
186
  with open(file, 'rb') as f:
187
  _, p_lookup = pickle.load(f)
188
  corpus_lookups[dataset_name] += p_lookup
189
+
190
  logger.info(f"Loaded corpus lookups for {dataset_name}. Total entries: {len(corpus_lookups[dataset_name])}")
191
  logger.info(f"Sample corpus lookup entry: {corpus_lookups[dataset_name][0]}")
192