import re
import os
import json
import time
import torch
import random
import shutil
import argparse
import warnings
import gradio as gr
import soundfile as sf
from transformers import GPT2Config
from model import Patchilizer, TunesFormer
from convert import abc2xml, xml2img, xml2, transpose_octaves_abc
from utils import (
    PATCH_NUM_LAYERS,
    PATCH_LENGTH,
    CHAR_NUM_LAYERS,
    PATCH_SIZE,
    SHARE_WEIGHTS,
    TEMP_DIR,
    WEIGHTS_DIR,
    DEVICE,
)


def get_args(parser: argparse.ArgumentParser):
    parser.add_argument(
        "-num_tunes",
        type=int,
        default=1,
        help="the number of independently computed returned tunes",
    )
    parser.add_argument(
        "-max_patch",
        type=int,
        default=128,
        help="maximum length of each tune, in patches",
    )
    parser.add_argument(
        "-top_p",
        type=float,
        default=0.8,
        help="nucleus sampling: only tokens whose cumulative probability is within top_p are sampled",
    )
    parser.add_argument(
        "-top_k",
        type=int,
        default=8,
        help="only the top_k most probable tokens are sampled",
    )
    parser.add_argument(
        "-temperature",
        type=float,
        default=1.2,
        help="the temperature of the sampling operation",
    )
    parser.add_argument("-seed", type=int, default=None, help="seed for random state")
    parser.add_argument(
        "-show_control_code",
        action="store_true",  # type=bool would parse any non-empty string as True
        help="whether to show control codes (S:/B:/E:) in the output",
    )
    return parser.parse_args()
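

# Hedged usage sketch: with the defaults above, a bare invocation generates one
# tune with nucleus sampling (top_p=0.8, top_k=8, temperature=1.2). The script
# name here is an assumption for illustration only:
#   python app.py -num_tunes 2 -temperature 1.0 -seed 42 -show_control_code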


def get_abc_key_val(text: str, key="K"):
    """Extract the value of an ABC header field (e.g. K: or Q:) from text."""
    pattern = re.escape(key) + r":(.*?)\n"
    match = re.search(pattern, text)
    if match:
        return match.group(1).strip()
    else:
        return None
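

# Minimal sketch of the extraction above (the ABC snippets are hypothetical):
#   get_abc_key_val("X:1\nK:Cmin\n")      -> "Cmin"
#   get_abc_key_val("Q:120\nK:C\n", "Q")  -> "120"
#   get_abc_key_val("X:1\n")              -> None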


def adjust_volume(in_audio: str, dB_change: int):
    """Rescale a WAV file in place by a gain given in decibels."""
    y, sr = sf.read(in_audio)
    sf.write(in_audio, y * 10 ** (dB_change / 20), sr)
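

# The linear gain applied above is 10 ** (dB / 20), so for example:
#   +5 dB -> x1.78, +10 dB -> x3.16, -5 dB -> x0.56
# Note this sketch assumes headroom: boosting past 0 dBFS may clip on playback.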


def generate_music(
    args,
    emo: str,
    weights: str,
    outdir=TEMP_DIR,
    fix_tempo=None,
    fix_pitch=None,
    fix_volume=None,
):
    patchilizer = Patchilizer()
    # TunesFormer pairs a patch-level decoder (over bars) with a char-level
    # decoder (over the 128 ASCII codes inside each bar)
    patch_config = GPT2Config(
        num_hidden_layers=PATCH_NUM_LAYERS,
        max_length=PATCH_LENGTH,
        max_position_embeddings=PATCH_LENGTH,
        vocab_size=1,
    )
    char_config = GPT2Config(
        num_hidden_layers=CHAR_NUM_LAYERS,
        max_length=PATCH_SIZE,
        max_position_embeddings=PATCH_SIZE,
        vocab_size=128,
    )
    model = TunesFormer(patch_config, char_config, share_weights=SHARE_WEIGHTS)
    checkpoint = torch.load(weights, map_location=DEVICE)
    model.load_state_dict(checkpoint["model"])
    model = model.to(DEVICE)
    model.eval()
    prompt = f"A:{emo}\n"
    tunes = ""
    num_tunes = args.num_tunes
    max_patch = args.max_patch
    top_p = args.top_p
    top_k = args.top_k
    temperature = args.temperature
    seed = args.seed
    show_control_code = args.show_control_code
    print(" Hyper params ".center(60, "#"), "\n")
    args_dict: dict = vars(args)
    for arg in args_dict.keys():
        print(f"{arg}: {args_dict[arg]}")

    print("\n", " Output tunes ".center(60, "#"))
    start_time = time.time()
    for i in range(num_tunes):
        title = f"T:{emo} Fragment\n"
        artist = "C:Generated by AI\n"
        tune = f"X:{i + 1}\n{title}{artist}{prompt}"
        lines = re.split(r"(\n)", tune)
        tune = ""
        skip = False
        # Echo the header, optionally hiding S:/B:/E: control lines
        # (and their trailing newlines)
        for line in lines:
            if show_control_code or line[:2] not in ["S:", "B:", "E:"]:
                if not skip:
                    print(line, end="")
                    tune += line
                skip = False
            else:
                skip = True

        input_patches = torch.tensor(
            [patchilizer.encode(prompt, add_special_patches=True)[:-1]],
            device=DEVICE,
        )
        if tune == "":
            tokens = None
        else:
            # Feed the model the prompt characters not yet covered by full patches
            prefix = patchilizer.decode(input_patches[0])
            remaining_tokens = prompt[len(prefix):]
            tokens = torch.tensor(
                [patchilizer.bos_token_id] + [ord(c) for c in remaining_tokens],
                device=DEVICE,
            )
        # Autoregressively generate one bar (patch) at a time until EOS or max_patch
        while input_patches.shape[1] < max_patch:
            predicted_patch, seed = model.generate(
                input_patches,
                tokens,
                top_p=top_p,
                top_k=top_k,
                temperature=temperature,
                seed=seed,
            )
            tokens = None
            if predicted_patch[0] != patchilizer.eos_token_id:
                next_bar = patchilizer.decode([predicted_patch])
                if show_control_code or next_bar[:2] not in ["S:", "B:", "E:"]:
                    print(next_bar, end="")

                tune += next_bar
                if next_bar == "":
                    break

                next_bar = remaining_tokens + next_bar
                remaining_tokens = ""
                predicted_patch = torch.tensor(
                    patchilizer.bar2patch(next_bar),
                    device=DEVICE,
                ).unsqueeze(0)
                input_patches = torch.cat(
                    [input_patches, predicted_patch.unsqueeze(0)],
                    dim=1,
                )
            else:
                break

        tunes += f"{tune}\n\n"
        print("\n")
    # Fix tempo: honor an explicit BPM if given, otherwise draw one from a
    # range matched to the emotion quadrant (high-arousal Q1/Q2 fast, Q3/Q4 slow)
    if fix_tempo is not None:
        tempo = f"Q:{fix_tempo}\n"
    else:
        tempo = f"Q:{random.randint(88, 132)}\n"
        if emo == "Q1":
            tempo = f"Q:{random.randint(160, 184)}\n"
        elif emo == "Q2":
            tempo = f"Q:{random.randint(184, 228)}\n"
        elif emo == "Q3":
            tempo = f"Q:{random.randint(40, 69)}\n"
        elif emo == "Q4":
            tempo = f"Q:{random.randint(40, 69)}\n"

    # Drop any model-emitted Q: field, then substitute our tempo for the A: control line
    Q_val = get_abc_key_val(tunes, "Q")
    if Q_val:
        tunes = tunes.replace(f"Q:{Q_val}\n", "")

    K_val = get_abc_key_val(tunes)
    if not K_val or K_val == "none":  # guard against a missing K: header
        K_val = "C"
        tunes = tunes.replace("K:none\n", f"K:{K_val}\n")

    tunes = tunes.replace(f"A:{emo}\n", tempo)
    # Fix mode: positive-valence quadrants (Q1, Q4) are major, Q2 and Q3 minor
    mode = "major" if emo == "Q1" or emo == "Q4" else "minor"
    if mode == "major" and "m" in K_val:
        tunes = tunes.replace(f"\nK:{K_val}\n", f"\nK:{K_val.split('m')[0]}\n")
    elif mode == "minor" and "m" not in K_val:
        tunes = tunes.replace(f"\nK:{K_val}\n", f"\nK:{K_val.replace('dor', '')}min\n")
print("Generation time: {:.2f} seconds".format(time.time() - start_time)) | |
timestamp = time.strftime("%a_%d_%b_%Y_%H_%M_%S", time.localtime()) | |
try: | |
# fix avg_pitch (octave) | |
if fix_pitch != None: | |
if fix_pitch: | |
tunes, xml = transpose_octaves_abc( | |
tunes, | |
f"{outdir}/{timestamp}.musicxml", | |
fix_pitch, | |
) | |
tunes = tunes.replace(title + title, title) | |
os.rename(xml, f"{outdir}/[{emo}]{timestamp}.musicxml") | |
xml = f"{outdir}/[{emo}]{timestamp}.musicxml" | |
else: | |
if mode == "minor": | |
offset = -12 | |
if emo == "Q2": | |
offset -= 12 | |
tunes, xml = transpose_octaves_abc( | |
tunes, | |
f"{outdir}/{timestamp}.musicxml", | |
offset, | |
) | |
tunes = tunes.replace(title + title, title) | |
os.rename(xml, f"{outdir}/[{emo}]{timestamp}.musicxml") | |
xml = f"{outdir}/[{emo}]{timestamp}.musicxml" | |
else: | |
xml = abc2xml(tunes, f"{outdir}/[{emo}]{timestamp}.musicxml") | |
audio = xml2(xml, "wav") | |
if fix_volume != None: | |
if fix_volume: | |
adjust_volume(audio, fix_volume) | |
elif os.path.exists(audio): | |
if emo == "Q1": | |
adjust_volume(audio, 5) | |
elif emo == "Q2": | |
adjust_volume(audio, 10) | |
mxl = xml2(xml, "mxl") | |
midi = xml2(xml, "mid") | |
pdf, jpg = xml2img(xml) | |
return audio, midi, pdf, xml, mxl, tunes, jpg | |
except Exception as e: | |
print(f"{e}") | |
return generate_music(args, emo, weights) | |
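

# A hedged call sketch (the weights path is an assumption; real paths are built
# from WEIGHTS_DIR, as in inference()/infer() below):
#   parser = argparse.ArgumentParser()
#   args = get_args(parser)
#   audio, midi, pdf, xml, mxl, abc, jpg = generate_music(
#       args, emo="Q1", weights="./weights/rough4q/weights.pth"
#   )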


def inference(dataset: str, v: str, a: str, add_chord: bool):
    if os.path.exists(TEMP_DIR):
        shutil.rmtree(TEMP_DIR)

    os.makedirs(TEMP_DIR, exist_ok=True)
    # Map valence/arousal onto the 4Q emotion quadrants:
    # Q1 = high V / high A, Q2 = low V / high A, Q3 = low V / low A, Q4 = high V / low A
    emotion = "Q1"
    if v == "Low" and a == "High":
        emotion = "Q2"
    elif v == "Low" and a == "Low":
        emotion = "Q3"
    elif v == "High" and a == "Low":
        emotion = "Q4"

    parser = argparse.ArgumentParser()
    args = get_args(parser)
    return generate_music(
        args,
        emo=emotion,
        weights=f"{WEIGHTS_DIR}/{dataset.lower()}/weights.pth",
    )
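

# For example, inference("Rough4Q", "Low", "High", False) maps to Q2 and loads
# f"{WEIGHTS_DIR}/rough4q/weights.pth"; add_chord is currently unused
# (chord generation is marked "Coming soon" in the UI).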


def infer(
    dataset: str,
    pitch_std: str,
    mode: str,
    tempo: int,
    octave: int,
    rms: int,
    add_chord: bool,
):
    if os.path.exists(TEMP_DIR):
        shutil.rmtree(TEMP_DIR)

    os.makedirs(TEMP_DIR, exist_ok=True)
    # Map musical features onto quadrants: Major/High -> Q1, Minor/High -> Q2,
    # Minor/Low -> Q3, Major/Low -> Q4
    emotion = "Q1"
    if mode == "Minor" and pitch_std == "High":
        emotion = "Q2"
    elif mode == "Minor" and pitch_std == "Low":
        emotion = "Q3"
    elif mode == "Major" and pitch_std == "Low":
        emotion = "Q4"

    parser = argparse.ArgumentParser()
    args = get_args(parser)
    return generate_music(
        args,
        emo=emotion,
        weights=f"{WEIGHTS_DIR}/{dataset.lower()}/weights.pth",
        fix_tempo=tempo,
        fix_pitch=octave,
        fix_volume=rms,
    )


def feedback(fixed_emo: str, source_dir="./flagged", target_dir="./feedbacks"):
    if not fixed_emo:
        return "Please select feedback before submitting!"

    os.makedirs(target_dir, exist_ok=True)
    for root, _, files in os.walk(source_dir):
        for file in files:
            if file.endswith(".mxl"):
                # Generated scores are named "[<emotion>]<timestamp>.mxl"
                prompt_emo = file.split("]")[0][1:]
                if prompt_emo != fixed_emo:
                    file_path = os.path.join(root, file)
                    target_path = os.path.join(
                        target_dir, file.replace(".mxl", f"_{fixed_emo}.mxl")
                    )
                    shutil.copy(file_path, target_path)
                    return f"Copied {file_path} to {target_path}"
                else:
                    return "Thanks for your feedback!"

    return "No .mxl files found in the source directory."


def save_template(
    label: str,
    pitch_std: str,
    mode: str,
    tempo: int,
    octave: int,
    rms: int,
):
    if (
        label
        and pitch_std
        and mode
        and tempo is not None
        and octave is not None
        and rms is not None
    ):
        json_str = json.dumps(
            {
                "label": label,
                "pitch_std": pitch_std == "High",
                "mode": mode == "Major",
                "tempo": tempo,
                "octave": octave,
                "volume": rms,
            }
        )
        # Ensure the target directory exists before the first append
        os.makedirs("./feedbacks", exist_ok=True)
        with open("./feedbacks/templates.jsonl", "a", encoding="utf-8") as file:
            file.write(json_str + "\n")
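

# Each saved template is one JSONL record; for example,
# save_template("Q1", "High", "Major", 120, 0, 0) appends:
#   {"label": "Q1", "pitch_std": true, "mode": true, "tempo": 120, "octave": 0, "volume": 0}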


if __name__ == "__main__":
    warnings.filterwarnings("ignore")
    if os.path.exists("./flagged"):
        shutil.rmtree("./flagged")

    with gr.Blocks() as demo:
        gr.Markdown(
            "## The current CPU-based version on HuggingFace has slow inference; you can access the GPU-based mirror on [ModelScope](https://www.modelscope.cn/studios/monetjoe/EMelodyGen)"
        )
        with gr.Row():
            with gr.Column():
                gr.Video(
                    "./tutorial.mp4",
                    label="Tutorial",
                    show_download_button=False,
                    show_share_button=False,
                )
                dataset_option = gr.Dropdown(
                    ["VGMIDI", "EMOPIA", "Rough4Q"],
                    label="Dataset",
                    value="Rough4Q",
                )
                gr.Markdown("# Generate by emotion condition")
                gr.Image(
                    "https://www.modelscope.cn/studio/monetjoe/EMelodyGen/resolve/master/src/4q.jpg",
                    show_label=False,
                    show_download_button=False,
                    show_fullscreen_button=False,
                    show_share_button=False,
                )
                valence_radio = gr.Radio(
                    ["Low", "High"],
                    label="Valence (reflects negative-positive levels of emotion)",
                    value="High",
                )
                arousal_radio = gr.Radio(
                    ["Low", "High"],
                    label="Arousal (reflects the calmness-intensity of the emotion)",
                    value="High",
                )
                chord_check = gr.Checkbox(
                    label="Generate chords (Coming soon)",
                    value=False,
                )
                gen_btn = gr.Button("Generate")
                gr.Markdown("# Generate by feature control")
                std_option = gr.Radio(["Low", "High"], label="Pitch SD", value="High")
                mode_option = gr.Radio(["Minor", "Major"], label="Mode", value="Major")
                tempo_option = gr.Slider(
                    minimum=40,
                    maximum=228,
                    step=1,
                    value=120,
                    label="Tempo (BPM)",
                )
                octave_option = gr.Slider(
                    minimum=-24,
                    maximum=24,
                    step=12,
                    value=0,
                    label="Octave (±12)",
                )
                volume_option = gr.Slider(
                    minimum=-5,
                    maximum=10,
                    step=5,
                    value=0,
                    label="Volume (dB)",
                )
                chord_check_2 = gr.Checkbox(
                    label="Generate chords (Coming soon)",
                    value=False,
                )
                gen_btn_2 = gr.Button("Generate")
                template_radio = gr.Radio(
                    ["Q1", "Q2", "Q3", "Q4"],
                    label="The emotion to which the current template belongs",
                )
                save_btn = gr.Button("Save template")
                # gr.Markdown(
                #     """
                #     ## Cite
                #     ```bibtex
                #     @article{Zhou2024EMelodyGen,
                #       title = {EMelodyGen: Emotion-Conditioned Melody Generation in ABC Notation},
                #       author = {Monan Zhou, Xiaobing Li, Feng Yu and Wei Li},
                #       month = {Sep},
                #       year = {2024},
                #       publisher = {GitHub},
                #       version = {0.1},
                #       url = {https://github.com/monetjoe/EMelodyGen}
                #     }
                #     ```
                #     """
                # )
            with gr.Column():
                wav_audio = gr.Audio(label="Audio", type="filepath")
                midi_file = gr.File(label="Download MIDI")
                pdf_file = gr.File(label="Download PDF score")
                xml_file = gr.File(label="Download MusicXML")
                mxl_file = gr.File(label="Download MXL")
                abc_textbox = gr.Textbox(label="ABC notation", show_copy_button=True)
                staff_img = gr.Image(label="Staff", type="filepath")

        gr.Interface(
            fn=feedback,
            inputs=gr.Radio(
                ["Q1", "Q2", "Q3", "Q4"],
                label="Feedback: the emotion you believe the generated result should belong to",
            ),
            outputs=gr.Textbox(show_copy_button=False, show_label=False),
            allow_flagging="never",
        )
        gen_btn.click(
            fn=inference,
            inputs=[dataset_option, valence_radio, arousal_radio, chord_check],
            outputs=[
                wav_audio,
                midi_file,
                pdf_file,
                xml_file,
                mxl_file,
                abc_textbox,
                staff_img,
            ],
        )
        gen_btn_2.click(
            fn=infer,
            inputs=[
                dataset_option,
                std_option,
                mode_option,
                tempo_option,
                octave_option,
                volume_option,
                chord_check_2,  # was chord_check; the feature-control form has its own checkbox
            ],
            outputs=[
                wav_audio,
                midi_file,
                pdf_file,
                xml_file,
                mxl_file,
                abc_textbox,
                staff_img,
            ],
        )
        save_btn.click(
            fn=save_template,
            inputs=[
                template_radio,
                std_option,
                mode_option,
                tempo_option,
                octave_option,
                volume_option,
            ],
        )

    demo.launch()