File size: 400 Bytes
6af7294
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
import torch


def get_device(device=None):
    """Return a device to run on, auto-detecting one when none is given.

    Detection order is CUDA, then Apple MPS, then CPU.

    Args:
        device: Explicit device choice. Any value other than ``None`` is
            returned unchanged, so callers can override auto-detection.

    Returns:
        ``device`` if it is not ``None``; otherwise the string ``"cuda"``,
        ``"mps"`` or ``"cpu"``, whichever backend is usable first.
    """
    if device is not None:
        return device
    if torch.cuda.is_available():
        return "cuda"
    # ``torch.backends.mps`` does not exist on older PyTorch releases; treat a
    # missing backend as "not available" rather than raising AttributeError.
    mps = getattr(torch.backends, "mps", None)
    if mps is not None and mps.is_available() and mps.is_built():
        return "mps"
    return "cpu"