Upload ConstBERT
modeling.py  CHANGED  (+5 -3)
@@ -6,6 +6,8 @@ from tqdm import tqdm
 from .colbert_configuration import ColBERTConfig
 from .tokenization_utils import QueryTokenizer, DocTokenizer
 import os
+
+
 class NullContextManager(object):
     def __init__(self, dummy_resource=None):
         self.dummy_resource = dummy_resource
@@ -54,7 +56,7 @@ class ConstBERT(BertPreTrainedModel):
     """
     _keys_to_ignore_on_load_unexpected = [r"cls"]

-    def __init__(self, config, colbert_config, verbose:int =
+    def __init__(self, config, colbert_config, verbose:int = 0):
         super().__init__(config)

         self.config = config
@@ -175,7 +177,7 @@ class ConstBERT(BertPreTrainedModel):

         return D

-    def
+    def encode_query(self, queries, bsize=None, to_cpu=False, context=None, full_length_search=False):
         if bsize:
             batches = self.query_tokenizer.tensorize(queries, context=context, bsize=bsize, full_length_search=full_length_search)
             batches = [self.query(input_ids, attention_mask, to_cpu=to_cpu) for input_ids, attention_mask in batches]
@@ -184,7 +186,7 @@ class ConstBERT(BertPreTrainedModel):
         input_ids, attention_mask = self.query_tokenizer.tensorize(queries, context=context, full_length_search=full_length_search)
         return self.query(input_ids, attention_mask)

-    def
+    def encode_document(self, docs, bsize=None, keep_dims=True, to_cpu=False, showprogress=False, return_tokens=False):
         assert keep_dims in [True, False, 'flatten']

         if bsize:
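The visible effect of this commit is the public encoding API on ConstBERT: encode_query and encode_document. A minimal usage sketch follows, assuming the uploaded repo is loaded as a custom-code checkpoint from the Hub; the repo id "org/ConstBERT" and the trust_remote_code loading path are assumptions for illustration, not part of this diff.

    # Minimal sketch, not part of the commit. Assumptions: "org/ConstBERT" is a
    # placeholder repo id, and the repo's config maps AutoModel to this
    # modeling.py (hence trust_remote_code=True).
    from transformers import AutoModel

    model = AutoModel.from_pretrained("org/ConstBERT", trust_remote_code=True)
    model.eval()

    queries = ["what is late interaction in retrieval?"]
    docs = ["ColBERT-style models keep one embedding per token of the text."]

    # encode_query tokenizes with the model's QueryTokenizer and runs
    # self.query; bsize batches the inputs, to_cpu moves outputs off the GPU.
    Q = model.encode_query(queries, bsize=32, to_cpu=True)

    # encode_document accepts keep_dims in {True, False, 'flatten'} (see the
    # assert in the diff). The exact return shapes for False/'flatten' are an
    # assumption based on ColBERT's docFromText, whose body this hunk truncates.
    D = model.encode_document(docs, bsize=32, keep_dims=True, to_cpu=True)

Keeping bsize optional mirrors the method bodies shown above: with bsize set, tensorize yields batches that are encoded one at a time; without it, the full input is tokenized and encoded in a single pass.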