Commit 86a22a6 by omer11a · 1 Parent(s): bdfcbd1

Load model to CPU

Files changed (1)
app.py (+18 -8)
app.py CHANGED
@@ -6,11 +6,14 @@ import numpy as np
 from PIL import Image, ImageDraw
 
 from diffusers import DDIMScheduler
+from diffusers.models.attention_processor import AttnProcessor2_0
 from pipeline_stable_diffusion_xl_opt import StableDiffusionXLPipeline
-from injection_utils import register_attention_editor_diffusers
+from injection_utils import register_attention_editor_diffusers, unregister_attention_editor_diffusers
 from bounded_attention import BoundedAttention
 from pytorch_lightning import seed_everything
 
+from functools import partial
+
 MODEL_PATH = "stabilityai/stable-diffusion-xl-base-1.0"
 RESOLUTION = 256
 MIN_SIZE = 0.01
@@ -111,6 +114,7 @@ FOOTNOTE = """
 
 
 def inference(
+    model,
     boxes,
     prompts,
     subject_token_indices,
@@ -131,10 +135,7 @@ def inference(
         raise gr.Error("cuda is not available")
 
     device = torch.device("cuda")
-    scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)
-    model = StableDiffusionXLPipeline.from_pretrained(MODEL_PATH, scheduler=scheduler, torch_dtype=torch.float16).to(device)
-    model.unet.set_default_attn_processor()
-    model.enable_sequential_cpu_offload()
+    model.to(device)
 
     seed_everything(seed)
     start_code = torch.randn([len(prompts), 4, 128, 128], device=device)
@@ -159,11 +160,15 @@ def inference(
     )
 
     register_attention_editor_diffusers(model, editor)
-    return model(prompts, latents=start_code, guidance_scale=classifier_free_guidance_scale).images
+    images = model(prompts, latents=start_code, guidance_scale=classifier_free_guidance_scale).images
+    unregister_attention_editor_diffusers(model)
+    model.to(torch.device("cpu"))
+    return images
 
 
 @spaces.GPU(duration=300)
 def generate(
+    model,
     prompt,
     subject_token_indices,
     filter_token_indices,
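The two hunks above give inference() its new GPU lifecycle: the pre-built pipeline is moved to CUDA, the BoundedAttention editor is registered, images are generated, the editor is unregistered, and the pipeline is moved back to the CPU before the images are returned. A minimal sketch of that cycle follows; the helper name run_on_gpu and the try/finally guard are illustrative additions, not part of this commit.

# Illustrative sketch only: condenses the load -> register -> run -> unregister -> unload
# cycle that inference() now performs; the try/finally is an added safeguard so the editor
# is always removed and the pipeline returned to the CPU even if generation fails.
import torch
from injection_utils import register_attention_editor_diffusers, unregister_attention_editor_diffusers

def run_on_gpu(model, editor, prompts, start_code, classifier_free_guidance_scale):
    model.to(torch.device("cuda"))
    register_attention_editor_diffusers(model, editor)
    try:
        return model(
            prompts,
            latents=start_code,
            guidance_scale=classifier_free_guidance_scale,
        ).images
    finally:
        unregister_attention_editor_diffusers(model)
        model.to(torch.device("cpu"))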
@@ -193,7 +198,7 @@ def generate(
     prompts = [prompt.strip(".").strip(",").strip()] * batch_size
 
     images = inference(
-        boxes, prompts, subject_token_indices, filter_token_indices, num_tokens, init_step_size,
+        model, boxes, prompts, subject_token_indices, filter_token_indices, num_tokens, init_step_size,
         final_step_size, num_clusters_per_subject, cross_loss_scale, self_loss_scale, classifier_free_guidance_scale,
         num_iterations, loss_threshold, num_guidance_steps, seed)
 
@@ -249,6 +254,11 @@ def clear(batch_size):
 
 def main():
     nltk.download("averaged_perceptron_tagger")
+
+    scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)
+    model = StableDiffusionXLPipeline.from_pretrained(MODEL_PATH, scheduler=scheduler, torch_dtype=torch.float16).to(device)
+    model.unet.set_attn_processor(AttnProcessor2_0())
+    model.enable_sequential_cpu_offload()
 
     with gr.Blocks(
         css=CSS,
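The per-call setup removed from inference() now runs once in main(). Below is a standalone sketch of that one-time setup, assuming the pipeline is left on the CPU at startup and only moved to the GPU inside inference() (the commit's own line also chains .to(device)).

# Standalone sketch of the one-time setup (assumption: pipeline stays on the CPU at startup).
import torch
from diffusers import DDIMScheduler
from diffusers.models.attention_processor import AttnProcessor2_0
from pipeline_stable_diffusion_xl_opt import StableDiffusionXLPipeline

MODEL_PATH = "stabilityai/stable-diffusion-xl-base-1.0"

scheduler = DDIMScheduler(
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    clip_sample=False,
    set_alpha_to_one=False,
)
model = StableDiffusionXLPipeline.from_pretrained(
    MODEL_PATH, scheduler=scheduler, torch_dtype=torch.float16
)
model.unet.set_attn_processor(AttnProcessor2_0())  # PyTorch 2.0 scaled-dot-product attention
model.enable_sequential_cpu_offload()  # keep weights on CPU, stream submodules to GPU on use

enable_sequential_cpu_offload() keeps the fp16 weights in system memory and moves each submodule to the GPU only while it executes, trading some speed for a much smaller resident GPU footprint.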
@@ -320,7 +330,7 @@ def main():
         )
 
         generate_image_button.click(
-            fn=generate,
+            fn=partial(generate, model),
            inputs=[
                prompt, subject_token_indices, filter_token_indices, num_tokens,
                init_step_size, final_step_size, num_clusters_per_subject, cross_loss_scale, self_loss_scale,
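Because the pipeline is now created in main() rather than inside the handler, it is pre-bound to generate with functools.partial; Gradio then only has to supply the values of the components listed in inputs. A small illustration of the equivalence (the name handler is hypothetical):

from functools import partial

handler = partial(generate, model)   # each click calls generate(model, <ui values...>)
# roughly equivalent to:
# handler = lambda *ui_values: generate(model, *ui_values)
# generate_image_button.click(fn=handler, inputs=[...], outputs=...)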
 
6
  from PIL import Image, ImageDraw
7
 
8
  from diffusers import DDIMScheduler
9
+ from diffusers.models.attention_processor import AttnProcessor2_0
10
  from pipeline_stable_diffusion_xl_opt import StableDiffusionXLPipeline
11
+ from injection_utils import register_attention_editor_diffusers, unregister_attention_editor_diffusers
12
  from bounded_attention import BoundedAttention
13
  from pytorch_lightning import seed_everything
14
 
15
+ from functools import partial
16
+
17
  MODEL_PATH = "stabilityai/stable-diffusion-xl-base-1.0"
18
  RESOLUTION = 256
19
  MIN_SIZE = 0.01
 
114
 
115
 
116
  def inference(
117
+ model,
118
  boxes,
119
  prompts,
120
  subject_token_indices,
 
135
  raise gr.Error("cuda is not available")
136
 
137
  device = torch.device("cuda")
138
+ model.to(device)
 
 
 
139
 
140
  seed_everything(seed)
141
  start_code = torch.randn([len(prompts), 4, 128, 128], device=device)
 
160
  )
161
 
162
  register_attention_editor_diffusers(model, editor)
163
+ images = model(prompts, latents=start_code, guidance_scale=classifier_free_guidance_scale).images
164
+ unregister_attention_editor_diffusers(model)
165
+ model.to(torch.device("cpu"))
166
+ return images
167
 
168
 
169
  @spaces.GPU(duration=300)
170
  def generate(
171
+ model,
172
  prompt,
173
  subject_token_indices,
174
  filter_token_indices,
 
198
  prompts = [prompt.strip(".").strip(",").strip()] * batch_size
199
 
200
  images = inference(
201
+ model, boxes, prompts, subject_token_indices, filter_token_indices, num_tokens, init_step_size,
202
  final_step_size, num_clusters_per_subject, cross_loss_scale, self_loss_scale, classifier_free_guidance_scale,
203
  num_iterations, loss_threshold, num_guidance_steps, seed)
204
 
 
254
 
255
  def main():
256
  nltk.download("averaged_perceptron_tagger")
257
+
258
+ scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)
259
+ model = StableDiffusionXLPipeline.from_pretrained(MODEL_PATH, scheduler=scheduler, torch_dtype=torch.float16).to(device)
260
+ model.unet.set_attn_processor(AttnProcessor2_0())
261
+ model.enable_sequential_cpu_offload()
262
 
263
  with gr.Blocks(
264
  css=CSS,
 
330
  )
331
 
332
  generate_image_button.click(
333
+ fn=partial(generate, model),
334
  inputs=[
335
  prompt, subject_token_indices, filter_token_indices, num_tokens,
336
  init_step_size, final_step_size, num_clusters_per_subject, cross_loss_scale, self_loss_scale,