|
""" |
|
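Tests that embeddings computed with a multi-process pool match the embeddings computed by a regular single-process encode call.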
|
""" |
|
|
|
|
|
import unittest |
|
from sentence_transformers import SentenceTransformer |
|
import numpy as np |
|
|
|
class ComputeMultiProcessTest(unittest.TestCase):
    """Compare multi-process encoding against regular single-process encoding."""

    def setUp(self):
        """Load the sentence embedding model that the test encodes with."""
        self.model = SentenceTransformer('paraphrase-distilroberta-base-v1')

    def test_multi_gpu_encode(self):
        """Encode the same sentences both ways and require matching results.

        The worker pool is started on two 'cpu' devices. The multi-process
        output must have one row of width 768 per input sentence, and its
        largest absolute element-wise difference from the plain ``encode``
        output must stay below 0.001.
        """
        pool = self.model.start_multi_process_pool(['cpu', 'cpu'])
        try:
            sentences = ["This is sentence {}".format(i) for i in range(1000)]

            # Encode through the worker pool, handing out 50 sentences per chunk.
            emb = self.model.encode_multi_process(sentences, pool, chunk_size=50)
            # unittest assertions are used instead of bare ``assert`` so the
            # checks are not stripped when Python runs with -O.
            self.assertEqual(emb.shape, (len(sentences), 768))

            # Reference embeddings from the regular in-process encoder.
            emb_normal = self.model.encode(sentences)

            diff = np.max(np.abs(emb - emb_normal))
            print("Max multi proc diff", diff)
            self.assertLess(diff, 0.001)
        finally:
            # Always shut the workers down, even when an assertion above
            # fails; otherwise the started processes outlive the test.
            self.model.stop_multi_process_pool(pool)
|
|
|
|
|
|
|
|
|
|