import argparse
import os
import random
import re
import shutil
import time

import soundfile as sf
import torch
from transformers import GPT2Config

from convert import abc2xml, transpose_octaves_abc, xml2, xml2img
from model import Patchilizer, TunesFormer
from utils import (
    CHAR_NUM_LAYERS,
    DEVICE,
    PATCH_LENGTH,
    PATCH_NUM_LAYERS,
    PATCH_SIZE,
    SHARE_WEIGHTS,
    TEMP_DIR,
)


def get_args(parser: argparse.ArgumentParser):
    parser.add_argument(
        "-num_tunes",
        type=int,
        default=1,
        help="the number of independently computed returned tunes",
    )
    parser.add_argument(
        "-max_patch",
        type=int,
        default=128,
        help="maximum length, in patches, of each generated tune",
    )
    parser.add_argument(
        "-top_p",
        type=float,
        default=0.8,
        help="cumulative probability threshold for nucleus (top-p) sampling",
    )
    parser.add_argument(
        "-top_k",
        type=int,
        default=8,
        help="number of highest-probability tokens kept for top-k sampling",
    )
    parser.add_argument(
        "-temperature",
        type=float,
        default=1.2,
        help="the temperature of the sampling operation",
    )
    parser.add_argument("-seed", type=int, default=None, help="random seed for sampling")
    parser.add_argument(
        "-show_control_code",
        action="store_true",
        help="whether to show control codes (S:, B:, E:, D:) in the output",
    )
    parser.add_argument(
        "-template",
        type=bool,
        default=True,
        help="whether to generate by template",
    )
    return parser.parse_args()


def get_abc_key_val(text: str, key="K"):
    """Return the value of the first ABC header field `key` (e.g. K: or Q:), or None."""
    pattern = re.escape(key) + r":(.*?)\n"
    match = re.search(pattern, text)
    return match.group(1).strip() if match else None


def adjust_volume(in_audio: str, dB_change: int):
    """Rescale an audio file in place by `dB_change` decibels."""
    y, sr = sf.read(in_audio)
    sf.write(in_audio, y * 10 ** (dB_change / 20), sr)


def clean_dir(dir_path: str):
    """Remove `dir_path` if it exists, then recreate it empty."""
    if os.path.exists(dir_path):
        shutil.rmtree(dir_path)
    os.makedirs(dir_path)


def generate_music(
    args,
    emo: str,
    weights: str,
    outdir=f"{TEMP_DIR}/output",
    fix_tempo=None,
    fix_pitch=None,
    fix_volume=None,
):
    """Generate `args.num_tunes` ABC tunes conditioned on the emotion label `emo`
    (Q1-Q4), then export them as wav, mid, pdf, musicxml, mxl and jpg files."""
    clean_dir(outdir)
    patchilizer = Patchilizer()
    patch_config = GPT2Config(
        num_hidden_layers=PATCH_NUM_LAYERS,
        max_length=PATCH_LENGTH,
        max_position_embeddings=PATCH_LENGTH,
        vocab_size=1,
    )
    char_config = GPT2Config(
        num_hidden_layers=CHAR_NUM_LAYERS,
        max_length=PATCH_SIZE,
        max_position_embeddings=PATCH_SIZE,
        vocab_size=128,
    )
    model = TunesFormer(patch_config, char_config, share_weights=SHARE_WEIGHTS)
    checkpoint = torch.load(weights, map_location=DEVICE)
    model.load_state_dict(checkpoint["model"])
    model = model.to(DEVICE)
    model.eval()

    prompt = f"A:{emo}\n"
    tunes = ""
    num_tunes = args.num_tunes
    max_patch = args.max_patch
    top_p = args.top_p
    top_k = args.top_k
    temperature = args.temperature
    seed = args.seed
    show_control_code = args.show_control_code
    fname_prefix = emo if args.template else "Melody"

    print(" Hyperparameters ".center(60, "#"), "\n")
    for arg, val in vars(args).items():
        print(f"{arg}: {val}")

    print("\n", " Output tunes ".center(60, "#"))
    start_time = time.time()
    for i in range(num_tunes):
        title = f"T:{fname_prefix} Fragment\n"
        artist = "C:Generated by AI\n"
        tune = f"X:{i + 1}\n{title}{artist}{prompt}"
        # re.split with a capturing group keeps the newlines as separate items,
        # so `skip` also drops the newline that follows a hidden control code.
        lines = re.split(r"(\n)", tune)
        tune = ""
        skip = False
        for line in lines:
            if show_control_code or line[:2] not in ["S:", "B:", "E:", "D:"]:
                if not skip:
                    print(line, end="")
                    tune += line
                skip = False
            else:
                skip = True

        input_patches = torch.tensor(
            [patchilizer.encode(prompt, add_special_patches=True)[:-1]],
            device=DEVICE,
        )
        if tune == "":
            tokens = None
            remaining_tokens = ""
        else:
            # Feed the model the part of the prompt not yet covered by whole patches.
            prefix = patchilizer.decode(input_patches[0])
            remaining_tokens = prompt[len(prefix):]
            tokens = torch.tensor(
                [patchilizer.bos_token_id] + [ord(c) for c in remaining_tokens],
                device=DEVICE,
            )

        # Autoregressively generate one bar (patch) at a time until EOS or max_patch.
        while input_patches.shape[1] < max_patch:
            predicted_patch, seed = model.generate(
                input_patches,
                tokens,
                top_p=top_p,
                top_k=top_k,
                temperature=temperature,
                seed=seed,
            )
            tokens = None
            if predicted_patch[0] == patchilizer.eos_token_id:
                break
            next_bar = patchilizer.decode([predicted_patch])
            if show_control_code or next_bar[:2] not in ["S:", "B:", "E:", "D:"]:
                print(next_bar, end="")
                tune += next_bar
            if next_bar == "":
                break
            next_bar = remaining_tokens + next_bar
            remaining_tokens = ""
            predicted_patch = torch.tensor(
                patchilizer.bar2patch(next_bar),
                device=DEVICE,
            ).unsqueeze(0)
            input_patches = torch.cat(
                [input_patches, predicted_patch.unsqueeze(0)],
                dim=1,
            )

        tunes += f"{tune}\n\n"
        print("\n")

    # fix tempo: either the caller's value, or a random BPM in a range typical
    # for the emotion quadrant (fast for Q1/Q2, slow for Q3/Q4)
    if fix_tempo is not None:
        tempo = f"Q:{fix_tempo}\n"
    elif emo == "Q1":
        tempo = f"Q:{random.randint(160, 184)}\n"
    elif emo == "Q2":
        tempo = f"Q:{random.randint(184, 228)}\n"
    elif emo in ("Q3", "Q4"):
        tempo = f"Q:{random.randint(40, 69)}\n"
    else:
        tempo = f"Q:{random.randint(88, 132)}\n"

    # drop any tempo the model produced, then swap the emotion control line
    # A:<emo> for the chosen Q: tempo field
    Q_val = get_abc_key_val(tunes, "Q")
    if Q_val:
        tunes = tunes.replace(f"Q:{Q_val}\n", "")

    K_val = get_abc_key_val(tunes) or ""
    if K_val == "none":
        K_val = "C"
        tunes = tunes.replace("K:none\n", f"K:{K_val}\n")

    tunes = tunes.replace(f"A:{emo}\n", tempo)

    # fix mode: Q1/Q4 are forced to major, Q2/Q3 to minor
    mode = "major" if emo in ("Q1", "Q4") else "minor"
    if mode == "major" and "m" in K_val:
        tunes = tunes.replace(f"\nK:{K_val}\n", f"\nK:{K_val.split('m')[0]}\n")
    elif mode == "minor" and "m" not in K_val:
        tunes = tunes.replace(f"\nK:{K_val}\n", f"\nK:{K_val.replace('dor', '')}min\n")

    print("Generation time: {:.2f} seconds".format(time.time() - start_time))
    timestamp = time.strftime("%a_%d_%b_%Y_%H_%M_%S", time.localtime())
    try:
        # fix average pitch: apply the caller's offset, or shift minor-mode
        # tunes down an octave (two octaves for Q2)
        if fix_pitch:
            tunes, xml = transpose_octaves_abc(
                tunes,
                f"{outdir}/{timestamp}.musicxml",
                fix_pitch,
            )
            tunes = tunes.replace(title + title, title)
            os.rename(xml, f"{outdir}/[{fname_prefix}]{timestamp}.musicxml")
            xml = f"{outdir}/[{fname_prefix}]{timestamp}.musicxml"
        elif fix_pitch is None and mode == "minor":
            offset = -12
            if emo == "Q2":
                offset -= 12
            tunes, xml = transpose_octaves_abc(
                tunes,
                f"{outdir}/{timestamp}.musicxml",
                offset,
            )
            tunes = tunes.replace(title + title, title)
            os.rename(xml, f"{outdir}/[{fname_prefix}]{timestamp}.musicxml")
            xml = f"{outdir}/[{fname_prefix}]{timestamp}.musicxml"
        else:
            xml = abc2xml(tunes, f"{outdir}/[{fname_prefix}]{timestamp}.musicxml")

        audio = xml2(xml, "wav")
        # fix volume: either the caller's gain, or a default boost for the
        # high-arousal quadrants Q1/Q2
        if fix_volume is not None:
            if fix_volume:
                adjust_volume(audio, fix_volume)
        elif os.path.exists(audio):
            if emo == "Q1":
                adjust_volume(audio, 5)
            elif emo == "Q2":
                adjust_volume(audio, 10)

        mxl = xml2(xml, "mxl")
        midi = xml2(xml, "mid")
        pdf, jpg = xml2img(xml)
        return audio, midi, pdf, xml, mxl, tunes, jpg

    except Exception as e:
        # conversion occasionally fails on malformed ABC; retry from scratch
        print(f"{e}")
        return generate_music(
            args, emo, weights, outdir, fix_tempo, fix_pitch, fix_volume
        )
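

# Minimal usage sketch (illustrative, not part of the original script): build the
# CLI parser, then render one tune for the emotion label "Q1". The checkpoint
# path "weights.pth" is a hypothetical placeholder; point it at a trained
# TunesFormer checkpoint whose state dict is stored under the "model" key.
if __name__ == "__main__":
    args = get_args(argparse.ArgumentParser())
    audio, midi, pdf, xml, mxl, tunes, jpg = generate_music(
        args,
        emo="Q1",  # one of the four emotion labels Q1-Q4 handled above
        weights="weights.pth",  # hypothetical path to trained weights
    )
    print(f"audio: {audio}\nmidi: {midi}\nscore: {pdf}")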