# Source code for itergp.methods.belief_updates.projected_residual

"""Belief update for projected residual observations."""

from __future__ import annotations

from probnum import backend, randvars
from probnum.linalg.solvers import belief_updates, beliefs

from itergp import linops


class ProjectedResidualBeliefUpdate(belief_updates.LinearSolverBeliefUpdate):
    r"""Gaussian belief update given projected residual information.

    Updates the belief over the quantities of interest of a linear system
    :math:`Ax=b` given a Gaussian prior over the solution
    :math:`x \sim \mathcal{N}(x_i, \Sigma_i)` with
    :math:`\Sigma_i = A^{-1} - C_i`, such that :math:`x_i = C_ib`, and
    information of the form
    :math:`s^\top r_i = s^\top (b - Ax_i) = s^\top A(x - x_i)`. The belief update
    computes the posterior belief about the solution, given by
    :math:`p(x \mid y) = \mathcal{N}(x; x_{i+1}, \Sigma_{i+1})`, such that

    .. math::
        \begin{align}
            x_{i+1} &= x_i + \Sigma_i A^\top s (s^\top A \Sigma_i A^\top s)^\dagger s^\top r_i\\
            \Sigma_{i+1} &= \Sigma_i - \Sigma_i A^\top s (s^\top A \Sigma_i A^\top s)^\dagger s^\top A \Sigma_i
        \end{align}
    """

    def __call__(
        self, solver_state: "probnum.linalg.solvers.LinearSolverState"
    ) -> beliefs.LinearSystemBelief:
        """Compute the updated belief from the current solver state.

        Parameters
        ----------
        solver_state
            State of the linear solver. This method reads ``problem.A``,
            ``action`` and ``belief.Ainv`` from it and forwards it to
            :meth:`updated_linsys_belief`.

        Returns
        -------
        beliefs.LinearSystemBelief
            Belief returned by :meth:`updated_linsys_belief`.
        """
        # Search direction
        A_action = solver_state.problem.A @ solver_state.action
        # NOTE(review): despite the ``Ainv0`` name, this product uses the *current*
        # belief over the inverse (``belief.Ainv``), not ``prior.Ainv``.
        Ainv0_A_action = solver_state.belief.Ainv @ A_action
        search_dir = solver_state.action - Ainv0_A_action  # assumes Sigma_0 = A^{-1}
        A_search_dir = A_action - solver_state.problem.A @ Ainv0_A_action

        # Normalization constant
        gram = solver_state.action.T @ A_search_dir
        # Scalar pseudo-inverse: a non-positive Gramian maps to zero, so this step
        # adds nothing to the mean and only zero columns to the low-rank factors.
        gram_pinv = 1.0 / gram if gram > 0.0 else 0.0

        return self.updated_linsys_belief(
            search_dir=search_dir,
            A_search_dir=A_search_dir,
            gram_pinv=gram_pinv,
            solver_state=solver_state,
        )

    def Ainv_update(
        self,
        normalized_search_dir: backend.Array,
        solver_state: "probnum.linalg.solvers.LinearSolverState",
    ) -> linops.LowRankMatrix:
        r"""Update the low-rank term of the belief over the system matrix inverse.

        Parameters
        ----------
        normalized_search_dir
            Normalized search direction
            :math:`\frac{1}{\sqrt{s_i^\top A \Sigma_{i-1} A s_i}}\Sigma_{i-1}As_i`.
            A one-dimensional array is reshaped into a single column.
        solver_state
            State of the linear solver.

        Returns
        -------
        linops.LowRankMatrix
            Low-rank matrix whose factor ``U`` is the factor accumulated so far with
            ``normalized_search_dir`` appended as a new column (or just that column
            at step ``0``).
        """
        if backend.ndim(normalized_search_dir) == 1:
            normalized_search_dir = backend.reshape(normalized_search_dir, (-1, 1))

        if solver_state.step == 0:
            Ainv_update = linops.LowRankMatrix(U=normalized_search_dir)
        else:
            # The belief built in `updated_linsys_belief` is `prior.Ainv + update`,
            # so the second summand carries the factor accumulated so far.
            # NOTE(review): this relies on the private `_summands` attribute.
            Ainv_update = linops.LowRankMatrix(
                U=backend.hstack(
                    [solver_state.belief.Ainv._summands[1].U, normalized_search_dir]
                )
            )

        return Ainv_update

    def A_update(
        self,
        A_normalized_search_dir: backend.Array,
        solver_state: "probnum.linalg.solvers.LinearSolverState",
    ) -> linops.LowRankMatrix:
        r"""Update the low-rank term of the belief over the system matrix.

        Parameters
        ----------
        A_normalized_search_dir
            :math:`A`-multiplied normalized search direction
            :math:`\frac{1}{\sqrt{s_i^\top A \Sigma_{i-1} A s_i}}A\Sigma_{i-1}As_i`.
            A one-dimensional array is reshaped into a single column.
        solver_state
            State of the linear solver.

        Returns
        -------
        linops.LowRankMatrix
            Low-rank matrix whose factor ``U`` is the factor accumulated so far with
            ``A_normalized_search_dir`` appended as a new column (or just that
            column at step ``0``).
        """
        if backend.ndim(A_normalized_search_dir) == 1:
            A_normalized_search_dir = backend.reshape(A_normalized_search_dir, (-1, 1))

        if solver_state.step == 0:
            A_update = linops.LowRankMatrix(U=A_normalized_search_dir)
        else:
            # The belief built in `updated_linsys_belief` is `prior.A + update`,
            # so the second summand carries the factor accumulated so far.
            # NOTE(review): this relies on the private `_summands` attribute.
            A_update = linops.LowRankMatrix(
                U=backend.hstack(
                    [solver_state.belief.A._summands[1].U, A_normalized_search_dir]
                )
            )

        return A_update

    def updated_linsys_belief(
        self,
        search_dir: backend.Array,
        A_search_dir: backend.Array,
        gram_pinv: backend.Scalar,
        solver_state: "probnum.linalg.solvers.LinearSolverState",
    ) -> beliefs.LinearSystemBelief:
        r"""Update the belief over the quantities of interest.

        Parameters
        ----------
        search_dir
            Search direction :math:`\Sigma_{i-1}As_i`.
        A_search_dir
            :math:`A`-multiplied search direction :math:`A\Sigma_{i-1}As_i`.
        gram_pinv
            Pseudo-inverse of the Gramian,
            :math:`(s_i^\top A \Sigma_{i-1} A s_i)^\dagger`. Its square root is the
            normalization constant
            :math:`\frac{1}{\sqrt{s_i^\top A \Sigma_{i-1} A s_i}}` applied to both
            search directions.
        solver_state
            State of the linear solver.

        Returns
        -------
        beliefs.LinearSystemBelief
            Belief with updated solution ``x``, system matrix ``A`` and inverse
            ``Ainv``; the right-hand side ``b`` is taken over from the current
            belief unchanged.
        """
        # Both low-rank factors are scaled by the same constant; compute it once.
        normalization = backend.sqrt(gram_pinv)

        # Update belief about inverse matrix
        Ainv_update = self.Ainv_update(
            normalized_search_dir=search_dir * normalization,
            solver_state=solver_state,
        )

        # Update belief about matrix
        A_update = self.A_update(
            A_normalized_search_dir=A_search_dir * normalization,
            solver_state=solver_state,
        )

        # Update belief about solution: the mean moves along the search direction,
        # the covariance is the prior covariance minus the accumulated low-rank term.
        x = randvars.Normal(
            mean=solver_state.belief.x.mean
            + solver_state.observation * gram_pinv * search_dir,
            cov=solver_state.prior.x.cov - Ainv_update,
        )

        return beliefs.LinearSystemBelief(
            x=x,
            A=solver_state.prior.A + A_update,
            Ainv=solver_state.prior.Ainv + Ainv_update,
            b=solver_state.belief.b,
        )