AbstractPhil committed on
Commit ca066a9 · 1 Parent(s): 11aea4e

initial push for v1

Files changed (3)
  1. app.py +121 -142
  2. configs.py +149 -0
  3. two_stream_shunt_adapter.py +123 -0
app.py CHANGED
@@ -1,153 +1,132 @@
 
  import gradio as gr
  import numpy as np
- import random
-
- # import spaces #[uncomment to use ZeroGPU]
  from diffusers import DiffusionPipeline
- import torch
-
- device = "cuda" if torch.cuda.is_available() else "cpu"
- model_repo_id = "stabilityai/sdxl-turbo"  # Replace to the model you would like to use
-
- if torch.cuda.is_available():
-     torch_dtype = torch.float16
- else:
-     torch_dtype = torch.float32
-
- pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
- pipe = pipe.to(device)
-
- MAX_SEED = np.iinfo(np.int32).max
- MAX_IMAGE_SIZE = 1024
-
-
- # @spaces.GPU #[uncomment to use ZeroGPU]
- def infer(
-     prompt,
-     negative_prompt,
-     seed,
-     randomize_seed,
-     width,
-     height,
-     guidance_scale,
-     num_inference_steps,
-     progress=gr.Progress(track_tqdm=True),
- ):
-     if randomize_seed:
-         seed = random.randint(0, MAX_SEED)
-
-     generator = torch.Generator().manual_seed(seed)

      image = pipe(
-         prompt=prompt,
-         negative_prompt=negative_prompt,
-         guidance_scale=guidance_scale,
-         num_inference_steps=num_inference_steps,
-         width=width,
-         height=height,
-         generator=generator,
      ).images[0]

-     return image, seed
-
-
- examples = [
-     "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
-     "An astronaut riding a green horse",
-     "A delicious ceviche cheesecake slice",
- ]
-
- css = """
- #col-container {
-     margin: 0 auto;
-     max-width: 640px;
- }
- """
-
- with gr.Blocks(css=css) as demo:
-     with gr.Column(elem_id="col-container"):
-         gr.Markdown(" # Text-to-Image Gradio Template")
-
-         with gr.Row():
-             prompt = gr.Text(
-                 label="Prompt",
-                 show_label=False,
-                 max_lines=1,
-                 placeholder="Enter your prompt",
-                 container=False,
-             )
-
-             run_button = gr.Button("Run", scale=0, variant="primary")
-
-         result = gr.Image(label="Result", show_label=False)
-
-         with gr.Accordion("Advanced Settings", open=False):
-             negative_prompt = gr.Text(
-                 label="Negative prompt",
-                 max_lines=1,
-                 placeholder="Enter a negative prompt",
-                 visible=False,
-             )
-
-             seed = gr.Slider(
-                 label="Seed",
-                 minimum=0,
-                 maximum=MAX_SEED,
-                 step=1,
-                 value=0,
-             )
-
-             randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
-
-             with gr.Row():
-                 width = gr.Slider(
-                     label="Width",
-                     minimum=256,
-                     maximum=MAX_IMAGE_SIZE,
-                     step=32,
-                     value=1024,  # Replace with defaults that work for your model
-                 )
-
-                 height = gr.Slider(
-                     label="Height",
-                     minimum=256,
-                     maximum=MAX_IMAGE_SIZE,
-                     step=32,
-                     value=1024,  # Replace with defaults that work for your model
-                 )
-
-             with gr.Row():
-                 guidance_scale = gr.Slider(
-                     label="Guidance scale",
-                     minimum=0.0,
-                     maximum=10.0,
-                     step=0.1,
-                     value=0.0,  # Replace with defaults that work for your model
-                 )
-
-                 num_inference_steps = gr.Slider(
-                     label="Number of inference steps",
-                     minimum=1,
-                     maximum=50,
-                     step=1,
-                     value=2,  # Replace with defaults that work for your model
-                 )
-
-         gr.Examples(examples=examples, inputs=[prompt])
-     gr.on(
-         triggers=[run_button.click, prompt.submit],
          fn=infer,
-         inputs=[
-             prompt,
-             negative_prompt,
-             seed,
-             randomize_seed,
-             width,
-             height,
-             guidance_scale,
-             num_inference_steps,
-         ],
-         outputs=[result, seed],
      )

  if __name__ == "__main__":
 
+ import torch
  import gradio as gr
  import numpy as np
+ import matplotlib.pyplot as plt
+ from transformers import T5Tokenizer, T5EncoderModel
  from diffusers import DiffusionPipeline
+ from safetensors.torch import load_file
+ from huggingface_hub import hf_hub_download
+ from shunt_adapter import TwoStreamShuntAdapter
+ from adapter_config import T5_SHUNT_REPOS
+
+ # ─── Device & Model Setup ─────────────────────────────────────
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ dtype = torch.float16 if torch.cuda.is_available() else torch.float32
+
+ t5_tok = T5Tokenizer.from_pretrained("google/flan-t5-base")
+ t5_mod = T5EncoderModel.from_pretrained("google/flan-t5-base").to(device).eval()
+
+ pipe = DiffusionPipeline.from_pretrained(
+     "stabilityai/stable-diffusion-xl-base-1.0",
+     torch_dtype=dtype,
+     variant="fp16" if dtype == torch.float16 else None
+ ).to(device)
+
+ # ─── Adapter Configs ──────────────────────────────────────────
+ clip_l_opts = T5_SHUNT_REPOS["clip_l"]["shunts_available"]["shunt_list"]
+ clip_g_opts = T5_SHUNT_REPOS["clip_g"]["shunts_available"]["shunt_list"]
+ repo_l = T5_SHUNT_REPOS["clip_l"]["repo"]
+ repo_g = T5_SHUNT_REPOS["clip_g"]["repo"]
+ config_l = T5_SHUNT_REPOS["clip_l"]["config"]
+ config_g = T5_SHUNT_REPOS["clip_g"]["config"]
+
+ # ─── Loader ───────────────────────────────────────────────────
+ def load_adapter(repo, filename, config):
+     path = hf_hub_download(repo_id=repo, filename=filename)
+     model = TwoStreamShuntAdapter(config).to(device).eval()
+     model.load_state_dict(load_file(path, device=device))
+     return model
+
+ # ─── Visualization ────────────────────────────────────────────
+ def plot_heat(mat, title):
+     import io
+     fig, ax = plt.subplots(figsize=(6, 3), dpi=100)
+     im = ax.imshow(mat, aspect="auto", cmap="bwr", origin="upper")
+     ax.set_title(title)
+     plt.colorbar(im, ax=ax)
+     buf = io.BytesIO()
+     plt.savefig(buf, format="png", bbox_inches='tight')
+     buf.seek(0)
+     return buf
+
+ # ─── Inference ────────────────────────────────────────────────
+ @torch.no_grad()
+ def infer(prompt, adapter_l_file, adapter_g_file, strength, noise, gate_prob, use_anchor):
+     t5_ids = t5_tok(prompt, return_tensors="pt").input_ids.to(device)
+     t5_seq = t5_mod(t5_ids).last_hidden_state
+
+     adapter_l = load_adapter(repo_l, adapter_l_file, config_l)
+     adapter_g = load_adapter(repo_g, adapter_g_file, config_g)
+
+     clip_l_in = torch.randn(t5_seq.shape[0], 77, 768).to(device)
+     clip_g_in = torch.randn(t5_seq.shape[0], 77, 1280).to(device)
+
+     anchor_l, delta_l, log_sigma_l, attn_l1, attn_l2, tau_l, g_pred_l, gate_l = adapter_l(t5_seq, clip_l_in)
+     gate_l_scaled = gate_l * gate_prob
+     delta_l_final = delta_l * strength * gate_l_scaled
+     clip_l_mod = clip_l_in + delta_l_final
+     if use_anchor:
+         clip_l_mod = clip_l_mod * (1 - gate_l_scaled) + anchor_l * gate_l_scaled
+     if noise > 0:
+         clip_l_mod += torch.randn_like(clip_l_mod) * noise
+
+     anchor_g, delta_g, log_sigma_g, attn_g1, attn_g2, tau_g, g_pred_g, gate_g = adapter_g(t5_seq, clip_g_in)
+     gate_g_scaled = gate_g * gate_prob
+     delta_g_final = delta_g * strength * gate_g_scaled
+     clip_g_mod = clip_g_in + delta_g_final
+     if use_anchor:
+         clip_g_mod = clip_g_mod * (1 - gate_g_scaled) + anchor_g * gate_g_scaled
+     if noise > 0:
+         clip_g_mod += torch.randn_like(clip_g_mod) * noise
+
+     prompt_embeds = torch.cat([clip_l_mod, clip_g_mod], dim=-1).to(dtype)
+     neg_embeds = torch.zeros_like(prompt_embeds)

      image = pipe(
+         prompt_embeds=prompt_embeds,
+         negative_prompt_embeds=neg_embeds,
+         num_inference_steps=20,
+         guidance_scale=5.0
      ).images[0]

+     return (
+         image,
+         plot_heat(delta_l_final.squeeze().cpu().numpy(), "Δ CLIP-L"),
+         plot_heat(gate_l_scaled.squeeze().cpu().numpy(), "Gate CLIP-L"),
+         plot_heat(delta_g_final.squeeze().cpu().numpy(), "Δ CLIP-G"),
+         plot_heat(gate_g_scaled.squeeze().cpu().numpy(), "Gate CLIP-G"),
+         f"g_pred_l: {g_pred_l.mean().item():.3f}, τ_l: {tau_l.mean().item():.3f}",
+         f"g_pred_g: {g_pred_g.mean().item():.3f}, τ_g: {tau_g.mean().item():.3f}"
+     )
+
+ # ─── Gradio App ───────────────────────────────────────────────
+ with gr.Blocks(title="Dual Adapter T5→CLIP") as demo:
+     gr.Markdown("# 🧠 Dual Shunt Adapter • SDXL Inference")
+
+     with gr.Row():
+         with gr.Column():
+             prompt = gr.Textbox(label="Prompt", value="a futuristic control station")
+             adapter_l = gr.Dropdown(choices=clip_l_opts, label="CLIP-L (768d) Adapter")
+             adapter_g = gr.Dropdown(choices=clip_g_opts, label="CLIP-G (1280d) Adapter")
+             strength = gr.Slider(0.0, 5.0, value=1.0, step=0.1, label="Adapter Strength")
+             noise = gr.Slider(0.0, 1.0, value=0.0, step=0.05, label="Noise Injection")
+             gate_prob = gr.Slider(0.0, 1.0, value=1.0, step=0.05, label="Gate Probability")
+             use_anchor = gr.Checkbox(label="Use Anchor", value=True)
+             run_btn = gr.Button("Run")
+
+         with gr.Column():
+             out_img = gr.Image(label="Generated Image")
+             delta_l = gr.Image(label="Δ CLIP-L")
+             gate_l = gr.Image(label="Gate CLIP-L")
+             delta_g = gr.Image(label="Δ CLIP-G")
+             gate_g = gr.Image(label="Gate CLIP-G")
+             stats_l = gr.Textbox(label="CLIP-L Stats")
+             stats_g = gr.Textbox(label="CLIP-G Stats")
+
+     run_btn.click(
          fn=infer,
+         inputs=[prompt, adapter_l, adapter_g, strength, noise, gate_prob, use_anchor],
+         outputs=[out_img, delta_l, gate_l, delta_g, gate_g, stats_l, stats_g]
      )

  if __name__ == "__main__":
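Both streams in infer are modulated the same way: the adapter's delta is scaled by the UI strength and by the gate (itself scaled by the Gate Probability slider), added to the base CLIP sequence, optionally blended toward the adapter's anchor, and optionally noised. Below is a minimal, self-contained sketch of that blend; the helper name apply_shunt and the random stand-in tensors are illustrative, not part of the commit.

import torch

def apply_shunt(clip_in, anchor, delta, gate, strength=1.0, gate_prob=1.0,
                use_anchor=True, noise=0.0):
    # Scale the per-token gate by the "Gate Probability" slider,
    # then scale the adapter delta by strength * gate and add it to the base sequence.
    gate_scaled = gate * gate_prob
    modified = clip_in + delta * strength * gate_scaled
    # Optionally pull gated positions toward the adapter's anchor sequence.
    if use_anchor:
        modified = modified * (1 - gate_scaled) + anchor * gate_scaled
    # Optional noise injection, mirroring the "Noise Injection" slider.
    if noise > 0:
        modified = modified + torch.randn_like(modified) * noise
    return modified

# Shapes mirror the app for CLIP-L: batch of 1, 77 tokens, 768 channels.
clip_in = torch.randn(1, 77, 768)
anchor = torch.randn_like(clip_in)
delta = torch.randn_like(clip_in)
gate = torch.rand(1, 77, 1)  # per-token gate in [0, 1]
out = apply_shunt(clip_in, anchor, delta, gate, strength=1.0)
print(out.shape)  # torch.Size([1, 77, 768])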
configs.py ADDED
@@ -0,0 +1,149 @@
+ T5_SHUNT_REPOS = {
+     "clip_g": {
+         "models": ["vit-bigG-14", "flan-t5-base"],
+         "config": {
+             "adapter_id": "003", "name": "DualShuntAdapter-G",
+             "t5": {
+                 "model": "google/flan-t5-base",
+                 "hidden_size": 768
+             },
+             "clip": {
+                 "model": "openai/clip-vit-large-patch14",
+                 "hidden_size": 1280
+             },
+             "hidden_size": 1280,  # This is the adapter's output size
+             "bottleneck": 640, "heads": 20,
+             "tau_init": 0.1, "max_guidance": 10.0,
+             "proj_layers": 2, "layer_norm": True, "dropout": 0.1,
+             "use_dropout": True, "use_proj_stack": True, "assert_input_dims": True,
+             "routing": {"type": "cross_attention", "enable_causal_mask": False, "bidirectional": True},
+             "version": "v0.3.2"
+         },
+         "repo": "AbstractPhil/t5-flan-base-vit-bigG-14-dual-stream-adapter",
+         "shunts_available": {
+             "shunt_type_name": "DualStreamAdapter-G",
+             "config_file_name": "config.json",
+             "shunt_list": [
+                 "t5-flan-vit-bigG-14-dual_shunt_caption.safetensors",
+                 "t5-flan-vit-bigG-14-dual_shunt_no_caption_e1.safetensors",
+                 "t5-flan-vit-bigG-14-dual_shunt_no_caption_e2.safetensors",
+                 "t5-flan-vit-bigG-14-dual_shunt_no_caption_e3.safetensors",
+                 "t5-flan-vit-bigG-14-dual_shunt_summarize.safetensors",
+                 "dual_shunt_omega_no_caption_e1_step_10000.safetensors",
+                 "dual_shunt_omega_no_caption_noised_e1_step_1000.safetensors",
+                 "dual_shunt_omega_no_caption_noised_e1_step_4000.safetensors",
+                 "dual_shunt_omega_no_caption_noised_e1_step_10000.safetensors",
+             ],
+         }
+     },
+     "clip_l": {
+         "models": ["vit-l-14", "flan-t5-base"],
+         "config": {
+             "adapter_id": "002",
+             "name": "DualShuntAdapter",
+             "t5": {"model": "google/flan-t5-base", "hidden_size": 768},
+             "clip": {"model": "openai/clip-vit-large-patch14", "hidden_size": 768},
+             "hidden_size": 768,  # This is the adapter's output size
+             "bottleneck": 384, "heads": 12,
+             "tau_init": 0.1, "max_guidance": 10.0,
+             "proj_layers": 2, "layer_norm": True, "dropout": 0.1,
+             "use_dropout": True, "use_proj_stack": True, "assert_input_dims": True,
+             "routing": {"type": "cross_attention", "enable_causal_mask": False, "bidirectional": True},
+             "version": "v0.3.2"
+         },
+         "repo": "AbstractPhil/t5-flan-base-vit-l-14-dual-stream-adapter",
+         "shunts_available": {
+             "shunt_type_name": "DualStreamAdapter-L",
+             "config_file_name": "config.json",
+             "shunt_list": [
+                 "t5-vit-l-14-dual_shunt_caption.safetensors",
+                 "t5-vit-l-14-dual_shunt_no_caption.safetensors",
+                 "t5-vit-l-14-dual_shunt_summarize.safetensors",
+             ],
+         },
+     }
+ }
+
+ # ─── Adapter Configs ─────────────────────────────────────────────
+
+ BERT_CONFIGS = {
+     "mobilebert-base-uncased": {
+         "repo_name": "google/mobilebert-uncased",
+         "use_huggingface": True,  # defaults to simple loading from HuggingFace, if False, will use repo_name and subfolder
+         "subfolder": "",
+     },
+     "bert-base-uncased": {
+         "repo_name": "bert-base-uncased",
+         "use_huggingface": True,  # defaults to simple loading from HuggingFace, if False, will use repo_name and subfolder
+     },
+     "bert-large-uncased": {
+         "repo_name": "bert-large-uncased",
+         "use_huggingface": True,  # defaults to simple loading from HuggingFace, if False, will use repo_name and subfolder
+     },
+     "bert-base-cased": {
+         "repo_name": "bert-base-cased",
+         "use_huggingface": True,  # defaults to simple loading from HuggingFace, if False, will use repo_name and subfolder
+     }
+ }
+
+ T5_CONFIGS = {
+     "flan-t5-base": {
+         "repo_name": "google/flan-t5-base",
+         "use_huggingface": True,  # defaults to simple loading from HuggingFace, if False, will use repo_name and subfolder
+     },
+     "t5-small": {
+         "repo_name": "google-t5/t5-small",
+         "use_huggingface": True,  # defaults to simple loading from HuggingFace, if False, will use repo_name and subfolder
+     },
+     "t5_small_human_attentive_try2_pass3": {
+         "repo_name": "AbstractPhil/t5_small_human_attentive_try2_pass3",
+         "use_huggingface": True,  # defaults to simple loading from HuggingFace, if False, will use repo_name and subfolder
+         # the necessary config is present here for posterity in case it fails to load from HuggingFace.
+         "subfolder": "",
+         "tokenizer": "t5-small",
+         "file_name": "model.safetensors",
+         "config": {
+             "config_file_name": "config.json",
+             "architectures": [
+                 "T5ForConditionalGeneration"
+             ],
+             "attention_dropout": 0.3,
+             "classifier_dropout": 0.0,
+             "d_ff": 2048,
+             "d_kv": 64,
+             "d_model": 512,
+             "decoder_start_token_id": 0,
+             "dense_act_fn": "relu",
+             "dropout_rate": 0.0,  # 0.3, # disable for generation
+             "eos_token_id": 1,
+             "feed_forward_proj": "relu",
+             "initializer_factor": 1.0,
+             "is_encoder_decoder": True,
+             "is_gated_act": False,
+             "layer_norm_epsilon": 1e-06,
+             "model_type": "t5",
+             "n_positions": 512,
+             "num_decoder_layers": 6,
+             "num_heads": 8,
+             "num_layers": 6,
+             "output_past": True,
+             "pad_token_id": 0,
+             "relative_attention_max_distance": 128,
+             "relative_attention_num_buckets": 32,
+             "task_specific_params": {
+                 "caption": {
+                     "early_stopping": True,
+                     "length_penalty": 1.0,
+                     "max_length": 64,
+                     "num_beams": 4,
+                     "prefix": "caption: "
+                 }
+             },
+             "torch_dtype": "float32",
+             "transformers_version": "4.51.3",
+             "use_cache": True,
+             "vocab_size": 32128
+         }
+     }
+ }
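T5_SHUNT_REPOS is the lookup table app.py uses to populate the adapter dropdowns and fetch weights (note that app.py imports it under the module name adapter_config, while this file is added as configs.py). A minimal sketch of that lookup, assuming the module is importable as configs and that the listed repository and filenames resolve on the Hub:

from huggingface_hub import hf_hub_download

from configs import T5_SHUNT_REPOS

entry = T5_SHUNT_REPOS["clip_l"]
repo_id = entry["repo"]                              # AbstractPhil/t5-flan-base-vit-l-14-dual-stream-adapter
filenames = entry["shunts_available"]["shunt_list"]  # safetensors checkpoints offered in the dropdown
config = entry["config"]                             # dict passed to TwoStreamShuntAdapter(config)

# Download the first listed checkpoint; hf_hub_download returns a local cache path.
local_path = hf_hub_download(repo_id=repo_id, filename=filenames[0])
print(local_path)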
two_stream_shunt_adapter.py ADDED
@@ -0,0 +1,123 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ # ─── Residual Pocket Block ───────────────────────────────────
+ class BottleneckResBlock(nn.Module):
+     def __init__(self, dim, kernel=3, dropout=0.1):
+         super().__init__()
+         self.norm = nn.LayerNorm(dim)
+         self.conv = nn.Conv1d(dim, dim, kernel_size=kernel, padding=kernel // 2, groups=1)
+         self.proj = nn.Sequential(
+             nn.Linear(dim, dim * 2),
+             nn.GELU(),
+             nn.Linear(dim * 2, dim),
+             nn.Dropout(dropout)
+         )
+
+     def forward(self, x):
+         residual = x
+         x = self.norm(x)
+         x = x.transpose(1, 2)
+         x = self.conv(x).transpose(1, 2)
+         return residual + self.proj(x)
+
+ # ─── Two Stream Shunt Adapter ──────────────────────────────────────
+ class TwoStreamShuntAdapter(nn.Module):
+     def __init__(self, config: dict):
+         super().__init__()
+         self.config = config
+         self.t5_dim = config["t5"]["hidden_size"]
+         self.clip_dim = config["clip"]["hidden_size"]
+         self.bneck = config["bottleneck"]
+         self.heads = config["heads"]
+         self.tau_init = config["tau_init"]
+         self.max_guidance = config["max_guidance"]
+
+         use_norm = config.get("layer_norm", True)
+         use_do = config.get("use_dropout", True)
+         do_p = config.get("dropout", 0.1)
+         proj_depth = config.get("proj_layers", 2)
+
+         def build_projection(input_dim, output_dim):
+             layers = []
+             last_dim = input_dim
+             if use_norm:
+                 layers.append(nn.LayerNorm(last_dim))
+             for i in range(proj_depth):
+                 next_dim = self.bneck * (2 if i == 0 and proj_depth > 1 else 1)
+                 layers.append(nn.Linear(last_dim, next_dim))
+                 layers.append(nn.GELU())
+                 if use_do:
+                     layers.append(nn.Dropout(do_p))
+                 last_dim = next_dim
+             layers.append(nn.Linear(last_dim, output_dim))
+             return nn.Sequential(*layers)
+
+         # Projections
+         self.proj_t5 = build_projection(self.t5_dim, self.bneck)
+         self.proj_clip = build_projection(self.clip_dim, self.bneck)
+
+         # Attention
+         self.cross_t2c = nn.MultiheadAttention(self.bneck, self.heads, batch_first=True, dropout=do_p)
+         self.cross_c2t = nn.MultiheadAttention(self.bneck, self.heads, batch_first=True, dropout=do_p)
+         self.tau = nn.Parameter(torch.full((self.heads, 1, 1), self.tau_init))
+
+         # Residual Pocket
+         self.pocket_blocks = nn.Sequential(
+             BottleneckResBlock(self.bneck, dropout=do_p),
+             BottleneckResBlock(self.bneck, dropout=do_p)
+         )
+
+         # Fuse
+         self.fuse = nn.Sequential(
+             nn.LayerNorm(2 * self.bneck),
+             nn.Linear(2 * self.bneck, self.bneck * 2),
+             nn.GELU(),
+             nn.Linear(self.bneck * 2, self.bneck)
+         )
+
+         # Output Projections
+         self.anchor_proj = build_projection(self.bneck, self.clip_dim)
+         self.delta_proj = build_projection(self.bneck, self.clip_dim)
+         self.logsig_proj = build_projection(self.bneck, self.clip_dim)
+
+         self.gate_proj = nn.Sequential(
+             nn.LayerNorm(self.bneck),
+             nn.Linear(self.bneck, self.bneck),
+             nn.GELU(),
+             nn.Linear(self.bneck, 1),
+             nn.Tanh(),
+             nn.Sigmoid()
+         )
+
+         self.guidance_proj = nn.Sequential(
+             nn.LayerNorm(self.bneck),
+             nn.Linear(self.bneck, 1),
+             nn.Sigmoid()
+         )
+
+     def forward(self, t5_seq: torch.Tensor, clip_seq: torch.Tensor):
+         if self.config.get("assert_input_dims", True):
+             assert t5_seq.size(-1) == self.t5_dim
+             assert clip_seq.size(-1) == self.clip_dim
+
+         t5_b = self.proj_t5(t5_seq)
+         clip_b = self.proj_clip(clip_seq)
+
+         t2c, attn_t2c = self.cross_t2c(t5_b, clip_b, clip_b, need_weights=True, average_attn_weights=False)
+         c2t, attn_c2t = self.cross_c2t(clip_b, t5_b, t5_b, need_weights=True, average_attn_weights=False)
+
+         pocket = self.pocket_blocks(t2c)
+
+         pocket_mean = pocket.mean(1, keepdim=True).expand(-1, clip_b.size(1), -1)
+         h = self.fuse(torch.cat([pocket_mean, c2t], dim=-1))
+
+         anchor = self.anchor_proj(h)
+         delta = self.delta_proj(h) * self.gate_proj(h)
+         log_sigma = self.logsig_proj(h)
+
+         g_tok = self.guidance_proj(h).squeeze(-1)
+         g_pred = g_tok.mean(1, keepdim=True) * self.max_guidance
+
+         return anchor, delta, log_sigma, attn_t2c, attn_c2t, self.tau, g_pred, self.gate_proj(h)
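A minimal smoke test for the adapter as committed, assuming configs.py and two_stream_shunt_adapter.py are importable by their filenames; the random tensors stand in for real flan-t5-base and CLIP-L sequences, and the expected shapes follow from the clip_l config (bottleneck 384, CLIP hidden size 768):

import torch

from configs import T5_SHUNT_REPOS
from two_stream_shunt_adapter import TwoStreamShuntAdapter

config = T5_SHUNT_REPOS["clip_l"]["config"]
adapter = TwoStreamShuntAdapter(config).eval()

t5_seq = torch.randn(1, 32, config["t5"]["hidden_size"])      # (batch, t5_tokens, 768)
clip_seq = torch.randn(1, 77, config["clip"]["hidden_size"])  # (batch, clip_tokens, 768)

with torch.no_grad():
    anchor, delta, log_sigma, attn_t2c, attn_c2t, tau, g_pred, gate = adapter(t5_seq, clip_seq)

print(anchor.shape, delta.shape, gate.shape, g_pred.shape)
# Expected: anchor/delta (1, 77, 768), gate (1, 77, 1), g_pred (1, 1)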