|
import numpy as np |
|
|
|
import gradio as gr |
|
import os |
|
|
|
# Install the runtime dependencies needed by the imports that follow.
# Everything is installed with ``sys.executable -m pip`` so packages land in
# the interpreter that is actually running this script (a bare ``pip3`` on
# PATH may belong to a different Python). ``collections`` is part of the
# standard library and needs no install, so it is not listed.
import subprocess
import sys


def _pip_install(*args):
    """Run ``<current python> -m pip install <args>`` on a best-effort basis.

    A non-zero exit status is reported but not raised; a package that failed
    to install will surface as an ImportError at its import statement.
    """
    cmd = [sys.executable, "-m", "pip", "install", *args]
    returncode = subprocess.run(cmd, check=False).returncode
    if returncode != 0:
        print(f"warning: {' '.join(cmd)} exited with status {returncode}")


# Upgrade pip first so the installs below use the new version.
_pip_install("--upgrade", "pip")
_pip_install("torch")
_pip_install("torchvision")
_pip_install("einops")
_pip_install("git+https://github.com/lucasb-eyer/pydensecrf.git")
|
|
|
import pydensecrf.densecrf as dcrf |
|
from PIL import Image |
|
import torch |
|
import torch.nn.functional as F |
|
from torchvision import transforms |
|
from model_video import build_model |
|
import numpy as np |
|
import collections |
|
|
|
def crf_refine(img, annos, *, n_iter=1):
    """Refine a soft foreground mask with a two-label fully connected CRF.

    Unary energies come from ``annos``: label 0 is background and label 1 is
    foreground. The pairwise terms are a Gaussian kernel plus a bilateral
    (position + colour) kernel computed over ``img``.

    Args:
        img: uint8 image of shape (H, W, 3). It drives the bilateral term.
        annos: uint8 mask with values in [0, 255] and shape (H, W). Extra
            leading axes of size 1 are dropped first, for example the
            (1, 1, H, W) array produced by ``F.interpolate(...).numpy()``.
        n_iter: Number of mean-field inference iterations. Keyword-only;
            defaults to 1.

    Returns:
        uint8 array of shape (H, W): the CRF's label-1 (foreground) marginal
        multiplied by 255 and truncated to an integer.

    Raises:
        ValueError: if either input is not uint8, if ``img`` is not
            (H, W, 3), or if the mask size does not match the image size.
            These checks raise exceptions rather than using ``assert``, so
            they still run under ``python -O``.
    """

    def _sigmoid(x):
        return 1 / (1 + np.exp(-x))

    if img.dtype != np.uint8 or annos.dtype != np.uint8:
        raise ValueError(
            f"crf_refine expects uint8 inputs, got img={img.dtype}, annos={annos.dtype}"
        )
    if img.ndim != 3 or img.shape[2] != 3:
        raise ValueError(f"crf_refine expects img of shape (H, W, 3), got {img.shape}")
    # Drop leading singleton axes, e.g. (1, 1, H, W) -> (H, W).
    while annos.ndim > 2 and annos.shape[0] == 1:
        annos = annos[0]
    if annos.shape != img.shape[:2]:
        raise ValueError(
            f"mask shape {annos.shape} does not match image size {img.shape[:2]}"
        )

    # pydensecrf reads the bilateral reference image through a typed buffer
    # that needs a C-contiguous array (presumably also writable). Copy only
    # when the input is a strided view or read-only.
    img = np.require(img, requirements=['C', 'W'])

    EPSILON = 1e-8  # keeps log() finite for mask values of exactly 0 or 255
    M = 2  # number of labels: background (0) and foreground (1)
    tau = 1.05  # divisor applied to both unary energies

    h, w = img.shape[:2]
    d = dcrf.DenseCRF2D(w, h, M)

    anno_norm = annos / 255.
    # For each label with probability p: energy = -log(p) / (tau * sigmoid(p)).
    n_energy = -np.log(1.0 - anno_norm + EPSILON) / (tau * _sigmoid(1 - anno_norm))
    p_energy = -np.log(anno_norm + EPSILON) / (tau * _sigmoid(anno_norm))

    U = np.zeros((M, h * w), dtype='float32')
    U[0, :] = n_energy.flatten()
    U[1, :] = p_energy.flatten()
    d.setUnaryEnergy(U)

    d.addPairwiseGaussian(sxy=3, compat=3)
    d.addPairwiseBilateral(sxy=60, srgb=5, rgbim=img, compat=5)

    # The inference result holds one row per label; keep the foreground row.
    infer = np.array(d.inference(n_iter)).astype('float32')
    res = infer[1, :] * 255
    return res.reshape(h, w).astype('uint8')
|
|
|
|
|
# Build the network on the CPU and load the trained weights from
# 'image_best.pth' (a relative path, resolved against the working directory).
device = 'cpu'
net = build_model(device).to(device)

model_path = 'image_best.pth'
print(model_path)
weight = torch.load(model_path, map_location=torch.device(device))

# Checkpoint keys are expected to carry a "module." prefix (presumably from
# saving an nn.DataParallel/DDP-wrapped model). Strip it only where it is
# present: slicing every key unconditionally would corrupt the key names of
# a checkpoint saved without the wrapper.
new_dict = collections.OrderedDict(
    (k[len('module.'):] if k.startswith('module.') else k, v)
    for k, v in weight.items()
)
net.load_state_dict(new_dict)
# Evaluation mode (affects layers such as dropout and batch-norm).
net.eval()
|
def test(gpu_id, net, img_list, group_size, img_size):
    """Co-segment a group of images and return CRF-refined masks.

    Args:
        gpu_id: Unused; kept so existing call sites keep working.
        net: Network from ``build_model``. It is called once on the whole
            batch and is expected to return a 2-tuple whose second item
            holds one mask map per image. The maps are assumed to be
            probabilities in [0, 1] (TODO confirm against model_video).
        img_list: ``group_size`` uint8 images as numpy arrays of shape
            (H, W) or (H, W, C). Each is converted to RGB first.
        group_size: Number of images in ``img_list`` and in the batch.
        img_size: Side length of the square resolution fed to ``net``.

    Returns:
        List of ``group_size`` RGB ``PIL.Image`` masks, each at the original
        resolution of its input image.

    Raises:
        ValueError: if ``len(img_list) != group_size``.
    """
    print('test')
    if len(img_list) != group_size:
        raise ValueError(f"expected {group_size} images, got {len(img_list)}")

    # RGB PIL copies feed the network transform. Writable (H, W, 3) uint8
    # arrays of the same pixels are the reference images for the CRF.
    pil_list = [Image.fromarray(img).convert('RGB') for img in img_list]
    rgb_list = [np.array(pil) for pil in pil_list]

    img_transform = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    result = []
    with torch.no_grad():
        # One (group_size, 3, img_size, img_size) batch for the whole group.
        group_img = torch.stack([img_transform(pil) for pil in pil_list])
        _, pred_mask = net(group_img)
        # One (1, 1, h, w) map per image, scaled to the [0, 255] pixel range.
        pred_mask = pred_mask.reshape(group_size, 1, 1, *pred_mask.shape[-2:]) * 255

        for rgb, mask in zip(rgb_list, pred_mask):
            height, width = rgb.shape[:2]
            # F.interpolate's ``size`` is (height, width) for 4-D input.
            mask = F.interpolate(mask, size=(height, width), mode='bilinear')
            # Take the 2-D (H, W) map; clamp so the uint8 cast cannot wrap.
            mask = mask[0, 0].clamp(0, 255).numpy().astype(np.uint8)
            refined = crf_refine(rgb, mask)
            # Replicate the single-channel mask into an RGB image.
            result.append(Image.fromarray(np.stack([refined] * 3, axis=-1)))

    print('done')
    return result
|
# Run the full pipeline once when the module loads. The input is five
# 224x224 RGB uint8 images of random noise (presumably a warm-up / smoke
# check); the returned masks are discarded.
test(
    gpu_id='cpu',
    net=net,
    img_list=[(torch.rand(224, 224, 3) * 255).numpy().astype(np.uint8) for _ in range(5)],
    group_size=5,
    img_size=224,
)
|
def sepia(img1, img2, img3, img4, img5):
    """Gradio callback: co-segment the five uploaded images.

    The name is unrelated to what the function does. It is kept because the
    ``gr.Interface`` in this file refers to it.

    Args:
        img1, img2, img3, img4, img5: The five Gradio image inputs as numpy
            arrays. An empty upload slot may arrive as ``None``.

    Returns:
        Tuple of five mask images, in the same order as the inputs.

    Raises:
        ValueError: if any input image is missing. This gives the UI a clear
            message instead of an opaque AttributeError on ``None``.
    """
    print('sepia')
    img_list = [img1, img2, img3, img4, img5]
    missing = [idx for idx, img in enumerate(img_list, start=1) if img is None]
    if missing:
        raise ValueError(f"Please provide all 5 images; missing image(s): {missing}")

    img1, img2, img3, img4, img5 = test(device, net, img_list, 5, 224)
    return img1, img2, img3, img4, img5
|
|
|
|
|
# Web UI: five image inputs map one-to-one onto five mask-image outputs,
# all handled by the ``sepia`` callback.
demo = gr.Interface(
    fn=sepia,
    inputs=["image"] * 5,
    outputs=["image"] * 5,
)
demo.launch(debug=True)
|
|