Update two_stream_shunt_adapter.py
two_stream_shunt_adapter.py (CHANGED, +81 -89)
@@ -1,123 +1,115 @@
(previous version, removed: most deleted lines are truncated in this view. The recoverable fragments show a residual bottleneck block built from nn.Linear(dim, dim * 2), nn.GELU() and nn.Dropout(dropout) with a conv + projection residual path; a projection builder that appended nn.Linear(last_dim, next_dim), nn.GELU(), an optional nn.Dropout(do_p), then a final nn.Linear(last_dim, output_dim); a pair of BottleneckResBlock(self.bneck, dropout=do_p) pocket blocks; a fuse stack; separate delta_proj, gate_proj and logsig_proj heads with Tanh/Sigmoid gating plus a LayerNorm(self.bneck) + Linear(self.bneck, 1) + Sigmoid gate; and a forward that computed delta = self.delta_proj(h) * self.gate_proj(h), log_sigma = self.logsig_proj(h) and returned anchor, delta, log_sigma, …)
# adapter_v2.py ─────────────────────────────────────────────────────────────
import torch, math
import torch.nn as nn
import torch.nn.functional as F


# ─── Residual pocket block ──────────────────────────────────────────────────
class PocketBlock(nn.Module):
    def __init__(self, dim, kernel=3, dropout=0.1):
        super().__init__()
        # norm over the channel dim while the tensor is still (B, T, dim);
        # the conv body then runs on the transposed (B, dim, T) layout
        self.norm = nn.LayerNorm(dim)
        self.body = nn.Sequential(
            nn.Conv1d(dim, dim, kernel, padding=kernel // 2, groups=1),
            nn.GELU(),
            nn.Conv1d(dim, dim, kernel, padding=kernel // 2, groups=1),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        y = self.body(self.norm(x).transpose(1, 2)).transpose(1, 2)
        return x + y

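# Shape note: for x of shape (B, T, dim), e.g. an assumed (2, 77, 256), the
# block normalizes over dim, applies two kernel-3 convolutions along the
# sequence axis (padding keeps T unchanged), and adds the result back to x,
# so the output is again (B, T, dim).
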
# ─── adapter ─────────────────────────────────────────────────────────────────
class TwoStreamShuntAdapter(nn.Module):
    """T5-seq → bottleneck → CLIP-seq → anchor / delta / σ …"""

    def __init__(self, cfg: dict):
        super().__init__()
        self.cfg = cfg
        hid_t5 = cfg["t5"]["hidden_size"]
        hid_clip = cfg["clip"]["hidden_size"]
        bneck = cfg["bottleneck"]
        heads = cfg["heads"]
        proj_layers = cfg.get("proj_layers", 2)
        use_norm = cfg.get("layer_norm", True)
        p_drop = cfg.get("dropout", 0.1)
        pocket_depth = cfg.get("pocket_depth", 2)

        # helper ----------------------------------------------------------------
        def proj(in_d, out_d):
            layers, d = [], in_d
            for i in range(proj_layers):
                if use_norm:
                    layers.append(nn.LayerNorm(d))
                # final layer maps to out_d; intermediate layers widen to bneck * 2
                next_d = out_d if i == proj_layers - 1 else bneck * 2
                layers += [nn.Linear(d, next_d), nn.GELU()]
                if p_drop:
                    layers.append(nn.Dropout(p_drop))
                d = next_d
            return nn.Sequential(*layers)

        # projections -----------------------------------------------------------
        self.t5_in = proj(hid_t5, bneck)
        self.clip_in = proj(hid_clip, bneck)

        # bidirectional cross-attention ------------------------------------------
        self.attn_t2c = nn.MultiheadAttention(bneck, heads, batch_first=True, dropout=p_drop)
        self.attn_c2t = nn.MultiheadAttention(bneck, heads, batch_first=True, dropout=p_drop)
        self.tau = nn.Parameter(torch.full((heads, 1, 1), cfg.get("tau_init", 1.0)))

        # pocket stack ------------------------------------------------------------
        self.pocket = nn.Sequential(*[PocketBlock(bneck, dropout=p_drop) for _ in range(pocket_depth)])

        # fuse concatenated streams (2 * bneck) → bneck ----------------------------
        self.fuse = nn.Sequential(
            nn.LayerNorm(bneck * 2),
            nn.Linear(bneck * 2, bneck * 2),
            nn.GELU(),
            nn.Linear(bneck * 2, bneck),
        )

        # head projections ----------------------------------------------------------
        self.anchor_out = proj(bneck, hid_clip)
        self.delta_out = proj(bneck, hid_clip)
        self.sigma_out = proj(bneck, hid_clip)  # log σ

        self.gate_guid_proj = nn.Sequential(
            nn.LayerNorm(bneck),
            nn.Linear(bneck, bneck),
            nn.GELU(),
            nn.Linear(bneck, 2),  # [..., 0] → gate, [..., 1] → g_pred
        )

        self.max_guidance = cfg.get("max_guidance", 2.0)

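    # Example config consumed by __init__ above; the widths below are assumptions
    # for illustration, not values taken from this Space:
    #   cfg = {
    #       "t5":   {"hidden_size": 768},
    #       "clip": {"hidden_size": 768},
    #       "bottleneck": 256, "heads": 8,
    #       "proj_layers": 2, "layer_norm": True, "dropout": 0.1,
    #       "pocket_depth": 2, "tau_init": 1.0, "max_guidance": 2.0,
    #   }
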
    # --- forward --------------------------------------------------------------
    def forward(self, t5_seq: torch.Tensor, clip_seq: torch.Tensor):
        assert t5_seq.size(-1) == self.cfg["t5"]["hidden_size"]
        assert clip_seq.size(-1) == self.cfg["clip"]["hidden_size"]

        t5_b = self.t5_in(t5_seq)
        clip_b = self.clip_in(clip_seq)

        t2c, attn_t2c = self.attn_t2c(t5_b, clip_b, clip_b, need_weights=True, average_attn_weights=False)
        c2t, attn_c2t = self.attn_c2t(clip_b, t5_b, t5_b, need_weights=True, average_attn_weights=False)

        p = self.pocket(t2c)
        z = torch.cat([p.mean(1, keepdim=True).expand_as(c2t), c2t], dim=-1)
        h = self.fuse(z)

        anchor = self.anchor_out(h)
        delta = self.delta_out(h)

        log_sigma = self.sigma_out(h)

        gate_and_g = self.gate_guid_proj(h)
        gate = torch.sigmoid(gate_and_g[..., 0:1])
        g_pred = torch.clamp(gate_and_g[..., 1:2].mean(1, keepdim=True),
                             0, self.max_guidance)

        return (anchor, delta, log_sigma,
                attn_t2c, attn_c2t,
                self.tau,
                g_pred,
                gate)
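A minimal usage sketch of the updated adapter follows. The batch size, token counts, hidden widths, bottleneck size, and head count are assumptions chosen for illustration, not values taken from this Space:

import torch
from two_stream_shunt_adapter import TwoStreamShuntAdapter  # module shown above

cfg = {
    "t5": {"hidden_size": 768},
    "clip": {"hidden_size": 768},
    "bottleneck": 256,
    "heads": 8,
}
adapter = TwoStreamShuntAdapter(cfg)

t5_seq = torch.randn(2, 77, 768)    # (batch, T5 tokens, T5 hidden)
clip_seq = torch.randn(2, 77, 768)  # (batch, CLIP tokens, CLIP hidden)

(anchor, delta, log_sigma,
 attn_t2c, attn_c2t, tau, g_pred, gate) = adapter(t5_seq, clip_seq)

print(anchor.shape, delta.shape, log_sigma.shape)  # torch.Size([2, 77, 768]) each
print(attn_t2c.shape)                              # per-head weights: torch.Size([2, 8, 77, 77])
print(gate.shape, g_pred.shape)                    # torch.Size([2, 77, 1]) torch.Size([2, 1, 1])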