# Spaces: Runtime error
import torch | |
from torchvision.utils import make_grid | |
from torchvision import transforms | |
import torchvision.transforms.functional as TF | |
from torch import nn, optim | |
from torch.optim.lr_scheduler import CosineAnnealingLR | |
from torch.utils.data import DataLoader, Dataset | |
from huggingface_hub import hf_hub_download | |
import requests | |
import gradio as gr | |
class Upsample(nn.Module):
    """Decoder stage: transposed conv -> instance norm -> ReLU, then optional dropout and skip concat.

    Args:
        in_channels: Number of channels of the incoming feature map.
        out_channels: Number of channels produced by the transposed convolution.
        kernel_size: Kernel size forwarded to ``nn.ConvTranspose2d``.
        stride: Stride forwarded to ``nn.ConvTranspose2d``.
        padding: Padding forwarded to ``nn.ConvTranspose2d``.
        dropout: When true, ``forward`` applies ``nn.Dropout2d(0.5)`` to the block output.
    """

    def __init__(self, in_channels, out_channels, kernel_size=4, stride=2, padding=1, dropout=True):
        super().__init__()
        self.dropout = dropout
        self.block = nn.Sequential(
            # The original passed the class ``nn.InstanceNorm2d`` as ``bias``. A class is
            # truthy, so the bias was always enabled; spell that out as ``True`` so the
            # parameter layout (and any saved checkpoint) stays the same.
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding, bias=True),
            nn.InstanceNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )
        # Always created, even when ``dropout`` is false, so the attribute set of the
        # module does not depend on the flag.
        self.dropout_layer = nn.Dropout2d(0.5)

    def forward(self, x, shortcut=None):
        """Run the block on ``x`` and, if ``shortcut`` is given, concatenate it on dim 1.

        Args:
            x: Input tensor for the transposed convolution.
            shortcut: Optional tensor concatenated after ``x`` along dim 1.

        Returns:
            The upsampled tensor, with ``shortcut`` appended along dim 1 when provided.
        """
        x = self.block(x)
        if self.dropout:
            x = self.dropout_layer(x)
        if shortcut is not None:
            x = torch.cat([x, shortcut], dim=1)
        return x
class Downsample(nn.Module):
    """Encoder stage: conv -> optional instance norm -> LeakyReLU(0.2).

    Args:
        in_channels: Number of channels of the incoming feature map.
        out_channels: Number of channels produced by the convolution.
        kernel_size: Kernel size forwarded to ``nn.Conv2d``.
        stride: Stride forwarded to ``nn.Conv2d``.
        padding: Padding forwarded to ``nn.Conv2d``.
        apply_instancenorm: When true, ``forward`` applies instance norm after the conv.
    """

    def __init__(self, in_channels, out_channels, kernel_size=4, stride=2, padding=1, apply_instancenorm=True):
        super().__init__()
        # The original passed the class ``nn.InstanceNorm2d`` as ``bias``. A class is
        # truthy, so the bias was always enabled; ``True`` states that explicitly and
        # keeps the parameter layout unchanged.
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=True)
        # Always created, even when normalisation is skipped in ``forward``.
        self.norm = nn.InstanceNorm2d(out_channels)
        self.relu = nn.LeakyReLU(0.2, inplace=True)
        self.apply_norm = apply_instancenorm

    def forward(self, x):
        """Apply the convolution, the optional instance norm, then the leaky ReLU.

        Args:
            x: Input tensor for the convolution.

        Returns:
            The activated feature map.
        """
        x = self.conv(x)
        if self.apply_norm:
            x = self.norm(x)
        x = self.relu(x)
        return x
class CycleGAN_Unet_Generator(nn.Module):
    """U-Net style generator built from ``Downsample`` and ``Upsample`` stages.

    Seven encoder stages are followed by six decoder stages; each decoder stage is
    given the matching encoder output as its shortcut. A final transposed
    convolution plus ``Tanh`` produces a 3-channel output.

    Args:
        filter: Base channel width; deeper stages use multiples of it.
    """

    def __init__(self, filter=64):
        super().__init__()
        width = filter

        # Shape comments assume a 256x256 input; each stage halves the spatial size.
        encoder_stages = [
            Downsample(3, width, kernel_size=4, apply_instancenorm=False),  # (b, width, 128, 128)
            Downsample(width, width * 2),  # (b, width * 2, 64, 64)
            Downsample(width * 2, width * 4),  # (b, width * 4, 32, 32)
            Downsample(width * 4, width * 8),  # (b, width * 8, 16, 16)
            Downsample(width * 8, width * 8),  # (b, width * 8, 8, 8)
            Downsample(width * 8, width * 8),  # (b, width * 8, 4, 4)
            Downsample(width * 8, width * 8),  # (b, width * 8, 2, 2)
        ]
        self.downsamples = nn.ModuleList(encoder_stages)

        # From the second stage on, input widths are doubled because the previous
        # stage's output arrives concatenated with an encoder shortcut.
        decoder_stages = [
            Upsample(width * 8, width * 8),
            Upsample(width * 16, width * 8),
            Upsample(width * 16, width * 8),
            Upsample(width * 16, width * 4, dropout=False),
            Upsample(width * 8, width * 2, dropout=False),
            Upsample(width * 4, width, dropout=False),
        ]
        self.upsamples = nn.ModuleList(decoder_stages)

        self.last = nn.Sequential(
            nn.ConvTranspose2d(width * 2, 3, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        )

    def forward(self, x):
        """Encode ``x``, decode with skip connections, and return the tanh-activated output."""
        encoder_outputs = []
        for down in self.downsamples:
            x = down(x)
            encoder_outputs.append(x)

        # Skip the deepest output (it is the decoder's input) and walk the rest
        # from deepest to shallowest.
        for up, skip in zip(self.upsamples, encoder_outputs[-2::-1]):
            x = up(x, skip)

        return self.last(x)
class ImageTransform:
    """Callable holding one torchvision pipeline per phase (``'train'`` and ``'test'``).

    Both pipelines resize to ``(img_size, img_size)``, convert to a tensor and
    normalise with mean 0.5 and std 0.5; the ``'train'`` pipeline additionally
    applies random horizontal and vertical flips.

    Args:
        img_size: Side length passed to ``transforms.Resize`` for both phases.
    """

    def __init__(self, img_size=256):
        self.transform = {
            'train': transforms.Compose([
                transforms.Resize((img_size, img_size)),
                transforms.RandomHorizontalFlip(),
                transforms.RandomVerticalFlip(),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.5], std=[0.5]),
            ]),
            'test': transforms.Compose([
                transforms.Resize((img_size, img_size)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.5], std=[0.5]),
            ]),  # the original closed this list with '}' instead of ']'
        }

    def __call__(self, img, phase='train'):
        """Apply the pipeline registered for ``phase`` to ``img``.

        Args:
            img: Image handed to the selected pipeline.
            phase: Key into ``self.transform``; ``'train'`` or ``'test'``.

        Returns:
            The output of the selected pipeline.

        Raises:
            KeyError: If ``phase`` is not a key of ``self.transform``.
        """
        img = self.transform[phase](img)

        return img
# Fetch 'model.bin' from the 'huggan/NeonGAN' Hub repository and load it on the CPU.
path = hf_hub_download('huggan/NeonGAN', 'model.bin')
# SECURITY: torch.load unpickles the downloaded file, which can execute arbitrary
# code; only load checkpoints from a repository you trust.
# NOTE(review): if model.bin is a fully pickled module rather than a state dict,
# recent PyTorch releases may require weights_only=False here -- confirm.
model_gen_n = torch.load(path, map_location=torch.device('cpu'))