File size: 220 Bytes
86b1a7e
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
from enum import Enum

class AccelerationType(Enum):
    """Kinds of compute acceleration this module distinguishes.

    Every member's value is the lowercase spelling of its name, so a member
    can be recovered from its value, e.g. ``AccelerationType("tpu")``.
    Of these, only ``TPU`` triggers extra work in ``execute_graph``.
    """

    # Declaration order is significant: it is the order in which iterating
    # the enum yields members.
    CPU = "cpu"
    GPU = "gpu"
    TPU = "tpu"
    MPS = "mps"

def execute_graph(acceleration_type: "AccelerationType | None" = None) -> None:
    """Call ``xm.mark_step()`` when the backend is TPU; otherwise do nothing.

    Args:
        acceleration_type: Backend to act for. Defaults to ``None``, meaning
            "use the module-level ``_acceleration_type`` setting".
            NOTE(review): nothing in this module assigns ``_acceleration_type``
            or ``xm``; presumably other code sets them on this module —
            confirm. If no setting is available, the call is a no-op.

    Raises:
        RuntimeError: If the backend is ``AccelerationType.TPU`` but no
            ``xm`` object is bound at module level (presumably
            ``torch_xla.core.xla_model`` — confirm).
    """
    if acceleration_type is None:
        # Look the setting up rather than referencing the bare name: the name
        # is never defined in this module, so a bare reference raised
        # NameError on every call.
        acceleration_type = globals().get("_acceleration_type")

    if acceleration_type != AccelerationType.TPU:
        return

    # Same reasoning for ``xm``: resolve it lazily so a missing binding
    # produces an actionable error instead of a NameError.
    xla_model = globals().get("xm")
    if xla_model is None:
        raise RuntimeError(
            "TPU acceleration is selected but 'xm' is not bound in module "
            f"{__name__!r}; e.g. `import torch_xla.core.xla_model as xm` first"
        )
    xla_model.mark_step()