Published: 06.12.2023

Vision Retention Networks

Introduction

Retention is a mechanism recently proposed in Retentive Network: A Successor to Transformer for Large Language Models by Sun et al., whose core idea is to carry out a computation similar to attention while being much more computationally efficient. It has become a recurring pattern that researchers from other fields of machine learning take inspiration from progress made in NLP and try to adapt NLP solutions to a different problem. It was thus only a matter of time before we heard of Retention in the field of Computer Vision.

ViR: Vision Retention Networks by Ali Hatamizadeh, Michael Ranzinger, and Jan Kautz was the first work to apply Retention in a CV model. I recently had a great time re-implementing the paper and digging into Retention, so I thought I would share what I have learned. You can find my re-implementation at brianpulfer/vision-Retention-networks or at brianpulfer/papersreimplementations.

Attention

Before we dig into ViR, we need to learn what Retention is. But before we learn what Retention is, a little recap on Attention.

Attention is a mechanism that allows a model to learn relationships between elements of the input. The meaning of a word can be completely altered by the surrounding words. A red pixel in an image might come from a tomato, a Ferrari or a balloon: only the combination with neighbouring pixels gives it a meaning. It is thus important for models to have the ability to learn the interplay of elements in the input sequence. That is where Attention comes in, and this is how it works:

Given an input sequence X \in \mathbb{R}^{N \times D}, attention computes Queries, Keys and Values for each element of the sequence as follows:

Q = X W_q
K = X W_k
V = X W_v

Where W_q, W_k, W_v \in \mathbb{R}^{D \times D} are learnable parameters. The output for each element of the sequence is going to be a weighted sum of the values, where the weights are computed as the dot product between the query and the keys:

\text{Attention}(Q, K, V) = \text{softmax} \left( \frac{Q K^T}{\sqrt{D}} \right) V

and \text{softmax}(x)_i = \frac{e^{x_i}}{\sum_j e^{x_j}} is applied row-wise.
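To make the shapes concrete, here is a minimal single-head attention sketch in PyTorch (my own toy code, not from any paper):

import torch
import torch.nn as nn


class Attention(nn.Module):
    # Minimal single-head self-attention, following the formulas above
    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        self.wq = nn.Linear(dim, dim, bias=False)
        self.wk = nn.Linear(dim, dim, bias=False)
        self.wv = nn.Linear(dim, dim, bias=False)

    def forward(self, x):  # x has shape (batch, N, D)
        q, k, v = self.wq(x), self.wk(x), self.wv(x)
        scores = q @ k.transpose(-1, -2) / self.dim**0.5  # (batch, N, N)
        return scores.softmax(dim=-1) @ v                 # (batch, N, D)


x = torch.randn(2, 16, 64)
print(Attention(64)(x).shape)  # torch.Size([2, 16, 64])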

This mechanism, ever since the Attention Is All You Need paper, has been empirically proven to be very powerful for learning relationships between elements of a sequence. It has been used in virtually all contexts (NLP, CV, TTS, ...), and it has become a de-facto standard for many tasks.

Then why get rid of it?

There is only one issue that has been troubling researchers a bit: the complexity of attention is O(N^2) (easily seen when computing QK^T), meaning that for an input sequence twice as long, computing attention takes four times as much time.

Quite some effort went into trying to solve this issue, with variants like Linear Attention and Efficient Attention trying to replicate the mechanism while being computationally more convenient.

Retention

Retention works recurrently, just like recurrent neural networks. At each step, it reads the input to update an inner state matrix, uses the inner state to compute an output, and passes the inner state onward. Here is the RECURRENT formulation of Retention:

\mathbf{s}_n = \alpha \mathbf{s}_{n-1} + \mathbf{k}_n^T \mathbf{v}_n
\text{Retention}(\mathbf{x}) = \mathbf{o}_n = \mathbf{q}_n \mathbf{s}_n

where \mathbf{s}_n is the inner state at step n, \mathbf{k}_n is the key and \mathbf{v}_n is the value of the current (n-th) element in the sequence (row vectors, so \mathbf{s}_n, \mathbf{k}_n^T \mathbf{v}_n \in \mathbb{R}^{D \times D}). Needless to say, \mathbf{q}_n, \mathbf{k}_n, \mathbf{v}_n are linear projections of the n-th sequence element \mathbf{x}_n. Finally, 0 \le \alpha \le 1 is a constant that exponentially decays older key-value products.

Translating these equations into text, the idea is the following: \mathbf{s}_n contains the state in the form of a (decayed) sum of all key-value products. The output is obtained by fetching the desired value (mixture of values) using the current query \mathbf{q}_n.
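As a quick sanity check of these equations, here is a toy single-head recurrence in PyTorch (my own sketch; the scaling by \sqrt{D} is added to match the parallel formulation and the full implementation further below):

import torch

D = 64
alpha = 0.97
x = torch.randn(10, D)                                   # a toy sequence of 10 tokens
Wq, Wk, Wv = (torch.randn(D, D) / D**0.5 for _ in range(3))

state = torch.zeros(D, D)                                # s_0
outputs = []
for n in range(x.shape[0]):
    q, k, v = x[n] @ Wq, x[n] @ Wk, x[n] @ Wv            # row vectors of size D
    state = alpha * state + k.unsqueeze(-1) @ v.unsqueeze(0)   # s_n = alpha * s_{n-1} + k_n^T v_n
    outputs.append(q.unsqueeze(0) @ state / D**0.5)      # o_n = q_n s_n (scaled by sqrt(D))
outputs = torch.cat(outputs)                             # (10, D)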

This is literally all there is to Retention! What is so special about it is that it can also be computed using a PARALLEL formulation just like we do for Attention. The formula to compute all outputs at once is the following:

M_{i,j} = \begin{cases} 0, & i < j \\ \alpha^{i-j}, & i \ge j \end{cases}
\text{Retention}(X) = \left( \frac{QK^T}{\sqrt{D}} \odot M \right) V

Looks familiar, right? In fact, we do everything exactly as for Attention, except that we do not apply a row-wise softmax function and always apply M, a lower-triangular matrix that simultaneously deals with causal masking (take no contribution from future elements in the sequence) and applies the exponential decay given by \alpha.

The key takeaway here is that if we get rid of the softmax operator we unlock the recurrent formulation, where we can just carry on what we had computed before to compute the next output.
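Continuing the toy example above, the parallel formulation produces the same outputs (up to floating point error):

N = x.shape[0]
Q, K, V = x @ Wq, x @ Wk, x @ Wv
idx = torch.arange(N)
M = (alpha ** (idx[:, None] - idx[None, :])).tril()      # decay + causal masking in one matrix
parallel_out = (Q @ K.T / D**0.5 * M) @ V

assert torch.allclose(parallel_out, outputs, atol=1e-4)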

However, processing sequences recurrently sucks! That is exactly the reason why we generally prefer Transformers over RNNs: the Transformer's time complexity might be quadratic in the sequence length, but at least we can process everything in parallel. With a recurrent formulation, we need to sequentially compute the n-th output before we can compute the (n+1)-th, while our GPUs sit idle.

Then why care about a recurrent formulation?

The real ✨magic✨ happens when we decide to use a hybrid between the parallel and recurrent formulations. In fact, it turns out that we can break the input sequence into multiple chunks, run each chunk in parallel using the parallel formulation, and then aggregate all of the results with a cross-chunk recurrent computation. This means that as soon as the sequence becomes prohibitively long for the parallel formulation (quadratic in N), we can just split it into chunks of size C, run those in parallel (quadratic in the chunk size C only!) and finally combine the cross-chunk information recurrently (linear in \frac{N}{C}). The real gain is thus obtained when we have very long sequences.
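To put numbers on it: with N = 1024 and C = 64, the parallel formulation needs on the order of N^2 = 1,048,576 query-key products, while the chunkwise one needs \frac{N}{C} \cdot C^2 = N \cdot C = 65,536 for the inner-chunk parts plus \frac{N}{C} = 16 cheap recurrent cross-chunk updates: roughly a 16-fold reduction, and a factor of \frac{N}{C} in general.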

Here we have the CHUNKWISE RECURRENT formulation of Retention:

Q_{[i]} = Q_{Ci:C(i+1)}, \quad K_{[i]} = K_{Ci:C(i+1)}, \quad V_{[i]} = V_{Ci:C(i+1)}
R_i = K_{[i]}^T (V_{[i]} \odot A) + \alpha^C R_{i-1}, \quad A_{ij} = \alpha^{C-i-1}
\text{Retention}(X_{[i]}) = \underbrace{\left( \frac{Q_{[i]} K_{[i]}^T}{\sqrt{D}} \odot M \right) V_{[i]}}_{\text{Inner-Chunk}} + \underbrace{\left( \frac{Q_{[i]} R_{i-1}}{\sqrt{D}} \right) \odot \xi}_{\text{Cross-Chunk}}, \quad \xi_{ij} = \alpha^{i+1}

The math looks scary, but really we are just applying the parallel computation for all chunks and, once we have the Inner-Chunk parts, we can merge them using the recurrent formulation.
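Again on the toy tensors from above, here is a sketch of the chunkwise computation with C = 5 (two chunks of 5 tokens), which matches the parallel output:

C = 5
R = torch.zeros(D, D)                                    # cross-chunk state R_{i-1}
chunk_out = []
for i in range(N // C):
    Qc, Kc, Vc = Q[i*C:(i+1)*C], K[i*C:(i+1)*C], V[i*C:(i+1)*C]
    Mc = M[:C, :C]                                       # chunk-sized decay/causal mask
    inner = (Qc @ Kc.T / D**0.5 * Mc) @ Vc               # parallel formulation within the chunk
    xi = alpha ** (torch.arange(C) + 1)                  # decay applied to the carried state
    cross = (Qc @ R) / D**0.5 * xi[:, None]
    chunk_out.append(inner + cross)
    A = alpha ** (C - torch.arange(C) - 1)               # decay used when summarising this chunk
    R = Kc.T @ (Vc * A[:, None]) + (alpha**C) * R        # R_i = K^T (V decayed) + alpha^C R_{i-1}
chunk_out = torch.cat(chunk_out)

assert torch.allclose(chunk_out, parallel_out, atol=1e-4)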

Comparison of Attention and Retention

Time complexity

Retention can, given the previous state, compute the next token in O(1) time complexity, whereas Attention does not have a previous state and needs to use all O(N) past keys and values to predict the next token.

Recurrent formulation

Attention does not admit (nor need) a recurrent formulation, whereas Retention is recurrent by construction, which forces a causal ordering of the sequence. This is perfectly fine for causal decoder transformers, where we do not want current tokens to attend to future tokens anyway. However, in computer vision we mostly use the encoder type of transformer, so it is not completely clear what impact forcing a causal relationship might have on a task where seemingly there is no causal relationship.

Personal observation: Because Retention accumulates all key-value products into a single state, I believe that it is probably not as powerful a mechanism as Attention. Perhaps this loss of expressivity is not a big deal for text and/or images, especially compared to the gains made in time complexity, but this is still something to keep in mind. It might very well be that Retention fails to become a de-facto standard like other alternatives to Attention before it due to worse performance. What is sure is that Retention enables faster inference and, for very long sequences, even faster training, while being quite similar to Attention.

Vision Retention Networks

Vision Retention Networks are a minor yet important variation on the Vision Transformer. I have previously written about how Vision Transformers (ViT) work, but in short, a ViT breaks an image into many distinct non-overlapping patches (typically a 16x16 grid of 14x14 patches for images of size 224x224), which are then flattened and treated as a sequence. An encoder transformer is then used to process the sequence without any causal masking, and the output is used for down-stream tasks.
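For reference, the patchification can be done with a strided convolution (the same trick used in the implementation further below); a quick sketch of the shapes:

import torch
import torch.nn as nn

patch_size, embed_dim = 14, 192
patchify = nn.Conv2d(3, embed_dim, kernel_size=patch_size, stride=patch_size)

img = torch.randn(1, 3, 224, 224)
tokens = patchify(img).flatten(2).transpose(1, 2)
print(tokens.shape)  # torch.Size([1, 256, 192]): a 16x16 grid of 14x14 patches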

The ViT is thus just a stack of encoder blocks, where each block sequentially applies an Attention block and an MLP block. In ViR, we get rid of the Attention block and swap in a Retention block instead.

Personal observation: It must be noted that because Retention works in a recurrent manner by definition, this is a big shift from ViT! While a ViT sees the whole image in one go, a ViR virtually reads the image from left to right and top to bottom. This is potentially a drawback of ViR over ViT, since it might not make sense to introduce causality in images.

Because Retention reads the image as a sequence, if we want our model to be a classifier, we need to use an output that comes after all tokens have been seen. To do so, we append a learnable [CLS] token at the end of the sequence and use the corresponding output to do classification. Notice that in a regular ViT, the [CLS] token is typically placed at the beginning of the sequence (although for a regular ViT this does not really make a difference).
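My re-implementation below returns the full token sequence rather than class scores, but a classification head along these lines could sit on top (a sketch only; the block stack and the number of classes are placeholders, not the paper's exact configuration):

import torch
import torch.nn as nn

embed_dim, n_classes = 192, 1000
cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))   # learnable [CLS] token
head = nn.Linear(embed_dim, n_classes)

tokens = torch.randn(8, 256, embed_dim)                  # patch embeddings (batch, seq_len, dim)
tokens = torch.cat([tokens, cls_token.expand(8, -1, -1)], dim=1)   # [CLS] goes at the END
# ... run the ViR blocks on `tokens` (causal, left to right) ...
logits = head(tokens[:, -1])                             # classify from the last position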

Implementation

Here is my full re-implementation of a ViR:

import torch
import torch.nn as nn


class ViRModes:
    PARALLEL = "parallel"
    RECURRENT = "recurrent"
    CHUNKWISE = "chunkwise"


class Retention(nn.Module):
    def __init__(
        self,
        embed_dim,
        max_len,
        alpha,
        mode=ViRModes.PARALLEL,
        chunk_size=20,
    ):
        super(Retention, self).__init__()
        self.dim = embed_dim
        self.max_len = max_len
        self.chunk_size = chunk_size
        self.alpha = alpha
        self.mode = mode

        # Useful buffers
        self.register_buffer("dim_sqrt", torch.tensor(embed_dim**0.5))

        indices = torch.arange(max_len).reshape(1, -1)
        self.register_buffer(
            "decay_mask",
            (alpha ** (indices.t() - indices)).tril(),
        )

        self.register_buffer("causal_mask", torch.ones(max_len, max_len).tril())
        self.qkv = nn.Linear(embed_dim, embed_dim * 3)

    def forward_parallel(self, x):
        # Getting queries, keys, values
        bs, sl, d = x.shape
        qkv = self.qkv(x)
        q, k, v = torch.chunk(qkv, 3, dim=-1)

        # Causal and decay masking
        M = (self.causal_mask[:sl, :sl] * self.decay_mask[:sl, :sl]).repeat(bs, 1, 1)

        # Retention
        out = (q @ k.transpose(-1, -2) / self.dim_sqrt * M) @ v

        return out

    def forward_recurrent(self, x, state):
        batch_size, length, dim = x.shape

        all_outputs = []
        state = torch.zeros(batch_size, dim, dim, device=x.device)
        for i in range(length):
            xi = x[:, i]
            q, k, v = self.qkv(xi).chunk(3, dim=-1)

            state = self.alpha * state + k.unsqueeze(-1) @ v.unsqueeze(1)
            out = q.unsqueeze(1) @ state / self.dim_sqrt
            all_outputs.append(out.squeeze(1))

        x = torch.stack(all_outputs, dim=1)
        return x

    def forward_chunkwise(self, x, chunk_size=None):
        # Getting queries, keys, values
        if chunk_size is None:
            chunk_size = self.chunk_size

        bs, sl, d = x.shape

        # Adding dummy tokens to make the sequence length divisible by chunk_size
        if sl % chunk_size != 0:
            x = torch.cat(
                [x, torch.zeros(bs, chunk_size - sl % chunk_size, d, device=x.device)],
                dim=1,
            )
        n_chunks = x.shape[1] // chunk_size

        # Running all chunks in parallel
        x = x.reshape(bs, n_chunks, chunk_size, d)
        q, k, v = self.qkv(x).chunk(3, dim=-1)

        M = (
            self.causal_mask[:chunk_size, :chunk_size]
            * self.decay_mask[:chunk_size, :chunk_size]
        ).repeat(bs, n_chunks, 1, 1)

        inner_chunk = (q @ k.transpose(-1, -2) / self.dim_sqrt * M) @ v

        # Updating outputs with chunk-wise recurrent
        retention_mask = (
            torch.tensor(
                [self.alpha ** (chunk_size - i - 1) for i in range(chunk_size)],
                device=x.device,
            )
            .repeat(bs, d, 1)
            .transpose(-1, -2)
        )

        cross_mask = (
            torch.tensor(
                [self.alpha ** (i + 1) for i in range(chunk_size)], device=x.device
            )
            .repeat(bs, n_chunks, d, 1)
            .transpose(-1, -2)
        )

        states = torch.zeros(bs, n_chunks, d, d, device=x.device)
        for i in range(1, n_chunks):
            chunk_state = k[:, i - 1].transpose(-1, -2) @ (v[:, i - 1] * retention_mask)
            states[:, i] = chunk_state + states[:, i - 1] * self.alpha**chunk_size

        cross_chunk = (q @ states) / self.dim_sqrt * cross_mask

        # Combining inner and cross chunk
        out = inner_chunk + cross_chunk

        # Removing dummy tokens
        out = out.flatten(1, 2)[:, :sl]
        return out

    def forward(self, x, state=None, mode=None, chunk_size=None):
        if mode is None:
            mode = self.mode

        if mode == ViRModes.PARALLEL:
            return self.forward_parallel(x)
        elif mode == ViRModes.RECURRENT:
            return self.forward_recurrent(x, state)
        elif mode == ViRModes.CHUNKWISE:
            return self.forward_chunkwise(x, chunk_size)
        else:
            raise ValueError(f"Unknown mode {mode}")


class MultiHeadRetention(nn.Module):
    def __init__(
        self,
        heads,
        embed_dim,
        max_len,
        alphas=None,
        mode=ViRModes.PARALLEL,
        chunk_size=20,
    ):
        super(MultiHeadRetention, self).__init__()
        self.n_heads = heads
        self.embed_dim = embed_dim
        self.max_len = max_len
        self.alphas = alphas
        self.head_dim = embed_dim // heads
        self.mode = mode
        self.chunk_size = chunk_size

        if alphas is None:
            alphas = [1 - 2 ** (-5 - i) for i in range(heads)]

        assert len(alphas) == heads, "Number of alphas must match number of heads"

        assert (
            embed_dim % heads == 0
        ), "Embedding dimension must be divisible by the number of heads"

        self.heads = nn.ModuleList(
            [
                Retention(embed_dim // heads, max_len, alpha, mode, chunk_size)
                for alpha in alphas
            ]
        )
        self.ln = nn.LayerNorm(embed_dim)
        self.gelu = nn.GELU()
        self.linear = nn.Linear(embed_dim, embed_dim)

    def forward(self, x, mode=None, chunk_size=None):
        if mode is None:
            mode = self.mode

        if chunk_size is None:
            chunk_size = self.chunk_size

        out = torch.cat(
            [
                head(
                    x[:, :, i * self.head_dim : (i + 1) * self.head_dim],
                    mode=mode,
                    chunk_size=chunk_size,
                )
                for i, head in enumerate(self.heads)
            ],
            dim=-1,
        )
        return self.linear(self.gelu(self.ln(out)))


class MLP(nn.Module):
    def __init__(self, embed_dim, hidden_dim=None):
        super(MLP, self).__init__()

        if hidden_dim is None:
            hidden_dim = 4 * embed_dim

        self.linear1 = nn.Linear(embed_dim, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, embed_dim)
        self.gelu = nn.GELU()

    def forward(self, x):
        return self.linear2(self.gelu(self.linear1(x)))


class ViRBlock(nn.Module):
    def __init__(
        self,
        heads,
        embed_dim,
        max_len,
        alphas=None,
        mode=ViRModes.PARALLEL,
        chunk_size=20,
        dropout=0.1,
    ):
        super(ViRBlock, self).__init__()
        self.mode = mode
        self.chunk_size = chunk_size

        self.ln1 = nn.LayerNorm(embed_dim)
        self.retention = MultiHeadRetention(
            heads, embed_dim, max_len, alphas, mode, chunk_size
        )
        self.ln2 = nn.LayerNorm(embed_dim)
        self.mlp = MLP(embed_dim)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x, mode=None, chunk_size=None):
        if mode is None:
            mode = self.mode

        if chunk_size is None:
            chunk_size = self.chunk_size

        x = (
            self.dropout1(self.retention(self.ln1(x), mode=mode, chunk_size=chunk_size))
            + x
        )
        x = self.dropout2(self.mlp(self.ln2(x))) + x
        return x


class ViR(nn.Module):
    def __init__(
        self,
        patch_size=14,
        depth=12,
        heads=12,
        embed_dim=768,
        max_len=256,
        alphas=None,
        mode=ViRModes.CHUNKWISE,
        chunk_size=256,
        dropout=0.1,
    ):
        super(ViR, self).__init__()

        # Local parameters
        self.out_dim = 10
        self.patch_size = patch_size
        self.depth = depth
        self.heads = heads
        self.embed_dim = embed_dim
        self.max_len = max_len
        self.alphas = alphas
        self.mode = mode
        self.chunk_size = chunk_size

        # Embeddings
        self.patch_embed = nn.Conv2d(
            3, embed_dim, (patch_size, patch_size), stride=(patch_size, patch_size)
        )
        self.pos_embed = nn.Parameter(torch.randn(1, max_len, embed_dim))

        # ViR blocks
        self.blocks = nn.ModuleList(
            [
                ViRBlock(heads, embed_dim, max_len, alphas, mode, chunk_size, dropout)
                for _ in range(depth)
            ]
        )

        # Head
        self.ln = nn.LayerNorm(embed_dim)

    def forward(self, x, mode=None, chunk_size=None, reshape=False):
        if mode is None:
            mode = self.mode

        if chunk_size is None:
            chunk_size = self.chunk_size

        # Patch embedding, positional embedding
        x = self.patch_embed(x).permute(0, 2, 3, 1).flatten(1, 2)
        bs, sl = x.shape[:2]
        x = x + self.pos_embed.repeat(bs, 1, 1)[:, :sl]

        # Blocks
        for block in self.blocks:
            x = block(x, mode=mode, chunk_size=chunk_size)

        # Layer Norm
        x = self.ln(x)

        # Reshape
        if reshape:
            ps = int(x.shape[1] ** 0.5)
            x = x.reshape(bs, ps, ps, self.embed_dim).permute(0, 3, 1, 2)

        return x


if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    x = torch.randn(16, 3, 224, 224).to(device)
    model = ViR(depth=12, heads=3, embed_dim=192).eval().to(device)

    with torch.no_grad():
        y1 = model(x, mode=ViRModes.PARALLEL)
        y2 = model(x, mode=ViRModes.RECURRENT)
        y3 = model(x, mode=ViRModes.CHUNKWISE, chunk_size=20)

        assert torch.allclose(
            y1, y2, atol=1e-5
        ), "Parallel and recurrent modes should give the same output"

        assert torch.allclose(
            y1, y3, atol=1e-5
        ), "Parallel and chunkwise modes should give the same output"

It feels like I should comment these 300+ lines, but really there is nothing that is not already covered by the formulas. The only thing worth mentioning is that the chunk size C might not evenly divide the sequence length N, so what one can do is add some dummy tokens at the end of the sequence such that the length becomes divisible by the chunk size (a sort of padding).
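For instance, in the test at the bottom of the code above, the 224x224 images yield N = 256 patch tokens and the chunk size is C = 20: four dummy tokens are appended to reach 260 = 13 x 20, the chunkwise computation runs on 13 chunks, and the dummy positions are dropped again from the output.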

Also, I found it key for performance to actually carry out the computation for all chunks in parallel, so it is not enough to re-use the forward_parallel function sequentially for each chunk.

Also notice that we use a different alpha for each head: heads with a higher alpha will look further back into the past, while heads with a lower alpha will mostly focus on the most recent tokens.
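To get a feel for the numbers, the default schedule in the code above gives, for 3 heads:

heads = 3
alphas = [1 - 2 ** (-5 - i) for i in range(heads)]
print(alphas)                                  # [0.96875, 0.984375, 0.9921875]
print([round(1 / (1 - a)) for a in alphas])    # [32, 64, 128]

where 1 / (1 - alpha) is a rough heuristic (my own, not from the paper) for how many past tokens a head effectively retains before the decay washes them out.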

Thank you for reading until here! If you found this helpful / interesting, or have suggestions on how to improve, please do not hesitate to contact me at me@brianpulfer.ch