# Authors: The scikit-learn developers | |
# SPDX-License-Identifier: BSD-3-Clause | |
from cython cimport floating | |
from libc.math cimport fabs | |
def _update_cdnmf_fast(floating[:, ::1] W, floating[:, :] HHt, | |
floating[:, :] XHt, Py_ssize_t[::1] permutation): | |
cdef: | |
floating violation = 0 | |
Py_ssize_t n_components = W.shape[1] | |
Py_ssize_t n_samples = W.shape[0] # n_features for H update | |
floating grad, pg, hess | |
Py_ssize_t i, r, s, t | |
with nogil: | |
for s in range(n_components): | |
t = permutation[s] | |
for i in range(n_samples): | |
# gradient = GW[t, i] where GW = np.dot(W, HHt) - XHt | |
grad = -XHt[i, t] | |
for r in range(n_components): | |
grad += HHt[t, r] * W[i, r] | |
# projected gradient | |
pg = min(0., grad) if W[i, t] == 0 else grad | |
violation += fabs(pg) | |
# Hessian | |
hess = HHt[t, t] | |
if hess != 0: | |
W[i, t] = max(W[i, t] - grad / hess, 0.) | |
return violation | |