Update new app
This view is limited to 50 files because it contains too many changes. See the raw diff for the complete change set.
- .gitattributes +1 -20
- .gitignore +2 -1
- README.md +0 -6
- UI.py +0 -81
- app.py +47 -213
- app_config.py +9 -0
- checkpoints/{colornet.pth → epoch_10/colornet.pth} +1 -1
- examples/bear/video.mp4 → checkpoints/epoch_10/discriminator.pth +2 -2
- checkpoints/{embed_net.pth → epoch_10/embed_net.pth} +0 -0
- checkpoints/epoch_10/learning_state.pth +3 -0
- checkpoints/{nonlocal_net.pth → epoch_10/nonlocal_net.pth} +1 -1
- checkpoints/epoch_12/colornet.pth +3 -0
- examples/cows/video.mp4 → checkpoints/epoch_12/discriminator.pth +2 -2
- checkpoints/epoch_12/embed_net.pth +3 -0
- checkpoints/epoch_12/learning_state.pth +3 -0
- checkpoints/epoch_12/nonlocal_net.pth +3 -0
- checkpoints/epoch_16/colornet.pth +3 -0
- checkpoints/epoch_16/discriminator.pth +3 -0
- checkpoints/epoch_16/embed_net.pth +3 -0
- checkpoints/epoch_16/learning_state.pth +3 -0
- checkpoints/epoch_16/nonlocal_net.pth +3 -0
- checkpoints/epoch_20/colornet.pth +3 -0
- checkpoints/epoch_20/discriminator.pth +3 -0
- checkpoints/epoch_20/embed_net.pth +3 -0
- checkpoints/epoch_20/learning_state.pth +3 -0
- checkpoints/epoch_20/nonlocal_net.pth +3 -0
- cmd.txt +0 -21
- cmd_ddp.txt +0 -20
- environment.yml +0 -0
- examples/bear/ref.jpg +0 -0
- examples/boat/ref.jpg +0 -0
- examples/boat/video.mp4 +0 -0
- examples/cows/ref.jpg +0 -0
- examples/flamingo/ref.jpg +0 -0
- examples/flamingo/video.mp4 +0 -3
- examples/man/ref.jpg +0 -0
- examples/man/video.mp4 +0 -3
- examples/military/ref.jpg +0 -0
- examples/military/video.mp4 +0 -3
- gradio_cached_examples/13/log.csv +0 -5
- gradio_cached_examples/13/output/003c3114319372a78bf2f812ebaf0041afa280fb/output_video.mp4 +0 -3
- gradio_cached_examples/13/output/74c76e483235b7e80665e32d7fcdcc3da2be7644/output_video.mp4 +0 -0
- gradio_cached_examples/13/output/7969adca8ae38cb3b38ff8e7bb54688d942c7bc8/output_video.mp4 +0 -3
- gradio_cached_examples/13/output/e6d6153dedeb9fec586b3241311cc49dbc17bc85/output_video.mp4 +0 -0
- inputs/video.mp4/000000000.jpg +0 -0
- inputs/video.mp4/000000001.jpg +0 -0
- inputs/video.mp4/000000002.jpg +0 -0
- inputs/video.mp4/000000003.jpg +0 -0
- inputs/video.mp4/000000004.jpg +0 -0
- inputs/video.mp4/000000005.jpg +0 -0
.gitattributes
CHANGED
@@ -1,4 +1,3 @@
-*.7z filter=lfs diff=lfs merge=lfs -text
 *.arrow filter=lfs diff=lfs merge=lfs -text
 *.bin filter=lfs diff=lfs merge=lfs -text
 *.bz2 filter=lfs diff=lfs merge=lfs -text
@@ -33,22 +32,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
-
-EvalDataset/clips/bear/output_video_gray.mp4 filter=lfs diff=lfs merge=lfs -text
-EvalDataset/clips/boat/output_video_gray.mp4 filter=lfs diff=lfs merge=lfs -text
-EvalDataset/clips/cows/output_video.mp4 filter=lfs diff=lfs merge=lfs -text
-EvalDataset/clips/cows/output_video_gray.mp4 filter=lfs diff=lfs merge=lfs -text
-EvalDataset/clips/dog/output_video.mp4 filter=lfs diff=lfs merge=lfs -text
-EvalDataset/clips/flamingo/output_video_gray.mp4 filter=lfs diff=lfs merge=lfs -text
-EvalDataset/ref/goat/0000.jpg filter=lfs diff=lfs merge=lfs -text
-EvalDataset/ref/hockey/0000.jpg filter=lfs diff=lfs merge=lfs -text
-EvalDataset/ref/horsejump-high/0000.jpg filter=lfs diff=lfs merge=lfs -text
-EvalDataset/ref/motorbike/0000.jpg filter=lfs diff=lfs merge=lfs -text
-EvalDataset/ref/surf/0000.jpg filter=lfs diff=lfs merge=lfs -text
-examples/bear/video.mp4 filter=lfs diff=lfs merge=lfs -text
-examples/cows/video.mp4 filter=lfs diff=lfs merge=lfs -text
-examples/flamingo/video.mp4 filter=lfs diff=lfs merge=lfs -text
-gradio_cached_examples/13/output/003c3114319372a78bf2f812ebaf0041afa280fb/output_video.mp4 filter=lfs diff=lfs merge=lfs -text
-gradio_cached_examples/13/output/7969adca8ae38cb3b38ff8e7bb54688d942c7bc8/output_video.mp4 filter=lfs diff=lfs merge=lfs -text
-examples/man/video.mp4 filter=lfs diff=lfs merge=lfs -text
-examples/military/video.mp4 filter=lfs diff=lfs merge=lfs -text
+*.mp4 filter=lfs diff=lfs merge=lfs -text
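Note that the nineteen per-file LFS rules removed above are replaced by a single wildcard rule: with *.mp4 tracked as a pattern, any MP4 added to the repository goes through Git LFS automatically. A rule in this form is what the command git lfs track "*.mp4" appends to .gitattributes.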
.gitignore
CHANGED
@@ -1,4 +1,5 @@
-
+flagged/
+sample_output/
 wandb/
 .vscode
 .DS_Store
README.md
DELETED
@@ -1,6 +0,0 @@
----
-title: ViTExCo
-app_file: app.py
-sdk: gradio
-sdk_version: 3.40.1
----
UI.py
DELETED
@@ -1,81 +0,0 @@
-import streamlit as st
-from PIL import Image
-import torchvision.transforms as transforms
-from streamlit_image_comparison import image_comparison
-import numpy as np
-import torch
-import torchvision
-
-######################################### Utils ########################################
-video_extensions = ["mp4"]
-image_extensions = ["png", "jpg"]
-
-
-def check_type(file_name: str):
-    for image_extension in image_extensions:
-        if file_name.endswith(image_extension):
-            return "image"
-    for video_extension in video_extensions:
-        if file_name.endswith(video_extension):
-            return "video"
-    return None
-
-
-transform = transforms.Compose(
-    [transforms.Resize((256, 256)), transforms.ToTensor(), transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))]
-)
-
-
-###################################### Load model ######################################
-@st.cache_resource
-def load_model():
-    model = torchvision.models.segmentation.deeplabv3_resnet101(pretrained=True)
-    model.eval()
-    return model
-
-
-model = load_model()
-########################################## UI ##########################################
-st.title("Colorization")
-
-uploaded_file = st.file_uploader("Upload grayscale image or video", type=image_extensions + video_extensions)
-if uploaded_file:
-    # Image
-    if check_type(file_name=uploaded_file.name) == "image":
-        image = np.array(Image.open(uploaded_file), dtype=np.float32)
-
-        input_tensor = torchvision.transforms.functional.normalize(
-            torch.tensor(image).permute(2, 0, 1),
-            mean=[0.485, 0.456, 0.406],
-            std=[0.229, 0.224, 0.225],
-        ).unsqueeze(0)
-        process_button = st.button("Process")
-        if process_button:
-            with st.spinner("Từ từ coi..."):
-                prediction = model(input_tensor)
-                segment = prediction["out"][0].permute(1, 2, 0)
-                segment = segment.detach().numpy()
-
-                st.image(segment)
-                st.image(image)
-
-                image_comparison(
-                    img1=image,
-                    img2=np.array(segment),
-                    label1="Grayscale",
-                    label2="Colorized",
-                    make_responsive=True,
-                    show_labels=True,
-                )
-    # Video
-    else:
-        # video = open(uploaded_file.name)
-        st.video("https://youtu.be/dQw4w9WgXcQ")
-
-hide_menu_style = """
-        <style>
-        #MainMenu {visibility: hidden; }
-        footer {visibility: hidden;}
-        </style>
-        """
-st.markdown(hide_menu_style, unsafe_allow_html=True)
app.py
CHANGED
@@ -1,215 +1,49 @@
-import …
-import …
+import gradio as gr
+from src.inference import SwinTExCo
+import cv2
 import os
-import argparse
-import torch
-import glob
-from tqdm import tqdm
 from PIL import Image
+import time
+import app_config as cfg
[… old lines 9-26 were also removed but are not rendered in this view …]
-)
[… old lines 28-52 were also removed but are not rendered in this view …]
-    predicted_rgb = np.transpose(predicted_rgb, (1,2,0))
-    pil_img = Image.fromarray(predicted_rgb)
-    pil_img.save(os.path.join(OUTPUT_RESULT_PATH, video_name, frame_name))
-
-def extract_frames_from_video(video_path):
-    cap = cv2.VideoCapture(video_path)
-    fps = cap.get(cv2.CAP_PROP_FPS)
-
-    # remove if exists folder
-    output_frames_path = os.path.join(INPUT_VIDEO_FRAMES_PATH, os.path.basename(video_path))
-    if os.path.exists(output_frames_path):
-        shutil.rmtree(output_frames_path)
-
-    # make new folder
-    os.makedirs(output_frames_path)
-
-    currentframe = 0
-    frame_path_list = []
-    while(True):
-
-        # reading from frame
-        ret,frame = cap.read()
-
-        if ret:
-            name = os.path.join(output_frames_path, f'{currentframe:09d}.jpg')
-            frame_path_list.append(name)
-            cv2.imwrite(name, frame)
-            currentframe += 1
-        else:
-            break
-
-    cap.release()
-    cv2.destroyAllWindows()
-
-    return frame_path_list, fps
-
-def combine_frames_from_folder(frames_list_path, fps = 30):
-    frames_list = glob.glob(f'{frames_list_path}/*.jpg')
-    frames_list.sort()
-
-    sample_shape = cv2.imread(frames_list[0]).shape
-
-    output_video_path = os.path.join(frames_list_path, 'output_video.mp4')
-    out = cv2.VideoWriter(output_video_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (sample_shape[1], sample_shape[0]))
-    for filename in frames_list:
-        img = cv2.imread(filename)
-        out.write(img)
-
-    out.release()
-    return output_video_path
-
-
-def upscale_image(I_current_rgb, I_current_ab_predict):
-    H, W = I_current_rgb.size
-    high_lab_transforms = [
-        SquaredPadding(target_size=max(H,W)),
-        RGB2Lab(),
-        ToTensor(),
-        Normalize()
-    ]
-    # current_frame_pil_rgb = Image.fromarray(np.clip(I_current_rgb.squeeze(0).permute(1,2,0).cpu().numpy() * 255, 0, 255).astype('uint8'))
-    high_lab_current, paddings = custom_transform(high_lab_transforms, I_current_rgb)
-    high_lab_current = torch.unsqueeze(high_lab_current,dim=0).to(device)
-    high_l_current = high_lab_current[:, 0:1, :, :]
-    high_ab_current = high_lab_current[:, 1:3, :, :]
-    upsampler = torch.nn.Upsample(scale_factor=max(H,W)/224,mode="bilinear")
-    high_ab_predict = upsampler(I_current_ab_predict)
-    I_predict_rgb = tensor_lab2rgb(torch.cat((uncenter_l(high_l_current), high_ab_predict), dim=1))
-    upadded = UnpaddingSquare()
-    I_predict_rgb = upadded(I_predict_rgb, paddings)
-    return I_predict_rgb
-
-def colorize_video(video_path, ref_np):
-    frames_list, fps = extract_frames_from_video(video_path)
-
-    frame_ref = Image.fromarray(ref_np).convert("RGB")
-    I_last_lab_predict = None
-    IB_lab, IB_paddings = custom_transform(transforms, frame_ref)
-    IB_lab = IB_lab.unsqueeze(0).to(device)
-    IB_l = IB_lab[:, 0:1, :, :]
-    IB_ab = IB_lab[:, 1:3, :, :]
-
-    with torch.no_grad():
-        I_reference_lab = IB_lab
-        I_reference_l = I_reference_lab[:, 0:1, :, :]
-        I_reference_ab = I_reference_lab[:, 1:3, :, :]
-        I_reference_rgb = tensor_lab2rgb(torch.cat((uncenter_l(I_reference_l), I_reference_ab), dim=1)).to(device)
-        features_B = embed_net(I_reference_rgb)
-
-    video_path_parts = frames_list[0].split(os.sep)
-
-    if os.path.exists(os.path.join(OUTPUT_RESULT_PATH, video_path_parts[-2])):
-        shutil.rmtree(os.path.join(OUTPUT_RESULT_PATH, video_path_parts[-2]))
-    os.makedirs(os.path.join(OUTPUT_RESULT_PATH, video_path_parts[-2]), exist_ok=True)
-
-    for frame_path in tqdm(frames_list):
-        curr_frame = Image.open(frame_path).convert("RGB")
-        IA_lab, IA_paddings = custom_transform(transforms, curr_frame)
-        IA_lab = IA_lab.unsqueeze(0).to(device)
-        IA_l = IA_lab[:, 0:1, :, :]
-        IA_ab = IA_lab[:, 1:3, :, :]
-
-        if I_last_lab_predict is None:
-            I_last_lab_predict = torch.zeros_like(IA_lab).to(device)
-
-        with torch.no_grad():
-            I_current_lab = IA_lab
-            I_current_ab_predict, _ = frame_colorization(
-                IA_l,
-                I_reference_lab,
-                I_last_lab_predict,
-                features_B,
-                embed_net,
-                nonlocal_net,
-                colornet,
-                luminance_noise=0,
-                temperature=1e-10,
-                joint_training=False
-            )
-            I_last_lab_predict = torch.cat((IA_l, I_current_ab_predict), dim=1)
-
-        # IA_predict_rgb = tensor_lab2rgb(torch.cat((uncenter_l(IA_l), I_current_ab_predict), dim=1))
-        IA_predict_rgb = upscale_image(curr_frame, I_current_ab_predict)
-        #IA_predict_rgb = torch.nn.functional.upsample_bilinear(IA_predict_rgb, scale_factor=2)
-        save_frames(IA_predict_rgb.squeeze(0).cpu().numpy() * 255, video_path_parts[-2], os.path.basename(frame_path))
-    return combine_frames_from_folder(os.path.join(OUTPUT_RESULT_PATH, video_path_parts[-2]), fps)
-
-if __name__ == '__main__':
-    # Init global variables
-    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-    INPUT_VIDEO_FRAMES_PATH = 'inputs'
-    OUTPUT_RESULT_PATH = 'outputs'
-    weight_path = 'checkpoints'
-
-    embed_net=GeneralEmbedModel(pretrained_model="swin-tiny", device=device).to(device)
-    nonlocal_net = GeneralWarpNet(feature_channel=128).to(device)
-    colornet=GeneralColorVidNet(7).to(device)
-
-    embed_net.eval()
-    nonlocal_net.eval()
-    colornet.eval()
-
-    # Load weights
-    # embed_net_params = load_params(os.path.join(weight_path, "embed_net.pth"))
-    nonlocal_net_params = load_params(os.path.join(weight_path, "nonlocal_net.pth"))
-    colornet_params = load_params(os.path.join(weight_path, "colornet.pth"))
-
-    # embed_net.load_state_dict(embed_net_params, strict=True)
-    nonlocal_net.load_state_dict(nonlocal_net_params, strict=True)
-    colornet.load_state_dict(colornet_params, strict=True)
-
-    transforms = [SquaredPadding(target_size=224),
-                  RGB2Lab(),
-                  ToTensor(),
-                  Normalize()]
-
-    #examples = [[vid, ref] for vid, ref in zip(sorted(glob.glob('examples/*/*.mp4')), sorted(glob.glob('examples/*/*.jpg')))]
-    demo = gr.Interface(colorize_video,
-                        inputs=[gr.Video(), gr.Image()],
-                        outputs="playable_video")#,
-                        #examples=examples,
-                        #cache_examples=True)
-    demo.launch()
+
+
+model = SwinTExCo(weights_path=cfg.ckpt_path)
+
+def video_colorization(video_path, ref_image, progress=gr.Progress()):
+    # Initialize video reader
+    video_reader = cv2.VideoCapture(video_path)
+    fps = video_reader.get(cv2.CAP_PROP_FPS)
+    height = int(video_reader.get(cv2.CAP_PROP_FRAME_HEIGHT))
+    width = int(video_reader.get(cv2.CAP_PROP_FRAME_WIDTH))
+    num_frames = int(video_reader.get(cv2.CAP_PROP_FRAME_COUNT))
+
+    # Initialize reference image
+    ref_image = Image.fromarray(ref_image)
+
+    # Initialize video writer
+    output_path = os.path.join(os.path.dirname(video_path), os.path.basename(video_path).split('.')[0] + '_colorized.mp4')
+    video_writer = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
+
+    # Init progress bar
+
+    for colorized_frame, _ in zip(model.predict_video(video_reader, ref_image), progress.tqdm(range(num_frames), desc="Colorizing video", unit="frames")):
+        video_writer.write(colorized_frame)
+
+    # for i in progress.tqdm(range(1000)):
+    #     time.sleep(0.5)
+
+    video_writer.release()
+
+    return output_path
+
+app = gr.Interface(
+    fn=video_colorization,
+    inputs=[gr.Video(format="mp4", sources="upload", label="Input video (grayscale)", interactive=True),
+            gr.Image(sources="upload", label="Reference image (color)")],
+    outputs=gr.Video(label="Output video (colorized)"),
+    title=cfg.TITLE,
+    description=cfg.DESCRIPTION
+).queue()
+
+
+app.launch()
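The rewritten app.py replaces the removed frame-extraction and recombination pipeline with a single call into the SwinTExCo wrapper. Below is a rough standalone sketch of the same loop outside Gradio, using only the API visible in the diff above (SwinTExCo(weights_path=...) and predict_video(video_reader, ref_image) yielding frames that cv2.VideoWriter accepts); the input and output file names are hypothetical.

# Minimal sketch of the colorization loop without the Gradio UI.
# Assumes the SwinTExCo API shown in app.py above; 'input_gray.mp4',
# 'reference.jpg' and 'output_colorized.mp4' are example paths only.
import cv2
from PIL import Image
from src.inference import SwinTExCo

model = SwinTExCo(weights_path='checkpoints/epoch_20')

reader = cv2.VideoCapture('input_gray.mp4')
fps = reader.get(cv2.CAP_PROP_FPS)
width = int(reader.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(reader.get(cv2.CAP_PROP_FRAME_HEIGHT))

writer = cv2.VideoWriter('output_colorized.mp4',
                         cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))

# Exemplar image that supplies the colors, as a PIL image
# (app.py builds it with Image.fromarray from the Gradio input).
ref = Image.open('reference.jpg').convert('RGB')

# predict_video() is consumed as a generator of colorized frames,
# exactly as in video_colorization() above.
for frame in model.predict_video(reader, ref):
    writer.write(frame)

writer.release()
reader.release()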
app_config.py
ADDED
@@ -0,0 +1,9 @@
+ckpt_path = 'checkpoints/epoch_20'
+TITLE = 'Deep Exemplar-based Video Colorization using Vision Transformer'
+DESCRIPTION = '''
+<center>
+This is a demo app of the thesis: <b>Deep Exemplar-based Video Colorization using Vision Transformer</b>.<br/>
+The code is available at: <i>The link will be updated soon</i>.<br/>
+Our previous work was also written into paper and accepted at the <a href="https://ictc.org/program_proceeding">ICTC 2023 conference</a> (Section <i>B1-4</i>).
+</center>
+'''.strip()
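app.py above imports this module as cfg: ckpt_path points the SwinTExCo wrapper at the epoch_20 checkpoints added below, while TITLE and DESCRIPTION are passed straight to gr.Interface, with the DESCRIPTION string written as HTML for Gradio to render in the app header.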
checkpoints/{colornet.pth → epoch_10/colornet.pth}
RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:1ecb43b5e02b77bec5342e2e296d336bf8f384a07d3c809d1a548fd5fb1e7365
 size 131239411

examples/bear/video.mp4 → checkpoints/epoch_10/discriminator.pth
RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:ce8968a9d3d2f99b1bc1e32080507e0d671cee00b66200105c8839be684b84b4
+size 45073068

checkpoints/{embed_net.pth → epoch_10/embed_net.pth}
RENAMED
File without changes

checkpoints/epoch_10/learning_state.pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8d09b1e96fdf0205930a21928449a44c51cedd965cc0d573068c73971bcb8bd2
+size 748166487

checkpoints/{nonlocal_net.pth → epoch_10/nonlocal_net.pth}
RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:86c97d6803d625a0dff8c6c09b70852371906eb5ef77df0277c27875666a68e2
 size 73189765

checkpoints/epoch_12/colornet.pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:50f4b92cd59f4c88c0c1d7c93652413d54b1b96d729fc4b93e235887b5164f28
+size 131239846

examples/cows/video.mp4 → checkpoints/epoch_12/discriminator.pth
RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:2b54b0bad6ceec33569cc5833cbf03ed8ddbb5f07998aa634badf8298d3cd15f
+size 45073513

checkpoints/epoch_12/embed_net.pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:73e2a156c0737e3d063af0e95e1e7176362e85120b88275a1aa02dcf488e1865
+size 110352698

checkpoints/epoch_12/learning_state.pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8f8bb4dbb3cb8e497a9a2079947f0221823fa8b44695e2d2ad8478be48464fad
+size 748166934

checkpoints/epoch_12/nonlocal_net.pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c1f76b53dad7bf15c7d26aa106c95387e75751b8c31fafef2bd73ea7d77160cb
+size 73190208

checkpoints/epoch_16/colornet.pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:81ec9cff0ad5b0d920179fa7a9cc229e1424bfc796b7134604ff66b97d748c49
+size 131239846

checkpoints/epoch_16/discriminator.pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:42262d5ed7596f38e65774085222530eee57da8dfaa7fe1aa223d824ed166f62
+size 45073513

checkpoints/epoch_16/embed_net.pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:73e2a156c0737e3d063af0e95e1e7176362e85120b88275a1aa02dcf488e1865
+size 110352698

checkpoints/epoch_16/learning_state.pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ea4cf81341750ebf517c696a0f6241bfeede0584b0ce75ad208e3ffc8280877f
+size 748166934

checkpoints/epoch_16/nonlocal_net.pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:85b63363bc9c79732df78ba50ed19491ed86e961214bbd1f796a871334eba516
+size 73190208

checkpoints/epoch_20/colornet.pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c524f4e5df5f6ce91db1973a30de55299ebcbbde1edd2009718d3b4cd2631339
+size 131239846

checkpoints/epoch_20/discriminator.pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fcd80950c796fcfe6e4b6bdeeb358776700458d868da94ee31df3d1d37779310
+size 45073513

checkpoints/epoch_20/embed_net.pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:73e2a156c0737e3d063af0e95e1e7176362e85120b88275a1aa02dcf488e1865
+size 110352698

checkpoints/epoch_20/learning_state.pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4b1163b210b246b07d8f1c50eb3766d97c6f03bf409c854d00b7c69edb6d7391
+size 748166934

checkpoints/epoch_20/nonlocal_net.pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:031e5f38cc79eb3c0ed51ca2ad3c8921fdda2fa05946c357f84881259de74e6d
+size 73190208
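Every .pth entry above is a Git LFS pointer (version, oid, size), so the actual weights are only present locally after a git lfs pull. Below is a small, purely illustrative check that the epoch_20 directory referenced by app_config.ckpt_path contains loadable files rather than 3-line pointers; the file names come from the diff, everything else is an assumption.

# Sanity check: confirm the LFS-tracked checkpoints are real weight files,
# not checked-out pointer stubs. Paths come from the diff; the check itself
# is only a sketch.
import os
import torch

CKPT_DIR = 'checkpoints/epoch_20'
FILES = ['colornet.pth', 'discriminator.pth', 'embed_net.pth',
         'learning_state.pth', 'nonlocal_net.pth']

for name in FILES:
    path = os.path.join(CKPT_DIR, name)
    size = os.path.getsize(path)
    if size < 1024:
        # A file of a few hundred bytes means the LFS pointer was checked out
        # instead of the object itself; run `git lfs pull` first.
        raise RuntimeError(f'{path} looks like an LFS pointer ({size} bytes)')
    # Depending on the PyTorch version, full training states such as
    # learning_state.pth may need torch.load(..., weights_only=False).
    state = torch.load(path, map_location='cpu')
    print(name, type(state).__name__)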
cmd.txt
DELETED
@@ -1,21 +0,0 @@
-python train.py --video_data_root_list datasets/images/images \
-    --flow_data_root_list datasets/flow_fp16/flow_fp16 \
-    --mask_data_root_list datasets/pgm/pgm \
-    --data_root_imagenet datasets/imgnet \
-    --annotation_file_path datasets/final_annot.csv \
-    --imagenet_pairs_file datasets/pairs.txt \
-    --gpu_ids 0 \
-    --workers 12 \
-    --batch_size 2 \
-    --real_reference_probability 0.99 \
-    --weight_contextual 1 \
-    --weight_perceptual 0.1 \
-    --weight_smoothness 5 \
-    --weight_gan 0.9 \
-    --weight_consistent 0.1 \
-    --use_wandb True \
-    --wandb_token "f05d31e6b15339b1cfc5ee1c77fe51f66fc3ea9e" \
-    --wandb_name "vit_tiny_patch16_384_nofeat" \
-    --checkpoint_step 500 \
-    --epoch_train_discriminator 3 \
-    --epoch 20
cmd_ddp.txt
DELETED
@@ -1,20 +0,0 @@
-!torchrun --nnodes=1 --nproc_per_node=2 train_ddp.py --video_data_root_list $video_data_root_list \
-    --flow_data_root_list $flow_data_root_list \
-    --mask_data_root_list $mask_data_root_list \
-    --data_root_imagenet $data_root_imagenet \
-    --annotation_file_path $annotation_file_path \
-    --imagenet_pairs_file $imagenet_pairs_file \
-    --gpu_ids "0,1" \
-    --workers 2 \
-    --batch_size 2 \
-    --real_reference_probability 0.99 \
-    --weight_contextual 1 \
-    --weight_perceptual 0.1 \
-    --weight_smoothness 5 \
-    --weight_gan 0.9 \
-    --weight_consistent 0.1 \
-    --wandb_token "165e7148081f263b423722115e2ad40fa5339ecf" \
-    --wandb_name "vit_tiny_patch16_384_nofeat" \
-    --checkpoint_step 2000 \
-    --epoch_train_discriminator 2 \
-    --epoch 10
environment.yml
DELETED
File without changes

examples/bear/ref.jpg
DELETED
Binary file (30.9 kB)

examples/boat/ref.jpg
DELETED
Binary file (65.4 kB)

examples/boat/video.mp4
DELETED
Binary file (853 kB)

examples/cows/ref.jpg
DELETED
Binary file (252 kB)

examples/flamingo/ref.jpg
DELETED
Binary file (539 kB)

examples/flamingo/video.mp4
DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:5a103fd4991a00e419e5236b885fe9d220704ba0a6ac794c87aaa3f62a4f1561
-size 1239570

examples/man/ref.jpg
DELETED
Binary file (176 kB)

examples/man/video.mp4
DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:dff54d74e38285d60e064a0332c66d3ca2860f3c05de814a63693a9c331e94c9
-size 1693420

examples/military/ref.jpg
DELETED
Binary file (111 kB)

examples/military/video.mp4
DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:02ce1717c2f5768af588a0bbeb47c659e54a310a880a52b68a8a7701647e145a
-size 1495376

gradio_cached_examples/13/log.csv
DELETED
@@ -1,5 +0,0 @@
-output,flag,username,timestamp
-/content/ViTExCo/gradio_cached_examples/13/output/003c3114319372a78bf2f812ebaf0041afa280fb/output_video.mp4,,,2023-08-15 09:45:37.897615
-/content/ViTExCo/gradio_cached_examples/13/output/e6d6153dedeb9fec586b3241311cc49dbc17bc85/output_video.mp4,,,2023-08-15 09:46:01.048997
-/content/ViTExCo/gradio_cached_examples/13/output/7969adca8ae38cb3b38ff8e7bb54688d942c7bc8/output_video.mp4,,,2023-08-15 09:46:34.503322
-/content/ViTExCo/gradio_cached_examples/13/output/74c76e483235b7e80665e32d7fcdcc3da2be7644/output_video.mp4,,,2023-08-15 09:46:58.088903

gradio_cached_examples/13/output/003c3114319372a78bf2f812ebaf0041afa280fb/output_video.mp4
DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:b5ab666998e14fb00281a90f8801753eca001a432641ae2770007a8336b4c64e
-size 1213824

gradio_cached_examples/13/output/74c76e483235b7e80665e32d7fcdcc3da2be7644/output_video.mp4
DELETED
Binary file (914 kB)

gradio_cached_examples/13/output/7969adca8ae38cb3b38ff8e7bb54688d942c7bc8/output_video.mp4
DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:7c367dab34e596f7f0fed34c7e2384525de2ba1824b410d0770bdbd17bc9e72a
-size 1793060

gradio_cached_examples/13/output/e6d6153dedeb9fec586b3241311cc49dbc17bc85/output_video.mp4
DELETED
Binary file (673 kB)

inputs/video.mp4/000000000.jpg
DELETED
Binary file (113 kB)

inputs/video.mp4/000000001.jpg
DELETED
Binary file (146 kB)

inputs/video.mp4/000000002.jpg
DELETED
Binary file (143 kB)

inputs/video.mp4/000000003.jpg
DELETED
Binary file (141 kB)

inputs/video.mp4/000000004.jpg
DELETED
Binary file (142 kB)

inputs/video.mp4/000000005.jpg
DELETED
Binary file (141 kB)