import os
import gradio as gr
import torch
import cv2
class YourModelArchitecture(torch.nn.Module):
    # Replace this with your actual architecture
    def __init__(self):
        super(YourModelArchitecture, self).__init__()
        # Define your layers here

    def forward(self, x):
        # Implement the forward pass
        return x
def load_model(model_path):
    model = YourModelArchitecture()
    checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
    model.load_state_dict(checkpoint['model'])  # checkpoint is assumed to store weights under the 'model' key
    model.eval()
    return model
def colorize_frame(frame, model):
    # Colorize a single frame
    with torch.no_grad():
        # Pass the transformed frame through the model
        colorized_frame = model(frame)  # `frame` must already be preprocessed by preprocess_frame
    return colorized_frame
def colorize_video(video_path, model):
    cap = cv2.VideoCapture(video_path)
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    output_path = "output_video.mp4"
    fps = cap.get(cv2.CAP_PROP_FPS) or 30  # fall back to 30 if the source reports no FPS
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break
        # Convert the frame to the format the model expects
        input_tensor = preprocess_frame(frame)
        colorized_frame = colorize_frame(input_tensor, model)
        # Convert the model output back to a BGR uint8 image and write it
        out.write(postprocess_frame(colorized_frame))
    cap.release()
    out.release()
    return output_path
def preprocess_frame(frame):
    # Convert the frame from BGR to RGB and normalize to [0, 1]
    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    frame = frame / 255.0
    # Convert to a float32 PyTorch tensor and add a batch dimension: (1, 3, H, W)
    tensor_frame = torch.tensor(frame, dtype=torch.float32).permute(2, 0, 1).unsqueeze(0)
    return tensor_frame
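
# Hypothetical helper, not in the original listing: a minimal sketch of the inverse
# of preprocess_frame, added so colorize_video writes frames cv2.VideoWriter can
# accept. It assumes the model returns an RGB float tensor in [0, 1] of shape (1, 3, H, W).
def postprocess_frame(tensor_frame):
    frame = (tensor_frame.squeeze(0).clamp(0.0, 1.0) * 255).byte()  # (3, H, W) uint8
    frame = frame.permute(1, 2, 0).contiguous().cpu().numpy()  # (H, W, 3) RGB
    return cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)  # OpenCV expects BGR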
def main(video_path):
    # Loaded per request for simplicity; a deployed app would load the model once at startup
    model = load_model("ColorizeVideo_gen.pth")
    output_video_path = colorize_video(video_path, model)
    return output_video_path
iface = gr.Interface(fn=main, inputs="video", outputs="video")
iface.launch()