omer11a committed
Commit e7e4657
1 Parent(s): 4b19f84

Creates model inside GPU

Files changed (1)
  1. app.py +26 -46
app.py CHANGED
@@ -10,7 +10,6 @@ from pipeline_stable_diffusion_xl_opt import StableDiffusionXLPipeline
 from injection_utils import regiter_attention_editor_diffusers
 from bounded_attention import BoundedAttention
 from pytorch_lightning import seed_everything
-from torch_kmeans import KMeans
 
 from functools import partial
 
@@ -21,7 +20,6 @@ COLORS = ["red", "blue", "green", "orange", "purple", "turquoise", "olive"]
 
 
 def inference(
-    model,
     boxes,
     prompts,
     subject_token_indices,
@@ -42,55 +40,41 @@ def inference(
         raise gr.Error("cuda is not available")
 
     device = torch.device("cuda")
-    model = model.to(device=device, dtype=torch.float16)
+    model_path = "stabilityai/stable-diffusion-xl-base-1.0"
+    scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)
+    model = StableDiffusionXLPipeline.from_pretrained(model_path, scheduler=scheduler, device=device, torch_dtype=torch.float16)
+    model.unet.set_default_attn_processor()
+    model.enable_sequential_cpu_offload()
 
     seed_everything(seed)
     start_code = torch.randn([len(prompts), 4, 128, 128], device=device)
     eos_token_index = num_tokens + 1
 
-    if hasattr(model, 'editor'):
-        editor.boxes = boxes
-        editor.prompts = prompts
-        editor.subject_token_indices = subject_token_indices
-        editor.filter_token_indices = filter_token_indices
-        editor.eos_token_index = eos_token_index
-        editor.cross_loss_coef = cross_loss_scale
-        editor.self_loss_coef = self_loss_scale
-        editor.max_guidance_iter = num_guidance_steps
-        editor.max_guidance_iter_per_step = num_iterations
-        editor.start_step_size = init_step_size
-        self.step_size_coef = (final_step_size - init_step_size) / num_guidance_steps
-        editor.loss_stopping_value = loss_threshold
-        num_clusters = len(boxes) * num_clusters_per_subject
-        self.clustering = KMeans(n_clusters=num_clusters, num_init=100)
-
-    else:
-        editor = BoundedAttention(
-            boxes,
-            prompts,
-            subject_token_indices,
-            list(range(70, 82)),
-            list(range(70, 82)),
-            filter_token_indices=filter_token_indices,
-            eos_token_index=eos_token_index,
-            cross_loss_coef=cross_loss_scale,
-            self_loss_coef=self_loss_scale,
-            max_guidance_iter=num_guidance_steps,
-            max_guidance_iter_per_step=num_iterations,
-            start_step_size=init_step_size,
-            end_step_size=final_step_size,
-            loss_stopping_value=loss_threshold,
-            num_clusters_per_box=num_clusters_per_subject,
-        )
-
-    regiter_attention_editor_diffusers(model, editor)
+    editor = BoundedAttention(
+        boxes,
+        prompts,
+        subject_token_indices,
+        list(range(70, 82)),
+        list(range(70, 82)),
+        filter_token_indices=filter_token_indices,
+        eos_token_index=eos_token_index,
+        cross_loss_coef=cross_loss_scale,
+        self_loss_coef=self_loss_scale,
+        max_guidance_iter=num_guidance_steps,
+        max_guidance_iter_per_step=num_iterations,
+        start_step_size=init_step_size,
+        end_step_size=final_step_size,
+        loss_stopping_value=loss_threshold,
+        num_clusters_per_box=num_clusters_per_subject,
+    )
+
+    regiter_attention_editor_diffusers(model, editor)
 
     return model(prompts, latents=start_code, guidance_scale=classifier_free_guidance_scale).images
 
 
 @spaces.GPU
 def generate(
-    model,
     prompt,
     subject_token_indices,
     filter_token_indices,
@@ -120,7 +104,7 @@ def generate(
     prompts = [prompt.strip('.').strip(',').strip()] * batch_size
 
     images = inference(
-        model, boxes, prompts, subject_token_indices, filter_token_indices, num_tokens, init_step_size,
+        boxes, prompts, subject_token_indices, filter_token_indices, num_tokens, init_step_size,
         final_step_size, num_clusters_per_subject, cross_loss_scale, self_loss_scale, classifier_free_guidance_scale,
         num_iterations, loss_threshold, num_guidance_steps, seed)
 
@@ -214,10 +198,6 @@ def main():
     }
     """
 
-    model_path = "stabilityai/stable-diffusion-xl-base-1.0"
-    scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)
-    model = StableDiffusionXLPipeline.from_pretrained(model_path, scheduler=scheduler)
-
     nltk.download('averaged_perceptron_tagger')
 
     with gr.Blocks(
@@ -328,7 +308,7 @@ def main():
         )
 
         generate_image_button.click(
-            fn=partial(generate, model),
+            fn=generate,
             inputs=[
                 prompt, subject_token_indices, filter_token_indices, num_tokens,
                 init_step_size, final_step_size, num_clusters_per_subject, cross_loss_scale, self_loss_scale,
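
Note on the pattern: as the commit message says, the change moves pipeline creation out of main() and into inference(), which is reached through the @spaces.GPU-decorated generate(). On ZeroGPU Spaces a GPU is generally only attached while a decorated function is running, which is the usual reason for building the model inside that call path. The sketch below is a minimal, hypothetical illustration of this pattern only; it uses the stock diffusers StableDiffusionXLPipeline and an invented generate_image() helper, not the Space's patched pipeline_stable_diffusion_xl_opt or its BoundedAttention editor shown in the diff above.

# Minimal ZeroGPU sketch (illustrative only, not the Space's actual code).
import spaces
import torch
from diffusers import DDIMScheduler, StableDiffusionXLPipeline


@spaces.GPU
def generate_image(prompt, seed=0):
    # The GPU is attached only while this decorated function runs,
    # so the pipeline is created here rather than at module scope.
    device = torch.device("cuda")
    scheduler = DDIMScheduler(
        beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear",
        clip_sample=False, set_alpha_to_one=False)
    pipe = StableDiffusionXLPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",
        scheduler=scheduler, torch_dtype=torch.float16).to(device)

    generator = torch.Generator(device=device).manual_seed(seed)
    return pipe(prompt, generator=generator).images[0]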