Kiwinicki commited on
Commit
d288725
·
verified ·
1 Parent(s): 93cad03

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -127
app.py CHANGED
@@ -6,130 +6,10 @@ from huggingface_hub import hf_hub_download
6
  import torch
7
  import json
8
  from omegaconf import OmegaConf
9
- from model import Generator
10
-
11
-
12
- class BaseGenerator(ABC, nn.Module):
13
- def __init__(self, channels: int = 3):
14
- super().__init__()
15
- self.channels = channels
16
-
17
- @abstractmethod
18
- def forward(self, x: Tensor) -> Tensor:
19
- pass
20
-
21
-
22
- class Generator(BaseGenerator):
23
- def __init__(self, cfg: DictConfig):
24
- super().__init__(cfg.channels)
25
- self.cfg = cfg
26
- self.model = self._construct_model()
27
-
28
- def _construct_model(self):
29
- initial_layer = nn.Sequential(
30
- nn.Conv2d(
31
- self.cfg.channels,
32
- self.cfg.num_features,
33
- kernel_size=7,
34
- stride=1,
35
- padding=3,
36
- padding_mode="reflect",
37
- ),
38
- nn.ReLU(inplace=True),
39
- )
40
-
41
- down_blocks = nn.Sequential(
42
- ConvBlock(
43
- self.cfg.num_features,
44
- self.cfg.num_features * 2,
45
- kernel_size=3,
46
- stride=2,
47
- padding=1,
48
- ),
49
- ConvBlock(
50
- self.cfg.num_features * 2,
51
- self.cfg.num_features * 4,
52
- kernel_size=3,
53
- stride=2,
54
- padding=1,
55
- ),
56
- )
57
-
58
- residual_blocks = nn.Sequential(
59
- *[
60
- ResidualBlock(self.cfg.num_features * 4)
61
- for _ in range(self.cfg.num_residuals)
62
- ]
63
- )
64
-
65
- up_blocks = nn.Sequential(
66
- ConvBlock(
67
- self.cfg.num_features * 4,
68
- self.cfg.num_features * 2,
69
- down=False,
70
- kernel_size=3,
71
- stride=2,
72
- padding=1,
73
- output_padding=1,
74
- ),
75
- ConvBlock(
76
- self.cfg.num_features * 2,
77
- self.cfg.num_features,
78
- down=False,
79
- kernel_size=3,
80
- stride=2,
81
- padding=1,
82
- output_padding=1,
83
- ),
84
- )
85
-
86
- last_layer = nn.Conv2d(
87
- self.cfg.num_features,
88
- self.cfg.channels,
89
- kernel_size=7,
90
- stride=1,
91
- padding=3,
92
- padding_mode="reflect",
93
- )
94
-
95
- return nn.Sequential(
96
- initial_layer, down_blocks, residual_blocks, up_blocks, last_layer
97
- )
98
-
99
- def forward(self, x: Tensor) -> Tensor:
100
- return tanh(self.model(x))
101
-
102
-
103
- class ConvBlock(nn.Module):
104
- def __init__(
105
- self, in_channels, out_channels, down=True, use_activation=True, **kwargs
106
- ):
107
- super().__init__()
108
- self.conv = nn.Sequential(
109
- nn.Conv2d(in_channels, out_channels, padding_mode="reflect", **kwargs)
110
- if down
111
- else nn.ConvTranspose2d(in_channels, out_channels, **kwargs),
112
- nn.InstanceNorm2d(out_channels),
113
- nn.ReLU(inplace=True) if use_activation else nn.Identity(),
114
- )
115
-
116
- def forward(self, x: Tensor) -> Tensor:
117
- return self.conv(x)
118
-
119
-
120
- class ResidualBlock(nn.Module):
121
- def __init__(self, channels: int):
122
- super().__init__()
123
- self.block = nn.Sequential(
124
- ConvBlock(channels, channels, kernel_size=3, padding=1),
125
- ConvBlock(
126
- channels, channels, use_activation=False, kernel_size=3, padding=1
127
- ),
128
- )
129
-
130
- def forward(self, x: Tensor) -> Tensor:
131
- return x + self.block(x)
132
 
 
 
 
133
 
134
  repo_id = "Kiwinicki/sat2map-generator"
135
  generator_path = hf_hub_download(repo_id=repo_id, filename="generator.pth")
@@ -145,10 +25,23 @@ generator = Generator(cfg)
145
  generator.load_state_dict(torch.load(generator_path))
146
  generator.eval()
147
 
 
 
 
148
 
 
 
 
 
 
149
 
150
- def greet(iamge):
151
- return image
 
 
 
 
 
152
 
153
- iface = gr.Interface(fn=greet, inputs="image", outputs="image")
154
- iface.launch()
 
6
  import torch
7
  import json
8
  from omegaconf import OmegaConf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
+ import sys
11
+ sys.path.append(os.path.dirname(model_path))
12
+ from model import Generator
13
 
14
  repo_id = "Kiwinicki/sat2map-generator"
15
  generator_path = hf_hub_download(repo_id=repo_id, filename="generator.pth")
 
25
  generator.load_state_dict(torch.load(generator_path))
26
  generator.eval()
27
 
28
+ from PIL import Image
29
+ import torchvision.transforms as transforms
30
+
31
 
32
+ transform = transforms.Compose([
33
+ transforms.Resize((256, 256)),
34
+ transforms.ToTensor(),
35
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
36
+ ])
37
 
38
+ def process_image(image):
39
+ image_tensor = transform(image).unsqueeze(0)
40
+ with torch.no_grad():
41
+ output_tensor = generator(image_tensor)
42
+ output_image = output_tensor.squeeze(0)
43
+ output_image = transforms.ToPILImage()(output_image)
44
+ return output_image
45
 
46
+ iface = gr.Interface(fn=process_image, inputs="image", outputs="image", title="Image Generator")
47
+ iface.launch()