Upload ConstBERT
Files changed:
- colbert_configuration.py (+3, −0)
- modeling.py (+12, −9)

This commit records the original Hugging Face repo id in a new `name_or_path` config field and uses it at model init to download the tokenizer files explicitly, replacing the earlier file-open hack.
colbert_configuration.py (CHANGED)

```diff
@@ -158,6 +158,7 @@ class ResourceSettings:
     collection: str = DefaultVal(None)
     queries: str = DefaultVal(None)
     index_name: str = DefaultVal(None)
+    name_or_path: str = DefaultVal(None)
 
 
 @dataclass
@@ -350,6 +351,7 @@ class BaseConfig(CoreConfig):
 
         return config
 
+        name_or_path = checkpoint_path
         try:
             checkpoint_path = hf_hub_download(
                 repo_id=checkpoint_path, filename="artifact.metadata"
@@ -360,6 +362,7 @@ class BaseConfig(CoreConfig):
         if os.path.exists(loaded_config_path):
             loaded_config, _ = cls.from_path(loaded_config_path)
             loaded_config.set("checkpoint", checkpoint_path)
+            loaded_config.set("name_or_path", name_or_path)
 
         return loaded_config
 
```
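For context, a minimal sketch (not part of the commit) of why the repo id is captured before the `hf_hub_download` call: that call reassigns `checkpoint_path` to a file path inside the local Hugging Face cache, so the original hub id would otherwise be lost. The repo id below is purely illustrative.

```python
from huggingface_hub import hf_hub_download

repo_id = "some-org/ConstBERT"  # illustrative repo id, not taken from the diff
name_or_path = repo_id          # save the hub id before it is overwritten

# hf_hub_download fetches the file into the local HF cache and returns
# the cached file's absolute path -- not the repo id.
checkpoint_path = hf_hub_download(repo_id=repo_id, filename="artifact.metadata")

print(checkpoint_path)  # e.g. a path under ~/.cache/huggingface/hub/...
print(name_or_path)     # still "some-org/ConstBERT", usable for later downloads
```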
modeling.py (CHANGED)

```diff
@@ -1,18 +1,11 @@
 import torch.nn as nn
 from transformers import BertPreTrainedModel, BertModel, AutoTokenizer
+from huggingface_hub import hf_hub_download
 import torch
 from tqdm import tqdm
-from transformers import AutoTokenizer
 from .colbert_configuration import ColBERTConfig
 from .tokenization_utils import QueryTokenizer, DocTokenizer
-
-# this is a hack to force huggingface hub to download the tokenizer files
-try:
-    with open("./tokenizer_config.json", "r") as f, open("./tokenizer.json", "r") as f2, open("./vocab.txt", "r") as f3:
-        pass
-except Exception as e:
-    pass
-
+import os
 class NullContextManager(object):
     def __init__(self, dummy_resource=None):
         self.dummy_resource = dummy_resource
@@ -70,6 +63,16 @@ class ConstBERT(BertPreTrainedModel):
         self.doc_project = nn.Linear(colbert_config.doc_maxlen, 32, bias=False)
         self.query_project = nn.Linear(colbert_config.query_maxlen, 64, bias=False)
 
+        ## Download required tokenizer files from Hugging Face
+        if not os.path.exists(os.path.join(colbert_config.name_or_path, "tokenizer.json")):
+            hf_hub_download(repo_id=colbert_config.name_or_path, filename="tokenizer.json")
+        if not os.path.exists(os.path.join(colbert_config.name_or_path, "vocab.txt")):
+            hf_hub_download(repo_id=colbert_config.name_or_path, filename="vocab.txt")
+        if not os.path.exists(os.path.join(colbert_config.name_or_path, "tokenizer_config.json")):
+            hf_hub_download(repo_id=colbert_config.name_or_path, filename="tokenizer_config.json")
+        if not os.path.exists(os.path.join(colbert_config.name_or_path, "special_tokens_map.json")):
+            hf_hub_download(repo_id=colbert_config.name_or_path, filename="special_tokens_map.json")
+
         self.query_tokenizer = QueryTokenizer(colbert_config, verbose=verbose)
         self.doc_tokenizer = DocTokenizer(colbert_config)
         self.amp_manager = MixedPrecisionManager(True)
```
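The eight-line bootstrap above can be read as one loop over the four tokenizer files; the equivalent sketch below (assuming the same `colbert_config.name_or_path` attribute) makes the pattern explicit. Note that `hf_hub_download` places files in the local Hugging Face cache and returns their path, while the `os.path.exists` guard skips the download when the files already sit in a local checkpoint directory.

```python
import os
from huggingface_hub import hf_hub_download

# The four files the tokenizer bootstrap in the diff above fetches.
TOKENIZER_FILES = (
    "tokenizer.json",
    "vocab.txt",
    "tokenizer_config.json",
    "special_tokens_map.json",
)

def ensure_tokenizer_files(name_or_path: str) -> None:
    """Fetch any tokenizer file missing from a local checkpoint directory.

    Equivalent sketch of the per-file checks in ConstBERT.__init__;
    downloads land in the Hugging Face cache, not in name_or_path itself.
    """
    for filename in TOKENIZER_FILES:
        if not os.path.exists(os.path.join(name_or_path, filename)):
            hf_hub_download(repo_id=name_or_path, filename=filename)
```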