| """This file contains methods for inference and image generation.""" | |
| import logging | |
| from typing import List, Tuple, Dict | |
| import streamlit as st | |
| import torch | |
| import time | |
| import numpy as np | |
| from PIL import Image | |
| from time import perf_counter | |
| from contextlib import contextmanager | |
| from scipy.signal import fftconvolve | |
| from PIL import ImageFilter | |
| from transformers import AutoImageProcessor, UperNetForSemanticSegmentation | |
| from diffusers import ControlNetModel, UniPCMultistepScheduler | |
| from diffusers import StableDiffusionInpaintPipeline | |
| from compel import Compel | |
| from config import WIDTH, HEIGHT | |
| from palette import ade_palette | |
| from stable_diffusion_controlnet_inpaint_img2img import StableDiffusionControlNetInpaintImg2ImgPipeline | |
| LOGGING = logging.getLogger(__name__) | |
class ControlNetPipeline:
    """Wrapper around the ControlNet inpainting pipeline that serializes concurrent calls."""

    def __init__(self):
        self.in_use = False
        self.controlnet = ControlNetModel.from_pretrained(
            "BertChristiaens/controlnet-seg-room", torch_dtype=torch.float16)

        self.pipe = StableDiffusionControlNetInpaintImg2ImgPipeline.from_pretrained(
            "runwayml/stable-diffusion-inpainting",
            controlnet=self.controlnet,
            safety_checker=None,
            torch_dtype=torch.float16,
        )

        self.pipe.scheduler = UniPCMultistepScheduler.from_config(self.pipe.scheduler.config)
        self.pipe.enable_xformers_memory_efficient_attention()
        self.pipe = self.pipe.to("cuda")

        self.waiting_queue = []
        self.count = 0

    def __call__(self, **kwargs):
        # take a ticket and join the queue
        self.count += 1
        number = self.count
        self.waiting_queue.append(number)

        # wait until this ticket is at the front of the queue
        while self.waiting_queue[0] != number:
            print(f"Wait for your turn {number} in queue {self.waiting_queue}")
            time.sleep(0.5)

        # it's this ticket's turn: run the pipeline, then release the slot
        print("It's the turn of", number)
        try:
            return self.pipe(**kwargs)
        finally:
            self.waiting_queue.pop(0)

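# Illustrative usage sketch (not executed here; argument values are assumptions): the
# class keeps one GPU pipeline and a FIFO ticket queue so that concurrent Streamlit
# sessions run their generations one at a time.
#
#   controlnet_pipe = ControlNetPipeline()
#   result = controlnet_pipe(prompt="a scandinavian living room",
#                            negative_prompt="blurry, low quality",
#                            image=init_image,
#                            mask_image=mask,
#                            controlnet_conditioning_image=seg_map,
#                            num_inference_steps=30)
#   generated = result.images[0]
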
@contextmanager
def catchtime(message: str):
    """Context manager to measure elapsed time.

    Args:
        message (str): message to log

    Yields:
        Callable[[], float]: callable returning the elapsed time in seconds
    """
    start = perf_counter()
    yield lambda: perf_counter() - start
    LOGGING.info('%s: %.3f seconds', message, perf_counter() - start)

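# Minimal usage sketch: the yielded callable returns the seconds elapsed so far, and the
# total is logged when the block exits.
#
#   with catchtime("segmentation") as elapsed:
#       seg = segment_image(image)
#       print(f"so far: {elapsed():.3f}s")
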
def convolution(mask: Image.Image, size: int = 9) -> Image.Image:
    """Method to blur the mask

    Args:
        mask (Image.Image): masking image
        size (int, optional): size of the blur kernel. Defaults to 9.

    Returns:
        Image.Image: blurred mask
    """
    mask = np.array(mask.convert("L"))
    conv = np.ones((size, size)) / size**2
    mask_blended = fftconvolve(mask, conv, 'same')
    mask_blended = mask_blended.astype(np.uint8).copy()

    border = size

    # replace borders with original values so the blur does not bleed over the image edges
    mask_blended[:border, :] = mask[:border, :]
    mask_blended[-border:, :] = mask[-border:, :]
    mask_blended[:, :border] = mask[:, :border]
    mask_blended[:, -border:] = mask[:, -border:]

    return Image.fromarray(mask_blended).convert("L")

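# Example (illustrative, `mask_array` is an assumed 0/1 numpy mask): blurring the mask
# before compositing softens the seam between the generated region and the original
# photo, while the restored borders keep the image edges untouched.
#
#   mask = Image.fromarray((mask_array * 255).astype(np.uint8))
#   soft_mask = convolution(mask, size=9)
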
def postprocess_image_masking(inpainted: Image.Image, image: Image.Image, mask: Image.Image) -> Image.Image:
    """Method to postprocess the inpainted image

    Args:
        inpainted (Image.Image): inpainted image
        image (Image.Image): original image
        mask (Image.Image): mask

    Returns:
        Image.Image: inpainted image composited onto the original
    """
    final_inpainted = Image.composite(inpainted.convert("RGBA"), image.convert("RGBA"), mask)
    return final_inpainted.convert("RGB")

def get_controlnet() -> ControlNetPipeline:
    """Method to load the ControlNet generation pipeline

    Returns:
        ControlNetPipeline: queued wrapper around the ControlNet inpainting pipeline
    """
    pipe = ControlNetPipeline()
    return pipe

def get_segmentation_pipeline() -> Tuple[AutoImageProcessor, UperNetForSemanticSegmentation]:
    """Method to load the segmentation pipeline

    Returns:
        Tuple[AutoImageProcessor, UperNetForSemanticSegmentation]: segmentation pipeline
    """
    image_processor = AutoImageProcessor.from_pretrained("openmmlab/upernet-convnext-small")
    image_segmentor = UperNetForSemanticSegmentation.from_pretrained(
        "openmmlab/upernet-convnext-small")
    return image_processor, image_segmentor

def get_inpainting_pipeline() -> StableDiffusionInpaintPipeline:
    """Method to load the inpainting pipeline

    Returns:
        StableDiffusionInpaintPipeline: inpainting pipeline
    """
    pipe = StableDiffusionInpaintPipeline.from_pretrained(
        "stabilityai/stable-diffusion-2-inpainting",
        torch_dtype=torch.float16,
        safety_checker=None,
    )
    pipe.enable_xformers_memory_efficient_attention()
    pipe = pipe.to("cuda")

    return pipe

def make_grid_parameters(grid_search: Dict, params: Dict) -> List[Dict]:
    """Method to make grid parameters

    Args:
        grid_search (Dict): grid search parameters
        params (Dict): fixed parameters

    Returns:
        List[Dict]: grid parameters
    """
    options = []
    for generator in grid_search['generator']:
        for strength in grid_search['strength']:
            for guidance_scale in grid_search['guidance_scale']:
                options.append({'strength': strength,
                                'guidance_scale': guidance_scale,
                                'generator': generator,
                                **params,
                                })
    return options

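# Example (illustrative values): a 2x1x1 grid expands into two full parameter dicts.
#
#   grid = {'strength': [0.9, 1.0], 'guidance_scale': [7.0], 'generator': [None]}
#   fixed = {'prompt': 'a cozy living room', 'num_inference_steps': 30}
#   make_grid_parameters(grid, fixed)
#   # -> [{'strength': 0.9, 'guidance_scale': 7.0, 'generator': None,
#   #      'prompt': 'a cozy living room', 'num_inference_steps': 30},
#   #     {'strength': 1.0, 'guidance_scale': 7.0, 'generator': None,
#   #      'prompt': 'a cozy living room', 'num_inference_steps': 30}]
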
def make_captions(options: List[Dict]) -> List[str]:
    """Method to make captions

    Args:
        options (List[Dict]): grid parameters

    Returns:
        List[str]: captions
    """
    captions = []
    for option in options:
        captions.append(
            f"strength {option['strength']}, guidance {option['guidance_scale']}, steps {option['num_inference_steps']}")
    return captions

def make_image_controlnet(image: np.ndarray,
                          mask_image: np.ndarray,
                          controlnet_conditioning_image: np.ndarray,
                          positive_prompt: str, negative_prompt: str,
                          seed: int = 2356132) -> List[Image.Image]:
    """Method to make image using controlnet

    Args:
        image (np.ndarray): input image
        mask_image (np.ndarray): mask image
        controlnet_conditioning_image (np.ndarray): conditioning image
        positive_prompt (str): positive prompt string
        negative_prompt (str): negative prompt string
        seed (int, optional): seed. Defaults to 2356132.

    Returns:
        List[Image.Image]: list of generated images
    """
    with catchtime("get controlnet"):
        pipe = get_controlnet()
    torch.cuda.empty_cache()

    images = []

    common_parameters = {'prompt': positive_prompt,
                         'negative_prompt': negative_prompt,
                         'num_inference_steps': 30,
                         'controlnet_conditioning_scale': 1.1,
                         'controlnet_conditioning_scale_decay': 0.96,
                         'controlnet_steps': 28,
                         }

    grid_search = {'strength': [1.00, ],
                   'guidance_scale': [7.0],
                   'generator': [[torch.Generator(device="cuda").manual_seed(seed + i)] for i in range(1)],
                   }

    prompt_settings = make_grid_parameters(grid_search, common_parameters)

    mask_image = Image.fromarray((mask_image * 255).astype(np.uint8)).convert("RGB")
    image = Image.fromarray(image).convert("RGB")
    controlnet_conditioning_image = Image.fromarray(controlnet_conditioning_image).convert("RGB").filter(
        ImageFilter.GaussianBlur(radius=9))
    mask_image_postproc = convolution(mask_image)

    with catchtime("Controlnet generation total"):
        for setting in prompt_settings:
            with catchtime("Controlnet generation"):
                generated_image = pipe(
                    **setting,
                    image=image,
                    mask_image=mask_image,
                    controlnet_conditioning_image=controlnet_conditioning_image,
                ).images[0]
            generated_image = postprocess_image_masking(
                generated_image, image, mask_image_postproc)
            images.append(generated_image)

    return images

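# Illustrative call (input variables are assumptions): the image and the segmentation map
# come in as uint8 arrays, and the mask as a 0/1 array marking the region to regenerate.
#
#   images = make_image_controlnet(image=np.array(room_photo),
#                                  mask_image=wall_mask,
#                                  controlnet_conditioning_image=np.array(seg_map),
#                                  positive_prompt="a minimalist living room, bright",
#                                  negative_prompt="lowres, blurry")
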
def make_inpainting(positive_prompt: str,
                    image: Image.Image,
                    mask_image: np.ndarray,
                    negative_prompt: str = "") -> List[Image.Image]:
    """Method to make inpainting

    Args:
        positive_prompt (str): positive prompt string
        image (Image.Image): input image
        mask_image (np.ndarray): mask image
        negative_prompt (str, optional): negative prompt string. Defaults to "".

    Returns:
        List[Image.Image]: list of generated images
    """
    with catchtime("Get inpainting pipeline"):
        pipe = get_inpainting_pipeline()

    common_parameters = {'prompt': positive_prompt,
                         'negative_prompt': negative_prompt,
                         'num_inference_steps': 20,
                         }

    torch.cuda.empty_cache()
    images = []
    for _ in range(1):
        with catchtime("Inpainting generation"):
            image_ = pipe(image=image,
                          mask_image=Image.fromarray((mask_image * 255).astype(np.uint8)),
                          height=HEIGHT,
                          width=WIDTH,
                          **common_parameters,
                          ).images[0]
        images.append(image_)
    return images

def segment_image(image: Image.Image) -> Image.Image:
    """Method to segment image

    Args:
        image (Image.Image): input image

    Returns:
        Image.Image: segmented image
    """
    image_processor, image_segmentor = get_segmentation_pipeline()
    pixel_values = image_processor(image, return_tensors="pt").pixel_values
    with torch.no_grad():
        outputs = image_segmentor(pixel_values)

    seg = image_processor.post_process_semantic_segmentation(
        outputs, target_sizes=[image.size[::-1]])
    seg = seg[0]

    color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)  # height, width, 3
    palette = np.array(ade_palette())

    for label, color in enumerate(palette):
        color_seg[seg == label, :] = color

    color_seg = color_seg.astype(np.uint8)
    seg_image = Image.fromarray(color_seg).convert('RGB')

    return seg_image
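

# Illustrative end-to-end sketch (variable names are assumptions): segment the room photo,
# then condition the ControlNet generation on the resulting ADE20K-style color map.
#
#   room_photo = Image.open("room.jpg").convert("RGB").resize((WIDTH, HEIGHT))
#   seg_map = segment_image(room_photo)
#   outputs = make_image_controlnet(np.array(room_photo), wall_mask,
#                                   np.array(seg_map),
#                                   positive_prompt="a modern bedroom",
#                                   negative_prompt="lowres")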