Spaces:
Sleeping
Sleeping
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os | |
from mmpt.utils import recursive_config | |
class BaseJob(object):
    """Common state for jobs that are described by a YAML config file.

    Attributes:
        yaml_file: the YAML path handed to the constructor.
        config: result of ``recursive_config(yaml_file)``.
        dryrun: flag stored for subclasses to consult when submitting.
    """

    def __init__(self, yaml_file, dryrun=False):
        """Remember the YAML path and the dry-run flag, and load the config."""
        self.yaml_file = yaml_file
        self.dryrun = dryrun
        self.config = recursive_config(yaml_file)

    def submit(self, **kwargs):
        """Launch the job; concrete subclasses must override this."""
        raise NotImplementedError()

    def _normalize_cmd(self, cmd_list):
        """Return a new list with the first ``"[yaml]"`` token substituted.

        The input sequence is not modified.  Only the first occurrence of the
        placeholder is replaced by ``self.yaml_file``.

        Raises:
            ValueError: if ``cmd_list`` contains no ``"[yaml]"`` token.
        """
        tokens = list(cmd_list)
        placeholder_at = tokens.index("[yaml]")
        # Rebuild around the placeholder instead of assigning in place.
        return (
            tokens[:placeholder_at]
            + [self.yaml_file]
            + tokens[placeholder_at + 1:]
        )
class LocalJob(BaseJob):
    """Run a job on the local machine by shelling out to a command.

    ``CMD_CONFIG`` maps a job type to a command template.  The ``"[yaml]"``
    token of the selected template is replaced by the job's YAML path when
    the job is submitted.
    """

    CMD_CONFIG = {
        "local_single": [
            "fairseq-train", "[yaml]", "--user-dir", "mmpt",
            "--task", "mmtask", "--arch", "mmarch",
            "--criterion", "mmloss",
        ],
        "local_small": [
            "fairseq-train", "[yaml]", "--user-dir", "mmpt",
            "--task", "mmtask", "--arch", "mmarch",
            "--criterion", "mmloss",
            "--distributed-world-size", "2",
        ],
        "local_big": [
            "fairseq-train", "[yaml]", "--user-dir", "mmpt",
            "--task", "mmtask", "--arch", "mmarch",
            "--criterion", "mmloss",
            "--distributed-world-size", "8",
        ],
        "local_predict": ["python", "mmpt_cli/predict.py", "[yaml]"],
    }

    # Options passed as bare switches (no value) on the command line.
    _BINARY_FLAGS = (
        "fp16", "reset_optimizer", "reset_dataloader", "reset_meters",
    )

    def __init__(self, yaml_file, job_type=None, dryrun=False):
        """Pick the job type and warn about large local batch sizes.

        Args:
            yaml_file: path of the YAML config describing the job.
            job_type: key into ``CMD_CONFIG``.  When ``None``, the config's
                ``task_type`` is used if it is not ``None``, otherwise
                ``"local_single"``.
            dryrun: when true, ``submit`` prints the command without
                running it.
        """
        super().__init__(yaml_file, dryrun)
        if job_type is None:
            self.job_type = "local_single"
            if self.config.task_type is not None:
                self.job_type = self.config.task_type
        else:
            self.job_type = job_type
        if self.job_type in ["local_single", "local_small"]:
            # Only a hint is printed; the configured batch size is untouched.
            if self.config.fairseq.dataset.batch_size > 32:
                print("decreasing batch_size to 32 for local testing?")

    @staticmethod
    def _fairseq_args(fairseq_config):
        """Flatten the two-level ``fairseq`` config section into CLI tokens.

        Each key becomes ``--key-with-dashes``.  Keys in ``_BINARY_FLAGS``
        are emitted as a bare switch, and only when their configured value
        is truthy, so that e.g. ``fp16: false`` does not enable ``--fp16``.
        For ``lr`` only the first element of the configured value is passed.
        Values are returned unquoted; shell quoting is applied by the caller.
        """
        args = []
        for field in fairseq_config:
            section = fairseq_config[field]
            for key in section:
                option = "--" + key.replace("_", "-")
                value = section[key]
                if key in LocalJob._BINARY_FLAGS:
                    if value:
                        args.append(option)
                    continue
                if key == "lr":
                    value = value[0]
                args.extend([option, str(value)])
        return args

    def submit(self):
        """Build the command line, print it, and run it unless ``dryrun``.

        For job types whose name does not contain ``"predict"``, the options
        of the ``fairseq`` section of the YAML config are appended to the
        command template.

        Returns:
            A ``JobStatus`` carrying the fixed placeholder id ``"12345678"``.

        Raises:
            KeyError: if ``self.job_type`` is not a key of ``CMD_CONFIG``.
        """
        import shlex

        cmd_list = self._normalize_cmd(LocalJob.CMD_CONFIG[self.job_type])
        if "predict" not in self.job_type:
            # append fairseq args.
            from mmpt.utils import load_config

            config = load_config(config_file=self.yaml_file)
            cmd_list.extend(self._fairseq_args(config.fairseq))

        # Quote every token so values containing spaces or shell
        # metacharacters (such as the ``adam_betas`` tuple) reach the program
        # as a single argument; plain tokens are left unchanged by quote().
        command = " ".join(shlex.quote(token) for token in cmd_list)
        print("launching", command)
        if not self.dryrun:
            # NOTE: the exit status of the command is not inspected.
            os.system(command)
        return JobStatus("12345678")
class JobStatus(object):
    """Minimal status handle identified by a job id.

    ``done`` and ``running`` always return ``False`` in this implementation,
    so ``result`` always produces the "is running." message.
    """

    def __init__(self, job_id):
        """Store ``job_id``; it is returned as-is by ``str()`` and ``repr()``."""
        self.job_id = job_id

    def __str__(self):
        return self.job_id

    def __repr__(self):
        return self.job_id

    def done(self):
        """Report whether the job has finished (always ``False`` here)."""
        return False

    def running(self):
        """Report whether the job is running (always ``False`` here)."""
        return False

    def result(self):
        """Return a one-line message chosen by the value of ``done()``."""
        template = "{} is done." if self.done() else "{} is running."
        return template.format(self.job_id)

    def stdout(self):
        """Return the same message as ``result()``."""
        return self.result()

    def stderr(self):
        """Return the same message as ``result()``."""
        return self.result()