from typing import List, Tuple import scipy import numpy as np # SPOTIFY def get_inferred_onsets(onset_roll: np.array, note_roll: np.array, n_diff: int = 2) -> np.array: """ Infer onsets from large changes in note roll matrix amplitudes. Modified from https://github.com/spotify/basic-pitch/blob/main/basic_pitch/note_creation.py :param onset_roll: Onset activation matrix (n_times, n_freqs). :param note_roll: Frame-level note activation matrix (n_times, n_freqs). :param n_diff: Differences used to detect onsets. :return: The maximum between the predicted onsets and its differences. """ diffs = [] for n in range(1, n_diff + 1): frames_appended = np.concatenate([np.zeros((n, note_roll.shape[1])), note_roll]) diffs.append(frames_appended[n:, :] - frames_appended[:-n, :]) frame_diff = np.min(diffs, axis=0) frame_diff[frame_diff < 0] = 0 frame_diff[:n_diff, :] = 0 frame_diff = np.max(onset_roll) * frame_diff / np.max(frame_diff) # rescale to have the same max as onsets max_onsets_diff = np.max([onset_roll, frame_diff], axis=0) # use the max of the predicted onsets and the differences return max_onsets_diff def spotify_create_notes( note_roll: np.array, onset_roll: np.array, onset_thresh: float, frame_thresh: float, min_note_len: int, infer_onsets: bool, note_low : int, #self.labeling.midi_centers[0] note_high : int, #self.labeling.midi_centers[-1], melodia_trick: bool = True, energy_tol: int = 11, ) -> List[Tuple[int, int, int, float]]: """Decode raw model output to polyphonic note events Modified from https://github.com/spotify/basic-pitch/blob/main/basic_pitch/note_creation.py Args: note_roll: Frame activation matrix (n_times, n_freqs). onset_roll: Onset activation matrix (n_times, n_freqs). onset_thresh: Minimum amplitude of an onset activation to be considered an onset. frame_thresh: Minimum amplitude of a frame activation for a note to remain "on". min_note_len: Minimum allowed note length in frames. infer_onsets: If True, add additional onsets when there are large differences in frame amplitudes. melodia_trick : Whether to use the melodia trick to better detect notes. energy_tol: Drop notes below this energy. Returns: list of tuples [(start_time_frames, end_time_frames, pitch_midi, amplitude)] representing the note events, where amplitude is a number between 0 and 1 """ n_frames = note_roll.shape[0] # use onsets inferred from frames in addition to the predicted onsets if infer_onsets: onset_roll = get_inferred_onsets(onset_roll, note_roll) peak_thresh_mat = np.zeros(onset_roll.shape) peaks = scipy.signal.argrelmax(onset_roll, axis=0) peak_thresh_mat[peaks] = onset_roll[peaks] onset_idx = np.where(peak_thresh_mat >= onset_thresh) onset_time_idx = onset_idx[0][::-1] # sort to go backwards in time onset_freq_idx = onset_idx[1][::-1] # sort to go backwards in time remaining_energy = np.zeros(note_roll.shape) remaining_energy[:, :] = note_roll[:, :] # loop over onsets note_events = [] for note_start_idx, freq_idx in zip(onset_time_idx, onset_freq_idx): # if we're too close to the end of the audio, continue if note_start_idx >= n_frames - 1: continue # find time index at this frequency band where the frames drop below an energy threshold i = note_start_idx + 1 k = 0 # number of frames since energy dropped below threshold while i < n_frames - 1 and k < energy_tol: if remaining_energy[i, freq_idx] < frame_thresh: k += 1 else: k = 0 i += 1 i -= k # go back to frame above threshold # if the note is too short, skip it if i - note_start_idx <= min_note_len: continue remaining_energy[note_start_idx:i, freq_idx] = 0 if freq_idx < note_high: remaining_energy[note_start_idx:i, freq_idx + 1] = 0 if freq_idx > note_low: remaining_energy[note_start_idx:i, freq_idx - 1] = 0 # add the note amplitude = np.mean(note_roll[note_start_idx:i, freq_idx]) note_events.append( ( note_start_idx, i, freq_idx + note_low, amplitude, ) ) if melodia_trick: energy_shape = remaining_energy.shape while np.max(remaining_energy) > frame_thresh: i_mid, freq_idx = np.unravel_index(np.argmax(remaining_energy), energy_shape) remaining_energy[i_mid, freq_idx] = 0 # forward pass i = i_mid + 1 k = 0 while i < n_frames - 1 and k < energy_tol: if remaining_energy[i, freq_idx] < frame_thresh: k += 1 else: k = 0 remaining_energy[i, freq_idx] = 0 if freq_idx < note_high: remaining_energy[i, freq_idx + 1] = 0 if freq_idx > note_low: remaining_energy[i, freq_idx - 1] = 0 i += 1 i_end = i - 1 - k # go back to frame above threshold # backward pass i = i_mid - 1 k = 0 while i > 0 and k < energy_tol: if remaining_energy[i, freq_idx] < frame_thresh: k += 1 else: k = 0 remaining_energy[i, freq_idx] = 0 if freq_idx < note_high: remaining_energy[i, freq_idx + 1] = 0 if freq_idx > note_low: remaining_energy[i, freq_idx - 1] = 0 i -= 1 i_start = i + 1 + k # go back to frame above threshold assert i_start >= 0, "{}".format(i_start) assert i_end < n_frames if i_end - i_start <= min_note_len: # note is too short, skip it continue # add the note amplitude = np.mean(note_roll[i_start:i_end, freq_idx]) note_events.append( ( i_start, i_end, freq_idx + note_low, amplitude, ) ) return note_events # TIKTOK def note_detection_with_onset_offset_regress(frame_output, onset_output, onset_shift_output, offset_output, offset_shift_output, velocity_output, frame_threshold): """Process prediction matrices to note events information. First, detect onsets with onset outputs. Then, detect offsets with frame and offset outputs. Args: frame_output: (frames_num,) onset_output: (frames_num,) onset_shift_output: (frames_num,) offset_output: (frames_num,) offset_shift_output: (frames_num,) velocity_output: (frames_num,) frame_threshold: float Returns: output_tuples: list of [bgn, fin, onset_shift, offset_shift, normalized_velocity], e.g., [ [1821, 1909, 0.47498, 0.3048533, 0.72119445], [1909, 1947, 0.30730522, -0.45764327, 0.64200014], ...] """ output_tuples = [] bgn = None frame_disappear = None offset_occur = None for i in range(onset_output.shape[0]): if onset_output[i] == 1: """Onset detected""" if bgn: """Consecutive onsets. E.g., pedal is not released, but two consecutive notes being played.""" fin = max(i - 1, 0) output_tuples.append([bgn, fin, onset_shift_output[bgn], 0, velocity_output[bgn]]) frame_disappear, offset_occur = None, None bgn = i if bgn and i > bgn: """If onset found, then search offset""" if frame_output[i] <= frame_threshold and not frame_disappear: """Frame disappear detected""" frame_disappear = i if offset_output[i] == 1 and not offset_occur: """Offset detected""" offset_occur = i if frame_disappear: if offset_occur and offset_occur - bgn > frame_disappear - offset_occur: """bgn --------- offset_occur --- frame_disappear""" fin = offset_occur else: """bgn --- offset_occur --------- frame_disappear""" fin = frame_disappear output_tuples.append([bgn, fin, onset_shift_output[bgn], offset_shift_output[fin], velocity_output[bgn]]) bgn, frame_disappear, offset_occur = None, None, None if bgn and (i - bgn >= 600 or i == onset_output.shape[0] - 1): """Offset not detected""" fin = i output_tuples.append([bgn, fin, onset_shift_output[bgn], offset_shift_output[fin], velocity_output[bgn]]) bgn, frame_disappear, offset_occur = None, None, None # Sort pairs by onsets output_tuples.sort(key=lambda pair: pair[0]) return output_tuples class RegressionPostProcessor(object): def __init__(self, frames_per_second, classes_num, onset_threshold, offset_threshold, frame_threshold, pedal_offset_threshold, begin_note): """Postprocess the output probabilities of a transription model to MIDI events. Args: frames_per_second: float classes_num: int onset_threshold: float offset_threshold: float frame_threshold: float pedal_offset_threshold: float """ self.frames_per_second = frames_per_second self.classes_num = classes_num self.onset_threshold = onset_threshold self.offset_threshold = offset_threshold self.frame_threshold = frame_threshold self.pedal_offset_threshold = pedal_offset_threshold self.begin_note = begin_note self.velocity_scale = 128 def output_dict_to_midi_events(self, output_dict): """Main function. Post process model outputs to MIDI events. Args: output_dict: { 'reg_onset_output': (segment_frames, classes_num), 'reg_offset_output': (segment_frames, classes_num), 'frame_output': (segment_frames, classes_num), 'velocity_output': (segment_frames, classes_num), 'reg_pedal_onset_output': (segment_frames, 1), 'reg_pedal_offset_output': (segment_frames, 1), 'pedal_frame_output': (segment_frames, 1)} Outputs: est_note_events: list of dict, e.g. [ {'onset_time': 39.74, 'offset_time': 39.87, 'midi_note': 27, 'velocity': 83}, {'onset_time': 11.98, 'offset_time': 12.11, 'midi_note': 33, 'velocity': 88}] est_pedal_events: list of dict, e.g. [ {'onset_time': 0.17, 'offset_time': 0.96}, {'osnet_time': 1.17, 'offset_time': 2.65}] """ output_dict['frame_output'] = output_dict['note'] output_dict['velocity_output'] = output_dict['note'] output_dict['reg_onset_output'] = output_dict['onset'] output_dict['reg_offset_output'] = output_dict['offset'] # Post process piano note outputs to piano note and pedal events information (est_on_off_note_vels, est_pedal_on_offs) = \ self.output_dict_to_note_pedal_arrays(output_dict) """est_on_off_note_vels: (events_num, 4), the four columns are: [onset_time, offset_time, piano_note, velocity], est_pedal_on_offs: (pedal_events_num, 2), the two columns are: [onset_time, offset_time]""" # Reformat notes to MIDI events est_note_events = self.detected_notes_to_events(est_on_off_note_vels) if est_pedal_on_offs is None: est_pedal_events = None else: est_pedal_events = self.detected_pedals_to_events(est_pedal_on_offs) return est_note_events, est_pedal_events def output_dict_to_note_pedal_arrays(self, output_dict): """Postprocess the output probabilities of a transription model to MIDI events. Args: output_dict: dict, { 'reg_onset_output': (frames_num, classes_num), 'reg_offset_output': (frames_num, classes_num), 'frame_output': (frames_num, classes_num), 'velocity_output': (frames_num, classes_num), ...} Returns: est_on_off_note_vels: (events_num, 4), the 4 columns are onset_time, offset_time, piano_note and velocity. E.g. [ [39.74, 39.87, 27, 0.65], [11.98, 12.11, 33, 0.69], ...] est_pedal_on_offs: (pedal_events_num, 2), the 2 columns are onset_time and offset_time. E.g. [ [0.17, 0.96], [1.17, 2.65], ...] """ # ------ 1. Process regression outputs to binarized outputs ------ # For example, onset or offset of [0., 0., 0.15, 0.30, 0.40, 0.35, 0.20, 0.05, 0., 0.] # will be processed to [0., 0., 0., 0., 1., 0., 0., 0., 0., 0.] # Calculate binarized onset output from regression output (onset_output, onset_shift_output) = \ self.get_binarized_output_from_regression( reg_output=output_dict['reg_onset_output'], threshold=self.onset_threshold, neighbour=2) output_dict['onset_output'] = onset_output # Values are 0 or 1 output_dict['onset_shift_output'] = onset_shift_output # Calculate binarized offset output from regression output (offset_output, offset_shift_output) = \ self.get_binarized_output_from_regression( reg_output=output_dict['reg_offset_output'], threshold=self.offset_threshold, neighbour=4) output_dict['offset_output'] = offset_output # Values are 0 or 1 output_dict['offset_shift_output'] = offset_shift_output if 'reg_pedal_onset_output' in output_dict.keys(): """Pedal onsets are not used in inference. Instead, frame-wise pedal predictions are used to detect onsets. We empirically found this is more accurate to detect pedal onsets.""" pass if 'reg_pedal_offset_output' in output_dict.keys(): # Calculate binarized pedal offset output from regression output (pedal_offset_output, pedal_offset_shift_output) = \ self.get_binarized_output_from_regression( reg_output=output_dict['reg_pedal_offset_output'], threshold=self.pedal_offset_threshold, neighbour=4) output_dict['pedal_offset_output'] = pedal_offset_output # Values are 0 or 1 output_dict['pedal_offset_shift_output'] = pedal_offset_shift_output # ------ 2. Process matrices results to event results ------ # Detect piano notes from output_dict est_on_off_note_vels = self.output_dict_to_detected_notes(output_dict) est_pedal_on_offs = None return est_on_off_note_vels, est_pedal_on_offs def get_binarized_output_from_regression(self, reg_output, threshold, neighbour): """Calculate binarized output and shifts of onsets or offsets from the regression results. Args: reg_output: (frames_num, classes_num) threshold: float neighbour: int Returns: binary_output: (frames_num, classes_num) shift_output: (frames_num, classes_num) """ binary_output = np.zeros_like(reg_output) shift_output = np.zeros_like(reg_output) (frames_num, classes_num) = reg_output.shape for k in range(classes_num): x = reg_output[:, k] for n in range(neighbour, frames_num - neighbour): if x[n] > threshold and self.is_monotonic_neighbour(x, n, neighbour): binary_output[n, k] = 1 """See Section III-D in [1] for deduction. [1] Q. Kong, et al., High-resolution Piano Transcription with Pedals by Regressing Onsets and Offsets Times, 2020.""" if x[n - 1] > x[n + 1]: shift = (x[n + 1] - x[n - 1]) / (x[n] - x[n + 1]) / 2 else: shift = (x[n + 1] - x[n - 1]) / (x[n] - x[n - 1]) / 2 shift_output[n, k] = shift return binary_output, shift_output def is_monotonic_neighbour(self, x, n, neighbour): """Detect if values are monotonic in both side of x[n]. Args: x: (frames_num,) n: int neighbour: int Returns: monotonic: bool """ monotonic = True for i in range(neighbour): if x[n - i] < x[n - i - 1]: monotonic = False if x[n + i] < x[n + i + 1]: monotonic = False return monotonic def output_dict_to_detected_notes(self, output_dict): """Postprocess output_dict to piano notes. Args: output_dict: dict, e.g. { 'onset_output': (frames_num, classes_num), 'onset_shift_output': (frames_num, classes_num), 'offset_output': (frames_num, classes_num), 'offset_shift_output': (frames_num, classes_num), 'frame_output': (frames_num, classes_num), 'onset_output': (frames_num, classes_num), ...} Returns: est_on_off_note_vels: (notes, 4), the four columns are onsets, offsets, MIDI notes and velocities. E.g., [[39.7375, 39.7500, 27., 0.6638], [11.9824, 12.5000, 33., 0.6892], ...] """ est_tuples = [] est_midi_notes = [] classes_num = output_dict['frame_output'].shape[-1] for piano_note in range(classes_num): """Detect piano notes""" est_tuples_per_note = note_detection_with_onset_offset_regress( frame_output=output_dict['frame_output'][:, piano_note], onset_output=output_dict['onset_output'][:, piano_note], onset_shift_output=output_dict['onset_shift_output'][:, piano_note], offset_output=output_dict['offset_output'][:, piano_note], offset_shift_output=output_dict['offset_shift_output'][:, piano_note], velocity_output=output_dict['velocity_output'][:, piano_note], frame_threshold=self.frame_threshold) est_tuples += est_tuples_per_note est_midi_notes += [piano_note + self.begin_note] * len(est_tuples_per_note) est_tuples = np.array(est_tuples) # (notes, 5) """(notes, 5), the five columns are onset, offset, onset_shift, offset_shift and normalized_velocity""" est_midi_notes = np.array(est_midi_notes) # (notes,) onset_times = (est_tuples[:, 0] + est_tuples[:, 2]) / self.frames_per_second offset_times = (est_tuples[:, 1] + est_tuples[:, 3]) / self.frames_per_second velocities = est_tuples[:, 4] est_on_off_note_vels = np.stack((onset_times, offset_times, est_midi_notes, velocities), axis=-1) """(notes, 3), the three columns are onset_times, offset_times and velocity.""" est_on_off_note_vels = est_on_off_note_vels.astype(np.float32) return est_on_off_note_vels def detected_notes_to_events(self, est_on_off_note_vels): """Reformat detected notes to midi events. Args: est_on_off_vels: (notes, 3), the three columns are onset_times, offset_times and velocity. E.g. [[32.8376, 35.7700, 0.7932], [37.3712, 39.9300, 0.8058], ...] Returns: midi_events, list, e.g., [{'onset_time': 39.7376, 'offset_time': 39.75, 'midi_note': 27, 'velocity': 84}, {'onset_time': 11.9824, 'offset_time': 12.50, 'midi_note': 33, 'velocity': 88}, ...] """ midi_events = [] for i in range(est_on_off_note_vels.shape[0]): midi_events.append({ 'onset_time': est_on_off_note_vels[i][0], 'offset_time': est_on_off_note_vels[i][1], 'midi_note': int(est_on_off_note_vels[i][2]), 'velocity': int(est_on_off_note_vels[i][3] * self.velocity_scale)}) return midi_events