|
import gc |
|
|
|
import PIL.Image |
|
import numpy as np |
|
import torch |
|
from controlnet_aux import (CannyDetector, ContentShuffleDetector, HEDdetector, LineartAnimeDetector, LineartDetector, |
|
MidasDetector, MLSDdetector, NormalBaeDetector, OpenposeDetector, PidiNetDetector) |
|
from controlnet_aux.util import HWC3 |
|
|
|
from controlnet.cv_utils import resize_image |
|
from controlnet.depth_estimator import DepthEstimator |
|
from controlnet.image_segmentor import ImageSegmentor |
|
|
|
|
|
class ControlNet_Preprocessor:
    """Load one ControlNet annotator at a time and apply it to input images.

    Only a single annotator is kept in ``self.model``; calling :meth:`load`
    with a different name replaces it. The ``preprocess_*`` methods are
    task-specific entry points that load the annotator they need and return
    the resulting control image.
    """

    # Repository id handed to ``from_pretrained`` for annotators with weights.
    MODEL_ID = 'lllyasviel/Annotators'

    def __init__(self):
        self.model = None  # currently loaded annotator callable (None until load())
        self.name = ''  # name under which ``self.model`` was loaded

    def _create_model(self, name: str):
        """Instantiate the annotator registered under ``name``.

        Raises:
            ValueError: If ``name`` is not a known annotator.
        """
        if name == 'HED':
            return HEDdetector.from_pretrained(self.MODEL_ID)
        if name == 'Midas':
            return MidasDetector.from_pretrained(self.MODEL_ID)
        if name == 'MLSD':
            return MLSDdetector.from_pretrained(self.MODEL_ID)
        if name == 'Openpose':
            return OpenposeDetector.from_pretrained(self.MODEL_ID)
        if name == 'PidiNet':
            return PidiNetDetector.from_pretrained(self.MODEL_ID)
        if name == 'NormalBae':
            return NormalBaeDetector.from_pretrained(self.MODEL_ID)
        if name == 'Lineart':
            return LineartDetector.from_pretrained(self.MODEL_ID)
        if name == 'LineartAnime':
            return LineartAnimeDetector.from_pretrained(self.MODEL_ID)
        if name == 'Canny':
            return CannyDetector()
        if name == 'ContentShuffle':
            return ContentShuffleDetector()
        if name == 'DPT':
            return DepthEstimator()
        if name == 'UPerNet':
            return ImageSegmentor()
        raise ValueError(f'Unknown preprocessor name: {name!r}')

    @staticmethod
    def _resize_only(image, image_resolution):
        """Resize ``image`` to ``image_resolution`` without running an annotator.

        ``image`` is passed straight to ``HWC3``, so it is presumably expected
        to be a NumPy image array -- TODO confirm against callers.
        """
        image = HWC3(image)
        image = resize_image(image, resolution=image_resolution)
        return PIL.Image.fromarray(image)

    def load(self, name: str) -> None:
        """Make ``name`` the active annotator, replacing the previous one.

        Does nothing when ``name`` is already the active annotator.

        Raises:
            ValueError: If ``name`` is not a known annotator; the previously
                loaded annotator is left untouched in that case.
        """
        if name == self.name:
            return
        self.model = self._create_model(name)
        # Collect garbage first so that memory still referenced only by the
        # replaced model is freed before the CUDA cache is emptied.
        gc.collect()
        torch.cuda.empty_cache()
        self.name = name

    def __call__(self, image: 'PIL.Image.Image', **kwargs) -> 'PIL.Image.Image':
        """Run the active annotator on ``image``.

        Args:
            image: Input image; it is converted with ``np.array`` before use.
            **kwargs: Forwarded to the annotator. ``detect_resolution`` (Canny,
                Midas) and ``image_resolution`` (Midas) are consumed here and
                not forwarded for those two annotators.

        Returns:
            For Canny and Midas, a PIL image built from the annotator output;
            otherwise whatever the annotator itself returns.
        """
        if self.name == 'Canny':
            # Always hand the detector an array so its output can be wrapped by
            # ``fromarray``; resizing is optional and only done on request.
            detect_resolution = kwargs.pop('detect_resolution', None)
            image = HWC3(np.array(image))
            if detect_resolution is not None:
                image = resize_image(image, resolution=detect_resolution)
            image = self.model(image, **kwargs)
            return PIL.Image.fromarray(image)
        if self.name == 'Midas':
            detect_resolution = kwargs.pop('detect_resolution', 512)
            image_resolution = kwargs.pop('image_resolution', 512)
            image = HWC3(np.array(image))
            image = resize_image(image, resolution=detect_resolution)
            image = self.model(image, **kwargs)
            image = HWC3(image)
            image = resize_image(image, resolution=image_resolution)
            return PIL.Image.fromarray(image)
        return self.model(np.array(image), **kwargs)

    @torch.inference_mode()
    def preprocess_canny(self, image, image_resolution, low_threshold, high_threshold):
        """Return the Canny control image for ``image``.

        ``image_resolution`` is used as the resolution the input is resized to
        before edge detection.
        """
        self.load('Canny')
        return self(
            image=image,
            low_threshold=low_threshold,
            high_threshold=high_threshold,
            detect_resolution=image_resolution,
        )

    @torch.inference_mode()
    def preprocess_mlsd(self, image, image_resolution, preprocess_resolution, value_threshold, distance_threshold):
        """Return the MLSD control image for ``image``.

        ``value_threshold`` and ``distance_threshold`` are forwarded to the
        detector as ``thr_v`` and ``thr_d``.
        """
        self.load('MLSD')
        return self(
            image=image,
            image_resolution=image_resolution,
            detect_resolution=preprocess_resolution,
            thr_v=value_threshold,
            thr_d=distance_threshold,
        )

    @torch.inference_mode()
    def preprocess_scribble(self, image, image_resolution, preprocess_resolution, preprocessor_name):
        """Return the scribble control image for ``image``.

        ``preprocessor_name`` must be ``'None'`` (resize only), ``'HED'`` or
        ``'PidiNet'``.

        Raises:
            ValueError: For any other ``preprocessor_name``.
        """
        if preprocessor_name == 'None':
            return self._resize_only(image, image_resolution)
        if preprocessor_name == 'HED':
            self.load(preprocessor_name)
            return self(
                image=image,
                image_resolution=image_resolution,
                detect_resolution=preprocess_resolution,
                scribble=False,
            )
        if preprocessor_name == 'PidiNet':
            self.load(preprocessor_name)
            return self(
                image=image,
                image_resolution=image_resolution,
                detect_resolution=preprocess_resolution,
                safe=False,
            )
        raise ValueError(f'Unknown scribble preprocessor: {preprocessor_name!r}')

    @torch.inference_mode()
    def preprocess_scribble_interactive(self, image_and_mask, image_resolution):
        """Return the drawn mask from ``image_and_mask['mask']`` as the control image."""
        return self._resize_only(image_and_mask['mask'], image_resolution)

    @torch.inference_mode()
    def preprocess_softedge(self, image, image_resolution, preprocess_resolution, preprocessor_name):
        """Return the soft-edge control image for ``image``.

        ``preprocessor_name`` must be ``'None'`` (resize only), ``'HED'``,
        ``'HED safe'``, ``'PidiNet'`` or ``'PidiNet safe'``; the ``safe``
        variants enable the detector's ``safe`` option.

        Raises:
            ValueError: For any other ``preprocessor_name``.
        """
        if preprocessor_name == 'None':
            return self._resize_only(image, image_resolution)
        safe = 'safe' in preprocessor_name
        if preprocessor_name in ('HED', 'HED safe'):
            self.load('HED')
            # The "safe" variant must drive the detector's ``safe`` flag, the
            # same way the PidiNet branch does, not its ``scribble`` flag.
            return self(
                image=image,
                image_resolution=image_resolution,
                detect_resolution=preprocess_resolution,
                safe=safe,
            )
        if preprocessor_name in ('PidiNet', 'PidiNet safe'):
            self.load('PidiNet')
            return self(
                image=image,
                image_resolution=image_resolution,
                detect_resolution=preprocess_resolution,
                safe=safe,
            )
        raise ValueError(f'Unknown softedge preprocessor: {preprocessor_name!r}')

    @torch.inference_mode()
    def preprocess_openpose(self, image, image_resolution, preprocess_resolution, preprocessor_name):
        """Return the Openpose control image, or the resized input for ``'None'``.

        Any ``preprocessor_name`` other than ``'None'`` selects Openpose with
        hand and face detection enabled.
        """
        if preprocessor_name == 'None':
            return self._resize_only(image, image_resolution)
        self.load('Openpose')
        return self(
            image=image,
            image_resolution=image_resolution,
            detect_resolution=preprocess_resolution,
            hand_and_face=True,
        )

    @torch.inference_mode()
    def preprocess_segmentation(self, image, image_resolution, preprocess_resolution, preprocessor_name):
        """Return the segmentation control image, or the resized input for ``'None'``.

        Any other ``preprocessor_name`` is passed to :meth:`load` as-is.
        """
        if preprocessor_name == 'None':
            return self._resize_only(image, image_resolution)
        self.load(preprocessor_name)
        return self(
            image=image,
            image_resolution=image_resolution,
            detect_resolution=preprocess_resolution,
        )

    @torch.inference_mode()
    def preprocess_depth(self, image, image_resolution, preprocess_resolution, preprocessor_name):
        """Return the depth control image, or the resized input for ``'None'``.

        Any other ``preprocessor_name`` is passed to :meth:`load` as-is.
        """
        if preprocessor_name == 'None':
            return self._resize_only(image, image_resolution)
        self.load(preprocessor_name)
        return self(
            image=image,
            image_resolution=image_resolution,
            detect_resolution=preprocess_resolution,
        )

    @torch.inference_mode()
    def preprocess_normal(self, image, image_resolution, preprocess_resolution, preprocessor_name):
        """Return the NormalBae control image, or the resized input for ``'None'``."""
        if preprocessor_name == 'None':
            return self._resize_only(image, image_resolution)
        self.load('NormalBae')
        return self(
            image=image,
            image_resolution=image_resolution,
            detect_resolution=preprocess_resolution,
        )

    @torch.inference_mode()
    def preprocess_lineart(self, image, image_resolution, preprocess_resolution, preprocessor_name):
        """Return the line-art control image for ``image``.

        ``preprocessor_name`` must be ``'None'`` / ``'None (anime)'`` (resize
        only), ``'Lineart'``, ``'Lineart coarse'`` or ``'Lineart (anime)'``.

        Raises:
            ValueError: For any other ``preprocessor_name``.
        """
        if preprocessor_name in ('None', 'None (anime)'):
            return self._resize_only(image, image_resolution)
        if preprocessor_name in ('Lineart', 'Lineart coarse'):
            self.load('Lineart')
            return self(
                image=image,
                image_resolution=image_resolution,
                detect_resolution=preprocess_resolution,
                coarse='coarse' in preprocessor_name,
            )
        if preprocessor_name == 'Lineart (anime)':
            self.load('LineartAnime')
            return self(
                image=image,
                image_resolution=image_resolution,
                detect_resolution=preprocess_resolution,
            )
        raise ValueError(f'Unknown lineart preprocessor: {preprocessor_name!r}')

    @torch.inference_mode()
    def preprocess_shuffle(self, image, image_resolution, preprocessor_name):
        """Return the content-shuffle control image, or the resized input for ``'None'``.

        Any other ``preprocessor_name`` is passed to :meth:`load` as-is.
        """
        if preprocessor_name == 'None':
            return self._resize_only(image, image_resolution)
        self.load(preprocessor_name)
        return self(
            image=image,
            image_resolution=image_resolution,
        )

    @torch.inference_mode()
    def preprocess_ip2p(self, image, image_resolution):
        """Return ``image`` resized to ``image_resolution``; no annotator is run."""
        return self._resize_only(image, image_resolution)
|
|