"""Locally Optimal Block Preconditioned Conjugate Gradient methods. | |
""" | |
# Author: Pearu Peterson
# Created: February 2020

from typing import Dict, Optional, Tuple

import torch
from torch import Tensor
from . import _linalg_utils as _utils
from .overrides import handle_torch_function, has_torch_function


__all__ = ["lobpcg"]


def _symeig_backward_complete_eigenspace(D_grad, U_grad, A, D, U):
    # compute F, such that F_ij = (d_j - d_i)^{-1} for i != j, F_ii = 0
    F = D.unsqueeze(-2) - D.unsqueeze(-1)
    F.diagonal(dim1=-2, dim2=-1).fill_(float("inf"))
    F.pow_(-1)

    # A.grad = U (D.grad + (U^T U.grad * F)) U^T
    Ut = U.mT.contiguous()
    res = torch.matmul(
        U, torch.matmul(torch.diag_embed(D_grad) + torch.matmul(Ut, U_grad) * F, Ut)
    )

    return res


def _polynomial_coefficients_given_roots(roots):
    """
    Given the `roots` of a polynomial, find the polynomial's coefficients.

    If roots = (r_1, ..., r_n), then the method returns
    coefficients (a_0, a_1, ..., a_n (== 1)) so that
    p(x) = (x - r_1) * ... * (x - r_n)
         = x^n + a_{n-1} * x^{n-1} + ... + a_1 * x + a_0

    Note: better performance would require writing a dedicated low-level kernel.
    """
    poly_order = roots.shape[-1]
    poly_coeffs_shape = list(roots.shape)
    # we assume p(x) = x^n + a_{n-1} * x^{n-1} + ... + a_1 * x + a_0,
    # so the coefficients are (a_0, ..., a_{n-1}, a_n (== 1)), i.e. n + 1 values,
    # but we allocate one extra leading slot to enable better vectorization below
    poly_coeffs_shape[-1] += 2
    poly_coeffs = roots.new_zeros(poly_coeffs_shape)
    poly_coeffs[..., 0] = 1
    poly_coeffs[..., -1] = 1

    # perform Horner's rule
    for i in range(1, poly_order + 1):
        # note that it is computationally hard to compute backward for this method,
        # because then given the coefficients it would require finding the roots and/or
        # calculating the sensitivity based on Vieta's theorem.
        # So the code below tries to circumvent explicit root finding by a series
        # of operations on memory copies imitating Horner's method.
        # The memory copies are required to construct nodes in the computational graph
        # by exploiting the explicit (not in-place, separate node for each step)
        # recursion of Horner's method.
        # This needs more memory, O(... * k^2), while keeping the overall
        # complexity at O(... * k^2).
        poly_coeffs_new = poly_coeffs.clone() if roots.requires_grad else poly_coeffs
        out = poly_coeffs_new.narrow(-1, poly_order - i, i + 1)
        out -= roots.narrow(-1, i - 1, 1) * poly_coeffs.narrow(
            -1, poly_order - i + 1, i + 1
        )
        poly_coeffs = poly_coeffs_new

    return poly_coeffs.narrow(-1, 1, poly_order + 1)
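
# Illustration (added note, not part of the implementation): for roots (1, 2)
# the monic polynomial is (x - 1) * (x - 2) = x^2 - 3x + 2, so the helper above
# should return the coefficients (a_0, a_1, a_2) = (2, -3, 1), roughly:
#
#   roots = torch.tensor([1.0, 2.0])
#   _polynomial_coefficients_given_roots(roots)  # tensor([ 2., -3.,  1.])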


def _polynomial_value(poly, x, zero_power, transition):
    """
    A generic method for computing poly(x) using Horner's rule.

    Args:
      poly (Tensor): the (possibly batched) 1D Tensor representing
                     polynomial coefficients such that
                     poly[..., :] = (a_0, ..., a_n (== 1)), and
                     poly(x) = poly[..., 0] * zero_power + ... + poly[..., n] * x^n

      x (Tensor): the (possibly batched) value to evaluate the polynomial `poly` at.

      zero_power (Tensor): the representation of `x^0`. It is application-specific.

      transition (Callable): the function that accepts some intermediate result `int_val`,
                             the `x` and a specific polynomial coefficient
                             `poly[..., k]` for some iteration `k`.
                             It basically performs one iteration of Horner's rule
                             defined as `x * int_val + poly[..., k] * zero_power`.
                             Note that `zero_power` is not passed to `transition`,
                             because the step `+ poly[..., k] * zero_power` depends on `x`,
                             whether it is a vector, a matrix, or something else, so this
                             functionality is delegated to the user.
    """
    res = zero_power.clone()
    for k in range(poly.size(-1) - 2, -1, -1):
        res = transition(res, x, poly[..., k])
    return res


def _matrix_polynomial_value(poly, x, zero_power=None):
    """
    Evaluates `poly(x)` for the (batched) matrix input `x`.
    Check out `_polynomial_value` function for more details.
    """

    # matrix-aware Horner's rule iteration
    def transition(curr_poly_val, x, poly_coeff):
        res = x.matmul(curr_poly_val)
        res.diagonal(dim1=-2, dim2=-1).add_(poly_coeff.unsqueeze(-1))
        return res

    if zero_power is None:
        zero_power = torch.eye(
            x.size(-1), x.size(-1), dtype=x.dtype, device=x.device
        ).view(*([1] * len(list(x.shape[:-2]))), x.size(-1), x.size(-1))

    return _polynomial_value(poly, x, zero_power, transition)


def _vector_polynomial_value(poly, x, zero_power=None):
    """
    Evaluates `poly(x)` for the (batched) vector input `x`.
    Check out `_polynomial_value` function for more details.
    """

    # vector-aware Horner's rule iteration
    def transition(curr_poly_val, x, poly_coeff):
        res = torch.addcmul(poly_coeff.unsqueeze(-1), x, curr_poly_val)
        return res

    if zero_power is None:
        zero_power = x.new_ones(1).expand(x.shape)

    return _polynomial_value(poly, x, zero_power, transition)
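
# Illustration (added note, not part of the implementation): with the
# coefficients from the example above, poly = (2, -3, 1) encodes the monic
# polynomial x^2 - 3x + 2, so the Horner evaluation should satisfy, roughly:
#
#   x = torch.tensor([0.0, 1.0, 3.0])
#   _vector_polynomial_value(torch.tensor([2.0, -3.0, 1.0]), x)  # tensor([2., 0., 2.])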


def _symeig_backward_partial_eigenspace(D_grad, U_grad, A, D, U, largest):
    # compute a projection operator onto the subspace orthogonal to the
    # columns of U, defined as (I - U U^T)
    Ut = U.mT.contiguous()
    proj_U_ortho = -U.matmul(Ut)
    proj_U_ortho.diagonal(dim1=-2, dim2=-1).add_(1)

    # compute U_ortho, a basis for the orthogonal complement to span(U),
    # by projecting a random [..., m, m - k] matrix onto the orthogonal
    # complement of the subspace spanned by the columns of U.
    #
    # fix generator for determinism
    gen = torch.Generator(A.device)

    # orthogonal complement to span(U)
    U_ortho = proj_U_ortho.matmul(
        torch.randn(
            (*A.shape[:-1], A.size(-1) - D.size(-1)),
            dtype=A.dtype,
            device=A.device,
            generator=gen,
        )
    )
    U_ortho_t = U_ortho.mT.contiguous()

    # compute the coefficients of the characteristic polynomial of the tensor D.
    # Note that D is diagonal, so the diagonal elements are exactly the roots
    # of the characteristic polynomial.
    chr_poly_D = _polynomial_coefficients_given_roots(D)

    # the code below finds the explicit solution to the Sylvester equation
    # U_ortho^T A U_ortho dX - dX D = -U_ortho^T A U
    # and incorporates it into the whole gradient stored in the `res` variable.
    #
    # Equivalent to the following naive implementation:
    # res = A.new_zeros(A.shape)
    # p_res = A.new_zeros(*A.shape[:-1], D.size(-1))
    # for k in range(1, chr_poly_D.size(-1)):
    #     p_res.zero_()
    #     for i in range(0, k):
    #         p_res += (A.matrix_power(k - 1 - i) @ U_grad) * D.pow(i).unsqueeze(-2)
    #     res -= chr_poly_D[k] * (U_ortho @ poly_D_at_A.inverse() @ U_ortho_t @ p_res @ U.t())
    #
    # Note that dX is a differential, so the gradient contribution comes from the backward sensitivity
    # Tr(f(U_grad, D_grad, A, U, D)^T dX) = Tr(g(U_grad, A, U, D)^T dA) for some functions f and g,
    # and we need to compute g(U_grad, A, U, D)
    #
    # The naive implementation is based on the paper
    # Hu, Qingxi, and Daizhan Cheng.
    # "The polynomial solution to the Sylvester matrix equation."
    # Applied Mathematics Letters 19.9 (2006): 859-864.
    #
    # We can modify the computation of `p_res` from above in a more efficient way
    # p_res = U_grad * (chr_poly_D[1] * D.pow(0) + ... + chr_poly_D[k] * D.pow(k - 1)).unsqueeze(-2)
    #       + A U_grad * (chr_poly_D[2] * D.pow(0) + ... + chr_poly_D[k] * D.pow(k - 2)).unsqueeze(-2)
    #       + ...
    #       + A.matrix_power(k - 1) U_grad * chr_poly_D[k]
    # Note that this saves us from redundant matrix products with A (elimination of matrix_power)
    U_grad_projected = U_grad
    series_acc = U_grad_projected.new_zeros(U_grad_projected.shape)
    for k in range(1, chr_poly_D.size(-1)):
        poly_D = _vector_polynomial_value(chr_poly_D[..., k:], D)
        series_acc += U_grad_projected * poly_D.unsqueeze(-2)
        U_grad_projected = A.matmul(U_grad_projected)

    # compute chr_poly_D(A) which essentially is:
    #
    # chr_poly_D_at_A = A.new_zeros(A.shape)
    # for k in range(chr_poly_D.size(-1)):
    #     chr_poly_D_at_A += chr_poly_D[k] * A.matrix_power(k)
    #
    # Note, however, for better performance we use Horner's rule
    chr_poly_D_at_A = _matrix_polynomial_value(chr_poly_D, A)

    # compute the action of `chr_poly_D_at_A` restricted to U_ortho_t
    chr_poly_D_at_A_to_U_ortho = torch.matmul(
        U_ortho_t, torch.matmul(chr_poly_D_at_A, U_ortho)
    )

    # we need to invert `chr_poly_D_at_A_to_U_ortho`, for that we compute its
    # Cholesky decomposition and then use `torch.cholesky_solve` for better stability.
    # Cholesky decomposition requires the input to be positive-definite.
    # Note that `chr_poly_D_at_A_to_U_ortho` is positive-definite if
    # 1. `largest` == False, or
    # 2. `largest` == True and `k` is even
    # under the assumption that `A` has distinct eigenvalues.
    #
    # check if `chr_poly_D_at_A_to_U_ortho` is positive-definite or negative-definite
    chr_poly_D_at_A_to_U_ortho_sign = -1 if (largest and (k % 2 == 1)) else +1
    chr_poly_D_at_A_to_U_ortho_L = torch.linalg.cholesky(
        chr_poly_D_at_A_to_U_ortho_sign * chr_poly_D_at_A_to_U_ortho
    )

    # compute the gradient part in span(U)
    res = _symeig_backward_complete_eigenspace(D_grad, U_grad, A, D, U)

    # incorporate the Sylvester equation solution into the full gradient;
    # it resides in span(U_ortho)
    res -= U_ortho.matmul(
        chr_poly_D_at_A_to_U_ortho_sign
        * torch.cholesky_solve(
            U_ortho_t.matmul(series_acc), chr_poly_D_at_A_to_U_ortho_L
        )
    ).matmul(Ut)

    return res


def _symeig_backward(D_grad, U_grad, A, D, U, largest):
    # if `U` is square, then the columns of `U` form a complete eigenspace
    if U.size(-1) == U.size(-2):
        return _symeig_backward_complete_eigenspace(D_grad, U_grad, A, D, U)
    else:
        return _symeig_backward_partial_eigenspace(D_grad, U_grad, A, D, U, largest)


class LOBPCGAutogradFunction(torch.autograd.Function):
    @staticmethod
    def forward(  # type: ignore[override]
        ctx,
        A: Tensor,
        k: Optional[int] = None,
        B: Optional[Tensor] = None,
        X: Optional[Tensor] = None,
        n: Optional[int] = None,
        iK: Optional[Tensor] = None,
        niter: Optional[int] = None,
        tol: Optional[float] = None,
        largest: Optional[bool] = None,
        method: Optional[str] = None,
        tracker: None = None,
        ortho_iparams: Optional[Dict[str, int]] = None,
        ortho_fparams: Optional[Dict[str, float]] = None,
        ortho_bparams: Optional[Dict[str, bool]] = None,
    ) -> Tuple[Tensor, Tensor]:
        # make sure that input is contiguous for efficiency.
        # Note: autograd does not support dense gradients for sparse input yet.
        A = A.contiguous() if (not A.is_sparse) else A
        if B is not None:
            B = B.contiguous() if (not B.is_sparse) else B

        D, U = _lobpcg(
            A,
            k,
            B,
            X,
            n,
            iK,
            niter,
            tol,
            largest,
            method,
            tracker,
            ortho_iparams,
            ortho_fparams,
            ortho_bparams,
        )

        ctx.save_for_backward(A, B, D, U)
        ctx.largest = largest

        return D, U

    @staticmethod
    def backward(ctx, D_grad, U_grad):
        A_grad = B_grad = None
        grads = [None] * 14

        A, B, D, U = ctx.saved_tensors
        largest = ctx.largest

        # lobpcg.backward has some limitations. Check for unsupported input.
        if A.is_sparse or (B is not None and B.is_sparse and ctx.needs_input_grad[2]):
            raise ValueError(
                "lobpcg.backward does not support sparse input yet."
                " Note that lobpcg.forward does though."
            )
        if (
            A.dtype in (torch.complex64, torch.complex128)
            or B is not None
            and B.dtype in (torch.complex64, torch.complex128)
        ):
            raise ValueError(
                "lobpcg.backward does not support complex input yet."
                " Note that lobpcg.forward does though."
            )
        if B is not None:
            raise ValueError(
                "lobpcg.backward does not support backward with B != I yet."
            )

        if largest is None:
            largest = True

        # symeig backward
        if B is None:
            A_grad = _symeig_backward(D_grad, U_grad, A, D, U, largest)

        # A has index 0
        grads[0] = A_grad
        # B has index 2
        grads[2] = B_grad
        return tuple(grads)


def lobpcg(
    A: Tensor,
    k: Optional[int] = None,
    B: Optional[Tensor] = None,
    X: Optional[Tensor] = None,
    n: Optional[int] = None,
    iK: Optional[Tensor] = None,
    niter: Optional[int] = None,
    tol: Optional[float] = None,
    largest: Optional[bool] = None,
    method: Optional[str] = None,
    tracker: None = None,
    ortho_iparams: Optional[Dict[str, int]] = None,
    ortho_fparams: Optional[Dict[str, float]] = None,
    ortho_bparams: Optional[Dict[str, bool]] = None,
) -> Tuple[Tensor, Tensor]:
    """Find the k largest (or smallest) eigenvalues and the corresponding
    eigenvectors of a symmetric positive definite generalized
    eigenvalue problem using matrix-free LOBPCG methods.

    This function is a front-end to the following LOBPCG algorithms
    selectable via the `method` argument:

      `method="basic"` - the LOBPCG method introduced by Andrew
      Knyazev, see [Knyazev2001]. A less robust method, may fail when
      Cholesky is applied to singular input.

      `method="ortho"` - the LOBPCG method with orthogonal basis
      selection [StathopoulosEtal2002]. A robust method.

    Supported inputs are dense, sparse, and batches of dense matrices.

    .. note:: In general, the basic method spends the least time per
      iteration. However, the robust method converges much faster and
      is more stable. So, the usage of the basic method is generally
      not recommended, but there exist cases where it may be
      preferred.

    .. warning:: The backward method does not support sparse and complex inputs.
      It works only when `B` is not provided (i.e. `B == None`).
      We are actively working on extensions, and the details of
      the algorithms are going to be published promptly.

    .. warning:: While it is assumed that `A` is symmetric, `A.grad` is not.
      To make sure that `A.grad` is symmetric, so that `A - t * A.grad` is symmetric
      in first-order optimization routines, prior to running `lobpcg`
      we do the following symmetrization map: `A -> (A + A.t()) / 2`.
      The map is performed only when `A` requires gradients.

    Args:

      A (Tensor): the input tensor of size :math:`(*, m, m)`

      B (Tensor, optional): the input tensor of size :math:`(*, m,
                  m)`. When not specified, `B` is interpreted as
                  the identity matrix.

      X (tensor, optional): the input tensor of size :math:`(*, m, n)`
                  where `k <= n <= m`. When specified, it is used as
                  an initial approximation of eigenvectors. X must be a
                  dense tensor.

      iK (tensor, optional): the input tensor of size :math:`(*, m,
                  m)`. When specified, it will be used as a preconditioner.

      k (integer, optional): the number of requested
                  eigenpairs. Default is the number of :math:`X`
                  columns (when specified) or `1`.

      n (integer, optional): if :math:`X` is not specified then `n`
                  specifies the size of the generated random
                  approximation of eigenvectors. Default value for `n`
                  is `k`. If :math:`X` is specified, the value of `n`
                  (when specified) must be the number of :math:`X`
                  columns.

      tol (float, optional): residual tolerance for the stopping
                 criterion. Default is `feps ** 0.5` where `feps` is
                 the machine epsilon of the data type of the
                 input tensor `A`.

      largest (bool, optional): when True, solve the eigenproblem for
                 the largest eigenvalues. Otherwise, solve the
                 eigenproblem for the smallest eigenvalues. Default is
                 `True`.

      method (str, optional): select LOBPCG method. See the
                 description of the function above. Default is
                 "ortho".

      niter (int, optional): maximum number of iterations. When
                 reached, the iteration process is hard-stopped and
                 the current approximation of eigenpairs is returned.
                 To iterate until the convergence criterion is met,
                 use `-1`.

      tracker (callable, optional) : a function for tracing the
                 iteration process. When specified, it is called at
                 each iteration step with the LOBPCG instance as an
                 argument. The LOBPCG instance holds the full state of
                 the iteration process in the following attributes:

                   `iparams`, `fparams`, `bparams` - dictionaries of
                   integer, float, and boolean valued input
                   parameters, respectively

                   `ivars`, `fvars`, `bvars`, `tvars` - dictionaries
                   of integer, float, boolean, and Tensor valued
                   iteration variables, respectively.

                   `A`, `B`, `iK` - input Tensor arguments.

                   `E`, `X`, `S`, `R` - iteration Tensor variables.

                 For instance:

                   `ivars["istep"]` - the current iteration step
                   `X` - the current approximation of eigenvectors
                   `E` - the current approximation of eigenvalues
                   `R` - the current residual
                   `ivars["converged_count"]` - the current number of converged eigenpairs
                   `tvars["rerr"]` - the current state of the convergence criterion

                 Note that when `tracker` stores Tensor objects from
                 the LOBPCG instance, it must make copies of these.

                 If `tracker` sets `bvars["force_stop"] = True`, the
                 iteration process will be hard-stopped.

      ortho_iparams, ortho_fparams, ortho_bparams (dict, optional):
                 various parameters to the LOBPCG algorithm when using
                 `method="ortho"`.

    Returns:

      E (Tensor): tensor of eigenvalues of size :math:`(*, k)`

      X (Tensor): tensor of eigenvectors of size :math:`(*, m, k)`

    References:

      [Knyazev2001] Andrew V. Knyazev. (2001) Toward the Optimal
      Preconditioned Eigensolver: Locally Optimal Block Preconditioned
      Conjugate Gradient Method. SIAM J. Sci. Comput., 23(2),
      517-541. (25 pages)
      https://epubs.siam.org/doi/abs/10.1137/S1064827500366124

      [StathopoulosEtal2002] Andreas Stathopoulos and Kesheng
      Wu. (2002) A Block Orthogonalization Procedure with Constant
      Synchronization Requirements. SIAM J. Sci. Comput., 23(6),
      2165-2182. (18 pages)
      https://epubs.siam.org/doi/10.1137/S1064827500370883

      [DuerschEtal2018] Jed A. Duersch, Meiyue Shao, Chao Yang, Ming
      Gu. (2018) A Robust and Efficient Implementation of LOBPCG.
      SIAM J. Sci. Comput., 40(5), C655-C676. (22 pages)
      https://epubs.siam.org/doi/abs/10.1137/17M1129830
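
    Example (illustrative sketch, not from the original documentation; assumes a
    small random symmetric positive definite input)::

      A = torch.randn(9, 9, dtype=torch.float64)
      A = A @ A.mT + 9 * torch.eye(9, dtype=torch.float64)  # make A s.p.d.
      E, X = torch.lobpcg(A, k=2, largest=True)
      # E holds the two largest eigenvalues, X the corresponding eigenvectors

      # a tracker callable can inspect the solver state at every iteration:
      def my_tracker(worker):
          print(worker.ivars["istep"], worker.tvars["rerr"].max().item())

      E, X = torch.lobpcg(A, k=2, tracker=my_tracker)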
""" | |
    if not torch.jit.is_scripting():
        tensor_ops = (A, B, X, iK)
        if not set(map(type, tensor_ops)).issubset(
            (torch.Tensor, type(None))
        ) and has_torch_function(tensor_ops):
            return handle_torch_function(
                lobpcg,
                tensor_ops,
                A,
                k=k,
                B=B,
                X=X,
                n=n,
                iK=iK,
                niter=niter,
                tol=tol,
                largest=largest,
                method=method,
                tracker=tracker,
                ortho_iparams=ortho_iparams,
                ortho_fparams=ortho_fparams,
                ortho_bparams=ortho_bparams,
            )

    if not torch._jit_internal.is_scripting():
        if A.requires_grad or (B is not None and B.requires_grad):
            # While it is expected that `A` is symmetric,
            # `A_grad` might not be. Therefore we perform the trick below,
            # so that `A_grad` becomes symmetric.
            # The symmetrization is important for first-order optimization methods,
            # so that (A - alpha * A_grad) is still a symmetric matrix.
            # Same holds for `B`.
            A_sym = (A + A.mT) / 2
            B_sym = (B + B.mT) / 2 if (B is not None) else None

            return LOBPCGAutogradFunction.apply(
                A_sym,
                k,
                B_sym,
                X,
                n,
                iK,
                niter,
                tol,
                largest,
                method,
                tracker,
                ortho_iparams,
                ortho_fparams,
                ortho_bparams,
            )
    else:
        if A.requires_grad or (B is not None and B.requires_grad):
            raise RuntimeError(
                "Script and require grads is not supported atm."
                " If you just want to do the forward, use .detach()"
                " on A and B before calling into lobpcg"
            )

    return _lobpcg(
        A,
        k,
        B,
        X,
        n,
        iK,
        niter,
        tol,
        largest,
        method,
        tracker,
        ortho_iparams,
        ortho_fparams,
        ortho_bparams,
    )


def _lobpcg(
    A: Tensor,
    k: Optional[int] = None,
    B: Optional[Tensor] = None,
    X: Optional[Tensor] = None,
    n: Optional[int] = None,
    iK: Optional[Tensor] = None,
    niter: Optional[int] = None,
    tol: Optional[float] = None,
    largest: Optional[bool] = None,
    method: Optional[str] = None,
    tracker: None = None,
    ortho_iparams: Optional[Dict[str, int]] = None,
    ortho_fparams: Optional[Dict[str, float]] = None,
    ortho_bparams: Optional[Dict[str, bool]] = None,
) -> Tuple[Tensor, Tensor]:
    # A must be square:
    assert A.shape[-2] == A.shape[-1], A.shape
    if B is not None:
        # A and B must have the same shapes:
        assert A.shape == B.shape, (A.shape, B.shape)

    dtype = _utils.get_floating_dtype(A)
    device = A.device
    if tol is None:
        feps = {torch.float32: 1.2e-07, torch.float64: 2.23e-16}[dtype]
        tol = feps**0.5

    m = A.shape[-1]
    k = (1 if X is None else X.shape[-1]) if k is None else k
    n = (k if n is None else n) if X is None else X.shape[-1]

    if m < 3 * n:
        raise ValueError(
            f"LOBPCG algorithm is not applicable when the number of A rows (={m})"
            f" is smaller than 3 x the number of requested eigenpairs (={n})"
        )

    method = "ortho" if method is None else method

    iparams = {
        "m": m,
        "n": n,
        "k": k,
        "niter": 1000 if niter is None else niter,
    }

    fparams = {
        "tol": tol,
    }

    bparams = {"largest": True if largest is None else largest}

    if method == "ortho":
        if ortho_iparams is not None:
            iparams.update(ortho_iparams)
        if ortho_fparams is not None:
            fparams.update(ortho_fparams)
        if ortho_bparams is not None:
            bparams.update(ortho_bparams)
        iparams["ortho_i_max"] = iparams.get("ortho_i_max", 3)
        iparams["ortho_j_max"] = iparams.get("ortho_j_max", 3)
        fparams["ortho_tol"] = fparams.get("ortho_tol", tol)
        fparams["ortho_tol_drop"] = fparams.get("ortho_tol_drop", tol)
        fparams["ortho_tol_replace"] = fparams.get("ortho_tol_replace", tol)
        bparams["ortho_use_drop"] = bparams.get("ortho_use_drop", False)

    if not torch.jit.is_scripting():
        LOBPCG.call_tracker = LOBPCG_call_tracker  # type: ignore[method-assign]

    if len(A.shape) > 2:
        N = int(torch.prod(torch.tensor(A.shape[:-2])))
        bA = A.reshape((N,) + A.shape[-2:])
        bB = B.reshape((N,) + A.shape[-2:]) if B is not None else None
        bX = X.reshape((N,) + X.shape[-2:]) if X is not None else None
        bE = torch.empty((N, k), dtype=dtype, device=device)
        bXret = torch.empty((N, m, k), dtype=dtype, device=device)

        for i in range(N):
            A_ = bA[i]
            B_ = bB[i] if bB is not None else None
            X_ = (
                torch.randn((m, n), dtype=dtype, device=device) if bX is None else bX[i]
            )
            assert len(X_.shape) == 2 and X_.shape == (m, n), (X_.shape, (m, n))
            iparams["batch_index"] = i
            worker = LOBPCG(A_, B_, X_, iK, iparams, fparams, bparams, method, tracker)
            worker.run()
            bE[i] = worker.E[:k]
            bXret[i] = worker.X[:, :k]

        if not torch.jit.is_scripting():
            LOBPCG.call_tracker = LOBPCG_call_tracker_orig  # type: ignore[method-assign]

        return bE.reshape(A.shape[:-2] + (k,)), bXret.reshape(A.shape[:-2] + (m, k))

    X = torch.randn((m, n), dtype=dtype, device=device) if X is None else X
    assert len(X.shape) == 2 and X.shape == (m, n), (X.shape, (m, n))

    worker = LOBPCG(A, B, X, iK, iparams, fparams, bparams, method, tracker)

    worker.run()

    if not torch.jit.is_scripting():
        LOBPCG.call_tracker = LOBPCG_call_tracker_orig  # type: ignore[method-assign]

    return worker.E[:k], worker.X[:, :k]


class LOBPCG:
    """Worker class of LOBPCG methods."""

    def __init__(
        self,
        A: Optional[Tensor],
        B: Optional[Tensor],
        X: Tensor,
        iK: Optional[Tensor],
        iparams: Dict[str, int],
        fparams: Dict[str, float],
        bparams: Dict[str, bool],
        method: str,
        tracker: None,
    ) -> None:
        # constant parameters
        self.A = A
        self.B = B
        self.iK = iK
        self.iparams = iparams
        self.fparams = fparams
        self.bparams = bparams
        self.method = method
        self.tracker = tracker
        m = iparams["m"]
        n = iparams["n"]

        # variable parameters
        self.X = X
        self.E = torch.zeros((n,), dtype=X.dtype, device=X.device)
        self.R = torch.zeros((m, n), dtype=X.dtype, device=X.device)
        self.S = torch.zeros((m, 3 * n), dtype=X.dtype, device=X.device)
        self.tvars: Dict[str, Tensor] = {}
        self.ivars: Dict[str, int] = {"istep": 0}
        self.fvars: Dict[str, float] = {"_": 0.0}
        self.bvars: Dict[str, bool] = {"_": False}

    def __str__(self):
        lines = ["LOBPCG:"]
        lines += [f" iparams={self.iparams}"]
        lines += [f" fparams={self.fparams}"]
        lines += [f" bparams={self.bparams}"]
        lines += [f" ivars={self.ivars}"]
        lines += [f" fvars={self.fvars}"]
        lines += [f" bvars={self.bvars}"]
        lines += [f" tvars={self.tvars}"]
        lines += [f" A={self.A}"]
        lines += [f" B={self.B}"]
        lines += [f" iK={self.iK}"]
        lines += [f" X={self.X}"]
        lines += [f" E={self.E}"]
        r = ""
        for line in lines:
            r += line + "\n"
        return r

    def update(self):
        """Set and update iteration variables."""
        if self.ivars["istep"] == 0:
            X_norm = float(torch.norm(self.X))
            iX_norm = X_norm**-1
            A_norm = float(torch.norm(_utils.matmul(self.A, self.X))) * iX_norm
            B_norm = float(torch.norm(_utils.matmul(self.B, self.X))) * iX_norm
            self.fvars["X_norm"] = X_norm
            self.fvars["A_norm"] = A_norm
            self.fvars["B_norm"] = B_norm
            self.ivars["iterations_left"] = self.iparams["niter"]
            self.ivars["converged_count"] = 0
            self.ivars["converged_end"] = 0

        if self.method == "ortho":
            self._update_ortho()
        else:
            self._update_basic()

        self.ivars["iterations_left"] = self.ivars["iterations_left"] - 1
        self.ivars["istep"] = self.ivars["istep"] + 1

    def update_residual(self):
        """Update residual R from A, B, X, E."""
        mm = _utils.matmul
        self.R = mm(self.A, self.X) - mm(self.B, self.X) * self.E

    def update_converged_count(self):
        """Determine the number of converged eigenpairs using the backward stable
        convergence criterion, see discussion in Sec 4.3 of [DuerschEtal2018].

        Users may redefine this method for custom convergence criteria.
        """
        # (...) -> int
        prev_count = self.ivars["converged_count"]
        tol = self.fparams["tol"]
        A_norm = self.fvars["A_norm"]
        B_norm = self.fvars["B_norm"]
        E, X, R = self.E, self.X, self.R
        rerr = (
            torch.norm(R, 2, (0,))
            * (torch.norm(X, 2, (0,)) * (A_norm + E[: X.shape[-1]] * B_norm)) ** -1
        )
        converged = rerr < tol
        count = 0
        for b in converged:
            if not b:
                # ignore convergence of the following pairs to ensure
                # strict ordering of eigenpairs
                break
            count += 1
        assert (
            count >= prev_count
        ), f"the number of converged eigenpairs (was {prev_count}, got {count}) cannot decrease"
        self.ivars["converged_count"] = count
        self.tvars["rerr"] = rerr
        return count
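
    # For reference (added note, not from the original sources): per eigenpair i,
    # the criterion above is roughly
    #   rerr_i = ||R[:, i]|| / (||X[:, i]|| * (||A|| + E_i * ||B||)) < tol,
    # i.e. the backward-stable residual test discussed in [DuerschEtal2018], Sec 4.3.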

    def stop_iteration(self):
        """Return True to stop iterations.

        Note that tracker (if defined) can force-stop iterations by
        setting ``worker.bvars['force_stop'] = True``.
        """
        return (
            self.bvars.get("force_stop", False)
            or self.ivars["iterations_left"] == 0
            or self.ivars["converged_count"] >= self.iparams["k"]
        )

    def run(self):
        """Run LOBPCG iterations.

        Use this method as a template for implementing the LOBPCG
        iteration scheme with a custom tracker that is compatible with
        TorchScript.
        """
        self.update()

        if not torch.jit.is_scripting() and self.tracker is not None:
            self.call_tracker()

        while not self.stop_iteration():
            self.update()

            if not torch.jit.is_scripting() and self.tracker is not None:
                self.call_tracker()

    def call_tracker(self):
        """Interface for tracking the iteration process in Python mode.

        Tracking the iteration process is disabled in TorchScript
        mode. In fact, one should specify tracker=None when JIT
        compiling functions using lobpcg.
        """
        # do nothing when in TorchScript mode
        pass

    # Internal methods

    def _update_basic(self):
        """
        Update or initialize iteration variables when `method == "basic"`.
        """
        mm = torch.matmul
        ns = self.ivars["converged_end"]
        nc = self.ivars["converged_count"]
        n = self.iparams["n"]
        largest = self.bparams["largest"]

        if self.ivars["istep"] == 0:
            Ri = self._get_rayleigh_ritz_transform(self.X)
            M = _utils.qform(_utils.qform(self.A, self.X), Ri)
            E, Z = _utils.symeig(M, largest)
            self.X[:] = mm(self.X, mm(Ri, Z))
            self.E[:] = E
            np = 0
            self.update_residual()
            nc = self.update_converged_count()
            self.S[..., :n] = self.X

            W = _utils.matmul(self.iK, self.R)
            self.ivars["converged_end"] = ns = n + np + W.shape[-1]
            self.S[:, n + np : ns] = W
        else:
            S_ = self.S[:, nc:ns]
            Ri = self._get_rayleigh_ritz_transform(S_)
            M = _utils.qform(_utils.qform(self.A, S_), Ri)
            E_, Z = _utils.symeig(M, largest)
            self.X[:, nc:] = mm(S_, mm(Ri, Z[:, : n - nc]))
            self.E[nc:] = E_[: n - nc]
            P = mm(S_, mm(Ri, Z[:, n : 2 * n - nc]))
            np = P.shape[-1]

            self.update_residual()
            nc = self.update_converged_count()
            self.S[..., :n] = self.X
            self.S[:, n : n + np] = P
            W = _utils.matmul(self.iK, self.R[:, nc:])

            self.ivars["converged_end"] = ns = n + np + W.shape[-1]
            self.S[:, n + np : ns] = W

    def _update_ortho(self):
        """
        Update or initialize iteration variables when `method == "ortho"`.
        """
        mm = torch.matmul
        ns = self.ivars["converged_end"]
        nc = self.ivars["converged_count"]
        n = self.iparams["n"]
        largest = self.bparams["largest"]

        if self.ivars["istep"] == 0:
            Ri = self._get_rayleigh_ritz_transform(self.X)
            M = _utils.qform(_utils.qform(self.A, self.X), Ri)
            E, Z = _utils.symeig(M, largest)
            self.X = mm(self.X, mm(Ri, Z))
            self.update_residual()
            np = 0
            nc = self.update_converged_count()
            self.S[:, :n] = self.X
            W = self._get_ortho(self.R, self.X)
            ns = self.ivars["converged_end"] = n + np + W.shape[-1]
            self.S[:, n + np : ns] = W
        else:
            S_ = self.S[:, nc:ns]
            # Rayleigh-Ritz procedure
            E_, Z = _utils.symeig(_utils.qform(self.A, S_), largest)

            # Update E, X, P
            self.X[:, nc:] = mm(S_, Z[:, : n - nc])
            self.E[nc:] = E_[: n - nc]
            P = mm(
                S_,
                mm(
                    Z[:, n - nc :],
                    _utils.basis(_utils.transpose(Z[: n - nc, n - nc :])),
                ),
            )
            np = P.shape[-1]

            # check convergence
            self.update_residual()
            nc = self.update_converged_count()

            # update S
            self.S[:, :n] = self.X
            self.S[:, n : n + np] = P
            W = self._get_ortho(self.R[:, nc:], self.S[:, : n + np])
            ns = self.ivars["converged_end"] = n + np + W.shape[-1]
            self.S[:, n + np : ns] = W

    def _get_rayleigh_ritz_transform(self, S):
        """Return a transformation matrix that is used in the Rayleigh-Ritz
        procedure for reducing a general eigenvalue problem :math:`(S^TAS)
        C = (S^TBS) C E` to a standard eigenvalue problem :math:`(Ri^T
        S^TAS Ri) Z = Z E` where `C = Ri Z`.

        .. note:: In the original Rayleigh-Ritz procedure in
          [DuerschEtal2018], the problem is formulated as follows::

            SAS = S^T A S
            SBS = S^T B S
            D = (<diagonal matrix of SBS>) ** -1/2
            R^T R = Cholesky(D SBS D)
            Ri = D R^-1
            solve symeig problem Ri^T SAS Ri Z = Theta Z
            C = Ri Z

          To reduce the number of matrix products (denoted by empty
          space between matrices), here we introduce element-wise
          products (denoted by symbol `*`) so that the Rayleigh-Ritz
          procedure becomes::

            SAS = S^T A S
            SBS = S^T B S
            d = (<diagonal of SBS>) ** -1/2    # this is a 1-d column vector
            dd = d d^T                         # this is a 2-d matrix
            R^T R = Cholesky(dd * SBS)
            Ri = R^-1 * d                      # broadcasting
            solve symeig problem Ri^T SAS Ri Z = Theta Z
            C = Ri Z

          where `dd` is a 2-d matrix that replaces the matrix products `D M
          D` with one element-wise product `M * dd`; and `d` replaces the
          matrix product `D M` with the element-wise product `M *
          d`. Also, creating the diagonal matrix `D` is avoided.

        Args:
          S (Tensor): the matrix basis for the search subspace, size is
                      :math:`(m, n)`.

        Returns:
          Ri (tensor): upper-triangular transformation matrix of size
                       :math:`(n, n)`.
        """
        B = self.B
        mm = torch.matmul
        SBS = _utils.qform(B, S)
        d_row = SBS.diagonal(0, -2, -1) ** -0.5
        d_col = d_row.reshape(d_row.shape[0], 1)

        # TODO use torch.linalg.cholesky_solve once it is implemented
        R = torch.linalg.cholesky((SBS * d_row) * d_col, upper=True)
        return torch.linalg.solve_triangular(
            R, d_row.diag_embed(), upper=True, left=False
        )
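
    # Note (added illustration, not part of the implementation): for a diagonal
    # matrix D = diag(d), the congruence D @ M @ D equals the element-wise
    # product M * (d d^T), which is the `dd * SBS` trick relied upon above;
    # a rough numerical check:
    #
    #   d = torch.rand(5)
    #   M = torch.rand(5, 5)
    #   torch.allclose(torch.diag(d) @ M @ torch.diag(d), M * torch.outer(d, d))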

    def _get_svqb(
        self, U: Tensor, drop: bool, tau: float  # Tensor  # bool  # float
    ) -> Tensor:
        """Return B-orthonormal U.

        .. note:: When `drop` is `False` then `svqb` is based on the
                  Algorithm 4 from [DuerschPhD2015] that is a slight
                  modification of the corresponding algorithm
                  introduced in [StathopoulosWu2002].

        Args:

          U (Tensor) : initial approximation, size is (m, n)
          drop (bool) : when True, drop columns whose
                     contribution to the `span([U])` is small.
          tau (float) : positive tolerance

        Returns:

          U (Tensor) : B-orthonormal columns (:math:`U^T B U = I`), size
                       is (m, n1), where `n1 = n` if `drop` is `False`,
                       otherwise `n1 <= n`.
        """
        if torch.numel(U) == 0:
            return U
        UBU = _utils.qform(self.B, U)
        d = UBU.diagonal(0, -2, -1)

        # Detect and drop exact zero columns from U. While the test
        # `abs(d) == 0` is unlikely to be True for random data, it is
        # possible to construct input data to lobpcg where it will be
        # True leading to a failure (notice the `d ** -0.5` operation
        # in the original algorithm). To prevent the failure, we drop
        # the exact zero columns here and then continue with the
        # original algorithm below.
        nz = torch.where(abs(d) != 0.0)
        assert len(nz) == 1, nz
        if len(nz[0]) < len(d):
            U = U[:, nz[0]]
            if torch.numel(U) == 0:
                return U
            UBU = _utils.qform(self.B, U)
            d = UBU.diagonal(0, -2, -1)
            nz = torch.where(abs(d) != 0.0)
            assert len(nz[0]) == len(d)

        # The original algorithm 4 from [DuerschPhD2015].
        d_col = (d**-0.5).reshape(d.shape[0], 1)
        DUBUD = (UBU * d_col) * _utils.transpose(d_col)
        E, Z = _utils.symeig(DUBUD)
        t = tau * abs(E).max()
        if drop:
            keep = torch.where(E > t)
            assert len(keep) == 1, keep
            E = E[keep[0]]
            Z = Z[:, keep[0]]
            d_col = d_col[keep[0]]
        else:
            E[(torch.where(E < t))[0]] = t

        return torch.matmul(U * _utils.transpose(d_col), Z * E**-0.5)

    def _get_ortho(self, U, V):
        """Return B-orthonormal U whose columns are B-orthogonal to V.

        .. note:: When `bparams["ortho_use_drop"] == False` then
                  `_get_ortho` is based on the Algorithm 3 from
                  [DuerschPhD2015] that is a slight modification of
                  the corresponding algorithm introduced in
                  [StathopoulosWu2002]. Otherwise, the method
                  implements Algorithm 6 from [DuerschPhD2015].

        .. note:: If all U columns are B-collinear to V then the
                  returned tensor U will be empty.

        Args:

          U (Tensor) : initial approximation, size is (m, n)
          V (Tensor) : B-orthogonal external basis, size is (m, k)

        Returns:

          U (Tensor) : B-orthonormal columns (:math:`U^T B U = I`)
                       such that :math:`V^T B U = 0`, size is (m, n1),
                       where `n1 = n` if `drop` is `False`, otherwise
                       `n1 <= n`.
        """
        mm = torch.matmul
        mm_B = _utils.matmul
        m = self.iparams["m"]
        tau_ortho = self.fparams["ortho_tol"]
        tau_drop = self.fparams["ortho_tol_drop"]
        tau_replace = self.fparams["ortho_tol_replace"]
        i_max = self.iparams["ortho_i_max"]
        j_max = self.iparams["ortho_j_max"]
        # when use_drop==True, enable dropping U columns that have
        # small contribution to the `span([U, V])`.
        use_drop = self.bparams["ortho_use_drop"]

        # clean up variables from the previous call
        for vkey in list(self.fvars.keys()):
            if vkey.startswith("ortho_") and vkey.endswith("_rerr"):
                self.fvars.pop(vkey)
        self.ivars.pop("ortho_i", 0)
        self.ivars.pop("ortho_j", 0)

        BV_norm = torch.norm(mm_B(self.B, V))
        BU = mm_B(self.B, U)
        VBU = mm(_utils.transpose(V), BU)
        i = j = 0
        stats = ""
        for i in range(i_max):
            U = U - mm(V, VBU)
            drop = False
            tau_svqb = tau_drop
            for j in range(j_max):
                if use_drop:
                    U = self._get_svqb(U, drop, tau_svqb)
                    drop = True
                    tau_svqb = tau_replace
                else:
                    U = self._get_svqb(U, False, tau_replace)
                if torch.numel(U) == 0:
                    # all initial U columns are B-collinear to V
                    self.ivars["ortho_i"] = i
                    self.ivars["ortho_j"] = j
                    return U
                BU = mm_B(self.B, U)
                UBU = mm(_utils.transpose(U), BU)
                U_norm = torch.norm(U)
                BU_norm = torch.norm(BU)
                R = UBU - torch.eye(UBU.shape[-1], device=UBU.device, dtype=UBU.dtype)
                R_norm = torch.norm(R)
                # https://github.com/pytorch/pytorch/issues/33810 workaround:
                rerr = float(R_norm) * float(BU_norm * U_norm) ** -1
                vkey = f"ortho_UBUmI_rerr[{i}, {j}]"
                self.fvars[vkey] = rerr
                if rerr < tau_ortho:
                    break
            VBU = mm(_utils.transpose(V), BU)
            VBU_norm = torch.norm(VBU)
            U_norm = torch.norm(U)
            rerr = float(VBU_norm) * float(BV_norm * U_norm) ** -1
            vkey = f"ortho_VBU_rerr[{i}]"
            self.fvars[vkey] = rerr
            if rerr < tau_ortho:
                break
            if m < U.shape[-1] + V.shape[-1]:
                # TorchScript needs the class var to be assigned to a local to
                # do optional type refinement
                B = self.B
                assert B is not None
                raise ValueError(
                    "Overdetermined shape of U:"
                    f" #B-cols(={B.shape[-1]}) >= #U-cols(={U.shape[-1]}) + #V-cols(={V.shape[-1]}) must hold"
                )
        self.ivars["ortho_i"] = i
        self.ivars["ortho_j"] = j
        return U


# Calling tracker is separated from LOBPCG definitions because
# TorchScript does not support user-defined callback arguments:
LOBPCG_call_tracker_orig = LOBPCG.call_tracker


def LOBPCG_call_tracker(self):
    self.tracker(self)