# UFO / app.py
import os

# Runtime dependency installation kept from the original Space setup; on Hugging Face
# Spaces these would normally be listed in requirements.txt instead.
os.system("pip3 install torch")
os.system("pip3 install torchvision")
os.system("pip3 install einops")

import collections  # standard library; no pip install required

import numpy as np
import gradio as gr
import torch
from PIL import Image
from torchvision import transforms

from model_video import build_model
# Build the model on CPU and load the pretrained weights.
net = build_model('cpu').to('cpu')
model_path = 'image_best.pth'
print(model_path)
weight = torch.load(model_path, map_location=torch.device('cpu'))

# The checkpoint was saved from a torch.nn.DataParallel wrapper, so every key carries a
# 'module.' prefix; strip it before loading into the plain (non-parallel) model.
new_dict = collections.OrderedDict()
for k in weight.keys():
    new_dict[k[len('module.'):]] = weight[k]
net.load_state_dict(new_dict)
net.eval()
def test(gpu_id, net, img_list, group_size, img_size):
    """Run the network on a group of images and return the predicted masks as PIL images."""
    print('test')
    device = 'cpu'
    img_transform = transforms.Compose([
        transforms.Resize((img_size, img_size)), transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
    img_transform_gray = transforms.Compose([
        transforms.Resize((img_size, img_size)), transforms.ToTensor(),
        transforms.Normalize(mean=[0.449], std=[0.226])])  # defined in the original but unused here
    with torch.no_grad():
        # Stack the inputs into a single (group_size, 3, img_size, img_size) batch.
        group_img = torch.empty(group_size, 3, img_size, img_size)
        for i in range(group_size):
            group_img[i] = img_transform(Image.fromarray(img_list[i]))
        _, pred_mask = net(group_img)
        print(pred_mask.shape)
        # Convert each single-channel mask to a 3-channel uint8 PIL image (kept at img_size x img_size).
        result = [Image.fromarray((pred_mask[i].detach().squeeze().unsqueeze(2).repeat(1, 1, 3) * 255)
                                  .numpy().astype(np.uint8)) for i in range(group_size)]
    print('done')
    return result
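# Hedged usage sketch (not executed): calling test() directly with five local images
# loaded as RGB numpy arrays; the file names below are hypothetical placeholders.
#
#   imgs = [np.array(Image.open(f'img{i}.jpg').convert('RGB')) for i in range(5)]
#   masks = test('cpu', net, imgs, group_size=5, img_size=224)
#   masks[0].save('mask0.png')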
def sepia(img1, img2, img3, img4, img5):
    """Gradio callback: takes five images and returns their predicted segmentation masks."""
    print('sepia')
    img_list = [img1, img2, img3, img4, img5]
    h_list, w_list = [_.shape[0] for _ in img_list], [_.shape[1] for _ in img_list]
    result_list = test('cpu', net, img_list, 5, 224)
    # Optionally resize the masks back to the original input sizes:
    # result_list = [result_list[i].resize((w_list[i], h_list[i]), Image.BILINEAR) for i in range(5)]
    img1, img2, img3, img4, img5 = result_list
    return img1, img2, img3, img4, img5
# Five image inputs and five mask outputs, one mask per image in the group.
demo = gr.Interface(sepia,
                    inputs=["image", "image", "image", "image", "image"],
                    outputs=["image", "image", "image", "image", "image"])
demo.launch(debug=True)
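# Hedged note: debug=True keeps the process attached and prints errors to the console;
# when running locally, demo.launch(share=True) would additionally create a temporary public link.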