# Snapshot of commit a3a23d3 (original file size: 5,226 bytes).
import gradio as gr
import torch
import pickle
import time
from options.test_options import TestOptions
from data.data_loader_test import CreateDataLoader
from models.networks import ResUnetGenerator, load_checkpoint
from models.afwm import AFWM
import torch.nn as nn
import os
import numpy as np
import torch
import cv2
import torch.nn.functional as F
from torchvision import utils
from util import flow_util
def de_offset(s_grid):
    """Convert a normalized sampling grid into per-pixel offsets in pixel units.

    Args:
        s_grid: tensor indexed as (b, 2, h, w). Channel 0 holds x and channel 1
            holds y sampling coordinates in normalized form, where pixel index
            0 maps to -1 and index w-1 (resp. h-1) maps to +1.

    Returns:
        Tensor of shape (2 * b, h, w). All y offsets come first, then all x
        offsets, because they are concatenated along dim 0 rather than stacked
        per sample. For b == 1 this is a (2, h, w) flow field. Each offset is
        (identity position - sampled position) in pixels.
    """
    b, _, h, w = s_grid.size()
    # Build the identity grid on the same device as the input so this works on
    # CPU as well as CUDA. It is kept in float32, as in the original code.
    device = s_grid.device
    x = torch.arange(w, device=device).view(1, -1).expand(h, -1).float()
    y = torch.arange(h, device=device).view(-1, 1).expand(-1, w).float()
    # Map pixel indices to the [-1, 1] normalized coordinate convention.
    x = 2 * x / (w - 1) - 1
    y = 2 * y / (h - 1) - 1
    grid = torch.stack([x, y], dim=0).unsqueeze(0).expand(b, -1, -1, -1)

    offset = grid - s_grid
    # Scale normalized differences back to pixels (inverse of the mapping above).
    offset_x = offset[:, 0, :, :] * (w - 1) / 2
    offset_y = offset[:, 1, :, :] * (h - 1) / 2
    return torch.cat((offset_y, offset_x), 0)
def _write_rgb_image(path, image, label):
    """Save an image received from Gradio to ``path`` so the data loader can read it.

    ``cv2.imwrite`` interprets 3-channel arrays as BGR, while Gradio's numpy
    images are RGB. Converting first keeps the file on disk in the same colors
    as the uploaded picture, instead of red and blue being swapped.

    Raises:
        ValueError: if ``image`` is None (nothing was uploaded).
        IOError: if OpenCV fails to write the file.
    """
    if image is None:
        raise ValueError(f"No {label} image was provided.")
    if image.ndim == 3 and image.shape[2] == 3:
        image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
    os.makedirs(os.path.dirname(path), exist_ok=True)
    if not cv2.imwrite(path, image):
        raise IOError(f"Could not write {label} image to {path}")


def _load_models(opt):
    """Build the warping (AFWM) and generator networks in eval mode on CUDA and load their checkpoints."""
    warp_model = AFWM(opt, 3)
    warp_model.eval()
    warp_model.cuda()
    load_checkpoint(warp_model, opt.warp_checkpoint)

    # 7 input channels must match the concatenation in _synthesize
    # (person image, warped cloth, warped edge); the 4 outputs are split 3 + 1 there.
    gen_model = ResUnetGenerator(7, 4, 5, ngf=64, norm_layer=nn.BatchNorm2d)
    gen_model.eval()
    gen_model.cuda()
    load_checkpoint(gen_model, opt.gen_checkpoint)
    return warp_model, gen_model


def _synthesize(warp_model, gen_model, data):
    """Warp the garment onto the person and compose the try-on image for one batch.

    Args:
        data: one batch from the data loader with 'image', 'clothes' and 'edge' tensors.

    Returns:
        The try-on tensor, intended for the [-1, 1] range (it is saved with
        ``value_range=(-1, 1)``).
    """
    real_image = data['image'].cuda()
    # Binarize the edge map and use it to mask out the garment's background.
    edge = (data['edge'] > 0.5).float()
    clothes = data['clothes'] * edge

    warped_cloth, last_flow = warp_model(real_image, clothes.cuda())
    # Warp the mask with the same flow so it stays aligned with the warped cloth.
    warped_edge = F.grid_sample(edge.cuda(), last_flow.permute(0, 2, 3, 1),
                                mode='bilinear', padding_mode='zeros')

    gen_outputs = gen_model(torch.cat([real_image, warped_cloth, warped_edge], 1))
    p_rendered, m_composite = torch.split(gen_outputs, [3, 1], 1)
    p_rendered = torch.tanh(p_rendered)
    # The composition mask is limited to where the warped garment actually lands.
    m_composite = torch.sigmoid(m_composite) * warped_edge
    return warped_cloth * m_composite + p_rendered * (1 - m_composite)


def tryon(person, cloth, edge):
    """Gradio callback: render ``person`` wearing ``cloth`` and return the result.

    The three inputs are written to fixed file names under ./data so that the
    project's test data loader picks them up as a single sample. The models are
    then run and the try-on image is saved to ./our_t_results and read back.

    Args:
        person, cloth, edge: images from the Gradio Image inputs
            (numpy RGB arrays with the default input type).

    Returns:
        The try-on result as an RGB numpy array, ready for Gradio to display.

    Raises:
        ValueError: if one of the images is missing.
        IOError: if an image cannot be written or the result cannot be read.
        RuntimeError: if the data loader yields no samples.
    """
    _write_rgb_image('./data/test_ma_img/000001_0.jpg', person, 'person')
    _write_rgb_image('./data/test_edge/000001_1.jpg', edge, 'edge')
    _write_rgb_image('./data/test_clothes/000001_1.jpg', cloth, 'cloth')

    # NOTE: pickle.load can execute arbitrary code; opt.pkl must be a trusted
    # file shipped with the app, never something a user can supply.
    with open('opt.pkl', 'rb') as handle:
        opt = pickle.load(handle)

    data_loader = CreateDataLoader(opt)
    dataset = data_loader.load_data()
    dataset_size = len(data_loader)  # expected to be 1 for this demo
    print(dataset_size)

    warp_model, gen_model = _load_models(opt)

    out_dir = './our_t_results'
    os.makedirs(out_dir, exist_ok=True)

    result_path = None
    processed = 0
    # Pure inference: disabling autograd avoids building graphs and keeps GPU memory down.
    with torch.no_grad():
        for data in dataset:
            p_tryon = _synthesize(warp_model, gen_model, data)
            print(data['p_name'])
            result_path = os.path.join(out_dir, data['p_name'][0])
            utils.save_image(
                p_tryon,
                result_path,
                nrow=1,
                normalize=True,
                value_range=(-1, 1),
            )
            processed += opt.batchSize
            if processed >= dataset_size:
                break

    if result_path is None:
        raise RuntimeError("The data loader produced no samples.")
    # Read back the file that was actually written, not a hard-coded name.
    result_img = cv2.imread(result_path)
    if result_img is None:
        raise IOError(f"Could not read try-on result from {result_path}")
    # cv2.imread returns BGR; Gradio expects RGB.
    return cv2.cvtColor(result_img, cv2.COLOR_BGR2RGB)
# Web UI: three image uploads (person photo, garment photo, garment edge mask)
# are passed to `tryon`, and its returned array is shown as the output image.
demo = gr.Interface(
    fn=tryon,
    inputs=[
        gr.inputs.Image(label="Person"),
        gr.inputs.Image(label="Cloth"),
        gr.inputs.Image(label="Edge"),
    ],
    outputs="image",
)

# Starts the Gradio server as soon as this module is executed.
demo.launch()