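"""Data transfer objects for the training dataloader.

`FileItemDTO` wraps a single image or video file (size lookup via a cached
size database, scale/crop bookkeeping, per-item training settings), and
`DataLoaderBatchDTO` collates a list of file items into batched tensors.
"""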

import os
import weakref
from _weakref import ReferenceType
from typing import TYPE_CHECKING, List, Union

import cv2
import torch
import random
from PIL import Image
from PIL.ImageOps import exif_transpose

from toolkit import image_utils
from toolkit.basic import get_quick_signature_string
from toolkit.dataloader_mixins import CaptionProcessingDTOMixin, ImageProcessingDTOMixin, LatentCachingFileItemDTOMixin, \
    ControlFileItemDTOMixin, ArgBreakMixin, PoiFileItemDTOMixin, MaskFileItemDTOMixin, AugmentationFileItemDTOMixin, \
    UnconditionalFileItemDTOMixin, ClipImageFileItemDTOMixin, InpaintControlFileItemDTOMixin

if TYPE_CHECKING:
    from toolkit.config_modules import DatasetConfig
    from toolkit.stable_diffusion_model import StableDiffusion


printed_messages = []


def print_once(msg):
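    """Print a message only the first time it is seen, to avoid log spam."""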
    global printed_messages
    if msg not in printed_messages:
        print(msg)
        printed_messages.append(msg)


class FileItemDTO(
    LatentCachingFileItemDTOMixin,
    CaptionProcessingDTOMixin,
    ImageProcessingDTOMixin,
    ControlFileItemDTOMixin,
    InpaintControlFileItemDTOMixin,
    ClipImageFileItemDTOMixin,
    MaskFileItemDTOMixin,
    AugmentationFileItemDTOMixin,
    UnconditionalFileItemDTOMixin,
    PoiFileItemDTOMixin,
    ArgBreakMixin,
):
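    """A single dataset file (image or video) plus its per-item training state.

    The mixins contribute latent caching, caption processing, image processing,
    and the optional control / inpaint / CLIP-image / mask / augmentation /
    unconditional / point-of-interest inputs.
    """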

    def __init__(self, *args, **kwargs):
        self.path = kwargs.get('path', '')
        self.dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None)
        self.is_video = self.dataset_config.num_frames > 1
        size_database = kwargs.get('size_database', {})
        dataset_root = kwargs.get('dataset_root', None)
        if dataset_root is not None:
            # remove dataset root from path
            file_key = self.path.replace(dataset_root, '')
        else:
            file_key = os.path.basename(self.path)
        file_signature = get_quick_signature_string(self.path)
        if file_signature is None:
| raise Exception("Error: Could not get file signature for {self.path}") | |
        use_db_entry = False
        if file_key in size_database:
            db_entry = size_database[file_key]
            if db_entry is not None and len(db_entry) >= 3 and db_entry[2] == file_signature:
                use_db_entry = True

        if use_db_entry:
            w, h, _ = size_database[file_key]
        elif self.is_video:
            # open the video file
            video = cv2.VideoCapture(self.path)
            # check that the video opened successfully
            if not video.isOpened():
                raise Exception(f"Error: Could not open video file {self.path}")
            # get width and height
            width = int(video.get(cv2.CAP_PROP_FRAME_WIDTH))
            height = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))
            w, h = width, height
            # release the video capture object immediately
            video.release()
            size_database[file_key] = (width, height, file_signature)
        else:
            # The fast header-read method below is significantly faster, but some
            # images come back sideways (likely EXIF orientation), so use the
            # slower PIL path with exif_transpose for now.
            # try:
            #     w, h = image_utils.get_image_size(self.path)
            # except image_utils.UnknownImageFormat:
            #     print_once('Warning: Some images in the dataset cannot be fast read. '
            #                'This process is faster for png, jpeg')
            img = exif_transpose(Image.open(self.path))
            w, h = img.size
            size_database[file_key] = (w, h, file_signature)

        self.width: int = w
        self.height: int = h
        self.dataloader_transforms = kwargs.get('dataloader_transforms', None)

        super().__init__(*args, **kwargs)

        # self.caption_path: str = kwargs.get('caption_path', None)
        self.raw_caption: str = kwargs.get('raw_caption', None)

        # we scale first, then crop
        self.scale_to_width: int = kwargs.get('scale_to_width', int(self.width * self.dataset_config.scale))
        self.scale_to_height: int = kwargs.get('scale_to_height', int(self.height * self.dataset_config.scale))
        # crop values are relative to the scaled size
        self.crop_x: int = kwargs.get('crop_x', 0)
        self.crop_y: int = kwargs.get('crop_y', 0)
        self.crop_width: int = kwargs.get('crop_width', self.scale_to_width)
        self.crop_height: int = kwargs.get('crop_height', self.scale_to_height)

        self.flip_x: bool = kwargs.get('flip_x', False)
        self.flip_y: bool = kwargs.get('flip_y', False)
        self.augments: List[str] = self.dataset_config.augments
        self.loss_multiplier: float = self.dataset_config.loss_multiplier
        self.network_weight: float = self.dataset_config.network_weight
        self.is_reg = self.dataset_config.is_reg
        self.tensor: Union[torch.Tensor, None] = None
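
    # Example construction (sketch; in the toolkit the dataloader that owns the
    # DTOs supplies these kwargs, and the names below are illustrative):
    #   item = FileItemDTO(
    #       path='/data/set/img_001.png',
    #       dataset_config=dataset_config,  # provides num_frames, scale, augments, ...
    #       size_database=size_db,          # shared {file_key: (w, h, signature)} dict
    #   )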

    def cleanup(self):
        self.tensor = None
        self.cleanup_latent()
        self.cleanup_control()
        self.cleanup_inpaint()
        self.cleanup_clip_image()
        self.cleanup_mask()
        self.cleanup_unconditional()


class DataLoaderBatchDTO:
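    """Collates a list of `FileItemDTO`s into batched tensors for one training step."""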

    def __init__(self, **kwargs):
        try:
            self.file_items: List['FileItemDTO'] = kwargs.get('file_items', None)
            is_latents_cached = self.file_items[0].is_latent_cached
            self.tensor: Union[torch.Tensor, None] = None
            self.latents: Union[torch.Tensor, None] = None
            self.control_tensor: Union[torch.Tensor, None] = None
            self.clip_image_tensor: Union[torch.Tensor, None] = None
            self.mask_tensor: Union[torch.Tensor, None] = None
            self.unaugmented_tensor: Union[torch.Tensor, None] = None
            self.unconditional_tensor: Union[torch.Tensor, None] = None
            self.unconditional_latents: Union[torch.Tensor, None] = None
            self.clip_image_embeds: Union[List[dict], None] = None
            self.clip_image_embeds_unconditional: Union[List[dict], None] = None
            self.sigmas: Union[torch.Tensor, None] = None  # can be set elsewhere and passed along through the training code
            self.extra_values: Union[torch.Tensor, None] = torch.tensor(
                [x.extra_values for x in self.file_items]
            ) if len(self.file_items[0].extra_values) > 0 else None

            if not is_latents_cached:
                # only return a tensor if latents are not cached
                self.tensor: torch.Tensor = torch.cat([x.tensor.unsqueeze(0) for x in self.file_items])
            # if we have encoded latents, we concatenate them
            self.latents: Union[torch.Tensor, None] = None
            if is_latents_cached:
                self.latents = torch.cat([x.get_latent().unsqueeze(0) for x in self.file_items])

            self.control_tensor: Union[torch.Tensor, None] = None
            # if self.file_items[0].control_tensor is not None:
            # if any have a control tensor, we concatenate them
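            # For each optional per-item tensor below, items that are missing it
            # are padded with zeros shaped like one real entry, so batch indices
            # stay aligned with file_items.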
            if any([x.control_tensor is not None for x in self.file_items]):
                # find one to use as a base
                base_control_tensor = None
                for x in self.file_items:
                    if x.control_tensor is not None:
                        base_control_tensor = x.control_tensor
                        break
                control_tensors = []
                for x in self.file_items:
                    if x.control_tensor is None:
                        control_tensors.append(torch.zeros_like(base_control_tensor))
                    else:
                        control_tensors.append(x.control_tensor)
                self.control_tensor = torch.cat([x.unsqueeze(0) for x in control_tensors])

            self.inpaint_tensor: Union[torch.Tensor, None] = None
            if any([x.inpaint_tensor is not None for x in self.file_items]):
                # find one to use as a base
                base_inpaint_tensor = None
                for x in self.file_items:
                    if x.inpaint_tensor is not None:
                        base_inpaint_tensor = x.inpaint_tensor
                        break
                inpaint_tensors = []
                for x in self.file_items:
                    if x.inpaint_tensor is None:
                        inpaint_tensors.append(torch.zeros_like(base_inpaint_tensor))
                    else:
                        inpaint_tensors.append(x.inpaint_tensor)
                self.inpaint_tensor = torch.cat([x.unsqueeze(0) for x in inpaint_tensors])

            self.loss_multiplier_list: List[float] = [x.loss_multiplier for x in self.file_items]

            if any([x.clip_image_tensor is not None for x in self.file_items]):
                # find one to use as a base
                base_clip_image_tensor = None
                for x in self.file_items:
                    if x.clip_image_tensor is not None:
                        base_clip_image_tensor = x.clip_image_tensor
                        break
                clip_image_tensors = []
                for x in self.file_items:
                    if x.clip_image_tensor is None:
                        clip_image_tensors.append(torch.zeros_like(base_clip_image_tensor))
                    else:
                        clip_image_tensors.append(x.clip_image_tensor)
                self.clip_image_tensor = torch.cat([x.unsqueeze(0) for x in clip_image_tensors])

            if any([x.mask_tensor is not None for x in self.file_items]):
                # find one to use as a base
                base_mask_tensor = None
                for x in self.file_items:
                    if x.mask_tensor is not None:
                        base_mask_tensor = x.mask_tensor
                        break
                mask_tensors = []
                for x in self.file_items:
                    if x.mask_tensor is None:
                        mask_tensors.append(torch.zeros_like(base_mask_tensor))
                    else:
                        mask_tensors.append(x.mask_tensor)
                self.mask_tensor = torch.cat([x.unsqueeze(0) for x in mask_tensors])

            # add unaugmented tensors for ones with augments
            if any([x.unaugmented_tensor is not None for x in self.file_items]):
                # find one to use as a base
                base_unaugmented_tensor = None
                for x in self.file_items:
                    if x.unaugmented_tensor is not None:
                        base_unaugmented_tensor = x.unaugmented_tensor
                        break
                unaugmented_tensor = []
                for x in self.file_items:
                    if x.unaugmented_tensor is None:
                        unaugmented_tensor.append(torch.zeros_like(base_unaugmented_tensor))
                    else:
                        unaugmented_tensor.append(x.unaugmented_tensor)
                self.unaugmented_tensor = torch.cat([x.unsqueeze(0) for x in unaugmented_tensor])

            # add unconditional tensors
            if any([x.unconditional_tensor is not None for x in self.file_items]):
                # find one to use as a base
                base_unconditional_tensor = None
                for x in self.file_items:
                    if x.unconditional_tensor is not None:
                        base_unconditional_tensor = x.unconditional_tensor
                        break
                unconditional_tensor = []
                for x in self.file_items:
                    if x.unconditional_tensor is None:
                        unconditional_tensor.append(torch.zeros_like(base_unconditional_tensor))
                    else:
                        unconditional_tensor.append(x.unconditional_tensor)
                self.unconditional_tensor = torch.cat([x.unsqueeze(0) for x in unconditional_tensor])

            if any([x.clip_image_embeds is not None for x in self.file_items]):
                self.clip_image_embeds = []
                for x in self.file_items:
                    if x.clip_image_embeds is not None:
                        self.clip_image_embeds.append(x.clip_image_embeds)
                    else:
                        raise Exception("clip_image_embeds is None for some file items")

            if any([x.clip_image_embeds_unconditional is not None for x in self.file_items]):
                self.clip_image_embeds_unconditional = []
                for x in self.file_items:
                    if x.clip_image_embeds_unconditional is not None:
                        self.clip_image_embeds_unconditional.append(x.clip_image_embeds_unconditional)
                    else:
                        raise Exception("clip_image_embeds_unconditional is None for some file items")
        except Exception as e:
            print(e)
            raise e

    def get_is_reg_list(self):
        return [x.is_reg for x in self.file_items]

    def get_network_weight_list(self):
        return [x.network_weight for x in self.file_items]

    def get_caption_list(
        self,
        trigger=None,
        to_replace_list=None,
        add_if_not_present=True
    ):
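        # trigger / to_replace_list / add_if_not_present are accepted for
        # call-site compatibility but are unused here; captions come from the
        # already-processed file items. The same applies to
        # get_caption_short_list below.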
        return [x.caption for x in self.file_items]

    def get_caption_short_list(
        self,
        trigger=None,
        to_replace_list=None,
        add_if_not_present=True
    ):
        return [x.caption_short for x in self.file_items]

    def cleanup(self):
        del self.latents
        del self.tensor
        del self.control_tensor
        for file_item in self.file_items:
            file_item.cleanup()
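

# Example batch collation (sketch; variable names below are illustrative, and
# in practice a DataLoader collate_fn builds the batch from dataset items):
#   batch = DataLoaderBatchDTO(file_items=[item_a, item_b])
#   inputs = batch.latents if batch.latents is not None else batch.tensor
#   ... run the training step ...
#   batch.cleanup()  # release tensor references once the step is done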