Source code for itergp.methods.policies._mixed

"""Policy returning actions from multiple different policies."""

from __future__ import annotations

import bisect
from typing import Optional, Tuple

from probnum import backend
from probnum.backend.random import RNGState
from probnum.linalg.solvers import policies


class MixedPolicy(policies.LinearSolverPolicy):
    """Mixed policy.

    Policy which chooses actions based on a set of base policies. Each call is
    delegated to exactly one base policy, selected by the current step of the
    linear solver.

    Parameters
    ----------
    base_policies
        Policies which make up the :class:`MixedPolicy`.
    iters
        Until which iteration (non-inclusive) to use the policy in the corresponding
        position in ``base_policies``. Assumed to be sorted in increasing order. If
        ``iters`` has one fewer entry than ``base_policies``, the last policy is used
        for all remaining iterations.
    """

    def __init__(
        self,
        base_policies: Tuple[policies.LinearSolverPolicy, ...],
        iters: Tuple[int, ...],
    ) -> None:
        # Stored as given; ``iters`` is not sorted or validated here, so the
        # caller is responsible for passing it in increasing order.
        self._base_policies = base_policies
        self._iters = iters
        super().__init__()

    def __call__(
        self,
        solver_state: "probnum.linalg.solvers.LinearSolverState",
        rng: Optional[RNGState] = None,
    ) -> backend.Array:
        """Return an action for a given solver state.

        Parameters
        ----------
        solver_state
            Current state of the linear solver.
        rng
            Random number generator.

        Returns
        -------
        action
            Next action to take.

        Raises
        ------
        IndexError
            If there is no base policy for the current solver step, i.e. the step
            is at or beyond the last entry of ``iters`` and ``base_policies`` has
            no additional trailing policy.
        """
        # ``bisect_right`` counts the entries of ``iters`` that are <= step. Since
        # each entry is a non-inclusive upper bound, that count is the position of
        # the first policy whose bound has not been reached yet.
        policy_idx = bisect.bisect_right(self._iters, solver_state.step)

        # Look the policy up separately from calling it, so that only a missing
        # policy is reported here and errors raised inside the selected policy
        # propagate unchanged.
        try:
            policy = self._base_policies[policy_idx]
        except IndexError as err:
            raise IndexError(
                f"No base policy available for solver step {solver_state.step}: "
                f"expected a policy at position {policy_idx}, but only "
                f"{len(self._base_policies)} base policies were given."
            ) from err

        return policy(solver_state=solver_state, rng=rng)