import logging as logger

import torch
from PIL import Image

from modules.Device import Device
from modules.UltimateSDUpscale import RDRB
from modules.UltimateSDUpscale import image_util
from modules.Utilities import util


def load_state_dict(state_dict: dict) -> RDRB.PyTorchModel:
    """#### Load a state dictionary into a PyTorch model.

    #### Args:
        - `state_dict` (dict): The state dictionary.

    #### Returns:
        - `RDRB.PyTorchModel`: The loaded PyTorch model.
    """
    logger.debug("Loading state dict into pytorch model arch")
    state_dict_keys = list(state_dict.keys())
    if "params_ema" in state_dict_keys:
        state_dict = state_dict["params_ema"]
    model = RDRB.RRDBNet(state_dict)
    return model
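
# Sketch of direct use (the checkpoint path below is a hypothetical example):
# the function accepts either a plain ESRGAN state dict or one wrapped under
# "params_ema".
#     raw = torch.load("./_internal/ESRGAN/example_x4.pth", map_location="cpu")
#     net = load_state_dict(raw).eval()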


class UpscaleModelLoader:
    """#### Class for loading upscale models."""

    def load_model(self, model_name: str) -> tuple:
        """#### Load an upscale model.

        #### Args:
            - `model_name` (str): The name of the model.

        #### Returns:
            - `tuple`: The loaded model.
        """
        model_path = f"./_internal/ESRGAN/{model_name}"
        sd = util.load_torch_file(model_path, safe_load=True)
        if "module.layers.0.residual_group.blocks.0.norm1.weight" in sd:
            sd = util.state_dict_prefix_replace(sd, {"module.": ""})
        out = load_state_dict(sd).eval()
        return (out,)
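
# Usage sketch for UpscaleModelLoader (the file name is hypothetical and must
# exist under ./_internal/ESRGAN/):
#     loader = UpscaleModelLoader()
#     (upscale_model,) = loader.load_model("example_x4.pth")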


class ImageUpscaleWithModel:
    """#### Class for upscaling images with a model."""

    def upscale(self, upscale_model: torch.nn.Module, image: torch.Tensor) -> tuple:
        """#### Upscale an image using a model.

        #### Args:
            - `upscale_model` (torch.nn.Module): The upscale model.
            - `image` (torch.Tensor): The input image tensor.

        #### Returns:
            - `tuple`: The upscaled image tensor.
        """
        if torch.cuda.is_available():
            device = torch.device(torch.cuda.current_device())
        else:
            device = torch.device("cpu")
        upscale_model.to(device)
        in_img = image.movedim(-1, -3).to(device)
        Device.get_free_memory(device)

        tile = 512
        overlap = 32
        oom = True
        while oom:
            try:
                steps = in_img.shape[0] * image_util.get_tiled_scale_steps(
                    in_img.shape[3],
                    in_img.shape[2],
                    tile_x=tile,
                    tile_y=tile,
                    overlap=overlap,
                )
                pbar = util.ProgressBar(steps)
                s = image_util.tiled_scale(
                    in_img,
                    lambda a: upscale_model(a),
                    tile_x=tile,
                    tile_y=tile,
                    overlap=overlap,
                    upscale_amount=upscale_model.scale,
                    pbar=pbar,
                )
                oom = False
            except torch.cuda.OutOfMemoryError:
                # Assumed retry strategy for the otherwise dead `oom` loop: halve
                # the tile size on CUDA OOM and try again, giving up once tiles
                # become too small to be useful.
                tile //= 2
                if tile < 128:
                    raise
        upscale_model.cpu()
        s = torch.clamp(s.movedim(-3, -1), min=0, max=1.0)
        return (s,)
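
# Usage sketch: `image` is expected channels-last, i.e. a float tensor of shape
# [batch, height, width, channels] in the 0..1 range (see the movedim/clamp calls
# above); `upscale_model` comes from UpscaleModelLoader.
#     node = ImageUpscaleWithModel()
#     (upscaled,) = node.upscale(upscale_model, image)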


def torch_gc() -> None:
    """#### Perform garbage collection for PyTorch."""
    pass


class Script:
    """#### Class representing a script."""

    pass


class Options:
    """#### Class representing options."""

    img2img_background_color: str = "#ffffff"  # Set to white for now


class State:
    """#### Class representing the state."""

    interrupted: bool = False

    def begin(self) -> None:
        """#### Begin the state."""
        pass

    def end(self) -> None:
        """#### End the state."""
        pass


opts = Options()
state = State()

# Will only ever hold 1 upscaler
sd_upscalers = [None]
actual_upscaler = None

# Batch of images to upscale
batch = None

if not hasattr(Image, "Resampling"):  # For older versions of Pillow
    Image.Resampling = Image


class Upscaler:
    """#### Class for upscaling images."""

    def _upscale(self, img: Image.Image, scale: float) -> Image.Image:
        """#### Upscale an image.

        #### Args:
            - `img` (Image.Image): The input image.
            - `scale` (float): The scale factor.

        #### Returns:
            - `Image.Image`: The upscaled image.
        """
        global actual_upscaler
        tensor = image_util.pil_to_tensor(img)
        image_upscale_node = ImageUpscaleWithModel()
        (upscaled,) = image_upscale_node.upscale(actual_upscaler, tensor)
        return image_util.tensor_to_pil(upscaled)

    def upscale(self, img: Image.Image, scale: float, selected_model: str = None) -> Image.Image:
        """#### Upscale an image with a selected model.

        #### Args:
            - `img` (Image.Image): The input image.
            - `scale` (float): The scale factor.
            - `selected_model` (str, optional): The selected model. Defaults to None.

        #### Returns:
            - `Image.Image`: The upscaled image.
        """
        global batch
        batch = [self._upscale(img, scale) for img in batch]
        return batch[0]


class UpscalerData:
    """#### Class for storing upscaler data."""

    name: str = ""
    data_path: str = ""

    def __init__(self):
        self.scaler = Upscaler()
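
# End-to-end wiring sketch, assuming a PIL image `pil_img` and a checkpoint under
# ./_internal/ESRGAN/ (file name hypothetical). Upscaler.upscale reads the
# module-level `batch`, so the globals must be populated first (by code in this
# module or by assigning the module attributes from the caller):
#     (model,) = UpscaleModelLoader().load_model("example_x4.pth")
#     actual_upscaler = model
#     sd_upscalers[0] = UpscalerData()
#     batch = [pil_img]
#     result = sd_upscalers[0].scaler.upscale(pil_img, 4, "example_x4.pth")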