from diffusers import DiffusionPipeline
import torch
from torchvision import transforms
from tqdm import tqdm


class DDPMSketch2ImgPipeline(DiffusionPipeline):
    # TODO: Move the normalization transforms to a separate class
    def __init__(self, unet, scheduler):
        super().__init__()
        self.register_modules(unet=unet, scheduler=scheduler)

    def __call__(self, sketch, num_inference_step=1000, tqdm_leave=True):
        # sketch : PIL image
        # return : PIL image
        sketch = transforms.functional.pil_to_tensor(sketch).float()
        sketch = self.normalize(sketch).to(self.device)
        sketch = sketch.unsqueeze(0)  # add a batch dimension: (c, h, w) -> (1, c, h, w)
        image = self.sample(sketch, num_inference_step, tqdm_leave)
        image = image.squeeze(0)  # drop the batch dimension
        image = self.denormalize(image)
        image = self.denormalized_tensor_to_pil(image)
        return image

    def sample(self, transformed_sketch, num_inference_step, tqdm_leave=True):
        assert (
            len(transformed_sketch.shape) == 4
        ), f"(bs, c, h, w) but {transformed_sketch.shape}"
        # Configure the reverse-diffusion schedule for this call
        self.scheduler.set_timesteps(num_inference_step, device=self.device)
        s = transformed_sketch.shape
        # Start from pure Gaussian noise; the generated image is assumed to have
        # the same number of channels as the UNet's out_channels
        image = torch.randn((s[0], self.unet.config["out_channels"], s[2], s[3])).to(
            self.device
        )
        for t in tqdm(self.scheduler.timesteps, leave=tqdm_leave):
            # Condition on the sketch by concatenating it channel-wise with the noisy image
            model_input = torch.concat([image, transformed_sketch], dim=1).to(
                self.device
            )
            with torch.no_grad():
                model_output = self.unet(model_input, t).sample
            # One denoising step: estimate the sample at the previous timestep
            image = self.scheduler.step(model_output, t, image).prev_sample
        return image

    def denormalized_tensor_to_pil(self, tensor):
        assert len(tensor.shape) == 3, f"(c, h, w) but {tensor.shape}"
        tensor = tensor.cpu().clip(0, 255).to(torch.uint8)
        pil = transforms.functional.to_pil_image(tensor)
        return pil

    def normalize(self, x):
        assert x.dtype == torch.float
        # Map pixel values from [0, 255] to [-1, 1], the range the UNet is trained on
        x = x / 255.0
        x = transforms.Normalize([0.5], [0.5])(x)
        return x

    def denormalize(self, x):
        assert x.dtype == torch.float
        x = x * 0.5 + 0.5  # map from (-1, 1) back to (0, 1)
        x = x * 255.0
        return x
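

# A minimal usage sketch, not part of the pipeline itself. It assumes a
# UNet2DModel trained on channel-wise concatenated (noisy image, sketch) pairs,
# i.e. in_channels = image channels + sketch channels, and a DDPMScheduler with
# matching settings. The checkpoint and file paths below are hypothetical; in
# practice the UNet would be loaded from a trained checkpoint, e.g.
# UNet2DModel.from_pretrained("path/to/trained-unet").
if __name__ == "__main__":
    from diffusers import UNet2DModel, DDPMScheduler
    from PIL import Image

    unet = UNet2DModel(
        sample_size=64,  # assumed training resolution
        in_channels=6,   # 3 noisy-image channels + 3 sketch channels
        out_channels=3,
    )
    scheduler = DDPMScheduler(num_train_timesteps=1000)

    pipe = DDPMSketch2ImgPipeline(unet=unet, scheduler=scheduler)
    pipe.to("cuda" if torch.cuda.is_available() else "cpu")

    sketch = Image.open("sketch.png").convert("RGB")  # hypothetical input path
    result = pipe(sketch, num_inference_step=1000)
    result.save("generated.png")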