Update modeling_esm_plusplus.py
Browse files- modeling_esm_plusplus.py +17 -17
modeling_esm_plusplus.py
CHANGED
@@ -569,24 +569,24 @@ class ESMplusplusForMaskedLM(PreTrainedModel):
|
|
569 |
to_embed = [seq for seq in sequences if seq not in already_embedded]
|
570 |
print(f"Found {len(already_embedded)} already embedded sequences in {sql_db_path}")
|
571 |
print(f"Embedding {len(to_embed)} new sequences")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
572 |
|
573 |
-
|
574 |
-
|
575 |
-
seqs = sequences[i * batch_size:(i + 1) * batch_size]
|
576 |
-
input_ids, attention_mask = batch['input_ids'].to(device), batch['attention_mask'].to(device)
|
577 |
-
x = self.embed(input_ids)
|
578 |
-
residue_embeddings = self.transformer(x, attention_mask).last_hidden_state.detach().float() # required for sql
|
579 |
-
embeddings = get_embeddings(residue_embeddings, attention_mask)
|
580 |
-
|
581 |
-
for seq, emb in zip(seqs, embeddings):
|
582 |
-
c.execute("INSERT OR REPLACE INTO embeddings VALUES (?, ?)",
|
583 |
-
(seq, emb.cpu().numpy().tobytes()))
|
584 |
-
|
585 |
-
if (i + 1) % 100 == 0:
|
586 |
-
conn.commit()
|
587 |
-
|
588 |
-
conn.commit()
|
589 |
-
conn.close()
|
590 |
return None
|
591 |
|
592 |
embeddings_dict = {}
|
|
|
569 |
to_embed = [seq for seq in sequences if seq not in already_embedded]
|
570 |
print(f"Found {len(already_embedded)} already embedded sequences in {sql_db_path}")
|
571 |
print(f"Embedding {len(to_embed)} new sequences")
|
572 |
+
if len(to_embed) > 0:
|
573 |
+
with torch.no_grad():
|
574 |
+
for i, batch in tqdm(enumerate(dataloader), total=len(dataloader), desc='Embedding batches'):
|
575 |
+
seqs = sequences[i * batch_size:(i + 1) * batch_size]
|
576 |
+
input_ids, attention_mask = batch['input_ids'].to(device), batch['attention_mask'].to(device)
|
577 |
+
x = self.embed(input_ids)
|
578 |
+
residue_embeddings = self.transformer(x, attention_mask).last_hidden_state.detach().float() # required for sql
|
579 |
+
embeddings = get_embeddings(residue_embeddings, attention_mask)
|
580 |
+
|
581 |
+
for seq, emb in zip(seqs, embeddings):
|
582 |
+
c.execute("INSERT OR REPLACE INTO embeddings VALUES (?, ?)",
|
583 |
+
(seq, emb.cpu().numpy().tobytes()))
|
584 |
+
|
585 |
+
if (i + 1) % 100 == 0:
|
586 |
+
conn.commit()
|
587 |
|
588 |
+
conn.commit()
|
589 |
+
conn.close()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
590 |
return None
|
591 |
|
592 |
embeddings_dict = {}
|