SauravMaheshkar commited on
Commit
090066d
·
unverified ·
1 Parent(s): 3dcca3c

feat: use bfloat16

Browse files
Files changed (1) hide show
  1. app.py +1 -0
app.py CHANGED
@@ -16,6 +16,7 @@ from src.plot_utils import export_mask
16
  @spaces.GPU()
17
  def predict(model_choice, annotations: Dict[str, Any]):
18
  # device = "cuda" if torch.cuda.is_available() else "cpu"
 
19
  sam2_model = load_model(
20
  variant=model_choice,
21
  ckpt_path=f"assets/checkpoints/sam2_hiera_{model_choice}.pt",
 
16
  @spaces.GPU()
17
  def predict(model_choice, annotations: Dict[str, Any]):
18
  # device = "cuda" if torch.cuda.is_available() else "cpu"
19
+ torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
20
  sam2_model = load_model(
21
  variant=model_choice,
22
  ckpt_path=f"assets/checkpoints/sam2_hiera_{model_choice}.pt",