File size: 5,456 Bytes
20239f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
from PIL import Image
from torch import Tensor
from typing import List, Optional
import numpy as np
import torchvision
import json


def load_json(path: str):
    """
    Load a json file from path and return its content
    :param path: Path to the json file
    :return:
    data: Deserialized content of the json file
    """
    # JSON text is UTF-8 by specification, so decode it explicitly instead of
    # relying on the platform's locale encoding (which is not UTF-8 everywhere).
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)


def save_json(data: dict, path: str):
    """
    Serialize data as json and write it to a file
    :param data: Object to serialize
    :param path: Destination file; created if missing, overwritten otherwise
    :return:
    """
    with open(path, mode="w") as out_file:
        json.dump(data, fp=out_file)

def pil_loader(path):
    """
    Read an image file with PIL and convert it to RGB
    :param path: Path to the image
    :return:
    img: PIL Image in RGB mode
    """
    # The conversion happens while the file is still open, so the pixel data
    # is read before the handle is closed.
    with open(path, mode='rb') as stream:
        return Image.open(stream).convert('RGB')


def get_dimensions(image: "Tensor | np.ndarray | Image.Image"):
    """
    Get the height and width of an image
    :param image: Tensor laid out as [..., H, W], np.ndarray laid out as [H, W] or [H, W, C],
                  or PIL Image
    :return:
    h: Height of the image
    w: Width of the image
    :raises ValueError: If the image type is not supported, or the tensor/array does not have
                        a supported number of dimensions
    """
    if isinstance(image, Tensor):
        # Height and width are the two trailing axes, so any leading
        # channel/batch axes are ignored (unpacking fails with ValueError for < 2 dims).
        h, w = image.shape[-2:]
    elif isinstance(image, np.ndarray):
        # Only channel-last images are unambiguous here: HxW (grayscale) or HxWxC.
        if image.ndim not in (2, 3):
            raise ValueError(f"Invalid image shape: {image.shape}")
        h, w = image.shape[:2]
    elif isinstance(image, Image.Image):
        # PIL reports size as (width, height)
        w, h = image.size
    else:
        raise ValueError(f"Invalid image type: {type(image)}")
    return h, w


def center_crop_boxes_kps(img: Tensor, output_size: Optional[List[int]] = 448, parts: Optional[Tensor] = None,
                          boxes: Optional[Tensor] = None, num_keypoints: int = 15):
    """
    Center crop an image and move its landmarks and bounding box into the cropped frame.
    Note that parts and boxes are modified in place (and returned as well).
    :param img: Image to crop
    :param output_size: Output size of the cropped image: an int (square crop), a sequence with one
                        element (square crop) or a sequence (height, width)
    :param parts: Locations of the landmarks of following format: <part_id> <x> <y> <visible>
    :param boxes: Bounding box of following format: <image_id> <x> <y> <width> <height>
    :param num_keypoints: Number of keypoints (rows of parts) to update
    :return:
    cropped_img: Center cropped image
    parts: Updated locations of the landmarks; landmarks that fall outside the crop are marked
           invisible and clamped to the crop border
    boxes: Updated bounding box, clipped to the crop window
    :raises ValueError: If output_size is not an int or a sequence of one or two elements
    """
    if isinstance(output_size, int):
        output_size = (output_size, output_size)
    elif isinstance(output_size, (tuple, list)) and len(output_size) == 1:
        output_size = (output_size[0], output_size[0])
    elif not (isinstance(output_size, (tuple, list)) and len(output_size) == 2):
        raise ValueError(f"Invalid output size: {output_size}")

    crop_height, crop_width = output_size
    # The original dimensions must be read before the image is replaced by its crop.
    image_height, image_width = get_dimensions(img)
    img = torchvision.transforms.functional.center_crop(img, output_size)

    # Offset of the crop window's top-left corner in original image coordinates.
    crop_top, crop_left = _get_center_crop_params_(image_height, image_width, output_size)

    if parts is not None:
        for j in range(num_keypoints):
            # Invisible landmarks carry no usable location, leave them untouched
            if parts[j][-1] == 0:
                continue

            x = parts[j][1] - crop_left
            y = parts[j][2] - crop_top

            # A landmark that ends up outside the crop is no longer visible
            if x < 0 or y < 0 or x > crop_width or y > crop_height:
                parts[j][-1] = 0

            # Keep the stored location inside the crop even for hidden landmarks
            parts[j][1] = max(0, min(crop_width, x))
            parts[j][2] = max(0, min(crop_height, y))

    if boxes is not None:
        # Work on both corners so that width/height shrink together with the
        # clamped origin; shifting only <x> <y> would leave a box that sticks
        # out of the cropped image.
        left = boxes[1] - crop_left
        top = boxes[2] - crop_top
        right = left + boxes[3]
        bottom = top + boxes[4]

        left = min(crop_width, max(0, left))
        top = min(crop_height, max(0, top))
        right = min(crop_width, max(0, right))
        bottom = min(crop_height, max(0, bottom))

        boxes[1] = left
        boxes[2] = top
        boxes[3] = right - left
        boxes[4] = bottom - top

    return img, parts, boxes


def _get_center_crop_params_(image_height: int, image_width: int, output_size: Optional[List[int]] = 448):
    """
    Get the position of the center-crop window relative to the original image.
    Subtracting the returned values from a point given in original image coordinates
    yields its coordinates in the cropped image.
    :param image_height: Height of the image
    :param image_width: Width of the image
    :param output_size: Output size of the cropped image: an int (square crop), a sequence with one
                        element (square crop) or a sequence (height, width)
    :return:
    crop_top: Top coordinate of the crop window; negative when the image is shorter than the crop
              and therefore gets padded at the top
    crop_left: Left coordinate of the crop window; negative when the image is narrower than the
               crop and therefore gets padded on the left
    :raises ValueError: If output_size is not an int or a sequence of one or two elements
    """
    if isinstance(output_size, int):
        crop_height = crop_width = output_size
    elif isinstance(output_size, (tuple, list)) and len(output_size) == 1:
        crop_height = crop_width = output_size[0]
    elif isinstance(output_size, (tuple, list)) and len(output_size) == 2:
        crop_height, crop_width = output_size
    else:
        raise ValueError(f"Invalid output size: {output_size}")

    def _axis_offset(image_len, crop_len):
        # Mirrors torchvision's center_crop, which handles each axis independently:
        # an axis shorter than the crop is first padded by (crop - image) // 2 at its
        # start, which moves the content towards larger coordinates, hence the
        # negative sign; otherwise the axis is cropped symmetrically.
        if crop_len > image_len:
            return -((crop_len - image_len) // 2)
        return int(round((image_len - crop_len) / 2.0))

    crop_top = _axis_offset(image_height, crop_height)
    crop_left = _axis_offset(image_width, crop_width)

    return crop_top, crop_left