Leo8613 commited on
Commit
13e1420
·
verified ·
1 Parent(s): 3cb3977

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +49 -14
app.py CHANGED
@@ -1,20 +1,55 @@
1
- from deoldify.visualize import get_video_colorizer
2
  import torch
 
 
3
 
4
- # Load the video colorization model
5
- def load_video_model():
6
- model_path = 'ColorizeVideo_gen.pth' # Path to the model in the root directory
7
- video_colorizer = get_video_colorizer()
8
 
9
- # Load the model's state from the .pth file
10
- state = torch.load(model_path, map_location=torch.device('cpu')) # Adjust if using GPU
11
- video_colorizer.learn.model.load_state_dict(state)
 
 
12
 
13
- return video_colorizer
 
 
 
 
 
 
14
 
15
- # Example usage of the video colorizer
16
- video_colorizer = load_video_model()
 
 
 
17
 
18
- # You can now use the colorizer to colorize a video
19
- video_path = 'your_video.mp4'
20
- colorized_video = video_colorizer.colorize_from_file_name(video_path, render_factor=35)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import torch
2
+ import cv2
3
+ import numpy as np
4
 
5
+ # Chemins vers le modèle et la vidéo à traiter
6
+ MODEL_PATH = 'ColorizeVideo_gen.pth'
7
+ VIDEO_PATH = '119195-716970703_small.mp4' # Nom de la vidéo exemple
8
+ OUTPUT_VIDEO_PATH = 'output_video.mp4'
9
 
10
+ # Charger le modèle
11
+ def load_model(model_path):
12
+ model = torch.load(model_path, map_location=torch.device('cpu')) # Charger sur le CPU
13
+ model.eval() # Met le modèle en mode évaluation
14
+ return model
15
 
16
+ # Prétraitement de l'image
17
+ def preprocess_frame(frame):
18
+ # Redimensionner et normaliser
19
+ frame = cv2.resize(frame, (224, 224)) # Ajustez la taille si nécessaire
20
+ frame = frame / 255.0 # Normaliser
21
+ input_tensor = torch.from_numpy(frame.astype(np.float32)).permute(2, 0, 1) # Convertir en format Tensor
22
+ return input_tensor.unsqueeze(0) # Ajouter une dimension de lot
23
 
24
+ # Traitement de la vidéo
25
+ def process_video(model, video_path, output_path):
26
+ cap = cv2.VideoCapture(video_path)
27
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v')
28
+ out = cv2.VideoWriter(output_path, fourcc, 30.0, (int(cap.get(3)), int(cap.get(4))))
29
 
30
+ while cap.isOpened():
31
+ ret, frame = cap.read()
32
+ if not ret:
33
+ break
34
+
35
+ # Prétraiter le cadre
36
+ input_tensor = preprocess_frame(frame)
37
+
38
+ # Faire des prédictions
39
+ with torch.no_grad():
40
+ predictions = model(input_tensor)
41
+
42
+ # Traiter les prédictions et convertir en image
43
+ output_frame = (predictions.squeeze().permute(1, 2, 0).numpy() * 255).astype(np.uint8)
44
+
45
+ # Écrire le cadre traité dans la sortie
46
+ out.write(output_frame)
47
+
48
+ cap.release()
49
+ out.release()
50
+
51
+ # Fonction principale
52
+ if __name__ == '__main__':
53
+ model = load_model(MODEL_PATH)
54
+ process_video(model, VIDEO_PATH, OUTPUT_VIDEO_PATH)
55
+ print(f"Traitement de la vidéo terminé. Résultats enregistrés dans {OUTPUT_VIDEO_PATH}.")