from __future__ import annotations

import os
import pathlib
import shlex
import shutil
import subprocess

import gradio as gr
import PIL.Image
import torch


def pad_image(image: PIL.Image.Image) -> PIL.Image.Image:
    """Pad a non-square image to a square canvas with a black background."""
    w, h = image.size
    if w == h:
        return image
    elif w > h:
        new_image = PIL.Image.new(image.mode, (w, w), (0, 0, 0))
        new_image.paste(image, (0, (w - h) // 2))
        return new_image
    else:
        new_image = PIL.Image.new(image.mode, (h, h), (0, 0, 0))
        new_image.paste(image, ((h - w) // 2, 0))
        return new_image


class Trainer:
    def __init__(self):
        self.is_running = False
        self.is_running_message = "Another training is in progress."

        self.output_dir = pathlib.Path("results")
        self.data_dir = pathlib.Path("data")
        self.ref_data_dir = self.data_dir / "ref"
        self.target_data_dir = self.data_dir / "target"

    def check_if_running(self) -> dict:
        if self.is_running:
            return gr.update(value=self.is_running_message)
        else:
            return gr.update(value="No training is running.")

    def cleanup_dirs(self) -> None:
        # Remove previous results and the previously prepared dataset so that
        # prepare_dataset() can recreate the directories from scratch.
        shutil.rmtree(self.output_dir, ignore_errors=True)
        shutil.rmtree(self.data_dir, ignore_errors=True)

    def prepare_dataset(
        self,
        ref_images: list,
        target_image: PIL.Image.Image,
        target_mask: PIL.Image.Image,
        resolution: int,
    ) -> None:
        self.ref_data_dir.mkdir(parents=True)
        self.target_data_dir.mkdir(parents=True)
        for i, temp_path in enumerate(ref_images):
            image = PIL.Image.open(temp_path.name)
            image = pad_image(image)
            image = image.resize((resolution, resolution))
            image = image.convert("RGB")
            out_path = self.ref_data_dir / f"{i:03d}.jpg"
            image.save(out_path, format="JPEG", quality=100)
        target_image.save(self.target_data_dir / "target.jpg", format="JPEG", quality=100)
        target_mask.save(self.target_data_dir / "mask.jpg", format="JPEG", quality=100)

    def run(
        self,
        base_model: str,
        resolution_s: str,
        n_steps: int,
        ref_images: list | None,
        target_image: PIL.Image.Image,
        target_mask: PIL.Image.Image,
        unet_learning_rate: float,
        text_encoder_learning_rate: float,
        gradient_accumulation: int,
        fp16: bool,
        use_8bit_adam: bool,
        gradient_checkpointing: bool,
        lora_rank: int,
        lora_alpha: int,
        lora_bias: str,
        lora_dropout: float,
    ) -> tuple[dict, list[pathlib.Path]]:
        if not torch.cuda.is_available():
            raise gr.Error("CUDA is not available.")
        if self.is_running:
            return gr.update(value=self.is_running_message), []

        if ref_images is None:
            raise gr.Error("You need to upload reference images.")
        if target_image is None:
            raise gr.Error("You need to upload a target image.")
        if target_mask is None:
            raise gr.Error("You need to upload a target mask.")

        resolution = int(resolution_s)

        self.cleanup_dirs()
        self.prepare_dataset(ref_images, target_image, target_mask, resolution)
        # Recreate the output directory so train.sh can be written before training starts.
        self.output_dir.mkdir(parents=True, exist_ok=True)

        command = f"""
        accelerate launch train_dreambooth.py \
          --pretrained_model_name_or_path={base_model} \
          --train_data_dir={self.data_dir} \
          --output_dir={self.output_dir} \
          --resolution={resolution} \
          --gradient_accumulation_steps={gradient_accumulation} \
          --unet_learning_rate={unet_learning_rate} \
          --text_encoder_learning_rate={text_encoder_learning_rate} \
          --max_train_steps={n_steps} \
          --train_batch_size=16 \
          --lr_scheduler=constant \
          --lr_warmup_steps=100 \
          --lora_r={lora_rank} \
          --lora_alpha={lora_alpha} \
          --lora_bias={lora_bias} \
          --lora_dropout={lora_dropout}
        """
        if fp16:
            command += " --mixed_precision fp16"
        if use_8bit_adam:
            command += " --use_8bit_adam"
        if gradient_checkpointing:
            command += " --gradient_checkpointing"

        with open(self.output_dir / "train.sh", "w") as f:
            command_s = " ".join(command.split())
            f.write(command_s)

        self.is_running = True
        res = subprocess.run(shlex.split(command))
        self.is_running = False

        if res.returncode == 0:
            result_message = "Training Completed!"
        else:
            result_message = "Training Failed!"
        model_paths = sorted(self.output_dir.glob("*"))
        return gr.update(value=result_message), model_paths