zestyoreo committed
Commit a3a23d3 · 1 Parent(s): 90e5fa0
Files changed (1)
  1. app.py +157 -0
app.py ADDED
@@ -0,0 +1,157 @@
+ import gradio as gr
+ import torch
+ import pickle
+ import time
+ from options.test_options import TestOptions
+ from data.data_loader_test import CreateDataLoader
+ from models.networks import ResUnetGenerator, load_checkpoint
+ from models.afwm import AFWM
+ import torch.nn as nn
+ import os
+ import numpy as np
+ import cv2
+ import torch.nn.functional as F
+ from torchvision import utils
+ from util import flow_util
+
+ def de_offset(s_grid):
+     # invert a grid_sample sampling grid into a pixel-space flow offset
+     [b, _, h, w] = s_grid.size()
+
+     # identity grid in grid_sample's normalized [-1, 1] coordinates
+     x = torch.arange(w).view(1, -1).expand(h, -1).float()
+     y = torch.arange(h).view(-1, 1).expand(-1, w).float()
+     x = 2 * x / (w - 1) - 1
+     y = 2 * y / (h - 1) - 1
+     grid = torch.stack([x, y], dim=0).float().cuda()
+     grid = grid.unsqueeze(0).expand(b, -1, -1, -1)
+
+     offset = grid - s_grid
+
+     # rescale normalized offsets back to pixel units
+     offset_x = offset[:, 0, :, :] * (w - 1) / 2
+     offset_y = offset[:, 1, :, :] * (h - 1) / 2
+
+     offset = torch.cat((offset_y, offset_x), 0)
+
+     return offset
+
+ def tryon(person, cloth, edge):
+
+     # save the inputs where the test data loader expects them;
+     # Gradio supplies RGB arrays, so convert to BGR for cv2.imwrite
+     cv2.imwrite('./data/test_ma_img/000001_0.jpg', cv2.cvtColor(person, cv2.COLOR_RGB2BGR))
+     cv2.imwrite('./data/test_edge/000001_1.jpg', edge)
+     cv2.imwrite('./data/test_clothes/000001_1.jpg', cv2.cvtColor(cloth, cv2.COLOR_RGB2BGR))
+
+     # load the pre-built test options instead of parsing command-line flags
+     with open('opt.pkl', 'rb') as handle:
+         opt = pickle.load(handle)
+
+     f2c = flow_util.flow2color()
+     start_epoch, epoch_iter = 1, 0
+
+     data_loader = CreateDataLoader(opt)
+     dataset = data_loader.load_data()
+     dataset_size = len(data_loader)  # must be 1
+     print(dataset_size)
+
+     # warping module: predicts the flow that aligns the garment with the person
+     warp_model = AFWM(opt, 3)
+     print(warp_model)
+     warp_model.eval()
+     warp_model.cuda()
+     load_checkpoint(warp_model, opt.warp_checkpoint)
+
+     # generative module: renders the final image plus a composition mask
+     gen_model = ResUnetGenerator(7, 4, 5, ngf=64, norm_layer=nn.BatchNorm2d)
+     gen_model.eval()
+     gen_model.cuda()
+     load_checkpoint(gen_model, opt.gen_checkpoint)
+
+     total_steps = (start_epoch - 1) * dataset_size + epoch_iter
+     step = 0
+     step_per_batch = dataset_size / opt.batchSize
+
+     if not os.path.exists('our_t_results'):
+         os.mkdir('our_t_results')
+
+     for epoch in range(1, 2):
+
+         for i, data in enumerate(dataset, start=epoch_iter):
+             iter_start_time = time.time()
+             total_steps += opt.batchSize
+             epoch_iter += opt.batchSize
+
+             real_image = data['image']
+             clothes = data['clothes']
+             # the edge map is a binary garment mask; threshold it and use it
+             # to keep only the cloth region of the garment image
+             edge = data['edge']
+             edge = torch.FloatTensor((edge.detach().numpy() > 0.5).astype(np.int64))
+             clothes = clothes * edge
+
+             # predict the appearance flow and warp the garment onto the person
+             flow_out = warp_model(real_image.cuda(), clothes.cuda())
+             warped_cloth, last_flow = flow_out
+             warped_edge = F.grid_sample(edge.cuda(), last_flow.permute(0, 2, 3, 1),
+                                         mode='bilinear', padding_mode='zeros')
+
+             # render the try-on image and blend it with the warped garment
+             gen_inputs = torch.cat([real_image.cuda(), warped_cloth, warped_edge], 1)
+             gen_outputs = gen_model(gen_inputs)
+             p_rendered, m_composite = torch.split(gen_outputs, [3, 1], 1)
+             p_rendered = torch.tanh(p_rendered)
+             m_composite = torch.sigmoid(m_composite)
+             m_composite = m_composite * warped_edge
+             p_tryon = warped_cloth * m_composite + p_rendered * (1 - m_composite)
+
+             path = 'results/' + opt.name
+             os.makedirs(path, exist_ok=True)
+             # sub_path = path + '/PFAFN'
+             # os.makedirs(sub_path, exist_ok=True)
+             print(data['p_name'])
+
+             if step % 1 == 0:
+
+                 ## save try-on image only
+
+                 utils.save_image(
+                     p_tryon,
+                     os.path.join('./our_t_results', data['p_name'][0]),
+                     nrow=int(1),
+                     normalize=True,
+                     value_range=(-1, 1),
+                 )
+
+                 ## save person image, garment, flow, warped garment, and try-on image
+
+                 # a = real_image.float().cuda()
+                 # b = clothes.cuda()
+                 # flow_offset = de_offset(last_flow)
+                 # flow_color = f2c(flow_offset).cuda()
+                 # c = warped_cloth.cuda()
+                 # d = p_tryon
+                 # combine = torch.cat([a[0], b[0], flow_color, c[0], d[0]], 2).squeeze()
+                 # utils.save_image(
+                 #     combine,
+                 #     os.path.join('./im_gar_flow_wg', data['p_name'][0]),
+                 #     nrow=int(1),
+                 #     normalize=True,
+                 #     value_range=(-1, 1),
+                 # )
+
+             step += 1
+             if epoch_iter >= dataset_size:
+                 break
+
+     # read the saved result back and convert BGR -> RGB for Gradio's display
+     result_img = cv2.imread('./our_t_results/000001_0.jpg')
+     result_img = cv2.cvtColor(result_img, cv2.COLOR_BGR2RGB)
+     return result_img
+
+ demo = gr.Interface(fn=tryon,
+                     inputs=[gr.inputs.Image(label="Person"), gr.inputs.Image(label="Cloth"), gr.inputs.Image(label="Edge")],
+                     outputs="image"
+                     )
+
+ # def pp(inp1, inp2):
+ #     return inp1 + " hello " + inp2
+
+ # demo2 = gr.Interface(fn=pp,
+ #                      inputs=[gr.inputs.Textbox(lines=5, label="Input Text"), gr.inputs.Textbox(lines=5, label="Input Text2")],
+ #                      outputs=gr.outputs.Textbox(label="Generated Text"),
+ #                      )
+
+ demo.launch()