caesar-one commited on
Commit
5219c44
·
verified ·
1 Parent(s): 624a902

Upload ConstBERT

Browse files
Files changed (1) hide show
  1. modeling.py +4 -0
modeling.py CHANGED
@@ -178,6 +178,8 @@ class ConstBERT(BertPreTrainedModel):
178
  return D
179
 
180
  def encode_query(self, queries, bsize=None, to_cpu=False, context=None, full_length_search=False):
 
 
181
  if bsize:
182
  batches = self.query_tokenizer.tensorize(queries, context=context, bsize=bsize, full_length_search=full_length_search)
183
  batches = [self.query(input_ids, attention_mask, to_cpu=to_cpu) for input_ids, attention_mask in batches]
@@ -187,6 +189,8 @@ class ConstBERT(BertPreTrainedModel):
187
  return self.query(input_ids, attention_mask)
188
 
189
  def encode_document(self, docs, bsize=None, keep_dims=True, to_cpu=False, showprogress=False, return_tokens=False):
 
 
190
  assert keep_dims in [True, False, 'flatten']
191
 
192
  if bsize:
 
178
  return D
179
 
180
  def encode_query(self, queries, bsize=None, to_cpu=False, context=None, full_length_search=False):
181
+ if type(queries) == str:
182
+ queries = [queries]
183
  if bsize:
184
  batches = self.query_tokenizer.tensorize(queries, context=context, bsize=bsize, full_length_search=full_length_search)
185
  batches = [self.query(input_ids, attention_mask, to_cpu=to_cpu) for input_ids, attention_mask in batches]
 
189
  return self.query(input_ids, attention_mask)
190
 
191
  def encode_document(self, docs, bsize=None, keep_dims=True, to_cpu=False, showprogress=False, return_tokens=False):
192
+ if type(docs) == str:
193
+ docs = [docs]
194
  assert keep_dims in [True, False, 'flatten']
195
 
196
  if bsize: