Merge pull request #30 from LightricksResearch/fix-no-flash-attention
model: fix flash attention enabling - do not check device type at this point
xora/models/transformers/attention.py

@@ -179,15 +179,14 @@ class BasicTransformerBlock(nn.Module):
         self._chunk_size = None
         self._chunk_dim = 0
 
-    def set_use_tpu_flash_attention(self, device):
+    def set_use_tpu_flash_attention(self):
         r"""
         Function sets the flag in this object and propagates down the children. The flag will enforce the usage of TPU
         attention kernel.
         """
-        self.attn2.set_use_tpu_flash_attention(device)
+        self.use_tpu_flash_attention = True
+        self.attn1.set_use_tpu_flash_attention()
+        self.attn2.set_use_tpu_flash_attention()
 
     def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
         # Sets chunk feed-forward

@@ -508,12 +507,11 @@ class Attention(nn.Module):
         processor = AttnProcessor2_0()
         self.set_processor(processor)
 
-    def set_use_tpu_flash_attention(self, device):
+    def set_use_tpu_flash_attention(self):
         r"""
         Function sets the flag in this object. The flag will enforce the usage of TPU attention kernel.
         """
-            self.use_tpu_flash_attention = True
+        self.use_tpu_flash_attention = True
 
     def set_processor(self, processor: "AttnProcessor") -> None:
         r"""
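After this change the block-level setter is a plain flag propagator: BasicTransformerBlock records the flag and forwards the same parameterless call to its attn1 and attn2 submodules, and Attention just flips its own boolean. A minimal standalone sketch of that propagation pattern (simplified stand-ins for illustration, not the real classes from attention.py):

class _AttentionStub:
    def __init__(self):
        self.use_tpu_flash_attention = False

    def set_use_tpu_flash_attention(self):
        # No device inspection here any more: just record that the TPU kernel is requested.
        self.use_tpu_flash_attention = True


class _BlockStub:
    def __init__(self):
        self.use_tpu_flash_attention = False
        self.attn1 = _AttentionStub()
        self.attn2 = _AttentionStub()

    def set_use_tpu_flash_attention(self):
        # Set the local flag and push it down to both attention submodules.
        self.use_tpu_flash_attention = True
        self.attn1.set_use_tpu_flash_attention()
        self.attn2.set_use_tpu_flash_attention()


block = _BlockStub()
block.set_use_tpu_flash_attention()
assert block.attn1.use_tpu_flash_attention and block.attn2.use_tpu_flash_attention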
xora/models/transformers/transformer3d.py

@@ -160,13 +160,11 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
         Function sets the flag in this object and propagates down the children. The flag will enforce the usage of TPU
         attention kernel.
         """
-        logger.info("
-        for block in self.transformer_blocks:
-            block.set_use_tpu_flash_attention(self.device.type)
+        logger.info("ENABLE TPU FLASH ATTENTION -> TRUE")
+        self.use_tpu_flash_attention = True
+        # push config down to the attention modules
+        for block in self.transformer_blocks:
+            block.set_use_tpu_flash_attention()
 
     def initialize(self, embedding_std: float, mode: Literal["xora", "legacy"]):
         def _basic_init(module):
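With the device-type check removed from all three setters, deciding when to enable the TPU kernel now rests with the call site. A hypothetical caller-side guard (the import path is derived from the file paths above, and the loading call is an assumption, not part of this diff), relying on the device property that diffusers' ModelMixin provides:

from xora.models.transformers.transformer3d import Transformer3DModel

transformer = Transformer3DModel.from_pretrained("path/to/checkpoint")  # assumed loading path
if transformer.device.type == "xla":
    # Only request the TPU flash-attention kernel when the model actually sits on an XLA device.
    transformer.set_use_tpu_flash_attention()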