gitesh-grover committed · verified
Commit 8af8283 · 1 Parent(s): f65755b

Upload 9 files
README.md CHANGED
@@ -1,13 +1,36 @@
  ---
- title: Stable Diffusion Textual Inversion Image Generator
- emoji: 🔥
- colorFrom: yellow
- colorTo: purple
+ title: Textual Inversion Image Generator with optional center focus (background blur)
+ emoji: 📚
+ colorFrom: blue
+ colorTo: red
  sdk: gradio
- sdk_version: 5.20.1
+ sdk_version: 3.50.2
  app_file: app.py
  pinned: false
- short_description: Generates an image based on prompt and the concept library
  ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # Textual Inversion Image Generator with optional center focus (background blur)
+
+ ## Description
+ This is a simple Gradio app that generates images using textual inversion. The user enters a prompt and selects a concept from the dropdown menu; the image is then generated from the entered prompt combined with the selected concept. Currently there are 5 concepts to choose from; to read more about them, see https://huggingface.co/sd-concepts-library. The user can optionally choose whether the background should be blurred. Selecting that option generates an image with a blurred background and the main subject in focus.
+
+ ## How to use
+
+ 1. Enter your prompt in the text input field
+ 2. Select a concept from the dropdown menu
+ 3. Click the generate button
+ 4. The generated image is displayed on the screen
+
+ ## How to set up and run the app
+
+ 1. Clone the repository
+ 2. Install the dependencies:
+ ```bash
+ pip install -r requirements.txt
+ ```
+ 3. Run app.py:
+ ```bash
+ python app.py
+ ```
+ This will start the Gradio app at http://127.0.0.1:7860. Open that link in your browser to use the app.
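Once the app described in the README above is running, it can also be called programmatically. The snippet below is a minimal sketch, assuming `gradio_client` is installed and the app is running locally on the default port; the `/predict` endpoint name is the default one exposed by a single `gr.Interface`.

```python
# Minimal sketch: call the running Gradio app programmatically.
# Assumes `pip install gradio_client` and that `python app.py` is already running
# on http://127.0.0.1:7860 ("/predict" is gr.Interface's default endpoint name).
from gradio_client import Client

client = Client("http://127.0.0.1:7860")
result = client.predict(
    "a flying dog",            # prompt
    "minecraft concept art",   # concept selected from the dropdown choices
    False,                     # grayscale / background-blur option
    api_name="/predict",
)
print(result)  # path to the generated image file
```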
app.py ADDED
@@ -0,0 +1,32 @@
+ import gradio as gr
+ from textual_inversio_with_blueloss import TextualInversion
+
+ display_choices = ["minecraft concept art", "dragon born", "birb style", "pool rooms", "matrix"]
+ repo_id_embeds = ["sd-concepts-library/minecraft-concept-art::with <minecraft-concept-art> concept",
+                   "sd-concepts-library/dragonborn::with <dragonborn> concept",
+                   "sd-concepts-library/birb-style::in <birb-style> concept",
+                   "sd-concepts-library/poolrooms::with <poolrooms>",
+                   "sd-concepts-library/matrix::in <hatman-matrix> world"
+                   ]
+
+ textualInversion = TextualInversion(pretrained_model_name_or_path="CompVis/stable-diffusion-v1-4", repo_id_embeds=repo_id_embeds)
+
+ def generate_image(prompt, selected_concept, grayscale_image):
+     return textualInversion.generate_image(prompt, display_choices.index(selected_concept), grayscale_image=grayscale_image)
+
+ demo = gr.Interface(
+     fn=generate_image,
+     inputs=[
+         gr.Textbox(label="Enter your prompt"),
+         gr.Dropdown(choices=display_choices, label="Select concept", value=display_choices[0]),
+         gr.Checkbox(label="Grayscale Image", value=False)
+     ],
+     outputs=gr.Image(label="Generated Image"),
+     title="Textual Inversion Image Generator",
+     description="Generate images using textual inversion concepts",
+     examples=[["a flying dog", display_choices[0], False]],
+     allow_flagging=False
+ )
+
+ # Launch the app
+ demo.launch()
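Each entry in `repo_id_embeds` packs two pieces of information separated by `::`: the concept repository id and the prompt suffix appended to the user's prompt. The sketch below illustrates how `TextualInversion.__init__` (shown further below) splits these strings; the variable names here are illustrative only.

```python
# Sketch of the "<repo-id>::<prompt suffix>" convention used by app.py.
# TextualInversion.__init__ derives the local embedding filename from the repo id
# and keeps the suffix to append to the user's prompt.
entry = "sd-concepts-library/birb-style::in <birb-style> concept"

repo_id, suffix = entry.split("::")
concept_name = repo_id.split("/")[-1]                                    # "birb-style"
embedding_path = f"sd-concepts-library/{concept_name}_learned_embeds.bin"

print(concept_name, "->", embedding_path)
print("prompt suffix:", suffix)                                          # "in <birb-style> concept"
```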
requirements.txt ADDED
@@ -0,0 +1,18 @@
+ torch>=1.8.0
+ torchvision>=0.9.0
+ pytest>=6.0.0
+ numpy>=1.19.0
+ torchsummary>=1.5.1
+ tqdm
+ matplotlib>=3.0.0
+ diffusers==0.16.1
+ # diffusers==0.21.4
+ ftfy
+ # transformers==4.35.0
+ transformers
+ accelerate
+ safetensors
+ Pillow
+ huggingface-hub==0.25.2
+ gradio
+ opencv-python
sd-concepts-library/birb-style_learned_embeds.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f2e23a8f2d3628ed77acb8151751ecd4efc4017e8da86bc29af10f855ca308d9
+ size 3819
sd-concepts-library/dragonborn_learned_embeds.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:78dcbcc13fa0303719ae335097f72413ac3328d8e9da4d637de917add46957b8
+ size 3819
sd-concepts-library/matrix_learned_embeds.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6b84b50aad5f237f0639cf7d705a66d33b3da5e4e285161fb5084187648f3b0c
+ size 3840
sd-concepts-library/minecraft-concept-art_learned_embeds.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:af8028909cdbd079194c4100042b96fd39bf65493879c584fd5e7f7984b13383
+ size 3819
sd-concepts-library/poolrooms_learned_embeds.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:13ac14803186125485b23b1eac11e1bbba83f6c979e8264442d6397656fb4cb0
+ size 3819
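The five `*_learned_embeds.bin` files above are Git LFS pointers to learned textual-inversion embeddings from the SD Concepts Library. Each is a small `torch.load`-able dictionary mapping the concept's placeholder token to its embedding vector; a quick way to inspect one is sketched below (the exact token name and embedding size are assumptions based only on how the code in this commit reads the files).

```python
# Sketch: inspect a learned-embedding file the same way TextualInversion does.
# The code only relies on the file being a dict of {token: tensor}.
import torch

embeds = torch.load("sd-concepts-library/birb-style_learned_embeds.bin", map_location="cpu")
token, embedding = next(iter(embeds.items()))
print(token, tuple(embedding.shape))  # e.g. ('<birb-style>', (768,)) for SD v1 text embeddings
```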
textual_inversio_with_blueloss.py ADDED
@@ -0,0 +1,297 @@
+ #@title Import required libraries
+ import os
+ import torch
+ import re
+ from tqdm import tqdm
+ import PIL
+ from PIL import Image
+
+ from typing import List, Optional, Tuple, Union
+
+ from torchvision import transforms as tfms
+ from diffusers import StableDiffusionPipeline, AutoencoderKL, LMSDiscreteScheduler, UNet2DConditionModel
+ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
+ from focus_blur_utils import calculate_focus_blur_loss
+ # from transformers.modeling_attn_mask_utils import AttentionMaskConverter
+
+ class TextualInversion:
+     def __init__(self, pretrained_model_name_or_path="CompVis/stable-diffusion-v1-4", repo_id_embeds=["sd-concepts-library/matrix::with <hatman-matrix> concept"]):
+         #@markdown `pretrained_model_name_or_path` which Stable Diffusion checkpoint you want to use. This should match the one used for training the embeddings.
+         self.pretrained_model_name_or_path = pretrained_model_name_or_path
+         #@title Load your concept here
+         #@markdown Enter the `repo_id` for a concept you like (you can find pre-learned concepts in the public [SD Concepts Library](https://huggingface.co/sd-concepts-library))
+         self.repo_id_embeds = [x.split("::")[0].split("/")[-1] for x in repo_id_embeds]
+         self.prompts_suffixes = [x.split("::")[1] for x in repo_id_embeds]
+
+         # Set device
+         self.device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
+         if "mps" == self.device: os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = "1"
+
+         #@title Load the Stable Diffusion pipeline
+         # self.pipe = StableDiffusionPipeline.from_pretrained(
+         #     pretrained_model_name_or_path,
+         #     torch_dtype=torch.float16
+         # ).to(self.device)
+
+         # Load the autoencoder model which will be used to decode the latents into image space.
+         self.vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae")
+         # Load the tokenizer and text encoder to tokenize and encode the text.
+         self.tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
+         self.text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
+         # The UNet model for generating the latents.
+         self.unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet")
+         # The noise scheduler
+         self.scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)
+
+         # To the GPU we go!
+         self.vae = self.vae.to(self.device)
+         self.text_encoder = self.text_encoder.to(self.device)
+         self.unet = self.unet.to(self.device)
+
+         # Access the token embedding layers
+         # Token Embedding Layer
+         self.token_emb_layer = self.text_encoder.text_model.embeddings.token_embedding
+         # Position Embedding Layer
+         self.position_ids = self.text_encoder.text_model.embeddings.position_ids
+         self.position_emb_layer = self.text_encoder.text_model.embeddings.position_embedding
+
+         self.conceptsEmbeddings = []
+         for index, repo_id in enumerate(self.repo_id_embeds):
+             #@title Load the concept into pipeline
+             concept_embed_lib = torch.load("sd-concepts-library/" + self.repo_id_embeds[index] + "_learned_embeds.bin")  # load the concept learned embeddings
+             print(self.repo_id_embeds[index])
+             print(concept_embed_lib.keys())
+             if self.repo_id_embeds[index] in concept_embed_lib.keys():
+                 concept_embed = concept_embed_lib[self.repo_id_embeds[index]]  # Read the embedding value using the key i.e. concept_embed_lib['<birb-style>']
+             else:
+                 first_key, concept_embed = next(iter(concept_embed_lib.items()))  # Read the first key and the embedding value
+
+             self.conceptsEmbeddings.append(concept_embed.to(self.device))
+         print(f"len(self.conceptsEmbeddings): {len(self.conceptsEmbeddings)}")
+
+     def _create_4d_causal_attention_mask(
+         self,
+         input_shape: Union[torch.Size, Tuple, List],
+         dtype: torch.dtype,
+         device: torch.device,
+         past_key_values_length: int = 0,
+         sliding_window: Optional[int] = None,
+     ) -> Optional[torch.Tensor]:
+         """
+         Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)`
+
+         Args:
+             input_shape (`tuple(int)` or `list(int)` or `torch.Size`):
+                 The input shape should be a tuple that defines `(batch_size, query_length)`.
+             dtype (`torch.dtype`):
+                 The torch dtype the created mask shall have.
+             device (`torch.device`):
+                 The torch device the created mask shall have.
+             sliding_window (`int`, *optional*):
+                 If the model uses windowed attention, a sliding window should be passed.
+         """
+         # NOTE: this helper needs the AttentionMaskConverter import above (currently commented out).
+         attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window)
+
+         key_value_length = past_key_values_length + input_shape[-1]
+         attention_mask = attn_mask_converter.to_causal_4d(
+             input_shape[0], input_shape[-1], key_value_length, dtype=dtype, device=device
+         )
+
+         return attention_mask
+
+     def get_output_embeds(self, input_embeddings):
+         # CLIP's text model uses a causal mask, so we prepare it here:
+         bsz, seq_len = input_embeddings.shape[:2]
+         # causal_attention_mask = text_encoder.text_model._build_causal_attention_mask(bsz, seq_len, dtype=input_embeddings.dtype)
+         # causal_attention_mask = self._create_4d_causal_attention_mask(input_shape=(bsz, seq_len), dtype=input_embeddings.dtype, device=self.device)
+         causal_attention_mask = self.text_encoder.text_model._build_causal_attention_mask(bsz, seq_len, dtype=input_embeddings.dtype)
+
+         # Getting the output embeddings involves calling the model with output_hidden_states=True
+         # so that it doesn't just return the pooled final predictions:
+         encoder_outputs = self.text_encoder.text_model.encoder(
+             inputs_embeds=input_embeddings,
+             attention_mask=None,  # We aren't using an attention mask, so that can be None
+             causal_attention_mask=causal_attention_mask.to(self.device),
+             output_attentions=None,
+             output_hidden_states=True,  # We want the output embs, not the final output
+             return_dict=None,
+         )
+
+         # We're interested in the output hidden state only
+         output = encoder_outputs[0]
+
+         # There is a final layer norm we need to pass these through
+         output = self.text_encoder.text_model.final_layer_norm(output)
+
+         # And now they're ready!
+         return output
+
+     def set_timesteps(self, num_inference_steps):
+         self.scheduler.set_timesteps(num_inference_steps)
+         self.scheduler.timesteps = self.scheduler.timesteps.to(torch.float32)  # minor fix to ensure MPS compatibility, fixed in diffusers PR 3925
+
+     def pil_to_latent(self, input_im):
+         # Single image -> single latent in a batch (so size 1, 4, 64, 64)
+         with torch.no_grad():
+             latent = self.vae.encode(tfms.ToTensor()(input_im).unsqueeze(0).to(self.device) * 2 - 1)  # Note scaling
+         return 0.18215 * latent.latent_dist.sample()
+
+     def latents_to_pil(self, latents):
+         # batch of latents -> list of images
+         latents = (1 / 0.18215) * latents
+         with torch.no_grad():
+             image = self.vae.decode(latents).sample
+         image = (image / 2 + 0.5).clamp(0, 1)
+         image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
+         images = (image * 255).round().astype("uint8")
+         pil_images = [Image.fromarray(image) for image in images]
+         return pil_images
+
+     def grayscale_loss(self, images):
+         """
+         Calculate the grayscale loss, which measures how far the image is from being grayscale.
+         A grayscale image has R = G = B for each pixel.
+
+         Args:
+             images (torch.Tensor): A tensor of shape (batch_size, 3, H, W) where 3 corresponds to
+                 the RGB channels of the image.
+
+         Returns:
+             torch.Tensor: A scalar loss value indicating how far the image is from being grayscale.
+         """
+         # Calculate the absolute difference between the channels
+         # images[:, 0] -> Red channel, images[:, 1] -> Green channel, images[:, 2] -> Blue channel
+         rg_diff = torch.abs(images[:, 0] - images[:, 1])  # R - G
+         gb_diff = torch.abs(images[:, 1] - images[:, 2])  # G - B
+         rb_diff = torch.abs(images[:, 0] - images[:, 2])  # R - B
+
+         # Compute the mean of these differences across the batch and image dimensions
+         loss = torch.mean(rg_diff + gb_diff + rb_diff)
+
+         return loss
+
+     def blue_loss(self, images):
+         # How far the blue channel values are from 0.9:
+         # error = torch.abs(images[:, 2] - 0.9).mean()  # [:, 2] -> all images in batch, only the blue channel
+         # Call grayscale loss instead of blue loss
+         error = self.grayscale_loss(images)
+         return error
+
+     def update_latents_with_blue_loss(self, latents, noise_pred, sigma, blue_loss_scale=50, print_loss=False):
+         # Requires grad on the latents
+         latents = latents.detach().requires_grad_()
+
+         # Get the predicted x0:
+         latents_x0 = latents - sigma * noise_pred
+         # latents_x0 = scheduler.step(noise_pred, t, latents).pred_original_sample
+
+         # Decode to image space
+         denoised_images = self.vae.decode((1 / 0.18215) * latents_x0).sample / 2 + 0.5  # range (0, 1)
+
+         # Calculate loss
+         loss = self.blue_loss(denoised_images) * blue_loss_scale
+
+         # Occasionally print it out
+         if print_loss:
+             print('loss:', loss.item())
+
+         # Get gradient
+         cond_grad = torch.autograd.grad(loss, latents)[0]
+
+         # Modify the latents based on this gradient
+         latents = latents.detach() - cond_grad * sigma**2
+
+         return latents
+
+     def generate_with_embs(self, text_embeddings, generator, max_length, batch_size=1, consider_blue_loss=False):
+         height = 512  # default height of Stable Diffusion
+         width = 512  # default width of Stable Diffusion
+         num_inference_steps = 50  # Number of denoising steps
+         guidance_scale = 7.5  # Scale for classifier-free guidance
+
+         uncond_input = self.tokenizer(
+             [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
+         )
+         with torch.no_grad():
+             uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
+         text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+
+         # Prep Scheduler
+         self.set_timesteps(num_inference_steps)
+
+         # Prep latents
+         latents = torch.randn(
+             (batch_size, self.unet.in_channels, height // 8, width // 8),
+             generator=generator,
+             # device=self.device
+         )
+         latents = latents.to(self.device)
+         latents = latents * self.scheduler.init_noise_sigma
+
+         # Loop
+         for i, t in tqdm(enumerate(self.scheduler.timesteps), total=len(self.scheduler.timesteps)):
+             # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
+             latent_model_input = torch.cat([latents] * 2)
+             latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+             # predict the noise residual
+             with torch.no_grad():
+                 noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"]
+
+             # perform guidance
+             noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+             noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+             if consider_blue_loss:
+                 print_loss = True if i % 10 == 0 else False
+                 latents = self.update_latents_with_blue_loss(latents, noise_pred, self.scheduler.sigmas[i], print_loss=print_loss)
+
+             # compute the previous noisy sample x_t -> x_t-1
+             latents = self.scheduler.step(noise_pred, t, latents).prev_sample
+
+         return self.latents_to_pil(latents)
+
+
+     def generate_image(self, prompt, concept_index, grayscale_image=False):
+         # # Get the index of the selected concept
+         # concept_index = self.repo_id_embeds.index(selected_concept)
+         prompt_to_send = prompt + " " + self.prompts_suffixes[concept_index]
+         print(f"Selected concept_index: {concept_index}.")
+         print(f"concept_index: {concept_index} Generating image for concept: {self.repo_id_embeds[concept_index]} with prompt: {prompt_to_send}")
+         print(f"Grayscale image: {grayscale_image}")
+
+         # replace <..> with a placeholder token that can easily be replaced with the embedding after tokenization
+         placeholder_text = "gloucestershire "  # 33789 is the token id
+         prompt_to_send = re.sub(r'<.*?>', placeholder_text, prompt_to_send)
+         print(f"prompt after replacing placeholder token: {prompt_to_send}")
+         # Tokenize
+         text_input = self.tokenizer(prompt_to_send, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt")
+         input_ids = text_input.input_ids.to(self.device)
+
+         # Get token embeddings
+         token_embeddings = self.token_emb_layer(input_ids)
+
+         # The new embedding - our concept embedding for the special word token
+         # replacement_token_embedding = birb_embed['<birb-style>'].to(torch_device)
+         replacement_token_embedding = self.conceptsEmbeddings[concept_index].to(self.device)
+         print(f"replacement_token_embedding.shape: {replacement_token_embedding.shape} and token_embeddings.shape: {token_embeddings.shape}")
+         print(f"torch.where(input_ids[0]==33789): {torch.where(input_ids[0]==33789)}")
+         # Replace the placeholder token with the concept embedding
+         token_embeddings[0, torch.where(input_ids[0]==33789)] = replacement_token_embedding.to(self.device)
+         # print(f"If embedding is replaced: {token_embeddings[0, torch.where(input_ids[0]==33789)] == replacement_token_embedding}")
+
+         B, T, C = token_embeddings.shape
+         # Get the position embeddings
+         position_embeddings = self.position_emb_layer(self.position_ids[:, :T])
+
+         # Combine with pos embs
+         input_embeddings = token_embeddings + position_embeddings
+
+         # Feed through to get final output embs
+         modified_output_embeddings = self.get_output_embeds(input_embeddings)
+
+         print(f"manual_seed: {concept_index + 11}")
+         generator = torch.manual_seed(concept_index + 11)
+         # And generate an image with this:
+         result = self.generate_with_embs(modified_output_embeddings, generator=generator, max_length=T, consider_blue_loss=grayscale_image)[0]
+
+         return result
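For reference, a minimal standalone usage of the class, mirroring what app.py does. This is a sketch, assuming the bundled embedding files, the `focus_blur_utils` module this file imports, and network access for the model downloads are all available.

```python
# Minimal usage sketch mirroring app.py: load one concept and generate an image.
from textual_inversio_with_blueloss import TextualInversion

ti = TextualInversion(
    pretrained_model_name_or_path="CompVis/stable-diffusion-v1-4",
    repo_id_embeds=["sd-concepts-library/birb-style::in <birb-style> concept"],
)
image = ti.generate_image("a flying dog", concept_index=0, grayscale_image=False)
image.save("output.png")  # generate_image returns a PIL.Image
```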