# Author: xinjie.wang
# update
# commit: 4811e40
import logging
import os
from typing import Union
import numpy as np
import torch
from huggingface_hub import snapshot_download
from PIL import Image
from asset3d_gen.data.utils import get_images_from_grid
# Configure root logging once at import time so every logger in this process
# emits "<timestamp> - <level> - <message>" records at INFO and above.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO
)
# Module-level logger, named after this module per the stdlib convention.
logger = logging.getLogger(__name__)
# Explicit public API of this module.
__all__ = [
    "ImageStableSR",
    "ImageRealESRGAN",
]
class ImageStableSR:
    """x4 image super-resolution backed by the Stable Diffusion upscaler.

    Wraps ``StableDiffusionUpscalePipeline`` loaded in fp16 on the given
    device; calling the instance upscales a single image by a factor of 4.
    """

    def __init__(
        self,
        model_path: str = "stabilityai/stable-diffusion-x4-upscaler",
        device="cuda",
    ) -> None:
        # Imported lazily so the module stays importable without diffusers.
        from diffusers import StableDiffusionUpscalePipeline

        pipeline = StableDiffusionUpscalePipeline.from_pretrained(
            model_path,
            torch_dtype=torch.float16,
        )
        self.up_pipeline_x4 = pipeline.to(device)
        # Silence per-step progress bars during inference.
        self.up_pipeline_x4.set_progress_bar_config(disable=True)
        # self.up_pipeline_x4.enable_model_cpu_offload()

    def __call__(
        self,
        image: Union[Image.Image, np.ndarray],
        prompt: str = "",
        infer_step: int = 20,
    ) -> Image.Image:
        """Upscale ``image`` x4 and return the result as a PIL image.

        Args:
            image: Input picture, either a PIL image or a numpy array.
            prompt: Optional text prompt guiding the upscaler.
            infer_step: Number of diffusion denoising steps.
        """
        if not isinstance(image, Image.Image):
            image = Image.fromarray(image)
        image = image.convert("RGB")
        with torch.no_grad():
            result = self.up_pipeline_x4(
                image=image,
                prompt=[prompt],
                num_inference_steps=infer_step,
            )
        return result.images[0]
class ImageRealESRGAN:
    """x4 image super-resolution via Real-ESRGAN (RRDBNet x4plus weights).

    Weights are downloaded from the Hugging Face Hub when no local
    ``model_path`` is supplied. Calling the instance upscales one image
    by ``outscale``.
    """

    def __init__(self, outscale: int, model_path: Union[str, None] = None) -> None:
        """Build the Real-ESRGAN upsampler.

        Args:
            outscale: Output scale factor applied at inference time.
            model_path: Local path to ``RealESRGAN_x4plus.pth``; when
                ``None`` the weights are fetched from the Hub.
        """
        # basicsr imports a module removed from newer torchvision; shim it
        # before importing basicsr below.
        self._patch_torchvision_functional_tensor()

        # Imported lazily so the module stays importable without realesrgan.
        from basicsr.archs.rrdbnet_arch import RRDBNet
        from realesrgan import RealESRGANer

        self.outscale = outscale
        # Architecture matching the RealESRGAN_x4plus checkpoint.
        model = RRDBNet(
            num_in_ch=3,
            num_out_ch=3,
            num_feat=64,
            num_block=23,
            num_grow_ch=32,
            scale=4,
        )
        if model_path is None:
            suffix = "super_resolution"
            repo_dir = snapshot_download(
                repo_id="xinjjj/RoboAssetGen", allow_patterns=f"{suffix}/*"
            )
            model_path = os.path.join(
                repo_dir, suffix, "RealESRGAN_x4plus.pth"
            )
        self.upsampler = RealESRGANer(
            scale=4,
            model_path=model_path,
            model=model,
            pre_pad=0,
            half=True,  # fp16 inference; assumes a CUDA device — TODO confirm
        )

    @staticmethod
    def _patch_torchvision_functional_tensor() -> None:
        """Alias the removed ``torchvision.transforms.functional_tensor``.

        basicsr still imports ``rgb_to_grayscale`` from that module, which
        torchvision dropped after 0.16; register a stub module exposing the
        function from its new location.
        """
        import torchvision
        from packaging import version

        if version.parse(torchvision.__version__) > version.parse("0.16"):
            import sys
            import types
            import torchvision.transforms.functional as TF

            shim = types.ModuleType(
                "torchvision.transforms.functional_tensor"
            )
            shim.rgb_to_grayscale = TF.rgb_to_grayscale
            sys.modules["torchvision.transforms.functional_tensor"] = shim

    def __call__(self, image: Union[Image.Image, np.ndarray]) -> Image.Image:
        """Upscale ``image`` by ``self.outscale`` and return a PIL image."""
        if isinstance(image, Image.Image):
            image = np.array(image)
        with torch.no_grad():
            output, _ = self.upsampler.enhance(image, outscale=self.outscale)
        return Image.fromarray(output)
if __name__ == "__main__":
    color_path = "outputs/texture_mesh_gen/multi_view/color_sample0.png"

    # Use RealESRGAN_x4plus for x4 (512->2048) image super resolution.
    # model_path = "/horizon-bucket/robot_lab/users/xinjie.wang/weights/super_resolution/RealESRGAN_x4plus.pth"  # noqa
    super_model = ImageRealESRGAN(outscale=4)
    views = get_images_from_grid(color_path, img_size=512)
    upscaled = [super_model(view.convert("RGB")) for view in views]
    for idx, img in enumerate(upscaled):
        img.save(f"sr{idx}.png")

    # # Use stable diffusion for x4 (512->2048) image super resolution.
    # super_model = ImageStableSR()
    # multiviews = get_images_from_grid(color_path, img_size=512)
    # multiviews = [super_model(img) for img in multiviews]
    # for idx, img in enumerate(multiviews):
    #     img.save(f"sr_stable{idx}.png")