Leo8613 commited on
Commit
ce67fb9
·
verified ·
1 Parent(s): f45ee40

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -10
app.py CHANGED
@@ -2,8 +2,19 @@ import tensorflow as tf
2
  import numpy as np
3
  import cv2
4
  import gradio as gr
 
5
 
6
- # Fonction pour charger le modèle sans 'batch_shape'
 
 
 
 
 
 
 
 
 
 
7
  def load_model_safe(model_path):
8
  try:
9
  # Charger le modèle sans compilation pour éviter des erreurs liées au batch_shape
@@ -13,14 +24,16 @@ def load_model_safe(model_path):
13
  print(f"Erreur lors du chargement du modèle: {e}")
14
  return None
15
 
16
- # Charger le modèle en toute sécurité
17
- generator = load_model_safe('generator.h5') # Assurez-vous que le fichier 'generator.h5' est dans le même répertoire
18
 
19
- # Vérifier que le modèle a bien été chargé
20
- if generator is None:
21
- print("Le modèle n'a pas pu être chargé. Vérifiez le fichier 'generator.h5'.")
 
22
  else:
23
- print("Modèle chargé avec succès.")
 
24
 
25
  # Fonction pour générer une vidéo à partir du générateur
26
  def generate_video():
@@ -31,11 +44,22 @@ def generate_video():
31
  noise = np.random.normal(0, 1, (1, 16, 64, 64, 3)) # Exemple de bruit pour 16 frames de 64x64x3
32
  generated_video = generator.predict(noise) # Générer la vidéo
33
 
 
 
 
34
  # Normaliser les données générées
35
  video = (generated_video[0] * 255).astype(np.uint8) # Convertir en entier 8 bits
36
- filename = "/content/generated_video.mp4"
 
 
 
 
 
 
 
 
37
 
38
- # Sauvegarder la vidéo
39
  fourcc = cv2.VideoWriter_fourcc(*'mp4v') # Codec vidéo
40
  height, width, _ = video[0].shape
41
  out = cv2.VideoWriter(filename, fourcc, 15, (width, height)) # 15 FPS
@@ -45,6 +69,7 @@ def generate_video():
45
  out.write(frame)
46
  out.release()
47
 
 
48
  return filename
49
 
50
  # Interface Gradio
@@ -57,7 +82,7 @@ def interface_function():
57
  gr.Interface(
58
  fn=interface_function, # Fonction à exécuter
59
  inputs=[], # Pas d'entrée utilisateur
60
- outputs=gr.Video(label="Vidéo générée"), # Vidéo en sortie
61
  title="Générateur de Vidéos avec IA",
62
  description="Cliquez sur le bouton ci-dessous pour générer une vidéo aléatoire avec l'IA."
63
  ).launch(server_name="0.0.0.0", server_port=7860)
 
2
  import numpy as np
3
  import cv2
4
  import gradio as gr
5
+ import os
6
 
7
+ # Fonction pour chercher le modèle dans tout le système de fichiers
8
+ def find_model_file(filename="generator.h5"):
9
+ # Parcourir tous les répertoires du serveur pour trouver le fichier
10
+ for root, dirs, files in os.walk("/"):
11
+ if filename in files:
12
+ model_path = os.path.join(root, filename)
13
+ print(f"Modèle trouvé à : {model_path}")
14
+ return model_path
15
+ return None
16
+
17
+ # Fonction pour charger le modèle en toute sécurité
18
  def load_model_safe(model_path):
19
  try:
20
  # Charger le modèle sans compilation pour éviter des erreurs liées au batch_shape
 
24
  print(f"Erreur lors du chargement du modèle: {e}")
25
  return None
26
 
27
+ # Chercher le modèle 'generator.h5' dans le système
28
+ model_path = find_model_file("generator.h5")
29
 
30
+ # Vérifier si le modèle est trouvé et charger
31
+ if model_path is not None:
32
+ generator = load_model_safe(model_path)
33
+ print(f"Modèle chargé depuis : {model_path}")
34
  else:
35
+ print("Le modèle 'generator.h5' n'a pas été trouvé sur le serveur.")
36
+ generator = None
37
 
38
  # Fonction pour générer une vidéo à partir du générateur
39
  def generate_video():
 
44
  noise = np.random.normal(0, 1, (1, 16, 64, 64, 3)) # Exemple de bruit pour 16 frames de 64x64x3
45
  generated_video = generator.predict(noise) # Générer la vidéo
46
 
47
+ # Vérifier la forme des données générées
48
+ print(f"Shape of generated video: {generated_video.shape}")
49
+
50
  # Normaliser les données générées
51
  video = (generated_video[0] * 255).astype(np.uint8) # Convertir en entier 8 bits
52
+
53
+ # Vérifier si la vidéo générée a la bonne forme
54
+ if len(video.shape) != 4 or video.shape[0] != 16:
55
+ return "Erreur dans les dimensions de la vidéo générée."
56
+
57
+ # Créer le répertoire pour la vidéo
58
+ output_dir = "/content/generated_videos"
59
+ os.makedirs(output_dir, exist_ok=True)
60
+ filename = os.path.join(output_dir, "generated_video.mp4")
61
 
62
+ # Enregistrer la vidéo avec OpenCV
63
  fourcc = cv2.VideoWriter_fourcc(*'mp4v') # Codec vidéo
64
  height, width, _ = video[0].shape
65
  out = cv2.VideoWriter(filename, fourcc, 15, (width, height)) # 15 FPS
 
69
  out.write(frame)
70
  out.release()
71
 
72
+ # Retourner le chemin de la vidéo générée
73
  return filename
74
 
75
  # Interface Gradio
 
82
  gr.Interface(
83
  fn=interface_function, # Fonction à exécuter
84
  inputs=[], # Pas d'entrée utilisateur
85
+ outputs=gr.Video(label="Vidéo générée", type="file"), # Vidéo en sortie avec type 'file' pour le téléchargement
86
  title="Générateur de Vidéos avec IA",
87
  description="Cliquez sur le bouton ci-dessous pour générer une vidéo aléatoire avec l'IA."
88
  ).launch(server_name="0.0.0.0", server_port=7860)