Spaces:
Runtime error
Runtime error
# Copyright 2024 MIT Han Lab | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
# | |
# SPDX-License-Identifier: Apache-2.0 | |
from typing import Any, Optional, Union | |
import numpy as np | |
import torch | |
# Names exported by ``from <module> import *``: the generator-aware sampling helpers.
__all__ = [
    "torch_randint", "torch_random", "torch_shuffle",
    "torch_uniform", "torch_random_choices",
]
def torch_randint(low: int, high: int, generator: Optional[torch.Generator] = None) -> int:
    """Draw one integer uniformly from the half-open range ``[low, high)``.

    When ``low == high`` the range is degenerate: ``low`` is returned directly
    and the generator is not advanced. Otherwise ``low`` must be below ``high``.
    """
    if low != high:
        assert low < high
        draw = torch.randint(low=low, high=high, size=(1,), generator=generator)
        return int(draw)
    return low
def torch_random(generator: Optional[torch.Generator] = None) -> float:
    """Return one Python float drawn uniformly from ``[0, 1)``.

    Pass a seeded ``generator`` for reproducible draws.
    """
    return torch.rand((1,), generator=generator).item()
def torch_shuffle(src_list: list[Any], generator: Optional[torch.Generator] = None) -> list[Any]:
    """Return a new list holding the elements of ``src_list`` in random order.

    The input is not modified. The permutation comes from ``torch.randperm``,
    so a seeded ``generator`` makes the order reproducible.
    """
    order = torch.randperm(len(src_list), generator=generator).tolist()
    return list(map(src_list.__getitem__, order))
def torch_uniform(low: float, high: float, generator: Optional[torch.Generator] = None) -> float:
    """Return a float drawn uniformly from ``[low, high)``.

    A ``[0, 1)`` sample from ``torch_random`` is stretched by the width of the
    range and shifted by ``low``.
    """
    return low + (high - low) * torch_random(generator)
def torch_random_choices(
    src_list: list[Any],
    generator: Optional[torch.Generator] = None,
    k: int = 1,
    weight_list: Optional[list[float]] = None,
) -> Union[Any, list]:
    """Draw ``k`` elements from ``src_list`` with replacement.

    Args:
        src_list: Population to sample from.
        generator: Optional torch generator used for every draw, so results are
            reproducible when it is seeded.
        k: Number of draws.
        weight_list: Optional relative weights, one per element of ``src_list``.
            When omitted, every element is equally likely. Weights are expected
            to be non-negative, so that their cumulative sum is non-decreasing.

    Returns:
        The sampled element itself when ``k == 1``; otherwise a list of ``k``
        sampled elements.
    """
    if weight_list is None:
        # Draw all k indices at once. ``tolist`` yields plain Python ints, so the
        # lookups below do not index with 0-d tensors.
        rand_idx = torch.randint(
            low=0, high=len(src_list), generator=generator, size=(k,)
        ).tolist()
        out_list = [src_list[i] for i in rand_idx]
    else:
        assert len(weight_list) == len(src_list)
        accumulate_weight_list = np.cumsum(weight_list)
        last_id = len(accumulate_weight_list) - 1
        out_list = []
        for _ in range(k):
            val = torch_uniform(0, accumulate_weight_list[-1], generator)
            # Binary-search for the first index whose cumulative weight strictly
            # exceeds ``val``. ``side="right"`` skips ties, so a zero-weight entry
            # is passed over. If no entry exceeds ``val`` (e.g. every weight is
            # zero), fall back to the last element, matching a linear scan.
            active_id = int(np.searchsorted(accumulate_weight_list, val, side="right"))
            out_list.append(src_list[min(active_id, last_id)])
    return out_list[0] if k == 1 else out_list