"""Colorization helpers: an image filter and a thin visualizer wrapper around it."""
from typing import Optional

import torch
from PIL import Image as PilImage
from deoldify.filters import IFilter, BaseFilter
from deoldify.visualize import ModelImageVisualizer
from fastai.basic_train import Learner
from fastai.vision import normalize_funcs

# Per-channel (mean, std) lists handed to ``normalize_funcs``; these are the
# values commonly used for ImageNet normalization.
stats = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

# Render factor used whenever a caller does not supply one.
DEFAULT_RENDER_FACTOR = 35


class ImageFilter(BaseFilter):
    """Filter that runs an image through the learner's model.

    The model input size is ``render_factor * render_base`` (see ``filter``).
    """

    def __init__(self, learn: Learner):
        """Initialise the base filter and set up device and normalisation.

        Args:
            learn: Learner forwarded to ``BaseFilter.__init__``.
        """
        super().__init__(learn)
        # Multiplier that turns a render factor into the size passed to the model.
        self.render_base = 16
        # Prefer the GPU when one is available, otherwise fall back to the CPU.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.norm, self.denorm = normalize_funcs(*stats)

    def filter(
        self,
        filtered_image: PilImage.Image,
        render_factor=DEFAULT_RENDER_FACTOR,
    ) -> PilImage.Image:
        """Process ``filtered_image`` with the model and return the result.

        Args:
            filtered_image: Input image; a copy is kept so the model output
                can be related back to the original.
            render_factor: Multiplied by ``self.render_base`` to obtain the
                size handed to ``_model_process``.

        Returns:
            Whatever ``_unsquare`` produces from the model output and the
            original copy (presumably the output restored to the original
            dimensions; confirm against ``BaseFilter``).
        """
        orig_image = filtered_image.copy()
        render_sz = render_factor * self.render_base
        # ``_model_process`` and ``_unsquare`` are not defined in this file;
        # they are expected to come from ``BaseFilter``.
        model_image = self._model_process(orig=filtered_image, sz=render_sz)
        raw_color = self._unsquare(model_image, orig_image)
        return raw_color


class ModelImageColorizer(ModelImageVisualizer):
    """Visualizer wrapper that delegates colorization to a single filter."""

    def __init__(self, filter: IFilter):
        """Store the filter used by ``get_colored_image``.

        NOTE(review): ``ModelImageVisualizer.__init__`` is intentionally not
        invoked here, matching the original code; presumably this skips the
        base class's own setup. Confirm this is desired.

        Args:
            filter: Filter object whose ``filter`` method performs the work.
        """
        self.filter = filter

    def get_colored_image(
        self, image, render_factor: Optional[int] = None
    ) -> PilImage.Image:
        """Return the filter's output for ``image``.

        Args:
            image: Image forwarded to ``self.filter.filter``.
            render_factor: Render factor forwarded to the filter. ``None``
                (the default) selects ``DEFAULT_RENDER_FACTOR``; previously
                ``None`` was forwarded as-is and the filter failed with a
                ``TypeError`` when multiplying it by its render base.

        Returns:
            The image returned by the filter.
        """
        if render_factor is None:
            render_factor = DEFAULT_RENDER_FACTOR
        # ``_clean_mem`` is not defined in this file; it is expected to come
        # from ``ModelImageVisualizer``.
        self._clean_mem()
        return self.filter.filter(image, render_factor=render_factor)