fix: cur_len mismatch
Browse files- modeling_asteroid.py +1 -0
modeling_asteroid.py
CHANGED
@@ -85,6 +85,7 @@ class CustomMixin(GenerationMixin):
|
|
85 |
needs_additional_steps = -1 * torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
|
86 |
tf_inputs = input_ids[:]
|
87 |
input_ids = input_ids[:, :-(channels - 1)]
|
|
|
88 |
model_kwargs["attention_mask"] = model_kwargs["attention_mask"][:, :-(channels - 1)]
|
89 |
base_length = input_ids.shape[1]
|
90 |
model_kwargs = self._get_initial_cache_position(cur_len, input_ids.device, model_kwargs)
|
|
|
85 |
needs_additional_steps = -1 * torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
|
86 |
tf_inputs = input_ids[:]
|
87 |
input_ids = input_ids[:, :-(channels - 1)]
|
88 |
+
cur_len = input_ids.shape[1]
|
89 |
model_kwargs["attention_mask"] = model_kwargs["attention_mask"][:, :-(channels - 1)]
|
90 |
base_length = input_ids.shape[1]
|
91 |
model_kwargs = self._get_initial_cache_position(cur_len, input_ids.device, model_kwargs)
|