Upload 3 files
TCUPY.py
ADDED
@@ -0,0 +1,1093 @@
#! /usr/bin/python3

r'''############################################################################
################################################################################
#
#
# Tegridy Cupy Python Module (TCUPY)
# Version 1.0
#
# Project Los Angeles
#
# Tegridy Code 2025
#
# https://github.com/asigalov61/tegridy-tools
#
#
################################################################################
#
# Copyright 2024 Project Los Angeles / Tegridy Code
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
################################################################################
#
# Critical dependencies
#
# !pip install cupy-cuda12x
# !pip install numpy==1.24.4
#
################################################################################
'''

################################################################################

print('=' * 70)
print('Loading module...')
print('Please wait...')
print('=' * 70)

################################################################################

import sys
import os

import tqdm    # required by the BPE and n-gram functions below
import TMIDIX  # required by autoregressive_generate below

################################################################################

try:
    import cupy as cp
    import cupy as np
    print('=' * 70)
    print('CuPy is found!')
    print('Will use CuPy and GPU for processing!')
    print('=' * 70)

except ImportError as e:
    print(f"Error: Could not import CuPy. Details: {e}")
    # Handle the error, such as providing a fallback or exiting the program
    # For example:
    print("Please make sure CuPy is installed.")
    print('=' * 70)

    raise RuntimeError("CuPy could not be loaded!") from e

################################################################################

from collections import defaultdict, deque
from typing import Optional, Tuple, Dict, Any, List

################################################################################

# Constants
MEMORY_LEN = 12          # Autoregressive context length
SEQUENCE_LENGTH = 32     # Each generated sequence has 32 triplets

# Baseline penalty values:
REPETITION_PENALTY = (1.0, 1.0, 1.0)       # base repetition penalty per element
SPIKE_PENALTY_STRENGTH = (1.0, 1.0, 1.0)   # base spike penalty strength per element
SPIKE_SIGMA = (1.0, 1.0, 1.0)              # baseline sigma value per element (minimum allowed)

###################################################################################

def find_numpy_array(src_array, trg_array):

    """
    Finds 1D numpy array in 2D numpy array
    """

    match_mask = np.all(src_array == trg_array, axis=1)

    return np.where(match_mask)[0]

###################################################################################

def vertical_list_search(src_list, trg_list):

    """
    For each vertical window of consecutive rows of height len(trg_list) in src_list,
    this function checks whether for every offset j (0 <= j < len(trg_list)) the row
    at index (window_start + j) contains trg_list[j].

    It returns a list of windows (each a list of consecutive row indices) that meet this condition.
    """

    if not src_list or not trg_list:
        return []

    n = len(src_list)
    k = len(trg_list)

    num_windows = n - k + 1

    if num_windows <= 0:
        return []

    # Determine the maximum row length.
    max_len = max(len(row) for row in src_list)

    # Determine a fill value guaranteed to be less than any valid value.
    global_min = min(min(row) for row in src_list if row)
    fill_value = global_min - 1

    # Build a padded 2D array A (shape n x max_len) from src_list.
    A = np.full((n, max_len), fill_value, dtype=np.int64)
    for i, row in enumerate(src_list):
        L = len(row)
        A[i, :L] = row

    # For each unique target in trg_list, compute a Boolean vector of length n.
    # present[t][i] will be True if A[i, :] contains t, else False.
    unique_targets = set(trg_list)

    present_dict = {}

    for t in unique_targets:
        # Compute along axis=1 so that for each row we see if any element equals t.
        present_dict[t] = np.any(A == t, axis=1)

    # Build a Boolean array B of shape (k, num_windows) where for each offset j,
    # B[j, s] = present_dict[ trg_list[j] ][s + j] for each window starting index s.
    B = np.empty((k, num_windows), dtype=bool)

    for j in range(k):
        t = trg_list[j]
        # For a vertical window starting at s, row s+j should contain t.
        B[j, :] = present_dict[t][j: j + num_windows]

    # A window is valid if all k rows in that window contain the required target.
    valid_windows_mask = np.all(B, axis=0)
    valid_starts = np.nonzero(valid_windows_mask)[0]

    # Create output windows (each as a list of consecutive row indices).
    result = [list(range(s, s + k)) for s in valid_starts]

    return result


###################################################################################

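# Illustrative usage sketch for vertical_list_search (hypothetical toy data,
# kept commented out so nothing runs at import time):
#
#   src = [[60, 64, 67], [62, 65], [60, 63, 67], [65, 69]]
#   # Windows of 2 consecutive rows where row 0 contains 62 and row 1 contains 60:
#   vertical_list_search(src, [62, 60])   # -> [[1, 2]]
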
def pack_sequences(train_data, pad_val=-1):
    """
    Packs a list of variable-length token sequences into a 2D CuPy array.

    This version computes lengths and builds the padded array and mask entirely on GPU.
    It converts each sequence into a CuPy array, concatenates them, and assigns tokens in one shot.

    Returns:
      batch: a CuPy array of shape (n, max_len)
      lengths: a CuPy array of shape (n,) containing each sequence's length.
    """
    n = len(train_data)
    # Compute lengths of each sequence and convert to a CuPy array.
    lengths = cp.array([len(seq) for seq in train_data], dtype=cp.int64)
    max_len_val = int(cp.max(lengths).get())
    # Allocate the padded 2D array filled with pad_val.
    batch = cp.full((n, max_len_val), pad_val, dtype=cp.int64)
    # Create a boolean mask: for each row, positions less than the sequence length are valid.
    mask = cp.arange(max_len_val).reshape(1, max_len_val) < lengths.reshape(n, 1)
    # Convert each sequence to a CuPy array and concatenate them.
    sequences = [cp.array(seq, dtype=cp.int64) for seq in train_data]
    flat = cp.concatenate(sequences)
    # Fill in the valid positions.
    batch[mask] = flat
    return batch, lengths

###################################################################################

def count_best_pair_gpu(batch, lengths, factor, pad_val=-1):
    """
    Given the entire GPU-resident packed data, compute the most frequent
    adjacent pair (encoded as: pair_val = first * factor + second) on GPU.
    """
    n, L = batch.shape
    cols = cp.arange(L - 1, dtype=cp.int64)
    cols_expanded = cp.broadcast_to(cols, (n, L - 1))
    valid_mask = cols_expanded < cp.reshape(lengths, (n, 1)) - 1

    first_tokens = batch[:, :L - 1]
    second_tokens = batch[:, 1:L]
    valid_first = first_tokens[valid_mask]
    valid_second = second_tokens[valid_mask]

    pairs = valid_first * factor + valid_second
    if pairs.size == 0:
        return None

    sorted_pairs = cp.sort(pairs)
    diff = cp.diff(sorted_pairs)
    boundaries = cp.nonzero(diff)[0] + 1
    group_starts = cp.concatenate([cp.array([0], dtype=cp.int64), boundaries])
    group_ends = cp.concatenate([boundaries, cp.array([sorted_pairs.size], dtype=cp.int64)])
    group_counts = group_ends - group_starts

    max_idx = int(cp.argmax(group_counts))
    best_pair_enc = int(sorted_pairs[group_starts[max_idx]])
    best_freq = int(group_counts[max_idx])
    first = best_pair_enc // factor
    second = best_pair_enc % factor
    return (first, second, best_freq)

###################################################################################

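# Illustrative usage sketch for pack_sequences / count_best_pair_gpu
# (hypothetical toy data, kept commented out so nothing runs at import time):
#
#   seqs = [[1, 2, 1, 2, 3], [1, 2, 4]]
#   batch, lengths = pack_sequences(seqs)          # batch shape: (2, 5), padded with -1
#   count_best_pair_gpu(batch, lengths, factor=5)  # -> (1, 2, 3): pair (1, 2) occurs 3 times
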
merge_kernel_code = r'''
extern "C" __global__
void merge_pair_kernel(const long* input, long* output,
                       const long* input_lengths, long* output_lengths,
                       const long num_rows, const long num_cols,
                       const long a, const long b, const long new_token,
                       const long pad_val) {
    int row = blockIdx.x * blockDim.x + threadIdx.x;
    if (row >= num_rows) return;
    long in_length = input_lengths[row];
    long out_idx = 0;
    bool skip_next = false;
    for (long i = 0; i < in_length; i++) {
        if (skip_next) {
            skip_next = false;
            continue;
        }
        long token = input[row * num_cols + i];
        if (i < in_length - 1 && token == a && input[row * num_cols + i + 1] == b) {
            output[row * num_cols + out_idx] = new_token;
            out_idx++;
            skip_next = true;
        } else {
            output[row * num_cols + out_idx] = token;
            out_idx++;
        }
    }
    output_lengths[row] = out_idx;
    for (long j = out_idx; j < num_cols; j++) {
        output[row * num_cols + j] = pad_val;
    }
}
'''
merge_kernel = cp.RawKernel(merge_kernel_code, 'merge_pair_kernel')

###################################################################################

def learn_bpe_codes_gpu(train_data, vocab_size=4096, max_merges=None, pad_val=-1):
    """
    Learn BPE merge rules completely on GPU.

    The training data is packed once (using the vectorized pack_sequences).
    On each merge iteration, the best adjacent pair is computed on GPU and then merged
    into a new token via a custom merge kernel (with double-buffering).

    Returns:
      codes: a list of merge rules as ((first, second), new_token)
      final_data: the merged training data (list of sequences)
    """
    # Pack the entire dataset onto GPU.
    batch, lengths = pack_sequences(train_data, pad_val)
    n, L = batch.shape

    # Initialize vocabulary and the next available token.
    initial_vocab = {token for seq in train_data for token in seq}
    next_token = max(initial_vocab) + 1
    codes = []
    merge_count = 0
    pbar = tqdm.tqdm(total=max_merges if max_merges is not None else None,
                     desc="Learning BPE Codes (GPU)", leave=True)

    # Preallocate buffers for double-buffering.
    work_batch = cp.empty_like(batch)
    work_lengths = cp.empty_like(lengths)
    input_batch = batch
    input_lengths = lengths

    threads_per_block = 128
    blocks = (n + threads_per_block - 1) // threads_per_block

    while next_token < vocab_size and (max_merges is None or merge_count < max_merges):
        # Early stop if all sequences have collapsed (checked on GPU).
        if bool(cp.all(input_lengths == 1)):
            pbar.write("All sequences have collapsed; stopping early.")
            break

        factor = next_token  # by construction, every token is < next_token
        best = count_best_pair_gpu(input_batch, input_lengths, factor, pad_val)
        if best is None:
            pbar.write("No mergeable pairs found; stopping early.")
            break

        best_pair = (best[0], best[1])
        best_freq = best[2]
        if best_freq < 2:
            pbar.write("Best pair frequency is less than 2; stopping early.")
            break

        codes.append((best_pair, next_token))

        # Launch the merge kernel.
        merge_kernel((blocks,), (threads_per_block,),
                     (input_batch,
                      work_batch,
                      input_lengths,
                      work_lengths,
                      cp.int64(n),
                      cp.int64(L),
                      cp.int64(best_pair[0]),
                      cp.int64(best_pair[1]),
                      cp.int64(next_token),
                      cp.int64(pad_val)))
        # Swap buffers for double-buffering.
        input_batch, work_batch = work_batch, input_batch
        input_lengths, work_lengths = work_lengths, input_lengths

        next_token += 1
        merge_count += 1
        pbar.update(1)
    pbar.close()

    final_batch = cp.asnumpy(input_batch)
    final_lengths = cp.asnumpy(input_lengths)
    final_data = [final_batch[i, :final_lengths[i]].tolist() for i in range(n)]
    return codes, final_data

###################################################################################

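# Illustrative usage sketch for learn_bpe_codes_gpu (hypothetical toy data,
# kept commented out so nothing runs at import time):
#
#   train = [[1, 2, 3, 1, 2], [1, 2, 1, 2, 4]] * 4
#   codes, merged = learn_bpe_codes_gpu(train, vocab_size=16)
#   # codes might start with ((1, 2), 5) since (1, 2) is the most frequent pair;
#   # 'merged' holds the same sequences with each learned pair collapsed into its new token.
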
fused_merge_kernel_code = r'''
extern "C" __global__
void fused_merge_kernel(long* data_in, long* data_out, long* lengths, const long pad_val,
                        const long num_rows, const long max_len, const long num_merges, const long* merge_rules) {
    int row = blockIdx.x * blockDim.x + threadIdx.x;
    if (row >= num_rows) return;
    long base = row * max_len;
    long cur_len = lengths[row];
    long* cur = data_in + base;
    long* other = data_out + base;
    // Process each merge rule sequentially.
    for (int m = 0; m < num_merges; m++) {
        long a = merge_rules[3 * m];
        long b = merge_rules[3 * m + 1];
        long new_token = merge_rules[3 * m + 2];
        long out_idx = 0;
        for (int i = 0; i < cur_len; i++) {
            if (i < cur_len - 1 && cur[i] == a && cur[i+1] == b) {
                other[out_idx] = new_token;
                out_idx++;
                i++; // Skip the next token.
            } else {
                other[out_idx] = cur[i];
                out_idx++;
            }
        }
        cur_len = out_idx;
        // Swap pointers for the next merge.
        long* temp = cur;
        cur = other;
        other = temp;
    }
    lengths[row] = cur_len;
    // Pad the remaining positions with pad_val.
    for (int i = cur_len; i < max_len; i++) {
        cur[i] = pad_val;
    }
    // If the final result is not in data_in, copy back.
    if (cur != data_in + base) {
        for (int i = 0; i < cur_len; i++) {
            data_in[base + i] = cur[i];
        }
    }
}
'''
fused_kernel = cp.RawKernel(fused_merge_kernel_code, 'fused_merge_kernel')

###################################################################################

def retokenize_train_data_fused_gpu(train_data, codes, pad_val=-1):
    """
    Retokenize training data using the fully fused GPU kernel.

    The entire training dataset is first packed into GPU memory (using pack_sequences).
    All learned merge rules (provided in 'codes') are applied via a single kernel launch.
    Each GPU thread processes one sequence by applying all merge rules sequentially.

    Returns:
      tokenized_data: list of retokenized sequences.
    """
    # Pack the data.
    batch, lengths = pack_sequences(train_data, pad_val)
    n, max_len = batch.shape
    # Build a flattened merge_rules array using CuPy.
    if len(codes) > 0:
        merge_rules_list = [[rule[0][0], rule[0][1], rule[1]] for rule in codes]
        merge_rules_gpu = cp.array(merge_rules_list, dtype=cp.int64)
        merge_rules_gpu = merge_rules_gpu.reshape(-1)
    else:
        merge_rules_gpu = cp.empty((0,), dtype=cp.int64)
    num_merges = merge_rules_gpu.shape[0] // 3
    # Preallocate a scratch buffer.
    scratch = cp.empty_like(batch)
    threads_per_block = 128
    blocks = (n + threads_per_block - 1) // threads_per_block
    # Launch the fused kernel.
    fused_kernel((blocks,), (threads_per_block,),
                 (batch, scratch, lengths, cp.int64(pad_val),
                  cp.int64(n), cp.int64(max_len), cp.int64(num_merges), merge_rules_gpu))
    final_batch = cp.asnumpy(batch)
    final_lengths = cp.asnumpy(lengths)
    tokenized_data = [final_batch[i, :final_lengths[i]].tolist() for i in range(n)]
    return tokenized_data

###################################################################################

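# Illustrative usage sketch for retokenize_train_data_fused_gpu, continuing the
# hypothetical 'train' / 'codes' from the sketch above (not executed at import time):
#
#   retokenized = retokenize_train_data_fused_gpu(train, codes)
#   # Applies every learned merge rule to every sequence in a single kernel launch.
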
def bpe_encode(seq, codes):
    """
    Iteratively encodes a sequence using BPE merge rules provided in a dictionary.

    Args:
        seq (list): A list of tokens (e.g. integers) representing the input sequence.
        codes (dict): A dictionary mapping token pairs (a tuple of two tokens)
                      to a merged token. For example:
                      { (1, 2): 100, (100, 3): 101 }

    Returns:
        list: The encoded sequence after applying all possible merges.

    The function repeatedly scans the entire sequence from left to right;
    whenever it finds a contiguous token pair that exists as a key in the
    codes dict, it replaces that pair with the merged token. This pass is
    repeated until no more merges are possible.
    """

    if type(codes) == list:
        codes = dict(codes)

    encoded_seq = seq.copy()  # work on a copy so as not to modify the original
    done = False
    while not done:
        new_seq = []
        i = 0
        changed = False
        while i < len(encoded_seq):
            # If a merge is possible, merge the two tokens.
            if i < len(encoded_seq) - 1 and (encoded_seq[i], encoded_seq[i + 1]) in codes:
                new_seq.append(codes[(encoded_seq[i], encoded_seq[i + 1])])
                i += 2  # Skip the next token as it was merged.
                changed = True
            else:
                new_seq.append(encoded_seq[i])
                i += 1
        # If no merges occurred in this pass, exit the loop.
        if not changed:
            done = True
        encoded_seq = new_seq
    return encoded_seq

###################################################################################

def bpe_decode(seq, codes):
    """
    Decodes a sequence encoded with BPE merge rules defined in a codes dictionary.

    Args:
        seq (list): The encoded sequence (a list of tokens).
        codes (dict): A dictionary mapping token pairs to the merged token, used during encoding.

    Returns:
        list: The fully decoded sequence, with all merged tokens recursively expanded.

    The function constructs a reverse mapping that converts a merged token back into
    its constituent pair. Each token in the sequence is then recursively expanded.
    """

    if type(codes) == list:
        codes = dict(codes)

    # Build the reverse mapping: key = merged token, value = tuple (original token pair)
    reverse_mapping = {merged: pair for pair, merged in codes.items()}

    def recursive_expand(token):
        # If the token is a merged token, expand it recursively.
        if token in reverse_mapping:
            a, b = reverse_mapping[token]
            return recursive_expand(a) + recursive_expand(b)
        else:
            return [token]

    decoded_seq = []
    for token in seq:
        decoded_seq.extend(recursive_expand(token))
    return decoded_seq

###################################################################################

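# Illustrative round-trip sketch for bpe_encode / bpe_decode (hypothetical codes,
# kept commented out so nothing runs at import time):
#
#   codes = {(1, 2): 100, (100, 3): 101}
#   enc = bpe_encode([1, 2, 3, 4], codes)   # -> [101, 4]
#   bpe_decode(enc, codes)                  # -> [1, 2, 3, 4]
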
def ensure_triplet(val: Any, name: str = "") -> Tuple[float, float, float]:
    """
    Ensure the given parameter is returned as a triplet.
    If provided as a scalar, promote it to a triplet.
    """
    if np.isscalar(val):
        return (float(val), float(val), float(val))
    elif isinstance(val, (list, tuple)) and len(val) == 3:
        return tuple(float(x) for x in val)
    else:
        raise ValueError(f"{name} must be a scalar or a sequence of 3 numbers.")

###################################################################################

REP_PENALTY = ensure_triplet(REPETITION_PENALTY, "REPETITION_PENALTY")
SPIKE_STRENGTH = ensure_triplet(SPIKE_PENALTY_STRENGTH, "SPIKE_PENALTY_STRENGTH")
SPIKE_SIG = ensure_triplet(SPIKE_SIGMA, "SPIKE_SIGMA")

###################################################################################

def sliding_window_view_alternative(a: np.ndarray, window_length: int) -> np.ndarray:
    """
    Create a sliding-window view (without copying) of an array.
    Expected input shape: (n, L, d) and returns: (n, L - window_length + 1, window_length, d)
    """
    n, L, d = a.shape
    new_shape = (n, L - window_length + 1, window_length, d)
    new_strides = (a.strides[0], a.strides[1], a.strides[1], a.strides[2])
    return np.lib.stride_tricks.as_strided(a, shape=new_shape, strides=new_strides)

###################################################################################

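# Illustrative sketch for ensure_triplet / sliding_window_view_alternative
# (kept commented out so nothing runs at import time):
#
#   ensure_triplet(2.0)        # -> (2.0, 2.0, 2.0)
#   ensure_triplet((1, 2, 3))  # -> (1.0, 2.0, 3.0)
#   # For an array of shape (n, L, d), window_length=3 yields a zero-copy view
#   # of shape (n, L - 2, 3, d).
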
def build_ngram_mapping(data: np.ndarray, memory_len: int) -> Dict[Any, Dict[Any, int]]:
    """
    Build an n-gram mapping from a context (a sequence of triplets) to candidate triplets with frequencies.
    """
    n, L, d = data.shape
    window_length = memory_len + 1  # context (memory) + candidate
    windows = sliding_window_view_alternative(data, window_length)
    # windows shape: (n, L - window_length + 1, window_length, d)

    # Split windows into context (first memory_len triplets) and candidates (last triplet)
    contexts = windows[:, :, :memory_len, :]   # shape: (n, num_windows, memory_len, d)
    candidates = windows[:, :, memory_len, :]  # shape: (n, num_windows, d)

    # Flatten the batch and window dimensions.
    contexts_flat = contexts.reshape(-1, memory_len, d)
    candidates_flat = candidates.reshape(-1, d)

    mapping = defaultdict(lambda: defaultdict(int))
    total_windows = contexts_flat.shape[0]
    for context_arr, candidate_arr in tqdm.tqdm(
            zip(contexts_flat, candidates_flat),
            total=total_windows,
            desc="Building n-gram mapping"):
        context_key = tuple(map(tuple, context_arr))  # use a tuple of triplets as the key
        candidate_val = tuple(candidate_arr)
        mapping[context_key][candidate_val] += 1

    return {context: dict(candidates) for context, candidates in mapping.items()}

###################################################################################

def precompute_mapping_lookup(mapping: Dict[Any, Dict[Any, int]]) -> Dict[Any, Tuple[Tuple[Any, ...], np.ndarray]]:
    """
    Converts the mapping into a lookup table: context -> (tuple(candidates), frequencies_array).
    """
    mapping_lookup = {}
    for context, candidate_dict in tqdm.tqdm(mapping.items(), desc="Precomputing lookup"):
        candidates = tuple(candidate_dict.keys())
        frequencies = np.array(list(candidate_dict.values()), dtype=np.float64)
        mapping_lookup[context] = (candidates, frequencies)
    return mapping_lookup

###################################################################################

def build_training_sequences_set(data: np.ndarray) -> set:
    """
    Build a set of training sequences (each as a tuple of triplets) for uniqueness checking.
    """
    return {tuple(map(tuple, seq)) for seq in data}

###################################################################################

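# Illustrative sketch of building the n-gram structures (hypothetical 'data' array
# of shape (num_sequences, seq_len, 3); kept commented out so nothing runs at import time):
#
#   mapping = build_ngram_mapping(data, MEMORY_LEN)
#   mapping_lookup = precompute_mapping_lookup(mapping)
#   training_set = build_training_sequences_set(data)
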
def generate_sequence_optimized(mapping_lookup: Dict[Any, Tuple[Tuple[Any, ...], np.ndarray]],
                                training_set: set,
                                memory_len: int,
                                sequence_length: int = 24,
                                max_attempts: int = 1000) -> Optional[Tuple[Tuple[float, float, float], ...]]:
    """
    Autoregressively generate a new, unique sequence using the precomputed mapping lookup.
    The invariant maintained is: the second element of one triplet is never greater than the first element
    of the following triplet.

    Two dynamic adjustments are applied for candidate selection:

    1. **Dynamic Repetition Penalty:**
       For each candidate, count the occurrences of each element in the generated sequence.
       Rather than a fixed penalty, this repetition penalty scales with the ratio
       (current_length / sequence_length). In log-space, it subtracts:
           (current_length / sequence_length) * sum_k(count[k] * log(REP_PENALTY[k]))
    2. **Dynamic Spike (Variance) Penalty:**
       For each candidate, compute the squared difference from the running average for each element.
       Use a dynamic sigma that is the maximum between the running standard deviation and the baseline.
       The penalty term for each element is:
           SPIKE_STRENGTH[k] * ((cand[k] - running_avg[k])^2) / (2 * dynamic_sigma[k]^2)
       The overall spike penalty is the sum of the three terms and is subtracted from the candidate's log frequency.

    The resulting candidate log score is computed as:
        log(candidate_frequency) - rep_penalty_component - spike_penalty_component
    A numerically stable softmax is then applied over these scores to determine the probability for drawing a candidate.

    If no candidate passing the invariant is found, the attempt is aborted.

    Parameters:
        mapping_lookup: Precomputed lookup mapping (context → (candidates, frequencies)).
        training_set: Set of training sequences to ensure uniqueness.
        memory_len: Number of triplets used as context.
        sequence_length: Desired length of the generated sequence.
        max_attempts: Maximum number of generation attempts.

    Returns:
        A new unique sequence (tuple of triplets) that respects the invariant, or None if not found.
    """
    mapping_keys = list(mapping_lookup.keys())
    num_keys = len(mapping_keys)

    for attempt in range(max_attempts):
        # Select a seed context randomly (from training data so that the invariant holds).
        seed = mapping_keys[np.random.randint(0, num_keys)]
        generated_sequence: List[Tuple[float, float, float]] = list(seed)
        valid_generation = True

        while len(generated_sequence) < sequence_length:
            last_triplet = generated_sequence[-1]
            current_context = tuple(generated_sequence[-memory_len:])  # context as tuple of triplets
            candidate_found = False

            if current_context in mapping_lookup:
                candidates, frequencies = mapping_lookup[current_context]
                # Filter candidates by invariant:
                # Candidate's first element must be >= last triplet's second element.
                valid_indices = [i for i, cand in enumerate(candidates) if cand[0] >= last_triplet[1]]
                if valid_indices:
                    # Filter candidates and their associated frequencies.
                    filtered_freqs = frequencies[valid_indices]
                    filtered_candidates = [candidates[i] for i in valid_indices]

                    # Convert candidates into a NumPy array for vectorized operations.
                    candidate_array = np.array(filtered_candidates, dtype=np.float64)  # shape: (n_candidates, 3)

                    # Prepare generation history as array.
                    generated_array = np.array(generated_sequence, dtype=np.float64)  # shape: (T, 3)
                    current_length = generated_array.shape[0]

                    # Running average and standard deviation for dynamic spike adjustment.
                    running_avg = np.mean(generated_array, axis=0)  # shape: (3,)
                    running_std = np.std(generated_array, axis=0)   # shape: (3,)
                    # Dynamic sigma: ensure a minimum sigma value.
                    dynamic_sigma = np.maximum(running_std, np.array(SPIKE_SIG))

                    # --- Compute Repetition Penalty ---
                    # For each candidate, count the number of occurrences for each element along the corresponding column.
                    rep_counts = np.array([
                        [np.sum(generated_array[:, k] == candidate_array[i, k]) for k in range(3)]
                        for i in range(candidate_array.shape[0])
                    ])  # shape: (n_candidates, 3)
                    # The repetition penalty in log-space.
                    rep_penalty_term = np.sum(rep_counts * np.log(np.array(REP_PENALTY)) *
                                              (current_length / sequence_length), axis=1)  # shape: (n_candidates,)

                    # --- Compute Spike (Variance) Penalty ---
                    # Compute the difference per candidate from the running average.
                    diff = candidate_array - running_avg  # shape: (n_candidates, 3)
                    spike_penalty_term = np.sum(np.array(SPIKE_STRENGTH) * (diff**2) / (2 * (dynamic_sigma**2)),
                                                axis=1)  # shape: (n_candidates,)

                    # --- Compute Candidate Log-Scores ---
                    # Use np.log on frequencies (they are positive by construction).
                    log_freq = np.log(filtered_freqs)
                    log_scores = log_freq - rep_penalty_term - spike_penalty_term

                    # --- Softmax in Log-space (stable computation) ---
                    max_log = np.max(log_scores)
                    exp_scores = np.exp(log_scores - max_log)
                    probabilities = exp_scores / np.sum(exp_scores)

                    # Choose the next candidate using advanced probabilities.
                    chosen_idx = np.random.choice(len(filtered_candidates), p=probabilities)
                    next_triplet = filtered_candidates[chosen_idx]
                    candidate_found = True

            if not candidate_found:
                # Abort this generation attempt if no valid candidate is available.
                valid_generation = False
                break

            generated_sequence.append(next_triplet)

        # Ensure the final sequence meets the invariant and is unique.
        if valid_generation and len(generated_sequence) == sequence_length:
            new_sequence = tuple(generated_sequence)
            invariant_ok = all(a[1] <= b[0] for a, b in zip(new_sequence, new_sequence[1:]))
            if invariant_ok and new_sequence not in training_set:
                return new_sequence

    return None

###################################################################################

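# Illustrative generation sketch (continues the hypothetical setup above;
# kept commented out so nothing runs at import time):
#
#   seq = generate_sequence_optimized(mapping_lookup, training_set,
#                                     memory_len=MEMORY_LEN,
#                                     sequence_length=SEQUENCE_LENGTH)
#   if seq is not None:
#       stats, report = analyze_generated_sequence(seq, mapping_lookup, MEMORY_LEN)
#       print(report)
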
def analyze_generated_sequence(sequence: tuple, mapping_lookup: dict, memory_len: int) -> tuple:
    """
    Analyze the generated sequence and return several useful statistics
    as both a dictionary and as a nicely formatted string report.

    Statistics Computed:
      - unigram_diversity: Ratio of unique triplets to total triplets.
      - repetition_rate: Fraction of repeated triplets.
      - bigram_diversity: Ratio of unique consecutive pairs to total pairs.
      - max_consecutive_repetitions: Maximum number of identical consecutive triplets.
      - avg_candidate_probability (overfit rate): For the transitions (using a sliding window of size
        MEMORY_LEN as context followed by candidate), the average probability of the chosen candidate
        as per the training mapping.

    Additional Analytics:
      - element_stats: For each element (index 0, 1, 2) in a triplet, includes:
          * mean, standard deviation, minimum, maximum, and average consecutive absolute difference.
      - avg_transition_entropy: The average entropy of the candidate distributions (from mapping_lookup)
        for each transition context.
      - context_coverage: The fraction of transitions (based on context of length MEMORY_LEN) that are found
        in the mapping_lookup.

    Parameters:
        sequence: Generated sequence (tuple of triplets).
        mapping_lookup: Precomputed mapping lookup.
        memory_len: The context length used.

    Returns:
        A tuple containing:
            (stats_dict, stats_report_string)
    """
    stats = {}
    seq_len = len(sequence)

    # --- Basic Statistics ---

    # Unigram.
    unique_triplets = len(set(sequence))
    stats["unigram_diversity"] = unique_triplets / seq_len
    stats["repetition_rate"] = 1 - (unique_triplets / seq_len)

    # Bigram.
    bigrams = [(sequence[i], sequence[i+1]) for i in range(seq_len - 1)]
    unique_bigrams = len(set(bigrams))
    stats["bigram_diversity"] = unique_bigrams / (seq_len - 1)

    # Maximum consecutive repetitions.
    max_consecutive = 1
    current_consecutive = 1
    for i in range(1, seq_len):
        if sequence[i] == sequence[i-1]:
            current_consecutive += 1
            if current_consecutive > max_consecutive:
                max_consecutive = current_consecutive
        else:
            current_consecutive = 1
    stats["max_consecutive_repetitions"] = max_consecutive

    # Avg Candidate Probability (Overfit Rate)
    overfit_probs = []
    for i in range(memory_len, seq_len):
        context = tuple(sequence[i - memory_len: i])
        candidate = sequence[i]
        if context in mapping_lookup:
            candidates, frequencies = mapping_lookup[context]
            total_freq = np.sum(frequencies)
            try:
                idx = candidates.index(candidate)
                cand_prob = frequencies[idx] / total_freq
                overfit_probs.append(cand_prob)
            except ValueError:
                pass
    stats["avg_candidate_probability"] = np.mean(overfit_probs) if overfit_probs else None

    # --- Additional Analytics ---

    # 1. Element-Level Statistics.
    seq_arr = np.array(sequence)  # shape: (seq_len, 3)
    element_stats = {}
    for dim in range(seq_arr.shape[1]):
        values = seq_arr[:, dim]
        mean_val = np.mean(values)
        std_val = np.std(values)
        min_val = np.min(values)
        max_val = np.max(values)
        # Calculate average absolute difference between consecutive values:
        diffs = np.abs(np.diff(values))
        avg_diff = np.mean(diffs) if diffs.size > 0 else 0
        element_stats[f"element_{dim}"] = {
            "mean": mean_val,
            "std": std_val,
            "min": min_val,
            "max": max_val,
            "avg_consecutive_diff": avg_diff,
        }
    stats["element_stats"] = element_stats

    # 2. Transition Entropy:
    entropies = []
    valid_transitions = 0
    for i in range(memory_len, seq_len):
        context = tuple(sequence[i - memory_len: i])
        if context in mapping_lookup:
            candidates, freqs = mapping_lookup[context]
            total_freq = np.sum(freqs)
            if total_freq > 0:
                probs = freqs / total_freq
                # Add a very small constant to avoid log(0)
                epsilon = 1e-10
                entropy = -np.sum(probs * np.log(probs + epsilon))
                entropies.append(entropy)
            valid_transitions += 1
    stats["avg_transition_entropy"] = np.mean(entropies) if entropies else None

    # 3. Context Coverage:
    total_transitions = seq_len - memory_len
    stats["context_coverage"] = (valid_transitions / total_transitions) if total_transitions > 0 else None

    # --- Build a Pretty Report String ---
    sep_line = "-" * 60
    lines = []
    lines.append(sep_line)
    lines.append("Sequence Analytics Report:")
    lines.append(sep_line)
    lines.append("Overall Statistics:")
    lines.append(f"  Unigram Diversity          : {stats['unigram_diversity']:.3f}")
    lines.append(f"  Repetition Rate            : {stats['repetition_rate']:.3f}")
    lines.append(f"  Bigram Diversity           : {stats['bigram_diversity']:.3f}")
    lines.append(f"  Max Consecutive Repetitions: {stats['max_consecutive_repetitions']}")
    cand_prob = stats["avg_candidate_probability"]
    cand_prob_str = f"{cand_prob:.3f}" if cand_prob is not None else "N/A"
    lines.append(f"  Avg Candidate Probability  : {cand_prob_str}")
    lines.append("")

    lines.append("Element-Level Statistics:")
    for dim in sorted(element_stats.keys()):
        ed = element_stats[dim]
        lines.append(f"  {dim.capitalize()}:")
        lines.append(f"    Mean                 : {ed['mean']:.3f}")
        lines.append(f"    Std Dev              : {ed['std']:.3f}")
        lines.append(f"    Min                  : {ed['min']:.3f}")
        lines.append(f"    Max                  : {ed['max']:.3f}")
        lines.append(f"    Avg Consecutive Diff : {ed['avg_consecutive_diff']:.3f}")
    lines.append("")

    lines.append("Transition Statistics:")
    avg_entropy = stats["avg_transition_entropy"]
    entropy_str = f"{avg_entropy:.3f}" if avg_entropy is not None else "N/A"
    lines.append(f"  Average Transition Entropy: {entropy_str}")
    cc = stats["context_coverage"]
    cc_str = f"{cc:.3f}" if cc is not None else "N/A"
    lines.append(f"  Context Coverage          : {cc_str}")
    lines.append(sep_line)

    stats_report = "\n".join(lines)

    # Return both the dictionary and the formatted report string.
    return stats, stats_report

###################################################################################

def autoregressive_generate(start_seq, mel_tones, trg_array, trg_matches_array, num_new_tokens, chunk_len=5):

    # Convert sequences to NumPy arrays.
    current_seq = np.array(start_seq, dtype=int)  # Shape: (num_tokens, token_dim)
    trg_array = np.array(trg_array, dtype=int)    # Shape: (num_candidates, 2, token_dim)
    start_len = len(start_seq)

    midx = start_len-1

    # Deque for sliding memory of candidate pairs (immutable tuples).
    recent_candidates = deque(maxlen=5)

    while (len(current_seq) - start_len) < num_new_tokens:

        midx += 1

        # Get the last (chunk_len - 1) tokens as context.
        context = current_seq[-(chunk_len-1):]  # Shape: (chunk_len-1, token_dim)

        sli = 0
        msize = 0

        ctx = context[:, :-1].reshape(1, -1)
        trg_mat_arr = trg_matches_array

        while msize < 8:

            print('=== Slice', sli)

            # Compare context with candidates in trg_array.
            match_mask = np.all(ctx == trg_mat_arr, axis=1)
            match_indices = np.where(match_mask)[0]

            msize = match_indices.size

            if msize < 8:
                sli += 1
                ctx = context[:, :-1].reshape(1, -1)[:, sli:]
                trg_mat_arr = trg_matches_array[:, :-sli]

        if match_indices.size == 0:
            if len(current_seq) > start_len:

                #tones_chord = sorted([mel_tones[midx], (mel_tones[midx]+7) % 12])
                tones_chord = sorted([mel_tones[midx]])
                new_tuple = [[mel_tones[midx], TMIDIX.ALL_CHORDS_SORTED.index(tones_chord)]]
                current_seq = np.concatenate((current_seq, new_tuple), axis=0)
                print('Subbed', midx)
                continue

        # From the matching candidates, filter out those whose candidate pair is in recent memory.
        available_candidates = []
        cseen = []
        for idx in match_indices:

            if idx not in recent_candidates:
                # Convert candidate pair to an immutable tuple
                candidate_pair = tuple(trg_array[idx].tolist())
                if candidate_pair[-1][0] == mel_tones[midx] and candidate_pair[-1][1] not in cseen:
                    available_candidates.append((idx, candidate_pair))
                    cseen.append(candidate_pair[-1][1])

        # If all candidates have recently been used, backtrack.
        if len(available_candidates) < 3:
            if len(current_seq) >= start_len:
                #tones_chord = sorted([mel_tones[midx], (mel_tones[midx]+7) % 12])
                tones_chord = sorted([mel_tones[midx]])
                new_tuple = [[mel_tones[midx], TMIDIX.ALL_CHORDS_SORTED.index(tones_chord)]]
                current_seq = np.concatenate((current_seq, new_tuple), axis=0)
                #rev_val = random.choice([-1, -2])
                #current_seq = current_seq[:rev_val]
                #print(midx)
                #midx = len(current_seq)
                #print('Reverted', midx, len(current_seq))
                continue

        else:
            print(len(available_candidates))
            # Choose one available candidate at random.
            chosen_idx, chosen_pair = available_candidates[np.random.choice(len(available_candidates))]
            new_token = trg_array[chosen_idx][-1]  # The second token of the candidate pair.

            # Append the new token to the sequence.
            current_seq = np.concatenate((current_seq, new_token[None, :]), axis=0)

            recent_candidates.append(chosen_idx)

        print('Gen seq len', len(current_seq))

    return current_seq

###################################################################################

def minkowski_distance_vector_to_matrix(x: cp.ndarray, X: cp.ndarray, p: float = 3) -> cp.ndarray:

    """
    Computes the Minkowski distance between a 1D CuPy array 'x' and each row of a 2D CuPy array 'X'.

    Parameters:
        x (cp.ndarray): A 1D array with shape (n_features,) representing a single vector.
        X (cp.ndarray): A 2D array with shape (n_samples, n_features) where each row is a vector.
        p (float): The order of the Minkowski distance.
                   For instance:
                     - p=1 yields the Manhattan distance,
                     - p=2 yields the Euclidean distance,
                     - p=3 yields the Minkowski distance and will use the cube-root implementation,
                     - p=∞ (or cp.inf) gives the Chebyshev distance.

    Returns:
        cp.ndarray: A 1D array of length n_samples containing the Minkowski distance between 'x'
                    and the corresponding row in 'X'.
    """

    # Compute the element-wise absolute differences between x and every row in X.
    # Broadcasting x over the rows of X results in an array of shape (n_samples, n_features).
    diff = cp.abs(X - x)

    if p == float('inf') or p == cp.inf:
        # For the Chebyshev distance, use the maximum absolute difference along the feature axis.
        distances = cp.max(diff, axis=1)
    elif p == 3:
        # Instead of using the generic power operation (sum(diff**3) ** (1/3)),
        # we use cp.cbrt for cube-root calculation when p is exactly 3.
        distances = cp.cbrt(cp.sum(diff ** 3, axis=1))
    else:
        # For general Minkowski distance with finite p,
        # compute the p-th power of differences, sum them, then take the p-th root.
        distances = cp.sum(diff ** p, axis=1) ** (1.0 / p)

    return distances

###################################################################################

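# Illustrative sketch for minkowski_distance_vector_to_matrix
# (kept commented out so nothing runs at import time):
#
#   x = cp.array([0.0, 0.0])
#   X = cp.array([[3.0, 4.0], [1.0, 1.0]])
#   minkowski_distance_vector_to_matrix(x, X, p=2)   # -> array([5.0, 1.41421356])
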
| 1023 |
+
def pairwise_minkowski_distance(X: cp.ndarray, p: float = 2) -> cp.ndarray:
|
| 1024 |
+
|
| 1025 |
+
"""
|
| 1026 |
+
Computes pairwise Minkowski distances for a 2D CuPy array.
|
| 1027 |
+
|
| 1028 |
+
Parameters:
|
| 1029 |
+
X (cp.ndarray): A 2D array of shape (n_samples, n_features), where each row represents a vector.
|
| 1030 |
+
p (float): The order of the Minkowski distance.
|
| 1031 |
+
For example:
|
| 1032 |
+
- p=1 is the Manhattan distance,
|
| 1033 |
+
- p=2 is the Euclidean distance,
|
| 1034 |
+
- p=∞ (e.g., float('inf') or cp.inf) is the Chebyshev distance.
|
| 1035 |
+
|
| 1036 |
+
Returns:
|
| 1037 |
+
cp.ndarray: A 2D array of shape (n_samples, n_samples) containing the pairwise Minkowski distances.
|
| 1038 |
+
"""
|
| 1039 |
+
|
| 1040 |
+
# Use broadcasting to compute the absolute difference between every pair of vectors.
|
| 1041 |
+
# The result of X[:, None, :] - X[None, :, :] will have shape (n_samples, n_samples, n_features).
|
| 1042 |
+
if p == float('inf') or p == cp.inf:
|
| 1043 |
+
# For the Chebyshev distance, take the maximum absolute difference along the feature axis.
|
| 1044 |
+
return cp.max(cp.abs(X[:, None, :] - X[None, :, :]), axis=-1)
|
| 1045 |
+
else:
|
| 1046 |
+
# Raise the absolute differences to the power p.
|
| 1047 |
+
diff_powered = cp.abs(X[:, None, :] - X[None, :, :]) ** p
|
| 1048 |
+
# Sum over the features for each pair (i, j) and then take the p-th root.
|
| 1049 |
+
distances = cp.sum(diff_powered, axis=-1) ** (1.0 / p)
|
| 1050 |
+
|
| 1051 |
+
return distances
|
| 1052 |
+
|
| 1053 |
+
###################################################################################
+
+def pairwise_cosine_similarity(X: cp.ndarray, eps: float = 1e-10) -> cp.ndarray:
+
+    """
+    Computes the pairwise cosine similarity for a 2D CuPy array.
+
+    Parameters:
+        X (cp.ndarray): A 2D array of shape (n_samples, n_features) where each row represents a vector.
+        eps (float): A small constant added to the denominator to prevent division by zero.
+
+    Returns:
+        cp.ndarray: A 2D array of shape (n_samples, n_samples) containing the pairwise cosine similarities.
+    """
+
+    # Compute the dot product between every pair of rows.
+    # This results in a matrix where element (i, j) is the dot product of X[i] and X[j].
+    dot_product = cp.dot(X, X.T)
+
+    # Compute the L2 norm (Euclidean norm) of each row vector.
+    norms = cp.linalg.norm(X, axis=1)
+
+    # Compute the outer product of the norms to form the denominator.
+    # Element (i, j) of this matrix is norms[i] * norms[j].
+    norm_matrix = cp.outer(norms, norms)
+
+    # Compute the cosine similarity matrix.
+    # Adding a small epsilon (eps) to the denominator prevents division by zero.
+    cosine_similarity = dot_product / (norm_matrix + eps)
+
+    return cosine_similarity
+
+###################################################################################
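Similarly, a small check of pairwise_cosine_similarity (again a sketch; the import name and sample data are assumptions):

import cupy as cp
import TCUPY

X = cp.asarray([[1.0, 0.0],
                [0.0, 1.0],
                [1.0, 1.0]])

S = TCUPY.pairwise_cosine_similarity(X)

print(S[0, 1])  # ~0.0    (orthogonal vectors)
print(S[0, 2])  # ~0.7071 (45 degrees apart)
print(S[2, 2])  # ~1.0    (self-similarity, up to the eps term)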
+
+print('Module is loaded!')
+print('Enjoy! :)')
+print('=' * 70)
+
+###################################################################################
+# This is the end of the TCUPY Python module
+###################################################################################
TMIDIX.py CHANGED

@@ -7,7 +7,7 @@ r'''############################################################################
 # Tegridy MIDI X Module (TMIDI X / tee-midi eks)
 # Version 1.0
 #
-# NOTE: TMIDI X Module starts after the partial MIDI.py module @ line
+# NOTE: TMIDI X Module starts after the partial MIDI.py module @ line 1438
 #
 # Based upon MIDI.py module v.6.7. by Peter Billam / pjb.com.au
 #

@@ -50,6 +50,7 @@ r'''############################################################################
 ###################################################################################'''

 import sys, struct, copy
+
 Version = '6.7'
 VersionDate = '20201120'

@@ -1492,6 +1493,10 @@ import psutil

 import json

+from pathlib import Path
+
+import shutil
+
 ###################################################################################
 #
 # Original TMIDI Tegridy helper functions

@@ -4144,15 +4149,16 @@ def tones_chord_to_pitches(tones_chord, base_pitch=60):
 ###################################################################################

 def advanced_score_processor(raw_score,
-
-
-
-
-
-
-
-
-
+                             patches_to_analyze=list(range(129)),
+                             return_score_analysis=False,
+                             return_enhanced_score=False,
+                             return_enhanced_score_notes=False,
+                             return_enhanced_monophonic_melody=False,
+                             return_chordified_enhanced_score=False,
+                             return_chordified_enhanced_score_with_lyrics=False,
+                             return_score_tones_chords=False,
+                             return_text_and_lyric_events=False,
+                             apply_sustain=False
                              ):

     '''TMIDIX Advanced Score Processor'''

@@ -4192,6 +4198,9 @@ def advanced_score_processor(raw_score,
         e[2] = e[2] % 16
         e[3] = e[3] % 128

+    if apply_sustain:
+        apply_sustain_to_ms_score([1000, basic_single_track_score])
+
     basic_single_track_score.sort(key=lambda x: x[4] if x[0] == 'note' else 128, reverse=True)
     basic_single_track_score.sort(key=lambda x: x[1])

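For context, a hypothetical call that exercises the new apply_sustain option (a sketch only: the file path is made up, midi2ms_score is assumed to be the MIDI.py-derived reader bundled at the top of TMIDIX.py, and only one of the return_* flags is shown):

import TMIDIX

midi_data = open('some_midi_file.mid', 'rb').read()

raw_score = TMIDIX.midi2ms_score(midi_data)

# With apply_sustain=True, sustain-pedal (CC 64) events are folded into the note
# durations via apply_sustain_to_ms_score() before the score is sorted and analyzed.
output = TMIDIX.advanced_score_processor(raw_score,
                                         return_enhanced_score_notes=True,
                                         apply_sustain=True)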
@@ -12226,6 +12235,130 @@ def escore_notes_pitches_chords_signature(escore_notes,
    else:
        return []

+###################################################################################
+
+def compute_sustain_intervals(events):
+
+    # Scan (time, value) CC 64 events and build raw (pedal_down, pedal_up) intervals.
+    intervals = []
+    pedal_on = False
+    current_start = None
+
+    for t, cc in events:
+        if not pedal_on and cc >= 64:
+            pedal_on = True
+            current_start = t
+        elif pedal_on and cc < 64:
+            pedal_on = False
+            intervals.append((current_start, t))
+            current_start = None
+
+    # A pedal that is never released holds until the end of the score.
+    if pedal_on:
+        intervals.append((current_start, float('inf')))
+
+    # Merge overlapping or touching intervals.
+    merged = []
+
+    for interval in intervals:
+        if merged and interval[0] <= merged[-1][1]:
+            merged[-1] = (merged[-1][0], max(merged[-1][1], interval[1]))
+        else:
+            merged.append(interval)
+
+    return merged
+
+###################################################################################
+
+def apply_sustain_to_ms_score(score):
+
+    # Collect sustain-pedal (CC 64) events per channel as (time, value) pairs.
+    sustain_by_channel = {}
+
+    for track in score[1:]:
+        for event in track:
+            if event[0] == 'control_change' and event[3] == 64:
+                channel = event[2]
+                sustain_by_channel.setdefault(channel, []).append((event[1], event[4]))
+
+    # Convert the raw pedal events into merged (start, end) intervals per channel.
+    sustain_intervals_by_channel = {}
+
+    for channel, events in sustain_by_channel.items():
+        events.sort(key=lambda x: x[0])
+        sustain_intervals_by_channel[channel] = compute_sustain_intervals(events)
+
+    # Find the latest note-off time so open-ended pedal intervals can be closed.
+    global_max_off = 0
+
+    for track in score[1:]:
+        for event in track:
+            if event[0] == 'note':
+                global_max_off = max(global_max_off, event[1] + event[2])
+
+    for channel, intervals in sustain_intervals_by_channel.items():
+        updated_intervals = []
+        for start, end in intervals:
+            if end == float('inf'):
+                end = global_max_off
+            updated_intervals.append((start, end))
+        sustain_intervals_by_channel[channel] = updated_intervals
+
+    if sustain_intervals_by_channel:
+
+        # Extend every note that would end while the pedal is held
+        # so that it ends when the pedal is released.
+        for track in score[1:]:
+            for event in track:
+                if event[0] == 'note':
+                    start = event[1]
+                    nominal_dur = event[2]
+                    nominal_off = start + nominal_dur
+                    channel = event[3]
+
+                    intervals = sustain_intervals_by_channel.get(channel, [])
+                    effective_off = nominal_off
+
+                    for intv_start, intv_end in intervals:
+                        if intv_start < nominal_off < intv_end:
+                            effective_off = intv_end
+                            break
+
+                    effective_dur = effective_off - start
+
+                    event[2] = effective_dur
+
+    return score
+
+###################################################################################
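And a minimal end-to-end check of the sustain processing (a sketch; the tiny hand-made ms-score below follows the event layout the function itself indexes: ['note', start, duration, channel, pitch, velocity] and ['control_change', time, channel, controller, value]):

import TMIDIX

# One note that nominally lasts 200 time units while the sustain pedal (CC 64)
# is held from t=0 until t=1000.
score = [1000,
         [['control_change', 0, 0, 64, 127],
          ['note', 0, 200, 0, 60, 100],
          ['control_change', 1000, 0, 64, 0]]]

TMIDIX.apply_sustain_to_ms_score(score)

print(score[1][1][2])  # 1000, the note now rings until the pedal release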
+
+def copy_file(src_file: str, trg_dir: str, add_subdir: bool = False, verbose: bool = False):
+
+    src_path = Path(src_file)
+    target_directory = Path(trg_dir)
+
+    if not src_path.is_file():
+        if verbose:
+            print("Source file does not exist or is not a file.")
+
+        return None
+
+    target_directory.mkdir(parents=True, exist_ok=True)
+
+    # Optionally file the copy under a sub-directory named after the first letter of the file name.
+    if add_subdir:
+        first_letter = src_path.name[0]
+        target_directory = target_directory / first_letter
+        target_directory.mkdir(parents=True, exist_ok=True)
+
+    destination = target_directory / src_path.name
+
+    try:
+        shutil.copy2(src_path, destination)
+
+    except:
+        if verbose:
+            print('File could not be copied!')
+
+        return None
+
+    if verbose:
+        print('File copied!')
+
+    return None
+
###################################################################################
# This is the end of the TMIDI X Python module
###################################################################################
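Finally, a usage sketch for the new copy_file helper (the paths are made up; add_subdir files the copy under a sub-directory named after the first letter of the file name):

import TMIDIX

TMIDIX.copy_file('song.mid', './sorted_midis', add_subdir=True, verbose=True)
# With add_subdir=True the file ends up at ./sorted_midis/s/song.mid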