Spaces: Running on Zero

AbstractPhil committed · ca066a9 · 1 Parent(s): 11aea4e

initial push for v1

Browse files:
- app.py +121 -142
- configs.py +149 -0
- two_stream_shunt_adapter.py +123 -0
app.py
CHANGED
@@ -1,153 +1,132 @@
+import torch
 import gradio as gr
 import numpy as np
-import …
-
-# import spaces #[uncomment to use ZeroGPU]
+import matplotlib.pyplot as plt
+from transformers import T5Tokenizer, T5EncoderModel
 from diffusers import DiffusionPipeline
-[old lines 7-39: the starter template's remaining imports, constants, and example infer() definition; not recoverable from this capture]
+from safetensors.torch import load_file
+from huggingface_hub import hf_hub_download
+from PIL import Image  # added: plot_heat below returns a PIL image
+from two_stream_shunt_adapter import TwoStreamShuntAdapter  # fixed: was "shunt_adapter", but this commit adds two_stream_shunt_adapter.py
+from configs import T5_SHUNT_REPOS  # fixed: was "adapter_config", but this commit adds configs.py
+
+# ─── Device & Model Setup ─────────────────────────────────────
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+dtype = torch.float16 if torch.cuda.is_available() else torch.float32
+
+t5_tok = T5Tokenizer.from_pretrained("google/flan-t5-base")
+t5_mod = T5EncoderModel.from_pretrained("google/flan-t5-base").to(device).eval()
+
+pipe = DiffusionPipeline.from_pretrained(
+    "stabilityai/stable-diffusion-xl-base-1.0",
+    torch_dtype=dtype,
+    variant="fp16" if dtype == torch.float16 else None
+).to(device)
+
+# ─── Adapter Configs ──────────────────────────────────────────
+clip_l_opts = T5_SHUNT_REPOS["clip_l"]["shunts_available"]["shunt_list"]
+clip_g_opts = T5_SHUNT_REPOS["clip_g"]["shunts_available"]["shunt_list"]
+repo_l = T5_SHUNT_REPOS["clip_l"]["repo"]
+repo_g = T5_SHUNT_REPOS["clip_g"]["repo"]
+config_l = T5_SHUNT_REPOS["clip_l"]["config"]
+config_g = T5_SHUNT_REPOS["clip_g"]["config"]
+
+# ─── Loader ───────────────────────────────────────────────────
+def load_adapter(repo, filename, config):
+    path = hf_hub_download(repo_id=repo, filename=filename)
+    model = TwoStreamShuntAdapter(config).to(device).eval()
+    model.load_state_dict(load_file(path, device=str(device)))  # fixed: safetensors expects a device string
+    return model
+
+# ─── Visualization ────────────────────────────────────────────
+def plot_heat(mat, title):
+    import io
+    fig, ax = plt.subplots(figsize=(6, 3), dpi=100)
+    im = ax.imshow(mat, aspect="auto", cmap="bwr", origin="upper")
+    ax.set_title(title)
+    plt.colorbar(im, ax=ax)
+    buf = io.BytesIO()
+    plt.savefig(buf, format="png", bbox_inches="tight")
+    plt.close(fig)  # fixed: close the figure so repeated calls do not leak memory
+    buf.seek(0)
+    return Image.open(buf)  # fixed: gr.Image accepts a PIL image, not a raw BytesIO
+
+# ─── Inference ────────────────────────────────────────────────
+@torch.no_grad()
+def infer(prompt, adapter_l_file, adapter_g_file, strength, noise, gate_prob, use_anchor):
+    t5_ids = t5_tok(prompt, return_tensors="pt").input_ids.to(device)
+    t5_seq = t5_mod(t5_ids).last_hidden_state
+
+    adapter_l = load_adapter(repo_l, adapter_l_file, config_l)
+    adapter_g = load_adapter(repo_g, adapter_g_file, config_g)
+
+    # Random stand-ins for the two CLIP streams; the adapters write gated deltas onto these.
+    clip_l_in = torch.randn(t5_seq.shape[0], 77, 768).to(device)
+    clip_g_in = torch.randn(t5_seq.shape[0], 77, 1280).to(device)
+
+    anchor_l, delta_l, log_sigma_l, attn_l1, attn_l2, tau_l, g_pred_l, gate_l = adapter_l(t5_seq, clip_l_in)
+    gate_l_scaled = gate_l * gate_prob
+    delta_l_final = delta_l * strength * gate_l_scaled
+    clip_l_mod = clip_l_in + delta_l_final
+    if use_anchor:
+        clip_l_mod = clip_l_mod * (1 - gate_l_scaled) + anchor_l * gate_l_scaled
+    if noise > 0:
+        clip_l_mod += torch.randn_like(clip_l_mod) * noise
+
+    anchor_g, delta_g, log_sigma_g, attn_g1, attn_g2, tau_g, g_pred_g, gate_g = adapter_g(t5_seq, clip_g_in)
+    gate_g_scaled = gate_g * gate_prob
+    delta_g_final = delta_g * strength * gate_g_scaled
+    clip_g_mod = clip_g_in + delta_g_final
+    if use_anchor:
+        clip_g_mod = clip_g_mod * (1 - gate_g_scaled) + anchor_g * gate_g_scaled
+    if noise > 0:
+        clip_g_mod += torch.randn_like(clip_g_mod) * noise
+
+    # SDXL expects CLIP-L and CLIP-G features concatenated on the channel axis (768 + 1280 = 2048).
+    prompt_embeds = torch.cat([clip_l_mod, clip_g_mod], dim=-1).to(dtype)
+    neg_embeds = torch.zeros_like(prompt_embeds)
 
     image = pipe(
-        [old lines 42-45: the template's prompt and guidance arguments; not recoverable from this capture]
-        width=width,
-        height=height,
-        generator=generator,
+        prompt_embeds=prompt_embeds,
+        negative_prompt_embeds=neg_embeds,
+        # NOTE: StableDiffusionXLPipeline also requires pooled_prompt_embeds and
+        # negative_pooled_prompt_embeds whenever prompt_embeds is passed directly.
+        num_inference_steps=20,
+        guidance_scale=5.0
     ).images[0]
 
-    return …
-[old lines 52-85: the template's return values, example prompts, and Blocks layout down to the negative-prompt box; not recoverable from this capture]
-                label="Negative prompt",
-                max_lines=1,
-                placeholder="Enter a negative prompt",
-                visible=False,
-            )
-
-            seed = gr.Slider(
-                label="Seed",
-                minimum=0,
-                maximum=MAX_SEED,
-                step=1,
-                value=0,
-            )
-
-            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
-
-            with gr.Row():
-                width = gr.Slider(
-                    label="Width",
-                    minimum=256,
-                    maximum=MAX_IMAGE_SIZE,
-                    step=32,
-                    value=1024,  # Replace with defaults that work for your model
-                )
-
-                height = gr.Slider(
-                    label="Height",
-                    minimum=256,
-                    maximum=MAX_IMAGE_SIZE,
-                    step=32,
-                    value=1024,  # Replace with defaults that work for your model
-                )
-
-            with gr.Row():
-                guidance_scale = gr.Slider(
-                    label="Guidance scale",
-                    minimum=0.0,
-                    maximum=10.0,
-                    step=0.1,
-                    value=0.0,  # Replace with defaults that work for your model
-                )
-
-                num_inference_steps = gr.Slider(
-                    label="Number of inference steps",
-                    minimum=1,
-                    maximum=50,
-                    step=1,
-                    value=2,  # Replace with defaults that work for your model
-                )
-
-        gr.Examples(examples=examples, inputs=[prompt])
-    gr.on(
-        triggers=[run_button.click, prompt.submit],
+    return (
+        image,
+        plot_heat(delta_l_final.squeeze().cpu().numpy(), "Δ CLIP-L"),
+        plot_heat(gate_l_scaled.squeeze().cpu().numpy(), "Gate CLIP-L"),
+        plot_heat(delta_g_final.squeeze().cpu().numpy(), "Δ CLIP-G"),
+        plot_heat(gate_g_scaled.squeeze().cpu().numpy(), "Gate CLIP-G"),
+        f"g_pred_l: {g_pred_l.mean().item():.3f}, τ_l: {tau_l.mean().item():.3f}",
+        f"g_pred_g: {g_pred_g.mean().item():.3f}, τ_g: {tau_g.mean().item():.3f}"
+    )
+
+# ─── Gradio App ───────────────────────────────────────────────
+with gr.Blocks(title="Dual Adapter T5→CLIP") as demo:
+    gr.Markdown("# 🧠 Dual Shunt Adapter • SDXL Inference")
+
+    with gr.Row():
+        with gr.Column():
+            prompt = gr.Textbox(label="Prompt", value="a futuristic control station")
+            adapter_l = gr.Dropdown(choices=clip_l_opts, label="CLIP-L (768d) Adapter")
+            adapter_g = gr.Dropdown(choices=clip_g_opts, label="CLIP-G (1280d) Adapter")
+            strength = gr.Slider(0.0, 5.0, value=1.0, step=0.1, label="Adapter Strength")
+            noise = gr.Slider(0.0, 1.0, value=0.0, step=0.05, label="Noise Injection")
+            gate_prob = gr.Slider(0.0, 1.0, value=1.0, step=0.05, label="Gate Probability")
+            use_anchor = gr.Checkbox(label="Use Anchor", value=True)
+            run_btn = gr.Button("Run")
+
+        with gr.Column():
+            out_img = gr.Image(label="Generated Image")
+            delta_l = gr.Image(label="Δ CLIP-L")
+            gate_l = gr.Image(label="Gate CLIP-L")
+            delta_g = gr.Image(label="Δ CLIP-G")
+            gate_g = gr.Image(label="Gate CLIP-G")
+            stats_l = gr.Textbox(label="CLIP-L Stats")
+            stats_g = gr.Textbox(label="CLIP-G Stats")
+
+    run_btn.click(
         fn=infer,
-        inputs=[
-            prompt,
-            negative_prompt,
-            seed,
-            randomize_seed,
-            width,
-            height,
-            guidance_scale,
-            num_inference_steps,
-        ],
-        outputs=[result, seed],
+        inputs=[prompt, adapter_l, adapter_g, strength, noise, gate_prob, use_anchor],
+        outputs=[out_img, delta_l, gate_l, delta_g, gate_g, stats_l, stats_g]
     )
 
 if __name__ == "__main__":
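The gating arithmetic in infer() reduces to a per-token blend. Below is a minimal sketch, runnable on CPU, with random tensors standing in for the encoder and adapter outputs; the shapes and the blend formula are taken from the diff above, while the tensor values themselves are purely illustrative:

import torch

strength, gate_prob, noise, use_anchor = 1.0, 1.0, 0.0, True
clip_in = torch.randn(1, 77, 768)   # stand-in CLIP-L sequence
delta   = torch.randn(1, 77, 768)   # adapter-predicted residual
anchor  = torch.randn(1, 77, 768)   # adapter-predicted anchor embedding
gate    = torch.rand(1, 77, 1)      # per-token gate in (0, 1)

gate_scaled = gate * gate_prob                       # UI slider rescales the learned gate
clip_mod = clip_in + delta * strength * gate_scaled  # gated residual update
if use_anchor:
    # where the gate is open, pull the embedding toward the anchor
    clip_mod = clip_mod * (1 - gate_scaled) + anchor * gate_scaled
if noise > 0:
    clip_mod = clip_mod + torch.randn_like(clip_mod) * noise
print(clip_mod.shape)  # torch.Size([1, 77, 768])

With gate_prob at 0 the update vanishes and the original CLIP sequence passes through unchanged, which is what the "Gate Probability" slider exploits.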
configs.py
ADDED
@@ -0,0 +1,149 @@
+T5_SHUNT_REPOS = {
+    "clip_g": {
+        "models": ["vit-bigG-14", "flan-t5-base"],
+        "config": {
+            "adapter_id": "003", "name": "DualShuntAdapter-G",
+            "t5": {
+                "model": "google/flan-t5-base",
+                "hidden_size": 768
+            },
+            "clip": {
+                "model": "openai/clip-vit-large-patch14",
+                "hidden_size": 1280
+            },
+            "hidden_size": 1280,  # the adapter's output size
+            "bottleneck": 640, "heads": 20,
+            "tau_init": 0.1, "max_guidance": 10.0,
+            "proj_layers": 2, "layer_norm": True, "dropout": 0.1,
+            "use_dropout": True, "use_proj_stack": True, "assert_input_dims": True,
+            "routing": {"type": "cross_attention", "enable_causal_mask": False, "bidirectional": True},
+            "version": "v0.3.2"
+        },
+        "repo": "AbstractPhil/t5-flan-base-vit-bigG-14-dual-stream-adapter",
+        "shunts_available": {
+            "shunt_type_name": "DualStreamAdapter-G",
+            "config_file_name": "config.json",
+            "shunt_list": [
+                "t5-flan-vit-bigG-14-dual_shunt_caption.safetensors",
+                "t5-flan-vit-bigG-14-dual_shunt_no_caption_e1.safetensors",
+                "t5-flan-vit-bigG-14-dual_shunt_no_caption_e2.safetensors",
+                "t5-flan-vit-bigG-14-dual_shunt_no_caption_e3.safetensors",
+                "t5-flan-vit-bigG-14-dual_shunt_summarize.safetensors",
+                "dual_shunt_omega_no_caption_e1_step_10000.safetensors",
+                "dual_shunt_omega_no_caption_noised_e1_step_1000.safetensors",
+                "dual_shunt_omega_no_caption_noised_e1_step_4000.safetensors",
+                "dual_shunt_omega_no_caption_noised_e1_step_10000.safetensors",
+            ],
+        }
+    },
+    "clip_l": {
+        "models": ["vit-l-14", "flan-t5-base"],
+        "config": {
+            "adapter_id": "002",
+            "name": "DualShuntAdapter",
+            "t5": {"model": "google/flan-t5-base", "hidden_size": 768},
+            "clip": {"model": "openai/clip-vit-large-patch14", "hidden_size": 768},
+            "hidden_size": 768,  # the adapter's output size
+            "bottleneck": 384, "heads": 12,
+            "tau_init": 0.1, "max_guidance": 10.0,
+            "proj_layers": 2, "layer_norm": True, "dropout": 0.1,
+            "use_dropout": True, "use_proj_stack": True, "assert_input_dims": True,
+            "routing": {"type": "cross_attention", "enable_causal_mask": False, "bidirectional": True},
+            "version": "v0.3.2"
+        },
+        "repo": "AbstractPhil/t5-flan-base-vit-l-14-dual-stream-adapter",
+        "shunts_available": {
+            "shunt_type_name": "DualStreamAdapter-L",
+            "config_file_name": "config.json",
+            "shunt_list": [
+                "t5-vit-l-14-dual_shunt_caption.safetensors",
+                "t5-vit-l-14-dual_shunt_no_caption.safetensors",
+                "t5-vit-l-14-dual_shunt_summarize.safetensors",
+            ],
+        },
+    }
+}
+
+# ─── Encoder Configs (BERT & T5) ──────────────────────────────────
+# "use_huggingface" defaults to simple loading from HuggingFace;
+# if False, the loader falls back to repo_name + subfolder.
+
+BERT_CONFIGS = {
+    "mobilebert-base-uncased": {
+        "repo_name": "google/mobilebert-uncased",
+        "use_huggingface": True,
+        "subfolder": "",
+    },
+    "bert-base-uncased": {
+        "repo_name": "bert-base-uncased",
+        "use_huggingface": True,
+    },
+    "bert-large-uncased": {
+        "repo_name": "bert-large-uncased",
+        "use_huggingface": True,
+    },
+    "bert-base-cased": {
+        "repo_name": "bert-base-cased",
+        "use_huggingface": True,
+    }
+}
+
+T5_CONFIGS = {
+    "flan-t5-base": {
+        "repo_name": "google/flan-t5-base",
+        "use_huggingface": True,
+    },
+    "t5-small": {
+        "repo_name": "google-t5/t5-small",
+        "use_huggingface": True,
+    },
+    "t5_small_human_attentive_try2_pass3": {
+        "repo_name": "AbstractPhil/t5_small_human_attentive_try2_pass3",
+        "use_huggingface": True,
+        # The full config is kept here for posterity in case loading from HuggingFace fails.
+        "subfolder": "",
+        "tokenizer": "t5-small",
+        "file_name": "model.safetensors",
+        "config": {
+            "config_file_name": "config.json",
+            "architectures": [
+                "T5ForConditionalGeneration"
+            ],
+            "attention_dropout": 0.3,
+            "classifier_dropout": 0.0,
+            "d_ff": 2048,
+            "d_kv": 64,
+            "d_model": 512,
+            "decoder_start_token_id": 0,
+            "dense_act_fn": "relu",
+            "dropout_rate": 0.0,  # 0.3 during training; disabled for generation
+            "eos_token_id": 1,
+            "feed_forward_proj": "relu",
+            "initializer_factor": 1.0,
+            "is_encoder_decoder": True,
+            "is_gated_act": False,
+            "layer_norm_epsilon": 1e-06,
+            "model_type": "t5",
+            "n_positions": 512,
+            "num_decoder_layers": 6,
+            "num_heads": 8,
+            "num_layers": 6,
+            "output_past": True,
+            "pad_token_id": 0,
+            "relative_attention_max_distance": 128,
+            "relative_attention_num_buckets": 32,
+            "task_specific_params": {
+                "caption": {
+                    "early_stopping": True,
+                    "length_penalty": 1.0,
+                    "max_length": 64,
+                    "num_beams": 4,
+                    "prefix": "caption: "
+                }
+            },
+            "torch_dtype": "float32",
+            "transformers_version": "4.51.3",
+            "use_cache": True,
+            "vocab_size": 32128
+        }
+    }
+}
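For reference, a short sketch of how app.py consumes this registry. hf_hub_download is the same API the loader above uses; picking the first entry of shunt_list is only for illustration:

from huggingface_hub import hf_hub_download
from configs import T5_SHUNT_REPOS

entry = T5_SHUNT_REPOS["clip_l"]
filename = entry["shunts_available"]["shunt_list"][0]              # e.g. the caption shunt
path = hf_hub_download(repo_id=entry["repo"], filename=filename)   # fetches the .safetensors weights
print(path, entry["config"]["hidden_size"])                        # adapter output size: 768 for CLIP-L

The "config" sub-dict is passed straight into TwoStreamShuntAdapter, so every key the adapter reads (bottleneck, heads, tau_init, and so on) lives here rather than in code.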
two_stream_shunt_adapter.py
ADDED
@@ -0,0 +1,123 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+# ─── Residual Pocket Block ───────────────────────────────────
+class BottleneckResBlock(nn.Module):
+    def __init__(self, dim, kernel=3, dropout=0.1):
+        super().__init__()
+        self.norm = nn.LayerNorm(dim)
+        self.conv = nn.Conv1d(dim, dim, kernel_size=kernel, padding=kernel // 2, groups=1)
+        self.proj = nn.Sequential(
+            nn.Linear(dim, dim * 2),
+            nn.GELU(),
+            nn.Linear(dim * 2, dim),
+            nn.Dropout(dropout)
+        )
+
+    def forward(self, x):
+        residual = x
+        x = self.norm(x)
+        x = x.transpose(1, 2)             # (B, T, C) -> (B, C, T) for Conv1d
+        x = self.conv(x).transpose(1, 2)  # back to (B, T, C)
+        return residual + self.proj(x)
+
+# ─── Two-Stream Shunt Adapter ────────────────────────────────
+class TwoStreamShuntAdapter(nn.Module):
+    def __init__(self, config: dict):
+        super().__init__()
+        self.config = config
+        self.t5_dim = config["t5"]["hidden_size"]
+        self.clip_dim = config["clip"]["hidden_size"]
+        self.bneck = config["bottleneck"]
+        self.heads = config["heads"]
+        self.tau_init = config["tau_init"]
+        self.max_guidance = config["max_guidance"]
+
+        use_norm = config.get("layer_norm", True)
+        use_do = config.get("use_dropout", True)
+        do_p = config.get("dropout", 0.1)
+        proj_depth = config.get("proj_layers", 2)
+
+        def build_projection(input_dim, output_dim):
+            layers = []
+            last_dim = input_dim
+            if use_norm:
+                layers.append(nn.LayerNorm(last_dim))
+            for i in range(proj_depth):
+                # first layer widens to 2x the bottleneck when stacking more than one layer
+                next_dim = self.bneck * (2 if i == 0 and proj_depth > 1 else 1)
+                layers.append(nn.Linear(last_dim, next_dim))
+                layers.append(nn.GELU())
+                if use_do:
+                    layers.append(nn.Dropout(do_p))
+                last_dim = next_dim
+            layers.append(nn.Linear(last_dim, output_dim))
+            return nn.Sequential(*layers)
+
+        # Projections
+        self.proj_t5 = build_projection(self.t5_dim, self.bneck)
+        self.proj_clip = build_projection(self.clip_dim, self.bneck)
+
+        # Attention
+        self.cross_t2c = nn.MultiheadAttention(self.bneck, self.heads, batch_first=True, dropout=do_p)
+        self.cross_c2t = nn.MultiheadAttention(self.bneck, self.heads, batch_first=True, dropout=do_p)
+        self.tau = nn.Parameter(torch.full((self.heads, 1, 1), self.tau_init))
+
+        # Residual Pocket
+        self.pocket_blocks = nn.Sequential(
+            BottleneckResBlock(self.bneck, dropout=do_p),
+            BottleneckResBlock(self.bneck, dropout=do_p)
+        )
+
+        # Fuse
+        self.fuse = nn.Sequential(
+            nn.LayerNorm(2 * self.bneck),
+            nn.Linear(2 * self.bneck, self.bneck * 2),
+            nn.GELU(),
+            nn.Linear(self.bneck * 2, self.bneck)
+        )
+
+        # Output Projections
+        self.anchor_proj = build_projection(self.bneck, self.clip_dim)
+        self.delta_proj = build_projection(self.bneck, self.clip_dim)
+        self.logsig_proj = build_projection(self.bneck, self.clip_dim)
+
+        self.gate_proj = nn.Sequential(
+            nn.LayerNorm(self.bneck),
+            nn.Linear(self.bneck, self.bneck),
+            nn.GELU(),
+            nn.Linear(self.bneck, 1),
+            nn.Tanh(),     # Tanh followed by Sigmoid bounds the gate to roughly (0.27, 0.73)
+            nn.Sigmoid()
+        )
+
+        self.guidance_proj = nn.Sequential(
+            nn.LayerNorm(self.bneck),
+            nn.Linear(self.bneck, 1),
+            nn.Sigmoid()
+        )
+
+    def forward(self, t5_seq: torch.Tensor, clip_seq: torch.Tensor):
+        if self.config.get("assert_input_dims", True):
+            assert t5_seq.size(-1) == self.t5_dim
+            assert clip_seq.size(-1) == self.clip_dim
+
+        t5_b = self.proj_t5(t5_seq)
+        clip_b = self.proj_clip(clip_seq)
+
+        t2c, attn_t2c = self.cross_t2c(t5_b, clip_b, clip_b, need_weights=True, average_attn_weights=False)
+        c2t, attn_c2t = self.cross_c2t(clip_b, t5_b, t5_b, need_weights=True, average_attn_weights=False)
+
+        pocket = self.pocket_blocks(t2c)
+
+        # pool the T5-side pocket over time and broadcast it along the CLIP sequence
+        pocket_mean = pocket.mean(1, keepdim=True).expand(-1, clip_b.size(1), -1)
+        h = self.fuse(torch.cat([pocket_mean, c2t], dim=-1))
+
+        anchor = self.anchor_proj(h)
+        delta = self.delta_proj(h) * self.gate_proj(h)
+        log_sigma = self.logsig_proj(h)
+
+        g_tok = self.guidance_proj(h).squeeze(-1)
+        g_pred = g_tok.mean(1, keepdim=True) * self.max_guidance
+
+        return anchor, delta, log_sigma, attn_t2c, attn_c2t, self.tau, g_pred, self.gate_proj(h)
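A quick shape check for the adapter, as a sketch: the config dict below mirrors the "clip_l" config from configs.py, and the T5/CLIP sequences are random stand-ins. Note the two streams may have different lengths; the T5 side is pooled before fusion, so the outputs follow the CLIP sequence length:

import torch
from two_stream_shunt_adapter import TwoStreamShuntAdapter

config = {
    "t5": {"hidden_size": 768}, "clip": {"hidden_size": 768},
    "bottleneck": 384, "heads": 12, "tau_init": 0.1, "max_guidance": 10.0,
    "proj_layers": 2, "layer_norm": True, "dropout": 0.1,
    "use_dropout": True, "use_proj_stack": True, "assert_input_dims": True,
}
adapter = TwoStreamShuntAdapter(config).eval()

t5_seq = torch.randn(1, 16, 768)    # T5 tokens: any length works
clip_seq = torch.randn(1, 77, 768)  # CLIP stream: 77 tokens
with torch.no_grad():
    anchor, delta, log_sigma, attn_t2c, attn_c2t, tau, g_pred, gate = adapter(t5_seq, clip_seq)

print(anchor.shape, delta.shape, gate.shape)  # (1, 77, 768) (1, 77, 768) (1, 77, 1)
print(g_pred.shape, tau.shape)                # (1, 1) (12, 1, 1)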