# pdiscoformer/utils/data_utils/dataset_utils.py
from PIL import Image
from torch import Tensor
from typing import List, Optional
import numpy as np
import torchvision
import json
def load_json(path: str):
    """
    Read a JSON file from disk and return its parsed contents.
    :param path: Path to the json file
    :return:
    data: Data in the json file
    """
    with open(path, 'r') as json_file:
        return json.load(json_file)
def save_json(data: dict, path: str):
    """
    Serialize a dictionary to disk as a JSON file.
    :param data: Data to be saved
    :param path: Path to save the data
    :return:
    """
    with open(path, "w") as out_file:
        json.dump(data, out_file)
def pil_loader(path):
    """
    Read an image file with PIL and return it converted to RGB mode.
    :param path: Path to the image
    :return:
    img: PIL Image in RGB mode
    """
    with open(path, 'rb') as image_file:
        loaded = Image.open(image_file)
        return loaded.convert('RGB')
def get_dimensions(image):
    """
    Get the spatial dimensions (height, width) of an image.
    :param image: Tensor of shape (..., H, W), np.ndarray of shape (H, W) or
        (H, W, C), or a PIL Image
    :return:
    h: Height of the image
    w: Width of the image
    :raises ValueError: if the input is not a Tensor, ndarray or PIL Image
    """
    if isinstance(image, Tensor):
        # Tensors are channels-first; use the trailing dims so 2D (H, W),
        # 3D (C, H, W) and batched (N, C, H, W) tensors all work.
        h, w = image.shape[-2], image.shape[-1]
    elif isinstance(image, np.ndarray):
        # numpy images are channels-last; the leading dims also cover
        # 2D grayscale arrays with no channel axis.
        h, w = image.shape[0], image.shape[1]
    elif isinstance(image, Image.Image):
        # PIL reports (width, height), so swap into (h, w) order.
        w, h = image.size
    else:
        raise ValueError(f"Invalid image type: {type(image)}")
    return h, w
def center_crop_boxes_kps(img: Tensor, output_size: Optional[List[int]] = 448, parts: Optional[Tensor] = None,
                          boxes: Optional[Tensor] = None, num_keypoints: int = 15):
    """
    Center-crop the image and shift its landmark and bounding-box annotations
    into the cropped image's coordinate frame.
    :param img: Image tensor; presumably (C, H, W) — get_dimensions unpacks 3 dims
    :param output_size: Output size of the crop; an int or a 1/2-element tuple/list (height, width)
    :param parts: Locations of the landmarks of following format: <part_id> <x> <y> <visible>;
        modified in place
    :param boxes: Bounding boxes of the landmarks of following format:
        <image_id> <x> <y> <width> <height>; modified in place
    :param num_keypoints: Number of keypoint rows in `parts` to process
    :return:
    cropped_img: Center cropped image
    parts: Updated locations of the landmarks (same object that was passed in)
    boxes: Updated bounding boxes of the landmarks (same object that was passed in)
    :raises ValueError: if output_size is not an int or a 1/2-element tuple/list
    """
    # Normalize output_size to a (height, width) pair.
    if isinstance(output_size, int):
        output_size = (output_size, output_size)
    elif isinstance(output_size, (tuple, list)) and len(output_size) == 1:
        output_size = (output_size[0], output_size[0])
    elif isinstance(output_size, (tuple, list)) and len(output_size) == 2:
        output_size = output_size
    else:
        raise ValueError(f"Invalid output size: {output_size}")
    crop_height, crop_width = output_size
    image_height, image_width = get_dimensions(img)
    img = torchvision.transforms.functional.center_crop(img, output_size)
    # Offsets of the crop within the original image; annotations are shifted by them below.
    crop_top, crop_left = _get_center_crop_params_(image_height, image_width, output_size)
    if parts is not None:
        for j in range(num_keypoints):
            # Skip if part is invisible
            if parts[j][-1] == 0:
                continue
            # Translate the keypoint into the crop's coordinate frame.
            parts[j][1] -= crop_left
            parts[j][2] -= crop_top
            # Mark the part invisible if it landed outside the crop
            # (checked before clamping, so the flag reflects the true position).
            if parts[j][1] > crop_width or parts[j][2] > crop_height:
                parts[j][-1] = 0
            if parts[j][1] < 0 or parts[j][2] < 0:
                parts[j][-1] = 0
            # Clamp coordinates into [0, crop_size], including for parts just marked invisible.
            parts[j][1] = min(crop_width, parts[j][1])
            parts[j][2] = min(crop_height, parts[j][2])
            parts[j][1] = max(0, parts[j][1])
            parts[j][2] = max(0, parts[j][2])
    if boxes is not None:
        # Shift the box's top-left corner and clamp it into the crop.
        # NOTE(review): only x/y are adjusted; width/height (boxes[3], boxes[4]) are
        # untouched, so a box straddling the crop edge can still extend past it — confirm intended.
        boxes[1] -= crop_left
        boxes[2] -= crop_top
        boxes[1] = max(0, boxes[1])
        boxes[2] = max(0, boxes[2])
        boxes[1] = min(crop_width, boxes[1])
        boxes[2] = min(crop_height, boxes[2])
    return img, parts, boxes
def _get_center_crop_params_(image_height: int, image_width: int, output_size: Optional[List[int]] = 448):
"""
Get the parameters for center cropping the image
:param image_height: Height of the image
:param image_width: Width of the image
:param output_size: Output size of the cropped image
:return:
crop_top: Top coordinate of the cropped image
crop_left: Left coordinate of the cropped image
"""
if isinstance(output_size, int):
output_size = (output_size, output_size)
elif isinstance(output_size, (tuple, list)) and len(output_size) == 1:
output_size = (output_size[0], output_size[0])
elif isinstance(output_size, (tuple, list)) and len(output_size) == 2:
output_size = output_size
else:
raise ValueError(f"Invalid output size: {output_size}")
crop_height, crop_width = output_size
if crop_width > image_width or crop_height > image_height:
padding_ltrb = [
(crop_width - image_width) // 2 if crop_width > image_width else 0,
(crop_height - image_height) // 2 if crop_height > image_height else 0,
(crop_width - image_width + 1) // 2 if crop_width > image_width else 0,
(crop_height - image_height + 1) // 2 if crop_height > image_height else 0,
]
crop_top, crop_left = padding_ltrb[1], padding_ltrb[0]
return crop_top, crop_left
if crop_width == image_width and crop_height == image_height:
crop_top = 0
crop_left = 0
return crop_top, crop_left
crop_top = int(round((image_height - crop_height) / 2.0))
crop_left = int(round((image_width - crop_width) / 2.0))
return crop_top, crop_left