Update modeling_esm_plusplus.py

modeling_esm_plusplus.py (+17 -17)

```diff
@@ -569,24 +569,24 @@ class ESMplusplusForMaskedLM(PreTrainedModel):
             to_embed = [seq for seq in sequences if seq not in already_embedded]
             print(f"Found {len(already_embedded)} already embedded sequences in {sql_db_path}")
             print(f"Embedding {len(to_embed)} new sequences")
+            if len(to_embed) > 0:
+                with torch.no_grad():
+                    for i, batch in tqdm(enumerate(dataloader), total=len(dataloader), desc='Embedding batches'):
+                        seqs = sequences[i * batch_size:(i + 1) * batch_size]
+                        input_ids, attention_mask = batch['input_ids'].to(device), batch['attention_mask'].to(device)
+                        x = self.embed(input_ids)
+                        residue_embeddings = self.transformer(x, attention_mask).last_hidden_state.detach().float() # required for sql
+                        embeddings = get_embeddings(residue_embeddings, attention_mask)
+
+                        for seq, emb in zip(seqs, embeddings):
+                            c.execute("INSERT OR REPLACE INTO embeddings VALUES (?, ?)",
+                                      (seq, emb.cpu().numpy().tobytes()))
+
+                        if (i + 1) % 100 == 0:
+                            conn.commit()
 
-            with torch.no_grad():
-                for i, batch in tqdm(enumerate(dataloader), total=len(dataloader), desc='Embedding batches'):
-                    seqs = sequences[i * batch_size:(i + 1) * batch_size]
-                    input_ids, attention_mask = batch['input_ids'].to(device), batch['attention_mask'].to(device)
-                    x = self.embed(input_ids)
-                    residue_embeddings = self.transformer(x, attention_mask).last_hidden_state.detach().float() # required for sql
-                    embeddings = get_embeddings(residue_embeddings, attention_mask)
-
-                    for seq, emb in zip(seqs, embeddings):
-                        c.execute("INSERT OR REPLACE INTO embeddings VALUES (?, ?)",
-                                  (seq, emb.cpu().numpy().tobytes()))
-
-                    if (i + 1) % 100 == 0:
-                        conn.commit()
-
-            conn.commit()
-            conn.close()
+                conn.commit()
+                conn.close()
             return None
 
             embeddings_dict = {}
```
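For readers skimming the change: the method writes one row per sequence into a SQLite cache, and after this commit it skips the entire embedding loop when every sequence is already cached, so a fully cached run returns without any forward passes. Below is a minimal, self-contained sketch of that pattern with the new guard in place. The table schema, the `mean_pool` helper (a stand-in for `get_embeddings`), the bare `model` callable, and the assumption that `dataloader` yields batches of `sequences` in order are illustrative assumptions, not the repository's exact code.

```python
import sqlite3

import torch


def mean_pool(residue_embeddings: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for get_embeddings: mean over non-padding residues.
    mask = attention_mask.unsqueeze(-1).float()  # (B, L, 1)
    return (residue_embeddings * mask).sum(dim=1) / mask.sum(dim=1)  # (B, D)


def embed_with_sqlite_cache(model, dataloader, sequences, batch_size, sql_db_path, device):
    conn = sqlite3.connect(sql_db_path)
    c = conn.cursor()
    # Schema assumed from the two-column INSERT OR REPLACE in the diff.
    c.execute("CREATE TABLE IF NOT EXISTS embeddings (sequence TEXT PRIMARY KEY, embedding BLOB)")
    already_embedded = {row[0] for row in c.execute("SELECT sequence FROM embeddings")}
    to_embed = [seq for seq in sequences if seq not in already_embedded]
    print(f"Found {len(already_embedded)} already embedded sequences in {sql_db_path}")
    print(f"Embedding {len(to_embed)} new sequences")

    # The guard this commit adds: when everything is cached, skip the loop
    # (and its forward passes) entirely.
    if len(to_embed) > 0:
        with torch.no_grad():
            for i, batch in enumerate(dataloader):
                # Assumes the dataloader iterates over `sequences` in order, unshuffled.
                seqs = sequences[i * batch_size:(i + 1) * batch_size]
                input_ids = batch["input_ids"].to(device)
                attention_mask = batch["attention_mask"].to(device)
                residue_embeddings = model(input_ids, attention_mask).float()  # (B, L, D)
                embeddings = mean_pool(residue_embeddings, attention_mask)
                for seq, emb in zip(seqs, embeddings):
                    # float32 bytes keyed by sequence; REPLACE makes reruns idempotent.
                    c.execute("INSERT OR REPLACE INTO embeddings VALUES (?, ?)",
                              (seq, emb.cpu().numpy().tobytes()))
                if (i + 1) % 100 == 0:
                    conn.commit()  # periodic commits bound lost work on interruption
    conn.commit()
    conn.close()
```

The trailing context line `embeddings_dict = {}` suggests a dictionary-returning path later in the method; a hypothetical read-back of the cache into such a dict might look like this (the database path is illustrative):

```python
import sqlite3

import numpy as np

# Rebuild a {sequence: vector} dict from the cache; float32 matches the
# .float() cast applied before tobytes() above.
conn = sqlite3.connect("embeddings.db")
c = conn.cursor()
embeddings_dict = {}
for seq, blob in c.execute("SELECT sequence, embedding FROM embeddings"):
    embeddings_dict[seq] = np.frombuffer(blob, dtype=np.float32)
conn.close()
```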