# latentnavigation-flux / clip_slider_pipeline.py
import diffusers
import torch
import random
from tqdm import tqdm
from constants import SUBJECTS, MEDIUMS
from PIL import Image
# Slerp (Spherical Linear Interpolation) function
def slerp(v0, v1, t, DOT_THRESHOLD=0.9995):
"""
Spherical linear interpolation.
v0, v1: Tensors to interpolate between.
t: Interpolation factor (scalar or tensor).
DOT_THRESHOLD: Threshold for considering vectors collinear.
"""
if not isinstance(t, torch.Tensor):
t = torch.tensor(t, device=v0.device, dtype=v0.dtype)
# Dot product
dot = torch.sum(v0 * v1 / (torch.norm(v0, dim=-1, keepdim=True) * torch.norm(v1, dim=-1, keepdim=True) + 1e-8), dim=-1, keepdim=True)
    # Nearly collinear vectors (|dot| > DOT_THRESHOLD) make sin(omega) vanish; they are
    # handled by the LERP fallback further down rather than by a separate branch here.
# Clamp dot to prevent NaN from acos due to floating point errors.
dot = torch.clamp(dot, -1.0, 1.0)
omega = torch.acos(dot) # Angle between vectors
# Get magnitudes for later linear interpolation of magnitude
mag_v0 = torch.norm(v0, dim=-1, keepdim=True)
mag_v1 = torch.norm(v1, dim=-1, keepdim=True)
interpolated_mag = (1 - t) * mag_v0 + t * mag_v1
# Normalize v0 and v1 for pure Slerp on direction
v0_norm = v0 / (mag_v0 + 1e-8)
v1_norm = v1 / (mag_v1 + 1e-8)
# If sin_omega is very small, vectors are nearly collinear.
# LERP on normalized vectors is a good approximation.
# Then re-apply interpolated magnitude.
sin_omega = torch.sin(omega)
    # Fall back to LERP when the vectors are nearly collinear (sin(omega) ~ 0 or |dot| > DOT_THRESHOLD).
    use_lerp_fallback = (sin_omega.abs() < 1e-5) | (dot.abs() > DOT_THRESHOLD)
s0 = torch.sin((1 - t) * omega) / (sin_omega + 1e-8) # Add epsilon to sin_omega for stability
s1 = torch.sin(t * omega) / (sin_omega + 1e-8) # Add epsilon to sin_omega for stability
# For elements where LERP fallback is needed
s0[use_lerp_fallback] = 1.0 - t
s1[use_lerp_fallback] = t
result_norm = s0 * v0_norm + s1 * v1_norm
result = result_norm * interpolated_mag # Re-apply interpolated magnitude
return result.to(v0.dtype)
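# Usage sketch for slerp (illustrative only): the random tensors below stand in for the
# pooled text embeddings that CLIPSliderFlux.generate interpolates with this function.
#
#   v0 = torch.randn(1, 768)
#   v1 = torch.randn(1, 768)
#   halfway = slerp(v0, v1, 0.5)  # direction slerped, magnitude linearly interpolated
#   assert halfway.shape == v0.shape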
class CLIPSlider:
def __init__(
self,
sd_pipe,
device: torch.device,
target_word: str = "",
opposite: str = "",
target_word_2nd: str = "",
opposite_2nd: str = "",
iterations: int = 300,
):
self.device = device
self.pipe = sd_pipe.to(self.device, torch.float16)
self.iterations = iterations
if target_word != "" or opposite != "":
self.avg_diff = self.find_latent_direction(target_word, opposite)
else:
self.avg_diff = None
if target_word_2nd != "" or opposite_2nd != "":
self.avg_diff_2nd = self.find_latent_direction(target_word_2nd, opposite_2nd)
else:
self.avg_diff_2nd = None
def find_latent_direction(self,
target_word:str,
opposite:str):
        # let's identify a latent direction by taking differences between opposites
# target_word = "happy"
# opposite = "sad"
with torch.no_grad():
positives = []
negatives = []
for i in tqdm(range(self.iterations)):
medium = random.choice(MEDIUMS)
subject = random.choice(SUBJECTS)
pos_prompt = f"a {medium} of a {target_word} {subject}"
neg_prompt = f"a {medium} of a {opposite} {subject}"
pos_toks = self.pipe.tokenizer(pos_prompt, return_tensors="pt", padding="max_length", truncation=True,
max_length=self.pipe.tokenizer.model_max_length).input_ids.to(self.device)
neg_toks = self.pipe.tokenizer(neg_prompt, return_tensors="pt", padding="max_length", truncation=True,
max_length=self.pipe.tokenizer.model_max_length).input_ids.to(self.device)
pos = self.pipe.text_encoder(pos_toks).pooler_output
neg = self.pipe.text_encoder(neg_toks).pooler_output
positives.append(pos)
negatives.append(neg)
positives = torch.cat(positives, dim=0)
negatives = torch.cat(negatives, dim=0)
diffs = positives - negatives
avg_diff = diffs.mean(0, keepdim=True)
return avg_diff
def generate(self,
prompt = "a photo of a house",
scale = 2.,
scale_2nd = 0., # scale for the 2nd dim directions when avg_diff_2nd is not None
seed = 15,
only_pooler = False,
normalize_scales = False, # whether to normalize the scales when avg_diff_2nd is not None
correlation_weight_factor = 1.0,
**pipeline_kwargs
):
        # with the full token sequence, scales in [-0.3, 0.3] work well (larger if correlation weighting is used)
        # with the pooler token only, scales in [-4, 4] work well
with torch.no_grad():
toks = self.pipe.tokenizer(prompt, return_tensors="pt", padding="max_length", truncation=True,
max_length=self.pipe.tokenizer.model_max_length).input_ids.to(self.device)
prompt_embeds = self.pipe.text_encoder(toks).last_hidden_state
            if self.avg_diff_2nd is not None and normalize_scales:
denominator = abs(scale) + abs(scale_2nd)
scale = scale / denominator
scale_2nd = scale_2nd / denominator
if only_pooler:
prompt_embeds[:, toks.argmax()] = prompt_embeds[:, toks.argmax()] + self.avg_diff * scale
                if self.avg_diff_2nd is not None:
prompt_embeds[:, toks.argmax()] += self.avg_diff_2nd * scale_2nd
else:
normed_prompt_embeds = prompt_embeds / prompt_embeds.norm(dim=-1, keepdim=True)
sims = normed_prompt_embeds[0] @ normed_prompt_embeds[0].T
weights = sims[toks.argmax(), :][None, :, None].repeat(1, 1, 768)
standard_weights = torch.ones_like(weights)
weights = standard_weights + (weights - standard_weights) * correlation_weight_factor
# weights = torch.sigmoid((weights-0.5)*7)
prompt_embeds = prompt_embeds + (
weights * self.avg_diff[None, :].repeat(1, self.pipe.tokenizer.model_max_length, 1) * scale)
                if self.avg_diff_2nd is not None:
prompt_embeds += weights * self.avg_diff_2nd[None, :].repeat(1, self.pipe.tokenizer.model_max_length, 1) * scale_2nd
torch.manual_seed(seed)
images = self.pipe(prompt_embeds=prompt_embeds, **pipeline_kwargs).images
return images
def spectrum(self,
prompt="a photo of a house",
low_scale=-2,
low_scale_2nd=-2,
high_scale=2,
high_scale_2nd=2,
steps=5,
seed=15,
only_pooler=False,
normalize_scales=False,
correlation_weight_factor=1.0,
**pipeline_kwargs
):
images = []
for i in range(steps):
scale = low_scale + (high_scale - low_scale) * i / (steps - 1)
scale_2nd = low_scale_2nd + (high_scale_2nd - low_scale_2nd) * i / (steps - 1)
image = self.generate(prompt, scale, scale_2nd, seed, only_pooler, normalize_scales, correlation_weight_factor, **pipeline_kwargs)
images.append(image[0])
canvas = Image.new('RGB', (640 * steps, 640))
for i, im in enumerate(images):
canvas.paste(im, (640 * i, 0))
return canvas
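# Usage sketch for CLIPSlider (the checkpoint name, device and scale values are illustrative
# assumptions; any SD 1.x pipeline whose text encoder exposes `pooler_output` should fit):
#
#   pipe = diffusers.StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
#   slider = CLIPSlider(pipe, torch.device("cuda"), target_word="happy", opposite="sad")
#   images = slider.generate("a photo of a person", scale=3.0, only_pooler=True, seed=15)
#   grid = slider.spectrum("a photo of a person", low_scale=-4, high_scale=4, steps=5)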
class CLIPSliderXL(CLIPSlider):
def find_latent_direction(self,
target_word:str,
opposite:str):
        # let's identify a latent direction by taking differences between opposites
# target_word = "happy"
# opposite = "sad"
with torch.no_grad():
positives = []
negatives = []
positives2 = []
negatives2 = []
for i in tqdm(range(self.iterations)):
medium = random.choice(MEDIUMS)
subject = random.choice(SUBJECTS)
pos_prompt = f"a {medium} of a {target_word} {subject}"
neg_prompt = f"a {medium} of a {opposite} {subject}"
pos_toks = self.pipe.tokenizer(pos_prompt, return_tensors="pt", padding="max_length", truncation=True,
max_length=self.pipe.tokenizer.model_max_length).input_ids.to(self.device)
neg_toks = self.pipe.tokenizer(neg_prompt, return_tensors="pt", padding="max_length", truncation=True,
max_length=self.pipe.tokenizer.model_max_length).input_ids.to(self.device)
pos = self.pipe.text_encoder(pos_toks).pooler_output
neg = self.pipe.text_encoder(neg_toks).pooler_output
positives.append(pos)
negatives.append(neg)
pos_toks2 = self.pipe.tokenizer_2(pos_prompt, return_tensors="pt", padding="max_length", truncation=True,
max_length=self.pipe.tokenizer_2.model_max_length).input_ids.to(self.device)
neg_toks2 = self.pipe.tokenizer_2(neg_prompt, return_tensors="pt", padding="max_length", truncation=True,
max_length=self.pipe.tokenizer_2.model_max_length).input_ids.to(self.device)
pos2 = self.pipe.text_encoder_2(pos_toks2).text_embeds
neg2 = self.pipe.text_encoder_2(neg_toks2).text_embeds
positives2.append(pos2)
negatives2.append(neg2)
positives = torch.cat(positives, dim=0)
negatives = torch.cat(negatives, dim=0)
diffs = positives - negatives
avg_diff = diffs.mean(0, keepdim=True)
positives2 = torch.cat(positives2, dim=0)
negatives2 = torch.cat(negatives2, dim=0)
diffs2 = positives2 - negatives2
avg_diff2 = diffs2.mean(0, keepdim=True)
return (avg_diff, avg_diff2)
def generate(self,
prompt = "a photo of a house",
scale = 2,
scale_2nd = 2,
seed = 15,
only_pooler = False,
normalize_scales = False,
correlation_weight_factor = 1.0,
**pipeline_kwargs
):
        # with the full token sequence, scales in [-0.3, 0.3] work well (larger if correlation weighting is used)
        # with the pooler token only, scales in [-4, 4] work well
text_encoders = [self.pipe.text_encoder, self.pipe.text_encoder_2]
tokenizers = [self.pipe.tokenizer, self.pipe.tokenizer_2]
with torch.no_grad():
# toks = pipe.tokenizer(prompt, return_tensors="pt", padding="max_length", truncation=True, max_length=77).input_ids.to(self.device)
# prompt_embeds = pipe.text_encoder(toks).last_hidden_state
prompt_embeds_list = []
for i, text_encoder in enumerate(text_encoders):
tokenizer = tokenizers[i]
text_inputs = tokenizer(
prompt,
padding="max_length",
max_length=tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
toks = text_inputs.input_ids
prompt_embeds = text_encoder(
toks.to(text_encoder.device),
output_hidden_states=True,
)
                # We are only interested in the pooled output of the final text encoder
pooled_prompt_embeds = prompt_embeds[0]
prompt_embeds = prompt_embeds.hidden_states[-2]
if self.avg_diff_2nd and normalize_scales:
denominator = abs(scale) + abs(scale_2nd)
scale = scale / denominator
scale_2nd = scale_2nd / denominator
if only_pooler:
                    # index by i so each text encoder gets the direction matching its hidden size
                    prompt_embeds[:, toks.argmax()] = prompt_embeds[:, toks.argmax()] + self.avg_diff[i] * scale
                    if self.avg_diff_2nd:
                        prompt_embeds[:, toks.argmax()] += self.avg_diff_2nd[i] * scale_2nd
else:
normed_prompt_embeds = prompt_embeds / prompt_embeds.norm(dim=-1, keepdim=True)
sims = normed_prompt_embeds[0] @ normed_prompt_embeds[0].T
if i == 0:
weights = sims[toks.argmax(), :][None, :, None].repeat(1, 1, 768)
standard_weights = torch.ones_like(weights)
weights = standard_weights + (weights - standard_weights) * correlation_weight_factor
prompt_embeds = prompt_embeds + (weights * self.avg_diff[0][None, :].repeat(1, self.pipe.tokenizer.model_max_length, 1) * scale)
if self.avg_diff_2nd:
prompt_embeds += (weights * self.avg_diff_2nd[0][None, :].repeat(1, self.pipe.tokenizer.model_max_length, 1) * scale_2nd)
else:
weights = sims[toks.argmax(), :][None, :, None].repeat(1, 1, 1280)
standard_weights = torch.ones_like(weights)
weights = standard_weights + (weights - standard_weights) * correlation_weight_factor
prompt_embeds = prompt_embeds + (weights * self.avg_diff[1][None, :].repeat(1, self.pipe.tokenizer_2.model_max_length, 1) * scale)
if self.avg_diff_2nd:
prompt_embeds += (weights * self.avg_diff_2nd[1][None, :].repeat(1, self.pipe.tokenizer_2.model_max_length, 1) * scale_2nd)
bs_embed, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
prompt_embeds_list.append(prompt_embeds)
prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)
torch.manual_seed(seed)
images = self.pipe(prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds,
**pipeline_kwargs).images
return images
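# Usage sketch for CLIPSliderXL (the SDXL checkpoint and sampler kwargs are illustrative assumptions):
#
#   pipe = diffusers.StableDiffusionXLPipeline.from_pretrained(
#       "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16)
#   slider = CLIPSliderXL(pipe, torch.device("cuda"), target_word="detailed", opposite="minimalist")
#   images = slider.generate("a photo of a house", scale=2, scale_2nd=0, num_inference_steps=30)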
class CLIPSliderXL_inv(CLIPSlider):
def find_latent_direction(self,
target_word:str,
opposite:str):
        # let's identify a latent direction by taking differences between opposites
# target_word = "happy"
# opposite = "sad"
with torch.no_grad():
positives = []
negatives = []
positives2 = []
negatives2 = []
for i in tqdm(range(self.iterations)):
medium = random.choice(MEDIUMS)
subject = random.choice(SUBJECTS)
pos_prompt = f"a {medium} of a {target_word} {subject}"
neg_prompt = f"a {medium} of a {opposite} {subject}"
pos_toks = self.pipe.tokenizer(pos_prompt, return_tensors="pt", padding="max_length", truncation=True,
max_length=self.pipe.tokenizer.model_max_length).input_ids.to(self.device)
neg_toks = self.pipe.tokenizer(neg_prompt, return_tensors="pt", padding="max_length", truncation=True,
max_length=self.pipe.tokenizer.model_max_length).input_ids.to(self.device)
pos = self.pipe.text_encoder(pos_toks).pooler_output
neg = self.pipe.text_encoder(neg_toks).pooler_output
positives.append(pos)
negatives.append(neg)
pos_toks2 = self.pipe.tokenizer_2(pos_prompt, return_tensors="pt", padding="max_length", truncation=True,
max_length=self.pipe.tokenizer_2.model_max_length).input_ids.to(self.device)
neg_toks2 = self.pipe.tokenizer_2(neg_prompt, return_tensors="pt", padding="max_length", truncation=True,
max_length=self.pipe.tokenizer_2.model_max_length).input_ids.to(self.device)
pos2 = self.pipe.text_encoder_2(pos_toks2).text_embeds
neg2 = self.pipe.text_encoder_2(neg_toks2).text_embeds
positives2.append(pos2)
negatives2.append(neg2)
positives = torch.cat(positives, dim=0)
negatives = torch.cat(negatives, dim=0)
diffs = positives - negatives
avg_diff = diffs.mean(0, keepdim=True)
positives2 = torch.cat(positives2, dim=0)
negatives2 = torch.cat(negatives2, dim=0)
diffs2 = positives2 - negatives2
avg_diff2 = diffs2.mean(0, keepdim=True)
return (avg_diff, avg_diff2)
def generate(self,
prompt = "a photo of a house",
scale = 2,
scale_2nd = 2,
seed = 15,
only_pooler = False,
normalize_scales = False,
correlation_weight_factor = 1.0,
**pipeline_kwargs
):
with torch.no_grad():
torch.manual_seed(seed)
images = self.pipe(editing_prompt=prompt,
avg_diff=self.avg_diff, avg_diff_2nd=self.avg_diff_2nd,
scale=scale, scale_2nd=scale_2nd,
**pipeline_kwargs).images
return images
class CLIPSliderFlux(CLIPSlider):
def find_latent_direction(self,
target_word:str,
opposite:str,
num_iterations: int = None):
        # let's identify a latent direction by taking differences between opposites
# target_word = "happy"
# opposite = "sad"
if num_iterations is not None:
iterations = num_iterations
else:
iterations = self.iterations
with torch.no_grad():
positives = []
negatives = []
for i in tqdm(range(iterations)):
medium = random.choice(MEDIUMS)
subject = random.choice(SUBJECTS)
pos_prompt = f"a {medium} of a {target_word} {subject}"
neg_prompt = f"a {medium} of a {opposite} {subject}"
pos_toks = self.pipe.tokenizer(pos_prompt,
padding="max_length",
max_length=self.pipe.tokenizer_max_length,
truncation=True,
return_overflowing_tokens=False,
return_length=False,
return_tensors="pt",).input_ids.to(self.device)
neg_toks = self.pipe.tokenizer(neg_prompt,
padding="max_length",
max_length=self.pipe.tokenizer_max_length,
truncation=True,
return_overflowing_tokens=False,
return_length=False,
return_tensors="pt",).input_ids.to(self.device)
pos = self.pipe.text_encoder(pos_toks).pooler_output
neg = self.pipe.text_encoder(neg_toks).pooler_output
positives.append(pos)
negatives.append(neg)
positives = torch.cat(positives, dim=0)
negatives = torch.cat(negatives, dim=0)
diffs = positives - negatives
avg_diff = diffs.mean(0, keepdim=True)
return avg_diff
def generate(self,
prompt = "a photo of a house",
scale = 2.0,
seed = 15,
normalize_scales = False,
avg_diff = None,
avg_diff_2nd = None,
use_slerp: bool = False,
max_strength_for_slerp_endpoint: float = 0.0,
**pipeline_kwargs
):
        # with the full token sequence, scales in [-0.3, 0.3] work well (larger if correlation weighting is used)
        # with the pooler token only, scales in [-4, 4] work well
# Remove slider-specific kwargs before passing to the pipeline
pipeline_kwargs.pop('use_slerp', None)
pipeline_kwargs.pop('max_strength_for_slerp_endpoint', None)
with torch.no_grad():
text_inputs = self.pipe.tokenizer(
prompt,
padding="max_length",
max_length=77,
truncation=True,
return_overflowing_tokens=False,
return_length=False,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
prompt_embeds_out = self.pipe.text_encoder(text_input_ids.to(self.device), output_hidden_states=False)
original_pooled_prompt_embeds = prompt_embeds_out.pooler_output.to(dtype=self.pipe.text_encoder.dtype, device=self.device)
# For the second text encoder (T5-like for FLUX)
text_inputs_2 = self.pipe.tokenizer_2(
prompt,
padding="max_length",
max_length=512,
truncation=True,
return_length=False,
return_overflowing_tokens=False,
return_tensors="pt",
)
toks_2 = text_inputs_2.input_ids
# This is the non-pooled, sequence output for the second encoder
prompt_embeds_seq_2 = self.pipe.text_encoder_2(toks_2.to(self.device), output_hidden_states=False)[0]
prompt_embeds_seq_2 = prompt_embeds_seq_2.to(dtype=self.pipe.text_encoder_2.dtype, device=self.device)
modified_pooled_embeds = original_pooled_prompt_embeds.clone()
if avg_diff is not None:
if use_slerp and max_strength_for_slerp_endpoint != 0.0:
# Slerp logic
slerp_t_val = 0.0
if max_strength_for_slerp_endpoint != 0:
slerp_t_val = abs(scale) / max_strength_for_slerp_endpoint
slerp_t_val = min(slerp_t_val, 1.0)
if scale == 0:
pass
else:
v0 = original_pooled_prompt_embeds.float()
if scale > 0:
v_end_target = original_pooled_prompt_embeds + max_strength_for_slerp_endpoint * avg_diff
else:
v_end_target = original_pooled_prompt_embeds - max_strength_for_slerp_endpoint * avg_diff
modified_pooled_embeds = slerp(v0, v_end_target.float(), slerp_t_val).to(original_pooled_prompt_embeds.dtype)
else:
modified_pooled_embeds = modified_pooled_embeds + avg_diff * scale
            # pop scale_2nd so it is not forwarded to the underlying pipeline call
            scale_2nd_val = pipeline_kwargs.pop("scale_2nd", 0.0)
            if avg_diff_2nd is not None:
                modified_pooled_embeds += avg_diff_2nd * scale_2nd_val
torch.manual_seed(seed)
images = self.pipe(prompt_embeds=prompt_embeds_seq_2,
pooled_prompt_embeds=modified_pooled_embeds,
**pipeline_kwargs).images
return images[0]
def spectrum(self,
prompt="a photo of a house",
low_scale=-2,
low_scale_2nd=-2,
high_scale=2,
high_scale_2nd=2,
steps=5,
seed=15,
normalize_scales=False,
**pipeline_kwargs
):
images = []
for i in range(steps):
scale = low_scale + (high_scale - low_scale) * i / (steps - 1)
scale_2nd = low_scale_2nd + (high_scale_2nd - low_scale_2nd) * i / (steps - 1)
            image = self.generate(prompt, scale, seed=seed, normalize_scales=normalize_scales,
                                  avg_diff=self.avg_diff, avg_diff_2nd=self.avg_diff_2nd,
                                  scale_2nd=scale_2nd, **pipeline_kwargs)
            images.append(image.resize((512, 512)))
canvas = Image.new('RGB', (640 * steps, 640))
for i, im in enumerate(images):
canvas.paste(im, (640 * i, 0))
return canvas
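# Usage sketch for CLIPSliderFlux (the FLUX checkpoint and sampler settings are illustrative
# assumptions). Unlike the other sliders, generate() takes the direction explicitly:
#
#   pipe = diffusers.FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell",
#                                                 torch_dtype=torch.bfloat16)
#   slider = CLIPSliderFlux(pipe, torch.device("cuda"), iterations=200)
#   direction = slider.find_latent_direction("happy", "sad")
#   img = slider.generate("a photo of a person", scale=2.0, avg_diff=direction,
#                         num_inference_steps=4, guidance_scale=0.0)
#
#   # With use_slerp=True the pooled embedding moves along a spherical path toward the
#   # endpoint defined by max_strength_for_slerp_endpoint instead of being shifted linearly:
#   img = slider.generate("a photo of a person", scale=2.0, avg_diff=direction,
#                         use_slerp=True, max_strength_for_slerp_endpoint=4.0,
#                         num_inference_steps=4, guidance_scale=0.0)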
class T5SliderFlux(CLIPSlider):
def find_latent_direction(self,
target_word:str,
opposite:str):
        # let's identify a latent direction by taking differences between opposites
# target_word = "happy"
# opposite = "sad"
with torch.no_grad():
positives = []
negatives = []
for i in tqdm(range(self.iterations)):
medium = random.choice(MEDIUMS)
subject = random.choice(SUBJECTS)
pos_prompt = f"a {medium} of a {target_word} {subject}"
neg_prompt = f"a {medium} of a {opposite} {subject}"
pos_toks = self.pipe.tokenizer_2(pos_prompt,
return_tensors="pt",
padding="max_length",
truncation=True,
return_length=False,
return_overflowing_tokens=False,
max_length=self.pipe.tokenizer_2.model_max_length).input_ids.to(self.device)
neg_toks = self.pipe.tokenizer_2(neg_prompt,
return_tensors="pt",
padding="max_length",
truncation=True,
return_length=False,
return_overflowing_tokens=False,
max_length=self.pipe.tokenizer_2.model_max_length).input_ids.to(self.device)
pos = self.pipe.text_encoder_2(pos_toks, output_hidden_states=False)[0]
neg = self.pipe.text_encoder_2(neg_toks, output_hidden_states=False)[0]
positives.append(pos)
negatives.append(neg)
positives = torch.cat(positives, dim=0)
negatives = torch.cat(negatives, dim=0)
diffs = positives - negatives
avg_diff = diffs.mean(0, keepdim=True)
return avg_diff
def generate(self,
prompt = "a photo of a house",
scale = 2,
scale_2nd = 2,
seed = 15,
only_pooler = False,
normalize_scales = False,
correlation_weight_factor = 1.0,
**pipeline_kwargs
):
        # with the full token sequence, scales in [-0.3, 0.3] work well (larger if correlation weighting is used)
        # with the pooler token only, scales in [-4, 4] work well
with torch.no_grad():
text_inputs = self.pipe.tokenizer(
prompt,
padding="max_length",
max_length=77,
truncation=True,
return_overflowing_tokens=False,
return_length=False,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
prompt_embeds = self.pipe.text_encoder(text_input_ids.to(self.device), output_hidden_states=False)
# Use pooled output of CLIPTextModel
prompt_embeds = prompt_embeds.pooler_output
pooled_prompt_embeds = prompt_embeds.to(dtype=self.pipe.text_encoder.dtype, device=self.device)
            # Sequence (non-pooled) output of the T5 text encoder (text_encoder_2)
text_inputs = self.pipe.tokenizer_2(
prompt,
padding="max_length",
max_length=512,
truncation=True,
return_length=False,
return_overflowing_tokens=False,
return_tensors="pt",
)
toks = text_inputs.input_ids
prompt_embeds = self.pipe.text_encoder_2(toks.to(self.device), output_hidden_states=False)[0]
dtype = self.pipe.text_encoder_2.dtype
prompt_embeds = prompt_embeds.to(dtype=dtype, device=self.device)
            if self.avg_diff_2nd is not None and normalize_scales:
denominator = abs(scale) + abs(scale_2nd)
scale = scale / denominator
scale_2nd = scale_2nd / denominator
if only_pooler:
prompt_embeds[:, toks.argmax()] = prompt_embeds[:, toks.argmax()] + self.avg_diff * scale
                if self.avg_diff_2nd is not None:
prompt_embeds[:, toks.argmax()] += self.avg_diff_2nd * scale_2nd
else:
normed_prompt_embeds = prompt_embeds / prompt_embeds.norm(dim=-1, keepdim=True)
sims = normed_prompt_embeds[0] @ normed_prompt_embeds[0].T
weights = sims[toks.argmax(), :][None, :, None].repeat(1, 1, prompt_embeds.shape[2])
standard_weights = torch.ones_like(weights)
weights = standard_weights + (weights - standard_weights) * correlation_weight_factor
prompt_embeds = prompt_embeds + (
weights * self.avg_diff * scale)
                if self.avg_diff_2nd is not None:
prompt_embeds += (
weights * self.avg_diff_2nd * scale_2nd)
torch.manual_seed(seed)
images = self.pipe(prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds,
**pipeline_kwargs).images
return images
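# Usage sketch for T5SliderFlux (the checkpoint name is an illustrative assumption): here the
# direction lives in the T5 sequence-embedding space and is applied to the full prompt sequence,
# while the pooled CLIP embedding is left untouched.
#
#   pipe = diffusers.FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell",
#                                                 torch_dtype=torch.bfloat16)
#   slider = T5SliderFlux(pipe, torch.device("cuda"), target_word="ornate", opposite="plain")
#   images = slider.generate("a photo of a house", scale=1.5, num_inference_steps=4,
#                            guidance_scale=0.0)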