Yanisadel commited on
Commit
9ca04b0
·
verified ·
1 Parent(s): 6ed7d0b

Update text_generation.py

Browse files
Files changed (1) hide show
  1. text_generation.py +13 -8
text_generation.py CHANGED
@@ -55,19 +55,24 @@ class TextGenerationPipeline(Pipeline):
55
  truncation=True,
56
  max_length=english_tokens_max_length,
57
  ).input_ids
58
- bio_tokens = self.bio_tokenizer(
59
- dna_sequences,
60
- return_tensors="pt",
61
- padding="max_length",
62
- max_length=bio_tokens_max_length,
63
- truncation=True,
64
- ).input_ids.unsqueeze(0)
 
 
 
65
 
66
  return {"english_tokens": english_tokens, "bio_tokens": bio_tokens}
67
 
68
  def _forward(self, model_inputs: dict, max_num_tokens_to_decode: int = 50) -> dict:
69
  english_tokens = model_inputs["english_tokens"].clone()
70
- bio_tokens = model_inputs["bio_tokens"].clone()
 
 
71
  projected_bio_embeddings = None
72
 
73
  actual_num_steps = 0
 
55
  truncation=True,
56
  max_length=english_tokens_max_length,
57
  ).input_ids
58
+ if len(dna_sequences) == 0:
59
+ bio_tokens = None
60
+ else:
61
+ bio_tokens = self.bio_tokenizer(
62
+ dna_sequences,
63
+ return_tensors="pt",
64
+ padding="max_length",
65
+ max_length=bio_tokens_max_length,
66
+ truncation=True,
67
+ ).input_ids.unsqueeze(0)
68
 
69
  return {"english_tokens": english_tokens, "bio_tokens": bio_tokens}
70
 
71
  def _forward(self, model_inputs: dict, max_num_tokens_to_decode: int = 50) -> dict:
72
  english_tokens = model_inputs["english_tokens"].clone()
73
+ bio_tokens = model_inputs["bio_tokens"]
74
+ if bio_tokens is not None:
75
+ bio_tokens = bio_tokens.clone()
76
  projected_bio_embeddings = None
77
 
78
  actual_num_steps = 0