Fix: Respect `is_causal=False` config in forward to enable bidirectional attention (#37)
Fix: Respect `is_causal=False` config in forward to enable bidirectional attention (commit 98540699bd4a2b84b672e00e7302cbc314082bdc)
Co-authored-by: Yihang Wang <[email protected]>
- modeling_qwen.py +4 -4
modeling_qwen.py
CHANGED
|
@@ -350,7 +350,7 @@ class Qwen2FlashAttention2(Qwen2Attention):
|
|
| 350 |
past_key_value: Optional[Cache] = None,
|
| 351 |
output_attentions: bool = False,
|
| 352 |
use_cache: bool = False,
|
| 353 |
-
is_causal: bool =
|
| 354 |
**kwargs,
|
| 355 |
):
|
| 356 |
if "padding_mask" in kwargs:
|
|
@@ -646,7 +646,7 @@ class Qwen2SdpaAttention(Qwen2Attention):
|
|
| 646 |
past_key_value: Optional[Cache] = None,
|
| 647 |
output_attentions: bool = False,
|
| 648 |
use_cache: bool = False,
|
| 649 |
-
is_causal: bool =
|
| 650 |
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
| 651 |
if output_attentions:
|
| 652 |
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
|
|
@@ -965,7 +965,7 @@ class Qwen2Model(Qwen2PreTrainedModel):
|
|
| 965 |
output_hidden_states: Optional[bool] = None,
|
| 966 |
return_dict: Optional[bool] = None,
|
| 967 |
labels: Optional[torch.LongTensor] = None,
|
| 968 |
-
is_causal: Optional[bool] =
|
| 969 |
) -> Union[Tuple, BaseModelOutputWithPast]:
|
| 970 |
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 971 |
output_hidden_states = (
|
|
@@ -1160,7 +1160,7 @@ class Qwen2ForCausalLM(Qwen2PreTrainedModel):
|
|
| 1160 |
output_attentions: Optional[bool] = None,
|
| 1161 |
output_hidden_states: Optional[bool] = None,
|
| 1162 |
return_dict: Optional[bool] = None,
|
| 1163 |
-
is_causal: Optional[bool] =
|
| 1164 |
) -> Union[Tuple, CausalLMOutputWithPast]:
|
| 1165 |
r"""
|
| 1166 |
Args:
|
|
|
|
| 350 |
past_key_value: Optional[Cache] = None,
|
| 351 |
output_attentions: bool = False,
|
| 352 |
use_cache: bool = False,
|
| 353 |
+
is_causal: bool = False,
|
| 354 |
**kwargs,
|
| 355 |
):
|
| 356 |
if "padding_mask" in kwargs:
|
|
|
|
| 646 |
past_key_value: Optional[Cache] = None,
|
| 647 |
output_attentions: bool = False,
|
| 648 |
use_cache: bool = False,
|
| 649 |
+
is_causal: bool = True,
|
| 650 |
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
| 651 |
if output_attentions:
|
| 652 |
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
|
|
|
|
| 965 |
output_hidden_states: Optional[bool] = None,
|
| 966 |
return_dict: Optional[bool] = None,
|
| 967 |
labels: Optional[torch.LongTensor] = None,
|
| 968 |
+
is_causal: Optional[bool] = False,
|
| 969 |
) -> Union[Tuple, BaseModelOutputWithPast]:
|
| 970 |
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 971 |
output_hidden_states = (
|
|
|
|
| 1160 |
output_attentions: Optional[bool] = None,
|
| 1161 |
output_hidden_states: Optional[bool] = None,
|
| 1162 |
return_dict: Optional[bool] = None,
|
| 1163 |
+
is_causal: Optional[bool] = False,
|
| 1164 |
) -> Union[Tuple, CausalLMOutputWithPast]:
|
| 1165 |
r"""
|
| 1166 |
Args:
|