import os
import json
import shutil
import argparse
import warnings
import gradio as gr
from generate import generate_music, get_args
from utils import WEIGHTS_DIR, TEMP_DIR, LANG
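
# Gradio demo for EMelodyGen (emotion-conditioned melody generation): melodies
# can be generated either from an emotion template chosen via valence/arousal,
# or from explicit musical-feature controls; users can also save feature
# templates and submit emotion feedback on the generated results.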

EN2ZH = {
    "Cite": "引用",
    "Submit": "提交",
    "Feedback: the emotion you believe the generated result should belong to": "反馈:你所认为的生成结果该所属的情感",
    "Status": "状态栏",
    "Staff": "五线谱",
    "ABC notation": "ABC 记谱",
    "Download MXL": "下载 MXL",
    "Download MusicXML": "下载 MusicXML",
    "Download PDF score": "下载 PDF 乐谱",
    "Download MIDI": "下载 MIDI",
    "Audio": "音频",
    "Download template": "下载模板",
    "Save template": "保存模板",
    "The emotion to which the current template belongs": "当前模板所属情感",
    "Generate": "生成",
    "Generate chords coming soon": "生成和声控制暂不可用",
    "Volume in dB": "dB 音量调节",
    "±12 octave": "±12 八度上下移",
    "BPM tempo": "BPM 速度",
    "Minor": "小调",
    "Major": "大调",
    "Mode": "大小调",
    "Pitch SD": "音高标准差",
    "Low": "低",
    "High": "高",
    "By feature control": "通过特征控制生成",
    "By template": "通过模板生成",
    "Arousal: reflects the calmness-intensity of the emotion": "唤醒度 反映情绪的 平静-激烈 程度",
    "Valence: reflects negative-positive levels of emotion": "愉悦度 反映情绪的 消极-积极 程度",
    "Video demo": "视频教程",
    "Dataset": "数据集",
}

def _L(en_txt: str):
    # Return the English label as-is when LANG is set, otherwise a bilingual
    # "English (中文)" label, e.g. "Generate (生成)", looked up from EN2ZH.
    return en_txt if LANG else f"{en_txt} ({EN2ZH[en_txt]})"
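
# Emotion quadrants passed to generate_music: Q1 = high valence & high arousal,
# Q2 = low valence & high arousal, Q3 = low valence & low arousal,
# Q4 = high valence & low arousal (see the mappings in the two infer_* helpers).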
def infer_by_template(dataset: str, v: str, a: str, add_chord: bool):
    # add_chord mirrors the UI checkbox; chord generation is not available yet.
    status = "Success"
    audio = midi = pdf = xml = mxl = tunes = jpg = None
    # Map the valence/arousal choices to an emotion quadrant (default Q1).
    emotion = "Q1"
    if v == _L("Low") and a == _L("High"):
        emotion = "Q2"
    elif v == _L("Low") and a == _L("Low"):
        emotion = "Q3"
    elif v == _L("High") and a == _L("Low"):
        emotion = "Q4"
    try:
        parser = argparse.ArgumentParser()
        args = get_args(parser)
        args.template = True
        audio, midi, pdf, xml, mxl, tunes, jpg = generate_music(
            args,
            emo=emotion,
            weights=f"{WEIGHTS_DIR}/{dataset.lower()}/weights.pth",
        )
    except Exception as e:
        status = f"{e}"

    return status, audio, midi, pdf, xml, mxl, tunes, jpg

def infer_by_features(
    dataset: str,
    pitch_std: str,
    mode: str,
    tempo: int,
    octave: int,
    rms: int,
    add_chord: bool,
):
    status = "Success"
    audio = midi = pdf = xml = mxl = tunes = jpg = None
    # Map the mode / pitch-SD choices to an emotion quadrant (default Q1).
    emotion = "Q1"
    if mode == _L("Minor") and pitch_std == _L("High"):
        emotion = "Q2"
    elif mode == _L("Minor") and pitch_std == _L("Low"):
        emotion = "Q3"
    elif mode == _L("Major") and pitch_std == _L("Low"):
        emotion = "Q4"
    try:
        parser = argparse.ArgumentParser()
        args = get_args(parser)
        args.template = False
        audio, midi, pdf, xml, mxl, tunes, jpg = generate_music(
            args,
            emo=emotion,
            weights=f"{WEIGHTS_DIR}/{dataset.lower()}/weights.pth",
            fix_tempo=tempo,
            fix_pitch=octave,
            fix_volume=rms,
        )
    except Exception as e:
        status = f"{e}"

    return status, audio, midi, pdf, xml, mxl, tunes, jpg
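
# Both infer_* helpers return (status, audio, midi, pdf, xml, mxl, tunes, jpg);
# this order must match the `outputs` lists wired to gen_btn_1 / gen_btn_2 in
# the Gradio block below.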

def feedback(
    fixed_emo: str,
    source_dir=f"./{TEMP_DIR}/output",
    target_dir=f"./{TEMP_DIR}/feedback",
):
    try:
        if not fixed_emo:
            raise ValueError("Please select feedback before submitting!")

        os.makedirs(target_dir, exist_ok=True)
        for root, _, files in os.walk(source_dir):
            for file in files:
                if file.endswith(".mxl"):
                    # Generated .mxl files carry a leading "[Qx]" emotion tag in
                    # their names; recover that prompt emotion here.
                    prompt_emo = file.split("]")[0][1:]
                    if prompt_emo != fixed_emo:
                        file_path = os.path.join(root, file)
                        target_path = os.path.join(
                            target_dir, file.replace(".mxl", f"_{fixed_emo}.mxl")
                        )
                        shutil.copy(file_path, target_path)
                        return f"Copied {file_path} to {target_path}"

                    else:
                        return "Thanks for your feedback!"

        return "No .mxl files found in the source directory."

    except Exception as e:
        return f"{e}"

def save_template(label: str, pitch_std: str, mode: str, tempo: int, octave: int, rms: int):
    status = "Success"
    template = None
    try:
        if (
            label
            and pitch_std
            and mode
            and tempo is not None
            and octave is not None
            and rms is not None
        ):
            json_str = json.dumps(
                {
                    "label": label,
                    "pitch_std": pitch_std == _L("High"),
                    "mode": mode == _L("Major"),
                    "tempo": tempo,
                    "octave": octave,
                    "volume": rms,
                }
            )
            # Append the template as one JSON line; ensure the feedback
            # directory exists before opening the file in append mode.
            os.makedirs(f"./{TEMP_DIR}/feedback", exist_ok=True)
            with open(
                f"./{TEMP_DIR}/feedback/templates.jsonl",
                "a",
                encoding="utf-8",
            ) as file:
                file.write(json_str + "\n")

            template = f"./{TEMP_DIR}/feedback/templates.jsonl"

        else:
            raise ValueError("Please check features")

    except Exception as e:
        status = f"{e}"

    return status, template
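
# Each saved template is appended to templates.jsonl as one JSON object, e.g.
# (illustrative values): {"label": "Q1", "pitch_std": true, "mode": true,
# "tempo": 120, "octave": 0, "volume": 0}.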

if __name__ == "__main__":
    warnings.filterwarnings("ignore")
    with gr.Blocks() as demo:
        if LANG:
            gr.Markdown(
                "## The current CPU-based version on HuggingFace has slow inference; you can access the GPU-based mirror on [ModelScope](https://www.modelscope.cn/studios/monetjoe/EMelodyGen)"
            )

        with gr.Row():
            with gr.Column():
                gr.Video(
                    "./demo.mp4" if LANG else "./src/tutorial.mp4",
                    label=_L("Video demo"),
                    show_download_button=False,
                    show_share_button=False,
                )
                dataset_option = gr.Dropdown(
                    ["VGMIDI", "EMOPIA", "Rough4Q"],
                    label=_L("Dataset"),
                    value="Rough4Q",
                )
                with gr.Tab(_L("By template")):
                    gr.Image(
                        "https://www.modelscope.cn/studio/monetjoe/EMelodyGen/resolve/master/src/4q.jpg",
                        show_label=False,
                        show_download_button=False,
                        show_fullscreen_button=False,
                        show_share_button=False,
                    )
                    valence_radio = gr.Radio(
                        [_L("Low"), _L("High")],
                        label=_L(
                            "Valence: reflects negative-positive levels of emotion"
                        ),
                        value=_L("High"),
                    )
                    arousal_radio = gr.Radio(
                        [_L("Low"), _L("High")],
                        label=_L(
                            "Arousal: reflects the calmness-intensity of the emotion"
                        ),
                        value=_L("High"),
                    )
                    chord_check = gr.Checkbox(
                        label=_L("Generate chords coming soon"),
                        value=False,
                    )
                    gen_btn_1 = gr.Button(_L("Generate"))

                with gr.Tab(_L("By feature control")):
                    std_option = gr.Radio(
                        [_L("Low"), _L("High")], label=_L("Pitch SD"), value=_L("High")
                    )
                    mode_option = gr.Radio(
                        [_L("Minor"), _L("Major")], label=_L("Mode"), value=_L("Major")
                    )
                    tempo_option = gr.Slider(
                        minimum=40,
                        maximum=228,
                        step=1,
                        value=120,
                        label=_L("BPM tempo"),
                    )
                    octave_option = gr.Slider(
                        minimum=-24,
                        maximum=24,
                        step=12,
                        value=0,
                        label=_L("±12 octave"),
                    )
                    volume_option = gr.Slider(
                        minimum=-5,
                        maximum=10,
                        step=5,
                        value=0,
                        label=_L("Volume in dB"),
                    )
                    chord_check_2 = gr.Checkbox(
                        label=_L("Generate chords coming soon"),
                        value=False,
                    )
                    gen_btn_2 = gr.Button(_L("Generate"))

                template_radio = gr.Radio(
                    ["Q1", "Q2", "Q3", "Q4"],
                    label=_L("The emotion to which the current template belongs"),
                )
                save_btn = gr.Button(_L("Save template"))
                dld_template = gr.File(label=_L("Download template"))

            with gr.Column():
                wav_audio = gr.Audio(label=_L("Audio"), type="filepath")
                midi_file = gr.File(label=_L("Download MIDI"))
                pdf_file = gr.File(label=_L("Download PDF score"))
                xml_file = gr.File(label=_L("Download MusicXML"))
                mxl_file = gr.File(label=_L("Download MXL"))
                abc_textbox = gr.Textbox(
                    label=_L("ABC notation"), show_copy_button=True
                )
                staff_img = gr.Image(label=_L("Staff"), type="filepath")

            with gr.Column():
                status_bar = gr.Textbox(label=_L("Status"), show_copy_button=True)
                fdb_radio = gr.Radio(
                    ["Q1", "Q2", "Q3", "Q4"],
                    label=_L(
                        "Feedback: the emotion you believe the generated result should belong to"
                    ),
                )
                fdb_btn = gr.Button(_L("Submit"))
                gr.Markdown(
                    f"""## {_L("Cite")}
```bibtex
@inproceedings{{Zhou2025EMelodyGen,
  title = {{EMelodyGen: Emotion-Conditioned Melody Generation in ABC Notation with the Musical Feature Template}},
  author = {{Monan Zhou and Xiaobing Li and Feng Yu and Wei Li}},
  month = {{Mar}},
  year = {{2025}},
  publisher = {{GitHub}},
  version = {{0.1}},
  url = {{https://github.com/monetjoe/EMelodyGen}}
}}
```"""
                )

        # actions
        gen_btn_1.click(
            fn=infer_by_template,
            inputs=[dataset_option, valence_radio, arousal_radio, chord_check],
            outputs=[
                status_bar,
                wav_audio,
                midi_file,
                pdf_file,
                xml_file,
                mxl_file,
                abc_textbox,
                staff_img,
            ],
        )
        gen_btn_2.click(
            fn=infer_by_features,
            inputs=[
                dataset_option,
                std_option,
                mode_option,
                tempo_option,
                octave_option,
                volume_option,
                chord_check_2,  # checkbox from the "By feature control" tab
            ],
            outputs=[
                status_bar,
                wav_audio,
                midi_file,
                pdf_file,
                xml_file,
                mxl_file,
                abc_textbox,
                staff_img,
            ],
        )
        save_btn.click(
            fn=save_template,
            inputs=[
                template_radio,
                std_option,
                mode_option,
                tempo_option,
                octave_option,
                volume_option,
            ],
            outputs=[status_bar, dld_template],
        )
        fdb_btn.click(fn=feedback, inputs=fdb_radio, outputs=status_bar)

    demo.launch()
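
    # demo.launch() serves the app on Gradio's default port (7860) unless other
    # options are passed; run this script directly with Python to start it locally.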