Spaces:
Running
Running
import re | |
import os | |
import shutil | |
import time | |
import torch | |
import random | |
import argparse | |
import soundfile as sf | |
from transformers import GPT2Config | |
from model import Patchilizer, TunesFormer | |
from convert import abc2xml, xml2img, xml2, transpose_octaves_abc | |
from utils import ( | |
PATCH_NUM_LAYERS, | |
PATCH_LENGTH, | |
CHAR_NUM_LAYERS, | |
PATCH_SIZE, | |
SHARE_WEIGHTS, | |
TEMP_DIR, | |
DEVICE, | |
) | |
def _str2bool(value):
    """Convert a command-line string such as ``"false"`` or ``"1"`` to a bool.

    ``type=bool`` cannot be used with argparse for this purpose because
    ``bool("False")`` is ``True`` (any non-empty string is truthy).
    """
    if isinstance(value, bool):
        return value

    lowered = value.strip().lower()
    if lowered in ("true", "t", "yes", "y", "1", "on"):
        return True

    # The empty string is kept falsy, matching what ``bool("")`` returned.
    if lowered in ("false", "f", "no", "n", "0", "off", ""):
        return False

    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def get_args(parser: argparse.ArgumentParser):
    """Register the generation options on ``parser`` and parse the command line.

    Args:
        parser: The argument parser to which the options are added.

    Returns:
        The namespace produced by ``parser.parse_args()`` (which reads
        ``sys.argv``), with attributes ``num_tunes``, ``max_patch``, ``top_p``,
        ``top_k``, ``temperature``, ``seed``, ``show_control_code`` and
        ``template``.
    """
    parser.add_argument(
        "-num_tunes",
        type=int,
        default=1,
        help="the number of independently computed returned tunes",
    )
    parser.add_argument(
        "-max_patch",
        type=int,
        default=128,
        help="integer to define the maximum length in tokens of each tune",
    )
    parser.add_argument(
        "-top_p",
        type=float,
        default=0.8,
        help="float to define the tokens that are within the sample operation of text generation",
    )
    parser.add_argument(
        "-top_k",
        type=int,
        default=8,
        help="integer to define the tokens that are within the sample operation of text generation",
    )
    parser.add_argument(
        "-temperature",
        type=float,
        default=1.2,
        help="the temperature of the sampling operation",
    )
    parser.add_argument("-seed", type=int, default=None, help="seed for randomstate")
    # Boolean options go through _str2bool so that "-template False" really
    # yields False instead of the truthy string "False".
    parser.add_argument(
        "-show_control_code",
        type=_str2bool,
        default=False,
        help="whether to show control code",
    )
    parser.add_argument(
        "-template",
        type=_str2bool,
        default=True,
        help="whether to generate by template",
    )
    return parser.parse_args()
def get_abc_key_val(text: str, key="K"):
    """Return the stripped value of the first ``<key>:...`` field in ``text``.

    The value is everything between ``<key>:`` and the next newline. ``None``
    is returned when no such newline-terminated field is found.
    """
    field_re = re.compile(re.escape(key) + r":(.*?)\n")
    found = field_re.search(text)
    if found is None:
        return None

    return found.group(1).strip()
def adjust_volume(in_audio: str, dB_change: int):
    """Rewrite the audio file ``in_audio`` in place with its gain changed.

    The samples read from the file are multiplied by ``10 ** (dB_change / 20)``
    and written back to the same path with the same sample rate.
    """
    samples, sample_rate = sf.read(in_audio)
    gain = 10 ** (dB_change / 20)
    sf.write(in_audio, samples * gain, sample_rate)
def clean_dir(dir_path: str):
    """Leave ``dir_path`` as a freshly created directory.

    An existing path is removed recursively first, so any previous content
    is discarded.
    """
    if not os.path.exists(dir_path):
        os.makedirs(dir_path)
        return

    shutil.rmtree(dir_path)
    os.makedirs(dir_path)
def generate_music(
    args,
    emo: str,
    weights: str,
    outdir=f"{TEMP_DIR}/output",
    fix_tempo=None,
    fix_pitch=None,
    fix_volume=None,
):
    """Sample ABC tunes conditioned on ``emo`` and export them to files.

    The model is rebuilt and its checkpoint loaded on every call. Each tune is
    generated patch by patch from the prompt ``A:<emo>``, then the collected
    ABC text is post-processed (tempo, key, major/minor mode, octave) and
    converted through the project's ``abc2xml`` / ``transpose_octaves_abc`` /
    ``xml2`` / ``xml2img`` helpers.

    Args:
        args: Namespace providing ``num_tunes``, ``max_patch``, ``top_p``,
            ``top_k``, ``temperature``, ``seed``, ``show_control_code`` and
            ``template`` (see ``get_args``).
        emo: Emotion label used as the prompt. ``"Q1"``..``"Q4"`` select
            specific tempo ranges; ``"Q1"`` and ``"Q4"`` are treated as major,
            everything else as minor.
        weights: Path of the checkpoint handed to ``torch.load``; its
            ``"model"`` entry is loaded as the state dict.
        outdir: Output directory. It is wiped and recreated on every call.
        fix_tempo: If not ``None``, written verbatim as the ``Q:`` tempo
            instead of a random value.
        fix_pitch: If not ``None``, the offset passed to
            ``transpose_octaves_abc``; ``0`` means no transposition. If
            ``None``, minor-mode tunes are shifted by -12 (-24 for ``"Q2"``).
        fix_volume: If not ``None``, the dB change applied to the rendered
            audio (``0`` leaves it untouched). If ``None``, ``"Q1"`` gets +5 dB
            and ``"Q2"`` +10 dB when the audio file exists.

    Returns:
        The tuple ``(audio, midi, pdf, xml, mxl, tunes, jpg)``. ``tunes`` is
        the final ABC text; the other items are whatever the conversion
        helpers return (presumably output file paths -- confirm in ``convert``).
    """
    clean_dir(outdir)
    patchilizer = Patchilizer()
    patch_config = GPT2Config(
        num_hidden_layers=PATCH_NUM_LAYERS,
        max_length=PATCH_LENGTH,
        max_position_embeddings=PATCH_LENGTH,
        vocab_size=1,
    )
    char_config = GPT2Config(
        num_hidden_layers=CHAR_NUM_LAYERS,
        max_length=PATCH_SIZE,
        max_position_embeddings=PATCH_SIZE,
        vocab_size=128,
    )
    model = TunesFormer(patch_config, char_config, share_weights=SHARE_WEIGHTS)
    checkpoint = torch.load(weights, map_location=DEVICE)
    model.load_state_dict(checkpoint["model"])
    model = model.to(DEVICE)
    model.eval()

    prompt = f"A:{emo}\n"
    tunes = ""
    num_tunes = args.num_tunes
    max_patch = args.max_patch
    top_p = args.top_p
    top_k = args.top_k
    temperature = args.temperature
    seed = args.seed
    show_control_code = args.show_control_code
    fname_prefix = emo if args.template else "Melody"
    # Lines starting with these field names are hidden from the output
    # unless show_control_code is set.
    control_codes = ("S:", "B:", "E:", "D:")
    # Loop-invariant header lines; `title` is also needed after the loop.
    title = f"T:{fname_prefix} Fragment\n"
    artist = "C:Generated by AI\n"

    print(" Hyper parms ".center(60, "#"), "\n")
    for arg_name, arg_value in vars(args).items():
        print(f"{arg_name}: {arg_value!s}")

    print("\n", " Output tunes ".center(60, "#"))
    start_time = time.time()
    for i in range(num_tunes):
        tune = f"X:{i + 1}\n{title}{artist}{prompt}"
        # Split while keeping the "\n" separators so the header can be echoed
        # and re-assembled piece by piece.
        lines = re.split(r"(\n)", tune)
        tune = ""
        skip = False
        for line in lines:
            if show_control_code or line[:2] not in control_codes:
                if not skip:
                    print(line, end="")
                    tune += line

                skip = False
            else:
                # Also drop the separator that follows a hidden line.
                skip = True

        input_patches = torch.tensor(
            [patchilizer.encode(prompt, add_special_patches=True)[:-1]],
            device=DEVICE,
        )
        if tune == "":
            tokens = None
            remaining_tokens = ""
        else:
            # Whatever part of the prompt is not covered by the encoded
            # patches is fed to the first generate() call as explicit tokens.
            prefix = patchilizer.decode(input_patches[0])
            remaining_tokens = prompt[len(prefix) :]
            tokens = torch.tensor(
                [patchilizer.bos_token_id] + [ord(c) for c in remaining_tokens],
                device=DEVICE,
            )

        while input_patches.shape[1] < max_patch:
            predicted_patch, seed = model.generate(
                input_patches,
                tokens,
                top_p=top_p,
                top_k=top_k,
                temperature=temperature,
                seed=seed,
            )
            tokens = None
            if predicted_patch[0] == patchilizer.eos_token_id:
                break

            next_bar = patchilizer.decode([predicted_patch])
            if show_control_code or next_bar[:2] not in control_codes:
                print(next_bar, end="")
                tune += next_bar

            if next_bar == "":
                break

            next_bar = remaining_tokens + next_bar
            remaining_tokens = ""
            predicted_patch = torch.tensor(
                patchilizer.bar2patch(next_bar),
                device=DEVICE,
            ).unsqueeze(0)
            input_patches = torch.cat(
                [input_patches, predicted_patch.unsqueeze(0)],
                dim=1,
            )

        tunes += f"{tune}\n\n"
        print("\n")

    # fix tempo
    if fix_tempo is not None:
        tempo = f"Q:{fix_tempo}\n"
    else:
        tempo = f"Q:{random.randint(88, 132)}\n"
        if emo == "Q1":
            tempo = f"Q:{random.randint(160, 184)}\n"
        elif emo == "Q2":
            tempo = f"Q:{random.randint(184, 228)}\n"
        elif emo == "Q3":
            tempo = f"Q:{random.randint(40, 69)}\n"
        elif emo == "Q4":
            tempo = f"Q:{random.randint(40, 69)}\n"

    # Drop a tempo field produced by the model; the prompt line is replaced
    # by the chosen tempo below.
    Q_val = get_abc_key_val(tunes, "Q")
    if Q_val:
        tunes = tunes.replace(f"Q:{Q_val}\n", "")

    K_val = get_abc_key_val(tunes)
    if K_val == "none":
        K_val = "C"
        tunes = tunes.replace("K:none\n", f"K:{K_val}\n")

    tunes = tunes.replace(f"A:{emo}\n", tempo)

    # fix mode:major/minor
    mode = "major" if emo == "Q1" or emo == "Q4" else "minor"
    # K_val is None when the model produced no K: field; there is then no
    # key line to rewrite.
    if K_val is not None:
        if (mode == "major") and ("m" in K_val):
            tunes = tunes.replace(f"\nK:{K_val}\n", f"\nK:{K_val.split('m')[0]}\n")
        elif (mode == "minor") and ("m" not in K_val):
            tunes = tunes.replace(f"\nK:{K_val}\n", f"\nK:{K_val.replace('dor', '')}min\n")

    print("Generation time: {:.2f} seconds".format(time.time() - start_time))
    timestamp = time.strftime("%a_%d_%b_%Y_%H_%M_%S", time.localtime())
    final_xml = f"{outdir}/[{fname_prefix}]{timestamp}.musicxml"
    try:
        # fix avg_pitch (octave)
        if fix_pitch is not None:
            offset = fix_pitch
        elif mode == "minor":
            offset = -24 if emo == "Q2" else -12
        else:
            offset = 0

        if offset:
            tunes, xml = transpose_octaves_abc(
                tunes,
                f"{outdir}/{timestamp}.musicxml",
                offset,
            )
            tunes = tunes.replace(title + title, title)
            os.rename(xml, final_xml)
            xml = final_xml
        else:
            # No transposition requested (including an explicit offset of 0):
            # convert the ABC text directly.
            xml = abc2xml(tunes, final_xml)

        audio = xml2(xml, "wav")
        if fix_volume is not None:
            if fix_volume:
                adjust_volume(audio, fix_volume)
        elif os.path.exists(audio):
            if emo == "Q1":
                adjust_volume(audio, 5)
            elif emo == "Q2":
                adjust_volume(audio, 10)

        mxl = xml2(xml, "mxl")
        midi = xml2(xml, "mid")
        pdf, jpg = xml2img(xml)
        return audio, midi, pdf, xml, mxl, tunes, jpg

    except Exception as e:
        # A failed conversion is reported and the whole generation is retried
        # with the caller's settings preserved.
        # NOTE(review): retries are unbounded; a failure that repeats on every
        # attempt (e.g. with a fixed seed) recurses until RecursionError.
        print(f"{e}")
        return generate_music(
            args,
            emo,
            weights,
            outdir,
            fix_tempo,
            fix_pitch,
            fix_volume,
        )