Spaces:
Sleeping
Sleeping
Update src/inference.py
Browse files- src/inference.py +3 -7
src/inference.py
CHANGED
|
@@ -17,7 +17,6 @@ from src.utils import (
|
|
| 17 |
tensor_lab2rgb
|
| 18 |
)
|
| 19 |
import numpy as np
|
| 20 |
-
from tqdm import tqdm
|
| 21 |
|
| 22 |
class SwinTExCo:
|
| 23 |
def __init__(self, weights_path, swin_backbone='swinv2-cr-t-224', device=None):
|
|
@@ -62,13 +61,13 @@ class SwinTExCo:
|
|
| 62 |
size=(H,W),
|
| 63 |
mode="bilinear",
|
| 64 |
align_corners=False)
|
| 65 |
-
large_IA_l = torch.cat((large_IA_l, large_current_ab_predict
|
| 66 |
large_current_rgb_predict = tensor_lab2rgb(large_IA_l)
|
| 67 |
-
return large_current_rgb_predict
|
| 68 |
|
| 69 |
def __proccess_sample(self, curr_frame, I_last_lab_predict, I_reference_lab, features_B):
|
| 70 |
large_IA_lab = ToTensor()(RGB2Lab()(curr_frame)).unsqueeze(0)
|
| 71 |
-
large_IA_l = large_IA_lab[:, 0:1, :, :]
|
| 72 |
|
| 73 |
IA_lab = self.processor(curr_frame)
|
| 74 |
IA_lab = IA_lab.unsqueeze(0).to(self.device)
|
|
@@ -113,9 +112,7 @@ class SwinTExCo:
|
|
| 113 |
I_reference_rgb = tensor_lab2rgb(torch.cat((uncenter_l(I_reference_l), I_reference_ab), dim=1)).to(self.device)
|
| 114 |
features_B = self.embed_net(I_reference_rgb)
|
| 115 |
|
| 116 |
-
#PBAR = tqdm(total=int(video.get(cv2.CAP_PROP_FRAME_COUNT)), desc="Colorizing video", unit="frame")
|
| 117 |
while video.isOpened():
|
| 118 |
-
#PBAR.update(1)
|
| 119 |
ret, curr_frame = video.read()
|
| 120 |
|
| 121 |
if not ret:
|
|
@@ -130,7 +127,6 @@ class SwinTExCo:
|
|
| 130 |
|
| 131 |
yield IA_predict_rgb
|
| 132 |
|
| 133 |
-
#PBAR.close()
|
| 134 |
video.release()
|
| 135 |
|
| 136 |
def predict_image(self, image, ref_image):
|
|
|
|
| 17 |
tensor_lab2rgb
|
| 18 |
)
|
| 19 |
import numpy as np
|
|
|
|
| 20 |
|
| 21 |
class SwinTExCo:
|
| 22 |
def __init__(self, weights_path, swin_backbone='swinv2-cr-t-224', device=None):
|
|
|
|
| 61 |
size=(H,W),
|
| 62 |
mode="bilinear",
|
| 63 |
align_corners=False)
|
| 64 |
+
large_IA_l = torch.cat((large_IA_l, large_current_ab_predict), dim=1)
|
| 65 |
large_current_rgb_predict = tensor_lab2rgb(large_IA_l)
|
| 66 |
+
return large_current_rgb_predict.cpu()
|
| 67 |
|
| 68 |
def __proccess_sample(self, curr_frame, I_last_lab_predict, I_reference_lab, features_B):
|
| 69 |
large_IA_lab = ToTensor()(RGB2Lab()(curr_frame)).unsqueeze(0)
|
| 70 |
+
large_IA_l = large_IA_lab[:, 0:1, :, :].to(self.device)
|
| 71 |
|
| 72 |
IA_lab = self.processor(curr_frame)
|
| 73 |
IA_lab = IA_lab.unsqueeze(0).to(self.device)
|
|
|
|
| 112 |
I_reference_rgb = tensor_lab2rgb(torch.cat((uncenter_l(I_reference_l), I_reference_ab), dim=1)).to(self.device)
|
| 113 |
features_B = self.embed_net(I_reference_rgb)
|
| 114 |
|
|
|
|
| 115 |
while video.isOpened():
|
|
|
|
| 116 |
ret, curr_frame = video.read()
|
| 117 |
|
| 118 |
if not ret:
|
|
|
|
| 127 |
|
| 128 |
yield IA_predict_rgb
|
| 129 |
|
|
|
|
| 130 |
video.release()
|
| 131 |
|
| 132 |
def predict_image(self, image, ref_image):
|