Colorize_video / app.py
Leo8613's picture
Update app.py
cbe01ed verified
import os
import gradio as gr
import torch
import cv2
class YourModelArchitecture(torch.nn.Module):
# Remplacez ceci par votre architecture réelle
def __init__(self):
super(YourModelArchitecture, self).__init__()
# Définissez votre architecture ici
def forward(self, x):
# Implémentez la méthode forward
return x
def load_model(model_path):
model = YourModelArchitecture()
checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
model.load_state_dict(checkpoint['model'])
model.eval()
return model
def colorize_frame(frame, model):
# Effectuer la colorisation sur une frame ici
with torch.no_grad():
# Transformez l'image et passez-la par le modèle
colorized_frame = model(frame) # Assurez-vous que `frame` est correctement transformé
return colorized_frame
def colorize_video(video_path, model):
cap = cv2.VideoCapture(video_path)
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
output_path = "output_video.mp4"
out = cv2.VideoWriter(output_path, fourcc, 30, (int(cap.get(3)), int(cap.get(4))))
while cap.isOpened():
ret, frame = cap.read()
if not ret:
break
# Convertir l'image au format approprié
input_tensor = preprocess_frame(frame) # Implémentez cette fonction pour le prétraitement
colorized_frame = colorize_frame(input_tensor, model)
# Enregistrez chaque frame colorisée
out.write(colorized_frame.numpy()) # Assurez-vous que `colorized_frame` est converti en numpy array
cap.release()
out.release()
return output_path
def preprocess_frame(frame):
# Convertir l'image de BGR à RGB et normaliser
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
frame = frame / 255.0 # Normaliser
# Convertir en tensor PyTorch
tensor_frame = torch.tensor(frame).permute(2, 0, 1).unsqueeze(0) # Ajouter la dimension du batch
return tensor_frame
def main(video_path):
model = load_model("ColorizeVideo_gen.pth")
output_video_path = colorize_video(video_path, model)
return output_video_path
iface = gr.Interface(fn=main, inputs="video", outputs="video")
iface.launch()