import torch
from typing import Optional, Tuple, Union
from diffusers.pipelines import FluxPipeline
from PIL import Image, ImageFilter
import numpy as np
import cv2
from .pipeline_tools import encode_images
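
# Maps each supported condition type to the integer type id attached to its
# tokens in `Condition.encode` (the ids are not contiguous).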
condition_dict = {
"depth": 0,
"canny": 1,
"subject": 4,
"coloring": 6,
"deblurring": 7,
"depth_pred": 8,
"fill": 9,
"sr": 10,
"cartoon": 11,
}


class Condition:
def __init__(
self,
condition_type: str,
        raw_img: Optional[Union[Image.Image, torch.Tensor]] = None,
        condition: Optional[Union[Image.Image, torch.Tensor]] = None,
mask=None,
position_delta=None,
position_scale=1.0,
) -> None:
self.condition_type = condition_type
        assert (
            raw_img is not None or condition is not None
        ), "Either raw_img or condition must be provided"
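        # Derive the condition image from raw_img when it is given; otherwise
        # use the precomputed condition supplied by the caller.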
if raw_img is not None:
self.condition = self.get_condition(condition_type, raw_img)
else:
self.condition = condition
self.position_delta = position_delta
self.position_scale = position_scale
# TODO: Add mask support
assert mask is None, "Mask not supported yet"

    def get_condition(
self, condition_type: str, raw_img: Union[Image.Image, torch.Tensor]
) -> Union[Image.Image, torch.Tensor]:
"""
Returns the condition image.
"""
if condition_type == "depth":
from transformers import pipeline
depth_pipe = pipeline(
task="depth-estimation",
model="LiheYoung/depth-anything-small-hf",
device="cuda",
)
source_image = raw_img.convert("RGB")
condition_img = depth_pipe(source_image)["depth"].convert("RGB")
return condition_img
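        # "canny": extract a Canny edge map (hysteresis thresholds 100/200)
        # and promote it back to a 3-channel PIL image.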
elif condition_type == "canny":
img = np.array(raw_img)
edges = cv2.Canny(img, 100, 200)
edges = Image.fromarray(edges).convert("RGB")
return edges
elif condition_type == "subject":
return raw_img
elif condition_type == "coloring":
return raw_img.convert("L").convert("RGB")
elif condition_type == "deblurring":
            return raw_img.convert("RGB").filter(ImageFilter.GaussianBlur(10))
elif condition_type == "fill":
return raw_img.convert("RGB")
elif condition_type == "cartoon":
return raw_img.convert("RGB")
return self.condition

    @property
def type_id(self) -> int:
"""
Returns the type id of the condition.
"""
return condition_dict[self.condition_type]

    @classmethod
def get_type_id(cls, condition_type: str) -> int:
"""
Returns the type id of the condition.
"""
return condition_dict[condition_type]

    def encode(
self, pipe: FluxPipeline, empty: bool = False
) -> Tuple[torch.Tensor, torch.Tensor, int]:
"""
Encodes the condition into tokens, ids and type_id.
"""
if self.condition_type in [
"depth",
"canny",
"subject",
"coloring",
"deblurring",
"depth_pred",
"fill",
"sr",
"cartoon",
]:
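            # All image-based condition types are encoded the same way: the
            # condition image is packed into latent tokens and position ids.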
            if empty:
                # Replace the condition with an all-black image of the same
                # size, yielding a "null" condition of matching token shape.
                e_condition = Image.new("RGB", self.condition.size, (0, 0, 0))
                tokens, ids = encode_images(pipe, e_condition)
            else:
                tokens, ids = encode_images(pipe, self.condition)
else:
raise NotImplementedError(
f"Condition type {self.condition_type} not implemented"
)
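        # Flux's VAE downsamples 8x and the pipeline packs 2x2 latent patches,
        # so one token spans 16 pixels; shifting the width ids by -width // 16
        # places the subject tokens beside the target image rather than on
        # top of it.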
if self.position_delta is None and self.condition_type == "subject":
self.position_delta = [0, -self.condition.size[0] // 16]
if self.position_delta is not None:
ids[:, 1] += self.position_delta[0]
ids[:, 2] += self.position_delta[1]
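        # Scale the ids about token-cell centers: index i (center i + 0.5)
        # maps to i * scale + (scale - 1) / 2 = scale * (i + 0.5) - 0.5.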
if self.position_scale != 1.0:
scale_bias = (self.position_scale - 1.0) / 2
ids[:, 1] *= self.position_scale
ids[:, 2] *= self.position_scale
ids[:, 1] += scale_bias
ids[:, 2] += scale_bias
type_id = torch.ones_like(ids[:, :1]) * self.type_id
return tokens, ids, type_id
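

# --- Usage sketch (illustrative) --------------------------------------------
# A minimal example of building and encoding a condition. It assumes a
# FLUX.1 checkpoint and a CUDA device; the checkpoint id and image path are
# placeholders, not part of this module.
#
#   import torch
#   from diffusers import FluxPipeline
#   from PIL import Image
#
#   pipe = FluxPipeline.from_pretrained(
#       "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
#   ).to("cuda")
#   condition = Condition("canny", raw_img=Image.open("example.jpg"))
#   tokens, ids, type_id = condition.encode(pipe)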