Commit
·
bdff8f5
1
Parent(s):
a2fe5a2
Align lib_name as birefnet and add inference endpoint option.
Browse files- README.md +1 -1
- birefnet.py +2 -1
- handler.py +7 -1
README.md
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
---
|
2 |
-
library_name:
|
3 |
tags:
|
4 |
- background-removal
|
5 |
- mask-generation
|
|
|
1 |
---
|
2 |
+
library_name: birefnet
|
3 |
tags:
|
4 |
- background-removal
|
5 |
- mask-generation
|
birefnet.py
CHANGED
@@ -1995,7 +1995,8 @@ class BiRefNet(
|
|
1995 |
):
|
1996 |
config_class = BiRefNetConfig
|
1997 |
def __init__(self, bb_pretrained=True, config=BiRefNetConfig()):
|
1998 |
-
super(BiRefNet, self).__init__()
|
|
|
1999 |
self.config = Config()
|
2000 |
self.epoch = 1
|
2001 |
self.bb = build_backbone(self.config.bb, pretrained=bb_pretrained)
|
|
|
1995 |
):
|
1996 |
config_class = BiRefNetConfig
|
1997 |
def __init__(self, bb_pretrained=True, config=BiRefNetConfig()):
|
1998 |
+
super(BiRefNet, self).__init__(config)
|
1999 |
+
bb_pretrained = config.bb_pretrained
|
2000 |
self.config = Config()
|
2001 |
self.epoch = 1
|
2002 |
self.bb = build_backbone(self.config.bb, pretrained=bb_pretrained)
|
handler.py
CHANGED
@@ -62,6 +62,7 @@ class ImagePreprocessor():
|
|
62 |
|
63 |
usage_to_weights_file = {
|
64 |
'General': 'BiRefNet',
|
|
|
65 |
'General-Lite': 'BiRefNet_lite',
|
66 |
'General-Lite-2K': 'BiRefNet_lite-2K',
|
67 |
'General-reso_512': 'BiRefNet-reso_512',
|
@@ -82,9 +83,12 @@ if usage in ['General-Lite-2K']:
|
|
82 |
resolution = (2560, 1440)
|
83 |
elif usage in ['General-reso_512']:
|
84 |
resolution = (512, 512)
|
|
|
|
|
85 |
else:
|
86 |
resolution = (1024, 1024)
|
87 |
|
|
|
88 |
|
89 |
class EndpointHandler():
|
90 |
def __init__(self, path=''):
|
@@ -93,6 +97,8 @@ class EndpointHandler():
|
|
93 |
)
|
94 |
self.birefnet.to(device)
|
95 |
self.birefnet.eval()
|
|
|
|
|
96 |
|
97 |
def __call__(self, data: Dict[str, Any]):
|
98 |
"""
|
@@ -122,7 +128,7 @@ class EndpointHandler():
|
|
122 |
|
123 |
# Prediction
|
124 |
with torch.no_grad():
|
125 |
-
preds = self.birefnet(image_proc.to(device))[-1].sigmoid().cpu()
|
126 |
pred = preds[0].squeeze()
|
127 |
|
128 |
# Show Results
|
|
|
62 |
|
63 |
usage_to_weights_file = {
|
64 |
'General': 'BiRefNet',
|
65 |
+
'General-HR': 'BiRefNet_HR',
|
66 |
'General-Lite': 'BiRefNet_lite',
|
67 |
'General-Lite-2K': 'BiRefNet_lite-2K',
|
68 |
'General-reso_512': 'BiRefNet-reso_512',
|
|
|
83 |
resolution = (2560, 1440)
|
84 |
elif usage in ['General-reso_512']:
|
85 |
resolution = (512, 512)
|
86 |
+
elif usage in ['General-HR']:
|
87 |
+
resolution = (2048, 2048)
|
88 |
else:
|
89 |
resolution = (1024, 1024)
|
90 |
|
91 |
+
half_precision = True
|
92 |
|
93 |
class EndpointHandler():
|
94 |
def __init__(self, path=''):
|
|
|
97 |
)
|
98 |
self.birefnet.to(device)
|
99 |
self.birefnet.eval()
|
100 |
+
if half_precision:
|
101 |
+
self.birefnet.half()
|
102 |
|
103 |
def __call__(self, data: Dict[str, Any]):
|
104 |
"""
|
|
|
128 |
|
129 |
# Prediction
|
130 |
with torch.no_grad():
|
131 |
+
preds = self.birefnet(image_proc.to(device).half() if half_precision else image_proc.to(device))[-1].sigmoid().cpu()
|
132 |
pred = preds[0].squeeze()
|
133 |
|
134 |
# Show Results
|