Spaces:
Running
on
Zero
Running
on
Zero
File size: 4,118 Bytes
55ed985 4811e40 55ed985 10c708b 2a08301 10c708b 2a08301 10c708b 2a08301 10c708b 2a08301 55ed985 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 |
import logging
import os
from typing import Union
import numpy as np
import torch
from huggingface_hub import snapshot_download
from PIL import Image
from asset3d_gen.data.utils import get_images_from_grid
# Configure the root logger when this module is imported: timestamped
# "time - LEVEL - message" records at INFO level and above.
# NOTE: logging.basicConfig is a no-op if the root logger already has handlers.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO
)
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Public API exported by `from <module> import *`.
__all__ = [
"ImageStableSR",
"ImageRealESRGAN",
]
class ImageStableSR:
    """Image super-resolution via diffusers' ``StableDiffusionUpscalePipeline``.

    The pipeline is loaded once at construction time (float16 weights) and
    reused for every call.
    """

    def __init__(
        self,
        model_path: str = "stabilityai/stable-diffusion-x4-upscaler",
        device="cuda",
    ) -> None:
        """Load the upscale pipeline and move it to ``device``.

        Args:
            model_path: Checkpoint identifier or path handed to
                ``StableDiffusionUpscalePipeline.from_pretrained``.
            device: Target device passed to the pipeline's ``.to()``.
        """
        # Imported lazily so that diffusers is only required when this
        # class is actually instantiated.
        from diffusers import StableDiffusionUpscalePipeline

        pipeline = StableDiffusionUpscalePipeline.from_pretrained(
            model_path,
            torch_dtype=torch.float16,
        )
        self.up_pipeline_x4 = pipeline.to(device)
        self.up_pipeline_x4.set_progress_bar_config(disable=True)
        # self.up_pipeline_x4.enable_model_cpu_offload()

    def __call__(
        self,
        image: Union[Image.Image, np.ndarray],
        prompt: str = "",
        infer_step: int = 20,
    ) -> Image.Image:
        """Upscale a single image.

        Args:
            image: Input as a PIL image or a numpy array (arrays are wrapped
                with ``Image.fromarray``).
            prompt: Text prompt forwarded to the pipeline.
            infer_step: Value passed as ``num_inference_steps``.

        Returns:
            The first image of the pipeline output.
        """
        pil_input = (
            Image.fromarray(image) if isinstance(image, np.ndarray) else image
        )
        # The pipeline is always fed a 3-channel RGB image.
        rgb_input = pil_input.convert("RGB")

        with torch.no_grad():
            result = self.up_pipeline_x4(
                image=rgb_input,
                prompt=[prompt],
                num_inference_steps=infer_step,
            )
            upscaled_image = result.images[0]

        return upscaled_image
class ImageRealESRGAN:
    """Image super-resolution with a Real-ESRGAN (RRDBNet) upsampler.

    The network is built for a native x4 scale; ``outscale`` is the final
    scale factor forwarded to ``RealESRGANer.enhance`` on every call.
    """

    def __init__(
        self, outscale: int, model_path: Union[str, None] = None
    ) -> None:
        """Build the RRDBNet model and the ``RealESRGANer`` wrapper.

        Args:
            outscale: Output scale factor passed to ``enhance``.
            model_path: Path to the ``RealESRGAN_x4plus.pth`` weights. When
                ``None`` the weights are fetched from the
                ``xinjjj/RoboAssetGen`` Hugging Face repo
                (``super_resolution/`` folder).
        """
        # Compatibility shim: basicsr imports `rgb_to_grayscale` from
        # `torchvision.transforms.functional_tensor`, which newer torchvision
        # releases no longer provide. Probe for the module and install a
        # minimal stand-in only when it is really missing, so an existing
        # real module is never replaced by the stub.
        try:
            import torchvision.transforms.functional_tensor  # noqa: F401
        except ImportError:
            import sys
            import types

            import torchvision.transforms.functional as TF

            functional_tensor = types.ModuleType(
                "torchvision.transforms.functional_tensor"
            )
            functional_tensor.rgb_to_grayscale = TF.rgb_to_grayscale
            sys.modules["torchvision.transforms.functional_tensor"] = (
                functional_tensor
            )

        # These imports must come after the shim above.
        from basicsr.archs.rrdbnet_arch import RRDBNet
        from realesrgan import RealESRGANer

        self.outscale = outscale
        model = RRDBNet(
            num_in_ch=3,
            num_out_ch=3,
            num_feat=64,
            num_block=23,
            num_grow_ch=32,
            scale=4,
        )

        if model_path is None:
            # Download only the super-resolution weights from the hub.
            suffix = "super_resolution"
            model_path = snapshot_download(
                repo_id="xinjjj/RoboAssetGen", allow_patterns=f"{suffix}/*"
            )
            model_path = os.path.join(
                model_path, suffix, "RealESRGAN_x4plus.pth"
            )

        self.upsampler = RealESRGANer(
            scale=4,
            model_path=model_path,
            model=model,
            pre_pad=0,
            # No device is passed, so RealESRGANer chooses it itself
            # (presumably CPU when CUDA is unavailable); only request fp16
            # when CUDA is present, since half precision on CPU is unreliable.
            half=torch.cuda.is_available(),
        )

    def __call__(self, image: Union[Image.Image, np.ndarray]) -> Image.Image:
        """Upscale one image and return it as a PIL image.

        Args:
            image: Input as a PIL image or a numpy array.

        Returns:
            The enhanced image wrapped with ``Image.fromarray``.
        """
        if isinstance(image, Image.Image):
            image = np.array(image)

        # NOTE(review): RealESRGANer.enhance appears to follow the OpenCV
        # (BGR) channel convention, while arrays coming from PIL are RGB.
        # Confirm whether a channel swap is wanted here.
        with torch.no_grad():
            output, _ = self.upsampler.enhance(image, outscale=self.outscale)

        return Image.fromarray(output)
if __name__ == "__main__":
    # Demo: read a multi-view grid image, upscale every view, and write the
    # results as sr<idx>.png into the current working directory.
    # get_images_from_grid presumably splits the grid into 512px views.
    color_path = "outputs/texture_mesh_gen/multi_view/color_sample0.png"

    # Use RealESRGAN_x4plus for x4 (512->2048) image super resolution.
    # model_path = "/horizon-bucket/robot_lab/users/xinjie.wang/weights/super_resolution/RealESRGAN_x4plus.pth"  # noqa
    super_model = ImageRealESRGAN(outscale=4)
    upscaled_views = [
        super_model(view.convert("RGB"))
        for view in get_images_from_grid(color_path, img_size=512)
    ]
    # All views are upscaled first; saving happens only afterwards.
    for idx, upscaled in enumerate(upscaled_views):
        upscaled.save(f"sr{idx}.png")

    # # Use stable diffusion for x4 (512->2048) image super resolution.
    # super_model = ImageStableSR()
    # multiviews = get_images_from_grid(color_path, img_size=512)
    # multiviews = [super_model(img) for img in multiviews]
    # for idx, img in enumerate(multiviews):
    #     img.save(f"sr_stable{idx}.png")
|