HumanSD / openclip /training /precision.py
liyy201912's picture
Upload folder using huggingface_hub
cc0dd3c
raw
history blame
383 Bytes
import torch
from contextlib import suppress
def get_autocast(precision: str):
    """Return a zero-argument context-manager factory for *precision*.

    The returned object is meant to be called with no arguments and used
    in a ``with`` statement, e.g. ``with get_autocast(precision)(): ...``.

    Args:
        precision: Precision mode name. ``'amp'`` selects
            ``torch.cuda.amp.autocast`` with its default dtype;
            ``'amp_bfloat16'`` and ``'amp_bf16'`` select
            ``torch.cuda.amp.autocast`` with ``dtype=torch.bfloat16``.
            Any other value selects a no-op context.

    Returns:
        ``torch.cuda.amp.autocast`` itself for ``'amp'``, a lambda building
        a bfloat16 autocast context for the bf16 modes, and
        ``contextlib.suppress`` otherwise.
    """
    if precision == 'amp':
        return torch.cuda.amp.autocast
    if precision in ('amp_bfloat16', 'amp_bf16'):
        # amp_bfloat16 is more stable than amp float16 for clip training
        return lambda: torch.cuda.amp.autocast(dtype=torch.bfloat16)
    # ``suppress()`` called with no exception types suppresses nothing, so it
    # serves as a do-nothing context manager for the non-autocast modes.
    return suppress