Leo8613 commited on
Commit
b6a5d06
·
verified ·
1 Parent(s): c6969cd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -27
app.py CHANGED
@@ -1,71 +1,80 @@
1
  import torch
 
2
  import gradio as gr
3
  import cv2
4
  import numpy as np
5
 
6
- # Chemin vers le modèle
 
 
 
 
 
 
 
 
 
 
 
7
  MODEL_PATH = 'ColorizeVideo_gen.pth'
8
 
9
- # Charger le modèle
10
  def load_model(model_path):
11
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
12
- model = torch.load(model_path, map_location=device)
13
- model.eval()
14
  return model
15
 
16
- # Prétraitement de l'image
17
  def preprocess_frame(frame):
18
- # Redimensionner et normaliser
19
- frame = cv2.resize(frame, (224, 224)) # Ajustez la taille si nécessaire
20
- frame = frame / 255.0 # Normaliser
21
- input_tensor = torch.from_numpy(frame.astype(np.float32)).permute(2, 0, 1)
22
- return input_tensor.unsqueeze(0)
23
 
24
- # Traitement de la vidéo
25
  def process_video(model, video_path):
26
  cap = cv2.VideoCapture(video_path)
27
- width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
28
- height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
29
  fourcc = cv2.VideoWriter_fourcc(*'mp4v')
30
  output_path = "output_video.mp4"
31
- out = cv2.VideoWriter(output_path, fourcc, 30.0, (width, height))
32
 
33
  while cap.isOpened():
34
  ret, frame = cap.read()
35
  if not ret:
36
  break
37
 
38
- # Prétraiter le cadre
39
  input_tensor = preprocess_frame(frame)
40
 
41
- # Faire des prédictions
42
  with torch.no_grad():
43
  predictions = model(input_tensor)
44
 
45
- # Convertir en image
46
  output_frame = (predictions.squeeze().permute(1, 2, 0).numpy() * 255).astype(np.uint8)
47
- output_frame = cv2.resize(output_frame, (frame.shape[1], frame.shape[0])) # Rétablir la taille originale
48
 
49
- # Écrire le cadre traité dans la sortie
50
  out.write(output_frame)
51
 
52
  cap.release()
53
  out.release()
54
  return output_path
55
 
56
- # Interface Gradio
57
  def colorize_video(video):
58
  model = load_model(MODEL_PATH)
59
- output_video_path = process_video(model, video.name)
60
  return output_video_path
61
 
62
- # Configuration de l'interface Gradio
63
  iface = gr.Interface(
64
  fn=colorize_video,
65
- inputs=gr.Video(label="Téléchargez une vidéo"),
66
- outputs=gr.Video(label="Vidéo colorisée"),
67
- title="Colorisation de Vidéos",
68
- description="Chargez une vidéo en noir et blanc et utilisez le modèle de colorisation pour obtenir une vidéo colorisée."
69
  )
70
 
71
  if __name__ == '__main__':
 
1
  import torch
2
+ import torch.nn as nn
3
  import gradio as gr
4
  import cv2
5
  import numpy as np
6
 
7
+ # Define your model architecture
8
+ class YourModelArchitecture(nn.Module):
9
+ def __init__(self):
10
+ super(YourModelArchitecture, self).__init__()
11
+ # Define the layers of your model here
12
+ # Example: self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
13
+
14
+ def forward(self, x):
15
+ # Define the forward pass logic
16
+ return x
17
+
18
+ # Path to the model weights
19
  MODEL_PATH = 'ColorizeVideo_gen.pth'
20
 
21
+ # Load the model function
22
  def load_model(model_path):
23
+ model = YourModelArchitecture() # Initialize the model architecture
24
+ model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu'))) # Load the model weights
25
+ model.eval() # Set the model to evaluation mode
26
  return model
27
 
28
+ # Preprocess the frame before passing it to the model
29
  def preprocess_frame(frame):
30
+ # Resize and normalize the image
31
+ frame = cv2.resize(frame, (224, 224)) # Resize to model input size
32
+ frame = frame / 255.0 # Normalize the pixel values to [0, 1]
33
+ input_tensor = torch.from_numpy(frame.astype(np.float32)).permute(2, 0, 1) # Convert to tensor and change the dimension order
34
+ return input_tensor.unsqueeze(0) # Add batch dimension
35
 
36
+ # Process the video, frame by frame
37
  def process_video(model, video_path):
38
  cap = cv2.VideoCapture(video_path)
 
 
39
  fourcc = cv2.VideoWriter_fourcc(*'mp4v')
40
  output_path = "output_video.mp4"
41
+ out = cv2.VideoWriter(output_path, fourcc, 30.0, (int(cap.get(3)), int(cap.get(4))))
42
 
43
  while cap.isOpened():
44
  ret, frame = cap.read()
45
  if not ret:
46
  break
47
 
48
+ # Preprocess the frame
49
  input_tensor = preprocess_frame(frame)
50
 
51
+ # Make predictions with the model
52
  with torch.no_grad():
53
  predictions = model(input_tensor)
54
 
55
+ # Convert the predictions back to an image format
56
  output_frame = (predictions.squeeze().permute(1, 2, 0).numpy() * 255).astype(np.uint8)
 
57
 
58
+ # Write the processed frame to the output video
59
  out.write(output_frame)
60
 
61
  cap.release()
62
  out.release()
63
  return output_path
64
 
65
+ # Gradio interface function
66
  def colorize_video(video):
67
  model = load_model(MODEL_PATH)
68
+ output_video_path = process_video(model, video.name) # Use the video file name to read the video
69
  return output_video_path
70
 
71
+ # Configure the Gradio interface
72
  iface = gr.Interface(
73
  fn=colorize_video,
74
+ inputs=gr.Video(label="Upload a black and white video"),
75
+ outputs=gr.Video(label="Colorized Video"),
76
+ title="Video Colorization",
77
+ description="Upload a black and white video to colorize it using the model."
78
  )
79
 
80
  if __name__ == '__main__':