# -*- coding: utf-8 -*-
# @Organization  : Alibaba XR-Lab
# @Author        : Lingteng Qiu
# @Email         : 220019047@link.cuhk.edu.cn
# @Time          : 2024-08-30 20:50:27
# @Function      : The class defines bbox, base-seg module

import copy

import numpy as np
import torch

try:
    import cv2
except ImportError:
    # OpenCV is only needed by ``Image.read_image``; everything else in this
    # module stays usable without it.
    cv2 = None


class BaseModel(object):
    """
    Simple BaseModel

    Forwards device / dtype / mode switches to ``self.model`` and returns
    ``self`` so that calls can be chained.  This class never assigns
    ``self.model``; it has to be set elsewhere (presumably by subclasses)
    before any of these methods, including ``__repr__``, is used.
    """

    def cuda(self):
        """Call ``self.model.cuda()`` and return ``self``."""
        self.model.cuda()
        return self

    def cpu(self):
        """Call ``self.model.cpu()`` and return ``self``."""
        self.model.cpu()
        return self

    def float(self):
        """Call ``self.model.float()`` and return ``self``."""
        self.model.float()
        return self

    def to(self, device):
        """Call ``self.model.to(device)`` and return ``self``."""
        self.model.to(device)
        return self

    def eval(self):
        """Call ``self.model.eval()`` and return ``self``."""
        self.model.eval()
        return self

    def train(self):
        """Call ``self.model.train()`` and return ``self``."""
        self.model.train()
        return self

    def __call__(self, x):
        """Run the model on ``x``; must be implemented by subclasses."""
        raise NotImplementedError

    def __repr__(self):
        return f"model: \n{self.model}"


def get_dtype_string(arr):
    """Return the type-mode name used by ``Image`` for ``arr.dtype``.

    Args:
        arr: any object with a ``dtype`` attribute comparable to numpy dtypes.

    Returns:
        str: ``"uint8"``, ``"float32"``, ``"float"`` (for float64) or
        ``"unknow"`` for every other dtype.
    """
    if arr.dtype == np.uint8:
        return "uint8"
    if arr.dtype == np.float32:
        return "float32"
    if arr.dtype == np.float64:
        return "float"
    # NOTE: the spelling is part of the runtime contract; do not "fix" it.
    return "unknow"


class BaseSeg(BaseModel):
    """Base class for segmentation models; its constructor does nothing."""

    def __init__(self):
        pass


class Bbox:
    """Axis-aligned box stored in one of two layouts.

    * ``"whwh"``: ``[left, top, right, bottom]``
    * ``"xywh"``: ``[center_x, center_y, width, height]``
    """

    def __init__(self, box, mode="whwh"):
        assert len(box) == 4, "box must have exactly 4 values"
        assert mode in ["whwh", "xywh"], "mode must be 'whwh' or 'xywh'"
        self.box = box
        self.mode = mode

    def to_xywh(self):
        """Return the box in center/size layout (``self`` if already so)."""
        if self.mode != "whwh":
            return self

        l, t, r, b = self.box
        center_x = (l + r) / 2
        center_y = (t + b) / 2
        width = r - l
        height = b - t
        return Bbox([center_x, center_y, width, height], mode="xywh")

    def to_whwh(self):
        """Return the box in corner layout (``self`` if already so)."""
        if self.mode == "whwh":
            return self

        cx, cy, w, h = self.box
        # Floor-divide so that right - left == w (and bottom - top == h) even
        # for odd sizes.
        l = cx - w // 2
        t = cy - h // 2
        r = cx + w - (w // 2)
        b = cy + h - (h // 2)
        return Bbox([l, t, r, b], mode="whwh")

    def area(self):
        """Return ``width * height`` of the box."""
        _, __, w, h = self.to_xywh().box
        return w * h

    def get_box(self):
        """Return the stored four values truncated to ``int``."""
        return list(map(int, self.box))

    def scale(self, scale, width, height):
        """Scale the box around its center and clamp it to an image.

        Args:
            scale: multiplicative factor applied to width and height.
            width: upper bound for the right edge.
            height: upper bound for the bottom edge.

        Returns:
            Bbox: a new ``"whwh"`` box with integer coordinates; left/top are
            clamped to ``>= 0`` and right/bottom to ``<= width`` / ``height``.
        """
        # Center and size are truncated to int before scaling.
        cx, cy, w, h = self.to_xywh().get_box()
        w = w * scale
        h = h * scale

        l = cx - w // 2
        t = cy - h // 2
        r = cx + w - (w // 2)
        b = cy + h - (h // 2)

        l = int(max(l, 0))
        t = int(max(t, 0))
        r = int(min(r, width))
        b = int(min(b, height))
        return Bbox([l, t, r, b], mode="whwh")

    def __repr__(self):
        l, t, r, b = self.to_whwh().box
        return f"BBox(left={l}, top={t}, right={r}, bottom={b})"


# numpy dtype behind every type-mode name understood by ``Image``.
_TYPE_MODE_TO_DTYPE = {
    "uint8": np.uint8,
    "float32": np.float32,
    "float": np.float64,
}


def _convert_dtype(data, type_mode):
    """Convert ``data`` to the dtype named by ``type_mode``.

    uint8 images live in [0, 255], float images in [0, 1].  Non-float sources
    are assumed to be in [0, 255].  Returns ``data`` itself when it already
    has the requested dtype.

    Raises:
        ValueError: if ``type_mode`` is not a known type-mode name.
    """
    if type_mode not in _TYPE_MODE_TO_DTYPE:
        raise ValueError(f"Unknown image type mode {type_mode}")

    target = _TYPE_MODE_TO_DTYPE[type_mode]
    if data.dtype == target:
        return data

    source_is_float = np.issubdtype(data.dtype, np.floating)
    if type_mode == "uint8":
        if source_is_float:
            # Round instead of truncating so uint8 -> float -> uint8 is exact.
            return np.clip(np.rint(data * 255.0), 0, 255).astype(np.uint8)
        return np.clip(data, 0, 255).astype(np.uint8)

    if source_is_float:
        return data.astype(target)
    return (data / 255.0).astype(target)


class Image:
    """Numpy-backed image with a channel order and a value type.

    Attributes set by the constructor:
        data: the pixel array, already converted to ``order`` / ``type_mode``.
        order: channel order of ``data`` (see ``ORDER``).
        type_mode: value type of ``data`` (see ``TYPE_ORDER``).
    """

    TYPE_ORDER = ["uint8", "float32", "float"]
    ORDER = ["RGB", "BGR"]
    MODE = ["numpy"]

    def __init__(self, input, order="RGB", type_mode="uint8"):
        """Only support 3 Channel Image

        Args:
            input: a file path (``str``), another ``Image`` or a numpy array.
                A numpy array is assumed to be in RGB order.
            order: channel order to store the image in.
            type_mode: value type to store the image in.
        """
        if isinstance(input, str):
            self.data = self.read_image(input, type_mode, order)
        else:
            self.data = self.get_image(input, type_mode, order)

        self.order = order
        self.type_mode = type_mode

    def get_image(self, input, type_mode, order):
        """Return ``input`` as a numpy array in ``order`` / ``type_mode``.

        For a numpy ``input`` this overwrites ``self.data``, ``self.order``
        and ``self.type_mode`` with the raw array and its assumed description
        (RGB, dtype-derived type mode) before converting.

        Raises:
            NotImplementedError: if ``input`` is neither ``Image`` nor
                ``np.ndarray``.
        """
        if isinstance(input, Image):
            return input.to_numpy(type_mode, order)

        if isinstance(input, np.ndarray):
            self.data = input
            self.order = "RGB"  # default
            self.type_mode = get_dtype_string(input)
            return self.to_numpy(type_mode, order)

        raise NotImplementedError

    def to_numpy(self, type_mode="uint8", order="RGB"):
        """Return a converted copy of the image; ``self.data`` is untouched.

        Args:
            type_mode: one of ``TYPE_ORDER``.
            order: requested channel order; when it differs from
                ``self.order`` the channel axis is reversed (only RGB <-> BGR
                is supported).

        Returns:
            np.ndarray: a C-contiguous copy in the requested order and type.

        Raises:
            ValueError: if ``type_mode`` is unknown.
        """
        data = copy.deepcopy(self.data)

        # Arrays without a channel axis (grayscale) have nothing to reorder.
        if order != self.order and data.ndim >= 3:
            # The reversed view has negative strides, which
            # ``torch.from_numpy`` rejects; materialise it.
            data = np.ascontiguousarray(data[..., ::-1])

        return _convert_dtype(data, type_mode)

    def to_tensor(self, order):
        """Return the image as a float32 torch tensor in ``order``."""
        data = self.to_numpy(type_mode="float32", order=order)
        return torch.from_numpy(data)

    def read_image(
        self,
        path,
        mode,
        order,
    ):
        """read an image file into various formats and color mode.

        Args:
            path (str): path to the image file.
            mode (Literal["uint8", "float32", "float"]): returned image format.
                uint8: uint8 numpy array, range [0, 255];
                float32 / float: float32 numpy array, range [0, 1].
                NOTE: "float" yields float32 here although ``to_numpy`` maps
                "float" to float64; kept for backward compatibility.
            order (Literal["RGB", "RGBA", "BGR", "BGRA"]): channel order.

        Note:
            A 4-channel image is blended onto a white background unless
            ``order`` contains ``"A"``.

        Returns:
            np.ndarray: the image array.

        Raises:
            NotImplementedError: for ``mode="pil"`` (PIL is not available in
                this module).
            ValueError: for any other unknown ``mode``.
            ImportError: if OpenCV is not installed.
            FileNotFoundError: if OpenCV cannot read ``path``.
        """
        if mode == "pil":
            # ``Image`` in this module is this class, not ``PIL.Image``.
            raise NotImplementedError(
                "read_image mode 'pil' is not supported by this Image class"
            )
        if mode not in ("uint8", "float32", "float"):
            raise ValueError(f"Unknown read_image mode {mode}")
        if cv2 is None:
            raise ImportError("OpenCV (cv2) is required to read image files")

        img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
        if img is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f"Could not read image file: {path}")

        # Normalise non-uint8 integer images (e.g. 16-bit) to float [0, 1].
        if np.issubdtype(img.dtype, np.integer) and img.dtype != np.uint8:
            img = img.astype(np.float32) / np.iinfo(img.dtype).max

        # cvtColor
        if img.ndim == 3:  # ignore if gray scale
            channels = img.shape[-1]
            if order in ["RGB", "RGBA"]:
                if channels == 4:
                    img = cv2.cvtColor(img, cv2.COLOR_BGRA2RGBA)
                elif channels == 3:
                    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

            # mix background
            if channels == 4 and "A" not in order:
                if img.dtype == np.uint8:
                    img = img.astype(np.float32) / 255
                alpha = img[..., 3:]
                img = img[..., :3] * alpha + (1 - alpha)

        # mode
        if mode == "uint8":
            return _convert_dtype(img, "uint8")
        return _convert_dtype(img, "float32")