|
import gc |
|
import math |
|
import os |
|
import torch |
|
from typing import Literal |
|
from PIL import Image, ImageFilter, ImageOps |
|
from PIL.ImageOps import exif_transpose |
|
from tqdm import tqdm |
|
|
|
from torchvision import transforms |
|
|
|
|
|
import warnings |
|
|
|
# Silence noisy third-party UserWarning / FutureWarning output for this module's
# consumers. Filters are prepended, so iterating in this order leaves the
# FutureWarning filter ahead of the UserWarning one, exactly as two direct calls would.
for _ignored_category in (UserWarning, FutureWarning):
    warnings.filterwarnings("ignore", category=_ignored_category)
del _ignored_category
|
|
|
|
|
def flush(garbage_collect=True):
    """Release unused host objects and cached CUDA memory.

    Args:
        garbage_collect: When True (the default), run ``gc.collect()`` before
            emptying the CUDA cache.

    The collection must happen first: tensors that are only unreachable through
    reference cycles still own allocator blocks until the garbage collector runs,
    and ``torch.cuda.empty_cache()`` can only hand back blocks that are already free.
    """
    if garbage_collect:
        gc.collect()
    torch.cuda.empty_cache()
|
|
|
|
|
# Control-image kinds accepted by ControlGenerator.get_control_path().
ControlTypes = Literal["depth", "pose", "line", "inpaint", "mask"]

# Lower-case image file extensions (leading dot included). They are used both to
# discover input images and to look for previously generated control files.
img_ext_list = [f".{suffix}" for suffix in ("jpg", "jpeg", "png", "webp")]
|
|
|
|
|
class ControlGenerator:
    """Create (or reuse cached) control images: depth, pose, line, inpaint, mask.

    Controls are stored in a ``_controls`` folder next to the source image and
    are named ``<stem>.<control_type>.<ext>``. Detector models are loaded
    lazily the first time their control type is requested. They stay on the
    instance until :meth:`cleanup` is called.

    Args:
        device: Torch device the detectors run on.
        sd: Optional model wrapper. When given, its
            ``set_device_state_preset('unload')`` is called before the first
            control is generated, and ``restore_device_state()`` is called in
            :meth:`cleanup`.
    """

    def __init__(self, device, sd=None):
        self.device = device
        self.sd = sd
        # True while self.sd is unloaded and still needs restoring in cleanup().
        self.has_unloaded = False
        # Lazily loaded detectors; None until the matching control is requested.
        self.control_depth_model = None
        self.control_pose_model = None
        self.control_line_model = None
        self.control_bg_remover = None
        # debug enables debug_print(); regen bypasses cached control files.
        self.debug = False
        self.regen = False

    @staticmethod
    def _controls_folder(img_path):
        """Return the ``_controls`` folder that sits beside *img_path*."""
        return os.path.join(os.path.dirname(img_path), '_controls')

    def get_control_path(self, img_path, control_type: ControlTypes):
        """Return the path of the *control_type* control image for *img_path*.

        Unless ``self.regen`` is set, an existing
        ``_controls/<stem>.<control_type><ext>`` file (for any ext in
        ``img_ext_list``) is returned as-is. Otherwise the control is
        generated, saved, and its new path returned.
        """
        if not self.regen:
            controls_folder = self._controls_folder(img_path)
            stem = os.path.splitext(os.path.basename(img_path))[0]
            for ext in img_ext_list:
                cached = os.path.join(
                    controls_folder, f"{stem}.{control_type}{ext}")
                if os.path.exists(cached):
                    return cached
        return self._generate_control(img_path, control_type)

    def debug_print(self, *args, **kwargs):
        """Forward to ``print`` only when ``self.debug`` is enabled."""
        if self.debug:
            print(*args, **kwargs)

    def _generate_control(self, img_path, control_type):
        """Generate and save a control image, returning its path.

        Raises:
            ValueError: if *control_type* is not one of the known types.
        """
        generators = {
            'depth': self._generate_depth,
            'pose': self._generate_pose,
            'line': self._generate_line,
            'inpaint': self._generate_inpaint,
            'mask': self._generate_mask,
        }
        generate = generators.get(control_type)
        if generate is None:
            # Validate before unloading self.sd or touching the filesystem.
            raise ValueError(f"Error: unknown control type {control_type}")

        if not self.has_unloaded and self.sd is not None:
            print("Unloading model to generate controls")
            self.sd.set_device_state_preset('unload')
            self.has_unloaded = True

        image = self._load_image(img_path)

        controls_folder = self._controls_folder(img_path)
        os.makedirs(controls_folder, exist_ok=True)
        stem = os.path.splitext(os.path.basename(img_path))[0]
        # Each generator appends its own file extension (.jpg or .webp).
        save_base = os.path.join(controls_folder, f"{stem}.{control_type}")
        return generate(image, save_base)

    @staticmethod
    def _load_image(img_path, max_size=1024 * 1024):
        """Load *img_path* as upright RGB, downscaled to at most *max_size* pixels."""
        # The context manager closes the file handle. convert() returns a
        # fully loaded copy, so the image remains usable after the file closes.
        with Image.open(img_path) as src:
            image = src.convert('RGB')
        image = exif_transpose(image)

        w, h = image.size
        if w * h > max_size:
            # Scale both sides by the same factor so the area fits the budget
            # while the aspect ratio is preserved.
            scale = math.sqrt(max_size / (w * h))
            image = image.resize((int(w * scale), int(h * scale)), Image.BICUBIC)
        return image

    def _generate_depth(self, image, save_base):
        """Estimate depth with Depth-Anything-V2 and save it as ``<save_base>.jpg``."""
        self.debug_print("Generating depth control")
        if self.control_depth_model is None:
            from transformers import pipeline
            self.control_depth_model = pipeline(
                task="depth-estimation",
                model="depth-anything/Depth-Anything-V2-Large-hf",
                device=self.device,
                torch_dtype=torch.float16,
            )
        in_size = image.size
        output = self.control_depth_model(image.copy())
        # NOTE(review): the raw predicted_depth is clamped to [0, 255] rather
        # than min-max normalised; confirm this matches the model's output range.
        depth = output["predicted_depth"].clamp(0, 255)
        depth = depth.squeeze(0).cpu().numpy()
        control = Image.fromarray(depth.astype('uint8'))
        control = control.resize(in_size, Image.LANCZOS)
        save_path = save_base + '.jpg'
        control.save(save_path)
        return save_path

    def _generate_pose(self, image, save_base):
        """Render a DWpose detection (body, hands, face) to ``<save_base>.jpg``."""
        self.debug_print("Generating pose control")
        if self.control_pose_model is None:
            try:
                import onnxruntime
            except ImportError as err:
                raise ImportError(
                    "onnxruntime is not installed. Please install it with pip install onnxruntime or onnxruntime-gpu") from err
            # Severity 3 is ERROR in onnxruntime's logging levels; quiets its output.
            onnxruntime.set_default_logger_severity(3)
            try:
                from easy_dwpose import DWposeDetector
            except ImportError as err:
                raise ImportError(
                    "easy-dwpose is not installed. Please install it with pip install easy-dwpose") from err
            self.control_pose_model = DWposeDetector(device=str(self.device))

        # Detect at roughly the image's own resolution (geometric mean of sides).
        detect_res = int(math.sqrt(image.size[0] * image.size[1]))
        control = self.control_pose_model(
            image.copy(), output_type="pil", include_hands=True,
            include_face=True, detect_resolution=detect_res)
        control = control.convert('RGB')
        save_path = save_base + '.jpg'
        control.save(save_path)
        return save_path

    def _generate_line(self, image, save_base):
        """Run the TEED line detector, binarise, and save to ``<save_base>.jpg``."""
        self.debug_print("Generating line control")
        if self.control_line_model is None:
            from controlnet_aux import TEEDdetector
            self.control_line_model = TEEDdetector.from_pretrained(
                "fal-ai/teed", filename="5_model.pth").to(self.device)
        control = self.control_line_model(image.copy(), detect_resolution=1024)
        # Hard threshold: values above 128 become 255, everything else 0.
        control = control.point(lambda p: 255 if p > 128 else 0)
        control = control.convert('RGB')
        save_path = save_base + '.jpg'
        control.save(save_path)
        return save_path

    def _segment(self, image):
        """Run BiRefNet on *image* and return its sigmoid output as a PIL image of ``image.size``."""
        self.debug_print("Generating inpaint/mask control")
        if self.control_bg_remover is None:
            from transformers import AutoModelForImageSegmentation
            self.control_bg_remover = AutoModelForImageSegmentation.from_pretrained(
                'ZhengPeng7/BiRefNet_HR',
                trust_remote_code=True,
                revision="595e212b3eaa6a1beaad56cee49749b1e00b1596",
                torch_dtype=torch.float16,
            ).to(self.device)
            self.control_bg_remover.eval()

        transform_image = transforms.Compose([
            transforms.Resize((1024, 1024)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])
        # Send the input to the configured device (previously hard-coded 'cuda')
        # in the same dtype as the model weights.
        input_images = transform_image(image).unsqueeze(0).to(
            self.device, dtype=torch.float16)
        # Inference only: no autograd graph, so activations are not retained.
        with torch.no_grad():
            preds = self.control_bg_remover(input_images)[-1].sigmoid().cpu()
        pred = preds[0].squeeze()
        return transforms.ToPILImage()(pred).resize(image.size)

    def _generate_inpaint(self, image, save_base):
        """Save *image* with alpha = inverted segmentation mask to ``<save_base>.webp``."""
        mask = self._segment(image)
        control = image.copy()
        # Inverting makes pixels the segmenter scores highly transparent.
        control.putalpha(ImageOps.invert(mask))
        # WebP keeps the alpha channel (JPEG cannot).
        save_path = save_base + '.webp'
        control.save(save_path)
        return save_path

    def _generate_mask(self, image, save_base):
        """Save the segmentation mask itself as an RGB ``<save_base>.jpg``."""
        control = self._segment(image).convert('RGB')
        save_path = save_base + '.jpg'
        control.save(save_path)
        return save_path

    def cleanup(self):
        """Drop all loaded detectors, restore ``self.sd`` if it was unloaded, and flush memory."""
        self.control_depth_model = None
        self.control_pose_model = None
        self.control_line_model = None
        self.control_bg_remover = None
        if self.sd is not None and self.has_unloaded:
            self.sd.restore_device_state()
            self.has_unloaded = False
        flush()
|
|
|
|
|
if __name__ == "__main__":
    # CLI: generate every control type for every image under a directory and
    # report the average generation time per control type.
    import argparse
    import sys
    import time

    import transformers

    transformers.logging.set_verbosity_error()

    # Accumulated seconds per control type (warm-up images excluded).
    control_times = {
        'depth': 0,
        'pose': 0,
        'line': 0,
        'inpaint': 0,
        'mask': 0,
    }
    controls = control_times.keys()
    # The first images pay the one-off model-loading cost, so they are not timed.
    warmup_images = 2

    parser = argparse.ArgumentParser(description="Generate control images")
    parser.add_argument("img_dir", type=str, help="Path to image directory")
    parser.add_argument('--debug', action='store_true',
                        help="Enable debug mode")
    parser.add_argument('--regen', action='store_true',
                        help="Regenerate all controls")
    args = parser.parse_args()

    img_dir = args.img_dir
    if not os.path.exists(img_dir):
        print(f"Error: {img_dir} does not exist")
        sys.exit(1)
    if not os.path.isdir(img_dir):
        print(f"Error: {img_dir} is not a directory")
        sys.exit(1)

    image_exts = tuple(img_ext_list)
    img_list = []
    for root, dirs, files in os.walk(img_dir):
        # Prune generated `_controls` folders in place so os.walk never descends
        # into them. The match is on the exact folder name, not a substring of
        # the path.
        dirs[:] = [d for d in dirs if d != '_controls']
        img_list.extend(
            os.path.join(root, file)
            for file in files
            if not file.startswith('.') and file.lower().endswith(image_exts)
        )
    if not img_list:
        print(f"Error: no images found in {img_dir}")
        sys.exit(1)

    # One generator for the whole run, so each detector is loaded only once.
    control_gen = ControlGenerator(torch.device('cuda'))
    control_gen.debug = args.debug
    control_gen.regen = args.regen
    try:
        for idx, img_path in enumerate(tqdm(img_list)):
            for control in controls:
                start = time.perf_counter()
                control_gen.get_control_path(img_path, control)
                elapsed = time.perf_counter() - start
                if idx >= warmup_images:
                    control_times[control] += elapsed
    finally:
        control_gen.cleanup()

    timed_images = len(img_list) - warmup_images
    if timed_images > 0:
        for control in controls:
            control_times[control] /= timed_images
            print(
                f"Avg time for {control} control: {control_times[control]:.2f} seconds")
    else:
        print(f"Skipping timing report: need more than {warmup_images} images")

    print("Done")
|
|