File size: 1,803 Bytes
b55d767
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
from __future__ import annotations

import numpy as np
import torch


class XYMasking:
    """Randomly overwrite vertical and horizontal stripes of an image tensor.

    "x" masks are stripes along the last dimension (columns), "y" masks are
    stripes along the second-to-last dimension (rows). Every masked element is
    set to ``fill_value``.

    Args:
        num_masks_x: Number of column stripes, or a ``(low, high)`` tuple from
            which the number is drawn on every call.
        num_masks_y: Number of row stripes, or a ``(low, high)`` tuple.
        mask_x_length: Width of each column stripe, or a ``(low, high)`` tuple
            from which the width is drawn independently for every stripe.
        mask_y_length: Height of each row stripe, or a ``(low, high)`` tuple.
        fill_value: Value written into the masked regions.
        p: Probability of applying the transform on a given call.

    Note:
        Tuple ranges are passed straight to ``np.random.randint``, so ``high``
        is exclusive and must be greater than ``low``.
    """

    def __init__(
        self,
        num_masks_x: int | tuple[int, int],
        num_masks_y: int | tuple[int, int],
        mask_x_length: int | tuple[int, int],
        mask_y_length: int | tuple[int, int],
        fill_value: int,
        p: float = 1.0,
    ):
        self.num_masks_x = num_masks_x
        self.num_masks_y = num_masks_y
        self.mask_x_length = mask_x_length
        self.mask_y_length = mask_y_length
        self.fill_value = fill_value
        self.p = p

    @staticmethod
    def _sample(value: int | tuple[int, int]) -> int:
        """Return ``value`` as is, or draw an int from a ``(low, high)`` tuple."""
        if isinstance(value, tuple):
            return int(np.random.randint(*value))
        return value

    @staticmethod
    def _sample_span(size: int, length: int) -> tuple[int, int]:
        """Pick a random ``(start, stop)`` span of ``length`` inside ``[0, size)``."""
        # Clamp so that an oversized mask covers the whole axis instead of
        # making np.random.randint raise on an empty range.
        length = max(0, min(length, size))
        # +1 because randint's upper bound is exclusive: the last valid start
        # position (size - length) must be reachable.
        start = int(np.random.randint(0, size - length + 1))
        return start, start + length

    def __call__(self, img: torch.Tensor) -> torch.Tensor:
        """Apply the masking to ``img`` with probability ``p``.

        Args:
            img: Tensor whose last two dimensions are ``(height, width)``,
                e.g. a ``(channels, height, width)`` image.

        Returns:
            The same tensor object; masked regions are written in place.
        """
        # rand() is in [0, 1): p=1.0 always applies, p=0.0 never applies.
        if np.random.rand() >= self.p:
            return img

        height, width = img.shape[-2:]

        # Column stripes: positions are bounded by the width (last dimension).
        for _ in range(self._sample(self.num_masks_x)):
            start, stop = self._sample_span(
                width, self._sample(self.mask_x_length)
            )
            img[..., start:stop] = self.fill_value

        # Row stripes: positions are bounded by the height.
        for _ in range(self._sample(self.num_masks_y)):
            start, stop = self._sample_span(
                height, self._sample(self.mask_y_length)
            )
            img[..., start:stop, :] = self.fill_value

        return img