File size: 2,759 Bytes
f55b4f4
b6a5d06
195a17a
13e1420
 
56d08ee
b6a5d06
 
 
 
 
 
 
 
 
 
 
 
13e1420
56d08ee
b6a5d06
13e1420
b6a5d06
 
 
13e1420
3cdf272
b6a5d06
13e1420
b6a5d06
 
 
 
 
3cdf272
b6a5d06
195a17a
13e1420
 
195a17a
b6a5d06
091614b
13e1420
 
 
 
 
b6a5d06
13e1420
 
b6a5d06
13e1420
 
 
b6a5d06
13e1420
 
b6a5d06
13e1420
 
 
 
195a17a
13e1420
b6a5d06
195a17a
13e1420
b6a5d06
195a17a
 
b6a5d06
195a17a
 
b6a5d06
 
 
 
195a17a
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import functools

import cv2
import gradio as gr
import numpy as np
import torch
import torch.nn as nn

# Placeholder network: swap in the real colorization architecture here.
class YourModelArchitecture(nn.Module):
    """Stand-in for the colorization model.

    No layers are registered yet, so the module has no parameters and
    ``forward`` hands its input straight back. For ``load_model`` to succeed
    with a real checkpoint, the layers declared here must match the keys
    stored in that checkpoint.
    """

    def __init__(self):
        super().__init__()
        # TODO: register the real layers, for example:
        #   self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        """Forward pass; currently the identity function."""
        return x

# Checkpoint file with the trained weights; as a relative path it is
# resolved against the working directory the app is started from.
MODEL_PATH = "ColorizeVideo_gen.pth"

# Load the model function
def _extract_state_dict(checkpoint):
    """Return the weights mapping from a bare or wrapped checkpoint.

    Some training frameworks save ``{"model": state_dict, "opt": ...}`` or
    ``{"state_dict": state_dict, ...}`` rather than a bare state dict. Unwrap
    those containers; anything else is returned unchanged.
    """
    if isinstance(checkpoint, dict):
        for key in ("state_dict", "model"):
            inner = checkpoint.get(key)
            # A bare state dict maps names to tensors, never to nested dicts,
            # so this check cannot misfire on an unwrapped checkpoint.
            if isinstance(inner, dict):
                return inner
    return checkpoint


def load_model(model_path, map_location="cpu"):
    """Build ``YourModelArchitecture`` and load its weights from *model_path*.

    Args:
        model_path: Path to a checkpoint written with ``torch.save``. It may
            be a bare state dict, or a dict wrapping one under ``"model"`` or
            ``"state_dict"``.
        map_location: Device the stored tensors are mapped onto while loading.
            It defaults to CPU so the app also runs on machines without a GPU.

    Returns:
        The model with weights loaded, switched to evaluation mode.

    Raises:
        RuntimeError: From ``load_state_dict`` if the checkpoint's keys or
            shapes do not match the architecture.

    Note:
        ``torch.load`` unpickles the file, which can execute arbitrary code.
        Only point this at checkpoints from a trusted source.
    """
    model = YourModelArchitecture()
    checkpoint = torch.load(model_path, map_location=map_location)
    model.load_state_dict(_extract_state_dict(checkpoint))
    model.eval()  # disable dropout / use running batch-norm statistics
    return model

# Preprocess the frame before passing it to the model
def preprocess_frame(frame, size=(224, 224)):
    """Turn one video frame into a batched model-input tensor.

    Args:
        frame: Image array as returned by ``cv2.VideoCapture.read``, i.e.
            ``(H, W, 3)`` in OpenCV's BGR channel order. A grayscale frame of
            shape ``(H, W)`` or ``(H, W, 1)`` is also accepted; its single
            channel is replicated to three channels.
        size: Target ``(width, height)`` in ``cv2.resize`` order. The default
            is the 224x224 input size this app has always used.

    Returns:
        ``float32`` tensor of shape ``(1, C, height, width)`` holding the
        pixel values divided by 255. The channel order is left as delivered
        (BGR for OpenCV frames). NOTE(review): confirm the model expects BGR
        rather than RGB.
    """
    if frame.ndim == 2:
        frame = frame[:, :, np.newaxis]
    if frame.shape[2] == 1:
        # Replicate before resizing: cv2.resize drops a trailing 1-channel axis.
        frame = np.repeat(frame, 3, axis=2)

    width, height = size
    if frame.shape[:2] != (height, width):
        frame = cv2.resize(frame, (width, height))

    # Cast first so the division happens in float32 (no float64 temporary).
    normalized = frame.astype(np.float32) / 255.0
    chw = torch.from_numpy(normalized).permute(2, 0, 1)  # HWC -> CHW
    return chw.unsqueeze(0)  # add the batch dimension

# Process the video, frame by frame
def _to_output_frame(predictions, width, height):
    """Convert a model output tensor into a ``uint8`` frame of ``width`` x ``height``.

    Accepts ``(1, C, H, W)``, ``(C, H, W)`` or ``(H, W)`` outputs. Values are
    clipped to ``[0, 1]`` so out-of-range predictions saturate instead of
    wrapping around in the ``uint8`` cast. A single-channel result is
    replicated to three channels. The frame is resized to the writer's
    size, because ``cv2.VideoWriter`` requires every frame to match the size
    it was opened with.
    """
    image = predictions.detach().cpu()
    if image.dim() == 4:
        image = image[0]  # drop the batch dimension
    if image.dim() == 2:
        image = image.unsqueeze(0)  # single-channel map without a channel axis
    image = image.clamp(0.0, 1.0).permute(1, 2, 0).numpy()  # CHW -> HWC
    image = (image * 255).astype(np.uint8)
    if image.shape[2] == 1:
        image = np.repeat(image, 3, axis=2)
    if image.shape[:2] != (height, width):
        image = cv2.resize(image, (width, height))
    return np.ascontiguousarray(image)


def process_video(model, video_path, output_path="output_video.mp4"):
    """Run *model* over every frame of *video_path* and write the result.

    Args:
        model: Callable taking the tensor built by ``preprocess_frame`` and
            returning an image tensor. The output is assumed to be 3-channel
            values in ``[0, 1]``; TODO confirm against the real architecture.
        video_path: Input video path readable by ``cv2.VideoCapture``.
        output_path: Destination of the ``mp4v``-encoded result. The default
            keeps the original fixed file name.

    Returns:
        *output_path*.

    Raises:
        OSError: If the input cannot be opened or the writer cannot be created.
    """
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        cap.release()
        raise OSError(f"Could not open input video: {video_path}")

    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    # Preserve the source frame rate. Some containers report 0 (or NaN),
    # so fall back to the previous fixed 30 fps in that case.
    fps = cap.get(cv2.CAP_PROP_FPS)
    fps = fps if fps > 0 else 30.0

    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
    try:
        if not out.isOpened():
            raise OSError(f"Could not open video writer: {output_path}")
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            input_tensor = preprocess_frame(frame)
            with torch.no_grad():
                predictions = model(input_tensor)
            out.write(_to_output_frame(predictions, width, height))
    finally:
        # Release both handles even if the model or writer raises mid-video.
        cap.release()
        out.release()
    return output_path

# Gradio interface function
@functools.lru_cache(maxsize=1)
def _get_model(model_path):
    """Load the model once and reuse it, instead of re-reading it per request."""
    return load_model(model_path)


def colorize_video(video):
    """Gradio callback: colorize the uploaded video and return the output path.

    Args:
        video: Value passed by the ``gr.Video`` input component. Gradio passes
            a filepath string. A file-like object exposing ``.name``, which is
            what the original code expected, is still accepted.

    Returns:
        Path of the video written by ``process_video``.

    Raises:
        gr.Error: If the callback is invoked without an uploaded video.
    """
    if video is None:
        raise gr.Error("Please upload a video first.")
    video_path = getattr(video, "name", video)
    return process_video(_get_model(MODEL_PATH), video_path)

# Wire the colorization callback into a video-in / video-out web UI.
iface = gr.Interface(
    title="Video Colorization",
    description="Upload a black and white video to colorize it using the model.",
    fn=colorize_video,
    inputs=gr.Video(label="Upload a black and white video"),
    outputs=gr.Video(label="Colorized Video"),
)

# Start the web server only when run as a script, not when imported.
if __name__ == "__main__":
    iface.launch()