Spaces:
Running
on
Zero
Running
on
Zero
File size: 5,239 Bytes
d9a2e19 1d117d0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 |
import logging as logger
import torch
from PIL import Image
from modules.Device import Device
from modules.UltimateSDUpscale import RDRB
from modules.UltimateSDUpscale import image_util
from modules.Utilities import util
def load_state_dict(state_dict: dict) -> RDRB.PyTorchModel:
    """#### Load a state dictionary into a PyTorch model.

    Some checkpoints nest the real weights under a wrapper key
    (``params_ema``, ``params-ema`` or ``params``). The first of those keys
    that is present is unwrapped before the weights are handed to
    ``RDRB.RRDBNet``.

    #### Args:
        - `state_dict` (dict): The state dictionary, optionally wrapped.

    #### Returns:
        - `RDRB.PyTorchModel`: The loaded PyTorch model.
    """
    logger.debug("Loading state dict into pytorch model arch")
    # Checked in priority order, so ``params_ema`` wins over ``params`` when
    # a checkpoint carries both.
    for wrapper_key in ("params_ema", "params-ema", "params"):
        if wrapper_key in state_dict:
            state_dict = state_dict[wrapper_key]
            break
    return RDRB.RRDBNet(state_dict)
class UpscaleModelLoader:
    """#### Loader for upscale models stored under ``./_internal/ESRGAN``."""

    def load_model(self, model_name: str) -> tuple:
        """#### Load an upscale model from disk.

        #### Args:
            - `model_name` (str): File name of the model inside ``./_internal/ESRGAN``.

        #### Returns:
            - `tuple`: A one-element tuple holding the model, switched to eval mode.
        """
        weights = util.load_torch_file(
            f"./_internal/ESRGAN/{model_name}", safe_load=True
        )
        # When this marker key appears with a "module." prefix, rewrite that
        # prefix away before building the model (presumably the checkpoint was
        # saved from a wrapped module — confirm against the checkpoints used).
        marker_key = "module.layers.0.residual_group.blocks.0.norm1.weight"
        if marker_key in weights:
            weights = util.state_dict_prefix_replace(weights, {"module.": ""})
        model = load_state_dict(weights)
        return (model.eval(),)
class ImageUpscaleWithModel:
    """#### Upscale image batches with a loaded upscale model, tile by tile."""

    # Smallest tile edge tried before an out-of-memory error is re-raised.
    _MIN_TILE = 128

    def upscale(self, upscale_model: torch.nn.Module, image: torch.Tensor) -> tuple:
        """#### Upscale an image using a model.

        The image is processed in overlapping tiles. If CUDA runs out of
        memory, the tile size is halved and the pass is retried. Once the tile
        would drop below ``_MIN_TILE``, the error is re-raised. The model is
        always moved back to the CPU afterwards, even when upscaling fails.

        #### Args:
            - `upscale_model` (torch.nn.Module): The upscale model; its `scale`
              attribute is used as the upscale amount.
            - `image` (torch.Tensor): The input image tensor. It must be 4-D, and
              its last axis is moved to axis -3 before inference (assumed
              (B, H, W, C) — confirm with callers).

        #### Returns:
            - `tuple`: A one-element tuple with the upscaled image, with the
              axis order of the input restored and values clamped to [0, 1].
        """
        if torch.cuda.is_available():
            device = torch.device(torch.cuda.current_device())
        else:
            device = torch.device("cpu")

        try:
            upscale_model.to(device)
            in_img = image.movedim(-1, -3).to(device)
            # NOTE(review): the returned free-memory value is unused; the call is
            # kept in case Device.get_free_memory has side effects — confirm.
            Device.get_free_memory(device)

            tile = 512
            overlap = 32
            while True:
                try:
                    steps = in_img.shape[0] * image_util.get_tiled_scale_steps(
                        in_img.shape[3],
                        in_img.shape[2],
                        tile_x=tile,
                        tile_y=tile,
                        overlap=overlap,
                    )
                    pbar = util.ProgressBar(steps)
                    s = image_util.tiled_scale(
                        in_img,
                        upscale_model,
                        tile_x=tile,
                        tile_y=tile,
                        overlap=overlap,
                        upscale_amount=upscale_model.scale,
                        pbar=pbar,
                    )
                    break
                except torch.cuda.OutOfMemoryError:
                    tile //= 2
                    if tile < self._MIN_TILE:
                        raise
                    # Drop cached blocks from the failed attempt before retrying
                    # with smaller tiles.
                    torch.cuda.empty_cache()
        finally:
            # Return the model to the CPU even if tiling failed, so GPU memory
            # is not held by an idle model.
            upscale_model.cpu()

        s = torch.clamp(s.movedim(-3, -1), min=0, max=1.0)
        return (s,)
def torch_gc() -> None:
    """#### Placeholder for PyTorch garbage collection.

    Currently a no-op: it performs no collection and simply returns None.
    """
    return None
class Script:
    """#### Empty placeholder class representing a script; it defines no attributes or behavior."""
class Options:
    """#### Holder for option settings."""

    # Hex color string; fixed to white for now.
    img2img_background_color: str = "#ffffff"
class State:
    """#### Minimal run state: an interruption flag plus no-op begin/end hooks."""

    # False by default; nothing in this class ever sets it.
    interrupted: bool = False

    def begin(self) -> None:
        """#### Start hook; does nothing."""
        return None

    def end(self) -> None:
        """#### End hook; does nothing."""
        return None
# Shared module-level instances of the option and state holders.
opts, state = Options(), State()

# Will only ever hold 1 upscaler.
sd_upscalers = [None]
# Model read by Upscaler._upscale; None until something assigns it.
actual_upscaler = None
# Batch of images to upscale; Upscaler.upscale reads it and replaces it with the results.
batch = None
# Older Pillow releases have no Image.Resampling namespace. Alias it to the
# Image module itself so that Image.Resampling.<FILTER> lookups resolve there.
try:
    Image.Resampling
except AttributeError:
    Image.Resampling = Image
class Upscaler:
    """#### Upscale PIL images with the module-level ``actual_upscaler`` model."""

    def _upscale(self, img: Image.Image, scale: float) -> Image.Image:
        """#### Upscale a single image.

        #### Args:
            - `img` (Image.Image): The input image.
            - `scale` (float): Requested scale factor. Unused: the amount is
              taken from the model's own `scale` attribute by
              `ImageUpscaleWithModel.upscale`.

        #### Returns:
            - `Image.Image`: The upscaled image.
        """
        # actual_upscaler is only read here, so no `global` declaration is needed.
        tensor = image_util.pil_to_tensor(img)
        (upscaled,) = ImageUpscaleWithModel().upscale(actual_upscaler, tensor)
        return image_util.tensor_to_pil(upscaled)

    def upscale(self, img: Image.Image, scale: float, selected_model: str = None) -> Image.Image:
        """#### Upscale the current batch and return its first image.

        Every image in the module-level ``batch`` is upscaled, and ``batch`` is
        replaced with the results. If ``batch`` is None or empty, ``img`` is
        upscaled on its own and becomes the new one-element batch.

        #### Args:
            - `img` (Image.Image): Image to upscale when ``batch`` is None or
              empty; ignored otherwise.
            - `scale` (float): Forwarded to `_upscale`, which does not use it.
            - `selected_model` (str, optional): Ignored; kept for interface
              compatibility. Defaults to None.

        #### Returns:
            - `Image.Image`: The first upscaled image of the batch.
        """
        global batch
        sources = [img] if batch is None or len(batch) == 0 else batch
        batch = [self._upscale(source, scale) for source in sources]
        return batch[0]
class UpscalerData:
    """#### Data holder bundling an upscaler's name, data path and `Upscaler` instance."""

    # Upscaler name; empty by default.
    name: str = ""
    # Path to the upscaler's data; empty by default.
    data_path: str = ""

    def __init__(self):
        """#### Create the holder with a fresh `Upscaler` stored as `scaler`."""
        self.scaler = Upscaler()