Spaces:
Sleeping
Sleeping
# Copyright 2021 The HuggingFace Team. All rights reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
from ctypes import c_float, sizeof | |
from enum import Enum | |
from typing import Any, Dict, Iterable | |
class ParameterFormat(Enum): | |
Float = c_float | |
def size(self) -> int: | |
""" | |
Number of byte required for this data type | |
Returns: | |
Integer > 0 | |
""" | |
return sizeof(self.value) | |
def compute_effective_axis_dimension(dimension: int, fixed_dimension: int, num_token_to_add: int = 0) -> int: | |
""" | |
Args: | |
dimension: | |
fixed_dimension: | |
num_token_to_add: | |
Returns: | |
""" | |
# < 0 is possible if using a dynamic axis | |
if dimension <= 0: | |
dimension = fixed_dimension | |
dimension -= num_token_to_add | |
return dimension | |
def compute_serialized_parameters_size(num_parameters: int, dtype: ParameterFormat) -> int: | |
""" | |
Compute the size taken by all the parameters in the given the storage format when serializing the model | |
Args: | |
num_parameters: Number of parameters to be saved | |
dtype: The data format each parameter will be saved | |
Returns: | |
Size (in byte) taken to save all the parameters | |
""" | |
return num_parameters * dtype.size | |
def flatten_output_collection_property(name: str, field: Iterable[Any]) -> Dict[str, Any]: | |
""" | |
Flatten any potential nested structure expanding the name of the field with the index of the element within the | |
structure. | |
Args: | |
name: The name of the nested structure | |
field: The structure to, potentially, be flattened | |
Returns: | |
(Dict[str, Any]): Outputs with flattened structure and key mapping this new structure. | |
""" | |
from itertools import chain | |
return {f"{name}.{idx}": item for idx, item in enumerate(chain.from_iterable(field))} | |