Commit: Update modeling_quiet.py
File changed: modeling_quiet.py (+13 −1)
@@ -1072,7 +1072,16 @@ class QuietModel(QuietPreTrainedModel):
 1072          )
 1073
 1074          if attention_mask is None:
 1075  -           [removed line — content not shown in this view]
 1076
 1077          if attention_mask.dim() == 2:
 1078              attention_mask = attention_mask.view(batch_size, 1, 1, seq_len)
@@ -1880,6 +1889,9 @@ class QuietForCausalLM(QuietPreTrainedModel):
 1880              attention_mask = attention_mask.expand(batch_size, 1, seq_len, seq_len)
 1881          elif attention_mask.dim() != 4:
 1882              raise ValueError(f"Attention mask should be of shape (batch_size, 1, seq_len, seq_len), but got {attention_mask.shape}")
 1883          past_key_values = outputs.past_key_values
 1884          position_ids = position_ids + 1
 1885
|
 1072          )
 1073
 1074          if attention_mask is None:
 1075  +           if input_ids is not None:
 1076  +               attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
 1077  +           elif inputs_embeds is not None:
 1078  +               attention_mask = torch.ones(
 1079  +                   (batch_size, seq_len),
 1080  +                   dtype=torch.bool,
 1081  +                   device=inputs_embeds.device
 1082  +               )
 1083  +           else:
 1084  +               raise ValueError("Either input_ids or inputs_embeds should be provided.")
 1085
 1086          if attention_mask.dim() == 2:
 1087              attention_mask = attention_mask.view(batch_size, 1, 1, seq_len)
|
|
|
 1889              attention_mask = attention_mask.expand(batch_size, 1, seq_len, seq_len)
 1890          elif attention_mask.dim() != 4:
 1891              raise ValueError(f"Attention mask should be of shape (batch_size, 1, seq_len, seq_len), but got {attention_mask.shape}")
 1892  +
 1893  +       attention_mask = attention_mask.to(dtype=torch.bool)
 1894  +
 1895          past_key_values = outputs.past_key_values
 1896          position_ids = position_ids + 1
 1897