# clamp3/utils.py
import re
import os
import math
import torch
import random
from config import *
from unidecode import unidecode
from torch.nn import functional as F
from transformers import AutoModel, BertModel, GPT2LMHeadModel, PreTrainedModel, GPT2Config
try:
import torch.distributed.nn
from torch import distributed as dist
has_distributed = True
except ImportError:
has_distributed = False
try:
import horovod.torch as hvd
except ImportError:
hvd = None
class ClipLoss(torch.nn.Module):
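    """CLIP-style contrastive loss with optional cross-rank feature gathering via torch.distributed or Horovod."""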
def __init__(
self,
local_loss=False,
gather_with_grad=False,
cache_labels=False,
rank=0,
world_size=1,
use_horovod=False,
):
super().__init__()
self.local_loss = local_loss
self.gather_with_grad = gather_with_grad
self.cache_labels = cache_labels
self.rank = rank
self.world_size = world_size
self.use_horovod = use_horovod
# cache state
self.prev_num_logits = 0
self.labels = {}
def gather_features(
self,
image_features,
text_features,
local_loss=False,
gather_with_grad=False,
rank=0,
world_size=1,
use_horovod=False
):
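        """Gather image/text features from all ranks; without gather_with_grad, the local rank's features are re-inserted so they keep their gradient (unless local_loss is set)."""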
assert has_distributed, 'torch.distributed did not import correctly, please use a PyTorch version with support.'
if use_horovod:
assert hvd is not None, 'Please install horovod'
if gather_with_grad:
all_image_features = hvd.allgather(image_features)
all_text_features = hvd.allgather(text_features)
else:
with torch.no_grad():
all_image_features = hvd.allgather(image_features)
all_text_features = hvd.allgather(text_features)
if not local_loss:
# ensure grads for local rank when all_* features don't have a gradient
gathered_image_features = list(all_image_features.chunk(world_size, dim=0))
gathered_text_features = list(all_text_features.chunk(world_size, dim=0))
gathered_image_features[rank] = image_features
gathered_text_features[rank] = text_features
all_image_features = torch.cat(gathered_image_features, dim=0)
all_text_features = torch.cat(gathered_text_features, dim=0)
else:
            # gather tensors from all GPUs
if gather_with_grad:
all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features), dim=0)
all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features), dim=0)
else:
gathered_image_features = [torch.zeros_like(image_features) for _ in range(world_size)]
gathered_text_features = [torch.zeros_like(text_features) for _ in range(world_size)]
dist.all_gather(gathered_image_features, image_features)
dist.all_gather(gathered_text_features, text_features)
if not local_loss:
# ensure grads for local rank when all_* features don't have a gradient
gathered_image_features[rank] = image_features
gathered_text_features[rank] = text_features
all_image_features = torch.cat(gathered_image_features, dim=0)
all_text_features = torch.cat(gathered_text_features, dim=0)
return all_image_features, all_text_features
def get_ground_truth(self, device, num_logits) -> torch.Tensor:
        # calculate ground-truth labels and cache them if enabled
if self.prev_num_logits != num_logits or device not in self.labels:
labels = torch.arange(num_logits, device=device, dtype=torch.long)
if self.world_size > 1 and self.local_loss:
labels = labels + num_logits * self.rank
if self.cache_labels:
self.labels[device] = labels
self.prev_num_logits = num_logits
else:
labels = self.labels[device]
return labels
def get_logits(self, image_features, text_features, logit_scale):
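        """Compute scaled image-to-text and text-to-image similarity logits, gathering features across ranks when world_size > 1."""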
if self.world_size > 1:
all_image_features, all_text_features = self.gather_features(
image_features, text_features,
self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod)
if self.local_loss:
logits_per_image = logit_scale * image_features @ all_text_features.T
logits_per_text = logit_scale * text_features @ all_image_features.T
else:
logits_per_image = logit_scale * all_image_features @ all_text_features.T
logits_per_text = logits_per_image.T
else:
logits_per_image = logit_scale * image_features @ text_features.T
logits_per_text = logit_scale * text_features @ image_features.T
return logits_per_image, logits_per_text
def forward(self, image_features, text_features, logit_scale, output_dict=False):
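        """Symmetric cross-entropy over image-to-text and text-to-image logits."""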
device = image_features.device
logits_per_image, logits_per_text = self.get_logits(image_features, text_features, logit_scale)
labels = self.get_ground_truth(device, logits_per_image.shape[0])
total_loss = (
F.cross_entropy(logits_per_image, labels) +
F.cross_entropy(logits_per_text, labels)
) / 2
return {"contrastive_loss": total_loss} if output_dict else total_loss
class M3Patchilizer:
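    """Converts symbolic music (ABC notation or MTF text) into fixed-size character-level patches and back."""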
def __init__(self):
self.delimiters = ["|:", "::", ":|", "[|", "||", "|]", "|"]
self.regexPattern = '(' + '|'.join(map(re.escape, self.delimiters)) + ')'
self.pad_token_id = 0
self.bos_token_id = 1
self.eos_token_id = 2
self.mask_token_id = 3
def split_bars(self, body):
bars = re.split(self.regexPattern, ''.join(body))
bars = list(filter(None, bars)) # remove empty strings
if bars[0] in self.delimiters:
bars[1] = bars[0] + bars[1]
bars = bars[1:]
bars = [bars[i * 2] + bars[i * 2 + 1] for i in range(len(bars) // 2)]
return bars
def bar2patch(self, bar, patch_size=PATCH_SIZE):
patch = [self.bos_token_id] + [ord(c) for c in bar] + [self.eos_token_id]
patch = patch[:patch_size]
patch += [self.pad_token_id] * (patch_size - len(patch))
return patch
def patch2bar(self, patch):
return ''.join(chr(idx) if idx > self.mask_token_id else '' for idx in patch)
def encode(self,
item,
patch_size=PATCH_SIZE,
add_special_patches=False,
truncate=False,
random_truncate=False):
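        """Split an ABC/MTF string into bar- or line-level patches of length patch_size, optionally adding BOS/EOS patches and truncating to PATCH_LENGTH."""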
item = item.replace("L:1/8\n", "")
item = unidecode(item)
lines = re.findall(r'.*?\n|.*$', item)
lines = list(filter(None, lines)) # remove empty lines
patches = []
if lines[0].split(" ")[0] == "ticks_per_beat":
patch = ""
for line in lines:
if patch.startswith(line.split(" ")[0]) and (len(patch) + len(" ".join(line.split(" ")[1:])) <= patch_size-2):
patch = patch[:-1] + "\t" + " ".join(line.split(" ")[1:])
else:
if patch:
patches.append(patch)
patch = line
if patch!="":
patches.append(patch)
else:
for line in lines:
if len(line) > 1 and ((line[0].isalpha() and line[1] == ':') or line.startswith('%%')):
patches.append(line)
else:
bars = self.split_bars(line)
if bars:
bars[-1] += '\n'
patches.extend(bars)
if add_special_patches:
bos_patch = chr(self.bos_token_id) * patch_size
eos_patch = chr(self.eos_token_id) * patch_size
patches = [bos_patch] + patches + [eos_patch]
if len(patches) > PATCH_LENGTH and truncate:
choices = ["head", "tail", "middle"]
choice = random.choice(choices)
            if choice == "head" or not random_truncate:
patches = patches[:PATCH_LENGTH]
elif choice=="tail":
patches = patches[-PATCH_LENGTH:]
else:
start = random.randint(1, len(patches)-PATCH_LENGTH)
patches = patches[start:start+PATCH_LENGTH]
patches = [self.bar2patch(patch) for patch in patches]
return patches
def decode(self, patches):
return ''.join(self.patch2bar(patch) for patch in patches)
class M3PatchEncoder(PreTrainedModel):
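    """BERT encoder over patches: each patch's characters are one-hot encoded and linearly projected to M3_HIDDEN_SIZE embeddings."""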
def __init__(self, config):
super(M3PatchEncoder, self).__init__(config)
self.patch_embedding = torch.nn.Linear(PATCH_SIZE*128, M3_HIDDEN_SIZE)
torch.nn.init.normal_(self.patch_embedding.weight, std=0.02)
self.base = BertModel(config=config)
self.pad_token_id = 0
self.bos_token_id = 1
self.eos_token_id = 2
self.mask_token_id = 3
def forward(self,
                input_patches, # [batch_size, seq_length, patch_size]
input_masks): # [batch_size, seq_length]
# Transform input_patches into embeddings
input_patches = torch.nn.functional.one_hot(input_patches, num_classes=128)
input_patches = input_patches.reshape(len(input_patches), -1, PATCH_SIZE*128).type(torch.FloatTensor)
input_patches = self.patch_embedding(input_patches.to(self.device))
# Apply BERT model to input_patches and input_masks
return self.base(inputs_embeds=input_patches, attention_mask=input_masks)
class M3TokenDecoder(PreTrainedModel):
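    """GPT-2 decoder that reconstructs a patch's characters conditioned on its encoded patch feature."""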
def __init__(self, config):
super(M3TokenDecoder, self).__init__(config)
self.base = GPT2LMHeadModel(config=config)
self.pad_token_id = 0
self.bos_token_id = 1
self.eos_token_id = 2
self.mask_token_id = 3
def forward(self,
patch_features, # [batch_size, hidden_size]
target_patches): # [batch_size, seq_length]
# get input embeddings
inputs_embeds = torch.nn.functional.embedding(target_patches, self.base.transformer.wte.weight)
# concatenate the encoded patches with the input embeddings
inputs_embeds = torch.cat((patch_features.unsqueeze(1), inputs_embeds[:,1:,:]), dim=1)
# preparing the labels for model training
target_masks = target_patches == self.pad_token_id
target_patches = target_patches.clone().masked_fill_(target_masks, -100)
# get the attention mask
target_masks = ~target_masks
target_masks = target_masks.type(torch.int)
return self.base(inputs_embeds=inputs_embeds,
attention_mask=target_masks,
labels=target_patches)
def generate(self,
patch_feature,
tokens):
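        """Return the next-character probability distribution given a patch feature and the characters decoded so far."""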
# reshape the patch_feature and tokens
patch_feature = patch_feature.reshape(1, 1, -1)
tokens = tokens.reshape(1, -1)
# get input embeddings
tokens = torch.nn.functional.embedding(tokens, self.base.transformer.wte.weight)
# concatenate the encoded patches with the input embeddings
tokens = torch.cat((patch_feature, tokens[:,1:,:]), dim=1)
# get the outputs from the model
outputs = self.base(inputs_embeds=tokens)
# get the probabilities of the next token
probs = torch.nn.functional.softmax(outputs.logits.squeeze(0)[-1], dim=-1)
return probs.detach().cpu().numpy()
class M3Model(PreTrainedModel):
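    """Encoder-decoder model for masked patch reconstruction: the encoder contextualizes patches and the decoder reconstructs the patches at selected_indices."""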
def __init__(self, encoder_config, decoder_config):
super(M3Model, self).__init__(encoder_config)
self.encoder = M3PatchEncoder(encoder_config)
self.decoder = M3TokenDecoder(decoder_config)
self.pad_token_id = 0
self.bos_token_id = 1
self.eos_token_id = 2
self.mask_token_id = 3
def forward(self,
                input_patches, # [batch_size, seq_length, patch_size]
input_masks, # [batch_size, seq_length]
selected_indices, # [batch_size, seq_length]
                target_patches): # [batch_size, seq_length, patch_size]
input_patches = input_patches.reshape(len(input_patches), -1, PATCH_SIZE).to(self.device)
input_masks = input_masks.to(self.device)
selected_indices = selected_indices.to(self.device)
target_patches = target_patches.reshape(len(target_patches), -1, PATCH_SIZE).to(self.device)
# Pass the input_patches and input_masks through the encoder
outputs = self.encoder(input_patches, input_masks)["last_hidden_state"]
# Use selected_indices to form target_patches
target_patches = target_patches[selected_indices.bool()]
patch_features = outputs[selected_indices.bool()]
# Pass patch_features and target_patches through the decoder
return self.decoder(patch_features, target_patches)
class CLaMP3Model(PreTrainedModel):
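    """Contrastive model aligning text with symbolic or audio music features in a shared embedding space."""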
def __init__(self,
audio_config,
symbolic_config,
global_rank=None,
world_size=None,
text_model_name=TEXT_MODEL_NAME,
hidden_size=CLAMP3_HIDDEN_SIZE,
load_m3=CLAMP3_LOAD_M3):
super(CLaMP3Model, self).__init__(symbolic_config)
self.text_model = AutoModel.from_pretrained(text_model_name) # Load the text model
self.text_proj = torch.nn.Linear(self.text_model.config.hidden_size, hidden_size) # Linear layer for text projections
torch.nn.init.normal_(self.text_proj.weight, std=0.02) # Initialize weights with normal distribution
self.symbolic_model = M3PatchEncoder(symbolic_config) # Initialize the symbolic model
self.symbolic_proj = torch.nn.Linear(M3_HIDDEN_SIZE, hidden_size) # Linear layer for symbolic projections
torch.nn.init.normal_(self.symbolic_proj.weight, std=0.02) # Initialize weights with normal distribution
self.audio_model = BertModel(audio_config) # Initialize the audio model
self.audio_proj = torch.nn.Linear(audio_config.hidden_size, hidden_size) # Linear layer for audio projections
torch.nn.init.normal_(self.audio_proj.weight, std=0.02) # Initialize weights with normal distribution
        if global_rank is None or world_size is None:
global_rank = 0
world_size = 1
self.loss_fn = ClipLoss(local_loss=False,
gather_with_grad=True,
cache_labels=False,
rank=global_rank,
world_size=world_size,
use_horovod=False)
if load_m3 and os.path.exists(M3_WEIGHTS_PATH):
checkpoint = torch.load(M3_WEIGHTS_PATH, map_location='cpu', weights_only=True)
decoder_config = GPT2Config(vocab_size=128,
n_positions=PATCH_SIZE,
n_embd=M3_HIDDEN_SIZE,
n_layer=TOKEN_NUM_LAYERS,
n_head=M3_HIDDEN_SIZE//64,
n_inner=M3_HIDDEN_SIZE*4)
model = M3Model(symbolic_config, decoder_config)
model.load_state_dict(checkpoint['model'])
self.symbolic_model = model.encoder
model = None
print(f"Successfully Loaded M3 Checkpoint from Epoch {checkpoint['epoch']} with loss {checkpoint['min_eval_loss']}")
def set_trainable(self, freeze_list):
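        """Freeze the components named in freeze_list and set the remaining components to training mode with gradients enabled."""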
if "text_model" in freeze_list:
self.text_model.eval()
for param in self.text_model.parameters():
param.requires_grad = False
print("Text Model Frozen")
else:
self.text_model.train()
for param in self.text_model.parameters():
param.requires_grad = True
print("Text Model Training")
if "text_proj" in freeze_list:
self.text_proj.eval()
for param in self.text_proj.parameters():
param.requires_grad = False
print("Text Projection Layer Frozen")
else:
self.text_proj.train()
for param in self.text_proj.parameters():
param.requires_grad = True
print("Text Projection Layer Training")
if "symbolic_model" in freeze_list:
self.symbolic_model.eval()
for param in self.symbolic_model.parameters():
param.requires_grad = False
print("Symbolic Model Frozen")
else:
self.symbolic_model.train()
for param in self.symbolic_model.parameters():
param.requires_grad = True
print("Symbolic Model Training")
if "symbolic_proj" in freeze_list:
self.symbolic_proj.eval()
for param in self.symbolic_proj.parameters():
param.requires_grad = False
print("Symbolic Projection Layer Frozen")
else:
self.symbolic_proj.train()
for param in self.symbolic_proj.parameters():
param.requires_grad = True
print("Symbolic Projection Layer Training")
if "audio_model" in freeze_list:
self.audio_model.eval()
for param in self.audio_model.parameters():
param.requires_grad = False
print("Audio Model Frozen")
else:
self.audio_model.train()
for param in self.audio_model.parameters():
param.requires_grad = True
print("Audio Model Training")
if "audio_proj" in freeze_list:
self.audio_proj.eval()
for param in self.audio_proj.parameters():
param.requires_grad = False
print("Audio Projection Layer Frozen")
else:
self.audio_proj.train()
for param in self.audio_proj.parameters():
param.requires_grad = True
print("Audio Projection Layer Training")
def avg_pooling(self, input_features, input_masks):
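        """Masked average pooling over the sequence dimension."""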
input_masks = input_masks.unsqueeze(-1).to(self.device) # add a dimension to match the feature dimension
input_features = input_features * input_masks # apply mask to input_features
avg_pool = input_features.sum(dim=1) / input_masks.sum(dim=1) # calculate average pooling
return avg_pool
def get_text_features(self,
text_inputs,
text_masks,
get_global=False):
text_features = self.text_model(text_inputs.to(self.device),
attention_mask=text_masks.to(self.device))['last_hidden_state']
if get_global:
text_features = self.avg_pooling(text_features, text_masks)
text_features = self.text_proj(text_features)
return text_features
def get_symbolic_features(self,
symbolic_inputs,
symbolic_masks,
get_global=False):
symbolic_features = self.symbolic_model(symbolic_inputs.to(self.device),
symbolic_masks.to(self.device))['last_hidden_state']
if get_global:
symbolic_features = self.avg_pooling(symbolic_features, symbolic_masks)
symbolic_features = self.symbolic_proj(symbolic_features)
return symbolic_features
def get_audio_features(self,
audio_inputs,
audio_masks,
get_global=False):
audio_features = self.audio_model(inputs_embeds=audio_inputs.to(self.device),
attention_mask=audio_masks.to(self.device))['last_hidden_state']
if get_global:
audio_features = self.avg_pooling(audio_features, audio_masks)
audio_features = self.audio_proj(audio_features)
return audio_features
def forward(self,
text_inputs, # [batch_size, seq_length]
text_masks, # [batch_size, seq_length]
music_inputs, # [batch_size, seq_length, hidden_size]
music_masks, # [batch_size, seq_length]
music_modality): # "symbolic" or "audio"
# Compute the text features
text_features = self.get_text_features(text_inputs, text_masks, get_global=True)
# Compute the music features
if music_modality=="symbolic":
music_features = self.get_symbolic_features(music_inputs, music_masks, get_global=True)
elif music_modality=="audio":
music_features = self.get_audio_features(music_inputs, music_masks, get_global=True)
else:
raise ValueError("music_modality must be either 'symbolic' or 'audio'")
return self.loss_fn(text_features,
music_features,
LOGIT_SCALE,
output_dict=False)
def split_data(data, eval_ratio=EVAL_SPLIT):
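    """Shuffle data in place and split it into (train_set, eval_set), holding out eval_ratio of the samples for evaluation."""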
random.shuffle(data)
split_idx = int(len(data)*eval_ratio)
eval_set = data[:split_idx]
train_set = data[split_idx:]
return train_set, eval_set
def mask_patches(target_patches, patchilizer, mode):
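    """Select a random M3_MASK_RATIO fraction of patches and corrupt them (masked, shuffled, or left unchanged; always unchanged in eval mode); returns the input patches and a 0/1 mask over the selected positions."""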
indices = list(range(len(target_patches)))
random.shuffle(indices)
selected_indices = indices[:math.ceil(M3_MASK_RATIO*len(indices))]
sorted_indices = sorted(selected_indices)
input_patches = torch.tensor(target_patches)
if mode=="eval":
choice = "original"
else:
choice = random.choices(["mask", "shuffle", "original"], weights=[0.8, 0.1, 0.1])[0]
if choice=="mask":
input_patches[sorted_indices] = torch.tensor([patchilizer.mask_token_id]*PATCH_SIZE)
elif choice=="shuffle":
for idx in sorted_indices:
patch = input_patches[idx]
try:
index_eos = (patch == patchilizer.eos_token_id).nonzero().item()
            except Exception:  # no EOS token found in this patch
index_eos = len(patch)
indices = list(range(1, index_eos))
random.shuffle(indices)
indices = [0] + indices + list(range(index_eos, len(patch)))
input_patches[idx] = patch[indices]
selected_indices = torch.zeros(len(target_patches))
selected_indices[sorted_indices] = 1.
return input_patches, selected_indices
def remove_instrument_info(item):
# remove instrument information from symbolic music
lines = re.findall(r'.*?\n|.*$', item)
lines = list(filter(None, lines))
    if lines[0].split(" ")[0] == "ticks_per_beat":
        fmt = "mtf"
    else:
        fmt = "abc"
cleaned_lines = []
for line in lines:
        if fmt=="abc" and line.startswith("V:"):
# find the position of " nm=" or " snm="
nm_pos = line.find(" nm=")
snm_pos = line.find(" snm=")
# keep the part before " nm=" or " snm="
if nm_pos != -1:
line = line[:nm_pos]
elif snm_pos != -1:
line = line[:snm_pos]
if nm_pos != -1 or snm_pos != -1:
line += "\n"
        elif fmt=="mtf" and line.startswith("program_change"):
line = " ".join(line.split(" ")[:-1]) + " 0\n"
cleaned_lines.append(line)
return ''.join(cleaned_lines)
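
# Minimal usage sketch, assuming config.py defines PATCH_SIZE and PATCH_LENGTH
# (as the wildcard import above implies). It exercises only the M3Patchilizer,
# which needs no pretrained weights, so it is safe to run standalone.
if __name__ == "__main__":
    patchilizer = M3Patchilizer()
    abc_tune = "X:1\nL:1/8\nK:C\nCDEF GABc|cBAG FEDC|\n"
    # encode to fixed-size patches with BOS/EOS special patches
    patches = patchilizer.encode(abc_tune, add_special_patches=True, truncate=True)
    print(f"{len(patches)} patches of {len(patches[0])} token ids each")
    # decode back to text (the L:1/8 field is stripped by encode)
    print(patchilizer.decode(patches))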