Allow repeated inference

app.py CHANGED
@@ -10,6 +10,7 @@ from pipeline_stable_diffusion_xl_opt import StableDiffusionXLPipeline
 from injection_utils import regiter_attention_editor_diffusers
 from bounded_attention import BoundedAttention
 from pytorch_lightning import seed_everything
+from torch_kmeans import KMeans
 
 from functools import partial
 
@@ -40,26 +41,46 @@ def inference(
 ):
     seed_everything(seed)
     start_code = torch.randn([len(prompts), 4, 128, 128], device=device)
-    ...
-    max_guidance_iter=num_guidance_steps
-    max_guidance_iter_per_step=num_iterations
-    start_step_size=init_step_size
-    ...
-    loss_stopping_value=loss_threshold
-    ...
+    eos_token_index = num_tokens + 1
+
+    if hasattr(model, 'editor'):
+        editor = model.editor
+        editor.boxes = boxes
+        editor.prompts = prompts
+        editor.subject_token_indices = subject_token_indices
+        editor.filter_token_indices = filter_token_indices
+        editor.eos_token_index = eos_token_index
+        editor.cross_loss_coef = cross_loss_scale
+        editor.self_loss_coef = self_loss_scale
+        editor.max_guidance_iter = num_guidance_steps
+        editor.max_guidance_iter_per_step = num_iterations
+        editor.start_step_size = init_step_size
+        editor.step_size_coef = (final_step_size - init_step_size) / num_guidance_steps
+        editor.loss_stopping_value = loss_threshold
+        num_clusters = len(boxes) * num_clusters_per_subject
+        editor.clustering = KMeans(n_clusters=num_clusters, num_init=100)
+
+    else:
+        editor = BoundedAttention(
+            boxes,
+            prompts,
+            subject_token_indices,
+            list(range(70, 82)),
+            list(range(70, 82)),
+            filter_token_indices=filter_token_indices,
+            eos_token_index=eos_token_index,
+            cross_loss_coef=cross_loss_scale,
+            self_loss_coef=self_loss_scale,
+            max_guidance_iter=num_guidance_steps,
+            max_guidance_iter_per_step=num_iterations,
+            start_step_size=init_step_size,
+            end_step_size=final_step_size,
+            loss_stopping_value=loss_threshold,
+            num_clusters_per_box=num_clusters_per_subject,
+        )
+
+        regiter_attention_editor_diffusers(model, editor)
+
     return model(prompts, latents=start_code, guidance_scale=classifier_free_guidance_scale).images
 
 
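A note on the new dependency: torch_kmeans provides a batched k-means that operates directly on GPU tensors, which is presumably why the editor's clustering is built with it here; the diff only shows the constructor call, `KMeans(n_clusters=num_clusters, num_init=100)`, where `num_init` restarts k-means multiple times and keeps the best assignment. A minimal usage sketch based on the library's documented API; the shapes and cluster counts below are illustrative, not values from app.py:

```python
import torch
from torch_kmeans import KMeans

# Batched k-means: input is (batch, num_points, dim); each batch item
# is clustered independently.
clustering = KMeans(n_clusters=6, num_init=10)

x = torch.randn(2, 128, 32)   # 2 batch items, 128 feature vectors of dim 32
result = clustering(x)        # fits and assigns in one call
print(result.labels.shape)    # (2, 128): a cluster index per point
```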
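The `step_size_coef` assignment implies a linear schedule: when `final_step_size` is smaller than `init_step_size` (the usual annealing setup), the coefficient is negative and the guidance step size decays linearly from the initial to the final value over `num_guidance_steps`. How BoundedAttention consumes the coefficient is not shown in this diff; the sketch below, with made-up values, only illustrates the schedule the formula defines:

```python
# Illustrative values; the real ones come from the app's inputs.
init_step_size, final_step_size, num_guidance_steps = 30.0, 10.0, 5

step_size_coef = (final_step_size - init_step_size) / num_guidance_steps  # -4.0

for i in range(num_guidance_steps + 1):
    step_size = init_step_size + step_size_coef * i
    print(i, step_size)  # 30.0, 26.0, 22.0, 18.0, 14.0, 10.0
```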
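The change itself: on the first call, `inference` constructs a BoundedAttention editor and registers it on the pipeline; subsequent calls detect the existing editor (the `hasattr(model, 'editor')` check implies registration attaches it to the model) and refresh its settings in place instead of wrapping the attention layers again, presumably the failure behind the commit title "Allow repeated inference". A self-contained sketch of that pattern; `DummyEditor`, `Pipeline`, and `get_or_update_editor` are hypothetical stand-ins, not names from this repo:

```python
class DummyEditor:
    """Stand-in for BoundedAttention: just records its settings."""
    def __init__(self, **settings):
        self.__dict__.update(settings)

class Pipeline:
    """Stand-in for the diffusers pipeline being wrapped."""

def get_or_update_editor(model, editor_cls, **settings):
    if hasattr(model, 'editor'):
        # Later calls: reuse the attached editor, refreshing its settings.
        for name, value in settings.items():
            setattr(model.editor, name, value)
    else:
        # First call: build the editor and attach it; the real code also
        # hooks it into the model's attention modules at this point.
        model.editor = editor_cls(**settings)
    return model.editor

model = Pipeline()
first = get_or_update_editor(model, DummyEditor, start_step_size=30.0)
second = get_or_update_editor(model, DummyEditor, start_step_size=8.0)
assert first is second                 # same editor reused across calls
assert second.start_step_size == 8.0   # settings refreshed in place
```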