# Copyright 2021 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from ctypes import c_float, sizeof from enum import Enum from typing import Any, Dict, Iterable class ParameterFormat(Enum): Float = c_float @property def size(self) -> int: """ Number of byte required for this data type Returns: Integer > 0 """ return sizeof(self.value) def compute_effective_axis_dimension(dimension: int, fixed_dimension: int, num_token_to_add: int = 0) -> int: """ Args: dimension: fixed_dimension: num_token_to_add: Returns: """ # < 0 is possible if using a dynamic axis if dimension <= 0: dimension = fixed_dimension dimension -= num_token_to_add return dimension def compute_serialized_parameters_size(num_parameters: int, dtype: ParameterFormat) -> int: """ Compute the size taken by all the parameters in the given the storage format when serializing the model Args: num_parameters: Number of parameters to be saved dtype: The data format each parameter will be saved Returns: Size (in byte) taken to save all the parameters """ return num_parameters * dtype.size def flatten_output_collection_property(name: str, field: Iterable[Any]) -> Dict[str, Any]: """ Flatten any potential nested structure expanding the name of the field with the index of the element within the structure. Args: name: The name of the nested structure field: The structure to, potentially, be flattened Returns: (Dict[str, Any]): Outputs with flattened structure and key mapping this new structure. """ from itertools import chain return {f"{name}.{idx}": item for idx, item in enumerate(chain.from_iterable(field))}