File size: 715 Bytes
70d9de4
d197e7f
 
 
96bca50
 
 
d197e7f
 
 
 
 
 
 
96bca50
 
d197e7f
70d9de4
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import torch
import wandb


def get_wandb_artifact(
    artifact_name: str, artifact_type: str, get_metadata: bool = False
) -> "str | tuple[str, dict]":
    """Download a W&B artifact and return the local directory it was saved to.

    When a run is active (``wandb.run`` is truthy) the artifact is fetched with
    ``wandb.use_artifact``; otherwise it is fetched through the public
    ``wandb.Api`` so no run is required.

    Args:
        artifact_name: Artifact reference passed straight to W&B
            (e.g. ``"entity/project/name:alias"``).
        artifact_type: Expected artifact type; forwarded to W&B in both
            branches.
        get_metadata: If True, also return ``artifact.metadata``.

    Returns:
        The download directory returned by ``artifact.download()``, or a
        ``(artifact_dir, metadata)`` tuple when ``get_metadata`` is True.
    """
    if wandb.run:
        # Inside a run: go through the run so the artifact is tracked by it.
        artifact = wandb.use_artifact(artifact_name, type=artifact_type)
    else:
        # No active run: use the public API. Forward the type so it is
        # handled the same way as in the run branch (it was ignored before).
        artifact = wandb.Api().artifact(artifact_name, type=artifact_type)
    artifact_dir = artifact.download()
    if get_metadata:
        return artifact_dir, artifact.metadata
    return artifact_dir


def get_torch_backend() -> str:
    """Return the name of the preferred available torch device backend.

    Preference order is ``"cuda"``, then ``"mps"`` (only when the MPS backend
    is both available and built into this torch binary), then ``"cpu"``.
    """
    if torch.cuda.is_available():
        return "cuda"
    # getattr guard: torch builds without a ``backends.mps`` submodule fall
    # back to CPU instead of raising AttributeError.
    mps = getattr(torch.backends, "mps", None)
    if mps is not None and mps.is_available() and mps.is_built():
        return "mps"
    return "cpu"