File size: 3,005 Bytes
2b6048b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import json
from dataclasses import asdict, dataclass
from functools import lru_cache
from os import PathLike
from pathlib import Path
from typing import Any

import numpy as np
import pandas as pd
from PIL import Image


class DictJsonMixin:
    """Mixin that adds ``asdict()`` / ``asjson()`` helpers to a dataclass.

    It must be mixed into a dataclass, because both helpers delegate to
    ``dataclasses.asdict``.
    """

    @staticmethod
    def _json_default(value: Any) -> Any:
        """Convert NumPy values that ``json`` cannot encode natively.

        ``json.dumps`` only calls this hook for objects it does not know how
        to serialise, so plain Python values never reach it.

        Raises:
            TypeError: if ``value`` is not a NumPy scalar or array. The
                message mirrors the one the ``json`` module itself raises.
        """
        if isinstance(value, np.generic):
            # NumPy scalar (np.int64, np.float32, ...) -> Python scalar.
            return value.item()
        if isinstance(value, np.ndarray):
            return value.tolist()
        raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")

    def asdict(self, *args, **kwargs) -> dict[str, Any]:
        """Return the dataclass fields as a dict.

        Positional and keyword arguments are forwarded to
        ``dataclasses.asdict``.
        """
        return asdict(self, *args, **kwargs)

    def asjson(self, *args, **kwargs) -> str:
        """Return the dataclass fields serialised as a JSON string.

        Positional and keyword arguments are forwarded to
        ``dataclasses.asdict``. NumPy scalars and arrays inside the fields
        are converted to plain Python values so they do not make
        ``json.dumps`` raise ``TypeError``.
        """
        return json.dumps(asdict(self, *args, **kwargs), default=self._json_default)


@dataclass
class LabelData(DictJsonMixin):
    """Tag names plus per-category row indices, as built by ``load_labels``."""

    # Tag names from the CSV "name" column, in row order.
    names: list[str]
    # Row indices of tags whose "category" column equals 9.
    rating: list[np.int64]
    # Row indices of tags whose "category" column equals 0.
    general: list[np.int64]
    # Row indices of tags whose "category" column equals 4.
    character: list[np.int64]


@dataclass
class ImageLabels(DictJsonMixin):
    """Labels associated with a single image.

    NOTE(review): the visible code never constructs this class, so the field
    notes below are inferred from names and types only; confirm against the
    code that fills it in.
    """

    # Presumably a free-text caption for the image; TODO confirm.
    caption: str
    # Presumably a booru-style tag string; TODO confirm.
    booru: str
    # Presumably rating tag name -> score; TODO confirm.
    rating: dict[str, float]
    # Presumably general tag name -> score; TODO confirm.
    general: dict[str, float]
    # Presumably character tag name -> score; TODO confirm.
    character: dict[str, float]


@lru_cache(maxsize=5)
def load_labels(csv_path: str | PathLike = "data/selected_tags.csv") -> LabelData:
    """Load tag names and per-category row indices from a tag CSV.

    Only the ``name`` and ``category`` columns of the CSV are read.

    Results are memoised by ``lru_cache`` on the argument exactly as passed,
    so repeated calls with the same argument return the *same* ``LabelData``
    object. Treat it as read-only.

    Args:
        csv_path: Path to the CSV file. It is resolved to an absolute path
            before being checked and read.

    Returns:
        A ``LabelData`` with every tag name in row order, plus the row
        indices whose category is 9 (rating), 0 (general) and 4 (character).

    Raises:
        FileNotFoundError: if the resolved path is not an existing file.
    """
    csv_path = Path(csv_path).resolve()
    if not csv_path.is_file():
        # Report the path that was actually checked, not a hard-coded name.
        raise FileNotFoundError(f"Label file not found: {csv_path}")

    df: pd.DataFrame = pd.read_csv(csv_path, usecols=["name", "category"])
    categories = df["category"].to_numpy()

    def rows_with(category: int) -> list[np.int64]:
        # Positions of the rows whose category matches, as NumPy integers.
        return list(np.flatnonzero(categories == category))

    return LabelData(
        names=df["name"].tolist(),
        rating=rows_with(9),
        general=rows_with(0),
        character=rows_with(4),
    )


def pil_ensure_rgb(image: Image.Image) -> Image.Image:
    """Return ``image`` as RGB, flattening any transparency onto white.

    Args:
        image: Source image in any mode Pillow can convert.

    Returns:
        An RGB image. An image that is already RGB is returned unchanged
        (the same object); otherwise a converted copy is returned.
    """
    if image.mode not in ("RGB", "RGBA"):
        # "LA" and "PA" carry a real alpha band but usually have no
        # "transparency" info key; converting them straight to RGB would
        # drop the alpha instead of compositing it, so route them via RGBA.
        has_alpha = image.mode in ("LA", "PA") or "transparency" in image.info
        image = image.convert("RGBA" if has_alpha else "RGB")
    # Flatten RGBA onto an opaque white background.
    if image.mode == "RGBA":
        canvas = Image.new("RGBA", image.size, (255, 255, 255))
        canvas.alpha_composite(image)
        image = canvas.convert("RGB")
    return image


def pil_pad_square(
    image: Image.Image,
    fill: tuple[int, int, int] = (255, 255, 255),
) -> Image.Image:
    """Centre ``image`` on a new square RGB canvas.

    Args:
        image: Image to pad.
        fill: Colour of the canvas around the image (white by default).

    Returns:
        A new RGB image whose side equals the larger dimension of ``image``,
        with ``image`` pasted in the middle.
    """
    width, height = image.size
    # The longer edge determines the side of the square.
    side = max(width, height)
    offset = ((side - width) // 2, (side - height) // 2)

    square = Image.new("RGB", (side, side), fill)
    square.paste(image, offset)
    return square


def preprocess_image(
    image: Image.Image,
    size_px: int | tuple[int, int],
    upscale: bool = True,
) -> Image.Image:
    """
    Convert an image to RGB, pad it to a square, and fit it to ``size_px``.

    Args:
        image: Source image.
        size_px: Target size; an int is used for both width and height.
        upscale: When ``False``, refuse to enlarge an image that is smaller
            than the target in either dimension.

    Returns:
        The processed image. It is resized with LANCZOS if it was smaller
        than the target, or shrunk with ``thumbnail`` (BICUBIC) if larger.

    Raises:
        ValueError: if the padded image is smaller than the target and
            ``upscale`` is ``False``.
    """
    target = (size_px, size_px) if isinstance(size_px, int) else size_px

    # Normalise the colour mode, then centre on a square canvas.
    squared = pil_pad_square(pil_ensure_rgb(image))
    width, height = squared.size

    # Smaller than the target in any dimension: enlarge, or refuse.
    if width < target[0] or height < target[1]:
        if upscale is False:
            raise ValueError("Image is smaller than target size, and upscaling is disabled")
        return squared.resize(target, Image.LANCZOS)

    # Larger than the target in any dimension: shrink in place.
    if width > target[0] or height > target[1]:
        squared.thumbnail(target, Image.BICUBIC)

    return squared