MariaUDmitrieva committed on
Commit 2ac4dc3 · verified · 1 Parent(s): 36ce20c

Upload app.py

Files changed (1)
  1. app.py +274 -0
app.py ADDED
@@ -0,0 +1,274 @@
+ import streamlit as st
+ import time
+ import cv2
+ import numpy as np
+
+ # model part
+
+ import json
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from torchvision import datasets, transforms as tr
+ from torchvision.transforms import v2
+ from sklearn.preprocessing import minmax_scale
+ from collections import OrderedDict
+
+ # initialise session state only once, so Streamlit reruns don't wipe the stored image
+ if 'image' not in st.session_state:
+     st.session_state.image = None
+ if 'calls' not in st.session_state:
+     st.session_state.calls = 0
+
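+ # get_transforms builds the 256-px evaluation pipeline together with a matching
+ # de-normalizer that maps a generated tensor back to a displayable [0, 1] HWC image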
+ def get_transforms(mean, std):
+
+     val_transform = tr.Compose([
+         tr.ToPILImage(),
+         v2.Resize(size=256),
+         tr.ToTensor(),
+         #...,
+         tr.Normalize(mean=mean, std=std)
+     ])
+
+     def de_normalize(img):
+         if isinstance(img, torch.Tensor):
+             image = img.cpu()
+         else:
+             image = img
+
+         # invert Normalize (x * std + mean), then rescale each channel to [0, 1]
+         return minmax_scale(
+             image.reshape(3, -1) * std[:, None] + mean[:, None],
+             feature_range=(0., 1.),
+             axis=1,
+         ).reshape(*img.shape).transpose(1, 2, 0)
+
+     return val_transform, de_normalize
+
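+ # generator building blocks, named after the CycleGAN paper's scheme:
+ # c7s1-k (7x7 conv, stride 1), dk (downsampling), Rk (residual block), uk (upsampling)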
+ class Conv7Stride1(nn.Module):
+     def __init__(self, in_channels, out_channels, use_norm=True):
+         super(Conv7Stride1, self).__init__()
+         if use_norm:
+             self.model = nn.Sequential(OrderedDict([
+                 ('pad', nn.ReflectionPad2d(3)),
+                 ('conv', torch.nn.Conv2d(in_channels, out_channels, kernel_size=7)),
+                 ('norm', nn.InstanceNorm2d(out_channels)),
+                 ('relu', nn.ReLU())
+             ]))
+         else:
+             self.model = nn.Sequential(OrderedDict([
+                 ('pad', nn.ReflectionPad2d(3)),
+                 ('conv', torch.nn.Conv2d(in_channels, out_channels, kernel_size=7)),
+                 ('tanh', nn.Tanh())
+             ]))
+
+     def forward(self, x):
+         return self.model(x)
+
+ class Down(nn.Module):
+     def __init__(self, k):
+         super(Down, self).__init__()
+         self.model = nn.Sequential(OrderedDict([
+             ('conv', torch.nn.Conv2d(k//2, k, kernel_size=3, stride=2, padding=1)),
+             ('norm', nn.InstanceNorm2d(k)),
+             ('relu', nn.ReLU())
+         ]))
+
+     def forward(self, x):
+         return self.model(x)
+
+ class ResBlock(nn.Module):
+     def __init__(self, k, use_dropout=False):
+         super(ResBlock, self).__init__()
+         self.blocks = []
+         for _ in range(2):
+             self.blocks += [nn.Sequential(OrderedDict([
+                 ('pad', nn.ReflectionPad2d(1)),
+                 ('conv', torch.nn.Conv2d(k, k, kernel_size=3)),
+                 # note: this norm layer is (mis)named 'dropout'; the key is kept
+                 # as-is so the trained checkpoint's state_dict still loads
+                 ('dropout', nn.BatchNorm2d(k)),
+                 ('relu', nn.ReLU())
+             ]))]
+
+         if use_dropout:
+             self.model = nn.Sequential(OrderedDict([
+                 ('block1', self.blocks[0]),
+                 ('dropout', nn.Dropout(0.5)),
+                 ('block2', self.blocks[1])
+             ]))
+         else:
+             self.model = nn.Sequential(OrderedDict([
+                 ('block1', self.blocks[0]),
+                 ('block2', self.blocks[1])
+             ]))
+
+     def forward(self, x):
+         return x + self.model(x)
+
+ class Up(nn.Module):
+     def __init__(self, k):
+         super(Up, self).__init__()
+         self.model = nn.Sequential(OrderedDict([
+             ('conv_transpose', nn.ConvTranspose2d(2*k, k, kernel_size=3, padding=1, output_padding=1, stride=2)),
+             ('norm', nn.InstanceNorm2d(k)),
+             ('relu', nn.ReLU())
+         ]))
+
+     def forward(self, x):
+         return self.model(x)
+
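+ # 9-block ResNet generator: c7s1-64, d128, d256, 9 x R256, u128, u64, c7s1-3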
+ class ResGenerator(nn.Module):
+     def __init__(self, res_blocks=9, use_dropout=False):
+         super(ResGenerator, self).__init__()
+         self.residual_blocks = nn.Sequential(OrderedDict([
+             (f'R256_{i+1}', ResBlock(256, use_dropout=use_dropout)) for i in range(res_blocks)
+         ]))
+         self.model = nn.Sequential(OrderedDict([
+             ('c7s1-64', Conv7Stride1(3, 64)),
+             ('d128', Down(128)),
+             ('d256', Down(256)),
+             ('res_blocks', self.residual_blocks),
+             ('u128', Up(128)),
+             ('u64', Up(64)),
+             ('c7s1-3', Conv7Stride1(64, 3, use_norm=False))
+         ]))
+
+     def forward(self, x):
+         return self.model(x)
+
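+ # discriminator blocks: 4x4 convs with LeakyReLU(0.2), matching the paper's
+ # 70x70 PatchGAN layout C64-C128-C256-C512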
+ class ConvForDisc(nn.Module):
+     def __init__(self, *channels, stride=2, use_norm=True):
+         super(ConvForDisc, self).__init__()
+         if len(channels) == 1:
+             channels = (channels[0] // 2, channels[0])
+         if use_norm:
+             self.model = nn.Sequential(OrderedDict([
+                 ('conv', nn.Conv2d(channels[0], channels[1], kernel_size=4, stride=stride, padding=1)),
+                 ('norm', nn.InstanceNorm2d(channels[1])),
+                 ('relu', nn.LeakyReLU(0.2, True))
+             ]))
+         else:
+             self.model = nn.Sequential(OrderedDict([
+                 ('conv', nn.Conv2d(channels[0], channels[1], kernel_size=4, stride=stride, padding=1)),
+                 ('relu', nn.LeakyReLU(0.2, True))
+             ]))
+
+     def forward(self, x):
+         return self.model(x)
+
+ class ConvDiscriminator(nn.Module):
+     def __init__(self):
+         super(ConvDiscriminator, self).__init__()
+         self.model = nn.Sequential(OrderedDict([
+             ('C64', ConvForDisc(3, 64, use_norm=False)),
+             ('C128', ConvForDisc(128)),
+             ('C256', ConvForDisc(256)),
+             ('C512', ConvForDisc(512, stride=1)),
+             ('conv1channel', nn.Conv2d(512, 1, kernel_size=4, padding=1))
+         ]))
+
+     def forward(self, x):
+         # predicts logits
+         return torch.flatten(self.model(x), start_dim=1)
+
+ class CycleGAN(nn.Module):
+     def __init__(self, res_blocks=9, use_dropout=False):
+         super(CycleGAN, self).__init__()
+         # pass the constructor arguments through instead of hard-coding them
+         self.a2b_generator = ResGenerator(res_blocks=res_blocks, use_dropout=use_dropout)
+         self.a_discriminator = ConvDiscriminator()
+         self.b2a_generator = ResGenerator(res_blocks=res_blocks, use_dropout=use_dropout)
+         self.b_discriminator = ConvDiscriminator()
+
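+ # st.cache_resource keeps a single loaded model in memory across reruns and sessions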
+ @st.cache_resource
+ def load_model():
+     checkpoint = torch.load('cycle_gan#21.pt', weights_only=False,
+                             map_location=torch.device('cpu'))
+     model = CycleGAN()
+     model.load_state_dict(checkpoint['model_state_dict'])
+     # switch to eval mode so the BatchNorm layers use their running statistics
+     model.eval()
+     return model
+
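+ # per-channel RGB statistics for Normalize, presumably computed on the
+ # night/day training splits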
+ mean_night = np.array([0.46207718, 0.52259593, 0.54372674])
+ mean_day = np.array([0.18620284, 0.18614635, 0.20172116])
+ std_night = np.array([0.21945059, 0.20839803, 0.2328357 ])
+ std_day = np.array([0.16982935, 0.14963816, 0.14965146])
+
+
+ # front part
+
+ st.markdown("<h1 style='text-align: center;'>Change daytime!</h1>", unsafe_allow_html=True)
+
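+ # the helper callbacks below are defined but not wired to any widget in this version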
+ def add_calls():
+     st.session_state.calls += 1
+     st.write(f'{st.session_state.calls=}')
+
+
+ def convert_day2night():
+     image = st.session_state.image
+     col1, col2 = st.columns(2)
+     with col1:
+         st.write("Left Column")
+         st.image(image, channels="BGR", use_container_width=True)
+     with col2:
+         st.write("Center Column")
+
+         model = load_model()
+         with torch.no_grad():
+             transform, de_norm = get_transforms(mean_day, std_day)
+             batch = transform(image)[None, :, :, :]
+             batch_tr = model.a2b_generator(batch)
+             img_tr = de_norm(batch_tr[0, :, :, :])
+             st.write(img_tr.shape)
+             st.image([image, img_tr], channels="BGR", use_container_width=True, clamp=True)
+
+ def convert_night2day():
+     image = st.session_state.image
+     col1, col2 = st.columns(2)
+     with col1:
+         st.write("Left Column")
+         st.image(image, channels="BGR", use_container_width=True)
+     with col2:
+         st.write("Center Column")
+
+         model = load_model()
+         with torch.no_grad():
+             transform, de_norm = get_transforms(mean_night, std_night)
+             batch = transform(image)[None, :, :, :]
+             batch_tr = model.b2a_generator(batch)
+             img_tr = de_norm(batch_tr[0, :, :, :])
+             st.write(img_tr.shape)
+             st.image([image, img_tr], channels="BGR", use_container_width=True, clamp=True)
+
+ def zero_calls():
+     st.session_state.calls = 0
+
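+ # main flow: choose a direction, upload a JPEG, then run the matching generator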
+ st.session_state.option = st.selectbox('day2night OR night2day', ['day2night', 'night2day'])
+
+ uploaded_file = st.file_uploader("Choose an image file", type="jpg")
+
+ if uploaded_file is not None:
+     # Convert the file to an OpenCV image.
+     file_bytes = np.asarray(bytearray(uploaded_file.read()), dtype=np.uint8)
+     opencv_image = cv2.imdecode(file_bytes, 1)
+
+     st.session_state.image = np.asarray(opencv_image)
+
+     image = st.session_state.image
+     col1, col2 = st.columns(2)
+     with col1:
+         st.write("Original")
+         st.image(opencv_image, channels="BGR", use_container_width=True)
+     with col2:
+         st.write("Transformed")
+
+         model = load_model()
+         with torch.no_grad():
+             # pick the statistics and generator that match the chosen direction
+             if st.session_state.option == 'day2night':
+                 transform, de_norm = get_transforms(mean_day, std_day)
+                 batch = transform(image)[None, :, :, :]
+                 batch_tr = model.a2b_generator(batch)
+             else:
+                 transform, de_norm = get_transforms(mean_night, std_night)
+                 batch = transform(image)[None, :, :, :]
+                 batch_tr = model.b2a_generator(batch)
+             img_tr = de_norm(batch_tr[0, :, :, :])
+             st.image(img_tr, channels="BGR", use_container_width=True, clamp=True)