"""Gradio app that loads a pretrained image generator from the Hugging Face Hub.

At import time this module downloads ``generator.pth``, ``config.json`` and
``model.py`` from the ``Kiwinicki/sat2map-generator`` repository, builds a
``Generator`` from the JSON config, loads the weights, and launches a Gradio
image-to-image interface.

NOTE(review): the interface handler (``greet``) currently returns its input
unchanged and never calls the loaded generator -- confirm whether inference
is meant to be wired in.
"""

import json
from abc import ABC, abstractmethod

import gradio as gr
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from omegaconf import DictConfig, OmegaConf
from torch import Tensor, tanh

# NOTE(review): this name is rebound by the ``Generator`` class defined below,
# so the imported object is never used. The import is kept so that nothing the
# file previously required is dropped.
from model import Generator


class BaseGenerator(ABC, nn.Module):
    """Abstract base for generator networks; stores the image channel count."""

    def __init__(self, channels: int = 3):
        super().__init__()
        self.channels = channels

    @abstractmethod
    def forward(self, x: Tensor) -> Tensor:
        """Map an input tensor to an output tensor."""


class Generator(BaseGenerator):
    """Convolutional encoder / residual / decoder generator.

    The network is built from ``cfg`` and reads three fields from it:
    ``channels``, ``num_features`` and ``num_residuals``.
    """

    def __init__(self, cfg: DictConfig):
        super().__init__(cfg.channels)
        self.cfg = cfg
        self.model = self._construct_model()

    def _construct_model(self) -> nn.Sequential:
        """Assemble the full network as a single ``nn.Sequential``.

        Layout: 7x7 conv + ReLU -> two stride-2 ``ConvBlock`` downsamplers ->
        ``num_residuals`` ``ResidualBlock``s -> two stride-2 transposed-conv
        ``ConvBlock`` upsamplers -> 7x7 conv back to ``channels``.

        The nesting of ``nn.Sequential`` containers determines the parameter
        names in the ``state_dict``, so it must not be flattened or reordered
        if existing checkpoints are to keep loading.
        """
        channels = self.cfg.channels
        features = self.cfg.num_features

        initial_layer = nn.Sequential(
            nn.Conv2d(
                channels,
                features,
                kernel_size=7,
                stride=1,
                padding=3,
                padding_mode="reflect",
            ),
            nn.ReLU(inplace=True),
        )
        down_blocks = nn.Sequential(
            ConvBlock(
                features,
                features * 2,
                kernel_size=3,
                stride=2,
                padding=1,
            ),
            ConvBlock(
                features * 2,
                features * 4,
                kernel_size=3,
                stride=2,
                padding=1,
            ),
        )
        residual_blocks = nn.Sequential(
            *[
                ResidualBlock(features * 4)
                for _ in range(self.cfg.num_residuals)
            ]
        )
        up_blocks = nn.Sequential(
            ConvBlock(
                features * 4,
                features * 2,
                down=False,
                kernel_size=3,
                stride=2,
                padding=1,
                output_padding=1,
            ),
            ConvBlock(
                features * 2,
                features,
                down=False,
                kernel_size=3,
                stride=2,
                padding=1,
                output_padding=1,
            ),
        )
        last_layer = nn.Conv2d(
            features,
            channels,
            kernel_size=7,
            stride=1,
            padding=3,
            padding_mode="reflect",
        )
        return nn.Sequential(
            initial_layer, down_blocks, residual_blocks, up_blocks, last_layer
        )

    def forward(self, x: Tensor) -> Tensor:
        """Run the network and squash the result with ``tanh``."""
        return tanh(self.model(x))


class ConvBlock(nn.Module):
    """Convolution (or transposed convolution) + InstanceNorm + optional ReLU.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        down: If true, use ``nn.Conv2d`` with reflect padding; otherwise use
            ``nn.ConvTranspose2d``.
        use_activation: If true, append ``nn.ReLU``; otherwise ``nn.Identity``.
        **kwargs: Forwarded unchanged to the chosen convolution layer.
    """

    def __init__(
        self, in_channels, out_channels, down=True, use_activation=True, **kwargs
    ):
        super().__init__()
        if down:
            conv = nn.Conv2d(
                in_channels, out_channels, padding_mode="reflect", **kwargs
            )
        else:
            conv = nn.ConvTranspose2d(in_channels, out_channels, **kwargs)
        activation = nn.ReLU(inplace=True) if use_activation else nn.Identity()
        self.conv = nn.Sequential(
            conv,
            nn.InstanceNorm2d(out_channels),
            activation,
        )

    def forward(self, x: Tensor) -> Tensor:
        """Apply the conv / norm / activation stack to ``x``."""
        return self.conv(x)


class ResidualBlock(nn.Module):
    """Two 3x3 ``ConvBlock``s with a skip connection around them.

    The second ``ConvBlock`` is built with ``use_activation=False``.
    """

    def __init__(self, channels: int):
        super().__init__()
        self.block = nn.Sequential(
            ConvBlock(channels, channels, kernel_size=3, padding=1),
            ConvBlock(
                channels, channels, use_activation=False, kernel_size=3, padding=1
            ),
        )

    def forward(self, x: Tensor) -> Tensor:
        """Return ``x`` plus the output of the inner block."""
        return x + self.block(x)


repo_id = "Kiwinicki/sat2map-generator"
generator_path = hf_hub_download(repo_id=repo_id, filename="generator.pth")
config_path = hf_hub_download(repo_id=repo_id, filename="config.json")
# Downloaded but not otherwise used in this file; kept as an existing side effect.
model_path = hf_hub_download(repo_id=repo_id, filename="model.py")

with open(config_path, "r", encoding="utf-8") as f:
    config_dict = json.load(f)

cfg = OmegaConf.create(config_dict)
generator = Generator(cfg)
# map_location="cpu" lets a checkpoint saved from a CUDA device load on a
# CPU-only host; the module itself is never moved off the CPU in this file.
# NOTE(review): torch.load unpickles the file. If the installed torch version
# supports it, consider passing weights_only=True for downloaded checkpoints.
generator.load_state_dict(torch.load(generator_path, map_location="cpu"))
generator.eval()


def greet(image):
    """Gradio handler: return the submitted image unchanged.

    NOTE(review): presumably a placeholder -- ``generator`` is not invoked.
    """
    return image


iface = gr.Interface(fn=greet, inputs="image", outputs="image")
iface.launch()