yhzx233 committed on
Commit
7ddb63f
·
1 Parent(s): 50730e1

fix: cur_len mismatch

Browse files
Files changed (1) hide show
  1. modeling_asteroid.py +1 -0
modeling_asteroid.py CHANGED
@@ -85,6 +85,7 @@ class CustomMixin(GenerationMixin):
85
  needs_additional_steps = -1 * torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
86
  tf_inputs = input_ids[:]
87
  input_ids = input_ids[:, :-(channels - 1)]
 
88
  model_kwargs["attention_mask"] = model_kwargs["attention_mask"][:, :-(channels - 1)]
89
  base_length = input_ids.shape[1]
90
  model_kwargs = self._get_initial_cache_position(cur_len, input_ids.device, model_kwargs)
 
85
  needs_additional_steps = -1 * torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
86
  tf_inputs = input_ids[:]
87
  input_ids = input_ids[:, :-(channels - 1)]
88
+ cur_len = input_ids.shape[1]
89
  model_kwargs["attention_mask"] = model_kwargs["attention_mask"][:, :-(channels - 1)]
90
  base_length = input_ids.shape[1]
91
  model_kwargs = self._get_initial_cache_position(cur_len, input_ids.device, model_kwargs)