omer11a committed on
Commit
137c79d
·
1 Parent(s): fa4d90e

Allow repeated inference

Browse files
Files changed (1) hide show
  1. app.py +40 -20
app.py CHANGED
@@ -10,6 +10,7 @@ from pipeline_stable_diffusion_xl_opt import StableDiffusionXLPipeline
10
  from injection_utils import regiter_attention_editor_diffusers
11
  from bounded_attention import BoundedAttention
12
  from pytorch_lightning import seed_everything
 
13
 
14
  from functools import partial
15
 
@@ -40,26 +41,45 @@ def inference(
40
  ):
41
  seed_everything(seed)
42
  start_code = torch.randn([len(prompts), 4, 128, 128], device=device)
43
- editor = BoundedAttention(
44
- boxes,
45
- prompts,
46
- subject_token_indices,
47
- list(range(70, 82)),
48
- list(range(70, 82)),
49
- eos_token_index=num_tokens + 1,
50
- cross_loss_coef=cross_loss_scale,
51
- self_loss_coef=self_loss_scale,
52
- filter_token_indices=filter_token_indices,
53
- max_guidance_iter=num_guidance_steps,
54
- max_guidance_iter_per_step=num_iterations,
55
- start_step_size=init_step_size,
56
- end_step_size=final_step_size,
57
- loss_stopping_value=loss_threshold,
58
- num_clusters_per_box=num_clusters_per_subject,
59
- debug=False,
60
- )
61
-
62
- regiter_attention_editor_diffusers(model, editor)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  return model(prompts, latents=start_code, guidance_scale=classifier_free_guidance_scale).images
64
 
65
 
 
10
  from injection_utils import regiter_attention_editor_diffusers
11
  from bounded_attention import BoundedAttention
12
  from pytorch_lightning import seed_everything
13
+ from torch_kmeans import KMeans
14
 
15
  from functools import partial
16
 
 
41
  ):
42
  seed_everything(seed)
43
  start_code = torch.randn([len(prompts), 4, 128, 128], device=device)
44
+ eos_token_index = num_tokens + 1
45
+
46
+ if hasattr(model, 'editor'):
47
+ editor = model.editor
+ editor.boxes = boxes
48
+ editor.prompts = prompts
49
+ editor.subject_token_indices = subject_token_indices
50
+ editor.filter_token_indices = filter_token_indices
51
+ editor.eos_token_index = eos_token_index
52
+ editor.cross_loss_coef = cross_loss_scale
53
+ editor.self_loss_coef = self_loss_scale
54
+ editor.max_guidance_iter = num_guidance_steps
55
+ editor.max_guidance_iter_per_step = num_iterations
56
+ editor.start_step_size = init_step_size
57
+ editor.step_size_coef = (final_step_size - init_step_size) / num_guidance_steps
58
+ editor.loss_stopping_value = loss_threshold
59
+ num_clusters = len(boxes) * num_clusters_per_subject
60
+ editor.clustering = KMeans(n_clusters=num_clusters, num_init=100)
61
+
62
+ else:
63
+ editor = BoundedAttention(
64
+ boxes,
65
+ prompts,
66
+ subject_token_indices,
67
+ list(range(70, 82)),
68
+ list(range(70, 82)),
69
+ filter_token_indices=filter_token_indices,
70
+ eos_token_index=eos_token_index,
71
+ cross_loss_coef=cross_loss_scale,
72
+ self_loss_coef=self_loss_scale,
73
+ max_guidance_iter=num_guidance_steps,
74
+ max_guidance_iter_per_step=num_iterations,
75
+ start_step_size=init_step_size,
76
+ end_step_size=final_step_size,
77
+ loss_stopping_value=loss_threshold,
78
+ num_clusters_per_box=num_clusters_per_subject,
79
+ )
80
+
81
+ regiter_attention_editor_diffusers(model, editor)
82
+
83
  return model(prompts, latents=start_code, guidance_scale=classifier_free_guidance_scale).images
84
 
85