File size: 5,455 Bytes
d9a2e19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
import os
import numpy as np
from PIL import Image

output_directory = "./_internal/output"


def get_output_directory() -> str:
    """#### Get the output directory.

    Reads the module-level ``output_directory`` at call time, so a value
    reassigned on this module after import is picked up.

    #### Returns:

        - `str`: The output directory.

    """
    # Reading a module global needs no `global` statement; that is only
    # required when the function assigns to the name.
    return output_directory


def get_save_image_path(
    filename_prefix: str, output_dir: str, image_width: int = 0, image_height: int = 0
) -> tuple:
    """#### Get the save image path.

    Expands the ``%width%`` / ``%height%`` placeholders in ``filename_prefix``,
    creates the category subfolders (Classic, HiresFix, Img2Img, Flux,
    Adetailer) under the resolved output folder, and returns the next counter
    value: one more than the highest counter found in any matching
    ``.png`` file across all of those subfolders (1 when there is none).

    #### Args:

        - `filename_prefix` (str): The filename prefix. May contain a relative
          subfolder and the ``%width%`` / ``%height%`` placeholders.

        - `output_dir` (str): The output directory.

        - `image_width` (int, optional): The image width. Defaults to 0.

        - `image_height` (int, optional): The image height. Defaults to 0.

    #### Returns:

        - `tuple`: The full output folder, filename, counter, subfolder, and
          the placeholder-expanded filename prefix.

    """

    def expand_placeholders(template: str) -> str:
        # Substitute the image dimensions into the caller-supplied template.
        template = template.replace("%width%", str(image_width))
        return template.replace("%height%", str(image_height))

    filename_prefix = expand_placeholders(filename_prefix)

    normalized = os.path.normpath(filename_prefix)
    subfolder = os.path.dirname(normalized)
    filename = os.path.basename(normalized)
    # Use the same normalized name for parsing as for the startswith filter
    # below, so the two never disagree (e.g. on a trailing separator).
    prefix_len = len(filename)

    def counter_of(name: str) -> int:
        # Saved files look like "<filename>_<counter>_.png". A name whose
        # counter part is not an integer (e.g. "LD-HF_..." when filename is
        # "LD") contributes 0 instead of aborting the scan.
        try:
            return int(name[prefix_len + 1 :].split("_")[0])
        except ValueError:
            return 0

    full_output_folder = os.path.join(output_dir, subfolder)
    category_paths = [
        os.path.join(full_output_folder, category)
        for category in ["Classic", "HiresFix", "Img2Img", "Flux", "Adetailer"]
    ]

    # The counter is shared across all category folders, so take the
    # highest one found anywhere.
    counter = 1
    for path in category_paths:
        os.makedirs(path, exist_ok=True)
        numbers = [
            counter_of(entry)
            for entry in os.listdir(path)
            if entry.startswith(filename) and entry.endswith(".png")
        ]
        if numbers:
            counter = max(counter, max(numbers) + 1)

    return full_output_folder, filename, counter, subfolder, filename_prefix


# Presumably the largest supported image side length in pixels (2**14); it is
# not referenced in this section of the file — confirm against its users.
MAX_RESOLUTION = 16_384


class SaveImage:
    """#### Class for saving images.

    Writes every image of a batch as a PNG into one category subfolder of
    the output directory, chosen from the filename prefix.
    """

    # Filename prefix -> category subfolder. Prefixes not listed here are
    # saved under "Classic". The folders themselves are created by
    # `get_save_image_path`.
    _CATEGORY_BY_PREFIX = {
        "LD-HF": "HiresFix",
        "LD-I2I": "Img2Img",
        "LD-Flux": "Flux",
        "LD-head": "Adetailer",
        "LD-body": "Adetailer",
    }

    def __init__(self):
        """#### Initialize the SaveImage class."""
        self.output_dir = get_output_directory()
        self.type = "output"
        self.prefix_append = ""
        self.compress_level = 4

    @classmethod
    def _category_subfolder(cls, filename_prefix: str) -> str:
        """Return the category subfolder name used for ``filename_prefix``."""
        return cls._CATEGORY_BY_PREFIX.get(filename_prefix, "Classic")

    @staticmethod
    def _image_size(image) -> tuple:
        """Return ``(width, height)`` of an image array.

        The saving code treats the squeezed array as ``(height, width,
        channels)`` (that is what ``Image.fromarray`` receives), so the last
        axis is channels, not height. A 2-D array is treated as
        ``(height, width)``.
        """
        shape = tuple(image.shape)
        if len(shape) >= 3:
            return shape[-2], shape[-3]
        return shape[-1], shape[-2]

    def save_images(
        self,
        images: list,
        filename_prefix: str = "LD",
        prompt: str = None,
        extra_pnginfo: dict = None,
    ) -> dict:
        """#### Save images to the output directory.

        #### Args:

            - `images` (list): The list of images. Each element must provide
              ``.cpu().numpy()`` with values in [0, 1] (they are scaled by 255
              and clipped).

            - `filename_prefix` (str, optional): The filename prefix. Defaults to "LD".

            - `prompt` (str, optional): The prompt. Currently ignored. Defaults to None.

            - `extra_pnginfo` (dict, optional): Additional PNG info. Currently
              ignored; no PNG metadata is written. Defaults to None.

        #### Returns:

            - `dict`: The saved images information, as
              ``{"ui": {"images": [{"filename", "subfolder", "type"}, ...]}}``.
              Empty when ``images`` is empty.

        """
        if len(images) == 0:
            return {"ui": {"images": []}}

        filename_prefix += self.prefix_append
        width, height = self._image_size(images[0])
        full_output_folder, filename, counter, subfolder, filename_prefix = (
            get_save_image_path(filename_prefix, self.output_dir, width, height)
        )
        # Resolve the category folder once. Re-joining it onto the folder
        # inside the loop would nest it once per image (e.g. HiresFix/HiresFix),
        # a directory that does not exist.
        target_folder = os.path.join(
            full_output_folder, self._category_subfolder(filename_prefix)
        )

        results = []
        for batch_number, image in enumerate(images):
            # Scale to 0-255 and drop singleton dimensions.
            pixels = np.squeeze(255.0 * image.cpu().numpy())
            # A leftover 4-D array is folded into a single 3-D
            # (rows, width, channels) array so PIL can take it.
            if pixels.ndim == 4:
                pixels = pixels.reshape(-1, pixels.shape[-2], pixels.shape[-1])
            img = Image.fromarray(np.clip(pixels, 0, 255).astype(np.uint8))

            filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
            file = f"{filename_with_batch_num}_{counter:05}_.png"
            img.save(
                os.path.join(target_folder, file),
                pnginfo=None,
                compress_level=self.compress_level,
            )
            results.append(
                {"filename": file, "subfolder": subfolder, "type": self.type}
            )
            counter += 1

        return {"ui": {"images": results}}