| """ | |
| Source url: https://github.com/OPHoperHPO/image-background-remove-tool | |
| Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. | |
| License: Apache License 2.0 | |
| """ | |
| from PIL import Image | |
| from carvekit.trimap.cv_gen import CV2TrimapGenerator | |
| from carvekit.trimap.add_ops import prob_filter, prob_as_unknown_area, post_erosion | |
class TrimapGenerator(CV2TrimapGenerator):
    """
    Trimap generator that combines CV2TrimapGenerator with probability filters.

    The parent CV2TrimapGenerator is always constructed with ``erosion_iters=0``.
    The erosion configured on this class is applied separately, as the final
    step of ``__call__``, via ``post_erosion``.
    """

    def __init__(
        self, prob_threshold: int = 231, kernel_size: int = 30, erosion_iters: int = 5
    ):
        """
        Initialize a TrimapGenerator instance

        Args:
            prob_threshold: Probability threshold passed to both the
                prob_filter and prob_as_unknown_area operations
            kernel_size: The size of the offset from the object mask
                in pixels when an unknown area is detected in the trimap;
                forwarded to CV2TrimapGenerator
            erosion_iters: Number of erosion iterations passed to post_erosion,
                which is applied to the trimap after prob_as_unknown_area has
                formed the unknown area (the parent generator itself is
                constructed with erosion_iters=0)
        """
        # Erosion is deliberately disabled in the parent; it is applied once,
        # at the end of __call__, via post_erosion.
        super().__init__(kernel_size, erosion_iters=0)
        self.prob_threshold = prob_threshold
        # Double underscore => stored as `_TrimapGenerator__erosion_iters`,
        # presumably to keep it separate from the parent's own erosion setting.
        self.__erosion_iters = erosion_iters

    def __call__(self, original_image: Image.Image, mask: Image.Image) -> Image.Image:
        """
        Generates trimap based on predicted object mask to refine object mask borders.
        Based on cv2 erosion algorithm and additional prob. filters.

        Args:
            original_image: Original image
            mask: Predicted object mask

        Returns:
            Generated trimap for image.
        """
        # 1. Filter the predicted mask by probability before the cv2 pass.
        filter_mask = prob_filter(mask=mask, prob_threshold=self.prob_threshold)
        # 2. Build the base trimap with the parent generator (its erosion is
        #    disabled in __init__).
        trimap = super().__call__(original_image, filter_mask)
        # 3. Re-use the *unfiltered* mask with the same threshold; presumably
        #    this marks low-confidence pixels as unknown (confirm in add_ops).
        new_trimap = prob_as_unknown_area(
            trimap=trimap, mask=mask, prob_threshold=self.prob_threshold
        )
        # 4. Apply the configured erosion as the last step.
        return post_erosion(new_trimap, self.__erosion_iters)