leejunhyeok iamwyldecat committed on
Commit 1c6e6f5 · verified · 1 Parent(s): a3eb76a

refactor: print flash attn usage log only once (#4)


- refactor: print flash attn usage log only once (c013e88166f88ab0d45328cb2654914802b127e6)


Co-authored-by: Jeesoo Lee <[email protected]>

Files changed (1)
  1. modeling_motif.py +2 -2
modeling_motif.py CHANGED

@@ -472,8 +472,6 @@ class MotifFlashAttention2(MotifAttention):
 
         self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
 
-        logger.info(f'flash attention is used {not self._flash_attn_uses_top_left_mask}')
-
     def _reshape_heads(self, tensor, batch_size, seq_len):
         """2-way head split tensor reshape"""
         return tensor.reshape(batch_size, seq_len, self.num_heads, 2, self.head_dim)

@@ -981,6 +979,8 @@ class MotifModel(MotifPreTrainedModel):
         self.gradient_checkpointing = False
         self.post_init()
 
+        logger.info(f'Using flash_attn: {is_flash_attn_greater_or_equal_2_10()}')
+
     def get_input_embeddings(self):
         return self.embed_tokens
 
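
For context, a minimal runnable sketch of why moving the log call deduplicates the message. It uses hypothetical AttentionLayer / Model stand-ins for MotifFlashAttention2 / MotifModel and a stubbed is_flash_attn_greater_or_equal_2_10 (the real helper lives in transformers); the point is only that the attention class is constructed once per decoder layer, while the model is constructed once.

import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def is_flash_attn_greater_or_equal_2_10():
    # Stand-in for the transformers utility of the same name; assumed True here.
    return True


class AttentionLayer:
    """Hypothetical stand-in for MotifFlashAttention2 (one instance per decoder layer)."""

    def __init__(self):
        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
        # Before this commit, a logger.info call here fired once per layer,
        # so a 32-layer model repeated the same message 32 times.


class Model:
    """Hypothetical stand-in for MotifModel (one instance per model load)."""

    def __init__(self, num_hidden_layers=32):
        self.layers = [AttentionLayer() for _ in range(num_hidden_layers)]
        # After this commit, the usage log lives at the model level and fires once.
        logger.info(f'Using flash_attn: {is_flash_attn_greater_or_equal_2_10()}')


Model()  # logs "Using flash_attn: True" exactly once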