Commit dfc5db4 by Aatricks (verified) · Parent: 68bad37

Upload folder using huggingface_hub
modules/Attention/AttentionMethods.py CHANGED
@@ -4,9 +4,17 @@ except ImportError:
     pass
 import torch
 
+BROKEN_XFORMERS = False
+try:
+    x_vers = xformers.__version__
+    # XFormers bug confirmed on all versions from 0.0.21 to 0.0.26 (q with bs bigger than 65535 gives CUDA error)
+    BROKEN_XFORMERS = x_vers.startswith("0.0.2") and not x_vers.startswith("0.0.20")
+except:
+    pass
+
 
 def attention_xformers(
-    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, heads: int, mask=None, skip_reshape=False
+    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, heads: int, mask=None, skip_reshape=False, flux=False
 ) -> torch.Tensor:
     """#### Make an attention call using xformers. Fastest attention implementation.
 
@@ -20,31 +28,84 @@ def attention_xformers(
     #### Returns:
         - `torch.Tensor`: The output tensor.
     """
-    b, _, dim_head = q.shape
-    dim_head //= heads
-
-    q, k, v = map(
-        lambda t: t.unsqueeze(3)
-        .reshape(b, -1, heads, dim_head)
-        .permute(0, 2, 1, 3)
-        .reshape(b * heads, -1, dim_head)
-        .contiguous(),
-        (q, k, v),
-    )
-
-    out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask)
-
-    out = (
-        out.unsqueeze(0)
-        .reshape(b, heads, -1, dim_head)
-        .permute(0, 2, 1, 3)
-        .reshape(b, -1, heads * dim_head)
-    )
-    return out
+    if not flux:
+        b, _, dim_head = q.shape
+        dim_head //= heads
+
+        q, k, v = map(
+            lambda t: t.unsqueeze(3)
+            .reshape(b, -1, heads, dim_head)
+            .permute(0, 2, 1, 3)
+            .reshape(b * heads, -1, dim_head)
+            .contiguous(),
+            (q, k, v),
+        )
+
+        out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask)
+
+        out = (
+            out.unsqueeze(0)
+            .reshape(b, heads, -1, dim_head)
+            .permute(0, 2, 1, 3)
+            .reshape(b, -1, heads * dim_head)
+        )
+        return out
+    else:
+        if skip_reshape:
+            b, _, _, dim_head = q.shape
+        else:
+            b, _, dim_head = q.shape
+            dim_head //= heads
+
+        disabled_xformers = False
+
+        if BROKEN_XFORMERS:
+            if b * heads > 65535:
+                disabled_xformers = True
+
+        if not disabled_xformers:
+            if torch.jit.is_tracing() or torch.jit.is_scripting():
+                disabled_xformers = True
+
+        if disabled_xformers:
+            return attention_pytorch(q, k, v, heads, mask, skip_reshape=skip_reshape)
+
+        if skip_reshape:
+            q, k, v = map(
+                lambda t: t.reshape(b * heads, -1, dim_head),
+                (q, k, v),
+            )
+        else:
+            q, k, v = map(
+                lambda t: t.reshape(b, -1, heads, dim_head),
+                (q, k, v),
+            )
+
+        if mask is not None:
+            pad = 8 - q.shape[1] % 8
+            mask_out = torch.empty(
+                [q.shape[0], q.shape[1], q.shape[1] + pad], dtype=q.dtype, device=q.device
+            )
+            mask_out[:, :, : mask.shape[-1]] = mask
+            mask = mask_out[:, :, : mask.shape[-1]]
+
+        out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask)
+
+        if skip_reshape:
+            out = (
+                out.unsqueeze(0)
+                .reshape(b, heads, -1, dim_head)
+                .permute(0, 2, 1, 3)
+                .reshape(b, -1, heads * dim_head)
+            )
+        else:
+            out = out.reshape(b, -1, heads * dim_head)
+
+        return out
 
 
 def attention_pytorch(
-    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, heads: int, mask=None, skip_reshape=False
+    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, heads: int, mask=None, skip_reshape=False, flux=False
 ) -> torch.Tensor:
     """#### Make an attention call using PyTorch.
 
@@ -58,19 +119,35 @@ def attention_pytorch(
     #### Returns:
         - `torch.Tensor`: The output tensor.
     """
-    b, _, dim_head = q.shape
-    dim_head //= heads
-    q, k, v = map(
-        lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2),
-        (q, k, v),
-    )
-
-    out = torch.nn.functional.scaled_dot_product_attention(
-        q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False
-    )
-    out = out.transpose(1, 2).reshape(b, -1, heads * dim_head)
-    return out
-
+    if not flux:
+        b, _, dim_head = q.shape
+        dim_head //= heads
+        q, k, v = map(
+            lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2),
+            (q, k, v),
+        )
+
+        out = torch.nn.functional.scaled_dot_product_attention(
+            q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False
+        )
+        out = out.transpose(1, 2).reshape(b, -1, heads * dim_head)
+        return out
+    else:
+        if skip_reshape:
+            b, _, _, dim_head = q.shape
+        else:
+            b, _, dim_head = q.shape
+            dim_head //= heads
+            q, k, v = map(
+                lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2),
+                (q, k, v),
+            )
+
+        out = torch.nn.functional.scaled_dot_product_attention(
+            q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False
+        )
+        out = out.transpose(1, 2).reshape(b, -1, heads * dim_head)
+        return out
 
 def xformers_attention(
     q: torch.Tensor, k: torch.Tensor, v: torch.Tensor
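The new `flux` flag mainly matters together with `skip_reshape=True`: in that case q/k/v arrive already split per head as `(batch, heads, seq_len, head_dim)`, and on the broken xformers versions the xformers path falls back to `attention_pytorch` whenever `batch * heads > 65535`. A minimal shape sketch of the new path, assuming the module is importable as `modules.Attention.AttentionMethods` and using the PyTorch backend so xformers is not required:

import torch
from modules.Attention import AttentionMethods  # assumed import path, matching this repo's layout

b, heads, seq, dim_head = 2, 24, 64, 128
# flux=True + skip_reshape=True: tensors are already in (batch, heads, seq, head_dim)
q = torch.randn(b, heads, seq, dim_head)
k = torch.randn_like(q)
v = torch.randn_like(q)

out = AttentionMethods.attention_pytorch(q, k, v, heads, skip_reshape=True, flux=True)
print(out.shape)  # torch.Size([2, 64, 3072]) -> (batch, seq, heads * dim_head)
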
modules/BlackForest/Flux.py CHANGED
@@ -29,7 +29,7 @@ def attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, pe: torch.Tenso
     """
     q, k = apply_rope(q, k, pe)
     heads = q.shape[1]
-    x = Attention.optimized_attention(q, k, v, heads, skip_reshape=True)
+    x = Attention.optimized_attention(q, k, v, heads, skip_reshape=True, flux=True)
     return x
 
 # Define the rotary positional encoding (RoPE)
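
With `flux=True`, Flux's `attention` keeps q and k in `(batch, heads, seq, head_dim)` layout after `apply_rope` and hands them straight to the backend. When the PyTorch backend is selected, the call is roughly equivalent to the following sketch (a reduction of `attention_pytorch` with `flux=True, skip_reshape=True`, written here for illustration, not code from this commit):

import torch
import torch.nn.functional as F

def flux_attention_reference(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, heads: int) -> torch.Tensor:
    # q, k, v: (batch, heads, seq, head_dim); q and k are already RoPE-rotated
    b, _, _, dim_head = q.shape
    out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)
    # merge heads back into (batch, seq, heads * head_dim)
    return out.transpose(1, 2).reshape(b, -1, heads * dim_head)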