#!/usr/bin/env python3
"""Tag every image in a directory with one of the ONNX tagger models in MODEL_VARIANTS.

The model and its ``selected_tags.csv`` label file are fetched from the Hugging Face Hub.
For each image a caption file (the comma-separated tags) is written next to the image,
using the same stem and the configured caption extension.
"""
import argparse
import logging
from collections import Counter
from dataclasses import dataclass
from os import PathLike
from pathlib import Path
from typing import Generator, Iterable, Optional, Tuple

import numpy as np
import onnxruntime as rt
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import HfHubHTTPError
from pandas import DataFrame, read_csv
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

# allowed extensions
IMAGE_EXTENSIONS = [".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp", ".tiff", ".tif"]

# image input shape
IMAGE_SIZE = 448

# short variant name (as accepted by --variant) -> Hugging Face repo id
MODEL_VARIANTS: dict[str, str] = {
    "swinv2": "SmilingWolf/wd-swinv2-tagger-v3",
    "convnext": "SmilingWolf/wd-convnext-tagger-v3",
    "vit": "SmilingWolf/wd-vit-tagger-v3",
}


@dataclass
class LabelData:
    """Tag names plus the row indices of each tag category in selected_tags.csv."""

    # tag names, in CSV row order
    names: list[str]
    # row indices whose category is 9
    rating: list[np.int64]
    # row indices whose category is 0
    general: list[np.int64]
    # row indices whose category is 4
    character: list[np.int64]


@dataclass
class ImageLabels:
    """Tagging result for a single image."""

    # general tags followed by character tags, joined with ", "
    caption: str
    # caption with underscores replaced by spaces and parentheses backslash-escaped
    booru: str
    # name of the highest-scoring rating label
    rating: str
    # general tags above the threshold -> confidence, highest first
    general: dict[str, float]
    # character tags above the threshold -> confidence, highest first
    character: dict[str, float]
    # every rating label -> confidence
    ratings: dict[str, float]


logging.basicConfig(level=logging.WARNING, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
logger = logging.getLogger()
logger.setLevel(logging.INFO)


## Model loading functions
def download_onnx(
    repo_id: str,
    filename: str = "model.onnx",
    revision: Optional[str] = None,
    token: Optional[str] = None,
) -> Path:
    """Download an ONNX file from a Hugging Face repo and return its resolved local path.

    ``.onnx`` is appended to ``filename`` when it is missing.
    """
    if not filename.endswith(".onnx"):
        filename += ".onnx"

    model_path = hf_hub_download(repo_id=repo_id, filename=filename, revision=revision, token=token)
    return Path(model_path).resolve()


def create_session(
    repo_id: str,
    revision: Optional[str] = None,
    token: Optional[str] = None,
) -> rt.InferenceSession:
    """Download ``model.onnx`` from ``repo_id`` and open an ONNX Runtime session on it.

    The CUDA execution provider is listed first, with the CPU provider as fallback.

    Raises:
        FileNotFoundError: if no model file exists at the downloaded location.
    """
    model_path = download_onnx(repo_id, revision=revision, token=token)
    if not model_path.is_file():
        # the returned path may be a directory; look for the model inside it
        model_path = model_path.joinpath("model.onnx")
        if not model_path.is_file():
            raise FileNotFoundError(f"Model not found: {model_path}")

    model = rt.InferenceSession(
        str(model_path),
        providers=[("CUDAExecutionProvider", {}), "CPUExecutionProvider"],
    )
    return model


## Label loading function
def load_labels_hf(
    repo_id: str,
    revision: Optional[str] = None,
    token: Optional[str] = None,
) -> LabelData:
    """Download ``selected_tags.csv`` from ``repo_id`` and split its rows by category.

    Raises:
        FileNotFoundError: if the CSV download fails with an HTTP error.
    """
    try:
        csv_path = hf_hub_download(
            repo_id=repo_id, filename="selected_tags.csv", revision=revision, token=token
        )
        csv_path = Path(csv_path).resolve()
    except HfHubHTTPError as e:
        raise FileNotFoundError(f"selected_tags.csv failed to download from {repo_id}") from e

    df: DataFrame = read_csv(csv_path, usecols=["name", "category"])
    tag_data = LabelData(
        names=df["name"].tolist(),
        rating=list(np.where(df["category"] == 9)[0]),
        general=list(np.where(df["category"] == 0)[0]),
        character=list(np.where(df["category"] == 4)[0]),
    )
    return tag_data


## Image preprocessing functions
def pil_ensure_rgb(image: Image.Image) -> Image.Image:
    """Return ``image`` in RGB mode, flattening any transparency onto white."""
    # convert to RGB/RGBA if not already (deals with palette images etc.)
    if image.mode not in ["RGB", "RGBA"]:
        image = image.convert("RGBA") if "transparency" in image.info else image.convert("RGB")
    # convert RGBA to RGB with white background
    if image.mode == "RGBA":
        canvas = Image.new("RGBA", image.size, (255, 255, 255))
        canvas.alpha_composite(image)
        image = canvas.convert("RGB")
    return image


def pil_pad_square(
    image: Image.Image,
    fill: tuple[int, int, int] = (255, 255, 255),
) -> Image.Image:
    """Center ``image`` on a new square RGB canvas of colour ``fill``."""
    w, h = image.size
    # get the largest dimension so we can pad to a square
    px = max(image.size)
    # pad to square with white background
    canvas = Image.new("RGB", (px, px), fill)
    canvas.paste(image, ((px - w) // 2, (px - h) // 2))
    return canvas


def preprocess_image(
    image: Image.Image,
    size_px: int | tuple[int, int],
    upscale: bool = True,
) -> Image.Image:
    """
    Preprocess an image to be square and centered on a white background.

    Args:
        image: source image in any PIL mode.
        size_px: target size; an int is treated as (size_px, size_px).
        upscale: when False, an image smaller than the target raises instead of being enlarged.

    Raises:
        ValueError: if the padded image is smaller than ``size_px`` and ``upscale`` is False.
    """
    if isinstance(size_px, int):
        size_px = (size_px, size_px)

    # ensure RGB and pad to square
    image = pil_ensure_rgb(image)
    image = pil_pad_square(image)

    # resize to target size
    if image.size[0] < size_px[0] or image.size[1] < size_px[1]:
        if upscale is False:
            raise ValueError("Image is smaller than target size, and upscaling is disabled")
        image = image.resize(size_px, Image.LANCZOS)
    if image.size[0] > size_px[0] or image.size[1] > size_px[1]:
        image.thumbnail(size_px, Image.BICUBIC)

    return image


## Dataset for DataLoader
class ImageDataset(Dataset):
    """Dataset that loads and preprocesses images, returning None for unreadable ones."""

    def __init__(self, image_paths: list[Path], size_px: int = IMAGE_SIZE, upscale: bool = True):
        self.size_px = size_px
        self.upscale = upscale
        # silently drop anything that does not carry a known image extension
        self.images = [p for p in image_paths if p.suffix.lower() in IMAGE_EXTENSIONS]

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``{"image": float32 H,W,C array in BGR order, "path": utf-8 bytes}`` or None."""
        image_path: Path = self.images[idx]
        try:
            # the context manager closes the file handle; preprocess_image pastes the pixels
            # onto a fresh canvas, so the result does not depend on the open file
            with Image.open(image_path) as raw:
                image = preprocess_image(raw, self.size_px, self.upscale)
            # The model wants BGR channel order. Reverse the channel axis in NumPy rather than
            # going through Pillow's experimental "BGR;24" mode, which newer Pillow drops.
            rgb = np.asarray(image.convert("RGB"), dtype=np.float32)
            image = np.ascontiguousarray(rgb[:, :, ::-1])
        except Exception as e:
            # best-effort: a broken file is logged and skipped, not fatal for the whole run
            logging.exception(f"Could not load image from {image_path}, error: {e}")
            return None

        return {"image": image, "path": np.array(str(image_path).encode("utf-8"), dtype=np.bytes_)}


def collate_fn_remove_corrupted(batch):
    """Collate function that allows to remove corrupted examples in the dataloader.

    It expects that the dataset returns 'None' when that occurs. The 'None's in the
    batch are removed and the remaining samples are stacked per key into NumPy arrays.

    Returns:
        A dict of stacked arrays, or None when every sample in the batch was None.
    """
    # Filter out all the Nones (corrupted examples)
    batch = [x for x in batch if x is not None]
    if not batch:
        return None
    return {k: np.array([x[k] for x in batch]) for k in batch[0]}


## Labeler
class ImageLabeler:
    """Wraps an ONNX tagger session and turns its output probabilities into ImageLabels."""

    def __init__(
        self,
        repo_id: Optional[PathLike] = None,
        general_threshold: float = 0.35,
        character_threshold: float = 0.35,
        banned_tags: Optional[Iterable[str]] = None,
    ):
        """Load the model and labels for ``repo_id``.

        Args:
            repo_id: Hugging Face repo id to load the model and labels from.
            general_threshold: general tags are kept when their confidence is above this.
            character_threshold: character tags are kept when their confidence is above this.
            banned_tags: tag names (as spelled in selected_tags.csv) to never emit.
        """
        self.repo_id = repo_id
        # create some object attributes for convenience
        self.general_threshold = general_threshold
        self.character_threshold = character_threshold
        # copy into a fresh list so no default or caller-owned object is shared
        self.banned_tags = list(banned_tags) if banned_tags is not None else []

        # actually load the model
        logging.info(f"Loading model from path: {self.repo_id}")
        self.model = create_session(self.repo_id)

        # Get input dimensions
        _, self.height, self.width, _ = self.model.get_inputs()[0].shape
        logging.info(f"Model loaded, input dimensions {self.height}x{self.width}")

        # load labels
        self.labels = load_labels_hf(self.repo_id)
        # labels.general / labels.character hold row indices, so map each index back to its
        # tag name before comparing against the banned names
        banned = set(self.banned_tags)
        names = self.labels.names
        self.labels.general = [i for i in self.labels.general if names[i] not in banned]
        self.labels.character = [i for i in self.labels.character if names[i] not in banned]
        logging.info(f"Loaded labels from {self.repo_id}")

    @property
    def input_size(self) -> Tuple[int, int]:
        """(height, width) read from the model's first input."""
        return (self.height, self.width)

    @property
    def input_name(self) -> Optional[str]:
        """Name of the model's first input, or None when no model is loaded."""
        return self.model.get_inputs()[0].name if self.model is not None else None

    @property
    def output_name(self) -> Optional[str]:
        """Name of the model's first output, or None when no model is loaded."""
        return self.model.get_outputs()[0].name if self.model is not None else None

    @staticmethod
    def _select_tags(labels: list, indices: list, threshold: float) -> dict[str, float]:
        """Return the (name, confidence) pairs at ``indices`` above ``threshold``, highest first."""
        selected = dict(labels[i] for i in indices if labels[i][1] > threshold)
        return dict(sorted(selected.items(), key=lambda item: item[1], reverse=True))

    def label_images(self, images: np.ndarray) -> list[ImageLabels]:
        """Run the model on a batch of preprocessed images.

        Args:
            images: batch array fed to the model's first input as-is.

        Returns:
            One ImageLabels per row of the model output, in batch order.
        """
        # Run the ONNX model
        probs: np.ndarray = self.model.run([self.output_name], {self.input_name: images})[0]

        # Convert to labels
        results = []
        for sample in probs:
            labels = list(zip(self.labels.names, sample.astype(float)))

            # Rating labels: report all of them and pick the best one with argmax
            rating_labels = dict(labels[i] for i in self.labels.rating)
            rating = max(rating_labels, key=rating_labels.get)

            # General / character labels, pick any where prediction confidence > threshold
            gen_labels = self._select_tags(labels, self.labels.general, self.general_threshold)
            char_labels = self._select_tags(labels, self.labels.character, self.character_threshold)

            # General tags first, then character tags (each group is sorted by confidence)
            combined_names = [*gen_labels, *char_labels]

            # Convert to a string suitable for use as a training caption
            caption = ", ".join(combined_names)
            booru = caption.replace("_", " ").replace("(", "\\(").replace(")", "\\)")

            results.append(
                ImageLabels(
                    caption=caption,
                    booru=booru,
                    rating=rating,
                    general=gen_labels,
                    character=char_labels,
                    ratings=rating_labels,
                )
            )

        return results

    def __call__(self, images: Iterable[np.ndarray]) -> Generator[list[ImageLabels], None, None]:
        """Yield ``label_images(batch)`` for every batch in ``images``."""
        for x in images:
            yield self.label_images(x)


## Main function
def main(args):
    """Tag every image under ``args.images_dir`` and write one caption file per image.

    ``args`` is the namespace produced by the argument parser at the bottom of this file.

    Raises:
        FileNotFoundError: if ``args.images_dir`` is not a directory.
        ValueError: if ``args.variant`` is not a key of MODEL_VARIANTS.
    """
    images_dir: Path = Path(args.images_dir).resolve()
    if not images_dir.is_dir():
        raise FileNotFoundError(f"Directory not found: {images_dir}")

    variant: str = args.variant
    recursive: bool = bool(args.recursive)
    # drop blanks so an empty --banned_tags does not yield a set containing ""
    banned_tags: set[str] = {t.strip() for t in (args.banned_tags or "").split(",") if t.strip()}
    caption_extension: str = str(args.caption_extension).lower()
    if not caption_extension.startswith("."):
        # Path.with_suffix() rejects a suffix without a leading dot
        caption_extension = "." + caption_extension
    print_freqs: bool = bool(args.print_freqs)
    num_workers: int = args.num_workers
    batch_size: int = args.batch_size
    remove_underscore: bool = bool(args.remove_underscore)
    # compare against None so that an explicit threshold of 0.0 is honoured
    general_threshold: float = args.general_threshold if args.general_threshold is not None else args.thresh
    character_threshold: float = (
        args.character_threshold if args.character_threshold is not None else args.thresh
    )
    debug: bool = bool(args.debug)

    # turn base model into a repo id and model path
    repo_id: Optional[str] = MODEL_VARIANTS.get(variant, None)
    if repo_id is None:
        raise ValueError(f"Unknown base model '{variant}'")

    # instantiate the dataset
    print(f"Loading images from {images_dir}...", end=" ")
    # rglob() already recurses, so a plain "*" pattern is enough
    candidates = images_dir.rglob("*") if recursive else images_dir.glob("*")
    image_paths = [p for p in candidates if p.suffix.lower() in IMAGE_EXTENSIONS]

    n_images = len(image_paths)
    print(f"found {n_images} images to process, creating DataLoader...")
    # sort by filename if we have a small number of images
    if n_images < 10000:
        image_paths = sorted(image_paths, key=lambda x: x.stem)
    dataset = ImageDataset(image_paths)

    # Create the data loader
    # prefetch_factor may only be given when worker processes are used
    loader_kwargs = {"prefetch_factor": 3} if num_workers > 0 else {}
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        collate_fn=collate_fn_remove_corrupted,
        drop_last=False,
        **loader_kwargs,
    )

    # Create the image labeler
    labeler: ImageLabeler = ImageLabeler(
        repo_id=repo_id,
        character_threshold=character_threshold,
        general_threshold=general_threshold,
        banned_tags=banned_tags,
    )

    # object to save tag frequencies
    tag_freqs: Counter = Counter()

    # iterate
    for batch in tqdm(dataloader, ncols=100, unit="image", unit_scale=batch_size):
        if batch is None:
            # the collate function returns None when every image in the batch failed to load
            continue

        images = batch["image"]
        paths = batch["path"]

        # label the images
        batch_labels = labeler.label_images(images)

        # save the labels
        for image_labels, image_path in zip(batch_labels, paths):
            if isinstance(image_path, (np.bytes_, bytes)):
                image_path = Path(image_path.decode("utf-8"))

            # save the labels
            caption = image_labels.caption
            if remove_underscore:
                caption = caption.replace("_", " ")
            Path(image_path).with_suffix(caption_extension).write_text(caption + "\n", encoding="utf-8")

            # save the tag frequencies
            if print_freqs:
                # an empty caption splits to [""]; skip that so "" is not counted as a tag
                tag_freqs.update(tag for tag in caption.split(", ") if tag and tag not in banned_tags)

            # debug
            if debug:
                print(
                    f"{image_path}:"
                    + f"\n Character tags: {image_labels.character}"
                    + f"\n General tags: {image_labels.general}"
                )

    if print_freqs:
        print("\nTag frequencies:")
        for tag, freq in tag_freqs.most_common():
            print(f"{tag}: {freq}")

    print("done!")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "images_dir",
        type=str,
        help="directory to tag image files in",
    )
    parser.add_argument(
        "--variant",
        type=str,
        default="swinv2",
        help="name of base model to use (one of 'swinv2', 'convnext', 'vit')",
    )
    parser.add_argument(
        "--num_workers",
        type=int,
        default=4,
        help="number of threads to use in Torch DataLoader (4 should be plenty)",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=1,
        help="batch size for Torch DataLoader (use 1 for cpu, 4-32 for gpu)",
    )
    parser.add_argument(
        "--caption_extension",
        type=str,
        default=".txt",
        help="extension of caption files to write (e.g. '.txt', '.caption')",
    )
    parser.add_argument(
        "--thresh",
        type=float,
        default=0.35,
        help="confidence threshold for adding tags",
    )
    parser.add_argument(
        "--general_threshold",
        type=float,
        default=None,
        help="confidence threshold for general tags - defaults to --thresh",
    )
    parser.add_argument(
        "--character_threshold",
        type=float,
        default=None,
        help="confidence threshold for character tags - defaults to --thresh",
    )
    parser.add_argument(
        "--recursive",
        action="store_true",
        help="whether to recurse into subdirectories of images_dir",
    )
    parser.add_argument(
        "--remove_underscore",
        action="store_true",
        help="whether to remove underscores from tags (e.g. 'long_hair' -> 'long hair')",
    )
    parser.add_argument(
        "--debug",
        action="store_true",
        help="enable debug logging mode",
    )
    parser.add_argument(
        "--banned_tags",
        type=str,
        default="",
        help="tags to filter out (comma-separated)",
    )
    parser.add_argument(
        "--print_freqs",
        action="store_true",
        help="Print overall tag frequencies at the end",
    )
    args = parser.parse_args()

    # images_dir is a required positional, so argparse never leaves it as None;
    # this fallback is kept only as a defensive default
    if args.images_dir is None:
        args.images_dir = Path.cwd().joinpath("temp/test")

    main(args)