deoldify / model_image_colorizer.py
leonelhs's picture
init app
56ecbb4
raw
history blame
1.2 kB
import torch
from PIL import Image as PilImage
from deoldify.filters import IFilter, BaseFilter
from deoldify.visualize import ModelImageVisualizer
from fastai.basic_train import Learner
from fastai.vision import normalize_funcs
# Two three-element lists that ImageFilter unpacks into fastai's
# ``normalize_funcs(*stats)`` to build its norm/denorm callables.
# NOTE(review): the values look like the usual ImageNet RGB mean / std
# statistics -- confirm against the model's training setup.
stats = (
    [0.485, 0.456, 0.406],
    [0.229, 0.224, 0.225],
)
class ImageFilter(BaseFilter):
    """Filter that runs an image through the wrapped learner at a given size.

    The processing size is ``render_factor * render_base``; the model output
    is then passed to ``_unsquare`` together with a copy of the input image.
    ``_model_process`` and ``_unsquare`` are not defined in this class and are
    presumably inherited from ``BaseFilter`` -- verify there.
    """

    # Render factor used when the caller passes nothing (or ``None``).
    DEFAULT_RENDER_FACTOR = 35

    def __init__(self, learn: Learner):
        """Store the learner via the base class and set up filter state.

        Args:
            learn: fastai ``Learner`` forwarded to ``BaseFilter.__init__``.
        """
        super().__init__(learn)
        # Multiplier that converts a render factor into a processing size.
        self.render_base = 16
        # Use CUDA when torch reports it as available, otherwise the CPU.
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu'
        )
        self.norm, self.denorm = normalize_funcs(*stats)

    def filter(
        self, filtered_image: PilImage, render_factor=DEFAULT_RENDER_FACTOR
    ) -> PilImage:
        """Process ``filtered_image`` and return the model's result.

        Args:
            filtered_image: Image to process. A copy is taken up front and
                used as the reference image for ``_unsquare``.
            render_factor: Multiplied by ``render_base`` to obtain the size
                handed to ``_model_process``. ``None`` is treated as "use
                ``DEFAULT_RENDER_FACTOR``" so that callers which forward an
                unset (``None``) render factor do not crash on the
                multiplication below.

        Returns:
            The value returned by ``_unsquare`` for the model output.
        """
        if render_factor is None:
            render_factor = self.DEFAULT_RENDER_FACTOR

        orig_image = filtered_image.copy()
        render_sz = render_factor * self.render_base
        model_image = self._model_process(orig=filtered_image, sz=render_sz)
        return self._unsquare(model_image, orig_image)
class ModelImageColorizer(ModelImageVisualizer):
    """Thin visualizer that delegates colorization to a single filter."""

    def __init__(self, filter: IFilter):
        """Keep a reference to ``filter``.

        NOTE(review): the parent ``__init__`` is not called here; this looks
        deliberate (only ``self.filter`` is needed below) -- confirm.
        """
        self.filter = filter

    def get_colored_image(self, image, render_factor: int = None) -> PilImage:
        """Return the result of applying the stored filter to ``image``.

        ``_clean_mem`` is called first, then ``image`` and ``render_factor``
        (forwarded unchanged, including ``None``) are handed to the filter's
        ``filter`` method.
        """
        self._clean_mem()
        active_filter = self.filter
        colored = active_filter.filter(image, render_factor=render_factor)
        return colored