Update modeling_quiet.py
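This commit replaces the inherited `_prepare_4d_causal_attention_mask` mask-preparation path in `QuietModel.forward` with an explicit normalization of the attention mask to a 4D boolean tensor, and swaps the `torch.eye`-based mask extension in the `QuietForCausalLM` generation loop for the same boolean normalization.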
modeling_quiet.py  +20 -44  CHANGED
```diff
@@ -1071,27 +1071,18 @@ class QuietModel(QuietPreTrainedModel):
                     " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
                 )
 
-        if self._attn_implementation == "flash_attention_2":
-            # 2d mask is passed through the layers
-            attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
-        elif self._attn_implementation == "sdpa" and not output_attentions:
-            # output_attentions=True can not be supported when using SDPA, and we fall back on
-            # the manual implementation that requires a 4D causal mask in all cases.
-            attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
-                attention_mask,
-                (batch_size, seq_length),
-                inputs_embeds,
-                past_key_values_length,
-            )
-        elif attention_mask is None or attention_mask.dim() == 2:
-            # 4d mask is passed through the layers
-            attention_mask = _prepare_4d_causal_attention_mask(
-                attention_mask,
-                (batch_size, seq_length),
-                inputs_embeds,
-                past_key_values_length,
-                sliding_window=self.config.sliding_window,
-            )
+        if attention_mask is None:
+            attention_mask = torch.ones((batch_size, seq_len), dtype=torch.bool, device=input_ids.device)
+
+        if attention_mask.dim() == 2:
+            attention_mask = attention_mask.view(batch_size, 1, 1, seq_len)
+            attention_mask = attention_mask.expand(batch_size, 1, seq_len, seq_len)
+        elif attention_mask.dim() == 3:
+            attention_mask = attention_mask.unsqueeze(1)
+        elif attention_mask.dim() != 4:
+            raise ValueError(f"Attention mask should be of shape (batch_size, 1, seq_len, seq_len) or (batch_size, 1, 1, seq_len), but got {attention_mask.shape}")
+
+        attention_mask = attention_mask.to(dtype=torch.bool, device=input_ids.device)
 
         hidden_states = inputs_embeds
 
```
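In place of the transformers mask helpers, the new code coerces whatever mask arrives into a `(batch_size, 1, seq_len, seq_len)` boolean tensor. Here is a minimal standalone sketch of that normalization; the helper name `normalize_attention_mask` is illustrative, not part of the file:

```python
import torch

def normalize_attention_mask(attention_mask, batch_size, seq_len, device):
    """Illustrative helper mirroring the added logic: coerce any incoming
    mask to a (batch_size, 1, seq_len, seq_len) boolean tensor."""
    if attention_mask is None:
        # No mask supplied: every position may attend everywhere.
        attention_mask = torch.ones((batch_size, seq_len), dtype=torch.bool, device=device)
    if attention_mask.dim() == 2:
        # (batch, seq) padding mask: broadcast each key column over every query row.
        attention_mask = attention_mask.view(batch_size, 1, 1, seq_len)
        attention_mask = attention_mask.expand(batch_size, 1, seq_len, seq_len)
    elif attention_mask.dim() == 3:
        # (batch, seq, seq): insert the singleton head dimension.
        attention_mask = attention_mask.unsqueeze(1)
    elif attention_mask.dim() != 4:
        raise ValueError(f"Unsupported attention mask shape: {attention_mask.shape}")
    return attention_mask.to(dtype=torch.bool, device=device)

padding_mask = torch.tensor([[1, 1, 0]])  # last position is padding
mask_4d = normalize_attention_mask(padding_mask, batch_size=1, seq_len=3, device="cpu")
print(mask_4d.shape)     # torch.Size([1, 1, 3, 3])
print(mask_4d[0, 0, 0])  # tensor([ True,  True, False])
```

Worth noting: unlike `_prepare_4d_causal_attention_mask`, this expanded padding mask carries no causal (lower-triangular) structure and is boolean rather than additive, so the attention layers consuming it are presumably expected to enforce causality themselves and to treat `True` as "may attend".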
```diff
@@ -1883,29 +1874,14 @@ class QuietForCausalLM(QuietPreTrainedModel):
             inputs_embeds = inputs_embeds.view(probabilities.size(0), probabilities.size(1), -1).to(self.model.embed_tokens.weight.dtype)
             inputs_embeds = inputs_embeds.view(probabilities.size(0), probabilities.size(1), -1).to(self.model.embed_tokens.weight.dtype)
 
-            if len(attention_mask.shape) == 2:
-                breakpoint()
-            else:
-                original_attention = attention_mask[..., :attention_mask.shape[-2]]
-            if self.use_upper_triangular:
-                new_attention = original_attention
-            else:
-                original_attention = original_attention == attention_mask.max()
-                # because eye isn't implemented for BF16, we need to handle the case
-                if not attention_mask.dtype == torch.bfloat16:
-                    new_attention = torch.eye(
-                        seq_len, dtype=attention_mask.dtype, device=attention_mask.device
-                    )
-                else:
-                    new_attention = torch.eye(
-                        seq_len, dtype=torch.float32, device=attention_mask.device
-                    ).to(attention_mask.dtype)
-
-                new_attention = new_attention.view(1, 1, seq_len, seq_len).repeat(input_ids.shape[0], 1, 1, 1)
-                new_attention = new_attention * original_attention
-                new_attention[new_attention == 0] = attention_mask.min()
-                new_attention[new_attention == 1] = attention_mask.max()
-            attention_mask = torch.cat([attention_mask, new_attention], dim=-1)
+            if attention_mask is not None:
+                if attention_mask.dim() == 2:
+                    attention_mask = attention_mask.view(batch_size, 1, 1, seq_len)
+                    attention_mask = attention_mask.expand(batch_size, 1, seq_len, seq_len)
+                elif attention_mask.dim() != 4:
+                    raise ValueError(f"Attention mask should be of shape (batch_size, 1, seq_len, seq_len), but got {attention_mask.shape}")
+
+                attention_mask = attention_mask.to(dtype=torch.bool, device=input_ids.device)
             past_key_values = outputs.past_key_values
             position_ids = position_ids + 1
 
```
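The deleted branch grew the additive attention mask by one diagonal block per generation step, letting each query position attend to the thought token just produced at that position; the float32 detour exists because, per the removed comment, `torch.eye` had no bfloat16 kernel at the time. A toy reconstruction of that step, with made-up mask values:

```python
import torch

# Toy reconstruction of the removed mask-extension step (values are made up):
# append one diagonal block along the key axis so each query position can also
# attend to the token just generated at that position.
batch, seq_len = 1, 3
max_val, min_val = 0.0, torch.finfo(torch.float32).min
attention_mask = torch.full((batch, 1, seq_len, seq_len), max_val)  # additive mask, all allowed

original_attention = attention_mask == attention_mask.max()  # bool map of allowed positions

# bf16 workaround from the removed code: build the eye in float32, then cast.
new_attention = torch.eye(seq_len, dtype=torch.float32, device=attention_mask.device).to(attention_mask.dtype)
new_attention = new_attention.view(1, 1, seq_len, seq_len).repeat(batch, 1, 1, 1)
new_attention = new_attention * original_attention  # keep diagonal only where allowed
new_attention[new_attention == 0] = min_val         # blocked -> large negative
new_attention[new_attention == 1] = max_val         # allowed -> 0
attention_mask = torch.cat([attention_mask, new_attention], dim=-1)
print(attention_mask.shape)  # torch.Size([1, 1, 3, 6]) -- one extra key block
```

After this commit no extra key block is concatenated; the mask is simply reshaped to 4D booleans, so any per-step mask growth for thought tokens is presumably handled elsewhere.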