Spaces:
Paused
Paused
Update app.py
Browse files
app.py
CHANGED
|
@@ -626,11 +626,10 @@ class LEditsPPPipelineStableDiffusionXL(
|
|
| 626 |
else:
|
| 627 |
# "2" because SDXL always indexes from the penultimate layer.
|
| 628 |
edit_concepts_embeds = edit_concepts_embeds.hidden_states[-(clip_skip + 2)]
|
| 629 |
-
|
| 630 |
-
|
| 631 |
-
if avg_diff is not None
|
| 632 |
-
|
| 633 |
-
print("SHALOM")
|
| 634 |
normed_prompt_embeds = edit_concepts_embeds / edit_concepts_embeds.norm(dim=-1, keepdim=True)
|
| 635 |
sims = normed_prompt_embeds[0] @ normed_prompt_embeds[0].T
|
| 636 |
if i == 0:
|
|
@@ -639,14 +638,26 @@ class LEditsPPPipelineStableDiffusionXL(
|
|
| 639 |
standard_weights = torch.ones_like(weights)
|
| 640 |
|
| 641 |
weights = standard_weights + (weights - standard_weights) * correlation_weight_factor
|
| 642 |
-
edit_concepts_embeds = edit_concepts_embeds + (
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 643 |
else:
|
| 644 |
weights = sims[toks.argmax(), :][None, :, None].repeat(1, 1, 1280)
|
| 645 |
|
| 646 |
standard_weights = torch.ones_like(weights)
|
| 647 |
|
| 648 |
weights = standard_weights + (weights - standard_weights) * correlation_weight_factor
|
| 649 |
-
edit_concepts_embeds = edit_concepts_embeds + (
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 650 |
|
| 651 |
edit_prompt_embeds_list.append(edit_concepts_embeds)
|
| 652 |
i+=1
|
|
|
|
| 626 |
else:
|
| 627 |
# "2" because SDXL always indexes from the penultimate layer.
|
| 628 |
edit_concepts_embeds = edit_concepts_embeds.hidden_states[-(clip_skip + 2)]
|
| 629 |
+
|
| 630 |
+
|
| 631 |
+
if avg_diff is not None:
|
| 632 |
+
|
|
|
|
| 633 |
normed_prompt_embeds = edit_concepts_embeds / edit_concepts_embeds.norm(dim=-1, keepdim=True)
|
| 634 |
sims = normed_prompt_embeds[0] @ normed_prompt_embeds[0].T
|
| 635 |
if i == 0:
|
|
|
|
| 638 |
standard_weights = torch.ones_like(weights)
|
| 639 |
|
| 640 |
weights = standard_weights + (weights - standard_weights) * correlation_weight_factor
|
| 641 |
+
edit_concepts_embeds = edit_concepts_embeds + (
|
| 642 |
+
weights * avg_diff[0][None, :].repeat(1, tokenizer.model_max_length, 1) * scale)
|
| 643 |
+
|
| 644 |
+
if avg_diff_2nd is not None:
|
| 645 |
+
edit_concepts_embeds += (weights * avg_diff_2nd[0][None, :].repeat(1,
|
| 646 |
+
self.pipe.tokenizer.model_max_length,
|
| 647 |
+
1) * scale_2nd)
|
| 648 |
else:
|
| 649 |
weights = sims[toks.argmax(), :][None, :, None].repeat(1, 1, 1280)
|
| 650 |
|
| 651 |
standard_weights = torch.ones_like(weights)
|
| 652 |
|
| 653 |
weights = standard_weights + (weights - standard_weights) * correlation_weight_factor
|
| 654 |
+
edit_concepts_embeds = edit_concepts_embeds + (
|
| 655 |
+
weights * avg_diff[1][None, :].repeat(1, tokenizer.model_max_length, 1) * scale)
|
| 656 |
+
if avg_diff_2nd is not None:
|
| 657 |
+
edit_concepts_embeds += (weights * avg_diff_2nd[1][None, :].repeat(1,
|
| 658 |
+
self.pipe.tokenizer_2.model_max_length,
|
| 659 |
+
1) * scale_2nd)
|
| 660 |
+
|
| 661 |
|
| 662 |
edit_prompt_embeds_list.append(edit_concepts_embeds)
|
| 663 |
i+=1
|