from typing import List, Union

import torch
from PIL import Image

from transformers import Qwen2VLImageProcessor
from transformers.feature_extraction_utils import BatchFeature
from transformers.image_utils import ImageInput, get_image_size, to_numpy_array
from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize
from transformers.processing_utils import ProcessingKwargs, ProcessorMixin, Unpack, _validate_images_text_input_order
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
from transformers.utils import logging

logger = logging.get_logger(__name__)


class TarsierProcessorKwargs(ProcessingKwargs, total=False):
    _defaults = {
        "text_kwargs": {},
        "images_kwargs": {},
    }


class TarsierProcessor(ProcessorMixin):
    attributes = ["image_processor", "tokenizer"]
    valid_kwargs = ["chat_template", "image_token", "patch_size", "merge_size", "temporal_patch_size", "max_seq_len"]
    image_processor_class = "AutoImageProcessor"
    tokenizer_class = "AutoTokenizer"

    def __init__(
        self,
        image_processor=None,
        tokenizer=None,
        chat_template=None,
        image_token="<image>",
        patch_size=None,
        merge_size=1,
        temporal_patch_size=1,
        max_seq_len=8192,
        **kwargs,
    ) -> None:
        self.image_token = image_token
        self.patch_size = patch_size
        self.merge_size = merge_size
        self.temporal_patch_size = temporal_patch_size
        self.max_seq_len = max_seq_len
        # Per-sample pixel budget: the equivalent of 128 frames at 384x384.
        self.max_pixels_per_sample = 128 * 384 * 384
        super().__init__(image_processor, tokenizer, chat_template=chat_template)
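
    # __call__ converts chat-format `messages` (each message carrying text,
    # image, and/or video content entries) into model-ready features:
    # `input_ids`, `labels` (-100 everywhere except assistant tokens),
    # `num_images`, and, when visual content is present, concatenated
    # `pixel_values` plus `image_grid_thw` for Qwen2-VL-style image processors.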
    def __call__(
        self,
        messages,
        image_processing_config=None,
        is_training=True,
    ) -> dict:
        output_kwargs = self._merge_kwargs(
            TarsierProcessorKwargs,
            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
        )
        # Image processing
        pixel_values, image_grid_thw = [], []
        num_images = 0
        for msg in messages:
            for content in msg['content']:
                if content['type'] == 'image':
                    num_images += self.temporal_patch_size
                elif content['type'] == 'video':
                    num_images += len(content['video'])
        # Shrink the per-image pixel cap so the whole sample stays within the budget.
        if num_images > 0 and self.max_pixels_per_sample // num_images < image_processing_config['max_pixels']:
            image_processing_config['max_pixels'] = self.max_pixels_per_sample // num_images
            image_processing_config['min_pixels'] = min(image_processing_config['min_pixels'], image_processing_config['max_pixels'])
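        # Example: with the 128 * 384 * 384 budget, a 64-frame video caps each
        # frame at 18874368 // 64 = 294912 pixels (2 * 384 * 384), and a
        # 128-frame video is capped at 147456 pixels (384 * 384) per frame.
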
        for msg in messages:
            for content in msg['content']:
                if content['type'] == 'image':
                    content['image'] = self.preprocess_image(content['image'], image_processing_config)
                    content['image'] = self.image_processor(images=content['image'], **output_kwargs["images_kwargs"], return_tensors="pt")
                    content['num_vision_tokens'] = self.get_num_vision_tokens(content)
                    pixel_values.append(content['image']['pixel_values'])
                    if 'image_grid_thw' in content['image']:
                        image_grid_thw.extend(content['image']['image_grid_thw'])
                elif content['type'] == 'video':
                    content['video'] = self.preprocess_image(content['video'], image_processing_config)
                    if isinstance(self.image_processor, Qwen2VLImageProcessor):
                        content['video'] = self.image_processor(images=None, videos=content['video'], **output_kwargs["images_kwargs"], return_tensors="pt")
                        pixel_values.append(content['video']['pixel_values_videos'])
                    else:
                        content['video'] = self.image_processor(images=content['video'], **output_kwargs["images_kwargs"], return_tensors="pt")
                        pixel_values.append(content['video']['pixel_values'])
                    if 'video_grid_thw' in content['video']:
                        image_grid_thw.extend(content['video']['video_grid_thw'])
                    content['num_vision_tokens'] = self.get_num_vision_tokens(content)

        # Text processing
        add_generation_prompt = (not is_training and messages[-1]['role'] != 'assistant')
        strip_final_eos = (not is_training and messages[-1]['role'] == 'assistant')
        text_inputs = self.tokenizer.apply_chat_template(
            messages,
            chat_template=self.chat_template,
            tokenize=True,
            tokenizer_kwargs=output_kwargs["text_kwargs"],
            return_assistant_tokens_mask=True,
            return_dict=True,
            add_generation_prompt=add_generation_prompt,
            strip_final_eos=strip_final_eos,
        )
        # Supervise only assistant tokens; everything else is masked with -100.
        labels = [-100 if j == 0 else i for i, j in zip(text_inputs['input_ids'], text_inputs['assistant_masks'])]
        labels = labels[:self.max_seq_len]
        input_ids = text_inputs['input_ids'][:self.max_seq_len]

        # Refuse samples whose truncation would drop image placeholder tokens.
        image_token_id = self.tokenizer.convert_tokens_to_ids(self.image_token)
        if image_token_id in text_inputs['input_ids'][self.max_seq_len:]:
            raise ValueError(f'Too long sequence! {len(text_inputs["input_ids"])}')

        outputs = {
            'input_ids': input_ids,
            'labels': labels,
            'num_images': num_images,
        }
        if len(pixel_values) > 0:
            outputs['pixel_values'] = torch.cat(pixel_values, dim=0)
        if len(image_grid_thw) > 0:
            outputs['image_grid_thw'] = torch.stack(image_grid_thw)
        return outputs
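
    # preprocess_image applies the optional PIL-level steps in a fixed order:
    # central crop (trimming the longer side to at most 4/3 of the shorter side),
    # pad to square with a black background, resize to square, then rescale so
    # the pixel count lands in [min_pixels, max_pixels].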
    def preprocess_image(self, pil_img: Union[Image.Image, List[Image.Image]], image_processing_config):
        if image_processing_config is None:
            return pil_img
        images = pil_img
        if isinstance(pil_img, Image.Image):
            images = [images]
        if image_processing_config['do_crop']:
            images = [self.centralcrop(img, rate=[4, 3]) for img in images]
        if image_processing_config['do_padding']:
            images = [self.expand2square(
                img,
                # tuple(int(x * 255) for x in self.processor.image_processor.image_mean)
                tuple(int(x * 255) for x in [0, 0, 0])
            ) for img in images]
        if image_processing_config['do_resize']:
            images = [self.resize2square(img) for img in images]
        if image_processing_config.get('max_pixels'):
            images = [self.resize2pixels(
                img,
                int(image_processing_config['max_pixels']),
                int(image_processing_config['min_pixels'])
            ) for img in images]
        if isinstance(pil_img, Image.Image):
            images = images[0]
        return images

    def expand2square(self, pil_img, background_color):
        width, height = pil_img.size
        if width == height:
            return pil_img
        elif width > height:
            result = Image.new(pil_img.mode, (width, width), background_color)
            result.paste(pil_img, (0, (width - height) // 2))
            return result
        else:
            result = Image.new(pil_img.mode, (height, height), background_color)
            result.paste(pil_img, ((height - width) // 2, 0))
            return result

    def resize2square(self, pil_img: Image.Image):
        width, height = pil_img.size
        pil_img = pil_img.resize((max(width, height), max(width, height)))
        return pil_img

    def centralcrop(self, pil_img: Image.Image, rate=[4, 3]):
        # Crop a centered box whose longer side is min(width, height) * rate[0] / rate[1],
        # clamped to the image bounds; e.g. a 1920x1080 frame with the default rate
        # becomes the 1440x1080 box (240, 0, 1680, 1080).
        width, height = pil_img.size
        size = (width, height)
        min_len = min(size)
        longer_side = 0 if width >= height else 1
        center = (width / 2, height / 2)
        box = [0, 0, size[0], size[1]]
        box[longer_side] = max(0, center[longer_side] - 1/2 * min_len / rate[1] * rate[0])
        box[2 + longer_side] = min(size[longer_side], center[longer_side] + 1/2 * min_len / rate[1] * rate[0])
        pil_img = pil_img.crop(box)
        return pil_img
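
    # resize2pixels delegates to Qwen2-VL's smart_resize with factor=1: the aspect
    # ratio is preserved and the frame is scaled until its area fits within
    # [min_pixels, max_pixels]. For instance, a 1920x1080 frame with
    # max_pixels=294912 comes out at roughly 724x407.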
    def resize2pixels(self, pil_img: Image.Image, max_pixels=None, min_pixels=None):
        width, height = pil_img.size
        new_height, new_width = smart_resize(height, width, factor=1, max_pixels=max_pixels, min_pixels=min_pixels)
        pil_img = pil_img.resize((new_width, new_height))
        return pil_img

    def get_num_vision_tokens(self, content):
        if isinstance(self.image_processor, Qwen2VLImageProcessor):
            # Qwen2-VL: tokens = t * h * w of the grid divided by the spatial merge area,
            # e.g. a [1, 24, 24] grid with merge_size=2 yields 576 // 4 = 144 tokens.
            merge_length = self.image_processor.merge_size ** 2
            if content['type'] == 'image':
                num_image_tokens = content['image']['image_grid_thw'].prod() // merge_length
            else:
                num_image_tokens = content['video']['video_grid_thw'].prod() // merge_length
            return num_image_tokens
        else:
            # Other image processors: image tokens (merged, e.g. 2x2 compressed), plus an
            # image_newline token per row and a trailing image_new token per frame,
            # e.g. 336x336 with patch_size=14, merge_size=2 -> 12 * 13 + 1 = 157 tokens.
            k = 'image' if content['type'] == 'image' else 'video'
            pixel_values = content[k]['pixel_values'][0]
            n_frames = len(content[k]['pixel_values'])
            height, width = get_image_size(to_numpy_array(pixel_values))
            num_image_tokens = (height // (self.patch_size * self.merge_size)) * (width // (self.patch_size * self.merge_size) + 1) + 1
            return num_image_tokens * n_frames

    def batch_decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
        refer to the docstring of this method for more information.
        """
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
        the docstring of this method for more information.
        """
        return self.tokenizer.decode(*args, **kwargs)

    @property
    def model_input_names(self):
        tokenizer_input_names = self.tokenizer.model_input_names
        image_processor_input_names = self.image_processor.model_input_names
        return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
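

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module. Assumptions: the
    # checkpoint id, image size, and pixel limits below are illustrative, and a
    # chat template with `{% generation %}` markers must be available for the
    # assistant mask to be meaningful; adapt these to the Tarsier release you use.
    from transformers import AutoImageProcessor, AutoTokenizer

    checkpoint = "omni-research/Tarsier2-7b"  # hypothetical checkpoint id
    processor = TarsierProcessor(
        image_processor=AutoImageProcessor.from_pretrained(checkpoint),
        tokenizer=AutoTokenizer.from_pretrained(checkpoint),
        patch_size=14,
    )
    messages = [
        {"role": "user", "content": [
            {"type": "image", "image": Image.new("RGB", (448, 448))},
            {"type": "text", "text": "Describe the image."},
        ]},
    ]
    image_processing_config = {
        "do_crop": False,
        "do_padding": False,
        "do_resize": False,
        "min_pixels": 64 * 28 * 28,
        "max_pixels": 1280 * 28 * 28,
    }
    inputs = processor(messages, image_processing_config, is_training=False)
    print({k: getattr(v, "shape", v) for k, v in inputs.items()})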