anshuln committed on
Commit
6e13e89
·
verified ·
1 Parent(s): 2d30cbe

Delete src/image_generation.py

Browse files
Files changed (1) hide show
  1. src/image_generation.py +0 -366
src/image_generation.py DELETED
@@ -1,366 +0,0 @@
1
-
2
- # Copyright (c) Facebook, Inc. and its affiliates.
3
- # Modified from: https://github.com/facebookresearch/detectron2/blob/master/demo/demo.py
4
- from transformers import pipeline
5
- import torchvision
6
- from PIL import Image
7
- from models.t2i_pipeline import StableDiffusionPipelineSpatialAware
8
- import torchvision.io as vision_io
9
- import torch.nn.functional as F
10
- import torch
11
- import tqdm
12
- import numpy as np
13
- import cv2
14
- import warnings
15
- import time
16
- import tempfile
17
- import argparse
18
- import glob
19
- import multiprocessing as mp
20
- import os
21
- import random
22
-
23
- # fmt: off
24
- import sys
25
- sys.path.insert(1, os.path.join(sys.path[0], '..'))
26
- # fmt: on
27
-
28
-
29
- warnings.filterwarnings("ignore")
30
-
31
- # constants
32
- WINDOW_NAME = "demo"
33
-
34
-
35
- def generate_image(pipe, overall_prompt, latents, get_latents=False, num_inference_steps=50, fg_masks=None,
36
- fg_masked_latents=None, frozen_steps=0, frozen_prompt=None, custom_attention_mask=None, fg_prompt=None):
37
- '''
38
- Main function that calls the image diffusion model
39
- latent: input_noise from where it starts the generation
40
- get_latents: if True, returns the latents for each frame
41
- '''
42
-
43
- image = pipe(overall_prompt, latents=latents, num_inference_steps=num_inference_steps, frozen_mask=fg_masks,
44
- frozen_steps=frozen_steps, latents_all_input=fg_masked_latents, frozen_prompt=frozen_prompt, custom_attention_mask=custom_attention_mask, output_type='pil',
45
- fg_prompt=fg_prompt, make_attention_mask_2d=True, attention_mask_block_diagonal=True).images[0]
46
- torch.save(image, "img.pt")
47
-
48
- if get_latents:
49
- video_latents = pipe(overall_prompt, latents=latents,
50
- num_inference_steps=num_inference_steps, output_type="latent").images
51
- torch.save(video_latents, "img_latents.pt")
52
- return image, video_latents
53
-
54
- return image
55
-
56
-
57
- def save_frames(path):
58
- video, audio, video_info = vision_io.read_video(
59
- f"demo3/{path}.mp4", pts_unit='sec')
60
-
61
- # Number of frames
62
- num_frames = video.size(0)
63
-
64
- # Save each frame
65
- os.makedirs(f"demo3/{path}", exist_ok=True)
66
- for i in range(num_frames):
67
- frame = video[i, :, :, :].numpy()
68
- # Convert from C x H x W to H x W x C and from torch tensor to PIL Image
69
- # frame = frame.permute(1, 2, 0).numpy()
70
- img = Image.fromarray(frame.astype('uint8'))
71
- img.save(f"demo3/{path}/frame_{i:04d}.png")
72
-
73
-
74
- def create_boxes():
75
- img_width = 96
76
- img_height = 96
77
-
78
- # initialize bboxes list
79
- sbboxes = []
80
-
81
- # object dimensions
82
- for object_size in [20, 30, 40, 50, 60]:
83
- obj_width, obj_height = object_size, object_size
84
-
85
- # starting position
86
- start_x = 3
87
- start_y = 4
88
-
89
- # calculate total size occupied by the objects in the grid
90
- total_obj_width = 3 * obj_width
91
- total_obj_height = 3 * obj_height
92
-
93
- # determine horizontal and vertical spacings
94
- spacing_horizontal = (img_width - total_obj_width - start_x) // 2
95
- spacing_vertical = (img_height - total_obj_height - start_y) // 2
96
-
97
- for i in range(3):
98
- for j in range(3):
99
- x_start = start_x + i * (obj_width + spacing_horizontal)
100
- y_start = start_y + j * (obj_height + spacing_vertical)
101
- # Corrected to img_width to include the last pixel
102
- x_end = min(x_start + obj_width, img_width)
103
- # Corrected to img_height to include the last pixel
104
- y_end = min(y_start + obj_height, img_height)
105
- sbboxes.append([x_start, y_start, x_end, y_end])
106
-
107
- mask_id = 0
108
- masks_list = []
109
-
110
- for sbbox in sbboxes:
111
- smask = torch.zeros(1, 1, 96, 96)
112
- smask[0, 0, sbbox[1]:sbbox[3], sbbox[0]:sbbox[2]] = 1.0
113
- masks_list.append(smask)
114
- # torchvision.utils.save_image(smask, f"{SAVE_DIR}/masks/mask_{mask_id}.png") # save masks as images
115
- mask_id += 1
116
-
117
- return masks_list
118
-
119
-
120
- def objects_list():
121
- objects_settings = [
122
- ("apple", "on a table"),
123
- ("ball", "in a park"),
124
- ("cat", "on a couch"),
125
- ("dog", "in a backyard"),
126
- ("elephant", "in a jungle"),
127
- ("fountain pen", "on a desk"),
128
- ("guitar", "on a stage"),
129
- ("helicopter", "in the sky"),
130
- ("island", "in the sea"),
131
- ("jar", "on a shelf"),
132
- ("kite", "in the sky"),
133
- ("lamp", "in a room"),
134
- ("motorbike", "on a road"),
135
- ("notebook", "on a table"),
136
- ("owl", "on a tree"),
137
- ("piano", "in a hall"),
138
- ("queen", "in a castle"),
139
- ("robot", "in a lab"),
140
- ("snake", "in a forest"),
141
- ("tent", "in the mountains"),
142
- ("umbrella", "on a beach"),
143
- ("violin", "in an orchestra"),
144
- ("wheel", "in a garage"),
145
- ("xylophone", "in a music class"),
146
- ("yacht", "in a marina"),
147
- ("zebra", "in a savannah"),
148
- ("aeroplane", "in the clouds"),
149
- ("bridge", "over a river"),
150
- ("computer", "in an office"),
151
- ("dragon", "in a cave"),
152
- ("egg", "in a nest"),
153
- ("flower", "in a garden"),
154
- ("globe", "in a library"),
155
- ("hat", "on a rack"),
156
- ("ice cube", "in a glass"),
157
- ("jewelry", "in a box"),
158
- ("kangaroo", "in a desert"),
159
- ("lion", "in a den"),
160
- ("mug", "on a counter"),
161
- ("nest", "on a branch"),
162
- ("octopus", "in the ocean"),
163
- ("parrot", "in a rainforest"),
164
- ("quilt", "on a bed"),
165
- ("rose", "in a vase"),
166
- ("ship", "in a dock"),
167
- ("train", "on the tracks"),
168
- ("utensils", "in a kitchen"),
169
- ("vase", "on a window sill"),
170
- ("watch", "in a store"),
171
- ("x-ray", "in a hospital"),
172
- ("yarn", "in a basket"),
173
- ("zeppelin", "above a city"),
174
- ]
175
- objects_settings.extend([
176
- ("muffin", "on a bakery shelf"),
177
- ("notebook", "on a student's desk"),
178
- ("owl", "in a tree"),
179
- ("piano", "in a concert hall"),
180
- ("quill", "on parchment"),
181
- ("robot", "in a factory"),
182
- ("snake", "in the grass"),
183
- ("telescope", "in an observatory"),
184
- ("umbrella", "at the beach"),
185
- ("violin", "in an orchestra"),
186
- ("whale", "in the ocean"),
187
- ("xylophone", "in a music store"),
188
- ("yacht", "in a marina"),
189
- ("zebra", "on a savanna"),
190
-
191
- # Kitchen items
192
- ("spoon", "in a drawer"),
193
- ("plate", "in a cupboard"),
194
- ("cup", "on a shelf"),
195
- ("frying pan", "on a stove"),
196
- ("jar", "in the refrigerator"),
197
-
198
- # Office items
199
- ("computer", "in an office"),
200
- ("printer", "by a desk"),
201
- ("chair", "around a conference table"),
202
- ("lamp", "on a workbench"),
203
- ("calendar", "on a wall"),
204
-
205
- # Outdoor items
206
- ("bicycle", "on a street"),
207
- ("tent", "in a campsite"),
208
- ("fire", "in a fireplace"),
209
- ("mountain", "in the distance"),
210
- ("river", "through the woods"),
211
-
212
-
213
- # and so on ...
214
- ])
215
-
216
- # To expedite the generation, you can combine themes and objects:
217
-
218
- themes = [
219
- ("wild animals", ["tiger", "lion", "cheetah",
220
- "giraffe", "hippopotamus"], "in the wild"),
221
- ("household items", ["sofa", "tv", "clock",
222
- "vase", "photo frame"], "in a living room"),
223
- ("clothes", ["shirt", "pants", "shoes",
224
- "hat", "jacket"], "in a wardrobe"),
225
- ("musical instruments", ["drum", "trumpet",
226
- "harp", "saxophone", "tuba"], "in a band"),
227
- ("cosmic entities", ["planet", "star",
228
- "comet", "nebula", "asteroid"], "in space"),
229
- # ... add more themes
230
- ]
231
-
232
- # Using the themes to extend our list
233
- for theme_name, theme_objects, theme_location in themes:
234
- for theme_object in theme_objects:
235
- objects_settings.append((theme_object, theme_location))
236
-
237
- # Sports equipment
238
- objects_settings.extend([
239
- ("basketball", "on a court"),
240
- ("golf ball", "on a golf course"),
241
- ("tennis racket", "on a tennis court"),
242
- ("baseball bat", "in a stadium"),
243
- ("hockey stick", "on an ice rink"),
244
- ("football", "on a field"),
245
- ("skateboard", "in a skatepark"),
246
- ("boxing gloves", "in a boxing ring"),
247
- ("ski", "on a snowy slope"),
248
- ("surfboard", "on a beach shore"),
249
- ])
250
-
251
- # Toys and games
252
- objects_settings.extend([
253
- ("teddy bear", "on a child's bed"),
254
- ("doll", "in a toy store"),
255
- ("toy car", "on a carpet"),
256
- ("board game", "on a table"),
257
- ("yo-yo", "in a child's hand"),
258
- ("kite", "in the sky on a windy day"),
259
- ("Lego bricks", "on a construction table"),
260
- ("jigsaw puzzle", "partially completed"),
261
- ("rubik's cube", "on a shelf"),
262
- ("action figure", "on display"),
263
- ])
264
-
265
- # Transportation
266
- objects_settings.extend([
267
- ("bus", "at a bus stop"),
268
- ("motorcycle", "on a road"),
269
- ("helicopter", "landing on a pad"),
270
- ("scooter", "on a sidewalk"),
271
- ("train", "at a station"),
272
- ("bicycle", "parked by a post"),
273
- ("boat", "in a harbor"),
274
- ("tractor", "on a farm"),
275
- ("airplane", "taking off from a runway"),
276
- ("submarine", "below sea level"),
277
- ])
278
-
279
- # Medieval theme
280
- objects_settings.extend([
281
- ("castle", "on a hilltop"),
282
- ("knight", "riding a horse"),
283
- ("bow and arrow", "in an archery range"),
284
- ("crown", "in a treasure chest"),
285
- ("dragon", "flying over mountains"),
286
- ("shield", "next to a warrior"),
287
- ("dagger", "on a wooden table"),
288
- ("torch", "lighting a dark corridor"),
289
- ("scroll", "sealed with wax"),
290
- ("cauldron", "with bubbling potion"),
291
- ])
292
-
293
- # Modern technology
294
- objects_settings.extend([
295
- ("smartphone", "on a charger"),
296
- ("laptop", "in a cafe"),
297
- ("headphones", "around a neck"),
298
- ("camera", "on a tripod"),
299
- ("drone", "flying over a park"),
300
- ("USB stick", "plugged into a computer"),
301
- ("watch", "on a wrist"),
302
- ("microphone", "on a podcast desk"),
303
- ("tablet", "with a digital pen"),
304
- ("VR headset", "ready for gaming"),
305
- ])
306
-
307
- # Nature
308
- objects_settings.extend([
309
- ("tree", "in a forest"),
310
- ("flower", "in a garden"),
311
- ("mountain", "on a horizon"),
312
- ("cloud", "in a blue sky"),
313
- ("waterfall", "in a scenic location"),
314
- ("beach", "next to an ocean"),
315
- ("cactus", "in a desert"),
316
- ("volcano", "erupting with lava"),
317
- ("coral", "under the sea"),
318
- ("moon", "in a night sky"),
319
- ])
320
-
321
- prompts = [f"A {obj} {setting}" for obj, setting in objects_settings]
322
-
323
- return objects_settings
324
-
325
-
326
- if __name__ == "__main__":
327
- SAVE_DIR = "/scr/image/"
328
- save_path = "img43-att_mask"
329
- torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
330
- random_latents = torch.randn(
331
- [1, 4, 96, 96], generator=torch.Generator().manual_seed(1)).to(torch_device)
332
-
333
- try:
334
- pipe = StableDiffusionPipelineSpatialAware.from_pretrained(
335
- "stabilityai/stable-diffusion-2-1", torch_dtype=torch.float, variant="fp32", cache_dir="/gscratch/scrubbed/anasery/").to(torch_device)
336
- except:
337
- pipe = StableDiffusionPipelineSpatialAware.from_pretrained(
338
- "stabilityai/stable-diffusion-2-1", torch_dtype=torch.float, variant="fp32").to(torch_device)
339
-
340
- fg_object = "apple" # fg object stores the object to be masked
341
- # overall prompt stores the prompt
342
- overall_prompt = f"An {fg_object} on plate"
343
- os.makedirs(f"{SAVE_DIR}/{overall_prompt}", exist_ok=True)
344
-
345
- masks_list = create_boxes()
346
-
347
- # torch.save(f"{overall_prompt}+masked", "prompt.pt")
348
- obj_settings = objects_list() # 166
349
- for obj_setting in obj_settings[120:]:
350
- fg_object = obj_setting[0]
351
- overall_prompt = f"A {obj_setting[0]} {obj_setting[1]}"
352
- print(overall_prompt)
353
-
354
- # randomly select 10 numbers from range len of masks_list
355
- selected_mask_ids = random.sample(range(len(masks_list)), 3)
356
- for mask_id in selected_mask_ids:
357
- os.makedirs(
358
- f"{SAVE_DIR}/{overall_prompt}/mask{mask_id}", exist_ok=True)
359
- torchvision.utils.save_image(
360
- masks_list[mask_id][0][0], f"{SAVE_DIR}/{overall_prompt}/mask{mask_id}/mask.png")
361
- for frozen_steps in range(0, 5):
362
- img = generate_image(pipe, overall_prompt, random_latents, get_latents=False, num_inference_steps=50, fg_masks=masks_list[mask_id].to(
363
- torch_device), fg_masked_latents=None, frozen_steps=frozen_steps, frozen_prompt=None, fg_prompt=fg_object)
364
-
365
- img.save(
366
- f"{SAVE_DIR}/{overall_prompt}/mask{mask_id}/{frozen_steps}.png")