# Source code for itergp.methods.policies._auto_preconditioned_gradient

"""Policy returning preconditioned gradients as actions from a self-constructed 
preconditioner."""

from __future__ import annotations

from typing import Optional, Tuple

from probnum import backend, linalg
from probnum.backend.random import RNGState

from itergp.methods import preconditioners

from ._gradient import GradientPolicy
from ._unit_vector import UnitVectorPolicy


class AutoPreconditionedGradientPolicy:
    """Policy which first constructs a preconditioner on the fly and then returns
    preconditioned gradient actions.

    This policy constructs a diagonal-plus-low-rank preconditioner by first using a
    ``precond_policy`` for ``precond_iter`` steps. Afterwards the policy returns
    preconditioned gradients as actions using the previously constructed preconditioner.

    Parameters
    ----------
    precond_iter
        Number of iterations to use for preconditioner construction. Equivalently, the
        rank of the low-rank component of the preconditioner (inverse). Must be
        non-negative. If ``0``, no preconditioner is constructed and the gradient
        policy is created with ``precond_inv=None``.
    precond_policy
        Policy to use for the first ``precond_iter`` iterations, which determines the
        constructed preconditioner. Defaults to a
        :class:`~itergp.methods.policies.UnitVectorPolicy`, which is equivalent to
        preconditioning with a partial Cholesky factorization. A new default
        instance is created for every policy object.

    Raises
    ------
    ValueError
        If ``precond_iter`` is negative.
    """

    def __init__(
        self,
        precond_iter: int,
        precond_policy: Optional[linalg.solvers.policies.LinearSolverPolicy] = None,
    ) -> None:
        if precond_iter < 0:
            raise ValueError(
                f"`precond_iter` must be non-negative, but is {precond_iter}."
            )

        # Construct the default here rather than in the signature: a default
        # argument is evaluated only once, which would make every instance
        # share the same policy object.
        if precond_policy is None:
            precond_policy = UnitVectorPolicy()

        self._precond_policy = precond_policy
        # Set once the step counter reaches ``precond_iter``.
        self._gradient_policy = None
        self._precond_iter = precond_iter

    def __call__(
        self,
        solver_state: "probnum.linalg.solvers.LinearSolverState",
        rng: Optional[RNGState],
    ) -> backend.Array:
        """Return an action for a given solver state.

        For steps before ``precond_iter`` the action is delegated to
        ``precond_policy``. At step ``precond_iter`` a gradient policy is built
        (preconditioned with the inverse of a preconditioner constructed from the
        current solver state, unless ``precond_iter == 0``). That step and every
        later step return the gradient policy's action.

        Parameters
        ----------
        solver_state
            Current state of the linear solver.
        rng
            Random number generator.

        Returns
        -------
        action
            Next action to take.

        Raises
        ------
        RuntimeError
            If called for a step after ``precond_iter`` although the gradient
            policy has not been constructed yet (step ``precond_iter`` was never
            reached through this policy).
        """
        step = solver_state.step

        if step < self._precond_iter:
            return self._precond_policy(solver_state=solver_state, rng=rng)

        if step == self._precond_iter:
            # Rebuilt every time this step is reached, so the preconditioner
            # always reflects the solver state of the current solve.
            self._gradient_policy = self._build_gradient_policy(solver_state)
        elif self._gradient_policy is None:
            raise RuntimeError(
                f"Gradient policy requested at step {step}, but it is only "
                f"constructed at step {self._precond_iter}, which was never reached."
            )

        return self._gradient_policy(solver_state=solver_state, rng=rng)

    def _build_gradient_policy(self, solver_state) -> GradientPolicy:
        """Create the gradient policy, preconditioned when ``precond_iter > 0``."""
        if self._precond_iter == 0:
            precond_inv = None
        else:
            precond = (
                preconditioners.DiagonalPlusLowRank.from_kernel_matrix_linear_solve(
                    solver_state=solver_state
                )
            )
            precond_inv = precond.inv()
        return GradientPolicy(precond_inv=precond_inv)