Spaces:
Sleeping
Sleeping
Create train.py
Browse files
train.py
ADDED
@@ -0,0 +1,241 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
import os
|
3 |
+
import numpy as np
|
4 |
+
from tqdm import tqdm
|
5 |
+
import tensorflow as tf
|
6 |
+
import cv2
|
7 |
+
import argparse
|
8 |
+
import typing
|
9 |
+
import h5py
|
10 |
+
|
11 |
+
# Parse the command-line arguments.
def parse_opt(known=False):
    """Build the style-transfer argument parser and parse the process arguments.

    Args:
        known: When truthy, parse with ``parse_known_args`` and drop any
            unrecognized arguments; otherwise use ``parse_args``.

    Returns:
        The ``argparse.Namespace`` holding the parsed options.
    """
    # (flag, value type, default, help text) in registration order.
    option_specs = (
        ("--content_img_path", str, "./images/1.jpg", "原图路径"),  # content image path
        ("--style_img_path", str, "./images/style.jpg", "风格图片路径"),  # style image path
        ("--output_path", str, "./output/1", "生成图片保存路径"),  # where to save the result
        ("--epochs", int, 20, "总训练轮数"),  # total number of epochs
        ("--step_per_epoch", int, 100, "每轮训练次数"),  # optimisation steps per epoch
        ("--learning_rate", float, 0.01, "学习率"),  # learning rate
        ("--content_loss_factor", float, 1.0, "内容损失总加权系数"),  # content loss weight
        ("--style_loss_factor", float, 100.0, "风格损失总加权系数"),  # style loss weight
        # image size mode: 0 = default 450*300, 1 = use the image's size, other = custom
        ("--img_size", int, 0, "图片尺寸,0代表不设置使用默认尺寸(450*300),输入1代表使用图片尺寸,其他输入代表使用自定义尺寸"),
        ("--img_width", int, 450, "自定义图片宽度"),  # custom image width
        ("--img_height", int, 300, "自定义图片高度"),  # custom image height
    )

    parser = argparse.ArgumentParser()
    for flag, value_type, default_value, help_text in option_specs:
        parser.add_argument(flag, type=value_type, default=default_value, help=help_text)

    if known:
        parsed, _unrecognized = parser.parse_known_args()
        return parsed
    return parser.parse_args()
|
27 |
+
|
28 |
+
def load_images(image_path, width, height):
    """Read an image file and return it as a normalized tensor with a batch axis.

    Args:
        image_path: Path of the file to read; it is decoded with
            ``tf.image.decode_jpeg`` into 3 channels.
        width: Target width passed to the resize.
        height: Target height passed to the resize.

    Returns:
        A tensor reshaped to ``[1, height, width, 3]``.
    """
    raw_bytes = tf.io.read_file(image_path)
    decoded = tf.image.decode_jpeg(raw_bytes, channels=3)
    resized = tf.image.resize(decoded, [height, width])
    # Divide by 255 before applying the module-level mean/std normalization.
    normalized = normalization(resized / 255.0)
    return tf.reshape(normalized, [1, height, width, 3])
|
39 |
+
|
40 |
+
def load_images_from_list(image_array, width, height):
    """Turn in-memory pixel data into a normalized tensor with a batch axis.

    Args:
        image_array: Pixel data convertible to a float32 tensor
            (presumably an H x W x 3 array with values in 0..255 -- confirm
            against the caller).
        width: Target width passed to the resize.
        height: Target height passed to the resize.

    Returns:
        A tensor reshaped to ``[1, height, width, 3]``.
    """
    pixels = tf.convert_to_tensor(image_array, dtype=tf.float32)
    resized = tf.image.resize(pixels, [height, width])
    # Divide by 255 before applying the module-level mean/std normalization.
    normalized = normalization(resized / 255.0)
    return tf.reshape(normalized, [1, height, width, 3])
|
50 |
+
|
51 |
+
def save_image(image, filename):
    """Undo the normalization of an image tensor and write it out as a JPEG.

    Args:
        image: Tensor whose leading axis is dropped by reshaping to
            ``image.shape[1:]`` (presumably a batch holding one image).
        filename: Destination path handed to ``tf.io.write_file``.
    """
    unbatched = tf.reshape(image, image.shape[1:])
    # Inverse of normalization(): scale by the std, add the mean, then go to 0..255.
    pixels = (unbatched * image_std + image_mean) * 255.0
    # Cast to int32, clamp into [0, 255], and only then narrow to uint8.
    clamped = tf.clip_by_value(tf.cast(pixels, tf.int32), 0, 255)
    encoded = tf.image.encode_jpeg(tf.cast(clamped, tf.uint8))
    tf.io.write_file(filename, encoded)
|
63 |
+
|
64 |
+
def save_image_for_gradio(image):
    """Undo the normalization of an image tensor and return it as a NumPy array.

    Args:
        image: Tensor whose leading axis is dropped by reshaping to
            ``image.shape[1:]`` (presumably a batch holding one image).

    Returns:
        The uint8 pixel data obtained from ``.numpy()`` on the converted tensor.
    """
    unbatched = tf.reshape(image, image.shape[1:])
    # Inverse of normalization(): scale by the std, add the mean, then go to 0..255.
    pixels = (unbatched * image_std + image_mean) * 255.0
    # Cast to int32, clamp into [0, 255], and only then narrow to uint8.
    clamped = tf.clip_by_value(tf.cast(pixels, tf.int32), 0, 255)
    as_bytes = tf.cast(clamped, tf.uint8)
    # Hand back a plain NumPy array instead of a TensorFlow tensor.
    return as_bytes.numpy()
|
76 |
+
|
77 |
+
def get_vgg19_model(layers):
    """Create a non-trainable VGG19 model that outputs the requested layers.

    Args:
        layers: Iterable of VGG19 layer names; the returned model emits one
            output per name, in the same order.

    Returns:
        A ``tf.keras.Model`` mapping the VGG19 input to the selected layer
        outputs, with ``trainable`` set to ``False``.
    """
    backbone = tf.keras.applications.VGG19(weights="imagenet", include_top=False)
    selected_outputs = []
    for layer_name in layers:
        selected_outputs.append(backbone.get_layer(layer_name).output)
    extractor = tf.keras.Model(backbone.input, selected_outputs)
    # The network is only used as a fixed feature extractor.
    extractor.trainable = False
    return extractor
|
86 |
+
|
87 |
+
class NeuralStyleTransferModel(tf.keras.Model):
    """Expose weighted content and style features from a VGG19 extractor."""

    def __init__(self, content_layers: typing.Dict[str, float], style_layers: typing.Dict[str, float]):
        """Store the layer weights and build the underlying VGG19 extractor.

        Args:
            content_layers: Maps a VGG19 layer name to its content-loss weight.
            style_layers: Maps a VGG19 layer name to its style-loss weight.
        """
        super().__init__()
        self.content_layers = content_layers
        self.style_layers = style_layers
        # Content layers come first, then style layers; the extractor emits
        # its outputs in exactly this order.
        layer_names = [*self.content_layers.keys(), *self.style_layers.keys()]
        # Layer name -> position of that layer in the extractor's output list.
        self.outputs_index_map = {name: position for position, name in enumerate(layer_names)}
        self.vgg = get_vgg19_model(layer_names)

    def call(self, inputs, training=None, mask=None):
        """Run the extractor and group its outputs by role.

        Args:
            inputs: Image batch fed to the VGG19 extractor.
            training: Accepted but not used by this method.
            mask: Accepted but not used by this method.

        Returns:
            A dict with keys ``"content"`` and ``"style"``; each value is a
            list of ``(feature, weight)`` tuples where ``feature`` is element
            0 along the first axis of the corresponding layer output.
        """
        features = self.vgg(inputs)

        def weighted_features(weighted_layers):
            # Pair the first batch element of each layer output with its weight.
            return [
                (features[self.outputs_index_map[name]][0], weight)
                for name, weight in weighted_layers.items()
            ]

        return {
            "content": weighted_features(self.content_layers),
            "style": weighted_features(self.style_layers),
        }
|
105 |
+
|
106 |
+
def normalization(x):
    """Normalize ``x`` with the module-level ``image_mean`` and ``image_std``.

    Returns:
        ``x`` with the mean subtracted and the result divided by the std.
    """
    centered = x - image_mean
    return centered / image_std
|
111 |
+
|
112 |
+
def _compute_content_loss(noise_features, target_features):
    """Return the content loss between two feature tensors of one layer.

    The sum of squared differences is divided by ``2 * M * N``, where ``M``
    and ``N`` are module-level values.
    """
    squared_error = tf.reduce_sum(tf.square(noise_features - target_features))
    scale = 2.0 * M * N
    return squared_error / scale
|
120 |
+
|
121 |
+
def compute_content_loss(noise_content_features, target_content_features):
    """Return the weighted content loss summed over all content layers.

    Args:
        noise_content_features: Sequence of ``(feature, weight)`` pairs for
            the image being optimized.
        target_content_features: Sequence of ``(feature, weight)`` pairs for
            the content image; only the features are used here.

    Returns:
        The sum of each layer's content loss multiplied by the weight taken
        from ``noise_content_features``.
    """
    weighted_losses = [
        _compute_content_loss(noise_feature, target_feature) * weight
        for (noise_feature, weight), (target_feature, _) in zip(noise_content_features, target_content_features)
    ]
    return tf.reduce_sum(weighted_losses)
|
130 |
+
|
131 |
+
def gram_matrix(feature):
    """Return the Gram matrix of a rank-3 feature tensor.

    The last axis is moved to the front, the two remaining axes are flattened,
    and the resulting 2-D matrix is multiplied by its own transpose.
    (Presumably ``feature`` is laid out as height x width x channels, which
    would make the result channels x channels -- confirm against the caller.)
    """
    last_axis_first = tf.transpose(feature, perm=[2, 0, 1])
    flattened = tf.reshape(last_axis_first, (last_axis_first.shape[0], -1))
    return flattened @ tf.transpose(flattened)
|
138 |
+
|
139 |
+
def _compute_style_loss(noise_feature, target_feature):
    """Return the style loss between two feature tensors of one layer.

    The squared difference of the two Gram matrices is summed and divided by
    ``4 * M**2 * N**2``, where ``M`` and ``N`` are module-level values.
    """
    gram_difference = gram_matrix(noise_feature) - gram_matrix(target_feature)
    squared_error = tf.reduce_sum(tf.square(gram_difference))
    scale = 4.0 * (M**2) * (N**2)
    return squared_error / scale
|
148 |
+
|
149 |
+
def compute_style_loss(noise_style_features, target_style_features):
    """Return the weighted style loss summed over all style layers.

    Args:
        noise_style_features: Sequence of ``(feature, weight)`` pairs for the
            image being optimized.
        target_style_features: Sequence of ``(feature, weight)`` pairs for the
            style image; only the features are used here.

    Returns:
        The sum of each layer's style loss multiplied by the weight taken
        from ``noise_style_features``.
    """
    weighted_losses = [
        _compute_style_loss(noise_feature, target_feature) * weight
        for (noise_feature, weight), (target_feature, _) in zip(noise_style_features, target_style_features)
    ]
    return tf.reduce_sum(weighted_losses)
|
158 |
+
|
159 |
+
def total_loss(noise_features, target_content_features, target_style_features):
    """Combine the content and style losses into the total objective.

    Args:
        noise_features: Dict with ``"content"`` and ``"style"`` feature lists
            for the image being optimized.
        target_content_features: Content features of the content image.
        target_style_features: Style features of the style image.

    Returns:
        The content loss times ``CONTENT_LOSS_FACTOR`` plus the style loss
        times ``STYLE_LOSS_FACTOR`` (both module-level values).
    """
    content_term = compute_content_loss(noise_features["content"], target_content_features) * CONTENT_LOSS_FACTOR
    style_term = compute_style_loss(noise_features["style"], target_style_features) * STYLE_LOSS_FACTOR
    return content_term + style_term
|
166 |
+
|
167 |
+
@tf.function
def train_one_step(model, noise_image, optimizer, target_content_features, target_style_features):
    """Perform one optimisation step on the generated image.

    Args:
        model: Callable returning the ``"content"``/``"style"`` feature dict
            for an image.
        noise_image: The image being optimized; gradients are taken with
            respect to it and applied to it.
        optimizer: Optimizer whose ``apply_gradients`` updates ``noise_image``.
        target_content_features: Content features of the content image.
        target_style_features: Style features of the style image.

    Returns:
        The total loss computed before the update was applied.
    """
    with tf.GradientTape() as tape:
        loss = total_loss(model(noise_image), target_content_features, target_style_features)
    # Only the image itself is updated; no model weights are touched here.
    gradient = tape.gradient(loss, noise_image)
    optimizer.apply_gradients([(gradient, noise_image)])
    return loss
|
178 |
+
|
179 |
+
def _load_input_image(image, width, height):
    """Load one input image from either a file path or in-memory pixel data."""
    if isinstance(image, (str, bytes, os.PathLike)):
        # The command-line entry point supplies file paths.
        return load_images(os.fspath(image), width, height)
    # Anything else is treated as pixel data, as before.
    return load_images_from_list(image, width, height)


def main(content_img, style_img, epochs, step_per_epoch, learning_rate, content_loss_factor, style_loss_factor, img_size, img_width, img_height):
    """Run neural style transfer and return the stylized image.

    Args:
        content_img: Content image, either a file path or in-memory pixel data.
        style_img: Style image, either a file path or in-memory pixel data.
        epochs: Number of epochs to run.
        step_per_epoch: Number of optimisation steps per epoch.
        learning_rate: Learning rate given to the Adam optimizer.
        content_loss_factor: Overall weight of the content loss.
        style_loss_factor: Overall weight of the style loss.
        img_size: Size mode. ``0`` (the CLI default) or ``"default size"``
            selects 450x300; any other value uses ``img_width``/``img_height``.
        img_width: Custom working width.
        img_height: Custom working height.

    Returns:
        The result of ``save_image_for_gradio`` on the optimized image: a
        uint8 NumPy array.

    Side effects:
        Assigns the module-level configuration globals that the loss and
        normalization helpers read, and prints the chosen image size.
    """
    global CONTENT_LOSS_FACTOR, STYLE_LOSS_FACTOR, CONTENT_IMAGE_PATH, STYLE_IMAGE_PATH
    global OUTPUT_DIR, EPOCHS, LEARNING_RATE, STEPS_PER_EPOCH
    global M, N, image_mean, image_std, IMG_WIDTH, IMG_HEIGHT

    CONTENT_LOSS_FACTOR = content_loss_factor
    STYLE_LOSS_FACTOR = style_loss_factor
    CONTENT_IMAGE_PATH = content_img
    STYLE_IMAGE_PATH = style_img
    EPOCHS = epochs
    LEARNING_RATE = learning_rate
    STEPS_PER_EPOCH = step_per_epoch

    # Content feature layers and their loss weights.
    content_layers = {"block4_conv2": 0.5, "block5_conv2": 0.5}
    # Style feature layers and their loss weights.
    style_layers = {
        "block1_conv1": 0.2,
        "block2_conv1": 0.2,
        "block3_conv1": 0.2,
        "block4_conv1": 0.2,
        "block5_conv1": 0.2,
    }

    # The command line delivers an int (0 means "default size") while other
    # callers pass the string "default size"; accept both spellings.
    # NOTE(review): the --img_size help text also describes a mode 1 ("use the
    # image's own size"); that mode is not implemented and falls through to
    # the custom width/height.
    if img_size in (0, "default size"):
        IMG_WIDTH = 450
        IMG_HEIGHT = 300
    else:
        IMG_WIDTH = img_width
        IMG_HEIGHT = img_height

    print("IMG_WIDTH:", IMG_WIDTH)
    print("IMG_HEIGHT:", IMG_HEIGHT)

    # The network uses weights pre-trained on ImageNet, so normalization uses
    # the ImageNet per-channel mean and standard deviation as well.
    image_mean = tf.constant([0.485, 0.456, 0.406])
    # The red-channel std is 0.229 (the previous 0.299 was a digit swap).
    image_std = tf.constant([0.229, 0.224, 0.225])

    # Normalisation constants used by the per-layer content and style losses.
    M = IMG_WIDTH * IMG_HEIGHT
    N = 3

    model = NeuralStyleTransferModel(content_layers, style_layers)

    content_image = _load_input_image(CONTENT_IMAGE_PATH, IMG_WIDTH, IMG_HEIGHT)
    style_image = _load_input_image(STYLE_IMAGE_PATH, IMG_WIDTH, IMG_HEIGHT)

    # The targets are fixed for the whole run, so compute them once.
    target_content_features = model(content_image)["content"]
    target_style_features = model(style_image)["style"]

    optimizer = tf.keras.optimizers.Adam(LEARNING_RATE)

    # Start from the content image blended with uniform noise; the noise
    # carries the batch axis, so the sum broadcasts to (1, H, W, 3).
    noise = np.random.uniform(-0.2, 0.2, (1, IMG_HEIGHT, IMG_WIDTH, 3))
    noise_image = tf.Variable((content_image[0] + noise) / 2)

    for epoch in range(EPOCHS):
        with tqdm(total=STEPS_PER_EPOCH, desc="Epoch {}/{}".format(epoch + 1, EPOCHS)) as pbar:
            for _ in range(STEPS_PER_EPOCH):
                _loss = train_one_step(model, noise_image, optimizer, target_content_features, target_style_features)
                pbar.set_postfix({"loss": "%.4f" % float(_loss)})
                pbar.update(1)

    return save_image_for_gradio(noise_image)
|
238 |
+
|
239 |
+
if __name__ == "__main__":
    # Script entry point: parse the command line and forward every option to main().
    # NOTE(review): --output_path is parsed but never forwarded, and the array
    # returned by main() is discarded here -- confirm whether the result is
    # meant to be written to disk.
    cli_options = parse_opt()
    main(
        content_img=cli_options.content_img_path,
        style_img=cli_options.style_img_path,
        epochs=cli_options.epochs,
        step_per_epoch=cli_options.step_per_epoch,
        learning_rate=cli_options.learning_rate,
        content_loss_factor=cli_options.content_loss_factor,
        style_loss_factor=cli_options.style_loss_factor,
        img_size=cli_options.img_size,
        img_width=cli_options.img_width,
        img_height=cli_options.img_height,
    )
|