Spaces:
Running
on
Zero
Running
on
Zero
Update ledits/pipeline_leditspp_stable_diffusion_xl.py
Browse files
ledits/pipeline_leditspp_stable_diffusion_xl.py
CHANGED
|
@@ -613,11 +613,10 @@ class LEditsPPPipelineStableDiffusionXL(
|
|
| 613 |
else:
|
| 614 |
# "2" because SDXL always indexes from the penultimate layer.
|
| 615 |
edit_concepts_embeds = edit_concepts_embeds.hidden_states[-(clip_skip + 2)]
|
| 616 |
-
|
| 617 |
-
|
| 618 |
-
if avg_diff is not None
|
| 619 |
-
|
| 620 |
-
print("SHALOM")
|
| 621 |
normed_prompt_embeds = edit_concepts_embeds / edit_concepts_embeds.norm(dim=-1, keepdim=True)
|
| 622 |
sims = normed_prompt_embeds[0] @ normed_prompt_embeds[0].T
|
| 623 |
if i == 0:
|
|
@@ -626,14 +625,26 @@ class LEditsPPPipelineStableDiffusionXL(
|
|
| 626 |
standard_weights = torch.ones_like(weights)
|
| 627 |
|
| 628 |
weights = standard_weights + (weights - standard_weights) * correlation_weight_factor
|
| 629 |
-
edit_concepts_embeds = edit_concepts_embeds + (
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 630 |
else:
|
| 631 |
weights = sims[toks.argmax(), :][None, :, None].repeat(1, 1, 1280)
|
| 632 |
|
| 633 |
standard_weights = torch.ones_like(weights)
|
| 634 |
|
| 635 |
weights = standard_weights + (weights - standard_weights) * correlation_weight_factor
|
| 636 |
-
edit_concepts_embeds = edit_concepts_embeds + (
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 637 |
|
| 638 |
edit_prompt_embeds_list.append(edit_concepts_embeds)
|
| 639 |
i+=1
|
|
|
|
| 613 |
else:
|
| 614 |
# "2" because SDXL always indexes from the penultimate layer.
|
| 615 |
edit_concepts_embeds = edit_concepts_embeds.hidden_states[-(clip_skip + 2)]
|
| 616 |
+
|
| 617 |
+
|
| 618 |
+
if avg_diff is not None:
|
| 619 |
+
|
|
|
|
| 620 |
normed_prompt_embeds = edit_concepts_embeds / edit_concepts_embeds.norm(dim=-1, keepdim=True)
|
| 621 |
sims = normed_prompt_embeds[0] @ normed_prompt_embeds[0].T
|
| 622 |
if i == 0:
|
|
|
|
| 625 |
standard_weights = torch.ones_like(weights)
|
| 626 |
|
| 627 |
weights = standard_weights + (weights - standard_weights) * correlation_weight_factor
|
| 628 |
+
edit_concepts_embeds = edit_concepts_embeds + (
|
| 629 |
+
weights * avg_diff[0][None, :].repeat(1, tokenizer.model_max_length, 1) * scale)
|
| 630 |
+
|
| 631 |
+
if avg_diff_2nd is not None:
|
| 632 |
+
edit_concepts_embeds += (weights * avg_diff_2nd[0][None, :].repeat(1,
|
| 633 |
+
self.pipe.tokenizer.model_max_length,
|
| 634 |
+
1) * scale_2nd)
|
| 635 |
else:
|
| 636 |
weights = sims[toks.argmax(), :][None, :, None].repeat(1, 1, 1280)
|
| 637 |
|
| 638 |
standard_weights = torch.ones_like(weights)
|
| 639 |
|
| 640 |
weights = standard_weights + (weights - standard_weights) * correlation_weight_factor
|
| 641 |
+
edit_concepts_embeds = edit_concepts_embeds + (
|
| 642 |
+
weights * avg_diff[1][None, :].repeat(1, tokenizer.model_max_length, 1) * scale)
|
| 643 |
+
if avg_diff_2nd is not None:
|
| 644 |
+
edit_concepts_embeds += (weights * avg_diff_2nd[1][None, :].repeat(1,
|
| 645 |
+
self.pipe.tokenizer_2.model_max_length,
|
| 646 |
+
1) * scale_2nd)
|
| 647 |
+
|
| 648 |
|
| 649 |
edit_prompt_embeds_list.append(edit_concepts_embeds)
|
| 650 |
i+=1
|