akshit-g's picture
add : files
d3cd5c1
from typing import List, Sequence, Tuple

import numpy as np
from PIL import Image
def im_resize(
    image: Image.Image,
    size: Tuple[int, int],
    resample: int = Image.Resampling.BICUBIC,
) -> Image.Image:
    """
    Resize an image.

    Thin wrapper that forwards ``size`` and the ``resample`` filter to
    ``image.resize`` and hands back whatever that call returns.
    """
    resized = image.resize(size, resample=resample)
    return resized
def normalize(
    image: np.ndarray,
    mean: Sequence[float] = (0.5, 0.5, 0.5),
    std: Sequence[float] = (0.5, 0.5, 0.5),
) -> np.ndarray:
    """
    Normalize an image array.

    Computes ``(image - mean) / std`` element-wise. ``mean`` and ``std`` are
    converted to NumPy arrays and broadcast against ``image`` following the
    usual NumPy broadcasting rules.

    Args:
        image: Array of pixel values.
        mean: Values subtracted from ``image``; any sequence of floats
            (list or tuple).
        std: Values the mean-subtracted image is divided by; any sequence of
            floats (list or tuple).

    Returns:
        A new array holding the normalized values.
    """
    # The defaults are tuples rather than lists so that no mutable object is
    # shared between calls.
    return (image - np.array(mean)) / np.array(std)
def create_patches(image: Image.Image, image_patch_size=378) -> np.ndarray:
    """
    Split the given image into a variable number of patches depending upon its
    resolution.

    Resolution-dependent patching is not implemented yet, so the result
    currently always holds exactly two identical patches: the global patch
    (the whole image resized to ``image_patch_size`` x ``image_patch_size``)
    and a duplicate of it.

    Args:
        image: Source image. Non-RGB images are converted to RGB first.
        image_patch_size: Side length, in pixels, of each square patch.

    Returns:
        A ``float16`` array with the patches stacked along the first axis.
        Each patch is scaled by 1/255, normalized with mean 0.5 and std 0.5
        per channel, and transposed to channel-first order.
    """
    # The three-element mean/std and the (2, 0, 1) transpose below both need a
    # height x width x 3 array. Grayscale, palette and alpha images would
    # otherwise fail there, so coerce everything to RGB up front.
    if image.mode != "RGB":
        image = image.convert("RGB")

    # Start off with the global patch.
    global_patch = im_resize(image, (image_patch_size, image_patch_size))

    # The second patch is a duplicate of the global patch. This creates a
    # small amount of redundant computation now, but it is simpler and
    # future-proofs us if/when we condition the vision encoder on the patch
    # type.
    #
    # TODO: Actually implement patching. The intended design is: for images
    # whose larger side is at least 1.4 * image_patch_size, pick the layout
    # from [(1, 2), (2, 1), (2, 2)] whose size[1] / size[0] ratio is closest
    # to the image's width / height ratio, and cut real patches from it.
    # Smaller images keep the duplicated global patch.
    patches = [global_patch, global_patch]

    # Cast each patch with astype() instead of passing dtype= to np.stack():
    # the result is the same, and it does not depend on the dtype argument of
    # np.stack, which older NumPy releases do not accept.
    return np.stack(
        [
            normalize(
                (np.array(patch_img) / 255.0),
                mean=[0.5, 0.5, 0.5],
                std=[0.5, 0.5, 0.5],
            )
            .transpose(2, 0, 1)
            .astype(np.float16)
            for patch_img in patches
        ]
    )