import re
import os
import math
import torch
import random
from config import *
from unidecode import unidecode
from torch.nn import functional as F
from transformers import AutoModel, BertModel, GPT2LMHeadModel, PreTrainedModel, GPT2Config

try:
    import torch.distributed.nn
    from torch import distributed as dist
    has_distributed = True
except ImportError:
    has_distributed = False

try:
    import horovod.torch as hvd
except ImportError:
    hvd = None
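
# Symmetric CLIP-style contrastive loss (cross-entropy over image-text logits
# in both directions). The structure follows OpenCLIP's ClipLoss, including
# optional multi-process feature gathering via torch.distributed or Horovod.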
class ClipLoss(torch.nn.Module):
    def __init__(
            self,
            local_loss=False,
            gather_with_grad=False,
            cache_labels=False,
            rank=0,
            world_size=1,
            use_horovod=False,
    ):
        super().__init__()
        self.local_loss = local_loss
        self.gather_with_grad = gather_with_grad
        self.cache_labels = cache_labels
        self.rank = rank
        self.world_size = world_size
        self.use_horovod = use_horovod

        # cache state
        self.prev_num_logits = 0
        self.labels = {}

    def gather_features(
            self,
            image_features,
            text_features,
            local_loss=False,
            gather_with_grad=False,
            rank=0,
            world_size=1,
            use_horovod=False
    ):
        assert has_distributed, 'torch.distributed did not import correctly, please use a PyTorch version with support.'
        if use_horovod:
            assert hvd is not None, 'Please install horovod'
            if gather_with_grad:
                all_image_features = hvd.allgather(image_features)
                all_text_features = hvd.allgather(text_features)
            else:
                with torch.no_grad():
                    all_image_features = hvd.allgather(image_features)
                    all_text_features = hvd.allgather(text_features)
                if not local_loss:
                    # ensure grads for local rank when all_* features don't have a gradient
                    gathered_image_features = list(all_image_features.chunk(world_size, dim=0))
                    gathered_text_features = list(all_text_features.chunk(world_size, dim=0))
                    gathered_image_features[rank] = image_features
                    gathered_text_features[rank] = text_features
                    all_image_features = torch.cat(gathered_image_features, dim=0)
                    all_text_features = torch.cat(gathered_text_features, dim=0)
        else:
            # We gather tensors from all gpus
            if gather_with_grad:
                all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features), dim=0)
                all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features), dim=0)
            else:
                gathered_image_features = [torch.zeros_like(image_features) for _ in range(world_size)]
                gathered_text_features = [torch.zeros_like(text_features) for _ in range(world_size)]
                dist.all_gather(gathered_image_features, image_features)
                dist.all_gather(gathered_text_features, text_features)
                if not local_loss:
                    # ensure grads for local rank when all_* features don't have a gradient
                    gathered_image_features[rank] = image_features
                    gathered_text_features[rank] = text_features
                all_image_features = torch.cat(gathered_image_features, dim=0)
                all_text_features = torch.cat(gathered_text_features, dim=0)
        return all_image_features, all_text_features

    def get_ground_truth(self, device, num_logits) -> torch.Tensor:
        # calculate ground-truth labels and cache them if enabled
        if self.prev_num_logits != num_logits or device not in self.labels:
            labels = torch.arange(num_logits, device=device, dtype=torch.long)
            if self.world_size > 1 and self.local_loss:
                labels = labels + num_logits * self.rank
            if self.cache_labels:
                self.labels[device] = labels
                self.prev_num_logits = num_logits
        else:
            labels = self.labels[device]
        return labels

    def get_logits(self, image_features, text_features, logit_scale):
        if self.world_size > 1:
            all_image_features, all_text_features = self.gather_features(
                image_features, text_features,
                self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod)

            if self.local_loss:
                logits_per_image = logit_scale * image_features @ all_text_features.T
                logits_per_text = logit_scale * text_features @ all_image_features.T
            else:
                logits_per_image = logit_scale * all_image_features @ all_text_features.T
                logits_per_text = logits_per_image.T
        else:
            logits_per_image = logit_scale * image_features @ text_features.T
            logits_per_text = logit_scale * text_features @ image_features.T
        return logits_per_image, logits_per_text

    def forward(self, image_features, text_features, logit_scale, output_dict=False):
        device = image_features.device
        logits_per_image, logits_per_text = self.get_logits(image_features, text_features, logit_scale)
        labels = self.get_ground_truth(device, logits_per_image.shape[0])
        total_loss = (
            F.cross_entropy(logits_per_image, labels) +
            F.cross_entropy(logits_per_text, labels)
        ) / 2
        return {"contrastive_loss": total_loss} if output_dict else total_loss
class M3Patchilizer:
    def __init__(self):
        self.delimiters = ["|:", "::", ":|", "[|", "||", "|]", "|"]
        self.regexPattern = '(' + '|'.join(map(re.escape, self.delimiters)) + ')'
        self.pad_token_id = 0
        self.bos_token_id = 1
        self.eos_token_id = 2
        self.mask_token_id = 3

    def split_bars(self, body):
        # split a tune body into bars; a leading delimiter is merged into the
        # first bar, and each bar keeps its trailing delimiter
        bars = re.split(self.regexPattern, ''.join(body))
        bars = list(filter(None, bars))  # remove empty strings
        if bars[0] in self.delimiters:
            bars[1] = bars[0] + bars[1]
            bars = bars[1:]
        bars = [bars[i * 2] + bars[i * 2 + 1] for i in range(len(bars) // 2)]
        return bars

    def bar2patch(self, bar, patch_size=PATCH_SIZE):
        # encode a bar as [bos, character codes..., eos], truncated and padded
        # to patch_size
        patch = [self.bos_token_id] + [ord(c) for c in bar] + [self.eos_token_id]
        patch = patch[:patch_size]
        patch += [self.pad_token_id] * (patch_size - len(patch))
        return patch

    def patch2bar(self, patch):
        # decode a patch back to text, dropping the special token ids
        return ''.join(chr(idx) if idx > self.mask_token_id else '' for idx in patch)
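
    # Round-trip sketch: bar2patch("C D E |") yields
    # [1, 67, 32, 68, 32, 69, 32, 124, 2, 0, ...] (bos, character codes, eos,
    # padding), and patch2bar maps it back to "C D E |" (ids 0-3 are dropped).
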
    def encode(self,
               item,
               patch_size=PATCH_SIZE,
               add_special_patches=False,
               truncate=False,
               random_truncate=False):
        item = item.replace("L:1/8\n", "")  # drop the default unit note length field
        item = unidecode(item)
        lines = re.findall(r'.*?\n|.*$', item)
        lines = list(filter(None, lines))  # remove empty lines

        patches = []
        if lines[0].split(" ")[0] == "ticks_per_beat":
            # MTF input: merge consecutive lines sharing the same leading token
            # into one tab-separated patch, within the patch size budget
            patch = ""
            for line in lines:
                if patch.startswith(line.split(" ")[0]) and (len(patch) + len(" ".join(line.split(" ")[1:])) <= patch_size - 2):
                    patch = patch[:-1] + "\t" + " ".join(line.split(" ")[1:])
                else:
                    if patch:
                        patches.append(patch)
                    patch = line
            if patch != "":
                patches.append(patch)
        else:
            # ABC input: header/directive lines become one patch each,
            # tune lines are split into bars
            for line in lines:
                if len(line) > 1 and ((line[0].isalpha() and line[1] == ':') or line.startswith('%%')):
                    patches.append(line)
                else:
                    bars = self.split_bars(line)
                    if bars:
                        bars[-1] += '\n'
                        patches.extend(bars)

        if add_special_patches:
            bos_patch = chr(self.bos_token_id) * patch_size
            eos_patch = chr(self.eos_token_id) * patch_size
            patches = [bos_patch] + patches + [eos_patch]

        if len(patches) > PATCH_LENGTH and truncate:
            choices = ["head", "tail", "middle"]
            choice = random.choice(choices)
            if choice == "head" or not random_truncate:
                patches = patches[:PATCH_LENGTH]
            elif choice == "tail":
                patches = patches[-PATCH_LENGTH:]
            else:
                start = random.randint(1, len(patches) - PATCH_LENGTH)
                patches = patches[start:start + PATCH_LENGTH]

        patches = [self.bar2patch(patch) for patch in patches]
        return patches

    def decode(self, patches):
        return ''.join(self.patch2bar(patch) for patch in patches)
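
# Usage sketch (PATCH_SIZE and PATCH_LENGTH come from config):
#   patchilizer = M3Patchilizer()
#   patches = patchilizer.encode(abc_string, add_special_patches=True, truncate=True)
#   text = patchilizer.decode(patches)


# Bar-level encoder: embeds each patch (one-hot over 128 character ids, then a
# linear projection) and contextualizes the patch sequence with a BERT encoder.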
class M3PatchEncoder(PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.patch_embedding = torch.nn.Linear(PATCH_SIZE * 128, M3_HIDDEN_SIZE)
        torch.nn.init.normal_(self.patch_embedding.weight, std=0.02)
        self.base = BertModel(config=config)
        self.pad_token_id = 0
        self.bos_token_id = 1
        self.eos_token_id = 2
        self.mask_token_id = 3

    def forward(self,
                input_patches,  # [batch_size, seq_length, patch_size] token ids
                input_masks):   # [batch_size, seq_length]
        # Transform input_patches into embeddings
        input_patches = torch.nn.functional.one_hot(input_patches, num_classes=128)
        input_patches = input_patches.reshape(len(input_patches), -1, PATCH_SIZE * 128).float()
        input_patches = self.patch_embedding(input_patches.to(self.device))
        # Apply BERT model to input_patches and input_masks
        return self.base(inputs_embeds=input_patches, attention_mask=input_masks)
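
# Character-level decoder: a GPT-2 LM head that reconstructs the characters of
# a single patch, conditioned on that patch's encoder feature in place of the
# bos embedding.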
class M3TokenDecoder(PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.base = GPT2LMHeadModel(config=config)
        self.pad_token_id = 0
        self.bos_token_id = 1
        self.eos_token_id = 2
        self.mask_token_id = 3

    def forward(self,
                patch_features,   # [batch_size, hidden_size]
                target_patches):  # [batch_size, seq_length]
        # get input embeddings
        inputs_embeds = torch.nn.functional.embedding(target_patches, self.base.transformer.wte.weight)
        # replace the bos embedding with the encoded patch feature
        inputs_embeds = torch.cat((patch_features.unsqueeze(1), inputs_embeds[:, 1:, :]), dim=1)
        # prepare the labels for model training (padding positions are ignored)
        target_masks = target_patches == self.pad_token_id
        target_patches = target_patches.clone().masked_fill_(target_masks, -100)
        # get the attention mask
        target_masks = ~target_masks
        target_masks = target_masks.type(torch.int)
        return self.base(inputs_embeds=inputs_embeds,
                         attention_mask=target_masks,
                         labels=target_patches)

    def generate(self,
                 patch_feature,
                 tokens):
        # reshape the patch_feature and tokens
        patch_feature = patch_feature.reshape(1, 1, -1)
        tokens = tokens.reshape(1, -1)
        # get input embeddings
        tokens = torch.nn.functional.embedding(tokens, self.base.transformer.wte.weight)
        # replace the bos embedding with the encoded patch feature
        tokens = torch.cat((patch_feature, tokens[:, 1:, :]), dim=1)
        # get the outputs from the model
        outputs = self.base(inputs_embeds=tokens)
        # return the probability distribution over the next token
        probs = torch.nn.functional.softmax(outputs.logits.squeeze(0)[-1], dim=-1)
        return probs.detach().cpu().numpy()
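
# M3 pretraining model: the patch encoder contextualizes all bars, and the
# token decoder reconstructs the characters of the selected (corrupted) bars.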
class M3Model(PreTrainedModel):
    def __init__(self, encoder_config, decoder_config):
        super().__init__(encoder_config)
        self.encoder = M3PatchEncoder(encoder_config)
        self.decoder = M3TokenDecoder(decoder_config)
        self.pad_token_id = 0
        self.bos_token_id = 1
        self.eos_token_id = 2
        self.mask_token_id = 3

    def forward(self,
                input_patches,     # [batch_size, seq_length * patch_size] token ids
                input_masks,       # [batch_size, seq_length]
                selected_indices,  # [batch_size, seq_length]
                target_patches):   # [batch_size, seq_length * patch_size] token ids
        input_patches = input_patches.reshape(len(input_patches), -1, PATCH_SIZE).to(self.device)
        input_masks = input_masks.to(self.device)
        selected_indices = selected_indices.to(self.device)
        target_patches = target_patches.reshape(len(target_patches), -1, PATCH_SIZE).to(self.device)

        # Pass the input_patches and input_masks through the encoder
        outputs = self.encoder(input_patches, input_masks)["last_hidden_state"]

        # Keep only the selected (corrupted) positions for reconstruction
        target_patches = target_patches[selected_indices.bool()]
        patch_features = outputs[selected_indices.bool()]

        # Pass patch_features and target_patches through the decoder
        return self.decoder(patch_features, target_patches)
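
# CLaMP 3: aligns text (a pretrained text model), symbolic music (the M3 patch
# encoder), and audio (a BERT over precomputed audio embeddings) in a shared
# embedding space via the contrastive loss above.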
class CLaMP3Model(PreTrainedModel):
    def __init__(self,
                 audio_config,
                 symbolic_config,
                 global_rank=None,
                 world_size=None,
                 text_model_name=TEXT_MODEL_NAME,
                 hidden_size=CLAMP3_HIDDEN_SIZE,
                 load_m3=CLAMP3_LOAD_M3):
        super().__init__(symbolic_config)

        self.text_model = AutoModel.from_pretrained(text_model_name)  # Load the text model
        self.text_proj = torch.nn.Linear(self.text_model.config.hidden_size, hidden_size)  # Linear layer for text projections
        torch.nn.init.normal_(self.text_proj.weight, std=0.02)  # Initialize weights with a normal distribution

        self.symbolic_model = M3PatchEncoder(symbolic_config)  # Initialize the symbolic model
        self.symbolic_proj = torch.nn.Linear(M3_HIDDEN_SIZE, hidden_size)  # Linear layer for symbolic projections
        torch.nn.init.normal_(self.symbolic_proj.weight, std=0.02)  # Initialize weights with a normal distribution

        self.audio_model = BertModel(audio_config)  # Initialize the audio model
        self.audio_proj = torch.nn.Linear(audio_config.hidden_size, hidden_size)  # Linear layer for audio projections
        torch.nn.init.normal_(self.audio_proj.weight, std=0.02)  # Initialize weights with a normal distribution

        if global_rank is None or world_size is None:
            global_rank = 0
            world_size = 1

        self.loss_fn = ClipLoss(local_loss=False,
                                gather_with_grad=True,
                                cache_labels=False,
                                rank=global_rank,
                                world_size=world_size,
                                use_horovod=False)

        if load_m3 and os.path.exists(M3_WEIGHTS_PATH):
            checkpoint = torch.load(M3_WEIGHTS_PATH, map_location='cpu', weights_only=True)
            decoder_config = GPT2Config(vocab_size=128,
                                        n_positions=PATCH_SIZE,
                                        n_embd=M3_HIDDEN_SIZE,
                                        n_layer=TOKEN_NUM_LAYERS,
                                        n_head=M3_HIDDEN_SIZE // 64,
                                        n_inner=M3_HIDDEN_SIZE * 4)
            model = M3Model(symbolic_config, decoder_config)
            model.load_state_dict(checkpoint['model'])
            self.symbolic_model = model.encoder
            model = None
            print(f"Successfully Loaded M3 Checkpoint from Epoch {checkpoint['epoch']} with loss {checkpoint['min_eval_loss']}")

    def set_trainable(self, freeze_list):
        # Freeze (eval mode, no grads) or unfreeze (train mode, grads) each
        # component depending on whether its attribute name is in freeze_list.
        components = {
            "text_model": (self.text_model, "Text Model"),
            "text_proj": (self.text_proj, "Text Projection Layer"),
            "symbolic_model": (self.symbolic_model, "Symbolic Model"),
            "symbolic_proj": (self.symbolic_proj, "Symbolic Projection Layer"),
            "audio_model": (self.audio_model, "Audio Model"),
            "audio_proj": (self.audio_proj, "Audio Projection Layer"),
        }
        for name, (module, display_name) in components.items():
            if name in freeze_list:
                module.eval()
                for param in module.parameters():
                    param.requires_grad = False
                print(f"{display_name} Frozen")
            else:
                module.train()
                for param in module.parameters():
                    param.requires_grad = True
                print(f"{display_name} Training")

    def avg_pooling(self, input_features, input_masks):
        # masked mean pooling over the sequence dimension
        input_masks = input_masks.unsqueeze(-1).to(self.device)  # add a dimension to match the feature dimension
        input_features = input_features * input_masks  # zero out padded positions
        avg_pool = input_features.sum(dim=1) / input_masks.sum(dim=1)  # average over non-padded positions
        return avg_pool

    def get_text_features(self,
                          text_inputs,
                          text_masks,
                          get_global=False):
        text_features = self.text_model(text_inputs.to(self.device),
                                        attention_mask=text_masks.to(self.device))['last_hidden_state']
        if get_global:
            text_features = self.avg_pooling(text_features, text_masks)
            text_features = self.text_proj(text_features)
        return text_features

    def get_symbolic_features(self,
                              symbolic_inputs,
                              symbolic_masks,
                              get_global=False):
        symbolic_features = self.symbolic_model(symbolic_inputs.to(self.device),
                                                symbolic_masks.to(self.device))['last_hidden_state']
        if get_global:
            symbolic_features = self.avg_pooling(symbolic_features, symbolic_masks)
            symbolic_features = self.symbolic_proj(symbolic_features)
        return symbolic_features

    def get_audio_features(self,
                           audio_inputs,
                           audio_masks,
                           get_global=False):
        audio_features = self.audio_model(inputs_embeds=audio_inputs.to(self.device),
                                          attention_mask=audio_masks.to(self.device))['last_hidden_state']
        if get_global:
            audio_features = self.avg_pooling(audio_features, audio_masks)
            audio_features = self.audio_proj(audio_features)
        return audio_features

    def forward(self,
                text_inputs,      # [batch_size, seq_length]
                text_masks,       # [batch_size, seq_length]
                music_inputs,     # [batch_size, seq_length, ...] (patch ids or audio embeddings)
                music_masks,      # [batch_size, seq_length]
                music_modality):  # "symbolic" or "audio"
        # Compute the global text features
        text_features = self.get_text_features(text_inputs, text_masks, get_global=True)

        # Compute the global music features
        if music_modality == "symbolic":
            music_features = self.get_symbolic_features(music_inputs, music_masks, get_global=True)
        elif music_modality == "audio":
            music_features = self.get_audio_features(music_inputs, music_masks, get_global=True)
        else:
            raise ValueError("music_modality must be either 'symbolic' or 'audio'")

        return self.loss_fn(text_features,
                            music_features,
                            LOGIT_SCALE,
                            output_dict=False)
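
# Retrieval-style usage sketch (hypothetical tensor names; the configs and
# LOGIT_SCALE come from config.py):
#   model = CLaMP3Model(audio_config, symbolic_config)
#   text_vec = model.get_text_features(text_ids, text_mask, get_global=True)
#   music_vec = model.get_symbolic_features(patch_ids, patch_mask, get_global=True)
#   sim = torch.nn.functional.cosine_similarity(text_vec, music_vec)
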

def split_data(data, eval_ratio=EVAL_SPLIT):
    # shuffle in place, then split off the first eval_ratio fraction for evaluation
    random.shuffle(data)
    split_idx = int(len(data) * eval_ratio)
    eval_set = data[:split_idx]
    train_set = data[split_idx:]
    return train_set, eval_set

def mask_patches(target_patches, patchilizer, mode):
    # sample M3_MASK_RATIO of the patch positions for corruption
    indices = list(range(len(target_patches)))
    random.shuffle(indices)
    selected_indices = indices[:math.ceil(M3_MASK_RATIO * len(indices))]
    sorted_indices = sorted(selected_indices)
    input_patches = torch.tensor(target_patches)

    if mode == "eval":
        choice = "original"
    else:
        choice = random.choices(["mask", "shuffle", "original"], weights=[0.8, 0.1, 0.1])[0]

    if choice == "mask":
        # replace the selected patches entirely with mask tokens
        input_patches[sorted_indices] = torch.tensor([patchilizer.mask_token_id] * PATCH_SIZE)
    elif choice == "shuffle":
        # shuffle the characters of each selected patch between bos and eos
        for idx in sorted_indices:
            patch = input_patches[idx]
            try:
                index_eos = (patch == patchilizer.eos_token_id).nonzero().item()
            except (ValueError, RuntimeError):
                # no single eos token found; shuffle up to the end of the patch
                index_eos = len(patch)
            indices = list(range(1, index_eos))
            random.shuffle(indices)
            indices = [0] + indices + list(range(index_eos, len(patch)))
            input_patches[idx] = patch[indices]

    selected_indices = torch.zeros(len(target_patches))
    selected_indices[sorted_indices] = 1.
    return input_patches, selected_indices
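
# Note: mask_patches mirrors BERT-style corruption. During training, roughly
# M3_MASK_RATIO of the patches are selected and, per example, either masked
# (80%), character-shuffled (10%), or left intact (10%); M3Model then
# reconstructs the original characters at the selected positions.
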
def remove_instrument_info(item):
    # remove instrument information from symbolic music (ABC voice names, MIDI programs)
    lines = re.findall(r'.*?\n|.*$', item)
    lines = list(filter(None, lines))
    fmt = "mtf" if lines[0].split(" ")[0] == "ticks_per_beat" else "abc"

    cleaned_lines = []
    for line in lines:
        if fmt == "abc" and line.startswith("V:"):
            # find the position of " nm=" or " snm=" (voice name annotations)
            nm_pos = line.find(" nm=")
            snm_pos = line.find(" snm=")
            # keep the part before " nm=" or " snm="
            if nm_pos != -1:
                line = line[:nm_pos]
            elif snm_pos != -1:
                line = line[:snm_pos]
            if nm_pos != -1 or snm_pos != -1:
                line += "\n"
        elif fmt == "mtf" and line.startswith("program_change"):
            # reset the MIDI program number to 0
            line = " ".join(line.split(" ")[:-1]) + " 0\n"
        cleaned_lines.append(line)
    return ''.join(cleaned_lines)
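
# Example (illustrative lines only; the exact MTF field layout depends on the
# upstream encoder):
#   ABC: 'V:1 nm="Flute" snm="Fl."\n'  ->  'V:1\n'
#   MTF: 'program_change 0 0 40\n'     ->  'program_change 0 0 0\n'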