Spaces:
Runtime error
Runtime error
import json | |
import logging | |
import os | |
import subprocess | |
from argparse import ArgumentParser | |
logger = logging.getLogger(__name__) | |
def _looks_numeric(token):
    """Return True when *token* parses as a number (e.g. ``-1`` or ``-0.5``)."""
    try:
        float(token)
    except ValueError:
        return False
    return True


def parse_args():
    """Parse arbitrary ``--name value`` / ``--name=value`` options from ``sys.argv``.

    No options are known in advance. A first pass collects every token that
    argparse does not recognise, each distinct flag found there is registered
    on the parser, and the command line is then parsed again for real.

    Returns:
        argparse.Namespace: one attribute per flag discovered on the command
        line, holding the string value supplied for it. When a flag is
        repeated, the last value wins (standard argparse "store" behaviour).
    """
    parser = ArgumentParser()
    _, unknown = parser.parse_known_args()

    registered = set()
    for token in unknown:
        flag = token.split("=")[0]
        # Only dash-prefixed tokens are flags; everything else is a value.
        if not flag.startswith("-"):
            continue
        # A bare "-" / "--" has no name that could become a dest.
        if not flag.lstrip("-"):
            continue
        # Values such as "-1" must not become options: once the parser owns a
        # numeric-looking option, argparse stops accepting negative numbers
        # as values and "--flag -1" fails with "expected one argument".
        if _looks_numeric(flag):
            continue
        # Registering the same option string twice raises a conflict error.
        if flag in registered:
            continue
        registered.add(flag)
        parser.add_argument(flag)

    return parser.parse_args()
def _build_launch_command(num_gpus, hosts, rank, port, script_args):
    """Return the argv list that runs ``./run_glue.py`` via torch.distributed.launch.

    Args:
        num_gpus: number of processes to start on this node.
        hosts: list of all host names; the first entry is used as master.
        rank: index of the current host within *hosts*.
        port: master port used when more than one host takes part.
        script_args: mapping of option name to value, forwarded to the
            training script as ``--name value`` pairs.
    """
    cmd = ["python", "-m", "torch.distributed.launch"]
    if len(hosts) > 1:
        # Multi-node run: every node needs the rendezvous information.
        cmd += [
            f"--nnodes={len(hosts)}",
            f"--node_rank={rank}",
            f"--nproc_per_node={num_gpus}",
            f"--master_addr={hosts[0]}",
            f"--master_port={port}",
        ]
    else:
        cmd.append(f"--nproc_per_node={num_gpus}")
    cmd.append("./run_glue.py")
    for parameter, value in script_args.items():
        cmd += [f"--{parameter}", str(value)]
    return cmd


def main():
    """Launch ``run_glue.py`` under ``torch.distributed.launch``.

    Reads the GPU count (``SM_NUM_GPUS``), the JSON list of hosts
    (``SM_HOSTS``) and this host's name (``SM_CURRENT_HOST``) from the
    environment (presumably provided by the SageMaker training container —
    TODO confirm), sets ``NCCL_DEBUG=INFO`` and forwards every parsed
    command-line option to the training script.

    Raises:
        KeyError: if one of the required environment variables is missing.
        ValueError: if the current host is not listed in ``SM_HOSTS``.
        OSError: if the launcher process cannot be started.
        SystemExit: with the child's return code when the launcher exits
            with a non-zero status, so the failure is visible to the caller.
    """
    args = parse_args()
    port = 8888
    num_gpus = int(os.environ["SM_NUM_GPUS"])
    hosts = json.loads(os.environ["SM_HOSTS"])
    rank = hosts.index(os.environ["SM_CURRENT_HOST"])
    os.environ["NCCL_DEBUG"] = "INFO"

    # NOTE: option names come from argparse dests, so a flag given as
    # "--max-len" is forwarded as "--max_len".
    cmd = _build_launch_command(num_gpus, hosts, rank, port, vars(args))

    # An argv list with the default shell=False passes every value through
    # verbatim: no re-splitting on spaces and no shell interpretation.
    try:
        completed = subprocess.run(cmd, check=False)
    except OSError:
        logger.exception("Failed to start launcher: %s", cmd)
        raise

    if completed.returncode != 0:
        logger.error(
            "Launcher exited with status %s: %s", completed.returncode, cmd
        )
        # Propagate the failure instead of reporting a successful run.
        raise SystemExit(completed.returncode)
if __name__ == "__main__":
    # Run the launcher only when this file is executed as a script,
    # not when it is imported as a module.
    main()