zestyoreo committed on
Commit 2e19df8 · 1 Parent(s): 3c97505
models/__init__.py ADDED
@@ -0,0 +1 @@
+ # model_init
models/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (154 Bytes)

models/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (152 Bytes)

models/__pycache__/afwm.cpython-310.pyc ADDED
Binary file (13.1 kB)

models/__pycache__/afwm.cpython-38.pyc ADDED
Binary file (13.2 kB)

models/__pycache__/networks.cpython-310.pyc ADDED
Binary file (6.31 kB)

models/__pycache__/networks.cpython-38.pyc ADDED
Binary file (6.36 kB)
models/afwm.py ADDED
@@ -0,0 +1,502 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import numpy as np
+ from math import sqrt
+
+ def apply_offset(offset):
+     # offset: (B, 2, H, W) per-pixel offsets; returns a (B, H, W, 2) sampling
+     # grid normalized to [-1, 1], as expected by F.grid_sample
+     sizes = list(offset.size()[2:])
+     grid_list = torch.meshgrid([torch.arange(size, device=offset.device) for size in sizes])
+     grid_list = reversed(grid_list)
+     # apply offset
+     grid_list = [grid.float().unsqueeze(0) + offset[:, dim, ...]
+                  for dim, grid in enumerate(grid_list)]
+     # normalize
+     grid_list = [grid / ((size - 1.0) / 2.0) - 1.0
+                  for grid, size in zip(grid_list, reversed(sizes))]
+
+     return torch.stack(grid_list, dim=-1)
+
+
+ def TVLoss(x):
+     # total-variation loss: mean absolute difference between neighbouring pixels
+     tv_h = x[:, :, 1:, :] - x[:, :, :-1, :]
+     tv_w = x[:, :, :, 1:] - x[:, :, :, :-1]
+
+     return torch.mean(torch.abs(tv_h)) + torch.mean(torch.abs(tv_w))
+
+
+ # backbone
+ class EqualLR:
+     # equalized learning rate (StyleGAN): keep the raw weight as `<name>_orig`
+     # and rescale it by sqrt(2 / fan_in) on every forward pass
+     def __init__(self, name):
+         self.name = name
+
+     def compute_weight(self, module):
+         weight = getattr(module, self.name + '_orig')
+         fan_in = weight.data.size(1) * weight.data[0][0].numel()
+
+         return weight * sqrt(2 / fan_in)
+
+     @staticmethod
+     def apply(module, name):
+         fn = EqualLR(name)
+
+         weight = getattr(module, name)
+         del module._parameters[name]
+         module.register_parameter(name + '_orig', nn.Parameter(weight.data))
+         module.register_forward_pre_hook(fn)
+
+         return fn
+
+     def __call__(self, module, input):
+         weight = self.compute_weight(module)
+         setattr(module, self.name, weight)
+
+
+ def equal_lr(module, name='weight'):
+     EqualLR.apply(module, name)
+
+     return module
+
+ class EqualLinear(nn.Module):
+     def __init__(self, in_dim, out_dim):
+         super().__init__()
+
+         linear = nn.Linear(in_dim, out_dim)
+         linear.weight.data.normal_()
+         linear.bias.data.zero_()
+
+         self.linear = equal_lr(linear)
+
+     def forward(self, input):
+         return self.linear(input)
+
+ class ModulatedConv2d(nn.Module):
+     def __init__(self, fin, fout, kernel_size, padding_type='zero', upsample=False, downsample=False, latent_dim=512, normalize_mlp=False):
+         super(ModulatedConv2d, self).__init__()
+         self.in_channels = fin
+         self.out_channels = fout
+         self.kernel_size = kernel_size
+         padding_size = kernel_size // 2
+
+         # demodulation (StyleGAN2) is skipped for 1x1 convolutions
+         if kernel_size == 1:
+             self.demodulate = False
+         else:
+             self.demodulate = True
+
+         self.weight = nn.Parameter(torch.Tensor(fout, fin, kernel_size, kernel_size))
+         self.bias = nn.Parameter(torch.Tensor(1, fout, 1, 1))
+
+         # NOTE: PixelNorm is not defined in this file; normalize_mlp is always
+         # False in this repo, so only the plain EqualLinear branch is exercised
+         if normalize_mlp:
+             self.mlp_class_std = nn.Sequential(EqualLinear(latent_dim, fin), PixelNorm())
+         else:
+             self.mlp_class_std = EqualLinear(latent_dim, fin)
+
+         if padding_type == 'reflect':
+             self.padding = nn.ReflectionPad2d(padding_size)
+         else:
+             self.padding = nn.ZeroPad2d(padding_size)
+
+         self.weight.data.normal_()
+         self.bias.data.zero_()
+
+     def forward(self, input, latent):
+         # equalized-lr scaling of the shared weight
+         fan_in = self.weight.data.size(1) * self.weight.data[0][0].numel()
+         weight = self.weight * sqrt(2 / fan_in)
+         weight = weight.view(1, self.out_channels, self.in_channels, self.kernel_size, self.kernel_size)
+
+         # modulate by the per-sample style, then optionally demodulate
+         s = self.mlp_class_std(latent).view(-1, 1, self.in_channels, 1, 1)
+         weight = s * weight
+         if self.demodulate:
+             d = torch.rsqrt((weight ** 2).sum(4).sum(3).sum(2) + 1e-5).view(-1, self.out_channels, 1, 1, 1)
+             weight = (d * weight).view(-1, self.in_channels, self.kernel_size, self.kernel_size)
+         else:
+             weight = weight.view(-1, self.in_channels, self.kernel_size, self.kernel_size)
+
+         # grouped convolution: fold the batch into the channel dimension so each
+         # sample is convolved with its own modulated kernels
+         batch, _, height, width = input.shape
+         input = input.view(1, -1, height, width)
+         input = self.padding(input)
+         out = F.conv2d(input, weight, groups=batch).view(batch, self.out_channels, height, width) + self.bias
+
+         return out
+
+
+ class StyledConvBlock(nn.Module):
+     def __init__(self, fin, fout, latent_dim=256, padding='zero',
+                  actvn='lrelu', normalize_affine_output=False, modulated_conv=False):
+         super(StyledConvBlock, self).__init__()
+         if not modulated_conv:
+             if padding == 'reflect':
+                 padding_layer = nn.ReflectionPad2d
+             else:
+                 padding_layer = nn.ZeroPad2d
+
+         if modulated_conv:
+             conv2d = ModulatedConv2d
+         else:
+             # NOTE: EqualConv2d is not defined in this file; this branch is never
+             # taken because modulated_conv=True throughout this repo
+             conv2d = EqualConv2d
+
+         if modulated_conv:
+             self.actvn_gain = sqrt(2)
+         else:
+             self.actvn_gain = 1.0
+
+         self.modulated_conv = modulated_conv
+
+         if actvn == 'relu':
+             activation = nn.ReLU(True)
+         else:
+             activation = nn.LeakyReLU(0.2, True)
+
+         if self.modulated_conv:
+             self.conv0 = conv2d(fin, fout, kernel_size=3, padding_type=padding, upsample=False,
+                                 latent_dim=latent_dim, normalize_mlp=normalize_affine_output)
+         else:
+             conv0 = conv2d(fin, fout, kernel_size=3)
+             seq0 = [padding_layer(1), conv0]
+             self.conv0 = nn.Sequential(*seq0)
+
+         self.actvn0 = activation
+
+         if self.modulated_conv:
+             self.conv1 = conv2d(fout, fout, kernel_size=3, padding_type=padding, downsample=False,
+                                 latent_dim=latent_dim, normalize_mlp=normalize_affine_output)
+         else:
+             conv1 = conv2d(fout, fout, kernel_size=3)
+             seq1 = [padding_layer(1), conv1]
+             self.conv1 = nn.Sequential(*seq1)
+
+         self.actvn1 = activation
+
+     def forward(self, input, latent=None):
+         if self.modulated_conv:
+             out = self.conv0(input, latent)
+         else:
+             out = self.conv0(input)
+
+         out = self.actvn0(out) * self.actvn_gain
+
+         if self.modulated_conv:
+             out = self.conv1(out, latent)
+         else:
+             out = self.conv1(out)
+
+         out = self.actvn1(out) * self.actvn_gain
+
+         return out
+
+
+ class Styled_F_ConvBlock(nn.Module):
+     # same as StyledConvBlock, but with a fixed 128-channel hidden layer and no
+     # activation after the second convolution (its output is a flow field)
+     def __init__(self, fin, fout, latent_dim=256, padding='zero',
+                  actvn='lrelu', normalize_affine_output=False, modulated_conv=False):
+         super(Styled_F_ConvBlock, self).__init__()
+         if not modulated_conv:
+             if padding == 'reflect':
+                 padding_layer = nn.ReflectionPad2d
+             else:
+                 padding_layer = nn.ZeroPad2d
+
+         if modulated_conv:
+             conv2d = ModulatedConv2d
+         else:
+             conv2d = EqualConv2d
+
+         if modulated_conv:
+             self.actvn_gain = sqrt(2)
+         else:
+             self.actvn_gain = 1.0
+
+         self.modulated_conv = modulated_conv
+
+         if actvn == 'relu':
+             activation = nn.ReLU(True)
+         else:
+             activation = nn.LeakyReLU(0.2, True)
+
+         if self.modulated_conv:
+             self.conv0 = conv2d(fin, 128, kernel_size=3, padding_type=padding, upsample=False,
+                                 latent_dim=latent_dim, normalize_mlp=normalize_affine_output)
+         else:
+             conv0 = conv2d(fin, 128, kernel_size=3)
+             seq0 = [padding_layer(1), conv0]
+             self.conv0 = nn.Sequential(*seq0)
+
+         self.actvn0 = activation
+
+         if self.modulated_conv:
+             self.conv1 = conv2d(128, fout, kernel_size=3, padding_type=padding, downsample=False,
+                                 latent_dim=latent_dim, normalize_mlp=normalize_affine_output)
+         else:
+             conv1 = conv2d(128, fout, kernel_size=3)
+             seq1 = [padding_layer(1), conv1]
+             self.conv1 = nn.Sequential(*seq1)
+
+     def forward(self, input, latent=None):
+         if self.modulated_conv:
+             out = self.conv0(input, latent)
+         else:
+             out = self.conv0(input)
+
+         out = self.actvn0(out) * self.actvn_gain
+
+         if self.modulated_conv:
+             out = self.conv1(out, latent)
+         else:
+             out = self.conv1(out)
+
+         return out
+
+
+ class ResBlock(nn.Module):
+     def __init__(self, in_channels):
+         super(ResBlock, self).__init__()
+         self.block = nn.Sequential(
+             nn.BatchNorm2d(in_channels),
+             nn.ReLU(inplace=True),
+             nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1, bias=False),
+             nn.BatchNorm2d(in_channels),
+             nn.ReLU(inplace=True),
+             nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1, bias=False)
+         )
+
+     def forward(self, x):
+         return self.block(x) + x
+
+
+ class DownSample(nn.Module):
+     def __init__(self, in_channels, out_channels):
+         super(DownSample, self).__init__()
+         self.block = nn.Sequential(
+             nn.BatchNorm2d(in_channels),
+             nn.ReLU(inplace=True),
+             nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1, bias=False)
+         )
+
+     def forward(self, x):
+         return self.block(x)
+
+
+ class FeatureEncoder(nn.Module):
+     def __init__(self, in_channels, chns=[64, 128, 256, 256, 256]):
+         # in_channels = 3 for images, and is larger (e.g., 17+1+1) for the agnostic representation
+         super(FeatureEncoder, self).__init__()
+         self.encoders = []
+         for i, out_chns in enumerate(chns):
+             if i == 0:
+                 encoder = nn.Sequential(DownSample(in_channels, out_chns),
+                                         ResBlock(out_chns),
+                                         ResBlock(out_chns))
+             else:
+                 encoder = nn.Sequential(DownSample(chns[i - 1], out_chns),
+                                         ResBlock(out_chns),
+                                         ResBlock(out_chns))
+
+             self.encoders.append(encoder)
+
+         self.encoders = nn.ModuleList(self.encoders)
+
+     def forward(self, x):
+         # return the feature map after every stage, finest to coarsest
+         encoder_features = []
+         for encoder in self.encoders:
+             x = encoder(x)
+             encoder_features.append(x)
+         return encoder_features
+
+ class RefinePyramid(nn.Module):
+     # FPN-style head: 1x1 lateral convolutions plus top-down nearest-neighbour
+     # upsampling, followed by a 3x3 smoothing convolution per level
+     def __init__(self, chns=[64, 128, 256, 256, 256], fpn_dim=256):
+         super(RefinePyramid, self).__init__()
+         self.chns = chns
+
+         # adaptive (lateral) convolutions
+         self.adaptive = []
+         for in_chns in list(reversed(chns)):
+             adaptive_layer = nn.Conv2d(in_chns, fpn_dim, kernel_size=1)
+             self.adaptive.append(adaptive_layer)
+         self.adaptive = nn.ModuleList(self.adaptive)
+         # output convolutions
+         self.smooth = []
+         for i in range(len(chns)):
+             smooth_layer = nn.Conv2d(fpn_dim, fpn_dim, kernel_size=3, padding=1)
+             self.smooth.append(smooth_layer)
+         self.smooth = nn.ModuleList(self.smooth)
+
+     def forward(self, x):
+         conv_ftr_list = x
+
+         feature_list = []
+         last_feature = None
+         for i, conv_ftr in enumerate(list(reversed(conv_ftr_list))):
+             # adaptive
+             feature = self.adaptive[i](conv_ftr)
+             # fuse with the coarser level
+             if last_feature is not None:
+                 feature = feature + F.interpolate(last_feature, scale_factor=2, mode='nearest')
+             # smooth
+             feature = self.smooth[i](feature)
+             last_feature = feature
+             feature_list.append(feature)
+
+         return tuple(reversed(feature_list))
+
+
+ class AFlowNet(nn.Module):
+     def __init__(self, num_pyramid, fpn_dim=256):
+         super(AFlowNet, self).__init__()
+
+         padding_type = 'zero'
+         actvn = 'lrelu'
+         normalize_mlp = False
+         modulated_conv = True
+
+         self.netRefine = []
+         self.netStyle = []
+         self.netF = []
+
+         for i in range(num_pyramid):
+             netRefine_layer = torch.nn.Sequential(
+                 torch.nn.Conv2d(2 * fpn_dim, out_channels=128, kernel_size=3, stride=1, padding=1),
+                 torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
+                 torch.nn.Conv2d(in_channels=128, out_channels=64, kernel_size=3, stride=1, padding=1),
+                 torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
+                 torch.nn.Conv2d(in_channels=64, out_channels=32, kernel_size=3, stride=1, padding=1),
+                 torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
+                 torch.nn.Conv2d(in_channels=32, out_channels=2, kernel_size=3, stride=1, padding=1)
+             )
+
+             style_block = StyledConvBlock(256, 49, latent_dim=256,
+                                           padding=padding_type, actvn=actvn,
+                                           normalize_affine_output=normalize_mlp,
+                                           modulated_conv=modulated_conv)
+
+             style_F_block = Styled_F_ConvBlock(49, 2, latent_dim=256,
+                                                padding=padding_type, actvn=actvn,
+                                                normalize_affine_output=normalize_mlp,
+                                                modulated_conv=modulated_conv)
+
+             self.netRefine.append(netRefine_layer)
+             self.netStyle.append(style_block)
+             self.netF.append(style_F_block)
+
+         self.netRefine = nn.ModuleList(self.netRefine)
+         self.netStyle = nn.ModuleList(self.netStyle)
+         self.netF = nn.ModuleList(self.netF)
+
+         # global style vectors pooled from the coarsest (8x6) feature maps
+         self.cond_style = torch.nn.Sequential(torch.nn.Conv2d(256, 128, kernel_size=(8, 6), stride=1, padding=0),
+                                               torch.nn.LeakyReLU(inplace=False, negative_slope=0.1))
+         self.image_style = torch.nn.Sequential(torch.nn.Conv2d(256, 128, kernel_size=(8, 6), stride=1, padding=0),
+                                                torch.nn.LeakyReLU(inplace=False, negative_slope=0.1))
+
+     def forward(self, x, x_warps, x_conds, warp_feature=True):
+         last_flow = None
+
+         B = x_conds[len(x_warps) - 1].shape[0]
+
+         cond_style = self.cond_style(x_conds[len(x_warps) - 1]).view(B, -1)
+         image_style = self.image_style(x_warps[len(x_warps) - 1]).view(B, -1)
+         style = torch.cat([cond_style, image_style], 1)
+
+         # coarse-to-fine: estimate a flow at each pyramid level, warp, refine,
+         # then upsample the accumulated flow for the next (finer) level
+         for i in range(len(x_warps)):
+             x_warp = x_warps[len(x_warps) - 1 - i]
+             x_cond = x_conds[len(x_warps) - 1 - i]
+
+             if last_flow is not None and warp_feature:
+                 x_warp_after = F.grid_sample(x_warp, last_flow.detach().permute(0, 2, 3, 1),
+                                              mode='bilinear', padding_mode='border')
+             else:
+                 x_warp_after = x_warp
+
+             stylemap = self.netStyle[i](x_warp_after, style)
+
+             flow = self.netF[i](stylemap, style)
+             flow = apply_offset(flow)
+             if last_flow is not None:
+                 # compose the new offset with the accumulated flow
+                 flow = F.grid_sample(last_flow, flow, mode='bilinear', padding_mode='border')
+             else:
+                 flow = flow.permute(0, 3, 1, 2)
+
+             last_flow = flow
+             x_warp = F.grid_sample(x_warp, flow.permute(0, 2, 3, 1), mode='bilinear', padding_mode='border')
+             concat = torch.cat([x_warp, x_cond], 1)
+             flow = self.netRefine[i](concat)
+             flow = apply_offset(flow)
+             flow = F.grid_sample(last_flow, flow, mode='bilinear', padding_mode='border')
+
+             last_flow = F.interpolate(flow, scale_factor=2, mode='bilinear')
+
+         x_warp = F.grid_sample(x, last_flow.permute(0, 2, 3, 1),
+                                mode='bilinear', padding_mode='border')
+         return x_warp, last_flow
+
+
+ class AFWM(nn.Module):
+
+     def __init__(self, opt, input_nc):
+         super(AFWM, self).__init__()
+         num_filters = [64, 128, 256, 256, 256]
+         self.image_features = FeatureEncoder(3, num_filters)
+         self.cond_features = FeatureEncoder(input_nc, num_filters)
+         self.image_FPN = RefinePyramid(num_filters)
+         self.cond_FPN = RefinePyramid(num_filters)
+         self.aflow_net = AFlowNet(len(num_filters))
+         # bookkeeping for the training-time LR schedulers below; guarded with a
+         # default because the test-time options define no lr
+         self.opt = opt
+         self.old_lr = getattr(opt, 'lr', 0.0)
+         self.old_lr_warp = getattr(opt, 'lr', 0.0) * 0.2
+
+     def forward(self, cond_input, image_input):
+         cond_pyramids = self.cond_FPN(self.cond_features(cond_input))  # maybe use nn.Sequential
+         image_pyramids = self.image_FPN(self.image_features(image_input))
+
+         x_warp, last_flow = self.aflow_net(image_input, image_pyramids, cond_pyramids)
+
+         return x_warp, last_flow
+
+     def update_learning_rate(self, optimizer):
+         # training-time helper: linear LR decay over opt.niter_decay epochs
+         opt = self.opt
+         lrd = opt.lr / opt.niter_decay
+         lr = self.old_lr - lrd
+         for param_group in optimizer.param_groups:
+             param_group['lr'] = lr
+         if opt.verbose:
+             print('update learning rate: %f -> %f' % (self.old_lr, lr))
+         self.old_lr = lr
+
+     def update_learning_rate_warp(self, optimizer):
+         opt = self.opt
+         lrd = 0.2 * opt.lr / opt.niter_decay
+         lr = self.old_lr_warp - lrd
+         for param_group in optimizer.param_groups:
+             param_group['lr'] = lr
+         if opt.verbose:
+             print('update learning rate: %f -> %f' % (self.old_lr_warp, lr))
+         self.old_lr_warp = lr
+
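A minimal sketch of how the warping module above is driven at inference time (the SimpleNamespace stand-in, input_nc=3, and the tensor shapes are illustrative assumptions; 256x192 is the resolution implied by the hard-coded (8, 6) style-pooling kernels after five stride-2 downsamplings):

    import torch
    from types import SimpleNamespace
    from models.afwm import AFWM

    opt = SimpleNamespace(verbose=False)       # stand-in for the parsed options (assumption)
    warp_model = AFWM(opt, input_nc=3).eval()  # input_nc=3 is an assumption

    cond = torch.randn(1, 3, 256, 192)         # conditioning image, e.g. the person
    cloth = torch.randn(1, 3, 256, 192)        # image to be warped
    with torch.no_grad():
        warped, last_flow = warp_model(cond, cloth)
    # warped: (1, 3, 256, 192); last_flow: (1, 2, 256, 192)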
models/networks.py ADDED
@@ -0,0 +1,213 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.parallel
+ from torchvision import models  # required by Vgg19 below
+ #from options.train_options import TrainOptions
+ import os
+
+ #opt = TrainOptions().parse()
+
+ class ResidualBlock(nn.Module):
+     def __init__(self, in_features=64, norm_layer=nn.BatchNorm2d):
+         super(ResidualBlock, self).__init__()
+         self.relu = nn.ReLU(True)
+         if norm_layer is None:
+             self.block = nn.Sequential(
+                 nn.Conv2d(in_features, in_features, 3, 1, 1, bias=False),
+                 nn.ReLU(inplace=True),
+                 nn.Conv2d(in_features, in_features, 3, 1, 1, bias=False),
+             )
+         else:
+             self.block = nn.Sequential(
+                 nn.Conv2d(in_features, in_features, 3, 1, 1, bias=False),
+                 norm_layer(in_features),
+                 nn.ReLU(inplace=True),
+                 nn.Conv2d(in_features, in_features, 3, 1, 1, bias=False),
+                 norm_layer(in_features)
+             )
+
+     def forward(self, x):
+         residual = x
+         out = self.block(x)
+         out += residual
+         out = self.relu(out)
+         return out
+
+
+ class ResUnetGenerator(nn.Module):
+     def __init__(self, input_nc, output_nc, num_downs, ngf=64,
+                  norm_layer=nn.BatchNorm2d, use_dropout=False):
+         super(ResUnetGenerator, self).__init__()
+         # construct unet structure from the innermost block outwards
+         unet_block = ResUnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True)
+
+         for i in range(num_downs - 5):
+             unet_block = ResUnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout)
+         unet_block = ResUnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
+         unet_block = ResUnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
+         unet_block = ResUnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
+         unet_block = ResUnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer)
+
+         self.model = unet_block
+
+     def forward(self, input):
+         return self.model(input)
+
+
+ # Defines the submodule with skip connection.
+ # X -------------------identity---------------------- X
+ #   |-- downsampling -- |submodule| -- upsampling --|
+ class ResUnetSkipConnectionBlock(nn.Module):
+     def __init__(self, outer_nc, inner_nc, input_nc=None,
+                  submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
+         super(ResUnetSkipConnectionBlock, self).__init__()
+         self.outermost = outermost
+         use_bias = norm_layer == nn.InstanceNorm2d
+
+         if input_nc is None:
+             input_nc = outer_nc
+         downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=3,
+                              stride=2, padding=1, bias=use_bias)
+         # two residual blocks after each down/up convolution
+         res_downconv = [ResidualBlock(inner_nc, norm_layer), ResidualBlock(inner_nc, norm_layer)]
+         res_upconv = [ResidualBlock(outer_nc, norm_layer), ResidualBlock(outer_nc, norm_layer)]
+
+         downrelu = nn.ReLU(True)
+         uprelu = nn.ReLU(True)
+         if norm_layer is not None:
+             downnorm = norm_layer(inner_nc)
+             upnorm = norm_layer(outer_nc)
+
+         if outermost:
+             upsample = nn.Upsample(scale_factor=2, mode='nearest')
+             upconv = nn.Conv2d(inner_nc * 2, outer_nc, kernel_size=3, stride=1, padding=1, bias=use_bias)
+             down = [downconv, downrelu] + res_downconv
+             up = [upsample, upconv]
+             model = down + [submodule] + up
+         elif innermost:
+             upsample = nn.Upsample(scale_factor=2, mode='nearest')
+             upconv = nn.Conv2d(inner_nc, outer_nc, kernel_size=3, stride=1, padding=1, bias=use_bias)
+             down = [downconv, downrelu] + res_downconv
+             if norm_layer is None:
+                 up = [upsample, upconv, uprelu] + res_upconv
+             else:
+                 up = [upsample, upconv, upnorm, uprelu] + res_upconv
+             model = down + up
+         else:
+             upsample = nn.Upsample(scale_factor=2, mode='nearest')
+             upconv = nn.Conv2d(inner_nc * 2, outer_nc, kernel_size=3, stride=1, padding=1, bias=use_bias)
+             if norm_layer is None:
+                 down = [downconv, downrelu] + res_downconv
+                 up = [upsample, upconv, uprelu] + res_upconv
+             else:
+                 down = [downconv, downnorm, downrelu] + res_downconv
+                 up = [upsample, upconv, upnorm, uprelu] + res_upconv
+
+             if use_dropout:
+                 model = down + [submodule] + up + [nn.Dropout(0.5)]
+             else:
+                 model = down + [submodule] + up
+
+         self.model = nn.Sequential(*model)
+
+     def forward(self, x):
+         if self.outermost:
+             return self.model(x)
+         else:
+             # concatenate the skip connection along the channel dimension
+             return torch.cat([x, self.model(x)], 1)
+
+
+ class Vgg19(nn.Module):
+     # five slices of a pretrained VGG-19, ending at relu1_1, relu2_1, relu3_1,
+     # relu4_1 and relu5_1 respectively
+     def __init__(self, requires_grad=False):
+         super(Vgg19, self).__init__()
+         vgg_pretrained_features = models.vgg19(pretrained=True).features
+         self.slice1 = nn.Sequential()
+         self.slice2 = nn.Sequential()
+         self.slice3 = nn.Sequential()
+         self.slice4 = nn.Sequential()
+         self.slice5 = nn.Sequential()
+         for x in range(2):
+             self.slice1.add_module(str(x), vgg_pretrained_features[x])
+         for x in range(2, 7):
+             self.slice2.add_module(str(x), vgg_pretrained_features[x])
+         for x in range(7, 12):
+             self.slice3.add_module(str(x), vgg_pretrained_features[x])
+         for x in range(12, 21):
+             self.slice4.add_module(str(x), vgg_pretrained_features[x])
+         for x in range(21, 30):
+             self.slice5.add_module(str(x), vgg_pretrained_features[x])
+         if not requires_grad:
+             for param in self.parameters():
+                 param.requires_grad = False
+
+     def forward(self, X):
+         h_relu1 = self.slice1(X)
+         h_relu2 = self.slice2(h_relu1)
+         h_relu3 = self.slice3(h_relu2)
+         h_relu4 = self.slice4(h_relu3)
+         h_relu5 = self.slice5(h_relu4)
+         out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
+         return out
+
+ class VGGLoss(nn.Module):
+     def __init__(self, layids=None):
+         super(VGGLoss, self).__init__()
+         self.vgg = Vgg19()
+         self.vgg.cuda()
+         self.criterion = nn.L1Loss()
+         self.weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0]
+         self.layids = layids
+
+     def forward(self, x, y):
+         x_vgg, y_vgg = self.vgg(x), self.vgg(y)
+         loss = 0
+         if self.layids is None:
+             self.layids = list(range(len(x_vgg)))
+         for i in self.layids:
+             loss += self.weights[i] * self.criterion(x_vgg[i], y_vgg[i].detach())
+         return loss
+
+ def save_checkpoint(model, save_path):
+     if not os.path.exists(os.path.dirname(save_path)):
+         os.makedirs(os.path.dirname(save_path))
+     torch.save(model.state_dict(), save_path)
+
+
+ def load_checkpoint_parallel(model, checkpoint_path):
+     # NOTE: relies on a module-level `opt` with a `local_rank` attribute; the
+     # TrainOptions parse that provided it is commented out in this commit
+     if not os.path.exists(checkpoint_path):
+         print('No checkpoint!')
+         return
+
+     checkpoint = torch.load(checkpoint_path, map_location='cuda:{}'.format(opt.local_rank))
+     checkpoint_new = model.state_dict()
+     for param in checkpoint_new:
+         checkpoint_new[param] = checkpoint[param]
+     model.load_state_dict(checkpoint_new)
+
+ def load_checkpoint_part_parallel(model, checkpoint_path):
+     # NOTE: also relies on the module-level `opt` described above
+     if not os.path.exists(checkpoint_path):
+         print('No checkpoint!')
+         return
+     checkpoint = torch.load(checkpoint_path, map_location='cuda:{}'.format(opt.local_rank))
+     checkpoint_new = model.state_dict()
+     for param in checkpoint_new:
+         # copy everything except the condition branch and refinement nets,
+         # but always keep aflow_net.cond_style
+         if 'cond_' not in param and 'aflow_net.netRefine' not in param or 'aflow_net.cond_style' in param:
+             checkpoint_new[param] = checkpoint[param]
+     model.load_state_dict(checkpoint_new)
+
+ def load_checkpoint(model, checkpoint_path):
+     if not os.path.exists(checkpoint_path):
+         print('No checkpoint!')
+         return
+
+     checkpoint = torch.load(checkpoint_path)
+     checkpoint_new = model.state_dict()
+     for param in checkpoint_new:
+         checkpoint_new[param] = checkpoint[param]
+
+     model.load_state_dict(checkpoint_new)
+
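A short sketch of how the perceptual loss above is typically used during training (a CUDA device is required because the constructor calls .cuda(), and the first run downloads the pretrained VGG-19 weights; shapes are illustrative):

    import torch
    from models.networks import VGGLoss

    criterion_vgg = VGGLoss()
    pred = torch.randn(1, 3, 256, 192, device='cuda', requires_grad=True)
    target = torch.randn(1, 3, 256, 192, device='cuda')
    loss = criterion_vgg(pred, target)   # weighted L1 over five VGG feature levels
    loss.backward()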
options/__init__.py ADDED
@@ -0,0 +1 @@
+ # options_init
options/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (155 Bytes)

options/__pycache__/__init__.cpython-36.pyc ADDED
Binary file (138 Bytes)

options/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (153 Bytes)

options/__pycache__/base_options.cpython-310.pyc ADDED
Binary file (3.09 kB)

options/__pycache__/base_options.cpython-36.pyc ADDED
Binary file (3.06 kB)

options/__pycache__/base_options.cpython-38.pyc ADDED
Binary file (3.08 kB)

options/__pycache__/test_options.cpython-310.pyc ADDED
Binary file (970 Bytes)

options/__pycache__/test_options.cpython-36.pyc ADDED
Binary file (943 Bytes)

options/__pycache__/test_options.cpython-38.pyc ADDED
Binary file (966 Bytes)
options/base_options.py ADDED
@@ -0,0 +1,59 @@
+ import argparse
+ import torch
+
+ class BaseOptions():
+     def __init__(self):
+         self.parser = argparse.ArgumentParser()
+         self.initialized = False
+
+     def initialize(self):
+         self.parser.add_argument('--name', type=str, default='demo', help='name of the experiment. It decides where to store samples and models')
+         self.parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids, e.g. 0 or 0,1,2 or 0,2; use -1 for CPU')
+         self.parser.add_argument('--norm', type=str, default='instance', help='instance normalization or batch normalization')
+         self.parser.add_argument('--use_dropout', action='store_true', help='use dropout for the generator')
+         self.parser.add_argument('--data_type', default=32, type=int, choices=[8, 16, 32], help='supported data type, i.e. 8, 16 or 32 bit')
+         self.parser.add_argument('--verbose', action='store_true', default=False, help='toggles verbose output')
+
+         self.parser.add_argument('--batchSize', type=int, default=1, help='input batch size')
+         self.parser.add_argument('--loadSize', type=int, default=512, help='scale images to this size')
+         self.parser.add_argument('--fineSize', type=int, default=512, help='then crop to this size')
+         self.parser.add_argument('--input_nc', type=int, default=3, help='# of input image channels')
+         self.parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels')
+
+         self.parser.add_argument('--dataroot', type=str,
+                                  default='/home/sh0089/sen/fashion/')
+         self.parser.add_argument('--resize_or_crop', type=str, default='scale_width', help='scaling and cropping of images at load time [resize_and_crop|crop|scale_width|scale_width_and_crop]')
+         self.parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')
+         self.parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation')
+         self.parser.add_argument('--nThreads', default=1, type=int, help='# of threads for loading data')
+         self.parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.')
+
+         self.parser.add_argument('--display_winsize', type=int, default=512, help='display window size')
+         self.parser.add_argument('--tf_log', action='store_true', help='if specified, use tensorboard logging. Requires tensorflow installed')
+
+         self.initialized = True
+
+     def parse(self, save=True):
+         if not self.initialized:
+             self.initialize()
+         self.opt = self.parser.parse_args()
+         self.opt.isTrain = self.isTrain  # set by the TrainOptions/TestOptions subclass
+
+         # parse the comma-separated gpu_ids string into a list of ints
+         str_ids = self.opt.gpu_ids.split(',')
+         self.opt.gpu_ids = []
+         for str_id in str_ids:
+             gpu_id = int(str_id)
+             if gpu_id >= 0:
+                 self.opt.gpu_ids.append(gpu_id)
+
+         if len(self.opt.gpu_ids) > 0:
+             torch.cuda.set_device(self.opt.gpu_ids[0])
+
+         args = vars(self.opt)
+
+         print('------------ Options -------------')
+         for k, v in sorted(args.items()):
+             print('%s: %s' % (str(k), str(v)))
+         print('-------------- End ----------------')
+
+         return self.opt
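A quick sketch of the parsing flow (DemoOptions is a hypothetical subclass for illustration; BaseOptions never sets isTrain itself, so it is only usable through a subclass; running without CLI flags keeps the defaults, and parse() pins the process to GPU 0 unless --gpu_ids -1 is passed):

    from options.base_options import BaseOptions

    class DemoOptions(BaseOptions):      # hypothetical subclass for illustration
        def initialize(self):
            BaseOptions.initialize(self)
            self.isTrain = False

    opt = DemoOptions().parse()          # prints the sorted option table
    print(opt.name, opt.gpu_ids)         # demo [0]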
options/test_options.py ADDED
@@ -0,0 +1,11 @@
+ from .base_options import BaseOptions
+
+ class TestOptions(BaseOptions):
+     def initialize(self):
+         BaseOptions.initialize(self)
+
+         self.parser.add_argument('--warp_checkpoint', type=str, default='/home/sh0089/sen/PF-AFN/PF-AFN_train/checkpoints_ours_fc/PFAFN_e2e_ours/PFAFN_warp_epoch_101.pth', help='load the pretrained warping model from the specified location')
+         self.parser.add_argument('--gen_checkpoint', type=str, default='/home/sh0089/sen/PF-AFN/PF-AFN_train/checkpoints_ours_fc/PFAFN_e2e_ours/PFAFN_gen_epoch_101.pth', help='load the pretrained generation model from the specified location')
+         self.parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc.')
+
+         self.isTrain = False
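Putting the commit together, a plausible test-time pipeline looks like this (the generator channel counts, the 256x192 resolution, and the composition of the generator input are assumptions in the style of PF-AFN, not values fixed by these files):

    import torch
    from options.test_options import TestOptions
    from models.afwm import AFWM
    from models.networks import ResUnetGenerator, load_checkpoint

    opt = TestOptions().parse()
    warp_model = AFWM(opt, 3).cuda().eval()              # input_nc=3 is an assumption
    load_checkpoint(warp_model, opt.warp_checkpoint)     # prints 'No checkpoint!' if missing

    gen_model = ResUnetGenerator(7, 4, 5).cuda().eval()  # assumed channel counts
    load_checkpoint(gen_model, opt.gen_checkpoint)

    person = torch.randn(1, 3, 256, 192, device='cuda')
    cloth = torch.randn(1, 3, 256, 192, device='cuda')
    with torch.no_grad():
        warped_cloth, last_flow = warp_model(person, cloth)
        mask = torch.ones_like(last_flow[:, :1])         # placeholder clothing mask
        gen_output = gen_model(torch.cat([person, warped_cloth, mask], 1))  # (1, 4, 256, 192)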
our_t_results/000001_0.jpg ADDED