# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Type, Union

import cv2
import numpy as np
import torch

if TYPE_CHECKING:
    from matplotlib.backends.backend_agg import FigureCanvasAgg


def tensor2ndarray(value: Union[np.ndarray, torch.Tensor]) -> np.ndarray:
    """If the type of value is torch.Tensor, convert the value to np.ndarray.

    A tensor is detached from the autograd graph and moved to CPU before
    conversion. Any other input is returned unchanged.

    Args:
        value (np.ndarray, torch.Tensor): value.

    Returns:
        np.ndarray: ``value`` converted to a NumPy array if it was a
        ``torch.Tensor``, otherwise ``value`` itself.
    """
    if isinstance(value, torch.Tensor):
        value = value.detach().cpu().numpy()
    return value


def value2list(value: Any, valid_type: Union[Type, Tuple[Type, ...]],
               expand_dim: int) -> List[Any]:
    """If the type of ``value`` is ``valid_type``, convert the value to list
    and expand to ``expand_dim``.

    Args:
        value (Any): value.
        valid_type (Union[Type, Tuple[Type, ...]]): valid type.
        expand_dim (int): expand dim.

    Returns:
        List[Any]: ``[value] * expand_dim`` if ``value`` is an instance of
        ``valid_type`` (every element is the same object); otherwise
        ``value`` is returned unchanged.
    """
    if isinstance(value, valid_type):
        value = [value] * expand_dim
    return value


def check_type(name: str, value: Any,
               valid_type: Union[Type, Tuple[Type, ...]]) -> None:
    """Check whether the type of value is in ``valid_type``.

    Args:
        name (str): value name.
        value (Any): value.
        valid_type (Type, Tuple[Type, ...]): expected type.

    Raises:
        TypeError: If ``value`` is not an instance of ``valid_type``.
    """
    if not isinstance(value, valid_type):
        raise TypeError(f'`{name}` should be {valid_type} '
                        f'but got {type(value)}')


def check_length(name: str, value: Any, valid_length: int) -> None:
    """If type of the ``value`` is list, check whether its length is equal
    with or greater than ``valid_length``.

    Values that are not lists are accepted without any check.

    Args:
        name (str): value name.
        value (Any): value.
        valid_length (int): expected length.

    Raises:
        AssertionError: If ``value`` is a list shorter than
            ``valid_length``.
    """
    if isinstance(value, list) and len(value) < valid_length:
        raise AssertionError(
            f'The length of {name} must equal with or '
            f'greater than {valid_length}, but got {len(value)}')


def check_type_and_length(name: str, value: Any,
                          valid_type: Union[Type, Tuple[Type, ...]],
                          valid_length: int) -> None:
    """Check whether the type of value is in ``valid_type``. If type of the
    ``value`` is list, check whether its length is equal with or greater than
    ``valid_length``.

    Args:
        name (str): value name, used in error messages.
        value (Any): value.
        valid_type (Type, Tuple[Type, ...]): expected type.
        valid_length (int): expected minimum length when ``value`` is a
            list.

    Raises:
        TypeError: If ``value`` is not an instance of ``valid_type``.
        AssertionError: If ``value`` is a list shorter than
            ``valid_length``.
    """
    check_type(name, value, valid_type)
    check_length(name, value, valid_length)


def color_val_matplotlib(
    colors: Union[str, tuple, List[Union[str, tuple]]]
) -> Union[str, tuple, List[Union[str, tuple]]]:
    """Convert various input in RGB order to normalized RGB matplotlib color
    tuples.

    Args:
        colors (Union[str, tuple, List[Union[str, tuple]]]): Color inputs.
            A str is passed through unchanged. A tuple must hold exactly 3
            channel values in [0, 255]. A list is converted element-wise.

    Returns:
        Union[str, tuple, List[Union[str, tuple]]]: A str input is returned
        as-is; a tuple input becomes a tuple of 3 floats in [0, 1]; a list
        input becomes a list of converted elements.

    Raises:
        AssertionError: If a tuple does not have 3 channels or a channel is
            outside [0, 255].
        TypeError: If ``colors`` (or a list element) has any other type.
    """
    if isinstance(colors, str):
        return colors
    if isinstance(colors, tuple):
        assert len(colors) == 3, (
            f'A color tuple must have 3 channels, but got {len(colors)}')
        for channel in colors:
            assert 0 <= channel <= 255, (
                f'Color channels must be in [0, 255], but got {channel}')
        return tuple(channel / 255 for channel in colors)
    if isinstance(colors, list):
        return [
            color_val_matplotlib(color)  # type:ignore
            for color in colors
        ]
    raise TypeError(f'Invalid type for color: {type(colors)}')


def color_str2rgb(color: str) -> tuple:
    """Convert Matplotlib str color to an RGB color which range is 0 to 255,
    silently dropping the alpha channel.

    Each normalized channel is scaled by 255 and truncated with ``int``.

    Args:
        color (str): Matplotlib color.

    Returns:
        tuple: RGB color.
    """
    # Import the submodule explicitly instead of relying on the top-level
    # ``matplotlib`` package having already imported ``matplotlib.colors``.
    from matplotlib.colors import to_rgb
    rgb_color: tuple = to_rgb(color)
    rgb_color = tuple(int(c * 255) for c in rgb_color)
    return rgb_color


def convert_overlay_heatmap(feat_map: Union[np.ndarray, torch.Tensor],
                            img: Optional[np.ndarray] = None,
                            alpha: float = 0.5) -> np.ndarray:
    """Convert feat_map to heatmap and overlay on image, if image is not None.

    The feature map is min-max normalized to [0, 255], cast to ``uint8``,
    colored with OpenCV's JET colormap and converted from BGR to RGB. If
    ``img`` is given, the result is ``img * (1 - alpha) + heatmap * alpha``.

    Args:
        feat_map (np.ndarray, torch.Tensor): The feat_map to convert, of
            shape (H, W), (1, H, W) or (3, H, W), where H is the image height
            and W is the image width. Channel-first inputs are transposed to
            channel-last before conversion.
        img (np.ndarray, optional): The origin image. The format should be
            RGB. Defaults to None. NOTE(review): ``cv2.addWeighted`` expects
            it to match the heatmap's size, channel count and depth —
            confirm callers pass (H, W, 3) ``uint8``.
        alpha (float): The transparency of featmap. Defaults to 0.5.

    Returns:
        np.ndarray: heatmap
    """
    assert feat_map.ndim == 2 or (feat_map.ndim == 3
                                  and feat_map.shape[0] in [1, 3])
    feat_map = tensor2ndarray(feat_map)
    if feat_map.ndim == 3:
        # (C, H, W) -> (H, W, C), the layout OpenCV expects.
        feat_map = feat_map.transpose(1, 2, 0)
    norm_img = np.zeros(feat_map.shape)
    norm_img = cv2.normalize(feat_map, norm_img, 0, 255, cv2.NORM_MINMAX)
    norm_img = np.asarray(norm_img, dtype=np.uint8)
    heat_img = cv2.applyColorMap(norm_img, cv2.COLORMAP_JET)
    # applyColorMap produces BGR; convert to match the RGB ``img``.
    heat_img = cv2.cvtColor(heat_img, cv2.COLOR_BGR2RGB)
    if img is not None:
        heat_img = cv2.addWeighted(img, 1 - alpha, heat_img, alpha, 0)
    return heat_img


def wait_continue(figure, timeout: float = 0, continue_key: str = ' ') -> int:
    """Show the image and wait for the user's input.

    This implementation refers to
    https://github.com/matplotlib/matplotlib/blob/v3.5.x/lib/matplotlib/_blocking_input.py

    Args:
        figure: The Matplotlib figure to show and whose canvas receives the
            key-press and close events.
        timeout (float): If positive, continue after ``timeout`` seconds.
            Defaults to 0.
        continue_key (str): The key for users to continue. Defaults to the
            space key.

    Returns:
        int: If zero, means time out or the user pressed ``continue_key``,
        and if one, means the user closed the show figure. Always zero when
        an inline backend is active.
    """  # noqa: E501
    import matplotlib.pyplot as plt
    from matplotlib.backend_bases import CloseEvent
    is_inline = 'inline' in plt.get_backend()
    if is_inline:
        # If use inline backend, interactive input and timeout is no use.
        return 0

    if figure.canvas.manager:  # type: ignore
        # Ensure that the figure is shown
        figure.show()  # type: ignore

    while True:

        # Connect the events to the handler function call.
        event = None

        def handler(ev):
            # Set external event variable
            nonlocal event
            # Qt backend may fire two events at the same time,
            # use a condition to avoid missing close event.
            event = ev if not isinstance(event, CloseEvent) else event
            figure.canvas.stop_event_loop()

        cids = [
            figure.canvas.mpl_connect(name, handler)  # type: ignore
            for name in ('key_press_event', 'close_event')
        ]

        try:
            figure.canvas.start_event_loop(timeout)  # type: ignore
        finally:  # Run even on exception like ctrl-c.
            # Disconnect the callbacks.
            for cid in cids:
                figure.canvas.mpl_disconnect(cid)  # type: ignore

        if isinstance(event, CloseEvent):
            return 1  # Quit for close.
        elif event is None or event.key == continue_key:
            return 0  # Quit for continue.
        # Any other key press: keep waiting.


def img_from_canvas(canvas: 'FigureCanvasAgg') -> np.ndarray:
    """Get RGB image from ``FigureCanvasAgg``.

    Args:
        canvas (FigureCanvasAgg): The canvas to get image.

    Returns:
        np.ndarray: The image in RGB, of shape (height, width, 3) and dtype
        ``uint8``; the alpha channel of the rendered RGBA buffer is dropped.
    """
    s, (width, height) = canvas.print_to_buffer()
    buffer = np.frombuffer(s, dtype='uint8')
    img_rgba = buffer.reshape(height, width, 4)
    # ``np.frombuffer`` may return a read-only view (e.g. over ``bytes``);
    # ``astype`` copies the RGB channels into a fresh, writable array.
    return img_rgba[..., :3].astype('uint8')