# Image preprocessing utilities: aspect-preserving resize-with-padding,
# torchvision transform pipelines, and a scalar EMA tracker.
import torch | |
import torchvision.transforms as T | |
import PIL.Image | |
from typing import List | |
# Default square output resolution (width, height) used by the transforms below.
size = (224, 224)
class ResizeWithPadding:
    """Resize a PIL image to fit inside a square of side ``target_size``,
    preserving aspect ratio, then center-pad the shorter dimension.

    Args:
        target_size: Side length of the square output image.
        fill: Pad value passed to ``PIL.Image.new`` (an int; applied per-band).
        mode: PIL image mode of the output. "RGB" resizes with bicubic
            resampling; any other mode (e.g. "L" masks) uses nearest-neighbor
            so discrete values are not interpolated.
    """

    def __init__(self, target_size: int = 224, fill: int = 0, mode: str = "RGB") -> None:
        self.target_size = target_size
        self.fill = fill
        self.mode = mode

    def __call__(self, image: PIL.Image.Image) -> PIL.Image.Image:
        """Return ``image`` resized and center-padded to a square.

        Note: the annotation is fixed from ``PIL.Image`` (a module) to
        ``PIL.Image.Image`` (the actual image class).
        """
        original_width, original_height = image.size
        aspect_ratio = original_width / original_height
        if aspect_ratio > 1:  # landscape: width is the limiting side
            new_width = self.target_size
            # max(1, ...): for extremely wide inputs int() truncates to 0 and
            # Image.resize((w, 0)) would raise; clamp to a 1-pixel dimension.
            new_height = max(1, int(self.target_size / aspect_ratio))
        else:  # portrait or square: height is the limiting side
            new_height = self.target_size
            new_width = max(1, int(self.target_size * aspect_ratio))
        # Nearest-neighbor for non-RGB modes keeps exact values (e.g. mask ids).
        resample = PIL.Image.BICUBIC if self.mode == "RGB" else PIL.Image.NEAREST
        resized_image = image.resize((new_width, new_height), resample)
        delta_w = self.target_size - new_width
        delta_h = self.target_size - new_height
        # Center the resized image; odd deltas put the extra pixel right/bottom.
        padding = (delta_w // 2, delta_h // 2, delta_w - delta_w // 2, delta_h - delta_h // 2)
        padded_image = PIL.Image.new(self.mode, (self.target_size, self.target_size), self.fill)
        padded_image.paste(resized_image, (padding[0], padding[1]))
        return padded_image
def get_transform(mean: List[float], std: List[float]) -> T.Compose:
    """Build the RGB preprocessing pipeline: aspect-preserving square
    resize-with-pad, tensor conversion, and per-channel normalization."""
    steps = [
        ResizeWithPadding(),
        T.ToTensor(),
        T.Normalize(mean=mean, std=std),
    ]
    return T.Compose(steps)
# Transform for segmentation masks: resize/pad in "L" mode (nearest-neighbor,
# no value interpolation), then recover integer class ids. ToTensor scales
# pixel values by 1/255, so we multiply back by 255; .round() is required
# before .long() because the float32 round-trip can yield values just below
# the integer (e.g. 127.99999), which bare truncation would turn into the
# wrong class id.
mask_transform = T.Compose([
    ResizeWithPadding(mode="L"),
    T.ToTensor(),
    T.Lambda(lambda x: (x * 255).round().long()),
])
class EMA:
    """Exponential moving average of a scalar stream.

    The first observation seeds the average; each subsequent observation
    ``v`` updates it as ``alpha * avg + (1 - alpha) * v``.
    """

    def __init__(self, alpha: float = 0.9) -> None:
        self.alpha = alpha
        self.value = None  # unset until the first observation arrives

    def __call__(self, value: float) -> float:
        """Fold ``value`` into the running average and return the new average."""
        previous = self.value
        if previous is None:
            self.value = value
        else:
            self.value = self.alpha * previous + (1 - self.alpha) * value
        return self.value