Update text_generation.py
Browse files- text_generation.py +13 -8
text_generation.py
CHANGED
@@ -55,19 +55,24 @@ class TextGenerationPipeline(Pipeline):
|
|
55 |
truncation=True,
|
56 |
max_length=english_tokens_max_length,
|
57 |
).input_ids
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
|
|
|
|
|
|
65 |
|
66 |
return {"english_tokens": english_tokens, "bio_tokens": bio_tokens}
|
67 |
|
68 |
def _forward(self, model_inputs: dict, max_num_tokens_to_decode: int = 50) -> dict:
|
69 |
english_tokens = model_inputs["english_tokens"].clone()
|
70 |
-
bio_tokens = model_inputs["bio_tokens"]
|
|
|
|
|
71 |
projected_bio_embeddings = None
|
72 |
|
73 |
actual_num_steps = 0
|
|
|
55 |
truncation=True,
|
56 |
max_length=english_tokens_max_length,
|
57 |
).input_ids
|
58 |
+
if len(dna_sequences) == 0:
|
59 |
+
bio_tokens = None
|
60 |
+
else:
|
61 |
+
bio_tokens = self.bio_tokenizer(
|
62 |
+
dna_sequences,
|
63 |
+
return_tensors="pt",
|
64 |
+
padding="max_length",
|
65 |
+
max_length=bio_tokens_max_length,
|
66 |
+
truncation=True,
|
67 |
+
).input_ids.unsqueeze(0)
|
68 |
|
69 |
return {"english_tokens": english_tokens, "bio_tokens": bio_tokens}
|
70 |
|
71 |
def _forward(self, model_inputs: dict, max_num_tokens_to_decode: int = 50) -> dict:
|
72 |
english_tokens = model_inputs["english_tokens"].clone()
|
73 |
+
bio_tokens = model_inputs["bio_tokens"]
|
74 |
+
if bio_tokens is not None:
|
75 |
+
bio_tokens = bio_tokens.clone()
|
76 |
projected_bio_embeddings = None
|
77 |
|
78 |
actual_num_steps = 0
|