StevenChen16 commited on
Commit
2340c45
·
verified ·
1 Parent(s): e721a5b

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +84 -0
app.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from train import main
3
+
4
+ # Function to calculate the aspect ratio
5
+ def calculate_aspect_ratio(width, height):
6
+ return width / height
7
+
8
+ # Function to update image dimensions based on aspect ratio
9
+ def update_dimensions(value, aspect_ratio, dim):
10
+ if dim == "width":
11
+ new_height = int(value / aspect_ratio)
12
+ return gr.update(value=value, visible=True), gr.update(value=new_height, visible=True)
13
+ elif dim == "height":
14
+ new_width = int(value * aspect_ratio)
15
+ return gr.update(value=new_width, visible=True), gr.update(value=value, visible=True)
16
+
17
+ # Function to handle image size selection
18
+ def handle_image_size_selection(img_size, content_img):
19
+ if img_size == "custom size":
20
+ height, width, _ = content_img.shape
21
+ aspect_ratio = calculate_aspect_ratio(width, height)
22
+ return (gr.update(value=width, visible=True),
23
+ gr.update(value=height, visible=True),
24
+ gr.update(visible=True),
25
+ gr.update(visible=True),
26
+ aspect_ratio)
27
+ else:
28
+ return (gr.update(value=450, visible=False),
29
+ gr.update(value=300, visible=False),
30
+ gr.update(visible=False),
31
+ gr.update(visible=False),
32
+ None)
33
+
34
+ # Define the function to process images
35
+ def process_images(content_img, style_img, epochs, steps_per_epoch, learning_rate, content_loss_factor, style_loss_factor, img_size, img_width, img_height):
36
+ print("Start processing")
37
+ output_img = main(content_img, style_img, epochs, steps_per_epoch, learning_rate, content_loss_factor, style_loss_factor, img_size, img_width, img_height)
38
+ return output_img
39
+
40
+ with gr.Blocks() as demo:
41
+ aspect_ratio = gr.State(None)
42
+ with gr.Row():
43
+ with gr.Column(scale=1):
44
+ content_img = gr.Image(type="numpy", label="Content Image")
45
+ style_img = gr.Image(type="numpy", label="Style Image")
46
+ process_button = gr.Button("Process")
47
+ with gr.Accordion("Parameters", open=False):
48
+ epochs = gr.Slider(minimum=1, maximum=100, step=1, label="Epochs", value=20)
49
+ steps_per_epoch = gr.Slider(minimum=1, maximum=1000, step=1, label="Steps per Epoch", value=100)
50
+ learning_rate = gr.Slider(minimum=0.0001, maximum=0.1, step=0.0001, label="Learning Rate", value=0.01)
51
+ content_loss_factor = gr.Slider(minimum=0.1, maximum=10, step=0.1, label="Content Loss Factor", value=1.0)
52
+ style_loss_factor = gr.Slider(minimum=0.1, maximum=1000, step=0.1, label="Style Loss Factor", value=100.0)
53
+ img_size = gr.Dropdown(choices=["default size", "custom size"], label="Image Size", value="default size")
54
+ img_width = gr.Number(label="Image Width", value=450, visible=False)
55
+ img_height = gr.Number(label="Image Height", value=300, visible=False)
56
+ with gr.Column(scale=1):
57
+ output_img = gr.Image(label="Output Image")
58
+
59
+ img_size.change(
60
+ fn=handle_image_size_selection,
61
+ inputs=[img_size, content_img],
62
+ outputs=[img_width, img_height, img_width, img_height, aspect_ratio]
63
+ )
64
+
65
+ img_width.change(
66
+ fn=lambda w, ar: update_dimensions(w, ar, "width"),
67
+ inputs=[img_width, aspect_ratio],
68
+ outputs=[img_width, img_height]
69
+ )
70
+
71
+ img_height.change(
72
+ fn=lambda h, ar: update_dimensions(h, ar, "height"),
73
+ inputs=[img_height, aspect_ratio],
74
+ outputs=[img_width, img_height]
75
+ )
76
+
77
+ process_button.click(
78
+ process_images,
79
+ inputs=[content_img, style_img, epochs, steps_per_epoch, learning_rate, content_loss_factor, style_loss_factor, img_size, img_width, img_height],
80
+ outputs=[output_img]
81
+ )
82
+
83
+ # Launch the app
84
+ demo.launch()