duongttr commited on
Commit
04f5f0c
·
1 Parent(s): da5e78e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +99 -17
app.py CHANGED
@@ -5,11 +5,20 @@ import os
5
  from PIL import Image
6
  import time
7
  import app_config as cfg
 
8
 
9
 
10
  model = SwinTExCo(weights_path=cfg.ckpt_path)
11
 
 
 
 
 
 
 
12
  def video_colorization(video_path, ref_image, progress=gr.Progress()):
 
 
13
  # Initialize video reader
14
  video_reader = cv2.VideoCapture(video_path)
15
  fps = video_reader.get(cv2.CAP_PROP_FPS)
@@ -17,34 +26,107 @@ def video_colorization(video_path, ref_image, progress=gr.Progress()):
17
  width = int(video_reader.get(cv2.CAP_PROP_FRAME_WIDTH))
18
  num_frames = int(video_reader.get(cv2.CAP_PROP_FRAME_COUNT))
19
 
 
 
 
 
 
20
  # Initialize reference image
21
  ref_image = Image.fromarray(ref_image)
22
 
23
  # Initialize video writer
24
- output_path = os.path.join(os.path.dirname(video_path), os.path.basename(video_path).split('.')[0] + '_colorized.mp4')
25
  video_writer = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
26
 
27
- # Init progress bar
28
-
29
  for colorized_frame, _ in zip(model.predict_video(video_reader, ref_image), progress.tqdm(range(num_frames), desc="Colorizing video", unit="frames")):
30
- colorized_frame = cv2.cvtColor(colorized_frame, cv2.COLOR_RGB2BGR)
31
- video_writer.write(colorized_frame)
32
-
33
- # for i in progress.tqdm(range(1000)):
34
- # time.sleep(0.5)
 
35
 
36
  video_writer.release()
37
 
38
  return output_path
39
 
40
- app = gr.Interface(
41
- fn=video_colorization,
42
- inputs=[gr.Video(format="mp4", sources="upload", label="Input video (grayscale)", interactive=True),
43
- gr.Image(sources="upload", label="Reference image (color)")],
44
- outputs=gr.Video(label="Output video (colorized)"),
45
- title=cfg.TITLE,
46
- description=cfg.DESCRIPTION
47
- ).queue()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
 
 
49
 
50
- app.launch()
 
5
  from PIL import Image
6
  import time
7
  import app_config as cfg
8
+ import threading
9
 
10
 
11
  model = SwinTExCo(weights_path=cfg.ckpt_path)
12
 
13
+ stop_thread = False
14
+
15
+ def stop_process():
16
+ global stop_thread
17
+ stop_thread = True
18
+
19
  def video_colorization(video_path, ref_image, progress=gr.Progress()):
20
+ global stop_thread
21
+
22
  # Initialize video reader
23
  video_reader = cv2.VideoCapture(video_path)
24
  fps = video_reader.get(cv2.CAP_PROP_FPS)
 
26
  width = int(video_reader.get(cv2.CAP_PROP_FRAME_WIDTH))
27
  num_frames = int(video_reader.get(cv2.CAP_PROP_FRAME_COUNT))
28
 
29
+ if not video_reader.isOpened():
30
+ gr.Warning("Please upload a valid video.")
31
+ if ref_image is None:
32
+ gr.Warning("Please upload a valid reference image.")
33
+
34
  # Initialize reference image
35
  ref_image = Image.fromarray(ref_image)
36
 
37
  # Initialize video writer
38
+ output_path = os.path.join(os.path.dirname(video_path), os.path.basename(video_path).split('.')[0] + f'_colorized_{time.time_ns()}.mp4')
39
  video_writer = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
40
 
 
 
41
  for colorized_frame, _ in zip(model.predict_video(video_reader, ref_image), progress.tqdm(range(num_frames), desc="Colorizing video", unit="frames")):
42
+ if stop_thread:
43
+ stop_thread = False
44
+ break
45
+ else:
46
+ colorized_frame = cv2.cvtColor(colorized_frame, cv2.COLOR_RGB2BGR)
47
+ video_writer.write(colorized_frame)
48
 
49
  video_writer.release()
50
 
51
  return output_path
52
 
53
+ def image_colorization(image, ref_image):
54
+ if image is None:
55
+ gr.Warning("Please upload a valid image.")
56
+ if ref_image is None:
57
+ gr.Warning("Please upload a valid reference image.")
58
+
59
+ # Initialize image
60
+ image = Image.fromarray(image)
61
+ ref_image = Image.fromarray(ref_image)
62
+
63
+ colorized_image = model.predict_image(image, ref_image)
64
+
65
+ return colorized_image
66
+
67
+ # app = gr.Interface(
68
+ # fn=video_colorization,
69
+ # inputs=[gr.Video(format="mp4", sources="upload", label="Input video (grayscale)", interactive=True),
70
+ # gr.Image(sources="upload", label="Reference image (color)")],
71
+ # outputs=gr.Video(label="Output video (colorized)"),
72
+ # title=cfg.TITLE,
73
+ # description=cfg.DESCRIPTION,
74
+ # allow_flagging='never'
75
+ # )
76
+
77
+ with gr.Blocks() as app:
78
+ # Title
79
+ gr.Markdown(cfg.CONTENT)
80
+
81
+ # Video tab
82
+ with gr.Tab("📹 Video"):
83
+ with gr.Row():
84
+ with gr.Column(scale=1):
85
+ input_video_comp = gr.Video(format="mp4", sources="upload", label="Input video (grayscale)", interactive=True)
86
+ ref_image_comp = gr.Image(sources="upload", label="Reference image (color)", height=300)
87
+ with gr.Row():
88
+ with gr.Column(scale=1):
89
+ clear_btn = gr.ClearButton(value="Clear input", variant=['secondary'])
90
+ clear_btn.add([input_video_comp, ref_image_comp])
91
+ with gr.Column(scale=1):
92
+ start_btn = gr.Button(value="Start!", variant=['primary'])
93
+ with gr.Column(scale=1):
94
+ output_video_comp = gr.Video(label="Output video (colorized)")
95
+ with gr.Row():
96
+ with gr.Column(scale=1):
97
+ clear_output_btn = gr.ClearButton(value="Clear output", variant=['secondary'])
98
+ clear_output_btn.add([output_video_comp])
99
+ with gr.Column(scale=1):
100
+ stop_btn = gr.Button(value="Stop!", variant=['stop'])
101
+
102
+ start_event = start_btn.click(video_colorization, inputs=[input_video_comp, ref_image_comp], outputs=[output_video_comp])
103
+ stop_btn.click(fn=None, cancels=[start_event])
104
+
105
+ # Image tab
106
+ with gr.Tab("🖼️ Image"):
107
+ with gr.Row():
108
+ with gr.Column(scale=1):
109
+ input_image_comp = gr.Image(sources="upload", label="Input image (grayscale)", height=300)
110
+ ref_image_comp = gr.Image(sources="upload", label="Reference image (color)", height=300)
111
+ with gr.Row():
112
+ with gr.Column(scale=1):
113
+ clear_input_btn = gr.ClearButton(value="Clear input", variant=['secondary'])
114
+ clear_input_btn.add([input_image_comp, ref_image_comp])
115
+ with gr.Column(scale=1):
116
+ start_btn = gr.Button(value="Start!", variant=['primary'])
117
+ with gr.Column(scale=1):
118
+ output_image_comp = gr.Image(label="Output image (colorized)", height=300)
119
+
120
+ with gr.Row():
121
+ with gr.Column():
122
+ clear_output_btn = gr.ClearButton(value="Clear output", variant=['secondary'])
123
+ clear_output_btn.add([output_image_comp])
124
+ with gr.Column():
125
+ stop_btn = gr.Button(value="Stop!", variant=['stop'])
126
+
127
+ start_event = start_btn.click(image_colorization, inputs=[input_image_comp, ref_image_comp], outputs=[output_image_comp])
128
+ stop_btn.click(fn=None, cancels=[start_event])
129
 
130
+ gr.Markdown(cfg.APPENDIX)
131
 
132
+ app.launch(auth=('admin', 'admin'))