Spaces:
Running
Running
gavinyuan
commited on
Commit
·
f057d66
1
Parent(s):
a104d3f
udpate: app.py import FSGenerator
Browse files- inference/tricks.py +6 -2
inference/tricks.py
CHANGED
@@ -138,10 +138,14 @@ class SoftErosion(nn.Module):
|
|
138 |
return x, mask
|
139 |
|
140 |
|
|
|
|
|
|
|
|
|
141 |
vgg_mean = torch.tensor([[[0.485]], [[0.456]], [[0.406]]],
|
142 |
-
requires_grad=False, device=
|
143 |
vgg_std = torch.tensor([[[0.229]], [[0.224]], [[0.225]]],
|
144 |
-
requires_grad=False, device=
|
145 |
def load_bisenet():
|
146 |
bisenet_model = BiSeNet(n_classes=19)
|
147 |
bisenet_model.load_state_dict(
|
|
|
138 |
return x, mask
|
139 |
|
140 |
|
141 |
+
if torch.cuda.is_available():
|
142 |
+
device = torch.device(0)
|
143 |
+
else:
|
144 |
+
device = torch.device('cpu')
|
145 |
vgg_mean = torch.tensor([[[0.485]], [[0.456]], [[0.406]]],
|
146 |
+
requires_grad=False, device=device)
|
147 |
vgg_std = torch.tensor([[[0.229]], [[0.224]], [[0.225]]],
|
148 |
+
requires_grad=False, device=device)
|
149 |
def load_bisenet():
|
150 |
bisenet_model = BiSeNet(n_classes=19)
|
151 |
bisenet_model.load_state_dict(
|