# Spaces: Runtime error
from transformers import CLIPImageProcessor | |
from transformers.image_processing_utils import BatchFeature, get_size_dict | |
from transformers.image_transforms import get_resize_output_image_size | |
import torch | |
import torch.nn.functional as F | |
import numpy as np | |
class VideoFramesProcessor(CLIPImageProcessor):
    """CLIP image processor with a batched tensor path for stacked video frames.

    Anything that is not a ``numpy.ndarray`` is delegated unchanged to
    ``CLIPImageProcessor.preprocess``. A ``numpy.ndarray`` is processed as one
    batch with torch ops (resize, center crop, rescale, normalize) instead of
    frame by frame.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def preprocess(self, images, **kwargs):
        """Preprocess images, using a batched path for ndarray input.

        Args:
            images: Either any input accepted by the parent ``preprocess``
                (delegated as-is), or a ``numpy.ndarray`` of stacked frames in
                channels-last layout ``(num_frames, height, width, channels)``.
                The layout is implied by the permutes around
                ``F.interpolate`` and the final permute to channels-first.
            **kwargs: Optional overrides for ``do_resize``, ``size``,
                ``do_center_crop``, ``crop_size``, ``do_rescale``,
                ``rescale_factor``, ``do_normalize``, ``image_mean`` and
                ``image_std``. A missing or ``None`` value falls back to the
                attribute of the same name on this processor.
                ``return_tensors`` is forwarded to ``BatchFeature``.

        Returns:
            BatchFeature: holds ``"pixel_values"`` as a float tensor in
            ``(num_frames, channels, height, width)`` layout for the ndarray
            path; whatever the parent returns otherwise.
        """
        if not isinstance(images, np.ndarray):
            return super().preprocess(images=images, **kwargs)

        def option(name):
            # An explicit None means "use the configured default" rather than
            # "disabled" (assumed to match the parent preprocess convention).
            value = kwargs.get(name)
            return getattr(self, name) if value is None else value

        do_resize = option("do_resize")
        size = get_size_dict(option("size"), param_name="size", default_to_square=False)
        do_center_crop = option("do_center_crop")
        crop_size = get_size_dict(option("crop_size"), param_name="crop_size", default_to_square=True)
        do_rescale = option("do_rescale")
        rescale_factor = option("rescale_factor")
        do_normalize = option("do_normalize")
        image_mean = option("image_mean")
        image_std = option("image_std")
        return_tensors = kwargs.get("return_tensors", None)

        def resize(frames, output_size):
            # F.interpolate expects channels-first, so convert NHWC -> NCHW
            # for the call and back again afterwards.
            frames = frames.permute((0, 3, 1, 2))
            frames = F.interpolate(frames, size=output_size, mode="bicubic")
            return frames.permute((0, 2, 3, 1))

        def center_crop(frames, target):
            # Frames are NHWC: dim 1 is height and dim 2 is width. Height must
            # be paired with dim 1, otherwise a non-square crop is transposed.
            crop_height, crop_width = target["height"], target["width"]
            img_height, img_width = frames.shape[1:3]
            top = (img_height - crop_height) // 2
            left = (img_width - crop_width) // 2
            # NOTE: frames smaller than the crop are not padded here; the
            # offsets go negative and the slice is then smaller than requested.
            return frames[:, top:top + crop_height, left:left + crop_width]

        def rescale(frames, factor):
            return frames * factor

        def normalize(frames, mean, std):
            # Per-channel statistics broadcast over the trailing channel axis.
            return (frames - torch.tensor(mean)) / torch.tensor(std)

        frames = torch.from_numpy(images).float()

        if do_resize:
            if "shortest_edge" in size:
                output_size = get_resize_output_image_size(
                    frames[0], size=size["shortest_edge"], default_to_square=False
                )
            else:
                # An explicit {"height", "width"} size is used directly as the
                # target. Other size dicts still raise KeyError as before.
                output_size = (size["height"], size["width"])
            frames = resize(frames, output_size)

        if do_center_crop:
            frames = center_crop(frames, crop_size)

        if do_rescale:
            frames = rescale(frames, rescale_factor)

        if do_normalize:
            frames = normalize(frames, image_mean, image_std)

        # Convert to channels-first for the model: NHWC -> NCHW.
        frames = frames.permute((0, 3, 1, 2))
        data = {"pixel_values": frames}
        return BatchFeature(data=data, tensor_type=return_tensors)