djl234 commited on
Commit
dd76ffb
·
1 Parent(s): 2828013

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -4
app.py CHANGED
@@ -13,6 +13,7 @@ os.system("pip3 install git+https://github.com/lucasb-eyer/pydensecrf.git")
13
  import pydensecrf.densecrf as dcrf
14
  from PIL import Image
15
  import torch
 
16
  from torchvision import transforms
17
  from model_video import build_model
18
  import numpy as np
@@ -71,7 +72,7 @@ for k in weight.keys():
71
  net.load_state_dict(new_dict)
72
  net.eval()
73
  net = net.to(device)
74
- def test(gpu_id, net, img_list, group_size, img_size):
75
  print('test')
76
  #device=device
77
 
@@ -85,8 +86,10 @@ def test(gpu_id, net, img_list, group_size, img_size):
85
  for i in range(5):
86
  group_img[i]=img_transform(Image.fromarray(img_list[i]))
87
  _,pred_mask=net(group_img*1)
88
- pred_mask=(pred_mask.detach().squeeze()*255).numpy().astype(np.uint8)
89
- pred_mask=[crf_refine(((group_img[i]-group_img[i].min())/(group_img[i].max()-group_img[i].min())*255).permute(1,2,0).contiguous().numpy().astype(np.uint8),pred_mask[i]) for i in range(5)]
 
 
90
  print(pred_mask[0].shape)
91
  result = [Image.fromarray((torch.from_numpy(pred_mask[i]).unsqueeze(2).repeat(1,1,3)).numpy()) for i in range(5)]
92
  #w, h = 224,224#Image.open(image_list[i][j]).size
@@ -110,7 +113,7 @@ def sepia(img1,img2,img3,img4,img5):
110
  h_list,w_list=[_.shape[0] for _ in img_list],[_.shape[1] for _ in img_list]
111
  #print(type(img1))
112
  #print(img1.shape)
113
- result_list=test(device,net,img_list,5,224)
114
  #result_list=[result_list[i].resize((w_list[i], h_list[i]), Image.BILINEAR) for i in range(5)]
115
  img1,img2,img3,img4,img5=result_list#test('cpu',net,img_list,5,224)
116
  return img1,img2,img3,img4,img5
 
13
  import pydensecrf.densecrf as dcrf
14
  from PIL import Image
15
  import torch
16
+ import torch.nn.functional as F
17
  from torchvision import transforms
18
  from model_video import build_model
19
  import numpy as np
 
72
  net.load_state_dict(new_dict)
73
  net.eval()
74
  net = net.to(device)
75
+ def test(gpu_id, net, img_list, group_size, img_size,wl,hl):
76
  print('test')
77
  #device=device
78
 
 
86
  for i in range(5):
87
  group_img[i]=img_transform(Image.fromarray(img_list[i]))
88
  _,pred_mask=net(group_img*1)
89
+ pred_mask=(pred_mask.detach().squeeze()*255)#.numpy().astype(np.uint8)
90
+ pred_mask=[F.interpolate(pred_mask[i].reshape(1,1,pred_mask[i].shape[-2],pred_mask[i].shape[-1]),size=(wl[i],hl[i]),mode='bilinear').numpy().astype(np.uint8) for i in range(5)]
91
+ #pred_mask=[crf_refine(((group_img[i]-group_img[i].min())/(group_img[i].max()-group_img[i].min())*255).permute(1,2,0).contiguous().numpy().astype(np.uint8),pred_mask[i]) for i in range(5)]
92
+ pred_mask=[crf_refine(img_list[i],pred_mask[i]) for i in range(5)]
93
  print(pred_mask[0].shape)
94
  result = [Image.fromarray((torch.from_numpy(pred_mask[i]).unsqueeze(2).repeat(1,1,3)).numpy()) for i in range(5)]
95
  #w, h = 224,224#Image.open(image_list[i][j]).size
 
113
  h_list,w_list=[_.shape[0] for _ in img_list],[_.shape[1] for _ in img_list]
114
  #print(type(img1))
115
  #print(img1.shape)
116
+ result_list=test(device,net,img_list,5,224,w_list,h_list)
117
  #result_list=[result_list[i].resize((w_list[i], h_list[i]), Image.BILINEAR) for i in range(5)]
118
  img1,img2,img3,img4,img5=result_list#test('cpu',net,img_list,5,224)
119
  return img1,img2,img3,img4,img5