caesar-one committed on
Commit
39157e5
·
verified ·
1 Parent(s): 9d3ebbc

Upload ConstBERT

Browse files
Files changed (1) hide show
  1. modeling.py +5 -3
modeling.py CHANGED
@@ -6,6 +6,8 @@ from tqdm import tqdm
6
  from .colbert_configuration import ColBERTConfig
7
  from .tokenization_utils import QueryTokenizer, DocTokenizer
8
  import os
 
 
9
  class NullContextManager(object):
10
  def __init__(self, dummy_resource=None):
11
  self.dummy_resource = dummy_resource
@@ -54,7 +56,7 @@ class ConstBERT(BertPreTrainedModel):
54
  """
55
  _keys_to_ignore_on_load_unexpected = [r"cls"]
56
 
57
- def __init__(self, config, colbert_config, verbose:int = 3):
58
  super().__init__(config)
59
 
60
  self.config = config
@@ -175,7 +177,7 @@ class ConstBERT(BertPreTrainedModel):
175
 
176
  return D
177
 
178
- def queryFromText(self, queries, bsize=None, to_cpu=False, context=None, full_length_search=False):
179
  if bsize:
180
  batches = self.query_tokenizer.tensorize(queries, context=context, bsize=bsize, full_length_search=full_length_search)
181
  batches = [self.query(input_ids, attention_mask, to_cpu=to_cpu) for input_ids, attention_mask in batches]
@@ -184,7 +186,7 @@ class ConstBERT(BertPreTrainedModel):
184
  input_ids, attention_mask = self.query_tokenizer.tensorize(queries, context=context, full_length_search=full_length_search)
185
  return self.query(input_ids, attention_mask)
186
 
187
- def docFromText(self, docs, bsize=None, keep_dims=True, to_cpu=False, showprogress=False, return_tokens=False):
188
  assert keep_dims in [True, False, 'flatten']
189
 
190
  if bsize:
 
6
  from .colbert_configuration import ColBERTConfig
7
  from .tokenization_utils import QueryTokenizer, DocTokenizer
8
  import os
9
+
10
+
11
  class NullContextManager(object):
12
  def __init__(self, dummy_resource=None):
13
  self.dummy_resource = dummy_resource
 
56
  """
57
  _keys_to_ignore_on_load_unexpected = [r"cls"]
58
 
59
+ def __init__(self, config, colbert_config, verbose:int = 0):
60
  super().__init__(config)
61
 
62
  self.config = config
 
177
 
178
  return D
179
 
180
+ def encode_query(self, queries, bsize=None, to_cpu=False, context=None, full_length_search=False):
181
  if bsize:
182
  batches = self.query_tokenizer.tensorize(queries, context=context, bsize=bsize, full_length_search=full_length_search)
183
  batches = [self.query(input_ids, attention_mask, to_cpu=to_cpu) for input_ids, attention_mask in batches]
 
186
  input_ids, attention_mask = self.query_tokenizer.tensorize(queries, context=context, full_length_search=full_length_search)
187
  return self.query(input_ids, attention_mask)
188
 
189
+ def encode_document(self, docs, bsize=None, keep_dims=True, to_cpu=False, showprogress=False, return_tokens=False):
190
  assert keep_dims in [True, False, 'flatten']
191
 
192
  if bsize: