# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch

from ..version_utils import digit_version
from .parrots_wrapper import TORCH_VERSION

# True when the installed backend is real PyTorch (not parrots) at version
# 1.10.0a0 or newer. The 'parrots' check comes first so that
# ``digit_version`` is only evaluated for non-parrots version strings.
_torch_version_meshgrid_indexing = (
    'parrots' not in TORCH_VERSION
    and digit_version(TORCH_VERSION) >= digit_version('1.10.0a0'))


def torch_meshgrid(*tensors):
    """Call ``torch.meshgrid`` in a way that works across PyTorch versions.

    Starting with PyTorch 1.10.0a0, ``torch.meshgrid`` accepts an
    ``indexing`` argument. This wrapper passes ``indexing='ij'`` explicitly
    when the installed version supports it, which avoids the warning raised
    by newer PyTorch, and omits the argument otherwise so that earlier
    versions keep working.

    Args:
        *tensors (Tensor): Scalars or 1 dimensional tensors, forwarded
            positionally to ``torch.meshgrid``.

    Returns:
        Sequence[Tensor]: Sequence of meshgrid tensors.
    """
    # Without the keyword, torch.meshgrid uses indexing='ij' by default, so
    # both paths yield the same grid layout.
    if _torch_version_meshgrid_indexing:
        indexing_kwargs = {'indexing': 'ij'}
    else:
        indexing_kwargs = {}
    return torch.meshgrid(*tensors, **indexing_kwargs)