# shunt-adapter-testing / two_stream_shunt_adapter.py
# AbstractPhil — commit 051fd3e (verified): "Update two_stream_shunt_adapter.py"
# (web file-viewer residue: raw · history · blame · 5.01 kB)
# adapter_v2.py ────────────────────────────────────────────────────────────
import torch, math
import torch.nn as nn
import torch.nn.functional as F
# ─── Residual pocket block ────────────────────────────────────────────────
class PocketBlock(nn.Module):
    """Residual 1-D convolution block applied along the sequence axis.

    Computes ``x + body(x)`` where ``body`` is
    LayerNorm -> Conv1d -> GELU -> Conv1d -> Dropout.

    Args:
        dim: feature width of the input; also the Conv1d channel count.
        kernel: Conv1d kernel size. "same" padding keeps the sequence length
            unchanged for odd *and* even kernels, so the residual add lines up.
        dropout: dropout probability applied to the residual branch.

    Input/output: ``(batch, seq, dim)`` -> ``(batch, seq, dim)``.
    """

    def __init__(self, dim, kernel=3, dropout=0.1):
        super().__init__()
        # Sub-module indices (body.0 ... body.4) match the original layout so
        # previously saved state_dicts still load.
        self.body = nn.Sequential(
            nn.LayerNorm(dim),
            # padding="same" equals kernel // 2 for odd kernels; for even
            # kernels kernel // 2 would lengthen the sequence by one.
            nn.Conv1d(dim, dim, kernel, padding="same", groups=1),
            nn.GELU(),
            nn.Conv1d(dim, dim, kernel, padding="same", groups=1),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        """Return ``x`` plus the conv-branch output; shape is preserved."""
        norm, convs = self.body[0], self.body[1:]
        # LayerNorm(dim) needs the feature axis last, so normalize *before*
        # switching to the channel-first layout Conv1d expects. Previously the
        # transpose came first, which normalized over the sequence axis and
        # failed whenever seq_len != dim.
        y = convs(norm(x).transpose(1, 2)).transpose(1, 2)
        return x + y
# ─── adapter ──────────────────────────────────────────────────────────────
class TwoStreamShuntAdapter(nn.Module):
    """T5-seq ➔ bottleneck ⇄ CLIP-seq → anchor / delta / σ …

    Projects a T5 sequence and a CLIP sequence into a shared bottleneck,
    cross-attends them in both directions, fuses the two streams, and emits
    per-CLIP-token heads plus a sigmoid gate and a clamped guidance scalar.

    Config keys (``cfg``), defaults in parentheses:
        t5.hidden_size, clip.hidden_size: input feature widths.
        bottleneck: shared width; nn.MultiheadAttention requires it to be
            divisible by ``heads``.
        heads: number of heads in both cross-attention modules.
        proj_layers (2): Linear layers per projection stack; hidden layers are
            ``2 * bottleneck`` wide, the last one emits the target width.
            With 0 the stacks are identity, so widths must already agree.
        layer_norm (True): put a LayerNorm before each projection Linear.
        dropout (0.1): dropout for projections, attention and pocket blocks.
        pocket_depth (2): number of PocketBlocks on the T5→CLIP stream.
        tau_init (1.0): initial value of the per-head ``tau`` parameter.
        max_guidance (2.0): upper clamp for ``g_pred``.

    NOTE(review): every projection stack (including the anchor/delta/log σ
    heads) ends in GELU, so those outputs are bounded below by GELU's minimum
    (≈ -0.17) — confirm this is intended, especially for log σ.
    """

    def __init__(self, cfg: dict):
        super().__init__()
        self.cfg = cfg
        hid_t5 = cfg["t5"]["hidden_size"]
        hid_clip = cfg["clip"]["hidden_size"]
        bneck = cfg["bottleneck"]
        heads = cfg["heads"]
        proj_layers = cfg.get("proj_layers", 2)
        use_norm = cfg.get("layer_norm", True)
        p_drop = cfg.get("dropout", 0.1)
        pocket_depth = cfg.get("pocket_depth", 2)

        # helper ----------------------------------------------------------------
        def proj(in_d, out_d):
            """Build ``proj_layers`` x [LayerNorm?] -> Linear -> GELU -> [Dropout?].

            Intermediate layers are ``2 * bneck`` wide; the final layer emits
            ``out_d``.
            """
            layers, d = [], in_d
            for i in range(proj_layers):
                nxt = out_d if i == proj_layers - 1 else bneck * 2
                if use_norm:
                    layers.append(nn.LayerNorm(d))
                layers += [nn.Linear(d, nxt), nn.GELU()]
                if p_drop:
                    layers.append(nn.Dropout(p_drop))
                # Track the width actually produced so the next LayerNorm /
                # Linear matches it (resetting to bneck broke proj_layers >= 2).
                d = nxt
            return nn.Sequential(*layers)

        # projections -----------------------------------------------------------
        self.t5_in = proj(hid_t5, bneck)
        self.clip_in = proj(hid_clip, bneck)

        # bidirectional cross-attention ----------------------------------------
        self.attn_t2c = nn.MultiheadAttention(bneck, heads, batch_first=True, dropout=p_drop)
        self.attn_c2t = nn.MultiheadAttention(bneck, heads, batch_first=True, dropout=p_drop)
        # float(): an integer tau_init would make torch.full build an int64
        # tensor, which nn.Parameter refuses (it requires a floating dtype).
        self.tau = nn.Parameter(torch.full((heads, 1, 1), float(cfg.get("tau_init", 1.0))))

        # pocket stack ----------------------------------------------------------
        self.pocket = nn.Sequential(*[PocketBlock(bneck, dropout=p_drop) for _ in range(pocket_depth)])

        # fuse [pooled pocket ‖ c2t] (2·bneck) → bneck --------------------------
        self.fuse = nn.Sequential(
            nn.LayerNorm(bneck * 2),
            nn.Linear(bneck * 2, bneck * 2),
            nn.GELU(),
            nn.Linear(bneck * 2, bneck),
        )

        # head projections ------------------------------------------------------
        self.anchor_out = proj(bneck, hid_clip)
        self.delta_out = proj(bneck, hid_clip)
        self.sigma_out = proj(bneck, hid_clip)  # log σ
        self.gate_guid_proj = nn.Sequential(
            nn.LayerNorm(bneck),
            nn.Linear(bneck, bneck),
            nn.GELU(),
            nn.Linear(bneck, 2),  # [..., 0] → gate, [..., 1] → g_pred
        )
        self.max_guidance = cfg.get("max_guidance", 2.0)

    # --- forward --------------------------------------------------------------
    def forward(self, t5_seq: torch.Tensor, clip_seq: torch.Tensor):
        """Run both streams and return the adapter outputs.

        Args:
            t5_seq: expected ``(batch, L_t5, t5.hidden_size)`` (batch-first).
            clip_seq: expected ``(batch, L_clip, clip.hidden_size)``.

        Returns:
            ``(anchor, delta, log_sigma, attn_t2c, attn_c2t, tau, g_pred, gate)``:
              anchor, delta, log_sigma: ``(batch, L_clip, clip.hidden_size)``;
              attn_t2c: per-head weights ``(batch, heads, L_t5, L_clip)``;
              attn_c2t: per-head weights ``(batch, heads, L_clip, L_t5)``;
              tau: the ``(heads, 1, 1)`` parameter, returned as-is (it is not
                  used inside this module);
              g_pred: ``(batch, 1, 1)`` clamped to ``[0, max_guidance]``;
              gate: ``(batch, L_clip, 1)`` sigmoid gate.

        Raises:
            AssertionError: if an input's last dimension disagrees with cfg.
        """
        exp_t5 = self.cfg["t5"]["hidden_size"]
        exp_clip = self.cfg["clip"]["hidden_size"]
        assert t5_seq.size(-1) == exp_t5, f"t5_seq last dim {t5_seq.size(-1)} != {exp_t5}"
        assert clip_seq.size(-1) == exp_clip, f"clip_seq last dim {clip_seq.size(-1)} != {exp_clip}"

        t5_b = self.t5_in(t5_seq)        # (B, L_t5, bneck)
        clip_b = self.clip_in(clip_seq)  # (B, L_clip, bneck)

        # T5 queries attend over CLIP tokens and vice versa; keep per-head weights.
        t2c, attn_t2c = self.attn_t2c(t5_b, clip_b, clip_b, need_weights=True, average_attn_weights=False)
        c2t, attn_c2t = self.attn_c2t(clip_b, t5_b, t5_b, need_weights=True, average_attn_weights=False)

        # Refine the T5-side stream, mean-pool it over L_t5, and broadcast the
        # pooled vector to every CLIP position before fusing with c2t.
        p = self.pocket(t2c)
        z = torch.cat([p.mean(1, keepdim=True).expand_as(c2t), c2t], dim=-1)
        h = self.fuse(z)                 # (B, L_clip, bneck)

        anchor = self.anchor_out(h)
        delta = self.delta_out(h)
        log_sigma = self.sigma_out(h)

        gate_and_g = self.gate_guid_proj(h)
        gate = torch.sigmoid(gate_and_g[..., 0:1])
        # One guidance value per sample: average over CLIP positions, then clamp.
        g_pred = torch.clamp(gate_and_g[..., 1:2].mean(1, keepdim=True),
                             0, self.max_guidance)

        return (anchor, delta, log_sigma,
                attn_t2c, attn_c2t,
                self.tau,
                g_pred,
                gate)