Spaces:
Build error
Build error
| # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. | |
| import glob | |
| import os.path | |
| import torch | |
# The JIT-build helpers live in torch.utils.cpp_extension; if that module
# cannot be imported, re-raise with a hint about the required PyTorch version
# while keeping the original failure attached as the explicit cause.
try:
    from torch.utils.cpp_extension import load as load_ext
    from torch.utils.cpp_extension import CUDA_HOME
except ImportError as err:
    raise ImportError(
        "The cpp layer extensions requires PyTorch 0.4 or higher"
    ) from err
def _load_C_extensions():
    """JIT-compile and load the C++/CUDA sources found in the ``csrc`` folder.

    ``csrc`` is looked up next to the directory that contains this file (i.e.
    ``<parent of this file's directory>/csrc``). Every ``*.cpp`` file directly
    inside it and inside its ``cpu/`` sub-folder is always compiled. The
    ``cuda/*.cu`` files are added, and ``-DWITH_CUDA`` is passed via
    ``extra_cflags``, only when ``torch.cuda.is_available()`` is true and
    ``CUDA_HOME`` is not ``None``.

    Returns:
        The result of ``torch.utils.cpp_extension.load`` for the extension
        registered under the name ``"torchvision"``.
    """
    package_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    csrc_dir = os.path.join(package_root, "csrc")

    # glob yields paths in arbitrary order; sort each group so the list of
    # sources handed to the build is the same on every run. The patterns are
    # rooted at the absolute csrc_dir, so the results are already absolute
    # paths and need no further joining.
    main_sources = sorted(glob.glob(os.path.join(csrc_dir, "*.cpp")))
    cpu_sources = sorted(glob.glob(os.path.join(csrc_dir, "cpu", "*.cpp")))
    cuda_sources = sorted(glob.glob(os.path.join(csrc_dir, "cuda", "*.cu")))

    sources = main_sources + cpu_sources
    extra_cflags = []

    # Only build the CUDA kernels when both a usable device and a CUDA
    # toolkit location are available.
    if torch.cuda.is_available() and CUDA_HOME is not None:
        sources.extend(cuda_sources)
        extra_cflags = ["-DWITH_CUDA"]

    return load_ext(
        "torchvision",
        sources,
        extra_cflags=extra_cflags,
        extra_include_paths=[csrc_dir],
    )
# Module-level side effect: importing this module calls _load_C_extensions()
# once and exposes the loaded extension module as `_C`.
_C = _load_C_extensions()