File size: 5,239 Bytes
d9a2e19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1d117d0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
import logging as logger
import torch
from PIL import Image

from modules.Device import Device
from modules.UltimateSDUpscale import RDRB
from modules.UltimateSDUpscale import image_util
from modules.Utilities import util


def load_state_dict(state_dict: dict) -> RDRB.PyTorchModel:
    """#### Load a state dictionary into a PyTorch model.

    #### Args:
        - `state_dict` (dict): The state dictionary.

    #### Returns:
        - `RDRB.PyTorchModel`: The loaded PyTorch model.
    """
    logger.debug("Loading state dict into pytorch model arch")
    # Some checkpoints nest the actual weights under a "params_ema" key.
    if "params_ema" in state_dict:
        state_dict = state_dict["params_ema"]
    return RDRB.RRDBNet(state_dict)


class UpscaleModelLoader:
    """#### Class for loading upscale models."""

    def load_model(self, model_name: str) -> tuple:
        """#### Load an upscale model.

        #### Args:
            - `model_name` (str): The name of the model.

        #### Returns:
            - `tuple`: The loaded model.
        """
        path = f"./_internal/ESRGAN/{model_name}"
        state = util.load_torch_file(path, safe_load=True)
        # Strip a DataParallel-style "module." prefix when present.
        if "module.layers.0.residual_group.blocks.0.norm1.weight" in state:
            state = util.state_dict_prefix_replace(state, {"module.": ""})
        model = load_state_dict(state).eval()
        return (model,)


class ImageUpscaleWithModel:
    """#### Class for upscaling images with a model."""

    def upscale(self, upscale_model: torch.nn.Module, image: torch.Tensor) -> tuple:
        """#### Upscale an image using a model.

        Tiles the input so large images fit in memory; on a CUDA
        out-of-memory error the tile size is halved and the pass retried.

        #### Args:
            - `upscale_model` (torch.nn.Module): The upscale model.
            - `image` (torch.Tensor): The input image tensor (channels-last).

        #### Returns:
            - `tuple`: The upscaled image tensor, clamped to [0, 1].

        #### Raises:
            - `torch.cuda.OutOfMemoryError`: If upscaling still fails at the
              smallest allowed tile size.
        """
        if torch.cuda.is_available():
            device = torch.device(torch.cuda.current_device())
        else:
            device = torch.device("cpu")
        upscale_model.to(device)
        # tiled_scale expects channels-first; the caller supplies channels-last.
        in_img = image.movedim(-1, -3).to(device)
        Device.get_free_memory(device)

        tile = 512
        overlap = 32

        oom = True
        while oom:
            steps = in_img.shape[0] * image_util.get_tiled_scale_steps(
                in_img.shape[3],
                in_img.shape[2],
                tile_x=tile,
                tile_y=tile,
                overlap=overlap,
            )
            pbar = util.ProgressBar(steps)
            try:
                s = image_util.tiled_scale(
                    in_img,
                    lambda a: upscale_model(a),
                    tile_x=tile,
                    tile_y=tile,
                    overlap=overlap,
                    upscale_amount=upscale_model.scale,
                    pbar=pbar,
                )
                oom = False
            except torch.cuda.OutOfMemoryError:
                # Previously nothing caught OOM, making the retry loop dead
                # code. Halve the tile and retry; give up below a sane minimum.
                tile //= 2
                if tile < 128:
                    upscale_model.cpu()
                    raise

        upscale_model.cpu()
        s = torch.clamp(s.movedim(-3, -1), min=0, max=1.0)
        return (s,)


def torch_gc() -> None:
    """#### Perform garbage collection for PyTorch.

    Runs Python's garbage collector and, when CUDA is available, releases
    cached GPU memory back to the driver. Previously a no-op stub.
    """
    import gc

    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


class Script:
    """#### Class representing a script.

    Placeholder with no behavior of its own.
    """


class Options:
    """#### Class representing options."""

    # Background color used for img2img; white for now.
    img2img_background_color: str = "#ffffff"


class State:
    """#### Class representing the state."""

    # Interrupt flag; starts cleared.
    interrupted: bool = False

    def begin(self) -> None:
        """#### Begin the state (no-op)."""

    def end(self) -> None:
        """#### End the state (no-op)."""


opts = Options()
state = State()

# Will only ever hold 1 upscaler
sd_upscalers = [None]
actual_upscaler = None

# Batch of images to upscale
batch = None


if not hasattr(Image, "Resampling"):  # For older versions of Pillow
    # Pillow < 9.1 exposes resampling filters (e.g. LANCZOS) directly on the
    # Image module; alias it so `Image.Resampling.*` works on both versions.
    Image.Resampling = Image


class Upscaler:
    """#### Class for upscaling images."""

    def _upscale(self, img: Image.Image, scale: float) -> Image.Image:
        """#### Upscale an image.

        #### Args:
            - `img` (Image.Image): The input image.
            - `scale` (float): The scale factor (not used by the model pass).

        #### Returns:
            - `Image.Image`: The upscaled image.
        """
        global actual_upscaler
        tensor = image_util.pil_to_tensor(img)
        node = ImageUpscaleWithModel()
        (result,) = node.upscale(actual_upscaler, tensor)
        return image_util.tensor_to_pil(result)

    def upscale(self, img: Image.Image, scale: float, selected_model: str = None) -> Image.Image:
        """#### Upscale an image with a selected model.

        Operates on the module-level `batch` list, upscaling every member;
        the `img` argument itself is not read here.

        #### Args:
            - `img` (Image.Image): The input image.
            - `scale` (float): The scale factor.
            - `selected_model` (str, optional): The selected model. Defaults to None.

        #### Returns:
            - `Image.Image`: The upscaled image.
        """
        global batch
        batch = [self._upscale(member, scale) for member in batch]
        return batch[0]


class UpscalerData:
    """#### Class for storing upscaler data."""

    name: str = ""
    data_path: str = ""

    def __init__(self) -> None:
        """#### Initialize with a fresh `Upscaler` instance."""
        self.scaler = Upscaler()