ZhengPeng7 commited on
Commit
bdff8f5
·
1 Parent(s): a2fe5a2

Align lib_name as birefnet and add inference endpoint option.

Browse files
Files changed (3) hide show
  1. README.md +1 -1
  2. birefnet.py +2 -1
  3. handler.py +7 -1
README.md CHANGED
@@ -1,5 +1,5 @@
1
  ---
2
- library_name: BiRefNet-legacy
3
  tags:
4
  - background-removal
5
  - mask-generation
 
1
  ---
2
+ library_name: birefnet
3
  tags:
4
  - background-removal
5
  - mask-generation
birefnet.py CHANGED
@@ -1995,7 +1995,8 @@ class BiRefNet(
1995
  ):
1996
  config_class = BiRefNetConfig
1997
  def __init__(self, bb_pretrained=True, config=BiRefNetConfig()):
1998
- super(BiRefNet, self).__init__()
 
1999
  self.config = Config()
2000
  self.epoch = 1
2001
  self.bb = build_backbone(self.config.bb, pretrained=bb_pretrained)
 
1995
  ):
1996
  config_class = BiRefNetConfig
1997
  def __init__(self, bb_pretrained=True, config=BiRefNetConfig()):
1998
+ super(BiRefNet, self).__init__(config)
1999
+ bb_pretrained = config.bb_pretrained
2000
  self.config = Config()
2001
  self.epoch = 1
2002
  self.bb = build_backbone(self.config.bb, pretrained=bb_pretrained)
handler.py CHANGED
@@ -62,6 +62,7 @@ class ImagePreprocessor():
62
 
63
  usage_to_weights_file = {
64
  'General': 'BiRefNet',
 
65
  'General-Lite': 'BiRefNet_lite',
66
  'General-Lite-2K': 'BiRefNet_lite-2K',
67
  'General-reso_512': 'BiRefNet-reso_512',
@@ -82,9 +83,12 @@ if usage in ['General-Lite-2K']:
82
  resolution = (2560, 1440)
83
  elif usage in ['General-reso_512']:
84
  resolution = (512, 512)
 
 
85
  else:
86
  resolution = (1024, 1024)
87
 
 
88
 
89
  class EndpointHandler():
90
  def __init__(self, path=''):
@@ -93,6 +97,8 @@ class EndpointHandler():
93
  )
94
  self.birefnet.to(device)
95
  self.birefnet.eval()
 
 
96
 
97
  def __call__(self, data: Dict[str, Any]):
98
  """
@@ -122,7 +128,7 @@ class EndpointHandler():
122
 
123
  # Prediction
124
  with torch.no_grad():
125
- preds = self.birefnet(image_proc.to(device))[-1].sigmoid().cpu()
126
  pred = preds[0].squeeze()
127
 
128
  # Show Results
 
62
 
63
  usage_to_weights_file = {
64
  'General': 'BiRefNet',
65
+ 'General-HR': 'BiRefNet_HR',
66
  'General-Lite': 'BiRefNet_lite',
67
  'General-Lite-2K': 'BiRefNet_lite-2K',
68
  'General-reso_512': 'BiRefNet-reso_512',
 
83
  resolution = (2560, 1440)
84
  elif usage in ['General-reso_512']:
85
  resolution = (512, 512)
86
+ elif usage in ['General-HR']:
87
+ resolution = (2048, 2048)
88
  else:
89
  resolution = (1024, 1024)
90
 
91
+ half_precision = True
92
 
93
  class EndpointHandler():
94
  def __init__(self, path=''):
 
97
  )
98
  self.birefnet.to(device)
99
  self.birefnet.eval()
100
+ if half_precision:
101
+ self.birefnet.half()
102
 
103
  def __call__(self, data: Dict[str, Any]):
104
  """
 
128
 
129
  # Prediction
130
  with torch.no_grad():
131
+ preds = self.birefnet(image_proc.to(device).half() if half_precision else image_proc.to(device))[-1].sigmoid().cpu()
132
  pred = preds[0].squeeze()
133
 
134
  # Show Results