Ahmadkhan12 committed on
Commit
4226ecb
·
verified ·
1 Parent(s): 5446064

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +160 -5
app.py CHANGED
@@ -10,8 +10,8 @@ import logging
10
  from scipy.io.wavfile import write as write_wav
11
  from scipy import signal
12
  from moviepy.editor import VideoFileClip, AudioFileClip
13
- from transformers import AutoProcessor, AutoModelForAudioGeneration
14
- import requests # Add this line
15
 
16
  # Set up logging for better debug tracking
17
  logging.basicConfig(level=logging.DEBUG, format="%(asctime)s - %(levelname)s - %(message)s")
@@ -48,9 +48,9 @@ with open("categories_places365.txt", "r") as f:
48
  try:
49
  logging.info("Loading AudioGen Medium and MusicGen Medium models...")
50
  audiogen_processor = AutoProcessor.from_pretrained("facebook/audiogen-medium")
51
- audiogen_model = AutoModelForAudioGeneration.from_pretrained("facebook/audiogen-medium")
52
  musicgen_processor = AutoProcessor.from_pretrained("facebook/musicgen-medium")
53
- musicgen_model = AutoModelForAudioGeneration.from_pretrained("facebook/musicgen-medium")
54
 
55
  # Move models to GPU if available
56
  device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -61,4 +61,159 @@ except Exception as e:
61
  logging.error(f"Error loading AudioGen/MusicGen models: {e}")
62
  raise
63
 
64
- # Rest of the code remains the same...
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
from scipy.io.wavfile import write as write_wav
from scipy import signal
from moviepy.editor import VideoFileClip, AudioFileClip
from transformers import AutoProcessor, AutoModelForCausalLM, AutoModelForTextToWaveform
import requests
15
 
16
  # Set up logging for better debug tracking
17
  logging.basicConfig(level=logging.DEBUG, format="%(asctime)s - %(levelname)s - %(message)s")
 
48
  try:
49
  logging.info("Loading AudioGen Medium and MusicGen Medium models...")
50
  audiogen_processor = AutoProcessor.from_pretrained("facebook/audiogen-medium")
51
+ audiogen_model = AutoModelForCausalLM.from_pretrained("facebook/audiogen-medium")
52
  musicgen_processor = AutoProcessor.from_pretrained("facebook/musicgen-medium")
53
+ musicgen_model = AutoModelForCausalLM.from_pretrained("facebook/musicgen-medium")
54
 
55
  # Move models to GPU if available
56
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
61
  logging.error(f"Error loading AudioGen/MusicGen models: {e}")
62
  raise
63
 
64
+ # Function to classify a frame using Places365
65
+ def classify_frame(frame):
66
+ try:
67
+ preprocess = transforms.Compose([
68
+ transforms.Resize(128), # Smaller resolution
69
+ transforms.CenterCrop(128),
70
+ transforms.ToTensor(),
71
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
72
+ ])
73
+ img = Image.fromarray(frame)
74
+ img = preprocess(img).unsqueeze(0)
75
+ with torch.no_grad():
76
+ output = places365(img.to("cpu")) # Ensure inference on CPU
77
+ probabilities = F.softmax(output, dim=1)
78
+ _, predicted = torch.max(probabilities, 1)
79
+ predicted_index = predicted.item()
80
+
81
+ # Ensure the predicted index is within the range of SCENE_CLASSES
82
+ if predicted_index >= len(SCENE_CLASSES) or predicted_index < 0:
83
+ logging.warning(f"Predicted class index {predicted_index} is out of range. Defaulting to 'nature'.")
84
+ return "nature" # Default scene type
85
+
86
+ scene_type = SCENE_CLASSES[predicted_index]
87
+ logging.info(f"Predicted scene: {scene_type}")
88
+ return scene_type
89
+ except Exception as e:
90
+ logging.error(f"Error classifying frame: {e}")
91
+ raise
92
+
93
+ # Function to analyze video content and return the scene type using Places365
94
+ def analyze_video(video_path):
95
+ try:
96
+ logging.info(f"Analyzing video: {video_path}")
97
+ clip = VideoFileClip(video_path)
98
+ frame = clip.get_frame(0) # Get the first frame
99
+ frame = Image.fromarray(frame) # Convert to PIL image
100
+ frame = np.array(frame.resize((128, 128))) # Resize to reduce memory usage
101
+
102
+ # Classify the frame using Places365
103
+ scene_type = classify_frame(frame)
104
+ logging.info(f"Scene type detected: {scene_type}")
105
+ return scene_type
106
+ except Exception as e:
107
+ logging.error(f"Error analyzing video: {e}")
108
+ raise
109
+
110
+ # Function to generate audio using AudioGen Medium
111
+ def generate_audio_audiogen(scene, duration=10):
112
+ try:
113
+ logging.info(f"Generating audio for scene: {scene} using AudioGen Medium...")
114
+ inputs = audiogen_processor(
115
+ text=[f"Ambient sounds of {scene}"],
116
+ padding=True,
117
+ return_tensors="pt",
118
+ ).to(audiogen_model.device) # Move inputs to the same device as the model
119
+ with torch.no_grad():
120
+ audio = audiogen_model.generate(**inputs, max_new_tokens=duration * 50) # Adjust tokens for duration
121
+ audio = audio.cpu().numpy().squeeze()
122
+ audio_path = "generated_audio_audiogen.wav"
123
+ write_wav(audio_path, 16000, audio) # Save as WAV file
124
+ logging.info(f"Audio generated and saved to: {audio_path}")
125
+ return audio_path
126
+ except Exception as e:
127
+ logging.error(f"Error generating audio with AudioGen Medium: {e}")
128
+ raise
129
+
130
+ # Function to generate music using MusicGen Medium
131
+ def generate_music_musicgen(scene, duration=10):
132
+ try:
133
+ logging.info(f"Generating music for scene: {scene} using MusicGen Medium...")
134
+ inputs = musicgen_processor(
135
+ text=[f"Calm music for {scene}"],
136
+ padding=True,
137
+ return_tensors="pt",
138
+ ).to(musicgen_model.device) # Move inputs to the same device as the model
139
+ with torch.no_grad():
140
+ music = musicgen_model.generate(**inputs, max_new_tokens=duration * 50) # Adjust tokens for duration
141
+ music = music.cpu().numpy().squeeze()
142
+ music_path = "generated_music_musicgen.wav"
143
+ write_wav(music_path, 16000, music) # Save as WAV file
144
+ logging.info(f"Music generated and saved to: {music_path}")
145
+ return music_path
146
+ except Exception as e:
147
+ logging.error(f"Error generating music with MusicGen Medium: {e}")
148
+ raise
149
+
150
+ # Function to merge audio and video into a final video file using moviepy
151
+ def merge_audio_video(video_path, audio_path, output_path="output.mp4"):
152
+ try:
153
+ logging.info("Merging audio and video using moviepy...")
154
+ video_clip = VideoFileClip(video_path)
155
+ audio_clip = AudioFileClip(audio_path)
156
+ final_clip = video_clip.set_audio(audio_clip)
157
+ final_clip.write_videofile(output_path, codec="libx264", audio_codec="aac")
158
+ logging.info(f"Final video saved to: {output_path}")
159
+ return output_path
160
+ except Exception as e:
161
+ logging.error(f"Error merging audio and video: {e}")
162
+ return None
163
+
164
+ # Main processing function to handle video upload, scene analysis, and video output
165
+ def process_video(video_path, progress=gr.Progress()):
166
+ try:
167
+ progress(0.1, desc="Starting video processing...")
168
+ logging.info("Starting video processing...")
169
+
170
+ # Analyze the video to determine the scene type
171
+ progress(0.3, desc="Analyzing video...")
172
+ scene_type = analyze_video(video_path)
173
+
174
+ # Generate audio using AudioGen Medium
175
+ progress(0.5, desc="Generating audio...")
176
+ audio_path = generate_audio_audiogen(scene_type, duration=10)
177
+
178
+ # Generate music using MusicGen Medium
179
+ progress(0.7, desc="Generating music...")
180
+ music_path = generate_music_musicgen(scene_type, duration=10)
181
+
182
+ # Merge the generated audio with the video and output the final video
183
+ progress(0.9, desc="Merging audio and video...")
184
+ output_path = merge_audio_video(video_path, music_path)
185
+ if not output_path:
186
+ return "Error: Failed to merge audio and video.", "Logs: Merge failed."
187
+
188
+ logging.info("Video processing completed successfully.")
189
+ return output_path, "Logs: Processing completed."
190
+ except Exception as e:
191
+ logging.error(f"Error in process_video: {e}")
192
+ return f"An error occurred during processing: {e}", f"Logs: {e}"
193
+
194
+ # Gradio UI for video upload
195
+ def gradio_interface(video_file, progress=gr.Progress()):
196
+ try:
197
+ progress(0.1, desc="Starting video processing...")
198
+ logging.info("Gradio interface triggered.")
199
+ output_video, logs = process_video(video_file, progress)
200
+ return output_video, logs
201
+ except Exception as e:
202
+ logging.error(f"Error in Gradio interface: {e}")
203
+ return f"An error occurred: {e}", f"Logs: {e}"
204
+
205
+ # Launch Gradio app
206
+ try:
207
+ logging.info("Launching Gradio app...")
208
+ interface = gr.Interface(
209
+ fn=gradio_interface,
210
+ inputs=[gr.Video(label="Upload Video")],
211
+ outputs=[gr.Video(label="Output Video with Generated Audio"), gr.Textbox(label="Logs", lines=10)],
212
+ title="Video to Video with Generated Audio and Music",
213
+ description="Upload a video, and this app will analyze it and generate matching audio and music using AudioGen Medium and MusicGen Medium."
214
+ )
215
+ interface.queue() # Enable queue for long-running tasks
216
+ interface.launch(share=True) # Launch the app
217
+ except Exception as e:
218
+ logging.error(f"Error launching Gradio app: {e}")
219
+ raise