Spaces:
Running
on
Zero
Running
on
Zero
Update src/layers_cache.py
Browse files
- src/layers_cache.py +2 -2
src/layers_cache.py
CHANGED
@@ -146,7 +146,7 @@ class MultiSingleStreamBlockLoraProcessor(nn.Module):
|
|
146 |
start = i * scaled_cond_size + scaled_block_size
|
147 |
end = (i + 1) * scaled_cond_size + scaled_block_size
|
148 |
mask[start:end, start:end] = 0 # Diagonal blocks
|
149 |
- mask = mask * -  [removed line; the numeric constant was truncated in extraction]
|
150 |
mask = mask.to(query.dtype)
|
151 |
|
152 |
hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False, attn_mask=mask)
|
@@ -305,7 +305,7 @@ class MultiDoubleStreamBlockLoraProcessor(nn.Module):
|
|
305 |
start = i * scaled_cond_size + scaled_block_size
|
306 |
end = (i + 1) * scaled_cond_size + scaled_block_size
|
307 |
mask[start:end, start:end] = 0 # Diagonal blocks
|
308 |
- mask = mask * -  [removed line; the numeric constant was truncated in extraction]
|
309 |
mask = mask.to(query.dtype)
|
310 |
|
311 |
hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False, attn_mask=mask)
|
|
|
146 |
start = i * scaled_cond_size + scaled_block_size
|
147 |
end = (i + 1) * scaled_cond_size + scaled_block_size
|
148 |
mask[start:end, start:end] = 0 # Diagonal blocks
|
149 |
+ mask = mask * -1e10
|
150 |
mask = mask.to(query.dtype)
|
151 |
|
152 |
hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False, attn_mask=mask)
|
|
|
305 |
start = i * scaled_cond_size + scaled_block_size
|
306 |
end = (i + 1) * scaled_cond_size + scaled_block_size
|
307 |
mask[start:end, start:end] = 0 # Diagonal blocks
|
308 |
+ mask = mask * -1e10
|
309 |
mask = mask.to(query.dtype)
|
310 |
|
311 |
hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False, attn_mask=mask)
|