miittnnss commited on
Commit
8d701f7
·
1 Parent(s): 7c252de

Update pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline.py +27 -0
pipeline.py CHANGED
@@ -3,6 +3,33 @@ import torch.nn as nn
3
  from torchvision import transforms
4
  from PIL import Image
5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  class PretrainedPipeline():
7
  def __init__(self):
8
  self.device = torch.device("cpu")
 
3
  from torchvision import transforms
4
  from PIL import Image
5
 
6
+ class Generator(nn.Module):
7
+ def __init__(self):
8
+ super(Generator, self).__init__()
9
+ self.main = nn.Sequential(
10
+ nn.ConvTranspose2d(128, 64 * 8, 4, 1, 0, bias=False),
11
+ nn.BatchNorm2d(64 * 8),
12
+ nn.LeakyReLU(0.2, inplace=True),
13
+
14
+ nn.ConvTranspose2d(64 * 8, 64 * 4, 4, 2, 1, bias=False),
15
+ nn.BatchNorm2d(64 * 4),
16
+ nn.LeakyReLU(0.2, inplace=True),
17
+
18
+ nn.ConvTranspose2d(64 * 4, 64 * 2, 4, 2, 1, bias=False),
19
+ nn.BatchNorm2d(64 * 2),
20
+ nn.LeakyReLU(0.2, inplace=True),
21
+
22
+ nn.ConvTranspose2d(64 * 2, 64, 4, 2, 1, bias=False),
23
+ nn.BatchNorm2d(64),
24
+ nn.LeakyReLU(0.2, inplace=True),
25
+
26
+ nn.ConvTranspose2d(64, 3, 4, 2, 1, bias=False),
27
+ nn.Tanh()
28
+ )
29
+
30
+ def forward(self, input):
31
+ return self.main(input)
32
+
33
  class PretrainedPipeline():
34
  def __init__(self):
35
  self.device = torch.device("cpu")