# vita_audio/data/processor/image_processor.py
import math
import os
import numpy as np
import torch
from PIL import Image
import decord
import natsort
from vita_audio.constants import (
IMAGENET_DEFAULT_MEAN,
IMAGENET_DEFAULT_STD,
IMAGENET_STANDARD_MEAN,
IMAGENET_STANDARD_STD,
OPENAI_CLIP_MEAN,
OPENAI_CLIP_STD,
)
class ImageProcessor:
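    """Preprocess images and video frames for the VITA-Audio vision tower.

    process_type selects the tiling strategy: "anyres" pads the image to the
    best-fitting grid resolution and slices fixed-size patches, while
    "dynamic" resizes to the closest-aspect-ratio grid and optionally
    prepends a thumbnail of the whole image.
    """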
def __init__(
self,
process_type,
image_size=448,
normalize_type="imagenet",
min_patch_grid=1,
max_patch_grid=6,
):
self.process_type = process_type
self.image_size = image_size
if normalize_type == "imagenet":
MEAN, STD = IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
elif normalize_type == "clip":
MEAN, STD = OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
elif normalize_type == "siglip":
MEAN, STD = IMAGENET_STANDARD_MEAN, IMAGENET_STANDARD_STD
        else:
            raise NotImplementedError(f"Unknown normalize_type: {normalize_type}")
self.mean = MEAN
self.std = STD
self.patch_size = image_size
self.min_patch_grid = min_patch_grid
self.max_patch_grid = max_patch_grid
if self.process_type == "anyres":
self.grid_pinpoints = [
(i, j)
for i in range(min_patch_grid, max_patch_grid + 1)
for j in range(min_patch_grid, max_patch_grid + 1)
]
self.possible_resolutions = [
[dim * self.patch_size for dim in pair] for pair in self.grid_pinpoints
]
print(f"grid_pinpoints {self.grid_pinpoints}")
print(f"possible_resolutions {self.possible_resolutions}")
if self.process_type == "dynamic":
max_num = self.max_patch_grid
min_num = self.min_patch_grid
            # enumerate every candidate (cols, rows) grid whose tile count lies in [min_num, max_num]
target_ratios = set(
(i, j)
for n in range(min_num, max_num + 1)
for i in range(1, n + 1)
for j in range(1, n + 1)
if i * j <= max_num and i * j >= min_num
)
self.target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
self.possible_resolutions = [
[dim * self.patch_size for dim in pair] for pair in self.target_ratios
]
print(f"target_ratios {self.target_ratios}")
print(f"possible_resolutions {self.possible_resolutions}")
def get_frame_paths(self, frame_root, num_frames=8):
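        """Return deterministic on-disk paths for num_frames cached frames."""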
os.makedirs(frame_root, exist_ok=True)
self.frame_tmpl = "frame-{}-of-{}.jpg"
return [
os.path.join(frame_root, self.frame_tmpl.format(i, num_frames))
for i in range(1, num_frames + 1)
]
def save_video_frames(self, vid_path, max_fps=1, num_frames=8):
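        """Decode up to num_frames frames (no faster than max_fps) and cache them as JPEGs next to the video."""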
        vid = decord.VideoReader(vid_path, num_threads=1)
        # Sample evenly across the clip, but never faster than max_fps frames per second.
        step_size = len(vid) / (num_frames + 1)
        fps = vid.get_avg_fps()
        step_size = max(fps / max_fps, step_size)
        indices = [int(i * step_size) for i in range(num_frames)]
        indices = [i for i in indices if i < len(vid)]
        num_frames = len(indices)
        frame_paths = self.get_frame_paths(vid_path + ".saved_frames", num_frames)
        if all(os.path.exists(p) for p in frame_paths):
            return frame_paths
        images = [Image.fromarray(vid[i].asnumpy()) for i in indices]
        for im, pth in zip(images, frame_paths):
            im.save(pth)
        return frame_paths
def get_video_frames(self, vid_path, max_fps=1, num_frames=8):
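        """Decode up to num_frames frames (no faster than max_fps) and return them as PIL images without caching."""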
        vid = decord.VideoReader(vid_path, num_threads=1)
        # Same sampling policy as save_video_frames, but frames stay in memory.
        step_size = len(vid) / (num_frames + 1)
        fps = vid.get_avg_fps()
        step_size = max(fps / max_fps, step_size)
        indices = [int(i * step_size) for i in range(num_frames)]
        indices = [i for i in indices if i < len(vid)]
        images = [Image.fromarray(vid[i].asnumpy()) for i in indices]
        return images
def process_video(self, video_file_or_dir, max_num_frame=8, max_fps=1):
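        """Load a video file or a directory of pre-extracted frames, subsample at most
        max_num_frame frames (no faster than max_fps), and return the processed frame
        tensor together with the source frames; returns None for an empty directory.
        """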
if os.path.isdir(video_file_or_dir):
all_filepath = []
for root, dirs, files in os.walk(video_file_or_dir):
for filename in files:
                    if filename.endswith((".png", ".jpeg", ".jpg")):
filepath = os.path.join(root, filename)
all_filepath.append(filepath)
if len(all_filepath) == 0:
return None
            # Natural sort so frame-10 follows frame-9 rather than frame-1.
            all_filepath = natsort.natsorted(all_filepath)
            total_frame = len(all_filepath)
            # ShareGPTVideo frames were extracted at 2 fps; assume 1 fps otherwise.
            if "ShareGPTVideo" in video_file_or_dir:
                fps = 2
            else:
                fps = 1
            # Guard against a zero frame budget on very short sequences.
            target_frame = max(1, int(min(total_frame / fps * max_fps, max_num_frame)))
            index = [int(1.0 * total_frame / target_frame) * x for x in range(target_frame)]
            selected_filepath = [all_filepath[x] for x in index]
            img_or_path_list = selected_filepath
        elif os.path.isfile(video_file_or_dir):
            img_or_path_list = self.get_video_frames(
                video_file_or_dir, num_frames=max_num_frame, max_fps=max_fps
            )
        else:
            raise FileNotFoundError(video_file_or_dir)
return self.process_images(img_or_path_list), img_or_path_list
def process_images(self, img_or_path_list):
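        """Pad each image to a square with the mean color, resize to image_size,
        normalize, and stack into a float tensor of shape (N, 3, H, W)."""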
if isinstance(img_or_path_list[0], str):
images = [Image.open(x).convert("RGB") for x in img_or_path_list]
elif isinstance(img_or_path_list[0], Image.Image):
images = [x.convert("RGB") for x in img_or_path_list]
else:
images = img_or_path_list
def expand2square(pil_img, background_color):
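            """Pad pil_img to a square of background_color, keeping the original centered."""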
width, height = pil_img.size
if width == height:
return pil_img
elif width > height:
result = Image.new(pil_img.mode, (width, width), background_color)
result.paste(pil_img, (0, (width - height) // 2))
return result
else:
result = Image.new(pil_img.mode, (height, height), background_color)
result.paste(pil_img, ((height - width) // 2, 0))
return result
image_tensor = torch.ones([len(images), 3, self.image_size, self.image_size])
for i, image in enumerate(images):
image = expand2square(image, tuple(int(x * 255) for x in self.mean))
image = image.resize(
(self.image_size, self.image_size), resample=Image.Resampling.BICUBIC
)
            # HWC uint8 -> normalized CHW float tensor.
            image = np.array(image, dtype=np.float32) / 255.0
            mean = np.array(self.mean, dtype=image.dtype)
            std = np.array(self.std, dtype=image.dtype)
            image = (image - mean) / std
            image = torch.from_numpy(image).permute(2, 0, 1)
            image_tensor[i] = image
return image_tensor
def process_images_with_subpatch(self, img_or_path):
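        """Dispatch to the configured tiling strategy ("anyres" or "dynamic");
        otherwise process the input as a single resized image."""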
if self.process_type == "anyres":
return self.process_anyres(img_or_path)
if self.process_type == "dynamic":
return self.process_dynamic(img_or_path)
if isinstance(img_or_path, str):
image = Image.open(img_or_path).convert("RGB")
elif isinstance(img_or_path, Image.Image):
image = img_or_path.convert("RGB")
else:
image = img_or_path
        return self.process_images([image])
def process_anyres(self, img_or_path):
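        """Pad the image to the best-fitting grid resolution, then return the
        global image plus its fixed-size patches and the chosen resolution."""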
if isinstance(img_or_path, str):
image = Image.open(img_or_path).convert("RGB")
elif isinstance(img_or_path, Image.Image):
image = img_or_path.convert("RGB")
else:
image = img_or_path
best_resolution = select_best_resolution(image.size, self.possible_resolutions)
image_padded = resize_and_pad_image(image, best_resolution)
patches = divide_to_patches(image_padded, self.patch_size)
if best_resolution == (self.patch_size, self.patch_size):
image_patches = [image]
else:
image_patches = [image] + patches
image_patches = self.process_images(image_patches)
# print(f"image {image.size} best_resolution {best_resolution} image_padded {image_padded.size} patches {len(patches)} image_patches {image_patches.size()}")
return image_patches, best_resolution
def process_dynamic(self, img_or_path):
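        """Tile the image on the closest-aspect-ratio grid (plus a thumbnail)
        and return the processed patches and the grid resolution."""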
if isinstance(img_or_path, str):
image = Image.open(img_or_path).convert("RGB")
elif isinstance(img_or_path, Image.Image):
image = img_or_path.convert("RGB")
else:
image = img_or_path
image_patches, best_resolution = dynamic_preprocess(
image,
min_num=self.min_patch_grid,
max_num=self.max_patch_grid,
image_size=self.patch_size,
use_thumbnail=True,
)
image_patches = self.process_images(image_patches)
# print(f"image {image.size} best_resolution {best_resolution} image_padded {image_padded.size} patches {len(patches)} image_patches {image_patches.size()}")
return image_patches, best_resolution
def select_best_resolution(original_size, possible_resolutions):
"""
Selects the best resolution from a list of possible resolutions based on the original size.
Args:
original_size (tuple): The original size of the image in the format (width, height).
possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...].
Returns:
tuple: The best fit resolution in the format (width, height).
"""
original_width, original_height = original_size
best_fit = None
max_effective_resolution = 0
min_wasted_resolution = float("inf")
for width, height in possible_resolutions:
# Calculate the downscaled size to keep the aspect ratio
scale = min(width / original_width, height / original_height)
        downscaled_width = int(original_width * scale)
        downscaled_height = int(original_height * scale)
# Calculate effective and wasted resolutions
effective_resolution = min(
downscaled_width * downscaled_height, original_width * original_height
)
wasted_resolution = (width * height) - effective_resolution
if effective_resolution > max_effective_resolution or (
effective_resolution == max_effective_resolution
and wasted_resolution < min_wasted_resolution
):
max_effective_resolution = effective_resolution
min_wasted_resolution = wasted_resolution
best_fit = (width, height)
return best_fit
def resize_and_pad_image(image, target_resolution):
"""
Resize and pad an image to a target resolution while maintaining aspect ratio.
Args:
image (PIL.Image.Image): The input image.
target_resolution (tuple): The target resolution (width, height) of the image.
Returns:
PIL.Image.Image: The resized and padded image.
"""
original_width, original_height = image.size
target_width, target_height = target_resolution
# Determine which dimension (width or height) to fill
scale_w = target_width / original_width
scale_h = target_height / original_height
if scale_w < scale_h:
# Width will be filled completely
new_width = target_width
new_height = min(math.ceil(original_height * scale_w), target_height)
else:
# Height will be filled completely
new_height = target_height
new_width = min(math.ceil(original_width * scale_h), target_width)
# Resize the image
resized_image = image.resize((new_width, new_height))
# Create a new image with the target size and paste the resized image onto it
new_image = Image.new("RGB", (target_width, target_height), (0, 0, 0))
paste_x = (target_width - new_width) // 2
paste_y = (target_height - new_height) // 2
new_image.paste(resized_image, (paste_x, paste_y))
return new_image
def divide_to_patches(image, patch_size):
"""
Divides an image into patches of a specified size.
Args:
image (PIL.Image.Image): The input image.
patch_size (int): The size of each patch.
Returns:
list: A list of PIL.Image.Image objects representing the patches.
"""
patches = []
width, height = image.size
for i in range(0, height, patch_size):
for j in range(0, width, patch_size):
box = (j, i, j + patch_size, i + patch_size)
patch = image.crop(box)
patches.append(patch)
return patches
def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
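    """Pick the candidate (cols, rows) grid whose aspect ratio best matches the image.

    Ties are broken toward the larger grid when the image covers more than
    half of that grid's pixel capacity.
    """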
best_ratio_diff = float("inf")
best_ratio = (1, 1)
area = width * height
for ratio in target_ratios:
target_aspect_ratio = ratio[0] / ratio[1]
ratio_diff = abs(aspect_ratio - target_aspect_ratio)
if ratio_diff < best_ratio_diff:
best_ratio_diff = ratio_diff
best_ratio = ratio
elif ratio_diff == best_ratio_diff:
if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
best_ratio = ratio
return best_ratio
def dynamic_preprocess(image, min_num=1, max_num=6, image_size=448, use_thumbnail=False):
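    """Split image into image_size x image_size tiles on the closest-aspect-ratio grid.

    Returns the tile list (optionally led by a thumbnail of the whole image)
    and the (target_width, target_height) the image was resized to.
    """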
orig_width, orig_height = image.size
aspect_ratio = orig_width / orig_height
    # enumerate every candidate (cols, rows) grid whose tile count lies in [min_num, max_num]
target_ratios = set(
(i, j)
for n in range(min_num, max_num + 1)
for i in range(1, n + 1)
for j in range(1, n + 1)
if i * j <= max_num and i * j >= min_num
)
target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
# find the closest aspect ratio to the target
target_aspect_ratio = find_closest_aspect_ratio(
aspect_ratio, target_ratios, orig_width, orig_height, image_size
)
# calculate the target width and height
target_width = image_size * target_aspect_ratio[0]
target_height = image_size * target_aspect_ratio[1]
blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
# resize the image
resized_img = image.resize((target_width, target_height))
    processed_images = []
    cols = target_width // image_size
    for i in range(blocks):
        # Crop tile i in row-major order from the resized image.
        box = (
            (i % cols) * image_size,
            (i // cols) * image_size,
            ((i % cols) + 1) * image_size,
            ((i // cols) + 1) * image_size,
        )
        split_img = resized_img.crop(box)
        processed_images.append(split_img)
assert len(processed_images) == blocks
    if use_thumbnail and len(processed_images) != 1:
        # Prepend a thumbnail of the whole image as a global view.
        thumbnail_img = image.resize((image_size, image_size))
        processed_images = [thumbnail_img] + processed_images
return processed_images, (target_width, target_height)
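

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original module): exercise the
    # "dynamic" tiling path on a synthetic image so no input files are assumed.
    dummy = Image.fromarray(np.uint8(np.random.rand(720, 1280, 3) * 255))
    processor = ImageProcessor(process_type="dynamic", image_size=448)
    patches, best_resolution = processor.process_images_with_subpatch(dummy)
    # Expect a thumbnail plus the grid tiles, each normalized to 3 x 448 x 448.
    print(patches.shape, best_resolution)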