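"""Texture-generation inference script.

Fetches multi-view control images (normals, positions, masks) for an asset from
an Asset3dGenDataset sample and runs a Kolors SDXL ControlNet img2img pipeline,
with an optional IP-Adapter reference image, to synthesize texture color images.
"""
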
import logging
import os
import random
from typing import List, Tuple

import fire
import numpy as np
import torch
from diffusers.utils import make_image_grid
from kolors.pipelines.pipeline_controlnet_xl_kolors_img2img import (
    StableDiffusionXLControlNetImg2ImgPipeline,
)
from PIL import Image, ImageEnhance, ImageFilter
from torchvision import transforms

from asset3d_gen.data.datasets import Asset3dGenDataset
from asset3d_gen.models.texture_model import build_texture_gen_pipe

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def get_init_noise_image(image: Image.Image) -> Image.Image:
    """Build a low-detail init image by blurring and reducing contrast."""
    blurred_image = image.convert("L").filter(
        ImageFilter.GaussianBlur(radius=3)
    )
    enhancer = ImageEnhance.Contrast(blurred_image)
    image_decreased_contrast = enhancer.enhance(factor=0.5)

    return image_decreased_contrast


def infer_pipe(
    index_file: str,
    controlnet_ckpt: str = None,
    uid: str = None,
    prompt: str = None,
    controlnet_cond_scale: float = 0.4,
    control_guidance_end: float = 0.9,
    strength: float = 1.0,
    num_inference_steps: int = 50,
    guidance_scale: float = 10,
    ip_adapt_scale: float = 0,
    ip_img_path: str = None,
    sub_idxs: List[List[int]] = None,
    num_images_per_prompt: int = 3,  # Increase to sample more similar images.
    device: str = "cuda",
    save_dir: str = "infer_vis",
    seed: int = None,
    target_hw: Tuple[int, int] = (512, 512),
    pipeline: StableDiffusionXLControlNetImg2ImgPipeline = None,
) -> List[str]:
    # sub_idxs = [[0, 1, 2], [3, 4, 5]]  # None for a single image.
    if sub_idxs is None:
        sub_idxs = [[random.randint(0, 5)]]  # 6 views.
        target_hw = [2 * size for size in target_hw]
    transform_list = [
        transforms.Resize(
            target_hw, interpolation=transforms.InterpolationMode.BILINEAR
        ),
        transforms.CenterCrop(target_hw),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ]
    image_transform = transforms.Compose(transform_list)
    control_transform = transforms.Compose(transform_list[:-1])

    grid_hw = (target_hw[0] * len(sub_idxs), target_hw[1] * len(sub_idxs[0]))
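    # Illustrative sizes: sub_idxs=[[0, 1, 2], [3, 4, 5]] with
    # target_hw=(512, 512) gives grid_hw=(1024, 1536), i.e. a 2 x 3 grid of
    # 512 px cells; the default single random view instead renders at the
    # doubled (1024, 1024) resolution.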

    dataset = Asset3dGenDataset(
        index_file, target_hw=grid_hw, sub_idxs=sub_idxs
    )

    if uid is None:
        uid = random.choice(list(dataset.meta_info.keys()))
    if prompt is None:
        prompt = dataset.meta_info[uid]["capture"]
    if isinstance(prompt, (list, tuple)):
        prompt = ", ".join(map(str, prompt))

    # prompt += "high quality, ultra-clear, high resolution, best quality, 4k"
    # prompt += "高品质,清晰,细节"  (zh: "high quality, clear, detailed")
    prompt += ", high quality, high resolution, best quality"
    # prompt += ", with diffuse lighting, showing no reflections."
    logger.info(f"Inference with prompt: {prompt}")

    # English translation of the Chinese negative prompt: "nsfw, facial
    # shadows, low resolution, JPEG artifacts, blurry, bad quality, dark face,
    # neon lights, highlights, specular reflections".
    negative_prompt = (
        "nsfw,脸部阴影,低分辨率,jpeg伪影、模糊、糟糕,黑脸,霓虹灯,高光,镜面反射"
    )

    control_image = dataset.fetch_sample_grid_images(
        uid,
        attrs=["image_view_normal", "image_position", "image_mask"],
        sub_idxs=sub_idxs,
        transform=control_transform,
    )
    color_image = dataset.fetch_sample_grid_images(
        uid,
        attrs=["image_color"],
        sub_idxs=sub_idxs,
        transform=image_transform,
    )
    normal_pil, position_pil, mask_pil, color_pil = dataset.visualize_item(
        control_image,
        color_image,
        save_dir=save_dir,
    )

    if pipeline is None:
        pipeline = build_texture_gen_pipe(
            base_ckpt_dir="./weights",
            controlnet_ckpt=controlnet_ckpt,
            ip_adapt_scale=ip_adapt_scale,
            device=device,
        )
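
    # Note (explanatory comment): the reference image and its scale below are
    # wrapped in one-element lists, presumably to match a single loaded
    # IP-Adapter; diffusers accepts per-adapter lists for both.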
    if ip_adapt_scale > 0 and ip_img_path is not None and len(ip_img_path) > 0:
        ip_image = Image.open(ip_img_path).convert("RGB")
        # PIL's resize expects (width, height).
        ip_image = ip_image.resize(target_hw[::-1])
        ip_image = [ip_image]
        pipeline.set_ip_adapter_scale([ip_adapt_scale])
    else:
        ip_image = None

    generator = None
    if seed is not None:
        generator = torch.Generator(device).manual_seed(seed)
        torch.manual_seed(seed)
        np.random.seed(seed)
        random.seed(seed)

    init_image = get_init_noise_image(normal_pil)
    # init_image = get_init_noise_image(color_pil)

    images = []
    row_num, col_num = 2, 3
    img_save_paths = []
    # Keep sampling until at least `col_num` candidates are available.
    while len(images) < col_num:
        image = pipeline(
            prompt=prompt,
            image=init_image,
            controlnet_conditioning_scale=controlnet_cond_scale,
            control_guidance_end=control_guidance_end,
            strength=strength,
            control_image=control_image[None, ...],
            negative_prompt=negative_prompt,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            num_images_per_prompt=num_images_per_prompt,
            ip_adapter_image=ip_image,
            generator=generator,
        ).images
        images.extend(image)

    grid_image = [normal_pil, position_pil, color_pil] + images[:col_num]
    # save_dir = os.path.join(save_dir, uid)
    os.makedirs(save_dir, exist_ok=True)

    for idx in range(col_num):
        rgba_image = Image.merge("RGBA", (*images[idx].split(), mask_pil))
        img_save_path = os.path.join(save_dir, f"color_sample{idx}.png")
        rgba_image.save(img_save_path)
        img_save_paths.append(img_save_path)

    sub_idxs_str = "_".join(
        str(item) for sublist in sub_idxs for item in sublist
    )
    save_path = os.path.join(
        save_dir, f"sample_idx{sub_idxs_str}_ip{ip_adapt_scale}.jpg"
    )
    make_image_grid(grid_image, row_num, col_num).save(save_path)
    logger.info(f"Visualization saved to {save_path}")

    return img_save_paths


def entrypoint() -> None:
    fire.Fire(infer_pipe)


if __name__ == "__main__":
    entrypoint()
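

# Example invocation (a sketch: the script name, index file, and uid below are
# placeholders; python-fire exposes the keyword arguments of `infer_pipe` as
# CLI flags):
#
#   python infer_texture.py \
#       --index_file data/index.json \
#       --uid <asset_uid> \
#       --controlnet_cond_scale 0.4 \
#       --num_inference_steps 50 \
#       --seed 42 \
#       --save_dir infer_vis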