# adapter_v2.py ─────────────────────────────────────────────────────────────
import torch, math
import torch.nn as nn
import torch.nn.functional as F
# ─── Residual pocket block ──────────────────────────────────────────────────
class PocketBlock(nn.Module):
    """Residual conv block: LayerNorm → Conv1d → GELU → Conv1d → Dropout."""

    def __init__(self, dim, kernel=3, dropout=0.1):
        super().__init__()
        # LayerNorm operates on (B, L, dim); the convolutions expect (B, dim, L),
        # so normalisation is kept outside the conv stack and applied before the transpose.
        self.norm = nn.LayerNorm(dim)
        self.body = nn.Sequential(
            nn.Conv1d(dim, dim, kernel, padding=kernel // 2, groups=1),
            nn.GELU(),
            nn.Conv1d(dim, dim, kernel, padding=kernel // 2, groups=1),
            nn.Dropout(dropout),
        )

    def forward(self, x):                                     # x: (B, L, dim)
        y = self.body(self.norm(x).transpose(1, 2)).transpose(1, 2)
        return x + y
# ─── adapter ────────────────────────────────────────────────────────────────
class TwoStreamShuntAdapter(nn.Module):
    """T5-seq → bottleneck → CLIP-seq → anchor / delta / σ …"""

    def __init__(self, cfg: dict):
        super().__init__()
        self.cfg = cfg
        hid_t5       = cfg["t5"]["hidden_size"]
        hid_clip     = cfg["clip"]["hidden_size"]
        bneck        = cfg["bottleneck"]
        heads        = cfg["heads"]
        proj_layers  = cfg.get("proj_layers", 2)
        use_norm     = cfg.get("layer_norm", True)
        p_drop       = cfg.get("dropout", 0.1)
        pocket_depth = cfg.get("pocket_depth", 2)
        # helper ----------------------------------------------------------------
        def proj(in_d, out_d):
            """Stack of [LayerNorm →] Linear → GELU [→ Dropout]; the last layer maps to out_d."""
            layers, d = [], in_d
            for i in range(proj_layers):
                next_d = out_d if i == proj_layers - 1 else bneck * 2
                if use_norm:
                    layers.append(nn.LayerNorm(d))
                layers += [nn.Linear(d, next_d), nn.GELU()]
                if p_drop:
                    layers.append(nn.Dropout(p_drop))
                d = next_d
            return nn.Sequential(*layers)
        # projections -----------------------------------------------------------
        self.t5_in   = proj(hid_t5, bneck)
        self.clip_in = proj(hid_clip, bneck)
        # bidirectional cross-attention ------------------------------------------
        self.attn_t2c = nn.MultiheadAttention(bneck, heads, batch_first=True, dropout=p_drop)
        self.attn_c2t = nn.MultiheadAttention(bneck, heads, batch_first=True, dropout=p_drop)
        self.tau = nn.Parameter(torch.full((heads, 1, 1), cfg.get("tau_init", 1.0)))
        # pocket stack ------------------------------------------------------------
        self.pocket = nn.Sequential(*[PocketBlock(bneck, dropout=p_drop) for _ in range(pocket_depth)])
        # fuse bottleneck → bneck --------------------------------------------------
        self.fuse = nn.Sequential(
            nn.LayerNorm(bneck * 2),
            nn.Linear(bneck * 2, bneck * 2),
            nn.GELU(),
            nn.Linear(bneck * 2, bneck),
        )
        # head projections ---------------------------------------------------------
        self.anchor_out = proj(bneck, hid_clip)
        self.delta_out  = proj(bneck, hid_clip)
        self.sigma_out  = proj(bneck, hid_clip)            # log σ
        self.gate_guid_proj = nn.Sequential(
            nn.LayerNorm(bneck),
            nn.Linear(bneck, bneck),
            nn.GELU(),
            nn.Linear(bneck, 2),                            # [:, :, 0] → gate, [:, :, 1] → g_pred
        )
        self.max_guidance = cfg.get("max_guidance", 2.0)
    # --- forward --------------------------------------------------------------
    def forward(self, t5_seq: torch.Tensor, clip_seq: torch.Tensor):
        assert t5_seq.size(-1) == self.cfg["t5"]["hidden_size"]
        assert clip_seq.size(-1) == self.cfg["clip"]["hidden_size"]
        # project both streams into the shared bottleneck
        t5_b   = self.t5_in(t5_seq)                         # (B, L_t5,   bneck)
        clip_b = self.clip_in(clip_seq)                     # (B, L_clip, bneck)
        # bidirectional cross-attention between the two streams
        t2c, attn_t2c = self.attn_t2c(t5_b, clip_b, clip_b, need_weights=True, average_attn_weights=False)
        c2t, attn_c2t = self.attn_c2t(clip_b, t5_b, t5_b, need_weights=True, average_attn_weights=False)
        # refine the T5→CLIP stream, pool it, and fuse it with the CLIP→T5 stream
        p = self.pocket(t2c)
        z = torch.cat([p.mean(1, keepdim=True).expand_as(c2t), c2t], dim=-1)
        h = self.fuse(z)                                    # (B, L_clip, bneck)
        # output heads
        anchor     = self.anchor_out(h)
        delta      = self.delta_out(h)
        log_sigma  = self.sigma_out(h)
        gate_and_g = self.gate_guid_proj(h)
        gate   = torch.sigmoid(gate_and_g[..., 0:1])
        g_pred = torch.clamp(gate_and_g[..., 1:2].mean(1, keepdim=True),
                             0, self.max_guidance)
        return (anchor, delta, log_sigma,
                attn_t2c, attn_c2t,
                self.tau,
                g_pred,
                gate)
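# ─── usage sketch ───────────────────────────────────────────────────────────
# Minimal smoke test, not part of the adapter itself. The hidden sizes
# (768 for T5, 1280 for CLIP), batch size, and sequence lengths below are
# illustrative assumptions only; the real values come from whichever encoders
# and config the surrounding training / inference code actually uses.
if __name__ == "__main__":
    cfg = {
        "t5":   {"hidden_size": 768},
        "clip": {"hidden_size": 1280},
        "bottleneck": 256,
        "heads": 8,
        "proj_layers": 2,
        "layer_norm": True,
        "dropout": 0.1,
        "pocket_depth": 2,
        "tau_init": 1.0,
        "max_guidance": 2.0,
    }
    adapter  = TwoStreamShuntAdapter(cfg).eval()
    t5_seq   = torch.randn(2, 64, cfg["t5"]["hidden_size"])    # (B, L_t5,   hid_t5)
    clip_seq = torch.randn(2, 77, cfg["clip"]["hidden_size"])  # (B, L_clip, hid_clip)
    with torch.no_grad():
        anchor, delta, log_sigma, attn_t2c, attn_c2t, tau, g_pred, gate = adapter(t5_seq, clip_seq)
    print(anchor.shape, delta.shape, log_sigma.shape, gate.shape, g_pred.shape)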