Spaces:
Runtime error
Runtime error
# Copyright (c) Facebook, Inc. and its affiliates. | |
# | |
# This source code is licensed under the MIT license found in the | |
# LICENSE file in the root directory of this source tree. | |
import logging | |
from typing import List | |
logger = logging.getLogger(__name__) | |
def uniform(dataset_sizes: List[int]): | |
return [1.0] * len(dataset_sizes) | |
def temperature_sampling(dataset_sizes, temp): | |
total_size = sum(dataset_sizes) | |
return [(size / total_size) ** (1.0 / temp) for size in dataset_sizes] | |
def make_temperature_sampling(temp=1.0): | |
def sampling_func(dataset_sizes): | |
return temperature_sampling(dataset_sizes, temp) | |
return sampling_func | |
def make_ratio_sampling(ratios): | |
def sampling_func(dataset_sizes): | |
return ratios | |
return sampling_func | |
class SamplingMethod: | |
def add_arguments(parser): | |
parser.add_argument( | |
"--sampling-method", | |
choices=[ | |
"uniform", | |
"temperature", | |
"concat", | |
"RoundRobin", | |
], | |
type=str, | |
default="concat", | |
help="The method to sample data per language pairs", | |
) | |
parser.add_argument( | |
"--sampling-temperature", | |
default=1.5, | |
type=float, | |
help="only work with --sampling-method temperature", | |
) | |
def build_sampler(args, task): | |
return SamplingMethod(args, task) | |
def __init__(self, args, task): | |
self.args = args | |
self.task = task | |
def is_adaptive(self): | |
return False | |
def sampling_method_selector(self): | |
args = self.args | |
logger.info(f"selected sampler: {args.sampling_method}") | |
if args.sampling_method == "uniform": | |
return uniform | |
elif args.sampling_method == "temperature" or self.is_adaptive(): | |
return make_temperature_sampling(float(args.sampling_temperature)) | |
else: | |
# default to concating all data set together | |
return None | |