Update app.py
Browse files
app.py
CHANGED
@@ -13,6 +13,7 @@ os.system("pip3 install git+https://github.com/lucasb-eyer/pydensecrf.git")
|
|
13 |
import pydensecrf.densecrf as dcrf
|
14 |
from PIL import Image
|
15 |
import torch
|
|
|
16 |
from torchvision import transforms
|
17 |
from model_video import build_model
|
18 |
import numpy as np
|
@@ -71,7 +72,7 @@ for k in weight.keys():
|
|
71 |
net.load_state_dict(new_dict)
|
72 |
net.eval()
|
73 |
net = net.to(device)
|
74 |
-
def test(gpu_id, net, img_list, group_size, img_size):
|
75 |
print('test')
|
76 |
#device=device
|
77 |
|
@@ -85,8 +86,10 @@ def test(gpu_id, net, img_list, group_size, img_size):
|
|
85 |
for i in range(5):
|
86 |
group_img[i]=img_transform(Image.fromarray(img_list[i]))
|
87 |
_,pred_mask=net(group_img*1)
|
88 |
-
pred_mask=(pred_mask.detach().squeeze()*255)
|
89 |
-
pred_mask=[
|
|
|
|
|
90 |
print(pred_mask[0].shape)
|
91 |
result = [Image.fromarray((torch.from_numpy(pred_mask[i]).unsqueeze(2).repeat(1,1,3)).numpy()) for i in range(5)]
|
92 |
#w, h = 224,224#Image.open(image_list[i][j]).size
|
@@ -110,7 +113,7 @@ def sepia(img1,img2,img3,img4,img5):
|
|
110 |
h_list,w_list=[_.shape[0] for _ in img_list],[_.shape[1] for _ in img_list]
|
111 |
#print(type(img1))
|
112 |
#print(img1.shape)
|
113 |
-
result_list=test(device,net,img_list,5,224)
|
114 |
#result_list=[result_list[i].resize((w_list[i], h_list[i]), Image.BILINEAR) for i in range(5)]
|
115 |
img1,img2,img3,img4,img5=result_list#test('cpu',net,img_list,5,224)
|
116 |
return img1,img2,img3,img4,img5
|
|
|
13 |
import pydensecrf.densecrf as dcrf
|
14 |
from PIL import Image
|
15 |
import torch
|
16 |
+
import torch.nn.functional as F
|
17 |
from torchvision import transforms
|
18 |
from model_video import build_model
|
19 |
import numpy as np
|
|
|
72 |
net.load_state_dict(new_dict)
|
73 |
net.eval()
|
74 |
net = net.to(device)
|
75 |
+
def test(gpu_id, net, img_list, group_size, img_size,wl,hl):
|
76 |
print('test')
|
77 |
#device=device
|
78 |
|
|
|
86 |
for i in range(5):
|
87 |
group_img[i]=img_transform(Image.fromarray(img_list[i]))
|
88 |
_,pred_mask=net(group_img*1)
|
89 |
+
pred_mask=(pred_mask.detach().squeeze()*255)#.numpy().astype(np.uint8)
|
90 |
+
pred_mask=[F.interpolate(pred_mask[i].reshape(1,1,pred_mask[i].shape[-2],pred_mask[i].shape[-1]),size=(wl[i],hl[i]),mode='bilinear').numpy().astype(np.uint8) for i in range(5)]
|
91 |
+
#pred_mask=[crf_refine(((group_img[i]-group_img[i].min())/(group_img[i].max()-group_img[i].min())*255).permute(1,2,0).contiguous().numpy().astype(np.uint8),pred_mask[i]) for i in range(5)]
|
92 |
+
pred_mask=[crf_refine(img_list[i],pred_mask[i]) for i in range(5)]
|
93 |
print(pred_mask[0].shape)
|
94 |
result = [Image.fromarray((torch.from_numpy(pred_mask[i]).unsqueeze(2).repeat(1,1,3)).numpy()) for i in range(5)]
|
95 |
#w, h = 224,224#Image.open(image_list[i][j]).size
|
|
|
113 |
h_list,w_list=[_.shape[0] for _ in img_list],[_.shape[1] for _ in img_list]
|
114 |
#print(type(img1))
|
115 |
#print(img1.shape)
|
116 |
+
result_list=test(device,net,img_list,5,224,w_list,h_list)
|
117 |
#result_list=[result_list[i].resize((w_list[i], h_list[i]), Image.BILINEAR) for i in range(5)]
|
118 |
img1,img2,img3,img4,img5=result_list#test('cpu',net,img_list,5,224)
|
119 |
return img1,img2,img3,img4,img5
|