Spaces:
Build error
Build error
# Copyright (c) Meta Platforms, Inc. and affiliates. | |
# All rights reserved. | |
# | |
# This source code is licensed under the license found in the | |
# LICENSE file in the root directory of this source tree. | |
# Copyright (c) https://github.com/pytorch/pytorch | |
# Modified from https://github.com/pytorch/pytorch/blob/master/torch/testing/_internal/common_distributed.py # noqa: E501 | |
import faulthandler | |
import logging | |
import multiprocessing | |
import sys | |
import tempfile | |
import threading | |
import time | |
import traceback | |
import types | |
import unittest | |
from enum import Enum | |
from functools import wraps | |
from typing import NamedTuple | |
from unittest import TestCase | |
import torch | |
from torch.multiprocessing import active_children | |
# NOTE: configuring the root logger at import time is a module-level side
# effect; it is kept so that subprocess log output (INFO and above) stays
# visible exactly as before.
logging.basicConfig(level=logging.INFO)
# Module-level logger used by the test harness below.
logger = logging.getLogger(__name__)
class TestSkip(NamedTuple):
    """Pair a subprocess exit code with the skip message it stands for."""

    # The name starts with ``Test`` but this is plain data, not a test
    # class; tell pytest not to try to collect it.
    __test__ = False

    # Exit code a subprocess uses to signal this skip condition.
    exit_code: int
    # Human-readable reason reported when the test is skipped.
    message: str
# Known skip conditions, keyed by a short name. A subprocess exits with the
# entry's ``exit_code`` and the main process turns it into a skipped test
# carrying the entry's ``message``.
# NOTE(review): exit code 10 is also the value of
# ``MultiProcessTestCase.TEST_ERROR_EXIT_CODE`` in this file, so a
# subprocess exiting with 'backend_unavailable' is indistinguishable from
# one that hit a test error - confirm whether that overlap is intended.
TEST_SKIPS = {
    'backend_unavailable': TestSkip(
        exit_code=10,
        message='Skipped because distributed backend is not available.',
    ),
    'no_cuda': TestSkip(
        exit_code=11,
        message='CUDA is not available.',
    ),
    'multi-gpu-2': TestSkip(
        exit_code=12,
        message='Need at least 2 CUDA device',
    ),
    'generic': TestSkip(
        exit_code=13,
        message=('Test skipped at subprocess level, look at subprocess '
                 'log for skip reason'),
    ),
}
# [How does MultiProcessTestCase work?]
# Each MultiProcessTestCase instance uses 1 + `world_size` processes; by
# default `world_size` is 2. Take `test_rpc_spawn.py` as an example of a
# test that inherits from this class: its `setUp()` method calls
# `MultiProcessTestCase._spawn_processes()`, which spawns `world_size`
# subprocesses. During the spawn, the main process passes the test name
# (derived from `self.id()`) to the subprocesses. Each subprocess then uses
# that name to look up the test method on its own test instance and runs
# it. The main process simply waits for all subprocesses to join.
class MultiProcessTestCase(TestCase): | |
MAIN_PROCESS_RANK = -1 | |
# This exit code is used to indicate that the test code had an error and | |
# exited abnormally. There are certain tests that might use sys.exit() to | |
# simulate failures and in those cases, we can't have an exit code of 0, | |
# but we still want to ensure we didn't run into any other errors. | |
TEST_ERROR_EXIT_CODE = 10 | |
# do not early terminate for distributed tests. | |
def _should_stop_test_suite(self) -> bool: | |
return False | |
def prepare_subprocess(self): | |
pass | |
def world_size(self) -> int: | |
return 2 | |
def timeout(self) -> int: | |
return 1000 | |
def join_or_run(self, fn): | |
def wrapper(self): | |
if self.rank == self.MAIN_PROCESS_RANK: | |
self._join_processes(fn) | |
else: | |
fn() | |
return types.MethodType(wrapper, self) | |
# The main process spawns N subprocesses that run the test. | |
# Constructor patches current instance test method to | |
# assume the role of the main process and join its subprocesses, | |
# or run the underlying test function. | |
def __init__(self, method_name: str = 'runTest') -> None: | |
super().__init__(method_name) | |
fn = getattr(self, method_name) | |
setattr(self, method_name, self.join_or_run(fn)) | |
def setUp(self) -> None: | |
super().setUp() | |
self.skip_return_code_checks = [] # type: ignore[var-annotated] | |
self.processes = [] # type: ignore[var-annotated] | |
self.rank = self.MAIN_PROCESS_RANK | |
self.file_name = tempfile.NamedTemporaryFile(delete=False).name | |
# pid to pipe consisting of error message from process. | |
self.pid_to_pipe = {} # type: ignore[var-annotated] | |
def tearDown(self) -> None: | |
super().tearDown() | |
for p in self.processes: | |
p.terminate() | |
# Each Process instance holds a few open file descriptors. The unittest | |
# runner creates a new TestCase instance for each test method and keeps | |
# it alive until the end of the entire suite. We must thus reset the | |
# processes to prevent an effective file descriptor leak. | |
self.processes = [] | |
def _current_test_name(self) -> str: | |
# self.id() | |
# e.g. '__main__.TestDistributed.TestAdditive.test_get_rank' | |
return self.id().split('.')[-1] | |
def _start_processes(self, proc) -> None: | |
self.processes = [] | |
for rank in range(int(self.world_size)): | |
parent_conn, child_conn = torch.multiprocessing.Pipe() | |
process = proc( | |
target=self.__class__._run, | |
name='process ' + str(rank), | |
args=(rank, self._current_test_name(), self.file_name, | |
child_conn), | |
) | |
process.start() | |
self.pid_to_pipe[process.pid] = parent_conn | |
self.processes.append(process) | |
def _spawn_processes(self) -> None: | |
proc = torch.multiprocessing.get_context('spawn').Process | |
self._start_processes(proc) | |
class Event(Enum): | |
GET_TRACEBACK = 1 | |
def _event_listener(parent_pipe, signal_pipe, rank: int): | |
while True: | |
ready_pipes = multiprocessing.connection.wait( | |
[parent_pipe, signal_pipe]) | |
if parent_pipe in ready_pipes: | |
if parent_pipe.closed: | |
return | |
event = parent_pipe.recv() | |
if event == MultiProcessTestCase.Event.GET_TRACEBACK: | |
# Return traceback to the parent process. | |
with tempfile.NamedTemporaryFile(mode='r+') as tmp_file: | |
faulthandler.dump_traceback(tmp_file) | |
# Flush buffers and seek to read from the beginning | |
tmp_file.flush() | |
tmp_file.seek(0) | |
parent_pipe.send(tmp_file.read()) | |
if signal_pipe in ready_pipes: | |
return | |
def _run(cls, rank: int, test_name: str, file_name: str, | |
parent_pipe) -> None: | |
self = cls(test_name) | |
try: | |
self.prepare_subprocess() | |
except Exception: | |
raise sys.exit(MultiProcessTestCase.TEST_ERROR_EXIT_CODE) | |
self.rank = rank | |
self.file_name = file_name | |
self.run_test(test_name, parent_pipe) | |
def run_test(self, test_name: str, parent_pipe) -> None: | |
# Start event listener thread. | |
signal_recv_pipe, signal_send_pipe = torch.multiprocessing.Pipe( | |
duplex=False) | |
event_listener_thread = threading.Thread( | |
target=MultiProcessTestCase._event_listener, | |
args=(parent_pipe, signal_recv_pipe, self.rank), | |
daemon=True, | |
) | |
event_listener_thread.start() | |
# self.id() == e.g. '__main__.TestDistributed.test_get_rank' | |
# We're retrieving a corresponding test and executing it. | |
try: | |
getattr(self, test_name)() | |
except unittest.SkipTest as se: | |
logger.info(f'Process {self.rank} skipping test {test_name} for ' | |
f'following reason: {str(se)}') | |
sys.exit(TEST_SKIPS['generic'].exit_code) | |
except Exception: | |
logger.error( | |
f'Caught exception: \n{traceback.format_exc()} exiting ' | |
f'process {self.rank} with exit code: ' | |
f'{MultiProcessTestCase.TEST_ERROR_EXIT_CODE}') | |
# Send error to parent process. | |
parent_pipe.send(traceback.format_exc()) | |
sys.exit(MultiProcessTestCase.TEST_ERROR_EXIT_CODE) | |
finally: | |
if signal_send_pipe is not None: | |
signal_send_pipe.send(None) | |
assert event_listener_thread is not None | |
event_listener_thread.join() | |
# Close pipe after done with test. | |
parent_pipe.close() | |
def _get_timedout_process_traceback(self) -> None: | |
pipes = [] | |
for i, process in enumerate(self.processes): | |
if process.exitcode is None: | |
pipe = self.pid_to_pipe[process.pid] | |
try: | |
pipe.send(MultiProcessTestCase.Event.GET_TRACEBACK) | |
pipes.append((i, pipe)) | |
except ConnectionError as e: | |
logger.error( | |
'Encountered error while trying to get traceback ' | |
f'for process {i}: {e}') | |
# Wait for results. | |
for rank, pipe in pipes: | |
try: | |
# Wait for traceback | |
if pipe.poll(5): | |
if pipe.closed: | |
logger.info( | |
f'Pipe closed for process {rank}, cannot retrieve ' | |
'traceback') | |
continue | |
traceback = pipe.recv() | |
logger.error(f'Process {rank} timed out with traceback: ' | |
f'\n\n{traceback}') | |
else: | |
logger.error('Could not retrieve traceback for timed out ' | |
f'process: {rank}') | |
except ConnectionError as e: | |
logger.error( | |
'Encountered error while trying to get traceback for ' | |
f'process {rank}: {e}') | |
def _join_processes(self, fn) -> None: | |
start_time = time.time() | |
subprocess_error = False | |
try: | |
while True: | |
# check to see if any subprocess exited with an error early. | |
for (i, p) in enumerate(self.processes): | |
# This is the exit code processes exit with if they | |
# encountered an exception. | |
if p.exitcode == MultiProcessTestCase.TEST_ERROR_EXIT_CODE: | |
print( | |
f'Process {i} terminated with exit code ' | |
f'{p.exitcode}, terminating remaining processes.') | |
_active_children = active_children() | |
for ac in _active_children: | |
ac.terminate() | |
subprocess_error = True | |
break | |
if subprocess_error: | |
break | |
# All processes have joined cleanly if they all a valid | |
# exitcode | |
if all([p.exitcode is not None for p in self.processes]): | |
break | |
# Check if we should time out the test. If so, we terminate | |
# each process. | |
elapsed = time.time() - start_time | |
if elapsed > self.timeout: | |
self._get_timedout_process_traceback() | |
print(f'Timing out after {self.timeout} seconds and ' | |
'killing subprocesses.') | |
for p in self.processes: | |
p.terminate() | |
break | |
# Sleep to avoid excessive busy polling. | |
time.sleep(0.1) | |
elapsed_time = time.time() - start_time | |
if fn in self.skip_return_code_checks: | |
self._check_no_test_errors(elapsed_time) | |
else: | |
self._check_return_codes(elapsed_time) | |
finally: | |
# Close all pipes | |
for pid, pipe in self.pid_to_pipe.items(): | |
pipe.close() | |
def _check_no_test_errors(self, elapsed_time) -> None: | |
"""Checks that we didn't have any errors thrown in the child | |
processes.""" | |
for i, p in enumerate(self.processes): | |
if p.exitcode is None: | |
raise RuntimeError( | |
'Process {} timed out after {} seconds'.format( | |
i, elapsed_time)) | |
self.assertNotEqual(self.TEST_ERROR_EXIT_CODE, p.exitcode) | |
def _check_return_codes(self, elapsed_time) -> None: | |
"""Checks that the return codes of all spawned processes match, and | |
skips tests if they returned a return code indicating a skipping | |
condition.""" | |
first_process = self.processes[0] | |
# first, we check if there are errors in actual processes | |
# (via TEST_ERROR_EXIT CODE), and raise an exception for those. | |
# the reason we do this is to attempt to raise a more helpful error | |
# message than "Process x terminated/timed out" | |
# TODO: we should pipe the exception of the failed subprocess here. | |
# Currently, the actual exception is displayed as a logging output. | |
errored_processes = [ | |
(i, p) for i, p in enumerate(self.processes) | |
if p.exitcode == MultiProcessTestCase.TEST_ERROR_EXIT_CODE | |
] | |
if errored_processes: | |
error = '' | |
for i, process in errored_processes: | |
# Get error from pipe. | |
error_message = self.pid_to_pipe[process.pid].recv() | |
error += ( | |
'Process {} exited with error code {} and exception:\n{}\n' | |
.format(i, MultiProcessTestCase.TEST_ERROR_EXIT_CODE, | |
error_message)) | |
raise RuntimeError(error) | |
# If no process exited uncleanly, we check for timeouts, and then | |
# ensure each process exited cleanly. | |
for i, p in enumerate(self.processes): | |
if p.exitcode is None: | |
raise RuntimeError( | |
f'Process {i} terminated or timed out after ' | |
'{elapsed_time} seconds') | |
for skip in TEST_SKIPS.values(): | |
if first_process.exitcode == skip.exit_code: | |
raise unittest.SkipTest(skip.message) | |
# Skip the unittest since the raised error maybe not caused by | |
# the tested function. For example, in CI environment, the tested | |
# method could be terminated by system signal for the limited | |
# resources. | |
self.skipTest(f'Skip test {self._testMethodName} due to ' | |
'the program abort') | |
def is_master(self) -> bool: | |
return self.rank == 0 | |