# adapter_v2.py ────────────────────────────────────────────────────────────
import torch
import torch.nn as nn
import torch.nn.functional as F


# ─── Residual pocket block ────────────────────────────────────────────────
class PocketBlock(nn.Module):
    """Pre-norm residual block: LayerNorm → Conv1d → GELU → Conv1d → Dropout."""

    def __init__(self, dim, kernel=3, dropout=0.1):
        super().__init__()
        # LayerNorm runs on the (B, L, dim) layout, before the transpose to
        # (B, dim, L) that Conv1d expects.
        self.norm = nn.LayerNorm(dim)
        self.body = nn.Sequential(
            nn.Conv1d(dim, dim, kernel, padding=kernel // 2, groups=1),
            nn.GELU(),
            nn.Conv1d(dim, dim, kernel, padding=kernel // 2, groups=1),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        # x: (B, L, dim) → convolve over the sequence axis → residual add
        y = self.body(self.norm(x).transpose(1, 2)).transpose(1, 2)
        return x + y


# ─── adapter ──────────────────────────────────────────────────────────────
class TwoStreamShuntAdapter(nn.Module):
    """T5-seq ➔ bottleneck ⇄ CLIP-seq → anchor / delta / σ …"""

    def __init__(self, cfg: dict):
        super().__init__()
        self.cfg = cfg
        hid_t5 = cfg["t5"]["hidden_size"]
        hid_clip = cfg["clip"]["hidden_size"]
        bneck = cfg["bottleneck"]
        heads = cfg["heads"]

        proj_layers = cfg.get("proj_layers", 2)
        use_norm = cfg.get("layer_norm", True)
        p_drop = cfg.get("dropout", 0.1)
        pocket_depth = cfg.get("pocket_depth", 2)

        # helper ----------------------------------------------------------------
        # MLP projection: intermediate layers widen to bneck * 2, the final
        # layer maps to out_d.
        def proj(in_d, out_d):
            layers, d = [], in_d
            for i in range(proj_layers):
                out = out_d if i == proj_layers - 1 else bneck * 2
                if use_norm:
                    layers.append(nn.LayerNorm(d))
                layers += [nn.Linear(d, out), nn.GELU()]
                if p_drop:
                    layers.append(nn.Dropout(p_drop))
                d = out
            return nn.Sequential(*layers)

        # projections -----------------------------------------------------------
        self.t5_in = proj(hid_t5, bneck)
        self.clip_in = proj(hid_clip, bneck)

        # bidirectional cross-attention ----------------------------------------
        self.attn_t2c = nn.MultiheadAttention(bneck, heads, batch_first=True, dropout=p_drop)
        self.attn_c2t = nn.MultiheadAttention(bneck, heads, batch_first=True, dropout=p_drop)
        self.tau = nn.Parameter(torch.full((heads, 1, 1), cfg.get("tau_init", 1.0)))

        # pocket stack ----------------------------------------------------------
        self.pocket = nn.Sequential(*[PocketBlock(bneck, dropout=p_drop)
                                      for _ in range(pocket_depth)])

        # fuse bneck * 2 → bneck -------------------------------------------------
        self.fuse = nn.Sequential(
            nn.LayerNorm(bneck * 2),
            nn.Linear(bneck * 2, bneck * 2), nn.GELU(),
            nn.Linear(bneck * 2, bneck),
        )

        # head projections ------------------------------------------------------
        self.anchor_out = proj(bneck, hid_clip)
        self.delta_out = proj(bneck, hid_clip)
        self.sigma_out = proj(bneck, hid_clip)      # log σ

        self.gate_guid_proj = nn.Sequential(
            nn.LayerNorm(bneck),
            nn.Linear(bneck, bneck), nn.GELU(),
            nn.Linear(bneck, 2),    # [:, :, 0] → gate, [:, :, 1] → g_pred
        )
        self.max_guidance = cfg.get("max_guidance", 2.0)

    # --- forward --------------------------------------------------------------
    def forward(self, t5_seq: torch.Tensor, clip_seq: torch.Tensor):
        assert t5_seq.size(-1) == self.cfg["t5"]["hidden_size"]
        assert clip_seq.size(-1) == self.cfg["clip"]["hidden_size"]

        # project both streams into the shared bottleneck
        t5_b = self.t5_in(t5_seq)
        clip_b = self.clip_in(clip_seq)

        # bidirectional cross-attention; per-head weights are returned to the caller
        t2c, attn_t2c = self.attn_t2c(t5_b, clip_b, clip_b,
                                      need_weights=True, average_attn_weights=False)
        c2t, attn_c2t = self.attn_c2t(clip_b, t5_b, t5_b,
                                      need_weights=True, average_attn_weights=False)

        # pocket stack on the T5→CLIP stream, then fuse its pooled summary
        # with the CLIP→T5 stream
        p = self.pocket(t2c)
        z = torch.cat([p.mean(1, keepdim=True).expand_as(c2t), c2t], dim=-1)
        h = self.fuse(z)

        # output heads
        anchor = self.anchor_out(h)
        delta = self.delta_out(h)
        log_sigma = self.sigma_out(h)

        gate_and_g = self.gate_guid_proj(h)
        gate = torch.sigmoid(gate_and_g[..., 0:1])
        g_pred = torch.clamp(gate_and_g[..., 1:2].mean(1, keepdim=True),
                             0.0, self.max_guidance)

        return (anchor, delta, log_sigma,
                attn_t2c, attn_c2t,
                self.tau, g_pred, gate)
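

# ─── usage sketch (illustrative only) ──────────────────────────────────────
# A minimal smoke test, assuming example sizes (T5 hidden 512, CLIP hidden 768,
# bottleneck 256, 8 heads, 77 tokens). The real config values come from the
# surrounding training setup and are not fixed by this file.
if __name__ == "__main__":
    cfg = {
        "t5": {"hidden_size": 512},     # example value
        "clip": {"hidden_size": 768},   # example value
        "bottleneck": 256,              # example value
        "heads": 8,                     # example value
    }
    adapter = TwoStreamShuntAdapter(cfg)
    t5_seq = torch.randn(2, 77, 512)    # (batch, T5 tokens, T5 hidden)
    clip_seq = torch.randn(2, 77, 768)  # (batch, CLIP tokens, CLIP hidden)
    anchor, delta, log_sigma, attn_t2c, attn_c2t, tau, g_pred, gate = adapter(t5_seq, clip_seq)
    print(anchor.shape, delta.shape, log_sigma.shape, g_pred.shape, gate.shape)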