Upload ConstBERT
Browse files- modeling.py +4 -0
modeling.py
CHANGED
@@ -178,6 +178,8 @@ class ConstBERT(BertPreTrainedModel):
|
|
178 |
return D
|
179 |
|
180 |
def encode_query(self, queries, bsize=None, to_cpu=False, context=None, full_length_search=False):
|
|
|
|
|
181 |
if bsize:
|
182 |
batches = self.query_tokenizer.tensorize(queries, context=context, bsize=bsize, full_length_search=full_length_search)
|
183 |
batches = [self.query(input_ids, attention_mask, to_cpu=to_cpu) for input_ids, attention_mask in batches]
|
@@ -187,6 +189,8 @@ class ConstBERT(BertPreTrainedModel):
|
|
187 |
return self.query(input_ids, attention_mask)
|
188 |
|
189 |
def encode_document(self, docs, bsize=None, keep_dims=True, to_cpu=False, showprogress=False, return_tokens=False):
|
|
|
|
|
190 |
assert keep_dims in [True, False, 'flatten']
|
191 |
|
192 |
if bsize:
|
|
|
178 |
return D
|
179 |
|
180 |
def encode_query(self, queries, bsize=None, to_cpu=False, context=None, full_length_search=False):
|
181 |
+
if type(queries) == str:
|
182 |
+
queries = [queries]
|
183 |
if bsize:
|
184 |
batches = self.query_tokenizer.tensorize(queries, context=context, bsize=bsize, full_length_search=full_length_search)
|
185 |
batches = [self.query(input_ids, attention_mask, to_cpu=to_cpu) for input_ids, attention_mask in batches]
|
|
|
189 |
return self.query(input_ids, attention_mask)
|
190 |
|
191 |
def encode_document(self, docs, bsize=None, keep_dims=True, to_cpu=False, showprogress=False, return_tokens=False):
|
192 |
+
if type(docs) == str:
|
193 |
+
docs = [docs]
|
194 |
assert keep_dims in [True, False, 'flatten']
|
195 |
|
196 |
if bsize:
|