Crystalcareai committed (verified)
Commit 833b955 · 1 Parent(s): 23c0feb

Update modeling_quiet.py

Files changed (1):
  1. modeling_quiet.py +13 -1
modeling_quiet.py CHANGED

@@ -1072,7 +1072,16 @@ class QuietModel(QuietPreTrainedModel):
             )
 
         if attention_mask is None:
-            attention_mask = torch.ones((batch_size, seq_len), dtype=torch.bool, device=input_ids.device)
+            if input_ids is not None:
+                attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
+            elif inputs_embeds is not None:
+                attention_mask = torch.ones(
+                    (batch_size, seq_len),
+                    dtype=torch.bool,
+                    device=inputs_embeds.device
+                )
+            else:
+                raise ValueError("Either input_ids or inputs_embeds should be provided.")
 
         if attention_mask.dim() == 2:
             attention_mask = attention_mask.view(batch_size, 1, 1, seq_len)

@@ -1880,6 +1889,9 @@ class QuietForCausalLM(QuietPreTrainedModel):
             attention_mask = attention_mask.expand(batch_size, 1, seq_len, seq_len)
         elif attention_mask.dim() != 4:
             raise ValueError(f"Attention mask should be of shape (batch_size, 1, seq_len, seq_len), but got {attention_mask.shape}")
+
+        attention_mask = attention_mask.to(dtype=torch.bool)
+
        past_key_values = outputs.past_key_values
        position_ids = position_ids + 1
 
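For reference, a minimal standalone sketch of the mask-construction fallback introduced in the first hunk. The helper name build_default_attention_mask and the explicit batch_size/seq_len parameters are illustrative only; inside QuietModel these values come from the surrounding forward pass.

import torch

def build_default_attention_mask(input_ids=None, inputs_embeds=None,
                                 batch_size=None, seq_len=None):
    # Prefer input_ids so the mask inherits its shape and device directly;
    # otherwise fall back to inputs_embeds, building the mask explicitly.
    if input_ids is not None:
        return torch.ones_like(input_ids, dtype=torch.bool)
    elif inputs_embeds is not None:
        return torch.ones((batch_size, seq_len), dtype=torch.bool,
                          device=inputs_embeds.device)
    else:
        raise ValueError("Either input_ids or inputs_embeds should be provided.")

# ids-only call: mask matches the token tensor's shape and device.
ids = torch.randint(0, 100, (2, 5))
print(build_default_attention_mask(input_ids=ids).shape)   # torch.Size([2, 5])

# embeddings-only call: shape information must be supplied explicitly here.
emb = torch.randn(2, 5, 16)
mask = build_default_attention_mask(inputs_embeds=emb, batch_size=2, seq_len=5)
print(mask.dtype)  # torch.bool

The second hunk serves the same purpose as the dtype arguments above: casting any user-supplied attention mask with attention_mask.to(dtype=torch.bool) keeps later masking logic on a single boolean dtype regardless of what the caller passed in.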