AbstractPhil committed
Commit 051fd3e · verified · 1 Parent(s): e543e33

Update two_stream_shunt_adapter.py

Files changed (1)
  1. two_stream_shunt_adapter.py +81 -89
two_stream_shunt_adapter.py CHANGED
@@ -1,123 +1,115 @@
-import torch
 import torch.nn as nn
 import torch.nn.functional as F

-# ─── Residual Pocket Block ───────────────────────────────────
-class BottleneckResBlock(nn.Module):
     def __init__(self, dim, kernel=3, dropout=0.1):
         super().__init__()
-        self.norm = nn.LayerNorm(dim)
-        self.conv = nn.Conv1d(dim, dim, kernel_size=kernel, padding=kernel // 2, groups=1)
-        self.proj = nn.Sequential(
-            nn.Linear(dim, dim * 2),
             nn.GELU(),
-            nn.Linear(dim * 2, dim),
-            nn.Dropout(dropout)
         )

     def forward(self, x):
-        residual = x
-        x = self.norm(x)
-        x = x.transpose(1, 2)
-        x = self.conv(x).transpose(1, 2)
-        return residual + self.proj(x)

-# ─── Two Stream Shunt Adapter ──────────────────────────────────────
 class TwoStreamShuntAdapter(nn.Module):
-    def __init__(self, config: dict):
         super().__init__()
-        self.config = config
-        self.t5_dim = config["t5"]["hidden_size"]
-        self.clip_dim = config["clip"]["hidden_size"]
-        self.bneck = config["bottleneck"]
-        self.heads = config["heads"]
-        self.tau_init = config["tau_init"]
-        self.max_guidance = config["max_guidance"]
-
-        use_norm = config.get("layer_norm", True)
-        use_do = config.get("use_dropout", True)
-        do_p = config.get("dropout", 0.1)
-        proj_depth = config.get("proj_layers", 2)
-
-        def build_projection(input_dim, output_dim):
-            layers = []
-            last_dim = input_dim
-            if use_norm:
-                layers.append(nn.LayerNorm(last_dim))
-            for i in range(proj_depth):
-                next_dim = self.bneck * (2 if i == 0 and proj_depth > 1 else 1)
-                layers.append(nn.Linear(last_dim, next_dim))
-                layers.append(nn.GELU())
-                if use_do:
-                    layers.append(nn.Dropout(do_p))
-                last_dim = next_dim
-            layers.append(nn.Linear(last_dim, output_dim))
             return nn.Sequential(*layers)

-        # Projections
-        self.proj_t5 = build_projection(self.t5_dim, self.bneck)
-        self.proj_clip = build_projection(self.clip_dim, self.bneck)

-        # Attention
-        self.cross_t2c = nn.MultiheadAttention(self.bneck, self.heads, batch_first=True, dropout=do_p)
-        self.cross_c2t = nn.MultiheadAttention(self.bneck, self.heads, batch_first=True, dropout=do_p)
-        self.tau = nn.Parameter(torch.full((self.heads, 1, 1), self.tau_init))

-        # Residual Pocket
-        self.pocket_blocks = nn.Sequential(
-            BottleneckResBlock(self.bneck, dropout=do_p),
-            BottleneckResBlock(self.bneck, dropout=do_p)
-        )

-        # Fuse
         self.fuse = nn.Sequential(
-            nn.LayerNorm(2 * self.bneck),
-            nn.Linear(2 * self.bneck, self.bneck * 2),
             nn.GELU(),
-            nn.Linear(self.bneck * 2, self.bneck)
         )

-        # Output Projections
-        self.anchor_proj = build_projection(self.bneck, self.clip_dim)
-        self.delta_proj = build_projection(self.bneck, self.clip_dim)
-        self.logsig_proj = build_projection(self.bneck, self.clip_dim)

-        self.gate_proj = nn.Sequential(
-            nn.LayerNorm(self.bneck),
-            nn.Linear(self.bneck, self.bneck),
             nn.GELU(),
-            nn.Linear(self.bneck, 1),
-            nn.Tanh(),
-            nn.Sigmoid()
         )

-        self.guidance_proj = nn.Sequential(
-            nn.LayerNorm(self.bneck),
-            nn.Linear(self.bneck, 1),
-            nn.Sigmoid()
-        )

     def forward(self, t5_seq: torch.Tensor, clip_seq: torch.Tensor):
-        if self.config.get("assert_input_dims", True):
-            assert t5_seq.size(-1) == self.t5_dim
-            assert clip_seq.size(-1) == self.clip_dim

-        t5_b = self.proj_t5(t5_seq)
-        clip_b = self.proj_clip(clip_seq)

-        t2c, attn_t2c = self.cross_t2c(t5_b, clip_b, clip_b, need_weights=True, average_attn_weights=False)
-        c2t, attn_c2t = self.cross_c2t(clip_b, t5_b, t5_b, need_weights=True, average_attn_weights=False)

-        pocket = self.pocket_blocks(t2c)

-        pocket_mean = pocket.mean(1, keepdim=True).expand(-1, clip_b.size(1), -1)
-        h = self.fuse(torch.cat([pocket_mean, c2t], dim=-1))

-        anchor = self.anchor_proj(h)
-        delta = self.delta_proj(h) * self.gate_proj(h)
-        log_sigma = self.logsig_proj(h)

-        g_tok = self.guidance_proj(h).squeeze(-1)
-        g_pred = g_tok.mean(1, keepdim=True) * self.max_guidance

-        return anchor, delta, log_sigma, attn_t2c, attn_c2t, self.tau, g_pred, self.gate_proj(h)

+# adapter_v2.py ────────────────────────────────────────────────────────────
+import torch, math
 import torch.nn as nn
 import torch.nn.functional as F

+
+# ─── Residual pocket block ────────────────────────────────────────────────
+class PocketBlock(nn.Module):
     def __init__(self, dim, kernel=3, dropout=0.1):
         super().__init__()
+        self.norm = nn.LayerNorm(dim)   # applied in (B, L, dim) layout, before the conv stack
+        self.body = nn.Sequential(
+            nn.Conv1d(dim, dim, kernel, padding=kernel // 2, groups=1),
             nn.GELU(),
+            nn.Conv1d(dim, dim, kernel, padding=kernel // 2, groups=1),
+            nn.Dropout(dropout),
         )

     def forward(self, x):
+        # x: (B, L, dim) – normalise, run the convs channel-first, then add the residual
+        y = self.body(self.norm(x).transpose(1, 2)).transpose(1, 2)
+        return x + y
+

+# ─── adapter ──────────────────────────────────────────────────────────────
 class TwoStreamShuntAdapter(nn.Module):
+    """T5-seq ➔ bottleneck ⇄ CLIP-seq → anchor / delta / σ …"""
+
+    def __init__(self, cfg: dict):
         super().__init__()
+        self.cfg = cfg
+        hid_t5 = cfg["t5"]["hidden_size"]
+        hid_clip = cfg["clip"]["hidden_size"]
+        bneck = cfg["bottleneck"]
+        heads = cfg["heads"]
+        proj_layers = cfg.get("proj_layers", 2)
+        use_norm = cfg.get("layer_norm", True)
+        p_drop = cfg.get("dropout", 0.1)
+        pocket_depth = cfg.get("pocket_depth", 2)
+
+        # helper ----------------------------------------------------------------
+        def proj(in_d, out_d):
+            layers, d = [], in_d
+            for i in range(proj_layers):
+                if use_norm:
+                    layers.append(nn.LayerNorm(d))
+                nxt = bneck if i == proj_layers - 1 else bneck * 2
+                layers += [nn.Linear(d, nxt), nn.GELU()]
+                if p_drop:
+                    layers.append(nn.Dropout(p_drop))
+                d = nxt
+            layers.append(nn.Linear(d, out_d))
             return nn.Sequential(*layers)

+        # projections -----------------------------------------------------------
+        self.t5_in = proj(hid_t5, bneck)
+        self.clip_in = proj(hid_clip, bneck)

+        # bidirectional cross-attention ----------------------------------------
+        self.attn_t2c = nn.MultiheadAttention(bneck, heads, batch_first=True, dropout=p_drop)
+        self.attn_c2t = nn.MultiheadAttention(bneck, heads, batch_first=True, dropout=p_drop)
+        self.tau = nn.Parameter(torch.full((heads, 1, 1), cfg.get("tau_init", 1.0)))

+        # pocket stack ----------------------------------------------------------
+        self.pocket = nn.Sequential(*[PocketBlock(bneck, dropout=p_drop) for _ in range(pocket_depth)])

+        # fuse bottleneck → bneck ----------------------------------------------
         self.fuse = nn.Sequential(
+            nn.LayerNorm(bneck * 2),
+            nn.Linear(bneck * 2, bneck * 2),
             nn.GELU(),
+            nn.Linear(bneck * 2, bneck)
         )

+        # head projections ------------------------------------------------------
+        self.anchor_out = proj(bneck, hid_clip)
+        self.delta_out = proj(bneck, hid_clip)
+        self.sigma_out = proj(bneck, hid_clip)   # log σ

+        self.gate_guid_proj = nn.Sequential(
+            nn.LayerNorm(bneck),
+            nn.Linear(bneck, bneck),
             nn.GELU(),
+            nn.Linear(bneck, 2),   # [:, :, 0] → gate, [:, :, 1] → g_pred
         )

+        self.max_guidance = cfg.get("max_guidance", 2.0)

+    # --- forward --------------------------------------------------------------
     def forward(self, t5_seq: torch.Tensor, clip_seq: torch.Tensor):
+        assert t5_seq.size(-1) == self.cfg["t5"]["hidden_size"]
+        assert clip_seq.size(-1) == self.cfg["clip"]["hidden_size"]

+        t5_b = self.t5_in(t5_seq)
+        clip_b = self.clip_in(clip_seq)

+        t2c, attn_t2c = self.attn_t2c(t5_b, clip_b, clip_b, need_weights=True, average_attn_weights=False)
+        c2t, attn_c2t = self.attn_c2t(clip_b, t5_b, t5_b, need_weights=True, average_attn_weights=False)

+        p = self.pocket(t2c)
+        z = torch.cat([p.mean(1, keepdim=True).expand_as(c2t), c2t], dim=-1)
+        h = self.fuse(z)

+        anchor = self.anchor_out(h)
+        delta = self.delta_out(h)
+        log_sigma = self.sigma_out(h)

+        gate_and_g = self.gate_guid_proj(h)
+        gate = torch.sigmoid(gate_and_g[..., 0:1])
+        g_pred = torch.clamp(gate_and_g[..., 1:2].mean(1, keepdim=True),
+                             0, self.max_guidance)

+        return (anchor, delta, log_sigma,
+                attn_t2c, attn_c2t,
+                self.tau,
+                g_pred,
+                gate)
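
For a quick sanity check of the updated module, a minimal smoke test along these lines should work. The hidden sizes, sequence lengths, and guidance bound below are illustrative placeholders, not values taken from this repo; the config keys are the ones read in `__init__`.

```python
import torch
from two_stream_shunt_adapter import TwoStreamShuntAdapter

# Illustrative config; hidden sizes, bottleneck width, and head count are placeholders.
cfg = {
    "t5":   {"hidden_size": 512},
    "clip": {"hidden_size": 768},
    "bottleneck": 256,
    "heads": 8,
    "tau_init": 0.1,
    "max_guidance": 10.0,
    "proj_layers": 2,
    "pocket_depth": 2,
    "dropout": 0.1,
    "layer_norm": True,
}

adapter = TwoStreamShuntAdapter(cfg).eval()
t5_seq   = torch.randn(2, 64, cfg["t5"]["hidden_size"])    # (batch, T5 tokens, T5 dim)
clip_seq = torch.randn(2, 77, cfg["clip"]["hidden_size"])  # (batch, CLIP tokens, CLIP dim)

with torch.no_grad():
    (anchor, delta, log_sigma,
     attn_t2c, attn_c2t,
     tau, g_pred, gate) = adapter(t5_seq, clip_seq)

print(anchor.shape, delta.shape, log_sigma.shape)  # each (2, 77, 768): per-CLIP-token heads
print(gate.shape, g_pred.shape)                    # (2, 77, 1) gate, (2, 1, 1) clamped guidance
print(attn_t2c.shape, attn_c2t.shape)              # (2, 8, 64, 77) and (2, 8, 77, 64)
```

The eight return values mirror the tuple at the end of `forward`; downstream code is presumably expected to blend `delta` into the CLIP stream under `gate` and use `g_pred` as a guidance scale, but that wiring lives outside this file.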