import os
import re

import torch
from tqdm import tqdm
from PIL import Image

from typing import List, Optional, Tuple, Union

from torchvision import transforms as tfms
from diffusers import AutoencoderKL, LMSDiscreteScheduler, UNet2DConditionModel
from transformers import CLIPTextModel, CLIPTokenizer
from transformers.modeling_attn_mask_utils import AttentionMaskConverter


class TextualInversion:

    def __init__(self, pretrained_model_name_or_path="CompVis/stable-diffusion-v1-4",
                 repo_id_embeds=["sd-concepts-library/matrix::with <hatman-matrix> concept"]):

        self.pretrained_model_name_or_path = pretrained_model_name_or_path

        # Each entry is "<repo>::<prompt suffix>"; keep the concept name (after the last "/")
        # and the suffix that gets appended to user prompts.
        self.repo_id_embeds = [x.split("::")[0].split("/")[-1] for x in repo_id_embeds]
        self.prompts_suffixes = [x.split("::")[1] for x in repo_id_embeds]

        # Pick the best available device; MPS needs the CPU fallback for unsupported ops.
        self.device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
        if self.device == "mps":
            os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

        # Load the individual Stable Diffusion components so that custom text embeddings
        # can be injected before the UNet sees them.
        self.vae = AutoencoderKL.from_pretrained(self.pretrained_model_name_or_path, subfolder="vae")

        self.tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
        self.text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

        self.unet = UNet2DConditionModel.from_pretrained(self.pretrained_model_name_or_path, subfolder="unet")

        self.scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)

        self.vae = self.vae.to(self.device)
        self.text_encoder = self.text_encoder.to(self.device)
        self.unet = self.unet.to(self.device)

        # CLIP text-model sub-modules used to build input embeddings manually.
        self.token_emb_layer = self.text_encoder.text_model.embeddings.token_embedding
        self.position_ids = self.text_encoder.text_model.embeddings.position_ids
        self.position_emb_layer = self.text_encoder.text_model.embeddings.position_embedding

        # Load each learned concept embedding (one placeholder-token vector per concept);
        # map_location avoids device mismatches when the file was saved on a GPU.
        self.conceptsEmbeddings = []
        for index, repo_id in enumerate(self.repo_id_embeds):
            concept_embed_lib = torch.load("sd-concepts-library/" + repo_id + "_learned_embeds.bin", map_location=self.device)
            print(repo_id)
            print(concept_embed_lib.keys())
            if repo_id in concept_embed_lib.keys():
                concept_embed = concept_embed_lib[repo_id]
            else:
                # Fall back to the first (usually only) entry in the file.
                first_key, concept_embed = next(iter(concept_embed_lib.items()))

            self.conceptsEmbeddings.append(concept_embed.to(self.device))
        print(f"len(self.conceptsEmbeddings): {len(self.conceptsEmbeddings)}")

    @staticmethod
    def _create_4d_causal_attention_mask(
        input_shape: Union[torch.Size, Tuple, List],
        dtype: torch.dtype,
        device: torch.device,
        past_key_values_length: int = 0,
        sliding_window: Optional[int] = None,
    ) -> Optional[torch.Tensor]:
        """
        Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)`.

        Args:
            input_shape (`tuple(int)` or `list(int)` or `torch.Size`):
                The input shape, defining `(batch_size, query_length)`.
            dtype (`torch.dtype`):
                The torch dtype the created mask shall have.
            device (`torch.device`):
                The torch device the created mask shall be placed on.
            past_key_values_length (`int`, *optional*, defaults to 0):
                The length of any cached key/values prepended to the sequence.
            sliding_window (`int`, *optional*):
                If the model uses windowed attention, a sliding window should be passed.
        """
        attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window)

        key_value_length = past_key_values_length + input_shape[-1]
        attention_mask = attn_mask_converter.to_causal_4d(
            input_shape[0], input_shape[-1], key_value_length, dtype=dtype, device=device
        )

        return attention_mask

    def get_output_embeds(self, input_embeddings):
        # Run pre-computed token+position embeddings through the CLIP text encoder
        # (skipping its embedding layer), mirroring CLIPTextModel.forward.
        bsz, seq_len = input_embeddings.shape[:2]

        # Recent transformers releases removed the private
        # text_model._build_causal_attention_mask helper, so build the causal mask
        # with this class's copy of the upstream utility.
        causal_attention_mask = self._create_4d_causal_attention_mask(
            (bsz, seq_len), dtype=input_embeddings.dtype, device=self.device
        )

        encoder_outputs = self.text_encoder.text_model.encoder(
            inputs_embeds=input_embeddings,
            attention_mask=None,
            causal_attention_mask=causal_attention_mask.to(self.device),
            output_attentions=None,
            output_hidden_states=True,
            return_dict=None,
        )

        # Last hidden state, followed by the final layer norm that CLIPTextModel normally applies.
        output = encoder_outputs[0]
        output = self.text_encoder.text_model.final_layer_norm(output)

        return output
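
    # With the default CLIP ViT-L/14 text encoder and the padding used in generate_image,
    # the returned tensor has shape (batch_size, 77, 768): one 768-dim embedding per
    # (padded) token position, which is what the UNet expects as encoder_hidden_states.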

    def set_timesteps(self, num_inference_steps):
        self.scheduler.set_timesteps(num_inference_steps)
        # Cast to float32: the LMS scheduler yields float64 timesteps, which MPS does not support.
        self.scheduler.timesteps = self.scheduler.timesteps.to(torch.float32)

    def pil_to_latent(self, input_im):
        # Encode a PIL image into scaled VAE latents ([-1, 1] input range, 0.18215 latent scaling).
        with torch.no_grad():
            latent = self.vae.encode(tfms.ToTensor()(input_im).unsqueeze(0).to(self.device) * 2 - 1)
        return 0.18215 * latent.latent_dist.sample()

    def latents_to_pil(self, latents):
        # Decode scaled latents back into a list of PIL images.
        latents = (1 / 0.18215) * latents
        with torch.no_grad():
            image = self.vae.decode(latents).sample
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
        images = (image * 255).round().astype("uint8")
        pil_images = [Image.fromarray(image) for image in images]
        return pil_images

    def grayscale_loss(self, images):
        """
        Calculate the grayscale loss, which measures how far the image is from being grayscale.
        A grayscale image has R = G = B for each pixel.

        Args:
            images (torch.Tensor): A tensor of shape (batch_size, 3, H, W) where 3 corresponds to
                the RGB channels of the image.

        Returns:
            torch.Tensor: A scalar loss value indicating how far the image is from being grayscale.
        """
        # Mean absolute difference between each pair of colour channels.
        rg_diff = torch.abs(images[:, 0] - images[:, 1])
        gb_diff = torch.abs(images[:, 1] - images[:, 2])
        rb_diff = torch.abs(images[:, 0] - images[:, 2])

        loss = torch.mean(rg_diff + gb_diff + rb_diff)

        return loss
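
    # Illustrative sanity check (not part of the pipeline): a grayscale batch gives zero
    # loss, while a fully saturated red image gives the maximum value of 2.0, e.g.
    #   gray = torch.rand(1, 1, 64, 64).repeat(1, 3, 1, 1)   # grayscale_loss(gray) == 0.0
    #   red = torch.zeros(1, 3, 64, 64); red[:, 0] = 1.0     # grayscale_loss(red) == 2.0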

    def blue_loss(self, images):
        # Despite the name, this delegates to the grayscale loss; the name is kept so that
        # update_latents_with_blue_loss works unchanged.
        error = self.grayscale_loss(images)
        return error

    def update_latents_with_blue_loss(self, latents, noise_pred, sigma, blue_loss_scale=50, print_loss=False):
        # Guidance step: nudge the latents so that the decoded image moves towards grayscale.
        latents = latents.detach().requires_grad_()

        # Predict x0 (the fully denoised latents) from the current noisy latents.
        latents_x0 = latents - sigma * noise_pred

        # Decode to image space (keeping gradients) and rescale to [0, 1].
        denoised_images = self.vae.decode((1 / 0.18215) * latents_x0).sample / 2 + 0.5

        loss = self.blue_loss(denoised_images) * blue_loss_scale

        if print_loss:
            print('loss:', loss.item())

        # Gradient of the loss w.r.t. the latents, used to steer the sampling trajectory.
        cond_grad = torch.autograd.grad(loss, latents)[0]

        latents = latents.detach() - cond_grad * sigma**2

        return latents

    def generate_with_embs(self, text_embeddings, generator, max_length, batch_size=1, consider_blue_loss=False):
        height = 512
        width = 512
        num_inference_steps = 50
        guidance_scale = 7.5

        # Unconditional (empty-prompt) embeddings for classifier-free guidance.
        uncond_input = self.tokenizer(
            [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
        )
        with torch.no_grad():
            uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
        text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

        self.set_timesteps(num_inference_steps)

        # Start from pure noise in latent space (1/8th of the image resolution).
        latents = torch.randn(
            (batch_size, self.unet.config.in_channels, height // 8, width // 8),
            generator=generator,
        )
        latents = latents.to(self.device)
        latents = latents * self.scheduler.init_noise_sigma

        for i, t in tqdm(enumerate(self.scheduler.timesteps), total=len(self.scheduler.timesteps)):
            # Duplicate the latents so the unconditional and conditional branches share one UNet pass.
            latent_model_input = torch.cat([latents] * 2)
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            with torch.no_grad():
                noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"]

            # Classifier-free guidance: uncond + scale * (cond - uncond).
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            if consider_blue_loss:
                print_loss = i % 10 == 0
                latents = self.update_latents_with_blue_loss(latents, noise_pred, self.scheduler.sigmas[i], print_loss=print_loss)

            latents = self.scheduler.step(noise_pred, t, latents).prev_sample

        return self.latents_to_pil(latents)

    def generate_image(self, prompt, concept_index, grayscale_image=False):
        # Append the concept's prompt suffix (e.g. "with <concept> concept") to the user prompt.
        prompt_to_send = prompt + " " + self.prompts_suffixes[concept_index]
        print(f"Selected concept_index: {concept_index}.")
        print(f"concept_index: {concept_index} Generating image for concept: {self.repo_id_embeds[concept_index]} with prompt: {prompt_to_send}")
        print(f"Grayscale image: {grayscale_image}")

        # Swap any <placeholder> token in the prompt for a rare single-token word, so its
        # position can be found and its embedding overwritten with the learned concept vector.
        placeholder_text = "gloucestershire "
        prompt_to_send = re.sub(r'<.*?>', placeholder_text, prompt_to_send)
        print(f"prompt after replacing placeholder token: {prompt_to_send}")

        text_input = self.tokenizer(prompt_to_send, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt")
        input_ids = text_input.input_ids.to(self.device)

        # Token embeddings straight from the CLIP embedding table.
        token_embeddings = self.token_emb_layer(input_ids)

        # Replace the embedding at the placeholder word's position (its CLIP token id is
        # hard-coded as 33789) with the learned concept embedding.
        replacement_token_embedding = self.conceptsEmbeddings[concept_index].to(self.device)
        print(f"replacement_token_embedding.shape: {replacement_token_embedding.shape} and token_embeddings.shape: {token_embeddings.shape}")
        print(f"torch.where(input_ids[0]==33789): {torch.where(input_ids[0]==33789)}")

        token_embeddings[0, torch.where(input_ids[0]==33789)] = replacement_token_embedding.to(self.device)

        # Add positional embeddings to obtain the full input embeddings for the text encoder.
        B, T, C = token_embeddings.shape
        position_embeddings = self.position_emb_layer(self.position_ids[:, :T])
        input_embeddings = token_embeddings + position_embeddings

        # Push the modified embeddings through the rest of the text encoder.
        modified_output_embeddings = self.get_output_embeds(input_embeddings)

        print(f"manual_seed: {concept_index + 11}")
        generator = torch.manual_seed(concept_index + 11)

        result = self.generate_with_embs(modified_output_embeddings, generator=generator, max_length=T, consider_blue_loss=grayscale_image)[0]

        return result
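

# Minimal usage sketch (assumptions: the learned embedding file
# sd-concepts-library/matrix_learned_embeds.bin is present locally, and the output
# filename is arbitrary). Shown here as an illustration, not part of the class.
if __name__ == "__main__":
    ti = TextualInversion(
        pretrained_model_name_or_path="CompVis/stable-diffusion-v1-4",
        repo_id_embeds=["sd-concepts-library/matrix::with <hatman-matrix> concept"],
    )
    # Generate one image for concept 0 and save it.
    image = ti.generate_image("a quiet city street at night", concept_index=0, grayscale_image=False)
    image.save("matrix_street.png")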