lhallee committed
Commit ed85e42 · verified · 1 Parent(s): d354b92

Update modeling_esm_plusplus.py

Files changed (1):
  1. modeling_esm_plusplus.py +17 -17
modeling_esm_plusplus.py CHANGED
@@ -569,24 +569,24 @@ class ESMplusplusForMaskedLM(PreTrainedModel):
         to_embed = [seq for seq in sequences if seq not in already_embedded]
         print(f"Found {len(already_embedded)} already embedded sequences in {sql_db_path}")
         print(f"Embedding {len(to_embed)} new sequences")
-
-        with torch.no_grad():
-            for i, batch in tqdm(enumerate(dataloader), total=len(dataloader), desc='Embedding batches'):
-                seqs = sequences[i * batch_size:(i + 1) * batch_size]
-                input_ids, attention_mask = batch['input_ids'].to(device), batch['attention_mask'].to(device)
-                x = self.embed(input_ids)
-                residue_embeddings = self.transformer(x, attention_mask).last_hidden_state.detach().float() # required for sql
-                embeddings = get_embeddings(residue_embeddings, attention_mask)
-
-                for seq, emb in zip(seqs, embeddings):
-                    c.execute("INSERT OR REPLACE INTO embeddings VALUES (?, ?)",
-                              (seq, emb.cpu().numpy().tobytes()))
-
-                if (i + 1) % 100 == 0:
-                    conn.commit()
-
-        conn.commit()
-        conn.close()
+        if len(to_embed) > 0:
+            with torch.no_grad():
+                for i, batch in tqdm(enumerate(dataloader), total=len(dataloader), desc='Embedding batches'):
+                    seqs = sequences[i * batch_size:(i + 1) * batch_size]
+                    input_ids, attention_mask = batch['input_ids'].to(device), batch['attention_mask'].to(device)
+                    x = self.embed(input_ids)
+                    residue_embeddings = self.transformer(x, attention_mask).last_hidden_state.detach().float() # required for sql
+                    embeddings = get_embeddings(residue_embeddings, attention_mask)
+
+                    for seq, emb in zip(seqs, embeddings):
+                        c.execute("INSERT OR REPLACE INTO embeddings VALUES (?, ?)",
+                                  (seq, emb.cpu().numpy().tobytes()))
+
+                    if (i + 1) % 100 == 0:
+                        conn.commit()
+
+            conn.commit()
+            conn.close()
         return None

         embeddings_dict = {}
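
A note on the pooling step: `get_embeddings(residue_embeddings, attention_mask)` reduces the per-residue hidden states to one vector per sequence, but its body is not shown in this diff. Below is a minimal sketch of attention-masked mean pooling, one common way to perform that reduction; the name `mean_pool` and the pooling strategy are illustrative assumptions, not the repository's confirmed implementation.

import torch

def mean_pool(residue_embeddings: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    # ASSUMPTION: mean pooling; the repo's get_embeddings helper may differ.
    # residue_embeddings: (batch, seq_len, hidden)
    # attention_mask:     (batch, seq_len), 1 for real tokens, 0 for padding
    mask = attention_mask.unsqueeze(-1).type_as(residue_embeddings)  # (batch, seq_len, 1)
    summed = (residue_embeddings * mask).sum(dim=1)                  # sum over valid positions only
    counts = mask.sum(dim=1).clamp(min=1.0)                          # guard against all-padding rows
    return summed / counts                                           # (batch, hidden)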
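
On the storage side, the `.float()` cast (flagged by the `# required for sql` comment) pins the hidden states to float32 before `emb.cpu().numpy().tobytes()` serializes them, so a reader must decode the BLOB with the matching dtype. A retrieval sketch under that assumption; the column names `sequence` and `embedding` are hypothetical, since only the table name and the two-column INSERT appear in the diff.

import sqlite3
import numpy as np

def load_embedding(sql_db_path: str, seq: str):
    # Fetch one cached embedding by sequence; returns None on a cache miss.
    conn = sqlite3.connect(sql_db_path)
    row = conn.execute(
        "SELECT embedding FROM embeddings WHERE sequence = ?",  # column names assumed
        (seq,),
    ).fetchone()
    conn.close()
    if row is None:
        return None
    # The writer stored a raw float32 buffer, so decode with the same dtype.
    return np.frombuffer(row[0], dtype=np.float32)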