crumb commited on
Commit
ad1e7a3
·
1 Parent(s): 4bba6a9

Upload model

Browse files
Files changed (1) hide show
  1. modeling_gzipembed.py +13 -0
modeling_gzipembed.py CHANGED
@@ -39,6 +39,19 @@ class GZIPEmbeddingModel(PreTrainedModel):
39
  x = x.to(self.reduction_head.dtype).to(self.reduction_head.device)
40
  return self.reduction_head(x)
41
  return x if not return_tensor else torch.tensor(x)
 
 
 
 
 
 
 
 
 
 
 
 
 
42
 
43
  def normalize(self, x):
44
  x = ''.join([char for char in x.lower() if char in "abcdefghijklmnopqrstuvwxyz "])
 
39
  x = x.to(self.reduction_head.dtype).to(self.reduction_head.device)
40
  return self.reduction_head(x)
41
  return x if not return_tensor else torch.tensor(x)
42
+
43
+ def encode(self, sentences, batch_size=32, **kwargs):
44
+ """
45
+ Returns a list of embeddings for the given sentences.
46
+ Args:
47
+ sentences (`List[str]`): List of sentences to encode
48
+ batch_size (`int`): Batch size for the encoding
49
+
50
+ Returns:
51
+ `List[np.ndarray]` or `List[tensor]`: List of embeddings for the given sentences
52
+ """
53
+ x = self.forward(sentences, num_procs=batch_size, return_tensor=False)
54
+ return [torch.tensor(i) for i in x]
55
 
56
  def normalize(self, x):
57
  x = ''.join([char for char in x.lower() if char in "abcdefghijklmnopqrstuvwxyz "])