Jinsung Lee’s Personal Homepage

Towards 2-Dimensional State-Space Models (2): WHippo


This post continues from the previous post: Towards 2-Dimensional State-Space Models (1): Intro.

Once we accept the separation of the sequence construction order $t$ from the spatial coordinates $(h, w)$, we can consider a wide range of image transformations that align much better with inherent image properties.

A few transformations that might come to mind:

- Mean (un)pooling
- Gaussian (de)blurring
- Gaussian (de)noising

Then, what’s next? Can we construct SSMs on this?

A generalized view of SSMs

I spent quite a lot of time thinking about how to derive a valid SSM on this concept. I started from HiPPO (Gu et al., NeurIPS 2020), since that is what 1D SSMs are built upon.

HiPPO derives how orthogonal polynomial coefficients change every time a new token is appended to the end of the input sequence.

To briefly explain, HiPPO manually derived the dynamics of the orthogonal polynomial coefficients that reflect the newly added token $u_t$, and let the RNN’s hidden update follow these dynamics, as described in the figure above. A more detailed explanation is in my previous post.

In the end, I figured that SSMs are implicitly nudged to follow the designated compression function (or basis projection, for mathematical ease) that we choose in the first place. Let me explain this in a more principled way.

In my opinion, the mechanisms of existing SSMs can be characterized by two components:

  1. a compression function $\mathbf{c}: \bigcup_{L \in \mathbb{N}_{\geq N}} \mathbb{R}^{L} \rightarrow \mathbb{R}^{N}$,
  2. an input transformation $\theta : \mathbf{I}_{t-1} \mapsto \mathbf{I}_{t}$, where $\mathbf{I}_{t}$ denotes the input at time $t$.

With this setup, we can describe most existing SSMs, whether 1D or 2D.

For example, HiPPO chose the compression function $\mathbf{c}(\cdot)$ as a projection onto 1D orthogonal polynomials, while choosing the input transformation $\theta : \mathbf{I}_{t-1} \mapsto \mathbf{I}_{t}$ as the concatenation of the $t$-th token to the end of the 1D input at time $t-1$.

Letting $\mathbf{c}(\cdot)$ be the projection onto 2D orthogonal polynomials and $\{\mathbf{I}_t\}$ be the corner-based construction using 2-dimensional sweeping coincides with the formulation introduced in S4ND (Nguyen et al., NeurIPS 2022) or 2D-SSM (Baron et al., ICLR 2024).
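To make the two components concrete, here is a toy sketch of the $(\mathbf{c}, \theta)$ view in code (my own illustrative construction, not HiPPO’s actual derivation): $\mathbf{c}$ is a least-squares projection of a growing 1D sequence onto a fixed polynomial basis, and $\theta$ appends the next token:

```python
import numpy as np

# Hypothetical sketch of the (c, theta) view of a 1D SSM:
# c compresses a growing sequence onto N fixed basis functions,
# theta appends the next token (HiPPO-style input transformation).
N, L = 4, 16
grid = np.linspace(0, 1, L)
# Legendre polynomial basis evaluated on the grid, shape (L, N):
basis = np.polynomial.legendre.legvander(2 * grid - 1, N - 1)

def c(seq):
    """Compression: least-squares projection of the sequence so far."""
    t = len(seq)
    coef, *_ = np.linalg.lstsq(basis[:t], np.asarray(seq), rcond=None)
    return coef

def theta(seq, token):
    """Input transformation: append the new token."""
    return seq + [token]

seq = [0.1, 0.4]
for u in [0.2, 0.9, 0.3]:
    seq = theta(seq, u)   # input grows token by token
state = c(seq)            # N coefficients summarize the whole sequence
```

An SSM never recomputes `c(seq)` from scratch like this; it learns a recurrence whose hidden state tracks how these coefficients move as `theta` is applied.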

In fact, we can regard the hidden states of SSMs as being trained to resemble the behavior of the compression function $\mathbf{c}(\cdot)$ in an indirect way: by mimicking its dynamics $\frac{\mathrm{d}}{\mathrm{d}t}\mathbf{c}(\mathbf{I}_t)$ under the predefined input transformation $\theta$, rather than mimicking its actual outputs.

This partly explains where the compactness, the well-known ability of SSMs, emerges from: it may be due to the inherent design pushing the hidden state to resemble the compression function $\mathbf{c}(\cdot)$. To step a little further, if you choose a typical basis projection as $\mathbf{c}(\cdot)$, you can expect the hidden state to capture spectral components of the input, since decomposing the input into spectral components is a property many basis projections share. With this approach, you can distill favorable properties of $\mathbf{c}(\cdot)$ into the hidden state, which has been an effective strategy for SSMs.

Thus, if we choose a combination of compression function and input transformation $\bigl(\mathbf{c}(\cdot), \theta(\cdot)\bigr)$ and enforce a feature to follow the update rule induced by $\mathbf{c}(\mathbf{I}_{t-1}) \mapsto \mathbf{c}(\mathbf{I}_{t})$, we can come up with a novel SSM formulation that inherits the SSM’s core strategy.

WHippo: finding a good $\bigl(\mathbf{c}(\cdot), \theta(\cdot)\bigr)$ combination

<Vision Mamba>
$\theta(\cdot)$: 1D scanning
$\mathbf{c}(\cdot)$: 1D OP projection

<S4ND, 2D-SSM>
$\theta(\cdot)$: 2D scanning
$\mathbf{c}(\cdot)$: 2D OP projection

<Ours (WHippo)>
$\theta(\cdot)$: Gaussian blurring
$\mathbf{c}(\cdot)$: 2D OP projection

The illustration above shows choices of $\theta(\cdot)$ with the designated compression function $\mathbf{c}(\cdot)$.

One thing to note when choosing $\mathbf{c}(\cdot)$ and $\theta(\cdot)$ is to ensure that the update $\mathbf{c}(\mathbf{I}_{t-1}) \mapsto \mathbf{c}(\mathbf{I}_t)$ is tractable: the dynamics of the compressed representation $\mathbf{c}(\cdot)$ must be computable without having to decompress the input. This comes down to whether the derivative $\frac{\mathrm{d}\mathbf{c}(\mathbf{I}_t)}{\mathrm{d}t}$ can be derived in closed form, and explains why it is convenient to use a basis projection as the compression function $\mathbf{c}(\cdot)$; a basis projection is smooth and differentiable, unlike other compression formats such as PNG or Huffman coding.

WHippo chooses $\mathbf{c}(\cdot)$ as 2D basis projection, and $\theta(\cdot)$ as Gaussian blurring

My friend Jaemin and I found that choosing $\mathbf{c}(\cdot)$ as a 2D orthogonal polynomial basis projection and $\theta(\cdot)$ as Gaussian blurring results in a quite simple update rule, and we name this framework WHippo.

Why named WHippo?


Source: Reddit


  1. The suborder containing the hippo is called Whippomorpha, a clade of mammals that includes whales and hippos. So, Whippo can be viewed as a superclass (or generalization) of hippo.
  2. This idea is a continuation of HiPPO (NeurIPS 2020), the starting point of state-space models, and we are trying to extend HiPPO, which previously only worked in 1D, to 2D. Therefore, our HiPPO has spatial dimensions (Width, Height).
Why Gaussian blurring?

TL;DR: the Gaussian kernel provides a lot of really nice properties that align with our visual perception process, and it reflects inherent properties of natural images as well.

Scale-space theory: A basic tool for analyzing structures at different scales (Tony Lindeberg, 1994; Journal of Applied Statistics) talks about scale-space in images.

Thinking about how to analyze a given input at different scales is one of the oldest challenges in computer vision, and this paper is one of the classic theories on the subject. It studies blurring techniques that yield good scale representations and introduces the resulting method.

... This chapter gives a tutorial review of a special type of multi-scale representation, linear scale-space representation, which has been developed by the computer vision community in order to handle image structures at different scales in a consistent manner...
The main result we will arrive at is that if rather general conditions are posed on the types of computations that are to be performed at the first stages of visual processing, then the Gaussian kernel and its derivatives are singled out as the only possible smoothing kernels.

In our case, the Gaussian kernel provides a number of good properties in other respects as well. The Gaussian kernel is defined in a continuous manner, is deterministic, and composes nicely: applying two Gaussian kernels with variances $\sigma_1^2$ and $\sigma_2^2$ is equivalent to applying a single Gaussian kernel with variance $\sigma_1^2 + \sigma_2^2$. This greatly helps our derivation of $\frac{\mathrm{d}\mathbf{c}(\mathbf{I}_t)}{\mathrm{d}t}$ later.
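The variance-additivity of Gaussian blurring can be checked numerically. This is a sketch using periodic FFT-based blurring (my choice, so the identity holds exactly on the grid rather than only approximately at the boundaries):

```python
import numpy as np

# Semigroup check: blurring with variance t1 and then t2 equals a single
# blur with variance t1 + t2. Periodic (FFT) blurring makes this exact.
def blur(img, t):
    H, W = img.shape
    kx = 2 * np.pi * np.fft.fftfreq(H)
    ky = 2 * np.pi * np.fft.fftfreq(W)
    mult = np.exp(-0.5 * t * (kx[:, None] ** 2 + ky[None, :] ** 2))
    return np.real(np.fft.ifft2(np.fft.fft2(img) * mult))

rng = np.random.default_rng(1)
img = rng.standard_normal((16, 16))

two_step = blur(blur(img, 0.3), 0.7)   # variances 0.3 then 0.7
one_step = blur(img, 1.0)              # single blur with variance 1.0
# np.allclose(two_step, one_step) -> True
```

In the Fourier domain the two exponential multipliers simply add their exponents, which is exactly why the composition collapses into one blur.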

Derivation of the WHippo update rule

Hence, our next step is to derive the update rule $\mathbf{c}(\mathbf{I}_{t-1}) \mapsto \mathbf{c}(\mathbf{I}_{t})$ based on the chosen $\bigl(\mathbf{c}(\cdot), \theta(\cdot)\bigr)$.

The figure above already spoils what the update rule looks like: it ends up inducing a pretty simple rule:

$$\begin{align} \frac{\mathrm{d}c_t}{\mathrm{d}t} = \mathbf{A}c_t \quad \xRightarrow[]{\text{discretize (e.g., Euler)}} \quad c_{t+1} &= (\mathbf{1}+\Delta \mathbf{A})c_t \\ &:= \overline{\mathbf{A}}c_t \nonumber \end{align}$$

with a structured matrix $\mathbf{A}$.
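A minimal sketch of the Euler discretization in code, using a small random stand-in for $\mathbf{A}$ (the structured WHippo matrix is derived below; everything here is illustrative):

```python
import numpy as np

# Euler discretization of dc/dt = A c: c_{t+1} = (I + dt * A) c_t.
rng = np.random.default_rng(0)
N = 6
A = 0.1 * rng.standard_normal((N, N))   # random stand-in for the real A
dt = 0.01
A_bar = np.eye(N) + dt * A              # discretized transition matrix

c0 = rng.standard_normal(N)
c = c0.copy()
for _ in range(100):
    c = A_bar @ c                       # 100 recurrent steps
```

Because the recurrence is linear, unrolling the loop is just a matrix power: after 100 steps, `c` equals `A_bar` raised to the 100th power applied to `c0`, which is what makes SSM recurrences amenable to fast parallel forms.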

Below, I walk through the derivation of the update rule. Hopefully, it is not as hard as HiPPO’s.

Let’s assume an image with only one channel, $\mathbf{I} \in \mathbb{R}^{H \times W}$, and derive the expression (although the illustrations show an RGB image).

Here, we treat the image $\mathbf{I}$ as a surface function $\mathbf{I}(x,y)$ rather than a set of discrete points, which makes our formulation perfectly analogous to the original 1D SSM’s derivation.

Let $\{\mathbf{I}_t\}_{t \in [0, T]}$ be a continuous sequence of increasingly blurred images (with $\mathbf{I} = \mathbf{I}_0$), where the Gaussian kernel $G_t$ applied at time $t$ is

$$G_t(x,y) = \frac{1}{2\pi t}\exp\!\left(-\frac{x^2 + y^2}{2t}\right).$$

Then the image $\mathbf{I}_t$ can be viewed as $\mathbf{I}_0 * G_t$, where $*$ is the convolution operation. We can obtain the coefficients $\mathbf{c}(t) = \begin{bmatrix} c_1(t) & c_2(t) & \cdots & c_{HW}(t) \end{bmatrix}^T$ corresponding to the 2D basis functions $\{\phi_k\}_{k=1}^{HW}$ by projecting an image onto the basis functions, where

$$c_k(t) = \langle G_t * \mathbf{I}, \phi_k \rangle = \iint_{[0,H]\times[0,W]} (G_t * \mathbf{I})(x,y)\,\phi_k(x,y)\,\mathrm{d}y\,\mathrm{d}x.$$
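On a discrete grid, these inner products are just Riemann sums. A quick sketch with a 2D cosine family (the basis choice and grid normalization here are mine, for illustration):

```python
import numpy as np

# Coefficients c_k as discrete inner products <I, phi_k> on an H x W grid.
H = W = 8
rng = np.random.default_rng(3)
img = rng.standard_normal((H, W))

# 2D cosine basis functions indexed by mode numbers (m, n):
x = (np.arange(H) + 0.5) / H
y = (np.arange(W) + 0.5) / W
def phi(m, n):
    return np.cos(np.pi * m * x)[:, None] * np.cos(np.pi * n * y)[None, :]

def coeff(img, m, n):
    # Riemann-sum approximation of the inner product <I, phi_mn>
    return (img * phi(m, n)).sum() / (H * W)

c00 = coeff(img, 0, 0)   # the (0,0) mode: phi is constant, so this is the mean
```

The $(0,0)$ coefficient recovering the image mean is a quick sanity check that the discrete projection behaves as the integral suggests.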

We are interested in the dynamics of $c_k(t)$ with respect to $t$, so differentiating both sides with respect to $t$ gives:

$$\frac{\mathrm{d}c_k(t)}{\mathrm{d}t} = \frac{\mathrm{d}}{\mathrm{d}t}\langle G_t * \mathbf{I}, \phi_k \rangle = \Bigl\langle \frac{\partial}{\partial t}(G_t * \mathbf{I}), \phi_k \Bigr\rangle. \tag{2}$$
Note that convolution with a Gaussian filter is a well-known solution of the heat equation.

What does this mean?
Assume a trivariate function $T(x,y,t)$ that models how the temperature at each $(x,y)$ coordinate (on $[0,H] \times [0,W]$ as defined above) changes over time. Then the heat equation (or diffusion equation) is the PDE describing the temperature change at each point:

$$\frac{\partial T}{\partial t} = \alpha \Bigl(\frac{\partial^2 T}{\partial x^2} + \frac{\partial^2 T}{\partial y^2}\Bigr) := \alpha \nabla^2 T.$$

If $T(x,y,0)$, the temperature distribution at time $t=0$ (initial condition), and the derivatives at the $(x,y)$ boundaries, $\frac{\partial T}{\partial x}\big|_{x\in \{0, H\}}$ and $\frac{\partial T}{\partial y}\big|_{y\in \{0, W\}}$ (boundary conditions; we usually set them to zero, following the standard Neumann boundary condition), are given, we can find the solution of the above equation. The known solution is $u(x,y,t) := G_t * \mathbf{I}$. More precisely, this holds when the standard deviation of the Gaussian filter $G_t$ is $\sqrt{2\alpha t}$.

In other words, if we consider the initial value $\mathbf{I}_0$ in the image sequence above to be temperature rather than pixel intensity, then $\mathbf{I}_t$ becomes the temperature distribution after time $t$.

$(G_t * \mathbf{I})$ is the solution to the heat equation $\frac{\partial T}{\partial t} = \frac{1}{2} \Bigl(\frac{\partial^2 T}{\partial x^2} + \frac{\partial^2 T}{\partial y^2}\Bigr) := \frac{1}{2} \nabla^2 T$, so the partial derivative appearing in (2) can be rewritten as follows:

$$\Bigl\langle \frac{\partial}{\partial t}(G_t * \mathbf{I}), \phi_k \Bigr\rangle = \Bigl\langle \frac{1}{2}\nabla^2(G_t * \mathbf{I}), \phi_k \Bigr\rangle.$$
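This identity can be sanity-checked numerically. The sketch below uses a periodic grid (my simplification), where both the blur and the Laplacian are exact in the Fourier domain, and compares a finite-difference time derivative of the blurred image against half its Laplacian:

```python
import numpy as np

# Check that periodic Gaussian blurring obeys du/dt = 0.5 * Laplacian(u).
def spectrum(H, W):
    kx = 2 * np.pi * np.fft.fftfreq(H)
    ky = 2 * np.pi * np.fft.fftfreq(W)
    return kx[:, None] ** 2 + ky[None, :] ** 2   # Laplacian eigenvalues

def blur(img, t):
    lam = spectrum(*img.shape)
    return np.real(np.fft.ifft2(np.fft.fft2(img) * np.exp(-0.5 * t * lam)))

def laplacian(u):
    lam = spectrum(*u.shape)
    return np.real(np.fft.ifft2(-lam * np.fft.fft2(u)))

rng = np.random.default_rng(0)
u0 = rng.standard_normal((16, 16))
t, dt = 0.5, 1e-5
lhs = (blur(u0, t + dt) - blur(u0, t)) / dt   # d/dt (G_t * I), finite diff
rhs = 0.5 * laplacian(blur(u0, t))            # (1/2) Laplacian(G_t * I)
```

Up to the finite-difference error in `dt`, the two sides agree at every pixel, which is exactly the heat-equation statement used above.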

Before computing the above, we can organize the terms more cleanly by first representing the image $(G_t * \mathbf{I})$ at each step as a sum of basis functions $\phi_k$. Since we have already defined the coefficients $\mathbf{c}(t)$ of the basis functions, we can express $(G_t * \mathbf{I})(x,y) = \mathbf{c}(t)^T \boldsymbol{\phi}(x,y) = \sum_{i=1}^{HW} c_i(t)\phi_i(x,y)$ and substitute:

$$\begin{align*} \Bigl\langle \frac{1}{2}\nabla^2(G_t * \mathbf{I}), \phi_k \Bigr\rangle &= \frac{1}{2} \Bigl\langle \sum_{i=1}^{HW} c_i(t)\nabla^2\phi_i, \phi_k \Bigr\rangle \\ &= \frac{1}{2} \sum_{i=1}^{HW} c_i(t) \bigl\langle \nabla^2\phi_i, \phi_k \bigr\rangle. \end{align*}$$

From here, the derivation can be easier or harder depending on which basis function $\phi$ you choose.

In general, good basis functions are mutually orthogonal, so the derivation is easiest for a basis in which $\nabla^2 \phi_k$ can be expressed using as few of the $\phi_i$ as possible.
(For example, in the Fourier case, $\nabla^2 \phi_k$ is a scaled $\phi_k$, so we can expect a very clean expression!)
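To see the Fourier claim concretely, here is a small numerical sketch: building $A_{ki} = \frac{1}{2}\langle \phi_k, \nabla^2 \phi_i \rangle$ for a few complex-exponential modes on a periodic grid yields a diagonal matrix (the mode selection and normalization are mine, for illustration):

```python
import numpy as np

# For Fourier modes, Laplacian(phi) = -(m^2 + n^2) * phi, so A is diagonal.
H = W = 8
x = np.arange(H) * 2 * np.pi / H
y = np.arange(W) * 2 * np.pi / W
modes = [(0, 0), (1, 0), (0, 1), (1, 1)]

def phi(m, n):
    # Orthonormal complex exponential on the H x W periodic grid
    return np.exp(1j * (m * x[:, None] + n * y[None, :])) / np.sqrt(H * W)

def inner(f, g):
    return np.vdot(f, g)   # conjugates the first argument

A = np.zeros((len(modes), len(modes)), dtype=complex)
for k, (mk, nk) in enumerate(modes):
    for i, (mi, ni) in enumerate(modes):
        lap_phi_i = -(mi**2 + ni**2) * phi(mi, ni)   # analytic Laplacian
        A[k, i] = 0.5 * inner(phi(mk, nk), lap_phi_i)
```

Because each $\nabla^2 \phi_i$ is just $\phi_i$ scaled by $-(m_i^2 + n_i^2)$, orthonormality kills every off-diagonal entry, leaving diagonal entries $-\frac{1}{2}(m^2 + n^2)$.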

The final expression:

$$\frac{\mathrm{d}c_k(t)}{\mathrm{d}t} = \frac{1}{2} \sum_{i=1}^{HW} c_i(t) \bigl\langle \nabla^2\phi_i, \phi_k \bigr\rangle \qquad (*)$$

The good news is that this already looks like a simple form of an SSM. Since the change in $c_k$ is expressed as a linear combination of the coefficients $c_i$, it can be understood as a matrix operation:

$$\begin{align} \frac{\mathrm{d}}{\mathrm{d}t}\mathbf{c}(t) &= \frac{1}{2} \begin{bmatrix} \langle \phi_1, \nabla^2\phi_1 \rangle & \langle \phi_1, \nabla^2\phi_2 \rangle & \cdots & \langle \phi_1, \nabla^2\phi_{HW} \rangle \\ \langle \phi_2, \nabla^2\phi_1 \rangle & \langle \phi_2, \nabla^2\phi_2 \rangle & \cdots & \langle \phi_2, \nabla^2\phi_{HW} \rangle \\ \vdots & \vdots & \ddots & \vdots \\ \langle \phi_{HW}, \nabla^2\phi_1 \rangle & \langle \phi_{HW}, \nabla^2\phi_2 \rangle & \cdots & \langle \phi_{HW}, \nabla^2\phi_{HW} \rangle \end{bmatrix} \mathbf{c}(t) \nonumber \\ &:= \mathbf{A}\mathbf{c}(t). \nonumber \end{align}$$

The structure of $\mathbf{A}$

The structure of $\mathbf{A}$ depends solely on the choice of the 2D basis function $\phi$.

I’ll skip the derivation of the $\mathbf{A}$ matrices for different bases, but I think it is worth sharing what they look like. Below are figures of those matrices, and you can see their sparse structures.

$\mathbf{A}$ of 2D Fourier bases
$\mathbf{A}$ of 2D Legendre polynomial bases
$\mathbf{A}$ of 2D Chebyshev polynomial bases
$\mathbf{A}$ of 2D Hermite polynomial bases. Due to its exponential values, log values are displayed; zero values are marked white.

So, what’s next? Any use of this?

We have now derived a novel 2D SSM formulation that is more natural in terms of image processing.

Then, how can this be implemented in neural network training?

I’ll explain its intriguing usage in the next post. To tease a little: this framework turned out to provide an unexpected property for diffusion modeling!



To be continued…