|
import logging
|
|
import sys
|
|
import threading
|
|
import torch
|
|
from torchvision import transforms
|
|
from typing import *
|
|
from diffusers import EulerAncestralDiscreteScheduler
|
|
import diffusers.schedulers.scheduling_euler_ancestral_discrete
|
|
from diffusers.schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteSchedulerOutput
|
|
import cv2
|
|
from PIL import Image
|
|
import numpy as np
|
|
|
|
|
|
def fire_in_thread(f, *args, **kwargs):
    """Start ``f(*args, **kwargs)`` on a new thread and return immediately.

    The thread is neither joined nor returned, so the caller cannot wait for
    completion or observe the callable's return value through this helper.
    """
    worker = threading.Thread(target=f, kwargs=kwargs, args=args)
    worker.start()
|
|
|
|
|
|
def add_logging_arguments(parser):
    """Register the console logging options on ``parser``.

    Adds three options, in this order:

    * ``--console_log_level``: one of the standard level names; defaults to
      ``None``.
    * ``--console_log_file``: path of a log file; defaults to ``None``.
    * ``--console_log_simple``: boolean flag, ``False`` unless given.

    Args:
        parser: Object exposing ``add_argument`` (e.g. an
            ``argparse.ArgumentParser``). It is modified in place.
    """
    option_specs = (
        (
            "--console_log_level",
            {
                "type": str,
                "default": None,
                "choices": ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
                "help": "Set the logging level, default is INFO / ログレベルを設定する。デフォルトはINFO",
            },
        ),
        (
            "--console_log_file",
            {
                "type": str,
                "default": None,
                "help": "Log to a file instead of stderr / 標準エラー出力ではなくファイルにログを出力する",
            },
        ),
        (
            "--console_log_simple",
            {
                "action": "store_true",
                "help": "Simple log output / シンプルなログ出力",
            },
        ),
    )
    for flag, options in option_specs:
        parser.add_argument(flag, **options)
|
|
|
|
|
|
def setup_logging(args=None, log_level=None, reset=False):
    """Install a single handler on the root logger.

    If the root logger already has handlers the call is a no-op, unless
    ``reset`` is true, in which case the existing handlers are removed and
    closed before a new one is installed.

    Handler selection:

    * ``args.console_log_file`` set: a ``logging.FileHandler`` opened with
      mode ``"w"`` (the file is truncated).
    * otherwise, unless ``args.console_log_simple`` is set: a
      ``rich.logging.RichHandler`` writing to stderr, when ``rich`` can be
      imported.
    * otherwise: a ``logging.StreamHandler`` on ``sys.stdout``.

    Args:
        args: Optional namespace carrying ``console_log_level``,
            ``console_log_file`` and ``console_log_simple``. Missing
            attributes are treated as unset.
        log_level: Level name such as ``"DEBUG"``. Takes priority over
            ``args.console_log_level``; when neither is given, ``"INFO"``.
        reset: Replace already-installed root handlers instead of returning.

    Raises:
        AttributeError: If the level name is not an attribute of ``logging``.
    """
    root = logging.root
    if root.handlers:
        if not reset:
            # Logging is already configured; leave it untouched.
            return
        for stale in root.handlers[:]:
            root.removeHandler(stale)
            # Detaching alone would leave a FileHandler's file open.
            stale.close()

    # The caller-supplied level wins over args; INFO is the last resort.
    if log_level is None and args is not None:
        log_level = getattr(args, "console_log_level", None)
    if log_level is None:
        log_level = "INFO"
    log_level = getattr(logging, log_level)

    msg_init = None
    handler = None
    log_file = getattr(args, "console_log_file", None)
    if log_file:
        handler = logging.FileHandler(log_file, mode="w")
    else:
        if not args or not getattr(args, "console_log_simple", False):
            try:
                from rich.console import Console
                from rich.logging import RichHandler

                handler = RichHandler(console=Console(stderr=True))
            except ImportError:
                # Deferred: the message can only be logged once a handler exists.
                msg_init = "rich is not installed, using basic logging"

        if handler is None:
            # Plain fallback writes to stdout, i.e. where print() output goes.
            handler = logging.StreamHandler(sys.stdout)

    handler.setFormatter(
        logging.Formatter(
            fmt="%(message)s",
            datefmt="%Y-%m-%d %H:%M:%S",
        )
    )
    root.setLevel(log_level)
    root.addHandler(handler)

    if msg_init is not None:
        logging.getLogger(__name__).info(msg_init)
|
|
|
|
|
|
def pil_resize(image, size, interpolation=Image.LANCZOS):
    """Resize an OpenCV-style image array using PIL's resampling.

    Colour input is converted from OpenCV channel order (BGR / BGRA) to PIL
    order (RGB / RGBA), resized with ``Image.resize`` and converted back, so
    the result has the same channel layout as the input. A 2-D (single
    channel) array has no channel order and is resized directly.

    Args:
        image: Array with a ``shape`` attribute; 2-D, or 3-D with 3 or 4
            channels in the last axis (4 is treated as BGRA).
        size: Target size, passed unchanged to ``PIL.Image.Image.resize``.
        interpolation: PIL resampling filter, passed unchanged to
            ``PIL.Image.Image.resize``.

    Returns:
        The resized image as a NumPy array.
    """
    if len(image.shape) == 2:
        # The BGR<->RGB conversions below expect colour input, so a
        # single-channel image skips the cv2 round trip.
        return np.array(Image.fromarray(image).resize(size, interpolation))

    has_alpha = image.shape[2] == 4 if len(image.shape) == 3 else False
    if has_alpha:
        to_pil_order, to_cv2_order = cv2.COLOR_BGRA2RGBA, cv2.COLOR_RGBA2BGRA
    else:
        to_pil_order, to_cv2_order = cv2.COLOR_BGR2RGB, cv2.COLOR_RGB2BGR

    pil_image = Image.fromarray(cv2.cvtColor(image, to_pil_order))
    resized_pil = pil_image.resize(size, interpolation)

    # Back to OpenCV channel order for the caller.
    return cv2.cvtColor(np.array(resized_pil), to_cv2_order)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class GradualLatent:
    """Parameter holder and resize helper for the "gradual latent" feature.

    The constructor only stores its arguments. ``interpolate`` resizes a
    tensor with bicubic interpolation and optionally sharpens it with an
    unsharp mask built from a Gaussian blur.
    """

    # Attribute names in the order they are rendered by ``__str__``.
    _FIELDS = (
        "ratio",
        "start_timesteps",
        "every_n_steps",
        "ratio_step",
        "s_noise",
        "gaussian_blur_ksize",
        "gaussian_blur_sigma",
        "gaussian_blur_strength",
        "unsharp_target_x",
    )

    def __init__(
        self,
        ratio,
        start_timesteps,
        every_n_steps,
        ratio_step,
        s_noise=1.0,
        gaussian_blur_ksize=None,
        gaussian_blur_sigma=0.5,
        gaussian_blur_strength=0.5,
        unsharp_target_x=True,
    ):
        """Store the given settings as same-named attributes."""
        self.ratio = ratio
        self.start_timesteps = start_timesteps
        self.every_n_steps = every_n_steps
        self.ratio_step = ratio_step
        self.s_noise = s_noise
        self.gaussian_blur_ksize = gaussian_blur_ksize
        self.gaussian_blur_sigma = gaussian_blur_sigma
        self.gaussian_blur_strength = gaussian_blur_strength
        self.unsharp_target_x = unsharp_target_x

    def __str__(self) -> str:
        """Return ``GradualLatent(name=value, ...)`` listing every setting."""
        rendered = ", ".join(f"{name}={getattr(self, name)}" for name in self._FIELDS)
        return f"GradualLatent({rendered})"

    def apply_unshark_mask(self, x: torch.Tensor):
        """Sharpen ``x`` with an unsharp mask.

        Returns ``x`` itself when ``gaussian_blur_ksize`` is ``None``;
        otherwise returns ``x + (x - blur(x)) * gaussian_blur_strength``.
        """
        if self.gaussian_blur_ksize is None:
            return x

        smoothed = transforms.functional.gaussian_blur(x, self.gaussian_blur_ksize, self.gaussian_blur_sigma)
        # The difference from the blurred copy is the detail that gets boosted.
        return x + (x - smoothed) * self.gaussian_blur_strength

    def interpolate(self, x: torch.Tensor, resized_size, unsharp=True):
        """Resize ``x`` to ``resized_size`` with bicubic interpolation.

        bfloat16 input is converted to float32 for the interpolation; the
        result is always cast back to the input dtype. The unsharp mask is
        applied afterwards only if ``unsharp`` is true and
        ``gaussian_blur_ksize`` is truthy.
        """
        source_dtype = x.dtype
        work = x.float() if source_dtype == torch.bfloat16 else x

        resized = torch.nn.functional.interpolate(work, size=resized_size, mode="bicubic", align_corners=False)
        resized = resized.to(dtype=source_dtype)

        if not unsharp or not self.gaussian_blur_ksize:
            return resized
        return self.apply_unshark_mask(resized)
|
|
|
|
|
|
class EulerAncestralDiscreteSchedulerGL(EulerAncestralDiscreteScheduler):
    """Euler ancestral scheduler whose ``step`` can also resize the sample.

    While ``resized_size`` is ``None`` the step performs the Euler ancestral
    update at the sample's own size. Once a size and a ``GradualLatent`` are
    set through ``set_gradual_latent_params``, each step additionally resizes
    the result to that size and scales the added noise by
    ``gradual_latent.s_noise``.
    """

    def __init__(self, *args, **kwargs):
        """Forward all arguments to the base scheduler; resizing starts disabled."""
        super().__init__(*args, **kwargs)
        self.resized_size = None
        self.gradual_latent = None

    def set_gradual_latent_params(self, size, gradual_latent: GradualLatent):
        """Set the target size and resize helper used by later ``step`` calls.

        Args:
            size: Two-element sequence indexed as ``size[0]``, ``size[1]`` for
                the last two dimensions of the noise tensor, or ``None`` to
                disable resizing.
            gradual_latent: Helper providing ``s_noise``, ``unsharp_target_x``
                and ``interpolate``; it is read whenever ``size`` is not
                ``None``.
        """
        self.resized_size = size
        self.gradual_latent = gradual_latent

    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: Union[float, torch.FloatTensor],
        sample: torch.FloatTensor,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
    ) -> Union[EulerAncestralDiscreteSchedulerOutput, Tuple]:
        """
        Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
        process from the learned model outputs (most often the predicted noise).

        When `resized_size` has been set, the result is additionally resized to that size via
        `gradual_latent.interpolate` and the ancestral noise is drawn at the resized shape.

        Args:
            model_output (`torch.FloatTensor`):
                The direct output from learned diffusion model.
            timestep (`float`):
                The current discrete timestep in the diffusion chain.
            sample (`torch.FloatTensor`):
                A current instance of a sample created by the diffusion process.
            generator (`torch.Generator`, *optional*):
                A random number generator.
            return_dict (`bool`):
                Whether or not to return a
                [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or tuple.

        Returns:
            [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or `tuple`:
                If return_dict is `True`,
                [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] is returned,
                otherwise a tuple is returned where the first element is the sample tensor.

        Raises:
            ValueError: If `timestep` is an integer index, or `prediction_type` is not a known value.
            NotImplementedError: If `prediction_type` is `"sample"`.
        """
        logger = logging.getLogger(__name__)

        if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)):
            raise ValueError(
                (
                    "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
                    " `EulerAncestralDiscreteSchedulerGL.step()` is not supported. Make sure to pass"
                    " one of the `scheduler.timesteps` as a timestep."
                ),
            )

        if not self.is_scale_input_called:
            logger.warning(
                "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
                "See `StableDiffusionPipeline` for a usage example."
            )

        if self.step_index is None:
            self._init_step_index(timestep)

        sigma = self.sigmas[self.step_index]

        # 1. Compute the predicted original sample (x_0) from the sigma-scaled model output.
        if self.config.prediction_type == "epsilon":
            pred_original_sample = sample - sigma * model_output
        elif self.config.prediction_type == "v_prediction":
            pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
        elif self.config.prediction_type == "sample":
            raise NotImplementedError("prediction_type not implemented yet: sample")
        else:
            raise ValueError(f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`")

        # Split the move to the next sigma into a deterministic part
        # (sigma_down) and a re-noising part (sigma_up).
        sigma_from = sigma
        sigma_to = self.sigmas[self.step_index + 1]
        sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
        sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5

        # 2. Convert to an ODE derivative and take the Euler step to sigma_down.
        derivative = (sample - pred_original_sample) / sigma
        dt = sigma_down - sigma

        if self.resized_size is None:
            prev_sample = sample + derivative * dt
            noise_shape = model_output.shape
            s_noise = 1.0
        else:
            logger.debug(
                "resized_size %s model_output.shape %s sample.shape %s",
                self.resized_size,
                model_output.shape,
                sample.shape,
            )
            s_noise = self.gradual_latent.s_noise

            if self.gradual_latent.unsharp_target_x:
                # Step first, then resize (and possibly sharpen) the stepped sample.
                prev_sample = sample + derivative * dt
                prev_sample = self.gradual_latent.interpolate(prev_sample, self.resized_size)
            else:
                # Resize both operands first; only the sample may be sharpened.
                sample = self.gradual_latent.interpolate(sample, self.resized_size)
                derivative = self.gradual_latent.interpolate(derivative, self.resized_size, unsharp=False)
                prev_sample = sample + derivative * dt

            noise_shape = (model_output.shape[0], model_output.shape[1], self.resized_size[0], self.resized_size[1])

        # The noise is drawn after prev_sample is computed, at its final shape.
        noise = diffusers.schedulers.scheduling_euler_ancestral_discrete.randn_tensor(
            noise_shape,
            dtype=model_output.dtype,
            device=model_output.device,
            generator=generator,
        )

        prev_sample = prev_sample + noise * sigma_up * s_noise

        # Upon completion, advance the step index by one.
        self._step_index += 1

        if not return_dict:
            return (prev_sample,)

        return EulerAncestralDiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
|
|
|
|
|
|
|
|
|