Spaces:
Running
on
Zero
Running
on
Zero
| # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates | |
| # Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| import torch | |
| from typing import Optional, Union, List, Tuple | |
| from diffusers.pipelines import FluxPipeline | |
| from PIL import Image, ImageFilter | |
| import numpy as np | |
| import cv2 | |
| from .pipeline_tools import encode_vae_images | |
# Maps each supported condition type name to the integer type id that is
# attached to the encoded condition tokens (see ``Condition.type_id``).
# NOTE: ids 2, 3 and 5 are not assigned in this mapping.
condition_dict = dict(
    depth=0,
    canny=1,
    subject=4,
    coloring=6,
    deblurring=7,
    depth_pred=8,
    fill=9,
    sr=10,
)
class Condition(object):
    """A single conditioning input for the pipeline.

    A condition is identified by ``condition_type`` (a key of
    ``condition_dict``) and holds the condition image, either derived from a
    raw image via :meth:`get_condition` or supplied directly by the caller.
    """

    def __init__(
        self,
        condition_type: str,
        raw_img: Optional[Union[Image.Image, torch.Tensor]] = None,
        condition: Optional[Union[Image.Image, torch.Tensor]] = None,
        mask=None,
        position_delta=None,
    ) -> None:
        """
        Builds a condition from either a raw image or a ready-made condition.

        Args:
            condition_type: Name of the condition type (e.g. ``"canny"``).
            raw_img: Source image; when given, the condition image is derived
                from it with :meth:`get_condition` and ``condition`` is ignored.
            condition: Pre-computed condition image, used only when
                ``raw_img`` is ``None``.
            mask: Must be ``None``; masks are not supported yet.
            position_delta: Optional two-element offset added to columns 1
                and 2 of the ids produced in :meth:`encode`.

        Raises:
            AssertionError: If neither ``raw_img`` nor ``condition`` is given,
                or if ``mask`` is not ``None``.
        """
        # Validate up front: deriving the condition below may load a model
        # (the "depth" branch), so fail before doing any expensive work.
        assert (
            raw_img is not None or condition is not None
        ), "Either raw_img or condition must be provided"
        # TODO: Add mask support
        assert mask is None, "Mask not supported yet"

        self.condition_type = condition_type
        self.position_delta = position_delta
        if raw_img is not None:
            self.condition = self.get_condition(condition_type, raw_img)
        else:
            self.condition = condition

    def get_condition(
        self, condition_type: str, raw_img: Union[Image.Image, torch.Tensor]
    ) -> Union[Image.Image, torch.Tensor]:
        """
        Returns the condition image derived from ``raw_img``.

        The preprocessing depends on ``condition_type``:

        * ``"depth"``: depth map predicted by a depth-estimation pipeline.
        * ``"canny"``: Canny edge map (thresholds 100 / 200).
        * ``"subject"``: the raw image, unchanged.
        * ``"coloring"``: grayscale version, converted back to RGB.
        * ``"deblurring"``: Gaussian-blurred (radius 10) RGB image.
        * ``"fill"``: the raw image converted to RGB.
        * any other type: no preprocessing (see the fallback below).
        """
        if condition_type == "depth":
            # Imported lazily so transformers is only needed for this branch.
            from transformers import pipeline

            # NOTE(review): the pipeline is rebuilt on every call and is
            # pinned to CUDA; consider caching it if this becomes a hot path.
            depth_pipe = pipeline(
                task="depth-estimation",
                model="LiheYoung/depth-anything-small-hf",
                device="cuda",
            )
            source_image = raw_img.convert("RGB")
            return depth_pipe(source_image)["depth"].convert("RGB")
        if condition_type == "canny":
            img = np.array(raw_img)
            edges = cv2.Canny(img, 100, 200)
            return Image.fromarray(edges).convert("RGB")
        if condition_type == "subject":
            return raw_img
        if condition_type == "coloring":
            return raw_img.convert("L").convert("RGB")
        if condition_type == "deblurring":
            return (
                raw_img.convert("RGB")
                .filter(ImageFilter.GaussianBlur(10))
                .convert("RGB")
            )
        if condition_type == "fill":
            return raw_img.convert("RGB")
        # No preprocessing is defined for this type (e.g. "depth_pred", "sr").
        # The previous fallback read ``self.condition`` unconditionally, which
        # raised AttributeError when called from ``__init__`` because the
        # attribute is not set yet. Keep returning an already-built condition
        # when one exists; otherwise pass the raw image through unchanged.
        return getattr(self, "condition", raw_img)

    @property
    def type_id(self) -> int:
        """
        Returns the type id of this condition.

        Exposed as a property because :meth:`encode` reads it as a plain
        value (``self.type_id``).
        """
        return condition_dict[self.condition_type]

    @classmethod
    def get_type_id(cls, condition_type: str) -> int:
        """
        Returns the type id for ``condition_type`` without needing an instance.

        Raises:
            KeyError: If ``condition_type`` is not in ``condition_dict``.
        """
        return condition_dict[condition_type]

    def encode(
        self, pipe: FluxPipeline
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Encodes the condition into tokens, ids and type_id.

        Args:
            pipe: Pipeline passed through to ``encode_vae_images``.

        Returns:
            ``(tokens, ids, type_id)`` where ``tokens`` and ``ids`` come from
            ``encode_vae_images`` and ``type_id`` is a tensor shaped like
            ``ids[:, :1]`` filled with this condition's type id.

        Raises:
            NotImplementedError: If the condition type is not one of the
                VAE-encoded types listed below.
        """
        vae_encoded_types = (
            "depth",
            "canny",
            "subject",
            "coloring",
            "deblurring",
            "depth_pred",
            "fill",
            "sr",
        )
        if self.condition_type not in vae_encoded_types:
            raise NotImplementedError(
                f"Condition type {self.condition_type} not implemented"
            )
        tokens, ids = encode_vae_images(pipe, self.condition)

        if self.position_delta is None and self.condition_type == "subject":
            # Default offset for subject conditions, derived from the image
            # width (PIL ``size[0]``). Unary minus binds tighter than ``//``,
            # so this is ``(-width) // 16`` and floors toward negative infinity.
            # NOTE(review): ``.size[0]`` assumes a PIL image; a torch.Tensor
            # exposes ``size`` as a method, so a tensor condition needs an
            # explicit position_delta - confirm against callers.
            # The computed default is stored on the instance and reused by
            # later calls.
            self.position_delta = [0, -self.condition.size[0] // 16]
        if self.position_delta is not None:
            # Shift columns 1 and 2 of the ids in place (presumably the two
            # spatial position columns - verify against encode_vae_images).
            ids[:, 1] += self.position_delta[0]
            ids[:, 2] += self.position_delta[1]
            print(f"[Condition.encode] position_delta={self.position_delta}")

        type_id = torch.ones_like(ids[:, :1]) * self.type_id
        return tokens, ids, type_id