schirrmacher committed
Commit d7921b8 · verified · 1 Parent(s): 85f55d7

Update app.py

Files changed (1)
  1. app.py +16 -31
app.py CHANGED
@@ -5,31 +5,25 @@ import gradio as gr
 from ormbg import ORMBG
 from PIL import Image
 
-
 model_path = "ormbg.pth"
 
+# Load the model globally but don't send to device yet
 net = ORMBG()
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-net.to(device)
-
-if torch.cuda.is_available():
-    net.load_state_dict(torch.load(model_path))
-    net = net.cuda()
-else:
-    net.load_state_dict(torch.load(model_path, map_location="cpu"))
+net.load_state_dict(torch.load(model_path, map_location="cpu"))
 net.eval()
 
-
 def resize_image(image):
     image = image.convert("RGB")
     model_input_size = (1024, 1024)
     image = image.resize(model_input_size, Image.BILINEAR)
     return image
 
-
 def inference(image):
+    # Check for CUDA and set the device inside inference
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    net.to(device)
 
-    # prepare input
+    # Prepare input
     orig_image = Image.fromarray(image)
     w, h = orig_image.size
     image = resize_image(orig_image)
@@ -37,50 +31,41 @@ def inference(image):
     im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2, 0, 1)
     im_tensor = torch.unsqueeze(im_tensor, 0)
     im_tensor = torch.divide(im_tensor, 255.0)
+
     if torch.cuda.is_available():
-        im_tensor = im_tensor.cuda()
+        im_tensor = im_tensor.to(device)
 
-    # inference
+    # Inference
     result = net(im_tensor)
-    # post process
+    # Post process
     result = torch.squeeze(F.interpolate(result[0][0], size=(h, w), mode="bilinear"), 0)
     ma = torch.max(result)
     mi = torch.min(result)
     result = (result - mi) / (ma - mi)
-    # image to pil
+    # Image to PIL
     im_array = (result * 255).cpu().data.numpy().astype(np.uint8)
     pil_im = Image.fromarray(np.squeeze(im_array))
-    # paste the mask on the original image
+    # Paste the mask on the original image
     new_im = Image.new("RGBA", pil_im.size, (0, 0, 0, 0))
     new_im.paste(orig_image, mask=pil_im)
 
     return new_im
 
-
-gr.Markdown("## Open Remove Background Model (ormbg)")
-gr.HTML(
-    """
-    <p style="margin-bottom: 10px; font-size: 94%">
-    This is a demo for Open Remove Background Model (ormbg) that using
-    <a href="https://huggingface.co/schirrmacher/ormbg" target="_blank">Open Remove Background Model (ormbg) model</a> as backbone.
-    </p>
-    """
-)
+# Gradio interface setup
 title = "Open Remove Background Model (ormbg)"
 description = r"""
 This model is a <strong>fully open-source background remover</strong> optimized for images with humans.
-
 It is based on [Highly Accurate Dichotomous Image Segmentation research](https://github.com/xuebinqin/DIS).
 The model was trained with the synthetic [Human Segmentation Dataset](https://huggingface.co/datasets/schirrmacher/humans).
 
 This is the first iteration of the model, so there will be improvements!
-If you identify cases were the model fails, <a href='https://huggingface.co/schirrmacher/ormbg/discussions' target='_blank'>upload your examples</a>!
+If you identify cases where the model fails, <a href='https://huggingface.co/schirrmacher/ormbg/discussions' target='_blank'>upload your examples</a>!
 
 - <a href='https://huggingface.co/schirrmacher/ormbg' target='_blank'>Model card</a>: find inference code, training information, tutorials
 - <a href='https://huggingface.co/schirrmacher/ormbg' target='_blank'>Dataset</a>: see training images, segmentation data, backgrounds
 - <a href='https://huggingface.co/schirrmacher/ormbg\#research' target='_blank'>Research</a>: see current approach for improvements
-
 """
+
 examples = ["./example1.png", "./example2.png", "./example3.png"]
 
 demo = gr.Interface(
@@ -89,7 +74,7 @@ demo = gr.Interface(
     outputs="image",
     examples=examples,
     title=title,
-    description=description,
+    description=description
 )
 
 if __name__ == "__main__":
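
Since the commit moves device selection into inference, a quick local smoke test of the new code path could look roughly like the sketch below. This is a hedged example, not part of the commit: it assumes ormbg.pth, the ormbg module, and the bundled example1.png sit next to app.py, and that Gradio hands the function a NumPy array (which matches the Image.fromarray call inside inference).

# smoke_test.py -- illustrative sketch, not part of this commit
import numpy as np
from PIL import Image

# Importing app loads ormbg.pth onto the CPU; no GPU is touched yet
from app import inference

# inference() expects a NumPy array (it calls Image.fromarray internally)
img = np.array(Image.open("./example1.png").convert("RGB"))

# Device selection now happens inside inference: CUDA if available, else CPU
cutout = inference(img)  # returns an RGBA PIL image with the background removed
cutout.save("example1_nobg.png")

Loading the weights with map_location="cpu" and deferring net.to(device) until a request arrives keeps the app importable on CPU-only machines while still using a GPU whenever one is available at inference time.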