schirrmacher committed on
Commit 0383b74 · verified · 1 Parent(s): 2c218d6

Upload folder using huggingface_hub

Files changed (1)
  1. app.py +18 -27
app.py CHANGED
@@ -1,3 +1,4 @@
+import spaces
 import numpy as np
 import torch
 import torch.nn.functional as F
@@ -5,18 +6,11 @@ import gradio as gr
 from ormbg import ORMBG
 from PIL import Image
 
-
 model_path = "ormbg.pth"
 
+# Load the model globally but don't send to device yet
 net = ORMBG()
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-net.to(device)
-
-if torch.cuda.is_available():
-    net.load_state_dict(torch.load(model_path))
-    net = net.cuda()
-else:
-    net.load_state_dict(torch.load(model_path, map_location="cpu"))
+net.load_state_dict(torch.load(model_path, map_location="cpu"))
 net.eval()
 
 
@@ -27,9 +21,14 @@ def resize_image(image):
     return image
 
 
+@spaces.GPU
+@torch.inference_mode()
 def inference(image):
+    # Check for CUDA and set the device inside inference
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    net.to(device)
 
-    # prepare input
+    # Prepare input
     orig_image = Image.fromarray(image)
     w, h = orig_image.size
     image = resize_image(orig_image)
@@ -37,50 +36,42 @@ def inference(image):
     im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2, 0, 1)
     im_tensor = torch.unsqueeze(im_tensor, 0)
     im_tensor = torch.divide(im_tensor, 255.0)
+
     if torch.cuda.is_available():
-        im_tensor = im_tensor.cuda()
+        im_tensor = im_tensor.to(device)
 
-    # inference
+    # Inference
     result = net(im_tensor)
-    # post process
+    # Post process
     result = torch.squeeze(F.interpolate(result[0][0], size=(h, w), mode="bilinear"), 0)
     ma = torch.max(result)
     mi = torch.min(result)
     result = (result - mi) / (ma - mi)
-    # image to pil
+    # Image to PIL
     im_array = (result * 255).cpu().data.numpy().astype(np.uint8)
     pil_im = Image.fromarray(np.squeeze(im_array))
-    # paste the mask on the original image
+    # Paste the mask on the original image
     new_im = Image.new("RGBA", pil_im.size, (0, 0, 0, 0))
     new_im.paste(orig_image, mask=pil_im)
 
     return new_im
 
 
-gr.Markdown("## Open Remove Background Model (ormbg)")
-gr.HTML(
-    """
-    <p style="margin-bottom: 10px; font-size: 94%">
-    This is a demo for Open Remove Background Model (ormbg) that using
-    <a href="https://huggingface.co/schirrmacher/ormbg" target="_blank">Open Remove Background Model (ormbg) model</a> as backbone.
-    </p>
-    """
-)
+# Gradio interface setup
 title = "Open Remove Background Model (ormbg)"
 description = r"""
 This model is a <strong>fully open-source background remover</strong> optimized for images with humans.
-
 It is based on [Highly Accurate Dichotomous Image Segmentation research](https://github.com/xuebinqin/DIS).
 The model was trained with the synthetic [Human Segmentation Dataset](https://huggingface.co/datasets/schirrmacher/humans).
 
 This is the first iteration of the model, so there will be improvements!
-If you identify cases were the model fails, <a href='https://huggingface.co/schirrmacher/ormbg/discussions' target='_blank'>upload your examples</a>!
+If you identify cases where the model fails, <a href='https://huggingface.co/schirrmacher/ormbg/discussions' target='_blank'>upload your examples</a>!
 
 - <a href='https://huggingface.co/schirrmacher/ormbg' target='_blank'>Model card</a>: find inference code, training information, tutorials
 - <a href='https://huggingface.co/schirrmacher/ormbg' target='_blank'>Dataset</a>: see training images, segmentation data, backgrounds
 - <a href='https://huggingface.co/schirrmacher/ormbg\#research' target='_blank'>Research</a>: see current approach for improvements
-
 """
+
 examples = ["./example1.png", "./example2.png", "./example3.png"]
 
 demo = gr.Interface(
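
The substance of this commit is the ZeroGPU migration: the checkpoint is now loaded on CPU at import time, and device placement happens inside the @spaces.GPU-decorated handler, since on ZeroGPU hardware a GPU is only attached while such a function runs. A minimal sketch of that pattern, separate from the ormbg specifics (the nn.Linear stand-in model, the "weights.pth" path, and the predict name below are placeholders, not taken from app.py):

import spaces
import torch
import torch.nn as nn

# Placeholder model standing in for ORMBG(); weights are loaded on CPU at import time.
net = nn.Linear(4, 1)
# net.load_state_dict(torch.load("weights.pth", map_location="cpu"))  # hypothetical checkpoint
net.eval()


@spaces.GPU              # ZeroGPU: a GPU is attached only while this function executes
@torch.inference_mode()  # skip autograd bookkeeping during inference
def predict(x: torch.Tensor) -> torch.Tensor:
    # Resolve the device per call; CUDA may not be available at import time on ZeroGPU Spaces.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net.to(device)
    return net(x.to(device)).cpu()

Outside a ZeroGPU Space the spaces.GPU decorator is documented to have no effect, so the same function still runs on CPU locally.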