Algorithm Illustration¶
We will visualize the different components of IterGP to better understand how it performs GP inference.
[1]:
import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format='svg'
from probnum import backend, randvars
from probnum.randprocs import kernels, mean_fns
from itergp import GaussianProcess, datasets, methods
Dataset¶
We consider a synthetically generated dataset without noise, so that we can evaluate the label function \(y(x) = \sin(\pi x^\top \mathbf{1})\).
[2]:
# Generate data
rng_state = backend.random.rng_state(42)
num_data = 6
input_shape = ()
output_shape = ()

# Split off a dedicated random state for the dataset and actually use it:
# the parent ``rng_state`` is kept for any further splits instead of being
# consumed directly after it has been split.
rng_state, rng_state_data = backend.random.split(rng_state, num=2)

data = datasets.SyntheticDataset(
    rng_state=rng_state_data,
    # NOTE(review): presumably (number of training points, number of test
    # points) -- confirm against the SyntheticDataset documentation.
    size=(num_data, num_data),
    input_shape=input_shape,
    output_shape=output_shape,
    noise_var=0.0,  # noise-free labels
)

# Training inputs and labels used to condition the GP below.
X = data.train.X
y = data.train.y
Gaussian Process Model¶
[3]:
# Model: GP prior built from a zero mean function and an ExpQuad kernel.
mean_fn = mean_fns.Zero(input_shape=input_shape, output_shape=output_shape)
kernel = kernels.ExpQuad(input_shape=input_shape, lengthscale=0.2)
gp = GaussianProcess(mean_fn, kernel)

# Observation noise: a Normal random variable with zero mean whose covariance
# is the identity scaled by ``sigma_sq`` (zero here, i.e. noise-free labels).
sigma_sq = 0.0
num_train = y.shape[0]
noise = randvars.Normal(
    mean=backend.zeros(y.shape),
    cov=sigma_sq * backend.eye(num_train),
)
IterGP Inference¶
[4]:
# Approximation methods: one subplot column is drawn per entry.
approx_methods = ["IterGP-Chol", "IterGP-CG", "IterGP-PI"]

# As many pseudo-inputs as there are training inputs, evenly spaced on
# [-1, 1].
num_pseudo_inputs = len(X)
pseudo_inputs = backend.linspace(-1, 1, num_pseudo_inputs)
[5]:
# Grid of axes: three rows, one column per approximation method. Columns
# share their x-axis and rows share their y-axis.
num_rows = 3
num_cols = len(approx_methods)
fig, axs = plt.subplots(
    nrows=num_rows,
    ncols=num_cols,
    figsize=(12, 6),
    sharex="col",
    sharey="row",
)

# Set figure background opacity to 0.
fig.patch.set_alpha(0.0)

# Close the current figure; it is shown through the animation instead.
plt.close()
def _draw_value_bars(ax, inputs, values, half_width, color):
    """Draw one bar per input, reaching from zero to the matching value.

    The bars are drawn in data coordinates, so they stay aligned with the
    plotted values when the y-limits of ``ax`` change afterwards (explicitly,
    or through autoscaling of axes that share their y-axis with ``ax``).

    Parameters
    ----------
    ax
        Matplotlib axes to draw on.
    inputs
        Bar centres; converted with ``backend.to_numpy``.
    values
        Bar heights, one per input; converted with ``backend.to_numpy``.
    half_width
        Half of the bar width, in data units.
    color
        Fill colour of the bars.
    """
    for x, value in zip(backend.to_numpy(inputs), backend.to_numpy(values)):
        ax.fill_between(
            x=[x - half_width, x + half_width],
            y1=0.0,
            y2=float(value),
            color=color,
            zorder=-5,
            lw=0.0,
        )


def animate(idxiter):
    """Redraw every subplot for one animation frame.

    For each approximation method the column shows, from top to bottom, the
    action (row 0), the residual (row 1) and the posterior prediction (row 2).

    Parameters
    ----------
    idxiter
        Frame index. It is passed on as ``maxrank`` (Cholesky), ``maxiter``
        (CG) or the number of pseudo-inputs used (PseudoInput).

    Raises
    ------
    NotImplementedError
        If ``approx_methods`` contains an unknown method name.
    """
    for ax in axs.flatten():
        ax.cla()

    # Quantities that are identical for every method: computed once per frame.
    Xnew = backend.linspace(-1, 1, 1000)
    data_range = X.max() - X.min()
    data_width = 0.05 * data_range
    bar_half_width = float(backend.to_numpy(0.25 * data_width))

    for idxmethod, approx_method in enumerate(approx_methods):
        # Latent function
        axs[2, idxmethod].plot(
            Xnew, data.fun(Xnew), linestyle="--", color="black", lw=0.75
        )

        # Training Data: full-height gray band at every training input.
        for i in range(3):
            for x in X:
                axs[i, idxmethod].axvspan(
                    xmin=x - 0.25 * data_width,
                    xmax=x + 0.25 * data_width,
                    color="gray",
                    alpha=0.25,
                    zorder=-10,
                    lw=0.0,
                )

        # Gaussian process approximation
        if approx_method == "IterGP-Chol":
            ameth = methods.Cholesky(maxrank=idxiter)
        elif approx_method == "IterGP-CG":
            ameth = methods.CG(maxiter=idxiter)
        elif approx_method == "IterGP-PI":
            ameth = methods.PseudoInput(pseudo_inputs=pseudo_inputs[0:idxiter])
        else:
            raise NotImplementedError

        gp_post = gp.condition_on_data(X, y, b=noise, approx_method=ameth)
        gp_post.plot(X=Xnew, data=(X, y), ax=axs[2, idxmethod])
        axs[2, idxmethod].set(ylabel="Prediction", ylim=(-2.01, 2.01))

        # Residual: latent function minus the current posterior mean.
        residual_fn = lambda x: data.fun(x) - gp_post.mean(x)
        residual_global = residual_fn(Xnew)
        residual = residual_fn(X)
        residual_color = "C3"

        axs[1, idxmethod].fill_between(
            x=Xnew,
            y1=backend.zeros_like(Xnew),
            y2=residual_global,
            alpha=0.2,
            lw=0.0,
            color=residual_color,
        )
        axs[1, idxmethod].plot(Xnew, residual_global, color=residual_color)
        axs[1, idxmethod].scatter(X, residual, color=residual_color, marker=".")
        axs[1, idxmethod].axhline(y=0.0, color="black", linestyle="--", lw=0.5)
        # Bars in data coordinates: the y-limits are overridden right below,
        # which would invalidate bar heights given as axes fractions.
        _draw_value_bars(
            axs[1, idxmethod], X, residual, bar_half_width, residual_color
        )
        axs[1, idxmethod].set(ylabel="Residual", ylim=(-1.01, 1.01))

        # Action
        if idxiter < len(X):
            if approx_method == "IterGP-Chol":
                action_fn = lambda x: x == X[idxiter]
            elif approx_method == "IterGP-CG":
                action_fn = residual_fn
            elif approx_method == "IterGP-PI":
                action_fn = lambda x: kernel(pseudo_inputs[idxiter], x)
            else:
                raise NotImplementedError

            action_global = action_fn(Xnew)
            action = action_fn(X)
            action_color = "C4"

            axs[0, idxmethod].scatter(X, action, color=action_color, marker=".")
            axs[0, idxmethod].plot(Xnew, action_global, color=action_color)
            axs[0, idxmethod].fill_between(
                x=Xnew,
                y1=backend.zeros_like(Xnew),
                y2=action_global,
                color=action_color,
                lw=0.0,
                alpha=0.2,
            )
            axs[0, idxmethod].axhline(y=0.0, color="black", linestyle="--", lw=0.5)
            # Bars in data coordinates: this row shares its y-axis across
            # columns, so its limits keep autoscaling while later columns
            # are drawn.
            _draw_value_bars(
                axs[0, idxmethod], X, action, bar_half_width, action_color
            )

        # Label the action axes on every frame, including the last one where
        # no action is drawn.
        axs[0, idxmethod].set(xlabel="Input Space", ylabel="Action")

    fig.align_ylabels()
from IPython.display import HTML
from matplotlib import animation
# Create animation: ``animate`` receives the frame indices 0, ..., num_data.
num_frames = num_data + 1
anim = animation.FuncAnimation(
    fig,
    func=animate,
    frames=num_frames,
    interval=1250,
    repeat_delay=4000,
    blit=False,
)

# Create interactive plot (kept as the last expression of the cell so that
# the notebook displays it).
anim_html = anim.to_jshtml()
HTML(anim_html)
[KeOps] Generating code for formula Sum_Reduction(Exp(-Var(0,1,2)*((Var(1,1,0)-Var(2,1,1))*Sum(Var(1,1,0)-Var(2,1,1))))*Var(3,2,1),0) ... OK
[KeOps] Generating code for formula Sum_Reduction(Exp(-Var(0,1,2)*((Var(1,1,0)-Var(2,1,1))*Sum(Var(1,1,0)-Var(2,1,1))))*Var(3,3,1),0) ... OK
[KeOps] Generating code for formula Sum_Reduction(Exp(-Var(0,1,2)*((Var(1,1,0)-Var(2,1,1))*Sum(Var(1,1,0)-Var(2,1,1))))*Var(3,4,1),0) ... OK
[KeOps] Generating code for formula Sum_Reduction(Exp(-Var(0,1,2)*((Var(1,1,0)-Var(2,1,1))*Sum(Var(1,1,0)-Var(2,1,1))))*Var(3,5,1),0) ... OK
[KeOps] Generating code for formula Sum_Reduction(Exp(-Var(0,1,2)*((Var(1,1,0)-Var(2,1,1))*Sum(Var(1,1,0)-Var(2,1,1))))*Var(3,6,1),0) ... OK
[5]:
Observations¶
Notice how the action can be interpreted as targeting computation towards certain datapoints. Different instances of IterGP differ by how the computation is targeted during a run of the algorithm. After \(n\) iterations all combined posteriors are identical to the mathematical GP posterior. Since we assumed no observation noise, the residual is zero at the datapoints.