Leo8613 commited on
Commit
ff98d03
·
verified ·
1 Parent(s): b6a5d06

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -33
app.py CHANGED
@@ -1,39 +1,41 @@
1
  import torch
2
- import torch.nn as nn
3
  import gradio as gr
4
  import cv2
5
  import numpy as np
6
 
7
- # Define your model architecture
8
- class YourModelArchitecture(nn.Module):
 
 
 
9
  def __init__(self):
10
  super(YourModelArchitecture, self).__init__()
11
- # Define the layers of your model here
12
- # Example: self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
13
-
14
- def forward(self, x):
15
- # Define the forward pass logic
16
- return x
17
 
18
- # Path to the model weights
19
- MODEL_PATH = 'ColorizeVideo_gen.pth'
 
20
 
21
- # Load the model function
22
  def load_model(model_path):
23
- model = YourModelArchitecture() # Initialize the model architecture
24
- model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu'))) # Load the model weights
25
- model.eval() # Set the model to evaluation mode
 
 
 
 
26
  return model
27
 
28
- # Preprocess the frame before passing it to the model
29
  def preprocess_frame(frame):
30
- # Resize and normalize the image
31
- frame = cv2.resize(frame, (224, 224)) # Resize to model input size
32
- frame = frame / 255.0 # Normalize the pixel values to [0, 1]
33
- input_tensor = torch.from_numpy(frame.astype(np.float32)).permute(2, 0, 1) # Convert to tensor and change the dimension order
34
- return input_tensor.unsqueeze(0) # Add batch dimension
35
 
36
- # Process the video, frame by frame
37
  def process_video(model, video_path):
38
  cap = cv2.VideoCapture(video_path)
39
  fourcc = cv2.VideoWriter_fourcc(*'mp4v')
@@ -45,36 +47,36 @@ def process_video(model, video_path):
45
  if not ret:
46
  break
47
 
48
- # Preprocess the frame
49
  input_tensor = preprocess_frame(frame)
50
 
51
- # Make predictions with the model
52
  with torch.no_grad():
53
  predictions = model(input_tensor)
54
 
55
- # Convert the predictions back to an image format
56
  output_frame = (predictions.squeeze().permute(1, 2, 0).numpy() * 255).astype(np.uint8)
57
 
58
- # Write the processed frame to the output video
59
  out.write(output_frame)
60
 
61
  cap.release()
62
  out.release()
63
  return output_path
64
 
65
- # Gradio interface function
66
  def colorize_video(video):
67
  model = load_model(MODEL_PATH)
68
- output_video_path = process_video(model, video.name) # Use the video file name to read the video
69
  return output_video_path
70
 
71
- # Configure the Gradio interface
72
  iface = gr.Interface(
73
  fn=colorize_video,
74
- inputs=gr.Video(label="Upload a black and white video"),
75
- outputs=gr.Video(label="Colorized Video"),
76
- title="Video Colorization",
77
- description="Upload a black and white video to colorize it using the model."
78
  )
79
 
80
  if __name__ == '__main__':
 
1
  import torch
 
2
  import gradio as gr
3
  import cv2
4
  import numpy as np
5
 
6
+ # Chemin vers le modèle
7
+ MODEL_PATH = 'ColorizeVideo_gen.pth'
8
+
9
+ # Définir l'architecture de votre modèle ici
10
+ class YourModelArchitecture(torch.nn.Module):
11
  def __init__(self):
12
  super(YourModelArchitecture, self).__init__()
13
+ # Définissez les couches de votre modèle
 
 
 
 
 
14
 
15
+ def forward(self, x):
16
+ # Définissez la logique de passage avant de votre modèle
17
+ return x # Modifiez ceci selon votre modèle
18
 
19
+ # Charger le modèle
20
  def load_model(model_path):
21
+ checkpoint = torch.load(model_path, map_location=torch.device('cpu')) # Charger le checkpoint
22
+ model = YourModelArchitecture() # Initialiser l'architecture du modèle
23
+
24
+ # Charger uniquement les poids du modèle à partir du checkpoint
25
+ model.load_state_dict(checkpoint['model'])
26
+
27
+ model.eval() # Met le modèle en mode évaluation
28
  return model
29
 
30
+ # Prétraitement de l'image
31
  def preprocess_frame(frame):
32
+ # Redimensionner et normaliser
33
+ frame = cv2.resize(frame, (224, 224)) # Ajustez la taille si nécessaire
34
+ frame = frame / 255.0 # Normaliser
35
+ input_tensor = torch.from_numpy(frame.astype(np.float32)).permute(2, 0, 1) # Convertir en format Tensor
36
+ return input_tensor.unsqueeze(0) # Ajouter une dimension de lot
37
 
38
+ # Traitement de la vidéo
39
  def process_video(model, video_path):
40
  cap = cv2.VideoCapture(video_path)
41
  fourcc = cv2.VideoWriter_fourcc(*'mp4v')
 
47
  if not ret:
48
  break
49
 
50
+ # Prétraiter le cadre
51
  input_tensor = preprocess_frame(frame)
52
 
53
+ # Faire des prédictions
54
  with torch.no_grad():
55
  predictions = model(input_tensor)
56
 
57
+ # Traiter les prédictions et convertir en image
58
  output_frame = (predictions.squeeze().permute(1, 2, 0).numpy() * 255).astype(np.uint8)
59
 
60
+ # Écrire le cadre traité dans la sortie
61
  out.write(output_frame)
62
 
63
  cap.release()
64
  out.release()
65
  return output_path
66
 
67
+ # Interface Gradio
68
  def colorize_video(video):
69
  model = load_model(MODEL_PATH)
70
+ output_video_path = process_video(model, video.name) # Utiliser le nom pour lire la vidéo
71
  return output_video_path
72
 
73
+ # Configuration de l'interface Gradio
74
  iface = gr.Interface(
75
  fn=colorize_video,
76
+ inputs=gr.Video(label="Téléchargez une vidéo"),
77
+ outputs=gr.Video(label="Vidéo colorisée"),
78
+ title="Colorisation de Vidéos",
79
+ description="Chargez une vidéo en noir et blanc et utilisez le modèle de colorisation pour obtenir une vidéo colorisée."
80
  )
81
 
82
  if __name__ == '__main__':