Spaces:
Sleeping
Sleeping
import streamlit as st | |
import time | |
import cv2 | |
import numpy as np | |
# model part | |
import json | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from torchvision import datasets, transforms as tr | |
from torchvision.transforms import v2 | |
from sklearn.preprocessing import minmax_scale | |
from collections import OrderedDict | |
# Per-session state.
# The image slot is cleared on every rerun so it always mirrors the current
# upload; it is set again further down whenever a file is present.
st.session_state.image = None
# Streamlit re-executes this script top to bottom on every interaction, so the
# call counter is initialised only once per session. An unconditional `= 0`
# here would reset it on each rerun and make `add_calls`/`zero_calls` pointless.
if "calls" not in st.session_state:
    st.session_state.calls = 0
def get_transforms(mean, std):
    """Build the model's input transform and a matching output-to-image mapper.

    Args:
        mean: Per-channel means (three values) used by ``Normalize``.
        std: Per-channel standard deviations (three values) used by ``Normalize``.

    Returns:
        ``(val_transform, de_normalize)``:

        * ``val_transform`` takes an image accepted by ``ToPILImage`` (the
          callers in this file pass OpenCV uint8 arrays). It resizes so the
          shorter side is 256 px, converts to a CxHxW float tensor and
          normalises with ``mean``/``std``.
        * ``de_normalize`` takes a 3xHxW tensor or array and returns an HxWx3
          float array with every channel min-max rescaled to ``[0, 1]``.
    """
    # Accept any array-like (lists, tuples, ndarrays); the de-normalisation
    # below needs ndarray slicing (`[:, None]`).
    mean = np.asarray(mean, dtype=float)
    std = np.asarray(std, dtype=float)

    val_transform = tr.Compose([
        tr.ToPILImage(),
        v2.Resize(size=256),
        tr.ToTensor(),
        tr.Normalize(mean=mean, std=std),
    ])

    def de_normalize(img):
        # Move tensors to host memory as plain numpy (detached, so this also
        # works outside `torch.no_grad()`); sklearn cannot consume tensors
        # that require grad.
        if isinstance(img, torch.Tensor):
            image = img.detach().cpu().numpy()
        else:
            image = np.asarray(img)
        # Invert Normalize, i.e. x = x_norm * std + mean per channel. The
        # per-channel min-max rescale that follows discards any positive
        # per-channel affine map, so this mainly keeps the intent correct.
        restored = image.reshape(3, -1) * std[:, None] + mean[:, None]
        return minmax_scale(
            restored,
            feature_range=(0., 1.),
            axis=1,
        ).reshape(*image.shape).transpose(1, 2, 0)

    return val_transform, de_normalize
class Conv7Stride1(nn.Module):
    """7x7 stride-1 convolution with reflection padding (the ``c7s1-k`` layer).

    With ``use_norm=True`` the convolution is followed by InstanceNorm and
    ReLU. Otherwise it is followed by Tanh, which is how the generator's
    output layer uses it. ReflectionPad2d(3) offsets the 7x7 kernel, so the
    spatial size is preserved.
    """

    def __init__(self, in_channels, out_channels, use_norm=True):
        super().__init__()
        layers = [
            ('pad', nn.ReflectionPad2d(3)),
            ('conv', torch.nn.Conv2d(in_channels, out_channels, kernel_size=7)),
        ]
        if not use_norm:
            layers.append(('tanh', nn.Tanh()))
        else:
            layers.extend([
                ('norm', nn.InstanceNorm2d(out_channels)),
                ('relu', nn.ReLU()),
            ])
        # Layer names end up in state_dict keys; keep them stable so saved
        # checkpoints keep loading.
        self.model = nn.Sequential(OrderedDict(layers))

    def forward(self, x):
        return self.model(x)
class Down(nn.Module):
    """Downsampling block ``d{k}``: 3x3 stride-2 conv from ``k // 2`` to ``k``
    channels, then InstanceNorm and ReLU.

    Height and width are halved, rounding up (stride 2, padding 1).
    """

    def __init__(self, k):
        super().__init__()
        in_ch = k // 2
        stages = OrderedDict()
        stages['conv'] = torch.nn.Conv2d(in_ch, k, kernel_size=3, stride=2, padding=1)
        stages['norm'] = nn.InstanceNorm2d(k)
        stages['relu'] = nn.ReLU()
        self.model = nn.Sequential(stages)

    def forward(self, x):
        return self.model(x)
class ResBlock(nn.Module):
    """Residual block ``R{k}`` computing ``x + F(x)``.

    ``F`` is two stages of reflection-pad(1) -> 3x3 conv -> BatchNorm2d ->
    ReLU. With ``use_dropout=True``, a Dropout(0.5) sits between the two
    stages. Channel count and spatial size are preserved.

    ``self.blocks`` is a plain list holding the two stages. They are
    registered as sub-modules through ``self.model``.
    """

    def __init__(self, k, use_dropout=False):
        super().__init__()

        def make_stage():
            # NOTE: the BatchNorm2d is registered under the name 'dropout'.
            # The name is misleading, but it is baked into the state_dict keys
            # of saved checkpoints, so it must not be renamed.
            return nn.Sequential(OrderedDict([
                ('pad', nn.ReflectionPad2d(1)),
                ('conv', torch.nn.Conv2d(k, k, kernel_size=3)),
                ('dropout', nn.BatchNorm2d(k)),
                ('relu', nn.ReLU()),
            ]))

        self.blocks = [make_stage() for _ in range(2)]
        first, second = self.blocks

        layers = [('block1', first)]
        if use_dropout:
            layers.append(('dropout', nn.Dropout(0.5)))
        layers.append(('block2', second))
        self.model = nn.Sequential(OrderedDict(layers))

    def forward(self, x):
        residual = self.model(x)
        return x + residual
class Up(nn.Module):
    """Upsampling block ``u{k}``: 3x3 stride-2 transposed conv from ``2 * k``
    to ``k`` channels, then InstanceNorm and ReLU.

    With padding 1 and output_padding 1, height and width are exactly doubled.
    """

    def __init__(self, k):
        super().__init__()
        upsample = nn.ConvTranspose2d(
            2 * k, k,
            kernel_size=3, stride=2, padding=1, output_padding=1,
        )
        self.model = nn.Sequential(OrderedDict([
            ('conv_transpose', upsample),
            ('norm', nn.InstanceNorm2d(k)),
            ('relu', nn.ReLU()),
        ]))

    def forward(self, x):
        return self.model(x)
class ResGenerator(nn.Module):
    """ResNet-style image-to-image generator.

    Layout: c7s1-64 -> d128 -> d256 -> ``res_blocks`` x R256 -> u128 -> u64
    -> c7s1-3. The last layer ends in Tanh, so outputs lie in ``[-1, 1]``.
    For inputs whose height and width are multiples of 4, the output has the
    same spatial size as the input.

    Args:
        res_blocks: Number of 256-channel residual blocks.
        use_dropout: Forwarded to every ``ResBlock``.
    """

    def __init__(self, res_blocks=9, use_dropout=False):
        super().__init__()
        # The residual trunk is reachable both as `self.residual_blocks` and
        # as `self.model.res_blocks`, so its weights appear under both
        # prefixes in the state_dict. Keep both registrations so existing
        # checkpoints still load.
        trunk = OrderedDict()
        for idx in range(1, res_blocks + 1):
            trunk[f'R256_{idx}'] = ResBlock(256, use_dropout=use_dropout)
        self.residual_blocks = nn.Sequential(trunk)

        stages = OrderedDict()
        stages['c7s1-64'] = Conv7Stride1(3, 64)
        stages['d128'] = Down(128)
        stages['d256'] = Down(256)
        stages['res_blocks'] = self.residual_blocks
        stages['u128'] = Up(128)
        stages['u64'] = Up(64)
        stages['c7s1-3'] = Conv7Stride1(64, 3, use_norm=False)
        self.model = nn.Sequential(stages)

    def forward(self, x):
        return self.model(x)
class ConvForDisc(nn.Module):
    """Discriminator stage: 4x4 conv (padding 1) -> optional InstanceNorm ->
    LeakyReLU(0.2, inplace=True).

    ``channels`` is either ``(in_channels, out_channels)`` or a single
    ``out_channels``. In the single-value case the input channel count is
    ``out_channels // 2``.
    """

    def __init__(self, *channels, stride=2, use_norm=True):
        super().__init__()
        if len(channels) == 1:
            (out_ch,) = channels
            in_ch = out_ch // 2
        else:
            in_ch, out_ch = channels[0], channels[1]

        layers = [('conv', nn.Conv2d(in_ch, out_ch, kernel_size=4, stride=stride, padding=1))]
        if use_norm:
            layers.append(('norm', nn.InstanceNorm2d(out_ch)))
        layers.append(('relu', nn.LeakyReLU(0.2, True)))
        self.model = nn.Sequential(OrderedDict(layers))

    def forward(self, x):
        return self.model(x)
class ConvDiscriminator(nn.Module):
    """Fully convolutional discriminator.

    Layout: C64 (no norm) -> C128 -> C256 -> C512 (stride 1) -> 4x4 conv down
    to a single channel. ``forward`` returns the raw logit map flattened to
    shape ``(batch, positions)``.
    """

    def __init__(self):
        super().__init__()
        stages = OrderedDict()
        stages['C64'] = ConvForDisc(3, 64, use_norm=False)
        stages['C128'] = ConvForDisc(128)
        stages['C256'] = ConvForDisc(256)
        stages['C512'] = ConvForDisc(512, stride=1)
        stages['conv1channel'] = nn.Conv2d(512, 1, kernel_size=4, padding=1)
        self.model = nn.Sequential(stages)

    def forward(self, x):
        # Raw logits (no sigmoid), one per spatial position of the final map.
        logit_map = self.model(x)
        return logit_map.flatten(start_dim=1)
class CycleGAN(nn.Module):
    """Container for the two generators and two discriminators of a CycleGAN.

    It defines no ``forward``; callers run the sub-networks directly, e.g.
    ``model.a2b_generator(batch)``.

    Args:
        res_blocks: Number of residual blocks in each generator.
        use_dropout: Whether the generators' residual blocks use Dropout(0.5).
    """

    def __init__(self, res_blocks=9, use_dropout=False):
        super(CycleGAN, self).__init__()
        # Previously these arguments were ignored (hard-coded 9 / False); they
        # are now forwarded. The defaults are unchanged, so `CycleGAN()` builds
        # the same network. Attribute names are the state_dict key prefixes;
        # do not rename them.
        self.a2b_generator = ResGenerator(res_blocks=res_blocks, use_dropout=use_dropout)
        self.a_discriminator = ConvDiscriminator()
        self.b2a_generator = ResGenerator(res_blocks=res_blocks, use_dropout=use_dropout)
        self.b_discriminator = ConvDiscriminator()
@st.cache_resource
def load_model(path='cycle_gan#21.pt'):
    """Load the CycleGAN checkpoint at ``path`` onto the CPU.

    Decorated with ``st.cache_resource``, so the checkpoint is read once per
    server process instead of on every Streamlit rerun. The returned model is
    shared by all sessions; callers must not modify it.

    Args:
        path: Checkpoint file containing a ``'model_state_dict'`` entry.

    Returns:
        A ``CycleGAN`` with the checkpoint weights loaded.
    """
    # weights_only=False unpickles arbitrary Python objects. Only ever point
    # this at a trusted, locally bundled checkpoint, never at user uploads.
    checkpoint = torch.load(path, weights_only=False,
                            map_location=torch.device('cpu'))
    model = CycleGAN()
    model.load_state_dict(checkpoint['model_state_dict'])
    # NOTE(review): the model stays in training mode (no `.eval()`). The
    # BatchNorm layers in ResBlock therefore normalise with per-batch
    # statistics at inference time. This may be intentional for this
    # checkpoint; confirm before adding `.eval()`.
    return model
# Per-channel normalisation statistics for each input domain, passed to
# get_transforms() before running the matching generator.
# NOTE(review): the "night" means (~0.46-0.54) are brighter than the "day"
# means (~0.19-0.20), which looks swapped. What matters is that each pair
# matches what the generator saw during training; verify against the
# training code before changing.
mean_night, std_night = (
    np.array([0.46207718, 0.52259593, 0.54372674]),
    np.array([0.21945059, 0.20839803, 0.2328357]),
)
mean_day, std_day = (
    np.array([0.18620284, 0.18614635, 0.20172116]),
    np.array([0.16982935, 0.14963816, 0.14965146]),
)

# ------------------------------------------------------------------ front end
st.markdown("<h1 style='text-align: center;'>Change daytime!</h1>", unsafe_allow_html=True)
def add_calls():
    """Increment the per-session call counter and write its value to the page."""
    session = st.session_state
    session.calls = session.calls + 1
    st.write(f'{st.session_state.calls=}')
def _show_conversion(generator_name, mean, std):
    """Render the stored image beside its translation by one CycleGAN generator.

    Args:
        generator_name: Attribute of the loaded model to run
            (``'a2b_generator'`` or ``'b2a_generator'``).
        mean: Per-channel means of the input domain, for ``get_transforms``.
        std: Per-channel standard deviations of the input domain.
    """
    image = st.session_state.image
    if image is None:
        # Nothing uploaded yet. `transform(None)` would fail, so tell the user
        # instead of crashing.
        st.warning("Upload an image first.")
        return
    col1, col2 = st.columns(2)
    with col1:
        st.write("Left Column")
        # Use the session image rather than the module-level `opencv_image`,
        # which does not exist until an upload has been processed.
        st.image(image, channels="BGR", use_container_width=True)
    with col2:
        st.write("Center Column")
        model = load_model()
        with torch.no_grad():
            transform, de_norm = get_transforms(mean, std)
            batch = transform(image)[None, :, :, :]
            batch_tr = getattr(model, generator_name)(batch)
            img_tr = de_norm(batch_tr[0, :, :, :])
        st.write(img_tr.shape)
        st.image([image, img_tr], channels="BGR", use_container_width=True, clamp=True)


def convert_day2night():
    """Show the uploaded image next to its day-to-night translation (a2b generator)."""
    _show_conversion('a2b_generator', mean_day, std_day)


def convert_night2day():
    """Show the uploaded image next to its night-to-day translation (b2a generator)."""
    _show_conversion('b2a_generator', mean_night, std_night)
def zero_calls():
    """Reset the per-session call counter back to zero."""
    st.session_state["calls"] = 0
st.session_state.option = st.selectbox('day2night OR night2day', ['day2night', 'night2day'])
uploaded_file = st.file_uploader("Choose a image file", type="jpg")
if uploaded_file is not None:
    # Decode the upload with OpenCV (flag 1 == cv2.IMREAD_COLOR, a 3-channel
    # BGR array). getvalue() returns the whole payload regardless of the
    # stream position.
    file_bytes = np.asarray(bytearray(uploaded_file.getvalue()), dtype=np.uint8)
    # cv2.imdecode returns None for undecodable data and rejects an empty
    # buffer, so both cases end up as None here.
    opencv_image = cv2.imdecode(file_bytes, 1) if file_bytes.size else None
    if opencv_image is None:
        st.error("Could not decode the uploaded file as an image.")
    else:
        st.session_state.image = np.asarray(opencv_image)
        image = st.session_state.image
        col1, col2 = st.columns(2)
        with col1:
            st.write("Original")
            st.image(opencv_image, channels="BGR", use_container_width=True)
        with col2:
            st.write("Transformed")
            model = load_model()
            # Pick the generator and the statistics of its *input* domain.
            if st.session_state.option == 'day2night':
                generator, norm_mean, norm_std = model.a2b_generator, mean_day, std_day
            else:
                generator, norm_mean, norm_std = model.b2a_generator, mean_night, std_night
            # NOTE(review): the BGR array goes through ToPILImage (which treats
            # channels as RGB), and the result is displayed as BGR. Confirm the
            # generator was trained with the same channel order.
            with torch.no_grad():
                transform, de_norm = get_transforms(norm_mean, norm_std)
                batch = transform(image)[None, :, :, :]
                batch_tr = generator(batch)
                img_tr = de_norm(batch_tr[0, :, :, :])
            st.image(img_tr, channels="BGR", use_container_width=True, clamp=True)