Vittorio Pippi commited on
Commit
6caadef
·
1 Parent(s): 9828a5d

Fix stopping criteria window size calculation to prevent index errors

Browse files
Files changed (1) hide show
  1. modeling_emuru.py +1 -1
modeling_emuru.py CHANGED
@@ -252,7 +252,7 @@ class Emuru(PreTrainedModel):
252
 
253
  if stopping_criteria == 'latent':
254
  similarity = torch.nn.functional.cosine_similarity(canvas_sequence, pad_token, dim=-1)
255
- windows = (similarity > self.padding_token_threshold).unfold(1, stopping_after, 1)
256
  window_sums = windows.to(torch.int).sum(dim=2)
257
 
258
  for i in range(similarity.size(0)):
 
252
 
253
  if stopping_criteria == 'latent':
254
  similarity = torch.nn.functional.cosine_similarity(canvas_sequence, pad_token, dim=-1)
255
+ windows = (similarity > self.padding_token_threshold).unfold(1, min(stopping_after, similarity.size(-1)), 1)
256
  window_sums = windows.to(torch.int).sum(dim=2)
257
 
258
  for i in range(similarity.size(0)):