Spaces:
Running
on
Zero
Running
on
Zero
import logging | |
import os | |
from typing import Union | |
import numpy as np | |
import torch | |
from huggingface_hub import snapshot_download | |
from PIL import Image | |
from asset3d_gen.data.utils import get_images_from_grid | |
# Configure root logging once at import time: INFO level, with each record
# rendered as "<timestamp> - <level> - <message>".
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
)

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Public API of this module: the two super-resolution wrappers defined below.
__all__ = ["ImageStableSR", "ImageRealESRGAN"]
class ImageStableSR:
    """Image super-resolution through a Stable Diffusion upscale pipeline.

    The pipeline is loaded with float16 weights, moved to ``device`` and has
    its progress bar disabled. Instances are callable: pass an image and get
    the upscaled image back.

    Args:
        model_path: Model identifier or path handed to
            ``StableDiffusionUpscalePipeline.from_pretrained``.
        device: Device the pipeline is moved to.
    """

    def __init__(
        self,
        model_path: str = "stabilityai/stable-diffusion-x4-upscaler",
        device="cuda",
    ) -> None:
        # Imported inside the constructor, so diffusers is only loaded when
        # an instance is actually created.
        from diffusers import StableDiffusionUpscalePipeline

        pipeline = StableDiffusionUpscalePipeline.from_pretrained(
            model_path, torch_dtype=torch.float16
        )
        self.up_pipeline_x4 = pipeline.to(device)
        self.up_pipeline_x4.set_progress_bar_config(disable=True)
        # self.up_pipeline_x4.enable_model_cpu_offload()

    def __call__(
        self,
        image: Union[Image.Image, np.ndarray],
        prompt: str = "",
        infer_step: int = 20,
    ) -> Image.Image:
        """Upscale a single image.

        Args:
            image: Input picture; a numpy array is first wrapped with
                ``Image.fromarray``. The image is converted to RGB.
            prompt: Text prompt passed to the pipeline (as a one-item list).
            infer_step: Value used for ``num_inference_steps``.

        Returns:
            The first image produced by the pipeline.
        """
        source = Image.fromarray(image) if isinstance(image, np.ndarray) else image
        source = source.convert("RGB")

        with torch.no_grad():
            result = self.up_pipeline_x4(
                image=source,
                prompt=[prompt],
                num_inference_steps=infer_step,
            )

        return result.images[0]
class ImageRealESRGAN:
    """Image super-resolution using RealESRGAN with ``RealESRGAN_x4plus`` weights.

    The underlying RRDBNet network and the ``RealESRGANer`` helper are both
    built for a native scale of 4; ``outscale`` is the final factor passed to
    ``RealESRGANer.enhance`` on every call.

    Args:
        outscale: Output scale forwarded to ``enhance(..., outscale=...)``.
        model_path: Path to the ``RealESRGAN_x4plus.pth`` checkpoint. When
            ``None``, the ``super_resolution`` folder of the
            ``xinjjj/RoboAssetGen`` Hugging Face repository is downloaded and
            the checkpoint inside it is used.
    """

    def __init__(
        self, outscale: int, model_path: Union[str, None] = None
    ) -> None:
        # Must run before basicsr/realesrgan are imported so that the
        # compatibility module is already registered.
        self._patch_functional_tensor()

        from basicsr.archs.rrdbnet_arch import RRDBNet
        from realesrgan import RealESRGANer

        self.outscale = outscale

        model = RRDBNet(
            num_in_ch=3,
            num_out_ch=3,
            num_feat=64,
            num_block=23,
            num_grow_ch=32,
            scale=4,
        )

        if model_path is None:
            suffix = "super_resolution"
            model_path = snapshot_download(
                repo_id="xinjjj/RoboAssetGen", allow_patterns=f"{suffix}/*"
            )
            model_path = os.path.join(
                model_path, suffix, "RealESRGAN_x4plus.pth"
            )

        self.upsampler = RealESRGANer(
            scale=4,
            model_path=model_path,
            model=model,
            pre_pad=0,
            half=True,
        )

    @staticmethod
    def _patch_functional_tensor() -> None:
        """Register a stand-in ``torchvision.transforms.functional_tensor``.

        Newer torchvision releases no longer ship this module. Instead of
        comparing version strings, probe for the module directly and install
        a shim only when the import fails, so an existing real module is
        never overwritten. The shim exposes ``rgb_to_grayscale`` only.
        """
        module_name = "torchvision.transforms.functional_tensor"
        try:
            import importlib

            importlib.import_module(module_name)
        except ImportError:
            import sys
            import types

            import torchvision.transforms.functional as TF

            functional_tensor = types.ModuleType(module_name)
            functional_tensor.rgb_to_grayscale = TF.rgb_to_grayscale
            sys.modules[module_name] = functional_tensor

    def __call__(self, image: Union[Image.Image, np.ndarray]) -> Image.Image:
        """Upscale ``image`` and return the result as a PIL image.

        Args:
            image: Input picture; a PIL image is converted to a numpy array
                before being passed to the upsampler.

        Returns:
            The enhanced output wrapped with ``Image.fromarray``.
        """
        if isinstance(image, Image.Image):
            image = np.array(image)

        with torch.no_grad():
            # enhance() returns a pair; only the first element is used.
            output, _ = self.upsampler.enhance(image, outscale=self.outscale)

        return Image.fromarray(output)
if __name__ == "__main__":
    # Demo: split a multi-view grid image into 512px views, upscale each
    # view, and write the results to the current directory.
    color_path = "outputs/texture_mesh_gen/multi_view/color_sample0.png"

    # RealESRGAN_x4plus for x4 (512->2048) image super resolution.
    # A local checkpoint can be supplied through `model_path=...`; without
    # it the weights are fetched by the class itself.
    upscaler = ImageRealESRGAN(outscale=4)

    # First upscale every view ...
    upscaled_views = []
    for view in get_images_from_grid(color_path, img_size=512):
        upscaled_views.append(upscaler(view.convert("RGB")))

    # ... then save them as sr0.png, sr1.png, ...
    for idx, upscaled in enumerate(upscaled_views):
        upscaled.save(f"sr{idx}.png")

    # Alternative: Stable Diffusion for x4 (512->2048) super resolution.
    # upscaler = ImageStableSR()
    # views = get_images_from_grid(color_path, img_size=512)
    # for idx, view in enumerate([upscaler(v) for v in views]):
    #     view.save(f"sr_stable{idx}.png")