Spaces: Running on Zero

Update app.py

app.py: CHANGED
Before the change (old line numbers; removed lines are marked with "-"; a bare "-" is a removed line whose text was not captured by the page):

@@ -11,12 +11,7 @@ from huggingface_hub import hf_hub_download
 11     from two_stream_shunt_adapter import TwoStreamShuntAdapter
 12     from configs import T5_SHUNT_REPOS
 13
 14  -  # ───
 15  -  # Don't initialize CUDA here for ZeroGPU compatibility
 16  -  device = None  # Will be set inside the GPU function
 17  -  dtype = torch.float16
 18  -
 19  -  # Don't load models here - will load inside GPU function
 20     t5_tok = None
 21     t5_mod = None
 22     pipe = None

@@ -36,11 +31,9 @@ repo_g = T5_SHUNT_REPOS["clip_g"]["repo"]
 36     config_l = T5_SHUNT_REPOS["clip_l"]["config"]
 37     config_g = T5_SHUNT_REPOS["clip_g"]["config"]
 38
 39  -  # ───
 40  -
 41  -
 42  -  def load_adapter(repo, filename, config):
 43  -      # Don't initialize device here
 44         path = hf_hub_download(repo_id=repo, filename=filename)
 45
 46         model = TwoStreamShuntAdapter(config).eval()

@@ -49,75 +42,67 @@ def load_adapter(repo, filename, config):
 49         for key in f.keys():
 50             tensors[key] = f.get_tensor(key)
 51         model.load_state_dict(tensors)
 52  -
 53  -      return model
 54
 55  -  # ─── Visualization ────────────────────────────────────────────
 56     def plot_heat(mat, title):
 57         import io
 58  -
 59  -
 60  -
 61  -
 62         buf = io.BytesIO()
 63  -      plt.savefig(buf, format="png", bbox_inches='tight')
 64         buf.seek(0)
 65         plt.close(fig)
 66  -      return
 67
 68  -
 69  -
 70  -      """Generate proper CLIP-L and CLIP-G embeddings using SDXL's text encoders"""
 71
 72         # Tokenize for both encoders
 73         tokens_l = pipe.tokenizer(
 74  -          prompt,
 75  -          padding="max_length",
 76  -          max_length=77,
 77  -          truncation=True,
 78  -          return_tensors="pt"
 79         ).input_ids.to(device)
 80
 81         tokens_g = pipe.tokenizer_2(
 82  -          prompt,
 83  -          padding="max_length",
 84  -          max_length=77,
 85  -          truncation=True,
 86  -          return_tensors="pt"
 87         ).input_ids.to(device)
 88
 89  -      # Negative prompts
 90         neg_tokens_l = pipe.tokenizer(
 91  -          negative_prompt,
 92  -          padding="max_length",
 93  -          max_length=77,
 94  -          truncation=True,
 95  -          return_tensors="pt"
 96         ).input_ids.to(device)
 97
 98         neg_tokens_g = pipe.tokenizer_2(
 99  -          negative_prompt,
100  -          padding="max_length",
101  -          max_length=77,
102  -          truncation=True,
103  -          return_tensors="pt"
104         ).input_ids.to(device)
105
106         with torch.no_grad():
107  -          # CLIP-L
108             clip_l_embeds = pipe.text_encoder(tokens_l)[0]
109             neg_clip_l_embeds = pipe.text_encoder(neg_tokens_l)[0]
110
111  -          # CLIP-G
112             clip_g_output = pipe.text_encoder_2(tokens_g)
113             clip_g_embeds = clip_g_output[1]  # sequence embeddings
114
115             neg_clip_g_output = pipe.text_encoder_2(neg_tokens_g)
116  -          neg_clip_g_embeds = neg_clip_g_output[1]
117  -
118  -          # Pooled embeddings for SDXL
119  -          pooled_embeds = clip_g_output[0]  # pooled embeddings
120  -          neg_pooled_embeds = neg_clip_g_output[0]  # pooled embeddings
121
122         return {
123             "clip_l": clip_l_embeds,

@@ -128,23 +113,16 @@ def encode_sdxl_prompt(prompt, negative_prompt=""):
128             "neg_pooled": neg_pooled_embeds
129         }
130
131  -  # ─── Inference
132  -  @torch.no_grad()
133  -  def infer(prompt, negative_prompt, adapter_l_file, adapter_g_file, strength, noise, gate_prob,
134  -            use_anchor, steps, cfg_scale, scheduler_name, width, height, seed):
135  -
136  -  # ─── Inference ────────────────────────────────────────────
137     @spaces.GPU
138  -  @torch.no_grad()
139     def infer(prompt, negative_prompt, adapter_l_file, adapter_g_file, strength, noise, gate_prob,
140               use_anchor, steps, cfg_scale, scheduler_name, width, height, seed):
141
142  -      # Initialize device and models inside GPU context
143         global t5_tok, t5_mod, pipe
144         device = torch.device("cuda")
145         dtype = torch.float16
146
147  -      #
148         if t5_tok is None:
149             t5_tok = T5Tokenizer.from_pretrained("google/flan-t5-base")
150             t5_mod = T5EncoderModel.from_pretrained("google/flan-t5-base").to(device).eval()

@@ -157,7 +135,7 @@ def infer(prompt, negative_prompt, adapter_l_file, adapter_g_file, strength, noi
157             use_safetensors=True
158         ).to(device)
159
160  -      # Set seed
161         if seed != -1:
162             torch.manual_seed(seed)
163             np.random.seed(seed)

@@ -166,67 +144,62 @@ def infer(prompt, negative_prompt, adapter_l_file, adapter_g_file, strength, noi
166         if scheduler_name in SCHEDULERS:
167             pipe.scheduler = SCHEDULERS[scheduler_name].from_config(pipe.scheduler.config)
168
169  -      # Get T5 embeddings
170         t5_ids = t5_tok(
171  -          prompt,
172  -          return_tensors="pt",
173  -          padding="max_length",
174  -          max_length=77,
175  -          truncation=True
176         ).input_ids.to(device)
177         t5_seq = t5_mod(t5_ids).last_hidden_state
178
179  -      # Get
180         clip_embeds = encode_sdxl_prompt(pipe, prompt, negative_prompt, device)
181
182  -      #
183  -
184  -
185  -      print(f"CLIP-G shape: {clip_embeds['clip_g'].shape}")
186  -
187  -      # Load adapters
188  -      adapter_l = load_adapter(repo_l, adapter_l_file, config_l).to(device) if adapter_l_file else None
189  -      adapter_g = load_adapter(repo_g, adapter_g_file, config_g).to(device) if adapter_g_file else None
190
191         # Apply CLIP-L adapter
192         if adapter_l is not None:
193  -          anchor_l, delta_l, log_sigma_l, attn_l1, attn_l2, tau_l, g_pred_l, gate_l = adapter_l(
194             gate_l_scaled = gate_l * gate_prob
195             delta_l_final = delta_l * strength * gate_l_scaled
196  -          clip_l_mod = clip_embeds["clip_l"] + delta_l_final
197             if use_anchor:
198  -              clip_l_mod = clip_l_mod * (1 - gate_l_scaled) + anchor_l * gate_l_scaled
199             if noise > 0:
200                 clip_l_mod += torch.randn_like(clip_l_mod) * noise
201         else:
202             clip_l_mod = clip_embeds["clip_l"]
203             delta_l_final = torch.zeros_like(clip_embeds["clip_l"])
204             gate_l_scaled = torch.zeros_like(clip_embeds["clip_l"])
205  -          g_pred_l = torch.tensor(0.0)
206  -          tau_l = torch.tensor(0.0)
207
208         # Apply CLIP-G adapter
209         if adapter_g is not None:
210  -          anchor_g, delta_g, log_sigma_g, attn_g1, attn_g2, tau_g, g_pred_g, gate_g = adapter_g(
211             gate_g_scaled = gate_g * gate_prob
212             delta_g_final = delta_g * strength * gate_g_scaled
213  -          clip_g_mod = clip_embeds["clip_g"] + delta_g_final
214             if use_anchor:
215  -              clip_g_mod = clip_g_mod * (1 - gate_g_scaled) + anchor_g * gate_g_scaled
216             if noise > 0:
217                 clip_g_mod += torch.randn_like(clip_g_mod) * noise
218         else:
219             clip_g_mod = clip_embeds["clip_g"]
220             delta_g_final = torch.zeros_like(clip_embeds["clip_g"])
221             gate_g_scaled = torch.zeros_like(clip_embeds["clip_g"])
222  -          g_pred_g = torch.tensor(0.0)
223  -          tau_g = torch.tensor(0.0)
224
225  -      # Combine embeddings
226  -      prompt_embeds = torch.cat([clip_l_mod, clip_g_mod], dim=-1)
227  -      neg_embeds = torch.cat([clip_embeds["neg_clip_l"], clip_embeds["neg_clip_g"]], dim=-1)
228
229  -      # Generate image
230         image = pipe(
231             prompt_embeds=prompt_embeds,
232             pooled_prompt_embeds=clip_embeds["pooled"],

@@ -236,69 +209,72 @@ def infer(prompt, negative_prompt, adapter_l_file, adapter_g_file, strength, noi
236             guidance_scale=cfg_scale,
237             width=width,
238             height=height,
239  -          num_images_per_prompt=1,
240             generator=torch.Generator(device=device).manual_seed(seed) if seed != -1 else None
241         ).images[0]
242
243  -
244  -
245  -
246  -
247  -
248  -
249  -
250  -
251  -  )
252
253     # ─── Gradio Interface ─────────────────────────────────────────
254  -
255  -  gr.
256  -
257  -
258  -
259  -  with gr.
260  -
261  -
262  -      gr.Markdown("### Prompts")
263         prompt = gr.Textbox(
264  -          label="Prompt",
265             value="a futuristic control station with holographic displays",
266  -          lines=3
267         )
268         negative_prompt = gr.Textbox(
269             label="Negative Prompt",
270             value="blurry, low quality, distorted",
271  -          lines=2
272         )
273  -
274  -
275  -
276  -      gr.Markdown("### Adapters")
277         adapter_l = gr.Dropdown(
278  -          choices=["None"] + clip_l_opts,
279             label="CLIP-L (768d) Adapter",
280  -          value="None"
281         )
282         adapter_g = gr.Dropdown(
283  -          choices=["None"] + clip_g_opts,
284  -          label="CLIP-G (1280d) Adapter",
285  -          value="None"
286         )
287  -
288  -
289  -
290  -      gr.Markdown("### Adapter Controls")
291         strength = gr.Slider(0.0, 5.0, value=1.0, step=0.1, label="Adapter Strength")
292         noise = gr.Slider(0.0, 1.0, value=0.0, step=0.05, label="Noise Injection")
293         gate_prob = gr.Slider(0.0, 1.0, value=1.0, step=0.05, label="Gate Probability")
294  -      use_anchor = gr.Checkbox(label="Use Anchor", value=True)
295  -
296  -
297  -
298  -      gr.Markdown("### Generation Settings")
299         with gr.Row():
300  -          steps = gr.Slider(1,
301  -          cfg_scale = gr.Slider(1.0,
302
303         scheduler_name = gr.Dropdown(
304             choices=list(SCHEDULERS.keys()),

@@ -311,57 +287,48 @@ with gr.Blocks(title="SDXL Dual Shunt Adapter", theme=gr.themes.Soft()) as demo:
311         height = gr.Slider(512, 1536, value=1024, step=64, label="Height")
312
313         seed = gr.Number(value=-1, label="Seed (-1 for random)")
314
315  -
316  -
317  -
318  -
319  -
320  -
321  -
322  -
323  -      # Visualizations
324  -      with gr.Group():
325  -          gr.Markdown("### Adapter Visualizations")
326         with gr.Row():
327  -
328  -
329         with gr.Row():
330  -
331  -
332  -
333  -
334  -
335  -      gr.
336  -
337  -      stats_g = gr.Textbox(label="CLIP-G Stats", interactive=False)
338  -
339  -      # Event handlers
340  -      def process_adapters(adapter_l_val, adapter_g_val):
341  -          # Convert "None" back to None for processing
342  -          adapter_l_processed = None if adapter_l_val == "None" else adapter_l_val
343  -          adapter_g_processed = None if adapter_g_val == "None" else adapter_g_val
344  -          return adapter_l_processed, adapter_g_processed
345  -
346  -      def run_inference(*args):
347  -          # Process adapter selections
348  -          adapter_l_processed, adapter_g_processed = process_adapters(args[2], args[3])
349
350  -          #
351  -
352  -
353  -
354
355  -
356
357  -
358  -          fn=run_inference,
359  -          inputs=[
360  -              prompt, negative_prompt, adapter_l, adapter_g, strength, noise, gate_prob,
361  -              use_anchor, steps, cfg_scale, scheduler_name, width, height, seed
362  -          ],
363  -          outputs=[out_img, delta_l, gate_l, delta_g, gate_g, stats_l, stats_g]
364  -      )
365
366     if __name__ == "__main__":
367         demo.launch()

After the change (new line numbers; added lines are marked with "+"):

 11     from two_stream_shunt_adapter import TwoStreamShuntAdapter
 12     from configs import T5_SHUNT_REPOS
 13
 14  +  # ─── Global Variables ─────────────────────────────────────────
 15     t5_tok = None
 16     t5_mod = None
 17     pipe = None

 31     config_l = T5_SHUNT_REPOS["clip_l"]["config"]
 32     config_g = T5_SHUNT_REPOS["clip_g"]["config"]
 33
 34  +  # ─── Helper Functions ─────────────────────────────────────────
 35  +  def load_adapter(repo, filename, config, device):
 36  +      from safetensors.torch import safe_open
 37         path = hf_hub_download(repo_id=repo, filename=filename)
 38
 39         model = TwoStreamShuntAdapter(config).eval()

 42         for key in f.keys():
 43             tensors[key] = f.get_tensor(key)
 44         model.load_state_dict(tensors)
 45  +      return model.to(device)
 46
 47     def plot_heat(mat, title):
 48  +      """Create heatmap visualization with proper shape handling"""
 49         import io
 50  +
 51  +      # Ensure we have a 2D array for visualization
 52  +      if len(mat.shape) == 1:
 53  +          mat = mat.reshape(1, -1)
 54  +      elif len(mat.shape) == 3:
 55  +          mat = mat.mean(axis=0)
 56  +      elif len(mat.shape) > 3:
 57  +          mat = mat.reshape(-1, mat.shape[-1])
 58  +
 59  +      fig, ax = plt.subplots(figsize=(8, 4), dpi=100)
 60  +      im = ax.imshow(mat, aspect="auto", cmap="RdBu_r", origin="upper")
 61  +      ax.set_title(title, fontsize=12, fontweight='bold')
 62  +      ax.set_xlabel("Token Position")
 63  +      ax.set_ylabel("Feature Dimension")
 64  +      plt.colorbar(im, ax=ax, shrink=0.8)
 65  +
 66         buf = io.BytesIO()
 67  +      plt.savefig(buf, format="png", bbox_inches='tight', dpi=100)
 68         buf.seek(0)
 69  +      pil_image = Image.open(buf)
 70         plt.close(fig)
 71  +      return pil_image
 72
 73  +  def encode_sdxl_prompt(pipe, prompt, negative_prompt, device):
 74  +      """Generate CLIP-L and CLIP-G embeddings using SDXL's text encoders"""
 75
 76         # Tokenize for both encoders
 77         tokens_l = pipe.tokenizer(
 78  +          prompt, padding="max_length", max_length=77, truncation=True, return_tensors="pt"
 79         ).input_ids.to(device)
 80
 81         tokens_g = pipe.tokenizer_2(
 82  +          prompt, padding="max_length", max_length=77, truncation=True, return_tensors="pt"
 83         ).input_ids.to(device)
 84
 85         neg_tokens_l = pipe.tokenizer(
 86  +          negative_prompt, padding="max_length", max_length=77, truncation=True, return_tensors="pt"
 87         ).input_ids.to(device)
 88
 89         neg_tokens_g = pipe.tokenizer_2(
 90  +          negative_prompt, padding="max_length", max_length=77, truncation=True, return_tensors="pt"
 91         ).input_ids.to(device)
 92
 93         with torch.no_grad():
 94  +          # CLIP-L: [0] = sequence, [1] = pooled
 95             clip_l_embeds = pipe.text_encoder(tokens_l)[0]
 96             neg_clip_l_embeds = pipe.text_encoder(neg_tokens_l)[0]
 97
 98  +          # CLIP-G: [0] = pooled, [1] = sequence (different from CLIP-L!)
 99             clip_g_output = pipe.text_encoder_2(tokens_g)
100             clip_g_embeds = clip_g_output[1]  # sequence embeddings
101  +          pooled_embeds = clip_g_output[0]  # pooled embeddings
102
103             neg_clip_g_output = pipe.text_encoder_2(neg_tokens_g)
104  +          neg_clip_g_embeds = neg_clip_g_output[1]
105  +          neg_pooled_embeds = neg_clip_g_output[0]
106
107         return {
108             "clip_l": clip_l_embeds,
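A note on the index comments above: in diffusers' stock SDXL pipeline, pipe.text_encoder is a CLIPTextModel, whose output exposes the token sequence at index 0 and the pooled vector at index 1, while pipe.text_encoder_2 is a CLIPTextModelWithProjection, whose output exposes the projected pooled vector at index 0 and the token sequence at index 1. A minimal sketch of that unpacking, assuming a loaded StableDiffusionXLPipeline named pipe and token-id tensors tokens_l / tokens_g as built in encode_sdxl_prompt:

# Sketch only: how SDXL's two text encoders order their outputs.
# Assumes `pipe` (StableDiffusionXLPipeline) and token ids of shape (batch, 77).
import torch

with torch.no_grad():
    # CLIP-L (CLIPTextModel): [0] = per-token hidden states, [1] = pooled output
    l_out = pipe.text_encoder(tokens_l)
    clip_l_seq = l_out[0]                      # (batch, 77, 768)

    # CLIP-G (CLIPTextModelWithProjection): [0] = projected pooled embedding,
    # [1] = per-token hidden states
    g_out = pipe.text_encoder_2(tokens_g)
    pooled = g_out[0]                          # (batch, 1280)
    clip_g_seq = g_out[1]                      # (batch, 77, 1280)

# SDXL consumes the two sequences concatenated on the feature axis.
prompt_embeds = torch.cat([clip_l_seq, clip_g_seq], dim=-1)   # (batch, 77, 2048)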
113             "neg_pooled": neg_pooled_embeds
114         }
115
116  +  # ─── Main Inference Function ──────────────────────────────────
117     @spaces.GPU
118     def infer(prompt, negative_prompt, adapter_l_file, adapter_g_file, strength, noise, gate_prob,
119               use_anchor, steps, cfg_scale, scheduler_name, width, height, seed):
120
121         global t5_tok, t5_mod, pipe
122         device = torch.device("cuda")
123         dtype = torch.float16
124
125  +      # Initialize models
126         if t5_tok is None:
127             t5_tok = T5Tokenizer.from_pretrained("google/flan-t5-base")
128             t5_mod = T5EncoderModel.from_pretrained("google/flan-t5-base").to(device).eval()

135             use_safetensors=True
136         ).to(device)
137
138  +      # Set seed
139         if seed != -1:
140             torch.manual_seed(seed)
141             np.random.seed(seed)

144         if scheduler_name in SCHEDULERS:
145             pipe.scheduler = SCHEDULERS[scheduler_name].from_config(pipe.scheduler.config)
146
147  +      # Get T5 embeddings
148         t5_ids = t5_tok(
149  +          prompt, return_tensors="pt", padding="max_length", max_length=77, truncation=True
150         ).input_ids.to(device)
151         t5_seq = t5_mod(t5_ids).last_hidden_state
152
153  +      # Get CLIP embeddings
154         clip_embeds = encode_sdxl_prompt(pipe, prompt, negative_prompt, device)
155
156  +      # Load and apply adapters
157  +      adapter_l = load_adapter(repo_l, adapter_l_file, config_l, device) if adapter_l_file else None
158  +      adapter_g = load_adapter(repo_g, adapter_g_file, config_g, device) if adapter_g_file else None
159
160         # Apply CLIP-L adapter
161         if adapter_l is not None:
162  +          anchor_l, delta_l, log_sigma_l, attn_l1, attn_l2, tau_l, g_pred_l, gate_l = adapter_l(
163  +              t5_seq.float(), clip_embeds["clip_l"].float()
164  +          )
165             gate_l_scaled = gate_l * gate_prob
166             delta_l_final = delta_l * strength * gate_l_scaled
167  +          clip_l_mod = clip_embeds["clip_l"] + delta_l_final.to(dtype)
168  +
169             if use_anchor:
170  +              clip_l_mod = clip_l_mod * (1 - gate_l_scaled.to(dtype)) + anchor_l.to(dtype) * gate_l_scaled.to(dtype)
171             if noise > 0:
172                 clip_l_mod += torch.randn_like(clip_l_mod) * noise
173         else:
174             clip_l_mod = clip_embeds["clip_l"]
175             delta_l_final = torch.zeros_like(clip_embeds["clip_l"])
176             gate_l_scaled = torch.zeros_like(clip_embeds["clip_l"])
177  +          g_pred_l, tau_l = torch.tensor(0.0), torch.tensor(0.0)
178
179         # Apply CLIP-G adapter
180         if adapter_g is not None:
181  +          anchor_g, delta_g, log_sigma_g, attn_g1, attn_g2, tau_g, g_pred_g, gate_g = adapter_g(
182  +              t5_seq.float(), clip_embeds["clip_g"].float()
183  +          )
184             gate_g_scaled = gate_g * gate_prob
185             delta_g_final = delta_g * strength * gate_g_scaled
186  +          clip_g_mod = clip_embeds["clip_g"] + delta_g_final.to(dtype)
187  +
188             if use_anchor:
189  +              clip_g_mod = clip_g_mod * (1 - gate_g_scaled.to(dtype)) + anchor_g.to(dtype) * gate_g_scaled.to(dtype)
190             if noise > 0:
191                 clip_g_mod += torch.randn_like(clip_g_mod) * noise
192         else:
193             clip_g_mod = clip_embeds["clip_g"]
194             delta_g_final = torch.zeros_like(clip_embeds["clip_g"])
195             gate_g_scaled = torch.zeros_like(clip_embeds["clip_g"])
196  +          g_pred_g, tau_g = torch.tensor(0.0), torch.tensor(0.0)
197
198  +      # Combine embeddings for SDXL: [CLIP-L(768) + CLIP-G(1280)] = 2048
199  +      prompt_embeds = torch.cat([clip_l_mod, clip_g_mod], dim=-1)
200  +      neg_embeds = torch.cat([clip_embeds["neg_clip_l"], clip_embeds["neg_clip_g"]], dim=-1)
201
202  +      # Generate image
203         image = pipe(
204             prompt_embeds=prompt_embeds,
205             pooled_prompt_embeds=clip_embeds["pooled"],

209             guidance_scale=cfg_scale,
210             width=width,
211             height=height,
212  +          num_images_per_prompt=1,
213             generator=torch.Generator(device=device).manual_seed(seed) if seed != -1 else None
214         ).images[0]
215
216  +      # Create visualizations
217  +      delta_l_viz = plot_heat(delta_l_final.squeeze().cpu().numpy(), "CLIP-L Delta Values")
218  +      gate_l_viz = plot_heat(gate_l_scaled.squeeze().cpu().numpy().mean(axis=-1, keepdims=True), "CLIP-L Gate Activations")
219  +      delta_g_viz = plot_heat(delta_g_final.squeeze().cpu().numpy(), "CLIP-G Delta Values")
220  +      gate_g_viz = plot_heat(gate_g_scaled.squeeze().cpu().numpy().mean(axis=-1, keepdims=True), "CLIP-G Gate Activations")
221  +
222  +      # Statistics
223  +      stats_l = f"g_pred_l: {g_pred_l.mean().item():.3f}, τ_l: {tau_l.mean().item():.3f}"
224  +      stats_g = f"g_pred_g: {g_pred_g.mean().item():.3f}, τ_g: {tau_g.mean().item():.3f}"
225  +
226  +      return image, delta_l_viz, gate_l_viz, delta_g_viz, gate_g_viz, stats_l, stats_g
227
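The reorganized infer function above is the ZeroGPU pattern that the removed module-level code violated: nothing touches CUDA at import time, and the tokenizer, T5 encoder, and SDXL pipeline are created lazily inside the @spaces.GPU-decorated function, the only place a GPU is guaranteed to be attached. A minimal sketch of that pattern, with a placeholder model loader that is not part of the app above:

# Sketch of the ZeroGPU-compatible lazy-initialization pattern.
# load_my_model() is a hypothetical placeholder, not from the app.
import spaces
import torch

_model = None  # no CUDA work at import time

@spaces.GPU  # ZeroGPU attaches a GPU only for the duration of this call
def generate(prompt):
    global _model
    device = torch.device("cuda")       # safe here, inside the GPU context
    if _model is None:                  # lazy, one-time setup on first call
        _model = load_my_model().to(device).eval()
    with torch.no_grad():
        return _model(prompt)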
228     # ─── Gradio Interface ─────────────────────────────────────────
229  +  def create_interface():
230  +      with gr.Blocks(title="SDXL Dual Shunt Adapter", theme=gr.themes.Soft()) as demo:
231  +          gr.Markdown("# SDXL Dual Shunt Adapter")
232  +          gr.Markdown("*Enhance SDXL generation using T5 semantic understanding to modify CLIP embeddings*")
233  +
234  +          with gr.Row():
235  +              with gr.Column(scale=1):
236  +                  # Prompts
237  +                  gr.Markdown("### Prompts")
238                     prompt = gr.Textbox(
239  +                      label="Prompt",
240                         value="a futuristic control station with holographic displays",
241  +                      lines=3,
242  +                      placeholder="Describe what you want to generate..."
243                     )
244                     negative_prompt = gr.Textbox(
245                         label="Negative Prompt",
246                         value="blurry, low quality, distorted",
247  +                      lines=2,
248  +                      placeholder="Describe what you want to avoid..."
249                     )
250  +
251  +                  # Adapters
252  +                  gr.Markdown("### Adapters")
253                     adapter_l = gr.Dropdown(
254  +                      choices=["None"] + clip_l_opts,
255                         label="CLIP-L (768d) Adapter",
256  +                      value="None",
257  +                      info="Choose adapter for CLIP-L embeddings"
258                     )
259                     adapter_g = gr.Dropdown(
260  +                      choices=["None"] + clip_g_opts,
261  +                      label="CLIP-G (1280d) Adapter",
262  +                      value="None",
263  +                      info="Choose adapter for CLIP-G embeddings"
264                     )
265  +
266  +                  # Controls
267  +                  gr.Markdown("### Adapter Controls")
268                     strength = gr.Slider(0.0, 5.0, value=1.0, step=0.1, label="Adapter Strength")
269                     noise = gr.Slider(0.0, 1.0, value=0.0, step=0.05, label="Noise Injection")
270                     gate_prob = gr.Slider(0.0, 1.0, value=1.0, step=0.05, label="Gate Probability")
271  +                  use_anchor = gr.Checkbox(label="Use Anchor Points", value=True)
272  +
273  +                  # Generation Settings
274  +                  gr.Markdown("### Generation Settings")
275                     with gr.Row():
276  +                      steps = gr.Slider(1, 50, value=25, step=1, label="Steps")
277  +                      cfg_scale = gr.Slider(1.0, 15.0, value=7.5, step=0.5, label="CFG Scale")
278
279                     scheduler_name = gr.Dropdown(
280                         choices=list(SCHEDULERS.keys()),

287                     height = gr.Slider(512, 1536, value=1024, step=64, label="Height")
288
289                     seed = gr.Number(value=-1, label="Seed (-1 for random)")
290  +
291  +                  generate_btn = gr.Button("Generate Image", variant="primary", size="lg")
292
293  +              with gr.Column(scale=1):
294  +                  # Output
295  +                  gr.Markdown("### Generated Image")
296  +                  output_image = gr.Image(label="Result", height=400, show_label=False)
297  +
298  +                  # Visualizations
299  +                  gr.Markdown("### Adapter Analysis")
300                     with gr.Row():
301  +                      delta_l_img = gr.Image(label="CLIP-L Deltas", height=200)
302  +                      gate_l_img = gr.Image(label="CLIP-L Gates", height=200)
303                     with gr.Row():
304  +                      delta_g_img = gr.Image(label="CLIP-G Deltas", height=200)
305  +                      gate_g_img = gr.Image(label="CLIP-G Gates", height=200)
306  +
307  +                  # Statistics
308  +                  gr.Markdown("### Statistics")
309  +                  stats_l_text = gr.Textbox(label="CLIP-L Metrics", interactive=False)
310  +                  stats_g_text = gr.Textbox(label="CLIP-G Metrics", interactive=False)
311
312  +          # Event handler
313  +          def run_generation(*args):
314  +              # Process adapter selections
315  +              processed_args = list(args)
316  +              processed_args[2] = None if args[2] == "None" else args[2]  # adapter_l
317  +              processed_args[3] = None if args[3] == "None" else args[3]  # adapter_g
318  +              return infer(*processed_args)
319
320  +          generate_btn.click(
321  +              fn=run_generation,
322  +              inputs=[
323  +                  prompt, negative_prompt, adapter_l, adapter_g, strength, noise, gate_prob,
324  +                  use_anchor, steps, cfg_scale, scheduler_name, width, height, seed
325  +              ],
326  +              outputs=[output_image, delta_l_img, gate_l_img, delta_g_img, gate_g_img, stats_l_text, stats_g_text]
327  +          )
328
329  +      return demo

330
331  +  # ─── Launch ────────────────────────────────────────────────────
332     if __name__ == "__main__":
333  +      demo = create_interface()
334         demo.launch()
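For reference, the per-encoder adapter update inside infer reduces to a gated residual plus an optional anchor blend. A standalone sketch of just that arithmetic on dummy tensors (shapes chosen to match CLIP-L; the real gate shape comes from TwoStreamShuntAdapter and may differ):

# Dummy stand-ins for the adapter outputs; not the real adapter.
import torch

clip_seq = torch.randn(1, 77, 768)     # CLIP-L token embeddings
delta    = torch.randn(1, 77, 768)     # adapter's predicted correction
anchor   = torch.randn(1, 77, 768)     # adapter's anchor embedding
gate     = torch.rand(1, 77, 1)        # assumed per-token gate in [0, 1]

strength, gate_prob, use_anchor = 1.0, 1.0, True

gate_scaled = gate * gate_prob
modified = clip_seq + delta * strength * gate_scaled            # gated residual update
if use_anchor:
    # pull gated positions toward the anchor, leave the rest unchanged
    modified = modified * (1 - gate_scaled) + anchor * gate_scaled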