# adapter_v2.py ──────────────────────────────────────────────────────────────
import torch
import torch.nn as nn
# ─── Residual pocket block ───────────────────────────────────────────────────
class PocketBlock(nn.Module):
    """Pre-norm residual block: LayerNorm → Conv1d → GELU → Conv1d → Dropout."""

    def __init__(self, dim, kernel=3, dropout=0.1):
        super().__init__()
        self.norm = nn.LayerNorm(dim)      # normalise over the feature dim (last axis of (B, T, dim))
        self.body = nn.Sequential(
            nn.Conv1d(dim, dim, kernel, padding=kernel // 2, groups=1),
            nn.GELU(),
            nn.Conv1d(dim, dim, kernel, padding=kernel // 2, groups=1),
            nn.Dropout(dropout),
        )

    def forward(self, x):                  # x: (B, T, dim)
        y = self.norm(x)                   # norm before transposing: Conv1d expects (B, dim, T)
        y = self.body(y.transpose(1, 2)).transpose(1, 2)
        return x + y                       # residual connection
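
# Minimal usage sketch (illustrative only, not part of the original file): the block
# is shape-preserving over (batch, tokens, dim) tensors, e.g.
#
#     blk = PocketBlock(dim=256)
#     assert blk(torch.randn(2, 77, 256)).shape == (2, 77, 256)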
# ─── adapter ─────────────────────────────────────────────────────────────────
class TwoStreamShuntAdapter(nn.Module):
    """T5-seq → bottleneck → CLIP-seq → anchor / delta / σ …"""

    def __init__(self, cfg: dict):
        super().__init__()
        self.cfg = cfg
        hid_t5       = cfg["t5"]["hidden_size"]
        hid_clip     = cfg["clip"]["hidden_size"]
        bneck        = cfg["bottleneck"]
        heads        = cfg["heads"]
        proj_layers  = cfg.get("proj_layers", 2)
        use_norm     = cfg.get("layer_norm", True)
        p_drop       = cfg.get("dropout", 0.1)
        pocket_depth = cfg.get("pocket_depth", 2)
        # helper ----------------------------------------------------------------
        def proj(in_d, out_d):
            """Small MLP: in_d → bneck*2 → … → out_d (optional LayerNorm / Dropout)."""
            layers, d = [], in_d
            for i in range(proj_layers):
                last = i == proj_layers - 1
                if use_norm:
                    layers.append(nn.LayerNorm(d))
                layers += [nn.Linear(d, out_d if last else bneck * 2),
                           nn.GELU()]
                if p_drop:
                    layers.append(nn.Dropout(p_drop))
                d = out_d if last else bneck * 2   # track the layer's actual output width
            return nn.Sequential(*layers)
        # projections -----------------------------------------------------------
        self.t5_in   = proj(hid_t5, bneck)
        self.clip_in = proj(hid_clip, bneck)
        # bidirectional cross-attention ------------------------------------------
        self.attn_t2c = nn.MultiheadAttention(bneck, heads, batch_first=True, dropout=p_drop)
        self.attn_c2t = nn.MultiheadAttention(bneck, heads, batch_first=True, dropout=p_drop)
        self.tau = nn.Parameter(torch.full((heads, 1, 1), cfg.get("tau_init", 1.0)))
        # pocket stack ------------------------------------------------------------
        self.pocket = nn.Sequential(*[PocketBlock(bneck, dropout=p_drop)
                                      for _ in range(pocket_depth)])
        # fuse bottleneck (2*bneck) → bneck ---------------------------------------
        self.fuse = nn.Sequential(
            nn.LayerNorm(bneck * 2),
            nn.Linear(bneck * 2, bneck * 2),
            nn.GELU(),
            nn.Linear(bneck * 2, bneck),
        )
        # head projections --------------------------------------------------------
        self.anchor_out = proj(bneck, hid_clip)
        self.delta_out  = proj(bneck, hid_clip)
        self.sigma_out  = proj(bneck, hid_clip)   # log σ
        self.gate_guid_proj = nn.Sequential(
            nn.LayerNorm(bneck),
            nn.Linear(bneck, bneck),
            nn.GELU(),
            nn.Linear(bneck, 2),                  # [:, :, 0] → gate, [:, :, 1] → g_pred
        )
        self.max_guidance = cfg.get("max_guidance", 2.0)
    # --- forward --------------------------------------------------------------
    def forward(self, t5_seq: torch.Tensor, clip_seq: torch.Tensor):
        assert t5_seq.size(-1) == self.cfg["t5"]["hidden_size"]
        assert clip_seq.size(-1) == self.cfg["clip"]["hidden_size"]

        # project both streams into the shared bottleneck
        t5_b   = self.t5_in(t5_seq)       # (B, T_t5,   bneck)
        clip_b = self.clip_in(clip_seq)   # (B, T_clip, bneck)

        # bidirectional cross-attention between the two streams
        t2c, attn_t2c = self.attn_t2c(t5_b, clip_b, clip_b,
                                      need_weights=True, average_attn_weights=False)
        c2t, attn_c2t = self.attn_c2t(clip_b, t5_b, t5_b,
                                      need_weights=True, average_attn_weights=False)

        # refine the T5→CLIP stream, pool it, and fuse with the CLIP→T5 stream
        p = self.pocket(t2c)
        z = torch.cat([p.mean(1, keepdim=True).expand_as(c2t), c2t], dim=-1)
        h = self.fuse(z)                  # (B, T_clip, bneck)

        # output heads
        anchor     = self.anchor_out(h)   # (B, T_clip, hid_clip)
        delta      = self.delta_out(h)
        log_sigma  = self.sigma_out(h)
        gate_and_g = self.gate_guid_proj(h)
        gate   = torch.sigmoid(gate_and_g[..., 0:1])          # per-token gate in (0, 1)
        g_pred = torch.clamp(gate_and_g[..., 1:2].mean(1, keepdim=True),
                             0, self.max_guidance)            # pooled guidance prediction

        return (anchor, delta, log_sigma,
                attn_t2c, attn_c2t,
                self.tau,
                g_pred,
                gate)
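
# A minimal smoke-test sketch (not from the original file). The config keys mirror the
# lookups in __init__ above; the hidden sizes and sequence lengths below are illustrative
# placeholders, not values prescribed by the source.
if __name__ == "__main__":
    cfg = {
        "t5":   {"hidden_size": 512},
        "clip": {"hidden_size": 768},
        "bottleneck": 256,
        "heads": 8,
        "proj_layers": 2,
        "layer_norm": True,
        "dropout": 0.1,
        "pocket_depth": 2,
        "tau_init": 1.0,
        "max_guidance": 2.0,
    }
    adapter = TwoStreamShuntAdapter(cfg)
    t5_seq   = torch.randn(2, 64, cfg["t5"]["hidden_size"])     # (B, T_t5,   hid_t5)
    clip_seq = torch.randn(2, 77, cfg["clip"]["hidden_size"])   # (B, T_clip, hid_clip)
    anchor, delta, log_sigma, attn_t2c, attn_c2t, tau, g_pred, gate = adapter(t5_seq, clip_seq)
    print(anchor.shape, delta.shape, log_sigma.shape, g_pred.shape, gate.shape)
    # expected: (2, 77, 768) for anchor / delta / log_sigma, (2, 1, 1) for g_pred, (2, 77, 1) for gate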