Yanisadel commited on
Commit
b9b892c
·
verified ·
1 Parent(s): 5b49149

Update text_generation.py

Browse files
Files changed (1) hide show
  1. text_generation.py +2 -2
text_generation.py CHANGED
@@ -66,8 +66,8 @@ class TextGenerationPipeline(Pipeline):
66
  return {"english_tokens": english_tokens, "bio_tokens": bio_tokens}
67
 
68
  def _forward(self, model_inputs: dict, max_num_tokens_to_decode: int = 50) -> dict:
69
- english_tokens = model_inputs["english_tokens"].clone().to(torch.bfloat16)
70
- bio_tokens = model_inputs["bio_tokens"].clone().to(torch.bfloat16)
71
  projected_bio_embeddings = None
72
 
73
  actual_num_steps = 0
 
66
  return {"english_tokens": english_tokens, "bio_tokens": bio_tokens}
67
 
68
  def _forward(self, model_inputs: dict, max_num_tokens_to_decode: int = 50) -> dict:
69
+ english_tokens = model_inputs["english_tokens"].clone()
70
+ bio_tokens = model_inputs["bio_tokens"].clone()
71
  projected_bio_embeddings = None
72
 
73
  actual_num_steps = 0