"""Textual inversion inference with Stable Diffusion: load learned concept
embeddings from the sd-concepts-library and inject them into the CLIP text
encoder before running the diffusion sampling loop."""

import os
import re
from typing import List, Optional, Tuple, Union

import torch
import PIL
from PIL import Image
from tqdm import tqdm
from torchvision import transforms as tfms

from diffusers import StableDiffusionPipeline, AutoencoderKL, LMSDiscreteScheduler, UNet2DConditionModel
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
# Needed by _create_4d_causal_attention_mask below (available in recent transformers releases).
from transformers.modeling_attn_mask_utils import AttentionMaskConverter


class TextualInversion:

    def __init__(self, pretrained_model_name_or_path="CompVis/stable-diffusion-v1-4",
                 repo_id_embeds=["sd-concepts-library/matrix::with <hatman-matrix> concept"]):

        self.pretrained_model_name_or_path = pretrained_model_name_or_path

        # Each entry has the form "<repo_id>::<prompt suffix>"; keep the concept name
        # and the prompt suffix separately.
        self.repo_id_embeds = [x.split("::")[0].split("/")[-1] for x in repo_id_embeds]
        self.prompts_suffixes = [x.split("::")[1] for x in repo_id_embeds]

        self.device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
        if self.device == "mps":
            os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

        # Load the individual Stable Diffusion components rather than the full pipeline.
        self.vae = AutoencoderKL.from_pretrained(self.pretrained_model_name_or_path, subfolder="vae")

        self.tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
        self.text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

        self.unet = UNet2DConditionModel.from_pretrained(self.pretrained_model_name_or_path, subfolder="unet")

        self.scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)

        self.vae = self.vae.to(self.device)
        self.text_encoder = self.text_encoder.to(self.device)
        self.unet = self.unet.to(self.device)

        # Handles into the CLIP text encoder's embedding layers, used to build
        # prompt embeddings manually in generate_image().
        self.token_emb_layer = self.text_encoder.text_model.embeddings.token_embedding
        self.position_ids = self.text_encoder.text_model.embeddings.position_ids
        self.position_emb_layer = self.text_encoder.text_model.embeddings.position_embedding

        self.conceptsEmbeddings = []
        for index, repo_id in enumerate(self.repo_id_embeds):
            # Each concept ships a "<name>_learned_embeds.bin" file mapping its placeholder
            # token to the learned embedding vector.
            concept_embed_lib = torch.load("sd-concepts-library/" + self.repo_id_embeds[index] + "_learned_embeds.bin")
            print(self.repo_id_embeds[index])
            print(concept_embed_lib.keys())
            if self.repo_id_embeds[index] in concept_embed_lib.keys():
                concept_embed = concept_embed_lib[self.repo_id_embeds[index]]
            else:
                # Fall back to the first (usually only) entry in the file.
                first_key, concept_embed = next(iter(concept_embed_lib.items()))
            self.conceptsEmbeddings.append(concept_embed.to(self.device))

        print(f"len(self.conceptsEmbeddings): {len(self.conceptsEmbeddings)}")

    @staticmethod
    def _create_4d_causal_attention_mask(
        input_shape: Union[torch.Size, Tuple, List],
        dtype: torch.dtype,
        device: torch.device,
        past_key_values_length: int = 0,
        sliding_window: Optional[int] = None,
    ) -> Optional[torch.Tensor]:
        """
        Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)`.

        Args:
            input_shape (`tuple(int)` or `list(int)` or `torch.Size`):
                The input shape should be a tuple that defines `(batch_size, query_length)`.
            dtype (`torch.dtype`):
                The torch dtype the created mask shall have.
            device (`torch.device`):
                The torch device the created mask shall be placed on.
            sliding_window (`int`, *optional*):
                If the model uses windowed attention, a sliding window should be passed.
        """
        attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window)

        key_value_length = past_key_values_length + input_shape[-1]
        attention_mask = attn_mask_converter.to_causal_4d(
            input_shape[0], input_shape[-1], key_value_length, dtype=dtype, device=device
        )

        return attention_mask

    def get_output_embeds(self, input_embeddings):
        # Push custom token embeddings through the CLIP text encoder (encoder stack plus
        # final layer norm) to obtain the hidden states the UNet is conditioned on.
        bsz, seq_len = input_embeddings.shape[:2]

        # Newer transformers releases no longer provide CLIPTextTransformer._build_causal_attention_mask,
        # so build the causal mask locally instead.
        causal_attention_mask = self._create_4d_causal_attention_mask(
            (bsz, seq_len), dtype=input_embeddings.dtype, device=self.device
        )

        encoder_outputs = self.text_encoder.text_model.encoder(
            inputs_embeds=input_embeddings,
            attention_mask=None,
            causal_attention_mask=causal_attention_mask,
            output_attentions=None,
            output_hidden_states=True,
            return_dict=None,
        )

        # Last hidden state, followed by the text model's final layer norm.
        output = encoder_outputs[0]
        output = self.text_encoder.text_model.final_layer_norm(output)

        return output

    def set_timesteps(self, num_inference_steps):
        self.scheduler.set_timesteps(num_inference_steps)
        # LMSDiscreteScheduler produces float64 timesteps; cast to float32 (needed e.g. on MPS).
        self.scheduler.timesteps = self.scheduler.timesteps.to(torch.float32)

    def pil_to_latent(self, input_im):
        with torch.no_grad():
            latent = self.vae.encode(tfms.ToTensor()(input_im).unsqueeze(0).to(self.device) * 2 - 1)
        return 0.18215 * latent.latent_dist.sample()

    def latents_to_pil(self, latents):
        latents = (1 / 0.18215) * latents
        with torch.no_grad():
            image = self.vae.decode(latents).sample
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
        images = (image * 255).round().astype("uint8")
        pil_images = [Image.fromarray(image) for image in images]
        return pil_images

    def grayscale_loss(self, images):
        """
        Calculate the grayscale loss, which measures how far the image is from being grayscale.
        A grayscale image has R = G = B for each pixel.

        Args:
            images (torch.Tensor): A tensor of shape (batch_size, 3, H, W) where 3 corresponds to
                the RGB channels of the image.

        Returns:
            torch.Tensor: A scalar loss value indicating how far the image is from being grayscale.
        """
        rg_diff = torch.abs(images[:, 0] - images[:, 1])
        gb_diff = torch.abs(images[:, 1] - images[:, 2])
        rb_diff = torch.abs(images[:, 0] - images[:, 2])

        loss = torch.mean(rg_diff + gb_diff + rb_diff)

        return loss

    def blue_loss(self, images):
        # Despite the name, this delegates to grayscale_loss and measures how far
        # the decoded images are from being grayscale.
        error = self.grayscale_loss(images)
        return error

    def update_latents_with_blue_loss(self, latents, noise_pred, sigma, blue_loss_scale=50, print_loss=False):
        # Guidance step: nudge the latents so the decoded image moves toward grayscale.
        latents = latents.detach().requires_grad_()

        # Predict the denoised sample x0 from the current latents and noise prediction.
        latents_x0 = latents - sigma * noise_pred

        # Decode to image space (keeping gradients) and rescale to [0, 1].
        denoised_images = self.vae.decode((1 / 0.18215) * latents_x0).sample / 2 + 0.5

        loss = self.blue_loss(denoised_images) * blue_loss_scale

        if print_loss:
            print('loss:', loss.item())

        # Gradient of the loss with respect to the latents, used to adjust them.
        cond_grad = torch.autograd.grad(loss, latents)[0]

        latents = latents.detach() - cond_grad * sigma**2

        return latents

    def generate_with_embs(self, text_embeddings, generator, max_length, batch_size=1, consider_blue_loss=False):
        height = 512
        width = 512
        num_inference_steps = 50
        guidance_scale = 7.5

        # Unconditional (empty-prompt) embeddings for classifier-free guidance.
        uncond_input = self.tokenizer(
            [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
        )
        with torch.no_grad():
            uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
        text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

        self.set_timesteps(num_inference_steps)

        # Start from random latents scaled by the scheduler's initial noise sigma.
        latents = torch.randn(
            (batch_size, self.unet.config.in_channels, height // 8, width // 8),
            generator=generator,
        )
        latents = latents.to(self.device)
        latents = latents * self.scheduler.init_noise_sigma

        for i, t in tqdm(enumerate(self.scheduler.timesteps), total=len(self.scheduler.timesteps)):
            # Duplicate the latents so the unconditional and text-conditioned passes share one UNet call.
            latent_model_input = torch.cat([latents] * 2)
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            with torch.no_grad():
                noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"]

            # Classifier-free guidance.
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            if consider_blue_loss:
                print_loss = i % 10 == 0
                latents = self.update_latents_with_blue_loss(latents, noise_pred, self.scheduler.sigmas[i], print_loss=print_loss)

            latents = self.scheduler.step(noise_pred, t, latents).prev_sample

        return self.latents_to_pil(latents)

    def generate_image(self, prompt, concept_index, grayscale_image=False):
        # Append the concept's prompt suffix (which contains the <concept> placeholder).
        prompt_to_send = prompt + " " + self.prompts_suffixes[concept_index]
        print(f"Selected concept_index: {concept_index}.")
        print(f"concept_index: {concept_index} Generating image for concept: {self.repo_id_embeds[concept_index]} with prompt: {prompt_to_send}")
        print(f"Grayscale image: {grayscale_image}")

        # Swap the <concept> placeholder for a rare single-token word so it can be located
        # in the tokenized prompt and its embedding overwritten below.
        placeholder_text = "gloucestershire "
        prompt_to_send = re.sub(r'<.*?>', placeholder_text, prompt_to_send)
        print(f"prompt after replacing placeholder token: {prompt_to_send}")

        text_input = self.tokenizer(prompt_to_send, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt")
        input_ids = text_input.input_ids.to(self.device)

        token_embeddings = self.token_emb_layer(input_ids)

        # Overwrite the embedding of the placeholder token (id 33789 in the CLIP tokenizer)
        # with the learned concept embedding.
        replacement_token_embedding = self.conceptsEmbeddings[concept_index].to(self.device)
        print(f"replacement_token_embedding.shape: {replacement_token_embedding.shape} and token_embeddings.shape: {token_embeddings.shape}")
        print(f"torch.where(input_ids[0]==33789): {torch.where(input_ids[0]==33789)}")

        token_embeddings[0, torch.where(input_ids[0]==33789)] = replacement_token_embedding.to(self.device)

        # Add position embeddings and run the result through the text encoder.
        B, T, C = token_embeddings.shape
        position_embeddings = self.position_emb_layer(self.position_ids[:, :T])

        input_embeddings = token_embeddings + position_embeddings

        modified_output_embeddings = self.get_output_embeds(input_embeddings)

        print(f"manual_seed: {concept_index + 11}")
        generator = torch.manual_seed(concept_index + 11)

        result = self.generate_with_embs(modified_output_embeddings, generator=generator, max_length=T, consider_blue_loss=grayscale_image)[0]

        return result
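

# Example usage (a minimal sketch, not part of the class above): it assumes the matrix
# concept's learned embeddings have been downloaded locally to
# "sd-concepts-library/matrix_learned_embeds.bin" and that the machine can run
# Stable Diffusion inference (CUDA, MPS, or a patient CPU).
if __name__ == "__main__":
    ti = TextualInversion(
        repo_id_embeds=["sd-concepts-library/matrix::with <hatman-matrix> concept"]
    )
    # Generate an image for the first loaded concept; pass grayscale_image=True to
    # steer sampling toward grayscale via the extra guidance loss.
    image = ti.generate_image("a photo of a city street", concept_index=0)
    image.save("textual_inversion_sample.png")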