Upload ConstBERT
Browse files- modeling.py +2 -2
modeling.py
CHANGED
@@ -177,7 +177,7 @@ class ConstBERT(BertPreTrainedModel):
|
|
177 |
|
178 |
return D
|
179 |
|
180 |
-
def
|
181 |
if type(queries) == str:
|
182 |
queries = [queries]
|
183 |
if bsize:
|
@@ -188,7 +188,7 @@ class ConstBERT(BertPreTrainedModel):
|
|
188 |
input_ids, attention_mask = self.query_tokenizer.tensorize(queries, context=context, full_length_search=full_length_search)
|
189 |
return self.query(input_ids, attention_mask)
|
190 |
|
191 |
-
def
|
192 |
if type(docs) == str:
|
193 |
docs = [docs]
|
194 |
assert keep_dims in [True, False, 'flatten']
|
|
|
177 |
|
178 |
return D
|
179 |
|
180 |
+
def encode_queries(self, queries, bsize=None, to_cpu=False, context=None, full_length_search=False):
|
181 |
if type(queries) == str:
|
182 |
queries = [queries]
|
183 |
if bsize:
|
|
|
188 |
input_ids, attention_mask = self.query_tokenizer.tensorize(queries, context=context, full_length_search=full_length_search)
|
189 |
return self.query(input_ids, attention_mask)
|
190 |
|
191 |
+
def encode_documents(self, docs, bsize=None, keep_dims=True, to_cpu=False, showprogress=False, return_tokens=False):
|
192 |
if type(docs) == str:
|
193 |
docs = [docs]
|
194 |
assert keep_dims in [True, False, 'flatten']
|