Spaces: Running on Zero
Commit · a4e1cd2
Parent(s): 12aa86c

app.py CHANGED
@@ -4,7 +4,6 @@ import gradio as gr
 import numpy as np
 import matplotlib.pyplot as plt
 from PIL import Image
-import spaces
 from transformers import T5Tokenizer, T5EncoderModel
 from diffusers import StableDiffusionXLPipeline, DDIMScheduler, EulerDiscreteScheduler, DPMSolverMultistepScheduler
 from safetensors.torch import load_file
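Note: `@spaces.GPU` (still used on `infer` further down in this diff) comes from this `spaces` package, which on ZeroGPU hardware attaches a GPU only for the duration of the decorated call. A minimal sketch of that entry-point pattern; the `generate` function here is hypothetical, not part of this app:

    import spaces
    import torch

    @spaces.GPU  # a GPU is attached only while this function runs
    def generate(prompt: str) -> str:
        device = torch.device("cuda")  # safe here: CUDA exists inside the call
        return f"{prompt} -> {torch.cuda.get_device_name(device)}"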
@@ -13,20 +12,14 @@ from two_stream_shunt_adapter import TwoStreamShuntAdapter
 from configs import T5_SHUNT_REPOS
 
 # ─── Device & Model Setup ─────────────────────────────────────
-
-
-
-# T5 Model for semantic understanding
-t5_tok = T5Tokenizer.from_pretrained("google/flan-t5-base")
-t5_mod = T5EncoderModel.from_pretrained("google/flan-t5-base").to(device).eval()
 
-#
-pipe = StableDiffusionXLPipeline.from_pretrained(
-    "stabilityai/stable-diffusion-xl-base-1.0",
-    torch_dtype=dtype,
-    variant="fp16" if dtype == torch.float16 else None,
-    use_safetensors=True
-).to(device)
+# Don't initialize CUDA here for ZeroGPU compatibility
+device = None  # Will be set inside the GPU function
+dtype = torch.float16
 
+# Don't load models here - will load inside GPU function
+t5_tok = None
+t5_mod = None
+pipe = None
 
 # Available schedulers
 SCHEDULERS = {
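This hunk is the standard ZeroGPU fix: module import must not touch CUDA, because the Space boots on a machine with no GPU attached, so the models become `None` sentinels that are filled in on the first GPU call. A compact sketch of the same lazy-singleton idea, with a hypothetical `_ensure_t5` helper:

    import torch

    _t5_tok = None  # sentinels; populated on the first call inside the GPU context
    _t5_mod = None

    def _ensure_t5(device: torch.device):
        global _t5_tok, _t5_mod
        if _t5_tok is None:
            from transformers import T5Tokenizer, T5EncoderModel
            _t5_tok = T5Tokenizer.from_pretrained("google/flan-t5-base")
            _t5_mod = T5EncoderModel.from_pretrained("google/flan-t5-base").to(device).eval()
        return _t5_tok, _t5_mod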
@@ -47,6 +40,7 @@ config_g = T5_SHUNT_REPOS["clip_g"]["config"]
 from safetensors.torch import safe_open
 
 def load_adapter(repo, filename, config):
+    # Don't initialize device here
     path = hf_hub_download(repo_id=repo, filename=filename)
 
     model = TwoStreamShuntAdapter(config).eval()
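For reference, `safe_open` streams tensors one key at a time, so the whole checkpoint never has to be materialized at once. The same read loop in isolation (note that `safe_open` is importable from the top-level `safetensors` package):

    from safetensors import safe_open

    def read_state_dict(path: str) -> dict:
        tensors = {}
        with safe_open(path, framework="pt", device="cpu") as f:
            for key in f.keys():
                tensors[key] = f.get_tensor(key)
        return tensors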
@@ -55,7 +49,7 @@ def load_adapter(repo, filename, config):
         for key in f.keys():
             tensors[key] = f.get_tensor(key)
     model.load_state_dict(tensors)
-
+    # Device will be set when called from GPU function
     return model
 
 # ─── Visualization ────────────────────────────────────────────
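The new comment encodes a contract: `load_adapter` builds and fills the module on CPU, and the GPU-decorated caller owns device placement. A minimal sketch of that contract, with a hypothetical stand-in module:

    import torch
    from torch import nn

    class TinyAdapter(nn.Module):  # hypothetical stand-in for TwoStreamShuntAdapter
        def __init__(self, dim: int = 64):
            super().__init__()
            self.proj = nn.Linear(dim, dim)

        def forward(self, x):
            return self.proj(x)

    def load_on_cpu(state_dict) -> nn.Module:
        model = TinyAdapter().eval()
        model.load_state_dict(state_dict)  # keys must match the module's layout
        return model  # caller does .to("cuda") once inside the GPU context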
@@ -135,11 +129,34 @@ def encode_sdxl_prompt(prompt, negative_prompt=""):
 }
 
 # ─── Inference ────────────────────────────────────────────────
+@torch.no_grad()
+def infer(prompt, negative_prompt, adapter_l_file, adapter_g_file, strength, noise, gate_prob,
+          use_anchor, steps, cfg_scale, scheduler_name, width, height, seed):
+
+# ─── Inference ────────────────────────────────────────────
 @spaces.GPU
 @torch.no_grad()
 def infer(prompt, negative_prompt, adapter_l_file, adapter_g_file, strength, noise, gate_prob,
           use_anchor, steps, cfg_scale, scheduler_name, width, height, seed):
 
+    # Initialize device and models inside GPU context
+    global t5_tok, t5_mod, pipe
+    device = torch.device("cuda")
+    dtype = torch.float16
+
+    # Load models if not already loaded
+    if t5_tok is None:
+        t5_tok = T5Tokenizer.from_pretrained("google/flan-t5-base")
+        t5_mod = T5EncoderModel.from_pretrained("google/flan-t5-base").to(device).eval()
+
+    if pipe is None:
+        pipe = StableDiffusionXLPipeline.from_pretrained(
+            "stabilityai/stable-diffusion-xl-base-1.0",
+            torch_dtype=dtype,
+            variant="fp16",
+            use_safetensors=True
+        ).to(device)
+
     # Set seed for reproducibility
     if seed != -1:
         torch.manual_seed(seed)
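One caveat on the seeding above: `torch.manual_seed` sets the global RNG, which is shared across concurrent requests in the same process. A per-call `torch.Generator` is the more isolated alternative; a sketch that keeps the same "-1 means random" convention:

    import torch

    def make_generator(seed: int, device: str = "cuda") -> torch.Generator:
        gen = torch.Generator(device=device)
        if seed != -1:
            gen.manual_seed(seed)
        return gen

    # usage: image = pipe(prompt, generator=make_generator(seed)).images[0]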
@@ -168,8 +185,8 @@ def infer(prompt, negative_prompt, adapter_l_file, adapter_g_file, strength, noi
     print(f"CLIP-G shape: {clip_embeds['clip_g'].shape}")
 
     # Load adapters
-    adapter_l = load_adapter(repo_l, adapter_l_file, config_l) if adapter_l_file else None
-    adapter_g = load_adapter(repo_g, adapter_g_file, config_g) if adapter_g_file else None
+    adapter_l = load_adapter(repo_l, adapter_l_file, config_l).to(device) if adapter_l_file else None
+    adapter_g = load_adapter(repo_g, adapter_g_file, config_g).to(device) if adapter_g_file else None
 
     # Apply CLIP-L adapter
     if adapter_l is not None:
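Since `infer` also takes a `scheduler_name` that indexes the `SCHEDULERS` dict defined near the top of app.py, the usual diffusers idiom for swapping schedulers is rebuilding from the current config. A sketch, assuming the dict maps names to scheduler classes (its exact contents are not shown in this diff):

    from diffusers import DDIMScheduler, EulerDiscreteScheduler, DPMSolverMultistepScheduler

    SCHEDULERS = {  # assumed shape of the mapping in app.py
        "DDIM": DDIMScheduler,
        "Euler": EulerDiscreteScheduler,
        "DPM++ 2M": DPMSolverMultistepScheduler,
    }

    def apply_scheduler(pipe, scheduler_name: str):
        pipe.scheduler = SCHEDULERS[scheduler_name].from_config(pipe.scheduler.config)
        return pipe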