Spaces:
				
			
			
	
			
			
		Paused
		
	
	
	
			
			
	
	
	
	
		
		
		Paused
		
	Upload 25 files
Browse files- library/__init__.py +0 -0
- library/adafactor_fused.py +106 -0
- library/attention_processors.py +227 -0
- library/config_util.py +721 -0
- library/custom_train_functions.py +559 -0
- library/deepspeed_utils.py +139 -0
- library/device_utils.py +84 -0
- library/huggingface_util.py +84 -0
- library/hypernetwork.py +223 -0
- library/ipex/__init__.py +180 -0
- library/ipex/attention.py +177 -0
- library/ipex/diffusers.py +312 -0
- library/ipex/gradscaler.py +183 -0
- library/ipex/hijacks.py +313 -0
- library/lpw_stable_diffusion.py +1233 -0
- library/model_util.py +1356 -0
- library/original_unet.py +1919 -0
- library/sai_model_spec.py +309 -0
- library/sdxl_lpw_stable_diffusion.py +1347 -0
- library/sdxl_model_util.py +583 -0
- library/sdxl_original_unet.py +1286 -0
- library/sdxl_train_util.py +381 -0
- library/slicing_vae.py +682 -0
- library/train_util.py +0 -0
- library/utils.py +287 -0
    	
        library/__init__.py
    ADDED
    
    | 
            File without changes
         | 
    	
        library/adafactor_fused.py
    ADDED
    
    | @@ -0,0 +1,106 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import math
         | 
| 2 | 
            +
            import torch
         | 
| 3 | 
            +
            from transformers import Adafactor
         | 
| 4 | 
            +
             | 
| 5 | 
            +
            @torch.no_grad()
            def adafactor_step_param(self, p, group):
                """Apply one Adafactor update to the single parameter ``p``.

                Arguments:
                    p: parameter tensor to update in place; skipped when ``p.grad`` is None.
                    group: the optimizer param group ``p`` belongs to. The keys read here
                        are ``decay_rate``, ``eps``, ``clip_threshold``, ``beta1`` and
                        ``weight_decay``; ``Adafactor._get_options`` / ``_get_lr`` also
                        receive the group.

                Raises:
                    RuntimeError: if the gradient is sparse.
                """
                if p.grad is None:
                    return

                half_dtypes = {torch.float16, torch.bfloat16}

                grad = p.grad
                if grad.dtype in half_dtypes:
                    # do the optimizer math in fp32
                    grad = grad.float()
                if grad.is_sparse:
                    raise RuntimeError("Adafactor does not support sparse gradients.")

                state = self.state[p]
                grad_shape = grad.shape
                factored, use_first_moment = Adafactor._get_options(group, grad_shape)

                if not state:
                    # first visit for this parameter: create the statistics
                    state["step"] = 0
                    if use_first_moment:
                        # running average of the (scaled) updates
                        state["exp_avg"] = torch.zeros_like(grad)
                    if factored:
                        # second moment kept as separate row / column statistics
                        state["exp_avg_sq_row"] = torch.zeros(grad_shape[:-1]).to(grad)
                        state["exp_avg_sq_col"] = torch.zeros(grad_shape[:-2] + grad_shape[-1:]).to(grad)
                    else:
                        state["exp_avg_sq"] = torch.zeros_like(grad)
                    state["RMS"] = 0
                else:
                    # existing statistics are re-cast to match the current gradient
                    tracked_keys = []
                    if use_first_moment:
                        tracked_keys.append("exp_avg")
                    if factored:
                        tracked_keys.extend(("exp_avg_sq_row", "exp_avg_sq_col"))
                    else:
                        tracked_keys.append("exp_avg_sq")
                    for key in tracked_keys:
                        state[key] = state[key].to(grad)

                # fp32 working copy for half-precision parameters, otherwise p itself
                p_data_fp32 = p.float() if p.dtype in half_dtypes else p

                state["step"] += 1
                state["RMS"] = Adafactor._rms(p_data_fp32)
                lr = Adafactor._get_lr(group, state)

                beta2t = 1.0 - math.pow(state["step"], group["decay_rate"])
                one_minus_beta2t = 1.0 - beta2t
                update = grad.pow(2) + group["eps"][0]

                if factored:
                    row_stat = state["exp_avg_sq_row"]
                    col_stat = state["exp_avg_sq_col"]
                    row_stat.mul_(beta2t).add_(update.mean(dim=-1), alpha=one_minus_beta2t)
                    col_stat.mul_(beta2t).add_(update.mean(dim=-2), alpha=one_minus_beta2t)

                    # approximate the squared-gradient average from its row/column factors
                    update = Adafactor._approx_sq_grad(row_stat, col_stat)
                    update.mul_(grad)
                else:
                    full_stat = state["exp_avg_sq"]
                    full_stat.mul_(beta2t).add_(update, alpha=one_minus_beta2t)
                    update = full_stat.rsqrt().mul_(grad)

                # scale the update down when its RMS exceeds clip_threshold, then apply lr
                clip_divisor = (Adafactor._rms(update) / group["clip_threshold"]).clamp_(min=1.0)
                update.div_(clip_divisor)
                update.mul_(lr)

                if use_first_moment:
                    momentum = state["exp_avg"]
                    momentum.mul_(group["beta1"]).add_(update, alpha=(1 - group["beta1"]))
                    update = momentum

                if group["weight_decay"] != 0:
                    p_data_fp32.add_(p_data_fp32, alpha=(-group["weight_decay"] * lr))

                p_data_fp32.add_(-update)

                if p.dtype in half_dtypes:
                    # write the fp32 working copy back into the half-precision parameter
                    p.copy_(p_data_fp32)
         | 
| 83 | 
            +
             | 
| 84 | 
            +
             | 
| 85 | 
            +
            @torch.no_grad()
            def adafactor_step(self, closure=None):
                """
                Performs a single optimization step

                Arguments:
                    closure (callable, optional): A closure that reevaluates the model
                        and returns the loss.

                Returns:
                    The value returned by ``closure``, or None when no closure is given.
                """
                loss = None
                if closure is not None:
                    # This function runs under torch.no_grad(); the closure re-evaluates
                    # the model (and usually calls backward), so autograd has to be
                    # switched back on while it runs.
                    with torch.enable_grad():
                        loss = closure()

                for group in self.param_groups:
                    for p in group["params"]:
                        adafactor_step_param(self, p, group)

                return loss
         | 
| 103 | 
            +
             | 
| 104 | 
            +
            def patch_adafactor_fused(optimizer: Adafactor):
                """Bind the per-parameter step functions onto ``optimizer``.

                Sets ``optimizer.step_param`` to ``adafactor_step_param`` and replaces
                ``optimizer.step`` with ``adafactor_step``, both bound to this one
                instance (the class itself is not modified).
                """
                overrides = (
                    ("step_param", adafactor_step_param),
                    ("step", adafactor_step),
                )
                for attr_name, func in overrides:
                    # __get__ turns the plain function into a method bound to optimizer
                    setattr(optimizer, attr_name, func.__get__(optimizer))
         | 
    	
        library/attention_processors.py
    ADDED
    
    | @@ -0,0 +1,227 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import math
         | 
| 2 | 
            +
            from typing import Any
         | 
| 3 | 
            +
            from einops import rearrange
         | 
| 4 | 
            +
            import torch
         | 
| 5 | 
            +
            from diffusers.models.attention_processor import Attention
         | 
| 6 | 
            +
             | 
| 7 | 
            +
             | 
| 8 | 
            +
            # flash attention forwards and backwards
         | 
| 9 | 
            +
             | 
| 10 | 
            +
            # https://arxiv.org/abs/2205.14135
         | 
| 11 | 
            +
             | 
| 12 | 
            +
            EPSILON = 1e-6  # lower bound applied to each block's softmax row sum


            class FlashAttentionFunction(torch.autograd.function.Function):
                """Attention computed bucket by bucket with a hand-written backward pass.

                Queries are processed in chunks of ``q_bucket_size`` rows and keys/values
                in chunks of ``k_bucket_size`` rows, keeping a running row maximum and
                row sum so the softmax is accumulated across key chunks.

                NOTE(review): the optional mask is reshaped from ``b n`` and split with
                ``q_bucket_size`` along its last axis, then broadcast against the
                ``(query, key)`` score tile; confirm the expected mask layout with the
                callers before relying on it.
                """

                @staticmethod
                @torch.no_grad()
                def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size):
                    """Algorithm 2 in the paper"""

                    device, dtype = q.device, q.dtype
                    neg_fill = -torch.finfo(dtype).max
                    len_gap = max(k.shape[-2] - q.shape[-2], 0)
                    stat_shape = (*q.shape[:-1], 1)

                    o = torch.zeros_like(q)
                    all_row_sums = torch.zeros(stat_shape, dtype=dtype, device=device)
                    all_row_maxes = torch.full(stat_shape, neg_fill, dtype=dtype, device=device)

                    scale = q.shape[-1] ** -0.5

                    if mask is not None:
                        mask = rearrange(mask, "b n -> b 1 1 n")
                        mask = mask.split(q_bucket_size, dim=-1)
                    else:
                        # one placeholder per query chunk so the zip below lines up
                        mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size)

                    q_tiles = zip(
                        q.split(q_bucket_size, dim=-2),
                        o.split(q_bucket_size, dim=-2),
                        mask,
                        all_row_sums.split(q_bucket_size, dim=-2),
                        all_row_maxes.split(q_bucket_size, dim=-2),
                    )

                    for q_idx, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(q_tiles):
                        q_offset = q_idx * q_bucket_size - len_gap

                        k_tiles = zip(
                            k.split(k_bucket_size, dim=-2),
                            v.split(k_bucket_size, dim=-2),
                        )

                        for k_idx, (kc, vc) in enumerate(k_tiles):
                            k_offset = k_idx * k_bucket_size

                            scores = torch.einsum("... i d, ... j d -> ... i j", qc, kc) * scale

                            if row_mask is not None:
                                scores.masked_fill_(~row_mask, neg_fill)

                            if causal and q_offset < (k_offset + k_bucket_size - 1):
                                # hide keys that lie after the query position
                                future = torch.ones(
                                    (qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device
                                ).triu(q_offset - k_offset + 1)
                                scores.masked_fill_(future, neg_fill)

                            tile_max = scores.amax(dim=-1, keepdims=True)
                            scores -= tile_max
                            tile_exp = torch.exp(scores)

                            if row_mask is not None:
                                tile_exp.masked_fill_(~row_mask, 0.0)

                            tile_sum = tile_exp.sum(dim=-1, keepdims=True).clamp(min=EPSILON)

                            merged_max = torch.maximum(tile_max, row_maxes)

                            tile_out = torch.einsum("... i j, ... j d -> ... i d", tile_exp, vc)

                            # factors that re-express old and new terms relative to merged_max
                            old_rescale = torch.exp(row_maxes - merged_max)
                            tile_rescale = torch.exp(tile_max - merged_max)

                            merged_sum = old_rescale * row_sums + tile_rescale * tile_sum

                            # oc / row_maxes / row_sums are views, so these in-place ops
                            # write straight into o, all_row_maxes and all_row_sums
                            oc.mul_((row_sums / merged_sum) * old_rescale).add_(
                                (tile_rescale / merged_sum) * tile_out
                            )

                            row_maxes.copy_(merged_max)
                            row_sums.copy_(merged_sum)

                    ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size)
                    ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes)

                    return o

                @staticmethod
                @torch.no_grad()
                def backward(ctx, do):
                    """Algorithm 4 in the paper"""

                    causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args
                    q, k, v, o, l, m = ctx.saved_tensors

                    device = q.device
                    neg_fill = -torch.finfo(q.dtype).max
                    len_gap = max(k.shape[-2] - q.shape[-2], 0)

                    dq, dk, dv = torch.zeros_like(q), torch.zeros_like(k), torch.zeros_like(v)

                    q_tiles = zip(
                        q.split(q_bucket_size, dim=-2),
                        o.split(q_bucket_size, dim=-2),
                        do.split(q_bucket_size, dim=-2),
                        mask,
                        l.split(q_bucket_size, dim=-2),
                        m.split(q_bucket_size, dim=-2),
                        dq.split(q_bucket_size, dim=-2),
                    )

                    for q_idx, (qc, oc, doc, row_mask, lc, mc, dqc) in enumerate(q_tiles):
                        q_offset = q_idx * q_bucket_size - len_gap

                        # depends only on the query chunk, so it is computed once per chunk
                        D = (doc * oc).sum(dim=-1, keepdims=True)

                        k_tiles = zip(
                            k.split(k_bucket_size, dim=-2),
                            v.split(k_bucket_size, dim=-2),
                            dk.split(k_bucket_size, dim=-2),
                            dv.split(k_bucket_size, dim=-2),
                        )

                        for k_idx, (kc, vc, dkc, dvc) in enumerate(k_tiles):
                            k_offset = k_idx * k_bucket_size

                            scores = torch.einsum("... i d, ... j d -> ... i j", qc, kc) * scale

                            if causal and q_offset < (k_offset + k_bucket_size - 1):
                                future = torch.ones(
                                    (qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device
                                ).triu(q_offset - k_offset + 1)
                                scores.masked_fill_(future, neg_fill)

                            # rebuild the attention probabilities from the saved row max / sum
                            unnormalized = torch.exp(scores - mc)

                            if row_mask is not None:
                                unnormalized.masked_fill_(~row_mask, 0.0)

                            p = unnormalized / lc

                            dv_tile = torch.einsum("... i j, ... i d -> ... j d", p, doc)
                            dp = torch.einsum("... i d, ... j d -> ... i j", doc, vc)

                            ds = p * scale * (dp - D)

                            dq_tile = torch.einsum("... i j, ... j d -> ... i d", ds, kc)
                            dk_tile = torch.einsum("... i j, ... i d -> ... j d", ds, qc)

                            # the chunk tensors are views into dq / dk / dv
                            dqc.add_(dq_tile)
                            dkc.add_(dk_tile)
                            dvc.add_(dv_tile)

                    # one gradient slot per forward argument; only q, k, v get gradients
                    return dq, dk, dv, None, None, None, None
         | 
| 180 | 
            +
             | 
| 181 | 
            +
             | 
| 182 | 
            +
            class FlashAttnProcessor:
                """Attention processor that runs the attention itself through
                ``FlashAttentionFunction`` (non-causal, fixed bucket sizes)."""

                def __call__(
                    self,
                    attn: Attention,
                    hidden_states,
                    encoder_hidden_states=None,
                    attention_mask=None,
                ) -> Any:
                    """Project, attend and re-project ``hidden_states``.

                    Arguments:
                        attn: layer supplying ``heads``, ``to_q``/``to_k``/``to_v``,
                            ``to_out`` and, optionally, a ``hypernetwork`` attribute.
                        hidden_states: input used for the query projection.
                        encoder_hidden_states: source for keys/values; falls back to
                            ``hidden_states`` when None.
                        attention_mask: forwarded unchanged as the ``mask`` argument of
                            ``FlashAttentionFunction``.

                    Returns:
                        The result of ``attn.to_out[1](attn.to_out[0](...))``.
                    """
                    q_bucket_size = 512
                    k_bucket_size = 1024

                    num_heads = attn.heads
                    query = attn.to_q(hidden_states)

                    if encoder_hidden_states is None:
                        # no separate context given: attend over the input itself
                        encoder_hidden_states = hidden_states
                    encoder_hidden_states = encoder_hidden_states.to(hidden_states.dtype)

                    hypernetwork = getattr(attn, "hypernetwork", None)
                    if hypernetwork is not None:
                        # the hypernetwork supplies separate key / value contexts
                        context_k, context_v = hypernetwork.forward(
                            hidden_states, encoder_hidden_states
                        )
                        context_k = context_k.to(hidden_states.dtype)
                        context_v = context_v.to(hidden_states.dtype)
                    else:
                        context_k = context_v = encoder_hidden_states

                    key = attn.to_k(context_k)
                    value = attn.to_v(context_v)
                    del encoder_hidden_states, hidden_states

                    # split the channel axis into (heads, head_dim)
                    query, key, value = (
                        rearrange(t, "b n (h d) -> b h n d", h=num_heads)
                        for t in (query, key, value)
                    )

                    out = FlashAttentionFunction.apply(
                        query, key, value, attention_mask, False, q_bucket_size, k_bucket_size
                    )

                    # merge the heads back into the channel axis
                    out = rearrange(out, "b h n d -> b n (h d)")

                    out = attn.to_out[0](out)
                    return attn.to_out[1](out)
         | 
    	
        library/config_util.py
    ADDED
    
    | @@ -0,0 +1,721 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import argparse
         | 
| 2 | 
            +
            from dataclasses import (
         | 
| 3 | 
            +
                asdict,
         | 
| 4 | 
            +
                dataclass,
         | 
| 5 | 
            +
            )
         | 
| 6 | 
            +
            import functools
         | 
| 7 | 
            +
            import random
         | 
| 8 | 
            +
            from textwrap import dedent, indent
         | 
| 9 | 
            +
            import json
         | 
| 10 | 
            +
            from pathlib import Path
         | 
| 11 | 
            +
             | 
| 12 | 
            +
            # from toolz import curry
         | 
| 13 | 
            +
            from typing import (
         | 
| 14 | 
            +
                List,
         | 
| 15 | 
            +
                Optional,
         | 
| 16 | 
            +
                Sequence,
         | 
| 17 | 
            +
                Tuple,
         | 
| 18 | 
            +
                Union,
         | 
| 19 | 
            +
            )
         | 
| 20 | 
            +
             | 
| 21 | 
            +
            import toml
         | 
| 22 | 
            +
            import voluptuous
         | 
| 23 | 
            +
            from voluptuous import (
         | 
| 24 | 
            +
                Any,
         | 
| 25 | 
            +
                ExactSequence,
         | 
| 26 | 
            +
                MultipleInvalid,
         | 
| 27 | 
            +
                Object,
         | 
| 28 | 
            +
                Required,
         | 
| 29 | 
            +
                Schema,
         | 
| 30 | 
            +
            )
         | 
| 31 | 
            +
            from transformers import CLIPTokenizer
         | 
| 32 | 
            +
             | 
| 33 | 
            +
            from . import train_util
         | 
| 34 | 
            +
            from .train_util import (
         | 
| 35 | 
            +
                DreamBoothSubset,
         | 
| 36 | 
            +
                FineTuningSubset,
         | 
| 37 | 
            +
                ControlNetSubset,
         | 
| 38 | 
            +
                DreamBoothDataset,
         | 
| 39 | 
            +
                FineTuningDataset,
         | 
| 40 | 
            +
                ControlNetDataset,
         | 
| 41 | 
            +
                DatasetGroup,
         | 
| 42 | 
            +
            )
         | 
| 43 | 
            +
            from .utils import setup_logging
         | 
| 44 | 
            +
             | 
| 45 | 
            +
            setup_logging()
         | 
| 46 | 
            +
            import logging
         | 
| 47 | 
            +
             | 
| 48 | 
            +
            logger = logging.getLogger(__name__)
         | 
| 49 | 
            +
             | 
| 50 | 
            +
             | 
| 51 | 
            +
            def add_config_arguments(parser: argparse.ArgumentParser):
                """Register the ``--dataset_config`` option on ``parser``.

                The option value is converted to a ``pathlib.Path`` and defaults to
                ``None`` when the flag is not given.
                """
                option_settings = {
                    "type": Path,
                    "default": None,
                    "help": "config file for detail settings / 詳細な設定用の設定ファイル",
                }
                parser.add_argument("--dataset_config", **option_settings)
         | 
| 55 | 
            +
             | 
| 56 | 
            +
             | 
| 57 | 
            +
            # TODO: inherit Params class in Subset, Dataset
         | 
| 58 | 
            +
             | 
| 59 | 
            +
             | 
| 60 | 
            +
            @dataclass
            class BaseSubsetParams:
                """Options shared by every subset parameter class in this module.

                Every field has a default, so the class can be instantiated without
                arguments to obtain the default values.
                """

                image_dir: Optional[str] = None
                num_repeats: int = 1
                shuffle_caption: bool = False
                # Was the one-element tuple (",",) because of a stray trailing comma;
                # the annotation (and the schema in ConfigSanitizer) expect a plain string.
                caption_separator: str = ","
                keep_tokens: int = 0
                # Was the one-element tuple (None,), which is truthy and is not a string;
                # the intended default is "no separator".
                keep_tokens_separator: Optional[str] = None
                secondary_separator: Optional[str] = None
                enable_wildcard: bool = False
                color_aug: bool = False
                flip_aug: bool = False
                face_crop_aug_range: Optional[Tuple[float, float]] = None
                random_crop: bool = False
                caption_prefix: Optional[str] = None
                caption_suffix: Optional[str] = None
                caption_dropout_rate: float = 0.0
                caption_dropout_every_n_epochs: int = 0
                caption_tag_dropout_rate: float = 0.0
                token_warmup_min: int = 1
                token_warmup_step: float = 0
         | 
| 81 | 
            +
             | 
| 82 | 
            +
             | 
| 83 | 
            +
            @dataclass
            class DreamBoothSubsetParams(BaseSubsetParams):
                """Subset options for DreamBooth subsets, on top of ``BaseSubsetParams``."""

                # Defaults to False.
                is_reg: bool = False
                class_tokens: Optional[str] = None
                # Default extension is ".caption".
                caption_extension: str = ".caption"
                cache_info: bool = False
                alpha_mask: bool = False
         | 
| 90 | 
            +
             | 
| 91 | 
            +
             | 
| 92 | 
            +
            @dataclass
            class FineTuningSubsetParams(BaseSubsetParams):
                """Subset options for fine-tuning subsets, on top of ``BaseSubsetParams``."""

                # No metadata file by default.
                metadata_file: Optional[str] = None
                alpha_mask: bool = False
         | 
| 96 | 
            +
             | 
| 97 | 
            +
             | 
| 98 | 
            +
            @dataclass
            class ControlNetSubsetParams(BaseSubsetParams):
                """Subset options for ControlNet subsets, on top of ``BaseSubsetParams``."""

                # Annotated Optional because the default is None (previously plain ``str``).
                conditioning_data_dir: Optional[str] = None
                caption_extension: str = ".caption"
                cache_info: bool = False
         | 
| 103 | 
            +
             | 
| 104 | 
            +
             | 
| 105 | 
            +
            @dataclass
            class BaseDatasetParams:
                """Options shared by every dataset parameter class in this module."""

                # Both of the following default to None, so they are annotated Optional
                # (previously the annotations did not admit None).
                tokenizer: Optional[Union[CLIPTokenizer, List[CLIPTokenizer]]] = None
                max_token_length: Optional[int] = None
                resolution: Optional[Tuple[int, int]] = None
                network_multiplier: float = 1.0
                debug_dataset: bool = False
         | 
| 112 | 
            +
             | 
| 113 | 
            +
             | 
| 114 | 
            +
            @dataclass
            class DreamBoothDatasetParams(BaseDatasetParams):
                """Dataset options for DreamBooth datasets, on top of ``BaseDatasetParams``.

                Adds the batch-size and bucket fields plus ``prior_loss_weight``.
                """

                batch_size: int = 1
                # Bucket-related fields.
                enable_bucket: bool = False
                min_bucket_reso: int = 256
                max_bucket_reso: int = 1024
                bucket_reso_steps: int = 64
                bucket_no_upscale: bool = False
                # Only this dataset parameter class declares prior_loss_weight.
                prior_loss_weight: float = 1.0
         | 
| 123 | 
            +
             | 
| 124 | 
            +
             | 
| 125 | 
            +
            @dataclass
            class FineTuningDatasetParams(BaseDatasetParams):
                """Dataset options for fine-tuning datasets, on top of ``BaseDatasetParams``.

                Adds the batch-size and bucket fields.
                """

                batch_size: int = 1
                # Bucket-related fields.
                enable_bucket: bool = False
                min_bucket_reso: int = 256
                max_bucket_reso: int = 1024
                bucket_reso_steps: int = 64
                bucket_no_upscale: bool = False
         | 
| 133 | 
            +
             | 
| 134 | 
            +
             | 
| 135 | 
            +
            @dataclass
            class ControlNetDatasetParams(BaseDatasetParams):
                """Dataset options for ControlNet datasets, on top of ``BaseDatasetParams``.

                Adds the batch-size and bucket fields.
                """

                batch_size: int = 1
                # Bucket-related fields.
                enable_bucket: bool = False
                min_bucket_reso: int = 256
                max_bucket_reso: int = 1024
                bucket_reso_steps: int = 64
                bucket_no_upscale: bool = False
         | 
| 143 | 
            +
             | 
| 144 | 
            +
             | 
| 145 | 
            +
            @dataclass
            class SubsetBlueprint:
                """Holds the parameter object for a single subset."""

                # NOTE(review): ControlNetSubsetParams is not listed in this Union —
                # confirm whether that omission is intentional.
                params: Union[DreamBoothSubsetParams, FineTuningSubsetParams]
         | 
| 148 | 
            +
             | 
| 149 | 
            +
             | 
| 150 | 
            +
            @dataclass
            class DatasetBlueprint:
                """Holds one dataset's mode flags, parameter object and subset blueprints."""

                # Mode flags for this dataset.
                is_dreambooth: bool
                is_controlnet: bool
                # NOTE(review): ControlNetDatasetParams is not listed in this Union —
                # confirm whether that omission is intentional.
                params: Union[DreamBoothDatasetParams, FineTuningDatasetParams]
                subsets: Sequence[SubsetBlueprint]
         | 
| 156 | 
            +
             | 
| 157 | 
            +
             | 
| 158 | 
            +
            @dataclass
            class DatasetGroupBlueprint:
                """Holds the sequence of dataset blueprints that form one group."""

                datasets: Sequence[DatasetBlueprint]
         | 
| 161 | 
            +
             | 
| 162 | 
            +
             | 
| 163 | 
            +
            @dataclass
            class Blueprint:
                """Top-level blueprint; wraps a single ``DatasetGroupBlueprint``."""

                dataset_group: DatasetGroupBlueprint
         | 
| 166 | 
            +
             | 
| 167 | 
            +
             | 
| 168 | 
            +
            class ConfigSanitizer:
                """Build voluptuous schemas for dataset configuration and validate input against them.

                The class-level dictionaries are schema fragments. ``__init__`` merges
                them into subset, dataset and general schemas according to the supported
                modes. Fragments named ``*_ASCENDABLE_SCHEMA`` are merged into the subset
                schemas and also into the dataset/general schemas; fragments named
                ``*_DISTINCT_SCHEMA`` are merged into subset schemas only.

                Prefixes: DO = DropOut, DB = DreamBooth, FT = FineTuning, CN = ControlNet.
                """

                @staticmethod
                def __validate_and_convert_twodim(klass, value: Sequence) -> Tuple:
                    """Validate that ``value`` is a sequence of exactly two ``klass`` items.

                    Returns:
                        ``value`` converted to a tuple.

                    Raises:
                        voluptuous.MultipleInvalid: if validation fails.
                    """
                    Schema(ExactSequence([klass, klass]))(value)
                    return tuple(value)

                @staticmethod
                def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence]) -> Tuple:
                    """Validate a scalar ``klass`` or a pair of ``klass`` and return a 2-tuple.

                    A scalar ``v`` is expanded to ``(v, v)``; a two-item sequence is
                    returned as a tuple.

                    Raises:
                        voluptuous.MultipleInvalid: if ``value`` is neither form.
                    """
                    Schema(Any(klass, ExactSequence([klass, klass])))(value)
                    try:
                        Schema(klass)(value)
                    except voluptuous.Invalid:
                        # Not a scalar, so it must be the two-item form. Only validation
                        # failures are caught here (previously a bare ``except``).
                        return ConfigSanitizer.__validate_and_convert_twodim(klass, value)
                    return (value, value)

                # subset schema
                SUBSET_ASCENDABLE_SCHEMA = {
                    "color_aug": bool,
                    "face_crop_aug_range": functools.partial(__validate_and_convert_twodim.__func__, float),
                    "flip_aug": bool,
                    "num_repeats": int,
                    "random_crop": bool,
                    "shuffle_caption": bool,
                    "keep_tokens": int,
                    "keep_tokens_separator": str,
                    "secondary_separator": str,
                    "caption_separator": str,
                    "enable_wildcard": bool,
                    "token_warmup_min": int,
                    "token_warmup_step": Any(float, int),
                    "caption_prefix": str,
                    "caption_suffix": str,
                }
                # DO means DropOut
                DO_SUBSET_ASCENDABLE_SCHEMA = {
                    "caption_dropout_every_n_epochs": int,
                    "caption_dropout_rate": Any(float, int),
                    "caption_tag_dropout_rate": Any(float, int),
                }
                # DB means DreamBooth
                DB_SUBSET_ASCENDABLE_SCHEMA = {
                    "caption_extension": str,
                    "class_tokens": str,
                    "cache_info": bool,
                }
                DB_SUBSET_DISTINCT_SCHEMA = {
                    Required("image_dir"): str,
                    "is_reg": bool,
                    "alpha_mask": bool,
                }
                # FT means FineTuning
                FT_SUBSET_DISTINCT_SCHEMA = {
                    Required("metadata_file"): str,
                    "image_dir": str,
                    "alpha_mask": bool,
                }
                # CN means ControlNet
                CN_SUBSET_ASCENDABLE_SCHEMA = {
                    "caption_extension": str,
                    "cache_info": bool,
                }
                CN_SUBSET_DISTINCT_SCHEMA = {
                    Required("image_dir"): str,
                    Required("conditioning_data_dir"): str,
                }

                # datasets schema
                DATASET_ASCENDABLE_SCHEMA = {
                    "batch_size": int,
                    "bucket_no_upscale": bool,
                    "bucket_reso_steps": int,
                    "enable_bucket": bool,
                    "max_bucket_reso": int,
                    "min_bucket_reso": int,
                    "resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int),
                    "network_multiplier": float,
                }

                # options handled by argparse but not handled by user config
                ARGPARSE_SPECIFIC_SCHEMA = {
                    "debug_dataset": bool,
                    "max_token_length": Any(None, int),
                    "prior_loss_weight": Any(float, int),
                }
                # for handling default None value of argparse
                ARGPARSE_NULLABLE_OPTNAMES = [
                    "face_crop_aug_range",
                    "resolution",
                ]
                # prepare map because option name may differ among argparse and user config
                ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME = {
                    "train_batch_size": "batch_size",
                    "dataset_repeats": "num_repeats",
                }

                def __init__(self, support_dreambooth: bool, support_finetuning: bool, support_controlnet: bool, support_dropout: bool) -> None:
                    """Assemble the schemas and validators for the supported modes.

                    Args:
                        support_dreambooth: accept DreamBooth-style datasets.
                        support_finetuning: accept fine-tuning-style datasets.
                        support_controlnet: accept ControlNet-style datasets.
                        support_dropout: include the ``DO_SUBSET_ASCENDABLE_SCHEMA`` options.

                    Raises:
                        AssertionError: if none of the three modes is enabled.
                    """
                    assert support_dreambooth or support_finetuning or support_controlnet, (
                        "Neither DreamBooth mode nor fine tuning mode nor controlnet mode specified. Please specify one mode or more."
                        + " / DreamBooth モードか fine tuning モードか controlnet モードのどれも指定されていません。1つ以上指定してください。"
                    )

                    # Dropout options are only part of the schemas when dropout is supported.
                    dropout_schema = self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {}

                    self.db_subset_schema = self.__merge_dict(
                        self.SUBSET_ASCENDABLE_SCHEMA,
                        self.DB_SUBSET_DISTINCT_SCHEMA,
                        self.DB_SUBSET_ASCENDABLE_SCHEMA,
                        dropout_schema,
                    )

                    self.ft_subset_schema = self.__merge_dict(
                        self.SUBSET_ASCENDABLE_SCHEMA,
                        self.FT_SUBSET_DISTINCT_SCHEMA,
                        dropout_schema,
                    )

                    self.cn_subset_schema = self.__merge_dict(
                        self.SUBSET_ASCENDABLE_SCHEMA,
                        self.CN_SUBSET_DISTINCT_SCHEMA,
                        self.CN_SUBSET_ASCENDABLE_SCHEMA,
                        dropout_schema,
                    )

                    self.db_dataset_schema = self.__merge_dict(
                        self.DATASET_ASCENDABLE_SCHEMA,
                        self.SUBSET_ASCENDABLE_SCHEMA,
                        self.DB_SUBSET_ASCENDABLE_SCHEMA,
                        dropout_schema,
                        {"subsets": [self.db_subset_schema]},
                    )

                    self.ft_dataset_schema = self.__merge_dict(
                        self.DATASET_ASCENDABLE_SCHEMA,
                        self.SUBSET_ASCENDABLE_SCHEMA,
                        dropout_schema,
                        {"subsets": [self.ft_subset_schema]},
                    )

                    self.cn_dataset_schema = self.__merge_dict(
                        self.DATASET_ASCENDABLE_SCHEMA,
                        self.SUBSET_ASCENDABLE_SCHEMA,
                        self.CN_SUBSET_ASCENDABLE_SCHEMA,
                        dropout_schema,
                        {"subsets": [self.cn_subset_schema]},
                    )

                    if support_dreambooth and support_finetuning:

                        def validate_flex_dataset(dataset_config: dict):
                            """Pick the dataset schema from the keys present in the subsets."""
                            subsets_config = dataset_config.get("subsets", [])

                            if support_controlnet and all("conditioning_data_dir" in subset for subset in subsets_config):
                                return Schema(self.cn_dataset_schema)(dataset_config)
                            # check dataset meets FT style
                            # NOTE: all FT subsets should have "metadata_file"
                            elif all("metadata_file" in subset for subset in subsets_config):
                                return Schema(self.ft_dataset_schema)(dataset_config)
                            # check dataset meets DB style
                            # NOTE: all DB subsets should have no "metadata_file"
                            elif all("metadata_file" not in subset for subset in subsets_config):
                                return Schema(self.db_dataset_schema)(dataset_config)
                            else:
                                raise voluptuous.Invalid(
                                    "DreamBooth subset and fine tuning subset cannot be mixed in the same dataset. Please split them into separate datasets. / DreamBoothのサブセットとfine tuninのサブセットを同一のデータセットに混在させることはできません。別々のデータセットに分割してください。"
                                )

                        self.dataset_schema = validate_flex_dataset
                    elif support_dreambooth:
                        if support_controlnet:
                            self.dataset_schema = self.cn_dataset_schema
                        else:
                            self.dataset_schema = self.db_dataset_schema
                    elif support_finetuning:
                        self.dataset_schema = self.ft_dataset_schema
                    elif support_controlnet:
                        self.dataset_schema = self.cn_dataset_schema

                    self.general_schema = self.__merge_dict(
                        self.DATASET_ASCENDABLE_SCHEMA,
                        self.SUBSET_ASCENDABLE_SCHEMA,
                        self.DB_SUBSET_ASCENDABLE_SCHEMA if support_dreambooth else {},
                        self.CN_SUBSET_ASCENDABLE_SCHEMA if support_controlnet else {},
                        dropout_schema,
                    )

                    self.user_config_validator = Schema(
                        {
                            "general": self.general_schema,
                            "datasets": [self.dataset_schema],
                        }
                    )

                    self.argparse_schema = self.__merge_dict(
                        self.general_schema,
                        self.ARGPARSE_SPECIFIC_SCHEMA,
                        # argparse may leave these options as None, so None is accepted too
                        {optname: Any(None, self.general_schema[optname]) for optname in self.ARGPARSE_NULLABLE_OPTNAMES},
                        # argparse option names that differ from the user-config names
                        {a_name: self.general_schema[c_name] for a_name, c_name in self.ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME.items()},
                    )

                    self.argparse_config_validator = Schema(Object(self.argparse_schema), extra=voluptuous.ALLOW_EXTRA)

                def sanitize_user_config(self, user_config: dict) -> dict:
                    """Validate ``user_config`` and return the validator's result.

                    Raises:
                        voluptuous.MultipleInvalid: logged and re-raised when validation fails.
                    """
                    try:
                        return self.user_config_validator(user_config)
                    except MultipleInvalid:
                        # TODO: make the message shown on error easier to understand
                        logger.error("Invalid user config / ユーザ設定の形式が正しくないようです")
                        raise

                # NOTE: In nature, argument parser result is not needed to be sanitize
                #   However this will help us to detect program bug
                def sanitize_argparse_namespace(self, argparse_namespace: argparse.Namespace) -> argparse.Namespace:
                    """Validate a parsed argparse namespace and return the validator's result.

                    Raises:
                        voluptuous.MultipleInvalid: logged and re-raised when validation fails.
                    """
                    try:
                        return self.argparse_config_validator(argparse_namespace)
                    except MultipleInvalid:
                        # XXX: this should be a bug
                        logger.error(
                            "Invalid cmdline parsed arguments. This should be a bug. / コマンドラインのパース結果が正しくないようです。プログラムのバグの可能性が高いです。"
                        )
                        raise

                # NOTE: value would be overwritten by latter dict if there is already the same key
                @staticmethod
                def __merge_dict(*dict_list: dict) -> dict:
                    """Return a new dict containing all given dicts merged left to right."""
                    merged = {}
                    for schema in dict_list:
                        merged.update(schema)
                    return merged
         | 
| 397 | 
            +
             | 
| 398 | 
            +
             | 
| 399 | 
            +
            class BlueprintGenerator:
                """Turns a user config plus parsed command-line arguments into a ``Blueprint``.

                Every parameter is resolved by walking a chain of sources, most specific
                first (subset config, dataset config, the ``general`` section, command-line
                arguments, runtime parameters). The first value that is not ``None`` wins;
                if none is found, the default of the params class is used.
                """

                # Blueprint parameter name -> config option name; names not listed here are used as-is.
                BLUEPRINT_PARAM_NAME_TO_CONFIG_OPTNAME = {}

                def __init__(self, sanitizer: ConfigSanitizer):
                    self.sanitizer = sanitizer

                # runtime_params carries values that can only be supplied at runtime, such as the tokenizer
                def generate(self, user_config: dict, argparse_namespace: argparse.Namespace, **runtime_params) -> Blueprint:
                    """Validate both inputs and build the blueprint for every dataset in ``user_config``."""
                    sanitizer = self.sanitizer
                    clean_user_config = sanitizer.sanitize_user_config(user_config)
                    clean_namespace = sanitizer.sanitize_argparse_namespace(argparse_namespace)

                    # Re-key the argparse namespace with config option names so it can serve as a fallback source.
                    # NOTE: it is ok to have extra entries in dict
                    optname_map = sanitizer.ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME
                    argparse_config = {}
                    for optname, value in vars(clean_namespace).items():
                        argparse_config[optname_map.get(optname, optname)] = value

                    general_config = clean_user_config.get("general", {})
                    # Sources shared by every subset and dataset, from most to least specific.
                    shared_fallbacks = [general_config, argparse_config, runtime_params]

                    dataset_blueprints = []
                    for dataset_config in clean_user_config.get("datasets", []):
                        subsets = dataset_config.get("subsets", [])
                        # NOTE: if no subset has "metadata_file", these are DreamBooth datasets/subsets
                        is_dreambooth = not any("metadata_file" in subset for subset in subsets)
                        is_controlnet = all("conditioning_data_dir" in subset for subset in subsets)

                        # ControlNet takes precedence over DreamBooth; fine-tuning is the remaining case.
                        if is_controlnet:
                            subset_params_klass, dataset_params_klass = ControlNetSubsetParams, ControlNetDatasetParams
                        elif is_dreambooth:
                            subset_params_klass, dataset_params_klass = DreamBoothSubsetParams, DreamBoothDatasetParams
                        else:
                            subset_params_klass, dataset_params_klass = FineTuningSubsetParams, FineTuningDatasetParams

                        subset_blueprints = [
                            SubsetBlueprint(
                                self.generate_params_by_fallbacks(subset_params_klass, [subset_config, dataset_config, *shared_fallbacks])
                            )
                            for subset_config in subsets
                        ]

                        dataset_params = self.generate_params_by_fallbacks(dataset_params_klass, [dataset_config, *shared_fallbacks])
                        dataset_blueprints.append(DatasetBlueprint(is_dreambooth, is_controlnet, dataset_params, subset_blueprints))

                    return Blueprint(DatasetGroupBlueprint(dataset_blueprints))

                @staticmethod
                def generate_params_by_fallbacks(param_klass, fallbacks: Sequence[dict]):
                    """Instantiate ``param_klass`` with each field taken from ``fallbacks``, or its default if absent.

                    ``param_klass`` must be a dataclass that can be constructed without arguments.
                    """
                    resolved = {}
                    for name, default in asdict(param_klass()).items():
                        optname = BlueprintGenerator.BLUEPRINT_PARAM_NAME_TO_CONFIG_OPTNAME.get(name, name)
                        resolved[name] = BlueprintGenerator.search_value(optname, fallbacks, default)

                    return param_klass(**resolved)

                @staticmethod
                def search_value(key: str, fallbacks: Sequence[dict], default_value=None):
                    """Return the first non-``None`` value stored under ``key`` in ``fallbacks``, else ``default_value``."""
                    for candidate in fallbacks:
                        found = candidate.get(key)
                        if found is None:
                            continue
                        return found

                    return default_value
         | 
| 470 | 
            +
             | 
| 471 | 
            +
             | 
| 472 | 
            +
            def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlueprint):
                """Instantiate every dataset described by ``dataset_group_blueprint`` and wrap them in a ``DatasetGroup``.

                For each dataset blueprint the subset/dataset classes are chosen (ControlNet
                first, then DreamBooth, otherwise fine-tuning) and built from the blueprint
                parameters. A summary of all datasets is logged, then every dataset gets its
                buckets made and the same random seed set.
                """
                datasets: List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = []

                for blueprint in dataset_group_blueprint.datasets:
                    if blueprint.is_controlnet:
                        subset_klass, dataset_klass = ControlNetSubset, ControlNetDataset
                    elif blueprint.is_dreambooth:
                        subset_klass, dataset_klass = DreamBoothSubset, DreamBoothDataset
                    else:
                        subset_klass, dataset_klass = FineTuningSubset, FineTuningDataset

                    subsets = []
                    for subset_blueprint in blueprint.subsets:
                        subsets.append(subset_klass(**asdict(subset_blueprint.params)))
                    datasets.append(dataset_klass(subsets=subsets, **asdict(blueprint.params)))

                # Build the summary piece by piece and join once before logging.
                # The text inside the triple-quoted strings is passed through dedent()/indent(),
                # so only the indentation of its lines relative to each other matters.
                info_parts = []
                for i, dataset in enumerate(datasets):
                    is_dreambooth = isinstance(dataset, DreamBoothDataset)
                    is_controlnet = isinstance(dataset, ControlNetDataset)

                    dataset_header = dedent(
                        f"""\
                  [Dataset {i}]
                    batch_size: {dataset.batch_size}
                    resolution: {(dataset.width, dataset.height)}
                    enable_bucket: {dataset.enable_bucket}
                    network_multiplier: {dataset.network_multiplier}
                """
                    )
                    info_parts.append(dataset_header)

                    if not dataset.enable_bucket:
                        info_parts.append("\n")
                    else:
                        bucket_details = dedent(
                            f"""\
                    min_bucket_reso: {dataset.min_bucket_reso}
                    max_bucket_reso: {dataset.max_bucket_reso}
                    bucket_reso_steps: {dataset.bucket_reso_steps}
                    bucket_no_upscale: {dataset.bucket_no_upscale}
                  \n"""
                        )
                        info_parts.append(indent(bucket_details, "  "))

                    for j, subset in enumerate(dataset.subsets):
                        subset_details = dedent(
                            f"""\
                    [Subset {j} of Dataset {i}]
                      image_dir: "{subset.image_dir}"
                      image_count: {subset.img_count}
                      num_repeats: {subset.num_repeats}
                      shuffle_caption: {subset.shuffle_caption}
                      keep_tokens: {subset.keep_tokens}
                      keep_tokens_separator: {subset.keep_tokens_separator}
                      caption_separator: {subset.caption_separator}
                      secondary_separator: {subset.secondary_separator}
                      enable_wildcard: {subset.enable_wildcard}
                      caption_dropout_rate: {subset.caption_dropout_rate}
                      caption_dropout_every_n_epoches: {subset.caption_dropout_every_n_epochs}
                      caption_tag_dropout_rate: {subset.caption_tag_dropout_rate}
                      caption_prefix: {subset.caption_prefix}
                      caption_suffix: {subset.caption_suffix}
                      color_aug: {subset.color_aug}
                      flip_aug: {subset.flip_aug}
                      face_crop_aug_range: {subset.face_crop_aug_range}
                      random_crop: {subset.random_crop}
                      token_warmup_min: {subset.token_warmup_min},
                      token_warmup_step: {subset.token_warmup_step},
                      alpha_mask: {subset.alpha_mask},
                  """
                        )
                        info_parts.append(indent(subset_details, "  "))

                        # DreamBooth and fine-tuning subsets each get extra lines; ControlNet subsets get none.
                        if is_dreambooth:
                            dreambooth_details = dedent(
                                f"""\
                      is_reg: {subset.is_reg}
                      class_tokens: {subset.class_tokens}
                      caption_extension: {subset.caption_extension}
                    \n"""
                            )
                            info_parts.append(indent(dreambooth_details, "    "))
                        elif not is_controlnet:
                            finetuning_details = dedent(
                                f"""\
                      metadata_file: {subset.metadata_file}
                    \n"""
                            )
                            info_parts.append(indent(finetuning_details, "    "))

                logger.info("".join(info_parts))

                # make buckets first because it determines the length of dataset
                # and set the same seed for all datasets
                shared_seed = random.randint(0, 2**31)  # actual seed is seed + epoch_no
                for index, dataset in enumerate(datasets):
                    logger.info(f"[Dataset {index}]")
                    dataset.make_buckets()
                    dataset.set_seed(shared_seed)

                return DatasetGroup(datasets)
         | 
| 583 | 
            +
             | 
| 584 | 
            +
             | 
| 585 | 
            +
            def generate_dreambooth_subsets_config_by_subdirs(train_data_dir: Optional[str] = None, reg_data_dir: Optional[str] = None):
                """Build DreamBooth subset configs from the sub-directories of the given data directories.

                Each sub-directory is expected to be named ``<repeats>_<class tokens>``.
                Sub-directories whose name does not start with an integer are skipped with a
                warning; those with a repeat count below 1 are skipped silently.

                Args:
                    train_data_dir: Directory holding the training sub-directories, or ``None``.
                    reg_data_dir: Directory holding the regularization sub-directories, or ``None``.

                Returns:
                    A list of subset config dicts with the keys ``image_dir``, ``num_repeats``,
                    ``is_reg`` and ``class_tokens``. Training subsets come first, then
                    regularization subsets; within each group the entries are sorted by path.
                    A directory that is ``None`` or does not exist contributes nothing.
                """

                def extract_dreambooth_params(name: str) -> Tuple[int, str]:
                    # Returns (repeat count, class tokens); (0, "") when the name has no integer prefix.
                    tokens = name.split("_")
                    try:
                        n_repeats = int(tokens[0])
                    except ValueError:
                        logger.warning(f"ignore directory without repeats / 繰り返し回数のないディレクトリを無視します: {name}")
                        return 0, ""
                    caption_by_folder = "_".join(tokens[1:])
                    return n_repeats, caption_by_folder

                def generate(base_dir: Optional[str], is_reg: bool):
                    if base_dir is None:
                        return []

                    base_dir: Path = Path(base_dir)
                    if not base_dir.is_dir():
                        return []

                    subsets_config = []
                    # iterdir() yields entries in arbitrary, filesystem-dependent order; sort them so
                    # the subset order is reproducible across machines and runs.
                    for subdir in sorted(base_dir.iterdir()):
                        if not subdir.is_dir():
                            continue

                        num_repeats, class_tokens = extract_dreambooth_params(subdir.name)
                        if num_repeats < 1:
                            continue

                        subset_config = {"image_dir": str(subdir), "num_repeats": num_repeats, "is_reg": is_reg, "class_tokens": class_tokens}
                        subsets_config.append(subset_config)

                    return subsets_config

                subsets_config = []
                subsets_config += generate(train_data_dir, False)
                subsets_config += generate(reg_data_dir, True)

                return subsets_config
         | 
| 623 | 
            +
             | 
| 624 | 
            +
             | 
| 625 | 
            +
            def generate_controlnet_subsets_config_by_subdirs(
                train_data_dir: Optional[str] = None, conditioning_data_dir: Optional[str] = None, caption_extension: str = ".txt"
            ):
                """Build the subset config list for a ControlNet dataset.

                Returns an empty list when ``train_data_dir`` is ``None`` or is not an existing
                directory. Otherwise returns a single subset config that pairs
                ``train_data_dir`` with ``conditioning_data_dir``, using the given caption
                extension and a repeat count of 1. ``conditioning_data_dir`` is passed through
                without being checked.
                """
                # Guard clause: without a usable training directory there is nothing to describe.
                if train_data_dir is None or not Path(train_data_dir).is_dir():
                    return []

                return [
                    {
                        "image_dir": train_data_dir,
                        "conditioning_data_dir": conditioning_data_dir,
                        "caption_extension": caption_extension,
                        "num_repeats": 1,
                    }
                ]
         | 
| 651 | 
            +
             | 
| 652 | 
            +
             | 
| 653 | 
            +
            def load_user_config(file: str) -> dict:
                """Load a user config file and return its parsed content.

                The format is chosen from the file name, case-insensitively: ``.json`` files
                are parsed with ``json`` and ``.toml`` files with ``toml``.

                Args:
                    file: Path to the config file.

                Returns:
                    The parsed configuration.

                Raises:
                    ValueError: If ``file`` is not an existing file, or its extension is
                        neither ``.json`` nor ``.toml``.
                    Exception: Any error raised while reading or parsing is logged and re-raised unchanged.
                """
                file: Path = Path(file)
                if not file.is_file():
                    raise ValueError(f"file not found / ファイルが見つかりません: {file}")

                if file.name.lower().endswith(".json"):
                    try:
                        # JSON text is UTF-8; do not depend on the platform's locale encoding,
                        # which would break non-ASCII configs on e.g. cp932/cp1252 systems.
                        with open(file, "r", encoding="utf-8") as f:
                            config = json.load(f)
                    except Exception:
                        logger.error(
                            f"Error on parsing JSON config file. Please check the format. / JSON 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}"
                        )
                        raise
                elif file.name.lower().endswith(".toml"):
                    try:
                        config = toml.load(file)
                    except Exception:
                        logger.error(
                            f"Error on parsing TOML config file. Please check the format. / TOML 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}"
                        )
                        raise
                else:
                    raise ValueError(f"not supported config file format / 対応していない設定ファイルの形式です: {file}")

                return config
         | 
| 679 | 
            +
             | 
| 680 | 
            +
             | 
| 681 | 
            +
            # for config test
            if __name__ == "__main__":
                # First pass: read only the switches that decide which dataset options exist,
                # plus the config file path; everything else is left for the second parser.
                bootstrap_parser = argparse.ArgumentParser()
                for flag in ("--support_dreambooth", "--support_finetuning", "--support_controlnet", "--support_dropout"):
                    bootstrap_parser.add_argument(flag, action="store_true")
                bootstrap_parser.add_argument("dataset_config")
                config_args, remain = bootstrap_parser.parse_known_args()

                # Second pass: parse the remaining arguments with the dataset/training options registered.
                dataset_parser = argparse.ArgumentParser()
                train_util.add_dataset_arguments(
                    dataset_parser, config_args.support_dreambooth, config_args.support_finetuning, config_args.support_dropout
                )
                train_util.add_training_arguments(dataset_parser, config_args.support_dreambooth)
                argparse_namespace = dataset_parser.parse_args(remain)
                train_util.prepare_dataset_args(argparse_namespace, config_args.support_finetuning)

                logger.info("[argparse_namespace]")
                logger.info(f"{vars(argparse_namespace)}")

                user_config = load_user_config(config_args.dataset_config)

                logger.info("")
                logger.info("[user_config]")
                logger.info(f"{user_config}")

                sanitizer = ConfigSanitizer(
                    config_args.support_dreambooth, config_args.support_finetuning, config_args.support_controlnet, config_args.support_dropout
                )
                sanitized_user_config = sanitizer.sanitize_user_config(user_config)

                logger.info("")
                logger.info("[sanitized_user_config]")
                logger.info(f"{sanitized_user_config}")

                # The generator validates the raw inputs itself, so it is given the unsanitized config.
                generator = BlueprintGenerator(sanitizer)
                blueprint = generator.generate(user_config, argparse_namespace)

                logger.info("")
                logger.info("[blueprint]")
                logger.info(f"{blueprint}")
         | 
    	
        library/custom_train_functions.py
    ADDED
    
    | @@ -0,0 +1,559 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import torch
         | 
| 2 | 
            +
            import argparse
         | 
| 3 | 
            +
            import random
         | 
| 4 | 
            +
            import re
         | 
| 5 | 
            +
            from typing import List, Optional, Union
         | 
| 6 | 
            +
            from .utils import setup_logging
         | 
| 7 | 
            +
             | 
| 8 | 
            +
            setup_logging()
         | 
| 9 | 
            +
            import logging
         | 
| 10 | 
            +
             | 
| 11 | 
            +
            logger = logging.getLogger(__name__)
         | 
| 12 | 
            +
             | 
| 13 | 
            +
             | 
| 14 | 
            +
            def prepare_scheduler_for_custom_training(noise_scheduler, device):
                """Cache the per-timestep signal-to-noise ratio on the scheduler.

                Reads ``noise_scheduler.alphas_cumprod``, computes
                ``(sqrt(alphas_cumprod) / sqrt(1 - alphas_cumprod)) ** 2`` and stores the
                result, moved to ``device``, as ``noise_scheduler.all_snr``.
                If the scheduler already has an ``all_snr`` attribute nothing is changed.

                Args:
                    noise_scheduler: object exposing an ``alphas_cumprod`` tensor.
                    device: target handed to ``Tensor.to`` for the cached tensor.
                """
                if not hasattr(noise_scheduler, "all_snr"):
                    cumulative_alphas = noise_scheduler.alphas_cumprod
                    signal = torch.sqrt(cumulative_alphas)
                    noise = torch.sqrt(1.0 - cumulative_alphas)
                    noise_scheduler.all_snr = ((signal / noise) ** 2).to(device)
         | 
| 26 | 
            +
             | 
| 27 | 
            +
             | 
| 28 | 
            +
            def fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler):
                """Rescale the scheduler's betas so the last timestep has zero SNR.

                Reads ``noise_scheduler.betas``, rescales them (method referenced in the
                log message: https://arxiv.org/abs/2305.08891) and overwrites the
                scheduler's ``betas``, ``alphas`` and ``alphas_cumprod`` attributes with
                the recomputed tensors. Returns nothing.
                """
                logger.info("fix noise scheduler betas: https://arxiv.org/abs/2305.08891")

                def enforce_zero_terminal_snr(betas):
                    # betas -> sqrt of the cumulative product of alphas
                    sqrt_alpha_bar = (1 - betas).cumprod(0).sqrt()

                    # remember the first and last entries before rescaling
                    first = sqrt_alpha_bar[0].clone()
                    last = sqrt_alpha_bar[-1].clone()

                    # shift so the last entry becomes zero, then scale so the first
                    # entry returns to its original value
                    sqrt_alpha_bar = (sqrt_alpha_bar - last) * (first / (first - last))

                    # sqrt(alpha_bar) -> per-step alphas -> betas
                    alpha_bar = sqrt_alpha_bar**2
                    step_alphas = torch.cat([alpha_bar[0:1], alpha_bar[1:] / alpha_bar[:-1]])
                    return 1 - step_alphas

                fixed_betas = enforce_zero_terminal_snr(noise_scheduler.betas)
                fixed_alphas = 1.0 - fixed_betas

                noise_scheduler.betas = fixed_betas
                noise_scheduler.alphas = fixed_alphas
                noise_scheduler.alphas_cumprod = torch.cumprod(fixed_alphas, dim=0)
         | 
| 64 | 
            +
             | 
| 65 | 
            +
             | 
| 66 | 
            +
            def apply_snr_weight(loss, timesteps, noise_scheduler, gamma, v_prediction=False):
                """Scale ``loss`` by ``min(snr, gamma) / snr`` for each timestep.

                With ``v_prediction=True`` the divisor is ``snr + 1`` instead of ``snr``.

                Args:
                    loss: tensor to be weighted.
                    timesteps: iterable of indices into ``noise_scheduler.all_snr``.
                    noise_scheduler: object carrying the cached ``all_snr`` tensor.
                    gamma: upper bound applied to the SNR in the numerator.
                    v_prediction: selects the ``snr + 1`` divisor when true.

                Returns:
                    ``loss`` multiplied by the float weight (moved to ``loss.device``).
                """
                per_step_snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps])
                capped_snr = torch.minimum(per_step_snr, torch.full_like(per_step_snr, gamma))
                divisor = per_step_snr + 1 if v_prediction else per_step_snr
                weight = torch.div(capped_snr, divisor).float().to(loss.device)
                return loss * weight
         | 
| 75 | 
            +
             | 
| 76 | 
            +
             | 
| 77 | 
            +
            def scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler):
                """Return ``loss`` multiplied by the factor from ``get_snr_scale`` for ``timesteps``."""
                return loss * get_snr_scale(timesteps, noise_scheduler)
         | 
| 81 | 
            +
             | 
| 82 | 
            +
             | 
| 83 | 
            +
            def get_snr_scale(timesteps, noise_scheduler):
                """Return ``snr / (snr + 1)`` for each timestep, with the SNR capped at 1000.

                Args:
                    timesteps: iterable of indices into ``noise_scheduler.all_snr``.
                    noise_scheduler: object carrying the cached ``all_snr`` tensor.
                """
                # one SNR value per entry in ``timesteps``
                snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps])
                # the SNR at timestep 0 is inf, so clamp it to a finite ceiling
                snr = torch.minimum(snr, torch.full_like(snr, 1000))
                return snr / (snr + 1)
         | 
| 90 | 
            +
             | 
| 91 | 
            +
             | 
| 92 | 
            +
            def add_v_prediction_like_loss(loss, timesteps, noise_scheduler, v_pred_like_loss):
                """Add ``loss / scale * v_pred_like_loss`` to ``loss``.

                ``scale`` is the per-timestep factor returned by ``get_snr_scale``.
                """
                snr_scale = get_snr_scale(timesteps, noise_scheduler)
                extra_term = loss / snr_scale * v_pred_like_loss
                return loss + extra_term
         | 
| 97 | 
            +
             | 
| 98 | 
            +
             | 
| 99 | 
            +
            def apply_debiased_estimation(loss, timesteps, noise_scheduler, v_prediction=False):
                """Weight ``loss`` by a function of the per-timestep SNR (capped at 1000).

                The weight is ``1 / sqrt(snr)`` by default and ``1 / (snr + 1)`` when
                ``v_prediction`` is true.

                Args:
                    loss: tensor to be weighted.
                    timesteps: iterable of indices into ``noise_scheduler.all_snr``.
                    noise_scheduler: object carrying the cached ``all_snr`` tensor.
                    v_prediction: selects the ``1 / (snr + 1)`` weight when true.
                """
                # one SNR value per entry in ``timesteps``
                snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps])
                # the SNR at timestep 0 is inf, so clamp it to a finite ceiling
                snr = torch.minimum(snr, torch.full_like(snr, 1000))
                weight = 1 / (snr + 1) if v_prediction else 1 / torch.sqrt(snr)
                return weight * loss
         | 
| 108 | 
            +
             | 
| 109 | 
            +
             | 
| 110 | 
            +
            # TODO: this logic is split between here and train_util; consolidate it into one of the two modules
         | 
| 111 | 
            +
             | 
| 112 | 
            +
             | 
| 113 | 
            +
            def add_custom_train_arguments(parser: argparse.ArgumentParser, support_weighted_captions: bool = True):
                """Register the custom-loss command-line options on ``parser``.

                Adds ``--min_snr_gamma``, ``--scale_v_pred_loss_like_noise_pred``,
                ``--v_pred_like_loss`` and ``--debiased_estimation_loss``; adds
                ``--weighted_captions`` as well unless ``support_weighted_captions`` is false.

                Args:
                    parser: the argument parser to extend in place.
                    support_weighted_captions: whether to register ``--weighted_captions``.
                """
                parser.add_argument(
                    "--min_snr_gamma",
                    type=float,
                    default=None,
                    help="gamma for reducing the weight of high loss timesteps. Lower numbers have stronger effect. 5 is recommended by paper. / 低いタイムステップでの高いlossに対して重みを減らすためのgamma値、低いほど効果が強く、論文では5が推奨",
                )
                parser.add_argument(
                    "--scale_v_pred_loss_like_noise_pred",
                    action="store_true",
                    help="scale v-prediction loss like noise prediction loss / v-prediction lossをnoise prediction lossと同じようにスケーリングする",
                )
                parser.add_argument(
                    "--v_pred_like_loss",
                    type=float,
                    default=None,
                    # the Japanese half of this message previously contained U+FFFD
                    # replacement characters (mis-encoded text); restored here
                    help="add v-prediction like loss multiplied by this value / v-prediction lossをこの値をかけたものをlossに加算する",
                )
                parser.add_argument(
                    "--debiased_estimation_loss",
                    action="store_true",
                    help="debiased estimation loss / debiased estimation loss",
                )
                if support_weighted_captions:
                    parser.add_argument(
                        "--weighted_captions",
                        action="store_true",
                        default=False,
                        help="Enable weighted captions in the standard style (token:1.3). No commas inside parens, or shuffle/dropout may break the decoder. / 「[token]」、「(token)」「(token:1.3)」のような重み付きキャプションを有効にする。カンマを括弧内に入れるとシャッフルやdropoutで重みづけがおかしくなるので注意",
                    )
         | 
| 143 | 
            +
             | 
| 144 | 
            +
             | 
| 145 | 
            +
            # Tokenizer for weighted prompts (verbose mode, so whitespace in the pattern is
            # ignored). Alternatives are tried in order: backslash-escaped brackets and
            # backslashes, opening brackets, an explicit ":<number>)" weight (captured in
            # group 1), closing brackets, runs of ordinary text, and finally a lone colon.
            re_attention = re.compile(
                r"""
            \\\(|
            \\\)|
            \\\[|
            \\]|
            \\\\|
            \\|
            \(|
            \[|
            :([+-]?[.\d]+)\)|
            \)|
            ]|
            [^\\()\[\]:]+|
            :
            """,
                re.X,
            )


            def parse_prompt_attention(text):
                r"""
                Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
                Accepted tokens are:
                  (abc) - increases attention to abc by a multiplier of 1.1
                  (abc:3.12) - increases attention to abc by a multiplier of 3.12
                  [abc] - decreases attention to abc by dividing it by 1.1
                  \( - literal character '('
                  \[ - literal character '['
                  \) - literal character ')'
                  \] - literal character ']'
                  \\ - literal character '\'
                  anything else - just text
                Brackets left unclosed at the end of the string are treated as if they were closed there.
                Adjacent pieces that end up with the same weight are merged.
                >>> parse_prompt_attention('normal text')
                [['normal text', 1.0]]
                >>> parse_prompt_attention('an (important) word')
                [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
                >>> parse_prompt_attention('(unbalanced')
                [['unbalanced', 1.1]]
                >>> parse_prompt_attention(r'\(literal\]')
                [['(literal]', 1.0]]
                >>> parse_prompt_attention('(unnecessary)(parens)')
                [['unnecessaryparens', 1.1]]
                >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
                [['a ', 1.0],
                 ['house', 1.5730000000000004],
                 [' ', 1.1],
                 ['on', 1.0],
                 [' a ', 1.1],
                 ['hill', 0.55],
                 [', sun, ', 1.1],
                 ['sky', 1.4641000000000006],
                 ['.', 1.1]]
                """

                res = []
                # each stack holds the index in ``res`` at which a still-open bracket started
                round_brackets = []
                square_brackets = []

                round_bracket_multiplier = 1.1
                square_bracket_multiplier = 1 / 1.1

                def multiply_range(start_position, multiplier):
                    # scale every piece collected since the bracket was opened
                    for entry in res[start_position:]:
                        entry[1] *= multiplier

                for m in re_attention.finditer(text):
                    # use a dedicated name so the ``text`` argument is not rebound
                    token = m.group(0)
                    weight = m.group(1)

                    if token.startswith("\\"):
                        # escaped character: keep it literally, without the backslash
                        res.append([token[1:], 1.0])
                    elif token == "(":
                        round_brackets.append(len(res))
                    elif token == "[":
                        square_brackets.append(len(res))
                    elif weight is not None and len(round_brackets) > 0:
                        multiply_range(round_brackets.pop(), float(weight))
                    elif token == ")" and len(round_brackets) > 0:
                        multiply_range(round_brackets.pop(), round_bracket_multiplier)
                    elif token == "]" and len(square_brackets) > 0:
                        multiply_range(square_brackets.pop(), square_bracket_multiplier)
                    else:
                        # plain text, or a closer / weight suffix with no matching opener
                        res.append([token, 1.0])

                # brackets that were never closed apply up to the end of the prompt
                for pos in round_brackets:
                    multiply_range(pos, round_bracket_multiplier)

                for pos in square_brackets:
                    multiply_range(pos, square_bracket_multiplier)

                if len(res) == 0:
                    res = [["", 1.0]]

                # merge runs of identical weights
                i = 0
                while i + 1 < len(res):
                    if res[i][1] == res[i + 1][1]:
                        res[i][0] += res[i + 1][0]
                        res.pop(i + 1)
                    else:
                        i += 1

                return res
         | 
| 249 | 
            +
             | 
| 250 | 
            +
             | 
| 251 | 
            +
            def get_prompts_with_weights(tokenizer, prompt: List[str], max_length: int):
                r"""
                Tokenize each prompt and pair every token with the weight of the text piece it came from.

                Each prompt is split by ``parse_prompt_attention``; every piece is tokenized with
                ``tokenizer(piece).input_ids`` and the first and last ids are dropped, so the
                result carries no starting, ending or padding tokens. Prompts longer than
                ``max_length`` tokens are cut to ``max_length`` and a warning is logged once.

                Returns a ``(tokens, weights)`` pair of lists with one inner list per prompt.
                """
                all_tokens = []
                all_weights = []
                any_truncated = False

                for single_prompt in prompt:
                    ids = []
                    id_weights = []
                    for piece, piece_weight in parse_prompt_attention(single_prompt):
                        # drop the first and last ids the tokenizer adds around the piece
                        piece_ids = tokenizer(piece).input_ids[1:-1]
                        ids += piece_ids
                        # every token of the piece inherits the piece's weight
                        id_weights += [piece_weight] * len(piece_ids)
                        # no point tokenizing further once the limit is exceeded
                        if len(ids) > max_length:
                            break

                    if len(ids) > max_length:
                        any_truncated = True
                        ids = ids[:max_length]
                        id_weights = id_weights[:max_length]

                    all_tokens.append(ids)
                    all_weights.append(id_weights)

                if any_truncated:
                    logger.warning("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples")
                return all_tokens, all_weights
         | 
| 284 | 
            +
             | 
| 285 | 
            +
             | 
| 286 | 
            +
            def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, no_boseos_middle=True, chunk_length=77):
                r"""
                Pad token lists (with ``bos`` in front and ``eos`` as filler) and weight lists (with 1.0).

                ``tokens`` and ``weights`` are updated in place and also returned. Every token list
                becomes ``max_length`` long. Weight lists become ``max_length`` long when
                ``no_boseos_middle`` is true; otherwise a 1.0 is inserted before and after each
                ``chunk_length - 2`` slice of weights and the list is padded to
                ``((max_length - 2) // (chunk_length - 2)) * chunk_length`` entries.
                """
                body_per_chunk = chunk_length - 2
                num_chunks = (max_length - 2) // body_per_chunk
                weights_length = max_length if no_boseos_middle else num_chunks * chunk_length

                for idx in range(len(tokens)):
                    token_seq = tokens[idx]
                    tokens[idx] = [bos] + token_seq + [eos] * (max_length - 1 - len(token_seq))

                    weight_seq = weights[idx]
                    if no_boseos_middle:
                        weights[idx] = [1.0] + weight_seq + [1.0] * (max_length - 1 - len(weight_seq))
                        continue

                    if len(weight_seq) == 0:
                        padded = [1.0] * weights_length
                    else:
                        padded = []
                        for chunk_idx in range(num_chunks):
                            start = chunk_idx * body_per_chunk
                            stop = min(len(weight_seq), start + body_per_chunk)
                            # 1.0 for the chunk's starting token, the chunk's own weights,
                            # then 1.0 for the chunk's ending token
                            padded += [1.0, *weight_seq[start:stop], 1.0]
                        padded += [1.0] * (weights_length - len(padded))
                    weights[idx] = padded

                return tokens, weights
         | 
| 309 | 
            +
             | 
| 310 | 
            +
             | 
| 311 | 
            +
            def get_unweighted_text_embeddings(
                tokenizer,
                text_encoder,
                text_input: torch.Tensor,
                chunk_length: int,
                clip_skip: int,
                eos: int,
                pad: int,
                no_boseos_middle: Optional[bool] = True,
            ):
                """
                Encode ``text_input`` with ``text_encoder``, chunk by chunk when it is long.

                When ``(text_input.shape[1] - 2) // (chunk_length - 2)`` is greater than 1, the input
                is cut into overlapping windows of ``chunk_length`` ids, each window gets its first
                and last position overwritten with start/end ids, each window is encoded separately
                and the results are concatenated along dimension 1. Otherwise the whole input is
                encoded in one call.

                Args:
                    tokenizer: not used by this function.
                    text_encoder: callable encoder; with ``clip_skip`` > 1 it must also accept
                        ``output_hidden_states``/``return_dict`` and expose
                        ``text_model.final_layer_norm``.
                    text_input: 2-D tensor of token ids (indexed as ``[row, position]``).
                    chunk_length: window size, including the start and end positions.
                    clip_skip: ``None`` or 1 uses the encoder's first output; otherwise hidden
                        state ``-clip_skip`` is taken and passed through the final layer norm.
                    eos: end-of-sequence token id.
                    pad: padding token id.
                    no_boseos_middle: when true, drop the start/end embeddings at chunk joins.
                """

                def encode(input_ids):
                    # single place for the clip_skip handling used by both code paths
                    if clip_skip is None or clip_skip == 1:
                        return text_encoder(input_ids)[0]
                    enc_out = text_encoder(input_ids, output_hidden_states=True, return_dict=True)
                    hidden = enc_out["hidden_states"][-clip_skip]
                    return text_encoder.text_model.final_layer_norm(hidden)

                body_len = chunk_length - 2
                num_chunks = (text_input.shape[1] - 2) // body_len
                if num_chunks <= 1:
                    return encode(text_input)

                chunk_embeddings = []
                for chunk_idx in range(num_chunks):
                    # window covering this chunk's body plus one position on each side
                    start = chunk_idx * body_len
                    chunk_ids = text_input[:, start : start + body_len + 2].clone()

                    # overwrite the window's head with the id found at the very start of the input
                    chunk_ids[:, 0] = text_input[0, 0]
                    if pad == eos:  # v1
                        chunk_ids[:, -1] = text_input[0, -1]
                    else:  # v2
                        for row in range(len(chunk_ids)):
                            last_id = chunk_ids[row, -1]
                            if last_id != eos and last_id != pad:
                                # the window ends on an ordinary token
                                chunk_ids[row, -1] = eos
                            if chunk_ids[row, 1] == pad:
                                # the window is only BOS followed by PAD
                                chunk_ids[row, 1] = eos

                    embedding = encode(chunk_ids)

                    if no_boseos_middle:
                        if chunk_idx == 0:
                            # first chunk: drop the ending position
                            embedding = embedding[:, :-1]
                        elif chunk_idx == num_chunks - 1:
                            # last chunk: drop the starting position
                            embedding = embedding[:, 1:]
                        else:
                            # middle chunk: drop both ends
                            embedding = embedding[:, 1:-1]

                    chunk_embeddings.append(embedding)

                return torch.concat(chunk_embeddings, axis=1)
         | 
| 371 | 
            +
             | 
| 372 | 
            +
             | 
| 373 | 
            +
            def get_weighted_text_embeddings(
                tokenizer,
                text_encoder,
                prompt: Union[str, List[str]],
                device,
                max_embeddings_multiples: Optional[int] = 3,
                no_boseos_middle: Optional[bool] = False,
                clip_skip=None,
            ):
                r"""
                Encode prompts and scale every token embedding by its per-token prompt weight.

                Token ids and weights come from `get_prompts_with_weights` (e.g. brackets such as
                'A (very beautiful) masterpiece' emphasise the enclosed words). After weighting, each
                prompt's embedding is rescaled so that its mean over the last two axes equals the mean
                it had before weighting.

                Args:
                    tokenizer:
                        Tokenizer providing `model_max_length`, `bos_token_id`, `eos_token_id` and `pad_token_id`.
                    text_encoder:
                        Text encoder forwarded to `get_unweighted_text_embeddings`.
                    prompt (`str` or `List[str]`):
                        The prompt or prompts to encode. A single string is treated as a one-element list.
                    device:
                        Device on which the token and weight tensors are created.
                    max_embeddings_multiples (`int`, *optional*, defaults to `3`):
                        Upper bound on how many encoder-sized chunks a prompt may span. Fewer chunks are
                        used when the longest prompt in the batch does not need that many.
                    no_boseos_middle (`bool`, *optional*, defaults to `False`):
                        Forwarded to `pad_tokens_and_weights` and `get_unweighted_text_embeddings`; controls
                        whether the start/end tokens are kept in the chunks in the middle.
                    clip_skip:
                        Forwarded unchanged to `get_unweighted_text_embeddings`.

                Returns:
                    The weighted, mean-normalised text embeddings tensor.
                """
                # Each chunk holds (model_max_length - 2) prompt tokens; the 2 extra slots are BOS and EOS.
                chunk_capacity = tokenizer.model_max_length - 2
                token_budget = chunk_capacity * max_embeddings_multiples + 2
                prompts = [prompt] if isinstance(prompt, str) else prompt

                prompt_tokens, prompt_weights = get_prompts_with_weights(tokenizer, prompts, token_budget - 2)

                # Use only as many chunks as the longest prompt needs, but always at least one.
                longest = max(len(tokens) for tokens in prompt_tokens)
                chunks_needed = (longest - 1) // chunk_capacity + 1
                chunk_count = max(1, min(max_embeddings_multiples, chunks_needed))
                padded_length = chunk_capacity * chunk_count + 2

                # Pad token ids and weights to the common length.
                bos = tokenizer.bos_token_id
                eos = tokenizer.eos_token_id
                pad = tokenizer.pad_token_id
                prompt_tokens, prompt_weights = pad_tokens_and_weights(
                    prompt_tokens,
                    prompt_weights,
                    padded_length,
                    bos,
                    eos,
                    no_boseos_middle=no_boseos_middle,
                    chunk_length=tokenizer.model_max_length,
                )
                token_tensor = torch.tensor(prompt_tokens, dtype=torch.long, device=device)

                # Encode without weights first.
                text_embeddings = get_unweighted_text_embeddings(
                    tokenizer,
                    text_encoder,
                    token_tensor,
                    tokenizer.model_max_length,
                    clip_skip,
                    eos,
                    pad,
                    no_boseos_middle=no_boseos_middle,
                )
                weight_tensor = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=device)

                # Apply the weights, then rescale so the per-prompt mean is unchanged.
                # Means are computed in float32 and cast back to the embedding dtype.
                mean_before = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
                weighted = text_embeddings * weight_tensor.unsqueeze(-1)
                mean_after = weighted.float().mean(axis=[-2, -1]).to(weighted.dtype)
                rescale = (mean_before / mean_after).unsqueeze(-1).unsqueeze(-1)

                return weighted * rescale
         | 
| 453 | 
            +
             | 
| 454 | 
            +
             | 
| 455 | 
            +
            # https://wandb.ai/johnowhitaker/multires_noise/reports/Multi-Resolution-Noise-for-Diffusion-Model-Training--VmlldzozNjYyOTU2
         | 
| 456 | 
            +
            def pyramid_noise_like(noise, device, iterations=6, discount=0.4):
                """
                Add progressively coarser random noise to `noise` and rescale the result by its std.

                NOTE: `noise` is accumulated into in place (`+=`); the returned tensor is a new,
                std-normalised tensor.

                Args:
                    noise: 4D tensor; its last two dimensions are the target size of the upsampler.
                    device: device the upsampler and the random tensors are moved to.
                    iterations: maximum number of pyramid levels.
                    discount: level `i` is weighted by `discount**i`.

                Returns:
                    `noise / noise.std()` after all levels have been added.
                """
                batch, channels, full_w, full_h = noise.shape
                upsampler = torch.nn.Upsample(size=(full_w, full_h), mode="bilinear").to(device)

                for level in range(iterations):
                    # Random per-level scale factor in [2, 4) rather than always halving.
                    ratio = random.random() * 2 + 2
                    shrink = ratio**level
                    small_w = max(1, int(full_w / shrink))
                    small_h = max(1, int(full_h / shrink))

                    coarse = torch.randn(batch, channels, small_w, small_h).to(device)
                    noise += upsampler(coarse) * discount**level

                    # Stop once either side has collapsed to a single pixel.
                    if small_w == 1 or small_h == 1:
                        break

                # Scale back to roughly unit variance.
                return noise / noise.std()
         | 
| 466 | 
            +
             | 
| 467 | 
            +
             | 
| 468 | 
            +
            # https://www.crosslabs.org//blog/diffusion-with-offset-noise
         | 
| 469 | 
            +
            def apply_noise_offset(latents, noise, noise_offset, adaptive_noise_scale):
                """
                Add a random per-sample, per-channel constant offset to `noise`.

                Args:
                    latents: tensor indexed as (batch_size, channels, height, width); supplies the
                        batch/channel sizes, the device, and the channel means for the adaptive term.
                    noise: noise tensor the offset is added to.
                    noise_offset: offset strength. When `None`, `noise` is returned unchanged.
                    adaptive_noise_scale: when not `None`, the absolute per-channel latent mean times
                        this scale is added to `noise_offset`, and the sum is clamped to be >= 0.

                Returns:
                    `noise` itself when `noise_offset` is `None`, otherwise a new offset tensor.
                """
                if noise_offset is None:
                    return noise

                offset = noise_offset
                if adaptive_noise_scale is not None:
                    # Absolute mean value of every channel, kept broadcastable against the latents.
                    channel_abs_mean = torch.abs(latents.mean(dim=(2, 3), keepdim=True))
                    # Clamp because a negative adaptive scale could push the offset below zero.
                    offset = torch.clamp(offset + adaptive_noise_scale * channel_abs_mean, 0.0, None)

                # One random value per (sample, channel), broadcast over the spatial dimensions.
                per_channel_shift = torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)
                return noise + offset * per_channel_shift
         | 
| 483 | 
            +
             | 
| 484 | 
            +
             | 
| 485 | 
            +
            def apply_masked_loss(loss, batch):
                """
                Multiply `loss` by a mask taken from the batch, if the batch carries one.

                The mask source is chosen in this order:
                  1. `batch["conditioning_images"]`: first (R) channel, mapped from [-1, 1] to [0, 1].
                  2. `batch["alpha_masks"]` (when present and not `None`): used as-is (0 to 1) with a
                     channel dimension added.
                If neither is available, `loss` is returned untouched.

                The mask is resized to the spatial size of `loss` with "area" interpolation.
                """
                if "conditioning_images" in batch:
                    # Conditioning image is -1 to 1; use the R channel and convert it to 0 to 1.
                    red_channel = batch["conditioning_images"].to(dtype=loss.dtype)[:, 0].unsqueeze(1)
                    mask_image = red_channel / 2 + 0.5
                elif "alpha_masks" in batch and batch["alpha_masks"] is not None:
                    # Alpha mask is already 0 to 1; only the channel dimension is missing.
                    mask_image = batch["alpha_masks"].to(dtype=loss.dtype).unsqueeze(1)
                else:
                    return loss

                # Resize the mask to the same spatial size as the loss before applying it.
                resized_mask = torch.nn.functional.interpolate(mask_image, size=loss.shape[2:], mode="area")
                return loss * resized_mask
         | 
| 502 | 
            +
             | 
| 503 | 
            +
             | 
| 504 | 
            +
            """
         | 
| 505 | 
            +
            ##########################################
         | 
| 506 | 
            +
            # Perlin Noise
         | 
| 507 | 
            +
            def rand_perlin_2d(device, shape, res, fade=lambda t: 6 * t**5 - 15 * t**4 + 10 * t**3):
         | 
| 508 | 
            +
                delta = (res[0] / shape[0], res[1] / shape[1])
         | 
| 509 | 
            +
                d = (shape[0] // res[0], shape[1] // res[1])
         | 
| 510 | 
            +
             | 
| 511 | 
            +
                grid = (
         | 
| 512 | 
            +
                    torch.stack(
         | 
| 513 | 
            +
                        torch.meshgrid(torch.arange(0, res[0], delta[0], device=device), torch.arange(0, res[1], delta[1], device=device)),
         | 
| 514 | 
            +
                        dim=-1,
         | 
| 515 | 
            +
                    )
         | 
| 516 | 
            +
                    % 1
         | 
| 517 | 
            +
                )
         | 
| 518 | 
            +
                angles = 2 * torch.pi * torch.rand(res[0] + 1, res[1] + 1, device=device)
         | 
| 519 | 
            +
                gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim=-1)
         | 
| 520 | 
            +
             | 
| 521 | 
            +
                tile_grads = (
         | 
| 522 | 
            +
                    lambda slice1, slice2: gradients[slice1[0] : slice1[1], slice2[0] : slice2[1]]
         | 
| 523 | 
            +
                    .repeat_interleave(d[0], 0)
         | 
| 524 | 
            +
                    .repeat_interleave(d[1], 1)
         | 
| 525 | 
            +
                )
         | 
| 526 | 
            +
                dot = lambda grad, shift: (
         | 
| 527 | 
            +
                    torch.stack((grid[: shape[0], : shape[1], 0] + shift[0], grid[: shape[0], : shape[1], 1] + shift[1]), dim=-1)
         | 
| 528 | 
            +
                    * grad[: shape[0], : shape[1]]
         | 
| 529 | 
            +
                ).sum(dim=-1)
         | 
| 530 | 
            +
             | 
| 531 | 
            +
                n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0])
         | 
| 532 | 
            +
                n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0])
         | 
| 533 | 
            +
                n01 = dot(tile_grads([0, -1], [1, None]), [0, -1])
         | 
| 534 | 
            +
                n11 = dot(tile_grads([1, None], [1, None]), [-1, -1])
         | 
| 535 | 
            +
                t = fade(grid[: shape[0], : shape[1]])
         | 
| 536 | 
            +
                return 1.414 * torch.lerp(torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1])
         | 
| 537 | 
            +
             | 
| 538 | 
            +
             | 
| 539 | 
            +
            def rand_perlin_2d_octaves(device, shape, res, octaves=1, persistence=0.5):
         | 
| 540 | 
            +
                noise = torch.zeros(shape, device=device)
         | 
| 541 | 
            +
                frequency = 1
         | 
| 542 | 
            +
                amplitude = 1
         | 
| 543 | 
            +
                for _ in range(octaves):
         | 
| 544 | 
            +
                    noise += amplitude * rand_perlin_2d(device, shape, (frequency * res[0], frequency * res[1]))
         | 
| 545 | 
            +
                    frequency *= 2
         | 
| 546 | 
            +
                    amplitude *= persistence
         | 
| 547 | 
            +
                return noise
         | 
| 548 | 
            +
             | 
| 549 | 
            +
             | 
| 550 | 
            +
            def perlin_noise(noise, device, octaves):
         | 
| 551 | 
            +
                _, c, w, h = noise.shape
         | 
| 552 | 
            +
                perlin = lambda: rand_perlin_2d_octaves(device, (w, h), (4, 4), octaves)
         | 
| 553 | 
            +
                noise_perlin = []
         | 
| 554 | 
            +
                for _ in range(c):
         | 
| 555 | 
            +
                    noise_perlin.append(perlin())
         | 
| 556 | 
            +
                noise_perlin = torch.stack(noise_perlin).unsqueeze(0)   # (1, c, w, h)
         | 
| 557 | 
            +
                noise += noise_perlin # broadcast for each batch
         | 
| 558 | 
            +
                return noise / noise.std()  # Scaled back to roughly unit variance
         | 
| 559 | 
            +
            """
         | 
    	
        library/deepspeed_utils.py
    ADDED
    
    | @@ -0,0 +1,139 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import os
         | 
| 2 | 
            +
            import argparse
         | 
| 3 | 
            +
            import torch
         | 
| 4 | 
            +
            from accelerate import DeepSpeedPlugin, Accelerator
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            from .utils import setup_logging
         | 
| 7 | 
            +
             | 
| 8 | 
            +
            setup_logging()
         | 
| 9 | 
            +
            import logging
         | 
| 10 | 
            +
             | 
| 11 | 
            +
            logger = logging.getLogger(__name__)
         | 
| 12 | 
            +
             | 
| 13 | 
            +
             | 
| 14 | 
            +
            def add_deepspeed_arguments(parser: argparse.ArgumentParser):
                """
                Register the DeepSpeed-related command line options on `parser`.

                See https://huggingface.co/docs/accelerate/usage_guides/deepspeed for the meaning of
                the individual options.

                Args:
                    parser: the argument parser to extend in place.
                """
                parser.add_argument("--deepspeed", action="store_true", help="enable deepspeed training")
                parser.add_argument("--zero_stage", type=int, default=2, choices=[0, 1, 2, 3], help="Possible options are 0,1,2,3.")
                parser.add_argument(
                    "--offload_optimizer_device",
                    type=str,
                    default=None,
                    choices=[None, "cpu", "nvme"],
                    help="Possible options are none|cpu|nvme. Only applicable with ZeRO Stages 2 and 3.",
                )
                parser.add_argument(
                    "--offload_optimizer_nvme_path",
                    type=str,
                    default=None,
                    help="Possible options are /nvme|/local_nvme. Only applicable with ZeRO Stage 3.",
                )
                parser.add_argument(
                    "--offload_param_device",
                    type=str,
                    default=None,
                    choices=[None, "cpu", "nvme"],
                    help="Possible options are none|cpu|nvme. Only applicable with ZeRO Stage 3.",
                )
                parser.add_argument(
                    "--offload_param_nvme_path",
                    type=str,
                    default=None,
                    help="Possible options are /nvme|/local_nvme. Only applicable with ZeRO Stage 3.",
                )
                parser.add_argument(
                    "--zero3_init_flag",
                    action="store_true",
                    # The trailing space matters: the two literals are concatenated into one sentence pair.
                    help="Flag to indicate whether to enable `deepspeed.zero.Init` for constructing massive models. "
                    "Only applicable with ZeRO Stage-3.",
                )
                parser.add_argument(
                    "--zero3_save_16bit_model",
                    action="store_true",
                    help="Flag to indicate whether to save 16-bit model. Only applicable with ZeRO Stage-3.",
                )
                parser.add_argument(
                    "--fp16_master_weights_and_gradients",
                    action="store_true",
                    help="fp16_master_and_gradients requires optimizer to support keeping fp16 master and gradients while keeping the optimizer states in fp32.",
                )
         | 
| 60 | 
            +
             | 
| 61 | 
            +
             | 
| 62 | 
            +
            def prepare_deepspeed_args(args: argparse.Namespace):
                """
                Adjust `args` in place for DeepSpeed runs; does nothing when `args.deepspeed` is falsy.

                With DeepSpeed enabled the data loader is limited to a single worker to avoid
                "RuntimeError: DataLoader worker exited unexpectedly with exit code 1."
                """
                if args.deepspeed:
                    args.max_data_loader_n_workers = 1
         | 
| 68 | 
            +
             | 
| 69 | 
            +
             | 
| 70 | 
            +
            def prepare_deepspeed_plugin(args: argparse.Namespace):
                """
                Build an Accelerate `DeepSpeedPlugin` from the parsed command line arguments.

                Args:
                    args: parsed arguments; reads the options registered by `add_deepspeed_arguments`
                        plus `gradient_accumulation_steps`, `max_grad_norm`, `train_batch_size`,
                        `mixed_precision` and `full_fp16`.

                Returns:
                    `None` when `args.deepspeed` is falsy, otherwise the configured plugin.

                Raises:
                    SystemExit: with exit code 1 when the `deepspeed` package cannot be imported.
                """
                if not args.deepspeed:
                    return None

                try:
                    import deepspeed
                except ImportError as e:
                    logger.error(
                        "deepspeed is not installed. please install deepspeed in your environment with following command. DS_BUILD_OPS=0 pip install deepspeed"
                    )
                    # The `exit` builtin is injected by the `site` module and may be missing;
                    # raising SystemExit directly has the same effect without that dependency.
                    raise SystemExit(1) from e

                deepspeed_plugin = DeepSpeedPlugin(
                    zero_stage=args.zero_stage,
                    gradient_accumulation_steps=args.gradient_accumulation_steps,
                    gradient_clipping=args.max_grad_norm,
                    offload_optimizer_device=args.offload_optimizer_device,
                    offload_optimizer_nvme_path=args.offload_optimizer_nvme_path,
                    offload_param_device=args.offload_param_device,
                    offload_param_nvme_path=args.offload_param_nvme_path,
                    zero3_init_flag=args.zero3_init_flag,
                    zero3_save_16bit_model=args.zero3_save_16bit_model,
                )

                # WORLD_SIZE is set by distributed launchers; fall back to a single process
                # instead of failing with a bare KeyError when it is absent.
                world_size = int(os.environ.get("WORLD_SIZE", "1"))
                deepspeed_plugin.deepspeed_config["train_micro_batch_size_per_gpu"] = args.train_batch_size
                deepspeed_plugin.deepspeed_config["train_batch_size"] = (
                    args.train_batch_size * args.gradient_accumulation_steps * world_size
                )
                deepspeed_plugin.set_mixed_precision(args.mixed_precision)
                if args.mixed_precision.lower() == "fp16":
                    deepspeed_plugin.deepspeed_config["fp16"]["initial_scale_power"] = 0  # preventing overflow.
                if args.full_fp16 or args.fp16_master_weights_and_gradients:
                    if args.offload_optimizer_device == "cpu" and args.zero_stage == 2:
                        deepspeed_plugin.deepspeed_config["fp16"]["fp16_master_weights_and_grads"] = True
                        logger.info("[DeepSpeed] full fp16 enable.")
                    else:
                        # The requested option is being ignored, so surface it as a warning.
                        logger.warning(
                            "[DeepSpeed]full fp16, fp16_master_weights_and_grads currently only supported using ZeRO-Offload with DeepSpeedCPUAdam on ZeRO-2 stage."
                        )

                if args.offload_optimizer_device is not None:
                    # Build the cpu_adam op up front rather than on first use.
                    logger.info("[DeepSpeed] start to manually build cpu_adam.")
                    deepspeed.ops.op_builder.CPUAdamBuilder().load()
                    logger.info("[DeepSpeed] building cpu_adam done.")

                return deepspeed_plugin
         | 
| 115 | 
            +
             | 
| 116 | 
            +
             | 
| 117 | 
            +
            # Accelerate library does not support multiple models for deepspeed. So, we need to wrap multiple models into a single model.
         | 
| 118 | 
            +
            def prepare_deepspeed_model(args: argparse.Namespace, **models):
                """
                Bundle several models into one `torch.nn.Module` so they can be prepared together.

                Args:
                    args: parsed arguments (not read by this function).
                    **models: name -> model. `None` values are dropped; a `list` value is wrapped in
                        a `torch.nn.ModuleList`; every remaining value must be a `torch.nn.Module`.

                Returns:
                    A wrapper module whose `get_models()` returns a `torch.nn.ModuleDict` keyed by
                    the given names.
                """

                class DeepSpeedWrapper(torch.nn.Module):
                    def __init__(self, **kw_models) -> None:
                        super().__init__()
                        self.models = torch.nn.ModuleDict()

                        for key, model in kw_models.items():
                            # Lists of modules are registered as a single ModuleList entry.
                            wrapped = torch.nn.ModuleList(model) if isinstance(model, list) else model
                            assert isinstance(
                                wrapped, torch.nn.Module
                            ), f"model must be an instance of torch.nn.Module, but got {key} is {type(wrapped)}"
                            self.models[key] = wrapped

                    def get_models(self):
                        return self.models

                # Drop the entries that were passed as None before wrapping.
                present = {name: module for name, module in models.items() if module is not None}
                return DeepSpeedWrapper(**present)
         | 
    	
        library/device_utils.py
    ADDED
    
    | @@ -0,0 +1,84 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import functools
         | 
| 2 | 
            +
            import gc
         | 
| 3 | 
            +
             | 
| 4 | 
            +
            import torch
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            # Probe the available accelerator backends once at import time.
            # Any failure while probing is treated as "backend not available".
            try:
                HAS_CUDA = torch.cuda.is_available()
            except Exception:
                HAS_CUDA = False

            try:
                HAS_MPS = torch.backends.mps.is_available()
            except Exception:
                HAS_MPS = False

            try:
                # The import itself is the availability check for IPEX; the name is otherwise unused here.
                import intel_extension_for_pytorch as ipex  # noqa

                HAS_XPU = torch.xpu.is_available()
            except Exception:
                HAS_XPU = False
         | 
| 22 | 
            +
             | 
| 23 | 
            +
             | 
| 24 | 
            +
            def clean_memory():
                """
                Run the garbage collector, then empty the cache of every backend flagged as available.
                """
                gc.collect()

                # Lambdas keep the backend attribute lookups lazy: they only run for available backends.
                cache_releasers = (
                    (HAS_CUDA, lambda: torch.cuda.empty_cache()),
                    (HAS_XPU, lambda: torch.xpu.empty_cache()),
                    (HAS_MPS, lambda: torch.mps.empty_cache()),
                )
                for available, release in cache_releasers:
                    if available:
                        release()
         | 
| 32 | 
            +
             | 
| 33 | 
            +
             | 
| 34 | 
            +
            def clean_memory_on_device(device: torch.device):
                r"""
                Run the garbage collector and empty the cache of the backend `device` belongs to.

                Intended to be called from training scripts. Devices other than cuda, xpu and mps
                only trigger garbage collection.
                """
                gc.collect()

                # The device may be "cuda" or "cuda:0", so compare its type rather than the device itself.
                backend = device.type
                if backend == "cuda":
                    torch.cuda.empty_cache()
                elif backend == "xpu":
                    torch.xpu.empty_cache()
                elif backend == "mps":
                    torch.mps.empty_cache()
         | 
| 47 | 
            +
             | 
| 48 | 
            +
             | 
| 49 | 
            +
            @functools.lru_cache(maxsize=None)
            def get_preferred_device() -> torch.device:
                r"""
                Return the preferred device, checked in the order cuda, xpu, mps, then cpu.

                The result is cached, so the selection and the print happen only on the first call.
                Do not call this function from training scripts. Use accelerator.device instead.
                """
                if HAS_CUDA:
                    backend = "cuda"
                elif HAS_XPU:
                    backend = "xpu"
                elif HAS_MPS:
                    backend = "mps"
                else:
                    backend = "cpu"

                device = torch.device(backend)
                print(f"get_preferred_device() -> {device}")
                return device
         | 
| 64 | 
            +
             | 
| 65 | 
            +
             | 
| 66 | 
            +
            def init_ipex():
                """
                Apply IPEX to CUDA hijacks using `library.ipex.ipex_init`.

                This function should run right after importing torch and before doing anything else.

                If IPEX is not available (HAS_XPU is false), this function does nothing.
                Any failure is reported with print() instead of being raised.
                """
                try:
                    if not HAS_XPU:
                        return

                    # imported lazily so that machines without XPU never import library.ipex
                    from library.ipex import ipex_init

                    # ipex_init() reports its outcome as an (is_initialized, error_message) pair
                    ok, message = ipex_init()
                    if not ok:
                        print("failed to initialize ipex:", message)
                except Exception as e:
                    print("failed to initialize ipex:", e)
         | 
    	
        library/huggingface_util.py
    ADDED
    
    | @@ -0,0 +1,84 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            from typing import Union, BinaryIO
         | 
| 2 | 
            +
            from huggingface_hub import HfApi
         | 
| 3 | 
            +
            from pathlib import Path
         | 
| 4 | 
            +
            import argparse
         | 
| 5 | 
            +
            import os
         | 
| 6 | 
            +
            from library.utils import fire_in_thread
         | 
| 7 | 
            +
            from library.utils import setup_logging
         | 
| 8 | 
            +
            setup_logging()
         | 
| 9 | 
            +
            import logging
         | 
| 10 | 
            +
            logger = logging.getLogger(__name__)
         | 
| 11 | 
            +
             | 
| 12 | 
            +
            def exists_repo(repo_id: str, repo_type: str, revision: str = "main", token: str = None):
                """
                Return True when `api.repo_info` succeeds for the given repo, False otherwise.

                Args:
                    repo_id: repository id passed to `HfApi.repo_info`.
                    repo_type: repository type passed to `HfApi.repo_info`.
                    revision: revision to query, "main" by default.
                    token: access token used to build the `HfApi` client (may be None).

                Returns:
                    bool: False for any `Exception` raised by the lookup, so the caller cannot
                    distinguish "repo missing" from other failures of the call.
                """
                api = HfApi(token=token)
                try:
                    api.repo_info(repo_id=repo_id, revision=revision, repo_type=repo_type)
                except Exception:
                    # `Exception` rather than a bare `except:` so that KeyboardInterrupt and
                    # SystemExit still propagate instead of being reported as "repo does not exist"
                    return False
                return True
         | 
| 21 | 
            +
             | 
| 22 | 
            +
             | 
| 23 | 
            +
            def upload(
                args: argparse.Namespace,
                src: Union[str, Path, bytes, BinaryIO],
                dest_suffix: str = "",
                force_sync_upload: bool = False,
            ):
                """
                Upload a file or a folder to the HuggingFace repository described by `args`.

                The repository is created first when `exists_repo` reports it missing. Failures of
                repo creation and of the upload itself are logged and not raised (best effort).

                Args:
                    args: namespace providing `huggingface_repo_id`, `huggingface_repo_type`,
                        `huggingface_token`, `huggingface_path_in_repo`,
                        `huggingface_repo_visibility` and `async_upload`.
                    src: what to upload. A `str` or `Path` pointing at a directory is sent with
                        `upload_folder`; everything else is sent with `upload_file`.
                    dest_suffix: appended to `args.huggingface_path_in_repo` when that is not None.
                    force_sync_upload: when True, upload in the calling thread even if
                        `args.async_upload` is set.
                """
                repo_id = args.huggingface_repo_id
                repo_type = args.huggingface_repo_type
                token = args.huggingface_token
                path_in_repo = args.huggingface_path_in_repo + dest_suffix if args.huggingface_path_in_repo is not None else None
                # anything other than an explicit "public" (including None) creates a private repo
                private = args.huggingface_repo_visibility != "public"
                api = HfApi(token=token)
                if not exists_repo(repo_id=repo_id, repo_type=repo_type, token=token):
                    try:
                        api.create_repo(repo_id=repo_id, repo_type=repo_type, private=private)
                    except Exception as e:  # RepositoryNotFoundError has been observed; catch broadly in case there are others
                        logger.error("===========================================")
                        logger.error(f"failed to create HuggingFace repo / HuggingFaceのリポジトリの作成に失敗しました : {e}")
                        logger.error("===========================================")

                # isinstance instead of `type(src) == str` so that str subclasses are handled as paths too
                is_folder = (isinstance(src, str) and os.path.isdir(src)) or (isinstance(src, Path) and src.is_dir())

                def uploader():
                    try:
                        if is_folder:
                            api.upload_folder(
                                repo_id=repo_id,
                                repo_type=repo_type,
                                folder_path=src,
                                path_in_repo=path_in_repo,
                            )
                        else:
                            api.upload_file(
                                repo_id=repo_id,
                                repo_type=repo_type,
                                path_or_fileobj=src,
                                path_in_repo=path_in_repo,
                            )
                    except Exception as e:  # RuntimeError has been confirmed; catch broadly in case there are others
                        logger.error("===========================================")
                        logger.error(f"failed to upload to HuggingFace / HuggingFaceへのアップロードに失敗しました : {e}")
                        logger.error("===========================================")

                if args.async_upload and not force_sync_upload:
                    fire_in_thread(uploader)
                else:
                    uploader()
         | 
| 70 | 
            +
             | 
| 71 | 
            +
             | 
| 72 | 
            +
            def list_dir(
                repo_id: str,
                subfolder: str,
                repo_type: str,
                revision: str = "main",
                token: str = None,
            ):
                """
                List the entries of a repo whose `rfilename` starts with `subfolder`.

                The entries are taken from `repo_info(...).siblings` and returned as a list,
                in the order the API reported them.
                """
                info = HfApi(token=token).repo_info(repo_id=repo_id, revision=revision, repo_type=repo_type)

                matching = []
                for sibling in info.siblings:
                    if sibling.rfilename.startswith(subfolder):
                        matching.append(sibling)
                return matching
         | 
    	
        library/hypernetwork.py
    ADDED
    
    | @@ -0,0 +1,223 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import torch
         | 
| 2 | 
            +
            import torch.nn.functional as F
         | 
| 3 | 
            +
            from diffusers.models.attention_processor import (
         | 
| 4 | 
            +
                Attention,
         | 
| 5 | 
            +
                AttnProcessor2_0,
         | 
| 6 | 
            +
                SlicedAttnProcessor,
         | 
| 7 | 
            +
                XFormersAttnProcessor
         | 
| 8 | 
            +
            )
         | 
| 9 | 
            +
             | 
| 10 | 
            +
            try:
         | 
| 11 | 
            +
                import xformers.ops
         | 
| 12 | 
            +
            except:
         | 
| 13 | 
            +
                xformers = None
         | 
| 14 | 
            +
             | 
| 15 | 
            +
             | 
| 16 | 
            +
            loaded_networks = []
         | 
| 17 | 
            +
             | 
| 18 | 
            +
             | 
| 19 | 
            +
            def apply_single_hypernetwork(hypernetwork, hidden_states, encoder_hidden_states):
                """
                Run one hypernetwork's `forward` on the two inputs and return its
                two results as a `(context_k, context_v)` tuple.
                """
                k_ctx, v_ctx = hypernetwork.forward(hidden_states, encoder_hidden_states)
                return k_ctx, v_ctx
         | 
| 24 | 
            +
             | 
| 25 | 
            +
             | 
| 26 | 
            +
            def apply_hypernetworks(context_k, context_v, layer=None):
                """
                Pass the key/value contexts through every hypernetwork in `loaded_networks`, in order.

                When no hypernetwork is loaded, `context_v` is returned for BOTH outputs. The
                attention forwards in this file call this with (hidden_states, encoder_hidden_states)
                and feed the results to `attn.to_k` / `attn.to_v`, so in that case key and value are
                both projected from `encoder_hidden_states`. Do not "fix" this to return `context_k`.

                `layer` is accepted but not used.
                """
                if not loaded_networks:
                    return context_v, context_v

                for network in loaded_networks:
                    context_k, context_v = network.forward(context_k, context_v)

                # bring the value context to the key context's dtype
                target_dtype = context_k.dtype
                context_k = context_k.to(dtype=target_dtype)
                context_v = context_v.to(dtype=target_dtype)
                return context_k, context_v
         | 
| 36 | 
            +
             | 
| 37 | 
            +
             | 
| 38 | 
            +
             | 
| 39 | 
            +
            def xformers_forward(
                self: XFormersAttnProcessor,
                attn: Attention,
                hidden_states: torch.Tensor,
                encoder_hidden_states: torch.Tensor = None,
                attention_mask: torch.Tensor = None,
            ):
                """
                Replacement `__call__` for `XFormersAttnProcessor` that routes the key/value
                contexts through `apply_hypernetworks` before the k/v projections.

                With `encoder_hidden_states=None` the layer attends over `hidden_states` itself.
                Returns the attention output after `attn.to_out[0]` and `attn.to_out[1]`.
                """
                self_attention = encoder_hidden_states is None
                shape_source = hidden_states if self_attention else encoder_hidden_states
                batch_size, sequence_length, _ = shape_source.shape

                attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

                query = attn.to_q(hidden_states)

                if self_attention:
                    encoder_hidden_states = hidden_states
                elif attn.norm_cross:
                    encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

                context_k, context_v = apply_hypernetworks(hidden_states, encoder_hidden_states)

                key = attn.to_k(context_k)
                value = attn.to_v(context_v)

                # move the heads into the batch dimension and make the tensors contiguous
                query, key, value = [attn.head_to_batch_dim(t).contiguous() for t in (query, key, value)]

                attended = xformers.ops.memory_efficient_attention(
                    query,
                    key,
                    value,
                    attn_bias=attention_mask,
                    op=self.attention_op,
                    scale=attn.scale,
                )
                attended = attn.batch_to_head_dim(attended.to(query.dtype))

                # to_out[0] is the linear projection, to_out[1] the dropout
                projection, dropout = attn.to_out[0], attn.to_out[1]
                return dropout(projection(attended))
         | 
| 88 | 
            +
             | 
| 89 | 
            +
             | 
| 90 | 
            +
            def sliced_attn_forward(
                self: SlicedAttnProcessor,
                attn: Attention,
                hidden_states: torch.Tensor,
                encoder_hidden_states: torch.Tensor = None,
                attention_mask: torch.Tensor = None,
            ):
                """
                Replacement `__call__` for `SlicedAttnProcessor` that routes the key/value
                contexts through `apply_hypernetworks` before the k/v projections.

                Attention is computed in chunks of `self.slice_size` rows along the
                (batch * heads) dimension and written into a preallocated output tensor.
                With `encoder_hidden_states=None` the layer attends over `hidden_states` itself.
                Returns the attention output after `attn.to_out[0]` and `attn.to_out[1]`.
                """
                batch_size, sequence_length, _ = (
                    hidden_states.shape
                    if encoder_hidden_states is None
                    else encoder_hidden_states.shape
                )
                attention_mask = attn.prepare_attention_mask(
                    attention_mask, sequence_length, batch_size
                )

                query = attn.to_q(hidden_states)
                dim = query.shape[-1]
                query = attn.head_to_batch_dim(query)

                if encoder_hidden_states is None:
                    encoder_hidden_states = hidden_states
                elif attn.norm_cross:
                    encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

                context_k, context_v = apply_hypernetworks(hidden_states, encoder_hidden_states)

                key = attn.to_k(context_k)
                value = attn.to_v(context_v)
                key = attn.head_to_batch_dim(key)
                value = attn.head_to_batch_dim(value)

                batch_size_attention, query_tokens, _ = query.shape
                hidden_states = torch.zeros(
                    (batch_size_attention, query_tokens, dim // attn.heads),
                    device=query.device,
                    dtype=query.dtype,
                )

                # Step through the batch with a stride of slice_size. The previous
                # `range(batch_size_attention // self.slice_size)` dropped the trailing partial
                # slice when the batch was not a multiple of slice_size, leaving those rows as zeros.
                # Slicing past the end is safe: the last chunk is simply shorter.
                for start_idx in range(0, batch_size_attention, self.slice_size):
                    end_idx = start_idx + self.slice_size

                    query_slice = query[start_idx:end_idx]
                    key_slice = key[start_idx:end_idx]
                    attn_mask_slice = (
                        attention_mask[start_idx:end_idx] if attention_mask is not None else None
                    )

                    attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)

                    attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])

                    hidden_states[start_idx:end_idx] = attn_slice

                hidden_states = attn.batch_to_head_dim(hidden_states)

                # linear proj
                hidden_states = attn.to_out[0](hidden_states)
                # dropout
                hidden_states = attn.to_out[1](hidden_states)

                return hidden_states
         | 
| 153 | 
            +
             | 
| 154 | 
            +
             | 
| 155 | 
            +
            def v2_0_forward(
                self: AttnProcessor2_0,
                attn: Attention,
                hidden_states,
                encoder_hidden_states=None,
                attention_mask=None,
            ):
                """
                Replacement `__call__` for `AttnProcessor2_0` that routes the key/value
                contexts through `apply_hypernetworks` before the k/v projections and then
                uses `F.scaled_dot_product_attention`.

                With `encoder_hidden_states=None` the layer attends over `hidden_states` itself.
                Returns the attention output after `attn.to_out[0]` and `attn.to_out[1]`.
                """
                self_attention = encoder_hidden_states is None
                shape_source = hidden_states if self_attention else encoder_hidden_states
                batch_size, sequence_length, _ = shape_source.shape
                inner_dim = hidden_states.shape[-1]

                if attention_mask is not None:
                    attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
                    # scaled_dot_product_attention expects attention_mask shape to be
                    # (batch, heads, source_length, target_length)
                    attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

                query = attn.to_q(hidden_states)

                if self_attention:
                    encoder_hidden_states = hidden_states
                elif attn.norm_cross:
                    encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

                context_k, context_v = apply_hypernetworks(hidden_states, encoder_hidden_states)

                key = attn.to_k(context_k)
                value = attn.to_v(context_v)

                head_dim = inner_dim // attn.heads

                # (batch, tokens, heads * head_dim) -> (batch, heads, tokens, head_dim)
                query, key, value = [
                    t.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) for t in (query, key, value)
                ]

                # the output of sdp = (batch, num_heads, seq_len, head_dim)
                # TODO: add support for attn.scale when we move to Torch 2.1
                attended = F.scaled_dot_product_attention(
                    query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
                )

                attended = attended.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
                attended = attended.to(query.dtype)

                # to_out[0] is the linear projection, to_out[1] the dropout
                projection, dropout = attn.to_out[0], attn.to_out[1]
                return dropout(projection(attended))
         | 
| 212 | 
            +
             | 
| 213 | 
            +
             | 
| 214 | 
            +
            def replace_attentions_for_hypernetwork():
                """
                Monkey-patch the `__call__` of three diffusers attention processors with the
                hypernetwork-aware forwards defined in this module.
                """
                import diffusers.models.attention_processor as processors

                processors.XFormersAttnProcessor.__call__ = xformers_forward
                processors.SlicedAttnProcessor.__call__ = sliced_attn_forward
                processors.AttnProcessor2_0.__call__ = v2_0_forward
         | 
    	
        library/ipex/__init__.py
    ADDED
    
    | @@ -0,0 +1,180 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import os
         | 
| 2 | 
            +
            import sys
         | 
| 3 | 
            +
            import contextlib
         | 
| 4 | 
            +
            import torch
         | 
| 5 | 
            +
            import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
         | 
| 6 | 
            +
            from .hijacks import ipex_hijacks
         | 
| 7 | 
            +
             | 
| 8 | 
            +
            # pylint: disable=protected-access, missing-function-docstring, line-too-long
         | 
| 9 | 
            +
             | 
| 10 | 
            +
            def ipex_init(): # pylint: disable=too-many-statements
         | 
| 11 | 
            +
                try:
         | 
| 12 | 
            +
                    if hasattr(torch, "cuda") and hasattr(torch.cuda, "is_xpu_hijacked") and torch.cuda.is_xpu_hijacked:
         | 
| 13 | 
            +
                        return True, "Skipping IPEX hijack"
         | 
| 14 | 
            +
                    else:
         | 
| 15 | 
            +
                        # Replace cuda with xpu:
         | 
| 16 | 
            +
                        torch.cuda.current_device = torch.xpu.current_device
         | 
| 17 | 
            +
                        torch.cuda.current_stream = torch.xpu.current_stream
         | 
| 18 | 
            +
                        torch.cuda.device = torch.xpu.device
         | 
| 19 | 
            +
                        torch.cuda.device_count = torch.xpu.device_count
         | 
| 20 | 
            +
                        torch.cuda.device_of = torch.xpu.device_of
         | 
| 21 | 
            +
                        torch.cuda.get_device_name = torch.xpu.get_device_name
         | 
| 22 | 
            +
                        torch.cuda.get_device_properties = torch.xpu.get_device_properties
         | 
| 23 | 
            +
                        torch.cuda.init = torch.xpu.init
         | 
| 24 | 
            +
                        torch.cuda.is_available = torch.xpu.is_available
         | 
| 25 | 
            +
                        torch.cuda.is_initialized = torch.xpu.is_initialized
         | 
| 26 | 
            +
                        torch.cuda.is_current_stream_capturing = lambda: False
         | 
| 27 | 
            +
                        torch.cuda.set_device = torch.xpu.set_device
         | 
| 28 | 
            +
                        torch.cuda.stream = torch.xpu.stream
         | 
| 29 | 
            +
                        torch.cuda.synchronize = torch.xpu.synchronize
         | 
| 30 | 
            +
                        torch.cuda.Event = torch.xpu.Event
         | 
| 31 | 
            +
                        torch.cuda.Stream = torch.xpu.Stream
         | 
| 32 | 
            +
                        torch.cuda.FloatTensor = torch.xpu.FloatTensor
         | 
| 33 | 
            +
                        torch.Tensor.cuda = torch.Tensor.xpu
         | 
| 34 | 
            +
                        torch.Tensor.is_cuda = torch.Tensor.is_xpu
         | 
| 35 | 
            +
                        torch.nn.Module.cuda = torch.nn.Module.xpu
         | 
| 36 | 
            +
                        torch.UntypedStorage.cuda = torch.UntypedStorage.xpu
         | 
| 37 | 
            +
                        torch.cuda._initialization_lock = torch.xpu.lazy_init._initialization_lock
         | 
| 38 | 
            +
                        torch.cuda._initialized = torch.xpu.lazy_init._initialized
         | 
| 39 | 
            +
                        torch.cuda._lazy_seed_tracker = torch.xpu.lazy_init._lazy_seed_tracker
         | 
| 40 | 
            +
                        torch.cuda._queued_calls = torch.xpu.lazy_init._queued_calls
         | 
| 41 | 
            +
                        torch.cuda._tls = torch.xpu.lazy_init._tls
         | 
| 42 | 
            +
                        torch.cuda.threading = torch.xpu.lazy_init.threading
         | 
| 43 | 
            +
                        torch.cuda.traceback = torch.xpu.lazy_init.traceback
         | 
| 44 | 
            +
                        torch.cuda.Optional = torch.xpu.Optional
         | 
| 45 | 
            +
                        torch.cuda.__cached__ = torch.xpu.__cached__
         | 
| 46 | 
            +
                        torch.cuda.__loader__ = torch.xpu.__loader__
         | 
| 47 | 
            +
                        torch.cuda.ComplexFloatStorage = torch.xpu.ComplexFloatStorage
         | 
| 48 | 
            +
                        torch.cuda.Tuple = torch.xpu.Tuple
         | 
| 49 | 
            +
                        torch.cuda.streams = torch.xpu.streams
         | 
| 50 | 
            +
                        torch.cuda._lazy_new = torch.xpu._lazy_new
         | 
| 51 | 
            +
                        torch.cuda.FloatStorage = torch.xpu.FloatStorage
         | 
| 52 | 
            +
                        torch.cuda.Any = torch.xpu.Any
         | 
| 53 | 
            +
                        torch.cuda.__doc__ = torch.xpu.__doc__
         | 
| 54 | 
            +
                        torch.cuda.default_generators = torch.xpu.default_generators
         | 
| 55 | 
            +
                        torch.cuda.HalfTensor = torch.xpu.HalfTensor
         | 
| 56 | 
            +
                        torch.cuda._get_device_index = torch.xpu._get_device_index
         | 
| 57 | 
            +
                        torch.cuda.__path__ = torch.xpu.__path__
         | 
| 58 | 
            +
                        torch.cuda.Device = torch.xpu.Device
         | 
| 59 | 
            +
                        torch.cuda.IntTensor = torch.xpu.IntTensor
         | 
| 60 | 
            +
                        torch.cuda.ByteStorage = torch.xpu.ByteStorage
         | 
| 61 | 
            +
                        torch.cuda.set_stream = torch.xpu.set_stream
         | 
| 62 | 
            +
                        torch.cuda.BoolStorage = torch.xpu.BoolStorage
         | 
| 63 | 
            +
                        torch.cuda.os = torch.xpu.os
         | 
| 64 | 
            +
                        torch.cuda.torch = torch.xpu.torch
         | 
| 65 | 
            +
                        torch.cuda.BFloat16Storage = torch.xpu.BFloat16Storage
         | 
| 66 | 
            +
                        torch.cuda.Union = torch.xpu.Union
         | 
| 67 | 
            +
                        torch.cuda.DoubleTensor = torch.xpu.DoubleTensor
         | 
| 68 | 
            +
                        torch.cuda.ShortTensor = torch.xpu.ShortTensor
         | 
| 69 | 
            +
                        torch.cuda.LongTensor = torch.xpu.LongTensor
         | 
| 70 | 
            +
                        torch.cuda.IntStorage = torch.xpu.IntStorage
         | 
| 71 | 
            +
                        torch.cuda.LongStorage = torch.xpu.LongStorage
         | 
| 72 | 
            +
                        torch.cuda.__annotations__ = torch.xpu.__annotations__
         | 
| 73 | 
            +
                        torch.cuda.__package__ = torch.xpu.__package__
         | 
| 74 | 
            +
                        torch.cuda.__builtins__ = torch.xpu.__builtins__
         | 
| 75 | 
            +
                        torch.cuda.CharTensor = torch.xpu.CharTensor
         | 
| 76 | 
            +
                        torch.cuda.List = torch.xpu.List
         | 
| 77 | 
            +
                        torch.cuda._lazy_init = torch.xpu._lazy_init
         | 
| 78 | 
            +
                        torch.cuda.BFloat16Tensor = torch.xpu.BFloat16Tensor
         | 
| 79 | 
            +
                        torch.cuda.DoubleStorage = torch.xpu.DoubleStorage
         | 
| 80 | 
            +
                        torch.cuda.ByteTensor = torch.xpu.ByteTensor
         | 
| 81 | 
            +
                        torch.cuda.StreamContext = torch.xpu.StreamContext
         | 
| 82 | 
            +
                        torch.cuda.ComplexDoubleStorage = torch.xpu.ComplexDoubleStorage
         | 
| 83 | 
            +
                        torch.cuda.ShortStorage = torch.xpu.ShortStorage
         | 
| 84 | 
            +
                        torch.cuda._lazy_call = torch.xpu._lazy_call
         | 
| 85 | 
            +
                        torch.cuda.HalfStorage = torch.xpu.HalfStorage
         | 
| 86 | 
            +
                        torch.cuda.random = torch.xpu.random
         | 
| 87 | 
            +
                        torch.cuda._device = torch.xpu._device
         | 
| 88 | 
            +
                        torch.cuda.classproperty = torch.xpu.classproperty
         | 
| 89 | 
            +
                        torch.cuda.__name__ = torch.xpu.__name__
         | 
| 90 | 
            +
                        torch.cuda._device_t = torch.xpu._device_t
         | 
| 91 | 
            +
                        torch.cuda.warnings = torch.xpu.warnings
         | 
| 92 | 
            +
                        torch.cuda.__spec__ = torch.xpu.__spec__
         | 
| 93 | 
            +
                        torch.cuda.BoolTensor = torch.xpu.BoolTensor
         | 
| 94 | 
            +
                        torch.cuda.CharStorage = torch.xpu.CharStorage
         | 
| 95 | 
            +
                        torch.cuda.__file__ = torch.xpu.__file__
         | 
| 96 | 
            +
                        torch.cuda._is_in_bad_fork = torch.xpu.lazy_init._is_in_bad_fork
         | 
| 97 | 
            +
                        # torch.cuda.is_current_stream_capturing = torch.xpu.is_current_stream_capturing
         | 
| 98 | 
            +
             | 
| 99 | 
            +
                        # Memory:
         | 
| 100 | 
            +
                        torch.cuda.memory = torch.xpu.memory
         | 
| 101 | 
            +
                        if 'linux' in sys.platform and "WSL2" in os.popen("uname -a").read():
         | 
| 102 | 
            +
                            torch.xpu.empty_cache = lambda: None
         | 
| 103 | 
            +
                        torch.cuda.empty_cache = torch.xpu.empty_cache
         | 
| 104 | 
            +
                        torch.cuda.memory_stats = torch.xpu.memory_stats
         | 
| 105 | 
            +
                        torch.cuda.memory_summary = torch.xpu.memory_summary
         | 
| 106 | 
            +
                        torch.cuda.memory_snapshot = torch.xpu.memory_snapshot
         | 
| 107 | 
            +
                        torch.cuda.memory_allocated = torch.xpu.memory_allocated
         | 
| 108 | 
            +
                        torch.cuda.max_memory_allocated = torch.xpu.max_memory_allocated
         | 
| 109 | 
            +
                        torch.cuda.memory_reserved = torch.xpu.memory_reserved
         | 
| 110 | 
            +
                        torch.cuda.memory_cached = torch.xpu.memory_reserved
         | 
| 111 | 
            +
                        torch.cuda.max_memory_reserved = torch.xpu.max_memory_reserved
         | 
| 112 | 
            +
                        torch.cuda.max_memory_cached = torch.xpu.max_memory_reserved
         | 
| 113 | 
            +
                        torch.cuda.reset_peak_memory_stats = torch.xpu.reset_peak_memory_stats
         | 
| 114 | 
            +
                        torch.cuda.reset_max_memory_cached = torch.xpu.reset_peak_memory_stats
         | 
| 115 | 
            +
                        torch.cuda.reset_max_memory_allocated = torch.xpu.reset_peak_memory_stats
         | 
| 116 | 
            +
                        torch.cuda.memory_stats_as_nested_dict = torch.xpu.memory_stats_as_nested_dict
         | 
| 117 | 
            +
                        torch.cuda.reset_accumulated_memory_stats = torch.xpu.reset_accumulated_memory_stats
         | 
| 118 | 
            +
             | 
| 119 | 
            +
                        # RNG:
         | 
| 120 | 
            +
                        torch.cuda.get_rng_state = torch.xpu.get_rng_state
         | 
| 121 | 
            +
                        torch.cuda.get_rng_state_all = torch.xpu.get_rng_state_all
         | 
| 122 | 
            +
                        torch.cuda.set_rng_state = torch.xpu.set_rng_state
         | 
| 123 | 
            +
                        torch.cuda.set_rng_state_all = torch.xpu.set_rng_state_all
         | 
| 124 | 
            +
                        torch.cuda.manual_seed = torch.xpu.manual_seed
         | 
| 125 | 
            +
                        torch.cuda.manual_seed_all = torch.xpu.manual_seed_all
         | 
| 126 | 
            +
                        torch.cuda.seed = torch.xpu.seed
         | 
| 127 | 
            +
                        torch.cuda.seed_all = torch.xpu.seed_all
         | 
| 128 | 
            +
                        torch.cuda.initial_seed = torch.xpu.initial_seed
         | 
| 129 | 
            +
             | 
| 130 | 
            +
                        # AMP:
         | 
| 131 | 
            +
                        torch.cuda.amp = torch.xpu.amp
         | 
| 132 | 
            +
                        torch.is_autocast_enabled = torch.xpu.is_autocast_xpu_enabled
         | 
| 133 | 
            +
                        torch.get_autocast_gpu_dtype = torch.xpu.get_autocast_xpu_dtype
         | 
| 134 | 
            +
             | 
| 135 | 
            +
                        if not hasattr(torch.cuda.amp, "common"):
         | 
| 136 | 
            +
                            torch.cuda.amp.common = contextlib.nullcontext()
         | 
| 137 | 
            +
                        torch.cuda.amp.common.amp_definitely_not_available = lambda: False
         | 
| 138 | 
            +
             | 
| 139 | 
            +
                        try:
         | 
| 140 | 
            +
                            torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler
         | 
| 141 | 
            +
                        except Exception: # pylint: disable=broad-exception-caught
         | 
| 142 | 
            +
                            try:
         | 
| 143 | 
            +
                                from .gradscaler import gradscaler_init # pylint: disable=import-outside-toplevel, import-error
         | 
| 144 | 
            +
                                gradscaler_init()
         | 
| 145 | 
            +
                                torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler
         | 
| 146 | 
            +
                            except Exception: # pylint: disable=broad-exception-caught
         | 
| 147 | 
            +
                                torch.cuda.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler
         | 
| 148 | 
            +
             | 
| 149 | 
            +
                        # C
         | 
| 150 | 
            +
                        torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentStream
         | 
| 151 | 
            +
                        ipex._C._DeviceProperties.multi_processor_count = ipex._C._DeviceProperties.gpu_subslice_count
         | 
| 152 | 
            +
                        ipex._C._DeviceProperties.major = 2024
         | 
| 153 | 
            +
                        ipex._C._DeviceProperties.minor = 0
         | 
| 154 | 
            +
             | 
| 155 | 
            +
                        # Fix functions with ipex:
         | 
| 156 | 
            +
                        torch.cuda.mem_get_info = lambda device=None: [(torch.xpu.get_device_properties(device).total_memory - torch.xpu.memory_reserved(device)), torch.xpu.get_device_properties(device).total_memory]
         | 
| 157 | 
            +
                        torch._utils._get_available_device_type = lambda: "xpu"
         | 
| 158 | 
            +
                        torch.has_cuda = True
         | 
| 159 | 
            +
                        torch.cuda.has_half = True
         | 
| 160 | 
            +
                        torch.cuda.is_bf16_supported = lambda *args, **kwargs: True
         | 
| 161 | 
            +
                        torch.cuda.is_fp16_supported = lambda *args, **kwargs: True
         | 
| 162 | 
            +
                        torch.backends.cuda.is_built = lambda *args, **kwargs: True
         | 
| 163 | 
            +
                        torch.version.cuda = "12.1"
         | 
| 164 | 
            +
                        torch.cuda.get_device_capability = lambda *args, **kwargs: [12,1]
         | 
| 165 | 
            +
                        torch.cuda.get_device_properties.major = 12
         | 
| 166 | 
            +
                        torch.cuda.get_device_properties.minor = 1
         | 
| 167 | 
            +
                        torch.cuda.ipc_collect = lambda *args, **kwargs: None
         | 
| 168 | 
            +
                        torch.cuda.utilization = lambda *args, **kwargs: 0
         | 
| 169 | 
            +
             | 
| 170 | 
            +
                        ipex_hijacks()
         | 
| 171 | 
            +
                        if not torch.xpu.has_fp64_dtype() or os.environ.get('IPEX_FORCE_ATTENTION_SLICE', None) is not None:
         | 
| 172 | 
            +
                            try:
         | 
| 173 | 
            +
                                from .diffusers import ipex_diffusers
         | 
| 174 | 
            +
                                ipex_diffusers()
         | 
| 175 | 
            +
                            except Exception: # pylint: disable=broad-exception-caught
         | 
| 176 | 
            +
                                pass
         | 
| 177 | 
            +
                        torch.cuda.is_xpu_hijacked = True
         | 
| 178 | 
            +
                except Exception as e:
         | 
| 179 | 
            +
                    return False, e
         | 
| 180 | 
            +
                return True, None
         | 
    	
        library/ipex/attention.py
    ADDED
    
    | @@ -0,0 +1,177 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import os
         | 
| 2 | 
            +
            import torch
         | 
| 3 | 
            +
            import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
         | 
| 4 | 
            +
            from functools import cache
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            # pylint: disable=protected-access, missing-function-docstring, line-too-long
         | 
| 7 | 
            +
             | 
| 8 | 
            +
            # ARC GPUs can't allocate more than 4GB to a single block so we slice the attention layers
         | 
| 9 | 
            +
             | 
| 10 | 
            +
            # Thresholds that control attention slicing; each can be overridden through
            # an environment variable and falls back to 4 when the variable is unset.
            # The slice-size helpers in this module compare them against block sizes
            # computed as element_count / 1024 / 1024 * element_size.
            sdpa_slice_trigger_rate = float(os.getenv('IPEX_SDPA_SLICE_TRIGGER_RATE', 4))
            attention_slice_rate = float(os.getenv('IPEX_ATTENTION_SLICE_RATE', 4))
         | 
| 12 | 
            +
             | 
| 13 | 
            +
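            # Find a slice size by repeated halving until one slice fits within attention_slice_rate (the result is not guaranteed to divide the original size evenly)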
         | 
| 14 | 
            +
            @cache
            def find_slice_size(slice_size, slice_block_size):
                """Shrink ``slice_size`` by halving until a slice fits the size budget.

                The budget is the module-level ``attention_slice_rate``; a slice "fits"
                when ``slice_size * slice_block_size`` no longer exceeds it. Halving uses
                floor division and stops at 1, which is returned even if a single unit
                is still over budget. Results are memoized, so both arguments must be
                hashable.

                Args:
                    slice_size: Starting number of units per slice.
                    slice_block_size: Size contributed by one unit, in the same units as
                        ``attention_slice_rate``.

                Returns:
                    The reduced slice size (the input unchanged if it already fits).
                """
                candidate = slice_size
                while attention_slice_rate < (candidate * slice_block_size):
                    candidate //= 2
                    if candidate <= 1:
                        # Cannot go below one unit per slice; give up shrinking here.
                        return 1
                return candidate
         | 
| 22 | 
            +
             | 
| 23 | 
            +
            # Find slice sizes for SDPA
         | 
| 24 | 
            +
            @cache
            def find_sdpa_slice_sizes(query_shape, query_element_size):
                """Decide whether and how to slice a scaled-dot-product-attention call.

                Slicing is attempted on the first dimension, then the second, then the
                third, going one level deeper only while the current slice still exceeds
                ``attention_slice_rate``. Nothing is sliced unless the whole query exceeds
                ``sdpa_slice_trigger_rate``. Results are memoized, so ``query_shape`` must
                be hashable.

                Args:
                    query_shape: Shape of the query with 3 entries, or otherwise 4 entries
                        (a missing fourth entry is treated as 1).
                    query_element_size: Size of one query element, as reported by
                        ``Tensor.element_size()`` at the call site.

                Returns:
                    ``(do_split, do_split_2, do_split_3, split_slice_size,
                    split_2_slice_size, split_3_slice_size)``. A slice size equals the
                    full dimension when its level is not split.
                """
                if len(query_shape) == 3:
                    batch_size_attention, query_tokens, shape_three = query_shape
                    shape_four = 1
                else:
                    batch_size_attention, query_tokens, shape_three, shape_four = query_shape

                def scaled(element_count):
                    # Same arithmetic order as every size computation in this function.
                    return element_count / 1024 / 1024 * query_element_size

                per_batch_size = scaled(query_tokens * shape_three * shape_four)
                whole_size = batch_size_attention * per_batch_size

                # Small enough overall: no slicing at any level.
                if not (whole_size > sdpa_slice_trigger_rate):
                    return False, False, False, batch_size_attention, query_tokens, shape_three

                # Level 1: slice along the first dimension.
                first_slice = find_slice_size(batch_size_attention, per_batch_size)
                if not (first_slice * per_batch_size > attention_slice_rate):
                    return True, False, False, first_slice, query_tokens, shape_three

                # Level 2: a first-dimension slice is still too big, slice the second dimension.
                per_token_size = scaled(first_slice * shape_three * shape_four)
                second_slice = find_slice_size(query_tokens, per_token_size)
                if not (second_slice * per_token_size > attention_slice_rate):
                    return True, True, False, first_slice, second_slice, shape_three

                # Level 3: still too big, slice the third dimension as well.
                per_third_size = scaled(first_slice * second_slice * shape_four)
                third_slice = find_slice_size(shape_three, per_third_size)
                return True, True, True, first_slice, second_slice, third_slice
         | 
| 56 | 
            +
             | 
| 57 | 
            +
            # Find slice sizes for BMM
         | 
| 58 | 
            +
            @cache
            def find_bmm_slice_sizes(input_shape, input_element_size, mat2_shape):
                """Decide whether and how to slice a batched matrix multiplication.

                The output block is described by ``input_shape[0]``, ``input_shape[1]``
                and ``mat2_shape[2]``. Slicing is attempted on those three extents in that
                order, going one level deeper only while the current slice still exceeds
                ``attention_slice_rate``. Results are memoized, so both shapes must be
                hashable.

                Args:
                    input_shape: Shape of the first operand; entries 0 and 1 are read.
                    input_element_size: Size of one element of the first operand, as
                        reported by ``Tensor.element_size()`` at the call site.
                    mat2_shape: Shape of the second operand; entry 2 is read.

                Returns:
                    ``(do_split, do_split_2, do_split_3, split_slice_size,
                    split_2_slice_size, split_3_slice_size)``. A slice size equals the
                    full extent when its level is not split.
                """
                batch_size_attention = input_shape[0]
                input_tokens = input_shape[1]
                mat2_atten_shape = mat2_shape[2]

                def scaled(element_count):
                    # Same arithmetic order as every size computation in this function.
                    return element_count / 1024 / 1024 * input_element_size

                per_batch_size = scaled(input_tokens * mat2_atten_shape)
                whole_size = batch_size_attention * per_batch_size

                # Small enough overall: no slicing at any level.
                if not (whole_size > attention_slice_rate):
                    return False, False, False, batch_size_attention, input_tokens, mat2_atten_shape

                # Level 1: slice along the batch extent.
                first_slice = find_slice_size(batch_size_attention, per_batch_size)
                if not (first_slice * per_batch_size > attention_slice_rate):
                    return True, False, False, first_slice, input_tokens, mat2_atten_shape

                # Level 2: a batch slice is still too big, slice the token extent.
                per_token_size = scaled(first_slice * mat2_atten_shape)
                second_slice = find_slice_size(input_tokens, per_token_size)
                if not (second_slice * per_token_size > attention_slice_rate):
                    return True, True, False, first_slice, second_slice, mat2_atten_shape

                # Level 3: still too big, slice the last output extent as well.
                per_third_size = scaled(first_slice * second_slice)
                third_slice = find_slice_size(mat2_atten_shape, per_third_size)
                return True, True, True, first_slice, second_slice, third_slice
         | 
| 85 | 
            +
             | 
| 86 | 
            +
             | 
| 87 | 
            +
            # Keep a handle on the unpatched implementation so the wrapper can delegate to it.
            original_torch_bmm = torch.bmm


            def torch_bmm_32_bit(input, mat2, *, out=None):
                """Batched matrix multiply that computes large XPU products in slices.

                Tensors on any device other than ``xpu`` are passed straight to the
                original ``torch.bmm``. On XPU, ``find_bmm_slice_sizes`` decides whether
                the output has to be produced piecewise; if so, the output is tiled along
                the batch, row and column extents and each tile is computed with the
                original ``torch.bmm``.

                Fixes over the previous implementation:
                  * Tiles are stepped with ``range(0, size, step)`` so a trailing partial
                    tile is computed instead of being left as zeros when the slice size
                    does not divide the extent.
                  * Each tile multiplies ``input[b, rows, :]`` by ``mat2[b, :, cols]``;
                    the contraction dimension is never sliced.
                  * ``out`` is no longer handed to the partial products (which would
                    overwrite it with the last tile); the assembled result is copied
                    into ``out`` once at the end and ``out`` is returned.

                Args:
                    input: First operand, indexed as ``(batch, rows, inner)``.
                    mat2: Second operand, indexed as ``(batch, inner, cols)``.
                    out: Optional tensor that receives the result.

                Returns:
                    The product tensor (``out`` itself when it was supplied).
                """
                if input.device.type != "xpu":
                    return original_torch_bmm(input, mat2, out=out)
                do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_bmm_slice_sizes(input.shape, input.element_size(), mat2.shape)

                if not do_split:
                    return original_torch_bmm(input, mat2, out=out)

                # Slice BMM
                batch_size_attention, input_tokens, mat2_atten_shape = input.shape[0], input.shape[1], mat2.shape[2]

                # A level that is not split uses the full extent as its step, which
                # collapses that loop to a single iteration. max(1, ...) guards range()
                # against a zero step.
                batch_step = max(1, split_slice_size)
                row_step = max(1, split_2_slice_size if do_split_2 else input_tokens)
                col_step = max(1, split_3_slice_size if (do_split_2 and do_split_3) else mat2_atten_shape)

                hidden_states = torch.zeros(batch_size_attention, input_tokens, mat2_atten_shape, device=input.device, dtype=input.dtype)
                for start_idx in range(0, batch_size_attention, batch_step):
                    end_idx = start_idx + batch_step  # slicing clamps a partial last tile
                    for start_idx_2 in range(0, input_tokens, row_step):
                        end_idx_2 = start_idx_2 + row_step
                        for start_idx_3 in range(0, mat2_atten_shape, col_step):
                            end_idx_3 = start_idx_3 + col_step
                            hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = original_torch_bmm(
                                input[start_idx:end_idx, start_idx_2:end_idx_2],
                                mat2[start_idx:end_idx, :, start_idx_3:end_idx_3],
                            )
                torch.xpu.synchronize(input.device)

                if out is not None:
                    if out.shape != hidden_states.shape:
                        out.resize_(hidden_states.shape)
                    out.copy_(hidden_states)
                    return out
                return hidden_states
         | 
| 129 | 
            +
             | 
| 130 | 
            +
            original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
         | 
| 131 | 
            +
            def scaled_dot_product_attention_32_bit(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, **kwargs):
         | 
| 132 | 
            +
                if query.device.type != "xpu":
         | 
| 133 | 
            +
                    return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, **kwargs)
         | 
| 134 | 
            +
                do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_sdpa_slice_sizes(query.shape, query.element_size())
         | 
| 135 | 
            +
             | 
| 136 | 
            +
                # Slice SDPA
         | 
| 137 | 
            +
                if do_split:
         | 
| 138 | 
            +
                    batch_size_attention, query_tokens, shape_three = query.shape[0], query.shape[1], query.shape[2]
         | 
| 139 | 
            +
                    hidden_states = torch.zeros(query.shape, device=query.device, dtype=query.dtype)
         | 
| 140 | 
            +
                    for i in range(batch_size_attention // split_slice_size):
         | 
| 141 | 
            +
                        start_idx = i * split_slice_size
         | 
| 142 | 
            +
                        end_idx = (i + 1) * split_slice_size
         | 
| 143 | 
            +
                        if do_split_2:
         | 
| 144 | 
            +
                            for i2 in range(query_tokens // split_2_slice_size): # pylint: disable=invalid-name
         | 
| 145 | 
            +
                                start_idx_2 = i2 * split_2_slice_size
         | 
| 146 | 
            +
                                end_idx_2 = (i2 + 1) * split_2_slice_size
         | 
| 147 | 
            +
                                if do_split_3:
         | 
| 148 | 
            +
                                    for i3 in range(shape_three // split_3_slice_size): # pylint: disable=invalid-name
         | 
| 149 | 
            +
                                        start_idx_3 = i3 * split_3_slice_size
         | 
| 150 | 
            +
                                        end_idx_3 = (i3 + 1) * split_3_slice_size
         | 
| 151 | 
            +
                                        hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = original_scaled_dot_product_attention(
         | 
| 152 | 
            +
                                            query[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
         | 
| 153 | 
            +
                                            key[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
         | 
| 154 | 
            +
                                            value[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
         | 
| 155 | 
            +
                                            attn_mask=attn_mask[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] if attn_mask is not None else attn_mask,
         | 
| 156 | 
            +
                                            dropout_p=dropout_p, is_causal=is_causal, **kwargs
         | 
| 157 | 
            +
                                        )
         | 
| 158 | 
            +
                                else:
         | 
| 159 | 
            +
                                    hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = original_scaled_dot_product_attention(
         | 
| 160 | 
            +
                                        query[start_idx:end_idx, start_idx_2:end_idx_2],
         | 
| 161 | 
            +
                                        key[start_idx:end_idx, start_idx_2:end_idx_2],
         | 
| 162 | 
            +
                                        value[start_idx:end_idx, start_idx_2:end_idx_2],
         | 
| 163 | 
            +
                                        attn_mask=attn_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attn_mask is not None else attn_mask,
         | 
| 164 | 
            +
                                        dropout_p=dropout_p, is_causal=is_causal, **kwargs
         | 
| 165 | 
            +
                                    )
         | 
| 166 | 
            +
                        else:
         | 
| 167 | 
            +
                            hidden_states[start_idx:end_idx] = original_scaled_dot_product_attention(
         | 
| 168 | 
            +
                                query[start_idx:end_idx],
         | 
| 169 | 
            +
                                key[start_idx:end_idx],
         | 
| 170 | 
            +
                                value[start_idx:end_idx],
         | 
| 171 | 
            +
                                attn_mask=attn_mask[start_idx:end_idx] if attn_mask is not None else attn_mask,
         | 
| 172 | 
            +
                                dropout_p=dropout_p, is_causal=is_causal, **kwargs
         | 
| 173 | 
            +
                            )
         | 
| 174 | 
            +
                    torch.xpu.synchronize(query.device)
         | 
| 175 | 
            +
                else:
         | 
| 176 | 
            +
                    return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, **kwargs)
         | 
| 177 | 
            +
                return hidden_states
         | 
    	
        library/ipex/diffusers.py
    ADDED
    
    | @@ -0,0 +1,312 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import os
         | 
| 2 | 
            +
            import torch
         | 
| 3 | 
            +
            import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
         | 
| 4 | 
            +
            import diffusers #0.24.0 # pylint: disable=import-error
         | 
| 5 | 
            +
            from diffusers.models.attention_processor import Attention
         | 
| 6 | 
            +
            from diffusers.utils import USE_PEFT_BACKEND
         | 
| 7 | 
            +
            from functools import cache
         | 
| 8 | 
            +
             | 
| 9 | 
            +
            # pylint: disable=protected-access, missing-function-docstring, line-too-long
         | 
| 10 | 
            +
             | 
| 11 | 
            +
            # Upper bound that find_slice_size() and find_attention_slice_sizes() compare
            # block sizes against. Those sizes are computed as
            # element_count / 1024 / 1024 * element_size, so the value is on that same scale.
            # Read once at import time from IPEX_ATTENTION_SLICE_RATE; defaults to 4.
            attention_slice_rate = float(os.environ.get('IPEX_ATTENTION_SLICE_RATE', 4))
         | 
| 12 | 
            +
             | 
| 13 | 
            +
            @cache
            def find_slice_size(slice_size, slice_block_size):
                """Halve `slice_size` until one slice fits under `attention_slice_rate`.

                Args:
                    slice_size: starting slice length along the axis being split.
                    slice_block_size: size contributed by a single index of that axis,
                        on the same scale as `attention_slice_rate`.

                Returns:
                    The first halved length whose block fits, or 1 as soon as halving
                    reaches 1 or below. The input is returned untouched when it already fits.
                """
                candidate = slice_size
                while True:
                    # Written as "not >" so that unordered comparisons behave like the loop test.
                    if not (candidate * slice_block_size > attention_slice_rate):
                        return candidate
                    candidate //= 2
                    if candidate <= 1:
                        return 1
         | 
| 21 | 
            +
             | 
| 22 | 
            +
            @cache
            def find_attention_slice_sizes(query_shape, query_element_size, query_device_type, slice_size=None):
                """Decide along which query axes attention must be chunked, and how coarsely.

                Args:
                    query_shape: shape of the query tensor, unpacked as either three or four
                        axes. It must be hashable because results are memoised.
                    query_element_size: size of one query element (callers pass
                        `query.element_size()`).
                    query_device_type: device type string; anything other than "xpu"
                        disables splitting entirely.
                    slice_size: optional value used in place of the first axis length.

                Returns:
                    `(do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size,
                    split_3_slice_size)`: one flag and one chunk length for each of the first
                    three axes. A chunk length equals the full axis length (or `slice_size`
                    for the first axis) when that axis is not split.
                """
                if len(query_shape) == 3:
                    rows, tokens, third = query_shape
                    fourth = 1
                else:
                    rows, tokens, third, fourth = query_shape
                if slice_size is not None:
                    rows = slice_size

                def _scaled(element_count):
                    # Same evaluation order as the inline formula so float results match.
                    return element_count / 1024 / 1024 * query_element_size

                per_row = _scaled(tokens * third * fourth)
                whole_block = rows * per_row

                # Nothing to do off-XPU, or when the whole block already fits.
                if query_device_type != "xpu" or not (whole_block > attention_slice_rate):
                    return False, False, False, rows, tokens, third

                # First axis.
                row_step = find_slice_size(rows, per_row)
                if not (row_step * per_row > attention_slice_rate):
                    return True, False, False, row_step, tokens, third

                # Second axis: still too large with the first axis chunked.
                per_token = _scaled(row_step * third * fourth)
                token_step = find_slice_size(tokens, per_token)
                if not (token_step * per_token > attention_slice_rate):
                    return True, True, False, row_step, token_step, third

                # Third axis as a last resort.
                per_third = _scaled(row_step * token_step * fourth)
                third_step = find_slice_size(third, per_third)
                return True, True, True, row_step, token_step, third_step
         | 
| 59 | 
            +
             | 
| 60 | 
            +
            class SlicedAttnProcessor: # pylint: disable=too-few-public-methods
                r"""
                Processor for implementing sliced attention.

                Attention is evaluated chunk by chunk so that no single intermediate block
                exceeds the limit used by `find_attention_slice_sizes`. Chunking is always
                done on the first query axis and, when requested by that helper, on the
                query-token axis and on the last axis of the value/output.

                Args:
                    slice_size (`int`, *optional*):
                        Starting chunk length along the first axis of the query after
                        `attn.head_to_batch_dim`. It is handed to `find_attention_slice_sizes`,
                        which may shrink it further. The axis length does not have to be a
                        multiple of it; the last chunk is simply shorter.
                """

                def __init__(self, slice_size):
                    self.slice_size = slice_size

                @staticmethod
                def _spans(total, step):
                    """Yield `(start, end)` pairs tiling `range(total)` in chunks of `step`, tail included."""
                    for start in range(0, total, step):
                        yield start, min(start + step, total)

                def __call__(self, attn: Attention, hidden_states: torch.FloatTensor,
                encoder_hidden_states=None, attention_mask=None) -> torch.FloatTensor: # pylint: disable=too-many-statements, too-many-locals, too-many-branches
                    """Run attention for `attn` on `hidden_states`, computing it in slices.

                    Args:
                        attn: attention module supplying the projections, norms and
                            `get_attention_scores`.
                        hidden_states: input tensor; a 4-D input is flattened to 3-D for the
                            computation and restored afterwards.
                        encoder_hidden_states: optional source for keys/values; when `None`
                            the (normalised) `hidden_states` are used.
                        attention_mask: optional mask, passed through
                            `attn.prepare_attention_mask` before use.

                    Returns:
                        The attention output, with the residual added when
                        `attn.residual_connection` is set, divided by
                        `attn.rescale_output_factor`.
                    """
                    residual = hidden_states

                    input_ndim = hidden_states.ndim

                    if input_ndim == 4:
                        batch_size, channel, height, width = hidden_states.shape
                        hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

                    batch_size, sequence_length, _ = (
                        hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
                    )
                    attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

                    if attn.group_norm is not None:
                        hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

                    query = attn.to_q(hidden_states)
                    dim = query.shape[-1]
                    query = attn.head_to_batch_dim(query)

                    if encoder_hidden_states is None:
                        encoder_hidden_states = hidden_states
                    elif attn.norm_cross:
                        encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

                    key = attn.to_k(encoder_hidden_states)
                    value = attn.to_v(encoder_hidden_states)
                    key = attn.head_to_batch_dim(key)
                    value = attn.head_to_batch_dim(value)

                    batch_size_attention, query_tokens, _ = query.shape
                    hidden_states = torch.zeros(
                        (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
                    )
                    out_width = hidden_states.shape[-1]

                    ####################################################################
                    # ARC GPUs can't allocate more than 4GB to a single block, Slice it:
                    _, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_attention_slice_sizes(query.shape, query.element_size(), query.device.type, slice_size=self.slice_size)

                    # _spans() also yields the shorter tail chunk, so no rows are left as zeros
                    # when a chunk length does not divide the axis (or exceeds it).
                    for start_idx, end_idx in self._spans(batch_size_attention, split_slice_size):
                        query_batch = query[start_idx:end_idx]
                        key_batch = key[start_idx:end_idx]
                        value_batch = value[start_idx:end_idx]
                        mask_batch = attention_mask[start_idx:end_idx] if attention_mask is not None else None

                        if do_split_2:
                            # Only follow the query chunk in the mask when the mask has one row per
                            # query token; a broadcast row is reused as is.
                            # NOTE(review): assumes the prepared mask is laid out as
                            # (batch, 1 or query_tokens, key_tokens) - confirm against prepare_attention_mask.
                            mask_follows_query = (
                                mask_batch is not None and mask_batch.ndim > 2 and mask_batch.shape[-2] == query_tokens
                            )
                            for start_idx_2, end_idx_2 in self._spans(query_tokens, split_2_slice_size):
                                mask_slice = mask_batch[..., start_idx_2:end_idx_2, :] if mask_follows_query else mask_batch

                                # Every query chunk must see ALL keys: the softmax runs over the key axis,
                                # so key/value are never cut along their token axis.
                                attn_slice = attn.get_attention_scores(query_batch[:, start_idx_2:end_idx_2], key_batch, mask_slice)
                                del mask_slice

                                if do_split_3:
                                    # The scores need the full dot product, so only the value/output
                                    # columns can be processed piecewise.
                                    for start_idx_3, end_idx_3 in self._spans(out_width, split_3_slice_size):
                                        hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = torch.bmm(
                                            attn_slice, value_batch[:, :, start_idx_3:end_idx_3]
                                        )
                                else:
                                    hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = torch.bmm(attn_slice, value_batch)
                                del attn_slice
                            torch.xpu.synchronize(query.device)
                        else:
                            attn_slice = attn.get_attention_scores(query_batch, key_batch, mask_batch)
                            hidden_states[start_idx:end_idx] = torch.bmm(attn_slice, value_batch)
                            del attn_slice
                    ####################################################################

                    hidden_states = attn.batch_to_head_dim(hidden_states)

                    # linear proj
                    hidden_states = attn.to_out[0](hidden_states)
                    # dropout
                    hidden_states = attn.to_out[1](hidden_states)

                    if input_ndim == 4:
                        hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

                    if attn.residual_connection:
                        hidden_states = hidden_states + residual

                    hidden_states = hidden_states / attn.rescale_output_factor

                    return hidden_states
         | 
| 184 | 
            +
             | 
| 185 | 
            +
             | 
| 186 | 
            +
            class AttnProcessor:
                r"""
                Default processor for performing attention-related computations.

                When `find_attention_slice_sizes` reports that the query is too large, the
                attention is computed chunk by chunk; otherwise it is computed in one pass.
                """

                @staticmethod
                def _spans(total, step):
                    """Yield `(start, end)` pairs tiling `range(total)` in chunks of `step`, tail included."""
                    for start in range(0, total, step):
                        yield start, min(start + step, total)

                def __call__(self, attn: Attention, hidden_states: torch.FloatTensor,
                encoder_hidden_states=None, attention_mask=None,
                temb=None, scale: float = 1.0) -> torch.Tensor: # pylint: disable=too-many-statements, too-many-locals, too-many-branches
                    """Run attention for `attn` on `hidden_states`.

                    Args:
                        attn: attention module supplying the projections, norms and
                            `get_attention_scores`.
                        hidden_states: input tensor; a 4-D input is flattened to 3-D for the
                            computation and restored afterwards.
                        encoder_hidden_states: optional source for keys/values; when `None`
                            the (normalised) `hidden_states` are used.
                        attention_mask: optional mask, passed through
                            `attn.prepare_attention_mask` before use.
                        temb: forwarded to `attn.spatial_norm` when that module is present.
                        scale: forwarded to the projection layers unless `USE_PEFT_BACKEND`
                            is set.

                    Returns:
                        The attention output, with the residual added when
                        `attn.residual_connection` is set, divided by
                        `attn.rescale_output_factor`.
                    """
                    residual = hidden_states

                    args = () if USE_PEFT_BACKEND else (scale,)

                    if attn.spatial_norm is not None:
                        hidden_states = attn.spatial_norm(hidden_states, temb)

                    input_ndim = hidden_states.ndim

                    if input_ndim == 4:
                        batch_size, channel, height, width = hidden_states.shape
                        hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

                    batch_size, sequence_length, _ = (
                        hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
                    )
                    attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

                    if attn.group_norm is not None:
                        hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

                    query = attn.to_q(hidden_states, *args)

                    if encoder_hidden_states is None:
                        encoder_hidden_states = hidden_states
                    elif attn.norm_cross:
                        encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

                    key = attn.to_k(encoder_hidden_states, *args)
                    value = attn.to_v(encoder_hidden_states, *args)

                    query = attn.head_to_batch_dim(query)
                    key = attn.head_to_batch_dim(key)
                    value = attn.head_to_batch_dim(value)

                    ####################################################################
                    # ARC GPUs can't allocate more than 4GB to a single block, Slice it:
                    do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_attention_slice_sizes(query.shape, query.element_size(), query.device.type)

                    if do_split:
                        batch_size_attention, query_tokens, out_width = query.shape[0], query.shape[1], query.shape[2]
                        # Output buffer is only needed on the sliced path.
                        hidden_states = torch.zeros(query.shape, device=query.device, dtype=query.dtype)

                        # _spans() also yields the shorter tail chunk, so no rows are left as zeros
                        # when a chunk length does not divide the axis.
                        for start_idx, end_idx in self._spans(batch_size_attention, split_slice_size):
                            query_batch = query[start_idx:end_idx]
                            key_batch = key[start_idx:end_idx]
                            value_batch = value[start_idx:end_idx]
                            mask_batch = attention_mask[start_idx:end_idx] if attention_mask is not None else None

                            if do_split_2:
                                # Only follow the query chunk in the mask when the mask has one row per
                                # query token; a broadcast row is reused as is.
                                # NOTE(review): assumes the prepared mask is laid out as
                                # (batch, 1 or query_tokens, key_tokens) - confirm against prepare_attention_mask.
                                mask_follows_query = (
                                    mask_batch is not None and mask_batch.ndim > 2 and mask_batch.shape[-2] == query_tokens
                                )
                                for start_idx_2, end_idx_2 in self._spans(query_tokens, split_2_slice_size):
                                    mask_slice = mask_batch[..., start_idx_2:end_idx_2, :] if mask_follows_query else mask_batch

                                    # Every query chunk must see ALL keys: the softmax runs over the key
                                    # axis, so key/value are never cut along their token axis.
                                    attn_slice = attn.get_attention_scores(query_batch[:, start_idx_2:end_idx_2], key_batch, mask_slice)
                                    del mask_slice

                                    if do_split_3:
                                        # The scores need the full dot product, so only the value/output
                                        # columns can be processed piecewise.
                                        for start_idx_3, end_idx_3 in self._spans(out_width, split_3_slice_size):
                                            hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = torch.bmm(
                                                attn_slice, value_batch[:, :, start_idx_3:end_idx_3]
                                            )
                                    else:
                                        hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = torch.bmm(attn_slice, value_batch)
                                    del attn_slice
                            else:
                                attn_slice = attn.get_attention_scores(query_batch, key_batch, mask_batch)
                                hidden_states[start_idx:end_idx] = torch.bmm(attn_slice, value_batch)
                                del attn_slice
                        torch.xpu.synchronize(query.device)
                    else:
                        attention_probs = attn.get_attention_scores(query, key, attention_mask)
                        hidden_states = torch.bmm(attention_probs, value)
                    ####################################################################
                    hidden_states = attn.batch_to_head_dim(hidden_states)

                    # linear proj
                    hidden_states = attn.to_out[0](hidden_states, *args)
                    # dropout
                    hidden_states = attn.to_out[1](hidden_states)

                    if input_ndim == 4:
                        hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

                    if attn.residual_connection:
                        hidden_states = hidden_states + residual

                    hidden_states = hidden_states / attn.rescale_output_factor

                    return hidden_states
         | 
| 308 | 
            +
             | 
| 309 | 
            +
            def ipex_diffusers():
                """Point diffusers' ``SlicedAttnProcessor`` and ``AttnProcessor`` at this module's replacements."""
                #ARC GPUs can't allocate more than 4GB to a single block:
                processor_module = diffusers.models.attention_processor
                processor_module.SlicedAttnProcessor = SlicedAttnProcessor
                processor_module.AttnProcessor = AttnProcessor
         | 
    	
        library/ipex/gradscaler.py
    ADDED
    
    | @@ -0,0 +1,183 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            from collections import defaultdict
         | 
| 2 | 
            +
            import torch
         | 
| 3 | 
            +
            import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
         | 
| 4 | 
            +
            import intel_extension_for_pytorch._C as core # pylint: disable=import-error, unused-import
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            # pylint: disable=protected-access, missing-function-docstring, line-too-long
         | 
| 7 | 
            +
             | 
| 8 | 
            +
            device_supports_fp64 = torch.xpu.has_fp64_dtype()
         | 
| 9 | 
            +
            OptState = ipex.cpu.autocast._grad_scaler.OptState
         | 
| 10 | 
            +
            _MultiDeviceReplicator = ipex.cpu.autocast._grad_scaler._MultiDeviceReplicator
         | 
| 11 | 
            +
            _refresh_per_optimizer_state = ipex.cpu.autocast._grad_scaler._refresh_per_optimizer_state
         | 
| 12 | 
            +
             | 
| 13 | 
            +
            def _unscale_grads_(self, optimizer, inv_scale, found_inf, allow_fp16): # pylint: disable=unused-argument
                """
                Run the IPEX inf/nan check and unscale kernel over CPU copies of the optimizer's gradients.

                Each gradient is moved to the CPU, bucketed by device and dtype, and every bucket is
                handed to ``core._amp_foreach_non_finite_check_and_unscale_`` together with the CPU
                replicas of ``found_inf`` and ``inv_scale``.

                Args:
                    optimizer: Optimizer whose ``param_groups`` hold the parameters to process.
                    inv_scale: Tensor wrapped in a ``_MultiDeviceReplicator`` and passed to the kernel.
                    found_inf: Tensor wrapped in a ``_MultiDeviceReplicator`` and passed to the kernel.
                    allow_fp16: When false, meeting a float16 gradient raises ``ValueError``.

                Returns:
                    The ``_per_device_tensors`` attribute of the ``found_inf`` replicator.

                Raises:
                    ValueError: A gradient is float16 and ``allow_fp16`` is false.
                """
                inv_scale_replicas = _MultiDeviceReplicator(inv_scale)
                found_inf_replicas = _MultiDeviceReplicator(found_inf)

                # Bucket gradients as {device: {dtype: [tensors]}} so each bucket needs one kernel call.
                # https://stackoverflow.com/questions/5029934/defaultdict-of-defaultdict
                grads_by_device_and_dtype = defaultdict(lambda: defaultdict(list))  # type: ignore[var-annotated]

                # sync grad to master weight
                if hasattr(optimizer, "sync_grad"):
                    optimizer.sync_grad()

                with torch.no_grad():
                    for param_group in optimizer.param_groups:
                        for parameter in param_group["params"]:
                            if parameter.grad is None:
                                continue
                            if (not allow_fp16) and parameter.grad.dtype == torch.float16:
                                raise ValueError("Attempting to unscale FP16 gradients.")

                            if not parameter.grad.is_sparse:
                                grad_values = parameter.grad
                            else:
                                # A non-coalesced sparse grad can hold duplicate indices; coalesce() sums them.
                                # For scaled fp16 values that sum may overflow, so the coalesced values are
                                # the ones that get checked.
                                if parameter.grad.dtype is torch.float16:
                                    parameter.grad = parameter.grad.coalesce()
                                grad_values = parameter.grad._values()

                            # NOTE(review): .to("cpu") returns a copy when the gradient is not already on
                            # the CPU, so the in-place unscale below appears to act on that copy rather
                            # than on the device gradient -- confirm this is intended.
                            grad_values = grad_values.to("cpu")
                            grads_by_device_and_dtype[grad_values.device][grad_values.dtype].append(grad_values)

                    for grads_by_dtype in grads_by_device_and_dtype.values():
                        for grad_bucket in grads_by_dtype.values():
                            core._amp_foreach_non_finite_check_and_unscale_(
                                grad_bucket,
                                found_inf_replicas.get("cpu"),
                                inv_scale_replicas.get("cpu"),
                            )

                return found_inf_replicas._per_device_tensors
         | 
| 60 | 
            +
             | 
| 61 | 
            +
            def unscale_(self, optimizer):
                """
                Divide ("unscale") the optimizer's gradient tensors by the current scale factor.

                Calling this explicitly is only needed when gradients must be inspected or modified
                (for example clipped) between the backward pass and ``step``::

                    scaler.scale(loss).backward()
                    scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
                    scaler.step(optimizer)
                    scaler.update()

                Does nothing when the scaler is disabled.

                Args:
                    optimizer (torch.optim.Optimizer): Optimizer that owns the gradients to be unscaled.

                Raises:
                    RuntimeError: The optimizer's recorded stage is already ``UNSCALED`` or ``STEPPED``.

                .. warning::
                    Sparse gradients may be unscaled out of place, replacing the ``.grad`` attribute.
                """
                if not self._enabled:
                    return

                self._check_scale_growth_tracker("unscale_")

                optimizer_state = self._per_optimizer_states[id(optimizer)]
                stage = optimizer_state["stage"]

                if stage is OptState.UNSCALED:
                    raise RuntimeError(
                        "unscale_() has already been called on this optimizer since the last update()."
                    )
                if stage is OptState.STEPPED:
                    raise RuntimeError("unscale_() is being called after step().")

                # FP32 division can be imprecise for certain compile options, so the reciprocal is taken in FP64.
                assert self._scale is not None
                scale = self._scale
                if device_supports_fp64:
                    inv_scale = scale.double().reciprocal().float()
                else:
                    # No FP64 on the device: take the reciprocal on the CPU and move the result back.
                    inv_scale = scale.to("cpu").double().reciprocal().float().to(scale.device)
                found_inf = torch.full((1,), 0.0, dtype=torch.float32, device=scale.device)

                optimizer_state["found_inf_per_device"] = self._unscale_grads_(optimizer, inv_scale, found_inf, False)
                optimizer_state["stage"] = OptState.UNSCALED
         | 
| 112 | 
            +
             | 
| 113 | 
            +
            def update(self, new_scale=None):
                """
                Update the scale factor.

                If any optimizer steps were skipped the scale is multiplied by ``backoff_factor``
                to reduce it. If ``growth_interval`` unskipped iterations occurred consecutively,
                the scale is multiplied by ``growth_factor`` to increase it.

                Passing ``new_scale`` sets the scale manually. The value is copied into the scaler's
                own scale tensor, so later in-place changes to a tensor passed as ``new_scale`` do
                not affect the scaler.

                The automatic update runs ``core._amp_update_scale_`` on CPU copies of the scale and
                growth tracker. The results are then copied back into the tensors returned by
                ``_check_scale_growth_tracker`` (presumably the scaler's live state -- verify against
                the IPEX GradScaler). Previously the updated values were only rebound to local names
                and were lost whenever those tensors were not already on the CPU.

                Args:
                    new_scale (float or :class:`torch.FloatTensor`, optional, default=None): New scale factor.

                .. warning::
                    :meth:`update` should only be called at the end of the iteration, after
                    ``scaler.step(optimizer)`` has been invoked for all optimizers used this iteration.
                """
                if not self._enabled:
                    return

                _scale, _growth_tracker = self._check_scale_growth_tracker("update")

                if new_scale is not None:
                    # Accept a new user-defined scale.
                    if isinstance(new_scale, float):
                        self._scale.fill_(new_scale)  # type: ignore[union-attr]
                    else:
                        reason = "new_scale should be a float or a 1-element torch.FloatTensor with requires_grad=False."
                        assert isinstance(new_scale, torch.FloatTensor), reason  # type: ignore[attr-defined]
                        assert new_scale.numel() == 1, reason
                        assert new_scale.requires_grad is False, reason
                        self._scale.copy_(new_scale)  # type: ignore[union-attr]
                else:
                    # Consume shared inf/nan data collected from optimizers to update the scale.
                    found_infs = [
                        found_inf.to(device="cpu", non_blocking=True)
                        for state in self._per_optimizer_states.values()
                        for found_inf in state["found_inf_per_device"].values()
                    ]

                    assert len(found_infs) > 0, "No inf checks were recorded prior to update."

                    found_inf_combined = found_infs[0]
                    for extra_found_inf in found_infs[1:]:
                        found_inf_combined += extra_found_inf

                    # The update kernel is run on the CPU, so work on CPU versions of the state.
                    cpu_scale = _scale.to("cpu")
                    cpu_growth_tracker = _growth_tracker.to("cpu")

                    core._amp_update_scale_(
                        cpu_scale,
                        cpu_growth_tracker,
                        found_inf_combined,
                        self._growth_factor,
                        self._backoff_factor,
                        self._growth_interval,
                    )

                    # .to("cpu") hands back a copy for off-CPU tensors; push the updated values back
                    # in place so the new scale and growth count are not discarded.
                    if cpu_scale is not _scale:
                        _scale.copy_(cpu_scale)
                    if cpu_growth_tracker is not _growth_tracker:
                        _growth_tracker.copy_(cpu_growth_tracker)
                # To prepare for next iteration, clear the data collected from optimizers this iteration.
                self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)
         | 
| 177 | 
            +
             | 
| 178 | 
            +
            def gradscaler_init():
                """Expose the IPEX CPU ``GradScaler`` as ``torch.xpu.amp.GradScaler``, patch three of its methods, and return it."""
                scaler_cls = ipex.cpu.autocast._grad_scaler.GradScaler
                torch.xpu.amp.GradScaler = scaler_cls
                # Replace the methods with the variants defined in this module.
                scaler_cls._unscale_grads_ = _unscale_grads_
                scaler_cls.unscale_ = unscale_
                scaler_cls.update = update
                return torch.xpu.amp.GradScaler
         | 
    	
        library/ipex/hijacks.py
    ADDED
    
    | @@ -0,0 +1,313 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import os
         | 
| 2 | 
            +
            from functools import wraps
         | 
| 3 | 
            +
            from contextlib import nullcontext
         | 
| 4 | 
            +
            import torch
         | 
| 5 | 
            +
            import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
         | 
| 6 | 
            +
            import numpy as np
         | 
| 7 | 
            +
             | 
| 8 | 
            +
            device_supports_fp64 = torch.xpu.has_fp64_dtype()
         | 
| 9 | 
            +
             | 
| 10 | 
            +
            # pylint: disable=protected-access, missing-function-docstring, line-too-long, unnecessary-lambda, no-else-return
         | 
| 11 | 
            +
             | 
| 12 | 
            +
            class DummyDataParallel(torch.nn.Module): # pylint: disable=missing-class-docstring, unused-argument, too-few-public-methods
                """Stand-in for ``DataParallel``: constructing it just returns ``module`` moved to ``"xpu"``."""

                def __new__(cls, module, device_ids=None, output_device=None, dim=0): # pylint: disable=unused-argument
                    wants_multiple_devices = isinstance(device_ids, list) and len(device_ids) > 1
                    if wants_multiple_devices:
                        print("IPEX backend doesn't support DataParallel on multiple XPU devices")
                    # The result of module.to() is returned instead of an instance of this class.
                    return module.to("xpu")
         | 
| 17 | 
            +
             | 
| 18 | 
            +
            def return_null_context(*args, **kwargs): # pylint: disable=unused-argument
                """Ignore every argument and return a fresh ``contextlib.nullcontext``."""
                noop_context = nullcontext()
                return noop_context
         | 
| 20 | 
            +
             | 
| 21 | 
            +
            @property
            def is_cuda(self):
                """Report True when ``self.device.type`` is either ``'xpu'`` or ``'cuda'``."""
                return self.device.type in ('xpu', 'cuda')
         | 
| 24 | 
            +
             | 
| 25 | 
            +
            def check_device(device):
                """
                Tell whether ``device`` should be treated as a CUDA device reference.

                True for a ``torch.device`` whose type is ``"cuda"``, for a string containing
                ``"cuda"``, and for any ``int``; False for everything else.
                """
                if isinstance(device, torch.device):
                    return device.type == "cuda"
                if isinstance(device, str):
                    return "cuda" in device
                return isinstance(device, int)
         | 
| 27 | 
            +
             | 
| 28 | 
            +
            def return_xpu(device):
                """
                Translate a device reference into its XPU equivalent, keeping the device index.

                Args:
                    device: A device string, an integer index, a ``torch.device``, or anything else.

                Returns:
                    ``"xpu:<suffix>"`` for a string containing ``":"`` (the suffix is the text after
                    the last colon), ``"xpu:<n>"`` for an int, a ``torch.device`` of type ``xpu`` for a
                    ``torch.device`` (with the same index when one is set), and ``"xpu"`` otherwise.
                """
                if isinstance(device, str):
                    return f"xpu:{device.split(':')[-1]}" if ":" in device else "xpu"
                if isinstance(device, int):
                    return f"xpu:{device}"
                if isinstance(device, torch.device):
                    # Keep the index so that e.g. cuda:1 maps to xpu:1, matching the str and int branches.
                    if device.index is None:
                        return torch.device("xpu")
                    return torch.device("xpu", device.index)
                return "xpu"
         | 
| 30 | 
            +
             | 
| 31 | 
            +
             | 
| 32 | 
            +
            # Autocast
         | 
| 33 | 
            +
            original_autocast_init = torch.amp.autocast_mode.autocast.__init__
            @wraps(torch.amp.autocast_mode.autocast.__init__)
            def autocast_init(self, device_type, dtype=None, enabled=True, cache_enabled=None):
                """Call the original autocast ``__init__``, substituting ``"xpu"`` when ``device_type`` is ``"cuda"``."""
                effective_type = "xpu" if device_type == "cuda" else device_type
                return original_autocast_init(self, device_type=effective_type, dtype=dtype, enabled=enabled, cache_enabled=cache_enabled)
         | 
| 40 | 
            +
             | 
| 41 | 
            +
            # Latent Antialias CPU Offload:
         | 
| 42 | 
            +
            original_interpolate = torch.nn.functional.interpolate
            @wraps(torch.nn.functional.interpolate)
            def interpolate(tensor, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False): # pylint: disable=too-many-arguments
                """
                Wrap ``torch.nn.functional.interpolate`` with a CPU float32 detour for some modes.

                When ``antialias`` is truthy, ``align_corners`` is not None, or ``mode`` is
                ``'bicubic'``, the input is converted to a float32 CPU tensor, interpolated there, and
                the result is converted back to the input's device and dtype. Every other call is
                forwarded unchanged.
                """
                options = {
                    "size": size,
                    "scale_factor": scale_factor,
                    "mode": mode,
                    "align_corners": align_corners,
                    "recompute_scale_factor": recompute_scale_factor,
                    "antialias": antialias,
                }
                needs_cpu_detour = antialias or align_corners is not None or mode == 'bicubic'
                if not needs_cpu_detour:
                    return original_interpolate(tensor, **options)

                source_device, source_dtype = tensor.device, tensor.dtype
                resized = original_interpolate(tensor.to("cpu", dtype=torch.float32), **options)
                return resized.to(source_device, dtype=source_dtype)
         | 
| 53 | 
            +
             | 
| 54 | 
            +
             | 
| 55 | 
            +
            # Diffusers Float64 (Alchemist GPUs don't support 64 bit):
         | 
| 56 | 
            +
            original_from_numpy = torch.from_numpy
            @wraps(torch.from_numpy)
            def from_numpy(ndarray):
                """Convert ``ndarray`` with ``torch.from_numpy``, first casting it to float32 when its dtype equals ``float``."""
                source = ndarray.astype('float32') if ndarray.dtype == float else ndarray
                return original_from_numpy(source)
         | 
| 63 | 
            +
             | 
| 64 | 
            +
            original_as_tensor = torch.as_tensor
            @wraps(torch.as_tensor)
            def as_tensor(data, dtype=None, device=None):
                """
                Wrap ``torch.as_tensor``, redirecting CUDA devices to XPU and forcing some arrays to float32.

                A ``device`` accepted by ``check_device`` is replaced by ``return_xpu(device)``. A NumPy
                array whose dtype equals ``float`` is created as float32 (the ``dtype`` argument is
                ignored in that case) unless the resulting device is a CPU ``torch.device`` or a
                string containing ``"cpu"``.
                """
                target_device = return_xpu(device) if check_device(device) else device
                targets_cpu = (
                    (isinstance(target_device, torch.device) and target_device.type == "cpu")
                    or (isinstance(target_device, str) and "cpu" in target_device)
                )
                is_float_array = isinstance(data, np.ndarray) and data.dtype == float
                if is_float_array and not targets_cpu:
                    return original_as_tensor(data, dtype=torch.float32, device=target_device)
                return original_as_tensor(data, dtype=dtype, device=target_device)
         | 
| 74 | 
            +
             | 
| 75 | 
            +
             | 
| 76 | 
            +
            # Choose the bmm / scaled_dot_product_attention implementations the wrappers below delegate to.
            if not device_supports_fp64 or os.environ.get('IPEX_FORCE_ATTENTION_SLICE') is not None:
                # 32 bit attention workarounds for Alchemist:
                try:
                    from .attention import torch_bmm_32_bit as original_torch_bmm
                    from .attention import scaled_dot_product_attention_32_bit as original_scaled_dot_product_attention
                except Exception: # pylint: disable=broad-exception-caught
                    # Best effort: if the workarounds cannot be imported, use the stock torch functions.
                    original_torch_bmm = torch.bmm
                    original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
            else:
                original_torch_bmm = torch.bmm
                original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
         | 
| 87 | 
            +
             | 
| 88 | 
            +
             | 
| 89 | 
            +
            # Data Type Errors:
         | 
| 90 | 
            +
            @wraps(torch.bmm)
            def torch_bmm(input, mat2, *, out=None):
                """Cast ``mat2`` to ``input``'s dtype when the two differ, then delegate to ``original_torch_bmm``."""
                rhs = mat2 if input.dtype == mat2.dtype else mat2.to(input.dtype)
                return original_torch_bmm(input, rhs, out=out)
         | 
| 95 | 
            +
             | 
| 96 | 
            +
            @wraps(torch.nn.functional.scaled_dot_product_attention)
            def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False):
                """Cast ``key``, ``value`` and ``attn_mask`` to ``query``'s dtype where they differ, then delegate to ``original_scaled_dot_product_attention``."""
                target_dtype = query.dtype
                if key.dtype != target_dtype:
                    key = key.to(dtype=target_dtype)
                if value.dtype != target_dtype:
                    value = value.to(dtype=target_dtype)
                mask_needs_cast = attn_mask is not None and attn_mask.dtype != target_dtype
                if mask_needs_cast:
                    attn_mask = attn_mask.to(dtype=target_dtype)
                return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal)
         | 
| 105 | 
            +
             | 
| 106 | 
            +
            # A1111 FP16
         | 
| 107 | 
            +
            original_functional_group_norm = torch.nn.functional.group_norm
            @wraps(torch.nn.functional.group_norm)
            def functional_group_norm(input, num_groups, weight=None, bias=None, eps=1e-05):
                """Group norm that follows the dtype of ``weight``.

                When ``weight`` is given, ``input`` is cast to its dtype and ``bias``
                (if present) has its ``.data`` replaced in place with a copy in that
                dtype. Without ``weight`` nothing is converted.
                """
                if weight is not None:
                    param_dtype = weight.data.dtype
                    if input.dtype != param_dtype:
                        input = input.to(dtype=param_dtype)
                    # Mutates the caller's bias tensor, exactly like the dtype fix-up it replaces.
                    if bias is not None and bias.data.dtype != param_dtype:
                        bias.data = bias.data.to(dtype=param_dtype)
                return original_functional_group_norm(input, num_groups, weight=weight, bias=bias, eps=eps)
         | 
| 115 | 
            +
             | 
| 116 | 
            +
            # A1111 BF16
         | 
| 117 | 
            +
            original_functional_layer_norm = torch.nn.functional.layer_norm
            @wraps(torch.nn.functional.layer_norm)
            def functional_layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-05):
                """Layer norm that follows the dtype of ``weight``.

                When ``weight`` is given, ``input`` is cast to its dtype and ``bias``
                (if present) has its ``.data`` replaced in place with a copy in that
                dtype. Without ``weight`` nothing is converted.
                """
                if weight is not None:
                    param_dtype = weight.data.dtype
                    if input.dtype != param_dtype:
                        input = input.to(dtype=param_dtype)
                    # In-place update of the caller's bias tensor is part of the existing contract.
                    if bias is not None and bias.data.dtype != param_dtype:
                        bias.data = bias.data.to(dtype=param_dtype)
                return original_functional_layer_norm(input, normalized_shape, weight=weight, bias=bias, eps=eps)
         | 
| 125 | 
            +
             | 
| 126 | 
            +
            # Training
         | 
| 127 | 
            +
            original_functional_linear = torch.nn.functional.linear
            @wraps(torch.nn.functional.linear)
            def functional_linear(input, weight, bias=None):
                """Linear layer call that follows the dtype of ``weight``.

                ``input`` is cast to the weight dtype when they differ, and ``bias``
                (if present) has its ``.data`` replaced in place with a copy in that
                dtype before delegating to the original ``linear``.
                """
                param_dtype = weight.data.dtype
                if input.dtype != param_dtype:
                    input = input.to(dtype=param_dtype)
                # In-place update of the caller's bias tensor is part of the existing contract.
                if bias is not None and bias.data.dtype != param_dtype:
                    bias.data = bias.data.to(dtype=param_dtype)
                return original_functional_linear(input, weight, bias=bias)
         | 
| 135 | 
            +
             | 
| 136 | 
            +
            original_functional_conv2d = torch.nn.functional.conv2d
            @wraps(torch.nn.functional.conv2d)
            def functional_conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
                """2D convolution call that follows the dtype of ``weight``.

                ``input`` is cast to the weight dtype when they differ, and ``bias``
                (if present) has its ``.data`` replaced in place with a copy in that
                dtype before delegating to the original ``conv2d``.
                """
                param_dtype = weight.data.dtype
                if input.dtype != param_dtype:
                    input = input.to(dtype=param_dtype)
                # In-place update of the caller's bias tensor is part of the existing contract.
                if bias is not None and bias.data.dtype != param_dtype:
                    bias.data = bias.data.to(dtype=param_dtype)
                return original_functional_conv2d(
                    input, weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups
                )
         | 
| 144 | 
            +
             | 
| 145 | 
            +
            # A1111 Embedding BF16
         | 
| 146 | 
            +
            original_torch_cat = torch.cat
            @wraps(torch.cat)
            def torch_cat(tensor, *args, **kwargs):
                """Concatenate, harmonising dtypes for the three-element case only.

                When exactly three items are passed and the first or last differs in
                dtype from the middle one, both outer items are cast to the middle
                item's dtype. Every other input is forwarded untouched.
                """
                if len(tensor) == 3:
                    head, pivot, tail = tensor[0], tensor[1], tensor[2]
                    if head.dtype != pivot.dtype or tail.dtype != pivot.dtype:
                        aligned = [head.to(pivot.dtype), pivot, tail.to(pivot.dtype)]
                        return original_torch_cat(aligned, *args, **kwargs)
                return original_torch_cat(tensor, *args, **kwargs)
         | 
| 153 | 
            +
             | 
| 154 | 
            +
            # SwinIR BF16:
         | 
| 155 | 
            +
            original_functional_pad = torch.nn.functional.pad
            @wraps(torch.nn.functional.pad)
            def functional_pad(input, pad, mode='constant', value=None):
                """Pad wrapper that routes bfloat16 reflect-padding through float32.

                For ``mode == 'reflect'`` on a bfloat16 input, the tensor is padded in
                float32 and the result is cast back to bfloat16. All other calls go
                straight to the original ``pad``.
                """
                upcast_needed = mode == 'reflect' and input.dtype == torch.bfloat16
                if not upcast_needed:
                    return original_functional_pad(input, pad, mode=mode, value=value)
                padded = original_functional_pad(input.to(torch.float32), pad, mode=mode, value=value)
                return padded.to(dtype=torch.bfloat16)
         | 
| 162 | 
            +
             | 
| 163 | 
            +
             | 
| 164 | 
            +
            original_torch_tensor = torch.tensor
         | 
| 165 | 
            +
            @wraps(torch.tensor)
            def torch_tensor(data, *args, dtype=None, device=None, **kwargs):
                """``torch.tensor`` wrapper with device remapping and an fp64 downgrade.

                The device is passed through ``return_xpu`` when ``check_device``
                accepts it. If ``device_supports_fp64`` is false and the resulting
                device is an XPU one, a float64 request (explicit ``dtype`` or a
                float64/``float`` ``data.dtype``) is replaced with float32.
                """
                if check_device(device):
                    device = return_xpu(device)
                if not device_supports_fp64:
                    if isinstance(device, torch.device):
                        on_xpu = device.type == "xpu"
                    else:
                        on_xpu = isinstance(device, str) and "xpu" in device
                    if on_xpu:
                        if dtype == torch.float64:
                            dtype = torch.float32
                        elif dtype is None and hasattr(data, "dtype"):
                            # Source data is already double precision; downgrade it too.
                            if data.dtype == torch.float64 or data.dtype == float:
                                dtype = torch.float32
                return original_torch_tensor(data, *args, dtype=dtype, device=device, **kwargs)
         | 
| 176 | 
            +
             | 
| 177 | 
            +
            original_Tensor_to = torch.Tensor.to
         | 
| 178 | 
            +
            @wraps(torch.Tensor.to)
            def Tensor_to(self, device=None, *args, **kwargs):
                """``Tensor.to`` wrapper: remap the first argument via ``return_xpu`` when ``check_device`` accepts it."""
                target = return_xpu(device) if check_device(device) else device
                return original_Tensor_to(self, target, *args, **kwargs)
         | 
| 184 | 
            +
             | 
| 185 | 
            +
            original_Tensor_cuda = torch.Tensor.cuda
         | 
| 186 | 
            +
            @wraps(torch.Tensor.cuda)
            def Tensor_cuda(self, device=None, *args, **kwargs):
                """``Tensor.cuda`` wrapper: remap ``device`` via ``return_xpu`` when ``check_device`` accepts it."""
                target = return_xpu(device) if check_device(device) else device
                return original_Tensor_cuda(self, target, *args, **kwargs)
         | 
| 192 | 
            +
             | 
| 193 | 
            +
            original_Tensor_pin_memory = torch.Tensor.pin_memory
         | 
| 194 | 
            +
            @wraps(torch.Tensor.pin_memory)
            def Tensor_pin_memory(self, device=None, *args, **kwargs):
                """``Tensor.pin_memory`` wrapper that defaults the device to ``"xpu"``.

                A missing device becomes ``"xpu"``; the value is then remapped via
                ``return_xpu`` when ``check_device`` accepts it.
                """
                requested = "xpu" if device is None else device
                target = return_xpu(requested) if check_device(requested) else requested
                return original_Tensor_pin_memory(self, target, *args, **kwargs)
         | 
| 202 | 
            +
             | 
| 203 | 
            +
            original_UntypedStorage_init = torch.UntypedStorage.__init__
         | 
| 204 | 
            +
            @wraps(torch.UntypedStorage.__init__)
            def UntypedStorage_init(*args, device=None, **kwargs):
                """``UntypedStorage.__init__`` wrapper: remap ``device`` via ``return_xpu`` when ``check_device`` accepts it."""
                target = return_xpu(device) if check_device(device) else device
                return original_UntypedStorage_init(*args, device=target, **kwargs)
         | 
| 210 | 
            +
             | 
| 211 | 
            +
            original_UntypedStorage_cuda = torch.UntypedStorage.cuda
         | 
| 212 | 
            +
            @wraps(torch.UntypedStorage.cuda)
            def UntypedStorage_cuda(self, device=None, *args, **kwargs):
                """``UntypedStorage.cuda`` wrapper: remap ``device`` via ``return_xpu`` when ``check_device`` accepts it."""
                target = return_xpu(device) if check_device(device) else device
                return original_UntypedStorage_cuda(self, target, *args, **kwargs)
         | 
| 218 | 
            +
             | 
| 219 | 
            +
            original_torch_empty = torch.empty
         | 
| 220 | 
            +
            @wraps(torch.empty)
            def torch_empty(*args, device=None, **kwargs):
                """``torch.empty`` wrapper: remap ``device`` via ``return_xpu`` when ``check_device`` accepts it."""
                target = return_xpu(device) if check_device(device) else device
                return original_torch_empty(*args, device=target, **kwargs)
         | 
| 226 | 
            +
             | 
| 227 | 
            +
            original_torch_randn = torch.randn
         | 
| 228 | 
            +
            @wraps(torch.randn)
            def torch_randn(*args, device=None, dtype=None, **kwargs):
                """``torch.randn`` wrapper with device remapping.

                ``device`` is passed through ``return_xpu`` when ``check_device``
                accepts it. ``dtype`` is forwarded to the original function; it was
                previously captured by the signature but never passed on, so a
                requested dtype was silently ignored and the default dtype was used.
                """
                # NOTE(review): the builtin `bytes` is treated as "no dtype requested";
                # presumably some caller passes it by mistake - confirm against callers.
                if dtype == bytes:
                    dtype = None
                if check_device(device):
                    device = return_xpu(device)
                return original_torch_randn(*args, device=device, dtype=dtype, **kwargs)
         | 
| 236 | 
            +
             | 
| 237 | 
            +
            original_torch_ones = torch.ones
         | 
| 238 | 
            +
            @wraps(torch.ones)
            def torch_ones(*args, device=None, **kwargs):
                """``torch.ones`` wrapper: remap ``device`` via ``return_xpu`` when ``check_device`` accepts it."""
                target = return_xpu(device) if check_device(device) else device
                return original_torch_ones(*args, device=target, **kwargs)
         | 
| 244 | 
            +
             | 
| 245 | 
            +
            original_torch_zeros = torch.zeros
         | 
| 246 | 
            +
            @wraps(torch.zeros)
            def torch_zeros(*args, device=None, **kwargs):
                """``torch.zeros`` wrapper: remap ``device`` via ``return_xpu`` when ``check_device`` accepts it."""
                target = return_xpu(device) if check_device(device) else device
                return original_torch_zeros(*args, device=target, **kwargs)
         | 
| 252 | 
            +
             | 
| 253 | 
            +
            original_torch_linspace = torch.linspace
         | 
| 254 | 
            +
            @wraps(torch.linspace)
            def torch_linspace(*args, device=None, **kwargs):
                """``torch.linspace`` wrapper: remap ``device`` via ``return_xpu`` when ``check_device`` accepts it."""
                target = return_xpu(device) if check_device(device) else device
                return original_torch_linspace(*args, device=target, **kwargs)
         | 
| 260 | 
            +
             | 
| 261 | 
            +
            original_torch_Generator = torch.Generator
         | 
| 262 | 
            +
            @wraps(torch.Generator)
            def torch_Generator(device=None):
                """Build a generator via ``original_torch_Generator``, remapping ``device`` when ``check_device`` accepts it."""
                target = return_xpu(device) if check_device(device) else device
                return original_torch_Generator(target)
         | 
| 268 | 
            +
             | 
| 269 | 
            +
            original_torch_load = torch.load
         | 
| 270 | 
            +
            @wraps(torch.load)
            def torch_load(f, map_location=None, *args, **kwargs):
                """``torch.load`` wrapper that defaults and remaps ``map_location``.

                A missing ``map_location`` becomes ``"xpu"``; the value is then passed
                through ``return_xpu`` when ``check_device`` accepts it.

                ``map_location`` is forwarded positionally. Forwarding it as a keyword
                after ``*args`` shifted any extra positional argument (for example a
                pickle module) into the ``map_location`` slot and raised
                ``TypeError: got multiple values for argument 'map_location'``.
                """
                if map_location is None:
                    map_location = "xpu"
                if check_device(map_location):
                    map_location = return_xpu(map_location)
                # NOTE(review): torch.load unpickles its input; only use on trusted files.
                return original_torch_load(f, map_location, *args, **kwargs)
         | 
| 278 | 
            +
             | 
| 279 | 
            +
             | 
| 280 | 
            +
            # Hijack Functions:
         | 
| 281 | 
            +
            def ipex_hijacks():
                """Install this module's wrappers over the corresponding torch attributes.

                Every replacement is a plain attribute assignment on ``torch`` (or one of
                its submodules/classes), applied in the order listed. ``torch.from_numpy``
                and ``torch.as_tensor`` are only replaced when ``device_supports_fp64``
                is false.
                """
                # (owner, attribute name, replacement) - applied top to bottom.
                patches = [
                    # Tensor / storage creation and movement.
                    (torch, "tensor", torch_tensor),
                    (torch.Tensor, "to", Tensor_to),
                    (torch.Tensor, "cuda", Tensor_cuda),
                    (torch.Tensor, "pin_memory", Tensor_pin_memory),
                    (torch.UntypedStorage, "__init__", UntypedStorage_init),
                    (torch.UntypedStorage, "cuda", UntypedStorage_cuda),
                    (torch, "empty", torch_empty),
                    (torch, "randn", torch_randn),
                    (torch, "ones", torch_ones),
                    (torch, "zeros", torch_zeros),
                    (torch, "linspace", torch_linspace),
                    (torch, "Generator", torch_Generator),
                    (torch, "load", torch_load),
                    # Stand-ins for CUDA-specific helpers.
                    (torch.backends.cuda, "sdp_kernel", return_null_context),
                    (torch.nn, "DataParallel", DummyDataParallel),
                    (torch.UntypedStorage, "is_cuda", is_cuda),
                    (torch.amp.autocast_mode.autocast, "__init__", autocast_init),
                    # Dtype-tolerant functional ops.
                    (torch.nn.functional, "scaled_dot_product_attention", scaled_dot_product_attention),
                    (torch.nn.functional, "group_norm", functional_group_norm),
                    (torch.nn.functional, "layer_norm", functional_layer_norm),
                    (torch.nn.functional, "linear", functional_linear),
                    (torch.nn.functional, "conv2d", functional_conv2d),
                    (torch.nn.functional, "interpolate", interpolate),
                    (torch.nn.functional, "pad", functional_pad),
                    (torch, "bmm", torch_bmm),
                    (torch, "cat", torch_cat),
                ]
                if not device_supports_fp64:
                    patches.append((torch, "from_numpy", from_numpy))
                    patches.append((torch, "as_tensor", as_tensor))

                for owner, attribute, replacement in patches:
                    setattr(owner, attribute, replacement)
         | 
    	
        library/lpw_stable_diffusion.py
    ADDED
    
    | @@ -0,0 +1,1233 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # copy from https://github.com/huggingface/diffusers/blob/main/examples/community/lpw_stable_diffusion.py
         | 
| 2 | 
            +
            # and modify to support SD2.x
         | 
| 3 | 
            +
             | 
| 4 | 
            +
            import inspect
         | 
| 5 | 
            +
            import re
         | 
| 6 | 
            +
            from typing import Callable, List, Optional, Union
         | 
| 7 | 
            +
             | 
| 8 | 
            +
            import numpy as np
         | 
| 9 | 
            +
            import PIL.Image
         | 
| 10 | 
            +
            import torch
         | 
| 11 | 
            +
            from packaging import version
         | 
| 12 | 
            +
            from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
         | 
| 13 | 
            +
             | 
| 14 | 
            +
            import diffusers
         | 
| 15 | 
            +
            from diffusers import SchedulerMixin, StableDiffusionPipeline
         | 
| 16 | 
            +
            from diffusers.models import AutoencoderKL, UNet2DConditionModel
         | 
| 17 | 
            +
            from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
         | 
| 18 | 
            +
            from diffusers.utils import logging
         | 
| 19 | 
            +
             | 
| 20 | 
            +
            try:
                from diffusers.utils import PIL_INTERPOLATION
            except ImportError:
                # diffusers did not provide the table, so build it from the installed PIL.
                # The constants are read from PIL.Image directly for base versions below
                # 9.1.0 and from PIL.Image.Resampling otherwise.
                if version.parse(version.parse(PIL.__version__).base_version) < version.parse("9.1.0"):
                    PIL_INTERPOLATION = {
                        "linear": PIL.Image.LINEAR,
                        "bilinear": PIL.Image.BILINEAR,
                        "bicubic": PIL.Image.BICUBIC,
                        "lanczos": PIL.Image.LANCZOS,
                        "nearest": PIL.Image.NEAREST,
                    }
                else:
                    PIL_INTERPOLATION = {
                        "linear": PIL.Image.Resampling.BILINEAR,
                        "bilinear": PIL.Image.Resampling.BILINEAR,
                        "bicubic": PIL.Image.Resampling.BICUBIC,
                        "lanczos": PIL.Image.Resampling.LANCZOS,
                        "nearest": PIL.Image.Resampling.NEAREST,
                    }
         | 
| 39 | 
            +
            # ------------------------------------------------------------------------------
         | 
| 40 | 
            +
             | 
| 41 | 
            +
            logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
         | 
| 42 | 
            +
             | 
| 43 | 
            +
            # Tokenizer for prompt-attention syntax. Alternatives are tried in order:
            # backslash-escaped brackets, an escaped backslash, a lone backslash, opening
            # brackets, a ":<number>)" weight suffix (number captured in group 1), closing
            # brackets, a run of ordinary text, and finally a bare colon.
            re_attention = re.compile(
                r"""
            \\\(|
            \\\)|
            \\\[|
            \\]|
            \\\\|
            \\|
            \(|
            \[|
            :([+-]?[.\d]+)\)|
            \)|
            ]|
            [^\\()\[\]:]+|
            :
            """,
                flags=re.VERBOSE,
            )
         | 
| 61 | 
            +
             | 
| 62 | 
            +
             | 
| 63 | 
            +
            def parse_prompt_attention(text):
         | 
| 64 | 
            +
                """
         | 
| 65 | 
            +
                Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
         | 
| 66 | 
            +
                Accepted tokens are:
         | 
| 67 | 
            +
                  (abc) - increases attention to abc by a multiplier of 1.1
         | 
| 68 | 
            +
                  (abc:3.12) - increases attention to abc by a multiplier of 3.12
         | 
| 69 | 
            +
                  [abc] - decreases attention to abc by a multiplier of 1.1
         | 
| 70 | 
            +
                  \( - literal character '('
         | 
| 71 | 
            +
                  \[ - literal character '['
         | 
| 72 | 
            +
                  \) - literal character ')'
         | 
| 73 | 
            +
                  \] - literal character ']'
         | 
| 74 | 
            +
                  \\ - literal character '\'
         | 
| 75 | 
            +
                  anything else - just text
         | 
| 76 | 
            +
                >>> parse_prompt_attention('normal text')
         | 
| 77 | 
            +
                [['normal text', 1.0]]
         | 
| 78 | 
            +
                >>> parse_prompt_attention('an (important) word')
         | 
| 79 | 
            +
                [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
         | 
| 80 | 
            +
                >>> parse_prompt_attention('(unbalanced')
         | 
| 81 | 
            +
                [['unbalanced', 1.1]]
         | 
| 82 | 
            +
                >>> parse_prompt_attention('\(literal\]')
         | 
| 83 | 
            +
                [['(literal]', 1.0]]
         | 
| 84 | 
            +
                >>> parse_prompt_attention('(unnecessary)(parens)')
         | 
| 85 | 
            +
                [['unnecessaryparens', 1.1]]
         | 
| 86 | 
            +
                >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
         | 
| 87 | 
            +
                [['a ', 1.0],
         | 
| 88 | 
            +
                 ['house', 1.5730000000000004],
         | 
| 89 | 
            +
                 [' ', 1.1],
         | 
| 90 | 
            +
                 ['on', 1.0],
         | 
| 91 | 
            +
                 [' a ', 1.1],
         | 
| 92 | 
            +
                 ['hill', 0.55],
         | 
| 93 | 
            +
                 [', sun, ', 1.1],
         | 
| 94 | 
            +
                 ['sky', 1.4641000000000006],
         | 
| 95 | 
            +
                 ['.', 1.1]]
         | 
| 96 | 
            +
                """
         | 
| 97 | 
            +
             | 
| 98 | 
            +
                res = []
         | 
| 99 | 
            +
                round_brackets = []
         | 
| 100 | 
            +
                square_brackets = []
         | 
| 101 | 
            +
             | 
| 102 | 
            +
                round_bracket_multiplier = 1.1
         | 
| 103 | 
            +
                square_bracket_multiplier = 1 / 1.1
         | 
| 104 | 
            +
             | 
| 105 | 
            +
                def multiply_range(start_position, multiplier):
         | 
| 106 | 
            +
                    for p in range(start_position, len(res)):
         | 
| 107 | 
            +
                        res[p][1] *= multiplier
         | 
| 108 | 
            +
             | 
| 109 | 
            +
                for m in re_attention.finditer(text):
         | 
| 110 | 
            +
                    text = m.group(0)
         | 
| 111 | 
            +
                    weight = m.group(1)
         | 
| 112 | 
            +
             | 
| 113 | 
            +
                    if text.startswith("\\"):
         | 
| 114 | 
            +
                        res.append([text[1:], 1.0])
         | 
| 115 | 
            +
                    elif text == "(":
         | 
| 116 | 
            +
                        round_brackets.append(len(res))
         | 
| 117 | 
            +
                    elif text == "[":
         | 
| 118 | 
            +
                        square_brackets.append(len(res))
         | 
| 119 | 
            +
                    elif weight is not None and len(round_brackets) > 0:
         | 
| 120 | 
            +
                        multiply_range(round_brackets.pop(), float(weight))
         | 
| 121 | 
            +
                    elif text == ")" and len(round_brackets) > 0:
         | 
| 122 | 
            +
                        multiply_range(round_brackets.pop(), round_bracket_multiplier)
         | 
| 123 | 
            +
                    elif text == "]" and len(square_brackets) > 0:
         | 
| 124 | 
            +
                        multiply_range(square_brackets.pop(), square_bracket_multiplier)
         | 
| 125 | 
            +
                    else:
         | 
| 126 | 
            +
                        res.append([text, 1.0])
         | 
| 127 | 
            +
             | 
| 128 | 
            +
                for pos in round_brackets:
         | 
| 129 | 
            +
                    multiply_range(pos, round_bracket_multiplier)
         | 
| 130 | 
            +
             | 
| 131 | 
            +
                for pos in square_brackets:
         | 
| 132 | 
            +
                    multiply_range(pos, square_bracket_multiplier)
         | 
| 133 | 
            +
             | 
| 134 | 
            +
                if len(res) == 0:
         | 
| 135 | 
            +
                    res = [["", 1.0]]
         | 
| 136 | 
            +
             | 
| 137 | 
            +
                # merge runs of identical weights
         | 
| 138 | 
            +
                i = 0
         | 
| 139 | 
            +
                while i + 1 < len(res):
         | 
| 140 | 
            +
                    if res[i][1] == res[i + 1][1]:
         | 
| 141 | 
            +
                        res[i][0] += res[i + 1][0]
         | 
| 142 | 
            +
                        res.pop(i + 1)
         | 
| 143 | 
            +
                    else:
         | 
| 144 | 
            +
                        i += 1
         | 
| 145 | 
            +
             | 
| 146 | 
            +
                return res
         | 
| 147 | 
            +
             | 
| 148 | 
            +
             | 
| 149 | 
            +
            def get_prompts_with_weights(pipe: StableDiffusionPipeline, prompt: List[str], max_length: int):
                r"""
                Tokenize each prompt and pair every token id with its attention weight.

                Each prompt is split by `parse_prompt_attention` into weighted fragments;
                every fragment is tokenized with `pipe.tokenizer` and its first and last
                ids (the starting and ending tokens) are dropped, so the result contains
                no padding, starting or ending token.

                Args:
                    pipe: Object whose `tokenizer` is used to tokenize the fragments.
                    prompt: The prompts to tokenize.
                    max_length: Maximum number of token ids kept per prompt.

                Returns:
                    `(tokens, weights)`: two lists with one entry per prompt, holding the
                    token ids and the matching per-token weights. A warning is logged if
                    any prompt had to be truncated.
                """
                all_tokens = []
                all_weights = []
                was_truncated = False
                for entry in prompt:
                    ids = []
                    id_weights = []
                    for fragment, fragment_weight in parse_prompt_attention(entry):
                        # drop the starting and ending token the tokenizer adds
                        fragment_ids = pipe.tokenizer(fragment).input_ids[1:-1]
                        ids.extend(fragment_ids)
                        # one weight per token id of this fragment
                        id_weights.extend([fragment_weight] * len(fragment_ids))
                        # no point tokenizing further once the limit is exceeded
                        if len(ids) > max_length:
                            was_truncated = True
                            break
                    # cut back to the limit
                    if len(ids) > max_length:
                        was_truncated = True
                        ids = ids[:max_length]
                        id_weights = id_weights[:max_length]
                    all_tokens.append(ids)
                    all_weights.append(id_weights)
                if was_truncated:
                    logger.warning("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples")
                return all_tokens, all_weights
         | 
| 182 | 
            +
             | 
| 183 | 
            +
             | 
| 184 | 
            +
            def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, no_boseos_middle=True, chunk_length=77):
                r"""
                Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.

                Both `tokens` and `weights` are modified in place (their elements are
                replaced by the padded lists) and also returned.

                Args:
                    tokens: List of token-id lists, one per prompt, without bos/eos.
                    weights: List of per-token weight lists, parallel to `tokens`.
                    max_length: Target length of every padded token list.
                    bos: Id placed at the start of every token list.
                    eos: Id used for the end token and for all padding positions.
                    no_boseos_middle: If true, weights are padded to `max_length` with a
                        single leading 1.0. Otherwise a 1.0 is inserted before and after
                        every `chunk_length - 2` weights and the result is padded to
                        `max_embeddings_multiples * chunk_length`.
                    chunk_length: Length of one chunk including its two boundary tokens.

                Returns:
                    The `(tokens, weights)` pair after padding.
                """
                body_length = chunk_length - 2  # weights per chunk, excluding the two boundary slots
                max_embeddings_multiples = (max_length - 2) // body_length
                weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length
                for i, token_ids in enumerate(tokens):
                    tokens[i] = [bos] + token_ids + [eos] * (max_length - 1 - len(token_ids))
                    token_weights = weights[i]
                    if no_boseos_middle:
                        weights[i] = [1.0] + token_weights + [1.0] * (max_length - 1 - len(token_weights))
                    elif not token_weights:
                        weights[i] = [1.0] * weights_length
                    else:
                        padded = []
                        for j in range(max_embeddings_multiples):
                            padded.append(1.0)  # weight for starting token in this chunk
                            # slicing clamps at the end of the list, so a short or empty
                            # final chunk needs no explicit bound
                            padded += token_weights[j * body_length : (j + 1) * body_length]
                            padded.append(1.0)  # weight for ending token in this chunk
                        padded += [1.0] * (weights_length - len(padded))
                        weights[i] = padded

                return tokens, weights
         | 
| 207 | 
            +
             | 
| 208 | 
            +
             | 
| 209 | 
            +
            def get_unweighted_text_embeddings(
                pipe: StableDiffusionPipeline,
                text_input: torch.Tensor,
                chunk_length: int,
                clip_skip: int,
                eos: int,
                pad: int,
                no_boseos_middle: Optional[bool] = True,
            ):
                """
                Run the text encoder over `text_input`, chunk by chunk when it is long.

                When the token length spans more than one multiple of the text encoder's
                capacity, the input is split into overlapping windows of `chunk_length`
                tokens, each window is encoded separately, and the results are
                concatenated along dimension 1.

                Args:
                    pipe: Object providing `text_encoder`.
                    text_input: Token ids, indexed as (row, position).
                    chunk_length: Number of tokens per encoder call, boundaries included.
                    clip_skip: If None or 1, the encoder's first output is used; otherwise
                        `hidden_states[-clip_skip]` passed through the final layer norm.
                    eos: Ending token id.
                    pad: Padding token id.
                    no_boseos_middle: If true, the boundary embeddings between chunks are
                        dropped so only the very first and very last remain.

                Returns:
                    The text embeddings produced by the encoder.
                """
                body_length = chunk_length - 2
                max_embeddings_multiples = (text_input.shape[1] - 2) // body_length

                def encode(token_ids):
                    # Encoder call shared by the chunked and the single-pass path.
                    if clip_skip is None or clip_skip == 1:
                        return pipe.text_encoder(token_ids)[0]
                    enc_out = pipe.text_encoder(token_ids, output_hidden_states=True, return_dict=True)
                    hidden = enc_out["hidden_states"][-clip_skip]
                    return pipe.text_encoder.text_model.final_layer_norm(hidden)

                if max_embeddings_multiples <= 1:
                    return encode(text_input)

                last_chunk = max_embeddings_multiples - 1
                chunk_embeddings = []
                for i in range(max_embeddings_multiples):
                    # extract the i-th chunk: body_length tokens plus one slot on each side
                    start = i * body_length
                    chunk = text_input[:, start : start + chunk_length].clone()

                    # cover the head and the tail by the starting and the ending tokens
                    chunk[:, 0] = text_input[0, 0]
                    if pad == eos:  # v1
                        chunk[:, -1] = text_input[0, -1]
                    else:  # v2
                        for row in range(len(chunk)):
                            # the last position holds an ordinary token
                            if chunk[row, -1] != eos and chunk[row, -1] != pad:
                                chunk[row, -1] = eos
                            # only BOS, everything after it is PAD
                            if chunk[row, 1] == pad:
                                chunk[row, 1] = eos

                    embedding = encode(chunk)

                    if no_boseos_middle:
                        if i == 0:
                            # discard the ending token
                            embedding = embedding[:, :-1]
                        elif i == last_chunk:
                            # discard the starting token
                            embedding = embedding[:, 1:]
                        else:
                            # discard both starting and ending tokens
                            embedding = embedding[:, 1:-1]

                    chunk_embeddings.append(embedding)

                return torch.concat(chunk_embeddings, axis=1)
         | 
| 268 | 
            +
             | 
| 269 | 
            +
             | 
| 270 | 
            +
            def get_weighted_text_embeddings(
                pipe: StableDiffusionPipeline,
                prompt: Union[str, List[str]],
                uncond_prompt: Optional[Union[str, List[str]]] = None,
                max_embeddings_multiples: Optional[int] = 3,
                no_boseos_middle: Optional[bool] = False,
                skip_parsing: Optional[bool] = False,
                skip_weighting: Optional[bool] = False,
                clip_skip=None,
            ):
                r"""
                Compute text embeddings for prompts that may carry bracket weights.

                Prompts can be assigned local weights using brackets. For example, in the
                prompt 'A (very beautiful) masterpiece' the words 'very beautiful' are
                highlighted, and the embeddings of their tokens are multiplied by a
                constant, 1.1.

                To regularize the result, the weighted embedding is rescaled so that its
                mean matches the mean of the unweighted embedding.

                Args:
                    pipe (`StableDiffusionPipeline`):
                        Pipe to provide access to the tokenizer and the text encoder.
                    prompt (`str` or `List[str]`):
                        The prompt or prompts to guide the image generation.
                    uncond_prompt (`str` or `List[str]`):
                        The unconditional prompt or prompts. If given, its embeddings are
                        computed the same way and returned as the second element.
                    max_embeddings_multiples (`int`, *optional*, defaults to `3`):
                        The max multiple length of prompt embeddings compared to the max output length of text encoder.
                    no_boseos_middle (`bool`, *optional*, defaults to `False`):
                        When the token length spans several multiples of the text encoder's
                        capacity, whether to drop the starting and ending tokens of the
                        chunks in the middle.
                    skip_parsing (`bool`, *optional*, defaults to `False`):
                        Skip the parsing of brackets.
                    skip_weighting (`bool`, *optional*, defaults to `False`):
                        Skip the weighting. Weighting is also skipped whenever parsing is skipped.
                    clip_skip (*optional*):
                        Forwarded unchanged to `get_unweighted_text_embeddings`.

                Returns:
                    `(text_embeddings, uncond_embeddings)`; the second element is `None`
                    when no `uncond_prompt` was given.
                """
                body_length = pipe.tokenizer.model_max_length - 2
                token_budget = body_length * max_embeddings_multiples + 2
                has_uncond = uncond_prompt is not None

                if isinstance(prompt, str):
                    prompt = [prompt]
                if has_uncond and isinstance(uncond_prompt, str):
                    uncond_prompt = [uncond_prompt]

                def tokenize_plain(texts):
                    # Tokenize without bracket parsing: strip the boundary ids, weight 1.0 everywhere.
                    encoded = pipe.tokenizer(texts, max_length=token_budget, truncation=True).input_ids
                    stripped = [token[1:-1] for token in encoded]
                    return stripped, [[1.0] * len(token) for token in stripped]

                if skip_parsing:
                    prompt_tokens, prompt_weights = tokenize_plain(prompt)
                    if has_uncond:
                        uncond_tokens, uncond_weights = tokenize_plain(uncond_prompt)
                else:
                    prompt_tokens, prompt_weights = get_prompts_with_weights(pipe, prompt, token_budget - 2)
                    if has_uncond:
                        uncond_tokens, uncond_weights = get_prompts_with_weights(pipe, uncond_prompt, token_budget - 2)

                # round up the longest length of tokens to a multiple of (model_max_length - 2)
                longest = max(len(token) for token in prompt_tokens)
                if has_uncond:
                    longest = max(longest, max(len(token) for token in uncond_tokens))

                max_embeddings_multiples = min(max_embeddings_multiples, (longest - 1) // body_length + 1)
                max_embeddings_multiples = max(1, max_embeddings_multiples)
                max_length = body_length * max_embeddings_multiples + 2

                bos = pipe.tokenizer.bos_token_id
                eos = pipe.tokenizer.eos_token_id
                pad = pipe.tokenizer.pad_token_id

                def pad_to_tensor(token_lists, weight_lists):
                    # Pad tokens and weights, then move the tokens onto the pipe's device.
                    token_lists, weight_lists = pad_tokens_and_weights(
                        token_lists,
                        weight_lists,
                        max_length,
                        bos,
                        eos,
                        no_boseos_middle=no_boseos_middle,
                        chunk_length=pipe.tokenizer.model_max_length,
                    )
                    return torch.tensor(token_lists, dtype=torch.long, device=pipe.device), weight_lists

                def embed(token_tensor, weight_lists):
                    # Encode the tokens and build a weight tensor of matching dtype.
                    embeddings = get_unweighted_text_embeddings(
                        pipe,
                        token_tensor,
                        pipe.tokenizer.model_max_length,
                        clip_skip,
                        eos,
                        pad,
                        no_boseos_middle=no_boseos_middle,
                    )
                    return embeddings, torch.tensor(weight_lists, dtype=embeddings.dtype, device=pipe.device)

                def apply_weights(embeddings, weight_tensor):
                    # In-place: scale by the weights, then restore the pre-weighting mean.
                    previous_mean = embeddings.float().mean(axis=[-2, -1]).to(embeddings.dtype)
                    embeddings *= weight_tensor.unsqueeze(-1)
                    current_mean = embeddings.float().mean(axis=[-2, -1]).to(embeddings.dtype)
                    embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)

                # pad the length of tokens and weights
                prompt_tokens, prompt_weights = pad_to_tensor(prompt_tokens, prompt_weights)
                if has_uncond:
                    uncond_tokens, uncond_weights = pad_to_tensor(uncond_tokens, uncond_weights)

                # get the embeddings
                text_embeddings, prompt_weights = embed(prompt_tokens, prompt_weights)
                uncond_embeddings = None
                if has_uncond:
                    uncond_embeddings, uncond_weights = embed(uncond_tokens, uncond_weights)

                # assign weights to the prompts and normalize in the sense of mean
                # TODO: should we normalize by chunk or in a whole (current implementation)?
                if not (skip_parsing or skip_weighting):
                    apply_weights(text_embeddings, prompt_weights)
                    if has_uncond:
                        apply_weights(uncond_embeddings, uncond_weights)

                return text_embeddings, uncond_embeddings
         | 
| 403 | 
            +
             | 
| 404 | 
            +
             | 
| 405 | 
            +
            def preprocess_image(image):
                """
                Convert an image into a batched float tensor.

                The image is resized so both sides are multiples of 32 (rounded down),
                converted to float32 and divided by 255, given a leading batch axis, and
                its axes are permuted from (0, 1, 2, 3) to (0, 3, 1, 2). The result is
                finally mapped through ``2 * x - 1``.

                Args:
                    image: Object exposing `size` and `resize` (presumably a PIL image with
                        a channel axis - confirm against callers).

                Returns:
                    A torch tensor built from the processed array.
                """
                width, height = image.size
                # resize to integer multiple of 32
                width -= width % 32
                height -= height % 32
                resized = image.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
                pixels = np.array(resized).astype(np.float32) / 255.0
                batch = torch.from_numpy(pixels[None].transpose(0, 3, 1, 2))
                return 2.0 * batch - 1.0
         | 
| 413 | 
            +
             | 
| 414 | 
            +
             | 
| 415 | 
            +
            def preprocess_mask(mask, scale_factor=8):
                """
                Convert a mask image into an inverted, downscaled 4-channel float tensor.

                The mask is converted to mode "L", its sides are rounded down to multiples
                of 32 and then divided by `scale_factor` for a nearest-neighbour resize.
                The result is scaled by 1/255, repeated 4 times along a new leading axis,
                given a batch axis (shape ``(1, 4, h, w)``) and inverted with ``1 - mask``.

                Args:
                    mask: Object exposing `convert`, `size` and `resize` (presumably a PIL
                        image - confirm against callers).
                    scale_factor: Integer divisor applied to both sides before resizing.

                Returns:
                    A torch tensor of shape ``(1, 4, h, w)`` holding the inverted mask.
                """
                mask = mask.convert("L")
                w, h = mask.size
                w, h = (side - side % 32 for side in (w, h))  # resize to integer multiple of 32
                mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL_INTERPOLATION["nearest"])
                mask = np.array(mask).astype(np.float32) / 255.0
                mask = np.tile(mask, (4, 1, 1))
                # Add the batch axis.  The original followed this with transpose(0, 1, 2, 3),
                # which is the identity permutation and therefore did nothing; it is dropped.
                mask = mask[None]
                mask = 1 - mask  # repaint white, keep black
                return torch.from_numpy(mask)
         | 
| 426 | 
            +
             | 
| 427 | 
            +
             | 
| 428 | 
            +
            def prepare_controlnet_image(
                image: PIL.Image.Image,
                width: int,
                height: int,
                batch_size: int,
                num_images_per_prompt: int,
                device: torch.device,
                dtype: torch.dtype,
                do_classifier_free_guidance: bool = False,
                guess_mode: bool = False,
            ):
                """Build the batched conditioning tensor passed to a ControlNet.

                Accepts a single PIL image, a list of PIL images, a list of tensors, or
                an already-assembled tensor. PIL images are converted to RGB, resized to
                ``(width, height)`` with Lanczos resampling, scaled to ``[0, 1]`` and
                moved to channel-first layout. A list of tensors is concatenated along
                dim 0. A tensor input is used as is.

                Args:
                    image: Conditioning image(s) in one of the forms above.
                    width: Target width used when resizing PIL inputs.
                    height: Target height used when resizing PIL inputs.
                    batch_size: Repeat count used when the input holds a single image.
                    num_images_per_prompt: Repeat count used when the input already
                        holds more than one image.
                    device: Device the result is moved to.
                    dtype: Dtype the result is cast to.
                    do_classifier_free_guidance: When true (and ``guess_mode`` is false)
                        the batch is concatenated with itself.
                    guess_mode: Disables the duplication above.

                Returns:
                    The conditioning images as a ``torch.Tensor`` on ``device``.
                """
                if not isinstance(image, torch.Tensor):
                    if isinstance(image, PIL.Image.Image):
                        image = [image]

                    first = image[0]
                    if isinstance(first, PIL.Image.Image):
                        # One (1, H, W, 3) uint8 array per input image.
                        frames = [
                            np.array(
                                pil_image.convert("RGB").resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
                            )[None, :]
                            for pil_image in image
                        ]
                        stacked = np.concatenate(frames, axis=0).astype(np.float32) / 255.0
                        # NHWC -> NCHW
                        image = torch.from_numpy(stacked.transpose(0, 3, 1, 2))
                    elif isinstance(first, torch.Tensor):
                        image = torch.cat(image, dim=0)

                # A single image is repeated for the whole batch; otherwise the image
                # batch size is the same as the prompt batch size.
                repeat_by = batch_size if image.shape[0] == 1 else num_images_per_prompt
                image = image.repeat_interleave(repeat_by, dim=0)
                image = image.to(device=device, dtype=dtype)

                if do_classifier_free_guidance and not guess_mode:
                    image = torch.cat([image, image])

                return image
         | 
| 478 | 
            +
             | 
| 479 | 
            +
             | 
| 480 | 
            +
            class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
         | 
| 481 | 
            +
                r"""
         | 
| 482 | 
            +
                Pipeline for text-to-image generation using Stable Diffusion without tokens length limit, and support parsing
         | 
| 483 | 
            +
                weighting in prompt.
         | 
| 484 | 
            +
             | 
| 485 | 
            +
                This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
         | 
| 486 | 
            +
                library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
         | 
| 487 | 
            +
             | 
| 488 | 
            +
                Args:
         | 
| 489 | 
            +
                    vae ([`AutoencoderKL`]):
         | 
| 490 | 
            +
                        Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
         | 
| 491 | 
            +
                    text_encoder ([`CLIPTextModel`]):
         | 
| 492 | 
            +
                        Frozen text-encoder. Stable Diffusion uses the text portion of
         | 
| 493 | 
            +
                        [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
         | 
| 494 | 
            +
                        the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
         | 
| 495 | 
            +
                    tokenizer (`CLIPTokenizer`):
         | 
| 496 | 
            +
                        Tokenizer of class
         | 
| 497 | 
            +
                        [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
         | 
| 498 | 
            +
                    unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
         | 
| 499 | 
            +
                    scheduler ([`SchedulerMixin`]):
         | 
| 500 | 
            +
                        A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
         | 
| 501 | 
            +
                        [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
         | 
| 502 | 
            +
                    safety_checker ([`StableDiffusionSafetyChecker`]):
         | 
| 503 | 
            +
                        Classification module that estimates whether generated images could be considered offensive or harmful.
         | 
| 504 | 
            +
                        Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
         | 
| 505 | 
            +
                    feature_extractor ([`CLIPFeatureExtractor`]):
         | 
| 506 | 
            +
                        Model that extracts features from generated images to be used as inputs for the `safety_checker`.
         | 
| 507 | 
            +
                """
         | 
| 508 | 
            +
             | 
| 509 | 
            +
                # if version.parse(version.parse(diffusers.__version__).base_version) >= version.parse("0.9.0"):
         | 
| 510 | 
            +
             | 
| 511 | 
            +
                def __init__(
                    self,
                    vae: AutoencoderKL,
                    text_encoder: CLIPTextModel,
                    tokenizer: CLIPTokenizer,
                    unet: UNet2DConditionModel,
                    scheduler: SchedulerMixin,
                    # clip_skip: int,
                    safety_checker: StableDiffusionSafetyChecker,
                    feature_extractor: CLIPFeatureExtractor,
                    requires_safety_checker: bool = True,
                    image_encoder: CLIPVisionModelWithProjection = None,
                    clip_skip: int = 1,
                ):
                    """Build the pipeline from its components.

                    All model components are forwarded unchanged to the parent
                    constructor. ``clip_skip`` is not forwarded; it is kept on the
                    instance as ``custom_clip_skip``.

                    Args:
                        vae: Forwarded to the parent constructor.
                        text_encoder: Forwarded to the parent constructor.
                        tokenizer: Forwarded to the parent constructor.
                        unet: Forwarded to the parent constructor.
                        scheduler: Forwarded to the parent constructor.
                        safety_checker: Forwarded to the parent constructor.
                        feature_extractor: Forwarded to the parent constructor.
                        requires_safety_checker: Forwarded to the parent constructor.
                        image_encoder: Forwarded to the parent constructor.
                        clip_skip: Stored as ``self.custom_clip_skip``.
                    """
                    super().__init__(
                        vae=vae,
                        text_encoder=text_encoder,
                        tokenizer=tokenizer,
                        unet=unet,
                        scheduler=scheduler,
                        safety_checker=safety_checker,
                        feature_extractor=feature_extractor,
                        requires_safety_checker=requires_safety_checker,
                        image_encoder=image_encoder,
                    )
                    self.custom_clip_skip = clip_skip
                    # Fills in attributes the parent constructor may not have set.
                    self.__init__additional__()
         | 
| 538 | 
            +
             | 
| 539 | 
            +
                def __init__additional__(self):
         | 
| 540 | 
            +
                    if not hasattr(self, "vae_scale_factor"):
         | 
| 541 | 
            +
                        setattr(self, "vae_scale_factor", 2 ** (len(self.vae.config.block_out_channels) - 1))
         | 
| 542 | 
            +
             | 
| 543 | 
            +
                @property
         | 
| 544 | 
            +
                def _execution_device(self):
         | 
| 545 | 
            +
                    r"""
         | 
| 546 | 
            +
                    Returns the device on which the pipeline's models will be executed. After calling
         | 
| 547 | 
            +
                    `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
         | 
| 548 | 
            +
                    hooks.
         | 
| 549 | 
            +
                    """
         | 
| 550 | 
            +
                    if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
         | 
| 551 | 
            +
                        return self.device
         | 
| 552 | 
            +
                    for module in self.unet.modules():
         | 
| 553 | 
            +
                        if (
         | 
| 554 | 
            +
                            hasattr(module, "_hf_hook")
         | 
| 555 | 
            +
                            and hasattr(module._hf_hook, "execution_device")
         | 
| 556 | 
            +
                            and module._hf_hook.execution_device is not None
         | 
| 557 | 
            +
                        ):
         | 
| 558 | 
            +
                            return torch.device(module._hf_hook.execution_device)
         | 
| 559 | 
            +
                    return self.device
         | 
| 560 | 
            +
             | 
| 561 | 
            +
                def _encode_prompt(
                    self,
                    prompt,
                    device,
                    num_images_per_prompt,
                    do_classifier_free_guidance,
                    negative_prompt,
                    max_embeddings_multiples,
                ):
                    r"""
                    Encode the prompt(s) into weighted text embeddings, one copy per requested image.

                    Args:
                        prompt (`str` or `list`):
                            prompt, or list of prompts, to be encoded
                        device: (`torch.device`):
                            torch device (not referenced inside this method)
                        num_images_per_prompt (`int`):
                            number of images that should be generated per prompt
                        do_classifier_free_guidance (`bool`):
                            whether to use classifier free guidance or not
                        negative_prompt (`str` or `List[str]`):
                            The prompt or prompts not to guide the image generation. `None` is treated as an empty
                            string per prompt, and a single string is reused for every prompt. Only encoded when
                            `do_classifier_free_guidance` is true.
                        max_embeddings_multiples (`int`):
                            Forwarded to `get_weighted_text_embeddings`.

                    Returns:
                        The text embeddings. With classifier free guidance the unconditional embeddings are
                        concatenated in front of them along the batch dimension.

                    Raises:
                        ValueError: if the number of negative prompts differs from the number of prompts.
                    """
                    num_prompts = len(prompt) if isinstance(prompt, list) else 1

                    # Normalise the negative prompt to one entry per prompt.
                    if negative_prompt is None:
                        negative_prompt = [""] * num_prompts
                    elif isinstance(negative_prompt, str):
                        negative_prompt = [negative_prompt] * num_prompts
                    if len(negative_prompt) != num_prompts:
                        raise ValueError(
                            f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                            f" {prompt} has batch size {num_prompts}. Please make sure that passed `negative_prompt` matches"
                            " the batch size of `prompt`."
                        )

                    def _per_image(embeddings):
                        # Duplicate each prompt's embedding for every image generated from it.
                        count, seq_len, _ = embeddings.shape
                        tiled = embeddings.repeat(1, num_images_per_prompt, 1)
                        return tiled.view(count * num_images_per_prompt, seq_len, -1)

                    text_embeddings, uncond_embeddings = get_weighted_text_embeddings(
                        pipe=self,
                        prompt=prompt,
                        uncond_prompt=negative_prompt if do_classifier_free_guidance else None,
                        max_embeddings_multiples=max_embeddings_multiples,
                        clip_skip=self.custom_clip_skip,
                    )
                    text_embeddings = _per_image(text_embeddings)

                    if not do_classifier_free_guidance:
                        return text_embeddings

                    return torch.cat([_per_image(uncond_embeddings), text_embeddings])
         | 
| 619 | 
            +
             | 
| 620 | 
            +
                def check_inputs(self, prompt, height, width, strength, callback_steps):
                    """Validate the user-supplied generation arguments.

                    Args:
                        prompt: Must be a `str` or a `list`.
                        height: Must be divisible by 8.
                        width: Must be divisible by 8.
                        strength: Must lie in `[0.0, 1.0]`; NaN is rejected as well.
                        callback_steps: Must be a positive `int` (`None` is not accepted).

                    Raises:
                        ValueError: if any argument violates the constraints above.
                    """
                    if not isinstance(prompt, (str, list)):
                        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

                    # Written as a chained comparison so that NaN, for which both `< 0` and
                    # `> 1` are false, is rejected too.
                    if not (0 <= strength <= 1):
                        raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")

                    # The offending values are reported in the exception itself; no separate
                    # log line is emitted.
                    if height % 8 != 0 or width % 8 != 0:
                        raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

                    # `None` fails the isinstance test, so it needs no separate branch.
                    if not isinstance(callback_steps, int) or callback_steps <= 0:
                        raise ValueError(
                            f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                            f" {type(callback_steps)}."
                        )
         | 
| 637 | 
            +
             | 
| 638 | 
            +
                def get_timesteps(self, num_inference_steps, strength, device, is_text2img):
                    """Select the scheduler timesteps to run and report how many there are.

                    Args:
                        num_inference_steps: Number of steps the scheduler was configured with.
                        strength: Fraction of the schedule to run when starting from an image.
                        device: Device the returned timesteps are moved to.
                        is_text2img: When true the whole schedule is used and `strength` is ignored.

                    Returns:
                        A `(timesteps, step_count)` tuple.
                    """
                    schedule = self.scheduler.timesteps
                    if is_text2img:
                        return schedule.to(device), num_inference_steps

                    # get the original timestep using init_timestep
                    steps_offset = self.scheduler.config.get("steps_offset", 0)
                    init_timestep = min(int(num_inference_steps * strength) + steps_offset, num_inference_steps)

                    # Skip the leading part of the schedule that the strength does not cover.
                    first_index = max(num_inference_steps - init_timestep + steps_offset, 0)
                    return schedule[first_index:].to(device), num_inference_steps - first_index
         | 
| 650 | 
            +
             | 
| 651 | 
            +
                def run_safety_checker(self, image, device, dtype):
                    """Run the safety checker on `image` if one is configured.

                    Args:
                        image: Images passed to `self.numpy_to_pil` and to the safety checker.
                        device: Device the feature-extractor output is moved to.
                        dtype: Dtype the extracted pixel values are cast to.

                    Returns:
                        `(image, has_nsfw_concept)` as returned by the safety checker, or
                        `(image, None)` with the input unchanged when no checker is set.
                    """
                    if self.safety_checker is None:
                        return image, None

                    checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
                    clip_input = checker_input.pixel_values.to(dtype)
                    image, has_nsfw_concept = self.safety_checker(images=image, clip_input=clip_input)
                    return image, has_nsfw_concept
         | 
| 658 | 
            +
             | 
| 659 | 
            +
                def decode_latents(self, latents):
                    """Decode latents with the VAE into a channel-last float32 NumPy array.

                    The latents are multiplied by `1 / 0.18215` before decoding; the decoded
                    sample is mapped through `x / 2 + 0.5` and clamped to `[0, 1]`.

                    Args:
                        latents: Latent tensor handed to `self.vae.decode`.

                    Returns:
                        `numpy.ndarray` of the decoded images with dims permuted `(0, 2, 3, 1)`.
                    """
                    rescaled = 1 / 0.18215 * latents
                    decoded = self.vae.decode(rescaled).sample
                    pixels = (decoded / 2 + 0.5).clamp(0, 1)
                    # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
                    channel_last = pixels.cpu().permute(0, 2, 3, 1)
                    return channel_last.float().numpy()
         | 
| 666 | 
            +
             | 
| 667 | 
            +
                def prepare_extra_step_kwargs(self, generator, eta):
                    """Collect the optional keyword arguments accepted by `self.scheduler.step`.

                    Not all schedulers share the same `step` signature, so `eta` and
                    `generator` are only included when `step` declares a parameter of that
                    name.

                    Args:
                        generator: Value passed as `generator` when supported.
                        eta: Value passed as `eta` when supported. eta (η) corresponds to η in
                            the DDIM paper (https://arxiv.org/abs/2010.02502) and should be
                            between [0, 1].

                    Returns:
                        Dict containing the supported subset of `{"eta", "generator"}`.
                    """
                    step_parameters = inspect.signature(self.scheduler.step).parameters
                    optional_arguments = (("eta", eta), ("generator", generator))
                    return {name: value for name, value in optional_arguments if name in step_parameters}
         | 
| 683 | 
            +
             | 
| 684 | 
            +
                def prepare_latents(self, image, timestep, batch_size, height, width, dtype, device, generator, latents=None):
                    """Create the starting latents for the denoising loop.

                    Without `image`, Gaussian noise is drawn (or the supplied `latents` are
                    used after a shape check) and scaled by the scheduler's
                    `init_noise_sigma`. With `image`, it is encoded by the VAE, scaled by
                    0.18215, repeated `batch_size` times and noised via
                    `self.scheduler.add_noise` at `timestep`.

                    Args:
                        image: Input handed to `self.vae.encode`, or `None`.
                        timestep: Timestep handed to `add_noise`; unused when `image` is `None`.
                        batch_size: Number of latents to create (or repeat count of the encoded image).
                        height: Image height; divided by `self.vae_scale_factor` for the latent shape.
                        width: Image width; divided by `self.vae_scale_factor` for the latent shape.
                        dtype: Dtype of the sampled noise.
                        device: Target `torch.device`.
                        generator: Optional `torch.Generator` used for all sampling.
                        latents: Optional pre-generated latents; only consulted when `image` is `None`.

                    Returns:
                        `(latents, None, None)` when `image` is `None`, otherwise
                        `(noised_latents, clean_image_latents, noise)`.

                    Raises:
                        ValueError: if supplied `latents` do not have the expected shape.
                    """

                    def _sample_noise(noise_shape):
                        if device.type == "mps":
                            # randn does not work reproducibly on mps
                            return torch.randn(noise_shape, generator=generator, device="cpu", dtype=dtype).to(device)
                        return torch.randn(noise_shape, generator=generator, device=device, dtype=dtype)

                    if image is not None:
                        encoded = self.vae.encode(image).latent_dist.sample(generator=generator)
                        image_latents = torch.cat([0.18215 * encoded] * batch_size, dim=0)

                        # add noise to latents using the timesteps
                        noise = _sample_noise(image_latents.shape)
                        noised_latents = self.scheduler.add_noise(image_latents, noise, timestep)
                        return noised_latents, image_latents, noise

                    expected_shape = (
                        batch_size,
                        self.unet.in_channels,
                        height // self.vae_scale_factor,
                        width // self.vae_scale_factor,
                    )

                    if latents is None:
                        latents = _sample_noise(expected_shape)
                    elif latents.shape != expected_shape:
                        raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {expected_shape}")
                    else:
                        latents = latents.to(device)

                    # scale the initial noise by the standard deviation required by the scheduler
                    return latents * self.scheduler.init_noise_sigma, None, None
         | 
| 722 | 
            +
             | 
| 723 | 
            +
                @torch.no_grad()
         | 
| 724 | 
            +
                def __call__(
         | 
| 725 | 
            +
                    self,
         | 
| 726 | 
            +
                    prompt: Union[str, List[str]],
         | 
| 727 | 
            +
                    negative_prompt: Optional[Union[str, List[str]]] = None,
         | 
| 728 | 
            +
                    image: Union[torch.FloatTensor, PIL.Image.Image] = None,
         | 
| 729 | 
            +
                    mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
         | 
| 730 | 
            +
                    height: int = 512,
         | 
| 731 | 
            +
                    width: int = 512,
         | 
| 732 | 
            +
                    num_inference_steps: int = 50,
         | 
| 733 | 
            +
                    guidance_scale: float = 7.5,
         | 
| 734 | 
            +
                    strength: float = 0.8,
         | 
| 735 | 
            +
                    num_images_per_prompt: Optional[int] = 1,
         | 
| 736 | 
            +
                    eta: float = 0.0,
         | 
| 737 | 
            +
                    generator: Optional[torch.Generator] = None,
         | 
| 738 | 
            +
                    latents: Optional[torch.FloatTensor] = None,
         | 
| 739 | 
            +
                    max_embeddings_multiples: Optional[int] = 3,
         | 
| 740 | 
            +
                    output_type: Optional[str] = "pil",
         | 
| 741 | 
            +
                    return_dict: bool = True,
         | 
| 742 | 
            +
                    controlnet=None,
         | 
| 743 | 
            +
                    controlnet_image=None,
         | 
| 744 | 
            +
                    callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
         | 
| 745 | 
            +
                    is_cancelled_callback: Optional[Callable[[], bool]] = None,
         | 
| 746 | 
            +
                    callback_steps: int = 1,
         | 
| 747 | 
            +
                ):
         | 
| 748 | 
            +
                    r"""
         | 
| 749 | 
            +
                    Function invoked when calling the pipeline for generation.
         | 
| 750 | 
            +
             | 
| 751 | 
            +
                    Args:
         | 
| 752 | 
            +
                        prompt (`str` or `List[str]`):
         | 
| 753 | 
            +
                            The prompt or prompts to guide the image generation.
         | 
| 754 | 
            +
                        negative_prompt (`str` or `List[str]`, *optional*):
         | 
| 755 | 
            +
                            The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
         | 
| 756 | 
            +
                            if `guidance_scale` is less than `1`).
         | 
| 757 | 
            +
                        image (`torch.FloatTensor` or `PIL.Image.Image`):
         | 
| 758 | 
            +
                            `Image`, or tensor representing an image batch, that will be used as the starting point for the
         | 
| 759 | 
            +
                            process.
         | 
| 760 | 
            +
                        mask_image (`torch.FloatTensor` or `PIL.Image.Image`):
         | 
| 761 | 
            +
                            `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
         | 
| 762 | 
            +
                            replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
         | 
| 763 | 
            +
                            PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
         | 
| 764 | 
            +
                            contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.
         | 
| 765 | 
            +
                        height (`int`, *optional*, defaults to 512):
         | 
| 766 | 
            +
                            The height in pixels of the generated image.
         | 
| 767 | 
            +
                        width (`int`, *optional*, defaults to 512):
         | 
| 768 | 
            +
                            The width in pixels of the generated image.
         | 
| 769 | 
            +
                        num_inference_steps (`int`, *optional*, defaults to 50):
         | 
| 770 | 
            +
                            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
         | 
| 771 | 
            +
                            expense of slower inference.
         | 
| 772 | 
            +
                        guidance_scale (`float`, *optional*, defaults to 7.5):
         | 
| 773 | 
            +
                            Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
         | 
| 774 | 
            +
                            `guidance_scale` is defined as `w` of equation 2. of [Imagen
         | 
| 775 | 
            +
                            Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
         | 
| 776 | 
            +
                            1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
         | 
| 777 | 
            +
                            usually at the expense of lower image quality.
         | 
| 778 | 
            +
                        strength (`float`, *optional*, defaults to 0.8):
         | 
| 779 | 
            +
                            Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1.
         | 
| 780 | 
            +
                            `image` will be used as a starting point, adding more noise to it the larger the `strength`. The
         | 
| 781 | 
            +
                            number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added
         | 
| 782 | 
            +
                            noise will be maximum and the denoising process will run for the full number of iterations specified in
         | 
| 783 | 
            +
                            `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
         | 
| 784 | 
            +
                        num_images_per_prompt (`int`, *optional*, defaults to 1):
         | 
| 785 | 
            +
                            The number of images to generate per prompt.
         | 
| 786 | 
            +
                        eta (`float`, *optional*, defaults to 0.0):
         | 
| 787 | 
            +
                            Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
         | 
| 788 | 
            +
                            [`schedulers.DDIMScheduler`], will be ignored for others.
         | 
| 789 | 
            +
                        generator (`torch.Generator`, *optional*):
         | 
| 790 | 
            +
                            A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
         | 
| 791 | 
            +
                            deterministic.
         | 
| 792 | 
            +
                        latents (`torch.FloatTensor`, *optional*):
         | 
| 793 | 
            +
                            Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
         | 
| 794 | 
            +
                            generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
         | 
| 795 | 
            +
                            tensor will be generated by sampling using the supplied random `generator`.
         | 
| 796 | 
            +
                        max_embeddings_multiples (`int`, *optional*, defaults to `3`):
         | 
| 797 | 
            +
                            The max multiple length of prompt embeddings compared to the max output length of text encoder.
         | 
| 798 | 
            +
                        output_type (`str`, *optional*, defaults to `"pil"`):
         | 
| 799 | 
            +
                            The output format of the generated image. Choose between
         | 
| 800 | 
            +
                            [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
         | 
| 801 | 
            +
                        return_dict (`bool`, *optional*, defaults to `True`):
         | 
| 802 | 
            +
                            Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
         | 
| 803 | 
            +
                            plain tuple.
         | 
| 804 | 
            +
                        controlnet (`diffusers.ControlNetModel`, *optional*):
         | 
| 805 | 
            +
                            A controlnet model to be used for the inference. If not provided, controlnet will be disabled.
         | 
| 806 | 
            +
                        controlnet_image (`torch.FloatTensor` or `PIL.Image.Image`, *optional*):
         | 
| 807 | 
            +
                            `Image`, or tensor representing an image batch, to be used as the starting point for the controlnet
         | 
| 808 | 
            +
                            inference.
         | 
| 809 | 
            +
                        callback (`Callable`, *optional*):
         | 
| 810 | 
            +
                            A function that will be called every `callback_steps` steps during inference. The function will be
         | 
| 811 | 
            +
                            called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
         | 
| 812 | 
            +
                        is_cancelled_callback (`Callable`, *optional*):
         | 
| 813 | 
            +
                            A function that will be called every `callback_steps` steps during inference. If the function returns
         | 
| 814 | 
            +
                            `True`, the inference will be cancelled.
         | 
| 815 | 
            +
                        callback_steps (`int`, *optional*, defaults to 1):
         | 
| 816 | 
            +
                            The frequency at which the `callback` function will be called. If not specified, the callback will be
         | 
| 817 | 
            +
                            called at every step.
         | 
| 818 | 
            +
             | 
| 819 | 
            +
                    Returns:
         | 
| 820 | 
            +
                        `None` if cancelled by `is_cancelled_callback`,
         | 
| 821 | 
            +
                        [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
         | 
| 822 | 
            +
                        [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
         | 
| 823 | 
            +
                        When returning a tuple, the first element is a list with the generated images, and the second element is a
         | 
| 824 | 
            +
                        list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
         | 
| 825 | 
            +
                        (nsfw) content, according to the `safety_checker`.
         | 
| 826 | 
            +
                    """
         | 
| 827 | 
            +
                    if controlnet is not None and controlnet_image is None:
         | 
| 828 | 
            +
                        raise ValueError("controlnet_image must be provided if controlnet is not None.")
         | 
| 829 | 
            +
             | 
| 830 | 
            +
                    # 0. Default height and width to unet
         | 
| 831 | 
            +
                    height = height or self.unet.config.sample_size * self.vae_scale_factor
         | 
| 832 | 
            +
                    width = width or self.unet.config.sample_size * self.vae_scale_factor
         | 
| 833 | 
            +
             | 
| 834 | 
            +
                    # 1. Check inputs. Raise error if not correct
         | 
| 835 | 
            +
                    self.check_inputs(prompt, height, width, strength, callback_steps)
         | 
| 836 | 
            +
             | 
| 837 | 
            +
                    # 2. Define call parameters
         | 
| 838 | 
            +
                    batch_size = 1 if isinstance(prompt, str) else len(prompt)
         | 
| 839 | 
            +
                    device = self._execution_device
         | 
| 840 | 
            +
                    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
         | 
| 841 | 
            +
                    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
         | 
| 842 | 
            +
                    # corresponds to doing no classifier free guidance.
         | 
| 843 | 
            +
                    do_classifier_free_guidance = guidance_scale > 1.0
         | 
| 844 | 
            +
             | 
| 845 | 
            +
                    # 3. Encode input prompt
         | 
| 846 | 
            +
                    text_embeddings = self._encode_prompt(
         | 
| 847 | 
            +
                        prompt,
         | 
| 848 | 
            +
                        device,
         | 
| 849 | 
            +
                        num_images_per_prompt,
         | 
| 850 | 
            +
                        do_classifier_free_guidance,
         | 
| 851 | 
            +
                        negative_prompt,
         | 
| 852 | 
            +
                        max_embeddings_multiples,
         | 
| 853 | 
            +
                    )
         | 
| 854 | 
            +
                    dtype = text_embeddings.dtype
         | 
| 855 | 
            +
             | 
| 856 | 
            +
                    # 4. Preprocess image and mask
         | 
| 857 | 
            +
                    if isinstance(image, PIL.Image.Image):
         | 
| 858 | 
            +
                        image = preprocess_image(image)
         | 
| 859 | 
            +
                    if image is not None:
         | 
| 860 | 
            +
                        image = image.to(device=self.device, dtype=dtype)
         | 
| 861 | 
            +
                    if isinstance(mask_image, PIL.Image.Image):
         | 
| 862 | 
            +
                        mask_image = preprocess_mask(mask_image, self.vae_scale_factor)
         | 
| 863 | 
            +
                    if mask_image is not None:
         | 
| 864 | 
            +
                        mask = mask_image.to(device=self.device, dtype=dtype)
         | 
| 865 | 
            +
                        mask = torch.cat([mask] * batch_size * num_images_per_prompt)
         | 
| 866 | 
            +
                    else:
         | 
| 867 | 
            +
                        mask = None
         | 
| 868 | 
            +
             | 
| 869 | 
            +
                    if controlnet_image is not None:
         | 
| 870 | 
            +
                        controlnet_image = prepare_controlnet_image(
         | 
| 871 | 
            +
                            controlnet_image, width, height, batch_size, 1, self.device, controlnet.dtype, do_classifier_free_guidance, False
         | 
| 872 | 
            +
                        )
         | 
| 873 | 
            +
             | 
| 874 | 
            +
                    # 5. set timesteps
         | 
| 875 | 
            +
                    self.scheduler.set_timesteps(num_inference_steps, device=device)
         | 
| 876 | 
            +
                    timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device, image is None)
         | 
| 877 | 
            +
                    latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
         | 
| 878 | 
            +
             | 
| 879 | 
            +
                    # 6. Prepare latent variables
         | 
| 880 | 
            +
                    latents, init_latents_orig, noise = self.prepare_latents(
         | 
| 881 | 
            +
                        image,
         | 
| 882 | 
            +
                        latent_timestep,
         | 
| 883 | 
            +
                        batch_size * num_images_per_prompt,
         | 
| 884 | 
            +
                        height,
         | 
| 885 | 
            +
                        width,
         | 
| 886 | 
            +
                        dtype,
         | 
| 887 | 
            +
                        device,
         | 
| 888 | 
            +
                        generator,
         | 
| 889 | 
            +
                        latents,
         | 
| 890 | 
            +
                    )
         | 
| 891 | 
            +
             | 
| 892 | 
            +
                    # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         | 
| 893 | 
            +
                    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
         | 
| 894 | 
            +
             | 
| 895 | 
            +
                    # 8. Denoising loop
         | 
| 896 | 
            +
                    for i, t in enumerate(self.progress_bar(timesteps)):
         | 
| 897 | 
            +
                        # expand the latents if we are doing classifier free guidance
         | 
| 898 | 
            +
                        latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
         | 
| 899 | 
            +
                        latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
         | 
| 900 | 
            +
             | 
| 901 | 
            +
                        unet_additional_args = {}
         | 
| 902 | 
            +
                        if controlnet is not None:
         | 
| 903 | 
            +
                            down_block_res_samples, mid_block_res_sample = controlnet(
         | 
| 904 | 
            +
                                latent_model_input,
         | 
| 905 | 
            +
                                t,
         | 
| 906 | 
            +
                                encoder_hidden_states=text_embeddings,
         | 
| 907 | 
            +
                                controlnet_cond=controlnet_image,
         | 
| 908 | 
            +
                                conditioning_scale=1.0,
         | 
| 909 | 
            +
                                guess_mode=False,
         | 
| 910 | 
            +
                                return_dict=False,
         | 
| 911 | 
            +
                            )
         | 
| 912 | 
            +
                            unet_additional_args["down_block_additional_residuals"] = down_block_res_samples
         | 
| 913 | 
            +
                            unet_additional_args["mid_block_additional_residual"] = mid_block_res_sample
         | 
| 914 | 
            +
             | 
| 915 | 
            +
                        # predict the noise residual
         | 
| 916 | 
            +
                        noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings, **unet_additional_args).sample
         | 
| 917 | 
            +
             | 
| 918 | 
            +
                        # perform guidance
         | 
| 919 | 
            +
                        if do_classifier_free_guidance:
         | 
| 920 | 
            +
                            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
         | 
| 921 | 
            +
                            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
         | 
| 922 | 
            +
             | 
| 923 | 
            +
                        # compute the previous noisy sample x_t -> x_t-1
         | 
| 924 | 
            +
                        latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
         | 
| 925 | 
            +
             | 
| 926 | 
            +
                        if mask is not None:
         | 
| 927 | 
            +
                            # masking
         | 
| 928 | 
            +
                            init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t]))
         | 
| 929 | 
            +
                            latents = (init_latents_proper * mask) + (latents * (1 - mask))
         | 
| 930 | 
            +
             | 
| 931 | 
            +
                        # call the callback, if provided
         | 
| 932 | 
            +
                        if i % callback_steps == 0:
         | 
| 933 | 
            +
                            if callback is not None:
         | 
| 934 | 
            +
                                callback(i, t, latents)
         | 
| 935 | 
            +
                            if is_cancelled_callback is not None and is_cancelled_callback():
         | 
| 936 | 
            +
                                return None
         | 
| 937 | 
            +
             | 
| 938 | 
            +
                    return latents
         | 
| 939 | 
            +
             | 
| 940 | 
            +
                def latents_to_image(self, latents):
                    """
                    Turn latents into images.

                    The latents are cast to `self.vae.dtype`, handed to `self.decode_latents`, and the
                    decoded result is converted with `self.numpy_to_pil`; that converted value is returned.

                    Args:
                        latents:
                            Object exposing `.to(dtype)`. NOTE(review): presumably the tensor returned by the
                            denoising routine of this pipeline - confirm against callers.

                    Returns:
                        Whatever `self.numpy_to_pil` returns for the decoded latents.
                    """
                    # Post-processing: cast first so the decoder sees its own dtype, then decode and convert.
                    decoded = self.decode_latents(latents.to(self.vae.dtype))
                    return self.numpy_to_pil(decoded)
         | 
| 945 | 
            +
             | 
| 946 | 
            +
                def text2img(
                    self,
                    prompt: Union[str, List[str]],
                    negative_prompt: Optional[Union[str, List[str]]] = None,
                    height: int = 512,
                    width: int = 512,
                    num_inference_steps: int = 50,
                    guidance_scale: float = 7.5,
                    num_images_per_prompt: Optional[int] = 1,
                    eta: float = 0.0,
                    generator: Optional[torch.Generator] = None,
                    latents: Optional[torch.FloatTensor] = None,
                    max_embeddings_multiples: Optional[int] = 3,
                    output_type: Optional[str] = "pil",
                    return_dict: bool = True,
                    callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
                    is_cancelled_callback: Optional[Callable[[], bool]] = None,
                    callback_steps: int = 1,
                ):
                    r"""
                    Function for text-to-image generation.

                    Thin convenience wrapper: every argument is forwarded by keyword, unchanged, to
                    `self.__call__`. No `image` or `mask_image` is passed.

                    Args:
                        prompt (`str` or `List[str]`):
                            The prompt or prompts to guide the image generation.
                        negative_prompt (`str` or `List[str]`, *optional*):
                            The prompt or prompts not to guide the image generation. Ignored when classifier-free
                            guidance is not in use.
                        height (`int`, *optional*, defaults to 512):
                            The height in pixels of the generated image.
                        width (`int`, *optional*, defaults to 512):
                            The width in pixels of the generated image.
                        num_inference_steps (`int`, *optional*, defaults to 50):
                            The number of denoising steps. More denoising steps usually lead to a higher quality
                            image at the expense of slower inference.
                        guidance_scale (`float`, *optional*, defaults to 7.5):
                            Guidance scale as defined in [Classifier-Free Diffusion
                            Guidance](https://arxiv.org/abs/2207.12598); it is the `w` of equation 2 of the [Imagen
                            Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance is enabled by setting
                            `guidance_scale > 1`. A higher guidance scale encourages images that are closely linked
                            to the text `prompt`, usually at the expense of lower image quality.
                        num_images_per_prompt (`int`, *optional*, defaults to 1):
                            The number of images to generate per prompt.
                        eta (`float`, *optional*, defaults to 0.0):
                            Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502.
                            Only applies to [`schedulers.DDIMScheduler`], will be ignored for others.
                        generator (`torch.Generator`, *optional*):
                            A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to
                            make generation deterministic.
                        latents (`torch.FloatTensor`, *optional*):
                            Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs
                            for image generation. Can be used to tweak the same generation with different prompts.
                            If not provided, a latents tensor will be generated by sampling using the supplied
                            random `generator`.
                        max_embeddings_multiples (`int`, *optional*, defaults to `3`):
                            The max multiple length of prompt embeddings compared to the max output length of the
                            text encoder.
                        output_type (`str`, *optional*, defaults to `"pil"`):
                            The output format of the generated image. Choose between
                            [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
                            NOTE(review): only forwarded here; confirm that `__call__` still honours it.
                        return_dict (`bool`, *optional*, defaults to `True`):
                            Forwarded to `__call__`. NOTE(review): documented upstream as choosing between a
                            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] and a plain tuple; confirm
                            that `__call__` still honours it.
                        callback (`Callable`, *optional*):
                            A function that will be called every `callback_steps` steps during inference. The
                            function will be called with the following arguments:
                            `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
                        is_cancelled_callback (`Callable`, *optional*):
                            A function that will be called every `callback_steps` steps during inference. If the
                            function returns `True`, the inference will be cancelled.
                        callback_steps (`int`, *optional*, defaults to 1):
                            The frequency at which the `callback` function will be called. If not specified, the
                            callback will be called at every step.

                    Returns:
                        Whatever `self.__call__` returns for these arguments.
                        NOTE(review): the denoising routine defined directly above this method returns the final
                        latents, or `None` when cancelled through `is_cancelled_callback`. Presumably that routine
                        is `__call__`; if so, the result is not a pipeline output or tuple - confirm.
                    """
                    # Collect the pass-through arguments first so the delegation itself is a single call.
                    forwarded = dict(
                        prompt=prompt,
                        negative_prompt=negative_prompt,
                        height=height,
                        width=width,
                        num_inference_steps=num_inference_steps,
                        guidance_scale=guidance_scale,
                        num_images_per_prompt=num_images_per_prompt,
                        eta=eta,
                        generator=generator,
                        latents=latents,
                        max_embeddings_multiples=max_embeddings_multiples,
                        output_type=output_type,
                        return_dict=return_dict,
                        callback=callback,
                        is_cancelled_callback=is_cancelled_callback,
                        callback_steps=callback_steps,
                    )
                    return self.__call__(**forwarded)
         | 
| 1040 | 
            +
             | 
| 1041 | 
            +
                def img2img(
         | 
| 1042 | 
            +
                    self,
         | 
| 1043 | 
            +
                    image: Union[torch.FloatTensor, PIL.Image.Image],
         | 
| 1044 | 
            +
                    prompt: Union[str, List[str]],
         | 
| 1045 | 
            +
                    negative_prompt: Optional[Union[str, List[str]]] = None,
         | 
| 1046 | 
            +
                    strength: float = 0.8,
         | 
| 1047 | 
            +
                    num_inference_steps: Optional[int] = 50,
         | 
| 1048 | 
            +
                    guidance_scale: Optional[float] = 7.5,
         | 
| 1049 | 
            +
                    num_images_per_prompt: Optional[int] = 1,
         | 
| 1050 | 
            +
                    eta: Optional[float] = 0.0,
         | 
| 1051 | 
            +
                    generator: Optional[torch.Generator] = None,
         | 
| 1052 | 
            +
                    max_embeddings_multiples: Optional[int] = 3,
         | 
| 1053 | 
            +
                    output_type: Optional[str] = "pil",
         | 
| 1054 | 
            +
                    return_dict: bool = True,
         | 
| 1055 | 
            +
                    callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
         | 
| 1056 | 
            +
                    is_cancelled_callback: Optional[Callable[[], bool]] = None,
         | 
| 1057 | 
            +
                    callback_steps: int = 1,
         | 
| 1058 | 
            +
                ):
         | 
| 1059 | 
            +
                    r"""
         | 
| 1060 | 
            +
                    Function for image-to-image generation.
         | 
| 1061 | 
            +
                    Args:
         | 
| 1062 | 
            +
                        image (`torch.FloatTensor` or `PIL.Image.Image`):
         | 
| 1063 | 
            +
                            `Image`, or tensor representing an image batch, that will be used as the starting point for the
         | 
| 1064 | 
            +
                            process.
         | 
| 1065 | 
            +
                        prompt (`str` or `List[str]`):
         | 
| 1066 | 
            +
                            The prompt or prompts to guide the image generation.
         | 
| 1067 | 
            +
                        negative_prompt (`str` or `List[str]`, *optional*):
         | 
| 1068 | 
            +
                            The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
         | 
| 1069 | 
            +
                            if `guidance_scale` is less than `1`).
         | 
| 1070 | 
            +
                        strength (`float`, *optional*, defaults to 0.8):
         | 
| 1071 | 
            +
                            Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1.
         | 
| 1072 | 
            +
                            `image` will be used as a starting point, adding more noise to it the larger the `strength`. The
         | 
| 1073 | 
            +
                            number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added
         | 
| 1074 | 
            +
                            noise will be maximum and the denoising process will run for the full number of iterations specified in
         | 
| 1075 | 
            +
                            `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
         | 
| 1076 | 
            +
                        num_inference_steps (`int`, *optional*, defaults to 50):
         | 
| 1077 | 
            +
                            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
         | 
| 1078 | 
            +
                            expense of slower inference. This parameter will be modulated by `strength`.
         | 
| 1079 | 
            +
                        guidance_scale (`float`, *optional*, defaults to 7.5):
         | 
| 1080 | 
            +
                            Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
         | 
| 1081 | 
            +
                            `guidance_scale` is defined as `w` of equation 2. of [Imagen
         | 
| 1082 | 
            +
                            Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
         | 
| 1083 | 
            +
                            1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
         | 
| 1084 | 
            +
                            usually at the expense of lower image quality.
         | 
| 1085 | 
            +
                        num_images_per_prompt (`int`, *optional*, defaults to 1):
         | 
| 1086 | 
            +
                            The number of images to generate per prompt.
         | 
| 1087 | 
            +
                        eta (`float`, *optional*, defaults to 0.0):
         | 
| 1088 | 
            +
                            Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
         | 
| 1089 | 
            +
                            [`schedulers.DDIMScheduler`], will be ignored for others.
         | 
| 1090 | 
            +
                        generator (`torch.Generator`, *optional*):
         | 
| 1091 | 
            +
                            A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
         | 
| 1092 | 
            +
                            deterministic.
         | 
| 1093 | 
            +
                        max_embeddings_multiples (`int`, *optional*, defaults to `3`):
         | 
| 1094 | 
            +
                            The max multiple length of prompt embeddings compared to the max output length of text encoder.
         | 
| 1095 | 
            +
                        output_type (`str`, *optional*, defaults to `"pil"`):
         | 
| 1096 | 
            +
                            The output format of the generate image. Choose between
         | 
| 1097 | 
            +
                            [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
         | 
| 1098 | 
            +
                        return_dict (`bool`, *optional*, defaults to `True`):
         | 
| 1099 | 
            +
                            Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
         | 
| 1100 | 
            +
                            plain tuple.
         | 
| 1101 | 
            +
                        callback (`Callable`, *optional*):
         | 
| 1102 | 
            +
                            A function that will be called every `callback_steps` steps during inference. The function will be
         | 
| 1103 | 
            +
                            called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
         | 
| 1104 | 
            +
                        is_cancelled_callback (`Callable`, *optional*):
         | 
| 1105 | 
            +
                            A function that will be called every `callback_steps` steps during inference. If the function returns
         | 
| 1106 | 
            +
                            `True`, the inference will be cancelled.
         | 
| 1107 | 
            +
                        callback_steps (`int`, *optional*, defaults to 1):
         | 
| 1108 | 
            +
                            The frequency at which the `callback` function will be called. If not specified, the callback will be
         | 
| 1109 | 
            +
                            called at every step.
         | 
| 1110 | 
            +
                    Returns:
         | 
| 1111 | 
            +
                        [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
         | 
| 1112 | 
            +
                        [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
         | 
| 1113 | 
            +
                        When returning a tuple, the first element is a list with the generated images, and the second element is a
         | 
| 1114 | 
            +
                        list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
         | 
| 1115 | 
            +
                        (nsfw) content, according to the `safety_checker`.
         | 
| 1116 | 
            +
                    """
         | 
| 1117 | 
            +
                    return self.__call__(
         | 
| 1118 | 
            +
                        prompt=prompt,
         | 
| 1119 | 
            +
                        negative_prompt=negative_prompt,
         | 
| 1120 | 
            +
                        image=image,
         | 
| 1121 | 
            +
                        num_inference_steps=num_inference_steps,
         | 
| 1122 | 
            +
                        guidance_scale=guidance_scale,
         | 
| 1123 | 
            +
                        strength=strength,
         | 
| 1124 | 
            +
                        num_images_per_prompt=num_images_per_prompt,
         | 
| 1125 | 
            +
                        eta=eta,
         | 
| 1126 | 
            +
                        generator=generator,
         | 
| 1127 | 
            +
                        max_embeddings_multiples=max_embeddings_multiples,
         | 
| 1128 | 
            +
                        output_type=output_type,
         | 
| 1129 | 
            +
                        return_dict=return_dict,
         | 
| 1130 | 
            +
                        callback=callback,
         | 
| 1131 | 
            +
                        is_cancelled_callback=is_cancelled_callback,
         | 
| 1132 | 
            +
                        callback_steps=callback_steps,
         | 
| 1133 | 
            +
                    )
         | 
| 1134 | 
            +
             | 
| 1135 | 
            +
                def inpaint(
                    self,
                    image: Union[torch.FloatTensor, PIL.Image.Image],
                    mask_image: Union[torch.FloatTensor, PIL.Image.Image],
                    prompt: Union[str, List[str]],
                    negative_prompt: Optional[Union[str, List[str]]] = None,
                    strength: float = 0.8,
                    num_inference_steps: Optional[int] = 50,
                    guidance_scale: Optional[float] = 7.5,
                    num_images_per_prompt: Optional[int] = 1,
                    eta: Optional[float] = 0.0,
                    generator: Optional[torch.Generator] = None,
                    max_embeddings_multiples: Optional[int] = 3,
                    output_type: Optional[str] = "pil",
                    return_dict: bool = True,
                    callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
                    is_cancelled_callback: Optional[Callable[[], bool]] = None,
                    callback_steps: int = 1,
                ):
                    r"""
                    Inpainting entry point.

                    This is a convenience wrapper: every argument is forwarded, by keyword, to `__call__` and the
                    result of that call is returned unchanged.

                    Args:
                        image (`torch.FloatTensor` or `PIL.Image.Image`):
                            `Image`, or tensor representing an image batch, used as the starting point of the process.
                            The masked region of this image is the part that gets inpainted.
                        mask_image (`torch.FloatTensor` or `PIL.Image.Image`):
                            `Image`, or tensor representing an image batch, used to mask `image`. White pixels are
                            replaced by noise and therefore repainted; black pixels are preserved. A PIL image is
                            converted to a single (luminance) channel before use. A tensor should carry one color
                            channel (L) rather than 3, i.e. the expected shape is `(B, H, W, 1)`.
                        prompt (`str` or `List[str]`):
                            The prompt or prompts that guide the image generation.
                        negative_prompt (`str` or `List[str]`, *optional*):
                            The prompt or prompts that should not guide the image generation. Ignored when guidance
                            is not used (i.e. when `guidance_scale` is less than `1`).
                        strength (`float`, *optional*, defaults to 0.8):
                            Conceptually, how strongly the masked area is inpainted. Must be between 0 and 1. `image`
                            serves as the reference for the masked area, and the larger `strength` is, the more noise
                            is added to that region. With `strength` equal to 1 the denoising process runs on the
                            masked area for the full `num_inference_steps` iterations; with `strength` equal to 0 no
                            inpainting occurs.
                        num_inference_steps (`int`, *optional*, defaults to 50):
                            The reference number of denoising steps. More steps usually give a higher quality image
                            at the cost of slower inference. The effective number is modulated by `strength`, as
                            described above.
                        guidance_scale (`float`, *optional*, defaults to 7.5):
                            Guidance scale as defined in [Classifier-Free Diffusion
                            Guidance](https://arxiv.org/abs/2207.12598); it corresponds to `w` in equation 2 of the
                            [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance is enabled by setting
                            `guidance_scale > 1`. Higher values push the result closer to the text `prompt`, usually
                            at the expense of image quality.
                        num_images_per_prompt (`int`, *optional*, defaults to 1):
                            The number of images to generate per prompt.
                        eta (`float`, *optional*, defaults to 0.0):
                            Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502.
                            Only applies to [`schedulers.DDIMScheduler`]; ignored by other schedulers.
                        generator (`torch.Generator`, *optional*):
                            A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) used
                            to make generation deterministic.
                        max_embeddings_multiples (`int`, *optional*, defaults to `3`):
                            The maximum length of the prompt embeddings, expressed as a multiple of the text
                            encoder's maximum output length.
                        output_type (`str`, *optional*, defaults to `"pil"`):
                            The output format of the generated image. Choose between
                            [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
                        return_dict (`bool`, *optional*, defaults to `True`):
                            Whether to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead
                            of a plain tuple.
                        callback (`Callable`, *optional*):
                            A function called every `callback_steps` steps during inference, with the arguments
                            `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
                        is_cancelled_callback (`Callable`, *optional*):
                            A function called every `callback_steps` steps during inference. If it returns `True`,
                            the inference is cancelled.
                        callback_steps (`int`, *optional*, defaults to 1):
                            How often `callback` is called. If not specified, it is called at every step.

                    Returns:
                        [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
                        [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True,
                        otherwise a `tuple`. When a tuple is returned, its first element is a list with the generated
                        images and its second element is a list of `bool`s denoting whether the corresponding
                        generated image likely represents "not-safe-for-work" (nsfw) content, according to the
                        `safety_checker`.
                    """
                    # Gather the forwarded arguments once; keyword order mirrors a direct `__call__` invocation.
                    call_kwargs = dict(
                        prompt=prompt,
                        negative_prompt=negative_prompt,
                        image=image,
                        mask_image=mask_image,
                        num_inference_steps=num_inference_steps,
                        guidance_scale=guidance_scale,
                        strength=strength,
                        num_images_per_prompt=num_images_per_prompt,
                        eta=eta,
                        generator=generator,
                        max_embeddings_multiples=max_embeddings_multiples,
                        output_type=output_type,
                        return_dict=return_dict,
                        callback=callback,
                        is_cancelled_callback=is_cancelled_callback,
                        callback_steps=callback_steps,
                    )
                    return self.__call__(**call_kwargs)
         | 
    	
        library/model_util.py
    ADDED
    
    | @@ -0,0 +1,1356 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # v1: split from train_db_fixed.py.
         | 
| 2 | 
            +
            # v2: support safetensors
         | 
| 3 | 
            +
             | 
| 4 | 
            +
            import math
         | 
| 5 | 
            +
            import os
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            import torch
         | 
| 8 | 
            +
            from library.device_utils import init_ipex
         | 
| 9 | 
            +
            init_ipex()
         | 
| 10 | 
            +
             | 
| 11 | 
            +
            import diffusers
         | 
| 12 | 
            +
            from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig, logging
         | 
| 13 | 
            +
            from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline  # , UNet2DConditionModel
         | 
| 14 | 
            +
            from safetensors.torch import load_file, save_file
         | 
| 15 | 
            +
            from library.original_unet import UNet2DConditionModel
         | 
| 16 | 
            +
            from library.utils import setup_logging
         | 
| 17 | 
            +
            setup_logging()
         | 
| 18 | 
            +
            import logging
         | 
| 19 | 
            +
            logger = logging.getLogger(__name__)
         | 
| 20 | 
            +
             | 
| 21 | 
            +
            # Model parameters for the Diffusers version of Stable Diffusion
            NUM_TRAIN_TIMESTEPS = 1000
            BETA_START = 0.00085
            BETA_END = 0.0120

            # UNet hyperparameters (presumably the Stable Diffusion v1 defaults — verify against callers)
            UNET_PARAMS_MODEL_CHANNELS = 320
            UNET_PARAMS_CHANNEL_MULT = [1, 2, 4, 4]
            UNET_PARAMS_ATTENTION_RESOLUTIONS = [4, 2, 1]
            UNET_PARAMS_IMAGE_SIZE = 64  # fixed from old invalid value `32`
            UNET_PARAMS_IN_CHANNELS = 4
            UNET_PARAMS_OUT_CHANNELS = 4
            UNET_PARAMS_NUM_RES_BLOCKS = 2
            UNET_PARAMS_CONTEXT_DIM = 768
            UNET_PARAMS_NUM_HEADS = 8
            # UNET_PARAMS_USE_LINEAR_PROJECTION = False

            # VAE hyperparameters
            VAE_PARAMS_Z_CHANNELS = 4
            VAE_PARAMS_RESOLUTION = 256
            VAE_PARAMS_IN_CHANNELS = 3
            VAE_PARAMS_OUT_CH = 3
            VAE_PARAMS_CH = 128
            VAE_PARAMS_CH_MULT = [1, 2, 4, 4]
            VAE_PARAMS_NUM_RES_BLOCKS = 2

            # V2 (presumably these replace the corresponding v1 UNet values above — TODO confirm at the use sites)
            V2_UNET_PARAMS_ATTENTION_HEAD_DIM = [5, 10, 20, 20]
            V2_UNET_PARAMS_CONTEXT_DIM = 1024
            # V2_UNET_PARAMS_USE_LINEAR_PROJECTION = True

            # Reference models used to load the Diffusers configuration
            DIFFUSERS_REF_MODEL_ID_V1 = "runwayml/stable-diffusion-v1-5"
            DIFFUSERS_REF_MODEL_ID_V2 = "stabilityai/stable-diffusion-2-1"
         | 
| 53 | 
            +
             | 
| 54 | 
            +
             | 
| 55 | 
            +
            # region Conversion code: StableDiffusion -> Diffusers
         | 
| 56 | 
            +
            # Copied and modified from convert_original_stable_diffusion_to_diffusers (ASL 2.0)
         | 
| 57 | 
            +
             | 
| 58 | 
            +
             | 
| 59 | 
            +
            def shave_segments(path, n_shave_prefix_segments=1):
                """
                Drop dot-separated segments from one end of `path`.

                A non-negative `n_shave_prefix_segments` removes that many leading segments;
                a negative value removes that many trailing segments. The remaining segments
                are re-joined with "." and returned.
                """
                segments = path.split(".")
                if n_shave_prefix_segments < 0:
                    # Negative count: trim from the tail.
                    kept = segments[:n_shave_prefix_segments]
                else:
                    # Zero or positive count: trim from the head.
                    kept = segments[n_shave_prefix_segments:]
                return ".".join(kept)
         | 
| 67 | 
            +
             | 
| 68 | 
            +
             | 
| 69 | 
            +
            def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
                """
                Map resnet parameter paths onto the new naming scheme (local renaming).

                For every path in `old_list`, the old sub-layer names are substituted with
                their new names and the result is passed through `shave_segments` with
                `n_shave_prefix_segments`. Returns a list of `{"old": ..., "new": ...}`
                dicts, one per input path, in input order.
                """
                # (old substring, new substring) pairs, applied one after another in this order.
                substitutions = (
                    ("in_layers.0", "norm1"),
                    ("in_layers.2", "conv1"),
                    ("out_layers.0", "norm2"),
                    ("out_layers.3", "conv2"),
                    ("emb_layers.1", "time_emb_proj"),
                    ("skip_connection", "conv_shortcut"),
                )

                def _rename(path):
                    for before, after in substitutions:
                        path = path.replace(before, after)
                    return shave_segments(path, n_shave_prefix_segments=n_shave_prefix_segments)

                return [{"old": source, "new": _rename(source)} for source in old_list]
         | 
| 89 | 
            +
             | 
| 90 | 
            +
             | 
| 91 | 
            +
            def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
                """
                Build the old -> new key mapping for VAE resnet parameters (local renaming only).

                The only rename is ``nin_shortcut`` -> ``conv_shortcut``; the result is then passed
                through ``shave_segments``. Returns a list of ``{"old": ..., "new": ...}`` dicts.
                """
                return [
                    {
                        "old": old_item,
                        "new": shave_segments(
                            old_item.replace("nin_shortcut", "conv_shortcut"),
                            n_shave_prefix_segments=n_shave_prefix_segments,
                        ),
                    }
                    for old_item in old_list
                ]
         | 
| 105 | 
            +
             | 
| 106 | 
            +
             | 
| 107 | 
            +
            def renew_attention_paths(old_list, n_shave_prefix_segments=0):
                """
                Build the old -> new key mapping for attention parameters (local renaming only).

                No local renaming is currently applied: every key maps to itself, and
                ``n_shave_prefix_segments`` is accepted but not used. (The norm/proj_out renames
                and the segment shaving that used to live here are intentionally disabled.)

                Returns a list of ``{"old": ..., "new": ...}`` dicts, one per entry of ``old_list``.
                """
                return [{"old": old_item, "new": old_item} for old_item in old_list]
         | 
| 126 | 
            +
             | 
| 127 | 
            +
             | 
| 128 | 
            +
            def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
                """
                Build the old -> new key mapping for VAE attention parameters (local renaming only).

                ``norm`` weights/biases are renamed to ``group_norm``. The q/k/v/proj_out names
                depend on the installed diffusers release: before 0.17.0 they become
                ``query``/``key``/``value``/``proj_attn``, otherwise ``to_q``/``to_k``/``to_v``/``to_out.0``.
                Each renamed key is finally passed through ``shave_segments``.

                Returns a list of ``{"old": ..., "new": ...}`` dicts, one per entry of ``old_list``.
                """
                import re  # local import so this block does not depend on the file's import list

                # Compare the release numerically. A plain string comparison is lexicographic and
                # would, for example, treat "0.9.0" as newer than "0.17.0".
                release = [int(number) for number in re.findall(r"\d+", diffusers.__version__)[:3]]
                release += [0] * (3 - len(release))
                legacy_attention_names = tuple(release) < (0, 17, 0)

                if legacy_attention_names:
                    attention_renames = (
                        ("q.weight", "query.weight"),
                        ("q.bias", "query.bias"),
                        ("k.weight", "key.weight"),
                        ("k.bias", "key.bias"),
                        ("v.weight", "value.weight"),
                        ("v.bias", "value.bias"),
                        ("proj_out.weight", "proj_attn.weight"),
                        ("proj_out.bias", "proj_attn.bias"),
                    )
                else:
                    attention_renames = (
                        ("q.weight", "to_q.weight"),
                        ("q.bias", "to_q.bias"),
                        ("k.weight", "to_k.weight"),
                        ("k.bias", "to_k.bias"),
                        ("v.weight", "to_v.weight"),
                        ("v.bias", "to_v.bias"),
                        ("proj_out.weight", "to_out.0.weight"),
                        ("proj_out.bias", "to_out.0.bias"),
                    )

                # The norm renames come first, then the version-specific ones, in the listed order.
                renames = (
                    ("norm.weight", "group_norm.weight"),
                    ("norm.bias", "group_norm.bias"),
                ) + attention_renames

                mapping = []
                for old_item in old_list:
                    new_item = old_item
                    for source, target in renames:
                        new_item = new_item.replace(source, target)

                    new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)

                    mapping.append({"old": old_item, "new": new_item})

                return mapping
         | 
| 169 | 
            +
             | 
| 170 | 
            +
             | 
| 171 | 
            +
            def assign_to_checkpoint(
                paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
            ):
                """
                This does the final conversion step: take locally converted weights and apply a global renaming
                to them. It splits attention layers, and takes into account additional replacements
                that may arise.

                Assigns the weights to the new checkpoint.

                Args:
                    paths: list of ``{"old": ..., "new": ...}`` dicts; ``old`` indexes ``old_checkpoint``,
                        ``new`` is the locally renamed key that receives the global renaming here.
                    checkpoint: dict that receives the converted entries (modified in place).
                    old_checkpoint: source state dict; only read from.
                    attention_paths_to_split: optional mapping from a key of ``old_checkpoint`` to a dict
                        with ``"query"``, ``"key"`` and ``"value"`` target keys. The tensor stored under
                        that key is split into three parts which are written to those targets.
                    additional_replacements: optional list of ``{"old": ..., "new": ...}`` substring
                        replacements applied to every new path after the middle-block renaming.
                    config: mapping providing ``"num_head_channels"``; only read when
                        ``attention_paths_to_split`` is given.

                Returns:
                    None. Results are stored in ``checkpoint``.
                """
                import re  # local import so this block does not depend on the file's import list

                assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."

                # Splits the attention layers into three variables.
                if attention_paths_to_split is not None:
                    for path, path_map in attention_paths_to_split.items():
                        old_tensor = old_checkpoint[path]
                        channels = old_tensor.shape[0] // 3

                        target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1,)

                        num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3

                        # Group the first dimension per head so q/k/v can be split along dim 1.
                        old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
                        query, key, value = old_tensor.split(channels // num_heads, dim=1)

                        checkpoint[path_map["query"]] = query.reshape(target_shape)
                        checkpoint[path_map["key"]] = key.reshape(target_shape)
                        checkpoint[path_map["value"]] = value.reshape(target_shape)

                # Compare the diffusers release numerically (once, outside the loop). A plain string
                # comparison is lexicographic and would treat e.g. "0.9.0" as newer than "0.17.0".
                release = [int(number) for number in re.findall(r"\d+", diffusers.__version__)[:3]]
                release += [0] * (3 - len(release))
                legacy_attention_names = tuple(release) < (0, 17, 0)

                for path in paths:
                    new_path = path["new"]

                    # These have already been assigned
                    if attention_paths_to_split is not None and new_path in attention_paths_to_split:
                        continue

                    # Global renaming happens here
                    new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
                    new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
                    new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")

                    if additional_replacements is not None:
                        for replacement in additional_replacements:
                            new_path = new_path.replace(replacement["old"], replacement["new"])

                    # proj_attn.weight has to be converted from conv 1D to linear
                    if legacy_attention_names:
                        reshaping = "proj_attn.weight" in new_path
                    else:
                        # With the newer key names, attention projections that still have more than
                        # two dimensions are the ones that get their trailing dimensions dropped.
                        reshaping = (
                            ".attentions." in new_path and ".0.to_" in new_path and old_checkpoint[path["old"]].ndim > 2
                        )

                    if reshaping:
                        checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0, 0]
                    else:
                        checkpoint[new_path] = old_checkpoint[path["old"]]
         | 
| 229 | 
            +
             | 
| 230 | 
            +
             | 
| 231 | 
            +
            def conv_attn_to_linear(checkpoint):
                """
                Reduce attention weights with more than two dimensions to 2-D, in place.

                Keys whose last two dot-separated segments are ``query.weight``, ``key.weight`` or
                ``value.weight`` are indexed with ``[:, :, 0, 0]``; other keys containing
                ``proj_attn.weight`` are indexed with ``[:, :, 0]``. Entries that already have at
                most two dimensions, and all other keys, are left untouched.
                """
                qkv_suffixes = ["query.weight", "key.weight", "value.weight"]
                every = slice(None)

                # Snapshot the keys up front since entries are reassigned while looping.
                for name in list(checkpoint.keys()):
                    if ".".join(name.split(".")[-2:]) in qkv_suffixes:
                        index = (every, every, 0, 0)
                    elif "proj_attn.weight" in name:
                        index = (every, every, 0)
                    else:
                        continue

                    if checkpoint[name].ndim > 2:
                        checkpoint[name] = checkpoint[name][index]
         | 
| 241 | 
            +
             | 
| 242 | 
            +
             | 
| 243 | 
            +
            def linear_transformer_to_conv(checkpoint):
                """
                Expand 2-D ``proj_in``/``proj_out`` weights to 4-D, in place.

                Every key whose last two dot-separated segments are ``proj_in.weight`` or
                ``proj_out.weight`` and whose value has exactly two dimensions gets two trailing
                singleton dimensions appended. All other entries are left untouched.
                """
                projection_suffixes = ["proj_in.weight", "proj_out.weight"]

                # Snapshot the keys up front since entries are reassigned while looping.
                for name in list(checkpoint.keys()):
                    suffix = ".".join(name.rsplit(".", 2)[-2:])
                    if suffix not in projection_suffixes:
                        continue
                    if checkpoint[name].ndim == 2:
                        checkpoint[name] = checkpoint[name].unsqueeze(2).unsqueeze(2)
         | 
| 250 | 
            +
             | 
| 251 | 
            +
             | 
| 252 | 
            +
            def convert_ldm_unet_checkpoint(v2, checkpoint, config):
                """
                Takes a state dict and a config, and returns a converted checkpoint.

                Args:
                    v2: when true and ``config`` does not enable ``use_linear_projection``, the 2-D
                        ``proj_in``/``proj_out`` weights of the result are expanded to 4-D.
                    checkpoint: state dict; every key starting with ``model.diffusion_model.`` is
                        popped from it (the input dict is modified).
                    config: mapping providing ``"layers_per_block"`` and optionally
                        ``"use_linear_projection"``; it is also forwarded to ``assign_to_checkpoint``.

                Returns:
                    A new dict holding the UNet weights under the renamed keys.
                """

                # extract state_dict for UNet
                unet_state_dict = {}
                unet_key = "model.diffusion_model."
                keys = list(checkpoint.keys())
                for key in keys:
                    if key.startswith(unet_key):
                        unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)

                new_checkpoint = {}

                new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
                new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
                new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
                new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]

                new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
                new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]

                new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
                new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
                new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
                new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]

                # Retrieves the keys for the input blocks only
                num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
                input_blocks = {
                    layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}." in key] for layer_id in range(num_input_blocks)
                }

                # Retrieves the keys for the middle blocks only
                num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
                middle_blocks = {
                    layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}." in key] for layer_id in range(num_middle_blocks)
                }

                # Retrieves the keys for the output blocks only
                num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
                output_blocks = {
                    layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}." in key] for layer_id in range(num_output_blocks)
                }

                # Input block 0 is conv_in (handled above), so the down blocks start at index 1.
                for i in range(1, num_input_blocks):
                    block_id = (i - 1) // (config["layers_per_block"] + 1)
                    layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)

                    resnets = [key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key]
                    attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]

                    if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
                        new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
                            f"input_blocks.{i}.0.op.weight"
                        )
                        new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(f"input_blocks.{i}.0.op.bias")

                    paths = renew_resnet_paths(resnets)
                    meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
                    assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)

                    if len(attentions):
                        paths = renew_attention_paths(attentions)
                        meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
                        assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)

                # Middle block layout: resnet, attention, resnet.
                resnet_0 = middle_blocks[0]
                attentions = middle_blocks[1]
                resnet_1 = middle_blocks[2]

                resnet_0_paths = renew_resnet_paths(resnet_0)
                assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)

                resnet_1_paths = renew_resnet_paths(resnet_1)
                assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)

                attentions_paths = renew_attention_paths(attentions)
                meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
                assign_to_checkpoint(attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)

                for i in range(num_output_blocks):
                    block_id = i // (config["layers_per_block"] + 1)
                    layer_in_block_id = i % (config["layers_per_block"] + 1)
                    output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
                    output_block_list = {}

                    # Group the remaining key suffixes by their leading sub-layer index.
                    for layer in output_block_layers:
                        layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
                        if layer_id in output_block_list:
                            output_block_list[layer_id].append(layer_name)
                        else:
                            output_block_list[layer_id] = [layer_name]

                    if len(output_block_list) > 1:
                        resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
                        attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]

                        paths = renew_resnet_paths(resnets)

                        meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
                        assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)

                        # Original:
                        # if ["conv.weight", "conv.bias"] in output_block_list.values():
                        #   index = list(output_block_list.values()).index(["conv.weight", "conv.bias"])

                        # Make this independent of the order of bias and weight (there is probably a nicer way)
                        for layer_names in output_block_list.values():
                            layer_names.sort()

                        if ["conv.bias", "conv.weight"] in output_block_list.values():
                            index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
                            new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
                                f"output_blocks.{i}.{index}.conv.bias"
                            ]
                            new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
                                f"output_blocks.{i}.{index}.conv.weight"
                            ]

                            # Clear attentions as they have been attributed above.
                            if len(attentions) == 2:
                                attentions = []

                        if len(attentions):
                            paths = renew_attention_paths(attentions)
                            meta_path = {
                                "old": f"output_blocks.{i}.1",
                                "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
                            }
                            assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)
                    else:
                        # Only one sub-layer in this output block: copy its keys as a resnet.
                        resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
                        for path in resnet_0_paths:
                            old_path = ".".join(["output_blocks", str(i), path["old"]])
                            new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])

                            new_checkpoint[new_path] = unet_state_dict[old_path]

                # In SD v2 the 1x1 conv2d layers were changed to linear.
                # The Diffusers side was mistakenly left as conv2d, so a conversion is needed.
                if v2 and not config.get("use_linear_projection", False):
                    linear_transformer_to_conv(new_checkpoint)

                return new_checkpoint
         | 
| 399 | 
            +
             | 
| 400 | 
            +
             | 
| 401 | 
            +
            def convert_ldm_vae_checkpoint(checkpoint, config):
         | 
| 402 | 
            +
                # extract state dict for VAE
         | 
| 403 | 
            +
                vae_state_dict = {}
         | 
| 404 | 
            +
                vae_key = "first_stage_model."
         | 
| 405 | 
            +
                keys = list(checkpoint.keys())
         | 
| 406 | 
            +
                for key in keys:
         | 
| 407 | 
            +
                    if key.startswith(vae_key):
         | 
| 408 | 
            +
                        vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
         | 
| 409 | 
            +
                # if len(vae_state_dict) == 0:
         | 
| 410 | 
            +
                #   # 渡されたcheckpointは.ckptから読み込んだcheckpointではなくvaeのstate_dict
         | 
| 411 | 
            +
                #   vae_state_dict = checkpoint
         | 
| 412 | 
            +
             | 
| 413 | 
            +
                new_checkpoint = {}
         | 
| 414 | 
            +
             | 
| 415 | 
            +
                new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
         | 
| 416 | 
            +
                new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
         | 
| 417 | 
            +
                new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
         | 
| 418 | 
            +
                new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
         | 
| 419 | 
            +
                new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
         | 
| 420 | 
            +
                new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]
         | 
| 421 | 
            +
             | 
| 422 | 
            +
                new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
         | 
| 423 | 
            +
                new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
         | 
| 424 | 
            +
                new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
         | 
| 425 | 
            +
                new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
         | 
| 426 | 
            +
                new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
         | 
| 427 | 
            +
                new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]
         | 
| 428 | 
            +
             | 
| 429 | 
            +
                new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
         | 
| 430 | 
            +
                new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
         | 
| 431 | 
            +
                new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
         | 
| 432 | 
            +
                new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]
         | 
| 433 | 
            +
             | 
| 434 | 
            +
                # Retrieves the keys for the encoder down blocks only
         | 
| 435 | 
            +
                num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
         | 
| 436 | 
            +
                down_blocks = {layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)}
         | 
| 437 | 
            +
             | 
| 438 | 
            +
                # Retrieves the keys for the decoder up blocks only
         | 
| 439 | 
            +
                num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
         | 
| 440 | 
            +
                up_blocks = {layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)}
         | 
| 441 | 
            +
             | 
| 442 | 
            +
                for i in range(num_down_blocks):
         | 
| 443 | 
            +
                    resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
         | 
| 444 | 
            +
             | 
| 445 | 
            +
                    if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
         | 
| 446 | 
            +
                        new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
         | 
| 447 | 
            +
                            f"encoder.down.{i}.downsample.conv.weight"
         | 
| 448 | 
            +
                        )
         | 
| 449 | 
            +
                        new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
         | 
| 450 | 
            +
                            f"encoder.down.{i}.downsample.conv.bias"
         | 
| 451 | 
            +
                        )
         | 
| 452 | 
            +
             | 
| 453 | 
            +
                    paths = renew_vae_resnet_paths(resnets)
         | 
| 454 | 
            +
                    meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
         | 
| 455 | 
            +
                    assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
         | 
| 456 | 
            +
             | 
| 457 | 
            +
                mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
         | 
| 458 | 
            +
                num_mid_res_blocks = 2
         | 
| 459 | 
            +
                for i in range(1, num_mid_res_blocks + 1):
         | 
| 460 | 
            +
                    resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
         | 
| 461 | 
            +
             | 
| 462 | 
            +
                    paths = renew_vae_resnet_paths(resnets)
         | 
| 463 | 
            +
                    meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
         | 
| 464 | 
            +
                    assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
         | 
| 465 | 
            +
             | 
| 466 | 
            +
                mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
         | 
| 467 | 
            +
                paths = renew_vae_attention_paths(mid_attentions)
         | 
| 468 | 
            +
                meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
         | 
| 469 | 
            +
                assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
         | 
| 470 | 
            +
                conv_attn_to_linear(new_checkpoint)
         | 
| 471 | 
            +
             | 
| 472 | 
            +
                for i in range(num_up_blocks):
         | 
| 473 | 
            +
                    block_id = num_up_blocks - 1 - i
         | 
| 474 | 
            +
                    resnets = [key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key]
         | 
| 475 | 
            +
             | 
| 476 | 
            +
                    if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
         | 
| 477 | 
            +
                        new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
         | 
| 478 | 
            +
                            f"decoder.up.{block_id}.upsample.conv.weight"
         | 
| 479 | 
            +
                        ]
         | 
| 480 | 
            +
                        new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
         | 
| 481 | 
            +
                            f"decoder.up.{block_id}.upsample.conv.bias"
         | 
| 482 | 
            +
                        ]
         | 
| 483 | 
            +
             | 
| 484 | 
            +
                    paths = renew_vae_resnet_paths(resnets)
         | 
| 485 | 
            +
                    meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
         | 
| 486 | 
            +
                    assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
         | 
| 487 | 
            +
             | 
| 488 | 
            +
                mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
         | 
| 489 | 
            +
                num_mid_res_blocks = 2
         | 
| 490 | 
            +
                for i in range(1, num_mid_res_blocks + 1):
         | 
| 491 | 
            +
                    resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
         | 
| 492 | 
            +
             | 
| 493 | 
            +
                    paths = renew_vae_resnet_paths(resnets)
         | 
| 494 | 
            +
                    meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
         | 
| 495 | 
            +
                    assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
         | 
| 496 | 
            +
             | 
| 497 | 
            +
                mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
         | 
| 498 | 
            +
                paths = renew_vae_attention_paths(mid_attentions)
         | 
| 499 | 
            +
                meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
         | 
| 500 | 
            +
                assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
         | 
| 501 | 
            +
                conv_attn_to_linear(new_checkpoint)
         | 
| 502 | 
            +
                return new_checkpoint
         | 
| 503 | 
            +
             | 
| 504 | 
            +
             | 
| 505 | 
            +
            def create_unet_diffusers_config(v2, use_linear_projection_in_v2=False):
                """
                Build the Diffusers UNet config dict from the module-level UNET_PARAMS_* constants.

                Args:
                    v2: when truthy, ``cross_attention_dim`` and ``attention_head_dim`` are taken
                        from the V2_UNET_PARAMS_* constants instead of the UNET_PARAMS_* ones.
                    use_linear_projection_in_v2: when truthy together with ``v2``, the returned
                        dict additionally carries ``use_linear_projection=True``.

                Returns:
                    dict with the keys sample_size, in_channels, out_channels, down_block_types,
                    up_block_types, block_out_channels, layers_per_block, cross_attention_dim and
                    attention_head_dim (plus use_linear_projection in the case described above).
                """
                channels_per_level = tuple(UNET_PARAMS_MODEL_CHANNELS * mult for mult in UNET_PARAMS_CHANNEL_MULT)

                # Level k is checked against the attention resolutions with factor 2**k
                # (1, 2, 4, ...); the up path visits the very same factors in reverse order.
                factors = [2**level for level in range(len(channels_per_level))]

                down_block_types = tuple(
                    "CrossAttnDownBlock2D" if factor in UNET_PARAMS_ATTENTION_RESOLUTIONS else "DownBlock2D"
                    for factor in factors
                )
                up_block_types = tuple(
                    "CrossAttnUpBlock2D" if factor in UNET_PARAMS_ATTENTION_RESOLUTIONS else "UpBlock2D"
                    for factor in reversed(factors)
                )

                config = {
                    "sample_size": UNET_PARAMS_IMAGE_SIZE,
                    "in_channels": UNET_PARAMS_IN_CHANNELS,
                    "out_channels": UNET_PARAMS_OUT_CHANNELS,
                    "down_block_types": down_block_types,
                    "up_block_types": up_block_types,
                    "block_out_channels": channels_per_level,
                    "layers_per_block": UNET_PARAMS_NUM_RES_BLOCKS,
                    "cross_attention_dim": V2_UNET_PARAMS_CONTEXT_DIM if v2 else UNET_PARAMS_CONTEXT_DIM,
                    "attention_head_dim": V2_UNET_PARAMS_ATTENTION_HEAD_DIM if v2 else UNET_PARAMS_NUM_HEADS,
                }

                # Linear projection is opt-in and only ever switched on for v2 models.
                if v2 and use_linear_projection_in_v2:
                    config["use_linear_projection"] = True

                return config
         | 
| 543 | 
            +
             | 
| 544 | 
            +
             | 
| 545 | 
            +
            def create_vae_diffusers_config():
                """
                Build the Diffusers VAE config dict from the module-level VAE_PARAMS_* constants.

                Returns:
                    dict with the keys sample_size, in_channels, out_channels, down_block_types,
                    up_block_types, block_out_channels, latent_channels and layers_per_block.
                    Every level uses "DownEncoderBlock2D" / "UpDecoderBlock2D"; there is one level
                    per entry of VAE_PARAMS_CH_MULT.
                """
                channels_per_level = tuple(VAE_PARAMS_CH * mult for mult in VAE_PARAMS_CH_MULT)
                num_levels = len(channels_per_level)

                return {
                    "sample_size": VAE_PARAMS_RESOLUTION,
                    "in_channels": VAE_PARAMS_IN_CHANNELS,
                    "out_channels": VAE_PARAMS_OUT_CH,
                    "down_block_types": ("DownEncoderBlock2D",) * num_levels,
                    "up_block_types": ("UpDecoderBlock2D",) * num_levels,
                    "block_out_channels": channels_per_level,
                    "latent_channels": VAE_PARAMS_Z_CHANNELS,
                    "layers_per_block": VAE_PARAMS_NUM_RES_BLOCKS,
                }
         | 
| 566 | 
            +
             | 
| 567 | 
            +
             | 
| 568 | 
            +
            def convert_ldm_clip_checkpoint_v1(checkpoint):
                """
                Extract the text-encoder entries of a v1 checkpoint, stripping their key prefix.

                Only keys inside the ``cond_stage_model.transformer.`` namespace are copied. The
                same dotted prefix is used both for filtering and for stripping, so a sibling
                namespace that merely shares the leading characters (for example
                ``cond_stage_model.transformers_extra.*``) is ignored instead of being emitted
                under a truncated key.

                Args:
                    checkpoint: mapping from parameter name to value; it is not modified.

                Returns:
                    New dict keyed by the names with the prefix removed. The entry
                    ``text_model.embeddings.position_ids`` is never included.
                """
                prefix = "cond_stage_model.transformer."
                text_model_dict = {
                    key[len(prefix) :]: checkpoint[key] for key in checkpoint.keys() if key.startswith(prefix)
                }

                # remove position_ids for newer transformer, which causes error :(
                text_model_dict.pop("text_model.embeddings.position_ids", None)

                return text_model_dict
         | 
| 580 | 
            +
             | 
| 581 | 
            +
             | 
| 582 | 
            +
            def convert_ldm_clip_checkpoint_v2(checkpoint, max_length):
                """
                Convert the text-encoder entries of a v2 checkpoint to ``text_model.*`` key names.

                The v2 key layout differs from v1 almost everywhere, so every key is renamed
                individually. Keys containing ``.resblocks.23.`` are dropped, as are
                ``text_projection`` and ``logit_scale``. Fused ``attn.in_proj_*`` tensors are split
                with ``torch.chunk(..., 3)`` and stored as ``q_proj`` / ``k_proj`` / ``v_proj``.

                Args:
                    checkpoint: mapping from parameter name to tensor.
                    max_length: number of positions used when ``position_ids`` has to be generated.

                Returns:
                    New dict with the converted entries; it always contains
                    ``text_model.embeddings.position_ids``.

                Raises:
                    ValueError: a ``resblocks`` key matches none of the known sub-layer patterns.
                """

                def rename(src_key):
                    # Returns the new key, or None when the entry is not copied one-to-one.
                    if not src_key.startswith("cond_stage_model"):
                        return None

                    # common conversion
                    renamed = src_key.replace("cond_stage_model.model.transformer.", "text_model.encoder.")
                    renamed = renamed.replace("cond_stage_model.model.", "text_model.")

                    if "resblocks" in renamed:
                        # resblocks conversion
                        renamed = renamed.replace(".resblocks.", ".layers.")
                        if ".ln_" in renamed:
                            return renamed.replace(".ln_", ".layer_norm")
                        if ".mlp." in renamed:
                            return renamed.replace(".c_fc.", ".fc1.").replace(".c_proj.", ".fc2.")
                        if ".attn.out_proj" in renamed:
                            return renamed.replace(".attn.out_proj.", ".self_attn.out_proj.")
                        if ".attn.in_proj" in renamed:
                            return None  # special case, handled by the split pass below
                        raise ValueError(f"unexpected key in SD: {renamed}")

                    if ".positional_embedding" in renamed:
                        return renamed.replace(".positional_embedding", ".embeddings.position_embedding.weight")
                    if ".text_projection" in renamed or ".logit_scale" in renamed:
                        return None  # apparently unused
                    if ".token_embedding" in renamed:
                        return renamed.replace(".token_embedding.weight", ".embeddings.token_embedding.weight")
                    if ".ln_final" in renamed:
                        return renamed.replace(".ln_final", ".final_layer_norm")
                    return renamed

                # remove resblocks 23
                kept_keys = [name for name in list(checkpoint.keys()) if ".resblocks.23." not in name]

                new_sd = {}
                for src_key in kept_keys:
                    dst_key = rename(src_key)
                    if dst_key is not None:
                        new_sd[dst_key] = checkpoint[src_key]

                # attention conversion: split each fused in_proj tensor into three
                for src_key in kept_keys:
                    if ".resblocks" not in src_key or ".attn.in_proj_" not in src_key:
                        continue

                    chunks = torch.chunk(checkpoint[src_key], 3)
                    suffix = ".weight" if "weight" in src_key else ".bias"

                    stem = src_key.replace("cond_stage_model.model.transformer.resblocks.", "text_model.encoder.layers.")
                    stem = stem.replace("_weight", "").replace("_bias", "")
                    stem = stem.replace(".attn.in_proj", ".self_attn.")

                    for index, proj_name in enumerate(("q_proj", "k_proj", "v_proj")):
                        new_sd[stem + proj_name + suffix] = chunks[index]

                # rename or add position_ids
                alternate_position_ids_key = "text_model.encoder.text_model.embeddings.position_ids"
                if alternate_position_ids_key in new_sd:
                    # waifu diffusion v1.4
                    position_ids = new_sd.pop(alternate_position_ids_key)
                else:
                    position_ids = torch.Tensor([list(range(max_length))]).to(torch.int64)

                new_sd["text_model.embeddings.position_ids"] = position_ids
                return new_sd
         | 
| 657 | 
            +
             | 
| 658 | 
            +
             | 
| 659 | 
            +
            # endregion
         | 
| 660 | 
            +
             | 
| 661 | 
            +
             | 
| 662 | 
            +
            # region Diffusers->StableDiffusion の変換コード
         | 
| 663 | 
            +
            # convert_diffusers_to_original_stable_diffusion をコピーして修正している(ASL 2.0)
         | 
| 664 | 
            +
             | 
| 665 | 
            +
             | 
| 666 | 
            +
            def conv_transformer_to_linear(checkpoint):
                """
                Reduce ``proj_in.weight`` / ``proj_out.weight`` entries to two dimensions, in place.

                An entry is touched only when the last two dot-separated components of its key are
                exactly ``proj_in.weight`` or ``proj_out.weight`` and its value has more than two
                dimensions; it is then replaced by its ``[:, :, 0, 0]`` slice.

                Args:
                    checkpoint: mutable mapping from parameter name to array/tensor; modified in place.

                Returns:
                    None.
                """
                linear_tails = ("proj_in.weight", "proj_out.weight")
                for name in list(checkpoint.keys()):
                    tail = ".".join(name.rsplit(".", 2)[-2:])
                    if tail not in linear_tails:
                        continue
                    # presumably these are 1x1 conv kernels whose spatial dims can be dropped -- verify against caller
                    if checkpoint[name].ndim > 2:
                        checkpoint[name] = checkpoint[name][:, :, 0, 0]
         | 
| 673 | 
            +
             | 
| 674 | 
            +
             | 
| 675 | 
            +
            def convert_unet_state_dict_to_sd(v2, unet_state_dict):
                """
                Rename the keys of a Diffusers UNet state dict to the Stable Diffusion key layout.

                Every rename table below holds ``(stable-diffusion, HF Diffusers)`` pairs. The tables
                are applied as plain substring replacements, one after another, so the order in which
                the entries are created is part of the behaviour and must not be changed.

                Args:
                    v2: when truthy, the result is additionally passed through
                        ``conv_transformer_to_linear``.
                    unet_state_dict: mapping from Diffusers parameter name to value. It must contain
                        all ten parameter names listed in ``whole_key_renames``.

                Returns:
                    New dict holding the same values under the Stable Diffusion names.

                Raises:
                    KeyError: one of the names in ``whole_key_renames`` is missing from
                        ``unet_state_dict``.
                """
                # Parameters whose complete name is replaced.
                whole_key_renames = [
                    ("time_embed.0.weight", "time_embedding.linear_1.weight"),
                    ("time_embed.0.bias", "time_embedding.linear_1.bias"),
                    ("time_embed.2.weight", "time_embedding.linear_2.weight"),
                    ("time_embed.2.bias", "time_embedding.linear_2.bias"),
                    ("input_blocks.0.0.weight", "conv_in.weight"),
                    ("input_blocks.0.0.bias", "conv_in.bias"),
                    ("out.0.weight", "conv_norm_out.weight"),
                    ("out.0.bias", "conv_norm_out.bias"),
                    ("out.2.weight", "conv_out.weight"),
                    ("out.2.bias", "conv_out.bias"),
                ]

                # Sub-module names, applied only to keys that contain "resnets".
                resnet_part_renames = [
                    ("in_layers.0", "norm1"),
                    ("in_layers.2", "conv1"),
                    ("out_layers.0", "norm2"),
                    ("out_layers.3", "conv2"),
                    ("emb_layers.1", "time_emb_proj"),
                    ("skip_connection", "conv_shortcut"),
                ]

                # Block prefixes, applied to every key.
                layer_prefix_renames = []
                for block in range(4):
                    # resnets/attentions of the down blocks
                    for unit in range(2):
                        sd_index = 3 * block + unit + 1
                        layer_prefix_renames.append((f"input_blocks.{sd_index}.0.", f"down_blocks.{block}.resnets.{unit}."))
                        if block < 3:
                            # no attention layers in down_blocks.3
                            layer_prefix_renames.append(
                                (f"input_blocks.{sd_index}.1.", f"down_blocks.{block}.attentions.{unit}.")
                            )

                    # resnets/attentions of the up blocks
                    for unit in range(3):
                        sd_index = 3 * block + unit
                        layer_prefix_renames.append((f"output_blocks.{sd_index}.0.", f"up_blocks.{block}.resnets.{unit}."))
                        if block > 0:
                            # no attention layers in up_blocks.0
                            layer_prefix_renames.append(
                                (f"output_blocks.{sd_index}.1.", f"up_blocks.{block}.attentions.{unit}.")
                            )

                    if block < 3:
                        # no downsample in down_blocks.3
                        layer_prefix_renames.append(
                            (f"input_blocks.{3 * (block + 1)}.0.op.", f"down_blocks.{block}.downsamplers.0.conv.")
                        )
                        # no upsample in up_blocks.3
                        upsample_slot = 1 if block == 0 else 2
                        layer_prefix_renames.append(
                            (f"output_blocks.{3 * block + 2}.{upsample_slot}.", f"up_blocks.{block}.upsamplers.0.")
                        )

                layer_prefix_renames.append(("middle_block.1.", "mid_block.attentions.0."))
                layer_prefix_renames.extend((f"middle_block.{2 * unit}.", f"mid_block.resnets.{unit}.") for unit in range(2))

                # Start from the identity mapping, then force the whole-key renames.
                hf_to_sd = {name: name for name in unet_state_dict.keys()}
                hf_to_sd.update((hf_name, sd_name) for sd_name, hf_name in whole_key_renames)

                # Each key is rewritten independently: resnet parts first, block prefixes second.
                new_state_dict = {}
                for hf_name, sd_name in hf_to_sd.items():
                    if "resnets" in hf_name:
                        for sd_part, hf_part in resnet_part_renames:
                            sd_name = sd_name.replace(hf_part, sd_part)
                    for sd_part, hf_part in layer_prefix_renames:
                        sd_name = sd_name.replace(hf_part, sd_part)
                    new_state_dict[sd_name] = unet_state_dict[hf_name]

                if v2:
                    conv_transformer_to_linear(new_state_dict)

                return new_state_dict
         | 
| 769 | 
            +
             | 
| 770 | 
            +
             | 
| 771 | 
            +
            def controlnet_conversion_map():
                """Build the name-translation tables between SD-style and Diffusers-style ControlNet keys.

                Returns:
                    A tuple of three lists. Every list element is an ``(sd_text, diffusers_text)`` pair:

                    1. full parameter names that are translated one-to-one,
                    2. fragments used inside resnet blocks,
                    3. layer prefixes (down blocks, mid block, hint embedding, zero convs).
                """
                # Modules whose weight and bias are renamed as complete key names.
                renamed_modules = (
                    ("time_embed.0", "time_embedding.linear_1"),
                    ("time_embed.2", "time_embedding.linear_2"),
                    ("input_blocks.0.0", "conv_in"),
                    ("middle_block_out.0", "controlnet_mid_block"),
                )
                full_name_pairs = [
                    (f"{sd_module}.{param}", f"{hf_module}.{param}")
                    for sd_module, hf_module in renamed_modules
                    for param in ("weight", "bias")
                ]

                # Fragments that differ inside a resnet block.
                sd_resnet_parts = ("in_layers.0", "in_layers.2", "out_layers.0", "out_layers.3", "emb_layers.1", "skip_connection")
                hf_resnet_parts = ("norm1", "conv1", "norm2", "conv2", "time_emb_proj", "conv_shortcut")
                resnet_pairs = list(zip(sd_resnet_parts, hf_resnet_parts))

                layer_pairs = []
                for level in range(4):
                    # Only the first three levels get attention and downsampler entries.
                    not_last_level = level < 3
                    for idx in range(2):
                        sd_block_no = 3 * level + idx + 1
                        layer_pairs.append((f"input_blocks.{sd_block_no}.0.", f"down_blocks.{level}.resnets.{idx}."))
                        if not_last_level:
                            layer_pairs.append((f"input_blocks.{sd_block_no}.1.", f"down_blocks.{level}.attentions.{idx}."))
                    if not_last_level:
                        layer_pairs.append((f"input_blocks.{3 * (level + 1)}.0.op.", f"down_blocks.{level}.downsamplers.0.conv."))

                # Mid block: the attention entry is listed before the two resnet entries.
                layer_pairs.append(("middle_block.1.", "mid_block.attentions.0."))
                layer_pairs.extend((f"middle_block.{2 * idx}.", f"mid_block.resnets.{idx}.") for idx in range(2))

                # Conditioning (hint) embedding: SD uses every second index of input_hint_block.
                hint_modules = ["conv_in", *(f"blocks.{n}" for n in range(6)), "conv_out"]
                layer_pairs.extend(
                    (f"input_hint_block.{position * 2}.", f"controlnet_cond_embedding.{module}.")
                    for position, module in enumerate(hint_modules)
                )

                # Zero convolutions, one per ControlNet down block output.
                layer_pairs.extend((f"zero_convs.{n}.0.", f"controlnet_down_blocks.{n}.") for n in range(12))

                return full_name_pairs, resnet_pairs, layer_pairs
         | 
| 830 | 
            +
             | 
| 831 | 
            +
             | 
| 832 | 
            +
            def convert_controlnet_state_dict_to_sd(controlnet_state_dict):
                """Return a new dict holding the ControlNet tensors under SD-style key names.

                Args:
                    controlnet_state_dict: mapping from Diffusers-style key names to values.

                Returns:
                    A new dict with the same values, keyed by the translated names.

                Raises:
                    KeyError: if one of the fully-renamed Diffusers keys from the conversion
                        table is not present in ``controlnet_state_dict``.
                """
                full_name_pairs, resnet_pairs, layer_pairs = controlnet_conversion_map()

                def swap_in_sd_text(name, pairs):
                    # Each pair is (sd_text, diffusers_text); substitute the SD text.
                    for sd_text, diffusers_text in pairs:
                        name = name.replace(diffusers_text, sd_text)
                    return name

                # Start from the identity mapping, then override the keys that are renamed as a whole.
                targets = {name: name for name in controlnet_state_dict.keys()}
                targets.update((diffusers_name, sd_name) for sd_name, diffusers_name in full_name_pairs)

                converted = {}
                for diffusers_name, sd_name in targets.items():
                    # Resnet fragments are only rewritten for keys that belong to a resnet.
                    if "resnets" in diffusers_name:
                        sd_name = swap_in_sd_text(sd_name, resnet_pairs)
                    sd_name = swap_in_sd_text(sd_name, layer_pairs)
                    converted[sd_name] = controlnet_state_dict[diffusers_name]
                return converted
         | 
| 849 | 
            +
             | 
| 850 | 
            +
             | 
| 851 | 
            +
            def convert_controlnet_state_dict_to_diffusers(controlnet_state_dict):
                """Return a new dict holding the ControlNet tensors under Diffusers-style key names.

                Args:
                    controlnet_state_dict: mapping from SD-style key names to values.

                Returns:
                    A new dict with the same values, keyed by the translated names.

                Raises:
                    KeyError: if one of the fully-renamed SD keys from the conversion table
                        is not present in ``controlnet_state_dict``.
                """
                full_name_pairs, resnet_pairs, layer_pairs = controlnet_conversion_map()

                def swap_in_diffusers_text(name, pairs):
                    # Each pair is (sd_text, diffusers_text); substitute the Diffusers text.
                    for sd_text, diffusers_text in pairs:
                        name = name.replace(sd_text, diffusers_text)
                    return name

                # Start from the identity mapping, then override the keys that are renamed as a whole.
                targets = {name: name for name in controlnet_state_dict.keys()}
                targets.update((sd_name, diffusers_name) for sd_name, diffusers_name in full_name_pairs)

                converted = {}
                for sd_name, diffusers_name in targets.items():
                    # Layer prefixes go first: only afterwards does a resnet key contain "resnets".
                    diffusers_name = swap_in_diffusers_text(diffusers_name, layer_pairs)
                    if "resnets" in diffusers_name:
                        diffusers_name = swap_in_diffusers_text(diffusers_name, resnet_pairs)
                    converted[diffusers_name] = controlnet_state_dict[sd_name]
                return converted
         | 
| 868 | 
            +
             | 
| 869 | 
            +
             | 
| 870 | 
            +
            # ================#
         | 
| 871 | 
            +
            # VAE Conversion #
         | 
| 872 | 
            +
            # ================#
         | 
| 873 | 
            +
             | 
| 874 | 
            +
             | 
| 875 | 
            +
            def reshape_weight_for_sd(w):
                """Append two trailing axes of size 1 to ``w``.

                Used to turn HF linear weights into SD conv2d weights.
                """
                conv_shape = (*w.shape, 1, 1)
                return w.reshape(conv_shape)
         | 
| 878 | 
            +
             | 
| 879 | 
            +
             | 
| 880 | 
            +
            def _diffusers_release_tuple(version):
                """Parse the leading numeric components of a version string into a 3-tuple of ints.

                Only the leading digits of each dot-separated component are used, parsing stops
                at the first component without leading digits, and missing components are
                padded with zeros (e.g. "0.9.0" -> (0, 9, 0), "0.17.0.dev0" -> (0, 17, 0)).
                """
                release = []
                for component in str(version).split("."):
                    digits = ""
                    for ch in component:
                        if ch not in "0123456789":
                            break
                        digits += ch
                    if not digits:
                        break
                    release.append(int(digits))
                    if len(release) == 3:
                        break
                release.extend([0] * (3 - len(release)))
                return tuple(release)


            def convert_vae_state_dict(vae_state_dict):
                """Rename Diffusers-style VAE keys to SD-style keys and reshape mid-attention weights.

                Args:
                    vae_state_dict: mapping from Diffusers-style VAE key names to tensors.

                Returns:
                    A new dict keyed by SD-style names. Values whose new key contains
                    ``mid.attn_1.{q,k,v,proj_out}.weight`` are passed through
                    ``reshape_weight_for_sd``; all other values are reused unchanged.
                """
                vae_conversion_map = [
                    # (stable-diffusion, HF Diffusers)
                    ("nin_shortcut", "conv_shortcut"),
                    ("norm_out", "conv_norm_out"),
                    ("mid.attn_1.", "mid_block.attentions.0."),
                ]

                for i in range(4):
                    # down_blocks have two resnets
                    for j in range(2):
                        hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}."
                        sd_down_prefix = f"encoder.down.{i}.block.{j}."
                        vae_conversion_map.append((sd_down_prefix, hf_down_prefix))

                    if i < 3:
                        hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0."
                        sd_downsample_prefix = f"down.{i}.downsample."
                        vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))

                        hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
                        sd_upsample_prefix = f"up.{3-i}.upsample."
                        vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))

                    # up_blocks have three resnets
                    # also, up blocks in hf are numbered in reverse from sd
                    for j in range(3):
                        hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
                        sd_up_prefix = f"decoder.up.{3-i}.block.{j}."
                        vae_conversion_map.append((sd_up_prefix, hf_up_prefix))

                # this part accounts for mid blocks in both the encoder and the decoder
                for i in range(2):
                    hf_mid_res_prefix = f"mid_block.resnets.{i}."
                    sd_mid_res_prefix = f"mid.block_{i+1}."
                    vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))

                # Compare the release numerically: a plain string comparison orders
                # "0.9.0" after "0.17.0" and would pick the wrong attention names.
                if _diffusers_release_tuple(diffusers.__version__) < (0, 17, 0):
                    vae_conversion_map_attn = [
                        # (stable-diffusion, HF Diffusers)
                        ("norm.", "group_norm."),
                        ("q.", "query."),
                        ("k.", "key."),
                        ("v.", "value."),
                        ("proj_out.", "proj_attn."),
                    ]
                else:
                    vae_conversion_map_attn = [
                        # (stable-diffusion, HF Diffusers)
                        ("norm.", "group_norm."),
                        ("q.", "to_q."),
                        ("k.", "to_k."),
                        ("v.", "to_v."),
                        ("proj_out.", "to_out.0."),
                    ]

                # First pass: general prefixes/fragments for every key.
                mapping = {k: k for k in vae_state_dict.keys()}
                for k, v in mapping.items():
                    for sd_part, hf_part in vae_conversion_map:
                        v = v.replace(hf_part, sd_part)
                    mapping[k] = v
                # Second pass: attention-specific fragments, only for keys that were attention keys originally.
                for k, v in mapping.items():
                    if "attentions" in k:
                        for sd_part, hf_part in vae_conversion_map_attn:
                            v = v.replace(hf_part, sd_part)
                        mapping[k] = v
                new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()}

                # The mid-attention projection weights get two trailing singleton axes
                # (HF linear weight -> SD conv2d weight).
                weights_to_convert = ["q", "k", "v", "proj_out"]
                for k, v in new_state_dict.items():
                    for weight_name in weights_to_convert:
                        if f"mid.attn_1.{weight_name}.weight" in k:
                            new_state_dict[k] = reshape_weight_for_sd(v)

                return new_state_dict
         | 
| 955 | 
            +
             | 
| 956 | 
            +
             | 
| 957 | 
            +
            # endregion
         | 
| 958 | 
            +
             | 
| 959 | 
            +
            # region custom model loading/saving utilities
         | 
| 960 | 
            +
             | 
| 961 | 
            +
             | 
| 962 | 
            +
            def is_safetensors(path):
                """Return True when the extension of ``path`` is ``.safetensors`` (case-insensitive)."""
                _, extension = os.path.splitext(path)
                return extension.lower() == ".safetensors"
         | 
| 964 | 
            +
             | 
| 965 | 
            +
             | 
| 966 | 
            +
            def load_checkpoint_with_text_encoder_conversion(ckpt_path, device="cpu"):
                """Load a checkpoint and normalise text-encoder keys that lack the ``text_model`` level.

                Args:
                    ckpt_path: path of the checkpoint; ``.safetensors`` files go through
                        ``load_file``, anything else through ``torch.load``.
                    device: passed as ``map_location`` to ``torch.load``. It is not passed to
                        ``load_file`` (the original note says doing so may cause an error).

                Returns:
                    ``(checkpoint, state_dict)``. ``checkpoint`` is the full loaded object only
                    when it contains a ``"state_dict"`` entry; otherwise it is ``None``.
                    ``state_dict`` is modified in place by the key renaming.
                """
                # Supports models that store the text encoder without the 'text_model' level.
                TEXT_ENCODER_KEY_REPLACEMENTS = [
                    ("cond_stage_model.transformer.embeddings.", "cond_stage_model.transformer.text_model.embeddings."),
                    ("cond_stage_model.transformer.encoder.", "cond_stage_model.transformer.text_model.encoder."),
                    ("cond_stage_model.transformer.final_layer_norm.", "cond_stage_model.transformer.text_model.final_layer_norm."),
                ]

                checkpoint = None
                if is_safetensors(ckpt_path):
                    state_dict = load_file(ckpt_path)  # device is intentionally not passed here
                else:
                    loaded = torch.load(ckpt_path, map_location=device)
                    if "state_dict" in loaded:
                        checkpoint = loaded
                        state_dict = loaded["state_dict"]
                    else:
                        # The file holds a bare state dict.
                        state_dict = loaded

                # Collect all renames first so the dict is not resized while it is being scanned.
                renames = [
                    (old_key, new_prefix + old_key[len(old_prefix) :])
                    for old_prefix, new_prefix in TEXT_ENCODER_KEY_REPLACEMENTS
                    for old_key in state_dict.keys()
                    if old_key.startswith(old_prefix)
                ]
                for old_key, new_key in renames:
                    state_dict[new_key] = state_dict[old_key]
                    del state_dict[old_key]

                return checkpoint, state_dict
         | 
| 997 | 
            +
             | 
| 998 | 
            +
             | 
| 999 | 
            +
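            # TODO: the behaviour of the dtype argument is questionable and needs checking; it is unverified whether text_encoder can be created in the specified format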
         | 
| 1000 | 
            +
            def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dtype=None, unet_use_linear_projection_in_v2=True):
                """Build the text encoder, VAE and U-Net from a single Stable Diffusion checkpoint file.

                Args:
                    v2: truthy selects the v2 conversion functions and text-encoder configuration,
                        falsy the v1 ones.
                    ckpt_path: checkpoint path, loaded through
                        ``load_checkpoint_with_text_encoder_conversion``.
                    device: used when loading the checkpoint and as the ``.to()`` target of the
                        U-Net and the VAE. The text encoder is not moved in this function.
                    dtype: accepted but not used anywhere in this function.
                    unet_use_linear_projection_in_v2: forwarded to ``create_unet_diffusers_config``.

                Returns:
                    ``(text_model, vae, unet)``.
                """
                _, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path, device)

                # U-Net: config -> converted weights -> model.
                unet_config = create_unet_diffusers_config(v2, unet_use_linear_projection_in_v2)
                unet_weights = convert_ldm_unet_checkpoint(v2, state_dict, unet_config)
                unet = UNet2DConditionModel(**unet_config).to(device)
                info = unet.load_state_dict(unet_weights)
                logger.info(f"loading u-net: {info}")

                # VAE: config -> converted weights -> model.
                vae_config = create_vae_diffusers_config()
                vae_weights = convert_ldm_vae_checkpoint(state_dict, vae_config)
                vae = AutoencoderKL(**vae_config).to(device)
                info = vae.load_state_dict(vae_weights)
                logger.info(f"loading vae: {info}")

                # Text encoder: only the architecture-dependent settings differ between v1 and v2.
                if v2:
                    text_encoder_weights = convert_ldm_clip_checkpoint_v2(state_dict, 77)
                    arch = {
                        "hidden_size": 1024,
                        "intermediate_size": 4096,
                        "num_hidden_layers": 23,
                        "num_attention_heads": 16,
                        "hidden_act": "gelu",
                        "projection_dim": 512,
                    }
                    extra_kwargs = {"transformers_version": "4.25.0.dev0"}
                else:
                    text_encoder_weights = convert_ldm_clip_checkpoint_v1(state_dict)
                    arch = {
                        "hidden_size": 768,
                        "intermediate_size": 3072,
                        "num_hidden_layers": 12,
                        "num_attention_heads": 12,
                        "hidden_act": "quick_gelu",
                        "projection_dim": 768,
                    }
                    extra_kwargs = {}

                # The config is built locally rather than fetched with from_pretrained.
                cfg = CLIPTextConfig(
                    vocab_size=49408,
                    hidden_size=arch["hidden_size"],
                    intermediate_size=arch["intermediate_size"],
                    num_hidden_layers=arch["num_hidden_layers"],
                    num_attention_heads=arch["num_attention_heads"],
                    max_position_embeddings=77,
                    hidden_act=arch["hidden_act"],
                    layer_norm_eps=1e-05,
                    dropout=0.0,
                    attention_dropout=0.0,
                    initializer_range=0.02,
                    initializer_factor=1.0,
                    pad_token_id=1,
                    bos_token_id=0,
                    eos_token_id=2,
                    model_type="clip_text_model",
                    projection_dim=arch["projection_dim"],
                    torch_dtype="float32",
                    **extra_kwargs,
                )
                text_model = CLIPTextModel._from_config(cfg)
                info = text_model.load_state_dict(text_encoder_weights)
                logger.info(f"loading text encoder: {info}")

                return text_model, vae, unet
         | 
| 1077 | 
            +
             | 
| 1078 | 
            +
             | 
| 1079 | 
            +
            def get_model_version_str_for_sd1_sd2(v2, v_parameterization):
                """Compose a version label such as ``sd_v1`` or ``sd_v2_v`` (for reference only).

                Args:
                    v2: truthy yields the ``_v2`` part, falsy the ``_v1`` part.
                    v_parameterization: truthy appends a trailing ``_v``.
                """
                pieces = ["sd", "_v2" if v2 else "_v1"]
                if v_parameterization:
                    pieces.append("_v")
                return "".join(pieces)
         | 
| 1089 | 
            +
             | 
| 1090 | 
            +
             | 
| 1091 | 
            +
            def convert_text_encoder_state_dict_to_sd_v2(checkpoint, make_dummy_weights=False):
                """Translate a Diffusers-style text-encoder state dict to SD v2 key names.

                Args:
                    checkpoint: mapping from Diffusers-style key names to tensors.
                    make_dummy_weights: when truthy, copy ``transformer.resblocks.22.*`` to
                        ``transformer.resblocks.23.*`` and add ``text_projection`` and
                        ``logit_scale`` entries.

                Returns:
                    A new dict keyed by the converted names. The separate q/k/v projection
                    tensors of each layer are concatenated into one ``attn.in_proj_*`` entry.

                Raises:
                    ValueError: if a key containing "layers" matches none of the known patterns.
                """

                def convert_key(key):
                    # position_ids entries are dropped.
                    if ".position_ids" in key:
                        return None

                    # Renames shared by every remaining key.
                    key = key.replace("text_model.encoder.", "transformer.").replace("text_model.", "")

                    if "layers" not in key:
                        if ".position_embedding" in key:
                            return key.replace("embeddings.position_embedding.weight", "positional_embedding")
                        if ".token_embedding" in key:
                            return key.replace("embeddings.token_embedding.weight", "token_embedding.weight")
                        if "final_layer_norm" in key:
                            return key.replace("final_layer_norm", "ln_final")
                        return key

                    # Transformer layer -> resblock.
                    key = key.replace(".layers.", ".resblocks.")
                    if ".layer_norm" in key:
                        return key.replace(".layer_norm", ".ln_")
                    if ".mlp." in key:
                        return key.replace(".fc1.", ".c_fc.").replace(".fc2.", ".c_proj.")
                    if ".self_attn.out_proj" in key:
                        return key.replace(".self_attn.out_proj.", ".attn.out_proj.")
                    if ".self_attn." in key:
                        # q/k/v projections need special handling and are merged afterwards.
                        return None
                    raise ValueError(f"unexpected key in DiffUsers model: {key}")

                keys = list(checkpoint.keys())

                # Plain renames; keys that map to None are skipped.
                renamed = ((convert_key(source_key), source_key) for source_key in keys)
                new_sd = {target_key: checkpoint[source_key] for target_key, source_key in renamed if target_key is not None}

                # Attention conversion: concatenate q, k and v into a single tensor.
                for q_key in keys:
                    if not ("layers" in q_key and "q_proj" in q_key):
                        continue
                    fused = torch.cat([checkpoint[q_key.replace("q_proj", proj)] for proj in ("q_proj", "k_proj", "v_proj")])
                    fused_key = q_key.replace("text_model.encoder.layers.", "transformer.resblocks.")
                    fused_key = fused_key.replace(".self_attn.q_proj.", ".attn.in_proj_")
                    new_sd[fused_key] = fused

                # Optionally fabricate the last layer and the other missing entries.
                if make_dummy_weights:
                    logger.info("make dummy weights for resblock.23, text_projection and logit scale.")
                    existing_keys = list(new_sd.keys())
                    for name in existing_keys:
                        if name.startswith("transformer.resblocks.22."):
                            # clone: without a copy, saving with safetensors fails
                            new_sd[name.replace(".22.", ".23.")] = new_sd[name].clone()

                    # Create the weights that are not part of the Diffusers model.
                    reference = new_sd[existing_keys[0]]
                    new_sd["text_projection"] = torch.ones((1024, 1024), dtype=reference.dtype, device=reference.device)
                    new_sd["logit_scale"] = torch.tensor(1)

                return new_sd
         | 
| 1160 | 
            +
             | 
| 1161 | 
            +
             | 
| 1162 | 
            +
            def save_stable_diffusion_checkpoint(
                v2, output_file, text_encoder, unet, ckpt_path, epochs, steps, metadata, save_dtype=None, vae=None
            ):
                """Write the text encoder, U-Net and (optionally) VAE as one Stable Diffusion checkpoint.

                Args:
                    v2: Truthy to use the SD v2 text-encoder conversion and key prefix.
                    output_file: Destination path. `is_safetensors(output_file)` selects the format.
                    text_encoder: Model whose `state_dict()` supplies the text-encoder weights.
                    unet: Model whose `state_dict()` is converted with `convert_unet_state_dict_to_sd`.
                    ckpt_path: Optional reference checkpoint. When given, its state dict is the
                        starting point and its epoch/global_step are added to `epochs`/`steps`.
                    epochs: Epoch count to record in the saved checkpoint.
                    steps: Global step count to record in the saved checkpoint.
                    metadata: Passed to `save_file` when saving as safetensors.
                    save_dtype: If not None, every written tensor is detached, cloned, moved to
                        the CPU and cast to this dtype.
                    vae: Optional VAE. Required when `ckpt_path` is None.

                Returns:
                    The number of keys in the saved state dict.

                Note:
                    In the safetensors branch only the state dict and `metadata` are written;
                    epoch and global_step are stored only in the `torch.save` branch.
                """
                if ckpt_path is not None:
                    # Reload the reference checkpoint: its epoch/step are carried over, and it is read
                    # again in full (VAE included) for cases such as the VAE not being held in memory.
                    checkpoint, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path)
                    if checkpoint is None:  # safetensors, or a ckpt that is a bare state_dict
                        checkpoint = {}
                        strict = False
                    else:
                        strict = True
                    if "state_dict" in state_dict:
                        del state_dict["state_dict"]
                else:
                    # No reference checkpoint: build everything from scratch.
                    assert vae is not None, "VAE is required to save a checkpoint without a given checkpoint"
                    checkpoint = {}
                    state_dict = {}
                    strict = False

                def update_sd(prefix, sd):
                    # Copy `sd` into the output state dict under `prefix`. In strict mode every key
                    # must already exist in the reference checkpoint.
                    for k, v in sd.items():
                        key = prefix + k
                        assert not strict or key in state_dict, f"Illegal key in save SD: {key}"
                        if save_dtype is not None:
                            v = v.detach().clone().to("cpu").to(save_dtype)
                        state_dict[key] = v

                # Convert the UNet model
                unet_state_dict = convert_unet_state_dict_to_sd(v2, unet.state_dict())
                update_sd("model.diffusion_model.", unet_state_dict)

                # Convert the text encoder model
                if v2:
                    # Without a reference checkpoint, fill in dummy weights (e.g. create the last
                    # layer by duplicating the previous one).
                    make_dummy = ckpt_path is None
                    text_enc_dict = convert_text_encoder_state_dict_to_sd_v2(text_encoder.state_dict(), make_dummy)
                    update_sd("cond_stage_model.model.", text_enc_dict)
                else:
                    text_enc_dict = text_encoder.state_dict()
                    update_sd("cond_stage_model.transformer.", text_enc_dict)

                # Convert the VAE
                if vae is not None:
                    vae_dict = convert_vae_state_dict(vae.state_dict())
                    update_sd("first_stage_model.", vae_dict)

                # Put together new checkpoint
                key_count = len(state_dict.keys())
                new_ckpt = {"state_dict": state_dict}

                # epoch and global_step are sometimes not int. The merge is best-effort: on failure
                # the values accumulated so far are kept. `except Exception` (rather than a bare
                # `except`) lets KeyboardInterrupt/SystemExit propagate.
                try:
                    if "epoch" in checkpoint:
                        epochs += checkpoint["epoch"]
                    if "global_step" in checkpoint:
                        steps += checkpoint["global_step"]
                except Exception:
                    pass

                new_ckpt["epoch"] = epochs
                new_ckpt["global_step"] = steps

                if is_safetensors(output_file):
                    # TODO: should non-Tensor values be removed from the dict?
                    save_file(state_dict, output_file, metadata)
                else:
                    torch.save(new_ckpt, output_file)

                return key_count
         | 
| 1231 | 
            +
             | 
| 1232 | 
            +
             | 
| 1233 | 
            +
            def save_diffusers_checkpoint(v2, output_dir, text_encoder, unet, pretrained_model_name_or_path, vae=None, use_safetensors=False):
                """Save the models as a Diffusers `StableDiffusionPipeline` under `output_dir`.

                The scheduler, tokenizer, U-Net skeleton and (when `vae` is None) the VAE are
                loaded from `pretrained_model_name_or_path`; when that is None the default
                v1 or v2 reference model id is used, chosen by `v2`. The pipeline is saved
                with no safety checker and no feature extractor.
                """
                reference = pretrained_model_name_or_path
                if reference is None:
                    # load default settings for v1/v2
                    reference = DIFFUSERS_REF_MODEL_ID_V2 if v2 else DIFFUSERS_REF_MODEL_ID_V1

                parts = {
                    "scheduler": DDIMScheduler.from_pretrained(reference, subfolder="scheduler"),
                    "tokenizer": CLIPTokenizer.from_pretrained(reference, subfolder="tokenizer"),
                }
                parts["vae"] = AutoencoderKL.from_pretrained(reference, subfolder="vae") if vae is None else vae

                # The original U-Net cannot be saved directly, so its weights are loaded into a
                # Diffusers U-Net built from the reference model.
                # TODO this consumes a lot of memory
                converted_unet = diffusers.UNet2DConditionModel.from_pretrained(reference, subfolder="unet")
                converted_unet.load_state_dict(unet.state_dict())
                parts["unet"] = converted_unet
                parts["text_encoder"] = text_encoder

                pipe = StableDiffusionPipeline(
                    safety_checker=None,
                    feature_extractor=None,
                    requires_safety_checker=None,
                    **parts,
                )
                pipe.save_pretrained(output_dir, safe_serialization=use_safetensors)
         | 
| 1262 | 
            +
             | 
| 1263 | 
            +
             | 
| 1264 | 
            +
            # Key prefix under which VAE weights are stored in a full Stable Diffusion checkpoint.
            VAE_PREFIX = "first_stage_model."


            def load_vae(vae_id, dtype):
                """Load an `AutoencoderKL` from a Diffusers location or a single weights file.

                Args:
                    vae_id: A directory, an existing file path, or anything else (treated as a
                        Diffusers model id or path and handed to `from_pretrained`).
                    dtype: Forwarded as `torch_dtype` to `from_pretrained`. The single-file
                        branch does not use it.

                Returns:
                    The loaded `AutoencoderKL`.
                """
                logger.info(f"load VAE: {vae_id}")
                if os.path.isdir(vae_id) or not os.path.isfile(vae_id):
                    # Diffusers local/remote: try the root first, then the "vae" subfolder.
                    try:
                        return AutoencoderKL.from_pretrained(vae_id, subfolder=None, torch_dtype=dtype)
                    except EnvironmentError as e:
                        logger.error(f"exception occurs in loading vae: {e}")
                        logger.error("retry with subfolder='vae'")
                        return AutoencoderKL.from_pretrained(vae_id, subfolder="vae", torch_dtype=dtype)

                # From here on vae_id is an existing local file.
                vae_config = create_vae_diffusers_config()

                if vae_id.endswith(".bin"):
                    # SD 1.5 VAE on Huggingface: loaded as-is, without the LDM key conversion.
                    # NOTE(review): torch.load unpickles the file, so only load trusted files.
                    diffusers_weights = torch.load(vae_id, map_location="cpu")
                else:
                    # StableDiffusion-format weights (safetensors or a torch checkpoint)
                    if is_safetensors(vae_id):
                        loaded = load_file(vae_id, "cpu")
                    else:
                        # NOTE(review): torch.load unpickles the file, so only load trusted files.
                        loaded = torch.load(vae_id, map_location="cpu")
                    ldm_weights = loaded["state_dict"] if "state_dict" in loaded else loaded

                    # A VAE-only file has no prefixed keys; add the prefix so both layouts
                    # go through the same conversion.
                    is_full_model = any(name.startswith(VAE_PREFIX) for name in ldm_weights)
                    if not is_full_model:
                        ldm_weights = {VAE_PREFIX + name: tensor for name, tensor in ldm_weights.items()}

                    # Convert the VAE model.
                    diffusers_weights = convert_ldm_vae_checkpoint(ldm_weights, vae_config)

                model = AutoencoderKL(**vae_config)
                model.load_state_dict(diffusers_weights)
                return model
         | 
| 1309 | 
            +
             | 
| 1310 | 
            +
             | 
| 1311 | 
            +
            # endregion
         | 
| 1312 | 
            +
             | 
| 1313 | 
            +
             | 
| 1314 | 
            +
            def make_bucket_resolutions(max_reso, min_size=256, max_size=1024, divisible=64):
                """Enumerate bucket resolutions whose area does not exceed that of `max_reso`.

                Args:
                    max_reso: `(width, height)` pair; only its area (width * height) is used.
                    min_size: Smallest side length to generate; also the first width tried.
                    max_size: Largest side length to generate.
                    divisible: Step between candidate widths; heights are rounded down to a
                        multiple of it. Must be positive.

                Returns:
                    A sorted list of unique `(width, height)` tuples. Each generated pair is
                    included in both orientations. The list also contains one square whose
                    side is sqrt(area) rounded down to a multiple of `divisible`; that
                    square is added without checking `min_size`/`max_size`.

                Raises:
                    ValueError: If `divisible` is not positive (zero would divide by zero and
                        a negative step would never terminate the width loop).
                """
                if divisible <= 0:
                    raise ValueError(f"divisible must be positive, got {divisible}")

                max_width, max_height = max_reso
                max_area = max_width * max_height

                # Largest square that fits in the area, snapped down to the grid.
                side = int(math.sqrt(max_area) // divisible) * divisible
                resos = {(side, side)}

                width = min_size
                while width <= max_size:
                    # Tallest grid-aligned height that keeps width * height within the area.
                    height = min(max_size, int((max_area // width) // divisible) * divisible)
                    if height >= min_size:
                        resos.add((width, height))
                        resos.add((height, width))
                    width += divisible

                return sorted(resos)
         | 
| 1343 | 
            +
             | 
| 1344 | 
            +
             | 
| 1345 | 
            +
            if __name__ == "__main__":
                # Manual smoke check: log the buckets for 512x768 and report any aspect ratio
                # that occurs more than once.
                buckets = make_bucket_resolutions((512, 768))
                logger.info(f"{len(buckets)}")
                logger.info(f"{buckets}")
                ratios = [bucket_w / bucket_h for bucket_w, bucket_h in buckets]
                logger.info(f"{ratios}")

                seen_ratios = set()
                for ratio in ratios:
                    if ratio in seen_ratios:
                        logger.error(f"error! duplicate ar: {ratio}")
                    else:
                        seen_ratios.add(ratio)
         | 
    	
        library/original_unet.py
    ADDED
    
    | @@ -0,0 +1,1919 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Diffusers 0.10.2からStable Diffusionに必要な部分だけを持ってくる
         | 
| 2 | 
            +
            # 条件分岐等で不要な部分は削除している
         | 
| 3 | 
            +
            # コードの多くはDiffusersからコピーしている
         | 
| 4 | 
            +
            # 制約として、モデルのstate_dictがDiffusers 0.10.2のものと同じ形式である必要がある
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            # Copy from Diffusers 0.10.2 for Stable Diffusion. Most of the code is copied from Diffusers.
         | 
| 7 | 
            +
            # Parts that are unnecessary here (e.g. conditional branches that are not needed) have been removed.
         | 
| 8 | 
            +
            # As a constraint, the state_dict of the model must be in the same format as that of Diffusers 0.10.2
         | 
| 9 | 
            +
             | 
| 10 | 
            +
            """
         | 
| 11 | 
            +
            v1.5とv2.1の相違点は
         | 
| 12 | 
            +
            - attention_head_dimがintかlist[int]か
         | 
| 13 | 
            +
            - cross_attention_dimが768か1024か
         | 
| 14 | 
            +
            - use_linear_projection: trueがない(=False, 1.5)かあるか
         | 
| 15 | 
            +
            - upcast_attentionがFalse(1.5)かTrue(2.1)か
         | 
| 16 | 
            +
            - (以下は多分無視していい)
         | 
| 17 | 
            +
            - sample_sizeが64か96か
         | 
| 18 | 
            +
            - dual_cross_attentionがあるかないか
         | 
| 19 | 
            +
            - num_class_embedsがあるかないか
         | 
| 20 | 
            +
            - only_cross_attentionがあるかないか
         | 
| 21 | 
            +
             | 
| 22 | 
            +
            v1.5
         | 
| 23 | 
            +
            {
         | 
| 24 | 
            +
              "_class_name": "UNet2DConditionModel",
         | 
| 25 | 
            +
              "_diffusers_version": "0.6.0",
         | 
| 26 | 
            +
              "act_fn": "silu",
         | 
| 27 | 
            +
              "attention_head_dim": 8,
         | 
| 28 | 
            +
              "block_out_channels": [
         | 
| 29 | 
            +
                320,
         | 
| 30 | 
            +
                640,
         | 
| 31 | 
            +
                1280,
         | 
| 32 | 
            +
                1280
         | 
| 33 | 
            +
              ],
         | 
| 34 | 
            +
              "center_input_sample": false,
         | 
| 35 | 
            +
              "cross_attention_dim": 768,
         | 
| 36 | 
            +
              "down_block_types": [
         | 
| 37 | 
            +
                "CrossAttnDownBlock2D",
         | 
| 38 | 
            +
                "CrossAttnDownBlock2D",
         | 
| 39 | 
            +
                "CrossAttnDownBlock2D",
         | 
| 40 | 
            +
                "DownBlock2D"
         | 
| 41 | 
            +
              ],
         | 
| 42 | 
            +
              "downsample_padding": 1,
         | 
| 43 | 
            +
              "flip_sin_to_cos": true,
         | 
| 44 | 
            +
              "freq_shift": 0,
         | 
| 45 | 
            +
              "in_channels": 4,
         | 
| 46 | 
            +
              "layers_per_block": 2,
         | 
| 47 | 
            +
              "mid_block_scale_factor": 1,
         | 
| 48 | 
            +
              "norm_eps": 1e-05,
         | 
| 49 | 
            +
              "norm_num_groups": 32,
         | 
| 50 | 
            +
              "out_channels": 4,
         | 
| 51 | 
            +
              "sample_size": 64,
         | 
| 52 | 
            +
              "up_block_types": [
         | 
| 53 | 
            +
                "UpBlock2D",
         | 
| 54 | 
            +
                "CrossAttnUpBlock2D",
         | 
| 55 | 
            +
                "CrossAttnUpBlock2D",
         | 
| 56 | 
            +
                "CrossAttnUpBlock2D"
         | 
| 57 | 
            +
              ]
         | 
| 58 | 
            +
            }
         | 
| 59 | 
            +
             | 
| 60 | 
            +
            v2.1
         | 
| 61 | 
            +
            {
         | 
| 62 | 
            +
              "_class_name": "UNet2DConditionModel",
         | 
| 63 | 
            +
              "_diffusers_version": "0.10.0.dev0",
         | 
| 64 | 
            +
              "act_fn": "silu",
         | 
| 65 | 
            +
              "attention_head_dim": [
         | 
| 66 | 
            +
                5,
         | 
| 67 | 
            +
                10,
         | 
| 68 | 
            +
                20,
         | 
| 69 | 
            +
                20
         | 
| 70 | 
            +
              ],
         | 
| 71 | 
            +
              "block_out_channels": [
         | 
| 72 | 
            +
                320,
         | 
| 73 | 
            +
                640,
         | 
| 74 | 
            +
                1280,
         | 
| 75 | 
            +
                1280
         | 
| 76 | 
            +
              ],
         | 
| 77 | 
            +
              "center_input_sample": false,
         | 
| 78 | 
            +
              "cross_attention_dim": 1024,
         | 
| 79 | 
            +
              "down_block_types": [
         | 
| 80 | 
            +
                "CrossAttnDownBlock2D",
         | 
| 81 | 
            +
                "CrossAttnDownBlock2D",
         | 
| 82 | 
            +
                "CrossAttnDownBlock2D",
         | 
| 83 | 
            +
                "DownBlock2D"
         | 
| 84 | 
            +
              ],
         | 
| 85 | 
            +
              "downsample_padding": 1,
         | 
| 86 | 
            +
              "dual_cross_attention": false,
         | 
| 87 | 
            +
              "flip_sin_to_cos": true,
         | 
| 88 | 
            +
              "freq_shift": 0,
         | 
| 89 | 
            +
              "in_channels": 4,
         | 
| 90 | 
            +
              "layers_per_block": 2,
         | 
| 91 | 
            +
              "mid_block_scale_factor": 1,
         | 
| 92 | 
            +
              "norm_eps": 1e-05,
         | 
| 93 | 
            +
              "norm_num_groups": 32,
         | 
| 94 | 
            +
              "num_class_embeds": null,
         | 
| 95 | 
            +
              "only_cross_attention": false,
         | 
| 96 | 
            +
              "out_channels": 4,
         | 
| 97 | 
            +
              "sample_size": 96,
         | 
| 98 | 
            +
              "up_block_types": [
         | 
| 99 | 
            +
                "UpBlock2D",
         | 
| 100 | 
            +
                "CrossAttnUpBlock2D",
         | 
| 101 | 
            +
                "CrossAttnUpBlock2D",
         | 
| 102 | 
            +
                "CrossAttnUpBlock2D"
         | 
| 103 | 
            +
              ],
         | 
| 104 | 
            +
              "use_linear_projection": true,
         | 
| 105 | 
            +
              "upcast_attention": true
         | 
| 106 | 
            +
            }
         | 
| 107 | 
            +
            """
         | 
| 108 | 
            +
             | 
| 109 | 
            +
            import math
         | 
| 110 | 
            +
            from types import SimpleNamespace
         | 
| 111 | 
            +
            from typing import Dict, Optional, Tuple, Union
         | 
| 112 | 
            +
            import torch
         | 
| 113 | 
            +
            from torch import nn
         | 
| 114 | 
            +
            from torch.nn import functional as F
         | 
| 115 | 
            +
            from einops import rearrange
         | 
| 116 | 
            +
            from library.utils import setup_logging
         | 
| 117 | 
            +
            setup_logging()
         | 
| 118 | 
            +
            import logging
         | 
| 119 | 
            +
            logger = logging.getLogger(__name__)
         | 
| 120 | 
            +
             | 
| 121 | 
            +
            # Fixed architecture hyper-parameters for this UNet. The values mirror the
            # v1.5 / v2.1 config listings in the docstring above (block_out_channels,
            # in_channels, layers_per_block, norm_eps, ...).

            # Channel width per resolution level ("block_out_channels" in the configs).
            BLOCK_OUT_CHANNELS: Tuple[int, ...] = (320, 640, 1280, 1280)
            # Equal to the first block width; presumably the size of the sinusoidal
            # timestep embedding fed to the time-embedding MLP -- confirm against usage.
            TIMESTEP_INPUT_DIM = BLOCK_OUT_CHANNELS[0]
            # Four times the first block width; used as the input size of
            # ResnetBlock2D.time_emb_proj.
            TIME_EMBED_DIM = BLOCK_OUT_CHANNELS[0] * 4
            # "in_channels" / "out_channels" in the configs.
            IN_CHANNELS: int = 4
            OUT_CHANNELS: int = 4
            # "layers_per_block" in the configs; DownBlock2D builds this many resnets.
            LAYERS_PER_BLOCK: int = 2
            # One more layer than LAYERS_PER_BLOCK; presumably the per-block layer count
            # on the up path -- confirm against the up-block definitions.
            LAYERS_PER_BLOCK_UP: int = LAYERS_PER_BLOCK + 1
            # "flip_sin_to_cos" / "freq_shift" in the configs.
            TIME_EMBED_FLIP_SIN_TO_COS: bool = True
            TIME_EMBED_FREQ_SHIFT: int = 0
            # "norm_num_groups" / "norm_eps" in the configs; passed to GroupNorm in ResnetBlock2D.
            NORM_GROUPS: int = 32
            NORM_EPS: float = 1e-5
            # NOTE(review): presumably the GroupNorm group count of the transformer
            # blocks defined later in the file -- confirm.
            TRANSFORMER_NORM_NUM_GROUPS = 32

            # Block layout, identical for both configs ("down_block_types" / "up_block_types").
            DOWN_BLOCK_TYPES = ["CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"]
            UP_BLOCK_TYPES = ["UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"]
         | 
| 136 | 
            +
             | 
| 137 | 
            +
             | 
| 138 | 
            +
            # region memory efficient attention
         | 
| 139 | 
            +
             | 
| 140 | 
            +
            # CrossAttention that uses FlashAttention
         | 
| 141 | 
            +
            # based on https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/memory_efficient_attention_pytorch/flash_attention.py
         | 
| 142 | 
            +
            # LICENSE MIT https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/LICENSE
         | 
| 143 | 
            +
             | 
| 144 | 
            +
            # constants
         | 
| 145 | 
            +
             | 
| 146 | 
            +
            EPSILON = 1e-6
         | 
| 147 | 
            +
             | 
| 148 | 
            +
            # helper functions
         | 
| 149 | 
            +
             | 
| 150 | 
            +
             | 
| 151 | 
            +
            def exists(val):
                """Return True when ``val`` is anything other than ``None``.

                Falsy values such as ``0``, ``""`` or ``[]`` still count as existing.
                """
                return not (val is None)
         | 
| 153 | 
            +
             | 
| 154 | 
            +
             | 
| 155 | 
            +
            def default(val, d):
                """Return ``val`` unless it is ``None``, in which case return the fallback ``d``."""
                if exists(val):
                    return val
                return d
         | 
| 157 | 
            +
             | 
| 158 | 
            +
             | 
| 159 | 
            +
            # flash attention forwards and backwards
         | 
| 160 | 
            +
             | 
| 161 | 
            +
            # https://arxiv.org/abs/2205.14135
         | 
| 162 | 
            +
             | 
| 163 | 
            +
             | 
| 164 | 
            +
            class FlashAttentionFunction(torch.autograd.Function):
                """Tiled (bucketed) attention with a hand-written backward pass.

                Queries are processed in chunks of ``q_bucket_size`` rows and keys/values in
                chunks of ``k_bucket_size`` rows, so only one (query chunk x key chunk) block
                of attention weights is materialised at a time. Softmax is accumulated across
                key chunks with running per-row maxima and sums.

                Follows https://arxiv.org/abs/2205.14135 (see the comments above the class).
                """

                @staticmethod
                @torch.no_grad()
                def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size):
                    """Algorithm 2 in the paper.

                    :param ctx: autograd context; receives the tensors and settings needed by backward.
                    :param q: query tensor; the second-to-last axis is the sequence axis and the
                        last axis is the feature axis used for the 1/sqrt(d) scale.
                    :param k: key tensor, chunked along its second-to-last axis.
                    :param v: value tensor, chunked along its second-to-last axis.
                    :param mask: optional 2-D boolean mask ("b n"); ``None`` disables masking.
                    :param causal: when truthy, key positions after the query position are masked.
                    :param q_bucket_size: number of query rows per chunk.
                    :param k_bucket_size: number of key/value rows per chunk.
                    :return: attention output with the same shape and dtype as ``q``.
                    """

                    device = q.device
                    dtype = q.dtype
                    # Most negative finite value of the dtype; used as the fill for masked logits.
                    max_neg_value = -torch.finfo(q.dtype).max
                    # How much longer the key sequence is than the query sequence (never negative);
                    # shifts the query index used for the causal mask.
                    qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)

                    # Output accumulator plus per-query-row running softmax statistics.
                    o = torch.zeros_like(q)
                    all_row_sums = torch.zeros((*q.shape[:-1], 1), dtype=dtype, device=device)
                    all_row_maxes = torch.full((*q.shape[:-1], 1), max_neg_value, dtype=dtype, device=device)

                    scale = q.shape[-1] ** -0.5

                    if not exists(mask):
                        # One placeholder per query chunk so the zip below still lines up.
                        mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size)
                    else:
                        mask = rearrange(mask, "b n -> b 1 1 n")
                        # NOTE(review): the mask is split along its last axis with q_bucket_size and
                        # then paired with query-row chunks below -- confirm this pairing is intended.
                        mask = mask.split(q_bucket_size, dim=-1)

                    # split() yields views, so the in-place updates on oc / row_sums / row_maxes
                    # below write straight into o / all_row_sums / all_row_maxes.
                    row_splits = zip(
                        q.split(q_bucket_size, dim=-2),
                        o.split(q_bucket_size, dim=-2),
                        mask,
                        all_row_sums.split(q_bucket_size, dim=-2),
                        all_row_maxes.split(q_bucket_size, dim=-2),
                    )

                    for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits):
                        q_start_index = ind * q_bucket_size - qk_len_diff

                        col_splits = zip(
                            k.split(k_bucket_size, dim=-2),
                            v.split(k_bucket_size, dim=-2),
                        )

                        for k_ind, (kc, vc) in enumerate(col_splits):
                            k_start_index = k_ind * k_bucket_size

                            # Scaled dot-product logits for this (query chunk, key chunk) block.
                            attn_weights = torch.einsum("... i d, ... j d -> ... i j", qc, kc) * scale

                            if exists(row_mask):
                                attn_weights.masked_fill_(~row_mask, max_neg_value)

                            # The causal mask is only built when this key chunk reaches past the
                            # first query position of the current query chunk.
                            if causal and q_start_index < (k_start_index + k_bucket_size - 1):
                                causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device).triu(
                                    q_start_index - k_start_index + 1
                                )
                                attn_weights.masked_fill_(causal_mask, max_neg_value)

                            # Block-local softmax numerator, stabilised by the block's row maxima.
                            block_row_maxes = attn_weights.amax(dim=-1, keepdims=True)
                            attn_weights -= block_row_maxes
                            exp_weights = torch.exp(attn_weights)

                            if exists(row_mask):
                                exp_weights.masked_fill_(~row_mask, 0.0)

                            # Clamp keeps the denominator away from zero (e.g. fully masked rows).
                            block_row_sums = exp_weights.sum(dim=-1, keepdims=True).clamp(min=EPSILON)

                            new_row_maxes = torch.maximum(block_row_maxes, row_maxes)

                            exp_values = torch.einsum("... i j, ... j d -> ... i d", exp_weights, vc)

                            # Factors that re-express the old and the block statistics relative to
                            # the new running maximum.
                            exp_row_max_diff = torch.exp(row_maxes - new_row_maxes)
                            exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes)

                            new_row_sums = exp_row_max_diff * row_sums + exp_block_row_max_diff * block_row_sums

                            # Rescale the output accumulated so far and add this block's
                            # contribution, both normalised by the new running sum.
                            oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_((exp_block_row_max_diff / new_row_sums) * exp_values)

                            row_maxes.copy_(new_row_maxes)
                            row_sums.copy_(new_row_sums)

                    # Settings (including the chunked mask tuple) are stored as a plain attribute;
                    # tensors go through save_for_backward.
                    ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size)
                    ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes)

                    return o

                @staticmethod
                @torch.no_grad()
                def backward(ctx, do):
                    """Algorithm 4 in the paper.

                    Recomputes each block of attention probabilities from the saved running
                    sums and maxima, then accumulates gradients chunk by chunk.

                    :param ctx: autograd context populated by ``forward``.
                    :param do: gradient with respect to the forward output ``o``.
                    :return: ``(dq, dk, dv, None, None, None, None)`` -- gradients for q, k, v and
                        ``None`` for mask, causal, q_bucket_size and k_bucket_size.
                    """

                    causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args
                    # l / m are the per-row softmax sums / maxima saved by forward.
                    q, k, v, o, l, m = ctx.saved_tensors

                    device = q.device

                    max_neg_value = -torch.finfo(q.dtype).max
                    qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)

                    dq = torch.zeros_like(q)
                    dk = torch.zeros_like(k)
                    dv = torch.zeros_like(v)

                    # dq is chunked as views so dqc.add_ below accumulates into dq in place.
                    row_splits = zip(
                        q.split(q_bucket_size, dim=-2),
                        o.split(q_bucket_size, dim=-2),
                        do.split(q_bucket_size, dim=-2),
                        mask,
                        l.split(q_bucket_size, dim=-2),
                        m.split(q_bucket_size, dim=-2),
                        dq.split(q_bucket_size, dim=-2),
                    )

                    for ind, (qc, oc, doc, row_mask, lc, mc, dqc) in enumerate(row_splits):
                        q_start_index = ind * q_bucket_size - qk_len_diff

                        # dk / dv are likewise chunked as views and accumulated in place.
                        col_splits = zip(
                            k.split(k_bucket_size, dim=-2),
                            v.split(k_bucket_size, dim=-2),
                            dk.split(k_bucket_size, dim=-2),
                            dv.split(k_bucket_size, dim=-2),
                        )

                        for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits):
                            k_start_index = k_ind * k_bucket_size

                            attn_weights = torch.einsum("... i d, ... j d -> ... i j", qc, kc) * scale

                            if causal and q_start_index < (k_start_index + k_bucket_size - 1):
                                causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device).triu(
                                    q_start_index - k_start_index + 1
                                )
                                attn_weights.masked_fill_(causal_mask, max_neg_value)

                            # Recompute the probabilities from the saved maxima (mc) and sums (lc).
                            exp_attn_weights = torch.exp(attn_weights - mc)

                            if exists(row_mask):
                                exp_attn_weights.masked_fill_(~row_mask, 0.0)

                            p = exp_attn_weights / lc

                            dv_chunk = torch.einsum("... i j, ... i d -> ... j d", p, doc)
                            dp = torch.einsum("... i d, ... j d -> ... i j", doc, vc)

                            # Row-wise sum of do * o, subtracted from dp in the softmax gradient.
                            D = (doc * oc).sum(dim=-1, keepdims=True)
                            ds = p * scale * (dp - D)

                            dq_chunk = torch.einsum("... i j, ... j d -> ... i d", ds, kc)
                            dk_chunk = torch.einsum("... i j, ... i d -> ... j d", ds, qc)

                            dqc.add_(dq_chunk)
                            dkc.add_(dk_chunk)
                            dvc.add_(dv_chunk)

                    return dq, dk, dv, None, None, None, None
         | 
| 314 | 
            +
             | 
| 315 | 
            +
             | 
| 316 | 
            +
            # endregion
         | 
| 317 | 
            +
             | 
| 318 | 
            +
             | 
| 319 | 
            +
            def get_parameter_dtype(parameter: torch.nn.Module):
                """Return the dtype of the first parameter yielded by ``parameter.parameters()``.

                Raises ``StopIteration`` when the module has no parameters.
                """
                first_param = next(parameter.parameters())
                return first_param.dtype
         | 
| 321 | 
            +
             | 
| 322 | 
            +
             | 
| 323 | 
            +
            def get_parameter_device(parameter: torch.nn.Module):
                """Return the device of the first parameter yielded by ``parameter.parameters()``.

                Raises ``StopIteration`` when the module has no parameters.
                """
                first_param = next(parameter.parameters())
                return first_param.device
         | 
| 325 | 
            +
             | 
| 326 | 
            +
             | 
| 327 | 
            +
            def get_timestep_embedding(
                timesteps: torch.Tensor,
                embedding_dim: int,
                flip_sin_to_cos: bool = False,
                downscale_freq_shift: float = 1,
                scale: float = 1,
                max_period: int = 10000,
            ):
                """
                Create sinusoidal timestep embeddings, matching the implementation in
                Denoising Diffusion Probabilistic Models.

                :param timesteps: a 1-D Tensor of N indices, one per batch element.
                                  These may be fractional.
                :param embedding_dim: the dimension of the output.
                :param flip_sin_to_cos: if True the output is laid out as [cos | sin] instead of [sin | cos].
                :param downscale_freq_shift: subtracted from ``embedding_dim // 2`` in the exponent denominator.
                :param scale: multiplier applied to the angles before sin/cos are taken.
                :param max_period: controls the minimum frequency of the embeddings.
                :return: an [N x dim] Tensor of positional embeddings.
                """
                assert timesteps.ndim == 1, "Timesteps should be a 1d-array"

                half_dim = embedding_dim // 2
                positions = torch.arange(start=0, end=half_dim, dtype=torch.float32, device=timesteps.device)
                log_freqs = (-math.log(max_period) * positions) / (half_dim - downscale_freq_shift)
                freqs = torch.exp(log_freqs)

                # outer product of timesteps and frequencies, then the optional global scale
                angles = scale * (timesteps[:, None].float() * freqs[None, :])

                sin_part = torch.sin(angles)
                cos_part = torch.cos(angles)
                halves = [cos_part, sin_part] if flip_sin_to_cos else [sin_part, cos_part]
                emb = torch.cat(halves, dim=-1)

                # an odd embedding_dim leaves one column unfilled: zero pad on the right
                if embedding_dim % 2 == 1:
                    emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
                return emb
         | 
| 366 | 
            +
             | 
| 367 | 
            +
             | 
| 368 | 
            +
            # Deep Shrink: this function is deliberately not shared with other modules, to minimize dependencies.
         | 
| 369 | 
            +
            def resize_like(x, target, mode="bicubic", align_corners=False):
                """Resize ``x`` so its last two dimensions match those of ``target``.

                :param x: tensor to resize; passed to ``F.interpolate`` when the sizes differ.
                :param target: tensor whose last two dimensions give the output size.
                :param mode: interpolation mode forwarded to ``F.interpolate``.
                :param align_corners: forwarded to ``F.interpolate`` for every mode except "nearest".
                :return: ``x`` unchanged when the sizes already match, otherwise the resized
                    tensor. bfloat16 inputs are always round-tripped through float32.
                """
                source_dtype = x.dtype
                is_bf16 = source_dtype == torch.bfloat16

                # bfloat16 is upcast before interpolating (presumably because some modes
                # lack bfloat16 support) and cast back at the end
                if is_bf16:
                    x = x.to(torch.float32)

                target_hw = target.shape[-2:]
                if x.shape[-2:] != target_hw:
                    # align_corners is not passed for "nearest"
                    extra = {} if mode == "nearest" else {"align_corners": align_corners}
                    x = F.interpolate(x, size=target_hw, mode=mode, **extra)

                return x.to(source_dtype) if is_bf16 else x
         | 
| 383 | 
            +
             | 
| 384 | 
            +
             | 
| 385 | 
            +
            class SampleOutput:
                """Minimal result container exposing a single ``sample`` attribute."""

                def __init__(self, sample):
                    # stored as-is; no copy or validation
                    self.sample = sample
         | 
| 388 | 
            +
             | 
| 389 | 
            +
             | 
| 390 | 
            +
            class TimestepEmbedding(nn.Module):
                """Two-layer MLP: ``linear_1`` -> optional activation -> ``linear_2``.

                :param in_channels: input feature size of ``linear_1``.
                :param time_embed_dim: output size of ``linear_1`` and input size of ``linear_2``.
                :param act_fn: "silu" or "mish"; any other value means no activation is applied.
                :param out_dim: output size of ``linear_2``; defaults to ``time_embed_dim`` when None.
                """

                def __init__(self, in_channels: int, time_embed_dim: int, act_fn: str = "silu", out_dim: int = None):
                    super().__init__()

                    self.linear_1 = nn.Linear(in_channels, time_embed_dim)

                    # unknown names fall through to None, i.e. no activation
                    activation_cls = {"silu": nn.SiLU, "mish": nn.Mish}.get(act_fn)
                    self.act = activation_cls() if activation_cls is not None else None

                    time_embed_dim_out = time_embed_dim if out_dim is None else out_dim
                    self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out)

                def forward(self, sample):
                    """Run ``sample`` through the MLP and return the result."""
                    hidden = self.linear_1(sample)
                    if self.act is not None:
                        hidden = self.act(hidden)
                    return self.linear_2(hidden)
         | 
| 415 | 
            +
             | 
| 416 | 
            +
             | 
| 417 | 
            +
            class Timesteps(nn.Module):
                """Parameter-free module wrapping ``get_timestep_embedding`` with fixed settings.

                :param num_channels: embedding dimension requested from ``get_timestep_embedding``.
                :param flip_sin_to_cos: forwarded as ``flip_sin_to_cos``.
                :param downscale_freq_shift: forwarded as ``downscale_freq_shift``.
                """

                def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float):
                    super().__init__()
                    self.num_channels = num_channels
                    self.flip_sin_to_cos = flip_sin_to_cos
                    self.downscale_freq_shift = downscale_freq_shift

                def forward(self, timesteps):
                    """Return the sinusoidal embedding of ``timesteps``."""
                    return get_timestep_embedding(
                        timesteps,
                        self.num_channels,
                        flip_sin_to_cos=self.flip_sin_to_cos,
                        downscale_freq_shift=self.downscale_freq_shift,
                    )
         | 
| 432 | 
            +
             | 
| 433 | 
            +
             | 
| 434 | 
            +
            class ResnetBlock2D(nn.Module):
                """Residual block: two GroupNorm -> SiLU -> 3x3 conv stages with a time-embedding bias.

                A projection of the time embedding is added between the two stages. When
                ``in_channels`` differs from ``out_channels`` the skip path goes through a
                1x1 convolution so it can be added to the result.

                :param in_channels: channel count of the input tensor.
                :param out_channels: channel count of the output tensor.
                """

                def __init__(
                    self,
                    in_channels,
                    out_channels,
                ):
                    super().__init__()
                    self.in_channels = in_channels
                    self.out_channels = out_channels

                    # first stage
                    self.norm1 = nn.GroupNorm(num_groups=NORM_GROUPS, num_channels=in_channels, eps=NORM_EPS, affine=True)
                    self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

                    # maps the time embedding to one bias value per output channel
                    self.time_emb_proj = nn.Linear(TIME_EMBED_DIM, out_channels)

                    # second stage
                    self.norm2 = nn.GroupNorm(num_groups=NORM_GROUPS, num_channels=out_channels, eps=NORM_EPS, affine=True)
                    self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)

                    # the activation is fixed to SiLU ("swish")
                    self.nonlinearity = lambda x: F.silu(x)

                    self.use_in_shortcut = self.in_channels != self.out_channels

                    # the 1x1 conv is only needed when the channel counts differ
                    self.conv_shortcut = (
                        nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) if self.use_in_shortcut else None
                    )

                def forward(self, input_tensor, temb):
                    """Apply the block to ``input_tensor`` using the time embedding ``temb``.

                    ``temb`` is indexed as a 2-D tensor after projection and broadcast over the
                    two trailing axes of the feature map.
                    """
                    features = self.conv1(self.nonlinearity(self.norm1(input_tensor)))

                    time_bias = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]
                    features = features + time_bias

                    features = self.conv2(self.nonlinearity(self.norm2(features)))

                    if self.conv_shortcut is None:
                        shortcut = input_tensor
                    else:
                        shortcut = self.conv_shortcut(input_tensor)

                    return shortcut + features
         | 
| 484 | 
            +
             | 
| 485 | 
            +
             | 
| 486 | 
            +
            class DownBlock2D(nn.Module):
                """Down block made of ResNet blocks only, with an optional downsampler.

                ``LAYERS_PER_BLOCK`` ``ResnetBlock2D`` layers are stacked; the first maps
                ``in_channels`` to ``out_channels`` and the rest keep ``out_channels``.
                When ``add_downsample`` is true a single ``Downsample2D`` follows them.

                Args:
                    in_channels: channel count of the incoming feature map.
                    out_channels: channel count produced by every ResNet block.
                    add_downsample: whether to append the downsampler.
                """

                def __init__(
                    self,
                    in_channels: int,
                    out_channels: int,
                    add_downsample=True,
                ):
                    super().__init__()

                    self.has_cross_attention = False

                    # Only the first ResNet block sees ``in_channels``.
                    self.resnets = nn.ModuleList(
                        [
                            ResnetBlock2D(
                                in_channels=in_channels if i == 0 else out_channels,
                                out_channels=out_channels,
                            )
                            for i in range(LAYERS_PER_BLOCK)
                        ]
                    )

                    if add_downsample:
                        # Must be an nn.ModuleList: a plain Python list is invisible to
                        # nn.Module, so the downsampler's weights would be missing from
                        # parameters()/state_dict() and would not follow .to()/.half().
                        self.downsamplers = nn.ModuleList([Downsample2D(out_channels, out_channels=out_channels)])
                    else:
                        self.downsamplers = None

                    self.gradient_checkpointing = False

                def set_use_memory_efficient_attention(self, xformers, mem_eff):
                    """No-op: this block has no attention layers."""
                    pass

                def set_use_sdpa(self, sdpa):
                    """No-op: this block has no attention layers."""
                    pass

                def forward(self, hidden_states, temb=None):
                    """Apply the ResNet blocks and the optional downsampler.

                    Args:
                        hidden_states: input feature map.
                        temb: time embedding forwarded to every ResNet block.

                    Returns:
                        ``(hidden_states, output_states)`` where ``output_states`` is a
                        tuple holding the output of each ResNet block followed, when a
                        downsampler exists, by the downsampled result.
                    """
                    output_states = ()

                    for resnet in self.resnets:
                        if self.training and self.gradient_checkpointing:
                            # Recompute activations in backward instead of storing them.

                            def create_custom_forward(module):
                                def custom_forward(*inputs):
                                    return module(*inputs)

                                return custom_forward

                            hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
                        else:
                            hidden_states = resnet(hidden_states, temb)

                        output_states += (hidden_states,)

                    if self.downsamplers is not None:
                        for downsampler in self.downsamplers:
                            hidden_states = downsampler(hidden_states)

                        output_states += (hidden_states,)

                    return hidden_states, output_states
         | 
| 546 | 
            +
             | 
| 547 | 
            +
             | 
| 548 | 
            +
            class Downsample2D(nn.Module):
                """Downsampling layer implemented as a 3x3 convolution with stride 2.

                Args:
                    channels: channel count the input must have.
                    out_channels: channel count produced by the convolution.
                """

                def __init__(self, channels, out_channels):
                    super().__init__()

                    self.channels, self.out_channels = channels, out_channels

                    self.conv = nn.Conv2d(channels, out_channels, kernel_size=3, stride=2, padding=1)

                def forward(self, hidden_states):
                    """Check the channel axis of ``hidden_states`` and convolve it."""
                    assert hidden_states.shape[1] == self.channels
                    return self.conv(hidden_states)
         | 
| 562 | 
            +
             | 
| 563 | 
            +
             | 
| 564 | 
            +
            class CrossAttention(nn.Module):
                """Multi-head attention layer usable for self- or cross-attention.

                When ``cross_attention_dim`` is ``None`` the key/value projections take
                ``query_dim`` features, and ``forward`` falls back to attending over
                ``hidden_states`` itself when no ``context`` is given.

                ``forward`` picks the implementation in this order: custom ``processor``,
                xformers, ``FlashAttentionFunction``, ``F.scaled_dot_product_attention``,
                and finally a plain ``baddbmm`` / softmax / ``bmm`` path.

                Args:
                    query_dim: feature size of ``hidden_states`` and of the output.
                    cross_attention_dim: feature size of ``context``; defaults to ``query_dim``.
                    heads: number of attention heads.
                    dim_head: feature size of each head.
                    upcast_attention: compute the attention scores of the plain path in float32.
                """

                # Sentinel that lets ``set_processor()`` tell "no argument" from ``None``.
                _PROCESSOR_UNSET = object()

                def __init__(
                    self,
                    query_dim: int,
                    cross_attention_dim: Optional[int] = None,
                    heads: int = 8,
                    dim_head: int = 64,
                    upcast_attention: bool = False,
                ):
                    super().__init__()
                    inner_dim = dim_head * heads
                    cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
                    self.upcast_attention = upcast_attention

                    self.scale = dim_head**-0.5
                    self.heads = heads

                    self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
                    self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=False)
                    self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=False)

                    # Kept as a ModuleList so the output projection lives at index 0.
                    self.to_out = nn.ModuleList([])
                    self.to_out.append(nn.Linear(inner_dim, query_dim))
                    # no dropout here

                    self.use_memory_efficient_attention_xformers = False
                    self.use_memory_efficient_attention_mem_eff = False
                    self.use_sdpa = False

                    # Attention processor
                    self.processor = None

                def set_use_memory_efficient_attention(self, xformers, mem_eff):
                    """Enable or disable the xformers and the FlashAttention-style paths."""
                    self.use_memory_efficient_attention_xformers = xformers
                    self.use_memory_efficient_attention_mem_eff = mem_eff

                def set_use_sdpa(self, sdpa):
                    """Enable or disable the ``F.scaled_dot_product_attention`` path."""
                    self.use_sdpa = sdpa

                def reshape_heads_to_batch_dim(self, tensor):
                    """Reshape ``(batch, seq, heads * d)`` into ``(batch * heads, seq, d)``."""
                    batch_size, seq_len, dim = tensor.shape
                    head_size = self.heads
                    tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
                    tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
                    return tensor

                def reshape_batch_dim_to_heads(self, tensor):
                    """Reshape ``(batch * heads, seq, d)`` back into ``(batch, seq, heads * d)``."""
                    batch_size, seq_len, dim = tensor.shape
                    head_size = self.heads
                    tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
                    tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
                    return tensor

                def set_processor(self, processor=_PROCESSOR_UNSET):
                    """Install a custom attention processor and return the active one.

                    Args:
                        processor: callable used by ``forward`` in place of the built-in
                            paths, or ``None`` to remove the current one. When the
                            argument is omitted the current processor is left untouched,
                            which keeps the previous zero-argument call working.

                    Returns:
                        The processor in effect after the call.
                    """
                    if processor is not self._PROCESSOR_UNSET:
                        self.processor = processor
                    return self.processor

                def get_processor(self):
                    """Return the installed attention processor, or ``None``."""
                    return self.processor

                def forward(self, hidden_states, context=None, mask=None, **kwargs):
                    """Compute attention of ``hidden_states`` over ``context``.

                    Args:
                        hidden_states: query input.
                        context: key/value input; ``hidden_states`` is used when ``None``.
                        mask: attention mask. Only the processor, FlashAttention-style and
                            SDPA paths receive it; the xformers and plain paths ignore it.
                        **kwargs: only used with a custom processor. The Hugging Face names
                            ``encoder_hidden_states`` / ``attention_mask`` are accepted as
                            aliases for ``context`` / ``mask``; anything else is passed to
                            the processor unchanged.

                    Returns:
                        The attention output after the output projection, or whatever the
                        custom processor returns.
                    """
                    if self.processor is not None:
                        # Pull the HF-named aliases out of kwargs so they are not passed
                        # twice, and let the explicit context/mask take precedence.
                        hidden_states, context, mask = translate_attention_names_from_diffusers(
                            hidden_states=hidden_states,
                            context=context,
                            mask=mask,
                            encoder_hidden_states=kwargs.pop("encoder_hidden_states", None),
                            attention_mask=kwargs.pop("attention_mask", None),
                        )
                        return self.processor(
                            attn=self,
                            hidden_states=hidden_states,
                            encoder_hidden_states=context,
                            attention_mask=mask,
                            **kwargs,
                        )
                    if self.use_memory_efficient_attention_xformers:
                        return self.forward_memory_efficient_xformers(hidden_states, context, mask)
                    if self.use_memory_efficient_attention_mem_eff:
                        return self.forward_memory_efficient_mem_eff(hidden_states, context, mask)
                    if self.use_sdpa:
                        return self.forward_sdpa(hidden_states, context, mask)

                    query = self.to_q(hidden_states)
                    context = context if context is not None else hidden_states
                    key = self.to_k(context)
                    value = self.to_v(context)

                    query = self.reshape_heads_to_batch_dim(query)
                    key = self.reshape_heads_to_batch_dim(key)
                    value = self.reshape_heads_to_batch_dim(value)

                    hidden_states = self._attention(query, key, value)

                    # linear proj
                    hidden_states = self.to_out[0](hidden_states)
                    # hidden_states = self.to_out[1](hidden_states)     # no dropout
                    return hidden_states

                def _attention(self, query, key, value):
                    """Plain scaled dot-product attention on head-folded tensors."""
                    if self.upcast_attention:
                        query = query.float()
                        key = key.float()

                    # beta=0 makes baddbmm ignore the (uninitialised) first operand, so
                    # this is just ``scale * query @ key^T``.
                    attention_scores = torch.baddbmm(
                        torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
                        query,
                        key.transpose(-1, -2),
                        beta=0,
                        alpha=self.scale,
                    )
                    attention_probs = attention_scores.softmax(dim=-1)

                    # cast back to the original dtype
                    attention_probs = attention_probs.to(value.dtype)

                    # compute attention output
                    hidden_states = torch.bmm(attention_probs, value)

                    # reshape hidden_states
                    hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
                    return hidden_states

                # TODO support Hypernetworks
                def forward_memory_efficient_xformers(self, x, context=None, mask=None):
                    """Attention through ``xformers.ops.memory_efficient_attention``.

                    ``mask`` is accepted for signature parity but not used.
                    """
                    import xformers.ops

                    h = self.heads
                    q_in = self.to_q(x)
                    context = context if context is not None else x
                    context = context.to(x.dtype)
                    k_in = self.to_k(context)
                    v_in = self.to_v(context)

                    q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b n h d", h=h), (q_in, k_in, v_in))
                    del q_in, k_in, v_in

                    q = q.contiguous()
                    k = k.contiguous()
                    v = v.contiguous()
                    out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)  # xformers picks the best available kernel

                    out = rearrange(out, "b n h d -> b n (h d)", h=h)

                    out = self.to_out[0](out)
                    return out

                def forward_memory_efficient_mem_eff(self, x, context=None, mask=None):
                    """Attention through ``FlashAttentionFunction`` with fixed bucket sizes."""
                    flash_func = FlashAttentionFunction

                    q_bucket_size = 512
                    k_bucket_size = 1024

                    h = self.heads
                    q = self.to_q(x)
                    context = context if context is not None else x
                    context = context.to(x.dtype)
                    k = self.to_k(context)
                    v = self.to_v(context)
                    del context, x

                    q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))

                    out = flash_func.apply(q, k, v, mask, False, q_bucket_size, k_bucket_size)

                    out = rearrange(out, "b h n d -> b n (h d)")

                    out = self.to_out[0](out)
                    return out

                def forward_sdpa(self, x, context=None, mask=None):
                    """Attention through ``F.scaled_dot_product_attention``."""
                    h = self.heads
                    q_in = self.to_q(x)
                    context = context if context is not None else x
                    context = context.to(x.dtype)
                    k_in = self.to_k(context)
                    v_in = self.to_v(context)

                    q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q_in, k_in, v_in))
                    del q_in, k_in, v_in

                    out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)

                    out = rearrange(out, "b h n d -> b n (h d)", h=h)

                    out = self.to_out[0](out)
                    return out
         | 
| 750 | 
            +
             | 
| 751 | 
            +
            def translate_attention_names_from_diffusers(
                hidden_states: torch.FloatTensor,
                context: Optional[torch.FloatTensor] = None,
                mask: Optional[torch.FloatTensor] = None,
                # HF naming
                encoder_hidden_states: Optional[torch.FloatTensor] = None,
                attention_mask: Optional[torch.FloatTensor] = None
            ):
                """Map Hugging Face diffusers argument names onto this file's names.

                ``encoder_hidden_states`` is used as ``context`` and ``attention_mask`` as
                ``mask``, but only when the corresponding native argument is ``None``.

                Returns:
                    The tuple ``(hidden_states, context, mask)``.
                """
                # Native names win; the HF names are only a fallback.
                if context is None:
                    context = encoder_hidden_states

                if mask is None:
                    mask = attention_mask

                return hidden_states, context, mask
         | 
| 766 | 
            +
             | 
| 767 | 
            +
            # feedforward
         | 
| 768 | 
            +
            class GEGLU(nn.Module):
                r"""
                Gated linear unit with a GELU gate, see https://arxiv.org/abs/2002.05202.

                A single linear layer produces ``2 * dim_out`` features; the first half is
                the value and the second half is the gate.

                Parameters:
                    dim_in (`int`): The number of channels in the input.
                    dim_out (`int`): The number of channels in the output.
                """

                def __init__(self, dim_in: int, dim_out: int):
                    super().__init__()
                    self.proj = nn.Linear(dim_in, dim_out * 2)

                def gelu(self, gate):
                    """Apply GELU to ``gate``, going through float32 on ``mps`` devices."""
                    if gate.device.type == "mps":
                        # mps: gelu is not implemented for float16
                        return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
                    return F.gelu(gate)

                def forward(self, hidden_states):
                    """Project ``hidden_states`` and return ``value * gelu(gate)``."""
                    value, gate = self.proj(hidden_states).chunk(2, dim=-1)
                    return value * self.gelu(gate)
         | 
| 790 | 
            +
             | 
| 791 | 
            +
             | 
| 792 | 
            +
            class FeedForward(nn.Module):
                """Feed-forward network: GEGLU projection to ``4 * dim``, then a linear layer back to ``dim``.

                An ``nn.Identity`` sits between the two as a stand-in for a dropout layer
                with probability 0, so the output projection lives at index 2 of ``net``.
                """

                def __init__(
                    self,
                    dim: int,
                ):
                    super().__init__()
                    inner_dim = int(dim * 4)  # mult is always 4

                    self.net = nn.ModuleList(
                        [
                            GEGLU(dim, inner_dim),  # project in
                            nn.Identity(),  # dummy for dropout with 0
                            nn.Linear(inner_dim, dim),  # project out
                        ]
                    )

                def forward(self, hidden_states):
                    """Feed ``hidden_states`` through every layer of ``net`` in order."""
                    out = hidden_states
                    for layer in self.net:
                        out = layer(out)
                    return out
         | 
| 812 | 
            +
             | 
| 813 | 
            +
             | 
| 814 | 
            +
            class BasicTransformerBlock(nn.Module):
                """Transformer block: self-attention, cross-attention and feed-forward.

                Each sub-layer is applied to a LayerNorm of its input and its result is
                added back to that input.
                """

                def __init__(
                    self, dim: int, num_attention_heads: int, attention_head_dim: int, cross_attention_dim: int, upcast_attention: bool = False
                ):
                    super().__init__()

                    # Settings common to both attention layers.
                    shared = dict(
                        query_dim=dim,
                        heads=num_attention_heads,
                        dim_head=attention_head_dim,
                        upcast_attention=upcast_attention,
                    )

                    # 1. Self-Attn (no cross_attention_dim)
                    self.attn1 = CrossAttention(cross_attention_dim=None, **shared)
                    self.ff = FeedForward(dim)

                    # 2. Cross-Attn
                    self.attn2 = CrossAttention(cross_attention_dim=cross_attention_dim, **shared)

                    self.norm1 = nn.LayerNorm(dim)
                    self.norm2 = nn.LayerNorm(dim)

                    # 3. Feed-forward
                    self.norm3 = nn.LayerNorm(dim)

                def set_use_memory_efficient_attention(self, xformers: bool, mem_eff: bool):
                    """Forward the memory-efficient attention switches to both attention layers."""
                    for attn in (self.attn1, self.attn2):
                        attn.set_use_memory_efficient_attention(xformers, mem_eff)

                def set_use_sdpa(self, sdpa: bool):
                    """Forward the SDPA switch to both attention layers."""
                    for attn in (self.attn1, self.attn2):
                        attn.set_use_sdpa(sdpa)

                def forward(self, hidden_states, context=None, timestep=None):
                    """Run the three residual sub-layers.

                    Args:
                        hidden_states: input sequence.
                        context: passed to the cross-attention layer as ``context``.
                        timestep: accepted but not used by this block.

                    Returns:
                        The transformed ``hidden_states``.
                    """
                    # 1. Self-Attention
                    attn_out = self.attn1(self.norm1(hidden_states))
                    hidden_states = attn_out + hidden_states

                    # 2. Cross-Attention
                    cross_out = self.attn2(self.norm2(hidden_states), context=context)
                    hidden_states = cross_out + hidden_states

                    # 3. Feed-forward
                    ff_out = self.ff(self.norm3(hidden_states))
                    return ff_out + hidden_states
         | 
| 867 | 
            +
             | 
| 868 | 
            +
             | 
| 869 | 
            +
            class Transformer2DModel(nn.Module):
                """Transformer applied to a 4D feature map, with a residual connection.

                The input is group-normalised, projected to ``inner_dim`` channels
                (``num_attention_heads * attention_head_dim``), flattened to a sequence of
                ``height * width`` tokens, passed through the transformer blocks, projected
                back to ``in_channels`` and added to the original input.

                Args:
                    num_attention_heads: number of attention heads passed to the block.
                    attention_head_dim: per-head dimension passed to the block.
                    in_channels: channel count of the incoming feature map.
                    cross_attention_dim: forwarded to ``BasicTransformerBlock``.
                    use_linear_projection: use ``nn.Linear`` (applied on the flattened
                        sequence) instead of a 1x1 ``nn.Conv2d`` for the in/out projections.
                    upcast_attention: forwarded to ``BasicTransformerBlock``.
                """

                def __init__(
                    self,
                    num_attention_heads: int = 16,
                    attention_head_dim: int = 88,
                    in_channels: Optional[int] = None,
                    cross_attention_dim: Optional[int] = None,
                    use_linear_projection: bool = False,
                    upcast_attention: bool = False,
                ):
                    super().__init__()
                    self.in_channels = in_channels
                    self.num_attention_heads = num_attention_heads
                    self.attention_head_dim = attention_head_dim
                    inner_dim = num_attention_heads * attention_head_dim
                    self.use_linear_projection = use_linear_projection

                    self.norm = torch.nn.GroupNorm(num_groups=TRANSFORMER_NORM_NUM_GROUPS, num_channels=in_channels, eps=1e-6, affine=True)

                    if use_linear_projection:
                        self.proj_in = nn.Linear(in_channels, inner_dim)
                    else:
                        self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)

                    self.transformer_blocks = nn.ModuleList(
                        [
                            BasicTransformerBlock(
                                inner_dim,
                                num_attention_heads,
                                attention_head_dim,
                                cross_attention_dim=cross_attention_dim,
                                upcast_attention=upcast_attention,
                            )
                        ]
                    )

                    if use_linear_projection:
                        # The output projection maps inner_dim back to in_channels. The argument
                        # order used to mirror proj_in, which only worked when both sizes were
                        # equal; in that case the weight shape is unchanged by this fix.
                        self.proj_out = nn.Linear(inner_dim, in_channels)
                    else:
                        self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)

                def set_use_memory_efficient_attention(self, xformers, mem_eff):
                    """Forward the memory-efficient attention flags to every transformer block."""
                    for transformer in self.transformer_blocks:
                        transformer.set_use_memory_efficient_attention(xformers, mem_eff)

                def set_use_sdpa(self, sdpa):
                    """Forward the SDPA flag to every transformer block."""
                    for transformer in self.transformer_blocks:
                        transformer.set_use_sdpa(sdpa)

                def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True):
                    """Run the transformer over a 4D input and add the residual.

                    Args:
                        hidden_states: 4D tensor; unpacked as (batch, channels, height, width).
                        encoder_hidden_states: passed to each block as ``context``.
                        timestep: passed through to each block.
                        return_dict: when False return a 1-tuple, otherwise a ``SampleOutput``.

                    Returns:
                        ``(output,)`` or ``SampleOutput(sample=output)``.
                    """
                    # 1. Input
                    batch, _, height, width = hidden_states.shape
                    residual = hidden_states

                    hidden_states = self.norm(hidden_states)
                    if not self.use_linear_projection:
                        # Conv projection runs on the 4D map, then the map is flattened to tokens.
                        hidden_states = self.proj_in(hidden_states)
                        inner_dim = hidden_states.shape[1]
                        hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
                    else:
                        # Linear projection runs on the flattened tokens. Here inner_dim is the
                        # input channel count, which is also what proj_out emits below.
                        inner_dim = hidden_states.shape[1]
                        hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
                        hidden_states = self.proj_in(hidden_states)

                    # 2. Blocks
                    for block in self.transformer_blocks:
                        hidden_states = block(hidden_states, context=encoder_hidden_states, timestep=timestep)

                    # 3. Output
                    if not self.use_linear_projection:
                        hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
                        hidden_states = self.proj_out(hidden_states)
                    else:
                        hidden_states = self.proj_out(hidden_states)
                        hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()

                    output = hidden_states + residual

                    if not return_dict:
                        return (output,)

                    return SampleOutput(sample=output)
         | 
| 951 | 
            +
             | 
| 952 | 
            +
             | 
| 953 | 
            +
            class CrossAttnDownBlock2D(nn.Module):
                """Down block made of (ResnetBlock2D, Transformer2DModel) pairs plus an optional downsampler.

                ``LAYERS_PER_BLOCK`` pairs are created. The first resnet takes
                ``in_channels``; every later one takes ``out_channels``.
                """

                def __init__(
                    self,
                    in_channels: int,
                    out_channels: int,
                    add_downsample=True,
                    cross_attention_dim=1280,
                    attn_num_head_channels=1,
                    use_linear_projection=False,
                    upcast_attention=False,
                ):
                    super().__init__()
                    self.has_cross_attention = True
                    self.attn_num_head_channels = attn_num_head_channels

                    resnet_layers = []
                    attention_layers = []
                    # Modules are created pairwise (resnet, then attention) for each layer.
                    for layer_idx in range(LAYERS_PER_BLOCK):
                        resnet_input = out_channels if layer_idx > 0 else in_channels
                        resnet_layers.append(ResnetBlock2D(in_channels=resnet_input, out_channels=out_channels))

                        transformer = Transformer2DModel(
                            attn_num_head_channels,
                            out_channels // attn_num_head_channels,
                            in_channels=out_channels,
                            cross_attention_dim=cross_attention_dim,
                            use_linear_projection=use_linear_projection,
                            upcast_attention=upcast_attention,
                        )
                        attention_layers.append(transformer)

                    self.attentions = nn.ModuleList(attention_layers)
                    self.resnets = nn.ModuleList(resnet_layers)

                    self.downsamplers = nn.ModuleList([Downsample2D(out_channels, out_channels)]) if add_downsample else None

                    self.gradient_checkpointing = False

                def set_use_memory_efficient_attention(self, xformers, mem_eff):
                    """Forward the memory-efficient attention flags to every attention module."""
                    for attention in self.attentions:
                        attention.set_use_memory_efficient_attention(xformers, mem_eff)

                def set_use_sdpa(self, sdpa):
                    """Forward the SDPA flag to every attention module."""
                    for attention in self.attentions:
                        attention.set_use_sdpa(sdpa)

                def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
                    """Apply each resnet/attention pair, then the downsamplers if present.

                    Returns:
                        ``(hidden_states, output_states)`` where ``output_states`` is a tuple
                        holding the output of every pair and, if downsampling is enabled,
                        the downsampled result as the last entry.
                    """

                    def make_forward(module, return_dict=None):
                        # Wrap the module so checkpoint() can call it with positional inputs only.
                        def call_module(*inputs):
                            if return_dict is None:
                                return module(*inputs)
                            return module(*inputs, return_dict=return_dict)

                        return call_module

                    collected = []

                    for resnet, attention in zip(self.resnets, self.attentions):
                        if self.training and self.gradient_checkpointing:
                            hidden_states = torch.utils.checkpoint.checkpoint(make_forward(resnet), hidden_states, temb)
                            attention_out = torch.utils.checkpoint.checkpoint(
                                make_forward(attention, return_dict=False), hidden_states, encoder_hidden_states
                            )
                            hidden_states = attention_out[0]
                        else:
                            hidden_states = resnet(hidden_states, temb)
                            hidden_states = attention(hidden_states, encoder_hidden_states=encoder_hidden_states).sample

                        collected.append(hidden_states)

                    if self.downsamplers is not None:
                        for downsample in self.downsamplers:
                            hidden_states = downsample(hidden_states)

                        collected.append(hidden_states)

                    return hidden_states, tuple(collected)
         | 
| 1035 | 
            +
             | 
| 1036 | 
            +
             | 
| 1037 | 
            +
            class UNetMidBlock2DCrossAttn(nn.Module):
                """Middle block: resnet -> attention -> resnet, all at ``in_channels``."""

                def __init__(
                    self,
                    in_channels: int,
                    attn_num_head_channels=1,
                    cross_attention_dim=1280,
                    use_linear_projection=False,
                ):
                    super().__init__()

                    self.has_cross_attention = True
                    self.attn_num_head_channels = attn_num_head_channels

                    # Two resnets are built first, then the single attention module.
                    first_resnet = ResnetBlock2D(
                        in_channels=in_channels,
                        out_channels=in_channels,
                    )
                    second_resnet = ResnetBlock2D(
                        in_channels=in_channels,
                        out_channels=in_channels,
                    )
                    transformer = Transformer2DModel(
                        attn_num_head_channels,
                        in_channels // attn_num_head_channels,
                        in_channels=in_channels,
                        cross_attention_dim=cross_attention_dim,
                        use_linear_projection=use_linear_projection,
                    )

                    self.attentions = nn.ModuleList([transformer])
                    self.resnets = nn.ModuleList([first_resnet, second_resnet])

                    self.gradient_checkpointing = False

                def set_use_memory_efficient_attention(self, xformers, mem_eff):
                    """Forward the memory-efficient attention flags to every attention module."""
                    for attention in self.attentions:
                        attention.set_use_memory_efficient_attention(xformers, mem_eff)

                def set_use_sdpa(self, sdpa):
                    """Forward the SDPA flag to every attention module."""
                    for attention in self.attentions:
                        attention.set_use_sdpa(sdpa)

                def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
                    """Run the resnets in order, applying ``attentions[i - 1]`` before resnet ``i`` for ``i > 0``."""

                    def make_forward(module, return_dict=None):
                        # Wrap the module so checkpoint() can call it with positional inputs only.
                        def call_module(*inputs):
                            if return_dict is None:
                                return module(*inputs)
                            return module(*inputs, return_dict=return_dict)

                        return call_module

                    for index, resnet in enumerate(self.resnets):
                        # The first resnet has no attention in front of it.
                        attention = self.attentions[index - 1] if index != 0 else None

                        if self.training and self.gradient_checkpointing:
                            if attention is not None:
                                attention_out = torch.utils.checkpoint.checkpoint(
                                    make_forward(attention, return_dict=False), hidden_states, encoder_hidden_states
                                )
                                hidden_states = attention_out[0]

                            hidden_states = torch.utils.checkpoint.checkpoint(make_forward(resnet), hidden_states, temb)
                            continue

                        if attention is not None:
                            hidden_states = attention(hidden_states, encoder_hidden_states).sample
                        hidden_states = resnet(hidden_states, temb)

                    return hidden_states
         | 
| 1111 | 
            +
             | 
| 1112 | 
            +
             | 
| 1113 | 
            +
            class Upsample2D(nn.Module):
                """Nearest-neighbour upsampling followed by a 3x3 convolution (padding 1)."""

                def __init__(self, channels, out_channels):
                    super().__init__()
                    self.channels = channels
                    self.out_channels = out_channels
                    self.conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=1)

                def forward(self, hidden_states, output_size):
                    """Upsample ``hidden_states`` and convolve the result.

                    Args:
                        hidden_states: tensor whose dimension 1 must equal ``self.channels``.
                        output_size: target size for the interpolation; when ``None`` a
                            scale factor of 2 is used instead.
                    """
                    assert hidden_states.shape[1] == self.channels

                    # Cast to float32 as the 'upsample_nearest2d_out_frame' op does not support bfloat16
                    # TODO(Suraj): Remove this cast once the issue is fixed in PyTorch
                    # https://github.com/pytorch/pytorch/issues/86679
                    input_dtype = hidden_states.dtype
                    was_bfloat16 = input_dtype == torch.bfloat16
                    if was_bfloat16:
                        hidden_states = hidden_states.to(torch.float32)

                    # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
                    if hidden_states.shape[0] >= 64:
                        hidden_states = hidden_states.contiguous()

                    # An explicit `output_size` forces the interpolation size; otherwise `scale_factor=2` is used.
                    if output_size is not None:
                        resize_kwargs = {"size": output_size}
                    else:
                        resize_kwargs = {"scale_factor": 2.0}
                    hidden_states = F.interpolate(hidden_states, mode="nearest", **resize_kwargs)

                    # Restore bfloat16 if that is what came in.
                    if was_bfloat16:
                        hidden_states = hidden_states.to(input_dtype)

                    return self.conv(hidden_states)
         | 
| 1147 | 
            +
             | 
| 1148 | 
            +
             | 
| 1149 | 
            +
            class UpBlock2D(nn.Module):
                """Up block of ``LAYERS_PER_BLOCK_UP`` resnets (no attention) plus an optional upsampler.

                Each resnet takes the running hidden state concatenated with one skip
                tensor, so its input width is the sum of both channel counts.
                """

                def __init__(
                    self,
                    in_channels: int,
                    prev_output_channel: int,
                    out_channels: int,
                    add_upsample=True,
                ):
                    super().__init__()

                    self.has_cross_attention = False

                    resnet_layers = []
                    final_index = LAYERS_PER_BLOCK_UP - 1
                    for layer_idx in range(LAYERS_PER_BLOCK_UP):
                        # Only the last layer's skip tensor has `in_channels` channels.
                        skip_width = out_channels if layer_idx != final_index else in_channels
                        # Only the first layer receives the previous block's output.
                        incoming_width = out_channels if layer_idx != 0 else prev_output_channel

                        resnet_layers.append(
                            ResnetBlock2D(
                                in_channels=incoming_width + skip_width,
                                out_channels=out_channels,
                            )
                        )

                    self.resnets = nn.ModuleList(resnet_layers)

                    self.upsamplers = nn.ModuleList([Upsample2D(out_channels, out_channels)]) if add_upsample else None

                    self.gradient_checkpointing = False

                def set_use_memory_efficient_attention(self, xformers, mem_eff):
                    """No-op: this block contains no attention modules."""
                    pass

                def set_use_sdpa(self, sdpa):
                    """No-op: this block contains no attention modules."""
                    pass

                def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
                    """Consume skip tensors from the end of ``res_hidden_states_tuple``, one per resnet.

                    Each skip tensor is concatenated to ``hidden_states`` along dim 1 before
                    the resnet is applied. ``upsample_size`` is passed to the upsamplers.
                    """

                    def make_forward(module):
                        # Wrap the module so checkpoint() can call it with positional inputs only.
                        def call_module(*inputs):
                            return module(*inputs)

                        return call_module

                    # Walk the skip tensors from last to first: offset 1 is the final entry.
                    for offset, resnet in enumerate(self.resnets, start=1):
                        skip = res_hidden_states_tuple[-offset]
                        hidden_states = torch.cat([hidden_states, skip], dim=1)

                        if self.training and self.gradient_checkpointing:
                            hidden_states = torch.utils.checkpoint.checkpoint(make_forward(resnet), hidden_states, temb)
                        else:
                            hidden_states = resnet(hidden_states, temb)

                    if self.upsamplers is not None:
                        for upsample in self.upsamplers:
                            hidden_states = upsample(hidden_states, upsample_size)

                    return hidden_states
         | 
| 1213 | 
            +
             | 
| 1214 | 
            +
             | 
| 1215 | 
            +
            class CrossAttnUpBlock2D(nn.Module):
         | 
| 1216 | 
            +
                def __init__(
         | 
| 1217 | 
            +
                    self,
         | 
| 1218 | 
            +
                    in_channels: int,
         | 
| 1219 | 
            +
                    out_channels: int,
         | 
| 1220 | 
            +
                    prev_output_channel: int,
         | 
| 1221 | 
            +
                    attn_num_head_channels=1,
         | 
| 1222 | 
            +
                    cross_attention_dim=1280,
         | 
| 1223 | 
            +
                    add_upsample=True,
         | 
| 1224 | 
            +
                    use_linear_projection=False,
         | 
| 1225 | 
            +
                    upcast_attention=False,
         | 
| 1226 | 
            +
                ):
         | 
| 1227 | 
            +
                    super().__init__()
         | 
| 1228 | 
            +
                    resnets = []
         | 
| 1229 | 
            +
                    attentions = []
         | 
| 1230 | 
            +
             | 
| 1231 | 
            +
                    self.has_cross_attention = True
         | 
| 1232 | 
            +
                    self.attn_num_head_channels = attn_num_head_channels
         | 
| 1233 | 
            +
             | 
| 1234 | 
            +
                    for i in range(LAYERS_PER_BLOCK_UP):
         | 
| 1235 | 
            +
                        res_skip_channels = in_channels if (i == LAYERS_PER_BLOCK_UP - 1) else out_channels
         | 
| 1236 | 
            +
                        resnet_in_channels = prev_output_channel if i == 0 else out_channels
         | 
| 1237 | 
            +
             | 
| 1238 | 
            +
                        resnets.append(
         | 
| 1239 | 
            +
                            ResnetBlock2D(
         | 
| 1240 | 
            +
                                in_channels=resnet_in_channels + res_skip_channels,
         | 
| 1241 | 
            +
                                out_channels=out_channels,
         | 
| 1242 | 
            +
                            )
         | 
| 1243 | 
            +
                        )
         | 
| 1244 | 
            +
                        attentions.append(
         | 
| 1245 | 
            +
                            Transformer2DModel(
         | 
| 1246 | 
            +
                                attn_num_head_channels,
         | 
| 1247 | 
            +
                                out_channels // attn_num_head_channels,
         | 
| 1248 | 
            +
                                in_channels=out_channels,
         | 
| 1249 | 
            +
                                cross_attention_dim=cross_attention_dim,
         | 
| 1250 | 
            +
                                use_linear_projection=use_linear_projection,
         | 
| 1251 | 
            +
                                upcast_attention=upcast_attention,
         | 
| 1252 | 
            +
                            )
         | 
| 1253 | 
            +
                        )
         | 
| 1254 | 
            +
             | 
| 1255 | 
            +
                    self.attentions = nn.ModuleList(attentions)
         | 
| 1256 | 
            +
                    self.resnets = nn.ModuleList(resnets)
         | 
| 1257 | 
            +
             | 
| 1258 | 
            +
                    if add_upsample:
         | 
| 1259 | 
            +
                        self.upsamplers = nn.ModuleList([Upsample2D(out_channels, out_channels)])
         | 
| 1260 | 
            +
                    else:
         | 
| 1261 | 
            +
                        self.upsamplers = None
         | 
| 1262 | 
            +
             | 
| 1263 | 
            +
                    self.gradient_checkpointing = False
         | 
| 1264 | 
            +
             | 
| 1265 | 
            +
                def set_use_memory_efficient_attention(self, xformers, mem_eff):
         | 
| 1266 | 
            +
                    for attn in self.attentions:
         | 
| 1267 | 
            +
                        attn.set_use_memory_efficient_attention(xformers, mem_eff)
         | 
| 1268 | 
            +
             | 
| 1269 | 
            +
                def set_use_sdpa(self, sdpa):
         | 
| 1270 | 
            +
                    for attn in self.attentions:
         | 
| 1271 | 
            +
                        attn.set_use_sdpa(sdpa)
         | 
| 1272 | 
            +
             | 
| 1273 | 
            +
                def forward(
                    self,
                    hidden_states,
                    res_hidden_states_tuple,
                    temb=None,
                    encoder_hidden_states=None,
                    upsample_size=None,
                ):
                    """Run the (resnet, attention) pairs of this up block, then the optional upsamplers.

                    Args:
                        hidden_states: tensor entering the block.
                        res_hidden_states_tuple: skip tensors; they are consumed from the last
                            element backwards, one per (resnet, attention) pair.
                        temb: passed as the second argument of every resnet.
                        encoder_hidden_states: passed to every attention module.
                        upsample_size: passed as the second argument of every upsampler.

                    Returns:
                        The hidden states after the last layer (and after the upsamplers, if any).
                    """

                    def make_checkpoint_fn(module, return_dict=None):
                        # The checkpoint call below only passes positional inputs, so the optional
                        # `return_dict` keyword is bound through this closure.
                        def run(*inputs):
                            if return_dict is None:
                                return module(*inputs)
                            return module(*inputs, return_dict=return_dict)

                        return run

                    layer_pairs = zip(self.resnets, self.attentions)
                    for depth, (resnet, attn) in enumerate(layer_pairs, start=1):
                        # Take the skip tensors from the end of the tuple: -1, then -2, ...
                        skip_states = res_hidden_states_tuple[-depth]
                        hidden_states = torch.cat([hidden_states, skip_states], dim=1)

                        if self.training and self.gradient_checkpointing:
                            hidden_states = torch.utils.checkpoint.checkpoint(
                                make_checkpoint_fn(resnet), hidden_states, temb
                            )
                            attn_outputs = torch.utils.checkpoint.checkpoint(
                                make_checkpoint_fn(attn, return_dict=False), hidden_states, encoder_hidden_states
                            )
                            # With return_dict=False the attention output is indexed, not read via `.sample`.
                            hidden_states = attn_outputs[0]
                        else:
                            hidden_states = resnet(hidden_states, temb)
                            hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample

                    if self.upsamplers is None:
                        return hidden_states

                    for upsampler in self.upsamplers:
                        hidden_states = upsampler(hidden_states, upsample_size)
                    return hidden_states
         | 
| 1312 | 
            +
             | 
| 1313 | 
            +
             | 
| 1314 | 
            +
            def get_down_block(
                down_block_type,
                in_channels,
                out_channels,
                add_downsample,
                attn_num_head_channels,
                cross_attention_dim,
                use_linear_projection,
                upcast_attention,
            ):
                """Build the down block selected by `down_block_type`.

                Args:
                    down_block_type: either "DownBlock2D" or "CrossAttnDownBlock2D".
                    in_channels: forwarded to the block constructor.
                    out_channels: forwarded to the block constructor.
                    add_downsample: forwarded to the block constructor.
                    attn_num_head_channels: only forwarded to `CrossAttnDownBlock2D`.
                    cross_attention_dim: only forwarded to `CrossAttnDownBlock2D`.
                    use_linear_projection: only forwarded to `CrossAttnDownBlock2D`.
                    upcast_attention: only forwarded to `CrossAttnDownBlock2D`.

                Returns:
                    The constructed block.

                Raises:
                    ValueError: if `down_block_type` is not one of the supported names.
                """
                if down_block_type == "DownBlock2D":
                    return DownBlock2D(
                        in_channels=in_channels,
                        out_channels=out_channels,
                        add_downsample=add_downsample,
                    )
                if down_block_type == "CrossAttnDownBlock2D":
                    return CrossAttnDownBlock2D(
                        in_channels=in_channels,
                        out_channels=out_channels,
                        add_downsample=add_downsample,
                        cross_attention_dim=cross_attention_dim,
                        attn_num_head_channels=attn_num_head_channels,
                        use_linear_projection=use_linear_projection,
                        upcast_attention=upcast_attention,
                    )
                # Fail at construction time instead of silently returning None, which would
                # only surface later as an attribute error on the missing block.
                raise ValueError(f"Unknown down block type: {down_block_type!r}")
         | 
| 1340 | 
            +
             | 
| 1341 | 
            +
             | 
| 1342 | 
            +
            def get_up_block(
                up_block_type,
                in_channels,
                out_channels,
                prev_output_channel,
                add_upsample,
                attn_num_head_channels,
                cross_attention_dim=None,
                use_linear_projection=False,
                upcast_attention=False,
            ):
                """Build the up block selected by `up_block_type`.

                Args:
                    up_block_type: either "UpBlock2D" or "CrossAttnUpBlock2D".
                    in_channels: forwarded to the block constructor.
                    out_channels: forwarded to the block constructor.
                    prev_output_channel: forwarded to the block constructor.
                    add_upsample: forwarded to the block constructor.
                    attn_num_head_channels: only forwarded to `CrossAttnUpBlock2D`.
                    cross_attention_dim: only forwarded to `CrossAttnUpBlock2D`.
                    use_linear_projection: only forwarded to `CrossAttnUpBlock2D`.
                    upcast_attention: only forwarded to `CrossAttnUpBlock2D`.

                Returns:
                    The constructed block.

                Raises:
                    ValueError: if `up_block_type` is not one of the supported names.
                """
                if up_block_type == "UpBlock2D":
                    return UpBlock2D(
                        in_channels=in_channels,
                        prev_output_channel=prev_output_channel,
                        out_channels=out_channels,
                        add_upsample=add_upsample,
                    )
                if up_block_type == "CrossAttnUpBlock2D":
                    return CrossAttnUpBlock2D(
                        in_channels=in_channels,
                        out_channels=out_channels,
                        prev_output_channel=prev_output_channel,
                        attn_num_head_channels=attn_num_head_channels,
                        cross_attention_dim=cross_attention_dim,
                        add_upsample=add_upsample,
                        use_linear_projection=use_linear_projection,
                        upcast_attention=upcast_attention,
                    )
                # Fail at construction time instead of silently returning None, which would
                # only surface later as an attribute error on the missing block.
                raise ValueError(f"Unknown up block type: {up_block_type!r}")
         | 
| 1371 | 
            +
             | 
| 1372 | 
            +
             | 
| 1373 | 
            +
            class UNet2DConditionModel(nn.Module):
         | 
| 1374 | 
            +
                _supports_gradient_checkpointing = True
         | 
| 1375 | 
            +
             | 
| 1376 | 
            +
                def __init__(
                    self,
                    sample_size: Optional[int] = None,
                    attention_head_dim: Union[int, Tuple[int]] = 8,
                    cross_attention_dim: int = 1280,
                    use_linear_projection: bool = False,
                    upcast_attention: bool = False,
                    **kwargs,
                ):
                    """Assemble the UNet from the module-level architecture constants.

                    Args:
                        sample_size: required; stored on the instance and in `self.config`.
                        attention_head_dim: a single int (repeated for 4 blocks) or one value per block.
                        cross_attention_dim: forwarded to every cross-attention block.
                        use_linear_projection: forwarded to every cross-attention block.
                        upcast_attention: forwarded to the down and up blocks.
                        **kwargs: accepted and not used in this constructor.
                    """
                    super().__init__()
                    assert sample_size is not None, "sample_size must be specified"
                    logger.info(
                        f"UNet2DConditionModel: {sample_size}, {attention_head_dim}, {cross_attention_dim}, {use_linear_projection}, {upcast_attention}"
                    )

                    # Defined so that code outside this class can read them.
                    self.in_channels = IN_CHANNELS
                    self.out_channels = OUT_CHANNELS

                    self.sample_size = sample_size
                    self.prepare_config(sample_size=sample_size)

                    # The way modules are held must not change, because that would change the
                    # state_dict format.
                    stem_channels = BLOCK_OUT_CHANNELS[0]
                    final_index = len(BLOCK_OUT_CHANNELS) - 1

                    # input
                    self.conv_in = nn.Conv2d(IN_CHANNELS, stem_channels, kernel_size=3, padding=(1, 1))

                    # time
                    self.time_proj = Timesteps(stem_channels, TIME_EMBED_FLIP_SIN_TO_COS, TIME_EMBED_FREQ_SHIFT)
                    self.time_embedding = TimestepEmbedding(TIMESTEP_INPUT_DIM, TIME_EMBED_DIM)

                    self.down_blocks = nn.ModuleList([])
                    self.mid_block = None
                    self.up_blocks = nn.ModuleList([])

                    if isinstance(attention_head_dim, int):
                        attention_head_dim = (attention_head_dim,) * 4

                    # down: each block takes the previous block's output width as its input width
                    block_channels = stem_channels
                    for index, block_type in enumerate(DOWN_BLOCK_TYPES):
                        incoming_channels, block_channels = block_channels, BLOCK_OUT_CHANNELS[index]
                        self.down_blocks.append(
                            get_down_block(
                                block_type,
                                in_channels=incoming_channels,
                                out_channels=block_channels,
                                add_downsample=index != final_index,  # every block but the last downsamples
                                attn_num_head_channels=attention_head_dim[index],
                                cross_attention_dim=cross_attention_dim,
                                use_linear_projection=use_linear_projection,
                                upcast_attention=upcast_attention,
                            )
                        )

                    # mid
                    self.mid_block = UNetMidBlock2DCrossAttn(
                        in_channels=BLOCK_OUT_CHANNELS[-1],
                        attn_num_head_channels=attention_head_dim[-1],
                        cross_attention_dim=cross_attention_dim,
                        use_linear_projection=use_linear_projection,
                    )

                    # count how many layers upsample the images
                    self.num_upsamplers = 0

                    # up: walk the channel / head-dim tables in reverse order
                    channels_reversed = list(reversed(BLOCK_OUT_CHANNELS))
                    head_dims_reversed = list(reversed(attention_head_dim))
                    block_channels = channels_reversed[0]
                    for index, block_type in enumerate(UP_BLOCK_TYPES):
                        preceding_channels, block_channels = block_channels, channels_reversed[index]
                        skip_channels = channels_reversed[min(index + 1, final_index)]

                        # add upsample block for all BUT final layer
                        add_upsample = index != final_index
                        if add_upsample:
                            self.num_upsamplers += 1

                        self.up_blocks.append(
                            get_up_block(
                                block_type,
                                in_channels=skip_channels,
                                out_channels=block_channels,
                                prev_output_channel=preceding_channels,
                                add_upsample=add_upsample,
                                attn_num_head_channels=head_dims_reversed[index],
                                cross_attention_dim=cross_attention_dim,
                                use_linear_projection=use_linear_projection,
                                upcast_attention=upcast_attention,
                            )
                        )

                    # out
                    self.conv_norm_out = nn.GroupNorm(num_channels=stem_channels, num_groups=NORM_GROUPS, eps=NORM_EPS)
                    self.conv_act = nn.SiLU()
                    self.conv_out = nn.Conv2d(stem_channels, OUT_CHANNELS, kernel_size=3, padding=1)
         | 
| 1481 | 
            +
             | 
| 1482 | 
            +
                # region diffusers compatibility
         | 
| 1483 | 
            +
                def prepare_config(self, *args, **kwargs):
                    """Store the keyword arguments as attributes of a `SimpleNamespace` in `self.config`.

                    Positional arguments are accepted and ignored.
                    """
                    config = SimpleNamespace(**kwargs)
                    self.config = config
         | 
| 1485 | 
            +
             | 
| 1486 | 
            +
                @property
                def dtype(self) -> torch.dtype:
                    """`torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype)."""
                    parameter_dtype = get_parameter_dtype(self)
                    return parameter_dtype
         | 
| 1490 | 
            +
             | 
| 1491 | 
            +
                @property
                def device(self) -> torch.device:
                    """`torch.device`: The device on which the module is (assuming that all the module parameters are on the same device)."""
                    parameter_device = get_parameter_device(self)
                    return parameter_device
         | 
| 1495 | 
            +
             | 
| 1496 | 
            +
                def set_attention_slice(self, slice_size):
                    """Always raise: this model has no attention slicing.

                    Raises:
                        NotImplementedError: unconditionally, whatever `slice_size` is.
                    """
                    message = "Attention slicing is not supported for this model."
                    raise NotImplementedError(message)
         | 
| 1498 | 
            +
             | 
| 1499 | 
            +
                def is_gradient_checkpointing(self) -> bool:
                    """Return True if any module from `self.modules()` has a truthy `gradient_checkpointing` attribute."""
                    # Modules without the attribute count as "not checkpointing".
                    return any(getattr(submodule, "gradient_checkpointing", False) for submodule in self.modules())
         | 
| 1501 | 
            +
             | 
| 1502 | 
            +
                def enable_gradient_checkpointing(self):
                    """Turn gradient checkpointing on by calling `set_gradient_checkpointing(value=True)`."""
                    enabled = True
                    self.set_gradient_checkpointing(value=enabled)
         | 
| 1504 | 
            +
             | 
| 1505 | 
            +
                def disable_gradient_checkpointing(self):
                    """Turn gradient checkpointing off by calling `set_gradient_checkpointing(value=False)`."""
                    enabled = False
                    self.set_gradient_checkpointing(value=enabled)
         | 
| 1507 | 
            +
             | 
| 1508 | 
            +
                def set_use_memory_efficient_attention(self, xformers: bool, mem_eff: bool) -> None:
                    """Forward the memory-efficient attention flags to the down blocks, the mid block and the up blocks, in that order.

                    Args:
                        xformers: handed unchanged to each block.
                        mem_eff: handed unchanged to each block.
                    """
                    for block in (*self.down_blocks, self.mid_block, *self.up_blocks):
                        block.set_use_memory_efficient_attention(xformers, mem_eff)
         | 
| 1512 | 
            +
             | 
| 1513 | 
            +
                def set_use_sdpa(self, sdpa: bool) -> None:
                    """Forward the `sdpa` flag to the down blocks, the mid block and the up blocks, in that order.

                    Args:
                        sdpa: handed unchanged to each block's `set_use_sdpa`.
                    """
                    for block in (*self.down_blocks, self.mid_block, *self.up_blocks):
                        block.set_use_sdpa(sdpa)
         | 
| 1517 | 
            +
             | 
| 1518 | 
            +
                def set_gradient_checkpointing(self, value=False):
                    """Set `gradient_checkpointing` on the down blocks, the mid block and the up blocks.

                    The previous and the new value are logged for each block before it is updated.

                    Args:
                        value: the value assigned to each block's `gradient_checkpointing` attribute.
                    """
                    for module in (*self.down_blocks, self.mid_block, *self.up_blocks):
                        logger.info(f"{module.__class__.__name__} {module.gradient_checkpointing} -> {value}")
                        module.gradient_checkpointing = value
         | 
| 1523 | 
            +
             | 
| 1524 | 
            +
                # endregion
         | 
| 1525 | 
            +
             | 
| 1526 | 
            +
                def forward(
                    self,
                    sample: torch.FloatTensor,
                    timestep: Union[torch.Tensor, float, int],
                    encoder_hidden_states: torch.Tensor,
                    class_labels: Optional[torch.Tensor] = None,
                    return_dict: bool = True,
                    down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
                    mid_block_additional_residual: Optional[torch.Tensor] = None,
                ) -> Union[Dict, Tuple]:
                    r"""
                    Args:
                        sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
                        timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
                        encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
                        class_labels: accepted but not read by this implementation.
                        return_dict (`bool`, *optional*, defaults to `True`):
                            Whether or not to return a dict instead of a plain tuple.
                        down_block_additional_residuals: optional ControlNet outputs, one per skip
                            connection (indexed in the same order as the collected down-block residuals).
                        mid_block_additional_residual: optional ControlNet output added to the mid-block result.

                    Returns:
                        `SampleOutput` or `tuple`:
                        `SampleOutput` if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
                    """
                    # By default samples have to be AT least a multiple of the overall upsampling factor.
                    # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
                    # However, the upsampling interpolation output size can be forced to fit any upsampling size
                    # on the fly if necessary.
                    # (Original author's note: other sizes probably lower the image quality, so it is
                    # better to keep the size divisible.)
                    default_overall_up_factor = 2**self.num_upsamplers

                    # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
                    forward_upsample_size = False
                    upsample_size = None

                    if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
                        # logger.info("Forward upsample size to force interpolation output size.")
                        forward_upsample_size = True

                    # 1. time
                    timesteps = timestep
                    timesteps = self.handle_unusual_timesteps(sample, timesteps)  # only does work in unusual cases

                    t_emb = self.time_proj(timesteps)

                    # timesteps does not contain any weights and will always return f32 tensors
                    # but time_embedding might actually be running in fp16. so we need to cast here.
                    # there might be better ways to encapsulate this.
                    # (Original author's note: perhaps the cast could be done inside time_proj instead.)
                    t_emb = t_emb.to(dtype=self.dtype)
                    emb = self.time_embedding(t_emb)

                    # 2. pre-process
                    sample = self.conv_in(sample)

                    # 3. down
                    down_block_res_samples = (sample,)
                    for downsample_block in self.down_blocks:
                        # The down blocks could always accept encoder_hidden_states in forward,
                        # but branching here is perhaps easier to follow.
                        if downsample_block.has_cross_attention:
                            sample, res_samples = downsample_block(
                                hidden_states=sample,
                                temb=emb,
                                encoder_hidden_states=encoder_hidden_states,
                            )
                        else:
                            sample, res_samples = downsample_block(hidden_states=sample, temb=emb)

                        down_block_res_samples += res_samples

                    # Add the ControlNet outputs to the skip connections.
                    # The sum is accumulated into a copy: an in-place `+=` on the residual itself
                    # would also modify every tensor that aliases it (possibly the running `sample`
                    # that is about to enter the mid block). Using clone() + add_() keeps the dtype
                    # and shape rules of the former in-place addition.
                    if down_block_additional_residuals is not None:
                        down_block_res_samples = tuple(
                            res_sample.clone().add_(down_block_additional_residuals[i])
                            for i, res_sample in enumerate(down_block_res_samples)
                        )

                    # 4. mid
                    sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states)

                    # Add the ControlNet output, again without mutating the mid-block result in place.
                    if mid_block_additional_residual is not None:
                        sample = sample.clone().add_(mid_block_additional_residual)

                    # 5. up
                    for i, upsample_block in enumerate(self.up_blocks):
                        is_final_block = i == len(self.up_blocks) - 1

                        res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
                        down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]  # skip connection

                        # if we have not reached the final block and need to forward the upsample size, we do it here
                        if not is_final_block and forward_upsample_size:
                            upsample_size = down_block_res_samples[-1].shape[2:]

                        if upsample_block.has_cross_attention:
                            sample = upsample_block(
                                hidden_states=sample,
                                temb=emb,
                                res_hidden_states_tuple=res_samples,
                                encoder_hidden_states=encoder_hidden_states,
                                upsample_size=upsample_size,
                            )
                        else:
                            sample = upsample_block(
                                hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
                            )

                    # 6. post-process
                    sample = self.conv_norm_out(sample)
                    sample = self.conv_act(sample)
                    sample = self.conv_out(sample)

                    if not return_dict:
                        return (sample,)

                    return SampleOutput(sample=sample)
         | 
| 1647 | 
            +
             | 
| 1648 | 
            +
                def handle_unusual_timesteps(self, sample, timesteps):
                    r"""
                    Normalize `timesteps` to a 1-D tensor with one entry per batch element of `sample`.

                    A non-tensor value is wrapped into a one-element tensor on `sample.device`, and a
                    0-dim tensor gains a leading dimension and is moved to `sample.device`. The result
                    is then broadcast to `sample.shape[0]` in a way that's compatible with ONNX/Core ML.
                    """
                    if torch.is_tensor(timesteps):
                        if timesteps.ndim == 0:
                            timesteps = timesteps[None].to(sample.device)
                    else:
                        # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
                        on_mps = sample.device.type == "mps"
                        # (is_float, on_mps) -> dtype; the narrower 32-bit types are picked on "mps"
                        dtype_table = {
                            (True, True): torch.float32,
                            (True, False): torch.float64,
                            (False, True): torch.int32,
                            (False, False): torch.int64,
                        }
                        chosen_dtype = dtype_table[(isinstance(timesteps, float), on_mps)]
                        timesteps = torch.tensor([timesteps], dtype=chosen_dtype, device=sample.device)

                    # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                    return timesteps.expand(sample.shape[0])
         | 
| 1668 | 
            +
             | 
| 1669 | 
            +
             | 
| 1670 | 
            +
            class InferUNet2DConditionModel:
                r"""
                Inference-time wrapper around a `UNet2DConditionModel` that adds Deep Shrink.

                On construction the wrapper replaces `forward` on the wrapped model, and on each of its
                up blocks whose class is named `UpBlock2D` or `CrossAttnUpBlock2D`, with the
                implementations defined on this class. Attributes that are not defined on the wrapper
                itself are looked up on the wrapped model.
                """

                def __init__(self, original_unet: UNet2DConditionModel):
                    self.delegate = original_unet

                    # override original model's forward method: because forward is not called by `__call__`
                    # overriding `__call__` is not enough, because nn.Module.forward has a special handling
                    self.delegate.forward = self.forward

                    # override original model's up blocks' forward method
                    replacements = {
                        "UpBlock2D": self.up_block_forward,
                        "CrossAttnUpBlock2D": self.cross_attn_up_block_forward,
                    }
                    for up_block in self.delegate.up_blocks:
                        replacement = replacements.get(up_block.__class__.__name__)
                        if replacement is not None:
                            up_block.forward = self._bind_block(replacement, up_block)

                    # Deep Shrink (all None means disabled; see set_deep_shrink)
                    self.ds_depth_1 = None
                    self.ds_depth_2 = None
                    self.ds_timesteps_1 = None
                    self.ds_timesteps_2 = None
                    self.ds_ratio = None

                @staticmethod
                def _bind_block(func, block):
                    """Return a callable that invokes `func` with `block` prepended to its arguments."""

                    def forward(*args, **kwargs):
                        return func(block, *args, **kwargs)

                    return forward

                # call original model's methods
                def __getattr__(self, name):
                    # __getattr__ only runs when normal lookup fails. If `delegate` itself is missing
                    # (e.g. an instance created through __new__ by copy/pickle before its state is
                    # restored), reading self.delegate below would re-enter this method without end.
                    if name == "delegate":
                        raise AttributeError(name)
                    return getattr(self.delegate, name)

                def __call__(self, *args, **kwargs):
                    return self.delegate(*args, **kwargs)

                def set_deep_shrink(self, ds_depth_1, ds_timesteps_1=650, ds_depth_2=None, ds_timesteps_2=None, ds_ratio=0.5):
                    r"""
                    Enable, reconfigure or disable Deep Shrink.

                    Args:
                        ds_depth_1: index of the down block before which the sample is resized while the
                            timestep is `>= ds_timesteps_1`. `None` disables Deep Shrink entirely.
                        ds_timesteps_1: timestep threshold for the first stage.
                        ds_depth_2: index of the down block used for the second stage; stored as -1 when
                            `None`, which never matches a block index.
                        ds_timesteps_2: lower timestep bound of the second stage; stored as 1000 when `None`.
                        ds_ratio: `scale_factor` passed to the interpolation.
                    """
                    if ds_depth_1 is None:
                        logger.info("Deep Shrink is disabled.")
                        self.ds_depth_1 = None
                        self.ds_timesteps_1 = None
                        self.ds_depth_2 = None
                        self.ds_timesteps_2 = None
                        self.ds_ratio = None
                    else:
                        logger.info(
                            f"Deep Shrink is enabled: [depth={ds_depth_1}/{ds_depth_2}, timesteps={ds_timesteps_1}/{ds_timesteps_2}, ratio={ds_ratio}]"
                        )
                        self.ds_depth_1 = ds_depth_1
                        self.ds_timesteps_1 = ds_timesteps_1
                        self.ds_depth_2 = ds_depth_2 if ds_depth_2 is not None else -1
                        self.ds_timesteps_2 = ds_timesteps_2 if ds_timesteps_2 is not None else 1000
                        self.ds_ratio = ds_ratio

                def up_block_forward(self, _self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
                    r"""
                    Replacement `forward` for an up block without attention.

                    `_self` is the up block being run. Skip connections are consumed from the end of
                    `res_hidden_states_tuple`, one per resnet.
                    """
                    for resnet in _self.resnets:
                        # pop res hidden states
                        res_hidden_states = res_hidden_states_tuple[-1]
                        res_hidden_states_tuple = res_hidden_states_tuple[:-1]

                        # Deep Shrink: match the spatial size of the skip connection before concatenating
                        if res_hidden_states.shape[-2:] != hidden_states.shape[-2:]:
                            hidden_states = resize_like(hidden_states, res_hidden_states)

                        hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
                        hidden_states = resnet(hidden_states, temb)

                    if _self.upsamplers is not None:
                        for upsampler in _self.upsamplers:
                            hidden_states = upsampler(hidden_states, upsample_size)

                    return hidden_states

                def cross_attn_up_block_forward(
                    self,
                    _self,
                    hidden_states,
                    res_hidden_states_tuple,
                    temb=None,
                    encoder_hidden_states=None,
                    upsample_size=None,
                ):
                    r"""
                    Replacement `forward` for an up block with cross attention.

                    Same as `up_block_forward`, except that each resnet is followed by its paired
                    attention module, whose `.sample` output becomes the new hidden states.
                    """
                    for resnet, attn in zip(_self.resnets, _self.attentions):
                        # pop res hidden states
                        res_hidden_states = res_hidden_states_tuple[-1]
                        res_hidden_states_tuple = res_hidden_states_tuple[:-1]

                        # Deep Shrink: match the spatial size of the skip connection before concatenating
                        if res_hidden_states.shape[-2:] != hidden_states.shape[-2:]:
                            hidden_states = resize_like(hidden_states, res_hidden_states)

                        hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
                        hidden_states = resnet(hidden_states, temb)
                        hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample

                    if _self.upsamplers is not None:
                        for upsampler in _self.upsamplers:
                            hidden_states = upsampler(hidden_states, upsample_size)

                    return hidden_states

                def forward(
                    self,
                    sample: torch.FloatTensor,
                    timestep: Union[torch.Tensor, float, int],
                    encoder_hidden_states: torch.Tensor,
                    class_labels: Optional[torch.Tensor] = None,
                    return_dict: bool = True,
                    down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
                    mid_block_additional_residual: Optional[torch.Tensor] = None,
                ) -> Union[Dict, Tuple]:
                    r"""
                    current implementation is a copy of `UNet2DConditionModel.forward()` with Deep Shrink.

                    Args:
                        sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
                        timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
                        encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
                        class_labels: accepted for signature compatibility; not read by this method.
                        return_dict (`bool`, *optional*, defaults to `True`):
                            Whether or not to return a dict instead of a plain tuple.
                        down_block_additional_residuals: optional per-skip-connection tensors (ControlNet
                            outputs) added in place to the down block residuals.
                        mid_block_additional_residual: optional tensor (ControlNet output) added in place
                            to the mid block output.

                    Returns:
                        `SampleOutput` or `tuple`:
                        `SampleOutput` if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
                    """

                    _self = self.delegate

                    # By default samples have to be AT least a multiple of the overall upsampling factor.
                    # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
                    # However, the upsampling interpolation output size can be forced to fit any upsampling size
                    # on the fly if necessary.
                    # (Original author's note: other sizes probably reduce image quality, so keeping
                    # the size divisible by 64 is preferable.)
                    default_overall_up_factor = 2**_self.num_upsamplers

                    # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
                    forward_upsample_size = any(s % default_overall_up_factor != 0 for s in sample.shape[-2:])
                    upsample_size = None

                    # 1. time
                    timesteps = timestep
                    timesteps = _self.handle_unusual_timesteps(sample, timesteps)  # only does work for unusual inputs

                    t_emb = _self.time_proj(timesteps)

                    # timesteps does not contain any weights and will always return f32 tensors
                    # but time_embedding might actually be running in fp16. so we need to cast here.
                    # there might be better ways to encapsulate this.
                    # (Original author's note: maybe the cast could be done inside time_proj instead.)
                    t_emb = t_emb.to(dtype=_self.dtype)
                    emb = _self.time_embedding(t_emb)

                    # 2. pre-process
                    sample = _self.conv_in(sample)

                    # 3. down
                    down_block_res_samples = (sample,)
                    for depth, downsample_block in enumerate(_self.down_blocks):
                        # Deep Shrink: resize the sample before the configured block. Stage 1 applies
                        # while timesteps[0] >= ds_timesteps_1, stage 2 while
                        # ds_timesteps_2 <= timesteps[0] < ds_timesteps_1.
                        if self.ds_depth_1 is not None:
                            if (depth == self.ds_depth_1 and timesteps[0] >= self.ds_timesteps_1) or (
                                self.ds_depth_2 is not None
                                and depth == self.ds_depth_2
                                and timesteps[0] < self.ds_timesteps_1
                                and timesteps[0] >= self.ds_timesteps_2
                            ):
                                org_dtype = sample.dtype
                                # bfloat16 input is interpolated in float32 and cast back afterwards
                                # (presumably because bicubic interpolation lacks bfloat16 support - TODO confirm)
                                if org_dtype == torch.bfloat16:
                                    sample = sample.to(torch.float32)
                                sample = F.interpolate(sample, scale_factor=self.ds_ratio, mode="bicubic", align_corners=False).to(org_dtype)

                        # Down blocks could be made to always accept encoder_hidden_states in forward,
                        # but branching here is perhaps easier to follow.
                        if downsample_block.has_cross_attention:
                            sample, res_samples = downsample_block(
                                hidden_states=sample,
                                temb=emb,
                                encoder_hidden_states=encoder_hidden_states,
                            )
                        else:
                            sample, res_samples = downsample_block(hidden_states=sample, temb=emb)

                        down_block_res_samples += res_samples

                    # add the ControlNet outputs to the skip connections
                    if down_block_additional_residuals is not None:
                        down_block_res_samples = list(down_block_res_samples)
                        for i in range(len(down_block_res_samples)):
                            down_block_res_samples[i] += down_block_additional_residuals[i]
                        down_block_res_samples = tuple(down_block_res_samples)

                    # 4. mid
                    sample = _self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states)

                    # add the ControlNet output
                    if mid_block_additional_residual is not None:
                        sample += mid_block_additional_residual

                    # 5. up
                    for i, upsample_block in enumerate(_self.up_blocks):
                        is_final_block = i == len(_self.up_blocks) - 1

                        res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
                        down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]  # skip connection

                        # if we have not reached the final block and need to forward the upsample size, we do it here
                        if not is_final_block and forward_upsample_size:
                            upsample_size = down_block_res_samples[-1].shape[2:]

                        if upsample_block.has_cross_attention:
                            sample = upsample_block(
                                hidden_states=sample,
                                temb=emb,
                                res_hidden_states_tuple=res_samples,
                                encoder_hidden_states=encoder_hidden_states,
                                upsample_size=upsample_size,
                            )
                        else:
                            sample = upsample_block(
                                hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
                            )

                    # 6. post-process
                    sample = _self.conv_norm_out(sample)
                    sample = _self.conv_act(sample)
                    sample = _self.conv_out(sample)

                    if not return_dict:
                        return (sample,)

                    return SampleOutput(sample=sample)
         | 
    	
        library/sai_model_spec.py
    ADDED
    
    | @@ -0,0 +1,309 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # based on https://github.com/Stability-AI/ModelSpec
         | 
| 2 | 
            +
            import datetime
         | 
| 3 | 
            +
            import hashlib
         | 
| 4 | 
            +
            from io import BytesIO
         | 
| 5 | 
            +
            import os
         | 
| 6 | 
            +
            from typing import List, Optional, Tuple, Union
         | 
| 7 | 
            +
            import safetensors
         | 
| 8 | 
            +
            from library.utils import setup_logging
         | 
| 9 | 
            +
            setup_logging()
         | 
| 10 | 
            +
            import logging
         | 
| 11 | 
            +
            logger = logging.getLogger(__name__)
         | 
| 12 | 
            +
             | 
| 13 | 
            +
            r"""
         | 
| 14 | 
            +
            # Metadata Example
         | 
| 15 | 
            +
            metadata = {
         | 
| 16 | 
            +
                # === Must ===
         | 
| 17 | 
            +
                "modelspec.sai_model_spec": "1.0.0", # Required version ID for the spec
         | 
| 18 | 
            +
                "modelspec.architecture": "stable-diffusion-xl-v1-base", # Architecture, reference the ID of the original model of the arch to match the ID
         | 
| 19 | 
            +
                "modelspec.implementation": "sgm",
         | 
| 20 | 
            +
                "modelspec.title": "Example Model Version 1.0", # Clean, human-readable title. May use your own phrasing/language/etc
         | 
| 21 | 
            +
                # === Should ===
         | 
| 22 | 
            +
                "modelspec.author": "Example Corp", # Your name or company name
         | 
| 23 | 
            +
                "modelspec.description": "This is my example model to show you how to do it!", # Describe the model in your own words/language/etc. Focus on what users need to know
         | 
| 24 | 
            +
                "modelspec.date": "2023-07-20", # ISO-8601 compliant date of when the model was created
         | 
| 25 | 
            +
                # === Can ===
         | 
| 26 | 
            +
                "modelspec.license": "ExampleLicense-1.0", # eg CreativeML Open RAIL, etc.
         | 
| 27 | 
            +
                "modelspec.usage_hint": "Use keyword 'example'" # In your own language, very short hints about how the user should use the model
         | 
| 28 | 
            +
            }
         | 
| 29 | 
            +
            """
         | 
| 30 | 
            +
             | 
| 31 | 
            +
            # Template for model-spec metadata: the spec version is fixed, every other key starts as None.
            BASE_METADATA = {
                "modelspec.sai_model_spec": "1.0.0",  # Required version ID for the spec
                **dict.fromkeys(
                    (
                        # === Must ===
                        "modelspec.architecture",
                        "modelspec.implementation",
                        "modelspec.title",
                        "modelspec.resolution",
                        # === Should ===
                        "modelspec.description",
                        "modelspec.author",
                        "modelspec.date",
                        # === Can ===
                        "modelspec.license",
                        "modelspec.tags",
                        "modelspec.merged_from",
                        "modelspec.prediction_type",
                        "modelspec.timestep_range",
                        "modelspec.encoder_layer",
                    )
                ),
            }

            # Named constants are defined only for the keys that are used separately
            MODELSPEC_TITLE = "modelspec.title"

            # architecture identifiers
            ARCH_SD_V1 = "stable-diffusion-v1"
            ARCH_SD_V2_512 = "stable-diffusion-v2-512"
            ARCH_SD_V2_768_V = "stable-diffusion-v2-768-v"
            ARCH_SD_XL_V1_BASE = "stable-diffusion-xl-v1-base"

            # adapter identifiers
            ADAPTER_LORA = "lora"
            ADAPTER_TEXTUAL_INVERSION = "textual-inversion"

            # implementation identifiers
            IMPL_STABILITY_AI = "https://github.com/Stability-AI/generative-models"
            IMPL_DIFFUSERS = "diffusers"

            # prediction type identifiers
            PRED_TYPE_EPSILON = "epsilon"
            PRED_TYPE_V = "v"
         | 
| 67 | 
            +
             | 
| 68 | 
            +
             | 
| 69 | 
            +
            def load_bytes_in_safetensors(tensors):
                """Serialize ``tensors`` with safetensors and return the bytes after the header.

                The serialized buffer is read as an 8-byte little-endian integer ``n``;
                everything located after the first ``n + 8`` bytes is returned.
                """
                stream = BytesIO(safetensors.torch.save(tensors))

                # A fresh BytesIO starts at position 0: the first 8 bytes are the length prefix.
                header_len = int.from_bytes(stream.read(8), "little")

                # Jump past the length prefix and the ``header_len`` bytes that follow it.
                stream.seek(header_len + 8)
                return stream.read()
         | 
| 81 | 
            +
             | 
| 82 | 
            +
             | 
| 83 | 
            +
            def precalculate_safetensors_hashes(state_dict):
                """Return a ``0x``-prefixed SHA-256 hex digest computed over ``state_dict``'s tensors.

                Every value of ``state_dict`` is serialized on its own (under the key
                ``"tensor"``) and fed to a single running hash, so that only one tensor
                is serialized at a time to reduce memory usage.
                """
                digest = hashlib.sha256()
                for value in state_dict.values():
                    digest.update(load_bytes_in_safetensors({"tensor": value}))

                return "0x" + digest.hexdigest()
         | 
| 92 | 
            +
             | 
| 93 | 
            +
             | 
| 94 | 
            +
            def update_hash_sha256(metadata: dict, state_dict: dict):
                """Placeholder that is not implemented; always raises ``NotImplementedError``."""
                raise NotImplementedError()
         | 
| 96 | 
            +
             | 
| 97 | 
            +
             | 
| 98 | 
            +
            def build_metadata(
                state_dict: Optional[dict],
                v2: bool,
                v_parameterization: bool,
                sdxl: bool,
                lora: bool,
                textual_inversion: bool,
                timestamp: float,
                title: Optional[str] = None,
                reso: Optional[Union[int, Tuple[int, int]]] = None,
                is_stable_diffusion_ckpt: Optional[bool] = None,
                author: Optional[str] = None,
                description: Optional[str] = None,
                license: Optional[str] = None,
                tags: Optional[str] = None,
                merged_from: Optional[str] = None,
                timesteps: Optional[Tuple[int, int]] = None,
                clip_skip: Optional[int] = None,
            ):
                """Build the ``modelspec.*`` metadata dict for a model about to be saved.

                The result starts as a copy of ``BASE_METADATA``; required entries are
                filled in and optional entries whose argument is ``None`` are removed.

                Args:
                    state_dict: Currently unused (the hash is not calculated).
                    v2: Selects a ``stable-diffusion-v2-*`` architecture when ``sdxl`` is false.
                    v_parameterization: Selects the 768-v variant for v2 models and the
                        ``"v"`` prediction type; otherwise ``"epsilon"`` is recorded.
                    sdxl: Selects the SDXL base architecture (takes precedence over ``v2``).
                    lora: Appends ``/lora`` to the architecture.
                    textual_inversion: Appends ``/textual-inversion`` to the architecture
                        (only when ``lora`` is false).
                    timestamp: Seconds since the epoch; truncated to whole seconds for
                        ``modelspec.date`` and used verbatim in the default title.
                    title: Explicit title; defaults to ``"<kind>@<timestamp>"``.
                    reso: Resolution as an int, a sequence of one or two ints, or a
                        comma separated string. ``None`` picks 1024 (SDXL), 768 (v2 with
                        v-parameterization) or 512.
                    is_stable_diffusion_ckpt: When ``None`` and the model is neither a LoRA
                        nor a textual inversion, it is treated as ``True``.
                    author, description, license, tags, merged_from: Stored as given, or
                        their key is removed when ``None``.
                    timesteps: A single value or a sequence of one or two values, stored as
                        ``"min,max"``; the key is removed when ``None``.
                    clip_skip: Stored as a string under ``modelspec.encoder_layer``; the key
                        is removed when ``None``.

                Returns:
                    The metadata dict. If any value is still ``None`` an error is logged,
                    but the dict is returned anyway.
                """
                metadata = dict(BASE_METADATA)

                # TODO: implement the hash (modelspec.hash_sha256, see
                # precalculate_safetensors_hashes) once a way to compute it correctly
                # without consuming memory is known. Until then state_dict is ignored.

                # --- architecture ---
                if sdxl:
                    arch = ARCH_SD_XL_V1_BASE
                elif not v2:
                    arch = ARCH_SD_V1
                elif v_parameterization:
                    arch = ARCH_SD_V2_768_V
                else:
                    arch = ARCH_SD_V2_512

                if lora:
                    arch = f"{arch}/{ADAPTER_LORA}"
                elif textual_inversion:
                    arch = f"{arch}/{ADAPTER_TEXTUAL_INVERSION}"
                metadata["modelspec.architecture"] = arch

                # --- implementation ---
                # default is stable diffusion ckpt if not lora and not textual_inversion
                if is_stable_diffusion_ckpt is None and not (lora or textual_inversion):
                    is_stable_diffusion_ckpt = True

                # Stable Diffusion ckpt, TI and SDXL LoRA use the Stability AI implementation;
                # v1/v2 LoRA or Diffusers models use "diffusers".
                uses_stability_impl = (lora and sdxl) or textual_inversion or is_stable_diffusion_ckpt
                metadata["modelspec.implementation"] = IMPL_STABILITY_AI if uses_stability_impl else IMPL_DIFFUSERS

                # --- title ---
                if title is None:
                    if lora:
                        kind = "LoRA"
                    elif textual_inversion:
                        kind = "TextualInversion"
                    else:
                        kind = "Checkpoint"
                    title = f"{kind}@{timestamp}"
                metadata[MODELSPEC_TITLE] = title

                # --- optional free-form fields: store the value or drop the key ---
                optional_fields = (
                    ("modelspec.author", author),
                    ("modelspec.description", description),
                    ("modelspec.merged_from", merged_from),
                    ("modelspec.license", license),
                    ("modelspec.tags", tags),
                )
                for key, value in optional_fields:
                    if value is None:
                        del metadata[key]
                    else:
                        metadata[key] = value

                # --- date: drop the sub-second part, then format as ISO-8601 ---
                metadata["modelspec.date"] = datetime.datetime.fromtimestamp(int(timestamp)).isoformat()

                # --- resolution ---
                if reso is None:
                    # resolution is defined in dataset, so use default
                    if sdxl:
                        reso = 1024
                    elif v2 and v_parameterization:
                        reso = 768
                    else:
                        reso = 512
                elif isinstance(reso, str):
                    # comma separated to tuple
                    reso = tuple(map(int, reso.split(",")))

                # Normalize ints BEFORE calling len(): an explicitly passed int (allowed by
                # the signature) has no len() and used to raise TypeError here.
                if isinstance(reso, int):
                    reso = (reso, reso)
                elif len(reso) == 1:
                    reso = (reso[0], reso[0])

                metadata["modelspec.resolution"] = f"{reso[0]}x{reso[1]}"

                # --- prediction type ---
                metadata["modelspec.prediction_type"] = PRED_TYPE_V if v_parameterization else PRED_TYPE_EPSILON

                # --- timestep range ---
                if timesteps is None:
                    del metadata["modelspec.timestep_range"]
                else:
                    if isinstance(timesteps, (str, int)):
                        timesteps = (timesteps, timesteps)
                    if len(timesteps) == 1:
                        timesteps = (timesteps[0], timesteps[0])
                    metadata["modelspec.timestep_range"] = f"{timesteps[0]},{timesteps[1]}"

                # --- encoder layer (clip skip) ---
                if clip_skip is None:
                    del metadata["modelspec.encoder_layer"]
                else:
                    metadata["modelspec.encoder_layer"] = f"{clip_skip}"

                # Every remaining value should be filled in; report (but do not fail) otherwise.
                if any(value is None for value in metadata.values()):
                    logger.error(f"Internal error: some metadata values are None: {metadata}")

                return metadata
         | 
| 241 | 
            +
             | 
| 242 | 
            +
             | 
| 243 | 
            +
            # region utils
         | 
| 244 | 
            +
             | 
| 245 | 
            +
             | 
| 246 | 
            +
            def get_title(metadata: dict) -> Optional[str]:
                """Return the value stored under ``MODELSPEC_TITLE`` in ``metadata``, or ``None`` if absent."""
                return metadata.get(MODELSPEC_TITLE)
         | 
| 248 | 
            +
             | 
| 249 | 
            +
             | 
| 250 | 
            +
            def load_metadata_from_safetensors(model: str) -> dict:
                """Read the metadata stored in a ``.safetensors`` file.

                Returns an empty dict when ``model`` does not end with ``.safetensors``
                or when the opened file reports ``None`` as its metadata.
                """
                if model.endswith(".safetensors"):
                    with safetensors.safe_open(model, framework="pt") as handle:
                        stored = handle.metadata()
                    if stored is not None:
                        return stored
                return {}
         | 
| 259 | 
            +
             | 
| 260 | 
            +
             | 
| 261 | 
            +
            def build_merged_from(models: List[str]) -> str:
                """Join the titles of ``models`` into one ``", "``-separated string.

                A model's title is the ``MODELSPEC_TITLE`` entry of its safetensors
                metadata when present, otherwise its file name without the extension.
                """

                def _title_of(path: str):
                    stored = load_metadata_from_safetensors(path).get(MODELSPEC_TITLE)
                    if stored is not None:
                        return stored
                    return os.path.splitext(os.path.basename(path))[0]  # use filename

                return ", ".join(map(_title_of, models))
         | 
| 271 | 
            +
             | 
| 272 | 
            +
             | 
| 273 | 
            +
            # endregion
         | 
| 274 | 
            +
             | 
| 275 | 
            +
             | 
| 276 | 
            +
            r"""
         | 
| 277 | 
            +
            if __name__ == "__main__":
         | 
| 278 | 
            +
                import argparse
         | 
| 279 | 
            +
                import torch
         | 
| 280 | 
            +
                from safetensors.torch import load_file
         | 
| 281 | 
            +
                from library import train_util
         | 
| 282 | 
            +
             | 
| 283 | 
            +
                parser = argparse.ArgumentParser()
         | 
| 284 | 
            +
                parser.add_argument("--ckpt", type=str, required=True)
         | 
| 285 | 
            +
                args = parser.parse_args()
         | 
| 286 | 
            +
             | 
| 287 | 
            +
                print(f"Loading {args.ckpt}")
         | 
| 288 | 
            +
                state_dict = load_file(args.ckpt)
         | 
| 289 | 
            +
             | 
| 290 | 
            +
                print(f"Calculating metadata")
         | 
| 291 | 
            +
                metadata = get(state_dict, False, False, False, False, "sgm", False, False, "title", "date", 256, 1000, 0)
         | 
| 292 | 
            +
                print(metadata)
         | 
| 293 | 
            +
                del state_dict
         | 
| 294 | 
            +
             | 
| 295 | 
            +
                # by reference implementation
         | 
| 296 | 
            +
                with open(args.ckpt, mode="rb") as file_data:
         | 
| 297 | 
            +
                    file_hash = hashlib.sha256()
         | 
| 298 | 
            +
                    head_len = struct.unpack("Q", file_data.read(8))  # int64 header length prefix
         | 
| 299 | 
            +
                    header = json.loads(file_data.read(head_len[0]))  # header itself, json string
         | 
| 300 | 
            +
                    content = (
         | 
| 301 | 
            +
                        file_data.read()
         | 
| 302 | 
            +
                    )  # All other content is tightly packed tensors. Copy to RAM for simplicity, but you can avoid this read with a more careful FS-dependent impl.
         | 
| 303 | 
            +
                    file_hash.update(content)
         | 
| 304 | 
            +
                    # ===== Update the hash for modelspec =====
         | 
| 305 | 
            +
                    by_ref = f"0x{file_hash.hexdigest()}"
         | 
| 306 | 
            +
                print(by_ref)
         | 
| 307 | 
            +
                print("is same?", by_ref == metadata["modelspec.hash_sha256"])
         | 
| 308 | 
            +
             | 
| 309 | 
            +
            """
         | 
    	
        library/sdxl_lpw_stable_diffusion.py
    ADDED
    
    | @@ -0,0 +1,1347 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # copy from https://github.com/huggingface/diffusers/blob/main/examples/community/lpw_stable_diffusion.py
         | 
| 2 | 
            +
            # and modify to support SD2.x
         | 
| 3 | 
            +
             | 
| 4 | 
            +
            import inspect
         | 
| 5 | 
            +
            import re
         | 
| 6 | 
            +
            from typing import Callable, List, Optional, Union
         | 
| 7 | 
            +
             | 
| 8 | 
            +
            import numpy as np
         | 
| 9 | 
            +
            import PIL.Image
         | 
| 10 | 
            +
            import torch
         | 
| 11 | 
            +
            from packaging import version
         | 
| 12 | 
            +
            from tqdm import tqdm
         | 
| 13 | 
            +
            from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
         | 
| 14 | 
            +
             | 
| 15 | 
            +
            from diffusers import SchedulerMixin, StableDiffusionPipeline
         | 
| 16 | 
            +
            from diffusers.models import AutoencoderKL, UNet2DConditionModel
         | 
| 17 | 
            +
            from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
         | 
| 18 | 
            +
            from diffusers.utils import logging
         | 
| 19 | 
            +
            from PIL import Image
         | 
| 20 | 
            +
             | 
| 21 | 
            +
            from library import sdxl_model_util, sdxl_train_util, train_util
         | 
| 22 | 
            +
             | 
| 23 | 
            +
             | 
| 24 | 
            +
            try:
         | 
| 25 | 
            +
                from diffusers.utils import PIL_INTERPOLATION
         | 
| 26 | 
            +
            except ImportError:
         | 
| 27 | 
            +
                if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
         | 
| 28 | 
            +
                    PIL_INTERPOLATION = {
         | 
| 29 | 
            +
                        "linear": PIL.Image.Resampling.BILINEAR,
         | 
| 30 | 
            +
                        "bilinear": PIL.Image.Resampling.BILINEAR,
         | 
| 31 | 
            +
                        "bicubic": PIL.Image.Resampling.BICUBIC,
         | 
| 32 | 
            +
                        "lanczos": PIL.Image.Resampling.LANCZOS,
         | 
| 33 | 
            +
                        "nearest": PIL.Image.Resampling.NEAREST,
         | 
| 34 | 
            +
                    }
         | 
| 35 | 
            +
                else:
         | 
| 36 | 
            +
                    PIL_INTERPOLATION = {
         | 
| 37 | 
            +
                        "linear": PIL.Image.LINEAR,
         | 
| 38 | 
            +
                        "bilinear": PIL.Image.BILINEAR,
         | 
| 39 | 
            +
                        "bicubic": PIL.Image.BICUBIC,
         | 
| 40 | 
            +
                        "lanczos": PIL.Image.LANCZOS,
         | 
| 41 | 
            +
                        "nearest": PIL.Image.NEAREST,
         | 
| 42 | 
            +
                    }
         | 
| 43 | 
            +
            # ------------------------------------------------------------------------------
         | 
| 44 | 
            +
             | 
| 45 | 
            +
            logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
         | 
| 46 | 
            +
             | 
| 47 | 
            +
            # Tokenizer for the prompt-attention mini language. Alternatives are tried in order, so the
            # backslash-escaped forms must come before the bare bracket forms. The only capture group
            # is the numeric weight of a closing ":<weight>)" token.
            re_attention = re.compile(
                r"""
            \\\(|
            \\\)|
            \\\[|
            \\]|
            \\\\|
            \\|
            \(|
            \[|
            :([+-]?[.\d]+)\)|
            \)|
            ]|
            [^\\()\[\]:]+|
            :
            """,
                re.X,
            )


            def parse_prompt_attention(text):
                r"""
                Parses a string with attention tokens and returns a list of pairs: text and its associated weight.

                Accepted tokens are:
                  (abc) - increases attention to abc by a multiplier of 1.1
                  (abc:3.12) - increases attention to abc by a multiplier of 3.12
                  [abc] - decreases attention to abc by a multiplier of 1.1 (i.e. multiplies by 1 / 1.1)
                  \( - literal character '('
                  \[ - literal character '['
                  \) - literal character ')'
                  \] - literal character ']'
                  \\ - literal character '\'
                  anything else - just text

                Brackets may be nested, in which case the multipliers compound. Brackets that are
                still open at the end of the string are closed implicitly with their default
                multiplier. A closing token with no matching opening bracket is kept as plain text.
                Adjacent pieces that end up with the same weight are merged.

                Args:
                    text: The prompt string to parse.

                Returns:
                    A list of ``[text, weight]`` lists. An input that produces no pieces yields
                    ``[["", 1.0]]``.

                Raises:
                    ValueError: If an explicit weight matches the tokenizer pattern but is not a
                        valid float (e.g. ``(abc:1.2.3)``).

                >>> parse_prompt_attention('normal text')
                [['normal text', 1.0]]
                >>> parse_prompt_attention('an (important) word')
                [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
                >>> parse_prompt_attention('(unbalanced')
                [['unbalanced', 1.1]]
                >>> parse_prompt_attention('\(literal\]')
                [['(literal]', 1.0]]
                >>> parse_prompt_attention('(unnecessary)(parens)')
                [['unnecessaryparens', 1.1]]
                >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
                [['a ', 1.0],
                 ['house', 1.5730000000000004],
                 [' ', 1.1],
                 ['on', 1.0],
                 [' a ', 1.1],
                 ['hill', 0.55],
                 [', sun, ', 1.1],
                 ['sky', 1.4641000000000006],
                 ['.', 1.1]]
                """

                res = []
                # Stacks of indices into `res` marking where each still-open bracket started.
                round_brackets = []
                square_brackets = []

                round_bracket_multiplier = 1.1
                square_bracket_multiplier = 1 / 1.1

                def multiply_range(start_position, multiplier):
                    # Scale every piece emitted since the bracket at `start_position` was opened.
                    for entry in res[start_position:]:
                        entry[1] *= multiplier

                for m in re_attention.finditer(text):
                    # Use a dedicated name instead of rebinding the `text` parameter.
                    token = m.group(0)
                    weight = m.group(1)

                    if token.startswith("\\"):
                        # Escaped character: drop the backslash and keep the rest literally.
                        res.append([token[1:], 1.0])
                    elif token == "(":
                        round_brackets.append(len(res))
                    elif token == "[":
                        square_brackets.append(len(res))
                    elif weight is not None and len(round_brackets) > 0:
                        multiply_range(round_brackets.pop(), float(weight))
                    elif token == ")" and len(round_brackets) > 0:
                        multiply_range(round_brackets.pop(), round_bracket_multiplier)
                    elif token == "]" and len(square_brackets) > 0:
                        multiply_range(square_brackets.pop(), square_bracket_multiplier)
                    else:
                        res.append([token, 1.0])

                # Implicitly close whatever is still open.
                for pos in round_brackets:
                    multiply_range(pos, round_bracket_multiplier)

                for pos in square_brackets:
                    multiply_range(pos, square_bracket_multiplier)

                if not res:
                    res = [["", 1.0]]

                # Merge runs of identical weights in a single pass (no repeated list.pop).
                merged = [res[0]]
                for piece, piece_weight in res[1:]:
                    if piece_weight == merged[-1][1]:
                        merged[-1][0] += piece
                    else:
                        merged.append([piece, piece_weight])

                return merged
         | 
| 151 | 
            +
             | 
| 152 | 
            +
             | 
| 153 | 
            +
            def get_prompts_with_weights(pipe: StableDiffusionPipeline, prompt: List[str], max_length: int):
                r"""
                Tokenize every prompt in `prompt` and pair each token id with its attention weight.

                Each prompt is split by `parse_prompt_attention` into weighted fragments; every fragment
                is tokenized with `pipe.tokenizer` and the fragment weight is repeated once per token id.
                The first and last id returned by the tokenizer are dropped, so no starting, ending or
                padding token is included in the result.

                Prompts that produce more than `max_length` token ids are cut to `max_length`, and a
                single warning is logged if that happened for any prompt.

                Returns:
                    A pair `(tokens, weights)` of lists with one inner list per prompt.
                """
                all_tokens = []
                all_weights = []
                was_truncated = False

                for entry in prompt:
                    ids = []
                    id_weights = []
                    for fragment, fragment_weight in parse_prompt_attention(entry):
                        # strip the leading and trailing special ids added by the tokenizer
                        fragment_ids = pipe.tokenizer(fragment).input_ids[1:-1]
                        ids += fragment_ids
                        # one weight per token id of this fragment
                        id_weights += [fragment_weight] * len(fragment_ids)
                        # no need to tokenize further once the limit is exceeded
                        if len(ids) > max_length:
                            break

                    if len(ids) > max_length:
                        was_truncated = True
                        ids = ids[:max_length]
                        id_weights = id_weights[:max_length]

                    all_tokens.append(ids)
                    all_weights.append(id_weights)

                if was_truncated:
                    logger.warning("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples")
                return all_tokens, all_weights
         | 
| 186 | 
            +
             | 
| 187 | 
            +
             | 
| 188 | 
            +
            def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, pad, no_boseos_middle=True, chunk_length=77):
                r"""
                Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.

                Both `tokens` and `weights` are modified in place (their rows are replaced) and the
                same two list objects are returned.

                Args:
                    tokens: List of token-id lists, one per prompt, without special tokens.
                    weights: List of weight lists parallel to `tokens`.
                    max_length: Target length of every padded token row.
                    bos: Id placed at the start of every token row.
                    eos: Id placed right after the prompt's own tokens.
                    pad: Id used to fill the rest of the token row.
                    no_boseos_middle: If True, each weight row is padded to `max_length`, with a single
                        1.0 in front. If False, each weight row is laid out per chunk with a 1.0 before
                        and after every chunk's weights, then padded with 1.0 to
                        `((max_length - 2) // (chunk_length - 2)) * chunk_length` entries.
                    chunk_length: Length of one chunk including its two special-token slots.

                Returns:
                    The tuple `(tokens, weights)`.
                """
                body_length = chunk_length - 2  # slots per chunk left after the two special tokens
                max_embeddings_multiples = (max_length - 2) // body_length
                weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length

                for i, token_row in enumerate(tokens):
                    weight_row = weights[i]
                    tokens[i] = [bos] + token_row + [eos] + [pad] * (max_length - 2 - len(token_row))

                    if no_boseos_middle:
                        weights[i] = [1.0] + weight_row + [1.0] * (max_length - 1 - len(weight_row))
                        continue

                    padded = []
                    if weight_row:
                        for j in range(max_embeddings_multiples):
                            padded.append(1.0)  # weight for starting token in this chunk
                            # slicing clamps at the end of the list, so no explicit min() is needed
                            padded += weight_row[j * body_length : (j + 1) * body_length]
                            padded.append(1.0)  # weight for ending token in this chunk
                    padded += [1.0] * (weights_length - len(padded))
                    weights[i] = padded

                return tokens, weights
         | 
| 211 | 
            +
             | 
| 212 | 
            +
             | 
| 213 | 
            +
            def get_hidden_states(text_encoder, input_ids, is_sdxl_text_encoder2: bool, eos_token_id, device):
                """
                Run `text_encoder` on `input_ids` and pick the hidden-state layer used for conditioning.

                The ids are moved to `text_encoder.device` before the call. For the first text encoder
                the hidden state at index 11 is returned and there is no pooled output. For the second
                SDXL text encoder the penultimate hidden state is returned together with a pooled
                output computed by `train_util.pool_workaround`.

                Returns:
                    `(hidden_states, pool)`, both moved to `device`; `pool` is None for the first encoder.
                """
                # the encoder call is the same for both encoders; only the selected outputs differ
                encoder_output = text_encoder(input_ids.to(text_encoder.device), output_hidden_states=True, return_dict=True)

                if is_sdxl_text_encoder2:
                    # text_encoder2: penultimate layer
                    selected = encoder_output["hidden_states"][-2]
                    # pooled output is derived via the workaround helper rather than enc_out["text_embeds"]
                    pooled = train_util.pool_workaround(
                        text_encoder, encoder_output["last_hidden_state"], input_ids, eos_token_id
                    )
                else:
                    # text_encoder1: same as SD1/2
                    selected = encoder_output["hidden_states"][11]
                    pooled = None

                selected = selected.to(device)
                if pooled is not None:
                    pooled = pooled.to(device)
                return selected, pooled
         | 
| 229 | 
            +
             | 
| 230 | 
            +
             | 
| 231 | 
            +
            def get_unweighted_text_embeddings(
                pipe: StableDiffusionPipeline,
                text_input: torch.Tensor,
                chunk_length: int,
                clip_skip: int,
                eos: int,
                pad: int,
                is_sdxl_text_encoder2: bool,
                no_boseos_middle: Optional[bool] = True,
            ):
                """
                Encode `text_input` with `pipe.text_encoder`, chunk by chunk when it is longer than
                one encoder window.

                When the token length covers more than one chunk, the input is split into windows of
                `chunk_length` ids, each window gets its first and last position rewritten to the
                starting / ending token, and the per-window hidden states are concatenated along
                dimension 1. The pooled output of the first window is the one returned.

                With `no_boseos_middle`, the embeddings of the special tokens at chunk joins are
                dropped: the first chunk loses its last position, the last chunk its first position,
                and middle chunks lose both.

                NOTE(review): `clip_skip` is accepted but not read anywhere in this function.

                Returns:
                    `(text_embeddings, text_pool)` as produced by `get_hidden_states`.
                """
                body_length = chunk_length - 2
                num_chunks = (text_input.shape[1] - 2) // body_length
                text_pool = None

                if num_chunks > 1:
                    pieces = []
                    last_index = num_chunks - 1
                    for idx in range(num_chunks):
                        # extract the idx-th window; consecutive windows overlap by two positions
                        window_start = idx * body_length
                        chunk = text_input[:, window_start : window_start + chunk_length].clone()

                        # overwrite the head and the tail with the starting and the ending tokens
                        chunk[:, 0] = text_input[0, 0]
                        if pad == eos:  # v1
                            chunk[:, -1] = text_input[0, -1]
                        else:  # v2
                            for row in range(len(chunk)):
                                # an ordinary token sits at the end of the window
                                if chunk[row, -1] != eos and chunk[row, -1] != pad:
                                    chunk[row, -1] = eos
                                # the window holds only BOS followed by PAD
                                if chunk[row, 1] == pad:
                                    chunk[row, 1] = eos

                        chunk_embedding, chunk_pool = get_hidden_states(
                            pipe.text_encoder, chunk, is_sdxl_text_encoder2, eos, pipe.device
                        )
                        if text_pool is None:
                            text_pool = chunk_pool

                        if no_boseos_middle:
                            # keep the starting token only on the first chunk and the ending token
                            # only on the last chunk
                            first_kept = None if idx == 0 else 1
                            end_kept = None if (idx != 0 and idx == last_index) else -1
                            chunk_embedding = chunk_embedding[:, first_kept:end_kept]

                        pieces.append(chunk_embedding)
                    text_embeddings = torch.concat(pieces, axis=1)
                else:
                    text_embeddings, text_pool = get_hidden_states(
                        pipe.text_encoder, text_input, is_sdxl_text_encoder2, eos, pipe.device
                    )
                return text_embeddings, text_pool
         | 
| 286 | 
            +
             | 
| 287 | 
            +
             | 
| 288 | 
            +
            def get_weighted_text_embeddings(
                pipe,  # : SdxlStableDiffusionLongPromptWeightingPipeline,
                prompt: Union[str, List[str]],
                uncond_prompt: Optional[Union[str, List[str]]] = None,
                max_embeddings_multiples: Optional[int] = 3,
                no_boseos_middle: Optional[bool] = False,
                skip_parsing: Optional[bool] = False,
                skip_weighting: Optional[bool] = False,
                clip_skip=None,
                is_sdxl_text_encoder2=False,
            ):
                r"""
                Prompts can be assigned with local weights using brackets. For example,
                prompt 'A (very beautiful) masterpiece' highlights the words 'very beautiful',
                and the embedding tokens corresponding to the words get multiplied by a constant, 1.1.

                Also, to regularize the embedding, the weighted embedding is rescaled so that its mean
                over the last two dimensions matches the mean before weighting.

                Args:
                    pipe (`StableDiffusionPipeline`):
                        Pipe to provide access to the tokenizer and the text encoder.
                    prompt (`str` or `List[str]`):
                        The prompt or prompts to guide the image generation.
                    uncond_prompt (`str` or `List[str]`):
                        The unconditional prompt or prompts to guide the image generation. When given,
                        its embeddings and pooled output are returned alongside those of `prompt`.
                    max_embeddings_multiples (`int`, *optional*, defaults to `3`):
                        The max multiple length of prompt embeddings compared to the max output length of text encoder.
                    no_boseos_middle (`bool`, *optional*, defaults to `False`):
                        If the length of text token is multiples of the capacity of text encoder, whether to drop the
                        starting and ending token of the chunks in the middle.
                    skip_parsing (`bool`, *optional*, defaults to `False`):
                        Skip the parsing of brackets.
                    skip_weighting (`bool`, *optional*, defaults to `False`):
                        Skip the weighting. When the parsing is skipped, it is forced True.
                    clip_skip:
                        Forwarded unchanged to `get_unweighted_text_embeddings`.
                    is_sdxl_text_encoder2 (`bool`, defaults to `False`):
                        Forwarded to `get_unweighted_text_embeddings` to select the second text encoder's handling.

                Returns:
                    `(text_embeddings, text_pool, uncond_embeddings, uncond_pool)`; the last two are
                    `None` when `uncond_prompt` is not given.
                """
                chunk_length = pipe.tokenizer.model_max_length
                token_budget = (chunk_length - 2) * max_embeddings_multiples + 2
                has_uncond = uncond_prompt is not None

                def tokenize(texts):
                    # Bracket-aware tokenization, or plain tokenization with all weights set to 1.0.
                    if not skip_parsing:
                        return get_prompts_with_weights(pipe, texts, token_budget - 2)
                    rows = [
                        ids[1:-1] for ids in pipe.tokenizer(texts, max_length=token_budget, truncation=True).input_ids
                    ]
                    return rows, [[1.0] * len(ids) for ids in rows]

                def pad(token_rows, weight_rows):
                    # Pad one batch to the common length and move the token ids onto the pipe's device.
                    token_rows, weight_rows = pad_tokens_and_weights(
                        token_rows,
                        weight_rows,
                        max_length,
                        bos,
                        eos,
                        pad_id,
                        no_boseos_middle=no_boseos_middle,
                        chunk_length=chunk_length,
                    )
                    return torch.tensor(token_rows, dtype=torch.long, device=pipe.device), weight_rows

                def encode(token_tensor, weight_rows):
                    # Encode one batch and turn its weights into a tensor matching the embedding dtype.
                    embeddings, pooled = get_unweighted_text_embeddings(
                        pipe,
                        token_tensor,
                        chunk_length,
                        clip_skip,
                        eos,
                        pad_id,
                        is_sdxl_text_encoder2,
                        no_boseos_middle=no_boseos_middle,
                    )
                    weight_tensor = torch.tensor(weight_rows, dtype=embeddings.dtype, device=pipe.device)
                    return embeddings, pooled, weight_tensor

                def apply_weights(embeddings, weight_tensor):
                    # In-place: scale each token embedding, then restore the pre-weighting mean.
                    mean_before = embeddings.float().mean(axis=[-2, -1]).to(embeddings.dtype)
                    embeddings *= weight_tensor.unsqueeze(-1)
                    mean_after = embeddings.float().mean(axis=[-2, -1]).to(embeddings.dtype)
                    embeddings *= (mean_before / mean_after).unsqueeze(-1).unsqueeze(-1)

                if isinstance(prompt, str):
                    prompt = [prompt]
                prompt_tokens, prompt_weights = tokenize(prompt)
                if has_uncond:
                    if isinstance(uncond_prompt, str):
                        uncond_prompt = [uncond_prompt]
                    uncond_tokens, uncond_weights = tokenize(uncond_prompt)

                # round up the longest length of tokens to a multiple of (model_max_length - 2)
                longest = max([len(ids) for ids in prompt_tokens])
                if has_uncond:
                    longest = max(longest, max([len(ids) for ids in uncond_tokens]))

                needed_multiples = (longest - 1) // (chunk_length - 2) + 1
                max_embeddings_multiples = max(1, min(max_embeddings_multiples, needed_multiples))
                max_length = (chunk_length - 2) * max_embeddings_multiples + 2

                # pad the length of tokens and weights
                bos = pipe.tokenizer.bos_token_id
                eos = pipe.tokenizer.eos_token_id
                pad_id = pipe.tokenizer.pad_token_id
                prompt_tokens, prompt_weights = pad(prompt_tokens, prompt_weights)
                if has_uncond:
                    uncond_tokens, uncond_weights = pad(uncond_tokens, uncond_weights)

                # get the embeddings
                text_embeddings, text_pool, prompt_weights = encode(prompt_tokens, prompt_weights)
                if has_uncond:
                    uncond_embeddings, uncond_pool, uncond_weights = encode(uncond_tokens, uncond_weights)

                # assign weights to the prompts and normalize in the sense of mean
                # TODO: should we normalize by chunk or in a whole (current implementation)?
                if not (skip_parsing or skip_weighting):
                    apply_weights(text_embeddings, prompt_weights)
                    if has_uncond:
                        apply_weights(uncond_embeddings, uncond_weights)

                if not has_uncond:
                    return text_embeddings, text_pool, None, None
                return text_embeddings, text_pool, uncond_embeddings, uncond_pool
         | 
| 427 | 
            +
             | 
| 428 | 
            +
             | 
| 429 | 
            +
            def preprocess_image(image):
                """Turn an image into a batched float tensor rescaled with ``2 * x - 1``.

                The image is resized with Lanczos resampling so that both sides are
                multiples of 32 (rounded down), divided by 255, given a leading batch
                axis and permuted with ``(0, 3, 1, 2)``.

                Args:
                    image: object exposing ``size`` and ``resize`` (presumably a
                        multi-channel ``PIL.Image.Image``; the four-axis transpose
                        requires the pixel array to be 3-D).

                Returns:
                    A ``torch`` tensor built from the float32 pixel array.
                """
                width, height = image.size
                # Round each side down to the nearest multiple of 32.
                width -= width % 32
                height -= height % 32
                resized = image.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
                pixels = np.array(resized).astype(np.float32) / 255.0
                # Add a batch axis in front, then move the last axis to position 1.
                batched = pixels[None].transpose(0, 3, 1, 2)
                tensor = torch.from_numpy(batched)
                return 2.0 * tensor - 1.0
         | 
| 437 | 
            +
             | 
| 438 | 
            +
             | 
| 439 | 
            +
            def preprocess_mask(mask, scale_factor=8):
                """Convert a mask image into an inverted, 4-times tiled, batched tensor.

                The mask is converted to mode ``"L"``, its sides are rounded down to
                multiples of 32 and then divided by ``scale_factor`` before a
                nearest-neighbour resize. The result is divided by 255, tiled four times
                along a new leading axis, given a batch axis and inverted (``1 - mask``).

                Args:
                    mask: object exposing ``convert``, ``size`` and ``resize``
                        (presumably a ``PIL.Image.Image``).
                    scale_factor: integer divisor applied to the rounded width and
                        height to get the resize target.

                Returns:
                    A ``torch`` tensor created from the inverted float32 array.
                """
                gray = mask.convert("L")
                width, height = gray.size
                # Round each side down to the nearest multiple of 32.
                width -= width % 32
                height -= height % 32
                target_size = (width // scale_factor, height // scale_factor)
                gray = gray.resize(target_size, resample=PIL_INTERPOLATION["nearest"])
                values = np.array(gray).astype(np.float32) / 255.0
                values = np.tile(values, (4, 1, 1))
                # Add a leading batch axis; the axis order is left unchanged.
                values = values[None]
                # Invert so that white areas are repainted and black areas are kept.
                inverted = 1 - values
                return torch.from_numpy(inverted)
         | 
| 450 | 
            +
             | 
| 451 | 
            +
             | 
| 452 | 
            +
            def prepare_controlnet_image(
         | 
| 453 | 
            +
                image: PIL.Image.Image,
         | 
| 454 | 
            +
                width: int,
         | 
| 455 | 
            +
                height: int,
         | 
| 456 | 
            +
                batch_size: int,
         | 
| 457 | 
            +
                num_images_per_prompt: int,
         | 
| 458 | 
            +
                device: torch.device,
         | 
| 459 | 
            +
                dtype: torch.dtype,
         | 
| 460 | 
            +
                do_classifier_free_guidance: bool = False,
         | 
| 461 | 
            +
                guess_mode: bool = False,
         | 
| 462 | 
            +
            ):
         | 
| 463 | 
            +
                if not isinstance(image, torch.Tensor):
         | 
| 464 | 
            +
                    if isinstance(image, PIL.Image.Image):
         | 
| 465 | 
            +
                        image = [image]
         | 
| 466 | 
            +
             | 
| 467 | 
            +
                    if isinstance(image[0], PIL.Image.Image):
         | 
| 468 | 
            +
                        images = []
         | 
| 469 | 
            +
             | 
| 470 | 
            +
                        for image_ in image:
         | 
| 471 | 
            +
                            image_ = image_.convert("RGB")
         | 
| 472 | 
            +
                            image_ = image_.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
         | 
| 473 | 
            +
                            image_ = np.array(image_)
         | 
| 474 | 
            +
                            image_ = image_[None, :]
         | 
| 475 | 
            +
                            images.append(image_)
         | 
| 476 | 
            +
             | 
| 477 | 
            +
                        image = images
         | 
| 478 | 
            +
             | 
| 479 | 
            +
                        image = np.concatenate(image, axis=0)
         | 
| 480 | 
            +
                        image = np.array(image).astype(np.float32) / 255.0
         | 
| 481 | 
            +
                        image = image.transpose(0, 3, 1, 2)
         | 
| 482 | 
            +
                        image = torch.from_numpy(image)
         | 
| 483 | 
            +
                    elif isinstance(image[0], torch.Tensor):
         | 
| 484 | 
            +
                        image = torch.cat(image, dim=0)
         | 
| 485 | 
            +
             | 
| 486 | 
            +
                image_batch_size = image.shape[0]
         | 
| 487 | 
            +
             | 
| 488 | 
            +
                if image_batch_size == 1:
         | 
| 489 | 
            +
                    repeat_by = batch_size
         | 
| 490 | 
            +
                else:
         | 
| 491 | 
            +
                    # image batch size is the same as prompt batch size
         | 
| 492 | 
            +
                    repeat_by = num_images_per_prompt
         | 
| 493 | 
            +
             | 
| 494 | 
            +
                image = image.repeat_interleave(repeat_by, dim=0)
         | 
| 495 | 
            +
             | 
| 496 | 
            +
                image = image.to(device=device, dtype=dtype)
         | 
| 497 | 
            +
             | 
| 498 | 
            +
                if do_classifier_free_guidance and not guess_mode:
         | 
| 499 | 
            +
                    image = torch.cat([image] * 2)
         | 
| 500 | 
            +
             | 
| 501 | 
            +
                return image
         | 
| 502 | 
            +
             | 
| 503 | 
            +
             | 
| 504 | 
            +
            class SdxlStableDiffusionLongPromptWeightingPipeline:
         | 
| 505 | 
            +
                r"""
         | 
| 506 | 
            +
                Pipeline for text-to-image generation using Stable Diffusion without tokens length limit, and support parsing
         | 
| 507 | 
            +
                weighting in prompt.
         | 
| 508 | 
            +
             | 
| 509 | 
            +
                This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
         | 
| 510 | 
            +
                library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
         | 
| 511 | 
            +
             | 
| 512 | 
            +
                Args:
         | 
| 513 | 
            +
                    vae ([`AutoencoderKL`]):
         | 
| 514 | 
            +
                        Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
         | 
| 515 | 
            +
                    text_encoder ([`CLIPTextModel`]):
         | 
| 516 | 
            +
                        Frozen text-encoder. Stable Diffusion uses the text portion of
         | 
| 517 | 
            +
                        [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
         | 
| 518 | 
            +
                        the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
         | 
| 519 | 
            +
                    tokenizer (`CLIPTokenizer`):
         | 
| 520 | 
            +
                        Tokenizer of class
         | 
| 521 | 
            +
                        [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
         | 
| 522 | 
            +
                    unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
         | 
| 523 | 
            +
                    scheduler ([`SchedulerMixin`]):
         | 
| 524 | 
            +
                        A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
         | 
| 525 | 
            +
                        [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
         | 
| 526 | 
            +
                    safety_checker ([`StableDiffusionSafetyChecker`]):
         | 
| 527 | 
            +
                        Classification module that estimates whether generated images could be considered offensive or harmful.
         | 
| 528 | 
            +
                        Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
         | 
| 529 | 
            +
                    feature_extractor ([`CLIPFeatureExtractor`]):
         | 
| 530 | 
            +
                        Model that extracts features from generated images to be used as inputs for the `safety_checker`.
         | 
| 531 | 
            +
                """
         | 
| 532 | 
            +
             | 
| 533 | 
            +
                # if version.parse(version.parse(diffusers.__version__).base_version) >= version.parse("0.9.0"):
         | 
| 534 | 
            +
             | 
| 535 | 
            +
                def __init__(
         | 
| 536 | 
            +
                    self,
         | 
| 537 | 
            +
                    vae: AutoencoderKL,
         | 
| 538 | 
            +
                    text_encoder: List[CLIPTextModel],
         | 
| 539 | 
            +
                    tokenizer: List[CLIPTokenizer],
         | 
| 540 | 
            +
                    unet: UNet2DConditionModel,
         | 
| 541 | 
            +
                    scheduler: SchedulerMixin,
         | 
| 542 | 
            +
                    # clip_skip: int,
         | 
| 543 | 
            +
                    safety_checker: StableDiffusionSafetyChecker,
         | 
| 544 | 
            +
                    feature_extractor: CLIPFeatureExtractor,
         | 
| 545 | 
            +
                    requires_safety_checker: bool = True,
         | 
| 546 | 
            +
                    clip_skip: int = 1,
         | 
| 547 | 
            +
                ):
         | 
| 548 | 
            +
                    # clip skip is ignored currently
         | 
| 549 | 
            +
                    self.tokenizer = tokenizer[0]
         | 
| 550 | 
            +
                    self.text_encoder = text_encoder[0]
         | 
| 551 | 
            +
                    self.unet = unet
         | 
| 552 | 
            +
                    self.scheduler = scheduler
         | 
| 553 | 
            +
                    self.safety_checker = safety_checker
         | 
| 554 | 
            +
                    self.feature_extractor = feature_extractor
         | 
| 555 | 
            +
                    self.requires_safety_checker = requires_safety_checker
         | 
| 556 | 
            +
                    self.vae = vae
         | 
| 557 | 
            +
                    self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
         | 
| 558 | 
            +
                    self.progress_bar = lambda x: tqdm(x, leave=False)
         | 
| 559 | 
            +
             | 
| 560 | 
            +
                    self.clip_skip = clip_skip
         | 
| 561 | 
            +
                    self.tokenizers = tokenizer
         | 
| 562 | 
            +
                    self.text_encoders = text_encoder
         | 
| 563 | 
            +
             | 
| 564 | 
            +
                #     self.__init__additional__()
         | 
| 565 | 
            +
             | 
| 566 | 
            +
                # def __init__additional__(self):
         | 
| 567 | 
            +
                #     if not hasattr(self, "vae_scale_factor"):
         | 
| 568 | 
            +
                #         setattr(self, "vae_scale_factor", 2 ** (len(self.vae.config.block_out_channels) - 1))
         | 
| 569 | 
            +
             | 
| 570 | 
            +
                def to(self, device=None, dtype=None):
         | 
| 571 | 
            +
                    if device is not None:
         | 
| 572 | 
            +
                        self.device = device
         | 
| 573 | 
            +
                        # self.vae.to(device=self.device)
         | 
| 574 | 
            +
                    if dtype is not None:
         | 
| 575 | 
            +
                        self.dtype = dtype
         | 
| 576 | 
            +
             | 
| 577 | 
            +
                    # do not move Text Encoders to device, because Text Encoder should be on CPU
         | 
| 578 | 
            +
             | 
| 579 | 
            +
                @property
         | 
| 580 | 
            +
                def _execution_device(self):
         | 
| 581 | 
            +
                    r"""
         | 
| 582 | 
            +
                    Returns the device on which the pipeline's models will be executed. After calling
         | 
| 583 | 
            +
                    `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
         | 
| 584 | 
            +
                    hooks.
         | 
| 585 | 
            +
                    """
         | 
| 586 | 
            +
                    if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
         | 
| 587 | 
            +
                        return self.device
         | 
| 588 | 
            +
                    for module in self.unet.modules():
         | 
| 589 | 
            +
                        if (
         | 
| 590 | 
            +
                            hasattr(module, "_hf_hook")
         | 
| 591 | 
            +
                            and hasattr(module._hf_hook, "execution_device")
         | 
| 592 | 
            +
                            and module._hf_hook.execution_device is not None
         | 
| 593 | 
            +
                        ):
         | 
| 594 | 
            +
                            return torch.device(module._hf_hook.execution_device)
         | 
| 595 | 
            +
                    return self.device
         | 
| 596 | 
            +
             | 
| 597 | 
            +
                def _encode_prompt(
                    self,
                    prompt,
                    device,
                    num_images_per_prompt,
                    do_classifier_free_guidance,
                    negative_prompt,
                    max_embeddings_multiples,
                    is_sdxl_text_encoder2,
                ):
                    r"""
                    Encode the prompt (and optionally the negative prompt) into embeddings.

                    Args:
                        prompt (`str` or `list`):
                            prompt to be encoded; a list defines the batch size, anything else counts as one.
                        device:
                            not referenced in this method.
                        num_images_per_prompt (`int`):
                            number of times each embedding is repeated.
                        do_classifier_free_guidance (`bool`):
                            when true, the negative prompt is encoded as well.
                        negative_prompt (`str`, `List[str]` or `None`):
                            `None` becomes a list of empty strings and a single string is repeated, so that
                            its length matches the prompt batch size.
                        max_embeddings_multiples (`int`):
                            forwarded unchanged to `get_weighted_text_embeddings`.
                        is_sdxl_text_encoder2:
                            forwarded unchanged to `get_weighted_text_embeddings`.

                    Returns:
                        `(text_embeddings, text_pool, uncond_embeddings, uncond_pool)`; the last two are
                        `None` when `do_classifier_free_guidance` is false.

                    Raises:
                        ValueError: if the negative prompt batch size differs from the prompt batch size.
                    """
                    if isinstance(prompt, list):
                        batch_size = len(prompt)
                    else:
                        batch_size = 1

                    if negative_prompt is None:
                        negative_prompt = [""] * batch_size
                    elif isinstance(negative_prompt, str):
                        negative_prompt = [negative_prompt] * batch_size
                    if batch_size != len(negative_prompt):
                        raise ValueError(
                            f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                            f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                            " the batch size of `prompt`."
                        )

                    def repeat_per_image(embeddings, pooled):
                        # Repeat along the sequence axis, then fold the copies into the batch axis.
                        bs_embed, seq_len, _ = embeddings.shape
                        expanded_batch = bs_embed * num_images_per_prompt
                        embeddings = embeddings.repeat(1, num_images_per_prompt, 1)
                        embeddings = embeddings.view(expanded_batch, seq_len, -1)
                        if pooled is not None:
                            pooled = pooled.repeat(1, num_images_per_prompt)
                            pooled = pooled.view(expanded_batch, -1)
                        return embeddings, pooled

                    text_embeddings, text_pool, uncond_embeddings, uncond_pool = get_weighted_text_embeddings(
                        pipe=self,
                        prompt=prompt,
                        uncond_prompt=negative_prompt if do_classifier_free_guidance else None,
                        max_embeddings_multiples=max_embeddings_multiples,
                        clip_skip=self.clip_skip,
                        is_sdxl_text_encoder2=is_sdxl_text_encoder2,
                    )
                    text_embeddings, text_pool = repeat_per_image(text_embeddings, text_pool)

                    if not do_classifier_free_guidance:
                        return text_embeddings, text_pool, None, None

                    uncond_embeddings, uncond_pool = repeat_per_image(uncond_embeddings, uncond_pool)
                    return text_embeddings, text_pool, uncond_embeddings, uncond_pool
         | 
| 664 | 
            +
             | 
| 665 | 
            +
                def check_inputs(self, prompt, height, width, strength, callback_steps):
                    """Validate call arguments, raising ``ValueError`` on the first violation.

                    Checks, in order: ``prompt`` is a ``str`` or ``list``; ``strength`` lies
                    within ``[0, 1]``; ``height`` and ``width`` are divisible by 8;
                    ``callback_steps`` is an ``int`` greater than zero.

                    Raises:
                        ValueError: describing the first argument that failed validation.
                    """
                    if not isinstance(prompt, (str, list)):
                        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

                    if not (0 <= strength <= 1):
                        raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")

                    if height % 8 != 0 or width % 8 != 0:
                        raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

                    # None is rejected explicitly before the integer and positivity checks.
                    steps_ok = callback_steps is not None and isinstance(callback_steps, int) and callback_steps > 0
                    if not steps_ok:
                        raise ValueError(
                            f"`callback_steps` has to be a positive integer but is {callback_steps} of type {type(callback_steps)}."
                        )
         | 
| 681 | 
            +
             | 
| 682 | 
            +
                def get_timesteps(self, num_inference_steps, strength, device, is_text2img):
         | 
| 683 | 
            +
                    if is_text2img:
         | 
| 684 | 
            +
                        return self.scheduler.timesteps.to(device), num_inference_steps
         | 
| 685 | 
            +
                    else:
         | 
| 686 | 
            +
                        # get the original timestep using init_timestep
         | 
| 687 | 
            +
                        offset = self.scheduler.config.get("steps_offset", 0)
         | 
| 688 | 
            +
                        init_timestep = int(num_inference_steps * strength) + offset
         | 
| 689 | 
            +
                        init_timestep = min(init_timestep, num_inference_steps)
         | 
| 690 | 
            +
             | 
| 691 | 
            +
                        t_start = max(num_inference_steps - init_timestep + offset, 0)
         | 
| 692 | 
            +
                        timesteps = self.scheduler.timesteps[t_start:].to(device)
         | 
| 693 | 
            +
                        return timesteps, num_inference_steps - t_start
         | 
| 694 | 
            +
             | 
| 695 | 
            +
                def run_safety_checker(self, image, device, dtype):
         | 
| 696 | 
            +
                    if self.safety_checker is not None:
         | 
| 697 | 
            +
                        safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
         | 
| 698 | 
            +
                        image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values.to(dtype))
         | 
| 699 | 
            +
                    else:
         | 
| 700 | 
            +
                        has_nsfw_concept = None
         | 
| 701 | 
            +
                    return image, has_nsfw_concept
         | 
| 702 | 
            +
             | 
| 703 | 
            +
                def decode_latents(self, latents):
                    """Decode latents with the VAE into a channels-last float32 numpy array.

                    The latents are multiplied by ``1 / sdxl_model_util.VAE_SCALE_FACTOR``,
                    cast to the VAE dtype and decoded; the result is mapped through
                    ``x / 2 + 0.5`` and clamped to ``[0, 1]``.

                    Args:
                        latents: latent tensor to decode.

                    Returns:
                        numpy array obtained after ``permute(0, 2, 3, 1)`` on the CPU copy.
                    """
                    with torch.no_grad():
                        unscaled = 1 / sdxl_model_util.VAE_SCALE_FACTOR * latents
                        decoded = self.vae.decode(unscaled.to(self.vae.dtype)).sample
                        decoded = (decoded / 2 + 0.5).clamp(0, 1)
                        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
                        return decoded.cpu().permute(0, 2, 3, 1).float().numpy()
         | 
| 719 | 
            +
             | 
| 720 | 
            +
                def prepare_extra_step_kwargs(self, generator, eta):
         | 
| 721 | 
            +
                    # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
         | 
| 722 | 
            +
                    # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
         | 
| 723 | 
            +
                    # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
         | 
| 724 | 
            +
                    # and should be between [0, 1]
         | 
| 725 | 
            +
             | 
| 726 | 
            +
                    accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
         | 
| 727 | 
            +
                    extra_step_kwargs = {}
         | 
| 728 | 
            +
                    if accepts_eta:
         | 
| 729 | 
            +
                        extra_step_kwargs["eta"] = eta
         | 
| 730 | 
            +
             | 
| 731 | 
            +
                    # check if the scheduler accepts generator
         | 
| 732 | 
            +
                    accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
         | 
| 733 | 
            +
                    if accepts_generator:
         | 
| 734 | 
            +
                        extra_step_kwargs["generator"] = generator
         | 
| 735 | 
            +
                    return extra_step_kwargs
         | 
| 736 | 
            +
             | 
| 737 | 
            +
                def prepare_latents(self, image, timestep, batch_size, height, width, dtype, device, generator, latents=None):
         | 
| 738 | 
            +
                    if image is None:
         | 
| 739 | 
            +
                        shape = (
         | 
| 740 | 
            +
                            batch_size,
         | 
| 741 | 
            +
                            self.unet.in_channels,
         | 
| 742 | 
            +
                            height // self.vae_scale_factor,
         | 
| 743 | 
            +
                            width // self.vae_scale_factor,
         | 
| 744 | 
            +
                        )
         | 
| 745 | 
            +
             | 
| 746 | 
            +
                        if latents is None:
         | 
| 747 | 
            +
                            if device.type == "mps":
         | 
| 748 | 
            +
                                # randn does not work reproducibly on mps
         | 
| 749 | 
            +
                                latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
         | 
| 750 | 
            +
                            else:
         | 
| 751 | 
            +
                                latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
         | 
| 752 | 
            +
                        else:
         | 
| 753 | 
            +
                            if latents.shape != shape:
         | 
| 754 | 
            +
                                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
         | 
| 755 | 
            +
                            latents = latents.to(device)
         | 
| 756 | 
            +
             | 
| 757 | 
            +
                        # scale the initial noise by the standard deviation required by the scheduler
         | 
| 758 | 
            +
                        latents = latents * self.scheduler.init_noise_sigma
         | 
| 759 | 
            +
                        return latents, None, None
         | 
| 760 | 
            +
                    else:
         | 
| 761 | 
            +
                        init_latent_dist = self.vae.encode(image).latent_dist
         | 
| 762 | 
            +
                        init_latents = init_latent_dist.sample(generator=generator)
         | 
| 763 | 
            +
                        init_latents = sdxl_model_util.VAE_SCALE_FACTOR * init_latents
         | 
| 764 | 
            +
                        init_latents = torch.cat([init_latents] * batch_size, dim=0)
         | 
| 765 | 
            +
                        init_latents_orig = init_latents
         | 
| 766 | 
            +
                        shape = init_latents.shape
         | 
| 767 | 
            +
             | 
| 768 | 
            +
                        # add noise to latents using the timesteps
         | 
| 769 | 
            +
                        if device.type == "mps":
         | 
| 770 | 
            +
                            noise = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
         | 
| 771 | 
            +
                        else:
         | 
| 772 | 
            +
                            noise = torch.randn(shape, generator=generator, device=device, dtype=dtype)
         | 
| 773 | 
            +
                        latents = self.scheduler.add_noise(init_latents, noise, timestep)
         | 
| 774 | 
            +
                        return latents, init_latents_orig, noise
         | 
| 775 | 
            +
             | 
| 776 | 
            +
                @torch.no_grad()
                def __call__(
                    self,
                    prompt: Union[str, List[str]],
                    negative_prompt: Optional[Union[str, List[str]]] = None,
                    image: Union[torch.FloatTensor, PIL.Image.Image] = None,
                    mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
                    height: int = 512,
                    width: int = 512,
                    num_inference_steps: int = 50,
                    guidance_scale: float = 7.5,
                    strength: float = 0.8,
                    num_images_per_prompt: Optional[int] = 1,
                    eta: float = 0.0,
                    generator: Optional[torch.Generator] = None,
                    latents: Optional[torch.FloatTensor] = None,
                    max_embeddings_multiples: Optional[int] = 3,
                    output_type: Optional[str] = "pil",
                    return_dict: bool = True,
                    controlnet=None,
                    controlnet_image=None,
                    callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
                    is_cancelled_callback: Optional[Callable[[], bool]] = None,
                    callback_steps: int = 1,
                ):
                    r"""
                    Function invoked when calling the pipeline for generation.

                    Runs prompt encoding with every tokenizer/text-encoder pair, prepares the latents and runs the
                    denoising loop. The result is the final latents; decoding to images is done separately (see
                    `latents_to_image`).

                    Args:
                        prompt (`str` or `List[str]`):
                            The prompt or prompts to guide the image generation.
                        negative_prompt (`str` or `List[str]`, *optional*):
                            The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
                            if `guidance_scale` is less than `1`).
                        image (`torch.FloatTensor` or `PIL.Image.Image`):
                            `Image`, or tensor representing an image batch, that will be used as the starting point for the
                            process.
                        mask_image (`torch.FloatTensor` or `PIL.Image.Image`):
                            `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
                            replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
                            PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
                            contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.
                        height (`int`, *optional*, defaults to 512):
                            The height in pixels of the generated image. A falsy value falls back to
                            `unet.config.sample_size * vae_scale_factor`.
                        width (`int`, *optional*, defaults to 512):
                            The width in pixels of the generated image. A falsy value falls back to
                            `unet.config.sample_size * vae_scale_factor`.
                        num_inference_steps (`int`, *optional*, defaults to 50):
                            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                            expense of slower inference.
                        guidance_scale (`float`, *optional*, defaults to 7.5):
                            Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                            `guidance_scale` is defined as `w` of equation 2. of [Imagen
                            Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                            1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                            usually at the expense of lower image quality.
                        strength (`float`, *optional*, defaults to 0.8):
                            Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1.
                            `image` will be used as a starting point, adding more noise to it the larger the `strength`. The
                            number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added
                            noise will be maximum and the denoising process will run for the full number of iterations specified in
                            `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
                        num_images_per_prompt (`int`, *optional*, defaults to 1):
                            The number of images to generate per prompt.
                        eta (`float`, *optional*, defaults to 0.0):
                            Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                            [`schedulers.DDIMScheduler`], will be ignored for others.
                        generator (`torch.Generator`, *optional*):
                            A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
                            deterministic.
                        latents (`torch.FloatTensor`, *optional*):
                            Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                            generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                            tensor will be generated by sampling using the supplied random `generator`.
                        max_embeddings_multiples (`int`, *optional*, defaults to `3`):
                            The max multiple length of prompt embeddings compared to the max output length of text encoder.
                        output_type (`str`, *optional*, defaults to `"pil"`):
                            Currently unused by this method (the method returns latents, not images).
                        return_dict (`bool`, *optional*, defaults to `True`):
                            Currently unused by this method (the method returns latents, not a pipeline output).
                        controlnet (`diffusers.ControlNetModel`, *optional*):
                            A controlnet model to be used for the inference. If not provided, controlnet will be disabled.
                        controlnet_image (`torch.FloatTensor` or `PIL.Image.Image`, *optional*):
                            `Image`, or tensor representing an image batch, to be used as the starting point for the controlnet
                            inference.
                        callback (`Callable`, *optional*):
                            A function that will be called every `callback_steps` steps during inference. The function will be
                            called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
                        is_cancelled_callback (`Callable`, *optional*):
                            A function that will be called every `callback_steps` steps during inference. If the function returns
                            `True`, the inference will be cancelled.
                        callback_steps (`int`, *optional*, defaults to 1):
                            The frequency at which the `callback` function will be called. If not specified, the callback will be
                            called at every step.

                    Returns:
                        `None` if cancelled by `is_cancelled_callback`, otherwise the latents produced by the last
                        denoising step.

                    Raises:
                        ValueError: if `controlnet` is given without `controlnet_image`.
                    """
                    if controlnet is not None and controlnet_image is None:
                        raise ValueError("controlnet_image must be provided if controlnet is not None.")

                    # 0. Default height and width to unet
                    height = height or self.unet.config.sample_size * self.vae_scale_factor
                    width = width or self.unet.config.sample_size * self.vae_scale_factor

                    # 1. Check inputs. Raise error if not correct
                    self.check_inputs(prompt, height, width, strength, callback_steps)

                    # 2. Define call parameters
                    batch_size = 1 if isinstance(prompt, str) else len(prompt)
                    device = self._execution_device
                    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
                    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
                    # corresponds to doing no classifier free guidance.
                    do_classifier_free_guidance = guidance_scale > 1.0

                    # 3. Encode input prompt
                    # To simplify the implementation, switch the tokenizer/text encoder and call it twice
                    text_embeddings_list = []
                    text_pool = None
                    uncond_embeddings_list = []
                    uncond_pool = None
                    for i in range(len(self.tokenizers)):
                        # the active tokenizer/text encoder are swapped on the instance before each encode call
                        self.tokenizer = self.tokenizers[i]
                        self.text_encoder = self.text_encoders[i]

                        text_embeddings, tp1, uncond_embeddings, up1 = self._encode_prompt(
                            prompt,
                            device,
                            num_images_per_prompt,
                            do_classifier_free_guidance,
                            negative_prompt,
                            max_embeddings_multiples,
                            is_sdxl_text_encoder2=i == 1,
                        )
                        text_embeddings_list.append(text_embeddings)
                        uncond_embeddings_list.append(uncond_embeddings)

                        # keep whichever pooled outputs an encoder actually returned
                        if tp1 is not None:
                            text_pool = tp1
                        if up1 is not None:
                            uncond_pool = up1

                    unet_dtype = self.unet.dtype
                    dtype = unet_dtype
                    if hasattr(dtype, "itemsize") and dtype.itemsize == 1:  # fp8
                        # run inference in float16 and restore the original dtype afterwards
                        dtype = torch.float16
                        self.unet.to(dtype)

                    # Everything below runs with the (possibly) temporarily re-cast U-Net; the `finally` clause
                    # guarantees the original dtype is restored on normal return, cancellation and errors alike.
                    try:
                        # 4. Preprocess image and mask
                        if isinstance(image, PIL.Image.Image):
                            image = preprocess_image(image)
                        if image is not None:
                            image = image.to(device=self.device, dtype=dtype)
                        if isinstance(mask_image, PIL.Image.Image):
                            mask_image = preprocess_mask(mask_image, self.vae_scale_factor)
                        if mask_image is not None:
                            mask = mask_image.to(device=self.device, dtype=dtype)
                            mask = torch.cat([mask] * batch_size * num_images_per_prompt)
                        else:
                            mask = None

                        # ControlNet is not working yet in SDXL, but keep the code here for future use
                        if controlnet_image is not None:
                            controlnet_image = prepare_controlnet_image(
                                controlnet_image, width, height, batch_size, 1, self.device, controlnet.dtype, do_classifier_free_guidance, False
                            )

                        # 5. set timesteps
                        self.scheduler.set_timesteps(num_inference_steps, device=device)
                        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device, image is None)
                        latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)

                        # 6. Prepare latent variables
                        latents, init_latents_orig, noise = self.prepare_latents(
                            image,
                            latent_timestep,
                            batch_size * num_images_per_prompt,
                            height,
                            width,
                            dtype,
                            device,
                            generator,
                            latents,
                        )

                        # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
                        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

                        # create size embs and concat embeddings for SDXL
                        orig_size = torch.tensor([height, width]).repeat(batch_size * num_images_per_prompt, 1).to(dtype)
                        crop_size = torch.zeros_like(orig_size)
                        target_size = orig_size
                        embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, device).to(dtype)

                        # make conditionings
                        if do_classifier_free_guidance:
                            text_embeddings = torch.cat(text_embeddings_list, dim=2)
                            uncond_embeddings = torch.cat(uncond_embeddings_list, dim=2)
                            text_embedding = torch.cat([uncond_embeddings, text_embeddings]).to(dtype)

                            cond_vector = torch.cat([text_pool, embs], dim=1)
                            uncond_vector = torch.cat([uncond_pool, embs], dim=1)
                            vector_embedding = torch.cat([uncond_vector, cond_vector]).to(dtype)
                        else:
                            text_embedding = torch.cat(text_embeddings_list, dim=2).to(dtype)
                            vector_embedding = torch.cat([text_pool, embs], dim=1).to(dtype)

                        # 8. Denoising loop
                        for i, t in enumerate(self.progress_bar(timesteps)):
                            # expand the latents if we are doing classifier free guidance
                            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                            # NOTE: these residuals are collected but not passed to the U-Net call below
                            # (ControlNet is not wired up for SDXL yet).
                            unet_additional_args = {}
                            if controlnet is not None:
                                down_block_res_samples, mid_block_res_sample = controlnet(
                                    latent_model_input,
                                    t,
                                    # use the combined conditioning that matches latent_model_input's batch
                                    encoder_hidden_states=text_embedding,
                                    controlnet_cond=controlnet_image,
                                    conditioning_scale=1.0,
                                    guess_mode=False,
                                    return_dict=False,
                                )
                                unet_additional_args["down_block_additional_residuals"] = down_block_res_samples
                                unet_additional_args["mid_block_additional_residual"] = mid_block_res_sample

                            # predict the noise residual
                            noise_pred = self.unet(latent_model_input, t, text_embedding, vector_embedding)
                            noise_pred = noise_pred.to(dtype)  # U-Net changes dtype in LoRA training

                            # perform guidance
                            if do_classifier_free_guidance:
                                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                            # compute the previous noisy sample x_t -> x_t-1
                            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

                            if mask is not None:
                                # masking: re-noise the original latents to timestep t and blend them with the
                                # denoised latents according to the mask
                                init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t]))
                                latents = (init_latents_proper * mask) + (latents * (1 - mask))

                            # call the callback, if provided
                            if i % callback_steps == 0:
                                if callback is not None:
                                    callback(i, t, latents)
                                if is_cancelled_callback is not None and is_cancelled_callback():
                                    return None

                        return latents
                    finally:
                        self.unet.to(unet_dtype)
         | 
| 1037 | 
            +
             | 
| 1038 | 
            +
                def latents_to_image(self, latents):
                    """
                    Post-processing step (9): turn latents into PIL images.

                    The latents are cast to the VAE's dtype, decoded with `decode_latents`, and the
                    decoded result is converted with `numpy_to_pil`.
                    """
                    vae_ready_latents = latents.to(self.vae.dtype)
                    decoded = self.decode_latents(vae_ready_latents)
                    return self.numpy_to_pil(decoded)
         | 
| 1043 | 
            +
             | 
| 1044 | 
            +
                # copy from pil_utils.py
         | 
| 1045 | 
            +
                def numpy_to_pil(self, images: np.ndarray) -> List[Image.Image]:
                    """
                    Convert a numpy image or a batch of images to a list of PIL images.

                    Args:
                        images: a single image (3 dimensions) or a batch of images (4 dimensions) with the
                            channel axis last. Values are multiplied by 255, rounded and cast to `uint8`
                            (presumably the input is expected in the range [0, 1] -- confirm with callers).

                    Returns:
                        A list with one PIL image per batch entry; a single input image yields a
                        one-element list.
                    """
                    if images.ndim == 3:
                        # promote a single image to a batch of one
                        images = images[None, ...]
                    images = (images * 255).round().astype("uint8")
                    if images.shape[-1] == 1:
                        # special case for grayscale (single channel) images
                        pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
                    else:
                        pil_images = [Image.fromarray(image) for image in images]

                    return pil_images
         | 
| 1059 | 
            +
             | 
| 1060 | 
            +
                def text2img(
         | 
| 1061 | 
            +
                    self,
         | 
| 1062 | 
            +
                    prompt: Union[str, List[str]],
         | 
| 1063 | 
            +
                    negative_prompt: Optional[Union[str, List[str]]] = None,
         | 
| 1064 | 
            +
                    height: int = 512,
         | 
| 1065 | 
            +
                    width: int = 512,
         | 
| 1066 | 
            +
                    num_inference_steps: int = 50,
         | 
| 1067 | 
            +
                    guidance_scale: float = 7.5,
         | 
| 1068 | 
            +
                    num_images_per_prompt: Optional[int] = 1,
         | 
| 1069 | 
            +
                    eta: float = 0.0,
         | 
| 1070 | 
            +
                    generator: Optional[torch.Generator] = None,
         | 
| 1071 | 
            +
                    latents: Optional[torch.FloatTensor] = None,
         | 
| 1072 | 
            +
                    max_embeddings_multiples: Optional[int] = 3,
         | 
| 1073 | 
            +
                    output_type: Optional[str] = "pil",
         | 
| 1074 | 
            +
                    return_dict: bool = True,
         | 
| 1075 | 
            +
                    callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
         | 
| 1076 | 
            +
                    is_cancelled_callback: Optional[Callable[[], bool]] = None,
         | 
| 1077 | 
            +
                    callback_steps: int = 1,
         | 
| 1078 | 
            +
                ):
         | 
| 1079 | 
            +
                    r"""
         | 
| 1080 | 
            +
                    Function for text-to-image generation.
         | 
| 1081 | 
            +
                    Args:
         | 
| 1082 | 
            +
                        prompt (`str` or `List[str]`):
         | 
| 1083 | 
            +
                            The prompt or prompts to guide the image generation.
         | 
| 1084 | 
            +
                        negative_prompt (`str` or `List[str]`, *optional*):
         | 
| 1085 | 
            +
                            The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
         | 
| 1086 | 
            +
                            if `guidance_scale` is less than `1`).
         | 
| 1087 | 
            +
                        height (`int`, *optional*, defaults to 512):
         | 
| 1088 | 
            +
                            The height in pixels of the generated image.
         | 
| 1089 | 
            +
                        width (`int`, *optional*, defaults to 512):
         | 
| 1090 | 
            +
                            The width in pixels of the generated image.
         | 
| 1091 | 
            +
                        num_inference_steps (`int`, *optional*, defaults to 50):
         | 
| 1092 | 
            +
                            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
         | 
| 1093 | 
            +
                            expense of slower inference.
         | 
| 1094 | 
            +
                        guidance_scale (`float`, *optional*, defaults to 7.5):
         | 
| 1095 | 
            +
                            Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
         | 
| 1096 | 
            +
                            `guidance_scale` is defined as `w` of equation 2. of [Imagen
         | 
| 1097 | 
            +
                            Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
         | 
| 1098 | 
            +
                            1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
         | 
| 1099 | 
            +
                            usually at the expense of lower image quality.
         | 
| 1100 | 
            +
                        num_images_per_prompt (`int`, *optional*, defaults to 1):
         | 
| 1101 | 
            +
                            The number of images to generate per prompt.
         | 
| 1102 | 
            +
                        eta (`float`, *optional*, defaults to 0.0):
         | 
| 1103 | 
            +
                            Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
         | 
| 1104 | 
            +
                            [`schedulers.DDIMScheduler`], will be ignored for others.
         | 
| 1105 | 
            +
                        generator (`torch.Generator`, *optional*):
         | 
| 1106 | 
            +
                            A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
         | 
| 1107 | 
            +
                            deterministic.
         | 
| 1108 | 
            +
                        latents (`torch.FloatTensor`, *optional*):
         | 
| 1109 | 
            +
                            Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
         | 
| 1110 | 
            +
                            generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
         | 
| 1111 | 
            +
                            tensor will be generated by sampling using the supplied random `generator`.
         | 
| 1112 | 
            +
                        max_embeddings_multiples (`int`, *optional*, defaults to `3`):
         | 
| 1113 | 
            +
                            The max multiple length of prompt embeddings compared to the max output length of text encoder.
         | 
| 1114 | 
            +
                        output_type (`str`, *optional*, defaults to `"pil"`):
         | 
| 1115 | 
            +
                            The output format of the generated image. Choose between
         | 
| 1116 | 
            +
                            [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
         | 
| 1117 | 
            +
                        return_dict (`bool`, *optional*, defaults to `True`):
         | 
| 1118 | 
            +
                            Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
         | 
| 1119 | 
            +
                            plain tuple.
         | 
| 1120 | 
            +
                        callback (`Callable`, *optional*):
         | 
| 1121 | 
            +
                            A function that will be called every `callback_steps` steps during inference. The function will be
         | 
| 1122 | 
            +
                            called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
         | 
| 1123 | 
            +
                        is_cancelled_callback (`Callable`, *optional*):
         | 
| 1124 | 
            +
                            A function that will be called every `callback_steps` steps during inference. If the function returns
         | 
| 1125 | 
            +
                            `True`, the inference will be cancelled.
         | 
| 1126 | 
            +
                        callback_steps (`int`, *optional*, defaults to 1):
         | 
| 1127 | 
            +
                            The frequency at which the `callback` function will be called. If not specified, the callback will be
         | 
| 1128 | 
            +
                            called at every step.
         | 
| 1129 | 
            +
                    Returns:
         | 
| 1130 | 
            +
                        [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
         | 
| 1131 | 
            +
                        [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
         | 
| 1132 | 
            +
                        When returning a tuple, the first element is a list with the generated images, and the second element is a
         | 
| 1133 | 
            +
                        list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
         | 
| 1134 | 
            +
                        (nsfw) content, according to the `safety_checker`.
         | 
| 1135 | 
            +
                    """
         | 
| 1136 | 
            +
                    return self.__call__(
         | 
| 1137 | 
            +
                        prompt=prompt,
         | 
| 1138 | 
            +
                        negative_prompt=negative_prompt,
         | 
| 1139 | 
            +
                        height=height,
         | 
| 1140 | 
            +
                        width=width,
         | 
| 1141 | 
            +
                        num_inference_steps=num_inference_steps,
         | 
| 1142 | 
            +
                        guidance_scale=guidance_scale,
         | 
| 1143 | 
            +
                        num_images_per_prompt=num_images_per_prompt,
         | 
| 1144 | 
            +
                        eta=eta,
         | 
| 1145 | 
            +
                        generator=generator,
         | 
| 1146 | 
            +
                        latents=latents,
         | 
| 1147 | 
            +
                        max_embeddings_multiples=max_embeddings_multiples,
         | 
| 1148 | 
            +
                        output_type=output_type,
         | 
| 1149 | 
            +
                        return_dict=return_dict,
         | 
| 1150 | 
            +
                        callback=callback,
         | 
| 1151 | 
            +
                        is_cancelled_callback=is_cancelled_callback,
         | 
| 1152 | 
            +
                        callback_steps=callback_steps,
         | 
| 1153 | 
            +
                    )
         | 
| 1154 | 
            +
             | 
| 1155 | 
            +
                def img2img(
         | 
| 1156 | 
            +
                    self,
         | 
| 1157 | 
            +
                    image: Union[torch.FloatTensor, PIL.Image.Image],
         | 
| 1158 | 
            +
                    prompt: Union[str, List[str]],
         | 
| 1159 | 
            +
                    negative_prompt: Optional[Union[str, List[str]]] = None,
         | 
| 1160 | 
            +
                    strength: float = 0.8,
         | 
| 1161 | 
            +
                    num_inference_steps: Optional[int] = 50,
         | 
| 1162 | 
            +
                    guidance_scale: Optional[float] = 7.5,
         | 
| 1163 | 
            +
                    num_images_per_prompt: Optional[int] = 1,
         | 
| 1164 | 
            +
                    eta: Optional[float] = 0.0,
         | 
| 1165 | 
            +
                    generator: Optional[torch.Generator] = None,
         | 
| 1166 | 
            +
                    max_embeddings_multiples: Optional[int] = 3,
         | 
| 1167 | 
            +
                    output_type: Optional[str] = "pil",
         | 
| 1168 | 
            +
                    return_dict: bool = True,
         | 
| 1169 | 
            +
                    callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
         | 
| 1170 | 
            +
                    is_cancelled_callback: Optional[Callable[[], bool]] = None,
         | 
| 1171 | 
            +
                    callback_steps: int = 1,
         | 
| 1172 | 
            +
                ):
         | 
| 1173 | 
            +
                    r"""
         | 
| 1174 | 
            +
                    Function for image-to-image generation.
         | 
| 1175 | 
            +
                    Args:
         | 
| 1176 | 
            +
                        image (`torch.FloatTensor` or `PIL.Image.Image`):
         | 
| 1177 | 
            +
                            `Image`, or tensor representing an image batch, that will be used as the starting point for the
         | 
| 1178 | 
            +
                            process.
         | 
| 1179 | 
            +
                        prompt (`str` or `List[str]`):
         | 
| 1180 | 
            +
                            The prompt or prompts to guide the image generation.
         | 
| 1181 | 
            +
                        negative_prompt (`str` or `List[str]`, *optional*):
         | 
| 1182 | 
            +
                            The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
         | 
| 1183 | 
            +
                            if `guidance_scale` is less than `1`).
         | 
| 1184 | 
            +
                        strength (`float`, *optional*, defaults to 0.8):
         | 
| 1185 | 
            +
                            Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1.
         | 
| 1186 | 
            +
                            `image` will be used as a starting point, adding more noise to it the larger the `strength`. The
         | 
| 1187 | 
            +
                            number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added
         | 
| 1188 | 
            +
                            noise will be maximum and the denoising process will run for the full number of iterations specified in
         | 
| 1189 | 
            +
                            `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
         | 
| 1190 | 
            +
                        num_inference_steps (`int`, *optional*, defaults to 50):
         | 
| 1191 | 
            +
                            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
         | 
| 1192 | 
            +
                            expense of slower inference. This parameter will be modulated by `strength`.
         | 
| 1193 | 
            +
                        guidance_scale (`float`, *optional*, defaults to 7.5):
         | 
| 1194 | 
            +
                            Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
         | 
| 1195 | 
            +
                            `guidance_scale` is defined as `w` of equation 2. of [Imagen
         | 
| 1196 | 
            +
                            Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
         | 
| 1197 | 
            +
                            1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
         | 
| 1198 | 
            +
                            usually at the expense of lower image quality.
         | 
| 1199 | 
            +
                        num_images_per_prompt (`int`, *optional*, defaults to 1):
         | 
| 1200 | 
            +
                            The number of images to generate per prompt.
         | 
| 1201 | 
            +
                        eta (`float`, *optional*, defaults to 0.0):
         | 
| 1202 | 
            +
                            Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
         | 
| 1203 | 
            +
                            [`schedulers.DDIMScheduler`], will be ignored for others.
         | 
| 1204 | 
            +
                        generator (`torch.Generator`, *optional*):
         | 
| 1205 | 
            +
                            A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
         | 
| 1206 | 
            +
                            deterministic.
         | 
| 1207 | 
            +
                        max_embeddings_multiples (`int`, *optional*, defaults to `3`):
         | 
| 1208 | 
            +
                            The max multiple length of prompt embeddings compared to the max output length of text encoder.
         | 
| 1209 | 
            +
                        output_type (`str`, *optional*, defaults to `"pil"`):
         | 
| 1210 | 
            +
                            The output format of the generated image. Choose between
         | 
| 1211 | 
            +
                            [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
         | 
| 1212 | 
            +
                        return_dict (`bool`, *optional*, defaults to `True`):
         | 
| 1213 | 
            +
                            Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
         | 
| 1214 | 
            +
                            plain tuple.
         | 
| 1215 | 
            +
                        callback (`Callable`, *optional*):
         | 
| 1216 | 
            +
                            A function that will be called every `callback_steps` steps during inference. The function will be
         | 
| 1217 | 
            +
                            called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
         | 
| 1218 | 
            +
                        is_cancelled_callback (`Callable`, *optional*):
         | 
| 1219 | 
            +
                            A function that will be called every `callback_steps` steps during inference. If the function returns
         | 
| 1220 | 
            +
                            `True`, the inference will be cancelled.
         | 
| 1221 | 
            +
                        callback_steps (`int`, *optional*, defaults to 1):
         | 
| 1222 | 
            +
                            The frequency at which the `callback` function will be called. If not specified, the callback will be
         | 
| 1223 | 
            +
                            called at every step.
         | 
| 1224 | 
            +
                    Returns:
         | 
| 1225 | 
            +
                        [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
         | 
| 1226 | 
            +
                        [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
         | 
| 1227 | 
            +
                        When returning a tuple, the first element is a list with the generated images, and the second element is a
         | 
| 1228 | 
            +
                        list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
         | 
| 1229 | 
            +
                        (nsfw) content, according to the `safety_checker`.
         | 
| 1230 | 
            +
                    """
         | 
| 1231 | 
            +
                    return self.__call__(
         | 
| 1232 | 
            +
                        prompt=prompt,
         | 
| 1233 | 
            +
                        negative_prompt=negative_prompt,
         | 
| 1234 | 
            +
                        image=image,
         | 
| 1235 | 
            +
                        num_inference_steps=num_inference_steps,
         | 
| 1236 | 
            +
                        guidance_scale=guidance_scale,
         | 
| 1237 | 
            +
                        strength=strength,
         | 
| 1238 | 
            +
                        num_images_per_prompt=num_images_per_prompt,
         | 
| 1239 | 
            +
                        eta=eta,
         | 
| 1240 | 
            +
                        generator=generator,
         | 
| 1241 | 
            +
                        max_embeddings_multiples=max_embeddings_multiples,
         | 
| 1242 | 
            +
                        output_type=output_type,
         | 
| 1243 | 
            +
                        return_dict=return_dict,
         | 
| 1244 | 
            +
                        callback=callback,
         | 
| 1245 | 
            +
                        is_cancelled_callback=is_cancelled_callback,
         | 
| 1246 | 
            +
                        callback_steps=callback_steps,
         | 
| 1247 | 
            +
                    )
         | 
| 1248 | 
            +
             | 
| 1249 | 
            +
                def inpaint(
         | 
| 1250 | 
            +
                    self,
         | 
| 1251 | 
            +
                    image: Union[torch.FloatTensor, PIL.Image.Image],
         | 
| 1252 | 
            +
                    mask_image: Union[torch.FloatTensor, PIL.Image.Image],
         | 
| 1253 | 
            +
                    prompt: Union[str, List[str]],
         | 
| 1254 | 
            +
                    negative_prompt: Optional[Union[str, List[str]]] = None,
         | 
| 1255 | 
            +
                    strength: float = 0.8,
         | 
| 1256 | 
            +
                    num_inference_steps: Optional[int] = 50,
         | 
| 1257 | 
            +
                    guidance_scale: Optional[float] = 7.5,
         | 
| 1258 | 
            +
                    num_images_per_prompt: Optional[int] = 1,
         | 
| 1259 | 
            +
                    eta: Optional[float] = 0.0,
         | 
| 1260 | 
            +
                    generator: Optional[torch.Generator] = None,
         | 
| 1261 | 
            +
                    max_embeddings_multiples: Optional[int] = 3,
         | 
| 1262 | 
            +
                    output_type: Optional[str] = "pil",
         | 
| 1263 | 
            +
                    return_dict: bool = True,
         | 
| 1264 | 
            +
                    callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
         | 
| 1265 | 
            +
                    is_cancelled_callback: Optional[Callable[[], bool]] = None,
         | 
| 1266 | 
            +
                    callback_steps: int = 1,
         | 
| 1267 | 
            +
                ):
         | 
| 1268 | 
            +
                    r"""
         | 
| 1269 | 
            +
                    Function for inpainting.
         | 
| 1270 | 
            +
                    Args:
         | 
| 1271 | 
            +
                        image (`torch.FloatTensor` or `PIL.Image.Image`):
         | 
| 1272 | 
            +
                            `Image`, or tensor representing an image batch, that will be used as the starting point for the
         | 
| 1273 | 
            +
                            process. This is the image whose masked region will be inpainted.
         | 
| 1274 | 
            +
                        mask_image (`torch.FloatTensor` or `PIL.Image.Image`):
         | 
| 1275 | 
            +
                            `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
         | 
| 1276 | 
            +
                            replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
         | 
| 1277 | 
            +
                            PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
         | 
| 1278 | 
            +
                            contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.
         | 
| 1279 | 
            +
                        prompt (`str` or `List[str]`):
         | 
| 1280 | 
            +
                            The prompt or prompts to guide the image generation.
         | 
| 1281 | 
            +
                        negative_prompt (`str` or `List[str]`, *optional*):
         | 
| 1282 | 
            +
                            The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
         | 
| 1283 | 
            +
                            if `guidance_scale` is less than `1`).
         | 
| 1284 | 
            +
                        strength (`float`, *optional*, defaults to 0.8):
         | 
| 1285 | 
            +
                            Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength`
         | 
| 1286 | 
            +
                            is 1, the denoising process will be run on the masked area for the full number of iterations specified
         | 
| 1287 | 
            +
                            in `num_inference_steps`. `image` will be used as a reference for the masked area, adding more
         | 
| 1288 | 
            +
                            noise to that region the larger the `strength`. If `strength` is 0, no inpainting will occur.
         | 
| 1289 | 
            +
                        num_inference_steps (`int`, *optional*, defaults to 50):
         | 
| 1290 | 
            +
                            The reference number of denoising steps. More denoising steps usually lead to a higher quality image at
         | 
| 1291 | 
            +
                            the expense of slower inference. This parameter will be modulated by `strength`, as explained above.
         | 
| 1292 | 
            +
                        guidance_scale (`float`, *optional*, defaults to 7.5):
         | 
| 1293 | 
            +
                            Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
         | 
| 1294 | 
            +
                            `guidance_scale` is defined as `w` of equation 2. of [Imagen
         | 
| 1295 | 
            +
                            Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
         | 
| 1296 | 
            +
                            1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
         | 
| 1297 | 
            +
                            usually at the expense of lower image quality.
         | 
| 1298 | 
            +
                        num_images_per_prompt (`int`, *optional*, defaults to 1):
         | 
| 1299 | 
            +
                            The number of images to generate per prompt.
         | 
| 1300 | 
            +
                        eta (`float`, *optional*, defaults to 0.0):
         | 
| 1301 | 
            +
                            Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
         | 
| 1302 | 
            +
                            [`schedulers.DDIMScheduler`], will be ignored for others.
         | 
| 1303 | 
            +
                        generator (`torch.Generator`, *optional*):
         | 
| 1304 | 
            +
                            A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
         | 
| 1305 | 
            +
                            deterministic.
         | 
| 1306 | 
            +
                        max_embeddings_multiples (`int`, *optional*, defaults to `3`):
         | 
| 1307 | 
            +
                            The max multiple length of prompt embeddings compared to the max output length of text encoder.
         | 
| 1308 | 
            +
                        output_type (`str`, *optional*, defaults to `"pil"`):
         | 
| 1309 | 
            +
                            The output format of the generated image. Choose between
         | 
| 1310 | 
            +
                            [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
         | 
| 1311 | 
            +
                        return_dict (`bool`, *optional*, defaults to `True`):
         | 
| 1312 | 
            +
                            Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
         | 
| 1313 | 
            +
                            plain tuple.
         | 
| 1314 | 
            +
                        callback (`Callable`, *optional*):
         | 
| 1315 | 
            +
                            A function that will be called every `callback_steps` steps during inference. The function will be
         | 
| 1316 | 
            +
                            called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
         | 
| 1317 | 
            +
                        is_cancelled_callback (`Callable`, *optional*):
         | 
| 1318 | 
            +
                            A function that will be called every `callback_steps` steps during inference. If the function returns
         | 
| 1319 | 
            +
                            `True`, the inference will be cancelled.
         | 
| 1320 | 
            +
                        callback_steps (`int`, *optional*, defaults to 1):
         | 
| 1321 | 
            +
                            The frequency at which the `callback` function will be called. If not specified, the callback will be
         | 
| 1322 | 
            +
                            called at every step.
         | 
| 1323 | 
            +
                    Returns:
         | 
| 1324 | 
            +
                        [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
         | 
| 1325 | 
            +
                        [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
         | 
| 1326 | 
            +
                        When returning a tuple, the first element is a list with the generated images, and the second element is a
         | 
| 1327 | 
            +
                        list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
         | 
| 1328 | 
            +
                        (nsfw) content, according to the `safety_checker`.
         | 
| 1329 | 
            +
                    """
         | 
| 1330 | 
            +
                    return self.__call__(
         | 
| 1331 | 
            +
                        prompt=prompt,
         | 
| 1332 | 
            +
                        negative_prompt=negative_prompt,
         | 
| 1333 | 
            +
                        image=image,
         | 
| 1334 | 
            +
                        mask_image=mask_image,
         | 
| 1335 | 
            +
                        num_inference_steps=num_inference_steps,
         | 
| 1336 | 
            +
                        guidance_scale=guidance_scale,
         | 
| 1337 | 
            +
                        strength=strength,
         | 
| 1338 | 
            +
                        num_images_per_prompt=num_images_per_prompt,
         | 
| 1339 | 
            +
                        eta=eta,
         | 
| 1340 | 
            +
                        generator=generator,
         | 
| 1341 | 
            +
                        max_embeddings_multiples=max_embeddings_multiples,
         | 
| 1342 | 
            +
                        output_type=output_type,
         | 
| 1343 | 
            +
                        return_dict=return_dict,
         | 
| 1344 | 
            +
                        callback=callback,
         | 
| 1345 | 
            +
                        is_cancelled_callback=is_cancelled_callback,
         | 
| 1346 | 
            +
                        callback_steps=callback_steps,
         | 
| 1347 | 
            +
                    )
         | 
    	
        library/sdxl_model_util.py
    ADDED
    
    | @@ -0,0 +1,583 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import torch
         | 
| 2 | 
            +
            import safetensors
         | 
| 3 | 
            +
            from accelerate import init_empty_weights
         | 
| 4 | 
            +
            from accelerate.utils.modeling import set_module_tensor_to_device
         | 
| 5 | 
            +
            from safetensors.torch import load_file, save_file
         | 
| 6 | 
            +
            from transformers import CLIPTextModel, CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer
         | 
| 7 | 
            +
            from typing import List
         | 
| 8 | 
            +
            from diffusers import AutoencoderKL, EulerDiscreteScheduler, UNet2DConditionModel
         | 
| 9 | 
            +
            from library import model_util
         | 
| 10 | 
            +
            from library import sdxl_original_unet
         | 
| 11 | 
            +
            from .utils import setup_logging
         | 
| 12 | 
            +
             | 
| 13 | 
            +
            setup_logging()
         | 
| 14 | 
            +
            import logging
         | 
| 15 | 
            +
             | 
| 16 | 
            +
            logger = logging.getLogger(__name__)
         | 
| 17 | 
            +
             | 
| 18 | 
            +
            # Scale factor associated with the SDXL VAE.
            # NOTE(review): presumably multiplied into latents by callers; usage is outside this view - confirm.
            VAE_SCALE_FACTOR = 0.13025
            # Identifier string for the SDXL base v1.0 model version.
            MODEL_VERSION_SDXL_BASE_V1_0 = "sdxl_base_v1-0"

            # Reference model used to load the Diffusers configuration
            DIFFUSERS_REF_MODEL_ID_SDXL = "stabilityai/stable-diffusion-xl-base-1.0"

            # U-Net settings for SDXL, keyed by Diffusers-style config field names.
            # NOTE(review): presumably passed to a Diffusers UNet2DConditionModel; the consumer is not visible here.
            DIFFUSERS_SDXL_UNET_CONFIG = {
                "act_fn": "silu",
                "addition_embed_type": "text_time",
                "addition_embed_type_num_heads": 64,
                "addition_time_embed_dim": 256,
                "attention_head_dim": [5, 10, 20],
                "block_out_channels": [320, 640, 1280],
                "center_input_sample": False,
                "class_embed_type": None,
                "class_embeddings_concat": False,
                "conv_in_kernel": 3,
                "conv_out_kernel": 3,
                "cross_attention_dim": 2048,
                "cross_attention_norm": None,
                "down_block_types": ["DownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D"],
                "downsample_padding": 1,
                "dual_cross_attention": False,
                "encoder_hid_dim": None,
                "encoder_hid_dim_type": None,
                "flip_sin_to_cos": True,
                "freq_shift": 0,
                "in_channels": 4,
                "layers_per_block": 2,
                "mid_block_only_cross_attention": None,
                "mid_block_scale_factor": 1,
                "mid_block_type": "UNetMidBlock2DCrossAttn",
                "norm_eps": 1e-05,
                "norm_num_groups": 32,
                "num_attention_heads": None,
                "num_class_embeds": None,
                "only_cross_attention": False,
                "out_channels": 4,
                "projection_class_embeddings_input_dim": 2816,
                "resnet_out_scale_factor": 1.0,
                "resnet_skip_time_act": False,
                "resnet_time_scale_shift": "default",
                "sample_size": 128,
                "time_cond_proj_dim": None,
                "time_embedding_act_fn": None,
                "time_embedding_dim": None,
                "time_embedding_type": "positional",
                "timestep_post_act": None,
                "transformer_layers_per_block": [1, 2, 10],
                "up_block_types": ["CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "UpBlock2D"],
                "upcast_attention": False,
                "use_linear_projection": True,
            }
         | 
| 71 | 
            +
             | 
| 72 | 
            +
             | 
| 73 | 
            +
            def convert_sdxl_text_encoder_2_checkpoint(checkpoint, max_length):
                """Rename SDXL text-encoder-2 weights into the ``text_model.*`` key layout.

                Basically the same conversion as the SD2 one, except that ``logit_scale``
                is returned as well so it can be written back when a checkpoint is saved.

                Args:
                    checkpoint: mapping of weight name to tensor whose keys start with
                        ``conditioner.embedders.1.model.``.
                    max_length: accepted for interface compatibility; not used here.

                Returns:
                    ``(new_sd, logit_scale)``: the renamed state dict, and the
                    ``logit_scale`` entry of ``checkpoint`` (``None`` when absent).

                Raises:
                    ValueError: if a ``resblocks`` key matches none of the known
                        sub-module patterns.
                """
                SDXL_KEY_PREFIX = "conditioner.embedders.1.model."

                def rename(name):
                    # Prefix handling shared by every key.
                    name = name.replace(SDXL_KEY_PREFIX + "transformer.", "text_model.encoder.")
                    name = name.replace(SDXL_KEY_PREFIX, "text_model.")

                    if "resblocks" in name:
                        # Transformer layer weights.
                        name = name.replace(".resblocks.", ".layers.")
                        if ".ln_" in name:
                            return name.replace(".ln_", ".layer_norm")
                        if ".mlp." in name:
                            return name.replace(".c_fc.", ".fc1.").replace(".c_proj.", ".fc2.")
                        if ".attn.out_proj" in name:
                            return name.replace(".attn.out_proj.", ".self_attn.out_proj.")
                        if ".attn.in_proj" in name:
                            # Fused q/k/v projection: special case, split in a later pass.
                            return None
                        raise ValueError(f"unexpected key in SD: {name}")

                    if ".positional_embedding" in name:
                        return name.replace(".positional_embedding", ".embeddings.position_embedding.weight")
                    if ".text_projection" in name:
                        return name.replace("text_model.text_projection", "text_projection.weight")
                    if ".logit_scale" in name:
                        # Returned separately at the end of the function.
                        return None
                    if ".token_embedding" in name:
                        return name.replace(".token_embedding.weight", ".embeddings.token_embedding.weight")
                    if ".ln_final" in name:
                        return name.replace(".ln_final", ".final_layer_norm")
                    # ckpt from comfy has this key: text_model.encoder.text_model.embeddings.position_ids
                    if ".embeddings.position_ids" in name:
                        # Dropped: position_ids is not used in newer transformers.
                        return None
                    return name

                source_names = list(checkpoint.keys())

                # First pass: every key that maps one-to-one.
                new_sd = {}
                for old_name in source_names:
                    new_name = rename(old_name)
                    if new_name is not None:
                        new_sd[new_name] = checkpoint[old_name]

                # Second pass: split each fused attention in_proj tensor into three (q, k, v).
                for old_name in source_names:
                    if ".resblocks" not in old_name or ".attn.in_proj_" not in old_name:
                        continue

                    pieces = torch.chunk(checkpoint[old_name], 3)
                    suffix = ".weight" if "weight" in old_name else ".bias"

                    stem = old_name.replace(SDXL_KEY_PREFIX + "transformer.resblocks.", "text_model.encoder.layers.")
                    stem = stem.replace("_weight", "").replace("_bias", "")
                    stem = stem.replace(".attn.in_proj", ".self_attn.")

                    for position, proj_name in enumerate(("q_proj", "k_proj", "v_proj")):
                        new_sd[stem + proj_name + suffix] = pieces[position]

                # logit_scale is not part of the Diffusers model, but it is needed again
                # when saving, so it is handed back separately.
                logit_scale = checkpoint.get(SDXL_KEY_PREFIX + "logit_scale", None)

                # temporary workaround for text_projection.weight.weight for Playground-v2
                if "text_projection.weight.weight" in new_sd:
                    logger.info("convert_sdxl_text_encoder_2_checkpoint: convert text_projection.weight.weight to text_projection.weight")
                    new_sd["text_projection.weight"] = new_sd.pop("text_projection.weight.weight")

                return new_sd, logit_scale
         | 
| 145 | 
            +
             | 
| 146 | 
            +
             | 
| 147 | 
            +
            # load state_dict without allocating new tensors
         | 
| 148 | 
            +
            def _load_state_dict_on_device(model, state_dict, device, dtype=None):
         | 
| 149 | 
            +
                # dtype will use fp32 as default
         | 
| 150 | 
            +
                missing_keys = list(model.state_dict().keys() - state_dict.keys())
         | 
| 151 | 
            +
                unexpected_keys = list(state_dict.keys() - model.state_dict().keys())
         | 
| 152 | 
            +
             | 
| 153 | 
            +
                # similar to model.load_state_dict()
         | 
| 154 | 
            +
                if not missing_keys and not unexpected_keys:
         | 
| 155 | 
            +
                    for k in list(state_dict.keys()):
         | 
| 156 | 
            +
                        set_module_tensor_to_device(model, k, device, value=state_dict.pop(k), dtype=dtype)
         | 
| 157 | 
            +
                    return "<All keys matched successfully>"
         | 
| 158 | 
            +
             | 
| 159 | 
            +
                # error_msgs
         | 
| 160 | 
            +
                error_msgs: List[str] = []
         | 
| 161 | 
            +
                if missing_keys:
         | 
| 162 | 
            +
                    error_msgs.insert(0, "Missing key(s) in state_dict: {}. ".format(", ".join('"{}"'.format(k) for k in missing_keys)))
         | 
| 163 | 
            +
                if unexpected_keys:
         | 
| 164 | 
            +
                    error_msgs.insert(0, "Unexpected key(s) in state_dict: {}. ".format(", ".join('"{}"'.format(k) for k in unexpected_keys)))
         | 
| 165 | 
            +
             | 
| 166 | 
            +
                raise RuntimeError("Error(s) in loading state_dict for {}:\n\t{}".format(model.__class__.__name__, "\n\t".join(error_msgs)))
         | 
| 167 | 
            +
             | 
| 168 | 
            +
             | 
| 169 | 
            +
            def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dtype=None, disable_mmap=False):
                """Build the SDXL U-Net, both text encoders and the VAE from a single checkpoint file.

                Args:
                    model_version: reserved for future use; not read here.
                    ckpt_path: path to a safetensors file or a ``torch.load``-able checkpoint.
                    map_location: device used when reading the file and when placing the weights.
                    dtype: passed on when loading the U-Net and VAE weights (full fp16/bf16
                        integration). The text encoders are loaded without it; the original
                        notes say they stay fp32 because they run on CPU when caching.
                    disable_mmap: for safetensors files, read the whole file into memory and
                        parse the bytes instead of using ``load_file``.

                Returns:
                    ``(text_model1, text_model2, vae, unet, logit_scale, ckpt_info)``.
                    ``ckpt_info`` is ``(epoch, global_step)`` for non-safetensors checkpoints
                    and ``None`` for safetensors files.

                Raises:
                    RuntimeError: from ``_load_state_dict_on_device`` when the checkpoint keys
                        do not match one of the models.
                """
                # Load the state dict
                if model_util.is_safetensors(ckpt_path):
                    if disable_mmap:
                        # Close the handle as soon as the bytes are read rather than leaking it.
                        with open(ckpt_path, "rb") as f:
                            state_dict = safetensors.torch.load(f.read())
                    else:
                        try:
                            state_dict = load_file(ckpt_path, device=map_location)
                        except Exception:
                            # Best-effort fallback to prevent an invalid-device error; a narrow
                            # Exception catch lets KeyboardInterrupt/SystemExit propagate.
                            state_dict = load_file(ckpt_path)
                    epoch = None
                    global_step = None
                else:
                    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
                    # code. Only load checkpoints from trusted sources (prefer safetensors).
                    checkpoint = torch.load(ckpt_path, map_location=map_location)
                    if "state_dict" in checkpoint:
                        state_dict = checkpoint["state_dict"]
                        epoch = checkpoint.get("epoch", 0)
                        global_step = checkpoint.get("global_step", 0)
                    else:
                        state_dict = checkpoint
                        epoch = 0
                        global_step = 0
                    checkpoint = None  # drop the extra reference to the loaded container

                # U-Net
                logger.info("building U-Net")
                with init_empty_weights():
                    unet = sdxl_original_unet.SdxlUNet2DConditionModel()

                logger.info("loading U-Net from checkpoint")
                unet_sd = {}
                for k in list(state_dict.keys()):
                    if k.startswith("model.diffusion_model."):
                        # pop() moves the tensor out of state_dict as it is claimed
                        unet_sd[k.replace("model.diffusion_model.", "")] = state_dict.pop(k)
                info = _load_state_dict_on_device(unet, unet_sd, device=map_location, dtype=dtype)
                logger.info(f"U-Net: {info}")

                # Text Encoders
                logger.info("building text encoders")

                # Text Encoder 1 is same to Stability AI's SDXL
                text_model1_cfg = CLIPTextConfig(
                    vocab_size=49408,
                    hidden_size=768,
                    intermediate_size=3072,
                    num_hidden_layers=12,
                    num_attention_heads=12,
                    max_position_embeddings=77,
                    hidden_act="quick_gelu",
                    layer_norm_eps=1e-05,
                    dropout=0.0,
                    attention_dropout=0.0,
                    initializer_range=0.02,
                    initializer_factor=1.0,
                    pad_token_id=1,
                    bos_token_id=0,
                    eos_token_id=2,
                    model_type="clip_text_model",
                    projection_dim=768,
                )
                with init_empty_weights():
                    text_model1 = CLIPTextModel._from_config(text_model1_cfg)

                # Text Encoder 2 is different from Stability AI's SDXL. SDXL uses open clip, but we use the model from HuggingFace.
                # Note: Tokenizer from HuggingFace is different from SDXL. We must use open clip's tokenizer.
                text_model2_cfg = CLIPTextConfig(
                    vocab_size=49408,
                    hidden_size=1280,
                    intermediate_size=5120,
                    num_hidden_layers=32,
                    num_attention_heads=20,
                    max_position_embeddings=77,
                    hidden_act="gelu",
                    layer_norm_eps=1e-05,
                    dropout=0.0,
                    attention_dropout=0.0,
                    initializer_range=0.02,
                    initializer_factor=1.0,
                    pad_token_id=1,
                    bos_token_id=0,
                    eos_token_id=2,
                    model_type="clip_text_model",
                    projection_dim=1280,
                )
                with init_empty_weights():
                    text_model2 = CLIPTextModelWithProjection(text_model2_cfg)

                logger.info("loading text encoders from checkpoint")
                te1_sd = {}
                te2_sd = {}
                for k in list(state_dict.keys()):
                    if k.startswith("conditioner.embedders.0.transformer."):
                        te1_sd[k.replace("conditioner.embedders.0.transformer.", "")] = state_dict.pop(k)
                    elif k.startswith("conditioner.embedders.1.model."):
                        # keys keep their prefix; the converter below strips it
                        te2_sd[k] = state_dict.pop(k)

                # remove position_ids for latest transformers (keeping it causes an error there)
                if "text_model.embeddings.position_ids" in te1_sd:
                    te1_sd.pop("text_model.embeddings.position_ids")

                info1 = _load_state_dict_on_device(text_model1, te1_sd, device=map_location)  # remain fp32
                logger.info(f"text encoder 1: {info1}")

                converted_sd, logit_scale = convert_sdxl_text_encoder_2_checkpoint(te2_sd, max_length=77)
                info2 = _load_state_dict_on_device(text_model2, converted_sd, device=map_location)  # remain fp32
                logger.info(f"text encoder 2: {info2}")

                # prepare vae
                logger.info("building VAE")
                vae_config = model_util.create_vae_diffusers_config()
                with init_empty_weights():
                    vae = AutoencoderKL(**vae_config)

                logger.info("loading VAE from checkpoint")
                converted_vae_checkpoint = model_util.convert_ldm_vae_checkpoint(state_dict, vae_config)
                info = _load_state_dict_on_device(vae, converted_vae_checkpoint, device=map_location, dtype=dtype)
                logger.info(f"VAE: {info}")

                # epoch is None only on the safetensors path, which carries no training progress
                ckpt_info = (epoch, global_step) if epoch is not None else None
                return text_model1, text_model2, vae, unet, logit_scale, ckpt_info
         | 
| 297 | 
            +
             | 
| 298 | 
            +
             | 
| 299 | 
            +
            def make_unet_conversion_map():
                """Return the ordered ``(sd_prefix, hf_prefix)`` pairs relating SDXL and Diffusers U-Net keys.

                The first element of each tuple is the stable-diffusion style prefix and
                the second is the HF Diffusers style prefix. Resnet entries are expanded
                into one pair per inner layer; every other entry stays at module level.
                Attention and down/upsampler pairs are emitted for every one of the three
                block indices, whether or not a given model has such a module.
                """
                layer_pairs = []

                for block in range(3):  # num_blocks is 3 in sdxl
                    # down blocks: two resnet/attention slots each
                    for slot in range(2):
                        sd_index = 3 * block + slot + 1
                        layer_pairs.append((f"input_blocks.{sd_index}.0.", f"down_blocks.{block}.resnets.{slot}."))
                        layer_pairs.append((f"input_blocks.{sd_index}.1.", f"down_blocks.{block}.attentions.{slot}."))

                    # up blocks: three resnet/attention slots each
                    for slot in range(3):
                        sd_index = 3 * block + slot
                        layer_pairs.append((f"output_blocks.{sd_index}.0.", f"up_blocks.{block}.resnets.{slot}."))
                        layer_pairs.append((f"output_blocks.{sd_index}.1.", f"up_blocks.{block}.attentions.{slot}."))

                    # downsampler / upsampler of this block (upsampler sits at sub-index 2 for sdxl)
                    layer_pairs.append((f"input_blocks.{3 * (block + 1)}.0.op.", f"down_blocks.{block}.downsamplers.0.conv."))
                    layer_pairs.append((f"output_blocks.{3 * block + 2}.2.", f"up_blocks.{block}.upsamplers.0."))

                # middle block: attention first, then the two resnets around it
                layer_pairs.append(("middle_block.1.", "mid_block.attentions.0."))
                for slot in range(2):
                    layer_pairs.append((f"middle_block.{2 * slot}.", f"mid_block.resnets.{slot}."))

                # (stable-diffusion, HF Diffusers) names of the layers inside one resnet
                resnet_parts = (
                    ("in_layers.0.", "norm1."),
                    ("in_layers.2.", "conv1."),
                    ("out_layers.0.", "norm2."),
                    ("out_layers.3.", "conv2."),
                    ("emb_layers.1.", "time_emb_proj."),
                    ("skip_connection.", "conv_shortcut."),
                )

                conversion_map = []
                for sd_prefix, hf_prefix in layer_pairs:
                    if "resnets" not in hf_prefix:
                        conversion_map.append((sd_prefix, hf_prefix))
                        continue
                    conversion_map.extend((sd_prefix + sd_part, hf_prefix + hf_part) for sd_part, hf_part in resnet_parts)

                # time embedding, then the additional (label) embedding
                conversion_map.extend((f"time_embed.{slot * 2}.", f"time_embedding.linear_{slot + 1}.") for slot in range(2))
                conversion_map.extend((f"label_emb.0.{slot * 2}.", f"add_embedding.linear_{slot + 1}.") for slot in range(2))

                # stem and output head
                conversion_map.extend(
                    [
                        ("input_blocks.0.0.", "conv_in."),
                        ("out.0.", "conv_norm_out."),
                        ("out.2.", "conv_out."),
                    ]
                )

                return conversion_map
         | 
| 381 | 
            +
             | 
| 382 | 
            +
             | 
| 383 | 
            +
            def convert_diffusers_unet_state_dict_to_sdxl(du_sd):
                """Rename the keys of a Diffusers U-Net state dict to the SDXL layout.

                The (sd_prefix, hf_prefix) pairs produced by make_unet_conversion_map()
                are inverted into an hf -> sd lookup, which convert_unet_state_dict()
                then applies to every key of ``du_sd``.
                """
                hf_to_sd = {}
                for sd_prefix, hf_prefix in make_unet_conversion_map():
                    hf_to_sd[hf_prefix] = sd_prefix
                return convert_unet_state_dict(du_sd, hf_to_sd)
         | 
| 388 | 
            +
             | 
| 389 | 
            +
             | 
| 390 | 
            +
            def convert_unet_state_dict(src_sd, conversion_map):
                """Rename state-dict keys by replacing their longest known prefix.

                Args:
                    src_sd: mapping of dotted parameter names to values.
                    conversion_map: mapping of source prefixes (each ending in ".")
                        to the prefix that should replace them.

                Returns:
                    A new dict with the same values under the converted keys.

                Raises:
                    AssertionError: if no prefix of a key is present in
                        ``conversion_map``. It is raised explicitly rather than via
                        ``assert`` so the check still runs under ``python -O``.
                """
                converted_sd = {}
                for src_key, value in src_sd.items():
                    # Scanning the whole map for every key would be slow, so look for
                    # a prefix by dropping fragments from the right one at a time.
                    fragments = src_key.split(".")[:-1]  # remove weight/bias
                    while fragments:
                        prefix = ".".join(fragments) + "."
                        if prefix in conversion_map:
                            converted_key = conversion_map[prefix] + src_key[len(prefix) :]
                            converted_sd[converted_key] = value
                            break
                        fragments.pop()
                    else:
                        # Reached only when every candidate prefix was rejected
                        # (or the key had no "." at all).
                        raise AssertionError(f"key {src_key} not found in conversion map")

                return converted_sd
         | 
| 406 | 
            +
             | 
| 407 | 
            +
             | 
| 408 | 
            +
            def convert_sdxl_unet_state_dict_to_diffusers(sd):
                """Rename the keys of an SDXL U-Net state dict to the Diffusers layout.

                make_unet_conversion_map() yields (sd_prefix, hf_prefix) pairs, so
                building a dict straight from them gives the sd -> hf lookup that
                convert_unet_state_dict() needs.
                """
                sd_to_hf = dict(make_unet_conversion_map())
                return convert_unet_state_dict(sd, sd_to_hf)
         | 
| 413 | 
            +
             | 
| 414 | 
            +
             | 
| 415 | 
            +
            def convert_text_encoder_2_state_dict_to_sdxl(checkpoint, logit_scale):
                """Convert a second-text-encoder state dict to SDXL key names.

                Keys are renamed one by one; the separate q/k/v projection entries of
                each layer are concatenated into a single ``attn.in_proj_*`` entry.
                When ``logit_scale`` is not None it is stored under "logit_scale".

                Raises:
                    ValueError: for a key containing "layers" that matches none of
                        the known sub-module patterns.
                """

                def convert_key(key):
                    # position_ids entries are dropped
                    if ".position_ids" in key:
                        return None

                    # common
                    key = key.replace("text_model.encoder.", "transformer.").replace("text_model.", "")

                    if "layers" in key:
                        # resblocks conversion
                        key = key.replace(".layers.", ".resblocks.")
                        if ".layer_norm" in key:
                            return key.replace(".layer_norm", ".ln_")
                        if ".mlp." in key:
                            return key.replace(".fc1.", ".c_fc.").replace(".fc2.", ".c_proj.")
                        if ".self_attn.out_proj" in key:
                            return key.replace(".self_attn.out_proj.", ".attn.out_proj.")
                        if ".self_attn." in key:
                            return None  # special case, handled in the second pass below
                        raise ValueError(f"unexpected key in DiffUsers model: {key}")

                    if ".position_embedding" in key:
                        return key.replace("embeddings.position_embedding.weight", "positional_embedding")
                    if ".token_embedding" in key:
                        return key.replace("embeddings.token_embedding.weight", "token_embedding.weight")
                    if "text_projection" in key:  # no dot in key
                        return key.replace("text_projection.weight", "text_projection")
                    if "final_layer_norm" in key:
                        return key.replace("final_layer_norm", "ln_final")
                    return key

                source_keys = list(checkpoint.keys())
                new_sd = {}

                # first pass: plain renames
                for name in source_keys:
                    renamed = convert_key(name)
                    if renamed is not None:
                        new_sd[renamed] = checkpoint[name]

                # second pass: attention conversion
                for name in source_keys:
                    if "layers" not in name or "q_proj" not in name:
                        continue

                    # concatenate the three projections in q, k, v order
                    qkv = [checkpoint[name.replace("q_proj", proj)] for proj in ("q_proj", "k_proj", "v_proj")]

                    merged_key = name.replace("text_model.encoder.layers.", "transformer.resblocks.")
                    merged_key = merged_key.replace(".self_attn.q_proj.", ".attn.in_proj_")
                    new_sd[merged_key] = torch.cat(qkv)

                if logit_scale is not None:
                    new_sd["logit_scale"] = logit_scale

                return new_sd
         | 
| 477 | 
            +
             | 
| 478 | 
            +
             | 
| 479 | 
            +
            def save_stable_diffusion_checkpoint(
                output_file,
                text_encoder1,
                text_encoder2,
                unet,
                epochs,
                steps,
                ckpt_info,
                vae,
                logit_scale,
                metadata,
                save_dtype=None,
            ):
                """Write U-Net, both text encoders and VAE into one checkpoint file.

                Each component's state dict is merged into a single flat dict under
                its own key prefix. When ``save_dtype`` is given, every value is
                detached, cloned, moved to CPU and cast to that dtype first.

                If ``ckpt_info`` is not None, its first two items are added to
                ``epochs`` and ``steps``. When model_util.is_safetensors(output_file)
                is true, only the flat state dict and ``metadata`` are written with
                save_file; otherwise torch.save writes a dict holding "state_dict",
                "epoch" and "global_step".

                Returns:
                    The number of keys in the merged state dict.
                """
                state_dict = {}

                def update_sd(prefix, sd):
                    for name, tensor in sd.items():
                        if save_dtype is not None:
                            tensor = tensor.detach().clone().to("cpu").to(save_dtype)
                        state_dict[prefix + name] = tensor

                # U-Net
                update_sd("model.diffusion_model.", unet.state_dict())

                # text encoders
                update_sd("conditioner.embedders.0.transformer.", text_encoder1.state_dict())
                update_sd(
                    "conditioner.embedders.1.model.",
                    convert_text_encoder_2_state_dict_to_sdxl(text_encoder2.state_dict(), logit_scale),
                )

                # VAE
                update_sd("first_stage_model.", model_util.convert_vae_state_dict(vae.state_dict()))

                key_count = len(state_dict.keys())

                # epoch and global_step are sometimes not int
                if ckpt_info is not None:
                    epochs += ckpt_info[0]
                    steps += ckpt_info[1]

                if model_util.is_safetensors(output_file):
                    save_file(state_dict, output_file, metadata)
                else:
                    torch.save({"state_dict": state_dict, "epoch": epochs, "global_step": steps}, output_file)

                return key_count
         | 
| 532 | 
            +
             | 
| 533 | 
            +
             | 
| 534 | 
            +
            def save_diffusers_checkpoint(
                output_dir, text_encoder1, text_encoder2, unet, pretrained_model_name_or_path, vae=None, use_safetensors=False, save_dtype=None
            ):
                """Save the models as a Diffusers StableDiffusionXLPipeline directory.

                The U-Net state dict is converted to Diffusers key names and loaded
                into a freshly built UNet2DConditionModel. Scheduler and both
                tokenizers (and the VAE, when ``vae`` is None) are loaded from
                ``pretrained_model_name_or_path``, which falls back to
                DIFFUSERS_REF_MODEL_ID_SDXL when None.

                Args:
                    output_dir: destination passed to ``pipeline.save_pretrained``.
                    text_encoder1, text_encoder2: text encoders placed in the pipeline as-is.
                    unet: model whose ``state_dict()`` is converted with
                        convert_sdxl_unet_state_dict_to_diffusers().
                    pretrained_model_name_or_path: source for scheduler/tokenizers/VAE, or None.
                    vae: VAE to save; loaded from the reference model when None.
                    use_safetensors: forwarded as ``safe_serialization``.
                    save_dtype: when not None, the new U-Net and then the whole
                        pipeline are cast to this dtype before saving.
                """
                from diffusers import StableDiffusionXLPipeline

                # convert U-Net
                du_unet_sd = convert_sdxl_unet_state_dict_to_diffusers(unet.state_dict())

                diffusers_unet = UNet2DConditionModel(**DIFFUSERS_SDXL_UNET_CONFIG)
                if save_dtype is not None:
                    diffusers_unet.to(save_dtype)
                diffusers_unet.load_state_dict(du_unet_sd)

                # create pipeline to save
                if pretrained_model_name_or_path is None:
                    pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_SDXL

                scheduler = EulerDiscreteScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler")
                tokenizer1 = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer")
                tokenizer2 = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer_2")
                if vae is None:
                    vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")

                # prevent local path from being saved
                def remove_name_or_path(model):
                    if hasattr(model, "config"):
                        model.config._name_or_path = None

                # same components, same order as before; the assignment used to be
                # written twice in a row, which had no additional effect
                for component in (diffusers_unet, text_encoder1, text_encoder2, scheduler, tokenizer1, tokenizer2, vae):
                    remove_name_or_path(component)

                pipeline = StableDiffusionXLPipeline(
                    unet=diffusers_unet,
                    text_encoder=text_encoder1,
                    text_encoder_2=text_encoder2,
                    vae=vae,
                    scheduler=scheduler,
                    tokenizer=tokenizer1,
                    tokenizer_2=tokenizer2,
                )
                if save_dtype is not None:
                    pipeline.to(None, save_dtype)
                pipeline.save_pretrained(output_dir, safe_serialization=use_safetensors)
         | 
    	
        library/sdxl_original_unet.py
    ADDED
    
    | @@ -0,0 +1,1286 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # U-Net for sd_xl_base, based on the Diffusers code
         | 
| 2 | 
            +
            # The state dict format is matched to SDXL
         | 
| 3 | 
            +
             | 
| 4 | 
            +
            """
         | 
| 5 | 
            +
                  target: sgm.modules.diffusionmodules.openaimodel.UNetModel
         | 
| 6 | 
            +
                  params:
         | 
| 7 | 
            +
                    adm_in_channels: 2816
         | 
| 8 | 
            +
                    num_classes: sequential
         | 
| 9 | 
            +
                    use_checkpoint: True
         | 
| 10 | 
            +
                    in_channels: 4
         | 
| 11 | 
            +
                    out_channels: 4
         | 
| 12 | 
            +
                    model_channels: 320
         | 
| 13 | 
            +
                    attention_resolutions: [4, 2]
         | 
| 14 | 
            +
                    num_res_blocks: 2
         | 
| 15 | 
            +
                    channel_mult: [1, 2, 4]
         | 
| 16 | 
            +
                    num_head_channels: 64
         | 
| 17 | 
            +
                    use_spatial_transformer: True
         | 
| 18 | 
            +
                    use_linear_in_transformer: True
         | 
| 19 | 
            +
                    transformer_depth: [1, 2, 10]  # note: the first is unused (due to attn_res starting at 2) 32, 16, 8 --> 64, 32, 16
         | 
| 20 | 
            +
                    context_dim: 2048
         | 
| 21 | 
            +
                    spatial_transformer_attn_type: softmax-xformers
         | 
| 22 | 
            +
                    legacy: False
         | 
| 23 | 
            +
            """
         | 
| 24 | 
            +
             | 
| 25 | 
            +
            import math
         | 
| 26 | 
            +
            from types import SimpleNamespace
         | 
| 27 | 
            +
            from typing import Any, Optional
         | 
| 28 | 
            +
            import torch
         | 
| 29 | 
            +
            import torch.utils.checkpoint
         | 
| 30 | 
            +
            from torch import nn
         | 
| 31 | 
            +
            from torch.nn import functional as F
         | 
| 32 | 
            +
            from einops import rearrange
         | 
| 33 | 
            +
            from .utils import setup_logging
         | 
| 34 | 
            +
             | 
| 35 | 
            +
            setup_logging()
         | 
| 36 | 
            +
            import logging
         | 
| 37 | 
            +
             | 
| 38 | 
            +
            logger = logging.getLogger(__name__)
         | 
| 39 | 
            +
             | 
| 40 | 
            +
            IN_CHANNELS: int = 4  # matches `in_channels: 4` in the config quoted in the module docstring
         | 
| 41 | 
            +
            OUT_CHANNELS: int = 4  # matches `out_channels: 4` in that config
         | 
| 42 | 
            +
            ADM_IN_CHANNELS: int = 2816  # matches `adm_in_channels: 2816` in that config
         | 
| 43 | 
            +
            CONTEXT_DIM: int = 2048  # matches `context_dim: 2048` in that config
         | 
| 44 | 
            +
            MODEL_CHANNELS: int = 320  # matches `model_channels: 320` in that config
         | 
| 45 | 
            +
            TIME_EMBED_DIM = 320 * 4  # numerically 4 * MODEL_CHANNELS
         | 
| 46 | 
            +
             | 
| 47 | 
            +
            USE_REENTRANT = True  # NOTE(review): presumably the use_reentrant flag for torch.utils.checkpoint; usage is not visible here, confirm
         | 
| 48 | 
            +
             | 
| 49 | 
            +
            # region memory efficient attention
         | 
| 50 | 
            +
             | 
| 51 | 
            +
            # CrossAttention that uses FlashAttention
         | 
| 52 | 
            +
            # based on https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/memory_efficient_attention_pytorch/flash_attention.py
         | 
| 53 | 
            +
            # LICENSE MIT https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/LICENSE
         | 
| 54 | 
            +
             | 
| 55 | 
            +
            # constants
         | 
| 56 | 
            +
             | 
| 57 | 
            +
            EPSILON = 1e-6
         | 
| 58 | 
            +
             | 
| 59 | 
            +
            # helper functions
         | 
| 60 | 
            +
             | 
| 61 | 
            +
             | 
| 62 | 
            +
            def exists(val):
                """Return True for any value other than None (falsy values included)."""
                return not (val is None)
         | 
| 64 | 
            +
             | 
| 65 | 
            +
             | 
| 66 | 
            +
            def default(val, d):
                """Return ``val`` unless it is None, in which case return the fallback ``d``."""
                if exists(val):
                    return val
                return d
         | 
| 68 | 
            +
             | 
| 69 | 
            +
             | 
| 70 | 
            +
            # flash attention forwards and backwards
         | 
| 71 | 
            +
             | 
| 72 | 
            +
            # https://arxiv.org/abs/2205.14135
         | 
| 73 | 
            +
             | 
| 74 | 
            +
             | 
| 75 | 
            +
            class FlashAttentionFunction(torch.autograd.Function):
                """Bucketed softmax attention with a hand-written backward pass.

                Queries are processed in buckets of ``q_bucket_size`` rows and keys/values in
                buckets of ``k_bucket_size`` rows, so attention scores are only ever computed
                for one (query bucket, key bucket) pair at a time. ``forward`` keeps running
                per-row maxima and sums; ``backward`` recomputes the probabilities per bucket
                pair from those saved statistics.
                """

                @staticmethod
                @torch.no_grad()
                def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size):
                    """Algorithm 2 in the paper

                    Args:
                        ctx: autograd context; receives the non-tensor arguments in ``ctx.args``
                            and the tensors needed by ``backward`` via ``save_for_backward``.
                        q, k, v: query / key / value tensors. The second-to-last axis is the
                            sequence axis that gets bucketed, the last axis is the feature axis.
                        mask: ``None`` or a 2-D ``(b, n)`` tensor; positions where it is False
                            are masked out (it is inverted with ``~`` before ``masked_fill_``).
                        causal: when truthy, an upper-triangular mask is applied per bucket pair.
                        q_bucket_size: number of query rows per bucket.
                        k_bucket_size: number of key/value rows per bucket.

                    Returns:
                        The attention output ``o``, same shape as ``q``.
                    """

                    device = q.device
                    dtype = q.dtype
                    # Large negative fill value used in place of -inf for masked scores.
                    max_neg_value = -torch.finfo(q.dtype).max
                    # How many more key positions than query positions there are (never negative).
                    qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)

                    # Output accumulator plus running softmax statistics, one value per query row.
                    o = torch.zeros_like(q)
                    all_row_sums = torch.zeros((*q.shape[:-1], 1), dtype=dtype, device=device)
                    all_row_maxes = torch.full((*q.shape[:-1], 1), max_neg_value, dtype=dtype, device=device)

                    scale = q.shape[-1] ** -0.5

                    if not exists(mask):
                        # One placeholder per query bucket so the zip below still lines up.
                        mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size)
                    else:
                        # NOTE(review): the mask is chunked with q_bucket_size and paired with query
                        # buckets, yet its last axis broadcasts against the key axis of attn_weights.
                        # Confirm the intended mask semantics before relying on a non-None mask.
                        mask = rearrange(mask, "b n -> b 1 1 n")
                        mask = mask.split(q_bucket_size, dim=-1)

                    # The chunks come from split(), and the in-place updates below are what
                    # populate o / all_row_sums / all_row_maxes.
                    row_splits = zip(
                        q.split(q_bucket_size, dim=-2),
                        o.split(q_bucket_size, dim=-2),
                        mask,
                        all_row_sums.split(q_bucket_size, dim=-2),
                        all_row_maxes.split(q_bucket_size, dim=-2),
                    )

                    for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits):
                        # First query row of this bucket, shifted by the key/query length difference.
                        q_start_index = ind * q_bucket_size - qk_len_diff

                        col_splits = zip(
                            k.split(k_bucket_size, dim=-2),
                            v.split(k_bucket_size, dim=-2),
                        )

                        for k_ind, (kc, vc) in enumerate(col_splits):
                            k_start_index = k_ind * k_bucket_size

                            # Scaled dot-product scores for this (query bucket, key bucket) pair.
                            attn_weights = torch.einsum("... i d, ... j d -> ... i j", qc, kc) * scale

                            if exists(row_mask):
                                attn_weights.masked_fill_(~row_mask, max_neg_value)

                            # Only build the causal mask when this bucket pair can contain a key
                            # index greater than a (shifted) query index.
                            if causal and q_start_index < (k_start_index + k_bucket_size - 1):
                                causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device).triu(
                                    q_start_index - k_start_index + 1
                                )
                                attn_weights.masked_fill_(causal_mask, max_neg_value)

                            # Subtract the per-row block maximum before exponentiating.
                            block_row_maxes = attn_weights.amax(dim=-1, keepdims=True)
                            attn_weights -= block_row_maxes
                            exp_weights = torch.exp(attn_weights)

                            if exists(row_mask):
                                exp_weights.masked_fill_(~row_mask, 0.0)

                            # Clamped so the sum stays positive; it feeds the divisor below.
                            block_row_sums = exp_weights.sum(dim=-1, keepdims=True).clamp(min=EPSILON)

                            new_row_maxes = torch.maximum(block_row_maxes, row_maxes)

                            exp_values = torch.einsum("... i j, ... j d -> ... i d", exp_weights, vc)

                            # Factors that re-express the old and the new block statistics
                            # relative to the updated running maximum.
                            exp_row_max_diff = torch.exp(row_maxes - new_row_maxes)
                            exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes)

                            new_row_sums = exp_row_max_diff * row_sums + exp_block_row_max_diff * block_row_sums

                            # Rescale the output accumulated so far and add this block's
                            # contribution, both normalised by the new running sum.
                            oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_((exp_block_row_max_diff / new_row_sums) * exp_values)

                            row_maxes.copy_(new_row_maxes)
                            row_sums.copy_(new_row_sums)

                    # Non-tensor state goes on ctx directly; tensors go through save_for_backward.
                    ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size)
                    ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes)

                    return o

                @staticmethod
                @torch.no_grad()
                def backward(ctx, do):
                    """Algorithm 4 in the paper

                    Args:
                        ctx: autograd context populated by ``forward``.
                        do: gradient of the loss with respect to the forward output ``o``.

                    Returns:
                        A 7-tuple matching the inputs of ``forward``: gradients for ``q``, ``k``
                        and ``v``, followed by ``None`` for ``mask``, ``causal``,
                        ``q_bucket_size`` and ``k_bucket_size``.
                    """

                    causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args
                    # l / m are the per-row sums / maxima saved by forward.
                    q, k, v, o, l, m = ctx.saved_tensors

                    device = q.device

                    max_neg_value = -torch.finfo(q.dtype).max
                    qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)

                    # Gradient accumulators, filled in place through their split chunks.
                    dq = torch.zeros_like(q)
                    dk = torch.zeros_like(k)
                    dv = torch.zeros_like(v)

                    row_splits = zip(
                        q.split(q_bucket_size, dim=-2),
                        o.split(q_bucket_size, dim=-2),
                        do.split(q_bucket_size, dim=-2),
                        mask,
                        l.split(q_bucket_size, dim=-2),
                        m.split(q_bucket_size, dim=-2),
                        dq.split(q_bucket_size, dim=-2),
                    )

                    for ind, (qc, oc, doc, row_mask, lc, mc, dqc) in enumerate(row_splits):
                        q_start_index = ind * q_bucket_size - qk_len_diff

                        col_splits = zip(
                            k.split(k_bucket_size, dim=-2),
                            v.split(k_bucket_size, dim=-2),
                            dk.split(k_bucket_size, dim=-2),
                            dv.split(k_bucket_size, dim=-2),
                        )

                        for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits):
                            k_start_index = k_ind * k_bucket_size

                            # Recompute the scores for this bucket pair instead of storing them.
                            attn_weights = torch.einsum("... i d, ... j d -> ... i j", qc, kc) * scale

                            if causal and q_start_index < (k_start_index + k_bucket_size - 1):
                                causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device).triu(
                                    q_start_index - k_start_index + 1
                                )
                                attn_weights.masked_fill_(causal_mask, max_neg_value)

                            # Rebuild the probabilities from the saved row maxima (mc) and sums (lc).
                            exp_attn_weights = torch.exp(attn_weights - mc)

                            if exists(row_mask):
                                exp_attn_weights.masked_fill_(~row_mask, 0.0)

                            p = exp_attn_weights / lc

                            dv_chunk = torch.einsum("... i j, ... i d -> ... j d", p, doc)
                            dp = torch.einsum("... i d, ... j d -> ... i j", doc, vc)

                            # Per-row sum of do * o, subtracted from dp in the softmax gradient.
                            D = (doc * oc).sum(dim=-1, keepdims=True)
                            ds = p * scale * (dp - D)

                            dq_chunk = torch.einsum("... i j, ... j d -> ... i d", ds, kc)
                            dk_chunk = torch.einsum("... i j, ... i d -> ... j d", ds, qc)

                            dqc.add_(dq_chunk)
                            dkc.add_(dk_chunk)
                            dvc.add_(dv_chunk)

                    return dq, dk, dv, None, None, None, None
         | 
| 225 | 
            +
             | 
| 226 | 
            +
             | 
| 227 | 
            +
            # endregion
         | 
| 228 | 
            +
             | 
| 229 | 
            +
             | 
| 230 | 
            +
            def get_parameter_dtype(parameter: torch.nn.Module):
                """Return the dtype of the first parameter yielded by ``parameter.parameters()``."""
                first_param = next(parameter.parameters())
                return first_param.dtype
         | 
| 232 | 
            +
             | 
| 233 | 
            +
             | 
| 234 | 
            +
            def get_parameter_device(parameter: torch.nn.Module):
                """Return the device of the first parameter yielded by ``parameter.parameters()``."""
                first_param = next(parameter.parameters())
                return first_param.device
         | 
| 236 | 
            +
             | 
| 237 | 
            +
             | 
| 238 | 
            +
            def get_timestep_embedding(
                timesteps: torch.Tensor,
                embedding_dim: int,
                downscale_freq_shift: float = 1,
                scale: float = 1,
                max_period: int = 10000,
            ):
                """
                Create sinusoidal timestep embeddings (matches the implementation in Denoising Diffusion
                Probabilistic Models).

                :param timesteps: a 1-D Tensor of N indices, one per batch element. These may be fractional.
                :param embedding_dim: the dimension of the output.
                :param downscale_freq_shift: value subtracted from ``embedding_dim // 2`` in the exponent's denominator.
                :param scale: multiplier applied to the angles before taking cos / sin.
                :param max_period: controls the minimum frequency of the embeddings.
                :return: an [N x embedding_dim] Tensor of positional embeddings, cosine half first, then sine half,
                         with one trailing zero column when ``embedding_dim`` is odd.
                """
                assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"

                half_dim = embedding_dim // 2
                positions = torch.arange(start=0, end=half_dim, dtype=torch.float32, device=timesteps.device)
                log_freqs = (-math.log(max_period) * positions) / (half_dim - downscale_freq_shift)
                freqs = torch.exp(log_freqs)

                # outer product of timesteps and frequencies, then the optional scaling
                angles = scale * (timesteps[:, None].float() * freqs[None, :])

                # cosine first, then sine: flipped from the Diffusers original because flip_sin_to_cos is always True here
                emb = torch.cat([torch.cos(angles), torch.sin(angles)], dim=-1)

                # an odd embedding_dim leaves one column short; fill it with zeros
                if embedding_dim % 2 == 1:
                    emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
                return emb
         | 
| 272 | 
            +
             | 
| 273 | 
            +
             | 
| 274 | 
            +
            # Deep Shrink: this function is intentionally duplicated here rather than shared, to minimize dependencies.
         | 
| 275 | 
            +
            def resize_like(x, target, mode="bicubic", align_corners=False):
                """Interpolate ``x`` so its last two dimensions match those of ``target``.

                bfloat16 input is converted to float32 for the operation and converted back
                afterwards. When the last two dimensions already match, no interpolation is done.
                ``align_corners`` is only forwarded for modes other than ``"nearest"``.
                """
                org_dtype = x.dtype
                was_bf16 = org_dtype == torch.bfloat16
                if was_bf16:
                    x = x.to(torch.float32)

                target_hw = target.shape[-2:]
                if x.shape[-2:] != target_hw:
                    extra_kwargs = {} if mode == "nearest" else {"align_corners": align_corners}
                    x = F.interpolate(x, size=target_hw, mode=mode, **extra_kwargs)

                return x.to(org_dtype) if was_bf16 else x
         | 
| 289 | 
            +
             | 
| 290 | 
            +
             | 
| 291 | 
            +
            class GroupNorm32(nn.GroupNorm):
                """GroupNorm that runs in float32 whenever its own weight is stored in float32."""

                def forward(self, x):
                    """Normalize ``x``; with float32 weights the input is upcast and the result cast back to ``x.dtype``."""
                    if self.weight.dtype == torch.float32:
                        normalized = super().forward(x.float())
                        return normalized.type(x.dtype)
                    # weights are in some other dtype: run the stock GroupNorm on the input as given
                    return super().forward(x)
         | 
| 296 | 
            +
             | 
| 297 | 
            +
             | 
| 298 | 
            +
            class ResnetBlock2D(nn.Module):
                """Residual block that adds a projected embedding between two norm/SiLU/conv stages."""

                def __init__(
                    self,
                    in_channels,
                    out_channels,
                ):
                    super().__init__()
                    self.in_channels = in_channels
                    self.out_channels = out_channels

                    # Submodules are created in the same order as before so parameter
                    # registration (and therefore state_dict layout) is unchanged.
                    in_stack = [
                        GroupNorm32(32, in_channels),
                        nn.SiLU(),
                        nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
                    ]
                    self.in_layers = nn.Sequential(*in_stack)

                    self.emb_layers = nn.Sequential(nn.SiLU(), nn.Linear(TIME_EMBED_DIM, out_channels))

                    out_stack = [
                        GroupNorm32(32, out_channels),
                        nn.SiLU(),
                        nn.Identity(),  # to make state_dict compatible with original model
                        nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
                    ]
                    self.out_layers = nn.Sequential(*out_stack)

                    # A 1x1 conv is only needed on the skip path when the channel count changes.
                    if in_channels == out_channels:
                        self.skip_connection = nn.Identity()
                    else:
                        self.skip_connection = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

                    self.gradient_checkpointing = False

                def forward_body(self, x, emb):
                    """Run the block: input stage, add the embedding projection, output stage, then add the skip path."""
                    hidden = self.in_layers(x)
                    emb_bias = self.emb_layers(emb).type(hidden.dtype)
                    # the projected embedding is indexed with two trailing singleton axes before the add
                    hidden = self.out_layers(hidden + emb_bias[:, :, None, None])
                    return self.skip_connection(x) + hidden

                def forward(self, x, emb):
                    """Call ``forward_body``, going through ``torch.utils.checkpoint`` when training with checkpointing on."""
                    if not (self.training and self.gradient_checkpointing):
                        return self.forward_body(x, emb)

                    def run_body(*inputs):
                        return self.forward_body(*inputs)

                    return torch.utils.checkpoint.checkpoint(run_body, x, emb, use_reentrant=USE_REENTRANT)
         | 
| 353 | 
            +
             | 
| 354 | 
            +
             | 
| 355 | 
            +
            class Downsample2D(nn.Module):
                """Downsampling layer implemented as a single 3x3 convolution with stride 2 and padding 1."""

                def __init__(self, channels, out_channels):
                    super().__init__()

                    self.channels = channels
                    self.out_channels = out_channels

                    self.op = nn.Conv2d(self.channels, self.out_channels, 3, stride=2, padding=1)

                    self.gradient_checkpointing = False

                def forward_body(self, hidden_states):
                    """Apply the strided convolution after checking the input channel count."""
                    assert hidden_states.shape[1] == self.channels
                    return self.op(hidden_states)

                def forward(self, hidden_states):
                    """Call ``forward_body``, going through ``torch.utils.checkpoint`` when training with checkpointing on."""
                    if not (self.training and self.gradient_checkpointing):
                        return self.forward_body(hidden_states)

                    def run_body(*inputs):
                        return self.forward_body(*inputs)

                    return torch.utils.checkpoint.checkpoint(run_body, hidden_states, use_reentrant=USE_REENTRANT)
         | 
| 389 | 
            +
             | 
| 390 | 
            +
             | 
| 391 | 
            +
            class CrossAttention(nn.Module):
         | 
| 392 | 
            +
                def __init__(
         | 
| 393 | 
            +
                    self,
         | 
| 394 | 
            +
                    query_dim: int,
         | 
| 395 | 
            +
                    cross_attention_dim: Optional[int] = None,
         | 
| 396 | 
            +
                    heads: int = 8,
         | 
| 397 | 
            +
                    dim_head: int = 64,
         | 
| 398 | 
            +
                    upcast_attention: bool = False,
         | 
| 399 | 
            +
                ):
         | 
| 400 | 
            +
                    super().__init__()
         | 
| 401 | 
            +
                    inner_dim = dim_head * heads
         | 
| 402 | 
            +
                    cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
         | 
| 403 | 
            +
                    self.upcast_attention = upcast_attention
         | 
| 404 | 
            +
             | 
| 405 | 
            +
                    self.scale = dim_head**-0.5
         | 
| 406 | 
            +
                    self.heads = heads
         | 
| 407 | 
            +
             | 
| 408 | 
            +
                    self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
         | 
| 409 | 
            +
                    self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=False)
         | 
| 410 | 
            +
                    self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=False)
         | 
| 411 | 
            +
             | 
| 412 | 
            +
                    self.to_out = nn.ModuleList([])
         | 
| 413 | 
            +
                    self.to_out.append(nn.Linear(inner_dim, query_dim))
         | 
| 414 | 
            +
                    # no dropout here
         | 
| 415 | 
            +
             | 
| 416 | 
            +
                    self.use_memory_efficient_attention_xformers = False
         | 
| 417 | 
            +
                    self.use_memory_efficient_attention_mem_eff = False
         | 
| 418 | 
            +
                    self.use_sdpa = False
         | 
| 419 | 
            +
             | 
| 420 | 
            +
                def set_use_memory_efficient_attention(self, xformers, mem_eff):
         | 
| 421 | 
            +
                    self.use_memory_efficient_attention_xformers = xformers
         | 
| 422 | 
            +
                    self.use_memory_efficient_attention_mem_eff = mem_eff
         | 
| 423 | 
            +
             | 
| 424 | 
            +
                def set_use_sdpa(self, sdpa):
         | 
| 425 | 
            +
                    self.use_sdpa = sdpa
         | 
| 426 | 
            +
             | 
| 427 | 
            +
                def reshape_heads_to_batch_dim(self, tensor):
         | 
| 428 | 
            +
                    batch_size, seq_len, dim = tensor.shape
         | 
| 429 | 
            +
                    head_size = self.heads
         | 
| 430 | 
            +
                    tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
         | 
| 431 | 
            +
                    tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
         | 
| 432 | 
            +
                    return tensor
         | 
| 433 | 
            +
             | 
| 434 | 
            +
                def reshape_batch_dim_to_heads(self, tensor):
         | 
| 435 | 
            +
                    batch_size, seq_len, dim = tensor.shape
         | 
| 436 | 
            +
                    head_size = self.heads
         | 
| 437 | 
            +
                    tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
         | 
| 438 | 
            +
                    tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
         | 
| 439 | 
            +
                    return tensor
         | 
| 440 | 
            +
             | 
| 441 | 
            +
                def forward(self, hidden_states, context=None, mask=None):
         | 
| 442 | 
            +
                    if self.use_memory_efficient_attention_xformers:
         | 
| 443 | 
            +
                        return self.forward_memory_efficient_xformers(hidden_states, context, mask)
         | 
| 444 | 
            +
                    if self.use_memory_efficient_attention_mem_eff:
         | 
| 445 | 
            +
                        return self.forward_memory_efficient_mem_eff(hidden_states, context, mask)
         | 
| 446 | 
            +
                    if self.use_sdpa:
         | 
| 447 | 
            +
                        return self.forward_sdpa(hidden_states, context, mask)
         | 
| 448 | 
            +
             | 
| 449 | 
            +
                    query = self.to_q(hidden_states)
         | 
| 450 | 
            +
                    context = context if context is not None else hidden_states
         | 
| 451 | 
            +
                    key = self.to_k(context)
         | 
| 452 | 
            +
                    value = self.to_v(context)
         | 
| 453 | 
            +
             | 
| 454 | 
            +
                    query = self.reshape_heads_to_batch_dim(query)
         | 
| 455 | 
            +
                    key = self.reshape_heads_to_batch_dim(key)
         | 
| 456 | 
            +
                    value = self.reshape_heads_to_batch_dim(value)
         | 
| 457 | 
            +
             | 
| 458 | 
            +
                    hidden_states = self._attention(query, key, value)
         | 
| 459 | 
            +
             | 
| 460 | 
            +
                    # linear proj
         | 
| 461 | 
            +
                    hidden_states = self.to_out[0](hidden_states)
         | 
| 462 | 
            +
                    # hidden_states = self.to_out[1](hidden_states)     # no dropout
         | 
| 463 | 
            +
                    return hidden_states
         | 
| 464 | 
            +
             | 
| 465 | 
            +
                def _attention(self, query, key, value):
         | 
| 466 | 
            +
                    if self.upcast_attention:
         | 
| 467 | 
            +
                        query = query.float()
         | 
| 468 | 
            +
                        key = key.float()
         | 
| 469 | 
            +
             | 
| 470 | 
            +
                    attention_scores = torch.baddbmm(
         | 
| 471 | 
            +
                        torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
         | 
| 472 | 
            +
                        query,
         | 
| 473 | 
            +
                        key.transpose(-1, -2),
         | 
| 474 | 
            +
                        beta=0,
         | 
| 475 | 
            +
                        alpha=self.scale,
         | 
| 476 | 
            +
                    )
         | 
| 477 | 
            +
                    attention_probs = attention_scores.softmax(dim=-1)
         | 
| 478 | 
            +
             | 
| 479 | 
            +
                    # cast back to the original dtype
         | 
| 480 | 
            +
                    attention_probs = attention_probs.to(value.dtype)
         | 
| 481 | 
            +
             | 
| 482 | 
            +
                    # compute attention output
         | 
| 483 | 
            +
                    hidden_states = torch.bmm(attention_probs, value)
         | 
| 484 | 
            +
             | 
| 485 | 
            +
                    # reshape hidden_states
         | 
| 486 | 
            +
                    hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
         | 
| 487 | 
            +
                    return hidden_states
         | 
| 488 | 
            +
             | 
| 489 | 
            +
                # TODO support Hypernetworks
         | 
| 490 | 
            +
                def forward_memory_efficient_xformers(self, x, context=None, mask=None):
                    """Attention forward pass backed by ``xformers.ops.memory_efficient_attention``.

                    Args:
                        x: input that is projected to queries (and to keys/values when
                            ``context`` is None).
                        context: optional source for keys/values; cast to ``x.dtype``.
                        mask: accepted for signature compatibility; not used here
                            (``attn_bias=None`` is always passed).

                    Returns:
                        The result of ``self.to_out[0]`` applied to the merged heads.
                    """
                    import xformers.ops

                    num_heads = self.heads
                    query_proj = self.to_q(x)
                    # Fall back to self-attention when no context is supplied.
                    if context is None:
                        context = x
                    context = context.to(x.dtype)
                    key_proj = self.to_k(context)
                    value_proj = self.to_v(context)

                    # Split the channel axis into (heads, head_dim), keeping tokens before heads.
                    split_heads = "b n (h d) -> b n h d"
                    q = rearrange(query_proj, split_heads, h=num_heads)
                    k = rearrange(key_proj, split_heads, h=num_heads)
                    v = rearrange(value_proj, split_heads, h=num_heads)
                    del query_proj, key_proj, value_proj

                    q, k, v = q.contiguous(), k.contiguous(), v.contiguous()
                    # xformers picks the most suitable kernel by itself.
                    attended = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)
                    del q, k, v

                    merged = rearrange(attended, "b n h d -> b n (h d)", h=num_heads)

                    # Only the linear output projection is applied (no dropout module).
                    return self.to_out[0](merged)
         | 
| 513 | 
            +
             | 
| 514 | 
            +
                def forward_memory_efficient_mem_eff(self, x, context=None, mask=None):
                    """Attention forward pass backed by ``FlashAttentionFunction``.

                    Args:
                        x: input that is projected to queries (and to keys/values when
                            ``context`` is None).
                        context: optional source for keys/values; cast to ``x.dtype``.
                        mask: forwarded unchanged to ``FlashAttentionFunction.apply``.

                    Returns:
                        The result of ``self.to_out[0]`` applied to the merged heads.
                    """
                    # Fixed chunk sizes handed to the flash-attention function.
                    q_bucket_size, k_bucket_size = 512, 1024

                    num_heads = self.heads
                    q = self.to_q(x)
                    # Fall back to self-attention when no context is supplied.
                    if context is None:
                        context = x
                    context = context.to(x.dtype)
                    k = self.to_k(context)
                    v = self.to_v(context)
                    # Drop the local references to the inputs; they are not needed below.
                    del context, x

                    # Split the channel axis into (heads, head_dim) and move heads before tokens.
                    split_heads = "b n (h d) -> b h n d"
                    q = rearrange(q, split_heads, h=num_heads)
                    k = rearrange(k, split_heads, h=num_heads)
                    v = rearrange(v, split_heads, h=num_heads)

                    # Positional arguments: q, k, v, mask, False, q bucket size, k bucket size.
                    attended = FlashAttentionFunction.apply(q, k, v, mask, False, q_bucket_size, k_bucket_size)

                    merged = rearrange(attended, "b h n d -> b n (h d)")

                    # Only the linear output projection is applied (no dropout module).
                    return self.to_out[0](merged)
         | 
| 536 | 
            +
             | 
| 537 | 
            +
                def forward_sdpa(self, x, context=None, mask=None):
                    """Attention forward pass backed by ``F.scaled_dot_product_attention``.

                    Args:
                        x: input that is projected to queries (and to keys/values when
                            ``context`` is None).
                        context: optional source for keys/values; cast to ``x.dtype``.
                        mask: passed through as ``attn_mask``.

                    Returns:
                        The result of ``self.to_out[0]`` applied to the merged heads.
                    """
                    num_heads = self.heads
                    query_proj = self.to_q(x)
                    # Fall back to self-attention when no context is supplied.
                    if context is None:
                        context = x
                    context = context.to(x.dtype)
                    key_proj = self.to_k(context)
                    value_proj = self.to_v(context)

                    # Split the channel axis into (heads, head_dim) and move heads before tokens.
                    split_heads = "b n (h d) -> b h n d"
                    q = rearrange(query_proj, split_heads, h=num_heads)
                    k = rearrange(key_proj, split_heads, h=num_heads)
                    v = rearrange(value_proj, split_heads, h=num_heads)
                    del query_proj, key_proj, value_proj

                    attended = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)

                    merged = rearrange(attended, "b h n d -> b n (h d)", h=num_heads)

                    # Only the linear output projection is applied (no dropout module).
                    return self.to_out[0](merged)
         | 
| 554 | 
            +
             | 
| 555 | 
            +
             | 
| 556 | 
            +
            # feedforward
         | 
| 557 | 
            +
            class GEGLU(nn.Module):
                r"""
                Gated linear unit with a GELU gate (see https://arxiv.org/abs/2002.05202).

                A single linear layer produces ``2 * dim_out`` features; the first half is the
                value and the second half is passed through GELU and used as a multiplicative gate.

                Parameters:
                    dim_in (`int`): The number of channels in the input.
                    dim_out (`int`): The number of channels in the output.
                """

                def __init__(self, dim_in: int, dim_out: int):
                    super().__init__()
                    # One projection yields both the value half and the gate half.
                    self.proj = nn.Linear(dim_in, 2 * dim_out)

                def gelu(self, gate):
                    """Apply GELU to ``gate``, going through float32 on mps devices."""
                    if gate.device.type == "mps":
                        # mps: gelu is not implemented for float16
                        upcast = gate.to(dtype=torch.float32)
                        return F.gelu(upcast).to(dtype=gate.dtype)
                    return F.gelu(gate)

                def forward(self, hidden_states):
                    """Project the input, split it in two along the last dim and gate one half."""
                    projected = self.proj(hidden_states)
                    value, gate = projected.chunk(2, dim=-1)
                    return value * self.gelu(gate)
         | 
| 579 | 
            +
             | 
| 580 | 
            +
             | 
| 581 | 
            +
            class FeedForward(nn.Module):
                """Feed-forward stack: GEGLU projection, a no-op placeholder, and a linear projection.

                The hidden width is always four times ``dim``. The middle entry of ``self.net`` is an
                ``nn.Identity`` standing in for a dropout of probability 0, so the output projection
                stays at index 2 of the module list.
                """

                def __init__(
                    self,
                    dim: int,
                ):
                    super().__init__()
                    hidden_dim = int(dim * 4)  # mult is always 4

                    self.net = nn.ModuleList(
                        [
                            # project in
                            GEGLU(dim, hidden_dim),
                            # dummy for dropout with 0
                            nn.Identity(),
                            # project out
                            nn.Linear(hidden_dim, dim),
                        ]
                    )

                def forward(self, hidden_states):
                    """Run the input through every module of ``self.net`` in order."""
                    out = hidden_states
                    for layer in self.net:
                        out = layer(out)
                    return out
         | 
| 601 | 
            +
             | 
| 602 | 
            +
             | 
| 603 | 
            +
            class BasicTransformerBlock(nn.Module):
                """Transformer block: self-attention, cross-attention and feed-forward, each with a residual.

                Every sub-layer is preceded by its own ``nn.LayerNorm`` and its output is added back
                to the un-normalised input of that sub-layer.
                """

                def __init__(
                    self, dim: int, num_attention_heads: int, attention_head_dim: int, cross_attention_dim: int, upcast_attention: bool = False
                ):
                    super().__init__()

                    self.gradient_checkpointing = False

                    # Settings shared by both attention layers; only cross_attention_dim differs.
                    shared_attn_args = dict(
                        query_dim=dim,
                        heads=num_attention_heads,
                        dim_head=attention_head_dim,
                        upcast_attention=upcast_attention,
                    )

                    # 1. Self-Attn
                    self.attn1 = CrossAttention(cross_attention_dim=None, **shared_attn_args)
                    self.ff = FeedForward(dim)

                    # 2. Cross-Attn
                    self.attn2 = CrossAttention(cross_attention_dim=cross_attention_dim, **shared_attn_args)

                    self.norm1 = nn.LayerNorm(dim)
                    self.norm2 = nn.LayerNorm(dim)

                    # 3. Feed-forward
                    self.norm3 = nn.LayerNorm(dim)

                def set_use_memory_efficient_attention(self, xformers: bool, mem_eff: bool):
                    """Forward the memory-efficient-attention flags to both attention layers."""
                    for attn in (self.attn1, self.attn2):
                        attn.set_use_memory_efficient_attention(xformers, mem_eff)

                def set_use_sdpa(self, sdpa: bool):
                    """Forward the SDPA flag to both attention layers."""
                    for attn in (self.attn1, self.attn2):
                        attn.set_use_sdpa(sdpa)

                def forward_body(self, hidden_states, context=None, timestep=None):
                    """Apply the three residual sub-layers; ``timestep`` is accepted but not used here."""
                    # 1. Self-Attention
                    self_attn_out = self.attn1(self.norm1(hidden_states))
                    hidden_states = self_attn_out + hidden_states

                    # 2. Cross-Attention
                    cross_attn_out = self.attn2(self.norm2(hidden_states), context=context)
                    hidden_states = cross_attn_out + hidden_states

                    # 3. Feed-forward
                    ff_out = self.ff(self.norm3(hidden_states))
                    hidden_states = ff_out + hidden_states

                    return hidden_states

                def forward(self, hidden_states, context=None, timestep=None):
                    """Run ``forward_body``, through activation checkpointing when training with it enabled."""
                    use_checkpoint = self.training and self.gradient_checkpointing
                    if not use_checkpoint:
                        return self.forward_body(hidden_states, context, timestep)

                    # Checkpointing is given a thin positional-args wrapper around the bound method.
                    body = self.forward_body

                    def run_body(*inputs):
                        return body(*inputs)

                    return torch.utils.checkpoint.checkpoint(
                        run_body, hidden_states, context, timestep, use_reentrant=USE_REENTRANT
                    )
         | 
| 676 | 
            +
             | 
| 677 | 
            +
             | 
| 678 | 
            +
            class Transformer2DModel(nn.Module):
                """Apply a stack of ``BasicTransformerBlock`` layers to a 4-D feature map.

                The input is group-normalised, projected to ``num_attention_heads * attention_head_dim``
                channels, flattened to a token sequence, run through the transformer blocks, projected
                back to ``in_channels`` and added to the original input.

                Parameters:
                    num_attention_heads: number of attention heads in every block.
                    attention_head_dim: channels per attention head.
                    in_channels: channels of the input (and output) feature map.
                    cross_attention_dim: forwarded to every block's cross-attention layer.
                    use_linear_projection: use ``nn.Linear`` projections on the flattened tokens
                        instead of 1x1 ``nn.Conv2d`` projections on the feature map.
                    upcast_attention: forwarded to every block.
                    num_transformer_layers: number of transformer blocks.
                """

                def __init__(
                    self,
                    num_attention_heads: int = 16,
                    attention_head_dim: int = 88,
                    in_channels: Optional[int] = None,
                    cross_attention_dim: Optional[int] = None,
                    use_linear_projection: bool = False,
                    upcast_attention: bool = False,
                    num_transformer_layers: int = 1,
                ):
                    super().__init__()
                    self.in_channels = in_channels
                    self.num_attention_heads = num_attention_heads
                    self.attention_head_dim = attention_head_dim
                    inner_dim = num_attention_heads * attention_head_dim
                    self.use_linear_projection = use_linear_projection

                    self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)

                    # Input projection: in_channels -> inner_dim.
                    if use_linear_projection:
                        self.proj_in = nn.Linear(in_channels, inner_dim)
                    else:
                        self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)

                    self.transformer_blocks = nn.ModuleList(
                        [
                            BasicTransformerBlock(
                                inner_dim,
                                num_attention_heads,
                                attention_head_dim,
                                cross_attention_dim=cross_attention_dim,
                                upcast_attention=upcast_attention,
                            )
                            for _ in range(num_transformer_layers)
                        ]
                    )

                    # Output projection: inner_dim -> in_channels, so the residual add in forward()
                    # is shape-compatible. The linear variant previously had its arguments swapped,
                    # which only worked when in_channels == inner_dim (weight shape is identical
                    # in that case, so existing checkpoints still load).
                    if use_linear_projection:
                        self.proj_out = nn.Linear(inner_dim, in_channels)
                    else:
                        self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)

                    self.gradient_checkpointing = False

                def set_use_memory_efficient_attention(self, xformers, mem_eff):
                    """Forward the memory-efficient-attention flags to every transformer block."""
                    for transformer in self.transformer_blocks:
                        transformer.set_use_memory_efficient_attention(xformers, mem_eff)

                def set_use_sdpa(self, sdpa):
                    """Forward the SDPA flag to every transformer block."""
                    for transformer in self.transformer_blocks:
                        transformer.set_use_sdpa(sdpa)

                def forward(self, hidden_states, encoder_hidden_states=None, timestep=None):
                    """Run the transformer stack over a ``(batch, channels, height, width)`` tensor.

                    Args:
                        hidden_states: 4-D input feature map.
                        encoder_hidden_states: passed to every block as ``context``.
                        timestep: passed to every block unchanged.

                    Returns:
                        A tensor of the same shape as ``hidden_states`` (block output plus residual).
                    """
                    # 1. Input
                    batch, _, height, width = hidden_states.shape
                    residual = hidden_states

                    hidden_states = self.norm(hidden_states)
                    if not self.use_linear_projection:
                        # Conv projection happens on the feature map, then tokens are flattened.
                        hidden_states = self.proj_in(hidden_states)
                        inner_dim = hidden_states.shape[1]
                        hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
                    else:
                        # Tokens are flattened first; here inner_dim is the channel count *before*
                        # proj_in, which is also the channel count after proj_out.
                        inner_dim = hidden_states.shape[1]
                        hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
                        hidden_states = self.proj_in(hidden_states)

                    # 2. Blocks
                    for block in self.transformer_blocks:
                        hidden_states = block(hidden_states, context=encoder_hidden_states, timestep=timestep)

                    # 3. Output
                    if not self.use_linear_projection:
                        hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
                        hidden_states = self.proj_out(hidden_states)
                    else:
                        hidden_states = self.proj_out(hidden_states)
                        hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()

                    return hidden_states + residual
         | 
| 763 | 
            +
             | 
| 764 | 
            +
             | 
| 765 | 
            +
            class Upsample2D(nn.Module):
                """Nearest-neighbour upsampling followed by a 3x3 convolution (padding 1)."""

                def __init__(self, channels, out_channels):
                    super().__init__()
                    self.channels = channels
                    self.out_channels = out_channels
                    self.conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=1)

                    self.gradient_checkpointing = False

                def forward_body(self, hidden_states, output_size=None):
                    """Upsample ``hidden_states`` and convolve the result.

                    Args:
                        hidden_states: input whose dim 1 must equal ``self.channels``.
                        output_size: if given, the interpolation target size; otherwise the
                            input is scaled by a factor of 2.
                    """
                    assert hidden_states.shape[1] == self.channels

                    # Cast to float32 as the 'upsample_nearest2d_out_frame' op does not support bfloat16
                    # TODO(Suraj): Remove this cast once the issue is fixed in PyTorch
                    # https://github.com/pytorch/pytorch/issues/86679
                    input_dtype = hidden_states.dtype
                    was_bfloat16 = input_dtype == torch.bfloat16
                    if was_bfloat16:
                        hidden_states = hidden_states.to(torch.float32)

                    # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
                    if hidden_states.shape[0] >= 64:
                        hidden_states = hidden_states.contiguous()

                    # An explicit `output_size` overrides the default `scale_factor=2`.
                    if output_size is None:
                        target = {"scale_factor": 2.0}
                    else:
                        target = {"size": output_size}
                    hidden_states = F.interpolate(hidden_states, mode="nearest", **target)

                    # Return to bfloat16 if that is what came in.
                    if was_bfloat16:
                        hidden_states = hidden_states.to(input_dtype)

                    return self.conv(hidden_states)

                def forward(self, hidden_states, output_size=None):
                    """Run ``forward_body``, through activation checkpointing when training with it enabled."""
                    use_checkpoint = self.training and self.gradient_checkpointing
                    if not use_checkpoint:
                        return self.forward_body(hidden_states, output_size)

                    # Checkpointing is given a thin positional-args wrapper around the bound method.
                    body = self.forward_body

                    def run_body(*inputs):
                        return body(*inputs)

                    return torch.utils.checkpoint.checkpoint(
                        run_body, hidden_states, output_size, use_reentrant=USE_REENTRANT
                    )
         | 
| 819 | 
            +
             | 
| 820 | 
            +
             | 
| 821 | 
            +
            class SdxlUNet2DConditionModel(nn.Module):
         | 
| 822 | 
            +
                _supports_gradient_checkpointing = True
         | 
| 823 | 
            +
             | 
| 824 | 
            +
                def __init__(
         | 
| 825 | 
            +
                    self,
         | 
| 826 | 
            +
                    **kwargs,
         | 
| 827 | 
            +
                ):
         | 
| 828 | 
            +
                    super().__init__()
         | 
| 829 | 
            +
             | 
| 830 | 
            +
                    self.in_channels = IN_CHANNELS
         | 
| 831 | 
            +
                    self.out_channels = OUT_CHANNELS
         | 
| 832 | 
            +
                    self.model_channels = MODEL_CHANNELS
         | 
| 833 | 
            +
                    self.time_embed_dim = TIME_EMBED_DIM
         | 
| 834 | 
            +
                    self.adm_in_channels = ADM_IN_CHANNELS
         | 
| 835 | 
            +
             | 
| 836 | 
            +
                    self.gradient_checkpointing = False
         | 
| 837 | 
            +
                    # self.sample_size = sample_size
         | 
| 838 | 
            +
             | 
| 839 | 
            +
                    # time embedding
         | 
| 840 | 
            +
                    self.time_embed = nn.Sequential(
         | 
| 841 | 
            +
                        nn.Linear(self.model_channels, self.time_embed_dim),
         | 
| 842 | 
            +
                        nn.SiLU(),
         | 
| 843 | 
            +
                        nn.Linear(self.time_embed_dim, self.time_embed_dim),
         | 
| 844 | 
            +
                    )
         | 
| 845 | 
            +
             | 
| 846 | 
            +
                    # label embedding
         | 
| 847 | 
            +
                    self.label_emb = nn.Sequential(
         | 
| 848 | 
            +
                        nn.Sequential(
         | 
| 849 | 
            +
                            nn.Linear(self.adm_in_channels, self.time_embed_dim),
         | 
| 850 | 
            +
                            nn.SiLU(),
         | 
| 851 | 
            +
                            nn.Linear(self.time_embed_dim, self.time_embed_dim),
         | 
| 852 | 
            +
                        )
         | 
| 853 | 
            +
                    )
         | 
| 854 | 
            +
             | 
| 855 | 
            +
                    # input
         | 
| 856 | 
            +
                    self.input_blocks = nn.ModuleList(
         | 
| 857 | 
            +
                        [
         | 
| 858 | 
            +
                            nn.Sequential(
         | 
| 859 | 
            +
                                nn.Conv2d(self.in_channels, self.model_channels, kernel_size=3, padding=(1, 1)),
         | 
| 860 | 
            +
                            )
         | 
| 861 | 
            +
                        ]
         | 
| 862 | 
            +
                    )
         | 
| 863 | 
            +
             | 
| 864 | 
            +
                    # level 0
         | 
| 865 | 
            +
                    for i in range(2):
         | 
| 866 | 
            +
                        layers = [
         | 
| 867 | 
            +
                            ResnetBlock2D(
         | 
| 868 | 
            +
                                in_channels=1 * self.model_channels,
         | 
| 869 | 
            +
                                out_channels=1 * self.model_channels,
         | 
| 870 | 
            +
                            ),
         | 
| 871 | 
            +
                        ]
         | 
| 872 | 
            +
                        self.input_blocks.append(nn.ModuleList(layers))
         | 
| 873 | 
            +
             | 
| 874 | 
            +
                    self.input_blocks.append(
         | 
| 875 | 
            +
                        nn.Sequential(
         | 
| 876 | 
            +
                            Downsample2D(
         | 
| 877 | 
            +
                                channels=1 * self.model_channels,
         | 
| 878 | 
            +
                                out_channels=1 * self.model_channels,
         | 
| 879 | 
            +
                            ),
         | 
| 880 | 
            +
                        )
         | 
| 881 | 
            +
                    )
         | 
| 882 | 
            +
             | 
| 883 | 
            +
                    # level 1
         | 
| 884 | 
            +
                    for i in range(2):
         | 
| 885 | 
            +
                        layers = [
         | 
| 886 | 
            +
                            ResnetBlock2D(
         | 
| 887 | 
            +
                                in_channels=(1 if i == 0 else 2) * self.model_channels,
         | 
| 888 | 
            +
                                out_channels=2 * self.model_channels,
         | 
| 889 | 
            +
                            ),
         | 
| 890 | 
            +
                            Transformer2DModel(
         | 
| 891 | 
            +
                                num_attention_heads=2 * self.model_channels // 64,
         | 
| 892 | 
            +
                                attention_head_dim=64,
         | 
| 893 | 
            +
                                in_channels=2 * self.model_channels,
         | 
| 894 | 
            +
                                num_transformer_layers=2,
         | 
| 895 | 
            +
                                use_linear_projection=True,
         | 
| 896 | 
            +
                                cross_attention_dim=2048,
         | 
| 897 | 
            +
                            ),
         | 
| 898 | 
            +
                        ]
         | 
| 899 | 
            +
                        self.input_blocks.append(nn.ModuleList(layers))
         | 
| 900 | 
            +
             | 
| 901 | 
            +
                    self.input_blocks.append(
         | 
| 902 | 
            +
                        nn.Sequential(
         | 
| 903 | 
            +
                            Downsample2D(
         | 
| 904 | 
            +
                                channels=2 * self.model_channels,
         | 
| 905 | 
            +
                                out_channels=2 * self.model_channels,
         | 
| 906 | 
            +
                            ),
         | 
| 907 | 
            +
                        )
         | 
| 908 | 
            +
                    )
         | 
| 909 | 
            +
             | 
| 910 | 
            +
                    # level 2
         | 
| 911 | 
            +
                    for i in range(2):
         | 
| 912 | 
            +
                        layers = [
         | 
| 913 | 
            +
                            ResnetBlock2D(
         | 
| 914 | 
            +
                                in_channels=(2 if i == 0 else 4) * self.model_channels,
         | 
| 915 | 
            +
                                out_channels=4 * self.model_channels,
         | 
| 916 | 
            +
                            ),
         | 
| 917 | 
            +
                            Transformer2DModel(
         | 
| 918 | 
            +
                                num_attention_heads=4 * self.model_channels // 64,
         | 
| 919 | 
            +
                                attention_head_dim=64,
         | 
| 920 | 
            +
                                in_channels=4 * self.model_channels,
         | 
| 921 | 
            +
                                num_transformer_layers=10,
         | 
| 922 | 
            +
                                use_linear_projection=True,
         | 
| 923 | 
            +
                                cross_attention_dim=2048,
         | 
| 924 | 
            +
                            ),
         | 
| 925 | 
            +
                        ]
         | 
| 926 | 
            +
                        self.input_blocks.append(nn.ModuleList(layers))
         | 
| 927 | 
            +
             | 
| 928 | 
            +
                    # mid
         | 
| 929 | 
            +
                    self.middle_block = nn.ModuleList(
         | 
| 930 | 
            +
                        [
         | 
| 931 | 
            +
                            ResnetBlock2D(
         | 
| 932 | 
            +
                                in_channels=4 * self.model_channels,
         | 
| 933 | 
            +
                                out_channels=4 * self.model_channels,
         | 
| 934 | 
            +
                            ),
         | 
| 935 | 
            +
                            Transformer2DModel(
         | 
| 936 | 
            +
                                num_attention_heads=4 * self.model_channels // 64,
         | 
| 937 | 
            +
                                attention_head_dim=64,
         | 
| 938 | 
            +
                                in_channels=4 * self.model_channels,
         | 
| 939 | 
            +
                                num_transformer_layers=10,
         | 
| 940 | 
            +
                                use_linear_projection=True,
         | 
| 941 | 
            +
                                cross_attention_dim=2048,
         | 
| 942 | 
            +
                            ),
         | 
| 943 | 
            +
                            ResnetBlock2D(
         | 
| 944 | 
            +
                                in_channels=4 * self.model_channels,
         | 
| 945 | 
            +
                                out_channels=4 * self.model_channels,
         | 
| 946 | 
            +
                            ),
         | 
| 947 | 
            +
                        ]
         | 
| 948 | 
            +
                    )
         | 
| 949 | 
            +
             | 
| 950 | 
            +
                    # output
         | 
| 951 | 
            +
                    self.output_blocks = nn.ModuleList([])
         | 
| 952 | 
            +
             | 
| 953 | 
            +
                    # level 2
         | 
| 954 | 
            +
                    for i in range(3):
         | 
| 955 | 
            +
                        layers = [
         | 
| 956 | 
            +
                            ResnetBlock2D(
         | 
| 957 | 
            +
                                in_channels=4 * self.model_channels + (4 if i <= 1 else 2) * self.model_channels,
         | 
| 958 | 
            +
                                out_channels=4 * self.model_channels,
         | 
| 959 | 
            +
                            ),
         | 
| 960 | 
            +
                            Transformer2DModel(
         | 
| 961 | 
            +
                                num_attention_heads=4 * self.model_channels // 64,
         | 
| 962 | 
            +
                                attention_head_dim=64,
         | 
| 963 | 
            +
                                in_channels=4 * self.model_channels,
         | 
| 964 | 
            +
                                num_transformer_layers=10,
         | 
| 965 | 
            +
                                use_linear_projection=True,
         | 
| 966 | 
            +
                                cross_attention_dim=2048,
         | 
| 967 | 
            +
                            ),
         | 
| 968 | 
            +
                        ]
         | 
| 969 | 
            +
                        if i == 2:
         | 
| 970 | 
            +
                            layers.append(
         | 
| 971 | 
            +
                                Upsample2D(
         | 
| 972 | 
            +
                                    channels=4 * self.model_channels,
         | 
| 973 | 
            +
                                    out_channels=4 * self.model_channels,
         | 
| 974 | 
            +
                                )
         | 
| 975 | 
            +
                            )
         | 
| 976 | 
            +
             | 
| 977 | 
            +
                        self.output_blocks.append(nn.ModuleList(layers))
         | 
| 978 | 
            +
             | 
| 979 | 
            +
                    # level 1
         | 
| 980 | 
            +
                    for i in range(3):
         | 
| 981 | 
            +
                        layers = [
         | 
| 982 | 
            +
                            ResnetBlock2D(
         | 
| 983 | 
            +
                                in_channels=2 * self.model_channels + (4 if i == 0 else (2 if i == 1 else 1)) * self.model_channels,
         | 
| 984 | 
            +
                                out_channels=2 * self.model_channels,
         | 
| 985 | 
            +
                            ),
         | 
| 986 | 
            +
                            Transformer2DModel(
         | 
| 987 | 
            +
                                num_attention_heads=2 * self.model_channels // 64,
         | 
| 988 | 
            +
                                attention_head_dim=64,
         | 
| 989 | 
            +
                                in_channels=2 * self.model_channels,
         | 
| 990 | 
            +
                                num_transformer_layers=2,
         | 
| 991 | 
            +
                                use_linear_projection=True,
         | 
| 992 | 
            +
                                cross_attention_dim=2048,
         | 
| 993 | 
            +
                            ),
         | 
| 994 | 
            +
                        ]
         | 
| 995 | 
            +
                        if i == 2:
         | 
| 996 | 
            +
                            layers.append(
         | 
| 997 | 
            +
                                Upsample2D(
         | 
| 998 | 
            +
                                    channels=2 * self.model_channels,
         | 
| 999 | 
            +
                                    out_channels=2 * self.model_channels,
         | 
| 1000 | 
            +
                                )
         | 
| 1001 | 
            +
                            )
         | 
| 1002 | 
            +
             | 
| 1003 | 
            +
                        self.output_blocks.append(nn.ModuleList(layers))
         | 
| 1004 | 
            +
             | 
| 1005 | 
            +
                    # level 0
         | 
| 1006 | 
            +
                    for i in range(3):
         | 
| 1007 | 
            +
                        layers = [
         | 
| 1008 | 
            +
                            ResnetBlock2D(
         | 
| 1009 | 
            +
                                in_channels=1 * self.model_channels + (2 if i == 0 else 1) * self.model_channels,
         | 
| 1010 | 
            +
                                out_channels=1 * self.model_channels,
         | 
| 1011 | 
            +
                            ),
         | 
| 1012 | 
            +
                        ]
         | 
| 1013 | 
            +
             | 
| 1014 | 
            +
                        self.output_blocks.append(nn.ModuleList(layers))
         | 
| 1015 | 
            +
             | 
| 1016 | 
            +
                    # output
         | 
| 1017 | 
            +
                    self.out = nn.ModuleList(
         | 
| 1018 | 
            +
                        [GroupNorm32(32, self.model_channels), nn.SiLU(), nn.Conv2d(self.model_channels, self.out_channels, 3, padding=1)]
         | 
| 1019 | 
            +
                    )
         | 
| 1020 | 
            +
             | 
| 1021 | 
            +
                # region diffusers compatibility
         | 
| 1022 | 
            +
                def prepare_config(self):
                    """Attach a blank ``config`` namespace to this model.

                    The namespace is created empty; this method does not populate it.
                    """
                    empty_config = SimpleNamespace()
                    self.config = empty_config
         | 
| 1024 | 
            +
             | 
| 1025 | 
            +
                @property
                def dtype(self) -> torch.dtype:
                    """Return the dtype that ``get_parameter_dtype`` reports for this module.

                    Assumes that all the module parameters have the same dtype.
                    """
                    module_dtype = get_parameter_dtype(self)
                    return module_dtype
         | 
| 1029 | 
            +
             | 
| 1030 | 
            +
                @property
                def device(self) -> torch.device:
                    """Return the device that ``get_parameter_device`` reports for this module.

                    Assumes that all the module parameters are on the same device.
                    """
                    module_device = get_parameter_device(self)
                    return module_device
         | 
| 1034 | 
            +
             | 
| 1035 | 
            +
                def set_attention_slice(self, slice_size):
                    """Reject attention slicing.

                    Raises:
                        NotImplementedError: always, regardless of ``slice_size``.
                    """
                    message = "Attention slicing is not supported for this model."
                    raise NotImplementedError(message)
         | 
| 1037 | 
            +
             | 
| 1038 | 
            +
                def is_gradient_checkpointing(self) -> bool:
                    """Report whether any submodule has a truthy ``gradient_checkpointing`` flag.

                    Walks ``self.modules()`` and stops at the first module whose
                    ``gradient_checkpointing`` attribute exists and is truthy.
                    """
                    for submodule in self.modules():
                        if getattr(submodule, "gradient_checkpointing", False):
                            return True
                    return False
         | 
| 1040 | 
            +
             | 
| 1041 | 
            +
                def enable_gradient_checkpointing(self):
                    """Set this model's own flag to ``True`` and propagate it to the blocks.

                    Propagation is delegated to ``set_gradient_checkpointing``.
                    """
                    enabled = True
                    self.gradient_checkpointing = enabled
                    self.set_gradient_checkpointing(value=enabled)
         | 
| 1044 | 
            +
             | 
| 1045 | 
            +
                def disable_gradient_checkpointing(self):
                    """Set this model's own flag to ``False`` and propagate it to the blocks.

                    Propagation is delegated to ``set_gradient_checkpointing``.
                    """
                    enabled = False
                    self.gradient_checkpointing = enabled
                    self.set_gradient_checkpointing(value=enabled)
         | 
| 1048 | 
            +
             | 
| 1049 | 
            +
                def set_use_memory_efficient_attention(self, xformers: bool, mem_eff: bool) -> None:
                    """Forward the attention-backend flags to every layer that accepts them.

                    Visits the direct children of each input block, the middle block and
                    each output block, in that order. Layers without a
                    ``set_use_memory_efficient_attention`` method are skipped.

                    Args:
                        xformers: passed through unchanged as the first argument.
                        mem_eff: passed through unchanged as the second argument.
                    """
                    for block in (*self.input_blocks, self.middle_block, *self.output_blocks):
                        for layer in block:
                            if not hasattr(layer, "set_use_memory_efficient_attention"):
                                continue
                            layer.set_use_memory_efficient_attention(xformers, mem_eff)
         | 
| 1056 | 
            +
             | 
| 1057 | 
            +
                def set_use_sdpa(self, sdpa: bool) -> None:
                    """Forward the ``sdpa`` flag to every layer that exposes ``set_use_sdpa``.

                    Visits the direct children of each input block, the middle block and
                    each output block, in that order. Other layers are left untouched.
                    """
                    for block in (*self.input_blocks, self.middle_block, *self.output_blocks):
                        for layer in block:
                            if not hasattr(layer, "set_use_sdpa"):
                                continue
                            layer.set_use_sdpa(sdpa)
         | 
| 1063 | 
            +
             | 
| 1064 | 
            +
                def set_gradient_checkpointing(self, value=False):
                    """Write ``value`` to every existing ``gradient_checkpointing`` flag in the blocks.

                    Walks ``.modules()`` of each input block, the middle block and each
                    output block. Only modules that already have a
                    ``gradient_checkpointing`` attribute are modified; the attribute is
                    never created on modules that lack it.
                    """
                    for block in (*self.input_blocks, self.middle_block, *self.output_blocks):
                        for submodule in block.modules():
                            if not hasattr(submodule, "gradient_checkpointing"):
                                continue
                            submodule.gradient_checkpointing = value
         | 
| 1071 | 
            +
             | 
| 1072 | 
            +
                # endregion
         | 
| 1073 | 
            +
             | 
| 1074 | 
            +
                def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
                    """Run the UNet: input blocks, middle block, then output blocks with skip connections.

                    Args:
                        x: input tensor; its first dimension is used as the batch size.
                        timesteps: tensor that is expanded to the batch size before embedding.
                        context: handed to every ``Transformer2DModel`` layer.
                        y: conditioning fed to ``self.label_emb``; must match ``x`` in batch
                            size and dtype.
                        **kwargs: accepted and ignored.

                    Returns:
                        The result of running ``self.out`` on the final hidden state.

                    Raises:
                        AssertionError: if ``x`` and ``y`` disagree in batch size or dtype.
                    """

                    def run_layers(layers, hidden, embedding, cond):
                        # Resnet blocks consume the embedding, transformer blocks consume
                        # the context, and every other layer takes the hidden state only.
                        for layer in layers:
                            if isinstance(layer, ResnetBlock2D):
                                hidden = layer(hidden, embedding)
                            elif isinstance(layer, Transformer2DModel):
                                hidden = layer(hidden, cond)
                            else:
                                hidden = layer(hidden)
                        return hidden

                    # broadcast timesteps to batch dimension
                    batch_size = x.shape[0]
                    timesteps = timesteps.expand(batch_size)

                    t_emb = get_timestep_embedding(timesteps, self.model_channels, downscale_freq_shift=0).to(x.dtype)
                    emb = self.time_embed(t_emb)

                    assert x.shape[0] == y.shape[0], f"batch size mismatch: {x.shape[0]} != {y.shape[0]}"
                    assert x.dtype == y.dtype, f"dtype mismatch: {x.dtype} != {y.dtype}"
                    emb = emb + self.label_emb(y)

                    # Down path: remember every block output for the matching up-path block.
                    skips = []
                    h = x
                    for block in self.input_blocks:
                        h = run_layers(block, h, emb, context)
                        skips.append(h)

                    h = run_layers(self.middle_block, h, emb, context)

                    # Up path: concatenate the most recent skip along dim 1 before each block.
                    for block in self.output_blocks:
                        merged = torch.cat([h, skips.pop()], dim=1)
                        h = run_layers(block, merged, emb, context)

                    h = h.type(x.dtype)
                    return run_layers(self.out, h, emb, context)
         | 
| 1117 | 
            +
             | 
| 1118 | 
            +
             | 
| 1119 | 
            +
            class InferSdxlUNet2DConditionModel:
                """Inference wrapper around ``SdxlUNet2DConditionModel`` that adds Deep Shrink.

                The wrapper stores the original model in ``delegate``, replaces the
                delegate's ``forward`` with its own, and forwards attribute access and
                calls to the delegate. Deep Shrink stays disabled until
                ``set_deep_shrink`` is called with a non-``None`` ``ds_depth_1``.
                """

                def __init__(self, original_unet: SdxlUNet2DConditionModel, **kwargs):
                    """Wrap ``original_unet``; extra keyword arguments are accepted and ignored."""
                    self.delegate = original_unet

                    # override original model's forward method: because forward is not called by `__call__`
                    # overriding `__call__` is not enough, because nn.Module.forward has a special handling
                    self.delegate.forward = self.forward

                    # Deep Shrink (disabled while ds_depth_1 is None)
                    self.ds_depth_1 = None
                    self.ds_depth_2 = None
                    self.ds_timesteps_1 = None
                    self.ds_timesteps_2 = None
                    self.ds_ratio = None

                # call original model's methods
                def __getattr__(self, name):
                    """Fall back to the delegate for attributes not found on the wrapper.

                    Raises:
                        AttributeError: if ``delegate`` itself is missing, or the delegate
                            does not have ``name``.
                    """
                    # __getattr__ only runs after normal lookup failed. If `delegate` is the
                    # missing attribute (e.g. an instance created through __new__ by
                    # copy/pickle before __init__ ran), reading self.delegate below would
                    # re-enter __getattr__ until RecursionError, so fail cleanly instead.
                    if name == "delegate":
                        raise AttributeError(f"{type(self).__name__!r} object has no attribute {name!r}")
                    return getattr(self.delegate, name)

                def __call__(self, *args, **kwargs):
                    """Call the delegate, whose ``forward`` was replaced in ``__init__``."""
                    return self.delegate(*args, **kwargs)

                def set_deep_shrink(self, ds_depth_1, ds_timesteps_1=650, ds_depth_2=None, ds_timesteps_2=None, ds_ratio=0.5):
                    """Enable or disable Deep Shrink.

                    Args:
                        ds_depth_1: index of the input block before which the hidden state
                            is downscaled while ``timesteps[0] >= ds_timesteps_1``.
                            ``None`` disables Deep Shrink and clears all settings.
                        ds_timesteps_1: timestep threshold for the first depth.
                        ds_depth_2: optional second input-block index, used while
                            ``ds_timesteps_2 <= timesteps[0] < ds_timesteps_1``. Stored as
                            ``-1`` when ``None``, which never equals a block index.
                        ds_timesteps_2: lower timestep threshold for the second depth.
                            Stored as ``1000`` when ``None``.
                        ds_ratio: ``scale_factor`` passed to ``F.interpolate``.
                    """
                    if ds_depth_1 is None:
                        logger.info("Deep Shrink is disabled.")
                        self.ds_depth_1 = None
                        self.ds_timesteps_1 = None
                        self.ds_depth_2 = None
                        self.ds_timesteps_2 = None
                        self.ds_ratio = None
                    else:
                        logger.info(
                            f"Deep Shrink is enabled: [depth={ds_depth_1}/{ds_depth_2}, timesteps={ds_timesteps_1}/{ds_timesteps_2}, ratio={ds_ratio}]"
                        )
                        self.ds_depth_1 = ds_depth_1
                        self.ds_timesteps_1 = ds_timesteps_1
                        self.ds_depth_2 = ds_depth_2 if ds_depth_2 is not None else -1
                        self.ds_timesteps_2 = ds_timesteps_2 if ds_timesteps_2 is not None else 1000
                        self.ds_ratio = ds_ratio

                def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
                    r"""
                    current implementation is a copy of `SdxlUNet2DConditionModel.forward()` with Deep Shrink.

                    Args:
                        x: input tensor; its first dimension is used as the batch size.
                        timesteps: tensor expanded to the batch size; only ``timesteps[0]``
                            is compared against the Deep Shrink thresholds.
                        context: handed to every ``Transformer2DModel`` layer.
                        y: conditioning fed to the delegate's ``label_emb``; must match
                            ``x`` in batch size and dtype.
                        **kwargs: accepted and ignored.

                    Returns:
                        The result of running the delegate's ``out`` layers on the final
                        hidden state.

                    Raises:
                        AssertionError: if ``x`` and ``y`` disagree in batch size or dtype.
                    """
                    _self = self.delegate

                    # broadcast timesteps to batch dimension
                    timesteps = timesteps.expand(x.shape[0])

                    hs = []
                    t_emb = get_timestep_embedding(timesteps, _self.model_channels, downscale_freq_shift=0)  # , repeat_only=False)
                    t_emb = t_emb.to(x.dtype)
                    emb = _self.time_embed(t_emb)

                    assert x.shape[0] == y.shape[0], f"batch size mismatch: {x.shape[0]} != {y.shape[0]}"
                    assert x.dtype == y.dtype, f"dtype mismatch: {x.dtype} != {y.dtype}"
                    emb = emb + _self.label_emb(y)

                    def call_module(module, h, emb, context):
                        # Resnet blocks take the embedding, transformer blocks take the
                        # context, every other layer takes the hidden state only.
                        x = h
                        for layer in module:
                            if isinstance(layer, ResnetBlock2D):
                                x = layer(x, emb)
                            elif isinstance(layer, Transformer2DModel):
                                x = layer(x, context)
                            else:
                                x = layer(x)
                        return x

                    h = x

                    for depth, module in enumerate(_self.input_blocks):
                        # Deep Shrink: downscale the hidden state before the configured block
                        if self.ds_depth_1 is not None:
                            if (depth == self.ds_depth_1 and timesteps[0] >= self.ds_timesteps_1) or (
                                self.ds_depth_2 is not None
                                and depth == self.ds_depth_2
                                and timesteps[0] < self.ds_timesteps_1
                                and timesteps[0] >= self.ds_timesteps_2
                            ):
                                org_dtype = h.dtype
                                # bfloat16 is upcast to float32 for the interpolation and cast
                                # back afterwards (presumably bicubic lacks bfloat16 support
                                # on some backends; verify before removing).
                                if org_dtype == torch.bfloat16:
                                    h = h.to(torch.float32)
                                h = F.interpolate(h, scale_factor=self.ds_ratio, mode="bicubic", align_corners=False).to(org_dtype)

                        h = call_module(module, h, emb, context)
                        hs.append(h)

                    h = call_module(_self.middle_block, h, emb, context)

                    for module in _self.output_blocks:
                        # Deep Shrink: bring h back to the spatial size of the skip connection
                        if self.ds_depth_1 is not None:
                            if hs[-1].shape[-2:] != h.shape[-2:]:
                                h = resize_like(h, hs[-1])

                        h = torch.cat([h, hs.pop()], dim=1)
                        h = call_module(module, h, emb, context)

                    # Deep Shrink: in case of depth 0
                    # (no skip connection has the original size, so restore it from x)
                    if self.ds_depth_1 == 0 and h.shape[-2:] != x.shape[-2:]:
                        h = resize_like(h, x)

                    h = h.type(x.dtype)
                    h = call_module(_self.out, h, emb, context)

                    return h
         | 
| 1232 | 
            +
             | 
| 1233 | 
            +
             | 
| 1234 | 
            +
            if __name__ == "__main__":
                # Manual benchmark: builds the UNet on CUDA and runs a few mock training
                # steps on random data, then logs the elapsed time.
                import time

                logger.info("create unet")
                unet = SdxlUNet2DConditionModel()

                unet.to("cuda")
                unet.set_use_memory_efficient_attention(True, False)
                unet.set_gradient_checkpointing(True)
                unet.train()

                # pseudo training loop for checking memory usage
                logger.info("preparing optimizer")

                # optimizer = torch.optim.SGD(unet.parameters(), lr=1e-3, nesterov=True, momentum=0.9) # not working

                # import bitsandbytes
                # optimizer = bitsandbytes.adam.Adam8bit(unet.parameters(), lr=1e-3)        # not working
                # optimizer = bitsandbytes.optim.RMSprop8bit(unet.parameters(), lr=1e-3)  # working at 23.5 GB with torch2
                # optimizer=bitsandbytes.optim.Adagrad8bit(unet.parameters(), lr=1e-3)  # working at 23.5 GB with torch2

                import transformers

                optimizer = transformers.optimization.Adafactor(unet.parameters(), relative_step=True)  # working at 22.2GB with torch2

                scaler = torch.cuda.amp.GradScaler(enabled=True)

                logger.info("start training")
                steps = 10
                batch_size = 1

                for step in range(steps):
                    logger.info(f"step {step}")
                    # Start the timer at step 1 so that step 0 is excluded from the measurement.
                    if step == 1:
                        time_start = time.perf_counter()

                    # Random inputs for one step.
                    x = torch.randn(batch_size, 4, 128, 128).cuda()  # 1024x1024
                    t = torch.randint(low=0, high=10, size=(batch_size,), device="cuda")
                    ctx = torch.randn(batch_size, 77, 2048).cuda()
                    y = torch.randn(batch_size, ADM_IN_CHANNELS).cuda()

                    with torch.cuda.amp.autocast(enabled=True):
                        output = unet(x, t, ctx, y)
                        # Random target: the loss value is meaningless here.
                        target = torch.randn_like(output)
                        loss = torch.nn.functional.mse_loss(output, target)

                    scaler.scale(loss).backward()
                    scaler.step(optimizer)
                    scaler.update()
                    optimizer.zero_grad(set_to_none=True)

                time_end = time.perf_counter()
                logger.info(f"elapsed time: {time_end - time_start} [sec] for last {steps - 1} steps")
         | 
    	
        library/sdxl_train_util.py
    ADDED
    
    | @@ -0,0 +1,381 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import argparse
         | 
| 2 | 
            +
            import math
         | 
| 3 | 
            +
            import os
         | 
| 4 | 
            +
            from typing import Optional
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            import torch
         | 
| 7 | 
            +
            from library.device_utils import init_ipex, clean_memory_on_device
         | 
| 8 | 
            +
             | 
| 9 | 
            +
            init_ipex()
         | 
| 10 | 
            +
             | 
| 11 | 
            +
            from accelerate import init_empty_weights
         | 
| 12 | 
            +
            from tqdm import tqdm
         | 
| 13 | 
            +
            from transformers import CLIPTokenizer
         | 
| 14 | 
            +
            from library import model_util, sdxl_model_util, train_util, sdxl_original_unet
         | 
| 15 | 
            +
            from library.sdxl_lpw_stable_diffusion import SdxlStableDiffusionLongPromptWeightingPipeline
         | 
| 16 | 
            +
            from .utils import setup_logging
         | 
| 17 | 
            +
             | 
| 18 | 
            +
            setup_logging()
         | 
| 19 | 
            +
            import logging
         | 
| 20 | 
            +
             | 
| 21 | 
            +
            logger = logging.getLogger(__name__)
         | 
| 22 | 
            +
             | 
| 23 | 
            +
            TOKENIZER1_PATH = "openai/clip-vit-large-patch14"
         | 
| 24 | 
            +
            TOKENIZER2_PATH = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
         | 
| 25 | 
            +
             | 
| 26 | 
            +
            # DEFAULT_NOISE_OFFSET = 0.0357
         | 
| 27 | 
            +
             | 
| 28 | 
            +
             | 
| 29 | 
            +
            def load_target_model(args, accelerator, model_version: str, weight_dtype):
                """Load the SDXL models, letting the processes take turns.

                The loop runs once per process index. Only the process whose
                ``local_process_index`` matches the current turn loads the models; every
                process calls ``accelerator.wait_for_everyone()`` at the end of each turn.

                :param args: parsed arguments; reads ``pretrained_model_name_or_path``, ``vae``,
                             ``lowram`` and ``disable_mmap_load_safetensors`` (plus whatever
                             ``match_mixed_precision`` reads).
                :param accelerator: object providing ``state``, ``device`` and ``wait_for_everyone()``.
                :param model_version: forwarded unchanged to ``_load_target_model``.
                :param weight_dtype: dtype forwarded to ``_load_target_model`` and ``match_mixed_precision``.
                :return: ``(load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet,
                         logit_scale, ckpt_info)`` as produced by ``_load_target_model``.
                """
                # dtype used for full fp16/bf16 training, or None otherwise
                model_dtype = match_mixed_precision(args, weight_dtype)
                state = accelerator.state

                for turn in range(state.num_processes):
                    if turn != state.local_process_index:
                        # not this process's turn: just join the synchronisation point
                        accelerator.wait_for_everyone()
                        continue

                    logger.info(f"loading model for process {state.local_process_index}/{state.num_processes}")

                    loaded = _load_target_model(
                        args.pretrained_model_name_or_path,
                        args.vae,
                        model_version,
                        weight_dtype,
                        accelerator.device if args.lowram else "cpu",
                        model_dtype,
                        args.disable_mmap_load_safetensors,
                    )
                    load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info = loaded

                    # with --lowram the models are moved to the accelerator device right away
                    if args.lowram:
                        for module in (text_encoder1, text_encoder2, unet, vae):
                            module.to(accelerator.device)

                    clean_memory_on_device(accelerator.device)
                    accelerator.wait_for_everyone()

                return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info
         | 
| 64 | 
            +
             | 
| 65 | 
            +
             | 
| 66 | 
            +
            def _load_target_model(
                name_or_path: str, vae_path: Optional[str], model_version: str, weight_dtype, device="cpu", model_dtype=None, disable_mmap=False
            ):
                """Load text encoders, VAE and U-Net from a checkpoint file or a Diffusers model.

                If ``name_or_path`` (after symlink resolution) is an existing file, it is loaded
                through ``sdxl_model_util.load_models_from_sdxl_checkpoint``. Otherwise it is
                handed to ``StableDiffusionXLPipeline.from_pretrained`` and the Diffusers U-Net
                is converted to the original SDXL U-Net.

                :param name_or_path: checkpoint file path, or a name/path accepted by ``from_pretrained``.
                :param vae_path: optional path of a VAE that replaces the loaded one.
                :param model_version: forwarded to the checkpoint loader.
                :param weight_dtype: selects the ``fp16`` Diffusers variant and the dtype of an additional VAE.
                :param device: device passed to the checkpoint loader / U-Net state-dict loader.
                :param model_dtype: dtype for the loaded weights; only meaningful for full fp16/bf16.
                :param disable_mmap: forwarded to the checkpoint loader.
                :return: ``(load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet,
                         logit_scale, ckpt_info)``; the last two are ``None`` for Diffusers models.
                :raises EnvironmentError: re-raised when the Diffusers model cannot be loaded.
                """
                # model_dtype only work with full fp16/bf16
                if os.path.islink(name_or_path):
                    # os.readlink() returns the raw link target; for a relative symlink that target is
                    # relative to the link's own directory, not the current working directory, so the
                    # isfile() check below would look in the wrong place. realpath() resolves it correctly.
                    name_or_path = os.path.realpath(name_or_path)
                load_stable_diffusion_format = os.path.isfile(name_or_path)  # determine SD or Diffusers

                if load_stable_diffusion_format:
                    logger.info(f"load StableDiffusion checkpoint: {name_or_path}")
                    (
                        text_encoder1,
                        text_encoder2,
                        vae,
                        unet,
                        logit_scale,
                        ckpt_info,
                    ) = sdxl_model_util.load_models_from_sdxl_checkpoint(model_version, name_or_path, device, model_dtype, disable_mmap)
                else:
                    # Diffusers model is loaded to CPU
                    from diffusers import StableDiffusionXLPipeline

                    variant = "fp16" if weight_dtype == torch.float16 else None
                    logger.info(f"load Diffusers pretrained models: {name_or_path}, variant={variant}")
                    try:
                        try:
                            pipe = StableDiffusionXLPipeline.from_pretrained(
                                name_or_path, torch_dtype=model_dtype, variant=variant, tokenizer=None
                            )
                        except EnvironmentError:
                            # the fp16 variant may not exist; retry without a variant before giving up
                            if variant is not None:
                                logger.info("try to load fp32 model")
                                pipe = StableDiffusionXLPipeline.from_pretrained(name_or_path, variant=None, tokenizer=None)
                            else:
                                raise
                    except EnvironmentError:
                        logger.error(
                            f"model is not found as a file or in Hugging Face, perhaps file name is wrong? / 指定したモデル名のファイル、またはHugging Faceのモデルが見つかりません。ファイル名が誤っているかもしれません: {name_or_path}"
                        )
                        raise

                    text_encoder1 = pipe.text_encoder
                    text_encoder2 = pipe.text_encoder_2

                    # convert to fp32 for cache text_encoders outputs
                    if text_encoder1.dtype != torch.float32:
                        text_encoder1 = text_encoder1.to(dtype=torch.float32)
                    if text_encoder2.dtype != torch.float32:
                        text_encoder2 = text_encoder2.to(dtype=torch.float32)

                    vae = pipe.vae
                    unet = pipe.unet
                    del pipe

                    # Diffusers U-Net to original U-Net
                    state_dict = sdxl_model_util.convert_diffusers_unet_state_dict_to_sdxl(unet.state_dict())
                    with init_empty_weights():
                        unet = sdxl_original_unet.SdxlUNet2DConditionModel()  # overwrite unet
                    sdxl_model_util._load_state_dict_on_device(unet, state_dict, device=device, dtype=model_dtype)
                    logger.info("U-Net converted to original U-Net")

                    logit_scale = None
                    ckpt_info = None

                # load the additional VAE, if one was given
                if vae_path is not None:
                    vae = model_util.load_vae(vae_path, weight_dtype)
                    logger.info("additional VAE loaded")

                return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info
         | 
| 135 | 
            +
             | 
| 136 | 
            +
             | 
| 137 | 
            +
            def load_tokenizers(args: argparse.Namespace):
                """Create the two CLIP tokenizers, using ``args.tokenizer_cache_dir`` when it is set.

                For each of ``TOKENIZER1_PATH`` and ``TOKENIZER2_PATH`` the tokenizer is read from
                the cache directory if a cached copy exists, otherwise from the original path,
                and is then saved into the cache directory if it was not there yet.

                :param args: parsed arguments; reads ``tokenizer_cache_dir`` and, if present,
                             ``max_token_length`` (which is only logged here).
                :return: list ``[tokenizer1, tokenizer2]``.
                """
                logger.info("prepare tokenizers")

                tokenizers = []
                for index, source_path in enumerate((TOKENIZER1_PATH, TOKENIZER2_PATH)):
                    tokenizer: CLIPTokenizer = None
                    cache_path = None

                    if args.tokenizer_cache_dir:
                        cache_path = os.path.join(args.tokenizer_cache_dir, source_path.replace("/", "_"))
                        if os.path.exists(cache_path):
                            logger.info(f"load tokenizer from cache: {cache_path}")
                            tokenizer = CLIPTokenizer.from_pretrained(cache_path)

                    if tokenizer is None:
                        tokenizer = CLIPTokenizer.from_pretrained(source_path)

                    # populate the cache on first use
                    if cache_path is not None and not os.path.exists(cache_path):
                        logger.info(f"save Tokenizer to cache: {cache_path}")
                        tokenizer.save_pretrained(cache_path)

                    if index == 1:
                        tokenizer.pad_token_id = 0  # fix pad token id to make same as open clip tokenizer

                    tokenizers.append(tokenizer)

                if hasattr(args, "max_token_length") and args.max_token_length is not None:
                    logger.info(f"update token length: {args.max_token_length}")

                return tokenizers
         | 
| 166 | 
            +
             | 
| 167 | 
            +
             | 
| 168 | 
            +
            def match_mixed_precision(args, weight_dtype):
                """Return the dtype to load the model in for full fp16/bf16 training.

                :param args: object with boolean-like ``full_fp16`` and ``full_bf16`` attributes.
                :param weight_dtype: dtype selected by the mixed-precision setting.
                :return: ``weight_dtype`` when ``full_fp16`` or ``full_bf16`` is set, else ``None``.
                :raises AssertionError: if the requested full-precision mode does not match ``weight_dtype``.
                """
                # neither full-precision mode requested: no dtype override
                if not args.full_fp16 and not args.full_bf16:
                    return None

                if args.full_fp16:
                    assert (
                        weight_dtype == torch.float16
                    ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
                else:
                    assert (
                        weight_dtype == torch.bfloat16
                    ), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。"

                return weight_dtype
         | 
| 181 | 
            +
             | 
| 182 | 
            +
             | 
| 183 | 
            +
            def timestep_embedding(timesteps, dim, max_period=10000):
                """
                Create sinusoidal timestep embeddings.

                :param timesteps: a 1-D Tensor of N indices, one per batch element.
                                  These may be fractional.
                :param dim: the dimension of the output.
                :param max_period: controls the minimum frequency of the embeddings.
                :return: an [N x dim] Tensor of positional embeddings
                         (cosine half, sine half, plus one zero column when dim is odd).
                """
                half = dim // 2
                freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
                    device=timesteps.device
                )
                angles = timesteps[:, None].float() * freqs[None]
                embedding = torch.cat([torch.cos(angles), torch.sin(angles)], dim=-1)
                if dim % 2:
                    # Pad with an explicit [N, 1] zero column. Slicing `embedding[:, :1]` would yield
                    # zero columns when half == 0 (dim == 1), leaving the output one column short.
                    padding = embedding.new_zeros((embedding.shape[0], 1))
                    embedding = torch.cat([embedding, padding], dim=-1)
                return embedding
         | 
| 201 | 
            +
             | 
| 202 | 
            +
             | 
| 203 | 
            +
            def get_timestep_embedding(x, outdim):
                """Embed every scalar of a 2-D tensor with ``timestep_embedding`` and concatenate per row.

                :param x: 2-D tensor; each element is embedded independently.
                :param outdim: embedding size passed to ``timestep_embedding`` for each element.
                :return: tensor of shape ``(x.shape[0], x.shape[1] * outdim)``.
                :raises AssertionError: if ``x`` is not 2-dimensional.
                """
                assert len(x.shape) == 2
                batch, num_values = x.shape
                # embed all elements at once, then regroup the embeddings row by row
                flat_embedding = timestep_embedding(torch.flatten(x), outdim)
                return torch.reshape(flat_embedding, (batch, num_values * outdim))
         | 
| 210 | 
            +
             | 
| 211 | 
            +
             | 
| 212 | 
            +
            def get_size_embeddings(orig_size, crop_size, target_size, device):
                """Build the concatenated size-conditioning vector.

                Each of the three inputs is embedded with ``get_timestep_embedding(..., 256)``;
                the results are concatenated along dim 1 in the order original size, crop size,
                target size, and the result is moved to ``device``.

                :return: the concatenated embedding tensor on ``device``.
                """
                parts = [get_timestep_embedding(size, 256) for size in (orig_size, crop_size, target_size)]
                return torch.cat(parts, dim=1).to(device)
         | 
| 218 | 
            +
             | 
| 219 | 
            +
             | 
| 220 | 
            +
            def save_sd_model_on_train_end(
                args: argparse.Namespace,
                src_path: str,
                save_stable_diffusion_format: bool,
                use_safetensors: bool,
                save_dtype: torch.dtype,
                epoch: int,
                global_step: int,
                text_encoder1,
                text_encoder2,
                unet,
                vae,
                logit_scale,
                ckpt_info,
            ):
                """Save the trained SDXL model at the end of training.

                Defines one saver callback per output format, each closing over the models
                given here, and hands both to ``train_util.save_sd_model_on_train_end_common``.
                """

                def write_checkpoint(ckpt_file, epoch_no, global_step):
                    # single-file StableDiffusion checkpoint, with SAI model-spec metadata
                    metadata = train_util.get_sai_model_spec(None, args, True, False, False, is_stable_diffusion_ckpt=True)
                    sdxl_model_util.save_stable_diffusion_checkpoint(
                        ckpt_file, text_encoder1, text_encoder2, unet, epoch_no, global_step, ckpt_info, vae, logit_scale, metadata, save_dtype
                    )

                def write_diffusers(out_dir):
                    # Diffusers-format output; src_path is passed through to the saver
                    sdxl_model_util.save_diffusers_checkpoint(
                        out_dir, text_encoder1, text_encoder2, unet, src_path, vae, use_safetensors=use_safetensors, save_dtype=save_dtype
                    )

                train_util.save_sd_model_on_train_end_common(
                    args, save_stable_diffusion_format, use_safetensors, epoch, global_step, write_checkpoint, write_diffusers
                )
         | 
| 266 | 
            +
             | 
| 267 | 
            +
             | 
| 268 | 
            +
            # Saving per epoch and per step is handled by one function, because the metadata
            # contains epoch/step and the arguments are the same in both cases.
            # on_epoch_end: True when called at the end of an epoch, False when called after a step interval.
            def save_sd_model_on_epoch_end_or_stepwise(
                args: argparse.Namespace,
                on_epoch_end: bool,
                accelerator,
                src_path,
                save_stable_diffusion_format: bool,
                use_safetensors: bool,
                save_dtype: torch.dtype,
                epoch: int,
                num_train_epochs: int,
                global_step: int,
                text_encoder1,
                text_encoder2,
                unet,
                vae,
                logit_scale,
                ckpt_info,
            ):
                """Save an intermediate SDXL model at an epoch end or after a step interval.

                Defines one saver callback per output format, each closing over the models
                given here, and hands both to
                ``train_util.save_sd_model_on_epoch_end_or_stepwise_common``.
                """

                def write_checkpoint(ckpt_file, epoch_no, global_step):
                    # single-file StableDiffusion checkpoint, with SAI model-spec metadata
                    metadata = train_util.get_sai_model_spec(None, args, True, False, False, is_stable_diffusion_ckpt=True)
                    sdxl_model_util.save_stable_diffusion_checkpoint(
                        ckpt_file, text_encoder1, text_encoder2, unet, epoch_no, global_step, ckpt_info, vae, logit_scale, metadata, save_dtype
                    )

                def write_diffusers(out_dir):
                    # Diffusers-format output; src_path is passed through to the saver
                    sdxl_model_util.save_diffusers_checkpoint(
                        out_dir, text_encoder1, text_encoder2, unet, src_path, vae, use_safetensors=use_safetensors, save_dtype=save_dtype
                    )

                train_util.save_sd_model_on_epoch_end_or_stepwise_common(
                    args,
                    on_epoch_end,
                    accelerator,
                    save_stable_diffusion_format,
                    use_safetensors,
                    epoch,
                    num_train_epochs,
                    global_step,
                    write_checkpoint,
                    write_diffusers,
                )
         | 
| 328 | 
            +
             | 
| 329 | 
            +
             | 
| 330 | 
            +
            def add_sdxl_training_arguments(parser: argparse.ArgumentParser):
                """Register the SDXL-specific command-line switches on ``parser``.

                All three options are ``store_true`` flags, so each defaults to ``False``.

                :param parser: the argument parser to extend in place.
                """
                switches = (
                    ("--cache_text_encoder_outputs", "cache text encoder outputs / text encoderの出力をキャッシュする"),
                    (
                        "--cache_text_encoder_outputs_to_disk",
                        "cache text encoder outputs to disk / text encoderの出力をディスクにキャッシュする",
                    ),
                    (
                        "--disable_mmap_load_safetensors",
                        "disable mmap load for safetensors. Speed up model loading in WSL environment / safetensorsのmmapロードを無効にする。WSL環境等でモデル読み込みを高速化できる",
                    ),
                )
                for flag, help_text in switches:
                    parser.add_argument(flag, action="store_true", help=help_text)
         | 
| 344 | 
            +
             | 
| 345 | 
            +
             | 
| 346 | 
            +
            def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCaching: bool = True):
                """Check parsed training options for SDXL and normalise the caching flags.

                Raises ``AssertionError`` when ``args.v2`` is set, or when ``args`` has a
                truthy ``weighted_captions`` attribute. Logs a warning when
                ``v_parameterization`` is enabled or ``clip_skip`` is not ``None``.

                When ``supportTextEncoderCaching`` is true and disk caching was requested
                without ``cache_text_encoder_outputs``, the latter is switched on in place
                on ``args`` and a warning is logged.
                """
                assert not args.v2, "v2 cannot be enabled in SDXL training / SDXL学習ではv2を有効にすることはできません"

                if args.v_parameterization:
                    logger.warning("v_parameterization will be unexpected / SDXL学習ではv_parameterizationは想定外の動作になります")

                if args.clip_skip is not None:
                    logger.warning("clip_skip will be unexpected / SDXL学習ではclip_skipは動作しません")

                # NOTE: noise_offset is not touched here; an earlier step that defaulted it
                # to DEFAULT_NOISE_OFFSET (and warned on mismatch) is disabled.

                # A missing attribute counts as "not enabled".
                captions_weighted = getattr(args, "weighted_captions", False)
                assert (
                    not captions_weighted
                ), "weighted_captions cannot be enabled in SDXL training currently / SDXL学習では今のところweighted_captionsを有効にすることはできません"

                if not supportTextEncoderCaching:
                    return

                # Caching to disk implies caching at all: turn the base flag on if needed.
                if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs:
                    args.cache_text_encoder_outputs = True
                    logger.warning(
                        "cache_text_encoder_outputs is enabled because cache_text_encoder_outputs_to_disk is enabled / "
                        "cache_text_encoder_outputs_to_diskが有効になっているためcache_text_encoder_outputsが有効になりました"
                    )
         | 
| 378 | 
            +
             | 
| 379 | 
            +
             | 
| 380 | 
            +
            def sample_images(*args, **kwargs):
                """Forward every argument to ``train_util.sample_images_common``.

                The SDXL long-prompt-weighting pipeline class is passed as the first
                positional argument; the callee's return value is returned unchanged.
                """
                pipeline_class = SdxlStableDiffusionLongPromptWeightingPipeline
                return train_util.sample_images_common(pipeline_class, *args, **kwargs)
         | 
    	
        library/slicing_vae.py
    ADDED
    
    | @@ -0,0 +1,682 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Modified from Diffusers to reduce VRAM usage
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            # Copyright 2022 The HuggingFace Team. All rights reserved.
         | 
| 4 | 
            +
            #
         | 
| 5 | 
            +
            # Licensed under the Apache License, Version 2.0 (the "License");
         | 
| 6 | 
            +
            # you may not use this file except in compliance with the License.
         | 
| 7 | 
            +
            # You may obtain a copy of the License at
         | 
| 8 | 
            +
            #
         | 
| 9 | 
            +
            #     http://www.apache.org/licenses/LICENSE-2.0
         | 
| 10 | 
            +
            #
         | 
| 11 | 
            +
            # Unless required by applicable law or agreed to in writing, software
         | 
| 12 | 
            +
            # distributed under the License is distributed on an "AS IS" BASIS,
         | 
| 13 | 
            +
            # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
         | 
| 14 | 
            +
            # See the License for the specific language governing permissions and
         | 
| 15 | 
            +
            # limitations under the License.
         | 
| 16 | 
            +
            from dataclasses import dataclass
         | 
| 17 | 
            +
            from typing import Optional, Tuple, Union
         | 
| 18 | 
            +
             | 
| 19 | 
            +
            import numpy as np
         | 
| 20 | 
            +
            import torch
         | 
| 21 | 
            +
            import torch.nn as nn
         | 
| 22 | 
            +
             | 
| 23 | 
            +
             | 
| 24 | 
            +
            from diffusers.configuration_utils import ConfigMixin, register_to_config
         | 
| 25 | 
            +
            from diffusers.models.modeling_utils import ModelMixin
         | 
| 26 | 
            +
            from diffusers.models.unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block
         | 
| 27 | 
            +
            from diffusers.models.vae import DecoderOutput, DiagonalGaussianDistribution
         | 
| 28 | 
            +
            from diffusers.models.autoencoder_kl import AutoencoderKLOutput
         | 
| 29 | 
            +
            from .utils import setup_logging
         | 
| 30 | 
            +
            setup_logging()
         | 
| 31 | 
            +
            import logging
         | 
| 32 | 
            +
            logger = logging.getLogger(__name__)
         | 
| 33 | 
            +
             | 
| 34 | 
            +
            def slice_h(x, num_slices):
                """Split ``x`` along dim 2 into overlapping slices.

                Neighbouring slices overlap by one row on each shared edge, so that a
                padded conv2d applied per slice is not affected by its own zero padding at
                the interior cuts. The first slice carries one extra row at its end; each
                later slice starts one row early and, unless it is the last one, ends one
                row late.

                If fewer than 3 rows would remain after a slice, that slice is extended to
                the end of the tensor and becomes the last one, so fewer than
                ``num_slices`` slices may be returned. (A remainder that thin could not be
                convolved on its own.)

                Original author's note: works for both NCHW and NHWC inputs; the split is
                always along dim 2.
                """
                height = x.shape[2]
                step = (height + num_slices - 1) // num_slices  # ceil(height / num_slices)

                pieces = []
                for index in range(num_slices):
                    if index == 0:
                        # First slice: no leading overlap row, no "too small" check.
                        pieces.append(x[:, :, : step + 1, :])
                        continue

                    stop = step * (index + 1) + 1
                    # Absorb a too-thin remainder into this slice.
                    stop = height if height - stop < 3 else stop
                    pieces.append(x[:, :, step * index - 1 : stop, :])
                    if stop >= height:
                        break
                return pieces
         | 
| 51 | 
            +
             | 
| 52 | 
            +
             | 
| 53 | 
            +
            def cat_h(sliced):
                """Concatenate slices along dim 2 after removing their overlap rows.

                This is the inverse of slicing with a one-row overlap at every interior
                cut: the first slice loses its last row, the last slice loses its first
                row, and every slice in between loses both.

                Args:
                    sliced: list of tensors to join along dim 2.

                Returns:
                    The concatenated tensor. When ``sliced`` holds exactly one tensor it
                    has no interior cut and therefore no overlap row, so it is returned
                    unchanged. (Previously its last row was dropped as if it were the
                    first of several slices.)
                """
                if len(sliced) == 1:
                    return sliced[0]

                last = len(sliced) - 1
                trimmed = []
                for index, piece in enumerate(sliced):
                    if index == 0:
                        trimmed.append(piece[:, :, :-1, :])
                    elif index == last:
                        trimmed.append(piece[:, :, 1:, :])
                    else:
                        trimmed.append(piece[:, :, 1:-1, :])
                return torch.cat(trimmed, dim=2)
         | 
| 66 | 
            +
             | 
| 67 | 
            +
             | 
| 68 | 
            +
            def resblock_forward(_self, num_slices, input_tensor, temb, **kwargs):
                """Low-VRAM replacement for a ResNet block's forward pass.

                ``_self`` is the block being patched. Both group norms are run unsliced on
                the CPU; the activations and convolutions are run slice by slice (split
                along dim 2) on the input's original device, with every intermediate
                result moved back to the CPU. The result is returned on the original
                device.

                Only blocks without up/down-sampling and without a time embedding are
                supported (asserted). Extra keyword arguments are accepted and ignored.

                Side effect: ``_self.norm1`` and ``_self.norm2`` are moved to the CPU, and
                converted to float32 when the input is float16.
                """
                assert _self.upsample is None and _self.downsample is None
                assert _self.norm1.num_groups == _self.norm2.num_groups
                assert temb is None

                compute_device = input_tensor.device
                cpu = torch.device("cpu")
                # Workaround: GroupNorm does not run in fp16 on the CPU.
                upcast = input_tensor.dtype == torch.float16

                # Make sure the norms live on the CPU (in fp32 for fp16 inputs).
                for norm in (_self.norm1, _self.norm2):
                    norm.to(cpu)
                    if upcast:
                        norm.to(torch.float32)

                def run_sliced(chunks, *layers):
                    # Apply ``layers`` in order to each chunk on the compute device and
                    # store the result back on the CPU, one chunk in flight at a time.
                    for idx in range(len(chunks)):
                        piece = chunks[idx]
                        chunks[idx] = None
                        piece = piece.to(compute_device)
                        for layer in layers:
                            piece = layer(piece)
                        piece = piece.to(cpu)
                        chunks[idx] = piece
                        del piece

                # Move everything to the CPU.
                input_tensor = input_tensor.to(cpu)
                hidden = input_tensor

                # Slicing the group norm was tried and appears to give different
                # results, so the norm is the one step that is not sliced. Running it on
                # the GPU would exhaust VRAM, so it runs on the CPU, which is fortunately
                # not that slow.
                if upcast:
                    hidden = hidden.to(torch.float32)
                hidden = _self.norm1(hidden)  # run on cpu
                if upcast:
                    hidden = hidden.to(torch.float16)

                chunks = slice_h(hidden, num_slices)
                del hidden
                run_sliced(chunks, _self.nonlinearity, _self.conv1)
                hidden = cat_h(chunks)
                del chunks

                if upcast:
                    hidden = hidden.to(torch.float32)
                hidden = _self.norm2(hidden)  # run on cpu
                if upcast:
                    hidden = hidden.to(torch.float16)

                chunks = slice_h(hidden, num_slices)
                del hidden
                run_sliced(chunks, _self.nonlinearity, _self.dropout, _self.conv2)
                hidden = cat_h(chunks)
                del chunks

                # Shortcut branch.
                if _self.conv_shortcut is not None:
                    # conv_shortcut has no padding, so a plain chunk is enough.
                    chunks = list(torch.chunk(input_tensor, num_slices, dim=2))
                    del input_tensor
                    run_sliced(chunks, _self.conv_shortcut)
                    input_tensor = torch.cat(chunks, dim=2)
                    del chunks

                result = (input_tensor + hidden) / _self.output_scale_factor

                # The next layer computes on the original device.
                result = result.to(compute_device)
                return result
         | 
| 175 | 
            +
             | 
| 176 | 
            +
             | 
| 177 | 
            +
            class SlicingEncoder(nn.Module):
         | 
| 178 | 
            +
                def __init__(
                    self,
                    in_channels=3,
                    out_channels=3,
                    down_block_types=("DownEncoderBlock2D",),
                    block_out_channels=(64,),
                    layers_per_block=2,
                    norm_num_groups=32,
                    act_fn="silu",
                    double_z=True,
                    num_slices=2,
                ):
                    """Build the encoder and patch its blocks to run in slices.

                    The layers are: ``conv_in``, one down block per entry of
                    ``down_block_types``, a ``UNetMidBlock2D``, then
                    GroupNorm / SiLU / ``conv_out``. ``conv_out`` emits
                    ``2 * out_channels`` channels when ``double_z`` is true.

                    After construction, the ``forward`` of the ResNet blocks (and of the
                    down blocks' downsamplers) is replaced by a sliced implementation
                    wherever the slice count for that depth is at least 2. ``num_slices``
                    is the count used at the input resolution.
                    """
                    super().__init__()
                    self.layers_per_block = layers_per_block

                    self.conv_in = torch.nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1)

                    self.mid_block = None
                    self.down_blocks = nn.ModuleList([])

                    # down
                    last_index = len(block_out_channels) - 1
                    previous_width = block_out_channels[0]
                    for idx, block_type in enumerate(down_block_types):
                        width = block_out_channels[idx]
                        self.down_blocks.append(
                            get_down_block(
                                block_type,
                                num_layers=self.layers_per_block,
                                in_channels=previous_width,
                                out_channels=width,
                                add_downsample=idx != last_index,  # no downsample on the final block
                                resnet_eps=1e-6,
                                downsample_padding=0,
                                resnet_act_fn=act_fn,
                                resnet_groups=norm_num_groups,
                                attention_head_dim=width,
                                temb_channels=None,
                            )
                        )
                        previous_width = width

                    # mid
                    self.mid_block = UNetMidBlock2D(
                        in_channels=block_out_channels[-1],
                        resnet_eps=1e-6,
                        resnet_act_fn=act_fn,
                        output_scale_factor=1,
                        resnet_time_scale_shift="default",
                        attention_head_dim=block_out_channels[-1],
                        resnet_groups=norm_num_groups,
                        temb_channels=None,
                    )
                    # For now, use Diffusers' xformers attention.
                    self.mid_block.attentions[0].set_use_memory_efficient_attention_xformers(True)

                    # out
                    self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
                    self.conv_act = nn.SiLU()

                    if double_z:
                        conv_out_channels = 2 * out_channels
                    else:
                        conv_out_channels = out_channels
                    self.conv_out = nn.Conv2d(block_out_channels[-1], conv_out_channels, 3, padding=1)

                    # Replace the forward of the ResBlocks (and downsamplers).
                    def wrapper(func, module, num_slices):
                        # Bind ``module`` and its slice count as the leading arguments of ``func``.
                        def forward(*args, **kwargs):
                            return func(module, num_slices, *args, **kwargs)

                        return forward

                    def patch_resnets(resnets, slices):
                        for resnet in resnets:
                            resnet.forward = wrapper(resblock_forward, resnet, slices)

                    self.num_slices = num_slices
                    # Deeper layers need less splitting, so start from a reduced count at
                    # the deepest level and double it for each shallower block.
                    div = num_slices / (2 ** (len(self.down_blocks) - 1))
                    if div >= 2:
                        div = int(div)
                        # The mid block has no downsampler.
                        patch_resnets(self.mid_block.resnets, div)

                    # Walk from the deepest down block to the shallowest.
                    for down_block in self.down_blocks[::-1]:
                        if div >= 2:
                            div = int(div)
                            patch_resnets(down_block.resnets, div)
                            if down_block.downsamplers is not None:
                                for downsample in down_block.downsamplers:
                                    downsample.forward = wrapper(self.downsample_forward, downsample, div * 2)
                        div *= 2
         | 
| 267 | 
            +
             | 
| 268 | 
            +
                def forward(self, x):
                    """Encode ``x``, running ``conv_in`` slice by slice.

                    The input is moved to the CPU and split with ``slice_h``. Each slice
                    is convolved on the input's original device and moved back to the
                    CPU. The slices are then re-joined with ``cat_h`` and the result is
                    returned to the original device for the down blocks, the mid block
                    and the output layers.
                    """
                    compute_device = x.device
                    cpu = torch.device("cpu")

                    sample = x.to(cpu)
                    del x
                    chunks = slice_h(sample, self.num_slices)
                    del sample

                    # Sliced equivalent of ``sample = self.conv_in(sample)``.
                    for idx in range(len(chunks)):
                        piece = chunks[idx]
                        chunks[idx] = None

                        piece = piece.to(compute_device)
                        piece = self.conv_in(piece)
                        piece = piece.to(cpu)
                        chunks[idx] = piece
                        del piece

                    sample = cat_h(chunks)
                    del chunks

                    sample = sample.to(compute_device)

                    # down
                    for block in self.down_blocks:
                        sample = block(sample)

                    # middle
                    sample = self.mid_block(sample)

                    # post-process
                    # Memory could be saved here as well, but this part probably does
                    # not use much, so it is left unsliced.
                    sample = self.conv_norm_out(sample)
                    sample = self.conv_act(sample)
                    sample = self.conv_out(sample)

                    return sample
         | 
| 309 | 
            +
             | 
| 310 | 
            +
                def downsample_forward(self, _self, num_slices, hidden_states):
                    """Apply a stride-2 downsampling convolution slice by slice along the height axis.

                    The input is moved to the CPU, padded by one row/column on the bottom/right,
                    cut into horizontal bands, and each band is convolved on the original device
                    before the results are concatenated along the height axis.

                    Args:
                        _self: the wrapped downsample module; must expose ``channels``,
                            ``use_conv``, ``padding`` (required to be 0) and ``conv``.
                        num_slices: upper bound on the number of bands; fewer bands are used
                            when the tail would be too short to convolve on its own.
                        hidden_states: 4-D tensor whose dim 1 equals ``_self.channels`` and
                            whose dim 2 is the axis that gets sliced.

                    Returns:
                        The convolved tensor, on the device ``hidden_states`` arrived on.
                    """
                    assert hidden_states.shape[1] == _self.channels
                    assert _self.use_conv and _self.padding == 0
                    logger.info(f"downsample forward {num_slices} {hidden_states.shape}")

                    org_device = hidden_states.device
                    cpu_device = torch.device("cpu")

                    hidden_states = hidden_states.to(cpu_device)
                    pad = (0, 1, 0, 1)
                    hidden_states = torch.nn.functional.pad(hidden_states, pad, mode="constant", value=0)
                    height = hidden_states.shape[2]

                    # The conv has stride 2, so the band size must be even: every band then
                    # starts on an even row and its windows line up with the un-sliced conv.
                    size = (height + num_slices - 1) // num_slices
                    size = size + 1 if size % 2 == 1 else size

                    sliced = []
                    for i in range(num_slices):
                        # Bands start exactly at size * i (an even row). Starting one row
                        # earlier would shift every window of the band by one row.
                        start = size * i
                        # One extra trailing row lets the last window of the band finish.
                        # NOTE(review): this overlap assumes a 3x3 kernel with stride 2 -- confirm.
                        end = start + size + 1
                        if height - end < 4:  # if the last slice is too small, use the rest of the tensor
                            end = height
                        sliced.append(hidden_states[:, :, start:end, :])
                        if end >= height:
                            break
                    del hidden_states

                    outputs = []
                    for i in range(len(sliced)):
                        x = sliced[i]
                        sliced[i] = None  # drop the list's reference to the input band

                        x = x.to(org_device)
                        x = _self.conv(x)
                        outputs.append(x.to(cpu_device))
                        del x
                    del sliced

                    # No trimming is needed: each band only produced windows it fully owns.
                    hidden_states = torch.cat(outputs, dim=2)
                    del outputs

                    hidden_states = hidden_states.to(org_device)
                    # logger.info(f"downsample forward done {hidden_states.shape}")
                    return hidden_states
         | 
| 358 | 
            +
             | 
| 359 | 
            +
             | 
| 360 | 
            +
            class SlicingDecoder(nn.Module):
                """VAE decoder that runs its heaviest layers slice by slice along the height axis.

                The constructor builds the usual conv_in -> mid block -> up blocks -> norm/act/conv_out
                stack and then replaces the ``forward`` of resnets and upsamplers with sliced
                versions (``resblock_forward`` and ``upsample_forward``). ``slice_h``, ``cat_h`` and
                ``resblock_forward`` are helpers defined elsewhere in this file.

                Args:
                    in_channels: channels of the latent fed to ``conv_in``.
                    out_channels: channels produced by ``conv_out``.
                    up_block_types: block type names passed to ``get_up_block``, one per up block.
                    block_out_channels: channel widths; consumed in reverse order by the up blocks.
                    layers_per_block: each up block gets ``layers_per_block + 1`` layers.
                    norm_num_groups: group count for the group norms.
                    act_fn: activation name forwarded to the blocks.
                    num_slices: number of height slices used for ``conv_out``; earlier
                        (lower-resolution) stages use proportionally fewer slices.
                """

                def __init__(
                    self,
                    in_channels=3,
                    out_channels=3,
                    up_block_types=("UpDecoderBlock2D",),
                    block_out_channels=(64,),
                    layers_per_block=2,
                    norm_num_groups=32,
                    act_fn="silu",
                    num_slices=2,
                ):
                    super().__init__()
                    self.layers_per_block = layers_per_block

                    self.conv_in = nn.Conv2d(in_channels, block_out_channels[-1], kernel_size=3, stride=1, padding=1)

                    self.mid_block = None
                    self.up_blocks = nn.ModuleList([])

                    # mid
                    self.mid_block = UNetMidBlock2D(
                        in_channels=block_out_channels[-1],
                        resnet_eps=1e-6,
                        resnet_act_fn=act_fn,
                        output_scale_factor=1,
                        resnet_time_scale_shift="default",
                        attention_head_dim=block_out_channels[-1],
                        resnet_groups=norm_num_groups,
                        temb_channels=None,
                    )
                    # Use Diffusers' xformers attention for now.
                    self.mid_block.attentions[0].set_use_memory_efficient_attention_xformers(True)

                    # up
                    reversed_block_out_channels = list(reversed(block_out_channels))
                    output_channel = reversed_block_out_channels[0]
                    for i, up_block_type in enumerate(up_block_types):
                        # The previous block's output width is this block's input width.
                        prev_output_channel = output_channel
                        output_channel = reversed_block_out_channels[i]

                        is_final_block = i == len(block_out_channels) - 1

                        up_block = get_up_block(
                            up_block_type,
                            num_layers=self.layers_per_block + 1,
                            in_channels=prev_output_channel,
                            out_channels=output_channel,
                            prev_output_channel=None,
                            add_upsample=not is_final_block,
                            resnet_eps=1e-6,
                            resnet_act_fn=act_fn,
                            resnet_groups=norm_num_groups,
                            attention_head_dim=output_channel,
                            temb_channels=None,
                        )
                        self.up_blocks.append(up_block)

                    # out
                    self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
                    self.conv_act = nn.SiLU()
                    self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)

                    # replace forward of ResBlocks
                    def wrapper(func, module, num_slices):
                        # Bind the target module and its slice count; the returned callable
                        # takes only the arguments the original forward would receive.
                        def forward(*args, **kwargs):
                            return func(module, num_slices, *args, **kwargs)

                        return forward

                    self.num_slices = num_slices
                    # Start from the divisor for the lowest resolution; it doubles after
                    # every up block, reaching num_slices at the last one.
                    div = num_slices / (2 ** (len(self.up_blocks) - 1))
                    logger.info(f"initial divisor: {div}")
                    if div >= 2:
                        div = int(div)
                        for resnet in self.mid_block.resnets:
                            resnet.forward = wrapper(resblock_forward, resnet, div)
                        # midblock doesn't have upsample

                    for i, up_block in enumerate(self.up_blocks):
                        # Stages whose divisor is below 2 keep their original forward.
                        if div >= 2:
                            div = int(div)
                            # logger.info(f"up block: {i} divisor: {div}")
                            for resnet in up_block.resnets:
                                resnet.forward = wrapper(resblock_forward, resnet, div)
                            if up_block.upsamplers is not None:
                                # logger.info("has upsample")
                                for upsample in up_block.upsamplers:
                                    # The upsampler gets twice the slice count of its block's resnets.
                                    upsample.forward = wrapper(self.upsample_forward, upsample, div * 2)
                        div *= 2

                def forward(self, z):
                    """Decode latent ``z`` and return the result on the device ``z`` arrived on."""
                    sample = z
                    del z
                    sample = self.conv_in(sample)

                    # middle
                    sample = self.mid_block(sample)

                    # up
                    for up_block in self.up_blocks:
                        sample = up_block(sample)

                    # post-process
                    sample = self.conv_norm_out(sample)
                    sample = self.conv_act(sample)

                    # conv_out uses a lot of VRAM, so run it slice by slice:
                    # park the tensor on the CPU and move one slice at a time to the device.
                    org_device = sample.device
                    cpu_device = torch.device("cpu")
                    sample = sample.to(cpu_device)

                    sliced = slice_h(sample, self.num_slices)
                    del sample
                    for i in range(len(sliced)):
                        x = sliced[i]
                        sliced[i] = None  # drop the list's reference to the input slice

                        x = x.to(org_device)
                        x = self.conv_out(x)
                        x = x.to(cpu_device)
                        sliced[i] = x
                    sample = cat_h(sliced)
                    del sliced

                    sample = sample.to(org_device)
                    return sample

                def upsample_forward(self, _self, num_slices, hidden_states, output_size=None):
                    """Sliced replacement for an upsampler's forward: nearest 2x upsample, then conv.

                    Args:
                        _self: the wrapped upsample module; must expose ``channels``,
                            ``use_conv`` (truthy), ``use_conv_transpose`` (False) and ``conv``.
                        num_slices: slice count handed to ``slice_h``.
                        hidden_states: 4-D tensor whose dim 1 equals ``_self.channels``.
                        output_size: accepted for signature compatibility; not used here.

                    Returns:
                        The upsampled and convolved tensor on the original device.
                    """
                    assert hidden_states.shape[1] == _self.channels
                    assert _self.use_conv_transpose == False and _self.use_conv

                    org_dtype = hidden_states.dtype
                    org_device = hidden_states.device
                    cpu_device = torch.device("cpu")

                    hidden_states = hidden_states.to(cpu_device)
                    sliced = slice_h(hidden_states, num_slices)
                    del hidden_states

                    # Use the number of slices actually produced, not the requested count:
                    # if slice_h returns fewer, the real last slice must keep its bottom rows.
                    last_index = len(sliced) - 1

                    for i in range(len(sliced)):
                        x = sliced[i]
                        sliced[i] = None  # drop the list's reference to the input slice

                        x = x.to(org_device)

                        # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16
                        # TODO(Suraj): Remove this cast once the issue is fixed in PyTorch
                        # https://github.com/pytorch/pytorch/issues/86679
                        # (Wondering whether PyTorch 2 fixes this...)
                        if org_dtype == torch.bfloat16:
                            x = x.to(torch.float32)

                        x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")

                        if org_dtype == torch.bfloat16:
                            x = x.to(org_dtype)

                        x = _self.conv(x)

                        # The slice was upsampled 2x, so each overlap row is now 2 rows to trim.
                        # presumably slice_h adds one overlap row per interior edge -- verify against slice_h
                        if i == 0:
                            x = x[:, :, :-2, :]
                        elif i == last_index:
                            x = x[:, :, 2:, :]
                        else:
                            x = x[:, :, 2:-2, :]

                        x = x.to(cpu_device)
                        sliced[i] = x
                        del x

                    hidden_states = torch.cat(sliced, dim=2)
                    # logger.info(f"us hidden_states {hidden_states.shape}")
                    del sliced

                    hidden_states = hidden_states.to(org_device)
                    return hidden_states
         | 
| 539 | 
            +
             | 
| 540 | 
            +
             | 
| 541 | 
            +
            class SlicingAutoencoderKL(ModelMixin, ConfigMixin):
                r"""KL-regularised variational autoencoder (Kingma & Welling, "Auto-Encoding Variational Bayes")
                built from `SlicingEncoder` and `SlicingDecoder`.

                Inherits from [`ModelMixin`]; see the superclass documentation for the generic
                methods the library provides for all models (downloading, saving, etc.).

                Parameters:
                    in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
                    out_channels (int, *optional*, defaults to 3): Number of channels in the output.
                    down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
                        Tuple of downsample block types.
                    up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
                        Tuple of upsample block types.
                    block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
                        Tuple of block output channels.
                    layers_per_block (`int`, *optional*, defaults to `1`): Forwarded to encoder and decoder.
                    act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
                    latent_channels (`int`, *optional*, defaults to `4`): Number of channels in the latent space.
                    norm_num_groups (`int`, *optional*, defaults to `32`): Forwarded to encoder and decoder.
                    sample_size (`int`, *optional*, defaults to `32`): TODO
                    num_slices (`int`, *optional*, defaults to `16`): Slice count forwarded to the
                        encoder and the decoder.
                """

                @register_to_config
                def __init__(
                    self,
                    in_channels: int = 3,
                    out_channels: int = 3,
                    down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
                    up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
                    block_out_channels: Tuple[int] = (64,),
                    layers_per_block: int = 1,
                    act_fn: str = "silu",
                    latent_channels: int = 4,
                    norm_num_groups: int = 32,
                    sample_size: int = 32,
                    num_slices: int = 16,
                ):
                    super().__init__()

                    # Settings that the encoder and the decoder receive unchanged.
                    shared_kwargs = dict(
                        block_out_channels=block_out_channels,
                        layers_per_block=layers_per_block,
                        act_fn=act_fn,
                        norm_num_groups=norm_num_groups,
                        num_slices=num_slices,
                    )

                    # Encoder: image -> 2 * latent_channels (double_z=True).
                    self.encoder = SlicingEncoder(
                        in_channels=in_channels,
                        out_channels=latent_channels,
                        down_block_types=down_block_types,
                        double_z=True,
                        **shared_kwargs,
                    )

                    # Decoder: latent -> image.
                    self.decoder = SlicingDecoder(
                        in_channels=latent_channels,
                        out_channels=out_channels,
                        up_block_types=up_block_types,
                        **shared_kwargs,
                    )

                    moment_channels = 2 * latent_channels
                    self.quant_conv = torch.nn.Conv2d(moment_channels, moment_channels, 1)
                    self.post_quant_conv = torch.nn.Conv2d(latent_channels, latent_channels, 1)
                    self.use_slicing = False

                def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
                    """Encode ``x`` into a `DiagonalGaussianDistribution` over latents.

                    Returns an `AutoencoderKLOutput` wrapping the distribution, or a 1-tuple
                    holding it when ``return_dict`` is false.
                    """
                    moments = self.quant_conv(self.encoder(x))
                    latent_dist = DiagonalGaussianDistribution(moments)
                    if return_dict:
                        return AutoencoderKLOutput(latent_dist=latent_dist)
                    return (latent_dist,)

                def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
                    """Decode the whole batch ``z`` in one pass (no batch slicing)."""
                    image = self.decoder(self.post_quant_conv(z))
                    if return_dict:
                        return DecoderOutput(sample=image)
                    return (image,)

                # Note: this toggles slicing along the *batch* dimension, which is easy to
                # confuse with the height slicing controlled by num_slices.
                def enable_slicing(self):
                    r"""
                    Turn on sliced VAE decoding.

                    While enabled, `decode` splits the input batch into single-sample slices and
                    decodes them one after another, which saves memory and allows larger batches.
                    """
                    self.use_slicing = True

                def disable_slicing(self):
                    r"""
                    Turn off sliced VAE decoding, so `decode` processes the whole batch in one step
                    again (undoes a previous `enable_slicing`).
                    """
                    self.use_slicing = False

                def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
                    """Decode ``z``, one sample at a time when batch slicing is enabled and the batch has more than one item."""
                    slice_batch = self.use_slicing and z.shape[0] > 1
                    if slice_batch:
                        per_sample = [self._decode(single).sample for single in z.split(1)]
                        decoded = torch.cat(per_sample)
                    else:
                        decoded = self._decode(z).sample

                    if return_dict:
                        return DecoderOutput(sample=decoded)
                    return (decoded,)

                def forward(
                    self,
                    sample: torch.FloatTensor,
                    sample_posterior: bool = False,
                    return_dict: bool = True,
                    generator: Optional[torch.Generator] = None,
                ) -> Union[DecoderOutput, torch.FloatTensor]:
                    r"""
                    Encode ``sample``, pick a latent from the posterior, and decode it.

                    Args:
                        sample (`torch.FloatTensor`): Input sample.
                        sample_posterior (`bool`, *optional*, defaults to `False`):
                            If true, draw the latent from the posterior (using `generator`);
                            otherwise use the posterior's mode.
                        return_dict (`bool`, *optional*, defaults to `True`):
                            Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
                        generator (`torch.Generator`, *optional*):
                            Passed to the posterior's `sample` when `sample_posterior` is true.
                    """
                    posterior = self.encode(sample).latent_dist
                    z = posterior.sample(generator=generator) if sample_posterior else posterior.mode()
                    dec = self.decode(z).sample

                    if return_dict:
                        return DecoderOutput(sample=dec)
                    return (dec,)
         | 
    	
        library/train_util.py
    ADDED
    
    | The diff for this file is too large to render. 
		See raw diff | 
|  | 
    	
        library/utils.py
    ADDED
    
    | @@ -0,0 +1,287 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import logging
         | 
| 2 | 
            +
            import sys
         | 
| 3 | 
            +
            import threading
         | 
| 4 | 
            +
            import torch
         | 
| 5 | 
            +
            from torchvision import transforms
         | 
| 6 | 
            +
            from typing import *
         | 
| 7 | 
            +
            from diffusers import EulerAncestralDiscreteScheduler
         | 
| 8 | 
            +
            import diffusers.schedulers.scheduling_euler_ancestral_discrete
         | 
| 9 | 
            +
            from diffusers.schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteSchedulerOutput
         | 
| 10 | 
            +
            import cv2
         | 
| 11 | 
            +
            from PIL import Image
         | 
| 12 | 
            +
            import numpy as np
         | 
| 13 | 
            +
             | 
| 14 | 
            +
             | 
| 15 | 
            +
            def fire_in_thread(f, *args, **kwargs):
                """Run ``f(*args, **kwargs)`` on a freshly started thread.

                The thread is started immediately and is not returned, so the caller
                can neither join it nor observe the return value of ``f``.
                """
                worker = threading.Thread(target=f, args=args, kwargs=kwargs)
                worker.start()
         | 
| 17 | 
            +
             | 
| 18 | 
            +
             | 
| 19 | 
            +
            def add_logging_arguments(parser):
                """Register the console logging options on an argparse-style parser.

                Adds ``--console_log_level`` (one of the five standard level names,
                default ``None``), ``--console_log_file`` (default ``None``) and the
                ``--console_log_simple`` flag. The values are consumed by
                ``setup_logging``.
                """
                option_specs = (
                    (
                        "--console_log_level",
                        dict(
                            type=str,
                            default=None,
                            choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
                            help="Set the logging level, default is INFO / ログレベルを設定する。デフォルトはINFO",
                        ),
                    ),
                    (
                        "--console_log_file",
                        dict(
                            type=str,
                            default=None,
                            help="Log to a file instead of stderr / 標準エラー出力ではなくファイルにログを出力する",
                        ),
                    ),
                    (
                        "--console_log_simple",
                        dict(action="store_true", help="Simple log output / シンプルなログ出力"),
                    ),
                )
                # Register in declaration order so the generated --help output keeps its layout.
                for flag, options in option_specs:
                    parser.add_argument(flag, **options)
         | 
| 34 | 
            +
             | 
| 35 | 
            +
             | 
| 36 | 
            +
            def setup_logging(args=None, log_level=None, reset=False):
                """Configure the root logger with a single file or console handler.

                If the root logger already has handlers this is a no-op, unless ``reset``
                is true, in which case the existing handlers are removed and closed
                before the new one is installed.

                Args:
                    args: Optional namespace (typically from ``add_logging_arguments``).
                        The attributes ``console_log_level``, ``console_log_file`` and
                        ``console_log_simple`` are read; a missing attribute is treated
                        as unset.
                    log_level: Level name (case-insensitive, e.g. ``"DEBUG"``) or a
                        numeric ``logging`` level. Takes priority over
                        ``args.console_log_level``; falls back to INFO when neither is
                        set.
                    reset: Replace handlers that are already installed on the root
                        logger instead of returning early.

                Raises:
                    AttributeError: If a level name does not exist in ``logging``.
                """
                root = logging.root
                if root.handlers:
                    if not reset:
                        return
                    for old_handler in root.handlers[:]:
                        root.removeHandler(old_handler)
                        # Close as well as detach, otherwise a replaced FileHandler
                        # keeps its file descriptor open.
                        old_handler.close()

                # The caller-supplied level wins over the command-line value; default is INFO.
                if log_level is None and args is not None:
                    log_level = getattr(args, "console_log_level", None)
                if log_level is None:
                    log_level = "INFO"
                if isinstance(log_level, str):
                    log_level = getattr(logging, log_level.upper())

                log_file = getattr(args, "console_log_file", None)
                use_simple = bool(getattr(args, "console_log_simple", False))

                msg_init = None
                if log_file:
                    handler = logging.FileHandler(log_file, mode="w")
                else:
                    handler = None
                    if not use_simple:
                        try:
                            from rich.console import Console
                            from rich.logging import RichHandler

                            handler = RichHandler(console=Console(stderr=True))
                        except ImportError:
                            # Reported through the logger once it is configured below.
                            msg_init = "rich is not installed, using basic logging"

                    if handler is None:
                        handler = logging.StreamHandler(sys.stdout)  # same as print

                formatter = logging.Formatter(
                    fmt="%(message)s",
                    datefmt="%Y-%m-%d %H:%M:%S",
                )
                handler.setFormatter(formatter)
                root.setLevel(log_level)
                root.addHandler(handler)

                if msg_init is not None:
                    logger = logging.getLogger(__name__)
                    logger.info(msg_init)
         | 
| 83 | 
            +
             | 
| 84 | 
            +
             | 
| 85 | 
            +
            def pil_resize(image, size, interpolation=Image.LANCZOS):
                """Resize an OpenCV-ordered image array with Pillow's resampling filters.

                Args:
                    image: Array exposing ``.shape``. Three-channel input is treated as
                        BGR, four-channel input as BGRA, and 2-D input as a single
                        channel that needs no colour conversion.
                    size: Target size, passed unchanged to ``PIL.Image.Image.resize``
                        (which expects ``(width, height)``).
                    interpolation: Pillow resampling filter; LANCZOS by default.

                Returns:
                    The resized image as a numpy array in the same channel order as the
                    input.
                """
                if len(image.shape) == 2:
                    # cv2's BGR<->RGB conversions reject single-channel input, so a
                    # grayscale image goes straight through Pillow.
                    return np.array(Image.fromarray(image).resize(size, interpolation))

                has_alpha = len(image.shape) == 3 and image.shape[2] == 4
                if has_alpha:
                    to_pil_order, to_cv2_order = cv2.COLOR_BGRA2RGBA, cv2.COLOR_RGBA2BGRA
                else:
                    to_pil_order, to_cv2_order = cv2.COLOR_BGR2RGB, cv2.COLOR_RGB2BGR

                pil_image = Image.fromarray(cv2.cvtColor(image, to_pil_order))
                resized_pil = pil_image.resize(size, interpolation)

                # Convert back to cv2 channel order.
                return cv2.cvtColor(np.array(resized_pil), to_cv2_order)
         | 
| 102 | 
            +
             | 
| 103 | 
            +
             | 
| 104 | 
            +
            # TODO make inf_utils.py
         | 
| 105 | 
            +
             | 
| 106 | 
            +
             | 
| 107 | 
            +
            # region Gradual Latent hires fix
         | 
| 108 | 
            +
             | 
| 109 | 
            +
             | 
| 110 | 
            +
            class GradualLatent:
                """Settings and tensor helpers for the "Gradual Latent" hires fix.

                The constructor arguments are stored unchanged as attributes of the same
                name. Within this class only the ``gaussian_blur_*`` values are used (by
                the unsharp mask); ``s_noise`` and ``unsharp_target_x`` are read by
                ``EulerAncestralDiscreteSchedulerGL.step``.
                """

                def __init__(
                    self,
                    ratio,
                    start_timesteps,
                    every_n_steps,
                    ratio_step,
                    s_noise=1.0,
                    gaussian_blur_ksize=None,
                    gaussian_blur_sigma=0.5,
                    gaussian_blur_strength=0.5,
                    unsharp_target_x=True,
                ):
                    # Schedule-related settings.
                    self.ratio = ratio
                    self.start_timesteps = start_timesteps
                    self.every_n_steps = every_n_steps
                    self.ratio_step = ratio_step
                    self.s_noise = s_noise
                    # Unsharp-mask settings; a ksize of None disables the mask.
                    self.gaussian_blur_ksize = gaussian_blur_ksize
                    self.gaussian_blur_sigma = gaussian_blur_sigma
                    self.gaussian_blur_strength = gaussian_blur_strength
                    self.unsharp_target_x = unsharp_target_x

                def __str__(self) -> str:
                    """Return ``GradualLatent(name=value, ...)`` listing every setting."""
                    field_names = (
                        "ratio",
                        "start_timesteps",
                        "every_n_steps",
                        "ratio_step",
                        "s_noise",
                        "gaussian_blur_ksize",
                        "gaussian_blur_sigma",
                        "gaussian_blur_strength",
                        "unsharp_target_x",
                    )
                    rendered = ", ".join(f"{name}={getattr(self, name)}" for name in field_names)
                    return f"GradualLatent({rendered})"

                def apply_unshark_mask(self, x: torch.Tensor):
                    """Sharpen ``x`` with an unsharp mask, or return it untouched.

                    When ``gaussian_blur_ksize`` is None the input is returned as is.
                    Otherwise the difference between ``x`` and its Gaussian-blurred copy,
                    scaled by ``gaussian_blur_strength``, is added back onto ``x``.
                    """
                    if self.gaussian_blur_ksize is None:
                        return x
                    blurred = transforms.functional.gaussian_blur(x, self.gaussian_blur_ksize, self.gaussian_blur_sigma)
                    return x + (x - blurred) * self.gaussian_blur_strength

                def interpolate(self, x: torch.Tensor, resized_size, unsharp=True):
                    """Bicubically resize ``x`` to ``resized_size`` and optionally sharpen it.

                    bfloat16 input is resized in float32; the result is always cast back
                    to the input dtype. The unsharp mask is applied only when ``unsharp``
                    is true and ``gaussian_blur_ksize`` is set.
                    """
                    source_dtype = x.dtype
                    work = x.float() if source_dtype == torch.bfloat16 else x

                    resized = torch.nn.functional.interpolate(work, size=resized_size, mode="bicubic", align_corners=False)
                    resized = resized.to(dtype=source_dtype)

                    if not (unsharp and self.gaussian_blur_ksize):
                        return resized
                    return self.apply_unshark_mask(resized)
         | 
| 162 | 
            +
             | 
| 163 | 
            +
             | 
| 164 | 
            +
            class EulerAncestralDiscreteSchedulerGL(EulerAncestralDiscreteScheduler):
                """Euler ancestral scheduler whose ``step`` can also resize the latent.

                Until ``set_gradual_latent_params`` supplies a target size, ``step``
                performs a plain Euler ancestral update. Once a size is set, the updated
                sample is interpolated to that size through ``GradualLatent.interpolate``
                and the ancestral noise is drawn at the new size.
                """

                def __init__(self, *args, **kwargs):
                    super().__init__(*args, **kwargs)
                    # Resizing stays disabled until set_gradual_latent_params() is called.
                    self.resized_size = None
                    self.gradual_latent = None

                def set_gradual_latent_params(self, size, gradual_latent: GradualLatent):
                    """Set the target latent size and the settings used by ``step``.

                    Args:
                        size: Two-element size used as the last two dimensions of the
                            output sample, or None to disable resizing.
                        gradual_latent: Supplies ``s_noise``, ``unsharp_target_x`` and
                            the interpolation helper.
                    """
                    self.resized_size = size
                    self.gradual_latent = gradual_latent

                def step(
                    self,
                    model_output: torch.FloatTensor,
                    timestep: Union[float, torch.FloatTensor],
                    sample: torch.FloatTensor,
                    generator: Optional[torch.Generator] = None,
                    return_dict: bool = True,
                ) -> Union[EulerAncestralDiscreteSchedulerOutput, Tuple]:
                    """
                    Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
                    process from the learned model outputs (most often the predicted noise).

                    Args:
                        model_output (`torch.FloatTensor`):
                            The direct output from learned diffusion model.
                        timestep (`float`):
                            The current discrete timestep in the diffusion chain.
                        sample (`torch.FloatTensor`):
                            A current instance of a sample created by the diffusion process.
                        generator (`torch.Generator`, *optional*):
                            A random number generator.
                        return_dict (`bool`):
                            Whether or not to return a
                            [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or tuple.

                    Returns:
                        [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or `tuple`:
                            If return_dict is `True`,
                            [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] is returned,
                            otherwise a tuple is returned where the first element is the sample tensor.

                    Raises:
                        ValueError: If `timestep` is an integer index, or the configured prediction type is unknown.
                        NotImplementedError: If the configured prediction type is `sample`.
                    """
                    logger = logging.getLogger(__name__)

                    if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)):
                        raise ValueError(
                            "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
                            " `EulerAncestralDiscreteSchedulerGL.step()` is not supported. Make sure to pass"
                            " one of the `scheduler.timesteps` as a timestep."
                        )

                    if not self.is_scale_input_called:
                        logger.warning(
                            "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
                            "See `StableDiffusionPipeline` for a usage example."
                        )

                    if self.step_index is None:
                        self._init_step_index(timestep)

                    sigma = self.sigmas[self.step_index]

                    # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
                    if self.config.prediction_type == "epsilon":
                        pred_original_sample = sample - sigma * model_output
                    elif self.config.prediction_type == "v_prediction":
                        # * c_out + input * c_skip
                        pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
                    elif self.config.prediction_type == "sample":
                        raise NotImplementedError("prediction_type not implemented yet: sample")
                    else:
                        raise ValueError(f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`")

                    # Split the move to the next sigma into a deterministic part (sigma_down)
                    # and a re-noising part (sigma_up).
                    sigma_from = sigma
                    sigma_to = self.sigmas[self.step_index + 1]
                    sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
                    sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5

                    # 2. Convert to an ODE derivative
                    derivative = (sample - pred_original_sample) / sigma

                    dt = sigma_down - sigma

                    if self.resized_size is None:
                        prev_sample = sample + derivative * dt
                        noise_shape = model_output.shape
                        s_noise = 1.0
                    else:
                        logger.debug(
                            "resized_size %s model_output.shape %s sample.shape %s",
                            self.resized_size,
                            model_output.shape,
                            sample.shape,
                        )
                        s_noise = self.gradual_latent.s_noise

                        if self.gradual_latent.unsharp_target_x:
                            # Take the Euler step first, then resize (and sharpen) the result.
                            prev_sample = sample + derivative * dt
                            prev_sample = self.gradual_latent.interpolate(prev_sample, self.resized_size)
                        else:
                            # Resize the inputs first; only the sample gets the unsharp mask.
                            sample = self.gradual_latent.interpolate(sample, self.resized_size)
                            derivative = self.gradual_latent.interpolate(derivative, self.resized_size, unsharp=False)
                            prev_sample = sample + derivative * dt

                        # The noise must match the resized sample, not the model output.
                        noise_shape = (model_output.shape[0], model_output.shape[1], self.resized_size[0], self.resized_size[1])

                    noise = diffusers.schedulers.scheduling_euler_ancestral_discrete.randn_tensor(
                        noise_shape,
                        dtype=model_output.dtype,
                        device=model_output.device,
                        generator=generator,
                    )

                    prev_sample = prev_sample + noise * sigma_up * s_noise

                    # upon completion increase step index by one
                    self._step_index += 1

                    if not return_dict:
                        return (prev_sample,)

                    return EulerAncestralDiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
         | 
| 285 | 
            +
             | 
| 286 | 
            +
             | 
| 287 | 
            +
            # endregion
         | 
