caesar-one committed on
Commit
9d3ebbc
·
verified ·
1 Parent(s): f622be2

Upload ConstBERT

Browse files
Files changed (2) hide show
  1. colbert_configuration.py +3 -0
  2. modeling.py +12 -9
colbert_configuration.py CHANGED
@@ -158,6 +158,7 @@ class ResourceSettings:
158
  collection: str = DefaultVal(None)
159
  queries: str = DefaultVal(None)
160
  index_name: str = DefaultVal(None)
 
161
 
162
 
163
  @dataclass
@@ -350,6 +351,7 @@ class BaseConfig(CoreConfig):
350
 
351
  return config
352
 
 
353
  try:
354
  checkpoint_path = hf_hub_download(
355
  repo_id=checkpoint_path, filename="artifact.metadata"
@@ -360,6 +362,7 @@ class BaseConfig(CoreConfig):
360
  if os.path.exists(loaded_config_path):
361
  loaded_config, _ = cls.from_path(loaded_config_path)
362
  loaded_config.set("checkpoint", checkpoint_path)
 
363
 
364
  return loaded_config
365
 
 
158
  collection: str = DefaultVal(None)
159
  queries: str = DefaultVal(None)
160
  index_name: str = DefaultVal(None)
161
+ name_or_path: str = DefaultVal(None)
162
 
163
 
164
  @dataclass
 
351
 
352
  return config
353
 
354
+ name_or_path = checkpoint_path
355
  try:
356
  checkpoint_path = hf_hub_download(
357
  repo_id=checkpoint_path, filename="artifact.metadata"
 
362
  if os.path.exists(loaded_config_path):
363
  loaded_config, _ = cls.from_path(loaded_config_path)
364
  loaded_config.set("checkpoint", checkpoint_path)
365
+ loaded_config.set("name_or_path", name_or_path)
366
 
367
  return loaded_config
368
 
modeling.py CHANGED
@@ -1,18 +1,11 @@
1
  import torch.nn as nn
2
  from transformers import BertPreTrainedModel, BertModel, AutoTokenizer
 
3
  import torch
4
  from tqdm import tqdm
5
- from transformers import AutoTokenizer
6
  from .colbert_configuration import ColBERTConfig
7
  from .tokenization_utils import QueryTokenizer, DocTokenizer
8
-
9
- # this is a hack to force huggingface hub to download the tokenizer files
10
- try:
11
- with open("./tokenizer_config.json", "r") as f, open("./tokenizer.json", "r") as f2, open("./vocab.txt", "r") as f3:
12
- pass
13
- except Exception as e:
14
- pass
15
-
16
  class NullContextManager(object):
17
  def __init__(self, dummy_resource=None):
18
  self.dummy_resource = dummy_resource
@@ -70,6 +63,16 @@ class ConstBERT(BertPreTrainedModel):
70
  self.doc_project = nn.Linear(colbert_config.doc_maxlen, 32, bias=False)
71
  self.query_project = nn.Linear(colbert_config.query_maxlen, 64, bias=False)
72
 
 
 
 
 
 
 
 
 
 
 
73
  self.query_tokenizer = QueryTokenizer(colbert_config, verbose=verbose)
74
  self.doc_tokenizer = DocTokenizer(colbert_config)
75
  self.amp_manager = MixedPrecisionManager(True)
 
1
  import torch.nn as nn
2
  from transformers import BertPreTrainedModel, BertModel, AutoTokenizer
3
+ from huggingface_hub import hf_hub_download
4
  import torch
5
  from tqdm import tqdm
 
6
  from .colbert_configuration import ColBERTConfig
7
  from .tokenization_utils import QueryTokenizer, DocTokenizer
8
+ import os
 
 
 
 
 
 
 
9
  class NullContextManager(object):
10
  def __init__(self, dummy_resource=None):
11
  self.dummy_resource = dummy_resource
 
63
  self.doc_project = nn.Linear(colbert_config.doc_maxlen, 32, bias=False)
64
  self.query_project = nn.Linear(colbert_config.query_maxlen, 64, bias=False)
65
 
66
+ ## Download required tokenizer files from Hugging Face
67
+ if not os.path.exists(os.path.join(colbert_config.name_or_path, "tokenizer.json")):
68
+ hf_hub_download(repo_id=colbert_config.name_or_path, filename="tokenizer.json")
69
+ if not os.path.exists(os.path.join(colbert_config.name_or_path, "vocab.txt")):
70
+ hf_hub_download(repo_id=colbert_config.name_or_path, filename="vocab.txt")
71
+ if not os.path.exists(os.path.join(colbert_config.name_or_path, "tokenizer_config.json")):
72
+ hf_hub_download(repo_id=colbert_config.name_or_path, filename="tokenizer_config.json")
73
+ if not os.path.exists(os.path.join(colbert_config.name_or_path, "special_tokens_map.json")):
74
+ hf_hub_download(repo_id=colbert_config.name_or_path, filename="special_tokens_map.json")
75
+
76
  self.query_tokenizer = QueryTokenizer(colbert_config, verbose=verbose)
77
  self.doc_tokenizer = DocTokenizer(colbert_config)
78
  self.amp_manager = MixedPrecisionManager(True)