import streamlit as st
import time
import cv2
import numpy as np

# model part
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms as tr
from torchvision.transforms import v2
from sklearn.preprocessing import minmax_scale
from collections import OrderedDict

# initialise session state only once per session, not on every rerun
if 'image' not in st.session_state:
    st.session_state.image = None
if 'calls' not in st.session_state:
    st.session_state.calls = 0


def get_transforms(mean, std):
    val_transform = tr.Compose([
        tr.ToPILImage(),
        v2.Resize(size=256),
        tr.ToTensor(),
        # ...,
        tr.Normalize(mean=mean, std=std)
    ])

    def de_normalize(img):
        # undo Normalize (x * std + mean), then rescale each channel to [0, 1]
        if isinstance(img, torch.Tensor):
            image = img.detach().cpu().numpy()
        else:
            image = img
        return minmax_scale(
            image.reshape(3, -1) * std[:, None] + mean[:, None],
            feature_range=(0., 1.),
            axis=1,
        ).reshape(*img.shape).transpose(1, 2, 0)

    return val_transform, de_normalize


class Conv7Stride1(nn.Module):
    # c7s1-k block: reflection pad + 7x7 conv, stride 1
    def __init__(self, in_channels, out_channels, use_norm=True):
        super(Conv7Stride1, self).__init__()
        if use_norm:
            self.model = nn.Sequential(OrderedDict([
                ('pad', nn.ReflectionPad2d(3)),
                ('conv', nn.Conv2d(in_channels, out_channels, kernel_size=7)),
                ('norm', nn.InstanceNorm2d(out_channels)),
                ('relu', nn.ReLU())
            ]))
        else:
            self.model = nn.Sequential(OrderedDict([
                ('pad', nn.ReflectionPad2d(3)),
                ('conv', nn.Conv2d(in_channels, out_channels, kernel_size=7)),
                ('tanh', nn.Tanh())
            ]))

    def forward(self, x):
        return self.model(x)


class Down(nn.Module):
    # dk block: 3x3 stride-2 conv that halves the spatial size and doubles channels
    def __init__(self, k):
        super(Down, self).__init__()
        self.model = nn.Sequential(OrderedDict([
            ('conv', nn.Conv2d(k // 2, k, kernel_size=3, stride=2, padding=1)),
            ('norm', nn.InstanceNorm2d(k)),
            ('relu', nn.ReLU())
        ]))

    def forward(self, x):
        return self.model(x)


class ResBlock(nn.Module):
    # Rk block: two pad-conv-norm-ReLU sub-blocks plus a skip connection
    def __init__(self, k, use_dropout=False):
        super(ResBlock, self).__init__()
        self.blocks = []
        for _ in range(2):
            self.blocks += [nn.Sequential(OrderedDict([
                ('pad', nn.ReflectionPad2d(1)),
                ('conv', nn.Conv2d(k, k, kernel_size=3)),
                # norm layer registered under the key 'dropout'; kept as-is so
                # the state_dict keys match the trained checkpoint
                ('dropout', nn.BatchNorm2d(k)),
                ('relu', nn.ReLU())
            ]))]
        if use_dropout:
            self.model = nn.Sequential(OrderedDict([
                ('block1', self.blocks[0]),
                ('dropout', nn.Dropout(0.5)),
                ('block2', self.blocks[1])
            ]))
        else:
            self.model = nn.Sequential(OrderedDict([
                ('block1', self.blocks[0]),
                ('block2', self.blocks[1])
            ]))

    def forward(self, x):
        return x + self.model(x)


class Up(nn.Module):
    # uk block: fractionally-strided conv that doubles the spatial size and halves channels
    def __init__(self, k):
        super(Up, self).__init__()
        self.model = nn.Sequential(OrderedDict([
            ('conv_transpose', nn.ConvTranspose2d(2 * k, k, kernel_size=3, padding=1,
                                                  output_padding=1, stride=2)),
            ('norm', nn.InstanceNorm2d(k)),
            ('relu', nn.ReLU())
        ]))

    def forward(self, x):
        return self.model(x)


class ResGenerator(nn.Module):
    # ResNet-based generator: c7s1-64, d128, d256, res_blocks x R256, u128, u64, c7s1-3
    def __init__(self, res_blocks=9, use_dropout=False):
        super(ResGenerator, self).__init__()
        self.residual_blocks = nn.Sequential(OrderedDict([
            (f'R256_{i+1}', ResBlock(256, use_dropout=use_dropout))
            for i in range(res_blocks)
        ]))
        self.model = nn.Sequential(OrderedDict([
            ('c7s1-64', Conv7Stride1(3, 64)),
            ('d128', Down(128)),
            ('d256', Down(256)),
            ('res_blocks', self.residual_blocks),
            ('u128', Up(128)),
            ('u64', Up(64)),
            ('c7s1-3', Conv7Stride1(64, 3, use_norm=False))
        ]))

    def forward(self, x):
        return self.model(x)


class ConvForDisc(nn.Module):
    # Ck block of the discriminator: 4x4 conv, optional InstanceNorm, LeakyReLU
    def __init__(self, *channels, stride=2, use_norm=True):
        super(ConvForDisc, self).__init__()
        if len(channels) == 1:
            # a single argument k means k//2 -> k channels
            channels = (channels[0] // 2, channels[0])
        if use_norm:
            self.model = nn.Sequential(OrderedDict([
                ('conv', nn.Conv2d(channels[0], channels[1], kernel_size=4,
                                   stride=stride, padding=1)),
                ('norm', nn.InstanceNorm2d(channels[1])),
                ('relu', nn.LeakyReLU(0.2, True))
            ]))
        else:
            self.model = nn.Sequential(OrderedDict([
                ('conv', nn.Conv2d(channels[0], channels[1], kernel_size=4,
                                   stride=stride, padding=1)),
                ('relu', nn.LeakyReLU(0.2, True))
            ]))

    def forward(self, x):
        return self.model(x)


class ConvDiscriminator(nn.Module):
    # PatchGAN discriminator: C64-C128-C256-C512 followed by a 1-channel conv
    def __init__(self):
        super(ConvDiscriminator, self).__init__()
        self.model = nn.Sequential(OrderedDict([
            ('C64', ConvForDisc(3, 64, use_norm=False)),
            ('C128', ConvForDisc(128)),
            ('C256', ConvForDisc(256)),
            ('C512', ConvForDisc(512, stride=1)),
            ('conv1channel', nn.Conv2d(512, 1, kernel_size=4, padding=1))
        ]))

    def forward(self, x):
        # predicts per-patch logits
        return torch.flatten(self.model(x), start_dim=1)


class CycleGAN(nn.Module):
    # two generators (A->B, B->A) and one discriminator per domain
    def __init__(self, res_blocks=9, use_dropout=False):
        super(CycleGAN, self).__init__()
        self.a2b_generator = ResGenerator(res_blocks=res_blocks, use_dropout=use_dropout)
        self.a_discriminator = ConvDiscriminator()
        self.b2a_generator = ResGenerator(res_blocks=res_blocks, use_dropout=use_dropout)
        self.b_discriminator = ConvDiscriminator()


@st.cache_resource
def load_model():
    checkpoint = torch.load('cycle_gan#21.pt', weights_only=False,
                            map_location=torch.device('cpu'))
    model = CycleGAN()
    model.load_state_dict(checkpoint['model_state_dict'])
    return model


mean_night = np.array([0.46207718, 0.52259593, 0.54372674])
mean_day = np.array([0.18620284, 0.18614635, 0.20172116])
std_night = np.array([0.21945059, 0.20839803, 0.2328357])
std_day = np.array([0.16982935, 0.14963816, 0.14965146])
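
# --- minimal usage sketch (not part of the original app flow) ---
# Shows how the pieces above could be wired together for inference.
# The helper name `stylize`, the domain assignment (A = night, B = day) and
# denormalising with the target domain's statistics are assumptions; swap the
# generators and the mean/std pairs if the checkpoint uses the opposite convention.
def stylize(image_bgr: np.ndarray, night_to_day: bool = True) -> np.ndarray:
    model = load_model()
    model.eval()
    if night_to_day:
        transform, _ = get_transforms(mean_night, std_night)   # normalise the input
        _, de_normalize = get_transforms(mean_day, std_day)    # denormalise the output
        generator = model.a2b_generator
    else:
        transform, _ = get_transforms(mean_day, std_day)
        _, de_normalize = get_transforms(mean_night, std_night)
        generator = model.b2a_generator
    # OpenCV delivers BGR, the transforms expect RGB
    rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
    with torch.no_grad():
        fake = generator(transform(rgb).unsqueeze(0))[0]
    return de_normalize(fake)  # H x W x 3 float array in [0, 1]
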
# front part
st.markdown("