kjysmu committed · verified
Commit b036e59 · 1 Parent(s): 2dfd92b

Upload 3 files

Files changed (3)
  1. demo.ipynb +72 -0
  2. music2emo.py +511 -0
  3. requirements.txt +0 -1
demo.ipynb ADDED
@@ -0,0 +1,72 @@
+ {
+  "cells": [
+   {
+    "cell_type": "code",
+    "execution_count": 1,
+    "metadata": {},
+    "outputs": [
+     {
+      "name": "stderr",
+      "output_type": "stream",
+      "text": [
+       "I music2emo 02-10 03:58:55.459 music2emo.py:280] audio file loaded and feature computation success : inference/input/test.mp3\n"
+      ]
+     },
+     {
+      "name": "stdout",
+      "output_type": "stream",
+      "text": [
+       "\n",
+       "🎵 **Music Emotion Recognition Results** 🎵\n",
+       "--------------------------------------------------\n",
+       "🎭 **Predicted Mood Tags:** ballad, calm, film, hopeful, inspiring, love, meditative, melancholic, relaxing, romantic, sad, soft\n",
+       "💖 **Valence:** 5.42 (Scale: 1-9)\n",
+       "⚡ **Arousal:** 4.16 (Scale: 1-9)\n",
+       "--------------------------------------------------\n"
+      ]
+     }
+    ],
+    "source": [
+     "from music2emo import Music2emo\n",
+     "\n",
+     "input_audio = \"inference/input/test.mp3\"\n",
+     "\n",
+     "music2emo = Music2emo()\n",
+     "output_dic = music2emo.predict(input_audio)\n",
+     "\n",
+     "valence = output_dic[\"valence\"]\n",
+     "arousal = output_dic[\"arousal\"]\n",
+     "predicted_moods = output_dic[\"predicted_moods\"]\n",
+     "\n",
+     "print(\"\\n🎵 **Music Emotion Recognition Results** 🎵\")\n",
+     "print(\"-\" * 50)\n",
+     "print(f\"🎭 **Predicted Mood Tags:** {', '.join(predicted_moods) if predicted_moods else 'None'}\")\n",
+     "print(f\"💖 **Valence:** {valence:.2f} (Scale: 1-9)\")\n",
+     "print(f\"⚡ **Arousal:** {arousal:.2f} (Scale: 1-9)\")\n",
+     "print(\"-\" * 50)\n",
+     "\n"
+    ]
+   }
+  ],
+  "metadata": {
+   "kernelspec": {
+    "display_name": "music2emo",
+    "language": "python",
+    "name": "python3"
+   },
+   "language_info": {
+    "codemirror_mode": {
+     "name": "ipython",
+     "version": 3
+    },
+    "file_extension": ".py",
+    "mimetype": "text/x-python",
+    "name": "python",
+    "nbconvert_exporter": "python",
+    "pygments_lexer": "ipython3",
+    "version": "3.10.14"
+   }
+  },
+  "nbformat": 4,
+  "nbformat_minor": 2
+ }
music2emo.py ADDED
@@ -0,0 +1,511 @@
+ import os
+ import argparse
+ import json
+ import logging
+ import shutil
+ import warnings
+ from pathlib import Path
+
+ import mir_eval
+ import pretty_midi as pm
+ import numpy as np
+ import torch
+ import torchaudio
+ import torchaudio.transforms as T
+ import hydra
+ from hydra.utils import to_absolute_path
+ from omegaconf import DictConfig
+ from music21 import converter
+ from tqdm import tqdm
+ from transformers import Wav2Vec2FeatureExtractor, AutoModel
+
+ from utils import logger
+ from utils.btc_model import BTC_model
+ from utils.transformer_modules import *
+ from utils.transformer_modules import _gen_timing_signal, _gen_bias_mask
+ from utils.hparams import HParams
+ from utils.mir_eval_modules import audio_file_to_features, idx2chord, idx2voca_chord, get_audio_paths, get_lab_paths
+ from utils.mert import FeatureExtractorMERT
+ from model.linear_mt_attn_ck import FeedforwardModelMTAttnCK
+
+ logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR)
+
+ PITCH_CLASS = ['C', 'C#', 'D', 'D#', 'E', 'F', 'F#', 'G', 'G#', 'A', 'A#', 'B']
+
+ pitch_num_dic = {
+     'C': 0, 'C#': 1, 'D': 2, 'D#': 3, 'E': 4, 'F': 5,
+     'F#': 6, 'G': 7, 'G#': 8, 'A': 9, 'A#': 10, 'B': 11
+ }
+
+ minor_major_dic = {
+     'D-': 'C#', 'E-': 'D#', 'G-': 'F#', 'A-': 'G#', 'B-': 'A#'
+ }
+ minor_major_dic2 = {
+     'Db': 'C#', 'Eb': 'D#', 'Gb': 'F#', 'Ab': 'G#', 'Bb': 'A#'
+ }
+
+ shift_major_dic = {
+     'C': 0, 'C#': 1, 'D': 2, 'D#': 3, 'E': 4, 'F': 5,
+     'F#': 6, 'G': 7, 'G#': 8, 'A': 9, 'A#': 10, 'B': 11
+ }
+
+ shift_minor_dic = {
+     'A': 0, 'A#': 1, 'B': 2, 'C': 3, 'C#': 4, 'D': 5,
+     'D#': 6, 'E': 7, 'F': 8, 'F#': 9, 'G': 10, 'G#': 11,
+ }
+
+ flat_to_sharp_mapping = {
+     "Cb": "B",
+     "Db": "C#",
+     "Eb": "D#",
+     "Fb": "E",
+     "Gb": "F#",
+     "Ab": "G#",
+     "Bb": "A#"
+ }
+
+ segment_duration = 30
+ resample_rate = 24000
+ is_split = True
+
+ def normalize_chord(file_path, key, key_type='major'):
+     with open(file_path, 'r') as f:
+         lines = f.readlines()
+
+     if key == "None":
+         new_key = "C major"
+         shift = 0
+     else:
+         if len(key) == 1:
+             key = key[0].upper()
+         else:
+             key = key[0].upper() + key[1:]
+
+         if key in minor_major_dic2:
+             key = minor_major_dic2[key]
+
+         shift = 0
+
+         if key_type == "major":
+             new_key = "C major"
+             shift = shift_major_dic[key]
+         else:
+             new_key = "A minor"
+             shift = shift_minor_dic[key]
+
+     converted_lines = []
+     for line in lines:
+         if line.strip():  # Skip empty lines
+             parts = line.split()
+             start_time = parts[0]
+             end_time = parts[1]
+             chord = parts[2]  # The chord is in the 3rd column
+             if chord == "N":
+                 newchordnorm = "N"
+             elif chord == "X":
+                 newchordnorm = "X"
+             elif ":" in chord:
+                 pitch = chord.split(":")[0]
+                 attr = chord.split(":")[1]
+                 pnum = pitch_num_dic[pitch]
+                 new_idx = (pnum - shift) % 12
+                 newchord = PITCH_CLASS[new_idx]
+                 newchordnorm = newchord + ":" + attr
+             else:
+                 pitch = chord
+                 pnum = pitch_num_dic[pitch]
+                 new_idx = (pnum - shift) % 12
+                 newchord = PITCH_CLASS[new_idx]
+                 newchordnorm = newchord
+
+             converted_lines.append(f"{start_time} {end_time} {newchordnorm}\n")
+
+     return converted_lines
+
+
+ def sanitize_key_signature(key):
+     return key.replace('-', 'b')
+
+
+ def resample_waveform(waveform, original_sample_rate, target_sample_rate):
+     if original_sample_rate != target_sample_rate:
+         resampler = T.Resample(original_sample_rate, target_sample_rate)
+         return resampler(waveform), target_sample_rate
+     return waveform, original_sample_rate
+
+
+ def split_audio(waveform, sample_rate):
+     segment_samples = segment_duration * sample_rate
+     total_samples = waveform.size(0)
+
+     segments = []
+     for start in range(0, total_samples, segment_samples):
+         end = start + segment_samples
+         if end <= total_samples:
+             segment = waveform[start:end]
+             segments.append(segment)
+
+     # In case audio length is shorter than segment length.
+     if len(segments) == 0:
+         segment = waveform
+         segments.append(segment)
+
+     return segments
+
+
+ class Music2emo:
+     def __init__(
+         self,
+         model_weights="saved_models/J_all.ckpt"
+     ):
+         use_cuda = torch.cuda.is_available()
+         self.device = torch.device("cuda" if use_cuda else "cpu")
+
+         self.feature_extractor = FeatureExtractorMERT(model_name='m-a-p/MERT-v1-95M', device=self.device, sr=resample_rate)
+         self.model_weights = model_weights
+
+         self.music2emo_model = FeedforwardModelMTAttnCK(
+             input_size=768 * 2,
+             output_size_classification=56,
+             output_size_regression=2
+         )
+
+         checkpoint = torch.load(self.model_weights, map_location=self.device, weights_only=False)
+         state_dict = checkpoint["state_dict"]
+
+         # Adjust the keys in the state_dict
+         state_dict = {key.replace("model.", ""): value for key, value in state_dict.items()}
+
+         # Filter state_dict to match the model's keys
+         model_keys = set(self.music2emo_model.state_dict().keys())
+         filtered_state_dict = {key: value for key, value in state_dict.items() if key in model_keys}
+
+         # Load the filtered state_dict and set the model to evaluation mode
+         self.music2emo_model.load_state_dict(filtered_state_dict)
+
+         self.music2emo_model.to(self.device)
+         self.music2emo_model.eval()
+
+     def predict(self, audio, threshold=0.5):
+         feature_dir = Path("./temp_out")
+         output_dir = Path("./output")
+         current_dir = Path("./")
+
+         if feature_dir.exists():
+             shutil.rmtree(str(feature_dir))
+         if output_dir.exists():
+             shutil.rmtree(str(output_dir))
+
+         feature_dir.mkdir(parents=True)
+         output_dir.mkdir(parents=True)
+
+         warnings.filterwarnings('ignore')
+         logger.logging_verbosity(1)
+
+         mert_dir = feature_dir / "mert"
+         mert_dir.mkdir(parents=True)
+
+         # --- MERT feature extraction ---
+         waveform, sample_rate = torchaudio.load(audio)
+         if waveform.shape[0] > 1:  # downmix stereo to mono
+             waveform = waveform.mean(dim=0).unsqueeze(0)
+         waveform = waveform.squeeze()
+         waveform, sample_rate = resample_waveform(waveform, sample_rate, resample_rate)
+
+         if is_split:
+             segments = split_audio(waveform, sample_rate)
+             for i, segment in enumerate(segments):
+                 segment_save_path = os.path.join(mert_dir, f"segment_{i}.npy")
+                 self.feature_extractor.extract_features_from_segment(segment, sample_rate, segment_save_path)
+         else:
+             segment_save_path = os.path.join(mert_dir, "segment_0.npy")
+             self.feature_extractor.extract_features_from_segment(waveform, sample_rate, segment_save_path)
+
+         layers_to_extract = [5, 6]
+         segment_embeddings = []
+         for filename in sorted(os.listdir(mert_dir)):  # Sort files to ensure sequential order
+             file_path = os.path.join(mert_dir, filename)
+             if os.path.isfile(file_path) and filename.endswith('.npy'):
+                 segment = np.load(file_path)
+                 concatenated_features = np.concatenate(
+                     [segment[:, layer_idx, :] for layer_idx in layers_to_extract], axis=1
+                 )
+                 concatenated_features = np.squeeze(concatenated_features)  # Shape: 768 * 2 = 1536
+                 segment_embeddings.append(concatenated_features)
+
+         # Average the per-segment MERT embeddings into one track-level embedding.
+         segment_embeddings = np.array(segment_embeddings)
+         if len(segment_embeddings) > 0:
+             final_embedding_mert = np.mean(segment_embeddings, axis=0)
+         else:
+             final_embedding_mert = np.zeros((1536,))
+
+         final_embedding_mert = torch.from_numpy(final_embedding_mert).float().to(self.device)
+
+         # --- Chord feature extraction (BTC model) ---
+         config = HParams.load("./inference/data/run_config.yaml")
+         config.feature['large_voca'] = True
+         config.model['num_chords'] = 170
+         model_file = './inference/data/btc_model_large_voca.pt'
+         idx_to_chord = idx2voca_chord()
+         model = BTC_model(config=config.model).to(self.device)
+
+         if os.path.isfile(model_file):
+             checkpoint = torch.load(model_file, map_location=self.device)
+             mean = checkpoint['mean']
+             std = checkpoint['std']
+             model.load_state_dict(checkpoint['model'])
+
+         audio_path = audio
+         audio_id = audio_path.split("/")[-1][:-4]
+         try:
+             feature, feature_per_second, song_length_second = audio_file_to_features(audio_path, config)
+         except Exception:
+             logger.info("audio file failed to load : %s" % audio_path)
+             raise
+
+         logger.info("audio file loaded and feature computation success : %s" % audio_path)
+
+         feature = feature.T
+         feature = (feature - mean) / std
+         time_unit = feature_per_second
+         n_timestep = config.model['timestep']
+
+         num_pad = n_timestep - (feature.shape[0] % n_timestep)
+         feature = np.pad(feature, ((0, num_pad), (0, 0)), mode="constant", constant_values=0)
+         num_instance = feature.shape[0] // n_timestep
+
+         start_time = 0.0
+         lines = []
+         with torch.no_grad():
+             model.eval()
+             feature = torch.tensor(feature, dtype=torch.float32).unsqueeze(0).to(self.device)
+             for t in range(num_instance):
+                 self_attn_output, _ = model.self_attn_layers(feature[:, n_timestep * t:n_timestep * (t + 1), :])
+                 prediction, _ = model.output_layer(self_attn_output)
+                 prediction = prediction.squeeze()
+                 for i in range(n_timestep):
+                     if t == 0 and i == 0:
+                         prev_chord = prediction[i].item()
+                         continue
+                     if prediction[i].item() != prev_chord:
+                         lines.append('%.3f %.3f %s\n' % (start_time, time_unit * (n_timestep * t + i), idx_to_chord[prev_chord]))
+                         start_time = time_unit * (n_timestep * t + i)
+                         prev_chord = prediction[i].item()
+                     if t == num_instance - 1 and i + num_pad == n_timestep:
+                         if start_time != time_unit * (n_timestep * t + i):
+                             lines.append('%.3f %.3f %s\n' % (start_time, time_unit * (n_timestep * t + i), idx_to_chord[prev_chord]))
+                         break
+
+         # Write the chord predictions to a .lab file.
+         save_path = os.path.join(feature_dir, os.path.split(audio_path)[-1].replace('.mp3', '').replace('.wav', '') + '.lab')
+         with open(save_path, 'w') as f:
+             for line in lines:
+                 f.write(line)
+
+         # lab file to midi file
+         starts, ends, pitchs = list(), list(), list()
+
+         intervals, chords = mir_eval.io.load_labeled_intervals(save_path)
+         for p in range(12):
+             for i, (interval, chord) in enumerate(zip(intervals, chords)):
+                 root_num, relative_bitmap, _ = mir_eval.chord.encode(chord)
+                 tmp_label = mir_eval.chord.rotate_bitmap_to_root(relative_bitmap, root_num)[p]
+                 if i == 0:
+                     start_time = interval[0]
+                     label = tmp_label
+                     continue
+                 if tmp_label != label:
+                     if label == 1.0:
+                         starts.append(start_time)
+                         ends.append(interval[0])
+                         pitchs.append(p + 48)
+                     start_time = interval[0]
+                     label = tmp_label
+                 if i == (len(intervals) - 1):
+                     if label == 1.0:
+                         starts.append(start_time)
+                         ends.append(interval[1])
+                         pitchs.append(p + 48)
+
+         midi = pm.PrettyMIDI()
+         instrument = pm.Instrument(program=0)
+
+         for start, end, pitch in zip(starts, ends, pitchs):
+             pm_note = pm.Note(velocity=120, pitch=pitch, start=start, end=end)
+             instrument.notes.append(pm_note)
+
+         midi.instruments.append(instrument)
+         midi.write(save_path.replace('.lab', '.midi'))
+
+         tonic_signatures = ["A", "A#", "B", "C", "C#", "D", "D#", "E", "F", "F#", "G", "G#"]
+         mode_signatures = ["major", "minor"]
+
+         tonic_to_idx = {tonic: idx for idx, tonic in enumerate(tonic_signatures)}
+         mode_to_idx = {mode: idx for idx, mode in enumerate(mode_signatures)}
+         idx_to_tonic = {idx: tonic for tonic, idx in tonic_to_idx.items()}
+         idx_to_mode = {idx: mode for mode, idx in mode_to_idx.items()}
+
+         with open('inference/data/chord.json', 'r') as f:
+             chord_to_idx = json.load(f)
+         with open('inference/data/chord_inv.json', 'r') as f:
+             idx_to_chord = json.load(f)
+             idx_to_chord = {int(k): v for k, v in idx_to_chord.items()}  # Ensure keys are ints
+         with open('inference/data/chord_root.json') as json_file:
+             chordRootDic = json.load(json_file)
+         with open('inference/data/chord_attr.json') as json_file:
+             chordAttrDic = json.load(json_file)
+
+         try:
+             midi_file = converter.parse(save_path.replace('.lab', '.midi'))
+             key_signature = str(midi_file.analyze('key'))
+         except Exception:
+             key_signature = "None"
+
+         key_parts = key_signature.split()
+         key_signature = sanitize_key_signature(key_parts[0])  # Sanitize key signature
+         key_type = key_parts[1] if len(key_parts) > 1 else 'major'
+
+         # --- Key feature (Tonic and Mode separation) ---
+         if key_signature == "None":
+             mode = "major"
+         else:
+             mode = key_type  # "major" or "minor" from the key analysis above
+
+         encoded_mode = mode_to_idx.get(mode, 0)
+         mode_tensor = torch.tensor([encoded_mode], dtype=torch.long).to(self.device)
+
+         converted_lines = normalize_chord(save_path, key_signature, key_type)
+
+         lab_norm_path = save_path[:-4] + "_norm.lab"
+
+         # Write the converted lines to the new file
+         with open(lab_norm_path, 'w') as f:
+             f.writelines(converted_lines)
+
+         chords = []
+
+         if not os.path.exists(lab_norm_path):
+             chords.append((float(0), float(0), "N"))
+         else:
+             with open(lab_norm_path, 'r') as file:
+                 for line in file:
+                     start, end, chord = line.strip().split()
+                     chords.append((float(start), float(end), chord))
+
+         encoded = []
+         encoded_root = []
+         encoded_attr = []
+         durations = []
+
+         for start, end, chord in chords:
+             chord_arr = chord.split(":")
+             if len(chord_arr) == 1:
+                 chordRootID = chordRootDic[chord_arr[0]]
+                 if chord_arr[0] == "N" or chord_arr[0] == "X":
+                     chordAttrID = 0
+                 else:
+                     chordAttrID = 1
+             elif len(chord_arr) == 2:
+                 chordRootID = chordRootDic[chord_arr[0]]
+                 chordAttrID = chordAttrDic[chord_arr[1]]
+             encoded_root.append(chordRootID)
+             encoded_attr.append(chordAttrID)
+
+             if chord in chord_to_idx:
+                 encoded.append(chord_to_idx[chord])
+             else:
+                 print(f"Warning: Chord {chord} not found in chord.json. Skipping.")
+
+             durations.append(end - start)  # Compute duration
+
+         encoded_chords = np.array(encoded)
+         encoded_chords_root = np.array(encoded_root)
+         encoded_chords_attr = np.array(encoded_attr)
+
+         # Maximum sequence length for chords
+         max_sequence_length = 100  # Define this globally or as a parameter
+
+         # Truncate or pad chord sequences
+         if len(encoded_chords) > max_sequence_length:
+             # Truncate to max length
+             encoded_chords = encoded_chords[:max_sequence_length]
+             encoded_chords_root = encoded_chords_root[:max_sequence_length]
+             encoded_chords_attr = encoded_chords_attr[:max_sequence_length]
+         else:
+             # Pad with zeros (padding value for chords)
+             padding = [0] * (max_sequence_length - len(encoded_chords))
+             encoded_chords = np.concatenate([encoded_chords, padding])
+             encoded_chords_root = np.concatenate([encoded_chords_root, padding])
+             encoded_chords_attr = np.concatenate([encoded_chords_attr, padding])
+
+         # Convert to tensors
+         chords_tensor = torch.tensor(encoded_chords, dtype=torch.long).to(self.device)
+         chords_root_tensor = torch.tensor(encoded_chords_root, dtype=torch.long).to(self.device)
+         chords_attr_tensor = torch.tensor(encoded_chords_attr, dtype=torch.long).to(self.device)
+
+         model_input_dic = {
+             "x_mert": final_embedding_mert.unsqueeze(0),
+             "x_chord": chords_tensor.unsqueeze(0),
+             "x_chord_root": chords_root_tensor.unsqueeze(0),
+             "x_chord_attr": chords_attr_tensor.unsqueeze(0),
+             "x_key": mode_tensor.unsqueeze(0)
+         }
+         model_input_dic = {k: v.to(self.device) for k, v in model_input_dic.items()}
+
+         with torch.no_grad():
+             classification_output, regression_output = self.music2emo_model(model_input_dic)
+         probs = torch.sigmoid(classification_output)
+
+         # The tag list entries from index 127 onward carry the "mood/theme---" prefix; strip it to get mood names.
+         tag_list = np.load("./inference/data/tag_list.npy")
+         tag_list = tag_list[127:]
+         mood_list = [t.replace("mood/theme---", "") for t in tag_list]
+         predicted_moods = [mood_list[i] for i, p in enumerate(probs.squeeze().tolist()) if p > threshold]
+
+         valence, arousal = regression_output.squeeze().tolist()
+
+         model_output_dic = {
+             "valence": valence,
+             "arousal": arousal,
+             "predicted_moods": predicted_moods
+         }
+
+         return model_output_dic
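
For reference, a minimal sketch of driving the class above outside the demo notebook, using the optional threshold argument of predict() to keep only higher-confidence mood tags. The audio path and the 0.7 value are illustrative, and the default checkpoint (saved_models/J_all.ckpt) plus the inference/data assets are assumed to be in place:

from music2emo import Music2emo

# Illustrative input; any .mp3 or .wav file should work.
input_audio = "inference/input/test.mp3"

# Loads the default checkpoint onto GPU if available, otherwise CPU.
music2emo = Music2emo()

# Raise the mood-tag probability threshold from the default 0.5 to 0.7.
output_dic = music2emo.predict(input_audio, threshold=0.7)

print(f"Valence (1-9): {output_dic['valence']:.2f}")
print(f"Arousal (1-9): {output_dic['arousal']:.2f}")
print("Moods:", ", ".join(output_dic["predicted_moods"]) or "None")
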
requirements.txt CHANGED
@@ -18,4 +18,3 @@ torchaudio==2.3.1
  torchmetrics==1.4.1
  tqdm==4.66.5
  transformers==4.44.0
- gradio==5.15.0
 