|
import torch |
|
from PIL import Image as PilImage |
|
|
|
from deoldify.filters import IFilter, BaseFilter |
|
from deoldify.visualize import ModelImageVisualizer |
|
from fastai.basic_train import Learner |
|
from fastai.vision import normalize_funcs |
|
|
|
# (mean, std) pair fed to fastai's `normalize_funcs` below; one entry per
# image channel (presumably RGB order). The values match the commonly used
# ImageNet normalization statistics.
stats = (
    [0.485, 0.456, 0.406],  # per-channel mean
    [0.229, 0.224, 0.225],  # per-channel standard deviation
)
|
|
|
|
|
class ImageFilter(BaseFilter):
    """Run a single PIL image through the model wrapped by ``BaseFilter``.

    The model input size is ``render_factor * render_base``. The model output
    is then passed, together with a copy of the input image, to the
    inherited ``_unsquare`` helper (defined in ``BaseFilter``, not shown here).
    """

    # Render factor used when the caller passes ``None`` or omits the argument.
    DEFAULT_RENDER_FACTOR = 35

    def __init__(self, learn: Learner):
        """Wrap ``learn`` and set up device and normalization helpers.

        Args:
            learn: fastai ``Learner`` whose model performs the filtering.
        """
        super().__init__(learn)
        # The size handed to the model is render_factor * render_base.
        self.render_base = 16
        # NOTE(review): this replaces any device chosen by BaseFilter.__init__
        # and picks CUDA whenever it is available, without checking where the
        # learner's model actually lives. Confirm the model is on the same
        # device, or derive the device from the model parameters instead.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # (normalize, denormalize) pair built from the module-level `stats`.
        self.norm, self.denorm = normalize_funcs(*stats)

    def filter(self, filtered_image: PilImage, render_factor=DEFAULT_RENDER_FACTOR) -> PilImage:
        """Process ``filtered_image`` with the model and return the result.

        Args:
            filtered_image: input PIL image.
            render_factor: multiplier applied to ``render_base`` to obtain
                the ``sz`` passed to ``_model_process``. ``None`` means
                "use the default" (35).

        Returns:
            The image produced by ``_unsquare`` from the model output and a
            copy of the input.

        Raises:
            ValueError: if ``render_factor`` is less than 1.
        """
        if render_factor is None:
            # Callers may forward an optional value (ModelImageColorizer's
            # get_colored_image defaults to None); multiplying None would
            # raise TypeError, so fall back to the default instead.
            render_factor = self.DEFAULT_RENDER_FACTOR
        if render_factor < 1:
            # A zero or negative size can never produce a usable model input;
            # fail early with a clear message.
            raise ValueError(f"render_factor must be >= 1, got {render_factor!r}")

        # Keep an untouched copy of the input; it is handed to `_unsquare` as
        # the reference image once the model has run.
        orig_image = filtered_image.copy()

        render_sz = render_factor * self.render_base
        model_image = self._model_process(orig=filtered_image, sz=render_sz)
        return self._unsquare(model_image, orig_image)
|
|
|
|
|
class ModelImageColorizer(ModelImageVisualizer):
    """Visualizer that delegates image colorization to an ``IFilter``."""

    def __init__(self, filter: IFilter):
        """Store ``filter``; it is used by ``get_colored_image``.

        NOTE(review): ``ModelImageVisualizer.__init__`` is not called here, so
        any attributes it would set are absent on this instance. Confirm that
        no inherited method used with this class relies on them.

        Args:
            filter: object whose ``filter(image, render_factor=...)`` method
                performs the colorization.
        """
        self.filter = filter

    def get_colored_image(self, image, render_factor: int = None) -> PilImage:
        """Run ``image`` through ``self.filter`` after the inherited memory cleanup.

        Args:
            image: input PIL image.
            render_factor: forwarded to the filter as a keyword argument. When
                ``None``, it is not forwarded at all, so the filter's own
                default applies instead of receiving ``None``.

        Returns:
            Whatever ``self.filter.filter`` returns.
        """
        # Inherited hook (not shown here); presumably releases cached memory
        # before the model runs.
        self._clean_mem()
        if render_factor is None:
            # Filters that compute with render_factor (e.g. multiply it by a
            # base size) must never receive None; let them use their default.
            return self.filter.filter(image)
        return self.filter.filter(image, render_factor=render_factor)
|
|