import logging as logger
import torch
from PIL import Image
from modules.Device import Device
from modules.UltimateSDUpscale import RDRB
from modules.UltimateSDUpscale import image_util
from modules.Utilities import util

# Smallest tile edge tried before an out-of-memory error is treated as fatal.
_MIN_TILE_SIZE = 128


def load_state_dict(state_dict: dict) -> RDRB.PyTorchModel:
    """#### Load a state dictionary into a PyTorch model.

    #### Args:
        - `state_dict` (dict): The state dictionary. When it contains a
          `"params_ema"` entry, that nested entry is used as the weights.

    #### Returns:
        - `RDRB.PyTorchModel`: The loaded PyTorch model.
    """
    logger.debug("Loading state dict into pytorch model arch")

    # Some checkpoints nest the actual weights under "params_ema".
    if "params_ema" in state_dict:
        state_dict = state_dict["params_ema"]
    return RDRB.RRDBNet(state_dict)


class UpscaleModelLoader:
    """#### Class for loading upscale models."""

    def load_model(self, model_name: str) -> tuple:
        """#### Load an upscale model.

        #### Args:
            - `model_name` (str): File name of the model, resolved relative to
              `./_internal/ESRGAN/`.

        #### Returns:
            - `tuple`: One-element tuple holding the model in eval mode.
        """
        model_path = f"./_internal/ESRGAN/{model_name}"
        sd = util.load_torch_file(model_path, safe_load=True)

        # Checkpoints whose keys carry a leading "module." get it stripped
        # (presumably saved from a wrapped model - TODO confirm).
        if "module.layers.0.residual_group.blocks.0.norm1.weight" in sd:
            sd = util.state_dict_prefix_replace(sd, {"module.": ""})

        out = load_state_dict(sd).eval()
        return (out,)


class ImageUpscaleWithModel:
    """#### Class for upscaling images with a model."""

    def upscale(self, upscale_model: torch.nn.Module, image: torch.Tensor) -> tuple:
        """#### Upscale an image using a model.

        The image is processed tile by tile. If the accelerator runs out of
        memory, the tile size is halved and the pass is retried; once the tile
        would drop below `_MIN_TILE_SIZE` the error is re-raised. The model is
        moved back to the CPU whether or not upscaling succeeds.

        #### Args:
            - `upscale_model` (torch.nn.Module): The upscale model. Must expose
              a `scale` attribute.
            - `image` (torch.Tensor): The input image tensor (assumed to be
              batched and channel-last - TODO confirm against callers).

        #### Returns:
            - `tuple`: One-element tuple with the upscaled image tensor,
              clamped to `[0, 1]`.

        #### Raises:
            - The torch out-of-memory error, if upscaling still fails at the
              smallest permitted tile size.
        """
        if torch.cuda.is_available():
            device = torch.device(torch.cuda.current_device())
        else:
            device = torch.device("cpu")

        # Older torch builds have no dedicated OOM class and raise a plain
        # RuntimeError instead, so fall back to that.
        oom_error = getattr(torch.cuda, "OutOfMemoryError", RuntimeError)

        upscale_model.to(device)
        try:
            # Move the last dimension to third-from-last for the model.
            in_img = image.movedim(-1, -3).to(device)
            # NOTE(review): return value is unused; call kept as in the original.
            Device.get_free_memory(device)

            tile = 512
            overlap = 32

            while True:
                try:
                    steps = in_img.shape[0] * image_util.get_tiled_scale_steps(
                        in_img.shape[3],
                        in_img.shape[2],
                        tile_x=tile,
                        tile_y=tile,
                        overlap=overlap,
                    )
                    pbar = util.ProgressBar(steps)
                    s = image_util.tiled_scale(
                        in_img,
                        upscale_model,
                        tile_x=tile,
                        tile_y=tile,
                        overlap=overlap,
                        upscale_amount=upscale_model.scale,
                        pbar=pbar,
                    )
                    break
                except oom_error:
                    # Smaller tiles need less memory per forward pass.
                    tile //= 2
                    if tile < _MIN_TILE_SIZE:
                        raise
                    logger.warning(
                        "Out of memory while upscaling; retrying with tile size %d",
                        tile,
                    )
        finally:
            # Release accelerator memory even when upscaling fails.
            upscale_model.cpu()

        s = torch.clamp(s.movedim(-3, -1), min=0, max=1.0)
        return (s,)


def torch_gc() -> None:
    """#### Placeholder for PyTorch garbage collection.

    Currently a no-op, kept so callers that expect this hook keep working.
    """


class Script:
    """#### Class representing a script."""


class Options:
    """#### Class representing options."""

    img2img_background_color: str = "#ffffff"  # Set to white for now


class State:
    """#### Class representing the state."""

    interrupted: bool = False

    def begin(self) -> None:
        """#### Begin the state (no-op)."""

    def end(self) -> None:
        """#### End the state (no-op)."""


opts = Options()
state = State()

# Will only ever hold 1 upscaler
sd_upscalers = [None]
actual_upscaler = None

# Batch of images to upscale
batch = None

if not hasattr(Image, "Resampling"):  # For older versions of Pillow
    Image.Resampling = Image


class Upscaler:
    """#### Class for upscaling images."""

    def _upscale(self, img: Image.Image, scale: float) -> Image.Image:
        """#### Upscale an image.

        Runs the module-level `actual_upscaler` model on the image.

        #### Args:
            - `img` (Image.Image): The input image.
            - `scale` (float): The scale factor. Currently unused; the output
              size is determined by the model.

        #### Returns:
            - `Image.Image`: The upscaled image.
        """
        tensor = image_util.pil_to_tensor(img)
        image_upscale_node = ImageUpscaleWithModel()
        (upscaled,) = image_upscale_node.upscale(actual_upscaler, tensor)
        return image_util.tensor_to_pil(upscaled)

    def upscale(self, img: Image.Image, scale: float, selected_model: str = None) -> Image.Image:
        """#### Upscale an image with a selected model.

        Every image in the module-level `batch` is upscaled and `batch` is
        replaced with the results. The `img` argument itself is not read.

        #### Args:
            - `img` (Image.Image): The input image (unused; `batch` is the
              actual input).
            - `scale` (float): The scale factor.
            - `selected_model` (str, optional): The selected model. Currently
              unused. Defaults to None.

        #### Returns:
            - `Image.Image`: The first upscaled image of the batch.
        """
        global batch
        # NOTE(review): `batch` must be set to a non-empty list by the caller
        # before this is invoked; it is None at import time.
        batch = [self._upscale(img, scale) for img in batch]
        return batch[0]


class UpscalerData:
    """#### Class for storing upscaler data."""

    name: str = ""
    data_path: str = ""

    def __init__(self):
        """#### Create the data holder with a fresh `Upscaler` as `scaler`."""
        self.scaler = Upscaler()