StevenChen16 committed on
Commit
e721a5b
·
verified ·
1 Parent(s): 0447a72

Create train.py

Browse files
Files changed (1) hide show
  1. train.py +241 -0
train.py ADDED
@@ -0,0 +1,241 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ import os
3
+ import numpy as np
4
+ from tqdm import tqdm
5
+ import tensorflow as tf
6
+ import cv2
7
+ import argparse
8
+ import typing
9
+ import h5py
10
+
11
# Parse command-line arguments
def parse_opt(known=False):
    """
    Build the command-line parser and return the parsed options.

    When ``known`` is true, unrecognized arguments are ignored
    (``parse_known_args``); otherwise they are an error (``parse_args``).
    """
    # (flag, type, default, help) for every supported option.
    arg_specs = (
        ("--content_img_path", str, "./images/1.jpg", "原图路径"),
        ("--style_img_path", str, "./images/style.jpg", "风格图片路径"),
        ("--output_path", str, "./output/1", "生成图片保存路径"),
        ("--epochs", int, 20, "总训练轮数"),
        ("--step_per_epoch", int, 100, "每轮训练次数"),
        ("--learning_rate", float, 0.01, "学习率"),
        ("--content_loss_factor", float, 1.0, "内容损失总加权系数"),
        ("--style_loss_factor", float, 100.0, "风格损失总加权系数"),
        ("--img_size", int, 0, "图片尺寸,0代表不设置使用默认尺寸(450*300),输入1代表使用图片尺寸,其他输入代表使用自定义尺寸"),
        ("--img_width", int, 450, "自定义图片宽度"),
        ("--img_height", int, 300, "自定义图片高度"),
    )

    parser = argparse.ArgumentParser()
    for flag, arg_type, default_value, help_text in arg_specs:
        parser.add_argument(flag, type=arg_type, default=default_value, help=help_text)

    if known:
        return parser.parse_known_args()[0]
    return parser.parse_args()
27
+
28
def load_images(image_path, width, height):
    """
    Read a JPEG file and return it as a normalized batch tensor.

    The file is decoded with 3 channels, resized to ``[height, width]``,
    scaled by 1/255, passed through ``normalization`` and reshaped to
    ``[1, height, width, 3]``.
    """
    raw_bytes = tf.io.read_file(image_path)
    decoded = tf.image.decode_jpeg(raw_bytes, channels=3)
    resized = tf.image.resize(decoded, [height, width])
    normalized = normalization(resized / 255.0)
    return tf.reshape(normalized, [1, height, width, 3])
39
+
40
def load_images_from_list(image_array, width, height):
    """
    Convert an in-memory image array into a normalized batch tensor.

    The array is converted to a float32 tensor, resized to
    ``[height, width]``, scaled by 1/255, passed through ``normalization``
    and reshaped to ``[1, height, width, 3]``.
    """
    pixels = tf.convert_to_tensor(image_array, dtype=tf.float32)
    resized = tf.image.resize(pixels, [height, width])
    normalized = normalization(resized / 255.0)
    return tf.reshape(normalized, [1, height, width, 3])
50
+
51
def save_image(image, filename):
    """
    Undo the normalization of ``image`` and write it to ``filename`` as JPEG.

    The leading axis of ``image`` is dropped via a reshape to
    ``image.shape[1:]`` before the pixel values are restored.
    """
    pixels = tf.reshape(image, image.shape[1:])
    # Invert ``normalization`` and go back to the 0-255 range.
    pixels = (pixels * image_std + image_mean) * 255.0
    # Truncate to integers first, then clamp into the valid byte range.
    pixels = tf.clip_by_value(tf.cast(pixels, tf.int32), 0, 255)
    encoded = tf.image.encode_jpeg(tf.cast(pixels, tf.uint8))
    tf.io.write_file(filename, encoded)
63
+
64
def save_image_for_gradio(image):
    """
    Undo the normalization of ``image`` and return it as a uint8 numpy array.

    The leading axis of ``image`` is dropped via a reshape to
    ``image.shape[1:]`` before the pixel values are restored.
    """
    pixels = tf.reshape(image, image.shape[1:])
    # Invert ``normalization`` and go back to the 0-255 range.
    pixels = (pixels * image_std + image_mean) * 255.0
    # Truncate to integers first, then clamp into the valid byte range.
    pixels = tf.clip_by_value(tf.cast(pixels, tf.int32), 0, 255)
    # Convert the TensorFlow tensor into a numpy array.
    return tf.cast(pixels, tf.uint8).numpy()
76
+
77
def get_vgg19_model(layers):
    """
    Create a frozen VGG19 model that outputs the requested layers.

    ``layers`` is an iterable of VGG19 layer names; the returned Keras model
    maps the VGG19 input to the outputs of those layers, in the given order.
    """
    backbone = tf.keras.applications.VGG19(include_top=False, weights="imagenet")
    feature_maps = []
    for layer_name in layers:
        feature_maps.append(backbone.get_layer(layer_name).output)
    extractor = tf.keras.Model(backbone.input, feature_maps)
    extractor.trainable = False
    return extractor
86
+
87
class NeuralStyleTransferModel(tf.keras.Model):
    """Expose weighted content and style features taken from VGG19 layers."""

    def __init__(self, content_layers: typing.Dict[str, float], style_layers: typing.Dict[str, float]):
        """
        Store the layer-name -> weight mappings and build the VGG19 extractor.

        Content layers come first in the extractor's output list, followed by
        the style layers.
        """
        super().__init__()
        self.content_layers = content_layers
        self.style_layers = style_layers
        layers = [*content_layers, *style_layers]
        # Position of each layer's output in the list returned by the VGG model.
        self.outputs_index_map = {name: position for position, name in enumerate(layers)}
        self.vgg = get_vgg19_model(layers)

    def call(self, inputs, training=None, mask=None):
        """
        Run VGG19 on ``inputs`` and group the outputs by role.

        Returns a dict with keys ``"content"`` and ``"style"``, each a list of
        ``(feature, weight)`` tuples where ``feature`` is element 0 of the
        corresponding layer output.
        """
        features = self.vgg(inputs)

        def collect(layer_weights):
            # Pair element 0 of each layer output with that layer's weight.
            return [
                (features[self.outputs_index_map[name]][0], weight)
                for name, weight in layer_weights.items()
            ]

        return {
            "content": collect(self.content_layers),
            "style": collect(self.style_layers),
        }
105
+
106
def normalization(x):
    """
    Normalize ``x`` with the module-level ``image_mean`` and ``image_std``.

    Returns ``(x - image_mean) / image_std``.
    """
    centered = x - image_mean
    return centered / image_std
111
+
112
def _compute_content_loss(noise_features, target_features):
    """
    Content loss between two feature tensors of a single layer.

    The summed squared difference is divided by ``2 * M * N``, where ``M`` and
    ``N`` are module-level globals.
    """
    squared_error = tf.square(noise_features - target_features)
    scale = 2.0 * M * N
    return tf.reduce_sum(squared_error) / scale
120
+
121
def compute_content_loss(noise_content_features, target_content_features):
    """
    Weighted sum of the per-layer content losses.

    Both arguments are sequences of ``(feature, factor)`` tuples; the factor
    is taken from the noise-side tuple.
    """
    weighted_losses = [
        _compute_content_loss(noise_feature, target_feature) * factor
        for (noise_feature, factor), (target_feature, _) in zip(
            noise_content_features, target_content_features
        )
    ]
    return tf.reduce_sum(weighted_losses)
130
+
131
def gram_matrix(feature):
    """
    Gram matrix of a rank-3 feature tensor.

    The last axis is moved to the front, the remaining two axes are
    flattened, and the resulting matrix is multiplied by its transpose.
    """
    channels_first = tf.transpose(feature, perm=[2, 0, 1])
    flattened = tf.reshape(channels_first, (channels_first.shape[0], -1))
    return tf.matmul(flattened, tf.transpose(flattened))
138
+
139
def _compute_style_loss(noise_feature, target_feature):
    """
    Style loss between two feature tensors of a single layer.

    The summed squared difference of the two Gram matrices is divided by
    ``4 * M**2 * N**2``, where ``M`` and ``N`` are module-level globals.
    """
    gram_difference = gram_matrix(noise_feature) - gram_matrix(target_feature)
    scale = 4.0 * (M**2) * (N**2)
    return tf.reduce_sum(tf.square(gram_difference)) / scale
148
+
149
def compute_style_loss(noise_style_features, target_style_features):
    """
    Weighted sum of the per-layer style losses.

    Both arguments are sequences of ``(feature, factor)`` tuples; the factor
    is taken from the noise-side tuple.
    """
    weighted_losses = [
        _compute_style_loss(noise_feature, target_feature) * factor
        for (noise_feature, factor), (target_feature, _) in zip(
            noise_style_features, target_style_features
        )
    ]
    return tf.reduce_sum(weighted_losses)
158
+
159
def total_loss(noise_features, target_content_features, target_style_features):
    """
    Combine content and style losses into the total loss.

    ``noise_features`` is a dict with ``"content"`` and ``"style"`` entries.
    The two losses are weighted by the module-level ``CONTENT_LOSS_FACTOR``
    and ``STYLE_LOSS_FACTOR``.
    """
    weighted_content = (
        compute_content_loss(noise_features["content"], target_content_features)
        * CONTENT_LOSS_FACTOR
    )
    weighted_style = (
        compute_style_loss(noise_features["style"], target_style_features)
        * STYLE_LOSS_FACTOR
    )
    return weighted_content + weighted_style
166
+
167
@tf.function
def train_one_step(model, noise_image, optimizer, target_content_features, target_style_features):
    """
    Run a single optimization step on ``noise_image``.

    The total loss of the model's features for ``noise_image`` is
    differentiated with respect to ``noise_image`` itself, the optimizer
    applies that gradient, and the loss is returned.
    """
    with tf.GradientTape() as tape:
        loss = total_loss(model(noise_image), target_content_features, target_style_features)
    # The image is the trainable quantity here, not the model weights.
    image_gradient = tape.gradient(loss, noise_image)
    optimizer.apply_gradients([(image_gradient, noise_image)])
    return loss
178
+
179
def _as_image_array(image):
    """Return ``image`` as a pixel array, decoding it first when it is a file path."""
    if isinstance(image, (str, os.PathLike)):
        encoded = tf.io.read_file(os.fspath(image))
        # expand_animations=False guarantees a rank-3 (height, width, 3) result.
        return tf.io.decode_image(encoded, channels=3, expand_animations=False).numpy()
    return image


def main(content_img, style_img, epochs, step_per_epoch, learning_rate, content_loss_factor, style_loss_factor, img_size, img_width, img_height):
    """
    Run neural style transfer and return the stylized image.

    Args:
        content_img: Content image, either a file path or a pixel array
            (height, width, 3) with values in 0-255.
        style_img: Style image, either a file path or a pixel array.
        epochs: Number of epochs.
        step_per_epoch: Optimization steps per epoch.
        learning_rate: Learning rate of the Adam optimizer.
        content_loss_factor: Overall weight of the content loss.
        style_loss_factor: Overall weight of the style loss.
        img_size: Size mode. ``0`` or ``"default size"`` uses 450x300,
            ``1`` uses the content image's own size, anything else uses
            ``img_width`` x ``img_height``.
        img_width: Custom working width.
        img_height: Custom working height.

    Returns:
        The stylized image as a uint8 numpy array (see
        ``save_image_for_gradio``).

    Side effects:
        Sets the module-level globals read by the loss and image helpers
        (``M``, ``N``, ``image_mean``, ``image_std``, the loss factors, ...).
    """
    global CONTENT_LOSS_FACTOR, STYLE_LOSS_FACTOR, CONTENT_IMAGE_PATH, STYLE_IMAGE_PATH, OUTPUT_DIR, EPOCHS, LEARNING_RATE, STEPS_PER_EPOCH, M, N, image_mean, image_std, IMG_WIDTH, IMG_HEIGHT

    CONTENT_LOSS_FACTOR = content_loss_factor
    STYLE_LOSS_FACTOR = style_loss_factor
    CONTENT_IMAGE_PATH = content_img
    STYLE_IMAGE_PATH = style_img
    EPOCHS = epochs
    LEARNING_RATE = learning_rate
    STEPS_PER_EPOCH = step_per_epoch

    # Content feature layers and their loss weights.
    content_layers = {"block4_conv2": 0.5, "block5_conv2": 0.5}
    # Style feature layers and their loss weights.
    style_layers = {
        "block1_conv1": 0.2,
        "block2_conv1": 0.2,
        "block3_conv1": 0.2,
        "block4_conv1": 0.2,
        "block5_conv1": 0.2,
    }

    # The CLI hands over file paths while other callers pass pixel arrays;
    # bring both to pixel arrays so a single loading path handles them.
    content_pixels = _as_image_array(content_img)
    style_pixels = _as_image_array(style_img)

    # Honor both the integer modes documented by --img_size and the
    # "default size" string this function has always recognized.
    if img_size in (0, "default size"):
        IMG_WIDTH = 450
        IMG_HEIGHT = 300
    elif img_size == 1:
        # Pixel arrays are laid out as (height, width, channels).
        IMG_HEIGHT, IMG_WIDTH = (int(dim) for dim in np.shape(content_pixels)[:2])
    else:
        # int() because UI callers may deliver whole numbers as floats.
        IMG_WIDTH = int(img_width)
        IMG_HEIGHT = int(img_height)

    print("IMG_WIDTH:", IMG_WIDTH)
    print("IMG_HEIGHT:", IMG_HEIGHT)

    # ImageNet-pretrained weights are used, so normalize with the ImageNet
    # channel mean and standard deviation (std of the first channel is 0.229).
    image_mean = tf.constant([0.485, 0.456, 0.406])
    image_std = tf.constant([0.229, 0.224, 0.225])

    model = NeuralStyleTransferModel(content_layers, style_layers)

    content_image = load_images_from_list(content_pixels, IMG_WIDTH, IMG_HEIGHT)
    style_image = load_images_from_list(style_pixels, IMG_WIDTH, IMG_HEIGHT)

    # The targets are fixed for the whole run, so compute them once.
    target_content_features = model(content_image)["content"]
    target_style_features = model(style_image)["style"]

    # Normalization constants read by the content/style loss helpers.
    M = IMG_WIDTH * IMG_HEIGHT
    N = 3

    optimizer = tf.keras.optimizers.Adam(LEARNING_RATE)

    # Start from the content image blended with uniform noise; this variable
    # is the only thing being optimized.
    noise_image = tf.Variable((content_image[0] + np.random.uniform(-0.2, 0.2, (1, IMG_HEIGHT, IMG_WIDTH, 3))) / 2)

    for epoch in range(EPOCHS):
        with tqdm(total=STEPS_PER_EPOCH, desc="Epoch {}/{}".format(epoch + 1, EPOCHS)) as pbar:
            for step in range(STEPS_PER_EPOCH):
                _loss = train_one_step(model, noise_image, optimizer, target_content_features, target_style_features)
                pbar.set_postfix({"loss": "%.4f" % float(_loss)})
                pbar.update(1)

    return save_image_for_gradio(noise_image)
238
+
239
if __name__ == "__main__":
    opt = parse_opt()
    result = main(opt.content_img_path, opt.style_img_path, opt.epochs, opt.step_per_epoch, opt.learning_rate, opt.content_loss_factor, opt.style_loss_factor, opt.img_size, opt.img_width, opt.img_height)
    # main() only returns the stylized image as a uint8 array; write it out so
    # a command-line run leaves a file under --output_path instead of
    # discarding the result.
    os.makedirs(opt.output_path, exist_ok=True)
    tf.io.write_file(os.path.join(opt.output_path, "result.jpg"), tf.image.encode_jpeg(result))