refactor: print flash attn usage log only once (#4)
Commit c013e88166f88ab0d45328cb2654914802b127e6
Co-authored-by: Jeesoo Lee <[email protected]>
- modeling_motif.py +2 -2
modeling_motif.py CHANGED
@@ -472,8 +472,6 @@ class MotifFlashAttention2(MotifAttention):
 
         self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
 
-        logger.info(f'flash attention is used {not self._flash_attn_uses_top_left_mask}')
-
     def _reshape_heads(self, tensor, batch_size, seq_len):
         """2-way head split tensor reshape"""
         return tensor.reshape(batch_size, seq_len, self.num_heads, 2, self.head_dim)
@@ -981,6 +979,8 @@ class MotifModel(MotifPreTrainedModel):
         self.gradient_checkpointing = False
         self.post_init()
 
+        logger.info(f'Using flash_attn: {is_flash_attn_greater_or_equal_2_10()}')
+
     def get_input_embeddings(self):
         return self.embed_tokens
 
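The motivation behind the move: `MotifFlashAttention2` is constructed once per decoder layer, so logging in its `__init__` repeats the message once per layer, while `MotifModel.__init__` runs once per model. Below is a minimal, self-contained sketch of that effect; the layer structure, the `num_hidden_layers` default, and the local stand-in for `is_flash_attn_greater_or_equal_2_10()` are assumptions for illustration, not the actual modeling_motif.py code.

import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def is_flash_attn_greater_or_equal_2_10() -> bool:
    # Stand-in for the transformers version check used in the diff (assumed True here).
    return True


class MotifFlashAttention2:
    def __init__(self, layer_idx: int):
        self.layer_idx = layer_idx
        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
        # Before this commit the logger.info call lived here, so a model with N layers
        # printed the flash-attention message N times.


class MotifModel:
    def __init__(self, num_hidden_layers: int = 4):
        # One attention module per decoder layer (simplified: layers reduced to attention only).
        self.layers = [MotifFlashAttention2(i) for i in range(num_hidden_layers)]
        # After this commit the log fires here, exactly once per model instance.
        logger.info(f'Using flash_attn: {is_flash_attn_greater_or_equal_2_10()}')


if __name__ == "__main__":
    MotifModel()  # prints a single "Using flash_attn: True" line, regardless of layer count

Note that the new message reports the version check directly rather than the derived `_flash_attn_uses_top_left_mask` flag, since that attribute belongs to the attention class and is no longer in scope at the model level.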