Caleb Spradlin
initial commit
ab687e7
raw
history blame
642 Bytes
import numpy as np
import torch
# ---
# Adapted from segmentation_models.pytorch (losses module):
# https://github.com/qubvel/segmentation_models.pytorch/tree/master/segmentation_models_pytorch/losses
# ---
def to_tensor(x, dtype=None) -> torch.Tensor:
    """Convert ``x`` to a ``torch.Tensor``, optionally casting to ``dtype``.

    Args:
        x: A ``torch.Tensor``, ``numpy.ndarray``, list/tuple of values, or a
            scalar (Python ``bool``/``int``/``float``/``complex`` or a NumPy
            scalar such as ``np.float32(1.0)``).
        dtype: Optional target passed to ``Tensor.type`` (a ``torch.dtype``
            or a tensor-type string). ``None`` keeps the inferred dtype.

    Returns:
        A tensor. An input tensor is returned unchanged when no cast is
        requested. An ndarray input shares memory with the result (via
        ``torch.from_numpy``) unless a cast or a negative-stride copy is
        needed. Scalars become 0-d tensors.

    Raises:
        TypeError: If ``x`` is not one of the supported input types.
    """
    if isinstance(x, torch.Tensor):
        tensor = x
    else:
        if isinstance(x, (list, tuple)):
            x = np.array(x)
        elif isinstance(x, (int, float, complex, np.generic)):
            # Scalars (incl. NumPy scalars, which are not ndarrays) previously
            # fell through and returned None; wrap them as 0-d arrays instead.
            x = np.asarray(x)
        if not isinstance(x, np.ndarray):
            raise TypeError(
                f"to_tensor: unsupported input type {type(x).__name__}"
            )
        # torch.from_numpy rejects arrays with negative strides (e.g. a[::-1]).
        # Copy only in that case so other arrays keep zero-copy sharing.
        if any(stride < 0 for stride in x.strides):
            x = x.copy()
        tensor = torch.from_numpy(x)
    if dtype is not None:
        tensor = tensor.type(dtype)
    return tensor