Upload model
Browse files- modeling_gzipembed.py +7 -1
modeling_gzipembed.py
CHANGED
@@ -17,6 +17,12 @@ class GZIPEmbeddingModel(PreTrainedModel):
|
|
17 |
self.dummy_parameter = torch.nn.Parameter(torch.ones(1))
|
18 |
|
19 |
def forward(self, prompt, num_procs=16):
|
|
|
|
|
|
|
|
|
|
|
|
|
20 |
if type(prompt) == str:
|
21 |
prompt = [prompt]
|
22 |
x = []
|
@@ -24,7 +30,7 @@ class GZIPEmbeddingModel(PreTrainedModel):
|
|
24 |
ncd = [0] * len(self.config.corpus)
|
25 |
with multiprocessing.Pool(num_procs) as pool:
|
26 |
data = enumerate(self.config.corpus)
|
27 |
-
results = pool.map(
|
28 |
for i,row in results:
|
29 |
ncd[i]=row
|
30 |
x.append(ncd)
|
|
|
17 |
self.dummy_parameter = torch.nn.Parameter(torch.ones(1))
|
18 |
|
19 |
def forward(self, prompt, num_procs=16):
|
20 |
+
global calculate_ncd_row
|
21 |
+
global p
|
22 |
+
def calculate_ncd_row(data_row):
|
23 |
+
i = data_row[0]
|
24 |
+
row = model.ncd(data_row[1], prompt)
|
25 |
+
return i, row
|
26 |
if type(prompt) == str:
|
27 |
prompt = [prompt]
|
28 |
x = []
|
|
|
30 |
ncd = [0] * len(self.config.corpus)
|
31 |
with multiprocessing.Pool(num_procs) as pool:
|
32 |
data = enumerate(self.config.corpus)
|
33 |
+
results = pool.map(calculate_ncd_row,(data,p))
|
34 |
for i,row in results:
|
35 |
ncd[i]=row
|
36 |
x.append(ncd)
|