jamesliu1217 committed
Commit 7761629 · verified · Parent(s): 4b47d90

Update src/layers_cache.py

Files changed (1)
  1. src/layers_cache.py +2 -2
src/layers_cache.py CHANGED
@@ -146,7 +146,7 @@ class MultiSingleStreamBlockLoraProcessor(nn.Module):
             start = i * scaled_cond_size + scaled_block_size
             end = (i + 1) * scaled_cond_size + scaled_block_size
             mask[start:end, start:end] = 0  # Diagonal blocks
-        mask = mask * -1e20
+        mask = mask * -1e10
         mask = mask.to(query.dtype)
 
         hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False, attn_mask=mask)
@@ -305,7 +305,7 @@ class MultiDoubleStreamBlockLoraProcessor(nn.Module):
             start = i * scaled_cond_size + scaled_block_size
             end = (i + 1) * scaled_cond_size + scaled_block_size
             mask[start:end, start:end] = 0  # Diagonal blocks
-        mask = mask * -1e20
+        mask = mask * -1e10
         mask = mask.to(query.dtype)
 
         hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False, attn_mask=mask)
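
What changed: in both MultiSingleStreamBlockLoraProcessor and MultiDoubleStreamBlockLoraProcessor, the additive attention-mask penalty is lowered from -1e20 to -1e10 before the mask is cast to the query dtype and passed to F.scaled_dot_product_attention. The sketch below is a minimal, self-contained illustration of that masking pattern, not the repository's actual code: the tensor shapes, num_conds, and the mask construction outside the visible loop are assumptions; only the diagonal-block loop, the scaling line, the dtype cast, and the SDPA call mirror the diff.

# Minimal sketch of the additive attention-mask pattern from layers_cache.py.
# Everything outside the lines taken from the diff is hypothetical illustration.
import torch
import torch.nn.functional as F

num_heads, head_dim = 4, 64
scaled_block_size = 8      # assumed size of the freely-attending main token block
scaled_cond_size = 4       # assumed per-condition token count
num_conds = 2              # assumed number of condition branches
seq_len = scaled_block_size + num_conds * scaled_cond_size

query = torch.randn(1, num_heads, seq_len, head_dim)
key = torch.randn_like(query)
value = torch.randn_like(query)

# Hypothetical construction: 1 marks token pairs that should NOT attend to each other.
mask = torch.ones(seq_len, seq_len)
mask[:scaled_block_size, :] = 0          # main block attends everywhere (assumed)
mask[:, :scaled_block_size] = 0          # everything attends to the main block (assumed)
for i in range(num_conds):
    start = i * scaled_cond_size + scaled_block_size
    end = (i + 1) * scaled_cond_size + scaled_block_size
    mask[start:end, start:end] = 0       # Diagonal blocks: each condition attends to itself

# The commit's change: use -1e10 instead of -1e20 as the additive penalty.
# Either value pushes masked logits far enough below the rest that softmax
# assigns them ~0 weight; the smaller magnitude is simply a less extreme bias.
mask = mask * -1e10
mask = mask.to(query.dtype)

hidden_states = F.scaled_dot_product_attention(
    query, key, value, dropout_p=0.0, is_causal=False, attn_mask=mask
)
print(hidden_states.shape)  # torch.Size([1, 4, 16, 64])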