File size: 4,454 Bytes
6056c5f
 
 
 
d3fa5af
d7d8a78
fb0b0a0
b2a4f5f
 
3b1f1d6
d7d8a78
62db64f
9d42a5c
6056c5f
 
 
 
 
 
9d42a5c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b2a4f5f
d3fa5af
d86b569
9d42a5c
472e711
6056c5f
d86b569
6056c5f
 
 
 
 
 
d86b569
6056c5f
 
d86b569
6056c5f
 
 
 
 
 
 
 
 
 
b02f90e
9d42a5c
fafb3c5
9d42a5c
 
6056c5f
 
 
 
 
7d308bc
6056c5f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d86b569
6056c5f
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import numpy as np

import gradio as gr
import os
import sys

# Runtime dependency bootstrap: packages are installed when the script starts,
# before the third-party imports that follow. Installs stay best-effort (the
# exit status is ignored, as before); a package that really failed to install
# still surfaces as an ImportError on the next import.
# - `sys.executable -m pip` installs into the interpreter running this file;
#   a bare `pip3` on PATH may belong to a different Python.
# - `collections` is part of the standard library and needs no install.
for _pip_args in (
    "torch",
    "--upgrade pip",
    "torchvision",
    "einops",
    "git+https://github.com/lucasb-eyer/pydensecrf.git",
):
    os.system(f'"{sys.executable}" -m pip install {_pip_args}')
import pydensecrf.densecrf as dcrf
from PIL import Image
import torch
from torchvision import transforms
from model_video import build_model
import numpy as np
import collections

def crf_refine(img, annos):
    """Refine a saliency map with a dense CRF (pydensecrf).

    Args:
        img: uint8 image whose first two dimensions equal ``annos.shape``;
            it is the colour reference for the bilateral kernel (presumably
            H x W x 3 RGB, the layout pydensecrf expects — TODO confirm).
        annos: uint8 saliency map of shape ``img.shape[:2]``; 0 is treated
            as background and 255 as salient.

    Returns:
        uint8 array of shape ``img.shape[:2]``: the CRF's probability for the
        salient label after one inference step, scaled to 0..255.
    """
    def _sigmoid(x):
        return 1 / (1 + np.exp(-x))

    assert img.dtype == np.uint8
    assert annos.dtype == np.uint8
    assert img.shape[:2] == annos.shape

    EPSILON = 1e-8  # keeps log() finite for pixels that are exactly 0 or 255
    M = 2  # two labels: background (0) and salient (1)
    tau = 1.05

    # pydensecrf's addPairwiseBilateral only accepts C-contiguous buffers.
    # Images produced from permute()/transpose views are not contiguous, so
    # copy when needed (this is a no-op for already-contiguous input).
    img = np.ascontiguousarray(img)
    height, width = img.shape[:2]

    # DenseCRF2D takes (width, height, n_labels).
    d = dcrf.DenseCRF2D(width, height, M)

    anno_norm = annos / 255.0
    # Unary energies: negative log-likelihood of each label, tempered by tau
    # and a sigmoid of the label's own probability.
    n_energy = -np.log(1.0 - anno_norm + EPSILON) / (tau * _sigmoid(1 - anno_norm))
    p_energy = -np.log(anno_norm + EPSILON) / (tau * _sigmoid(anno_norm))

    U = np.zeros((M, height * width), dtype='float32')
    U[0, :] = n_energy.flatten()
    U[1, :] = p_energy.flatten()
    d.setUnaryEnergy(U)

    d.addPairwiseGaussian(sxy=3, compat=3)
    d.addPairwiseBilateral(sxy=60, srgb=5, rgbim=img, compat=5)

    # A single mean-field iteration; row 1 holds the salient-label probability.
    infer = np.array(d.inference(1)).astype('float32')
    res = infer[1, :] * 255
    return res.reshape(height, width).astype('uint8')

device = 'cpu'
net = build_model(device).to(device)
model_path = 'image_best.pth'
print(model_path)
weight = torch.load(model_path, map_location=torch.device(device))

# The checkpoint keys carry a "module." prefix (presumably saved from a
# torch.nn.DataParallel wrapper). Strip it only where present, so keys without
# the prefix keep their full name instead of losing their first 7 characters.
_DP_PREFIX = 'module.'
new_dict = collections.OrderedDict(
    (k[len(_DP_PREFIX):] if k.startswith(_DP_PREFIX) else k, v)
    for k, v in weight.items()
)
net.load_state_dict(new_dict)
net.eval()  # inference mode for the module's layers; net is already on `device`
def test(gpu_id, net, img_list, group_size, img_size):
    """Predict masks for a group of images and refine each with a dense CRF.

    Args:
        gpu_id: Unused; tensors are built on the CPU. Kept for interface
            compatibility.
        net: Model called as ``net(batch)``. Its second returned element is
            the mask tensor (presumably (N, 1, H, W) with values in [0, 1] —
            TODO confirm against model_video).
        img_list: Sequence of at least ``group_size`` images as numpy arrays
            accepted by ``PIL.Image.fromarray``; only the first
            ``group_size`` are used.
        group_size: Number of images run through the model together.
        img_size: Side length that every image is resized to.

    Returns:
        List of ``group_size`` RGB PIL images of ``img_size`` x ``img_size``,
        each holding a refined mask replicated across the three channels.

    Raises:
        ValueError: If ``img_list`` has fewer than ``group_size`` images.
    """
    print('test')
    if len(img_list) < group_size:
        raise ValueError(
            f'expected at least {group_size} images, got {len(img_list)}')

    img_transform = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    with torch.no_grad():
        # convert('RGB') so greyscale or RGBA inputs also produce the three
        # channels that the normalisation above expects.
        group_img = torch.stack([
            img_transform(Image.fromarray(img).convert('RGB'))
            for img in img_list[:group_size]
        ])
        # Hand the model a copy so any in-place ops inside it cannot alter
        # group_img, which is reused below as the CRF colour reference.
        _, pred_mask = net(group_img.clone())
        # Reshape to (N, H, W) explicitly: a bare squeeze() would also drop the
        # batch axis when group_size == 1. crf_refine requires each mask to
        # match the img_size x img_size input.
        pred_mask = pred_mask.detach().reshape(group_size, img_size, img_size)
        pred_mask = (pred_mask * 255).numpy().astype(np.uint8)

        refined = []
        for img_t, mask in zip(group_img, pred_mask):
            # Min-max stretch the normalised image back to 0..255 for the CRF.
            # Clamp the range so a flat (single-colour) image does not divide
            # by zero and turn into NaNs.
            lo, hi = img_t.min(), img_t.max()
            span = (hi - lo).clamp(min=1e-8)
            rgb = ((img_t - lo) / span * 255).permute(1, 2, 0).contiguous()
            refined.append(crf_refine(rgb.numpy().astype(np.uint8), mask))

        print(refined[0].shape)
        result = [Image.fromarray(np.repeat(m[:, :, np.newaxis], 3, axis=2))
                  for m in refined]
        print('done')
        return result
# Import-time smoke run: push one group of five random 224x224 RGB images
# through the whole pipeline and discard the result. Because it runs before
# the UI is launched, model or CRF failures surface at start-up.
test(
    'cpu',
    net,
    [(torch.rand(224, 224, 3) * 255).numpy().astype(np.uint8) for _ in range(5)],
    5,
    224,
)
def sepia(img1, img2, img3, img4, img5):
    """Gradio callback: compute refined masks for five uploaded images.

    The name is a leftover from Gradio's sepia-filter example; the function
    runs the mask model on the five images as one group.

    Args:
        img1, img2, img3, img4, img5: Uploaded images as numpy arrays, or
            None when a slot was left empty.

    Returns:
        Tuple of five PIL images holding the refined masks, in input order.

    Raises:
        gr.Error: If any image slot is empty, so the UI shows a readable
            message instead of an AttributeError.
    """
    print('sepia')
    img_list = [img1, img2, img3, img4, img5]
    missing = [pos for pos, img in enumerate(img_list, start=1) if img is None]
    if missing:
        raise gr.Error(f'Please upload all five images (missing: {missing}).')

    out1, out2, out3, out4, out5 = test(device, net, img_list, 5, 224)
    return out1, out2, out3, out4, out5

# Five image inputs map one-to-one onto the five mask images returned by
# `sepia`.
demo = gr.Interface(
    sepia,
    inputs=["image"] * 5,
    outputs=["image"] * 5,
)

demo.launch(debug=True)