not-lain commited on
Commit
dc79526
·
verified ·
1 Parent(s): a5944b1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -0
app.py CHANGED
@@ -10,6 +10,13 @@ from sam2.sam2_image_predictor import SAM2ImagePredictor
10
  import numpy as np
11
  from simple_lama_inpainting import SimpleLama
12
 
 
 
 
 
 
 
 
13
 
14
  pipe = FluxFillPipeline.from_pretrained(
15
  "black-forest-labs/FLUX.1-Fill-dev", torch_dtype=torch.bfloat16
 
10
  import numpy as np
11
  from simple_lama_inpainting import SimpleLama
12
 
13
+ if device.type == "cuda":
14
+ # use bfloat16 for the entire notebook
15
+ torch.autocast("cuda", dtype=torch.bfloat16).__enter__()
16
+ # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
17
+ if torch.cuda.get_device_properties(0).major >= 8:
18
+ torch.backends.cuda.matmul.allow_tf32 = True
19
+ torch.backends.cudnn.allow_tf32 = True
20
 
21
  pipe = FluxFillPipeline.from_pretrained(
22
  "black-forest-labs/FLUX.1-Fill-dev", torch_dtype=torch.bfloat16