Spaces:
Runtime error
Runtime error
| import torch | |
| from contextlib import suppress | |
| def get_autocast(precision): | |
| if precision == 'amp': | |
| return torch.cuda.amp.autocast | |
| elif precision == 'amp_bfloat16' or precision == 'amp_bf16': | |
| # amp_bfloat16 is more stable than amp float16 for clip training | |
| return lambda: torch.cuda.amp.autocast(dtype=torch.bfloat16) | |
| else: | |
| return suppress | |