EMelodyGen / generate.py
admin
upd ui
1aa8b04
raw
history blame
8.69 kB
import re
import os
import shutil
import time
import torch
import random
import argparse
import soundfile as sf
from transformers import GPT2Config
from model import Patchilizer, TunesFormer
from convert import abc2xml, xml2img, xml2, transpose_octaves_abc
from utils import (
PATCH_NUM_LAYERS,
PATCH_LENGTH,
CHAR_NUM_LAYERS,
PATCH_SIZE,
SHARE_WEIGHTS,
TEMP_DIR,
DEVICE,
)
def _str2bool(value) -> bool:
    """Parse a command-line boolean value such as "true", "False", "1" or "no".

    ``type=bool`` cannot be used with argparse because ``bool("False")`` is
    True (any non-empty string is truthy), so this converter is used instead.

    Raises:
        argparse.ArgumentTypeError: if *value* is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ("1", "true", "t", "yes", "y", "on"):
        return True
    if text in ("0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def get_args(parser: argparse.ArgumentParser, argv=None):
    """Register the generation options on *parser* and parse them.

    Args:
        parser: parser to add the options to.
        argv: argument list to parse; None (the default) parses ``sys.argv``,
            exactly as ``parser.parse_args()`` does.

    Returns:
        The ``argparse.Namespace`` produced by ``parser.parse_args``.
    """
    parser.add_argument(
        "-num_tunes",
        type=int,
        default=1,
        help="the number of independently computed returned tunes",
    )
    parser.add_argument(
        "-max_patch",
        type=int,
        default=128,
        help="integer to define the maximum number of patches generated for each tune",
    )
    parser.add_argument(
        "-top_p",
        type=float,
        default=0.8,
        help="float to define the tokens that are within the sample operation of text generation",
    )
    parser.add_argument(
        "-top_k",
        type=int,
        default=8,
        help="integer to define the tokens that are within the sample operation of text generation",
    )
    parser.add_argument(
        "-temperature",
        type=float,
        default=1.2,
        help="the temperature of the sampling operation",
    )
    parser.add_argument("-seed", type=int, default=None, help="seed for randomstate")
    # Boolean flags go through _str2bool so "-template False" really means False.
    parser.add_argument(
        "-show_control_code",
        type=_str2bool,
        default=False,
        help="whether to show control code",
    )
    parser.add_argument(
        "-template",
        type=_str2bool,
        default=True,
        help="whether to generate by template",
    )
    return parser.parse_args(argv)
def get_abc_key_val(text: str, key="K"):
    """Return the value of the first ABC header field ``<key>:`` in *text*.

    The field must start a line, so a ``K:`` appearing in the middle of some
    other line (e.g. inside a title) is not mistaken for the key field. The
    field may also be the last line of *text* without a trailing newline.

    Args:
        text: ABC notation text.
        key: field letter to look up (default ``"K"``).

    Returns:
        The field value with surrounding whitespace stripped, or None when no
        such field is present.
    """
    pattern = r"^" + re.escape(key) + r":(.*)$"
    # MULTILINE lets ^/$ match at every line boundary; "." never crosses "\n".
    match = re.search(pattern, text, flags=re.MULTILINE)
    return match.group(1).strip() if match else None
def adjust_volume(in_audio: str, dB_change: float):
    """Apply a gain of *dB_change* decibels to the audio file *in_audio*, in place.

    The file is read with ``soundfile``, scaled by ``10 ** (dB_change / 20)``
    and written back to the same path.

    Args:
        in_audio: path of the audio file to modify.
        dB_change: gain in decibels (positive = louder).
    """
    samples, sample_rate = sf.read(in_audio)
    gain = 10 ** (dB_change / 20)
    # sf.read returns float samples in [-1.0, 1.0] by default. A positive gain
    # can push peaks past full scale, and out-of-range floats written to an
    # integer PCM file may wrap around or distort, so clip to the legal range.
    sf.write(in_audio, (samples * gain).clip(-1.0, 1.0), sample_rate)
def clean_dir(dir_path: str):
    """Make *dir_path* an empty directory.

    If something already exists at the path, its whole directory tree is
    deleted first. Missing parent directories are created as needed.
    """
    if not os.path.exists(dir_path):
        os.makedirs(dir_path)
        return
    # Drop the old tree wholesale, then recreate an empty directory.
    shutil.rmtree(dir_path)
    os.makedirs(dir_path)
# Line prefixes that are hidden from the printed/saved tune unless
# args.show_control_code is set.
_CONTROL_CODE_PREFIXES = ("S:", "B:", "E:", "D:")


def _load_model(weights: str):
    """Build the TunesFormer model from the ``utils`` constants and load *weights*.

    Returns:
        ``(patchilizer, model)``, with the model moved to ``DEVICE`` and in eval mode.
    """
    patchilizer = Patchilizer()
    patch_config = GPT2Config(
        num_hidden_layers=PATCH_NUM_LAYERS,
        max_length=PATCH_LENGTH,
        max_position_embeddings=PATCH_LENGTH,
        vocab_size=1,
    )
    char_config = GPT2Config(
        num_hidden_layers=CHAR_NUM_LAYERS,
        max_length=PATCH_SIZE,
        max_position_embeddings=PATCH_SIZE,
        vocab_size=128,
    )
    model = TunesFormer(patch_config, char_config, share_weights=SHARE_WEIGHTS)
    # NOTE(review): torch.load unpickles the checkpoint; only load trusted weights.
    checkpoint = torch.load(weights, map_location=DEVICE)
    model.load_state_dict(checkpoint["model"])
    model = model.to(DEVICE)
    model.eval()
    return patchilizer, model


def _sample_tune(model, patchilizer, header: str, prompt: str, args, seed):
    """Sample one tune whose text starts with *header*, printing it as it grows.

    Sampling stops at the EOS patch, at an empty decoded bar, or once the patch
    sequence reaches ``args.max_patch``.

    Returns:
        ``(tune_text, seed)``; the seed returned by ``model.generate`` is handed
        back so the caller can thread it into the next tune.
    """
    show_control_code = args.show_control_code
    tune = ""
    skip = False
    for line in re.split(r"(\n)", header):
        if show_control_code or line[:2] not in _CONTROL_CODE_PREFIXES:
            # `skip` also swallows the "\n" separator right after a hidden line.
            if not skip:
                print(line, end="")
                tune += line
            skip = False
        else:
            skip = True

    input_patches = torch.tensor(
        [patchilizer.encode(prompt, add_special_patches=True)[:-1]],
        device=DEVICE,
    )
    tokens = None
    remaining_tokens = ""
    if tune != "":
        # Prompt characters not covered by the decoded patches go to the first
        # model.generate call as character tokens after the BOS id.
        prefix = patchilizer.decode(input_patches[0])
        remaining_tokens = prompt[len(prefix):]
        tokens = torch.tensor(
            [patchilizer.bos_token_id] + [ord(c) for c in remaining_tokens],
            device=DEVICE,
        )

    while input_patches.shape[1] < args.max_patch:
        predicted_patch, seed = model.generate(
            input_patches,
            tokens,
            top_p=args.top_p,
            top_k=args.top_k,
            temperature=args.temperature,
            seed=seed,
        )
        tokens = None
        if predicted_patch[0] == patchilizer.eos_token_id:
            break
        next_bar = patchilizer.decode([predicted_patch])
        if show_control_code or next_bar[:2] not in _CONTROL_CODE_PREFIXES:
            print(next_bar, end="")
            tune += next_bar
        if next_bar == "":
            break
        next_bar = remaining_tokens + next_bar
        remaining_tokens = ""
        predicted_patch = torch.tensor(
            patchilizer.bar2patch(next_bar),
            device=DEVICE,
        ).unsqueeze(0)
        input_patches = torch.cat(
            [input_patches, predicted_patch.unsqueeze(0)],
            dim=1,
        )
    return tune, seed


def _tempo_field(emo: str, fix_tempo) -> str:
    """Return the ``Q:`` line: *fix_tempo* if given, else a random emotion-dependent value."""
    if fix_tempo is not None:
        return f"Q:{fix_tempo}\n"
    low, high = {
        "Q1": (160, 184),
        "Q2": (184, 228),
        "Q3": (40, 69),
        "Q4": (40, 69),
    }.get(emo, (88, 132))
    return f"Q:{random.randint(low, high)}\n"


def _force_key_mode(key: str, mode: str) -> str:
    """Rewrite the ABC key value *key* so that its mode is *mode* ("major"/"minor").

    The tonic and anything after the mode word (e.g. a clef) are kept. Keys
    already in the requested mode, or whose tonic is not A-G, are returned
    unchanged. Non-major/minor modes (dor, mix, lyd, ...) are replaced.
    """
    match = re.match(r"([A-G][#b]?)\s*([A-Za-z]*)", key)
    if match is None:
        return key
    tonic, word = match.group(1), match.group(2).lower()
    mode_prefixes = ("maj", "min", "mix", "dor", "phr", "lyd", "loc", "ion", "aeo")
    if word == "m" or word[:3] in mode_prefixes:
        rest = key[match.end():]
    else:
        # No mode word (e.g. "D" or "D clef=bass"): the key is implicitly major.
        word, rest = "", key[len(tonic):]
    if mode == "major":
        if word == "" or word[:3] in ("maj", "ion"):
            return key
        return f"{tonic}{rest}"
    if word == "m" or word[:3] in ("min", "aeo"):
        return key
    return f"{tonic}min{rest}"


def _fix_header_fields(tunes: str, emo: str, mode: str, fix_tempo) -> str:
    """Replace the ``A:<emo>`` prompt with a tempo and force the key into *mode*."""
    tempo = _tempo_field(emo, fix_tempo)
    # Drop a model-generated tempo so the chosen one is the only Q: field.
    q_val = get_abc_key_val(tunes, "Q")
    if q_val:
        tunes = tunes.replace(f"Q:{q_val}\n", "")
    k_val = get_abc_key_val(tunes)
    if k_val == "none":
        k_val = "C"
    tunes = tunes.replace("K:none\n", "K:C\n")
    tunes = tunes.replace(f"A:{emo}\n", tempo)
    # NOTE(review): only the first tune's key is inspected; other tunes sharing
    # the same K: value are rewritten too, differing ones are left alone.
    if k_val is not None:
        new_key = _force_key_mode(k_val, mode)
        if new_key != k_val:
            tunes = tunes.replace(f"\nK:{k_val}\n", f"\nK:{new_key}\n")
    return tunes


def _export_outputs(tunes, outdir, fname_prefix, title, emo, mode, fix_pitch, fix_volume):
    """Render *tunes* to MusicXML, WAV, MXL, MIDI and score images under *outdir*.

    Returns:
        ``(audio, midi, pdf, xml, mxl, tunes, jpg)``; *tunes* may have been
        rewritten by ``transpose_octaves_abc``.
    """
    timestamp = time.strftime("%a_%d_%b_%Y_%H_%M_%S", time.localtime())
    xml_path = f"{outdir}/[{fname_prefix}]{timestamp}.musicxml"
    # Pitch offset for transpose_octaves_abc: an explicit fix_pitch wins;
    # otherwise minor-mode emotions get -12 (-24 for Q2). Zero means no shift.
    if fix_pitch is not None:
        offset = fix_pitch
    elif mode == "minor":
        offset = -24 if emo == "Q2" else -12
    else:
        offset = 0
    if offset:
        tunes, xml = transpose_octaves_abc(
            tunes,
            f"{outdir}/{timestamp}.musicxml",
            offset,
        )
        # NOTE(review): presumably the transposed ABC can repeat the title line.
        tunes = tunes.replace(title + title, title)
        os.rename(xml, xml_path)
        xml = xml_path
    else:
        xml = abc2xml(tunes, xml_path)

    audio = xml2(xml, "wav")
    gain = fix_volume if fix_volume is not None else {"Q1": 5, "Q2": 10}.get(emo, 0)
    # Only touch the WAV if it was actually produced (same rule for both paths).
    if gain and os.path.exists(audio):
        adjust_volume(audio, gain)
    mxl = xml2(xml, "mxl")
    midi = xml2(xml, "mid")
    pdf, jpg = xml2img(xml)
    return audio, midi, pdf, xml, mxl, tunes, jpg


def generate_music(
    args,
    emo: str,
    weights: str,
    outdir=f"{TEMP_DIR}/output",
    fix_tempo=None,
    fix_pitch=None,
    fix_volume=None,
    *,
    max_retries: int = 10,
):
    """Generate ``args.num_tunes`` tunes for emotion *emo* and render them to files.

    Each attempt wipes *outdir*, samples fresh tunes, fixes tempo/key header
    fields and converts the result. If conversion fails, the error is printed
    and a new attempt starts with the same options, up to *max_retries* times.
    The model is loaded once.

    Args:
        args: options from ``get_args`` (num_tunes, max_patch, top_p, top_k,
            temperature, seed, show_control_code, template).
        emo: emotion label used for the ``A:`` prompt; "Q1".."Q4" select tempo
            range, key mode (Q1/Q4 major, others minor), pitch offset and volume.
        weights: checkpoint path for ``torch.load``; its "model" entry is loaded.
        outdir: output directory, wiped and recreated on every attempt.
        fix_tempo: explicit ``Q:`` value; None picks a random emotion-based tempo.
        fix_pitch: explicit offset for ``transpose_octaves_abc``; None uses the
            emotion default and 0 skips transposition.
        fix_volume: explicit gain in dB for the WAV; None uses the emotion default.
        max_retries: maximum number of generate-and-render attempts (at least 1).

    Returns:
        ``(audio, midi, pdf, xml, mxl, tunes, jpg)``: the values returned by the
        ``convert`` helpers (presumably file paths) plus the final ABC text.

    Raises:
        RuntimeError: if every attempt fails during conversion/rendering.
    """
    patchilizer, model = _load_model(weights)
    prompt = f"A:{emo}\n"
    fname_prefix = emo if args.template else "Melody"
    title = f"T:{fname_prefix} Fragment\n"
    mode = "major" if emo in ("Q1", "Q4") else "minor"

    print(" Hyper parms ".center(60, "#"), "\n")
    for name, value in vars(args).items():
        print(f"{name}: {str(value)}")

    attempts = max(1, max_retries)
    last_error = None
    for _ in range(attempts):
        clean_dir(outdir)
        print("\n", " Output tunes ".center(60, "#"))
        start_time = time.time()
        # Each attempt restarts from the user seed, as a fresh call would.
        seed = args.seed
        tunes = ""
        for i in range(args.num_tunes):
            header = f"X:{i + 1}\n{title}C:Generated by AI\n{prompt}"
            tune, seed = _sample_tune(model, patchilizer, header, prompt, args, seed)
            tunes += f"{tune}\n\n"
            print("\n")
        tunes = _fix_header_fields(tunes, emo, mode, fix_tempo)
        print("Generation time: {:.2f} seconds".format(time.time() - start_time))
        try:
            return _export_outputs(
                tunes, outdir, fname_prefix, title, emo, mode, fix_pitch, fix_volume
            )
        except Exception as e:  # conversion of a bad sample failed: resample
            print(f"{e}")
            last_error = e
    raise RuntimeError(
        f"generate_music failed after {attempts} attempt(s): {last_error}"
    ) from last_error