# coding=utf-8
# Copyright 2024 Microsoft and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Processor class for Florence-2.
"""
import logging
import re
from typing import List, Optional, Union

import numpy as np
import PIL
import torch
from transformers.feature_extraction_utils import BatchFeature
from transformers.image_utils import ImageInput
from transformers.processing_utils import ProcessorMixin
from transformers.tokenization_utils_base import (
    PaddingStrategy,
    TextInput,
    TruncationStrategy,
)
from transformers.utils import TensorType

logger = logging.getLogger(__name__)


class Florence2Processor(ProcessorMixin):
    attributes = ["image_processor", "tokenizer"]
    image_processor_class = "CLIPImageProcessor"
    tokenizer_class = ("BartTokenizer", "BartTokenizerFast")

    def __init__(
        self,
        image_processor=None,
        tokenizer=None,
    ):
        if image_processor is None:
            raise ValueError("You need to specify an `image_processor`.")
        if tokenizer is None:
            raise ValueError("You need to specify a `tokenizer`.")
        if not hasattr(image_processor, "image_seq_length"):
            raise ValueError("Image processor is missing an `image_seq_length` attribute.")
        self.image_seq_length = image_processor.image_seq_length
        tokens_to_add = {
            'additional_special_tokens':
                tokenizer.additional_special_tokens +
                ['<od>', '</od>', '<ocr>', '</ocr>'] +
                [f'<loc_{x}>' for x in range(1000)] +
                # NOTE: '<region_to_desciption>' below is kept spelled as-is;
                # renaming a special token here would change the tokenizer vocabulary.
                ['<cap>', '</cap>', '<ncap>', '</ncap>', '<dcap>', '</dcap>',
                 '<grounding>', '</grounding>', '<seg>', '</seg>', '<sep>',
                 '<region_cap>', '</region_cap>',
                 '<region_to_desciption>', '</region_to_desciption>',
                 '<proposal>', '</proposal>', '<poly>', '</poly>', '<and>'] +
                ['<panel>', '<text>', '<character>', '<tail>']
        }
        tokenizer.add_special_tokens(tokens_to_add)
        self.decoder_start_token_id = 2
        self.box_quantizer = BoxQuantizer(
            mode='floor',
            bins=(1000, 1000),
        )
        super().__init__(image_processor, tokenizer)

    def __call__(
        self,
        batch_input_text: List[TextInput] = None,
        batch_input_list_of_list_of_bboxes: List[List[List[List[float]]]] = None,
        batch_output_text: List[TextInput] = None,
        batch_output_list_of_list_of_bboxes: List[List[List[List[float]]]] = None,
        batch_images: ImageInput = None,
        batch_character_cluster_labels=None,
        batch_text_character_association_labels=None,
        batch_text_tail_association_labels=None,
        batch_is_essential_text_labels=None,
        batch_tail_character_association_labels=None,
        padding: Union[bool, str, PaddingStrategy] = None,
        truncation: Union[bool, str, TruncationStrategy] = None,
        max_input_length_including_image_tokens=None,
        max_output_length=None,
        return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
        do_resize: bool = None,
        do_normalize: bool = None,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        data_format: Optional["ChannelDimension"] = "channels_first",  # noqa: F821
        input_data_format: Optional[Union[str, "ChannelDimension"]] = None,  # noqa: F821
        resample: "PILImageResampling" = None,  # noqa: F821
        do_convert_rgb: bool = None,
        dtype: torch.dtype = None,
        device: torch.device = None,
    ) -> BatchFeature:
        assert batch_images is not None, "`batch_images` is expected as an argument to a `Florence2Processor` instance."
        assert batch_input_text is not None, "`batch_input_text` is expected as an argument to a `Florence2Processor` instance."
        if batch_input_list_of_list_of_bboxes is None:
            batch_input_list_of_list_of_bboxes = [[] for _ in range(len(batch_input_text))]
        assert len(batch_input_text) == len(batch_input_list_of_list_of_bboxes) == len(batch_images), \
            "`batch_input_text`, `batch_input_list_of_list_of_bboxes` and `batch_images` have different lengths."
        if batch_output_text is None:
            assert batch_output_list_of_list_of_bboxes is None, \
                "`batch_output_text` and `batch_output_list_of_list_of_bboxes` should be provided together."
        else:
            if batch_output_list_of_list_of_bboxes is None:
                batch_output_list_of_list_of_bboxes = [[] for _ in range(len(batch_output_text))]
            assert len(batch_output_text) == len(batch_output_list_of_list_of_bboxes) == len(batch_images), \
                "`batch_output_text`, `batch_output_list_of_list_of_bboxes` and `batch_images` have different lengths."

        # Reserve room for the image tokens in the overall input length budget.
        max_input_length = (
            max_input_length_including_image_tokens - self.image_seq_length
            if max_input_length_including_image_tokens is not None
            else None
        )

        batch_input_texts = [
            self._format_text_with_bboxes(text, list_of_list_of_bboxes, image)
            for text, list_of_list_of_bboxes, image in zip(
                batch_input_text, batch_input_list_of_list_of_bboxes, batch_images)
        ]
        inputs = self.tokenizer(
            batch_input_texts,
            return_tensors=return_tensors,
            padding=padding,
            truncation=False,
        )
        # Truncate manually: the tokenizer's built-in truncation appends an </s>
        # token to truncated sequences, which is not desired here.
        if max_input_length is not None and inputs["input_ids"].shape[1] > max_input_length:
            inputs["input_ids"] = inputs["input_ids"][:, :max_input_length]
            inputs["attention_mask"] = inputs["attention_mask"][:, :max_input_length]

        if batch_output_text is not None:
            batch_output_texts = [
                self._format_text_with_bboxes(text, list_of_list_of_bboxes, image)
                for text, list_of_list_of_bboxes, image in zip(
                    batch_output_text, batch_output_list_of_list_of_bboxes, batch_images)
            ]
            decoder_inputs = self.tokenizer(
                batch_output_texts,
                return_tensors=return_tensors,
                padding=padding,
                truncation=False,
            )
            # Truncate manually for the same reason as above.
            if max_output_length is not None and decoder_inputs["input_ids"].shape[1] > max_output_length:
                decoder_inputs["input_ids"] = decoder_inputs["input_ids"][:, :max_output_length]
                decoder_inputs["attention_mask"] = decoder_inputs["attention_mask"][:, :max_output_length]

        pixel_values = self.image_processor(
            batch_images,
            do_resize=do_resize,
            do_normalize=do_normalize,
            return_tensors=return_tensors,
            image_mean=image_mean,
            image_std=image_std,
            input_data_format=input_data_format,
            data_format=data_format,
            resample=resample,
            do_convert_rgb=do_convert_rgb,
        )["pixel_values"]
        if dtype is not None:
            pixel_values = pixel_values.to(dtype)

        return_data = {**inputs, "pixel_values": pixel_values}

        if batch_output_text is not None:
            labels = decoder_inputs["input_ids"]
            # Shift the labels one position to the right to build the decoder inputs,
            # prepending the decoder start token.
            decoder_input_ids = labels.new_zeros(labels.shape)
            decoder_input_ids[:, 1:] = labels[:, :-1].clone()
            decoder_input_ids[:, 0] = self.decoder_start_token_id
            decoder_attention_mask = decoder_inputs["attention_mask"].new_ones(decoder_input_ids.shape)
            decoder_attention_mask[:, 1:] = decoder_inputs["attention_mask"][:, :-1].clone()
            # Replace pad token IDs in the labels with -100 so the loss ignores them.
            labels.masked_fill_(labels == self.tokenizer.pad_token_id, -100)
            return_data.update({
                "labels": labels,
                "decoder_input_ids": decoder_input_ids,
                "decoder_attention_mask": decoder_attention_mask,
            })

        if device is not None:
            for key, value in return_data.items():
                if isinstance(value, torch.Tensor):
                    return_data[key] = value.to(device)

        if batch_character_cluster_labels is not None:
            return_data["character_cluster_labels"] = batch_character_cluster_labels
        if batch_text_character_association_labels is not None:
            return_data["text_character_association_labels"] = batch_text_character_association_labels
        if batch_text_tail_association_labels is not None:
            return_data["text_tail_association_labels"] = batch_text_tail_association_labels
        if batch_is_essential_text_labels is not None:
            return_data["is_essential_text_labels"] = batch_is_essential_text_labels
        if batch_tail_character_association_labels is not None:
            return_data["tail_character_association_labels"] = batch_tail_character_association_labels
        return_data["tokenizer"] = self.tokenizer

        return BatchFeature(data=return_data)
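
    # Illustrative usage sketch (comment only; the argument values below are
    # hypothetical). `{}` placeholders in the text are filled with quantized
    # <loc_*> tokens for the corresponding list of boxes:
    #
    #   processor = Florence2Processor(image_processor=clip_image_processor,
    #                                  tokenizer=bart_tokenizer)
    #   batch = processor(
    #       batch_input_text=["Locate the character at {}"],
    #       batch_input_list_of_list_of_bboxes=[[[[10.0, 20.0, 200.0, 300.0]]]],
    #       batch_images=[pil_image],
    #       padding="longest",
    #       max_input_length_including_image_tokens=1024,
    #   )
    #
    # The result is a BatchFeature with input_ids, attention_mask and pixel_values;
    # labels, decoder_input_ids and decoder_attention_mask are added only when
    # batch_output_text is provided.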

    def cleanup_generated_text(self, generated_text):
        return generated_text.replace("<s>", "").replace("</s>", "").replace("<pad>", "")

    def postprocess_output(self, generated_ids, images):
        # Currently only used for testing purposes.
        generated_ids.masked_fill_(generated_ids == -100, self.tokenizer.pad_token_id)
        batch_decoded_texts = self.batch_decode(generated_ids, skip_special_tokens=False)
        batch_decoded_texts = [self.cleanup_generated_text(text) for text in batch_decoded_texts]
        batch_list_of_list_of_bboxes = []
        batch_indices_of_bboxes_in_new_string = []
        batch_new_texts = []
        for text, image in zip(batch_decoded_texts, images):
            size_wh = self._get_image_size_wh(image)
            parsed_text, list_of_stringified_bboxes, start_end_in_new_string = self._parse_text_with_bboxes(text)
            list_of_list_of_bboxes = [
                self.box_quantizer.dequantize_from_stringified_bboxes(stringified_bbox, size_wh)
                for stringified_bbox in list_of_stringified_bboxes
            ]
            batch_list_of_list_of_bboxes.append(list_of_list_of_bboxes)
            batch_indices_of_bboxes_in_new_string.append(start_end_in_new_string)
            batch_new_texts.append(parsed_text)
        return batch_new_texts, batch_list_of_list_of_bboxes, batch_indices_of_bboxes_in_new_string
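
    # Illustrative example (approximate values): for a decoded string
    #   "<grounding>a cat</grounding><loc_10><loc_20><loc_500><loc_700> sleeping"
    # and a 1000x1000 image, postprocess_output would return roughly:
    #   batch_new_texts                       -> ["a cat sleeping"]
    #   batch_list_of_list_of_bboxes          -> [[[[10.5, 20.5, 500.5, 700.5]]]]
    #   batch_indices_of_bboxes_in_new_string -> [[(0, 5)]]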

    def _parse_text_with_bboxes(self, text):
        loc_pattern = r'((?:<loc_\d+>){4}(?:,(?:<loc_\d+>){4})*)'
        grounding_pattern = r'<grounding>(.*?)</grounding>' + loc_pattern
        list_of_stringified_bboxes = []
        start_end_in_new_string = []
        new_text = ""
        original_pos = 0
        new_pos = 0
        for match in re.finditer(grounding_pattern + '|' + loc_pattern, text):
            # Add the text before the match.
            new_text += text[original_pos:match.start()]
            new_pos += match.start() - original_pos
            if match.group(0).startswith('<grounding>'):
                # Grounding pattern: keep the grounded text and record its span.
                grounding_text = match.group(1)
                locs = match.group(2)
                new_text += grounding_text
                list_of_stringified_bboxes.append(locs)
                start_end_in_new_string.append((new_pos, new_pos + len(grounding_text)))
                new_pos += len(grounding_text)
            else:
                # Bare loc pattern: drop the loc tokens and record an empty span.
                locs = match.group(0)
                replacement = ""
                new_text += replacement
                list_of_stringified_bboxes.append(locs)
                start_end_in_new_string.append((new_pos, new_pos + len(replacement)))
                new_pos += len(replacement)
            original_pos = match.end()
        # Add any remaining text.
        new_text += text[original_pos:]
        return new_text, list_of_stringified_bboxes, start_end_in_new_string
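
    # Parsing sketch (illustrative): a bare run of loc tokens is removed from the
    # text but its (empty) span and its boxes are still recorded, e.g.
    #   "The dog <loc_1><loc_2><loc_3><loc_4> runs"
    # yields new_text "The dog  runs", one stringified bbox group
    # "<loc_1><loc_2><loc_3><loc_4>", and the zero-length span (8, 8).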

    def _format_text_with_bboxes(self, text, list_of_list_of_bboxes, image):
        size_wh = self._get_image_size_wh(image)
        quantized_bbox_lists = []
        for list_of_bboxes in list_of_list_of_bboxes:
            quantized_bboxes = self.box_quantizer.quantize(list_of_bboxes, size_wh=size_wh)
            stringified_bboxes = [
                f"<loc_{x1}><loc_{y1}><loc_{x2}><loc_{y2}>" for x1, y1, x2, y2 in quantized_bboxes
            ]
            stringified_bboxes = ",".join(stringified_bboxes)
            quantized_bbox_lists.append(stringified_bboxes)
        return text.format(*quantized_bbox_lists)
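
    # Formatting sketch (illustrative values): with text "Describe {}" and a single
    # box [100.0, 150.0, 300.0, 450.0] on a 1000x800 image, the `{}` placeholder is
    # replaced by the quantized tokens, producing approximately
    #   "Describe <loc_100><loc_187><loc_300><loc_562>".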

    def _get_image_size_wh(self, image):
        # Get (width, height) from the image, depending on its type.
        if isinstance(image, torch.Tensor):
            # PyTorch tensor, assumed (C, H, W) or (B, C, H, W).
            if image.dim() == 3:
                size_wh = (image.shape[2], image.shape[1])  # (width, height)
            elif image.dim() == 4:
                size_wh = (image.shape[3], image.shape[2])  # (width, height)
            else:
                raise ValueError("Unsupported tensor dimensions")
        elif isinstance(image, np.ndarray):
            # NumPy array, assumed (H, W) or (H, W, C).
            if image.ndim == 2:
                size_wh = (image.shape[1], image.shape[0])  # (width, height)
            elif image.ndim == 3:
                size_wh = (image.shape[1], image.shape[0])  # (width, height)
            else:
                raise ValueError("Unsupported array dimensions")
        elif isinstance(image, PIL.Image.Image):
            # PIL Image: `size` is already (width, height).
            size_wh = image.size
        else:
            raise TypeError("Unsupported image type")
        return size_wh

    # Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Florence2
    def batch_decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to BartTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
        refer to the docstring of this method for more information.
        """
        return self.tokenizer.batch_decode(*args, **kwargs)

    # Copied from transformers.models.clip.processing_clip.CLIPProcessor.decode with CLIP->Florence2
    def decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to BartTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
        the docstring of this method for more information.
        """
        return self.tokenizer.decode(*args, **kwargs)

    # Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names with CLIP->Florence2
    @property
    def model_input_names(self):
        tokenizer_input_names = self.tokenizer.model_input_names
        image_processor_input_names = self.image_processor.model_input_names
        return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))


class BoxQuantizer(object):
    def __init__(self, mode, bins):
        self.mode = mode
        self.bins = bins

    def quantize(self, boxes, size_wh):
        if not isinstance(boxes, torch.Tensor):
            boxes = torch.tensor(boxes)
        bins_w, bins_h = self.bins  # Quantization bins.
        size_w, size_h = size_wh  # Original image size.
        size_per_bin_w = size_w / bins_w
        size_per_bin_h = size_h / bins_h
        xmin, ymin, xmax, ymax = boxes.split(1, dim=-1)  # Shape: 4 * [N, 1].

        if self.mode == 'floor':
            quantized_xmin = (xmin / size_per_bin_w).floor().clamp(0, bins_w - 1)
            quantized_ymin = (ymin / size_per_bin_h).floor().clamp(0, bins_h - 1)
            quantized_xmax = (xmax / size_per_bin_w).floor().clamp(0, bins_w - 1)
            quantized_ymax = (ymax / size_per_bin_h).floor().clamp(0, bins_h - 1)
        elif self.mode == 'round':
            raise NotImplementedError()
        else:
            raise ValueError('Incorrect quantization type.')

        quantized_boxes = torch.cat(
            (quantized_xmin, quantized_ymin, quantized_xmax, quantized_ymax), dim=-1
        ).int()
        return quantized_boxes.tolist()
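
    # Worked example (floor mode, bins=(1000, 1000), image size 1000x800):
    # size_per_bin_w = 1.0 and size_per_bin_h = 0.8, so the box
    # [100.0, 150.0, 300.0, 450.0] maps to [100/1.0, 150/0.8, 300/1.0, 450/0.8]
    # and, after floor and clamp, to [100, 187, 300, 562].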

    def dequantize_from_stringified_bboxes(self, stringified_bboxes, size_wh):
        bboxes = stringified_bboxes.split(',')

        def parse_bbox(bbox_string):
            pattern = r'<loc_(\d+)><loc_(\d+)><loc_(\d+)><loc_(\d+)>'
            match = re.match(pattern, bbox_string)
            if match:
                return [int(match.group(i)) for i in range(1, 5)]
            else:
                raise ValueError(f"Invalid bbox string format: {bbox_string}")

        parsed_bboxes = [parse_bbox(bbox) for bbox in bboxes]
        return self.dequantize(parsed_bboxes, size_wh).tolist()

    def dequantize(self, boxes: torch.Tensor, size):
        if not isinstance(boxes, torch.Tensor):
            boxes = torch.tensor(boxes)
        bins_w, bins_h = self.bins  # Quantization bins.
        size_w, size_h = size  # Original image size.
        size_per_bin_w = size_w / bins_w
        size_per_bin_h = size_h / bins_h
        xmin, ymin, xmax, ymax = boxes.split(1, dim=-1)  # Shape: 4 * [N, 1].

        if self.mode == 'floor':
            # Add 0.5 to use the center position of the bin as the coordinate.
            dequantized_xmin = (xmin + 0.5) * size_per_bin_w
            dequantized_ymin = (ymin + 0.5) * size_per_bin_h
            dequantized_xmax = (xmax + 0.5) * size_per_bin_w
            dequantized_ymax = (ymax + 0.5) * size_per_bin_h
        elif self.mode == 'round':
            raise NotImplementedError()
        else:
            raise ValueError('Incorrect quantization type.')

        dequantized_boxes = torch.cat(
            (dequantized_xmin, dequantized_ymin, dequantized_xmax, dequantized_ymax), dim=-1
        )
        return dequantized_boxes
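

if __name__ == "__main__":
    # Minimal round-trip sketch for BoxQuantizer on its own (only torch is needed
    # for this part); the exact values are illustrative and subject to
    # floating-point rounding.
    quantizer = BoxQuantizer(mode='floor', bins=(1000, 1000))
    image_size_wh = (1000, 800)
    boxes = [[100.0, 150.0, 300.0, 450.0]]
    quantized = quantizer.quantize(boxes, size_wh=image_size_wh)
    stringified = ",".join(
        f"<loc_{x1}><loc_{y1}><loc_{x2}><loc_{y2}>" for x1, y1, x2, y2 in quantized
    )
    recovered = quantizer.dequantize_from_stringified_bboxes(stringified, image_size_wh)
    print("quantized:  ", quantized)    # e.g. [[100, 187, 300, 562]]
    print("stringified:", stringified)  # e.g. "<loc_100><loc_187><loc_300><loc_562>"
    print("recovered:  ", recovered)    # approximately the original box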