using cuda
app.py
CHANGED
@@ -106,8 +106,8 @@ def inference(ic_image, ic_mask, image1, image2):
     ic_mask = np.array(ic_mask.convert("RGB"))
 
     sam_type, sam_ckpt = 'vit_h', 'sam_vit_h_4b8939.pth'
-
-    sam = sam_model_registry[sam_type](checkpoint=sam_ckpt)
+    sam = sam_model_registry[sam_type](checkpoint=sam_ckpt).cuda()
+    # sam = sam_model_registry[sam_type](checkpoint=sam_ckpt)
     predictor = SamPredictor(sam)
 
     # Image features encoding
@@ -206,8 +206,8 @@ def inference_scribble(image, image1, image2):
     ic_mask = np.array(ic_mask.convert("RGB"))
 
     sam_type, sam_ckpt = 'vit_h', 'sam_vit_h_4b8939.pth'
-
-    sam = sam_model_registry[sam_type](checkpoint=sam_ckpt)
+    sam = sam_model_registry[sam_type](checkpoint=sam_ckpt).cuda()
+    # sam = sam_model_registry[sam_type](checkpoint=sam_ckpt)
     predictor = SamPredictor(sam)
 
     # Image features encoding
@@ -304,12 +304,12 @@ def inference_finetune(ic_image, ic_mask, image1, image2):
     ic_mask = np.array(ic_mask.convert("RGB"))
 
     gt_mask = torch.tensor(ic_mask)[:, :, 0] > 0
-
-    gt_mask = gt_mask.float().unsqueeze(0).flatten(1)
+    gt_mask = gt_mask.float().unsqueeze(0).flatten(1).cuda()
+    # gt_mask = gt_mask.float().unsqueeze(0).flatten(1)
 
     sam_type, sam_ckpt = 'vit_h', 'sam_vit_h_4b8939.pth'
-
-    sam = sam_model_registry[sam_type](checkpoint=sam_ckpt)
+    sam = sam_model_registry[sam_type](checkpoint=sam_ckpt).cuda()
+    # sam = sam_model_registry[sam_type](checkpoint=sam_ckpt)
     for name, param in sam.named_parameters():
         param.requires_grad = False
     predictor = SamPredictor(sam)
@@ -347,8 +347,8 @@ def inference_finetune(ic_image, ic_mask, image1, image2):
 
     print('======> Start Training')
     # Learnable mask weights
-
-    mask_weights = Mask_Weights()
+    mask_weights = Mask_Weights().cuda()
+    # mask_weights = Mask_Weights()
     mask_weights.train()
     train_epoch = 1000
     optimizer = torch.optim.AdamW(mask_weights.parameters(), lr=1e-3, eps=1e-4)
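Every hunk in this commit applies the same pattern: move the SAM model, the ground-truth mask tensor, and the learnable mask weights onto the GPU so all of them live on one device. Below is a minimal sketch of that pattern, assuming the segment_anything package and the same sam_vit_h_4b8939.pth checkpoint used in app.py. The torch.cuda.is_available() fallback is an addition for illustration, not part of the commit: the hard-coded .cuda() calls above fail at runtime on CPU-only hardware.

import torch
from segment_anything import sam_model_registry, SamPredictor

# Pick the GPU when one exists, otherwise fall back to the CPU
# (assumption: a graceful CPU fallback is acceptable for this Space).
device = "cuda" if torch.cuda.is_available() else "cpu"

sam_type, sam_ckpt = 'vit_h', 'sam_vit_h_4b8939.pth'
sam = sam_model_registry[sam_type](checkpoint=sam_ckpt).to(device)
predictor = SamPredictor(sam)

# Every tensor fed to the model must sit on the same device, mirroring
# the gt_mask and Mask_Weights() hunks above, e.g.:
#   gt_mask = gt_mask.float().unsqueeze(0).flatten(1).to(device)
#   mask_weights = Mask_Weights().to(device)

Using .to(device) instead of a bare .cuda() keeps the app runnable on CPU-only hardware, only slower, whereas .cuda() raises immediately when no GPU is available.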