Leo8613 committed
Commit 20654d0 · verified · 1 Parent(s): 8eaf927

Update app.py

Files changed (1)
  1. app.py +53 -35
app.py CHANGED
@@ -1,48 +1,66 @@
  import os
  import gradio as gr
  import torch
+ import cv2

- # Define your model architecture
  class YourModelArchitecture(torch.nn.Module):
+     # Replace this with your actual architecture
      def __init__(self):
          super(YourModelArchitecture, self).__init__()
-         # Initialize your model layers here
-         # Example: self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size)
+         # Define your architecture here

      def forward(self, x):
-         # Define forward pass here
-         return x  # Change this to return the output of your model
+         # Implement the forward method
+         return x

- # Load model function
  def load_model(model_path):
-     checkpoint = torch.load(model_path, map_location=torch.device('cpu'))  # Load checkpoint
-     model = YourModelArchitecture()  # Initialize your model architecture
-
-     # Load only model weights from checkpoint, ignoring unexpected keys
-     model.load_state_dict(checkpoint['model'], strict=False)  # Use strict=False
-
-     model.eval()  # Set model to evaluation mode
+     model = YourModelArchitecture()
+     checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
+     model.load_state_dict(checkpoint['model'])
+     model.eval()
      return model

- # Colorize video function
- def colorize_video(video):
-     model = load_model(MODEL_PATH)  # Load the model
-     # Add your video processing logic here
-     return "Processed video output"  # Replace with actual output
-
- # Gradio interface setup
- def create_interface():
-     interface = gr.Interface(
-         fn=colorize_video,
-         inputs=gr.Video(label="Upload Video"),
-         outputs=gr.Video(label="Colorized Video"),
-         title="Video Colorizer",
-         description="Upload a video to colorize it using a trained model.",
-     )
-     return interface
-
- if __name__ == "__main__":
-     MODEL_PATH = "ColorizeVideo_gen.pth"  # Define your model path
-
-     interface = create_interface()
-     interface.launch()
+ def colorize_frame(frame, model):
+     # Perform the colorization on a frame here
+     with torch.no_grad():
+         # Transform the image and pass it through the model
+         colorized_frame = model(frame)  # Make sure `frame` is properly transformed
+     return colorized_frame
+
+ def colorize_video(video_path, model):
+     cap = cv2.VideoCapture(video_path)
+     fourcc = cv2.VideoWriter_fourcc(*'mp4v')
+     output_path = "output_video.mp4"
+     out = cv2.VideoWriter(output_path, fourcc, 30, (int(cap.get(3)), int(cap.get(4))))
+
+     while cap.isOpened():
+         ret, frame = cap.read()
+         if not ret:
+             break
+
+         # Convert the frame to the appropriate format
+         input_tensor = preprocess_frame(frame)  # Implement this function for preprocessing
+         colorized_frame = colorize_frame(input_tensor, model)
+
+         # Write each colorized frame
+         out.write(colorized_frame.numpy())  # Make sure `colorized_frame` is converted to a numpy array
+
+     cap.release()
+     out.release()
+     return output_path
+
+ def preprocess_frame(frame):
+     # Convert the frame from BGR to RGB and normalize
+     frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+     frame = frame / 255.0  # Normalize
+     # Convert to a PyTorch tensor
+     tensor_frame = torch.tensor(frame).permute(2, 0, 1).unsqueeze(0)  # Add the batch dimension
+     return tensor_frame
+
+ def main(video_path):
+     model = load_model("model.pth")
+     output_video_path = colorize_video(video_path, model)
+     return output_video_path
+
+ iface = gr.Interface(fn=main, inputs="video", outputs="video")
+ iface.launch()
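
Two review notes on the committed code. First, `load_model` now calls `model.load_state_dict(checkpoint['model'])` with the earlier `strict=False` fallback removed, so it assumes the `.pth` file is a dict with a `'model'` key whose entries match `YourModelArchitecture` exactly; any mismatch raises a `RuntimeError`. A hedged loader sketch that tolerates the common checkpoint layouts (the key names tried here are assumptions, not something this commit guarantees about the file):

```python
import torch

def load_model(model_path):
    # Assumes YourModelArchitecture (defined in app.py) is in scope.
    model = YourModelArchitecture()
    checkpoint = torch.load(model_path, map_location="cpu")
    # Checkpoints vary: some store weights under a 'model' or 'state_dict'
    # key, others are the raw state dict itself. These names are assumptions.
    if isinstance(checkpoint, dict):
        state_dict = checkpoint.get("model", checkpoint.get("state_dict", checkpoint))
    else:
        state_dict = checkpoint
    model.load_state_dict(state_dict, strict=False)  # tolerate key mismatches while prototyping
    model.eval()
    return model
```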
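
Second, the frame round trip will likely fail at runtime as written: `frame / 255.0` yields a float64 tensor while PyTorch layers default to float32, the writer's frame rate is hardcoded to 30, and `out.write(colorized_frame.numpy())` hands `cv2.VideoWriter` a float RGB tensor that still carries its batch dimension, whereas OpenCV expects uint8 BGR frames of the declared size. A minimal sketch of the missing conversions, assuming the model maps a `(1, 3, H, W)` float tensor in `[0, 1]` to a tensor of the same shape (`postprocess_frame` is an illustrative helper, not part of this commit):

```python
import cv2
import numpy as np
import torch

def preprocess_frame(frame):
    # BGR -> RGB, scale to [0, 1], and cast to float32 (PyTorch's default dtype)
    rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    tensor = torch.from_numpy(rgb).float() / 255.0
    return tensor.permute(2, 0, 1).unsqueeze(0)  # (H, W, 3) -> (1, 3, H, W)

def postprocess_frame(tensor):
    # (1, 3, H, W) in [0, 1] -> uint8 BGR (H, W, 3), the layout cv2.VideoWriter expects
    rgb = tensor.squeeze(0).permute(1, 2, 0).numpy()
    rgb = np.clip(rgb * 255.0, 0.0, 255.0).astype(np.uint8)
    return cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)

def colorize_video(video_path, model):
    cap = cv2.VideoCapture(video_path)
    fps = cap.get(cv2.CAP_PROP_FPS) or 30.0  # keep the source rate; fall back if it is missing
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    out = cv2.VideoWriter("output_video.mp4", cv2.VideoWriter_fourcc(*"mp4v"), fps, (width, height))

    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break
        with torch.no_grad():
            colorized = model(preprocess_frame(frame))
        out.write(postprocess_frame(colorized))

    cap.release()
    out.release()
    return "output_video.mp4"
```

With the placeholder `forward` that returns its input unchanged, this pipeline simply round-trips frames, which is a convenient way to verify the video I/O path before wiring in the real checkpoint.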