File size: 4,148 Bytes
b2a4f5f
6056c5f
 
 
 
d3fa5af
b2a4f5f
 
 
3b1f1d6
9d42a5c
b2a4f5f
9d42a5c
6056c5f
 
 
 
 
 
9d42a5c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b2a4f5f
d3fa5af
d86b569
9d42a5c
472e711
6056c5f
d86b569
6056c5f
 
 
 
 
 
d86b569
6056c5f
 
d86b569
6056c5f
 
 
 
 
 
 
 
 
 
 
9d42a5c
 
 
 
6056c5f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d86b569
6056c5f
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121

import numpy as np

import gradio as gr
import os
# Best-effort runtime installation of the third-party packages this app needs
# (torch, torchvision and pydensecrf are imported below; einops is presumably
# needed by model_video). Each package gets its own pip call, in the original
# order, so one failing build does not skip the others; as before, the exit
# status returned by os.system is ignored. 'collections' is part of the Python
# standard library and must not be pip-installed.
for _pkg in ("torch", "torchvision", "einops", "pydensecrf"):
    os.system("pip3 install " + _pkg)
import pydensecrf.densecrf as dcrf
from PIL import Image
import torch
from torchvision import transforms
from model_video import build_model
import numpy as np
import collections

def crf_refine(img, annos):
    def _sigmoid(x):
        return 1 / (1 + np.exp(-x))

    assert img.dtype == np.uint8
    assert annos.dtype == np.uint8
    assert img.shape[:2] == annos.shape

    # img and annos should be np array with data type uint8

    EPSILON = 1e-8

    M = 2  # salient or not
    tau = 1.05
    # Setup the CRF model
    d = dcrf.DenseCRF2D(img.shape[1], img.shape[0], M)

    anno_norm = annos / 255.

    n_energy = -np.log((1.0 - anno_norm + EPSILON)) / (tau * _sigmoid(1 - anno_norm))
    p_energy = -np.log(anno_norm + EPSILON) / (tau * _sigmoid(anno_norm))

    U = np.zeros((M, img.shape[0] * img.shape[1]), dtype='float32')
    U[0, :] = n_energy.flatten()
    U[1, :] = p_energy.flatten()

    d.setUnaryEnergy(U)

    d.addPairwiseGaussian(sxy=3, compat=3)
    d.addPairwiseBilateral(sxy=60, srgb=5, rgbim=img, compat=5)

    # Do the inference
    infer = np.array(d.inference(1)).astype('float32')
    res = infer[1, :]

    res = res * 255
    res = res.reshape(img.shape[:2])
    return res.astype('uint8')

# Build the network and load the trained weights once at import time, so every
# Gradio request reuses the same model instance.
device = 'cpu'
net = build_model(device).to(device)
model_path = 'image_best.pth'
print(model_path)
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
weight = torch.load(model_path, map_location=torch.device(device))
# Checkpoint keys are expected to carry a 'module.' prefix (presumably saved
# from a torch.nn.DataParallel wrapper). Strip it only where it is present:
# slicing unconditionally would corrupt the keys of an un-prefixed checkpoint.
new_dict = collections.OrderedDict(
    (k[len('module.'):] if k.startswith('module.') else k, v)
    for k, v in weight.items()
)
net.load_state_dict(new_dict)
net.eval()
def test(gpu_id, net, img_list, group_size, img_size):
    """Run the network on a group of images and CRF-refine each predicted mask.

    Args:
        gpu_id: torch device the input batch is moved to (the caller passes
            the module-level ``device``).
        net: model whose forward pass returns a ``(_, pred_mask)`` pair.
        img_list: images as numpy arrays; the first ``group_size`` are used.
        group_size: number of images fed to the network as one batch.
        img_size: side length images are resized to before inference.

    Returns:
        list of ``group_size`` RGB PIL images holding the refined masks, each
        at the resolution of its input image.

    Raises:
        ValueError: if fewer than ``group_size`` images are supplied.
    """
    print('test')
    if len(img_list) < group_size:
        raise ValueError('expected %d images, got %d' % (group_size, len(img_list)))

    img_transform = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    # Convert to RGB so RGBA or grayscale uploads still yield the 3 channels
    # expected by the normalisation above and by crf_refine.
    rgb_imgs = [Image.fromarray(img).convert('RGB') for img in img_list[:group_size]]

    with torch.no_grad():
        group_img = torch.stack([img_transform(im) for im in rgb_imgs]).to(gpu_id)
        _, pred_mask = net(group_img)

    # NOTE(review): assumes pred_mask holds one single-channel map per image
    # with values in [0, 1] - confirm against model_video.build_model.
    masks = (pred_mask.detach().cpu().squeeze() * 255).numpy().astype(np.uint8)
    # squeeze() also drops the batch axis when group_size == 1; restore it.
    masks = masks.reshape((len(rgb_imgs),) + masks.shape[-2:])

    result = []
    for im, mask in zip(rgb_imgs, masks):
        # The network predicts at img_size x img_size, but crf_refine needs the
        # mask at the image's own resolution, so upsample it first.
        mask = np.array(Image.fromarray(mask).resize(im.size, Image.BILINEAR))
        refined = crf_refine(np.array(im), mask)
        result.append(Image.fromarray(refined).convert('RGB'))
    print(result[0].size)
    print('done')
    return result
def sepia(img1, img2, img3, img4, img5):
    """Gradio callback: predict one refined mask for each of five images.

    The name comes from the sepia-filter example this app was adapted from.
    It is kept because the ``gr.Interface`` is constructed with it.

    Args:
        img1, img2, img3, img4, img5: the uploaded images as numpy arrays
            (presumably ``None`` when a slot is left empty).

    Returns:
        tuple of five images, one mask per input, as produced by ``test``.

    Raises:
        ValueError: if any of the five images is missing.
    """
    print('sepia')
    img_list = [img1, img2, img3, img4, img5]
    # The network is always run on a full group of five images.
    missing = [str(i) for i, img in enumerate(img_list, start=1) if img is None]
    if missing:
        raise ValueError('Please provide all five images (missing: '
                         + ', '.join(missing) + ')')
    img1, img2, img3, img4, img5 = test(device, net, img_list, 5, 224)
    return img1, img2, img3, img4, img5

# Web UI: five image uploads map onto sepia's five parameters, and the five
# images it returns are shown in five image outputs.
demo = gr.Interface(
    fn=sepia,
    inputs=["image"] * 5,
    outputs=["image"] * 5,
)

demo.launch(debug=True)