Crystalcareai committed on
Commit d33b844 · verified · 1 Parent(s): 21d94a3

Update modeling_quiet.py

Files changed (1)
  1. modeling_quiet.py +20 -44
modeling_quiet.py CHANGED
@@ -1071,27 +1071,18 @@ class QuietModel(QuietPreTrainedModel):
                     " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
                 )

-        if self._attn_implementation == "flash_attention_2":
-            # 2d mask is passed through the layers
-            attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
-        elif self._attn_implementation == "sdpa" and not output_attentions and attention_mask.dim() == 2 and False:
-            # output_attentions=True can not be supported when using SDPA, and we fall back on
-            # the manual implementation that requires a 4D causal mask in all cases.
-            attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
-                attention_mask,
-                (batch_size, seq_length),
-                inputs_embeds,
-                past_key_values_length,
-            )
-        elif attention_mask is None or attention_mask.dim() == 2:
-            # 4d mask is passed through the layers
-            attention_mask = _prepare_4d_causal_attention_mask(
-                attention_mask,
-                (batch_size, seq_length),
-                inputs_embeds,
-                past_key_values_length,
-                sliding_window=self.config.sliding_window,
-            )
+        if attention_mask is None:
+            attention_mask = torch.ones((batch_size, seq_len), dtype=torch.bool, device=input_ids.device)
+
+        if attention_mask.dim() == 2:
+            attention_mask = attention_mask.view(batch_size, 1, 1, seq_len)
+            attention_mask = attention_mask.expand(batch_size, 1, seq_len, seq_len)
+        elif attention_mask.dim() == 3:
+            attention_mask = attention_mask.unsqueeze(1)
+        elif attention_mask.dim() != 4:
+            raise ValueError(f"Attention mask should be of shape (batch_size, 1, seq_len, seq_len) or (batch_size, 1, 1, seq_len), but got {attention_mask.shape}")
+
+        attention_mask = attention_mask.to(dtype=torch.bool, device=input_ids.device)

        hidden_states = inputs_embeds
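For reference, a minimal standalone sketch of the mask normalization this hunk introduces in QuietModel.forward: a missing mask defaults to all ones, and any 2D or 3D mask is broadcast to a boolean mask of shape (batch_size, 1, seq_len, seq_len). The helper name normalize_attention_mask and the example call are illustrative only, not part of the model code.

import torch

def normalize_attention_mask(attention_mask, batch_size, seq_len, device):
    # Mirror of the hunk above: default to an all-ones mask, then broadcast
    # a 2D (batch, seq) padding mask to 4D (batch, 1, seq, seq) boolean form.
    if attention_mask is None:
        attention_mask = torch.ones((batch_size, seq_len), dtype=torch.bool, device=device)
    if attention_mask.dim() == 2:
        attention_mask = attention_mask.view(batch_size, 1, 1, seq_len)
        attention_mask = attention_mask.expand(batch_size, 1, seq_len, seq_len)
    elif attention_mask.dim() == 3:
        attention_mask = attention_mask.unsqueeze(1)
    elif attention_mask.dim() != 4:
        raise ValueError(f"Unexpected attention mask shape: {attention_mask.shape}")
    return attention_mask.to(dtype=torch.bool, device=device)

# Example: one left-padded position in a batch of one.
mask = torch.tensor([[0, 1, 1, 1]])                       # (batch_size=1, seq_len=4)
print(normalize_attention_mask(mask, 1, 4, "cpu").shape)  # torch.Size([1, 1, 4, 4])

Note that this broadcast only repeats the padding pattern along the query dimension; unlike the removed _prepare_4d_causal_attention_mask path, it does not itself encode causality or the sliding window.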
 
@@ -1883,29 +1874,14 @@ class QuietForCausalLM(QuietPreTrainedModel):
                inputs_embeds = inputs_embeds.view(probabilities.size(0), probabilities.size(1), -1).to(self.model.embed_tokens.weight.dtype)
                inputs_embeds = inputs_embeds.view(probabilities.size(0), probabilities.size(1), -1).to(self.model.embed_tokens.weight.dtype)

-                if len(attention_mask.shape) == 2:
-                    breakpoint()
-                else:
-                    original_attention = attention_mask[..., :attention_mask.shape[-2]]
-                    if self.use_upper_triangular:
-                        new_attention = original_attention
-                    else:
-                        original_attention = original_attention == attention_mask.max()
-                        # because eye isn't implemented for BF16, we need to handle the case
-                        if not attention_mask.dtype == torch.bfloat16:
-                            new_attention = torch.eye(
-                                seq_len, dtype=attention_mask.dtype, device=attention_mask.device
-                            )
-                        else:
-                            new_attention = torch.eye(
-                                seq_len, dtype=torch.float32, device=attention_mask.device
-                            ).to(attention_mask.dtype)
-
-                        new_attention = new_attention.view(1, 1, seq_len, seq_len).repeat(input_ids.shape[0], 1, 1, 1)
-                        new_attention = new_attention * original_attention
-                        new_attention[new_attention == 0] = attention_mask.min()
-                        new_attention[new_attention == 1] = attention_mask.max()
-                    attention_mask = torch.cat([attention_mask, new_attention], dim=-1)
+                if attention_mask is not None:
+                    if attention_mask.dim() == 2:
+                        attention_mask = attention_mask.view(batch_size, 1, 1, seq_len)
+                        attention_mask = attention_mask.expand(batch_size, 1, seq_len, seq_len)
+                    elif attention_mask.dim() != 4:
+                        raise ValueError(f"Attention mask should be of shape (batch_size, 1, seq_len, seq_len), but got {attention_mask.shape}")
+
+                    attention_mask = attention_mask.to(dtype=torch.bool, device=input_ids.device)
                past_key_values = outputs.past_key_values
                position_ids = position_ids + 1
 
 
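Similarly, a minimal sketch of the guarded variant in the QuietForCausalLM hunk above, where the mask may legitimately be None and 3D masks are not accepted. The helper name maybe_normalize_mask is illustrative, not from the model code.

import torch

def maybe_normalize_mask(attention_mask, batch_size, seq_len, device):
    # Mirror of the QuietForCausalLM hunk: skip normalization when no mask is
    # given, otherwise broadcast 2D (batch, seq) to 4D (batch, 1, seq, seq) bool.
    if attention_mask is None:
        return None
    if attention_mask.dim() == 2:
        attention_mask = attention_mask.view(batch_size, 1, 1, seq_len)
        attention_mask = attention_mask.expand(batch_size, 1, seq_len, seq_len)
    elif attention_mask.dim() != 4:
        raise ValueError(f"Unexpected attention mask shape: {attention_mask.shape}")
    return attention_mask.to(dtype=torch.bool, device=device)

print(maybe_normalize_mask(None, 2, 8, "cpu"))                     # None
print(maybe_normalize_mask(torch.ones(2, 8), 2, 8, "cpu").shape)   # torch.Size([2, 1, 8, 8])
try:
    maybe_normalize_mask(torch.ones(2, 8, 8), 2, 8, "cpu")         # 3D masks are rejected in
except ValueError as err:                                          # this path, unlike QuietModel
    print(err)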