Spaces:
Sleeping
Sleeping
File size: 5,430 Bytes
3d85088 62ef5f4 3d85088 04f5f0c 3d85088 04f5f0c 3d85088 04f5f0c 3d85088 04f5f0c 3d85088 04f5f0c 3d85088 04f5f0c 3d85088 04f5f0c 3d85088 04f5f0c 3d85088 04f5f0c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 |
import gradio as gr
from src.inference import SwinTExCo
import cv2
import os
from PIL import Image
import time
import app_config as cfg
import threading
# --- Module-level state shared by the Gradio handlers defined below. ---

# Cooperative stop flag: set to True by stop_process(); video_colorization
# checks it once per produced frame and resets it to False when it stops.
stop_thread = False

# Colorization model, constructed once at import time from the weights path
# configured in app_config (cfg.ckpt_path).
model = SwinTExCo(weights_path=cfg.ckpt_path)
def stop_process():
    """Ask a running video colorization loop to stop early.

    Raises the module-level ``stop_thread`` flag. ``video_colorization``
    checks this flag on every frame; when it sees it set, it stops writing
    frames and clears the flag again.
    """
    global stop_thread
    stop_thread = True
def _colorized_output_path(video_path):
    """Build a timestamped ``.mp4`` output path next to *video_path*.

    The result is ``<dir>/<stem>_colorized_<time_ns>.mp4``. ``os.path.splitext``
    removes only the final extension, so stems containing dots
    (e.g. ``clip.v2.mp4``) are preserved in full.
    """
    stem = os.path.splitext(os.path.basename(video_path))[0]
    return os.path.join(
        os.path.dirname(video_path), f"{stem}_colorized_{time.time_ns()}.mp4"
    )


def video_colorization(video_path, ref_image, progress=gr.Progress()):
    """Colorize a grayscale video using a color reference image.

    Frames produced by ``model.predict_video`` are converted RGB -> BGR and
    written to a new MP4 file next to the input video. Setting the
    module-level ``stop_thread`` flag (see ``stop_process``) stops the loop
    early. The flag is then reset, and the partially written file is still
    returned.

    Args:
        video_path: Path of the uploaded video as provided by ``gr.Video``,
            or ``None``/empty when nothing was uploaded.
        ref_image: Reference image in a form accepted by
            ``PIL.Image.fromarray`` (presumably ``gr.Image``'s default numpy
            output), or ``None`` when nothing was uploaded.
        progress: Gradio progress tracker, injected by the framework.

    Returns:
        Path of the colorized MP4 file. Returns ``None`` when an input is
        missing or the video cannot be opened; in that case a warning is
        shown to the user instead.
    """
    global stop_thread

    video_reader = cv2.VideoCapture(video_path) if video_path else None
    try:
        video_ok = video_reader is not None and video_reader.isOpened()
        # gr.Warning only shows a toast and does not abort the handler, so
        # report every problem first, then return explicitly rather than
        # crashing later (e.g. in Image.fromarray(None)).
        if not video_ok:
            gr.Warning("Please upload a valid video.")
        if ref_image is None:
            gr.Warning("Please upload a valid reference image.")
        if not video_ok or ref_image is None:
            return None

        fps = video_reader.get(cv2.CAP_PROP_FPS)
        height = int(video_reader.get(cv2.CAP_PROP_FRAME_HEIGHT))
        width = int(video_reader.get(cv2.CAP_PROP_FRAME_WIDTH))
        # The container's frame count may be an estimate (or 0), so it only
        # feeds the progress bar; it never limits how many frames are written.
        num_frames = int(video_reader.get(cv2.CAP_PROP_FRAME_COUNT))

        ref_pil = Image.fromarray(ref_image)

        output_path = _colorized_output_path(video_path)
        video_writer = cv2.VideoWriter(
            output_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (width, height)
        )
        try:
            frames = progress.tqdm(
                model.predict_video(video_reader, ref_pil),
                desc="Colorizing video",
                total=num_frames or None,
                unit="frames",
            )
            for colorized_frame in frames:
                if stop_thread:
                    # Honour the stop request once, then re-arm the flag.
                    stop_thread = False
                    break
                video_writer.write(cv2.cvtColor(colorized_frame, cv2.COLOR_RGB2BGR))
        finally:
            # Finalize the container even if prediction raises.
            video_writer.release()
        return output_path
    finally:
        # Calling release() more than once on a VideoCapture is harmless.
        if video_reader is not None:
            video_reader.release()
def image_colorization(image, ref_image):
    """Colorize a single grayscale image using a color reference image.

    Args:
        image: Input image in a form accepted by ``PIL.Image.fromarray``
            (presumably ``gr.Image``'s default numpy output), or ``None``
            when nothing was uploaded.
        ref_image: Reference image in the same form, or ``None``.

    Returns:
        The result of ``model.predict_image`` for the two PIL images. Returns
        ``None`` when either input is missing; in that case a warning is
        shown to the user instead.
    """
    # gr.Warning only shows a toast and does not stop execution, so report
    # every missing input and then return instead of crashing in fromarray.
    missing_input = False
    if image is None:
        gr.Warning("Please upload a valid image.")
        missing_input = True
    if ref_image is None:
        gr.Warning("Please upload a valid reference image.")
        missing_input = True
    if missing_input:
        return None

    return model.predict_image(Image.fromarray(image), Image.fromarray(ref_image))
# app = gr.Interface(
# fn=video_colorization,
# inputs=[gr.Video(format="mp4", sources="upload", label="Input video (grayscale)", interactive=True),
# gr.Image(sources="upload", label="Reference image (color)")],
# outputs=gr.Video(label="Output video (colorized)"),
# title=cfg.TITLE,
# description=cfg.DESCRIPTION,
# allow_flagging='never'
# )
with gr.Blocks() as app:
    # Page header text comes from app_config.
    gr.Markdown(cfg.CONTENT)

    # Video tab: grayscale video + color reference image -> colorized video.
    with gr.Tab("📹 Video"):
        with gr.Row():
            with gr.Column(scale=1):
                input_video_comp = gr.Video(
                    format="mp4",
                    sources="upload",
                    label="Input video (grayscale)",
                    interactive=True,
                )
                ref_image_comp = gr.Image(
                    sources="upload", label="Reference image (color)", height=300
                )
                with gr.Row():
                    with gr.Column(scale=1):
                        # Button `variant` takes a single string literal, not a list.
                        clear_btn = gr.ClearButton(value="Clear input", variant="secondary")
                        clear_btn.add([input_video_comp, ref_image_comp])
                    with gr.Column(scale=1):
                        start_btn = gr.Button(value="Start!", variant="primary")
            with gr.Column(scale=1):
                output_video_comp = gr.Video(label="Output video (colorized)")
                with gr.Row():
                    with gr.Column(scale=1):
                        clear_output_btn = gr.ClearButton(value="Clear output", variant="secondary")
                        clear_output_btn.add([output_video_comp])
                    with gr.Column(scale=1):
                        stop_btn = gr.Button(value="Stop!", variant="stop")
        start_event = start_btn.click(
            video_colorization,
            inputs=[input_video_comp, ref_image_comp],
            outputs=[output_video_comp],
        )
        # The stop button runs no function; it only cancels the start event.
        stop_btn.click(fn=None, cancels=[start_event])

    # Image tab: grayscale image + color reference image -> colorized image.
    # Components use tab-specific names so they do not rebind (shadow) the
    # video tab's components defined above.
    with gr.Tab("🖼️ Image"):
        with gr.Row():
            with gr.Column(scale=1):
                input_image_comp = gr.Image(
                    sources="upload", label="Input image (grayscale)", height=300
                )
                image_ref_comp = gr.Image(
                    sources="upload", label="Reference image (color)", height=300
                )
                with gr.Row():
                    with gr.Column(scale=1):
                        clear_input_btn = gr.ClearButton(value="Clear input", variant="secondary")
                        clear_input_btn.add([input_image_comp, image_ref_comp])
                    with gr.Column(scale=1):
                        image_start_btn = gr.Button(value="Start!", variant="primary")
            with gr.Column(scale=1):
                output_image_comp = gr.Image(label="Output image (colorized)", height=300)
                with gr.Row():
                    with gr.Column():
                        image_clear_output_btn = gr.ClearButton(value="Clear output", variant="secondary")
                        image_clear_output_btn.add([output_image_comp])
                    with gr.Column():
                        image_stop_btn = gr.Button(value="Stop!", variant="stop")
        image_start_event = image_start_btn.click(
            image_colorization,
            inputs=[input_image_comp, image_ref_comp],
            outputs=[output_image_comp],
        )
        image_stop_btn.click(fn=None, cancels=[image_start_event])

    # Footer text comes from app_config.
    gr.Markdown(cfg.APPENDIX)

# Credentials can be supplied via the APP_USERNAME / APP_PASSWORD environment
# variables; the previous hard-coded "admin"/"admin" pair remains the fallback.
# NOTE(review): that fallback is insecure for any public deployment.
app.launch(
    auth=(
        os.environ.get("APP_USERNAME", "admin"),
        os.environ.get("APP_PASSWORD", "admin"),
    )
)