Spaces:
Running
on
Zero
Running
on
Zero
File size: 3,069 Bytes
d1ed09d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 |
import numpy as np
import numba
@numba.njit(
    [
        "f4(f4[::1],f4[::1])",
        numba.types.float32(
            numba.types.Array(numba.types.float32, 1, "C", readonly=True),
            numba.types.Array(numba.types.float32, 1, "C", readonly=True),
        ),
    ],
    fastmath=True,
    locals={
        "result": numba.types.float32,
        "diff": numba.types.float32,
        "dim": numba.types.intp,
        # intp rather than uint16: a uint16 index silently wraps past 65535,
        # which would make high-dimensional vectors read the wrong elements.
        "i": numba.types.intp,
    },
)
def euclidean(x, y):
    r"""Euclidean (L2) distance between two float32 vectors.

    .. math::
        D(x, y) = \sqrt{\sum_i (x_i - y_i)^2}

    The eager signatures above compile this only for 1-D, C-contiguous
    float32 arrays (writable or read-only). The loop length is taken from
    ``x``; ``y`` is expected to have the same length (not checked here).

    Returns the distance as float32; the sum is accumulated in float32.
    """
    result = 0.0
    dim = x.shape[0]
    for i in range(dim):
        diff = x[i] - y[i]
        result += diff * diff
    return np.sqrt(result)
@numba.njit(parallel=True, nogil=True)
def chunked_parallel_pairwise_distances(X, Y=None, metric=euclidean, chunk_size=16):
    """Dense pairwise-distance matrix, computed in square tiles in parallel.

    Rows of the output are split into chunks of ``chunk_size``. Each chunk is
    handled by one ``numba.prange`` iteration, which walks the columns in
    tiles of the same size.

    Parameters
    ----------
    X : 2-D array
        Query points, one per row.
    Y : 2-D array or None
        Reference points, one per row. When ``None``, distances are computed
        between the rows of ``X`` themselves. ``metric`` is then evaluated
        only for ``j >= i`` and each value is mirrored to ``[j, i]``. This
        assumes ``metric`` is symmetric.
    metric : jitted callable
        Distance between two rows; defaults to ``euclidean``.
    chunk_size : int
        Side length of the tiles; must be positive.

    Returns
    -------
    float32 array of shape ``(X.shape[0], Y.shape[0])``, or
    ``(X.shape[0], X.shape[0])`` when ``Y`` is None.
    """
    if Y is None:
        XX, symmetrical = X, True
        row_size = col_size = X.shape[0]
    else:
        XX, symmetrical = Y, False
        row_size, col_size = X.shape[0], Y.shape[0]
    result = np.zeros((row_size, col_size), dtype=np.float32)
    # Ceiling division: one task per (possibly partial) chunk of rows, without
    # scheduling an empty trailing chunk when row_size % chunk_size == 0.
    n_row_chunks = (row_size + chunk_size - 1) // chunk_size
    for chunk_idx in numba.prange(n_row_chunks):
        n = chunk_idx * chunk_size
        chunk_end_n = min(n + chunk_size, row_size)
        # Symmetric case: tiles left of the diagonal are filled by mirroring,
        # so only tiles starting at column n are visited.
        m_start = n if symmetrical else 0
        for m in range(m_start, col_size, chunk_size):
            chunk_end_m = min(m + chunk_size, col_size)
            for i in range(n, chunk_end_n):
                # On the diagonal tile skip j < i (the mirror fills it);
                # off-diagonal tiles have m > i, so this is just m there.
                j_start = max(m, i) if symmetrical else m
                for j in range(j_start, chunk_end_m):
                    d = metric(X[i], XX[j])
                    result[i, j] = d
                    if symmetrical:
                        # Safe under prange: column i lies in this chunk's
                        # row range, which no other chunk writes to.
                        result[j, i] = d
    return result
@numba.njit()
def pull_arms(data, arms, num_pulls_per_arm, estimates, pull_counts):
    """Pull every given arm ``num_pulls_per_arm`` times, updating means in place.

    One random sample of ``num_pulls_per_arm`` distinct row indices of
    ``data`` is drawn without replacement and shared by all arms. For each
    arm, the distances from ``data[arm]`` to the sampled rows are summed and
    folded into that arm's running mean distance.

    Parameters
    ----------
    data : 2-D array
        All points, one per row.
    arms : 1-D int array
        Row indices of the candidate points (arms) to pull.
    num_pulls_per_arm : int
        Sample size; must not exceed ``data.shape[0]`` since sampling is
        without replacement.
    estimates : 1-D array
        Running mean distance per arm; expected to be aligned with ``arms``.
        Updated in place.
    pull_counts : 1-D int array
        Pulls accumulated so far per arm, aligned with ``arms``. Updated in
        place.

    Returns
    -------
    None; results are written into ``estimates`` and ``pull_counts``.
    """
    drawn = np.random.choice(data.shape[0], size=num_pulls_per_arm, replace=False)
    sample_idx = drawn.astype(np.int32)

    # Shape (len(arms), num_pulls_per_arm): distance of each arm to each sample.
    arm_to_sample = chunked_parallel_pairwise_distances(data[arms], data[sample_idx])
    batch_totals = np.sum(arm_to_sample, axis=1)

    # Running-mean update: mean -> total, add the new batch, divide by the
    # new count. Kept as separate in-place steps on purpose.
    estimates *= pull_counts
    estimates += batch_totals
    pull_counts += num_pulls_per_arm
    estimates /= pull_counts
@numba.njit()
def medoid(data, arm_budget=20):
    """Approximate the medoid of ``data`` by successive halving over random samples.

    Every row starts as a candidate ("arm"). Each round does two things:
    - ``pull_arms`` updates each surviving arm's running mean distance
      against a shared random sample of rows;
    - the half of the arms with the lowest estimates is kept.
    The search ends when a single arm remains.

    Parameters
    ----------
    data : 2-D array
        One point per row. Rows are passed to the distance metric used by
        ``pull_arms`` (``euclidean`` by default, which needs float32
        C-contiguous rows — TODO confirm callers pass that).
    arm_budget : int
        Scales the overall sampling budget, ``arm_budget * data.shape[0]``.
        The budget is spread across an expected ``ceil(log2(n))`` rounds.

    Returns
    -------
    The row of ``data`` selected as the approximate medoid.

    Raises
    ------
    ValueError
        If ``data`` has no rows.
    """
    n_points = data.shape[0]
    if n_points == 0:
        # Without this guard, log2(0) below is -inf and data[...] then fails.
        raise ValueError("medoid() requires at least one data point")
    pull_counts = np.zeros(n_points, dtype=np.int32)
    pull_budget = arm_budget * n_points
    estimates = np.zeros(n_points, dtype=np.float32)
    current_active_arms = np.arange(n_points)
    n_rounds = int(np.ceil(np.log2(n_points)))
    while current_active_arms.shape[0] > 1:
        n_active = current_active_arms.shape[0]
        num_pulls_per_arm = max(
            1,
            int(min(n_points, np.floor(pull_budget / (n_active * n_rounds)))),
        )
        pull_arms(data, current_active_arms, num_pulls_per_arm, estimates, pull_counts)
        # Keep exactly ceil(n_active / 2) arms with the lowest estimates.
        # For distinct estimates this is the same set as ``estimates <= median``.
        # Unlike that mask, it always shrinks the set, so tied estimates
        # (e.g. duplicate points) cannot stall the loop.
        # A NaN estimate cannot empty the set either: argsort places NaN last.
        # The stable sort breaks ties by position; np.sort restores the
        # original relative order of the survivors.
        n_keep = (n_active + 1) // 2
        order = np.argsort(estimates, kind="mergesort")
        keep = np.sort(order[:n_keep])
        current_active_arms = current_active_arms[keep]
        estimates = estimates[keep]
        pull_counts = pull_counts[keep]
    return data[current_active_arms[0]]
|