# HW5 / app.py
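"""Streamlit demo: upload a photo and a CycleGAN translates it between day and night."""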
import streamlit as st
import time
import cv2
import numpy as np
# model part
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms as tr
from torchvision.transforms import v2
from sklearn.preprocessing import minmax_scale
from collections import OrderedDict
# Initialize session state once; plain assignment here would reset these values on every rerun
if 'image' not in st.session_state:
    st.session_state.image = None
    st.session_state.calls = 0
def get_transforms(mean, std):
    # Preprocessing: numpy (H, W, C) image -> resized, normalized float tensor (C, H, W)
    val_transform = tr.Compose([
        tr.ToPILImage(),
        v2.Resize(size=256),
        tr.ToTensor(),
        tr.Normalize(mean=mean, std=std)
    ])

    def de_normalize(img):
        # Invert Normalize (x * std + mean), then rescale each channel to [0, 1]
        # and return an (H, W, C) array suitable for st.image
        if isinstance(img, torch.Tensor):
            image = img.detach().cpu().numpy()
        else:
            image = np.asarray(img)
        return minmax_scale(
            image.reshape(3, -1) * std[:, None] + mean[:, None],
            feature_range=(0., 1.),
            axis=1,
        ).reshape(*image.shape).transpose(1, 2, 0)

    return val_transform, de_normalize
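# c7s1-k block of the generator: reflection padding + 7x7 stride-1 conv,
# followed by InstanceNorm + ReLU, or by Tanh for the final output layer.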
class Conv7Stride1(nn.Module):
def __init__(self, in_channels, out_channels, use_norm=True):
super(Conv7Stride1, self).__init__()
if use_norm:
self.model = nn.Sequential(OrderedDict([
('pad', nn.ReflectionPad2d(3)),
('conv', torch.nn.Conv2d(in_channels, out_channels, kernel_size=7)),
('norm', nn.InstanceNorm2d(out_channels)),
('relu', nn.ReLU())
]))
else:
self.model = nn.Sequential(OrderedDict([
('pad', nn.ReflectionPad2d(3)),
('conv', torch.nn.Conv2d(in_channels, out_channels, kernel_size=7)),
('tanh', nn.Tanh())
]))
def forward(self, x):
return self.model(x)
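# dk block: stride-2 3x3 conv that halves the spatial size and doubles the channels (k//2 -> k).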
class Down(nn.Module):
def __init__(self, k):
super(Down, self).__init__()
self.model = nn.Sequential(OrderedDict([
('conv', torch.nn.Conv2d(k//2, k, kernel_size=3, stride=2, padding=1)),
('norm', nn.InstanceNorm2d(k)),
('relu', nn.ReLU())
]))
def forward(self, x):
return self.model(x)
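# Rk block: two pad-conv-norm-ReLU sub-blocks with an additive skip connection
# (optionally with Dropout between them).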
class ResBlock(nn.Module):
    def __init__(self, k, use_dropout=False):
        super(ResBlock, self).__init__()
        self.blocks = []
        for _ in range(2):
            # NB: the key 'dropout' actually holds a BatchNorm2d layer; the name is kept
            # unchanged so the state_dict keys still match the trained checkpoint.
            self.blocks += [nn.Sequential(OrderedDict([
                ('pad', nn.ReflectionPad2d(1)),
                ('conv', torch.nn.Conv2d(k, k, kernel_size=3)),
                ('dropout', nn.BatchNorm2d(k)),
                ('relu', nn.ReLU())
            ]))]
        if use_dropout:
            self.model = nn.Sequential(OrderedDict([
                ('block1', self.blocks[0]),
                ('dropout', nn.Dropout(0.5)),
                ('block2', self.blocks[1])
            ]))
        else:
            self.model = nn.Sequential(OrderedDict([
                ('block1', self.blocks[0]),
                ('block2', self.blocks[1])
            ]))

    def forward(self, x):
        return x + self.model(x)
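# uk block: stride-2 transposed 3x3 conv that doubles the spatial size and halves the channels (2k -> k).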
class Up(nn.Module):
def __init__(self, k):
super(Up, self).__init__()
self.model = nn.Sequential(OrderedDict([
('conv_transpose', nn.ConvTranspose2d(2*k, k, kernel_size=3, padding=1, output_padding=1, stride=2)),
('norm', nn.InstanceNorm2d(k)),
('relu', nn.ReLU())
]))
def forward(self, x):
return self.model(x)
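# Generator: c7s1-64, d128, d256, N residual R256 blocks, u128, u64, c7s1-3 (CycleGAN-style layout).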
class ResGenerator(nn.Module):
def __init__(self, res_blocks=9, use_dropout=False):
super(ResGenerator, self).__init__()
self.residual_blocks = nn.Sequential(OrderedDict([
(f'R256_{i+1}', ResBlock(256, use_dropout=use_dropout)) for i in range(res_blocks)
]))
self.model = nn.Sequential(OrderedDict([
('c7s1-64', Conv7Stride1(3, 64)),
('d128', Down(128)),
('d256', Down(256)),
('res_blocks', self.residual_blocks),
('u128', Up(128)),
('u64', Up(64)),
('c7s1-3', Conv7Stride1(64, 3, use_norm=False))
]))
def forward(self, x):
return self.model(x)
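# Ck block of the discriminator: 4x4 conv (stride 2 by default) + optional InstanceNorm + LeakyReLU(0.2).
# With a single positional argument k, the input channel count defaults to k // 2.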
class ConvForDisc(nn.Module):
def __init__(self, *channels, stride=2, use_norm=True):
super(ConvForDisc, self).__init__()
if len(channels) == 1:
channels = (channels[0] // 2, channels[0])
if use_norm:
self.model = nn.Sequential(OrderedDict([
('conv', nn.Conv2d(channels[0], channels[1], kernel_size=4, stride=stride, padding=1)),
('norm', nn.InstanceNorm2d(channels[1])),
('relu', nn.LeakyReLU(0.2, True))
]))
else:
self.model = nn.Sequential(OrderedDict([
('conv', nn.Conv2d(channels[0], channels[1], kernel_size=4, stride=stride, padding=1)),
('relu', nn.LeakyReLU(0.2, True))
]))
def forward(self, x):
return self.model(x)
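# Discriminator: C64-C128-C256-C512 followed by a 1-channel conv; forward returns flattened per-patch logits.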
class ConvDiscriminator(nn.Module):
def __init__(self):
super(ConvDiscriminator, self).__init__()
self.model = nn.Sequential(OrderedDict([
('C64', ConvForDisc(3, 64, use_norm=False)),
('C128', ConvForDisc(128)),
('C256', ConvForDisc(256)),
('C512', ConvForDisc(512, stride=1)),
('conv1channel', nn.Conv2d(512, 1, kernel_size=4, padding=1))
]))
def forward(self, x):
# predicts logits
return torch.flatten(self.model(x), start_dim=1)
class CycleGAN(nn.Module):
    def __init__(self, res_blocks=9, use_dropout=False):
        super(CycleGAN, self).__init__()
        # Two generators (A->B and B->A) plus one discriminator per domain
        self.a2b_generator = ResGenerator(res_blocks=res_blocks, use_dropout=use_dropout)
        self.a_discriminator = ConvDiscriminator()
        self.b2a_generator = ResGenerator(res_blocks=res_blocks, use_dropout=use_dropout)
        self.b_discriminator = ConvDiscriminator()
@st.cache_resource
def load_model():
    # Load the trained CycleGAN weights once and cache the model across reruns
    checkpoint = torch.load('cycle_gan#21.pt', weights_only=False,
                            map_location=torch.device('cpu'))
    model = CycleGAN()
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()  # inference only: put BatchNorm/Dropout layers into eval mode
    return model
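# Per-channel mean/std statistics of the night and day image domains, used for (de)normalization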
mean_night = np.array([0.46207718, 0.52259593, 0.54372674])
mean_day = np.array([0.18620284, 0.18614635, 0.20172116])
std_night = np.array([0.21945059, 0.20839803, 0.2328357 ])
std_day = np.array([0.16982935, 0.14963816, 0.14965146])
# front part
st.markdown("<h1 style='text-align: center;'>Change daytime!</h1>", unsafe_allow_html=True)
def add_calls():
st.session_state.calls += 1
st.write(f'{st.session_state.calls=}')
def convert_day2night():
    # Helper (not wired into the UI flow below): shows the original image next to its day->night translation
    image = st.session_state.image
    col1, col2 = st.columns(2)
    with col1:
        st.write("Original")
        st.image(image, channels="BGR", use_container_width=True)
    with col2:
        st.write("Transformed")
        model = load_model()
        with torch.no_grad():
            transform, de_norm = get_transforms(mean_day, std_day)
            batch = transform(image)[None, :, :, :]
            batch_tr = model.a2b_generator(batch)
            img_tr = de_norm(batch_tr[0, :, :, :])
            st.write(img_tr.shape)
            st.image([image, img_tr], channels="BGR", use_container_width=True, clamp=True)
def convert_night2day():
    # Helper (not wired into the UI flow below): shows the original image next to its night->day translation
    image = st.session_state.image
    col1, col2 = st.columns(2)
    with col1:
        st.write("Original")
        st.image(image, channels="BGR", use_container_width=True)
    with col2:
        st.write("Transformed")
        model = load_model()
        with torch.no_grad():
            transform, de_norm = get_transforms(mean_night, std_night)
            batch = transform(image)[None, :, :, :]
            batch_tr = model.b2a_generator(batch)
            img_tr = de_norm(batch_tr[0, :, :, :])
            st.write(img_tr.shape)
            st.image([image, img_tr], channels="BGR", use_container_width=True, clamp=True)
def zero_calls():
st.session_state.calls = 0
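# Main UI flow: pick a translation direction, upload a JPEG, then show the original and translated images side by side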
st.session_state.option = st.selectbox('day2night OR night2day', ['day2night', 'night2day'])
uploaded_file = st.file_uploader("Choose an image file", type="jpg")
if uploaded_file is not None:
    # Decode the uploaded bytes into an OpenCV (BGR) image
    file_bytes = np.asarray(bytearray(uploaded_file.read()), dtype=np.uint8)
    opencv_image = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR)
    st.session_state.image = np.asarray(opencv_image)
    image = st.session_state.image
    col1, col2 = st.columns(2)
    with col1:
        st.write("Original")
        st.image(opencv_image, channels="BGR", use_container_width=True)
    with col2:
        st.write("Transformed")
        model = load_model()
        with torch.no_grad():
            if st.session_state.option == 'day2night':
                transform, de_norm = get_transforms(mean_day, std_day)
                batch = transform(image)[None, :, :, :]
                batch_tr = model.a2b_generator(batch)
                img_tr = de_norm(batch_tr[0, :, :, :])
                st.image(img_tr, channels="BGR", use_container_width=True, clamp=True)
            else:
                transform, de_norm = get_transforms(mean_night, std_night)
                batch = transform(image)[None, :, :, :]
                batch_tr = model.b2a_generator(batch)
                img_tr = de_norm(batch_tr[0, :, :, :])
                st.image(img_tr, channels="BGR", use_container_width=True, clamp=True)