Spaces:
mikitona
/
Running on Zero

test3 / SUPIR /util.py
mikitona's picture
Update SUPIR/util.py
595ed80 verified
raw
history blame
7.9 kB
import os
import torch
import numpy as np
import cv2
from PIL import Image
from torch.nn.functional import interpolate
from omegaconf import OmegaConf
from sgm.util import instantiate_from_config
from huggingface_hub import hf_hub_download
def get_state_dict(d):
    '''
    Unwrap a loaded checkpoint.

    If the mapping has a 'state_dict' entry, that entry is returned;
    otherwise the mapping itself is returned unchanged.
    '''
    if 'state_dict' in d:
        return d['state_dict']
    return d
def load_state_dict(ckpt_path, location='cpu'):
    '''
    Load a state dict from a local file or from a Hugging Face Hub file.

    Path resolution:
      * If ``ckpt_path`` contains no '/' or names an existing local path, it
        is used as-is.
      * Otherwise it must have the form 'username/repo/filename'; the file is
        fetched with ``hf_hub_download`` and the returned local path is used.

    Loading:
      * '.safetensors' files (case-insensitive) are read with
        ``safetensors.torch.load_file`` onto ``location``.
      * Anything else is read with ``torch.load`` (mapped to ``location``) and
        unwrapped with ``get_state_dict``.

    Args:
        ckpt_path: local path or 'username/repo/filename' Hub reference.
        location: device string passed to the loader (default 'cpu').

    Returns:
        The loaded state dict.

    Raises:
        ValueError: if ``ckpt_path`` contains '/', is not an existing local
            path, and does not split into exactly three '/'-separated parts.
    '''
    # Only treat the string as a Hub reference when it is not a local file:
    # previously any '/' (e.g. 'models/x.ckpt' or an absolute path) was sent
    # down the Hub branch and either rejected or wrongly downloaded.
    if '/' in ckpt_path and not os.path.exists(ckpt_path):
        parts = ckpt_path.split('/')
        if len(parts) != 3:
            raise ValueError("Invalid format for Hugging Face path. Expected format 'username/repo/filename'.")
        repo_id = f"{parts[0]}/{parts[1]}"
        filename = parts[2]
        # Log before downloading
        print(f"Attempting to download from Hugging Face Hub with repo_id: {repo_id} and filename: {filename}")
        ckpt_path = hf_hub_download(repo_id=repo_id, filename=filename)
        # Log the actual local path of the downloaded file
        print(f"Downloaded file path for {filename}: {ckpt_path}")

    # Decide the loader from the file that will actually be read.
    _, extension = os.path.splitext(ckpt_path)
    if extension.lower() == ".safetensors":
        # Imported lazily so safetensors is only required for .safetensors files.
        import safetensors.torch
        state_dict = safetensors.torch.load_file(ckpt_path, device=location)
    else:
        # NOTE(review): torch.load unpickles arbitrary objects; only load
        # checkpoints from trusted sources.
        state_dict = get_state_dict(torch.load(ckpt_path, map_location=torch.device(location)))
    print(f'Loaded state_dict from [{ckpt_path}]')
    return state_dict
def create_model(config_path):
    '''
    Build a model from the ``model`` section of an OmegaConf config file.

    The model is instantiated via ``instantiate_from_config`` and moved to
    the CPU. No weights are loaded here.

    Args:
        config_path: path to the config file read by ``OmegaConf.load``.

    Returns:
        The instantiated model (on CPU).
    '''
    model_cfg = OmegaConf.load(config_path).model
    net = instantiate_from_config(model_cfg)
    net = net.cpu()
    print(f'Loaded model config from [{config_path}]')
    return net
def create_SUPIR_model(config_path, SUPIR_sign=None, load_default_setting=False):
    '''
    Build a SUPIR model from a config file and load its checkpoints.

    Checkpoints are applied in this order, each with ``strict=False``:
      1. ``SDXL_CKPT``  (skipped if absent or null)
      2. ``SUPIR_CKPT`` (skipped if absent or null)
      3. ``SUPIR_CKPT_F`` when ``SUPIR_sign == 'F'``, or
         ``SUPIR_CKPT_Q`` when ``SUPIR_sign == 'Q'``

    Args:
        config_path: path to the OmegaConf config file.
        SUPIR_sign: None, 'F' or 'Q' — selects the extra checkpoint to load.
        load_default_setting: if True, also return ``config.default_setting``.

    Returns:
        ``model``, or ``(model, config.default_setting)`` when
        ``load_default_setting`` is True.

    Raises:
        AssertionError: if ``SUPIR_sign`` is not None, 'F' or 'Q'.
    '''
    config = OmegaConf.load(config_path)
    model = instantiate_from_config(config.model).cpu()
    print(f'Loaded model config from [{config_path}]')

    # These two keys are optional. Use .get() so an omitted key behaves like an
    # explicit null; attribute access on a missing key raises in recent
    # OmegaConf (2.1+).
    sdxl_ckpt = config.get('SDXL_CKPT')
    if sdxl_ckpt is not None:
        model.load_state_dict(load_state_dict(sdxl_ckpt), strict=False)
    supir_ckpt = config.get('SUPIR_CKPT')
    if supir_ckpt is not None:
        model.load_state_dict(load_state_dict(supir_ckpt), strict=False)

    if SUPIR_sign is not None:
        assert SUPIR_sign in ['F', 'Q']
        if SUPIR_sign == 'F':
            model.load_state_dict(load_state_dict(config.SUPIR_CKPT_F), strict=False)
        elif SUPIR_sign == 'Q':
            model.load_state_dict(load_state_dict(config.SUPIR_CKPT_Q), strict=False)

    if load_default_setting:
        default_setting = config.default_setting
        return model, default_setting
    return model
def load_QF_ckpt(config_path):
    '''
    Load the SUPIR 'Q' and 'F' checkpoints named in a config file.

    ``SUPIR_CKPT_F`` and ``SUPIR_CKPT_Q`` are each resolved the same way:
      * a value without '/' or naming an existing local path is used as-is;
      * otherwise it must be 'username/repo/filename' and is downloaded with
        ``hf_hub_download``.
    Both paths are resolved (F first, then Q) before either file is loaded
    with ``torch.load(..., map_location='cpu')``.

    Args:
        config_path: path to the OmegaConf config file.

    Returns:
        ``(ckpt_Q, ckpt_F)`` — note the Q-then-F order — as returned by
        ``torch.load`` (not unwrapped).

    Raises:
        ValueError: if a Hub-style value does not split into exactly three
            '/'-separated parts.
    '''
    config = OmegaConf.load(config_path)

    def _resolve(ckpt, key):
        # Resolve one config value to a local file path, downloading if needed.
        # Existing local files (relative subdirectories or absolute paths) are
        # used directly instead of being misread as Hub references.
        if '/' not in ckpt or os.path.exists(ckpt):
            return ckpt
        parts = ckpt.split('/')
        if len(parts) != 3:
            raise ValueError(f"Invalid format for {key}. Expected format 'username/repo/filename'.")
        repo_id = f"{parts[0]}/{parts[1]}"
        filename = parts[2]
        print(f"Attempting to download {key} from repo_id: {repo_id} and filename: {filename}")
        local_path = hf_hub_download(repo_id=repo_id, filename=filename)
        print(f"Downloaded {key} to: {local_path}")
        return local_path

    ckpt_F_path = _resolve(config.SUPIR_CKPT_F, 'SUPIR_CKPT_F')
    ckpt_Q_path = _resolve(config.SUPIR_CKPT_Q, 'SUPIR_CKPT_Q')

    # NOTE(review): torch.load unpickles arbitrary objects; only load
    # checkpoints from trusted sources.
    ckpt_F = torch.load(ckpt_F_path, map_location='cpu')
    ckpt_Q = torch.load(ckpt_Q_path, map_location='cpu')
    return ckpt_Q, ckpt_F
def PIL2Tensor(img, upsacle=1, min_size=1024, fix_resize=None):
    '''
    Convert a PIL image into a model-ready float tensor.

    The image size is first multiplied by ``upsacle``. If the short side is
    then below ``min_size``, both sides are enlarged so the short side equals
    ``min_size``. If ``fix_resize`` is given, both sides are instead scaled so
    the short side equals ``fix_resize``. Finally each side is rounded to the
    nearest multiple of 64 and the image is resized bicubically.

    Returns:
        (tensor, h0, w0): ``tensor`` is float32 [C, H, W] in [-1, 1].
        ``(h0, w0)`` is the rounded size after ``upsacle`` (or after
        ``fix_resize`` when given), i.e. before the ``min_size`` enlargement
        and the 64-alignment. It is meant for resizing the result back.
        Assumes ``np.array(img)`` is (H, W, C) — permute(2, 0, 1) needs 3 dims.
    '''
    width, height = img.size
    width, height = width * upsacle, height * upsacle
    out_w, out_h = round(width), round(height)

    short_side = min(width, height)
    if short_side < min_size:
        factor = min_size / short_side
        width, height = width * factor, height * factor

    if fix_resize is not None:
        factor = fix_resize / min(width, height)
        width, height = width * factor, height * factor
        out_w, out_h = round(width), round(height)

    # Snap both sides to multiples of 64.
    width = int(np.round(width / 64.0)) * 64
    height = int(np.round(height / 64.0)) * 64

    resized = img.resize((width, height), Image.BICUBIC)
    pixels = np.array(resized).round().clip(0, 255).astype(np.uint8)
    pixels = pixels / 255 * 2 - 1
    tensor = torch.tensor(pixels, dtype=torch.float32).permute(2, 0, 1)
    return tensor, out_h, out_w
def Tensor2PIL(x, h0, w0):
    '''
    Tensor[C, H, W], RGB, [-1, 1] -> PIL.Image

    The tensor is bicubically resized to (h0, w0) and mapped from [-1, 1]
    to uint8 [0, 255] before conversion.

    Args:
        x: image tensor of shape [C, H, W] with values nominally in [-1, 1].
        h0: output height.
        w0: output width.

    Returns:
        PIL.Image built from the (h0, w0, C) uint8 array.
    '''
    # detach() so tensors that still require grad can be converted with
    # .numpy(), which refuses such tensors.
    x = x.detach().unsqueeze(0)
    x = interpolate(x, size=(h0, w0), mode='bicubic')
    # Clip before the uint8 cast: bicubic resampling can leave the input range.
    x = (x.squeeze(0).permute(1, 2, 0) * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
    return Image.fromarray(x)
def HWC3(x):
    '''
    Normalize a uint8 image array to 3-channel HWC layout.

    * 2-D (H, W) input is treated as a single channel.
    * 3-channel input is returned unchanged (same object).
    * 1-channel input is replicated to 3 channels.
    * 4-channel input is alpha-composited onto a white background.

    Raises:
        AssertionError: if dtype is not uint8, the array is not 2-D/3-D, or
            the channel count is not 1, 3 or 4.
    '''
    assert x.dtype == np.uint8
    if x.ndim == 2:
        x = x[:, :, None]
    assert x.ndim == 3
    channels = x.shape[2]
    assert channels in (1, 3, 4)

    if channels == 3:
        return x
    if channels == 1:
        return np.concatenate([x] * 3, axis=2)

    # channels == 4: blend RGB over white using the alpha channel.
    rgb = x[..., :3].astype(np.float32)
    alpha = x[..., 3:].astype(np.float32) / 255.0
    blended = rgb * alpha + 255.0 * (1.0 - alpha)
    return blended.clip(0, 255).astype(np.uint8)
def upscale_image(input_image, upscale, min_size=None, unit_resolution=64):
    '''
    Resize an (H, W, C) image by ``upscale``, enforce a minimum short side,
    and snap both sides to a grid.

    Args:
        input_image: array of shape (H, W, C) accepted by ``cv2.resize``.
        upscale: scale factor applied to both sides.
        min_size: if given and the scaled short side is below it, both sides
            are enlarged further so the short side equals ``min_size``.
        unit_resolution: output height and width are rounded to the nearest
            multiple of this value.

    Returns:
        uint8 array of the resized image.
    '''
    H, W, _ = input_image.shape
    H = float(H) * upscale
    W = float(W) * upscale
    # Track the overall scale so the interpolation choice reflects the real
    # resize, not just the requested ``upscale`` (min_size may enlarge further).
    scale = upscale
    if min_size is not None and min(H, W) < min_size:
        extra = min_size / min(W, H)
        W *= extra
        H *= extra
        scale *= extra
    H = int(np.round(H / unit_resolution)) * unit_resolution
    W = int(np.round(W / unit_resolution)) * unit_resolution
    # Lanczos when enlarging, area averaging when shrinking (as in fix_resize).
    interpolation = cv2.INTER_LANCZOS4 if scale > 1 else cv2.INTER_AREA
    img = cv2.resize(input_image, (W, H), interpolation=interpolation)
    img = img.round().clip(0, 255).astype(np.uint8)
    return img
def fix_resize(input_image, size=512, unit_resolution=64):
    '''
    Resize an (H, W, C) image so its short side is about ``size``.

    Both output sides are rounded to the nearest multiple of
    ``unit_resolution``. Lanczos interpolation is used when enlarging and
    area interpolation otherwise.

    Returns:
        uint8 array of the resized image.
    '''
    src_h, src_w, _ = input_image.shape
    src_h, src_w = float(src_h), float(src_w)
    factor = size / min(src_h, src_w)
    out_h = int(np.round(src_h * factor / unit_resolution)) * unit_resolution
    out_w = int(np.round(src_w * factor / unit_resolution)) * unit_resolution
    method = cv2.INTER_LANCZOS4 if factor > 1 else cv2.INTER_AREA
    resized = cv2.resize(input_image, (out_w, out_h), interpolation=method)
    return resized.round().clip(0, 255).astype(np.uint8)
def Numpy2Tensor(img):
    '''
    np.array[H, W, C] in [0, 255] -> float32 Tensor[C, H, W] in [-1, 1]

    The input is passed through ``np.array`` first, so array-likes are accepted.
    '''
    arr = np.array(img)
    scaled = arr / 255 * 2 - 1
    return torch.tensor(scaled, dtype=torch.float32).permute(2, 0, 1)
def Tensor2Numpy(x, h0=None, w0=None):
    '''
    Tensor[C, H, W], RGB, [-1, 1] -> np.ndarray[H, W, C], uint8 [0, 255]

    Args:
        x: image tensor of shape [C, H, W] with values nominally in [-1, 1].
        h0, w0: if BOTH are given, the tensor is first bicubically resized to
            (h0, w0); if either is None, no resizing happens.

    Returns:
        uint8 numpy array of shape (H, W, C) (or (h0, w0, C) when resized),
        computed as x * 127.5 + 127.5, clipped to [0, 255] and truncated.
    '''
    # detach() so tensors that still require grad can be converted with
    # .numpy(), which refuses such tensors.
    x = x.detach()
    if h0 is not None and w0 is not None:
        x = x.unsqueeze(0)
        x = interpolate(x, size=(h0, w0), mode='bicubic')
        x = x.squeeze(0)
    x = (x.permute(1, 2, 0) * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
    return x
def convert_dtype(dtype_str):
    '''
    Map a short precision name to a torch dtype.

    'fp32' -> torch.float32, 'fp16' -> torch.float16, 'bf16' -> torch.bfloat16.

    Raises:
        NotImplementedError: for any other value.
    '''
    known = (
        ('fp32', torch.float32),
        ('fp16', torch.float16),
        ('bf16', torch.bfloat16),
    )
    for name, dtype in known:
        if dtype_str == name:
            return dtype
    raise NotImplementedError