feat: use bfloat16
Browse files
app.py
CHANGED
@@ -16,6 +16,7 @@ from src.plot_utils import export_mask
|
|
16 |
@spaces.GPU()
|
17 |
def predict(model_choice, annotations: Dict[str, Any]):
|
18 |
# device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
19 |
sam2_model = load_model(
|
20 |
variant=model_choice,
|
21 |
ckpt_path=f"assets/checkpoints/sam2_hiera_{model_choice}.pt",
|
|
|
16 |
@spaces.GPU()
|
17 |
def predict(model_choice, annotations: Dict[str, Any]):
|
18 |
# device = "cuda" if torch.cuda.is_available() else "cpu"
|
19 |
+
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
|
20 |
sam2_model = load_model(
|
21 |
variant=model_choice,
|
22 |
ckpt_path=f"assets/checkpoints/sam2_hiera_{model_choice}.pt",
|