Leo8613 commited on
Commit
7c05905
·
verified ·
1 Parent(s): ff98d03

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -69
app.py CHANGED
@@ -1,83 +1,47 @@
1
- import torch
2
  import gradio as gr
3
- import cv2
4
- import numpy as np
5
-
6
- # Chemin vers le modèle
7
- MODEL_PATH = 'ColorizeVideo_gen.pth'
8
 
9
- # Définir l'architecture de votre modèle ici
10
  class YourModelArchitecture(torch.nn.Module):
11
  def __init__(self):
12
  super(YourModelArchitecture, self).__init__()
13
- # Définissez les couches de votre modèle
 
14
 
15
  def forward(self, x):
16
- # Définissez la logique de passage avant de votre modèle
17
- return x # Modifiez ceci selon votre modèle
18
 
19
- # Charger le modèle
20
  def load_model(model_path):
21
- checkpoint = torch.load(model_path, map_location=torch.device('cpu')) # Charger le checkpoint
22
- model = YourModelArchitecture() # Initialiser l'architecture du modèle
23
 
24
- # Charger uniquement les poids du modèle à partir du checkpoint
25
- model.load_state_dict(checkpoint['model'])
26
 
27
- model.eval() # Met le modèle en mode évaluation
28
  return model
29
 
30
- # Prétraitement de l'image
31
- def preprocess_frame(frame):
32
- # Redimensionner et normaliser
33
- frame = cv2.resize(frame, (224, 224)) # Ajustez la taille si nécessaire
34
- frame = frame / 255.0 # Normaliser
35
- input_tensor = torch.from_numpy(frame.astype(np.float32)).permute(2, 0, 1) # Convertir en format Tensor
36
- return input_tensor.unsqueeze(0) # Ajouter une dimension de lot
37
-
38
- # Traitement de la vidéo
39
- def process_video(model, video_path):
40
- cap = cv2.VideoCapture(video_path)
41
- fourcc = cv2.VideoWriter_fourcc(*'mp4v')
42
- output_path = "output_video.mp4"
43
- out = cv2.VideoWriter(output_path, fourcc, 30.0, (int(cap.get(3)), int(cap.get(4))))
44
-
45
- while cap.isOpened():
46
- ret, frame = cap.read()
47
- if not ret:
48
- break
49
-
50
- # Prétraiter le cadre
51
- input_tensor = preprocess_frame(frame)
52
-
53
- # Faire des prédictions
54
- with torch.no_grad():
55
- predictions = model(input_tensor)
56
-
57
- # Traiter les prédictions et convertir en image
58
- output_frame = (predictions.squeeze().permute(1, 2, 0).numpy() * 255).astype(np.uint8)
59
-
60
- # Écrire le cadre traité dans la sortie
61
- out.write(output_frame)
62
-
63
- cap.release()
64
- out.release()
65
- return output_path
66
-
67
- # Interface Gradio
68
  def colorize_video(video):
69
- model = load_model(MODEL_PATH)
70
- output_video_path = process_video(model, video.name) # Utiliser le nom pour lire la vidéo
71
- return output_video_path
72
-
73
- # Configuration de l'interface Gradio
74
- iface = gr.Interface(
75
- fn=colorize_video,
76
- inputs=gr.Video(label="Téléchargez une vidéo"),
77
- outputs=gr.Video(label="Vidéo colorisée"),
78
- title="Colorisation de Vidéos",
79
- description="Chargez une vidéo en noir et blanc et utilisez le modèle de colorisation pour obtenir une vidéo colorisée."
80
- )
81
-
82
- if __name__ == '__main__':
83
- iface.launch()
 
 
 
 
 
1
+ import os
2
  import gradio as gr
3
+ import torch
 
 
 
 
4
 
5
+ # Define your model architecture
6
  class YourModelArchitecture(torch.nn.Module):
7
  def __init__(self):
8
  super(YourModelArchitecture, self).__init__()
9
+ # Initialize your model layers here
10
+ # Example: self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size)
11
 
12
  def forward(self, x):
13
+ # Define forward pass here
14
+ return x # Change this to return the output of your model
15
 
16
+ # Load model function
17
  def load_model(model_path):
18
+ checkpoint = torch.load(model_path, map_location=torch.device('cpu')) # Load checkpoint
19
+ model = YourModelArchitecture() # Initialize your model architecture
20
 
21
+ # Load only model weights from checkpoint, ignoring unexpected keys
22
+ model.load_state_dict(checkpoint['model'], strict=False) # Use strict=False
23
 
24
+ model.eval() # Set model to evaluation mode
25
  return model
26
 
27
+ # Colorize video function
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  def colorize_video(video):
29
+ model = load_model(MODEL_PATH) # Load the model
30
+ # Add your video processing logic here
31
+ return "Processed video output" # Replace with actual output
32
+
33
+ # Gradio interface setup
34
+ def create_interface():
35
+ interface = gr.Interface(
36
+ fn=colorize_video,
37
+ inputs=gr.Video(label="Upload Video"),
38
+ outputs=gr.Video(label="Colorized Video"),
39
+ title="Video Colorizer",
40
+ description="Upload a video to colorize it using a trained model.",
41
+ )
42
+ return interface
43
+
44
+ if __name__ == "__main__":
45
+ MODEL_PATH = "path/to/your/model.pth" # Define your model path
46
+ interface = create_interface()
47
+ interface.launch()