import numpy as np | |
import torch | |
# --- | |
# Adapted from:
# https://github.com/qubvel/segmentation_models.pytorch/tree/master/segmentation_models_pytorch/losses
# --- | |
def to_tensor(x, dtype=None) -> torch.Tensor:
    """Convert ``x`` to a ``torch.Tensor``, optionally casting it to ``dtype``.

    Args:
        x: A ``torch.Tensor``, a ``numpy.ndarray``, a list/tuple that
            ``numpy.array`` can turn into a numeric array, or a scalar
            (``bool``/``int``/``float``/``complex`` or a NumPy scalar).
        dtype: Optional target type forwarded to ``Tensor.type`` (e.g. a
            ``torch.dtype``). ``None`` keeps the inferred type.

    Returns:
        A tensor. A tensor input is returned unchanged when no cast is
        needed. An ndarray input is wrapped with ``torch.from_numpy`` and
        therefore shares its memory unless the cast makes a copy.
        Lists, tuples and scalars follow NumPy's dtype inference before
        the optional cast.

    Raises:
        TypeError: If ``x`` is of an unsupported type. Previously such
            inputs silently produced ``None``.
    """
    if isinstance(x, torch.Tensor):
        tensor = x
    elif isinstance(x, np.ndarray):
        # from_numpy shares memory with the array (no copy).
        tensor = torch.from_numpy(x)
    elif isinstance(x, (list, tuple, bool, int, float, complex, np.generic)):
        # Sequences and scalars go through NumPy so their dtype inference
        # matches the ndarray path; scalars become 0-d tensors.
        tensor = torch.from_numpy(np.array(x))
    else:
        raise TypeError(
            f"to_tensor: unsupported input type {type(x).__name__!r}"
        )
    if dtype is not None:
        tensor = tensor.type(dtype)
    return tensor