Spaces:
Runtime error
Runtime error
Delete src/image_generation.py
Browse files- src/image_generation.py +0 -366
src/image_generation.py
DELETED
@@ -1,366 +0,0 @@
|
|
1 |
-
|
2 |
-
# Copyright (c) Facebook, Inc. and its affiliates.
|
3 |
-
# Modified from: https://github.com/facebookresearch/detectron2/blob/master/demo/demo.py
|
4 |
-
from transformers import pipeline
|
5 |
-
import torchvision
|
6 |
-
from PIL import Image
|
7 |
-
from models.t2i_pipeline import StableDiffusionPipelineSpatialAware
|
8 |
-
import torchvision.io as vision_io
|
9 |
-
import torch.nn.functional as F
|
10 |
-
import torch
|
11 |
-
import tqdm
|
12 |
-
import numpy as np
|
13 |
-
import cv2
|
14 |
-
import warnings
|
15 |
-
import time
|
16 |
-
import tempfile
|
17 |
-
import argparse
|
18 |
-
import glob
|
19 |
-
import multiprocessing as mp
|
20 |
-
import os
|
21 |
-
import random
|
22 |
-
|
23 |
-
# fmt: off
|
24 |
-
import sys
|
25 |
-
sys.path.insert(1, os.path.join(sys.path[0], '..'))
|
26 |
-
# fmt: on
|
27 |
-
|
28 |
-
|
29 |
-
warnings.filterwarnings("ignore")
|
30 |
-
|
31 |
-
# constants
|
32 |
-
WINDOW_NAME = "demo"
|
33 |
-
|
34 |
-
|
35 |
-
def generate_image(pipe, overall_prompt, latents, get_latents=False, num_inference_steps=50, fg_masks=None,
                   fg_masked_latents=None, frozen_steps=0, frozen_prompt=None, custom_attention_mask=None, fg_prompt=None):
    """Run the diffusion pipeline once and return the first generated image.

    Args:
        pipe: Callable pipeline; it is invoked with ``overall_prompt`` plus
            keyword arguments and must return an object with an ``images``
            attribute.
        overall_prompt: Prompt forwarded as the first positional argument.
        latents: Starting noise forwarded as ``latents``.
        get_latents: When true, ``pipe`` is called a second time with
            ``output_type="latent"`` and its ``images`` are returned too.
        num_inference_steps: Forwarded unchanged to ``pipe``.
        fg_masks: Forwarded to ``pipe`` as ``frozen_mask``.
        fg_masked_latents: Forwarded to ``pipe`` as ``latents_all_input``.
        frozen_steps, frozen_prompt, custom_attention_mask, fg_prompt:
            Forwarded to ``pipe`` under the same names.

    Returns:
        The first element of ``pipe(...).images``; or, when ``get_latents``
        is true, the tuple ``(image, latents_from_second_call)``.

    Side effects:
        Writes ``img.pt`` (and ``img_latents.pt`` when ``get_latents`` is
        true) into the current working directory via ``torch.save``.
    """
    spatial_kwargs = dict(
        latents=latents,
        num_inference_steps=num_inference_steps,
        frozen_mask=fg_masks,
        frozen_steps=frozen_steps,
        latents_all_input=fg_masked_latents,
        frozen_prompt=frozen_prompt,
        custom_attention_mask=custom_attention_mask,
        output_type='pil',
        fg_prompt=fg_prompt,
        make_attention_mask_2d=True,
        attention_mask_block_diagonal=True,
    )
    pipeline_output = pipe(overall_prompt, **spatial_kwargs)
    image = pipeline_output.images[0]
    torch.save(image, "img.pt")

    if not get_latents:
        return image

    # Second pass: only the prompt, noise and step count are forwarded here
    # (no mask / frozen arguments), and the raw latents are requested.
    latent_output = pipe(
        overall_prompt,
        latents=latents,
        num_inference_steps=num_inference_steps,
        output_type="latent",
    )
    video_latents = latent_output.images
    torch.save(video_latents, "img_latents.pt")
    return image, video_latents
|
55 |
-
|
56 |
-
|
57 |
-
def save_frames(path, root="demo3"):
    """Dump every frame of ``{root}/{path}.mp4`` as a PNG file.

    Frames are written to ``{root}/{path}/frame_NNNN.png`` (zero-padded,
    starting at 0). The output directory is created if it does not exist.

    Args:
        path: Video name without the ``.mp4`` extension; also used as the
            name of the output sub-directory.
        root: Directory holding the video and receiving the frame folder.
            Defaults to ``"demo3"``, the location previously hard-coded.
    """
    # Only the video tensor is used; the audio track and metadata are ignored.
    # NOTE(review): frames are passed to PIL without any permute, so this
    # relies on read_video returning frames as H x W x C (its documented
    # default layout) — confirm against the installed torchvision version.
    video, _audio, _info = vision_io.read_video(
        f"{root}/{path}.mp4", pts_unit='sec')

    frames_dir = f"{root}/{path}"
    os.makedirs(frames_dir, exist_ok=True)

    for i, frame in enumerate(video):
        img = Image.fromarray(frame.numpy().astype('uint8'))
        img.save(f"{frames_dir}/frame_{i:04d}.png")
|
72 |
-
|
73 |
-
|
74 |
-
def create_boxes(img_width=96, img_height=96, object_sizes=(20, 30, 40, 50, 60)):
    """Build binary box masks laid out on a 3 x 3 grid for several box sizes.

    For every size in ``object_sizes`` nine square boxes are placed on a
    3 x 3 grid whose first box starts at ``(x=3, y=4)``. Each box becomes a
    float tensor of shape ``(1, 1, img_height, img_width)`` that is 1.0
    inside the box and 0.0 elsewhere. Boxes are clipped to the canvas.

    When three boxes of a given size do not fit on the canvas the computed
    spacing is negative, so neighbouring boxes of that size overlap.

    Args:
        img_width: Canvas width in pixels (default 96).
        img_height: Canvas height in pixels (default 96).
        object_sizes: Side lengths of the square boxes to generate.

    Returns:
        A list of ``9 * len(object_sizes)`` mask tensors, ordered by size,
        then by grid column (x), then by grid row (y).
    """
    grid_cells = 3
    # Offset of the first grid cell from the top-left corner.
    start_x = 3
    start_y = 4

    masks_list = []
    for object_size in object_sizes:
        obj_width, obj_height = object_size, object_size

        # Gap between neighbouring boxes; floor division may yield a
        # negative gap (overlap) for large boxes.
        spacing_horizontal = (img_width - grid_cells * obj_width - start_x) // 2
        spacing_vertical = (img_height - grid_cells * obj_height - start_y) // 2

        for i in range(grid_cells):
            x_start = start_x + i * (obj_width + spacing_horizontal)
            x_end = min(x_start + obj_width, img_width)  # clip to canvas
            for j in range(grid_cells):
                y_start = start_y + j * (obj_height + spacing_vertical)
                y_end = min(y_start + obj_height, img_height)  # clip to canvas

                smask = torch.zeros(1, 1, img_height, img_width)
                smask[0, 0, y_start:y_end, x_start:x_end] = 1.0
                masks_list.append(smask)

    return masks_list
|
118 |
-
|
119 |
-
|
120 |
-
def objects_list():
    """Return the catalogue of ``(object, setting)`` string pairs.

    The list is assembled from hand-written pairs plus a set of themes,
    where every object of a theme is paired with that theme's location.
    Some objects appear more than once with different (or identical)
    settings; duplicates are kept on purpose so indices stay stable.

    Returns:
        A new list of 2-tuples ``(object_name, setting_phrase)``. Callers
        build prompts from it as ``f"A {object_name} {setting_phrase}"``.
    """
    objects_settings = [
        ("apple", "on a table"),
        ("ball", "in a park"),
        ("cat", "on a couch"),
        ("dog", "in a backyard"),
        ("elephant", "in a jungle"),
        ("fountain pen", "on a desk"),
        ("guitar", "on a stage"),
        ("helicopter", "in the sky"),
        ("island", "in the sea"),
        ("jar", "on a shelf"),
        ("kite", "in the sky"),
        ("lamp", "in a room"),
        ("motorbike", "on a road"),
        ("notebook", "on a table"),
        ("owl", "on a tree"),
        ("piano", "in a hall"),
        ("queen", "in a castle"),
        ("robot", "in a lab"),
        ("snake", "in a forest"),
        ("tent", "in the mountains"),
        ("umbrella", "on a beach"),
        ("violin", "in an orchestra"),
        ("wheel", "in a garage"),
        ("xylophone", "in a music class"),
        ("yacht", "in a marina"),
        ("zebra", "in a savannah"),
        ("aeroplane", "in the clouds"),
        ("bridge", "over a river"),
        ("computer", "in an office"),
        ("dragon", "in a cave"),
        ("egg", "in a nest"),
        ("flower", "in a garden"),
        ("globe", "in a library"),
        ("hat", "on a rack"),
        ("ice cube", "in a glass"),
        ("jewelry", "in a box"),
        ("kangaroo", "in a desert"),
        ("lion", "in a den"),
        ("mug", "on a counter"),
        ("nest", "on a branch"),
        ("octopus", "in the ocean"),
        ("parrot", "in a rainforest"),
        ("quilt", "on a bed"),
        ("rose", "in a vase"),
        ("ship", "in a dock"),
        ("train", "on the tracks"),
        ("utensils", "in a kitchen"),
        ("vase", "on a window sill"),
        ("watch", "in a store"),
        ("x-ray", "in a hospital"),
        ("yarn", "in a basket"),
        ("zeppelin", "above a city"),
    ]
    objects_settings.extend([
        ("muffin", "on a bakery shelf"),
        ("notebook", "on a student's desk"),
        ("owl", "in a tree"),
        ("piano", "in a concert hall"),
        ("quill", "on parchment"),
        ("robot", "in a factory"),
        ("snake", "in the grass"),
        ("telescope", "in an observatory"),
        ("umbrella", "at the beach"),
        ("violin", "in an orchestra"),
        ("whale", "in the ocean"),
        ("xylophone", "in a music store"),
        ("yacht", "in a marina"),
        ("zebra", "on a savanna"),

        # Kitchen items
        ("spoon", "in a drawer"),
        ("plate", "in a cupboard"),
        ("cup", "on a shelf"),
        ("frying pan", "on a stove"),
        ("jar", "in the refrigerator"),

        # Office items
        ("computer", "in an office"),
        ("printer", "by a desk"),
        ("chair", "around a conference table"),
        ("lamp", "on a workbench"),
        ("calendar", "on a wall"),

        # Outdoor items
        ("bicycle", "on a street"),
        ("tent", "in a campsite"),
        ("fire", "in a fireplace"),
        ("mountain", "in the distance"),
        ("river", "through the woods"),
    ])

    # Themes: (label, objects, shared location). The label is descriptive
    # only; every object is paired with the theme's location.
    themes = [
        ("wild animals", ["tiger", "lion", "cheetah",
                          "giraffe", "hippopotamus"], "in the wild"),
        ("household items", ["sofa", "tv", "clock",
                             "vase", "photo frame"], "in a living room"),
        ("clothes", ["shirt", "pants", "shoes",
                     "hat", "jacket"], "in a wardrobe"),
        ("musical instruments", ["drum", "trumpet",
                                 "harp", "saxophone", "tuba"], "in a band"),
        ("cosmic entities", ["planet", "star",
                             "comet", "nebula", "asteroid"], "in space"),
    ]
    for _theme_name, theme_objects, theme_location in themes:
        objects_settings.extend(
            (theme_object, theme_location) for theme_object in theme_objects)

    # Sports equipment
    objects_settings.extend([
        ("basketball", "on a court"),
        ("golf ball", "on a golf course"),
        ("tennis racket", "on a tennis court"),
        ("baseball bat", "in a stadium"),
        ("hockey stick", "on an ice rink"),
        ("football", "on a field"),
        ("skateboard", "in a skatepark"),
        ("boxing gloves", "in a boxing ring"),
        ("ski", "on a snowy slope"),
        ("surfboard", "on a beach shore"),
    ])

    # Toys and games
    objects_settings.extend([
        ("teddy bear", "on a child's bed"),
        ("doll", "in a toy store"),
        ("toy car", "on a carpet"),
        ("board game", "on a table"),
        ("yo-yo", "in a child's hand"),
        ("kite", "in the sky on a windy day"),
        ("Lego bricks", "on a construction table"),
        ("jigsaw puzzle", "partially completed"),
        ("rubik's cube", "on a shelf"),
        ("action figure", "on display"),
    ])

    # Transportation
    objects_settings.extend([
        ("bus", "at a bus stop"),
        ("motorcycle", "on a road"),
        ("helicopter", "landing on a pad"),
        ("scooter", "on a sidewalk"),
        ("train", "at a station"),
        ("bicycle", "parked by a post"),
        ("boat", "in a harbor"),
        ("tractor", "on a farm"),
        ("airplane", "taking off from a runway"),
        ("submarine", "below sea level"),
    ])

    # Medieval theme
    objects_settings.extend([
        ("castle", "on a hilltop"),
        ("knight", "riding a horse"),
        ("bow and arrow", "in an archery range"),
        ("crown", "in a treasure chest"),
        ("dragon", "flying over mountains"),
        ("shield", "next to a warrior"),
        ("dagger", "on a wooden table"),
        ("torch", "lighting a dark corridor"),
        ("scroll", "sealed with wax"),
        ("cauldron", "with bubbling potion"),
    ])

    # Modern technology
    objects_settings.extend([
        ("smartphone", "on a charger"),
        ("laptop", "in a cafe"),
        ("headphones", "around a neck"),
        ("camera", "on a tripod"),
        ("drone", "flying over a park"),
        ("USB stick", "plugged into a computer"),
        ("watch", "on a wrist"),
        ("microphone", "on a podcast desk"),
        ("tablet", "with a digital pen"),
        ("VR headset", "ready for gaming"),
    ])

    # Nature
    objects_settings.extend([
        ("tree", "in a forest"),
        ("flower", "in a garden"),
        ("mountain", "on a horizon"),
        ("cloud", "in a blue sky"),
        ("waterfall", "in a scenic location"),
        ("beach", "next to an ocean"),
        ("cactus", "in a desert"),
        ("volcano", "erupting with lava"),
        ("coral", "under the sea"),
        ("moon", "in a night sky"),
    ])

    return objects_settings
|
324 |
-
|
325 |
-
|
326 |
-
if __name__ == "__main__":
    # Script entry point: for each (object, setting) pair from index 120
    # onwards, pick 3 random box masks and generate one image per mask for
    # frozen_steps = 0..4, saving everything below SAVE_DIR.
    SAVE_DIR = "/scr/image/"
    torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Fixed seed so every generation starts from the same noise.
    random_latents = torch.randn(
        [1, 4, 96, 96], generator=torch.Generator().manual_seed(1)).to(torch_device)

    try:
        pipe = StableDiffusionPipelineSpatialAware.from_pretrained(
            "stabilityai/stable-diffusion-2-1", torch_dtype=torch.float, variant="fp32", cache_dir="/gscratch/scrubbed/anasery/").to(torch_device)
    except Exception:
        # Loading with the explicit cache_dir failed; retry with the default
        # cache. `except Exception` (rather than a bare `except`) lets
        # KeyboardInterrupt / SystemExit propagate.
        pipe = StableDiffusionPipelineSpatialAware.from_pretrained(
            "stabilityai/stable-diffusion-2-1", torch_dtype=torch.float, variant="fp32").to(torch_device)

    fg_object = "apple"  # fg object stores the object to be masked
    # overall prompt stores the prompt
    overall_prompt = f"An {fg_object} on plate"
    os.makedirs(f"{SAVE_DIR}/{overall_prompt}", exist_ok=True)

    masks_list = create_boxes()

    obj_settings = objects_list()  # 166
    # NOTE(review): only entries from index 120 on are processed —
    # presumably resuming an earlier run; confirm before reuse.
    for obj_setting in obj_settings[120:]:
        fg_object = obj_setting[0]
        overall_prompt = f"A {obj_setting[0]} {obj_setting[1]}"
        print(overall_prompt)

        # randomly select 3 distinct mask indices
        selected_mask_ids = random.sample(range(len(masks_list)), 3)
        for mask_id in selected_mask_ids:
            mask_dir = f"{SAVE_DIR}/{overall_prompt}/mask{mask_id}"
            os.makedirs(mask_dir, exist_ok=True)
            torchvision.utils.save_image(
                masks_list[mask_id][0][0], f"{mask_dir}/mask.png")
            for frozen_steps in range(0, 5):
                img = generate_image(pipe, overall_prompt, random_latents, get_latents=False, num_inference_steps=50, fg_masks=masks_list[mask_id].to(
                    torch_device), fg_masked_latents=None, frozen_steps=frozen_steps, frozen_prompt=None, fg_prompt=fg_object)

                img.save(f"{mask_dir}/{frozen_steps}.png")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|