# adapter_v2.py ────────────────────────────────────────────────────────────
import torch
import torch.nn as nn


# ─── Residual pocket block ────────────────────────────────────────────────
class PocketBlock(nn.Module):
    """Residual conv refinement over the sequence axis."""

    def __init__(self, dim, kernel=3, dropout=0.1):
        super().__init__()
        # LayerNorm normalizes the channel (last) dim, so it runs on (B, L, dim)
        # before the transpose to the (B, dim, L) layout that Conv1d expects.
        self.norm = nn.LayerNorm(dim)
        self.conv = nn.Sequential(
            nn.Conv1d(dim, dim, kernel, padding=kernel // 2),
            nn.GELU(),
            nn.Conv1d(dim, dim, kernel, padding=kernel // 2),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        # x: (B, L, dim) → same shape, with a residual connection
        y = self.conv(self.norm(x).transpose(1, 2)).transpose(1, 2)
        return x + y


# ─── adapter ──────────────────────────────────────────────────────────────
class TwoStreamShuntAdapter(nn.Module):
    """T5-seq  βž” bottleneck  ⇄  CLIP-seq  β†’  anchor / delta / Οƒ …"""

    def __init__(self, cfg: dict):
        super().__init__()
        self.cfg       = cfg
        hid_t5         = cfg["t5"]["hidden_size"]
        hid_clip       = cfg["clip"]["hidden_size"]
        bneck          = cfg["bottleneck"]
        heads          = cfg["heads"]
        proj_layers    = cfg.get("proj_layers", 2)
        use_norm       = cfg.get("layer_norm", True)
        p_drop         = cfg.get("dropout", 0.1)
        pocket_depth   = cfg.get("pocket_depth", 2)

        # helper ----------------------------------------------------------------
        def proj(in_d, out_d):
            # stack of `proj_layers` blocks: widen to out_d * 2 internally,
            # then map to out_d on the final layer
            layers, d = [], in_d
            for i in range(proj_layers):
                if use_norm:
                    layers.append(nn.LayerNorm(d))
                width = out_d if i == proj_layers - 1 else out_d * 2
                layers += [nn.Linear(d, width), nn.GELU()]
                if p_drop:
                    layers.append(nn.Dropout(p_drop))
                d = width
            return nn.Sequential(*layers)

        # projections -----------------------------------------------------------
        self.t5_in   = proj(hid_t5,   bneck)
        self.clip_in = proj(hid_clip, bneck)

        # bidirectional cross-attention ----------------------------------------
        self.attn_t2c = nn.MultiheadAttention(bneck, heads, batch_first=True, dropout=p_drop)
        self.attn_c2t = nn.MultiheadAttention(bneck, heads, batch_first=True, dropout=p_drop)
        self.tau      = nn.Parameter(torch.full((heads, 1, 1), float(cfg.get("tau_init", 1.0))))

        # pocket stack ----------------------------------------------------------
        self.pocket = nn.Sequential(*[PocketBlock(bneck, dropout=p_drop) for _ in range(pocket_depth)])

        # fuse pooled pocket features with the CLIP stream: 2 * bneck → bneck ---
        self.fuse = nn.Sequential(
            nn.LayerNorm(bneck * 2),
            nn.Linear(bneck * 2, bneck * 2),
            nn.GELU(),
            nn.Linear(bneck * 2, bneck)
        )

        # head projections ------------------------------------------------------
        self.anchor_out = proj(bneck, hid_clip)
        self.delta_out  = proj(bneck, hid_clip)
        self.sigma_out  = proj(bneck, hid_clip)    # log σ

        self.gate_guid_proj = nn.Sequential(
            nn.LayerNorm(bneck),
            nn.Linear(bneck, bneck),
            nn.GELU(),
            nn.Linear(bneck, 2),                   # [:, :, 0] → gate, [:, :, 1] → g_pred
        )

        self.max_guidance = cfg.get("max_guidance", 2.0)

    # --- forward --------------------------------------------------------------
    def forward(self, t5_seq: torch.Tensor, clip_seq: torch.Tensor):
        assert t5_seq.size(-1)   == self.cfg["t5"]["hidden_size"]
        assert clip_seq.size(-1) == self.cfg["clip"]["hidden_size"]

        t5_b   = self.t5_in(t5_seq)
        clip_b = self.clip_in(clip_seq)

        t2c, attn_t2c = self.attn_t2c(t5_b,  clip_b, clip_b, need_weights=True, average_attn_weights=False)
        c2t, attn_c2t = self.attn_c2t(clip_b, t5_b,  t5_b,  need_weights=True, average_attn_weights=False)

        # refine the T5→CLIP stream, pool it, and fuse with the CLIP→T5 stream
        p   = self.pocket(t2c)                                          # (B, L_t5, bneck)
        z   = torch.cat([p.mean(1, keepdim=True).expand_as(c2t), c2t], dim=-1)
        h   = self.fuse(z)                                              # (B, L_clip, bneck)

        anchor    = self.anchor_out(h)
        delta     = self.delta_out(h)

        log_sigma = self.sigma_out(h)

        gate_and_g = self.gate_guid_proj(h)
        gate   = torch.sigmoid(gate_and_g[..., 0:1])
        g_pred = torch.clamp(gate_and_g[..., 1:2].mean(1, keepdim=True),
                             0, self.max_guidance)

        return (anchor, delta, log_sigma,
                attn_t2c, attn_c2t,
                self.tau,
                g_pred,
                gate)
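

# ─── usage sketch ─────────────────────────────────────────────────────────
# Minimal smoke test. The hidden sizes (512 for the T5 stream, 768 for the
# CLIP stream), sequence length, and bottleneck width below are illustrative
# assumptions, not values required by the adapter itself.
if __name__ == "__main__":
    cfg = {
        "t5":   {"hidden_size": 512},
        "clip": {"hidden_size": 768},
        "bottleneck": 256,
        "heads": 8,
        "proj_layers": 2,
        "layer_norm": True,
        "dropout": 0.1,
        "pocket_depth": 2,
        "tau_init": 1.0,
        "max_guidance": 2.0,
    }
    adapter = TwoStreamShuntAdapter(cfg).eval()

    t5_seq   = torch.randn(2, 77, 512)     # (batch, T5 tokens, T5 hidden)
    clip_seq = torch.randn(2, 77, 768)     # (batch, CLIP tokens, CLIP hidden)

    with torch.no_grad():
        anchor, delta, log_sigma, attn_t2c, attn_c2t, tau, g_pred, gate = adapter(t5_seq, clip_seq)

    print(anchor.shape, delta.shape, log_sigma.shape)   # each (2, 77, 768)
    print(gate.shape, g_pred.shape, tau.shape)          # (2, 77, 1), (2, 1, 1), (8, 1, 1)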