AbstractPhil commited on
Commit
a4e1cd2
Β·
1 Parent(s): 12aa86c
Files changed (1) hide show
  1. app.py +34 -17
app.py CHANGED
@@ -4,7 +4,6 @@ import gradio as gr
4
  import numpy as np
5
  import matplotlib.pyplot as plt
6
  from PIL import Image
7
- import spaces
8
  from transformers import T5Tokenizer, T5EncoderModel
9
  from diffusers import StableDiffusionXLPipeline, DDIMScheduler, EulerDiscreteScheduler, DPMSolverMultistepScheduler
10
  from safetensors.torch import load_file
@@ -13,20 +12,14 @@ from two_stream_shunt_adapter import TwoStreamShuntAdapter
13
  from configs import T5_SHUNT_REPOS
14
 
15
  # ─── Device & Model Setup ─────────────────────────────────────
16
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
17
- dtype = torch.float16 if torch.cuda.is_available() else torch.float32
18
-
19
- # T5 Model for semantic understanding
20
- t5_tok = T5Tokenizer.from_pretrained("google/flan-t5-base")
21
- t5_mod = T5EncoderModel.from_pretrained("google/flan-t5-base").to(device).eval()
22
 
23
- # SDXL Pipeline with proper text encoders
24
- pipe = StableDiffusionXLPipeline.from_pretrained(
25
- "stabilityai/stable-diffusion-xl-base-1.0",
26
- torch_dtype=dtype,
27
- variant="fp16" if dtype == torch.float16 else None,
28
- use_safetensors=True
29
- ).to(device)
30
 
31
  # Available schedulers
32
  SCHEDULERS = {
@@ -47,6 +40,7 @@ config_g = T5_SHUNT_REPOS["clip_g"]["config"]
47
  from safetensors.torch import safe_open
48
 
49
  def load_adapter(repo, filename, config):
 
50
  path = hf_hub_download(repo_id=repo, filename=filename)
51
 
52
  model = TwoStreamShuntAdapter(config).eval()
@@ -55,7 +49,7 @@ def load_adapter(repo, filename, config):
55
  for key in f.keys():
56
  tensors[key] = f.get_tensor(key)
57
  model.load_state_dict(tensors)
58
- model.to(device)
59
  return model
60
 
61
  # ─── Visualization ────────────────────────────────────────────
@@ -135,11 +129,34 @@ def encode_sdxl_prompt(prompt, negative_prompt=""):
135
  }
136
 
137
  # ─── Inference ────────────────────────────────────────────────
 
 
 
 
 
138
  @spaces.GPU
139
  @torch.no_grad()
140
  def infer(prompt, negative_prompt, adapter_l_file, adapter_g_file, strength, noise, gate_prob,
141
  use_anchor, steps, cfg_scale, scheduler_name, width, height, seed):
142
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
143
  # Set seed for reproducibility
144
  if seed != -1:
145
  torch.manual_seed(seed)
@@ -168,8 +185,8 @@ def infer(prompt, negative_prompt, adapter_l_file, adapter_g_file, strength, noi
168
  print(f"CLIP-G shape: {clip_embeds['clip_g'].shape}")
169
 
170
  # Load adapters
171
- adapter_l = load_adapter(repo_l, adapter_l_file, config_l) if adapter_l_file else None
172
- adapter_g = load_adapter(repo_g, adapter_g_file, config_g) if adapter_g_file else None
173
 
174
  # Apply CLIP-L adapter
175
  if adapter_l is not None:
 
4
  import numpy as np
5
  import matplotlib.pyplot as plt
6
  from PIL import Image
 
7
  from transformers import T5Tokenizer, T5EncoderModel
8
  from diffusers import StableDiffusionXLPipeline, DDIMScheduler, EulerDiscreteScheduler, DPMSolverMultistepScheduler
9
  from safetensors.torch import load_file
 
12
  from configs import T5_SHUNT_REPOS
13
 
14
  # ─── Device & Model Setup ─────────────────────────────────────
15
+ # Don't initialize CUDA here for ZeroGPU compatibility
16
+ device = None # Will be set inside the GPU function
17
+ dtype = torch.float16
 
 
 
18
 
19
+ # Don't load models here - will load inside GPU function
20
+ t5_tok = None
21
+ t5_mod = None
22
+ pipe = None
 
 
 
23
 
24
  # Available schedulers
25
  SCHEDULERS = {
 
40
  from safetensors.torch import safe_open
41
 
42
  def load_adapter(repo, filename, config):
43
+ # Don't initialize device here
44
  path = hf_hub_download(repo_id=repo, filename=filename)
45
 
46
  model = TwoStreamShuntAdapter(config).eval()
 
49
  for key in f.keys():
50
  tensors[key] = f.get_tensor(key)
51
  model.load_state_dict(tensors)
52
+ # Device will be set when called from GPU function
53
  return model
54
 
55
  # ─── Visualization ────────────────────────────────────────────
 
129
  }
130
 
131
  # ─── Inference ────────────────────────────────────────────────
132
+ @torch.no_grad()
133
+ def infer(prompt, negative_prompt, adapter_l_file, adapter_g_file, strength, noise, gate_prob,
134
+ use_anchor, steps, cfg_scale, scheduler_name, width, height, seed):
135
+
136
+ # ─── Inference ────────────────────────────────────────────
137
  @spaces.GPU
138
  @torch.no_grad()
139
  def infer(prompt, negative_prompt, adapter_l_file, adapter_g_file, strength, noise, gate_prob,
140
  use_anchor, steps, cfg_scale, scheduler_name, width, height, seed):
141
 
142
+ # Initialize device and models inside GPU context
143
+ global t5_tok, t5_mod, pipe
144
+ device = torch.device("cuda")
145
+ dtype = torch.float16
146
+
147
+ # Load models if not already loaded
148
+ if t5_tok is None:
149
+ t5_tok = T5Tokenizer.from_pretrained("google/flan-t5-base")
150
+ t5_mod = T5EncoderModel.from_pretrained("google/flan-t5-base").to(device).eval()
151
+
152
+ if pipe is None:
153
+ pipe = StableDiffusionXLPipeline.from_pretrained(
154
+ "stabilityai/stable-diffusion-xl-base-1.0",
155
+ torch_dtype=dtype,
156
+ variant="fp16",
157
+ use_safetensors=True
158
+ ).to(device)
159
+
160
  # Set seed for reproducibility
161
  if seed != -1:
162
  torch.manual_seed(seed)
 
185
  print(f"CLIP-G shape: {clip_embeds['clip_g'].shape}")
186
 
187
  # Load adapters
188
+ adapter_l = load_adapter(repo_l, adapter_l_file, config_l).to(device) if adapter_l_file else None
189
+ adapter_g = load_adapter(repo_g, adapter_g_file, config_g).to(device) if adapter_g_file else None
190
 
191
  # Apply CLIP-L adapter
192
  if adapter_l is not None: