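"""Gradio demo for Music2Emo: music emotion recognition from audio.

Pipeline sketch: MERT embeddings and BTC chord recognition (plus a key
estimate) feed a multi-task feedforward model that predicts mood tags
(56 classes) and valence/arousal scores on a 1-9 scale.
"""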
import os
import shutil
import json
import logging
import warnings

import numpy as np
import torch
import torchaudio
import torchaudio.transforms as T
import matplotlib.pyplot as plt
import mir_eval
import pretty_midi as pm
import gradio as gr
from pathlib import Path
from music21 import converter

from utils import logger
from utils.btc_model import BTC_model
from utils.hparams import HParams
from utils.mir_eval_modules import audio_file_to_features, idx2voca_chord
from utils.mert import FeatureExtractorMERT
from model.linear_mt_attn_ck import FeedforwardModelMTAttnCK

warnings.filterwarnings("ignore")
logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR)

PITCH_CLASS = ['C', 'C#', 'D', 'D#', 'E', 'F', 'F#', 'G', 'G#', 'A', 'A#', 'B']

tonic_signatures = ["A", "A#", "B", "C", "C#", "D", "D#", "E", "F", "F#", "G", "G#"]
mode_signatures = ["major", "minor"]

pitch_num_dic = {
    'C': 0, 'C#': 1, 'D': 2, 'D#': 3, 'E': 4, 'F': 5,
    'F#': 6, 'G': 7, 'G#': 8, 'A': 9, 'A#': 10, 'B': 11
}

# Flat-to-sharp spellings (music21 writes flats as '-', .lab files as 'b').
minor_major_dic = {
    'D-': 'C#', 'E-': 'D#', 'G-': 'F#', 'A-': 'G#', 'B-': 'A#'
}
minor_major_dic2 = {
    'Db': 'C#', 'Eb': 'D#', 'Gb': 'F#', 'Ab': 'G#', 'Bb': 'A#'
}

# Semitone shifts that transpose a given tonic to C major / A minor.
shift_major_dic = {
    'C': 0, 'C#': 1, 'D': 2, 'D#': 3, 'E': 4, 'F': 5,
    'F#': 6, 'G': 7, 'G#': 8, 'A': 9, 'A#': 10, 'B': 11
}

shift_minor_dic = {
    'A': 0, 'A#': 1, 'B': 2, 'C': 3, 'C#': 4, 'D': 5,
    'D#': 6, 'E': 7, 'F': 8, 'F#': 9, 'G': 10, 'G#': 11,
}

flat_to_sharp_mapping = {
    "Cb": "B",
    "Db": "C#",
    "Eb": "D#",
    "Fb": "E",
    "Gb": "F#",
    "Ab": "G#",
    "Bb": "A#"
}

segment_duration = 30    # seconds of audio per MERT segment
resample_rate = 24000    # MERT-v1-95M expects 24 kHz input
is_split = True

def normalize_chord(file_path, key, key_type='major'):
    """Transpose the chord labels in a .lab file to C major / A minor.

    Returns the converted lines; key == "None" means no transposition.
    """
    with open(file_path, 'r') as f:
        lines = f.readlines()

    if key == "None":
        shift = 0
    else:
        # Normalize capitalization (e.g. "c#" -> "C#") and flat spellings.
        if len(key) == 1:
            key = key.upper()
        else:
            key = key[0].upper() + key[1:]
        if key in minor_major_dic2:
            key = minor_major_dic2[key]

        if key_type == "major":
            shift = shift_major_dic[key]
        else:
            shift = shift_minor_dic[key]

    converted_lines = []
    for line in lines:
        if line.strip():
            parts = line.split()
            start_time, end_time, chord = parts[0], parts[1], parts[2]
            if chord in ("N", "X"):
                # "N" (no chord) and "X" (unknown) are kept as-is.
                newchordnorm = chord
            elif ":" in chord:
                pitch, attr = chord.split(":", 1)
                new_idx = (pitch_num_dic[pitch] - shift) % 12
                newchordnorm = PITCH_CLASS[new_idx] + ":" + attr
            else:
                new_idx = (pitch_num_dic[chord] - shift) % 12
                newchordnorm = PITCH_CLASS[new_idx]
            converted_lines.append(f"{start_time} {end_time} {newchordnorm}\n")

    return converted_lines

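# A quick worked example for normalize_chord: with key="D", key_type="major",
# shift_major_dic["D"] == 2, so a hypothetical .lab line "0.000 1.500 E:min"
# is transposed down two semitones to "0.000 1.500 D:min".
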
def sanitize_key_signature(key):
    """Map music21's flat notation ('-') to 'b' (e.g. 'E-' -> 'Eb')."""
    return key.replace('-', 'b')

def resample_waveform(waveform, original_sample_rate, target_sample_rate):
    """Resample to target_sample_rate if needed; returns (waveform, sr)."""
    if original_sample_rate != target_sample_rate:
        resampler = T.Resample(original_sample_rate, target_sample_rate)
        return resampler(waveform), target_sample_rate
    return waveform, original_sample_rate

def split_audio(waveform, sample_rate):
    """Split a 1-D waveform into non-overlapping segment_duration chunks.

    Only full segments are kept; a clip shorter than one segment is
    returned whole so at least one segment is always produced.
    """
    segment_samples = segment_duration * sample_rate
    total_samples = waveform.size(0)

    segments = []
    for start in range(0, total_samples, segment_samples):
        end = start + segment_samples
        if end <= total_samples:
            segments.append(waveform[start:end])

    if len(segments) == 0:
        segments.append(waveform)

    return segments

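# A quick sanity check of the segmentation arithmetic (hypothetical clip):
# at resample_rate 24 kHz, a 75 s mono clip has 1_800_000 samples, yielding
# two full 30 s segments (720_000 samples each); the trailing 15 s remainder
# is dropped, while a clip under 30 s would be returned whole.
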
def safe_remove_dir(directory):
    """Remove a directory tree if it exists, tolerating common errors."""
    directory = Path(directory)
    if directory.exists():
        try:
            shutil.rmtree(directory)
        except FileNotFoundError:
            print(f"Warning: Some files in {directory} were already deleted.")
        except PermissionError:
            print(f"Warning: Permission issue encountered while deleting {directory}.")
        except Exception as e:
            print(f"Unexpected error while deleting {directory}: {e}")

class Music2emo:
    def __init__(
        self,
        name="amaai-lab/music2emo",
        device="cuda:0",
        cache_dir=None,
        local_files_only=False,
    ):
        model_weights = "saved_models/J_all.ckpt"
        self.device = device

        # MERT feature extractor (24 kHz input).
        self.feature_extractor = FeatureExtractorMERT(
            model_name='m-a-p/MERT-v1-95M', device=self.device, sr=resample_rate
        )
        self.model_weights = model_weights

        # Multi-task head: 56 mood tags + 2 regression targets (valence, arousal).
        self.music2emo_model = FeedforwardModelMTAttnCK(
            input_size=768 * 2,
            output_size_classification=56,
            output_size_regression=2
        )

        checkpoint = torch.load(self.model_weights, map_location=self.device, weights_only=False)
        state_dict = checkpoint["state_dict"]

        # Strip the training-time "model." prefix and keep only keys the model defines.
        state_dict = {key.replace("model.", ""): value for key, value in state_dict.items()}
        model_keys = set(self.music2emo_model.state_dict().keys())
        filtered_state_dict = {key: value for key, value in state_dict.items() if key in model_keys}
        self.music2emo_model.load_state_dict(filtered_state_dict)

        self.music2emo_model.to(self.device)
        self.music2emo_model.eval()

        # BTC chord-recognition model (large vocabulary, 170 chord classes).
        self.config = HParams.load("./inference/data/run_config.yaml")
        self.config.feature['large_voca'] = True
        self.config.model['num_chords'] = 170
        model_file = './inference/data/btc_model_large_voca.pt'
        self.idx_to_voca = idx2voca_chord()
        self.btc_model = BTC_model(config=self.config.model).to(self.device)

        if os.path.isfile(model_file):
            checkpoint = torch.load(model_file, map_location=self.device)
            self.mean = checkpoint['mean']
            self.std = checkpoint['std']
            self.btc_model.load_state_dict(checkpoint['model'])
        else:
            # Fail fast: self.mean / self.std are required later in predict().
            raise FileNotFoundError(f"BTC model weights not found: {model_file}")

        self.tonic_to_idx = {tonic: idx for idx, tonic in enumerate(tonic_signatures)}
        self.mode_to_idx = {mode: idx for idx, mode in enumerate(mode_signatures)}
        self.idx_to_tonic = {idx: tonic for tonic, idx in self.tonic_to_idx.items()}
        self.idx_to_mode = {idx: mode for mode, idx in self.mode_to_idx.items()}

        with open('inference/data/chord.json', 'r') as f:
            self.chord_to_idx = json.load(f)
        with open('inference/data/chord_inv.json', 'r') as f:
            self.idx_to_chord = json.load(f)
        self.idx_to_chord = {int(k): v for k, v in self.idx_to_chord.items()}
        with open('inference/data/chord_root.json') as json_file:
            self.chordRootDic = json.load(json_file)
        with open('inference/data/chord_attr.json') as json_file:
            self.chordAttrDic = json.load(json_file)

    def predict(self, audio, threshold=0.5):
        """Run the full pipeline on an audio file path and return a dict
        with mood tags and valence/arousal predictions."""
        feature_dir = Path("./inference/temp_out")
        output_dir = Path("./inference/output")

        # Start from clean scratch directories.
        safe_remove_dir(feature_dir)
        safe_remove_dir(output_dir)
        feature_dir.mkdir(parents=True, exist_ok=True)
        output_dir.mkdir(parents=True, exist_ok=True)

        warnings.filterwarnings('ignore')
        logger.logging_verbosity(1)

        mert_dir = feature_dir / "mert"
        mert_dir.mkdir(parents=True, exist_ok=True)

        # Load audio, downmix to mono, and resample to the MERT rate.
        waveform, sample_rate = torchaudio.load(audio)
        if waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0).unsqueeze(0)
        waveform = waveform.squeeze()
        waveform, sample_rate = resample_waveform(waveform, sample_rate, resample_rate)

        if is_split:
            segments = split_audio(waveform, sample_rate)
            for i, segment in enumerate(segments):
                segment_save_path = os.path.join(mert_dir, f"segment_{i}.npy")
                self.feature_extractor.extract_features_from_segment(segment, sample_rate, segment_save_path)
        else:
            segment_save_path = os.path.join(mert_dir, "segment_0.npy")
            self.feature_extractor.extract_features_from_segment(waveform, sample_rate, segment_save_path)

        # Concatenate two MERT hidden-state layers (indices 5 and 6) per
        # segment, then average across segments into one 1536-d embedding.
        layers_to_extract = [5, 6]
        segment_embeddings = []
        for filename in sorted(os.listdir(mert_dir)):
            file_path = os.path.join(mert_dir, filename)
            if os.path.isfile(file_path) and filename.endswith('.npy'):
                segment = np.load(file_path)
                concatenated_features = np.concatenate(
                    [segment[:, layer_idx, :] for layer_idx in layers_to_extract], axis=1
                )
                segment_embeddings.append(np.squeeze(concatenated_features))

        segment_embeddings = np.array(segment_embeddings)
        if len(segment_embeddings) > 0:
            final_embedding_mert = np.mean(segment_embeddings, axis=0)
        else:
            final_embedding_mert = np.zeros((1536,), dtype=np.float32)

        final_embedding_mert = torch.from_numpy(final_embedding_mert).to(self.device)

        # Chord recognition (BTC): compute features for the full file.
        audio_path = audio
        try:
            feature, feature_per_second, song_length_second = audio_file_to_features(audio_path, self.config)
        except Exception:
            logger.info("audio file failed to load : %s" % audio_path)
            raise
        logger.info("audio file loaded and feature computation success : %s" % audio_path)

        feature = feature.T
        feature = (feature - self.mean) / self.std
        time_unit = feature_per_second
        n_timestep = self.config.model['timestep']

        # Zero-pad so the frame count is a multiple of the model timestep.
        num_pad = n_timestep - (feature.shape[0] % n_timestep)
        feature = np.pad(feature, ((0, num_pad), (0, 0)), mode="constant", constant_values=0)
        num_instance = feature.shape[0] // n_timestep

        # Decode frame-wise chord predictions into (start, end, chord) spans.
        start_time = 0.0
        lines = []
        with torch.no_grad():
            self.btc_model.eval()
            feature = torch.tensor(feature, dtype=torch.float32).unsqueeze(0).to(self.device)
            for t in range(num_instance):
                self_attn_output, _ = self.btc_model.self_attn_layers(feature[:, n_timestep * t:n_timestep * (t + 1), :])
                prediction, _ = self.btc_model.output_layer(self_attn_output)
                prediction = prediction.squeeze()
                for i in range(n_timestep):
                    if t == 0 and i == 0:
                        prev_chord = prediction[i].item()
                        continue
                    if prediction[i].item() != prev_chord:
                        lines.append(
                            '%.3f %.3f %s\n' % (start_time, time_unit * (n_timestep * t + i), self.idx_to_voca[prev_chord]))
                        start_time = time_unit * (n_timestep * t + i)
                        prev_chord = prediction[i].item()
                    if t == num_instance - 1 and i + num_pad == n_timestep:
                        if start_time != time_unit * (n_timestep * t + i):
                            lines.append('%.3f %.3f %s\n' % (start_time, time_unit * (n_timestep * t + i), self.idx_to_voca[prev_chord]))
                        break

        # Write the chord annotations in .lab format ("start end chord").
        save_path = os.path.join(feature_dir, os.path.split(audio_path)[-1].replace('.mp3', '').replace('.wav', '') + '.lab')
        with open(save_path, 'w') as f:
            for line in lines:
                f.write(line)

        # Convert chord intervals into per-pitch-class note on/off events.
        starts, ends, pitchs = [], [], []
        intervals, chords = mir_eval.io.load_labeled_intervals(save_path)
        for p in range(12):
            for i, (interval, chord) in enumerate(zip(intervals, chords)):
                root_num, relative_bitmap, _ = mir_eval.chord.encode(chord)
                tmp_label = mir_eval.chord.rotate_bitmap_to_root(relative_bitmap, root_num)[p]
                if i == 0:
                    start_time = interval[0]
                    label = tmp_label
                    continue
                if tmp_label != label:
                    if label == 1.0:
                        starts.append(start_time)
                        ends.append(interval[0])
                        pitchs.append(p + 48)
                    start_time = interval[0]
                    label = tmp_label
                if i == (len(intervals) - 1):
                    if label == 1.0:
                        starts.append(start_time)
                        ends.append(interval[1])
                        pitchs.append(p + 48)

        # Render the chord notes to MIDI for key analysis with music21.
        midi = pm.PrettyMIDI()
        instrument = pm.Instrument(program=0)
        for start, end, pitch in zip(starts, ends, pitchs):
            instrument.notes.append(pm.Note(velocity=120, pitch=pitch, start=start, end=end))
        midi.instruments.append(instrument)
        midi.write(save_path.replace('.lab', '.midi'))

        # Estimate the global key from the chord MIDI; fall back to "None".
        try:
            midi_file = converter.parse(save_path.replace('.lab', '.midi'))
            key_signature = str(midi_file.analyze('key'))
        except Exception:
            key_signature = "None"

        key_parts = key_signature.split()
        key_signature = sanitize_key_signature(key_parts[0])
        key_type = key_parts[1] if len(key_parts) > 1 else 'major'

        # The mode is the second token of the analyzed key (e.g. "E- minor");
        # key_signature itself holds only the tonic at this point.
        mode = "major" if key_signature == "None" else key_type

        encoded_mode = self.mode_to_idx.get(mode, 0)
        mode_tensor = torch.tensor([encoded_mode], dtype=torch.long).to(self.device)

        # Transpose the .lab chords to C major / A minor and reload them.
        converted_lines = normalize_chord(save_path, key_signature, key_type)
        lab_norm_path = save_path[:-4] + "_norm.lab"
        with open(lab_norm_path, 'w') as f:
            f.writelines(converted_lines)

        chords = []
        if not os.path.exists(lab_norm_path):
            chords.append((float(0), float(0), "N"))
        else:
            with open(lab_norm_path, 'r') as file:
                for line in file:
                    start, end, chord = line.strip().split()
                    chords.append((float(start), float(end), chord))

        encoded = []
        encoded_root = []
        encoded_attr = []
        durations = []

        for start, end, chord in chords:
            if chord not in self.chord_to_idx:
                # Skip the whole event so the three sequences stay aligned.
                print(f"Warning: Chord {chord} not found in chord.json. Skipping.")
                continue

            chord_arr = chord.split(":")
            if len(chord_arr) == 1:
                chordRootID = self.chordRootDic[chord_arr[0]]
                # "N"/"X" carry no quality; bare roots default to attribute 1.
                if chord_arr[0] in ("N", "X"):
                    chordAttrID = 0
                else:
                    chordAttrID = 1
            elif len(chord_arr) == 2:
                chordRootID = self.chordRootDic[chord_arr[0]]
                chordAttrID = self.chordAttrDic[chord_arr[1]]

            encoded.append(self.chord_to_idx[chord])
            encoded_root.append(chordRootID)
            encoded_attr.append(chordAttrID)
            durations.append(end - start)

        encoded_chords = np.array(encoded)
        encoded_chords_root = np.array(encoded_root)
        encoded_chords_attr = np.array(encoded_attr)

        # Truncate or zero-pad each sequence to a fixed length of 100.
        max_sequence_length = 100
        if len(encoded_chords) > max_sequence_length:
            encoded_chords = encoded_chords[:max_sequence_length]
            encoded_chords_root = encoded_chords_root[:max_sequence_length]
            encoded_chords_attr = encoded_chords_attr[:max_sequence_length]
        else:
            padding = [0] * (max_sequence_length - len(encoded_chords))
            encoded_chords = np.concatenate([encoded_chords, padding])
            encoded_chords_root = np.concatenate([encoded_chords_root, padding])
            encoded_chords_attr = np.concatenate([encoded_chords_attr, padding])

        chords_tensor = torch.tensor(encoded_chords, dtype=torch.long).to(self.device)
        chords_root_tensor = torch.tensor(encoded_chords_root, dtype=torch.long).to(self.device)
        chords_attr_tensor = torch.tensor(encoded_chords_attr, dtype=torch.long).to(self.device)

        # Add a batch dimension of 1 for the single input clip.
        model_input_dic = {
            "x_mert": final_embedding_mert.unsqueeze(0),
            "x_chord": chords_tensor.unsqueeze(0),
            "x_chord_root": chords_root_tensor.unsqueeze(0),
            "x_chord_attr": chords_attr_tensor.unsqueeze(0),
            "x_key": mode_tensor.unsqueeze(0)
        }

        model_input_dic = {k: v.to(self.device) for k, v in model_input_dic.items()}
        classification_output, regression_output = self.music2emo_model(model_input_dic)

        # Keep only the "mood/theme---*" entries of the tag list; the first
        # 127 entries are other tag families, the rest match the 56 classes.
        tag_list = np.load("./inference/data/tag_list.npy")
        tag_list = tag_list[127:]
        mood_list = [t.replace("mood/theme---", "") for t in tag_list]

        probs = torch.sigmoid(classification_output).squeeze().tolist()

        predicted_moods_with_scores = [
            {"mood": mood_list[i], "score": round(p, 4)}
            for i, p in enumerate(probs) if p > threshold
        ]
        predicted_moods_with_scores.sort(key=lambda x: x["score"], reverse=True)

        predicted_moods_with_scores_all = [
            {"mood": mood_list[i], "score": round(p, 4)}
            for i, p in enumerate(probs)
        ]

        valence, arousal = regression_output.squeeze().tolist()

        model_output_dic = {
            "valence": valence,
            "arousal": arousal,
            "predicted_moods": predicted_moods_with_scores,
            "predicted_moods_all": predicted_moods_with_scores_all
        }

        return model_output_dic

# Instantiate a single shared model for the demo (GPU if available).
if torch.cuda.is_available():
    music2emo = Music2emo()
else:
    music2emo = Music2emo(device="cpu")
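
# A minimal programmatic usage sketch (the file name below is a placeholder,
# not shipped with this repo):
#
#   output = music2emo.predict("some_song.mp3", threshold=0.5)
#   print(output["valence"], output["arousal"])   # floats on a 1-9 scale
#   print(output["predicted_moods"])              # [{"mood": ..., "score": ...}, ...]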

def plot_mood_probabilities(predicted_moods_with_scores):
    """Plot the top-10 mood probabilities as a horizontal bar chart."""
    if not predicted_moods_with_scores:
        return None

    moods = [m["mood"] for m in predicted_moods_with_scores]
    probs = [m["score"] for m in predicted_moods_with_scores]

    # Sort by descending probability and keep the ten highest.
    sorted_indices = np.argsort(probs)[::-1]
    sorted_probs = [probs[i] for i in sorted_indices]
    sorted_moods = [moods[i] for i in sorted_indices]

    fig, ax = plt.subplots(figsize=(8, 4))
    ax.barh(sorted_moods[:10], sorted_probs[:10], color="#4CAF50")
    ax.set_xlabel("Probability")
    ax.set_title("Top 10 Predicted Mood Tags")
    ax.invert_yaxis()

    return fig

def plot_valence_arousal(valence, arousal):
    """Plot valence-arousal on a 2D circumplex model."""
    fig, ax = plt.subplots(figsize=(4, 4))
    ax.scatter(valence, arousal, color="red", s=100, label="Prediction")
    ax.set_xlim(1, 9)
    ax.set_ylim(1, 9)

    # Quadrant guides at the midpoint of the 1-9 scale.
    ax.axhline(y=5, color='gray', linestyle='--', linewidth=1)
    ax.axvline(x=5, color='gray', linestyle='--', linewidth=1)

    ax.set_xlabel("Valence (Positivity)")
    ax.set_ylabel("Arousal (Intensity)")
    ax.set_title("Valence-Arousal Plot")
    ax.legend()
    ax.grid(True, linestyle="--", alpha=0.6)

    return fig

def format_prediction(model_output_dic):
    """Format the model output for display in the Gradio UI."""
    valence = model_output_dic["valence"]
    arousal = model_output_dic["arousal"]
    predicted_moods_with_scores = model_output_dic["predicted_moods"]
    predicted_moods_with_scores_all = model_output_dic["predicted_moods_all"]

    va_chart = plot_valence_arousal(valence, arousal)
    mood_chart = plot_mood_probabilities(predicted_moods_with_scores_all)

    if predicted_moods_with_scores:
        moods_text = ", ".join(
            [f"{m['mood']} ({m['score']:.2f})" for m in predicted_moods_with_scores]
        )
    else:
        moods_text = "No significant moods detected."

    output_text = f"""🎭 Predicted Mood Tags: {moods_text}

💖 Valence: {valence:.2f} (Scale: 1-9)
⚡ Arousal: {arousal:.2f} (Scale: 1-9)"""

    return output_text, va_chart, mood_chart

title = "🎵 Music2Emo: Toward Unified Music Emotion Recognition"
description_text = """
<p> Upload an audio file to analyze its emotional characteristics with Music2Emo. The model predicts: 1) mood tags describing the emotional content, 2) a valence score (1-9 scale, representing emotional positivity), and 3) an arousal score (1-9 scale, representing emotional intensity).
<br/><br/> This is the demo of Music2Emo for music emotion recognition: <a href="https://arxiv.org/abs/2502.03979">read our paper.</a>
</p>
"""

css = """
.gradio-container {
    font-family: 'Inter', -apple-system, system-ui, sans-serif;
}
.gr-button {
    color: white;
    background: #4CAF50;
    border-radius: 8px;
    padding: 10px;
}
/* Add padding to the top of the two plot boxes */
.gr-box {
    padding-top: 25px !important;
}
"""

with gr.Blocks(css=css) as demo:
    gr.HTML(f"<h1 style='text-align: center;'>{title}</h1>")
    gr.Markdown(description_text)

    gr.Markdown("""
    ### 📝 Notes:
    - **Supported audio formats:** MP3, WAV
    - **Recommended:** High-quality audio files
    """)

    with gr.Row():
        with gr.Column(scale=1):
            input_audio = gr.Audio(
                label="Upload Audio File",
                type="filepath"
            )
            threshold = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                value=0.5,
                step=0.01,
                label="Mood Detection Threshold",
                info="Adjust threshold for mood detection"
            )
            predict_btn = gr.Button("🔍 Analyze Emotions", variant="primary")

        with gr.Column(scale=1):
            output_text = gr.Textbox(
                label="Analysis Results",
                lines=4,
                interactive=False
            )

    with gr.Row(equal_height=True):
        mood_chart = gr.Plot(label="Mood Probabilities", scale=2, elem_classes=["gr-box"])
        va_chart = gr.Plot(label="Valence-Arousal Space", scale=1, elem_classes=["gr-box"])

    predict_btn.click(
        fn=lambda audio, thresh: format_prediction(music2emo.predict(audio, thresh)),
        inputs=[input_audio, threshold],
        outputs=[output_text, va_chart, mood_chart]
    )

demo.queue().launch()