Bingsu committed on
Commit
1b21c84
·
1 Parent(s): 9740b1e

Delete files

Browse files
Files changed (7) hide show
  1. asdff/__init__.py +0 -10
  2. asdff/__version__.py +0 -1
  3. asdff/base.py +0 -152
  4. asdff/sd.py +0 -51
  5. asdff/utils.py +0 -70
  6. asdff/yolo.py +0 -80
  7. pipeline.py +0 -1
asdff/__init__.py DELETED
@@ -1,10 +0,0 @@
1
- from .__version__ import __version__
2
- from .sd import AdCnPipeline, AdPipeline
3
- from .yolo import yolo_detector
4
-
5
- __all__ = [
6
- "AdPipeline",
7
- "AdCnPipeline",
8
- "yolo_detector",
9
- "__version__",
10
- ]
 
 
 
 
 
 
 
 
 
 
 
asdff/__version__.py DELETED
@@ -1 +0,0 @@
1
- __version__ = "0.2.0"
 
 
asdff/base.py DELETED
@@ -1,152 +0,0 @@
1
- from __future__ import annotations
2
-
3
- import inspect
4
- from abc import ABC, abstractmethod
5
- from typing import Any, Callable, Iterable, List, Mapping, Optional
6
-
7
- from diffusers.utils import logging
8
- from PIL import Image
9
-
10
- from asdff.utils import (
11
- ADOutput,
12
- bbox_padding,
13
- composite,
14
- mask_dilate,
15
- mask_gaussian_blur,
16
- )
17
- from asdff.yolo import yolo_detector
18
-
19
- logger = logging.get_logger("diffusers")
20
-
21
-
22
- DetectorType = Callable[[Image.Image], Optional[List[Image.Image]]]
23
-
24
-
25
- def ordinal(n: int) -> str:
26
- d = {1: "st", 2: "nd", 3: "rd"}
27
- return str(n) + ("th" if 11 <= n % 100 <= 13 else d.get(n % 10, "th"))
28
-
29
-
30
- class AdPipelineBase(ABC):
31
- @property
32
- @abstractmethod
33
- def inpaint_pipeline(self) -> Callable:
34
- raise NotImplementedError
35
-
36
- @property
37
- @abstractmethod
38
- def txt2img_class(self) -> type:
39
- raise NotImplementedError
40
-
41
- def __call__( # noqa: C901
42
- self,
43
- common: Mapping[str, Any] | None = None,
44
- txt2img_only: Mapping[str, Any] | None = None,
45
- inpaint_only: Mapping[str, Any] | None = None,
46
- images: Image.Image | Iterable[Image.Image] | None = None,
47
- detectors: DetectorType | Iterable[DetectorType] | None = None,
48
- mask_dilation: int = 4,
49
- mask_blur: int = 4,
50
- mask_padding: int = 32,
51
- ):
52
- if common is None:
53
- common = {}
54
- if txt2img_only is None:
55
- txt2img_only = {}
56
- if inpaint_only is None:
57
- inpaint_only = {}
58
- if "strength" not in inpaint_only:
59
- inpaint_only = {**inpaint_only, "strength": 0.4}
60
-
61
- if detectors is None:
62
- detectors = [self.default_detector]
63
- elif not isinstance(detectors, Iterable):
64
- detectors = [detectors]
65
-
66
- if images and txt2img_only:
67
- logger.warning(
68
- "Both `images` and `txt2img_only` are specified. if `images` is specified, `txt2img_only` is ignored."
69
- )
70
-
71
- if images is None:
72
- txt2img_args = self._get_txt2img_args(common, txt2img_only)
73
- txt2img_output = self.txt2img_class.__call__(self, **txt2img_args)
74
- txt2img_images: list[Image.Image] = txt2img_output[0]
75
- else:
76
- if not isinstance(images, Iterable):
77
- txt2img_images = [images]
78
- else:
79
- txt2img_images = images
80
-
81
- init_images = []
82
- final_images = []
83
-
84
- for i, init_image in enumerate(txt2img_images):
85
- init_images.append(init_image.copy())
86
- final_image = None
87
-
88
- for j, detector in enumerate(detectors):
89
- masks = detector(init_image)
90
- if masks is None:
91
- logger.info(
92
- f"No object detected on {ordinal(i + 1)} image with {ordinal(j + 1)} detector."
93
- )
94
- continue
95
-
96
- for k, mask in enumerate(masks):
97
- mask = mask.convert("L")
98
- mask = mask_dilate(mask, mask_dilation)
99
- bbox = mask.getbbox()
100
- if bbox is None:
101
- logger.info(f"No object in {ordinal(k + 1)} mask.")
102
- continue
103
- mask = mask_gaussian_blur(mask, mask_blur)
104
- bbox_padded = bbox_padding(bbox, init_image.size, mask_padding)
105
-
106
- crop_image = init_image.crop(bbox_padded)
107
- crop_mask = mask.crop(bbox_padded)
108
-
109
- inpaint_args = self._get_inpaint_args(common, inpaint_only)
110
- inpaint_args["image"] = crop_image
111
- inpaint_args["mask_image"] = crop_mask
112
- inpaint_output = self.inpaint_pipeline(**inpaint_args)
113
- inpaint_image: Image.Image = inpaint_output[0][0]
114
- final_image = composite(
115
- init=init_image,
116
- mask=mask,
117
- gen=inpaint_image,
118
- bbox_padded=bbox_padded,
119
- )
120
- init_image = final_image
121
-
122
- if final_image is not None:
123
- final_images.append(final_image)
124
-
125
- return ADOutput(images=final_images, init_images=init_images)
126
-
127
- @property
128
- def default_detector(self) -> Callable[..., list[Image.Image] | None]:
129
- return yolo_detector
130
-
131
- def _get_txt2img_args(
132
- self, common: Mapping[str, Any], txt2img_only: Mapping[str, Any]
133
- ):
134
- return {**common, **txt2img_only, "output_type": "pil"}
135
-
136
- def _get_inpaint_args(
137
- self, common: Mapping[str, Any], inpaint_only: Mapping[str, Any]
138
- ):
139
- common = dict(common)
140
- sig = inspect.signature(self.inpaint_pipeline)
141
- if (
142
- "control_image" in sig.parameters
143
- and "control_image" not in common
144
- and "image" in common
145
- ):
146
- common["control_image"] = common.pop("image")
147
- return {
148
- **common,
149
- **inpaint_only,
150
- "num_images_per_prompt": 1,
151
- "output_type": "pil",
152
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
asdff/sd.py DELETED
@@ -1,51 +0,0 @@
1
- from __future__ import annotations
2
-
3
- from functools import cached_property
4
-
5
- from diffusers import (
6
- StableDiffusionControlNetInpaintPipeline,
7
- StableDiffusionControlNetPipeline,
8
- StableDiffusionInpaintPipeline,
9
- StableDiffusionPipeline,
10
- )
11
-
12
- from asdff.base import AdPipelineBase
13
-
14
-
15
- class AdPipeline(AdPipelineBase, StableDiffusionPipeline):
16
- @cached_property
17
- def inpaint_pipeline(self):
18
- return StableDiffusionInpaintPipeline(
19
- vae=self.vae,
20
- text_encoder=self.text_encoder,
21
- tokenizer=self.tokenizer,
22
- unet=self.unet,
23
- scheduler=self.scheduler,
24
- safety_checker=self.safety_checker,
25
- feature_extractor=self.feature_extractor,
26
- requires_safety_checker=self.config.requires_safety_checker,
27
- )
28
-
29
- @property
30
- def txt2img_class(self):
31
- return StableDiffusionPipeline
32
-
33
-
34
- class AdCnPipeline(AdPipelineBase, StableDiffusionControlNetPipeline):
35
- @cached_property
36
- def inpaint_pipeline(self):
37
- return StableDiffusionControlNetInpaintPipeline(
38
- vae=self.vae,
39
- text_encoder=self.text_encoder,
40
- tokenizer=self.tokenizer,
41
- unet=self.unet,
42
- controlnet=self.controlnet,
43
- scheduler=self.scheduler,
44
- safety_checker=self.safety_checker,
45
- feature_extractor=self.feature_extractor,
46
- requires_safety_checker=self.config.requires_safety_checker,
47
- )
48
-
49
- @property
50
- def txt2img_class(self):
51
- return StableDiffusionControlNetPipeline
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
asdff/utils.py DELETED
@@ -1,70 +0,0 @@
1
- from __future__ import annotations
2
-
3
- from dataclasses import dataclass
4
-
5
- import cv2
6
- import numpy as np
7
- from diffusers.utils import BaseOutput
8
- from PIL import Image, ImageFilter, ImageOps
9
-
10
-
11
- @dataclass
12
- class ADOutput(BaseOutput):
13
- images: list[Image.Image]
14
- init_images: list[Image.Image]
15
-
16
-
17
- def mask_dilate(image: Image.Image, value: int = 4) -> Image.Image:
18
- if value <= 0:
19
- return image
20
-
21
- arr = np.array(image)
22
- kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (value, value))
23
- dilated = cv2.dilate(arr, kernel, iterations=1)
24
- return Image.fromarray(dilated)
25
-
26
-
27
- def mask_gaussian_blur(image: Image.Image, value: int = 4) -> Image.Image:
28
- if value <= 0:
29
- return image
30
-
31
- blur = ImageFilter.GaussianBlur(value)
32
- return image.filter(blur)
33
-
34
-
35
- def bbox_padding(
36
- bbox: tuple[int, int, int, int], image_size: tuple[int, int], value: int = 32
37
- ) -> tuple[int, int, int, int]:
38
- if value <= 0:
39
- return bbox
40
-
41
- arr = np.array(bbox).reshape(2, 2)
42
- arr[0] -= value
43
- arr[1] += value
44
- arr = np.clip(arr, (0, 0), image_size)
45
- return tuple(arr.flatten())
46
-
47
-
48
- def composite(
49
- init: Image.Image,
50
- mask: Image.Image,
51
- gen: Image.Image,
52
- bbox_padded: tuple[int, int, int, int],
53
- ) -> Image.Image:
54
- img_masked = Image.new("RGBa", init.size)
55
- img_masked.paste(
56
- init.convert("RGBA").convert("RGBa"),
57
- mask=ImageOps.invert(mask),
58
- )
59
- img_masked = img_masked.convert("RGBA")
60
-
61
- size = (
62
- bbox_padded[2] - bbox_padded[0],
63
- bbox_padded[3] - bbox_padded[1],
64
- )
65
- resized = gen.resize(size)
66
-
67
- output = Image.new("RGBA", init.size)
68
- output.paste(resized, bbox_padded)
69
- output.alpha_composite(img_masked)
70
- return output.convert("RGB")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
asdff/yolo.py DELETED
@@ -1,80 +0,0 @@
1
- from __future__ import annotations
2
-
3
- from pathlib import Path
4
-
5
- import numpy as np
6
- import torch
7
- from huggingface_hub import hf_hub_download
8
- from PIL import Image, ImageDraw
9
- from torchvision.transforms.functional import to_pil_image
10
-
11
- try:
12
- from ultralytics import YOLO
13
- except ModuleNotFoundError:
14
- print("Please install ultralytics using `pip install ultralytics`")
15
- raise
16
-
17
-
18
- def create_mask_from_bbox(
19
- bboxes: np.ndarray, shape: tuple[int, int]
20
- ) -> list[Image.Image]:
21
- """
22
- Parameters
23
- ----------
24
- bboxes: list[list[float]]
25
- list of [x1, y1, x2, y2]
26
- bounding boxes
27
- shape: tuple[int, int]
28
- shape of the image (width, height)
29
-
30
- Returns
31
- -------
32
- masks: list[Image.Image]
33
- A list of masks
34
-
35
- """
36
- masks = []
37
- for bbox in bboxes:
38
- mask = Image.new("L", shape, "black")
39
- mask_draw = ImageDraw.Draw(mask)
40
- mask_draw.rectangle(bbox, fill="white")
41
- masks.append(mask)
42
- return masks
43
-
44
-
45
- def mask_to_pil(masks: torch.Tensor, shape: tuple[int, int]) -> list[Image.Image]:
46
- """
47
- Parameters
48
- ----------
49
- masks: torch.Tensor, dtype=torch.float32, shape=(N, H, W).
50
- The device can be CUDA, but `to_pil_image` takes care of that.
51
-
52
- shape: tuple[int, int]
53
- (width, height) of the original image
54
-
55
- Returns
56
- -------
57
- images: list[Image.Image]
58
- """
59
- n = masks.shape[0]
60
- return [to_pil_image(masks[i], mode="L").resize(shape) for i in range(n)]
61
-
62
-
63
- def yolo_detector(
64
- image: Image.Image, model_path: str | Path | None = None, confidence: float = 0.3
65
- ) -> list[Image.Image] | None:
66
- if not model_path:
67
- model_path = hf_hub_download("Bingsu/adetailer", "face_yolov8n.pt")
68
- model = YOLO(model_path)
69
- pred = model(image, conf=confidence)
70
-
71
- bboxes = pred[0].boxes.xyxy.cpu().numpy()
72
- if bboxes.size == 0:
73
- return None
74
-
75
- if pred[0].masks is None:
76
- masks = create_mask_from_bbox(bboxes, image.size)
77
- else:
78
- masks = mask_to_pil(pred[0].masks.data, image.size)
79
-
80
- return masks
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
pipeline.py DELETED
@@ -1 +0,0 @@
1
- from asdff import AdCnPipeline # noqa: F401