import torch
import torch.nn as nn
import gradio as gr
import cv2
import numpy as np
# Define your model architecture (placeholder; replace with the real network)
class YourModelArchitecture(nn.Module):
    def __init__(self):
        super(YourModelArchitecture, self).__init__()
        # Define the layers of your model here
        # Example: self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        # Define the forward pass; as written, this is an identity mapping
        return x
# Path to the model weights
MODEL_PATH = 'ColorizeVideo_gen.pth'

# Load the model
def load_model(model_path):
    model = YourModelArchitecture()  # Initialize the model architecture
    model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))  # Load the weights onto the CPU
    model.eval()  # Set evaluation mode (disables dropout and batch-norm updates)
    return model
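
# 'ColorizeVideo_gen.pth' matches the filename of DeOldify's video-colorizer
# weights, so the placeholder class above will not load it as-is. A quick,
# hypothetical helper (not part of the app flow) to see what the real
# architecture must define is to print the checkpoint's keys and shapes:
def inspect_checkpoint(model_path=MODEL_PATH):
    state = torch.load(model_path, map_location=torch.device('cpu'))
    for key, value in state.items():
        shape = tuple(value.shape) if hasattr(value, 'shape') else type(value)
        print(key, shape)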
# Preprocess a frame before passing it to the model
def preprocess_frame(frame):
    # OpenCV decodes frames as BGR; most PyTorch vision models expect RGB
    # (an assumption here, since the real architecture is a placeholder)
    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    frame = cv2.resize(frame, (224, 224))  # Resize to the model's input size
    frame = frame / 255.0  # Normalize pixel values to [0, 1]
    # Convert to a CHW float tensor
    input_tensor = torch.from_numpy(frame.astype(np.float32)).permute(2, 0, 1)
    return input_tensor.unsqueeze(0)  # Add a batch dimension
# Process the video frame by frame
def process_video(model, video_path):
    cap = cv2.VideoCapture(video_path)
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    output_path = "output_video.mp4"
    fps = cap.get(cv2.CAP_PROP_FPS) or 30.0  # Fall back to 30 fps if the container reports none
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break
        # Preprocess the frame
        input_tensor = preprocess_frame(frame)
        # Run inference without tracking gradients
        with torch.no_grad():
            predictions = model(input_tensor)
        # Convert the prediction back to an 8-bit HWC image
        output_frame = (predictions.squeeze(0).clamp(0, 1).permute(1, 2, 0).numpy() * 255).astype(np.uint8)
        # The model works at 224x224, so resize back to the original frame size
        # (VideoWriter silently drops frames whose size does not match the one
        # it was opened with) and convert RGB back to BGR before writing
        output_frame = cv2.resize(output_frame, (width, height))
        output_frame = cv2.cvtColor(output_frame, cv2.COLOR_RGB2BGR)
        out.write(output_frame)
    cap.release()
    out.release()
    return output_path
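
# Hedged speed-up, assuming a CUDA-capable machine: move the model and each
# input tensor to the GPU once, e.g.
#
#   device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#   model.to(device)
#   predictions = model(input_tensor.to(device)).cpu()
#
# On a CPU-only Space the device resolves to 'cpu', so behavior is unchanged.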
# Gradio interface function
def colorize_video(video):
    model = load_model(MODEL_PATH)
    # Recent Gradio versions pass the uploaded video as a filepath string;
    # older versions passed a tempfile object exposing a .name attribute
    video_path = video if isinstance(video, str) else video.name
    output_video_path = process_video(model, video_path)
    return output_video_path
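
# Hedged optimization: colorize_video() above reloads the weights on every
# request. A module-level cache avoids that; to use it, swap
# load_model(MODEL_PATH) for get_model() in colorize_video.
_MODEL = None

def get_model():
    global _MODEL
    if _MODEL is None:  # Load the weights only once per process
        _MODEL = load_model(MODEL_PATH)
    return _MODEL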
# Configure the Gradio interface
iface = gr.Interface(
    fn=colorize_video,
    inputs=gr.Video(label="Upload a black-and-white video"),
    outputs=gr.Video(label="Colorized Video"),
    title="Video Colorization",
    description="Upload a black-and-white video to colorize it with the model."
)

if __name__ == '__main__':
    iface.launch()
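    # Note: launch() also accepts share=True for a temporary public link when
    # running locally; on Hugging Face Spaces the plain launch() suffices.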