|
|
|
|
|
|
|
import os
import weakref
from typing import TYPE_CHECKING

import torch
import yaml
from diffusers import AutoencoderKL, CogView4Pipeline, CogView4Transformer2DModel
from optimum.quanto import freeze
from transformers import AutoTokenizer, GlmModel

from toolkit.accelerator import unwrap_model
from toolkit.basic import flush
from toolkit.config_modules import GenerateImageConfig, ModelConfig
from toolkit.dequantize import patch_dequantization_on_save
from toolkit.models.base_model import BaseModel
from toolkit.prompt_utils import PromptEmbeds
from toolkit.samplers.custom_flowmatch_sampler import CustomFlowMatchEulerDiscreteScheduler
from toolkit.util.quantize import get_qtype, quantize

if TYPE_CHECKING:
    from toolkit.lora_special import LoRASpecialNetwork
|
|
|
|
|
|
|
|
|
class FakeModel: |
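    """Weakref-backed stand-in that exposes only a .device property, so it can be
    assigned as text_encoder.model below without creating a reference cycle."""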
|
def __init__(self, model): |
|
self.model_ref = weakref.ref(model) |
|
|
|
|
@property |
|
def device(self): |
|
return self.model_ref().device |
|
|
|
|
|
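# Flow-matching Euler scheduler configuration used by get_train_scheduler() below.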
scheduler_config = { |
|
"base_image_seq_len": 256, |
|
"base_shift": 0.25, |
|
"invert_sigmas": False, |
|
"max_image_seq_len": 4096, |
|
"max_shift": 0.75, |
|
"num_train_timesteps": 1000, |
|
"shift": 1.0, |
|
"shift_terminal": None, |
|
"time_shift_type": "linear", |
|
"use_beta_sigmas": False, |
|
"use_dynamic_shifting": True, |
|
"use_exponential_sigmas": False, |
|
"use_karras_sigmas": False |
|
} |
|
|
|
|
|
class CogView4(BaseModel): |
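    """Wrapper for CogView4 (THUDM/CogView4-6B): a flow-matching CogView4Transformer2DModel
    paired with a GLM text encoder and an AutoencoderKL VAE."""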
|
arch = 'cogview4' |
|
def __init__( |
|
self, |
|
device, |
|
model_config: ModelConfig, |
|
dtype='bf16', |
|
custom_pipeline=None, |
|
noise_scheduler=None, |
|
**kwargs |
|
): |
|
super().__init__(device, model_config, dtype, |
|
custom_pipeline, noise_scheduler, **kwargs) |
|
self.is_flow_matching = True |
|
self.is_transformer = True |
|
self.target_lora_modules = ['CogView4Transformer2DModel'] |
|
|
|
|
|
self.effective_noise = None |
|
|
|
|
|
@staticmethod |
|
def get_train_scheduler(): |
|
scheduler = CustomFlowMatchEulerDiscreteScheduler(**scheduler_config) |
|
return scheduler |
|
|
|
def load_model(self): |
|
dtype = self.torch_dtype |
|
base_model_path = "THUDM/CogView4-6B" |
|
model_path = self.model_config.name_or_path |
|
|
|
self.print_and_status_update("Loading CogView4 model") |
|
|
|
        base_model_path = self.model_config.name_or_path_original or base_model_path
|
subfolder = 'transformer' |
|
transformer_path = model_path |
|
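        # A local checkpoint directory keeps the transformer weights in its 'transformer' subfolder.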
if os.path.exists(transformer_path): |
|
subfolder = None |
|
transformer_path = os.path.join(transformer_path, 'transformer') |
|
|
|
te_folder_path = os.path.join(model_path, 'text_encoder') |
|
|
|
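        # If the checkpoint also bundles a text encoder, load the text encoder, tokenizer, and VAE from it.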
if os.path.exists(te_folder_path): |
|
base_model_path = model_path |
|
|
|
self.print_and_status_update("Loading GlmModel") |
|
        tokenizer = AutoTokenizer.from_pretrained(
            base_model_path, subfolder="tokenizer")
|
text_encoder = GlmModel.from_pretrained( |
|
base_model_path, subfolder="text_encoder", torch_dtype=dtype) |
|
|
|
text_encoder.to(self.device_torch, dtype=dtype) |
|
flush() |
|
|
|
if self.model_config.quantize_te: |
|
self.print_and_status_update("Quantizing GlmModel") |
|
quantize(text_encoder, weights=get_qtype(self.model_config.qtype)) |
|
freeze(text_encoder) |
|
flush() |
|
|
|
|
|
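        # Attach a lightweight proxy so lookups of text_encoder.model.device resolve to the encoder's device.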
text_encoder.model = FakeModel(text_encoder) |
|
|
|
self.print_and_status_update("Loading transformer") |
|
transformer = CogView4Transformer2DModel.from_pretrained( |
|
transformer_path, |
|
subfolder=subfolder, |
|
torch_dtype=dtype, |
|
) |
|
|
|
if self.model_config.split_model_over_gpus: |
|
raise ValueError( |
|
"Splitting model over gpus is not supported for CogViewModels models") |
|
|
|
transformer.to(self.quantize_device, dtype=dtype) |
|
flush() |
|
|
|
if self.model_config.assistant_lora_path is not None or self.model_config.inference_lora_path is not None: |
|
raise ValueError( |
|
"Assistant LoRA is not supported for CogViewModels models currently") |
|
|
|
if self.model_config.lora_path is not None: |
|
raise ValueError( |
|
"Loading LoRA is not supported for CogViewModels models currently") |
|
|
|
flush() |
|
|
|
if self.model_config.quantize: |
|
quantization_args = self.model_config.quantize_kwargs |
|
if 'exclude' not in quantization_args: |
|
quantization_args['exclude'] = [] |
|
if 'include' not in quantization_args: |
|
quantization_args['include'] = [] |
|
|
|
|
|
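            # Quantize only the transformer blocks.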
quantization_args['include'] += ["transformer_blocks.*"] |
|
|
|
|
|
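            # Leave the norm layers and attention q/k norms unquantized.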
quantization_args['exclude'] += [ |
|
"transformer_blocks.*.norm1", |
|
"transformer_blocks.*.norm2", |
|
"transformer_blocks.*.norm2_context", |
|
"transformer_blocks.*.attn1.norm_q", |
|
"transformer_blocks.*.attn1.norm_k" |
|
] |
|
|
|
|
|
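            # Patch saving so quantized weights are written out dequantized.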
patch_dequantization_on_save(transformer) |
|
quantization_type = get_qtype(self.model_config.qtype) |
|
self.print_and_status_update("Quantizing transformer") |
|
quantize(transformer, weights=quantization_type, **quantization_args) |
|
freeze(transformer) |
|
transformer.to(self.device_torch) |
|
else: |
|
transformer.to(self.device_torch, dtype=dtype) |
|
|
|
flush() |
|
|
|
scheduler = CogView4.get_train_scheduler() |
|
self.print_and_status_update("Loading VAE") |
|
vae = AutoencoderKL.from_pretrained( |
|
base_model_path, subfolder="vae", torch_dtype=dtype) |
|
flush() |
|
|
|
self.print_and_status_update("Making pipe") |
|
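        # Build the pipeline shell first, then attach the (possibly quantized) text encoder and transformer.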
pipe: CogView4Pipeline = CogView4Pipeline( |
|
scheduler=scheduler, |
|
text_encoder=None, |
|
tokenizer=tokenizer, |
|
vae=vae, |
|
transformer=None, |
|
) |
|
pipe.text_encoder = text_encoder |
|
pipe.transformer = transformer |
|
|
|
self.print_and_status_update("Preparing Model") |
|
|
|
text_encoder = pipe.text_encoder |
|
tokenizer = pipe.tokenizer |
|
|
|
pipe.transformer = pipe.transformer.to(self.device_torch) |
|
|
|
flush() |
|
text_encoder.to(self.device_torch) |
|
text_encoder.requires_grad_(False) |
|
text_encoder.eval() |
|
pipe.transformer = pipe.transformer.to(self.device_torch) |
|
flush() |
|
self.pipeline = pipe |
|
self.model = transformer |
|
self.vae = vae |
|
self.text_encoder = text_encoder |
|
self.tokenizer = tokenizer |
|
|
|
def get_generation_pipeline(self): |
|
scheduler = CogView4.get_train_scheduler() |
|
pipeline = CogView4Pipeline( |
|
vae=self.vae, |
|
transformer=self.unet, |
|
text_encoder=self.text_encoder, |
|
tokenizer=self.tokenizer, |
|
scheduler=scheduler, |
|
) |
|
return pipeline |
|
|
|
def generate_single_image( |
|
self, |
|
pipeline: CogView4Pipeline, |
|
gen_config: GenerateImageConfig, |
|
conditional_embeds: PromptEmbeds, |
|
unconditional_embeds: PromptEmbeds, |
|
generator: torch.Generator, |
|
extra: dict, |
|
): |
|
img = pipeline( |
|
prompt_embeds=conditional_embeds.text_embeds.to( |
|
self.device_torch, dtype=self.torch_dtype), |
|
negative_prompt_embeds=unconditional_embeds.text_embeds.to( |
|
self.device_torch, dtype=self.torch_dtype), |
|
height=gen_config.height, |
|
width=gen_config.width, |
|
num_inference_steps=gen_config.num_inference_steps, |
|
guidance_scale=gen_config.guidance_scale, |
|
latents=gen_config.latents, |
|
generator=generator, |
|
**extra |
|
).images[0] |
|
return img |
|
|
|
def get_noise_prediction( |
|
self, |
|
latent_model_input: torch.Tensor, |
|
timestep: torch.Tensor, |
|
text_embeddings: PromptEmbeds, |
|
**kwargs |
|
): |
|
|
|
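        # Latents are 1/8 of the pixel resolution; recover the pixel-space size for the size conditioning.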
target_size = latent_model_input.shape[-2:] |
|
|
|
target_size = (target_size[0] * 8, target_size[1] * 8) |
|
crops_coords_top_left = torch.tensor( |
|
[(0, 0)], dtype=self.torch_dtype, device=self.device_torch) |
|
|
|
original_size = torch.tensor( |
|
[target_size], dtype=self.torch_dtype, device=self.device_torch) |
|
target_size = original_size.clone() |
|
noise_pred_cond = self.model( |
|
hidden_states=latent_model_input, |
|
encoder_hidden_states=text_embeddings.text_embeds, |
|
timestep=timestep, |
|
original_size=original_size, |
|
target_size=target_size, |
|
crop_coords=crops_coords_top_left, |
|
return_dict=False, |
|
)[0] |
|
return noise_pred_cond |
|
|
|
def get_prompt_embeds(self, prompt: str) -> PromptEmbeds: |
|
prompt_embeds, _ = self.pipeline.encode_prompt( |
|
prompt, |
|
do_classifier_free_guidance=False, |
|
device=self.device_torch, |
|
dtype=self.torch_dtype, |
|
) |
|
return PromptEmbeds(prompt_embeds) |
|
|
|
def get_model_has_grad(self): |
|
return self.model.proj_out.weight.requires_grad |
|
|
|
def get_te_has_grad(self): |
|
return self.text_encoder.layers[0].mlp.down_proj.weight.requires_grad |
|
|
|
def save_model(self, output_path, meta, save_dtype): |
|
|
|
transformer: CogView4Transformer2DModel = unwrap_model(self.model) |
|
transformer.save_pretrained( |
|
save_directory=os.path.join(output_path, 'transformer'), |
|
safe_serialization=True, |
|
) |
|
|
|
meta_path = os.path.join(output_path, 'aitk_meta.yaml') |
|
with open(meta_path, 'w') as f: |
|
yaml.dump(meta, f) |
|
|
|
def get_loss_target(self, *args, **kwargs): |
|
noise = kwargs.get('noise') |
|
effective_noise = self.effective_noise |
|
batch = kwargs.get('batch') |
|
if batch is None: |
|
raise ValueError("Batch is not provided") |
|
if noise is None: |
|
raise ValueError("Noise is not provided") |
|
|
|
|
|
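        # Flow-matching target: the velocity pointing from the latents toward the noise.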
return (noise - batch.latents).detach() |
|
|
|
|
|
|
|
def _get_low_res_latents(self, latents): |
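        """Decode the latents, downsample the images 2x, upsample back, and re-encode to get low-resolution latents."""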
|
|
|
with torch.no_grad(): |
|
|
|
images = self.decode_latents( |
|
latents, device=latents.device, dtype=latents.dtype) |
|
|
|
|
|
B, C, H, W = images.shape |
|
low_res_images = torch.nn.functional.interpolate( |
|
images, |
|
size=(H // 2, W // 2), |
|
mode="bilinear", |
|
align_corners=False |
|
) |
|
|
|
|
|
upsampled_low_res_images = torch.nn.functional.interpolate( |
|
low_res_images, |
|
size=(H, W), |
|
mode="bilinear", |
|
align_corners=False |
|
) |
|
|
|
|
|
low_res_latents = self.encode_images( |
|
upsampled_low_res_images, device=latents.device, dtype=latents.dtype) |
|
return low_res_latents |
|