# Spaces: Running on Zero
import math  # NOTE(review): appears unused — slerp uses torch.acos / torch.sin
import random

import diffusers
import torch
from PIL import Image
from tqdm import tqdm

from constants import SUBJECTS, MEDIUMS
# Slerp (Spherical Linear Interpolation) function
def slerp(v0, v1, t, DOT_THRESHOLD=0.9995):
    """
    Spherical linear interpolation between ``v0`` and ``v1`` along the last dim.

    The *direction* of the result is slerped between the unit directions of
    ``v0`` and ``v1``; its *magnitude* is linearly interpolated between
    ``|v0|`` and ``|v1|``.

    Args:
        v0, v1: Tensors to interpolate between. Every vector along ``dim=-1``
            is interpolated independently.
        t: Interpolation factor, either a Python scalar or a tensor that
            broadcasts against ``v0[..., :1]`` (e.g. one value per row).
            ``t=0`` gives ``v0`` and ``t=1`` gives ``v1``.
        DOT_THRESHOLD: Cosine similarity above which the two directions are
            treated as (nearly) parallel. Plain linear interpolation of the
            directions is used there instead of the slerp formula.

    Returns:
        The interpolated tensor, cast to ``v0.dtype``.
    """
    if not isinstance(t, torch.Tensor):
        t = torch.tensor(t, device=v0.device, dtype=v0.dtype)

    eps = 1e-8
    mag_v0 = torch.norm(v0, dim=-1, keepdim=True)
    mag_v1 = torch.norm(v1, dim=-1, keepdim=True)
    v0_norm = v0 / (mag_v0 + eps)
    v1_norm = v1 / (mag_v1 + eps)

    # Cosine of the angle between the directions. Clamp it so that acos never
    # sees a value pushed outside [-1, 1] by rounding error.
    dot = torch.sum(v0_norm * v1_norm, dim=-1, keepdim=True).clamp(-1.0, 1.0)
    omega = torch.acos(dot)
    sin_omega = torch.sin(omega)

    s0 = torch.sin((1 - t) * omega) / (sin_omega + eps)
    s1 = torch.sin(t * omega) / (sin_omega + eps)

    # Fall back to LERP weights where the slerp weights are ill-conditioned:
    # nearly parallel directions, or a vanishing sin(omega) (e.g. exactly
    # opposite vectors). torch.where broadcasts, so a per-row `t` works too.
    use_lerp = (dot > DOT_THRESHOLD) | (sin_omega.abs() < 1e-5)
    s0 = torch.where(use_lerp, 1 - t, s0)
    s1 = torch.where(use_lerp, t, s1)

    result_norm = s0 * v0_norm + s1 * v1_norm
    interpolated_mag = (1 - t) * mag_v0 + t * mag_v1
    return (result_norm * interpolated_mag).to(v0.dtype)
class CLIPSlider:
    """Steer a diffusion pipeline's text embeddings along learned "slider" directions.

    A direction is the mean difference between the CLIP ``text_encoder``
    pooled outputs of prompts containing ``target_word`` and prompts
    containing ``opposite`` (see ``find_latent_direction``). ``generate``
    shifts the prompt embeddings along up to two such directions before
    sampling.
    """

    def __init__(
        self,
        sd_pipe,
        device: torch.device,
        target_word: str = "",
        opposite: str = "",
        target_word_2nd: str = "",
        opposite_2nd: str = "",
        iterations: int = 300,
    ):
        """
        Args:
            sd_pipe: Pipeline exposing ``tokenizer``/``text_encoder``. It is
                moved to ``device`` in float16.
            device: Device used for the pipeline and the token ids.
            target_word, opposite: Word pair defining the first direction.
                No direction is computed when both are empty.
            target_word_2nd, opposite_2nd: Optional pair for a second direction.
            iterations: Number of random prompt pairs averaged per direction.
        """
        self.device = device
        self.pipe = sd_pipe.to(self.device, torch.float16)
        self.iterations = iterations
        if target_word != "" or opposite != "":
            self.avg_diff = self.find_latent_direction(target_word, opposite)
        else:
            self.avg_diff = None
        if target_word_2nd != "" or opposite_2nd != "":
            self.avg_diff_2nd = self.find_latent_direction(target_word_2nd, opposite_2nd)
        else:
            self.avg_diff_2nd = None

    def _tokenize(self, prompt: str) -> torch.Tensor:
        """Return ``prompt``'s padded/truncated ``tokenizer`` ids on ``self.device``."""
        tokenizer = self.pipe.tokenizer
        return tokenizer(
            prompt,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=tokenizer.model_max_length,
        ).input_ids.to(self.device)

    def find_latent_direction(self,
                              target_word: str,
                              opposite: str):
        """Estimate the direction from ``opposite`` to ``target_word``.

        For ``self.iterations`` random (medium, subject) pairs, encode
        ``"a {medium} of a {target_word} {subject}"`` and the same prompt with
        ``opposite``. Then average the difference of their ``pooler_output``.

        Returns:
            The mean difference, keeping a leading dim of 1.
        """
        # e.g. target_word = "happy", opposite = "sad"
        with torch.no_grad():
            positives = []
            negatives = []
            for _ in tqdm(range(self.iterations)):
                medium = random.choice(MEDIUMS)
                subject = random.choice(SUBJECTS)
                pos_prompt = f"a {medium} of a {target_word} {subject}"
                neg_prompt = f"a {medium} of a {opposite} {subject}"
                pos_toks = self._tokenize(pos_prompt)
                neg_toks = self._tokenize(neg_prompt)
                positives.append(self.pipe.text_encoder(pos_toks).pooler_output)
                negatives.append(self.pipe.text_encoder(neg_toks).pooler_output)
            diffs = torch.cat(positives, dim=0) - torch.cat(negatives, dim=0)
            return diffs.mean(0, keepdim=True)

    def generate(self,
                 prompt="a photo of a house",
                 scale=2.,
                 scale_2nd=0.,  # scale for the 2nd dim directions when avg_diff_2nd is not None
                 seed=15,
                 only_pooler=False,
                 normalize_scales=False,  # whether to normalize the scales when avg_diff_2nd is not None
                 correlation_weight_factor=1.0,
                 **pipeline_kwargs
                 ):
        """Generate images with the prompt embeddings shifted along the slider(s).

        Args:
            prompt: Text prompt.
            scale: Strength along ``self.avg_diff`` (skipped when it is None).
            scale_2nd: Strength along ``self.avg_diff_2nd`` (skipped when None).
            seed: Passed to ``torch.manual_seed`` right before sampling.
            only_pooler: If true, shift only the embedding at the position of
                the largest token id. Otherwise shift every token, weighted by
                its cosine similarity to that position.
            normalize_scales: When a second direction exists, rescale both
                scales so that ``|scale| + |scale_2nd| == 1``. This is skipped
                when both are 0.
            correlation_weight_factor: Blends the per-token weights between
                1 (factor 0) and the raw cosine similarity (factor 1).
            **pipeline_kwargs: Forwarded to ``self.pipe``.

        Returns:
            The pipeline output's ``images``.
        """
        # if doing full sequence, [-0.3,0.3] work well, higher if correlation weighted is true
        # if pooler token only [-4,4] work well
        has_2nd = self.avg_diff_2nd is not None
        if has_2nd and normalize_scales:
            denominator = abs(scale) + abs(scale_2nd)
            if denominator != 0:
                scale = scale / denominator
                scale_2nd = scale_2nd / denominator

        with torch.no_grad():
            toks = self._tokenize(prompt)
            prompt_embeds = self.pipe.text_encoder(toks).last_hidden_state
            # Position of the largest token id (the end-of-text token for CLIP
            # tokenizers — assumption; single-prompt batch).
            eot_idx = int(toks.argmax())

            if only_pooler:
                if self.avg_diff is not None:
                    prompt_embeds[:, eot_idx] += self.avg_diff * scale
                if has_2nd:
                    prompt_embeds[:, eot_idx] += self.avg_diff_2nd * scale_2nd
            else:
                normed = prompt_embeds / prompt_embeds.norm(dim=-1, keepdim=True)
                sims = normed[0] @ normed[0].T
                # (1, seq_len, 1): each token's similarity to the EOT position.
                # It broadcasts over the hidden dim, whatever its size.
                weights = sims[eot_idx, :][None, :, None]
                weights = 1 + (weights - 1) * correlation_weight_factor
                if self.avg_diff is not None:
                    prompt_embeds = prompt_embeds + weights * self.avg_diff * scale
                if has_2nd:
                    prompt_embeds = prompt_embeds + weights * self.avg_diff_2nd * scale_2nd

        torch.manual_seed(seed)
        images = self.pipe(prompt_embeds=prompt_embeds, **pipeline_kwargs).images
        return images

    def spectrum(self,
                 prompt="a photo of a house",
                 low_scale=-2,
                 low_scale_2nd=-2,
                 high_scale=2,
                 high_scale_2nd=2,
                 steps=5,
                 seed=15,
                 only_pooler=False,
                 normalize_scales=False,
                 correlation_weight_factor=1.0,
                 **pipeline_kwargs
                 ):
        """Render ``steps`` images across evenly spaced scales, tiled left to right.

        Scales go linearly from ``low_*`` to ``high_*``. With ``steps == 1``
        only the low scales are used. The canvas is sized from the generated
        images themselves.

        Returns:
            A single RGB ``PIL.Image`` canvas.
        """
        images = []
        for i in range(steps):
            if steps > 1:
                scale = low_scale + (high_scale - low_scale) * i / (steps - 1)
                scale_2nd = low_scale_2nd + (high_scale_2nd - low_scale_2nd) * i / (steps - 1)
            else:
                scale, scale_2nd = low_scale, low_scale_2nd
            image = self.generate(prompt,
                                  scale=scale,
                                  scale_2nd=scale_2nd,
                                  seed=seed,
                                  only_pooler=only_pooler,
                                  normalize_scales=normalize_scales,
                                  correlation_weight_factor=correlation_weight_factor,
                                  **pipeline_kwargs)
            images.append(image[0])

        width = sum(im.width for im in images)
        height = max((im.height for im in images), default=0)
        canvas = Image.new('RGB', (width, height))
        x_offset = 0
        for im in images:
            canvas.paste(im, (x_offset, 0))
            x_offset += im.width
        return canvas
class CLIPSliderXL(CLIPSlider):
    """Two-text-encoder variant of ``CLIPSlider`` (one direction per encoder).

    ``find_latent_direction`` returns a ``(direction_1, direction_2)`` tuple.
    The first comes from ``text_encoder``'s ``pooler_output`` and the second
    from ``text_encoder_2``'s ``text_embeds``. ``generate`` shifts each
    encoder's penultimate hidden states with its matching direction.
    """

    def find_latent_direction(self,
                              target_word: str,
                              opposite: str):
        """Estimate per-encoder directions from ``opposite`` to ``target_word``.

        For ``self.iterations`` random (medium, subject) pairs, both prompts
        ``"a {medium} of a {target_word} {subject}"`` and
        ``"a {medium} of a {opposite} {subject}"`` are encoded by both text
        encoders.

        Returns:
            ``(avg_diff, avg_diff2)``: the mean positive-minus-negative
            ``text_encoder`` pooler output and ``text_encoder_2`` text_embeds,
            each keeping a leading dim of 1.
        """
        # e.g. target_word = "happy", opposite = "sad"
        with torch.no_grad():
            positives = []
            negatives = []
            positives2 = []
            negatives2 = []
            for _ in tqdm(range(self.iterations)):
                medium = random.choice(MEDIUMS)
                subject = random.choice(SUBJECTS)
                pos_prompt = f"a {medium} of a {target_word} {subject}"
                neg_prompt = f"a {medium} of a {opposite} {subject}"
                pos_toks = self.pipe.tokenizer(pos_prompt, return_tensors="pt", padding="max_length", truncation=True,
                                               max_length=self.pipe.tokenizer.model_max_length).input_ids.to(self.device)
                neg_toks = self.pipe.tokenizer(neg_prompt, return_tensors="pt", padding="max_length", truncation=True,
                                               max_length=self.pipe.tokenizer.model_max_length).input_ids.to(self.device)
                positives.append(self.pipe.text_encoder(pos_toks).pooler_output)
                negatives.append(self.pipe.text_encoder(neg_toks).pooler_output)

                pos_toks2 = self.pipe.tokenizer_2(pos_prompt, return_tensors="pt", padding="max_length", truncation=True,
                                                  max_length=self.pipe.tokenizer_2.model_max_length).input_ids.to(self.device)
                neg_toks2 = self.pipe.tokenizer_2(neg_prompt, return_tensors="pt", padding="max_length", truncation=True,
                                                  max_length=self.pipe.tokenizer_2.model_max_length).input_ids.to(self.device)
                positives2.append(self.pipe.text_encoder_2(pos_toks2).text_embeds)
                negatives2.append(self.pipe.text_encoder_2(neg_toks2).text_embeds)

            avg_diff = (torch.cat(positives, dim=0) - torch.cat(negatives, dim=0)).mean(0, keepdim=True)
            avg_diff2 = (torch.cat(positives2, dim=0) - torch.cat(negatives2, dim=0)).mean(0, keepdim=True)
            return (avg_diff, avg_diff2)

    def generate(self,
                 prompt="a photo of a house",
                 scale=2,
                 scale_2nd=2,
                 seed=15,
                 only_pooler=False,
                 normalize_scales=False,
                 correlation_weight_factor=1.0,
                 **pipeline_kwargs
                 ):
        """Generate images with both encoders' embeddings shifted along the slider(s).

        Encoder ``i`` is shifted with ``self.avg_diff[i]`` (and
        ``self.avg_diff_2nd[i]``). The shifted sequences are concatenated on
        the last dim and passed as ``prompt_embeds``. The pooled embedding is
        the first output of the last encoder and is left unmodified.

        Args:
            prompt: Text prompt.
            scale, scale_2nd: Strengths along the first/second direction.
            seed: Passed to ``torch.manual_seed`` right before sampling.
            only_pooler: Shift only the largest-token-id position instead of
                every token (similarity-weighted).
            normalize_scales: When a second direction exists, rescale so that
                ``|scale| + |scale_2nd| == 1``. This is skipped when both are 0.
            correlation_weight_factor: Blends per-token weights between 1 and
                the raw cosine similarity.
            **pipeline_kwargs: Forwarded to ``self.pipe``.

        Returns:
            The pipeline output's ``images``.
        """
        # if doing full sequence, [-0.3,0.3] work well, higher if correlation weighted is true
        # if pooler token only [-4,4] work well
        text_encoders = [self.pipe.text_encoder, self.pipe.text_encoder_2]
        tokenizers = [self.pipe.tokenizer, self.pipe.tokenizer_2]

        # Normalize once, before the per-encoder loop.
        if self.avg_diff_2nd is not None and normalize_scales:
            denominator = abs(scale) + abs(scale_2nd)
            if denominator != 0:
                scale = scale / denominator
                scale_2nd = scale_2nd / denominator

        with torch.no_grad():
            prompt_embeds_list = []
            for i, (tokenizer, text_encoder) in enumerate(zip(tokenizers, text_encoders)):
                toks = tokenizer(
                    prompt,
                    padding="max_length",
                    max_length=tokenizer.model_max_length,
                    truncation=True,
                    return_tensors="pt",
                ).input_ids
                encoder_output = text_encoder(toks.to(text_encoder.device), output_hidden_states=True)
                # We are only ALWAYS interested in the pooled output of the final text encoder
                pooled_prompt_embeds = encoder_output[0]
                prompt_embeds = encoder_output.hidden_states[-2]

                # Use the direction learned for *this* encoder so hidden sizes match.
                diff = self.avg_diff[i] if self.avg_diff is not None else None
                diff_2nd = self.avg_diff_2nd[i] if self.avg_diff_2nd is not None else None
                eot_idx = int(toks.argmax())

                if only_pooler:
                    if diff is not None:
                        prompt_embeds[:, eot_idx] += diff * scale
                    if diff_2nd is not None:
                        prompt_embeds[:, eot_idx] += diff_2nd * scale_2nd
                else:
                    normed = prompt_embeds / prompt_embeds.norm(dim=-1, keepdim=True)
                    sims = normed[0] @ normed[0].T
                    # (1, seq_len, 1): broadcasts over this encoder's hidden dim.
                    weights = sims[eot_idx, :][None, :, None]
                    weights = 1 + (weights - 1) * correlation_weight_factor
                    if diff is not None:
                        prompt_embeds = prompt_embeds + weights * diff * scale
                    if diff_2nd is not None:
                        prompt_embeds = prompt_embeds + weights * diff_2nd * scale_2nd
                prompt_embeds_list.append(prompt_embeds)

            prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
            pooled_prompt_embeds = pooled_prompt_embeds.view(prompt_embeds.shape[0], -1)

        torch.manual_seed(seed)
        images = self.pipe(prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds,
                           **pipeline_kwargs).images
        return images
class CLIPSliderXL_inv(CLIPSlider):
    """Two-encoder slider whose pipeline applies the directions itself.

    Directions are estimated per text encoder (``text_encoder`` pooler output
    and ``text_encoder_2`` ``text_embeds``). ``generate`` does not edit any
    embeddings; it forwards the prompt, both directions and both scales to
    ``self.pipe``. NOTE(review): this presumably requires a custom pipeline
    that accepts the ``editing_prompt`` / ``avg_diff`` / ``avg_diff_2nd`` /
    ``scale`` / ``scale_2nd`` keywords — confirm against the pipeline in use.
    """

    def find_latent_direction(self,
                              target_word: str,
                              opposite: str):
        """Return ``(avg_diff, avg_diff2)`` for the two text encoders.

        Each entry is the mean, over ``self.iterations`` random
        (medium, subject) prompt pairs, of the positive-minus-negative encoder
        output, keeping a leading dim of 1.
        """
        def embed(tokenizer, text_encoder, text, field):
            # Pad/truncate to the tokenizer's own max length, then read the
            # requested field of the encoder output.
            token_ids = tokenizer(text, return_tensors="pt", padding="max_length", truncation=True,
                                  max_length=tokenizer.model_max_length).input_ids.to(self.device)
            return getattr(text_encoder(token_ids), field)

        def mean_difference(pos_list, neg_list):
            return (torch.cat(pos_list, dim=0) - torch.cat(neg_list, dim=0)).mean(0, keepdim=True)

        with torch.no_grad():
            pos_a, neg_a, pos_b, neg_b = [], [], [], []
            for _ in tqdm(range(self.iterations)):
                # Only the adjective differs between the two prompts (e.g. "happy" vs "sad").
                medium = random.choice(MEDIUMS)
                subject = random.choice(SUBJECTS)
                pos_prompt = f"a {medium} of a {target_word} {subject}"
                neg_prompt = f"a {medium} of a {opposite} {subject}"
                pos_a.append(embed(self.pipe.tokenizer, self.pipe.text_encoder, pos_prompt, "pooler_output"))
                neg_a.append(embed(self.pipe.tokenizer, self.pipe.text_encoder, neg_prompt, "pooler_output"))
                pos_b.append(embed(self.pipe.tokenizer_2, self.pipe.text_encoder_2, pos_prompt, "text_embeds"))
                neg_b.append(embed(self.pipe.tokenizer_2, self.pipe.text_encoder_2, neg_prompt, "text_embeds"))
            return (mean_difference(pos_a, neg_a), mean_difference(pos_b, neg_b))

    def generate(self,
                 prompt="a photo of a house",
                 scale=2,
                 scale_2nd=2,
                 seed=15,
                 only_pooler=False,
                 normalize_scales=False,
                 correlation_weight_factor=1.0,
                 **pipeline_kwargs
                 ):
        """Seed the RNG and delegate the slider edit to ``self.pipe``.

        ``only_pooler``, ``normalize_scales`` and ``correlation_weight_factor``
        are accepted for signature compatibility but are not used here.

        Returns:
            The pipeline output's ``images``.
        """
        slider_kwargs = dict(
            editing_prompt=prompt,
            avg_diff=self.avg_diff,
            avg_diff_2nd=self.avg_diff_2nd,
            scale=scale,
            scale_2nd=scale_2nd,
        )
        with torch.no_grad():
            torch.manual_seed(seed)
            result = self.pipe(**slider_kwargs, **pipeline_kwargs)
            return result.images
class CLIPSliderFlux(CLIPSlider):
    """FLUX variant: slider directions act on the CLIP pooled prompt embedding.

    ``find_latent_direction`` uses the CLIP ``text_encoder`` pooler output.
    ``generate`` shifts the pooled embedding (additively, or via ``slerp``)
    and passes the ``text_encoder_2`` sequence embeddings through unchanged.
    """

    def find_latent_direction(self,
                              target_word: str,
                              opposite: str,
                              num_iterations: int = None):
        """Estimate the pooled-embedding direction from ``opposite`` to ``target_word``.

        Args:
            target_word, opposite: Word pair inserted into random
                ``"a {medium} of a {word} {subject}"`` prompts.
            num_iterations: Number of prompt pairs. Defaults to
                ``self.iterations`` when None.

        Returns:
            The mean positive-minus-negative ``pooler_output``, keeping a
            leading dim of 1.
        """
        # e.g. target_word = "happy", opposite = "sad"
        iterations = num_iterations if num_iterations is not None else self.iterations
        with torch.no_grad():
            positives = []
            negatives = []
            for _ in tqdm(range(iterations)):
                medium = random.choice(MEDIUMS)
                subject = random.choice(SUBJECTS)
                pos_prompt = f"a {medium} of a {target_word} {subject}"
                neg_prompt = f"a {medium} of a {opposite} {subject}"
                pos_toks = self.pipe.tokenizer(pos_prompt,
                                               padding="max_length",
                                               max_length=self.pipe.tokenizer_max_length,
                                               truncation=True,
                                               return_overflowing_tokens=False,
                                               return_length=False,
                                               return_tensors="pt",).input_ids.to(self.device)
                neg_toks = self.pipe.tokenizer(neg_prompt,
                                               padding="max_length",
                                               max_length=self.pipe.tokenizer_max_length,
                                               truncation=True,
                                               return_overflowing_tokens=False,
                                               return_length=False,
                                               return_tensors="pt",).input_ids.to(self.device)
                positives.append(self.pipe.text_encoder(pos_toks).pooler_output)
                negatives.append(self.pipe.text_encoder(neg_toks).pooler_output)
            diffs = torch.cat(positives, dim=0) - torch.cat(negatives, dim=0)
            return diffs.mean(0, keepdim=True)

    def generate(self,
                 prompt="a photo of a house",
                 scale=2.0,
                 seed=15,
                 normalize_scales=False,
                 avg_diff=None,
                 avg_diff_2nd=None,
                 use_slerp: bool = False,
                 max_strength_for_slerp_endpoint: float = 0.0,
                 **pipeline_kwargs
                 ):
        """Generate one image with the pooled CLIP embedding pushed along ``avg_diff``.

        Args:
            prompt: Text prompt, encoded by both text encoders.
            scale: Strength along ``avg_diff``.
            seed: Passed to ``torch.manual_seed`` right before sampling.
            normalize_scales: When ``avg_diff_2nd`` is given, rescale ``scale``
                and ``scale_2nd`` so that ``|scale| + |scale_2nd| == 1``
                (skipped when both are 0).
            avg_diff: Direction for the pooled embedding. ``self.avg_diff`` is
                not used implicitly; pass it explicitly.
            avg_diff_2nd: Optional second direction, scaled by the
                ``scale_2nd`` keyword (taken from ``pipeline_kwargs``,
                default 0.0). It is always added linearly.
            use_slerp: With a non-zero ``max_strength_for_slerp_endpoint``,
                spherically interpolate from the original pooled embedding
                towards ``original ± max_strength * avg_diff`` (sign from
                ``scale``) with ``t = min(|scale| / max_strength, 1)``,
                instead of adding ``avg_diff * scale``.
            max_strength_for_slerp_endpoint: Scale at which the slerp
                reaches its endpoint.
            **pipeline_kwargs: Forwarded to ``self.pipe`` (minus ``scale_2nd``).

        Returns:
            The first image of the pipeline output.
        """
        # scale_2nd is a slider parameter, not a pipeline argument: consume it here.
        scale_2nd = pipeline_kwargs.pop("scale_2nd", 0.0)
        if avg_diff_2nd is not None and normalize_scales:
            denominator = abs(scale) + abs(scale_2nd)
            if denominator != 0:
                scale = scale / denominator
                scale_2nd = scale_2nd / denominator

        with torch.no_grad():
            text_inputs = self.pipe.tokenizer(
                prompt,
                padding="max_length",
                max_length=77,
                truncation=True,
                return_overflowing_tokens=False,
                return_length=False,
                return_tensors="pt",
            )
            clip_output = self.pipe.text_encoder(text_inputs.input_ids.to(self.device), output_hidden_states=False)
            original_pooled_prompt_embeds = clip_output.pooler_output.to(dtype=self.pipe.text_encoder.dtype,
                                                                         device=self.device)

            # For the second text encoder (T5-like for FLUX)
            text_inputs_2 = self.pipe.tokenizer_2(
                prompt,
                padding="max_length",
                max_length=512,
                truncation=True,
                return_length=False,
                return_overflowing_tokens=False,
                return_tensors="pt",
            )
            # This is the non-pooled, sequence output for the second encoder
            prompt_embeds_seq_2 = self.pipe.text_encoder_2(text_inputs_2.input_ids.to(self.device),
                                                           output_hidden_states=False)[0]
            prompt_embeds_seq_2 = prompt_embeds_seq_2.to(dtype=self.pipe.text_encoder_2.dtype, device=self.device)

            modified_pooled_embeds = original_pooled_prompt_embeds.clone()
            if avg_diff is not None:
                if use_slerp and max_strength_for_slerp_endpoint != 0.0:
                    # scale == 0 leaves the embedding untouched.
                    if scale != 0:
                        slerp_t_val = min(abs(scale) / max_strength_for_slerp_endpoint, 1.0)
                        direction = avg_diff if scale > 0 else -avg_diff
                        v_end_target = original_pooled_prompt_embeds + max_strength_for_slerp_endpoint * direction
                        modified_pooled_embeds = slerp(original_pooled_prompt_embeds.float(),
                                                       v_end_target.float(),
                                                       slerp_t_val).to(original_pooled_prompt_embeds.dtype)
                else:
                    modified_pooled_embeds = modified_pooled_embeds + avg_diff * scale
            if avg_diff_2nd is not None:
                modified_pooled_embeds += avg_diff_2nd * scale_2nd

        torch.manual_seed(seed)
        images = self.pipe(prompt_embeds=prompt_embeds_seq_2,
                           pooled_prompt_embeds=modified_pooled_embeds,
                           **pipeline_kwargs).images
        return images[0]

    def spectrum(self,
                 prompt="a photo of a house",
                 low_scale=-2,
                 low_scale_2nd=-2,
                 high_scale=2,
                 high_scale_2nd=2,
                 steps=5,
                 seed=15,
                 normalize_scales=False,
                 **pipeline_kwargs
                 ):
        """Render ``steps`` images across evenly spaced scales as one 512px-tall strip.

        The instance's ``avg_diff`` / ``avg_diff_2nd`` are used unless the
        caller passes them in ``pipeline_kwargs``. With ``steps == 1`` only
        the low scales are used.

        Returns:
            An RGB ``PIL.Image`` of size ``(512 * steps, 512)``.
        """
        tile = 512
        pipeline_kwargs.setdefault("avg_diff", self.avg_diff)
        pipeline_kwargs.setdefault("avg_diff_2nd", self.avg_diff_2nd)
        images = []
        for i in range(steps):
            if steps > 1:
                scale = low_scale + (high_scale - low_scale) * i / (steps - 1)
                scale_2nd = low_scale_2nd + (high_scale_2nd - low_scale_2nd) * i / (steps - 1)
            else:
                scale, scale_2nd = low_scale, low_scale_2nd
            # generate() returns a single image and takes scale_2nd via kwargs.
            image = self.generate(prompt,
                                  scale=scale,
                                  seed=seed,
                                  normalize_scales=normalize_scales,
                                  scale_2nd=scale_2nd,
                                  **pipeline_kwargs)
            images.append(image.resize((tile, tile)))
        canvas = Image.new('RGB', (tile * steps, tile))
        for i, im in enumerate(images):
            canvas.paste(im, (tile * i, 0))
        return canvas
class T5SliderFlux(CLIPSlider):
    """FLUX variant whose slider directions live in ``text_encoder_2``'s sequence space.

    Directions are differences of the first output of ``text_encoder_2``.
    ``generate`` shifts those sequence embeddings and passes the CLIP pooled
    embedding through unchanged.
    """

    def find_latent_direction(self,
                              target_word: str,
                              opposite: str):
        """Estimate the ``text_encoder_2`` direction from ``opposite`` to ``target_word``.

        Returns:
            The mean positive-minus-negative first output of ``text_encoder_2``
            over ``self.iterations`` random prompt pairs, keeping a leading
            dim of 1. (Presumably per-token, i.e. ``(1, seq_len, hidden)`` —
            confirm.)
        """
        # e.g. target_word = "happy", opposite = "sad"
        with torch.no_grad():
            positives = []
            negatives = []
            for _ in tqdm(range(self.iterations)):
                medium = random.choice(MEDIUMS)
                subject = random.choice(SUBJECTS)
                pos_prompt = f"a {medium} of a {target_word} {subject}"
                neg_prompt = f"a {medium} of a {opposite} {subject}"
                pos_toks = self.pipe.tokenizer_2(pos_prompt,
                                                 return_tensors="pt",
                                                 padding="max_length",
                                                 truncation=True,
                                                 return_length=False,
                                                 return_overflowing_tokens=False,
                                                 max_length=self.pipe.tokenizer_2.model_max_length).input_ids.to(self.device)
                neg_toks = self.pipe.tokenizer_2(neg_prompt,
                                                 return_tensors="pt",
                                                 padding="max_length",
                                                 truncation=True,
                                                 return_length=False,
                                                 return_overflowing_tokens=False,
                                                 max_length=self.pipe.tokenizer_2.model_max_length).input_ids.to(self.device)
                positives.append(self.pipe.text_encoder_2(pos_toks, output_hidden_states=False)[0])
                negatives.append(self.pipe.text_encoder_2(neg_toks, output_hidden_states=False)[0])
            diffs = torch.cat(positives, dim=0) - torch.cat(negatives, dim=0)
            return diffs.mean(0, keepdim=True)

    @staticmethod
    def _at_token(direction, idx):
        """Select position ``idx`` of a per-token (3-D) direction; pass others through."""
        return direction[:, idx] if direction.dim() == 3 else direction

    def generate(self,
                 prompt="a photo of a house",
                 scale=2,
                 scale_2nd=2,
                 seed=15,
                 only_pooler=False,
                 normalize_scales=False,
                 correlation_weight_factor=1.0,
                 **pipeline_kwargs
                 ):
        """Generate images with the ``text_encoder_2`` sequence shifted along the slider(s).

        Args:
            prompt: Text prompt, encoded by both text encoders.
            scale, scale_2nd: Strengths along ``self.avg_diff`` /
                ``self.avg_diff_2nd``. Each is skipped when its direction is None.
            seed: Passed to ``torch.manual_seed`` right before sampling.
            only_pooler: Shift only the position of the largest token id, using
                the direction's value at that position. NOTE(review): for a T5
                tokenizer that position is not a pooled/EOS token the way it is
                for CLIP — confirm this mode is meaningful.
            normalize_scales: When a second direction exists, rescale so that
                ``|scale| + |scale_2nd| == 1`` (skipped when both are 0).
            correlation_weight_factor: Blends per-token weights between 1 and
                the raw cosine similarity.
            **pipeline_kwargs: Forwarded to ``self.pipe``.

        Returns:
            The pipeline output's ``images``.
        """
        # if doing full sequence, [-0.3,0.3] work well, higher if correlation weighted is true
        # if pooler token only [-4,4] work well
        has_2nd = self.avg_diff_2nd is not None
        if has_2nd and normalize_scales:
            denominator = abs(scale) + abs(scale_2nd)
            if denominator != 0:
                scale = scale / denominator
                scale_2nd = scale_2nd / denominator

        with torch.no_grad():
            text_inputs = self.pipe.tokenizer(
                prompt,
                padding="max_length",
                max_length=77,
                truncation=True,
                return_overflowing_tokens=False,
                return_length=False,
                return_tensors="pt",
            )
            clip_output = self.pipe.text_encoder(text_inputs.input_ids.to(self.device), output_hidden_states=False)
            # Use pooled output of CLIPTextModel
            pooled_prompt_embeds = clip_output.pooler_output.to(dtype=self.pipe.text_encoder.dtype, device=self.device)

            text_inputs_2 = self.pipe.tokenizer_2(
                prompt,
                padding="max_length",
                max_length=512,
                truncation=True,
                return_length=False,
                return_overflowing_tokens=False,
                return_tensors="pt",
            )
            toks = text_inputs_2.input_ids
            prompt_embeds = self.pipe.text_encoder_2(toks.to(self.device), output_hidden_states=False)[0]
            prompt_embeds = prompt_embeds.to(dtype=self.pipe.text_encoder_2.dtype, device=self.device)

            tok_idx = int(toks.argmax())
            if only_pooler:
                if self.avg_diff is not None:
                    prompt_embeds[:, tok_idx] += self._at_token(self.avg_diff, tok_idx) * scale
                if has_2nd:
                    prompt_embeds[:, tok_idx] += self._at_token(self.avg_diff_2nd, tok_idx) * scale_2nd
            else:
                normed = prompt_embeds / prompt_embeds.norm(dim=-1, keepdim=True)
                sims = normed[0] @ normed[0].T
                # (1, seq_len, 1): broadcasts over the hidden dim.
                weights = sims[tok_idx, :][None, :, None]
                weights = 1 + (weights - 1) * correlation_weight_factor
                if self.avg_diff is not None:
                    prompt_embeds = prompt_embeds + weights * self.avg_diff * scale
                if has_2nd:
                    prompt_embeds = prompt_embeds + weights * self.avg_diff_2nd * scale_2nd

        torch.manual_seed(seed)
        images = self.pipe(prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds,
                           **pipeline_kwargs).images
        return images