mridulk committed
Commit ae6263a · 1 Parent(s): 4d5ae96

added files

Files changed (1)
  1. sample_level_encoding.py +274 -0
sample_level_encoding.py ADDED
@@ -0,0 +1,274 @@
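"""Sample per-species images from a trained hierarchical level-encoding (HLE)
latent diffusion model: each species' ancestral node representation is used as
conditioning, and one grid of samples per species is written to disk."""
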
import argparse, os, sys, glob
import torch
import pickle
import numpy as np
from omegaconf import OmegaConf
from PIL import Image
from tqdm import tqdm, trange
from einops import rearrange
from torchvision.utils import make_grid

from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
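

# Load a latent diffusion model from an OmegaConf config and a Lightning
# checkpoint; missing/unexpected state-dict keys are tolerated (strict=False)
# and printed when verbose=True.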
def load_model_from_config(config, ckpt, verbose=False):
    print(f"Loading model from {ckpt}")
    pl_sd = torch.load(ckpt)  # pass map_location="cpu" to load on a CPU-only machine
    sd = pl_sd["state_dict"]
    model = instantiate_from_config(config.model)
    m, u = model.load_state_dict(sd, strict=False)
    if len(m) > 0 and verbose:
        print("missing keys:")
        print(m)
    if len(u) > 0 and verbose:
        print("unexpected keys:")
        print(u)

    model.cuda()
    model.eval()
    return model


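# Sampling options: prompt, DDIM steps, sampler choice, guidance scale,
# image size, batch size, and output layout.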
if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--prompt",
        type=str,
        nargs="?",
        default="a painting of a virus monster playing guitar",
        help="the prompt to render"
    )
    parser.add_argument(
        "--outdir",
        type=str,
        nargs="?",
        help="dir to write results to",
        default="outputs/txt2img-samples"
    )
    parser.add_argument(
        "--ddim_steps",
        type=int,
        default=200,
        help="number of ddim sampling steps",
    )
    parser.add_argument(
        "--plms",
        action='store_true',
        help="use plms sampling",
    )
    parser.add_argument(
        "--ddim_eta",
        type=float,
        default=1.0,
        help="ddim eta (eta=0.0 corresponds to deterministic sampling)",
    )
    parser.add_argument(
        "--n_iter",
        type=int,
        default=1,
        help="sample this often",
    )
    parser.add_argument(
        "--H",
        type=int,
        default=256,
        help="image height, in pixel space",
    )
    parser.add_argument(
        "--W",
        type=int,
        default=256,
        help="image width, in pixel space",
    )
    parser.add_argument(
        "--n_samples",
        type=int,
        default=4,
        help="how many samples to produce for the given prompt",
    )
    parser.add_argument(
        "--output_dir_name",
        type=str,
        default='default_file',
        help="name of the output folder",
    )
    parser.add_argument(
        "--postfix",
        type=str,
        default='',
        help="postfix appended to output names",
    )
    parser.add_argument(
        "--scale",
        type=float,
        default=1.0,
        help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))",
    )
    opt = parser.parse_args()

    # --scale 1.0 --n_samples 3 --ddim_steps 20

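    # Select which trained run to sample from; paths for earlier experiments
    # (CLIP text conditioning, plain label encoding, leave-one-out, and
    # guidance-scale ablations) are kept below, commented out, for reference.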
    # #### CLIP f4
    # config_path = '/globalscratch/mridul/ldm/clip/2023-11-09T15-34-23_CLIP_f4_maxlen77_classname/configs/2023-11-09T15-34-23-project.yaml'
    # ckpt_path = '/globalscratch/mridul/ldm/clip/2023-11-09T15-34-23_CLIP_f4_maxlen77_classname/checkpoints/epoch=000158.ckpt'

    # #### CLIP f8
    # config_path = '/globalscratch/mridul/ldm/clip/2023-11-09T15-30-05_CLIP_f8_maxlen77_classname/configs/2023-11-09T15-30-05-project.yaml'
    # ckpt_path = '/globalscratch/mridul/ldm/clip/2023-11-09T15-30-05_CLIP_f8_maxlen77_classname/checkpoints/epoch=000119.ckpt'

    # #### Label Encoding
    # config_path = '/globalscratch/mridul/ldm/test/test_bert/2023-11-13T23-08-55_TEST_f4_ancestral_label_encoding/configs/2023-11-13T23-08-55-project.yaml'
    # ckpt_path = '/globalscratch/mridul/ldm/test/test_bert/2023-11-13T23-08-55_TEST_f4_ancestral_label_encoding/checkpoints/epoch=000119.ckpt'

    # #### Label Encoding, leave-one-out
    # config_path = '/globalscratch/mridul/ldm/level_encoding/leave_out/2023-12-01T01-49-15_HLE_f4_label_encoding_leave_out/configs/2023-12-01T01-49-15-project.yaml'
    # ckpt_path = '/globalscratch/mridul/ldm/level_encoding/leave_out/2023-12-01T01-49-15_HLE_f4_label_encoding_leave_out/checkpoints/epoch=000131.ckpt'

    # #### Level encoding 371
    # ckpt_path = '/globalscratch/mridul/ldm/level_encoding/2023-12-03T09-33-45_HLE_f4_level_encoding_371/checkpoints/epoch=000119.ckpt'
    # config_path = '/globalscratch/mridul/ldm/level_encoding/2023-12-03T09-33-45_HLE_f4_level_encoding_371/configs/2023-12-03T09-33-45-project.yaml'

    # ### scale 1.25 - 137 epoch
    # ckpt_path = '/globalscratch/mridul/ldm/level_encoding/2024-01-29T21-52-36_HLE_f4_scale1.25/checkpoints/epoch=000119.ckpt'
    # config_path = '/globalscratch/mridul/ldm/level_encoding/2024-01-29T21-52-36_HLE_f4_scale1.25/configs/2024-01-29T21-52-36-project.yaml'

    # ### scale 1.5 - 137 epoch
    # ckpt_path = '/globalscratch/mridul/ldm/level_encoding/2024-01-29T20-33-03_HLE_f4_scale1.5/checkpoints/epoch=000119.ckpt'
    # config_path = '/globalscratch/mridul/ldm/level_encoding/2024-01-29T20-33-03_HLE_f4_scale1.5/configs/2024-01-29T20-33-03-project.yaml'

    # ### scale 2 - 137 epoch
    # ckpt_path = '/globalscratch/mridul/ldm/level_encoding/2024-01-29T21-52-36_HLE_f4_scale2/checkpoints/epoch=000095.ckpt'
    # config_path = '/globalscratch/mridul/ldm/level_encoding/2024-01-29T21-52-36_HLE_f4_scale2/configs/2024-01-29T21-52-36-project.yaml'

    # ### scale 5 - 137 epoch
    # ckpt_path = '/globalscratch/mridul/ldm/level_encoding/2024-01-29T20-26-32_HLE_f4_scale5/checkpoints/epoch=000095.ckpt'
    # config_path = '/globalscratch/mridul/ldm/level_encoding/2024-01-29T20-26-32_HLE_f4_scale5/configs/2024-01-29T20-26-32-project.yaml'

    # ### scale 10 - 137 epoch
    # ckpt_path = '/globalscratch/mridul/ldm/level_encoding/2024-01-29T20-26-02_HLE_f4_scale10/checkpoints/epoch=000101.ckpt'
    # config_path = '/globalscratch/mridul/ldm/level_encoding/2024-01-29T20-26-02_HLE_f4_scale10/configs/2024-01-29T20-26-02-project.yaml'

    # ### HLE 371 (active run)
    ckpt_path = '/globalscratch/mridul/ldm/final_runs_eccv/fishes/2024-03-01T23-15-36_HLE_days3/checkpoints/epoch=000119.ckpt'
    config_path = '/globalscratch/mridul/ldm/final_runs_eccv/fishes/2024-03-01T23-15-36_HLE_days3/configs/2024-03-01T23-15-36-project.yaml'

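    # Fixed integer label <-> species name mapping for the 38 fish species in
    # this run (used by the optional .npz export below).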
    label_to_class_mapping = {
        0: 'Alosa-chrysochloris', 1: 'Carassius-auratus', 2: 'Cyprinus-carpio', 3: 'Esox-americanus',
        4: 'Gambusia-affinis', 5: 'Lepisosteus-osseus', 6: 'Lepisosteus-platostomus', 7: 'Lepomis-auritus',
        8: 'Lepomis-cyanellus', 9: 'Lepomis-gibbosus', 10: 'Lepomis-gulosus', 11: 'Lepomis-humilis',
        12: 'Lepomis-macrochirus', 13: 'Lepomis-megalotis', 14: 'Lepomis-microlophus', 15: 'Morone-chrysops',
        16: 'Morone-mississippiensis', 17: 'Notropis-atherinoides', 18: 'Notropis-blennius', 19: 'Notropis-boops',
        20: 'Notropis-buccatus', 21: 'Notropis-buchanani', 22: 'Notropis-dorsalis', 23: 'Notropis-hudsonius',
        24: 'Notropis-leuciodus', 25: 'Notropis-nubilus', 26: 'Notropis-percobromus', 27: 'Notropis-stramineus',
        28: 'Notropis-telescopus', 29: 'Notropis-texanus', 30: 'Notropis-volucellus', 31: 'Notropis-wickliffi',
        32: 'Noturus-exilis', 33: 'Noturus-flavus', 34: 'Noturus-gyrinus', 35: 'Noturus-miurus',
        36: 'Noturus-nocturnus', 37: 'Phenacobius-mirabilis'}

    def get_label_from_class(class_name):
        for key, value in label_to_class_mapping.items():
            if value == class_name:
                return key
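
    # Build the model from the training config. Note: load_model_from_config
    # already calls model.cuda(), so the CPU fallback below only matters if
    # that call is removed.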
    config = OmegaConf.load(config_path)  # TODO: optionally download from the same location as ckpt and change this logic
    model = load_model_from_config(config, ckpt_path)  # TODO: check path

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model = model.to(device)

    if opt.plms:
        sampler = PLMSSampler(model)
    else:
        sampler = DDIMSampler(model)

    os.makedirs(opt.outdir, exist_ok=True)
    outpath = opt.outdir

    prompt = opt.prompt
    all_images = []
    labels = []

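    # Mapping from species name to its ancestral (hierarchical) node
    # representation, used as the conditioning input.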
    class_to_node = '/fastscratch/mridul/fishes/class_to_ancestral_label.pkl'
    with open(class_to_node, 'rb') as pickle_file:
        class_to_node_dict = pickle.load(pickle_file)

    sample_path = os.path.join(outpath, opt.output_dir_name)
    os.makedirs(sample_path, exist_ok=True)
    base_count = len(os.listdir(sample_path))

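    # For each species, sample a batch conditioned on its node representation
    # and save the results.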
    for class_name, node_representation in tqdm(class_to_node_dict.items()):
        prompt = node_representation
        all_samples = list()
        with torch.no_grad():
            with model.ema_scope():
                uc = None
                # if opt.scale != 1.0:
                #     uc = model.get_learned_conditioning(opt.n_samples * [""])
                for n in trange(opt.n_iter, desc="Sampling"):
                    # repeat the node representation opt.n_samples times and
                    # pack it as a single conditioning entry
                    all_prompts = opt.n_samples * (prompt)
                    all_prompts = [tuple(all_prompts)]
                    print(class_name, prompt)
                    c = model.get_learned_conditioning({'class_to_node': all_prompts})
                    # latent shape; note opt.H and opt.W are not used here
                    shape = [3, 64, 64]
                    samples_ddim, _ = sampler.sample(S=opt.ddim_steps,
                                                     conditioning=c,
                                                     batch_size=opt.n_samples,
                                                     shape=shape,
                                                     verbose=False,
                                                     unconditional_guidance_scale=opt.scale,
                                                     unconditional_conditioning=uc,
                                                     eta=opt.ddim_eta)

                    # decode latents to images and map from [-1, 1] to [0, 1]
                    x_samples_ddim = model.decode_first_stage(samples_ddim)
                    x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)

                    all_samples.append(x_samples_ddim)

        # additionally, save all samples for this class as a single grid
        grid = torch.stack(all_samples, 0)
        grid = rearrange(grid, 'n b c h w -> (n b) c h w')
        grid = make_grid(grid, nrow=opt.n_samples)

        # to image
        grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
        Image.fromarray(grid.astype(np.uint8)).save(os.path.join(sample_path, f'{class_name.replace(" ", "-")}.png'))

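        # Optional (currently disabled): save each sample individually and
        # collect image arrays plus integer labels for an .npz dump.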
        # # individual images
        # grid = torch.stack(all_samples, 0)
        # grid = rearrange(grid, 'n b c h w -> (n b) c h w')

        # for i in range(opt.n_samples):
        #     sample = grid[i]
        #     img = 255. * rearrange(sample, 'c h w -> h w c').cpu().numpy()
        #     img_arr = img.astype(np.uint8)
        #     class_name = class_name.replace(" ", "-")
        #     all_images.append(img_arr)
        #     labels.append(get_label_from_class(class_name))
        #     Image.fromarray(img_arr).save(f'{sample_path}/{class_name}_{i}.png')

    # all_images = np.array(all_images)
    # labels = np.array(labels)

    # np.savez(sample_path + '.npz', all_images, labels)


    print(f"Your samples are ready and waiting for you here: \n{sample_path} \nEnjoy.")


# python sample_level_encoding.py --outdir /home/mridul/sample_images_text --scale 1.0 --n_samples 3 --ddim_steps 200 --ddim_eta 1.0