import logging import os from multiprocessing import Process from pathlib import Path import torch from s3prl.util.download import _urls_to_filepaths logger = logging.getLogger(__name__) URL = "https://dl.fbaipublicfiles.com/librilight/CPC_checkpoints/60k_epoch4-d0f474de.pt" def _download_with_timeout(timeout: float, num_process: int): processes = [] for _ in range(num_process): process = Process( target=_urls_to_filepaths, args=(URL,), kwargs=dict(refresh=True) ) process.start() processes.append(process) exitcodes = [] for process in processes: process.join(timeout=timeout) exitcodes.append(process.exitcode) assert len(set(exitcodes)) == 1 exitcode = exitcodes[0] if exitcode != 0: for process in processes: process.terminate() def test_download(): filepath = Path(_urls_to_filepaths(URL, download=False)) if filepath.is_file(): os.remove(filepath) logger.info("This should timeout") _download_with_timeout(0.1, 2) assert not filepath.is_file(), ( "The download should failed due to the too short timeout second: 0.1 sec, " "and hence there should not be any corrupted (incomplete) file" ) logger.info("This should success") _download_with_timeout(None, 2) torch.load(filepath, map_location="cpu")