import os
from typing import TYPE_CHECKING, List, Optional

import einops
import torch
import torchvision
import yaml
from toolkit import train_tools
from toolkit.config_modules import GenerateImageConfig, ModelConfig
from PIL import Image
from toolkit.models.base_model import BaseModel
from diffusers import AutoencoderKL, TorchAoConfig
from toolkit.basic import flush
from toolkit.prompt_utils import PromptEmbeds
from toolkit.samplers.custom_flowmatch_sampler import CustomFlowMatchEulerDiscreteScheduler
from toolkit.models.flux import add_model_gpu_splitter_to_flux, bypass_flux_guidance, restore_flux_guidance
from toolkit.dequantize import patch_dequantization_on_save
from toolkit.accelerator import get_accelerator, unwrap_model
from optimum.quanto import freeze, QTensor
from toolkit.util.mask import generate_random_mask, random_dialate_mask
from toolkit.util.quantize import quantize, get_qtype
from transformers import T5TokenizerFast, T5EncoderModel, CLIPTextModel, CLIPTokenizer, TorchAoConfig as TorchAoConfigTransformers
from .src.pipelines.hidream_image.pipeline_hidream_image import HiDreamImagePipeline
from .src.models.transformers.transformer_hidream_image import HiDreamImageTransformer2DModel
from .src.schedulers.fm_solvers_unipc import FlowUniPCMultistepScheduler
from transformers import LlamaForCausalLM, PreTrainedTokenizerFast
from einops import rearrange, repeat
import random
import torch.nn.functional as F
from tqdm import tqdm
from transformers import (
    CLIPTextModelWithProjection,
    CLIPTokenizer,
    T5EncoderModel,
    T5Tokenizer,
    LlamaForCausalLM,
    PreTrainedTokenizerFast
)

if TYPE_CHECKING:
    from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO

scheduler_config = {
    "num_train_timesteps": 1000,
    "shift": 3.0
}
# LLAMA_MODEL_NAME = "meta-llama/Meta-Llama-3.1-8B-Instruct"
LLAMA_MODEL_PATH = "unsloth/Meta-Llama-3.1-8B-Instruct"
BASE_MODEL_PATH = "HiDream-ai/HiDream-I1-Full"


class HidreamModel(BaseModel):
    arch = "hidream"

    def __init__(
            self,
            device,
            model_config: ModelConfig,
            dtype='bf16',
            custom_pipeline=None,
            noise_scheduler=None,
            **kwargs
    ):
        super().__init__(
            device,
            model_config,
            dtype,
            custom_pipeline,
            noise_scheduler,
            **kwargs
        )
        self.is_flow_matching = True
        self.is_transformer = True
        self.target_lora_modules = ['HiDreamImageTransformer2DModel']

    # static method to get the noise scheduler
    @staticmethod
    def get_train_scheduler():
        return CustomFlowMatchEulerDiscreteScheduler(**scheduler_config)
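
    # Buckets are sized in pixels divisible by 16: assuming the usual 8x VAE
    # downsampling into latents and a 2x2 transformer patch size, one patch
    # covers 8 * 2 = 16 pixels along each edge.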
    def get_bucket_divisibility(self):
        return 16

    def load_model(self):
        dtype = self.torch_dtype
        # HiDream-ai/HiDream-I1-Full
        self.print_and_status_update("Loading HiDream model")
        # will be updated if we detect an existing checkpoint in the training folder
        model_path = self.model_config.name_or_path
        extras_path = self.model_config.extras_name_or_path
        llama_model_path = self.model_config.model_kwargs.get('llama_model_path', LLAMA_MODEL_PATH)

        scheduler = HidreamModel.get_train_scheduler()

        self.print_and_status_update("Loading llama 8b model")
        tokenizer_4 = PreTrainedTokenizerFast.from_pretrained(
            llama_model_path,
            use_fast=False
        )
        text_encoder_4 = LlamaForCausalLM.from_pretrained(
            llama_model_path,
            output_hidden_states=True,
            output_attentions=True,
            torch_dtype=torch.bfloat16,
        )
        text_encoder_4.to(self.device_torch, dtype=dtype)

        if self.model_config.quantize_te:
            self.print_and_status_update("Quantizing llama 8b model")
            quantization_type = get_qtype(self.model_config.qtype_te)
            quantize(text_encoder_4, weights=quantization_type)
            freeze(text_encoder_4)

        if self.low_vram:
            # unload it for now
            text_encoder_4.to('cpu')

        flush()
| self.print_and_status_update("Loading transformer") | |
| transformer = HiDreamImageTransformer2DModel.from_pretrained( | |
| model_path, | |
| subfolder="transformer", | |
| torch_dtype=torch.bfloat16 | |
| ) | |
| if not self.low_vram: | |
| transformer.to(self.device_torch, dtype=dtype) | |
| if self.model_config.quantize: | |
| self.print_and_status_update("Quantizing transformer") | |
| quantization_type = get_qtype(self.model_config.qtype) | |
| if self.low_vram: | |
| # move and quantize only certain pieces at a time. | |
| all_blocks = list(transformer.double_stream_blocks) + list(transformer.single_stream_blocks) | |
| self.print_and_status_update(" - quantizing transformer blocks") | |
| for block in tqdm(all_blocks): | |
| block.to(self.device_torch, dtype=dtype) | |
| quantize(block, weights=quantization_type) | |
| freeze(block) | |
| block.to('cpu') | |
| # flush() | |
| self.print_and_status_update(" - quantizing extras") | |
| transformer.to(self.device_torch, dtype=dtype) | |
| quantize(transformer, weights=quantization_type) | |
| freeze(transformer) | |
| else: | |
| quantize(transformer, weights=quantization_type) | |
| freeze(transformer) | |
| if self.low_vram: | |
| # unload it for now | |
| transformer.to('cpu') | |
| flush() | |

        self.print_and_status_update("Loading vae")
        vae = AutoencoderKL.from_pretrained(
            extras_path,
            subfolder="vae",
            torch_dtype=torch.bfloat16
        ).to(self.device_torch, dtype=dtype)

        self.print_and_status_update("Loading clip encoders")
        text_encoder = CLIPTextModelWithProjection.from_pretrained(
            extras_path,
            subfolder="text_encoder",
            torch_dtype=torch.bfloat16
        ).to(self.device_torch, dtype=dtype)
        tokenizer = CLIPTokenizer.from_pretrained(
            extras_path,
            subfolder="tokenizer"
        )

        text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(
            extras_path,
            subfolder="text_encoder_2",
            torch_dtype=torch.bfloat16
        ).to(self.device_torch, dtype=dtype)
        tokenizer_2 = CLIPTokenizer.from_pretrained(
            extras_path,
            subfolder="tokenizer_2"
        )
        flush()

        self.print_and_status_update("Loading T5 encoders")
        text_encoder_3 = T5EncoderModel.from_pretrained(
            extras_path,
            subfolder="text_encoder_3",
            torch_dtype=torch.bfloat16
        ).to(self.device_torch, dtype=dtype)

        if self.model_config.quantize_te:
            self.print_and_status_update("Quantizing T5")
            quantization_type = get_qtype(self.model_config.qtype_te)
            quantize(text_encoder_3, weights=quantization_type)
            freeze(text_encoder_3)
            flush()

        tokenizer_3 = T5Tokenizer.from_pretrained(
            extras_path,
            subfolder="tokenizer_3"
        )
        flush()

        if self.low_vram:
            self.print_and_status_update("Moving everything to device")
            # move it all back
            transformer.to(self.device_torch, dtype=dtype)
            vae.to(self.device_torch, dtype=dtype)
            text_encoder.to(self.device_torch, dtype=dtype)
            text_encoder_2.to(self.device_torch, dtype=dtype)
            text_encoder_4.to(self.device_torch, dtype=dtype)
            text_encoder_3.to(self.device_torch, dtype=dtype)

        # set to eval mode
        # transformer.eval()
        vae.eval()
        text_encoder.eval()
        text_encoder_2.eval()
        text_encoder_4.eval()
        text_encoder_3.eval()

        pipe = HiDreamImagePipeline(
            scheduler=scheduler,
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            text_encoder_2=text_encoder_2,
            tokenizer_2=tokenizer_2,
            text_encoder_3=text_encoder_3,
            tokenizer_3=tokenizer_3,
            text_encoder_4=text_encoder_4,
            tokenizer_4=tokenizer_4,
            transformer=transformer,
        )
        flush()

        text_encoder_list = [text_encoder, text_encoder_2, text_encoder_3, text_encoder_4]
        tokenizer_list = [tokenizer, tokenizer_2, tokenizer_3, tokenizer_4]

        for te in text_encoder_list:
            # set the dtype
            te.to(self.device_torch, dtype=dtype)
            # freeze the model
            freeze(te)
            # set to eval mode
            te.eval()
            # set the requires grad to false
            te.requires_grad_(False)

        flush()

        # save it to the model class
        self.vae = vae
        self.text_encoder = text_encoder_list  # list of text encoders
        self.tokenizer = tokenizer_list  # list of tokenizers
        self.model = pipe.transformer
        self.pipeline = pipe
        self.print_and_status_update("Model Loaded")
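
    # Inference pipeline for sampling/validation images. Sampling uses the UniPC
    # flow-matching scheduler below, while training uses the flow-match Euler
    # scheduler from get_train_scheduler().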
    def get_generation_pipeline(self):
        scheduler = FlowUniPCMultistepScheduler(
            num_train_timesteps=1000,
            shift=3.0,
            use_dynamic_shifting=False
        )
        pipeline: HiDreamImagePipeline = HiDreamImagePipeline(
            scheduler=scheduler,
            vae=self.vae,
            text_encoder=self.text_encoder[0],
            tokenizer=self.tokenizer[0],
            text_encoder_2=self.text_encoder[1],
            tokenizer_2=self.tokenizer[1],
            text_encoder_3=self.text_encoder[2],
            tokenizer_3=self.tokenizer[2],
            text_encoder_4=self.text_encoder[3],
            tokenizer_4=self.tokenizer[3],
            transformer=unwrap_model(self.model),
            aggressive_unloading=self.low_vram
        )

        pipeline = pipeline.to(self.device_torch)

        return pipeline

    def generate_single_image(
        self,
        pipeline: HiDreamImagePipeline,
        gen_config: GenerateImageConfig,
        conditional_embeds: PromptEmbeds,
        unconditional_embeds: PromptEmbeds,
        generator: torch.Generator,
        extra: dict,
    ):
        img = pipeline(
            prompt_embeds=conditional_embeds.text_embeds,
            pooled_prompt_embeds=conditional_embeds.pooled_embeds,
            negative_prompt_embeds=unconditional_embeds.text_embeds,
            negative_pooled_prompt_embeds=unconditional_embeds.pooled_embeds,
            height=gen_config.height,
            width=gen_config.width,
            num_inference_steps=gen_config.num_inference_steps,
            guidance_scale=gen_config.guidance_scale,
            latents=gen_config.latents,
            generator=generator,
            **extra
        ).images[0]
        return img

    def get_noise_prediction(
        self,
        latent_model_input: torch.Tensor,
        timestep: torch.Tensor,  # 0 to 1000 scale
        text_embeddings: PromptEmbeds,
        **kwargs
    ):
        batch_size = latent_model_input.shape[0]
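
        # For non-square latents, precompute per-patch position ids (img_ids) and the
        # patch-grid size (img_sizes), zero-padded out to the transformer's max_seq.
        # Square latents are passed through as-is and img_sizes/img_ids are left None.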
        with torch.no_grad():
            if latent_model_input.shape[-2] != latent_model_input.shape[-1]:
                B, C, H, W = latent_model_input.shape
                pH, pW = H // self.model.config.patch_size, W // self.model.config.patch_size

                img_sizes = torch.tensor([pH, pW], dtype=torch.int64).reshape(-1)
                img_ids = torch.zeros(pH, pW, 3)
                img_ids[..., 1] = img_ids[..., 1] + torch.arange(pH)[:, None]
                img_ids[..., 2] = img_ids[..., 2] + torch.arange(pW)[None, :]
                img_ids = img_ids.reshape(pH * pW, -1)
                img_ids_pad = torch.zeros(self.transformer.max_seq, 3)
                img_ids_pad[:pH*pW, :] = img_ids

                img_sizes = img_sizes.unsqueeze(0).to(latent_model_input.device)
                img_sizes = torch.cat([img_sizes] * batch_size, dim=0)
                img_ids = img_ids_pad.unsqueeze(0).to(latent_model_input.device)
                img_ids = torch.cat([img_ids] * batch_size, dim=0)
            else:
                img_sizes = img_ids = None

        dtype = self.model.dtype
        device = self.device_torch

        # Pack the latent: for non-square inputs, flatten each patch_size x patch_size
        # tile into one token and zero-pad the sequence to the transformer's max_seq.
        if latent_model_input.shape[-2] != latent_model_input.shape[-1]:
            B, C, H, W = latent_model_input.shape
            patch_size = self.transformer.config.patch_size
            pH, pW = H // patch_size, W // patch_size
            out = torch.zeros(
                (B, C, self.transformer.max_seq, patch_size * patch_size),
                dtype=latent_model_input.dtype,
                device=latent_model_input.device
            )
            latent_model_input = einops.rearrange(
                latent_model_input, 'B C (H p1) (W p2) -> B C (H W) (p1 p2)',
                p1=patch_size, p2=patch_size
            )
            out[:, :, 0:pH*pW] = latent_model_input
            latent_model_input = out

        text_embeds = text_embeddings.text_embeds
        # move each encoder's embeddings in the list to the model device and dtype
        text_embeds = [te.to(device, dtype=dtype) for te in text_embeds]

        noise_pred = self.transformer(
            hidden_states=latent_model_input,
            timesteps=timestep,
            encoder_hidden_states=text_embeds,
            pooled_embeds=text_embeddings.pooled_embeds.to(device, dtype=dtype),
            img_sizes=img_sizes,
            img_ids=img_ids,
            return_dict=False,
        )[0]
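
        # Flip the sign so the prediction lines up with the flow-matching target
        # (noise - latents) returned by get_loss_target below.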
        noise_pred = -noise_pred
        return noise_pred

    def get_prompt_embeds(self, prompt: str) -> PromptEmbeds:
        self.text_encoder_to(self.device_torch, dtype=self.torch_dtype)
        max_sequence_length = 128
        prompt_embeds, pooled_prompt_embeds = self.pipeline._encode_prompt(
            prompt=prompt,
            prompt_2=prompt,
            prompt_3=prompt,
            prompt_4=prompt,
            device=self.device_torch,
            dtype=self.torch_dtype,
            num_images_per_prompt=1,
            max_sequence_length=max_sequence_length,
        )
        pe = PromptEmbeds(
            [prompt_embeds, pooled_prompt_embeds]
        )
        return pe

    def get_model_has_grad(self):
        # check a representative weight to see whether the model has grads enabled
        return self.model.double_stream_blocks[0].block.attn1.to_q.weight.requires_grad

    def get_te_has_grad(self):
        # assume no one wants to finetune 4 text encoders.
        return False

    def save_model(self, output_path, meta, save_dtype):
        # only save the transformer
        transformer: HiDreamImageTransformer2DModel = unwrap_model(self.model)
        transformer.save_pretrained(
            save_directory=os.path.join(output_path, 'transformer'),
            safe_serialization=True,
        )

        meta_path = os.path.join(output_path, 'aitk_meta.yaml')
        with open(meta_path, 'w') as f:
            yaml.dump(meta, f)
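
    # Flow matching: the training target is the velocity from clean latents toward
    # noise, i.e. (noise - latents), matching the sign-flipped prediction above.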
    def get_loss_target(self, *args, **kwargs):
        noise = kwargs.get('noise')
        batch = kwargs.get('batch')
        return (noise - batch.latents).detach()

    def get_transformer_block_names(self) -> Optional[List[str]]:
        return ['double_stream_blocks', 'single_stream_blocks']

    def convert_lora_weights_before_save(self, state_dict):
        # keys currently start with "transformer.", but need to start with "diffusion_model." for ComfyUI
        new_sd = {}
        for key, value in state_dict.items():
            new_key = key.replace("transformer.", "diffusion_model.")
            new_sd[new_key] = value
        return new_sd

    def convert_lora_weights_before_load(self, state_dict):
        # saved as "diffusion_model.", but needs to be "transformer." for ai-toolkit
        new_sd = {}
        for key, value in state_dict.items():
            new_key = key.replace("diffusion_model.", "transformer.")
            new_sd[new_key] = value
        return new_sd

    def get_base_model_version(self):
        return "hidream_i1"