duongttr committed
Commit 3d85088 · 1 Parent(s): efb56b8

Update new app

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the complete changeset.
Files changed (50)
  1. .gitattributes +1 -20
  2. .gitignore +2 -1
  3. README.md +0 -6
  4. UI.py +0 -81
  5. app.py +47 -213
  6. app_config.py +9 -0
  7. checkpoints/{colornet.pth → epoch_10/colornet.pth} +1 -1
  8. examples/bear/video.mp4 → checkpoints/epoch_10/discriminator.pth +2 -2
  9. checkpoints/{embed_net.pth → epoch_10/embed_net.pth} +0 -0
  10. checkpoints/epoch_10/learning_state.pth +3 -0
  11. checkpoints/{nonlocal_net.pth → epoch_10/nonlocal_net.pth} +1 -1
  12. checkpoints/epoch_12/colornet.pth +3 -0
  13. examples/cows/video.mp4 → checkpoints/epoch_12/discriminator.pth +2 -2
  14. checkpoints/epoch_12/embed_net.pth +3 -0
  15. checkpoints/epoch_12/learning_state.pth +3 -0
  16. checkpoints/epoch_12/nonlocal_net.pth +3 -0
  17. checkpoints/epoch_16/colornet.pth +3 -0
  18. checkpoints/epoch_16/discriminator.pth +3 -0
  19. checkpoints/epoch_16/embed_net.pth +3 -0
  20. checkpoints/epoch_16/learning_state.pth +3 -0
  21. checkpoints/epoch_16/nonlocal_net.pth +3 -0
  22. checkpoints/epoch_20/colornet.pth +3 -0
  23. checkpoints/epoch_20/discriminator.pth +3 -0
  24. checkpoints/epoch_20/embed_net.pth +3 -0
  25. checkpoints/epoch_20/learning_state.pth +3 -0
  26. checkpoints/epoch_20/nonlocal_net.pth +3 -0
  27. cmd.txt +0 -21
  28. cmd_ddp.txt +0 -20
  29. environment.yml +0 -0
  30. examples/bear/ref.jpg +0 -0
  31. examples/boat/ref.jpg +0 -0
  32. examples/boat/video.mp4 +0 -0
  33. examples/cows/ref.jpg +0 -0
  34. examples/flamingo/ref.jpg +0 -0
  35. examples/flamingo/video.mp4 +0 -3
  36. examples/man/ref.jpg +0 -0
  37. examples/man/video.mp4 +0 -3
  38. examples/military/ref.jpg +0 -0
  39. examples/military/video.mp4 +0 -3
  40. gradio_cached_examples/13/log.csv +0 -5
  41. gradio_cached_examples/13/output/003c3114319372a78bf2f812ebaf0041afa280fb/output_video.mp4 +0 -3
  42. gradio_cached_examples/13/output/74c76e483235b7e80665e32d7fcdcc3da2be7644/output_video.mp4 +0 -0
  43. gradio_cached_examples/13/output/7969adca8ae38cb3b38ff8e7bb54688d942c7bc8/output_video.mp4 +0 -3
  44. gradio_cached_examples/13/output/e6d6153dedeb9fec586b3241311cc49dbc17bc85/output_video.mp4 +0 -0
  45. inputs/video.mp4/000000000.jpg +0 -0
  46. inputs/video.mp4/000000001.jpg +0 -0
  47. inputs/video.mp4/000000002.jpg +0 -0
  48. inputs/video.mp4/000000003.jpg +0 -0
  49. inputs/video.mp4/000000004.jpg +0 -0
  50. inputs/video.mp4/000000005.jpg +0 -0
.gitattributes CHANGED
@@ -1,4 +1,3 @@
- *.7z filter=lfs diff=lfs merge=lfs -text
  *.arrow filter=lfs diff=lfs merge=lfs -text
  *.bin filter=lfs diff=lfs merge=lfs -text
  *.bz2 filter=lfs diff=lfs merge=lfs -text
@@ -33,22 +32,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
- EvalDataset/clips/bear/output_video.mp4 filter=lfs diff=lfs merge=lfs -text
- EvalDataset/clips/bear/output_video_gray.mp4 filter=lfs diff=lfs merge=lfs -text
- EvalDataset/clips/boat/output_video_gray.mp4 filter=lfs diff=lfs merge=lfs -text
- EvalDataset/clips/cows/output_video.mp4 filter=lfs diff=lfs merge=lfs -text
- EvalDataset/clips/cows/output_video_gray.mp4 filter=lfs diff=lfs merge=lfs -text
- EvalDataset/clips/dog/output_video.mp4 filter=lfs diff=lfs merge=lfs -text
- EvalDataset/clips/flamingo/output_video_gray.mp4 filter=lfs diff=lfs merge=lfs -text
- EvalDataset/ref/goat/0000.jpg filter=lfs diff=lfs merge=lfs -text
- EvalDataset/ref/hockey/0000.jpg filter=lfs diff=lfs merge=lfs -text
- EvalDataset/ref/horsejump-high/0000.jpg filter=lfs diff=lfs merge=lfs -text
- EvalDataset/ref/motorbike/0000.jpg filter=lfs diff=lfs merge=lfs -text
- EvalDataset/ref/surf/0000.jpg filter=lfs diff=lfs merge=lfs -text
- examples/bear/video.mp4 filter=lfs diff=lfs merge=lfs -text
- examples/cows/video.mp4 filter=lfs diff=lfs merge=lfs -text
- examples/flamingo/video.mp4 filter=lfs diff=lfs merge=lfs -text
- gradio_cached_examples/13/output/003c3114319372a78bf2f812ebaf0041afa280fb/output_video.mp4 filter=lfs diff=lfs merge=lfs -text
- gradio_cached_examples/13/output/7969adca8ae38cb3b38ff8e7bb54688d942c7bc8/output_video.mp4 filter=lfs diff=lfs merge=lfs -text
- examples/man/video.mp4 filter=lfs diff=lfs merge=lfs -text
- examples/military/video.mp4 filter=lfs diff=lfs merge=lfs -text
+ *.mp4 filter=lfs diff=lfs merge=lfs -text
.gitignore CHANGED
@@ -1,4 +1,5 @@
- checkpoints/
+ flagged/
+ sample_output/
  wandb/
  .vscode
  .DS_Store
README.md DELETED
@@ -1,6 +0,0 @@
- ---
- title: ViTExCo
- app_file: app.py
- sdk: gradio
- sdk_version: 3.40.1
- ---
UI.py DELETED
@@ -1,81 +0,0 @@
- import streamlit as st
- from PIL import Image
- import torchvision.transforms as transforms
- from streamlit_image_comparison import image_comparison
- import numpy as np
- import torch
- import torchvision
-
- ######################################### Utils ########################################
- video_extensions = ["mp4"]
- image_extensions = ["png", "jpg"]
-
-
- def check_type(file_name: str):
-     for image_extension in image_extensions:
-         if file_name.endswith(image_extension):
-             return "image"
-     for video_extension in video_extensions:
-         if file_name.endswith(video_extension):
-             return "video"
-     return None
-
-
- transform = transforms.Compose(
-     [transforms.Resize((256, 256)), transforms.ToTensor(), transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))]
- )
-
-
- ###################################### Load model ######################################
- @st.cache_resource
- def load_model():
-     model = torchvision.models.segmentation.deeplabv3_resnet101(pretrained=True)
-     model.eval()
-     return model
-
-
- model = load_model()
- ########################################## UI ##########################################
- st.title("Colorization")
-
- uploaded_file = st.file_uploader("Upload grayscale image or video", type=image_extensions + video_extensions)
- if uploaded_file:
-     # Image
-     if check_type(file_name=uploaded_file.name) == "image":
-         image = np.array(Image.open(uploaded_file), dtype=np.float32)
-
-         input_tensor = torchvision.transforms.functional.normalize(
-             torch.tensor(image).permute(2, 0, 1),
-             mean=[0.485, 0.456, 0.406],
-             std=[0.229, 0.224, 0.225],
-         ).unsqueeze(0)
-         process_button = st.button("Process")
-         if process_button:
-             with st.spinner("Từ từ coi..."):
-                 prediction = model(input_tensor)
-                 segment = prediction["out"][0].permute(1, 2, 0)
-                 segment = segment.detach().numpy()
-
-                 st.image(segment)
-                 st.image(image)
-
-                 image_comparison(
-                     img1=image,
-                     img2=np.array(segment),
-                     label1="Grayscale",
-                     label2="Colorized",
-                     make_responsive=True,
-                     show_labels=True,
-                 )
-     # Video
-     else:
-         # video = open(uploaded_file.name)
-         st.video("https://youtu.be/dQw4w9WgXcQ")
-
- hide_menu_style = """
-     <style>
-     #MainMenu {visibility: hidden; }
-     footer {visibility: hidden;}
-     </style>
- """
- st.markdown(hide_menu_style, unsafe_allow_html=True)
app.py CHANGED
@@ -1,215 +1,49 @@
- import numpy as np
- import shutil
+ import gradio as gr
+ from src.inference import SwinTExCo
+ import cv2
  import os
- import argparse
- import torch
- import glob
- from tqdm import tqdm
  from PIL import Image
- from collections import OrderedDict
- from src.models.vit.config import load_config
- import torchvision.transforms as transforms
- import cv2
- from skimage import io
-
- from src.models.CNN.ColorVidNet import GeneralColorVidNet
- from src.models.vit.embed import GeneralEmbedModel
- from src.models.CNN.NonlocalNet import GeneralWarpNet
- from src.models.CNN.FrameColor import frame_colorization
- from src.utils import (
-     RGB2Lab,
-     ToTensor,
-     Normalize,
-     uncenter_l,
-     tensor_lab2rgb,
-     SquaredPadding,
-     UnpaddingSquare
- )
-
- import gradio as gr
-
- def load_params(ckpt_file):
-     params = torch.load(ckpt_file, map_location=device)
-     new_params = []
-     for key, value in params.items():
-         new_params.append((key, value))
-     return OrderedDict(new_params)
-
- def custom_transform(transforms, img):
-     for transform in transforms:
-         if isinstance(transform, SquaredPadding):
-             img,padding=transform(img, return_paddings=True)
-         else:
-             img = transform(img)
-     return img.to(device), padding
-
- def save_frames(predicted_rgb, video_name, frame_name):
-     if predicted_rgb is not None:
-         predicted_rgb = np.clip(predicted_rgb, 0, 255).astype(np.uint8)
-         # frame_path_parts = frame_path.split(os.sep)
-         # if os.path.exists(os.path.join(OUTPUT_RESULT_PATH, frame_path_parts[-2])):
-         #     shutil.rmtree(os.path.join(OUTPUT_RESULT_PATH, frame_path_parts[-2]))
-         # os.makedirs(os.path.join(OUTPUT_RESULT_PATH, frame_path_parts[-2]), exist_ok=True)
-         predicted_rgb = np.transpose(predicted_rgb, (1,2,0))
-         pil_img = Image.fromarray(predicted_rgb)
-         pil_img.save(os.path.join(OUTPUT_RESULT_PATH, video_name, frame_name))
-
- def extract_frames_from_video(video_path):
-     cap = cv2.VideoCapture(video_path)
-     fps = cap.get(cv2.CAP_PROP_FPS)
-
-     # remove if exists folder
-     output_frames_path = os.path.join(INPUT_VIDEO_FRAMES_PATH, os.path.basename(video_path))
-     if os.path.exists(output_frames_path):
-         shutil.rmtree(output_frames_path)
-
-     # make new folder
-     os.makedirs(output_frames_path)
-
-     currentframe = 0
-     frame_path_list = []
-     while(True):
-
-         # reading from frame
-         ret,frame = cap.read()
-
-         if ret:
-             name = os.path.join(output_frames_path, f'{currentframe:09d}.jpg')
-             frame_path_list.append(name)
-             cv2.imwrite(name, frame)
-             currentframe += 1
-         else:
-             break
-
-     cap.release()
-     cv2.destroyAllWindows()
-
-     return frame_path_list, fps
-
- def combine_frames_from_folder(frames_list_path, fps = 30):
-     frames_list = glob.glob(f'{frames_list_path}/*.jpg')
-     frames_list.sort()
-
-     sample_shape = cv2.imread(frames_list[0]).shape
-
-     output_video_path = os.path.join(frames_list_path, 'output_video.mp4')
-     out = cv2.VideoWriter(output_video_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (sample_shape[1], sample_shape[0]))
-     for filename in frames_list:
-         img = cv2.imread(filename)
-         out.write(img)
-
-     out.release()
-     return output_video_path
-
-
- def upscale_image(I_current_rgb, I_current_ab_predict):
-     H, W = I_current_rgb.size
-     high_lab_transforms = [
-         SquaredPadding(target_size=max(H,W)),
-         RGB2Lab(),
-         ToTensor(),
-         Normalize()
-     ]
-     # current_frame_pil_rgb = Image.fromarray(np.clip(I_current_rgb.squeeze(0).permute(1,2,0).cpu().numpy() * 255, 0, 255).astype('uint8'))
-     high_lab_current, paddings = custom_transform(high_lab_transforms, I_current_rgb)
-     high_lab_current = torch.unsqueeze(high_lab_current,dim=0).to(device)
-     high_l_current = high_lab_current[:, 0:1, :, :]
-     high_ab_current = high_lab_current[:, 1:3, :, :]
-     upsampler = torch.nn.Upsample(scale_factor=max(H,W)/224,mode="bilinear")
-     high_ab_predict = upsampler(I_current_ab_predict)
-     I_predict_rgb = tensor_lab2rgb(torch.cat((uncenter_l(high_l_current), high_ab_predict), dim=1))
-     upadded = UnpaddingSquare()
-     I_predict_rgb = upadded(I_predict_rgb, paddings)
-     return I_predict_rgb
-
- def colorize_video(video_path, ref_np):
-     frames_list, fps = extract_frames_from_video(video_path)
-
-     frame_ref = Image.fromarray(ref_np).convert("RGB")
-     I_last_lab_predict = None
-     IB_lab, IB_paddings = custom_transform(transforms, frame_ref)
-     IB_lab = IB_lab.unsqueeze(0).to(device)
-     IB_l = IB_lab[:, 0:1, :, :]
-     IB_ab = IB_lab[:, 1:3, :, :]
-
-     with torch.no_grad():
-         I_reference_lab = IB_lab
-         I_reference_l = I_reference_lab[:, 0:1, :, :]
-         I_reference_ab = I_reference_lab[:, 1:3, :, :]
-         I_reference_rgb = tensor_lab2rgb(torch.cat((uncenter_l(I_reference_l), I_reference_ab), dim=1)).to(device)
-         features_B = embed_net(I_reference_rgb)
-
-     video_path_parts = frames_list[0].split(os.sep)
-
-     if os.path.exists(os.path.join(OUTPUT_RESULT_PATH, video_path_parts[-2])):
-         shutil.rmtree(os.path.join(OUTPUT_RESULT_PATH, video_path_parts[-2]))
-     os.makedirs(os.path.join(OUTPUT_RESULT_PATH, video_path_parts[-2]), exist_ok=True)
-
-     for frame_path in tqdm(frames_list):
-         curr_frame = Image.open(frame_path).convert("RGB")
-         IA_lab, IA_paddings = custom_transform(transforms, curr_frame)
-         IA_lab = IA_lab.unsqueeze(0).to(device)
-         IA_l = IA_lab[:, 0:1, :, :]
-         IA_ab = IA_lab[:, 1:3, :, :]
-
-         if I_last_lab_predict is None:
-             I_last_lab_predict = torch.zeros_like(IA_lab).to(device)
-
-         with torch.no_grad():
-             I_current_lab = IA_lab
-             I_current_ab_predict, _ = frame_colorization(
-                 IA_l,
-                 I_reference_lab,
-                 I_last_lab_predict,
-                 features_B,
-                 embed_net,
-                 nonlocal_net,
-                 colornet,
-                 luminance_noise=0,
-                 temperature=1e-10,
-                 joint_training=False
-             )
-             I_last_lab_predict = torch.cat((IA_l, I_current_ab_predict), dim=1)
-
-         # IA_predict_rgb = tensor_lab2rgb(torch.cat((uncenter_l(IA_l), I_current_ab_predict), dim=1))
-         IA_predict_rgb = upscale_image(curr_frame, I_current_ab_predict)
-         #IA_predict_rgb = torch.nn.functional.upsample_bilinear(IA_predict_rgb, scale_factor=2)
-         save_frames(IA_predict_rgb.squeeze(0).cpu().numpy() * 255, video_path_parts[-2], os.path.basename(frame_path))
-     return combine_frames_from_folder(os.path.join(OUTPUT_RESULT_PATH, video_path_parts[-2]), fps)
-
- if __name__ == '__main__':
-     # Init global variables
-     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-     INPUT_VIDEO_FRAMES_PATH = 'inputs'
-     OUTPUT_RESULT_PATH = 'outputs'
-     weight_path = 'checkpoints'
-
-     embed_net=GeneralEmbedModel(pretrained_model="swin-tiny", device=device).to(device)
-     nonlocal_net = GeneralWarpNet(feature_channel=128).to(device)
-     colornet=GeneralColorVidNet(7).to(device)
-
-     embed_net.eval()
-     nonlocal_net.eval()
-     colornet.eval()
-
-     # Load weights
-     # embed_net_params = load_params(os.path.join(weight_path, "embed_net.pth"))
-     nonlocal_net_params = load_params(os.path.join(weight_path, "nonlocal_net.pth"))
-     colornet_params = load_params(os.path.join(weight_path, "colornet.pth"))
-
-     # embed_net.load_state_dict(embed_net_params, strict=True)
-     nonlocal_net.load_state_dict(nonlocal_net_params, strict=True)
-     colornet.load_state_dict(colornet_params, strict=True)
-
-     transforms = [SquaredPadding(target_size=224),
-                   RGB2Lab(),
-                   ToTensor(),
-                   Normalize()]
-
-     #examples = [[vid, ref] for vid, ref in zip(sorted(glob.glob('examples/*/*.mp4')), sorted(glob.glob('examples/*/*.jpg')))]
-     demo = gr.Interface(colorize_video,
-                         inputs=[gr.Video(), gr.Image()],
-                         outputs="playable_video")#,
-                         #examples=examples,
-                         #cache_examples=True)
-     demo.launch()
+ import time
+ import app_config as cfg
+
+
+ model = SwinTExCo(weights_path=cfg.ckpt_path)
+
+ def video_colorization(video_path, ref_image, progress=gr.Progress()):
+     # Initialize video reader
+     video_reader = cv2.VideoCapture(video_path)
+     fps = video_reader.get(cv2.CAP_PROP_FPS)
+     height = int(video_reader.get(cv2.CAP_PROP_FRAME_HEIGHT))
+     width = int(video_reader.get(cv2.CAP_PROP_FRAME_WIDTH))
+     num_frames = int(video_reader.get(cv2.CAP_PROP_FRAME_COUNT))
+
+     # Initialize reference image
+     ref_image = Image.fromarray(ref_image)
+
+     # Initialize video writer
+     output_path = os.path.join(os.path.dirname(video_path), os.path.basename(video_path).split('.')[0] + '_colorized.mp4')
+     video_writer = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
+
+     # Init progress bar
+
+     for colorized_frame, _ in zip(model.predict_video(video_reader, ref_image), progress.tqdm(range(num_frames), desc="Colorizing video", unit="frames")):
+         video_writer.write(colorized_frame)
+
+     # for i in progress.tqdm(range(1000)):
+     #     time.sleep(0.5)
+
+     video_writer.release()
+
+     return output_path
+
+ app = gr.Interface(
+     fn=video_colorization,
+     inputs=[gr.Video(format="mp4", sources="upload", label="Input video (grayscale)", interactive=True),
+             gr.Image(sources="upload", label="Reference image (color)")],
+     outputs=gr.Video(label="Output video (colorized)"),
+     title=cfg.TITLE,
+     description=cfg.DESCRIPTION
+ ).queue()
+
+
+ app.launch()
app_config.py ADDED
@@ -0,0 +1,9 @@
+ ckpt_path = 'checkpoints/epoch_20'
+ TITLE = 'Deep Exemplar-based Video Colorization using Vision Transformer'
+ DESCRIPTION = '''
+ <center>
+ This is a demo app of the thesis: <b>Deep Exemplar-based Video Colorization using Vision Transformer</b>.<br/>
+ The code is available at: <i>The link will be updated soon</i>.<br/>
+ Our previous work was also written into paper and accepted at the <a href="https://ictc.org/program_proceeding">ICTC 2023 conference</a> (Section <i>B1-4</i>).
+ </center>
+ '''.strip()
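Note: the new app.py drives colorization entirely through the src.inference.SwinTExCo wrapper configured by app_config.py. Below is a minimal headless sketch of the same call pattern, outside Gradio. It assumes, as the handler above implies, that SwinTExCo.predict_video takes an open cv2.VideoCapture plus a PIL reference image and yields frames ready for cv2.VideoWriter.write; the input and output file names are hypothetical.

import cv2
from PIL import Image
import app_config as cfg
from src.inference import SwinTExCo

# Load the exemplar-based colorization model from the configured checkpoint directory.
model = SwinTExCo(weights_path=cfg.ckpt_path)

reader = cv2.VideoCapture('input_gray.mp4')        # hypothetical grayscale input
ref = Image.open('reference.jpg').convert('RGB')   # hypothetical color reference

fps = reader.get(cv2.CAP_PROP_FPS)
size = (int(reader.get(cv2.CAP_PROP_FRAME_WIDTH)),
        int(reader.get(cv2.CAP_PROP_FRAME_HEIGHT)))
writer = cv2.VideoWriter('output_colorized.mp4', cv2.VideoWriter_fourcc(*'mp4v'), fps, size)

# Assumed to yield one colorized frame per input frame, in order.
for frame in model.predict_video(reader, ref):
    writer.write(frame)

writer.release()
reader.release()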
checkpoints/{colornet.pth → epoch_10/colornet.pth} RENAMED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:5257ae325e292cd5fb2eff47095e1c4e4815455bd5fb6dc5ed2ee2b923172875
+ oid sha256:1ecb43b5e02b77bec5342e2e296d336bf8f384a07d3c809d1a548fd5fb1e7365
  size 131239411
examples/bear/video.mp4 → checkpoints/epoch_10/discriminator.pth RENAMED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:cb4cec5064873a4616f78bdb653830683a4842b2a5cfd0665b395cff4d120d04
- size 1263445
+ oid sha256:ce8968a9d3d2f99b1bc1e32080507e0d671cee00b66200105c8839be684b84b4
+ size 45073068
checkpoints/{embed_net.pth → epoch_10/embed_net.pth} RENAMED
File without changes
checkpoints/epoch_10/learning_state.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:8d09b1e96fdf0205930a21928449a44c51cedd965cc0d573068c73971bcb8bd2
+ size 748166487
checkpoints/{nonlocal_net.pth → epoch_10/nonlocal_net.pth} RENAMED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:b94c6990f20088bc3cc3fe0b29a6d52e6e746b915c506f0cd349fc6ad6197e72
+ oid sha256:86c97d6803d625a0dff8c6c09b70852371906eb5ef77df0277c27875666a68e2
  size 73189765
checkpoints/epoch_12/colornet.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:50f4b92cd59f4c88c0c1d7c93652413d54b1b96d729fc4b93e235887b5164f28
+ size 131239846
examples/cows/video.mp4 → checkpoints/epoch_12/discriminator.pth RENAMED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:1ac08603d719cd7a8d71fac76c9318d3e8f1e516e9b3c2a06323a0e4e78f6410
- size 2745681
+ oid sha256:2b54b0bad6ceec33569cc5833cbf03ed8ddbb5f07998aa634badf8298d3cd15f
+ size 45073513
checkpoints/epoch_12/embed_net.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:73e2a156c0737e3d063af0e95e1e7176362e85120b88275a1aa02dcf488e1865
+ size 110352698
checkpoints/epoch_12/learning_state.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:8f8bb4dbb3cb8e497a9a2079947f0221823fa8b44695e2d2ad8478be48464fad
+ size 748166934
checkpoints/epoch_12/nonlocal_net.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c1f76b53dad7bf15c7d26aa106c95387e75751b8c31fafef2bd73ea7d77160cb
+ size 73190208
checkpoints/epoch_16/colornet.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:81ec9cff0ad5b0d920179fa7a9cc229e1424bfc796b7134604ff66b97d748c49
+ size 131239846
checkpoints/epoch_16/discriminator.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:42262d5ed7596f38e65774085222530eee57da8dfaa7fe1aa223d824ed166f62
+ size 45073513
checkpoints/epoch_16/embed_net.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:73e2a156c0737e3d063af0e95e1e7176362e85120b88275a1aa02dcf488e1865
+ size 110352698
checkpoints/epoch_16/learning_state.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ea4cf81341750ebf517c696a0f6241bfeede0584b0ce75ad208e3ffc8280877f
+ size 748166934
checkpoints/epoch_16/nonlocal_net.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:85b63363bc9c79732df78ba50ed19491ed86e961214bbd1f796a871334eba516
+ size 73190208
checkpoints/epoch_20/colornet.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c524f4e5df5f6ce91db1973a30de55299ebcbbde1edd2009718d3b4cd2631339
+ size 131239846
checkpoints/epoch_20/discriminator.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:fcd80950c796fcfe6e4b6bdeeb358776700458d868da94ee31df3d1d37779310
+ size 45073513
checkpoints/epoch_20/embed_net.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:73e2a156c0737e3d063af0e95e1e7176362e85120b88275a1aa02dcf488e1865
+ size 110352698
checkpoints/epoch_20/learning_state.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4b1163b210b246b07d8f1c50eb3766d97c6f03bf409c854d00b7c69edb6d7391
+ size 748166934
checkpoints/epoch_20/nonlocal_net.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:031e5f38cc79eb3c0ed51ca2ad3c8921fdda2fa05946c357f84881259de74e6d
+ size 73190208
cmd.txt DELETED
@@ -1,21 +0,0 @@
- python train.py --video_data_root_list datasets/images/images \
- --flow_data_root_list datasets/flow_fp16/flow_fp16 \
- --mask_data_root_list datasets/pgm/pgm \
- --data_root_imagenet datasets/imgnet \
- --annotation_file_path datasets/final_annot.csv \
- --imagenet_pairs_file datasets/pairs.txt \
- --gpu_ids 0 \
- --workers 12 \
- --batch_size 2 \
- --real_reference_probability 0.99 \
- --weight_contextual 1 \
- --weight_perceptual 0.1 \
- --weight_smoothness 5 \
- --weight_gan 0.9 \
- --weight_consistent 0.1 \
- --use_wandb True \
- --wandb_token "f05d31e6b15339b1cfc5ee1c77fe51f66fc3ea9e" \
- --wandb_name "vit_tiny_patch16_384_nofeat" \
- --checkpoint_step 500 \
- --epoch_train_discriminator 3 \
- --epoch 20
cmd_ddp.txt DELETED
@@ -1,20 +0,0 @@
- !torchrun --nnodes=1 --nproc_per_node=2 train_ddp.py --video_data_root_list $video_data_root_list \
- --flow_data_root_list $flow_data_root_list \
- --mask_data_root_list $mask_data_root_list \
- --data_root_imagenet $data_root_imagenet \
- --annotation_file_path $annotation_file_path \
- --imagenet_pairs_file $imagenet_pairs_file \
- --gpu_ids "0,1" \
- --workers 2 \
- --batch_size 2 \
- --real_reference_probability 0.99 \
- --weight_contextual 1 \
- --weight_perceptual 0.1 \
- --weight_smoothness 5 \
- --weight_gan 0.9 \
- --weight_consistent 0.1 \
- --wandb_token "165e7148081f263b423722115e2ad40fa5339ecf" \
- --wandb_name "vit_tiny_patch16_384_nofeat" \
- --checkpoint_step 2000 \
- --epoch_train_discriminator 2 \
- --epoch 10
environment.yml DELETED
File without changes
examples/bear/ref.jpg DELETED
Binary file (30.9 kB)
 
examples/boat/ref.jpg DELETED
Binary file (65.4 kB)
 
examples/boat/video.mp4 DELETED
Binary file (853 kB)
 
examples/cows/ref.jpg DELETED
Binary file (252 kB)
 
examples/flamingo/ref.jpg DELETED
Binary file (539 kB)
 
examples/flamingo/video.mp4 DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:5a103fd4991a00e419e5236b885fe9d220704ba0a6ac794c87aaa3f62a4f1561
- size 1239570
examples/man/ref.jpg DELETED
Binary file (176 kB)
 
examples/man/video.mp4 DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:dff54d74e38285d60e064a0332c66d3ca2860f3c05de814a63693a9c331e94c9
- size 1693420
examples/military/ref.jpg DELETED
Binary file (111 kB)
 
examples/military/video.mp4 DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:02ce1717c2f5768af588a0bbeb47c659e54a310a880a52b68a8a7701647e145a
- size 1495376
gradio_cached_examples/13/log.csv DELETED
@@ -1,5 +0,0 @@
- output,flag,username,timestamp
- /content/ViTExCo/gradio_cached_examples/13/output/003c3114319372a78bf2f812ebaf0041afa280fb/output_video.mp4,,,2023-08-15 09:45:37.897615
- /content/ViTExCo/gradio_cached_examples/13/output/e6d6153dedeb9fec586b3241311cc49dbc17bc85/output_video.mp4,,,2023-08-15 09:46:01.048997
- /content/ViTExCo/gradio_cached_examples/13/output/7969adca8ae38cb3b38ff8e7bb54688d942c7bc8/output_video.mp4,,,2023-08-15 09:46:34.503322
- /content/ViTExCo/gradio_cached_examples/13/output/74c76e483235b7e80665e32d7fcdcc3da2be7644/output_video.mp4,,,2023-08-15 09:46:58.088903
gradio_cached_examples/13/output/003c3114319372a78bf2f812ebaf0041afa280fb/output_video.mp4 DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:b5ab666998e14fb00281a90f8801753eca001a432641ae2770007a8336b4c64e
- size 1213824
gradio_cached_examples/13/output/74c76e483235b7e80665e32d7fcdcc3da2be7644/output_video.mp4 DELETED
Binary file (914 kB)
 
gradio_cached_examples/13/output/7969adca8ae38cb3b38ff8e7bb54688d942c7bc8/output_video.mp4 DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:7c367dab34e596f7f0fed34c7e2384525de2ba1824b410d0770bdbd17bc9e72a
- size 1793060
gradio_cached_examples/13/output/e6d6153dedeb9fec586b3241311cc49dbc17bc85/output_video.mp4 DELETED
Binary file (673 kB)
 
inputs/video.mp4/000000000.jpg DELETED
Binary file (113 kB)
 
inputs/video.mp4/000000001.jpg DELETED
Binary file (146 kB)
 
inputs/video.mp4/000000002.jpg DELETED
Binary file (143 kB)
 
inputs/video.mp4/000000003.jpg DELETED
Binary file (141 kB)
 
inputs/video.mp4/000000004.jpg DELETED
Binary file (142 kB)
 
inputs/video.mp4/000000005.jpg DELETED
Binary file (141 kB)