Commit: Update modeling_quiet.py
File changed: modeling_quiet.py (+2 additions, -11 deletions)
@@ -1072,16 +1072,7 @@ class QuietModel(QuietPreTrainedModel):
         )

         if attention_mask is None:
-
-            attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
-        elif inputs_embeds is not None:
-            attention_mask = torch.ones(
-                (batch_size, seq_len),
-                dtype=torch.bool,
-                device=inputs_embeds.device
-            )
-        else:
-            raise ValueError("Either input_ids or inputs_embeds should be provided.")
+            attention_mask = torch.ones((batch_size, seq_len), dtype=torch.bool, device=inputs_embeds.device)

         if attention_mask.dim() == 2:
             attention_mask = attention_mask.view(batch_size, 1, 1, seq_len)
@@ -1091,7 +1082,7 @@ class QuietModel(QuietPreTrainedModel):
         elif attention_mask.dim() != 4:
             raise ValueError(f"Attention mask should be of shape (batch_size, 1, seq_len, seq_len) or (batch_size, 1, 1, seq_len), but got {attention_mask.shape}")

-        attention_mask = attention_mask.to(dtype=torch.bool
+        attention_mask = attention_mask.to(dtype=torch.bool)

         hidden_states = inputs_embeds