Update modeling_custom.py
Browse files- modeling_custom.py +1 -1
modeling_custom.py
CHANGED
@@ -410,7 +410,7 @@ class GPTNeoXFlashAttention2(GPTNeoXAttention):
|
|
410 |
|
411 |
attention_dropout = self.config.attention_dropout if self.training else 0.0
|
412 |
|
413 |
-
#TODO: Compute attention
|
414 |
attn_weights = ...
|
415 |
|
416 |
#TODO: Reshape outputs before projection
|
|
|
410 |
|
411 |
attention_dropout = self.config.attention_dropout if self.training else 0.0
|
412 |
|
413 |
+
#TODO: Compute attention with _flash_attention_forward
|
414 |
attn_weights = ...
|
415 |
|
416 |
#TODO: Reshape outputs before projection
|