not-lain commited on
Commit
3aa0053
·
1 Parent(s): fe5d262

introduce context manager

Browse files
Files changed (1) hide show
  1. app.py +27 -16
app.py CHANGED
@@ -9,12 +9,35 @@ from PIL import Image, ImageOps
9
  from sam2.sam2_image_predictor import SAM2ImagePredictor
10
  import numpy as np
11
  from simple_lama_inpainting import SimpleLama
 
 
 
 
 
 
 
 
 
 
12
 
13
 
14
  pipe = FluxFillPipeline.from_pretrained(
15
  "black-forest-labs/FLUX.1-Fill-dev", torch_dtype=torch.bfloat16
16
  ).to("cuda")
17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
  def prepare_image_and_mask(
20
  image,
@@ -98,27 +121,15 @@ def rmbg(image=None, url=None):
98
  image = url
99
  image = load_img(image).convert("RGB")
100
  image_size = image.size
101
- torch.set_float32_matmul_precision(["high", "highest"][0])
102
- birefnet = AutoModelForImageSegmentation.from_pretrained(
103
- "ZhengPeng7/BiRefNet", trust_remote_code=True
104
- )
105
- birefnet.to("cuda")
106
- transform_image = transforms.Compose(
107
- [
108
- transforms.Resize((1024, 1024)),
109
- transforms.ToTensor(),
110
- transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
111
- ]
112
- )
113
  input_images = transform_image(image).unsqueeze(0).to("cuda")
114
- # Prediction
115
- with torch.no_grad():
116
- preds = birefnet(input_images)[-1].sigmoid().cpu()
 
117
  pred = preds[0].squeeze()
118
  pred_pil = transforms.ToPILImage()(pred)
119
  mask = pred_pil.resize(image_size)
120
  image.putalpha(mask)
121
- torch.set_float32_matmul_precision(["high", "highest"][1])
122
  return image
123
 
124
 
 
9
  from sam2.sam2_image_predictor import SAM2ImagePredictor
10
  import numpy as np
11
  from simple_lama_inpainting import SimpleLama
12
+ from contextlib import contextmanager
13
+
14
+
15
+ @contextmanager
16
+ def float32_high_matmul_precision():
17
+ torch.set_float32_matmul_precision("high")
18
+ try:
19
+ yield
20
+ finally:
21
+ torch.set_float32_matmul_precision("highest")
22
 
23
 
24
  pipe = FluxFillPipeline.from_pretrained(
25
  "black-forest-labs/FLUX.1-Fill-dev", torch_dtype=torch.bfloat16
26
  ).to("cuda")
27
 
28
+ birefnet = AutoModelForImageSegmentation.from_pretrained(
29
+ "ZhengPeng7/BiRefNet", trust_remote_code=True
30
+ )
31
+ birefnet.to("cuda")
32
+
33
+ transform_image = transforms.Compose(
34
+ [
35
+ transforms.Resize((1024, 1024)),
36
+ transforms.ToTensor(),
37
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
38
+ ]
39
+ )
40
+
41
 
42
  def prepare_image_and_mask(
43
  image,
 
121
  image = url
122
  image = load_img(image).convert("RGB")
123
  image_size = image.size
 
 
 
 
 
 
 
 
 
 
 
 
124
  input_images = transform_image(image).unsqueeze(0).to("cuda")
125
+ with float32_high_matmul_precision():
126
+ # Prediction
127
+ with torch.no_grad():
128
+ preds = birefnet(input_images)[-1].sigmoid().cpu()
129
  pred = preds[0].squeeze()
130
  pred_pil = transforms.ToPILImage()(pred)
131
  mask = pred_pil.resize(image_size)
132
  image.putalpha(mask)
 
133
  return image
134
 
135