# File-view header captured together with the source (not part of the program):
# uploader "naderasadi", commit 5b2ab1c ("Initial commit"), size 5.07 kB.
from typing import Any, List, Optional, Tuple, Union
import itertools
from PIL import Image
import numpy as np
import torch
from controlnet_aux import MLSDdetector, PidiNetDetector, HEDdetector
from diffusers import (
ControlNetModel,
StableDiffusionControlNetPipeline,
UniPCMultistepScheduler,
)
# Registry of supported control types, keyed by the name callers pass as
# ``control_model_name``. Each entry holds:
#   "name":     repository id handed to the detector's ``from_pretrained``,
#   "detector": annotator class that turns an input image into a control image,
#   "model":    repository id handed to ``ControlNetModel.from_pretrained``.
MODEL_DICT = {
    control_type: {
        "name": "lllyasviel/Annotators",
        "detector": detector_cls,
        "model": controlnet_repo,
    }
    for control_type, detector_cls, controlnet_repo in (
        ("mlsd", MLSDdetector, "lllyasviel/control_v11p_sd15_mlsd"),
        ("soft_edge", PidiNetDetector, "lllyasviel/control_v11p_sd15_softedge"),
        ("hed", HEDdetector, "lllyasviel/sd-controlnet-hed"),
        ("scribble", HEDdetector, "lllyasviel/control_v11p_sd15_scribble"),
    )
}
class StableDiffusionControlNet:
    """ControlNet pipeline for generating images from prompts.

    Args:
        control_model_name (str):
            Name of the controlnet processor; must be a key of ``MODEL_DICT``.
        sd_model_name (str):
            Name of the StableDiffusion model.

    Raises:
        KeyError: If ``control_model_name`` is not a key of ``MODEL_DICT``.
    """

    # Exclusive upper bound for randomly drawn seeds. Kept within int32 so the
    # draw works regardless of numpy's platform-dependent default integer type.
    _MAX_SEED = 2**31 - 1

    def __init__(
        self,
        control_model_name: str,
        sd_model_name: str = "runwayml/stable-diffusion-v1-5",
    ) -> None:
        if control_model_name not in MODEL_DICT:
            # Same exception type as the plain lookup, but with a usable message.
            raise KeyError(
                f"Unknown control model {control_model_name!r}; "
                f"expected one of {sorted(MODEL_DICT)}."
            )
        spec = MODEL_DICT[control_model_name]
        self.processor = spec["detector"].from_pretrained(spec["name"])
        self.pipe = self.create_pipe(
            sd_model_name=sd_model_name, control_model_name=control_model_name
        )

    def _repeat(self, items: List[Any], n: int) -> List[Any]:
        """Repeat each item of a list n times, keeping repeats adjacent.

        Args:
            items (List[Any]): List of items to be repeated.
            n (int): Number of repetitions.

        Returns:
            List[Any]: ``[a, a, ..., b, b, ...]`` for ``items == [a, b]``.
        """
        return [item for item in items for _ in range(n)]

    def _make_generators(self, count: int) -> "List[torch.Generator]":
        """Create ``count`` independently seeded torch generators.

        Seeds come from numpy's global random state, so ``np.random.seed``
        still makes a run reproducible.

        Args:
            count (int): Number of generators to create.

        Returns:
            List[torch.Generator]: One seeded generator per requested sample.
        """
        device = "cuda" if torch.cuda.is_available() else "cpu"
        # Draw from a wide range: bounding the seed by the batch size would
        # pin a single prompt to seed 0 and make duplicate seeds (and thus
        # duplicate images) likely for small batches.
        seeds = np.random.randint(0, self._MAX_SEED, size=count)
        return [torch.Generator(device=device).manual_seed(int(s)) for s in seeds]

    def generate_control_images(
        self, images: "List[Image.Image]"
    ) -> "List[Image.Image]":
        """Generate control images from input images.

        Args:
            images (List[Image.Image]): Input images.

        Returns:
            List[Image.Image]: One control image per input, in the same order.
        """
        return [self.processor(image) for image in images]

    def create_pipe(
        self, sd_model_name: str, control_model_name: str
    ) -> "StableDiffusionControlNetPipeline":
        """Create a StableDiffusionControlNetPipeline.

        Args:
            sd_model_name (str): StableDiffusion model name.
            control_model_name (str): Name of the ControlNet module
                (a key of ``MODEL_DICT``).

        Returns:
            StableDiffusionControlNetPipeline
        """
        controlnet = ControlNetModel.from_pretrained(
            MODEL_DICT[control_model_name]["model"], torch_dtype=torch.float16
        )
        pipe = StableDiffusionControlNetPipeline.from_pretrained(
            sd_model_name, controlnet=controlnet, torch_dtype=torch.float16
        )
        pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
        # Memory-saving options.
        # NOTE(review): presumably these need an accelerator and the optional
        # xformers package to be installed — confirm for the target environment.
        pipe.enable_model_cpu_offload()
        pipe.enable_xformers_memory_efficient_attention()
        return pipe

    def process(
        self,
        images: "List[Image.Image]",
        prompts: List[str],
        negative_prompt: Optional[str] = None,
        n_outputs: int = 1,
        num_inference_steps: int = 30,
    ) -> "List[List[Image.Image]]":
        """Generate images from `prompts`, conditioned on control images derived from `images`.

        Args:
            images (List[Image.Image]): Input images; one per prompt.
            prompts (List[str]): List of prompts.
            negative_prompt (Optional[str], optional): Negative prompt applied
                to every sample. Defaults to None (no negative prompt).
            n_outputs (int, optional): Number of generated outputs per
                prompt/image pair. Defaults to 1.
            num_inference_steps (int, optional): Number of inference
                iterations. Defaults to 30.

        Returns:
            List[List[Image.Image]]: One list of ``n_outputs`` images per input
            image, in input order.

        Raises:
            ValueError: If the number of prompts and images differ, or if
                ``n_outputs`` is smaller than 1.
        """
        # Validate up front (with real exceptions, which survive ``python -O``)
        # so bad input fails before any detector or diffusion work is done.
        if len(prompts) != len(images):
            raise ValueError("Number of prompts and input images must be equal.")
        if n_outputs < 1:
            raise ValueError("n_outputs must be at least 1.")
        if not prompts:
            return []

        control_images = self.generate_control_images(images)
        if n_outputs > 1:
            # Repeats stay adjacent so outputs can be regrouped by slicing.
            prompts = self._repeat(prompts, n=n_outputs)
            control_images = self._repeat(control_images, n=n_outputs)

        output = self.pipe(
            prompts,
            image=control_images,
            # Forward None untouched: a list of None values is not treated as
            # "no negative prompt" by the pipeline.
            negative_prompt=(
                None if negative_prompt is None else [negative_prompt] * len(prompts)
            ),
            num_inference_steps=num_inference_steps,
            generator=self._make_generators(len(prompts)),
        )
        return [
            output.images[idx * n_outputs : (idx + 1) * n_outputs]
            for idx in range(len(images))
        ]