
14. Vision Transformers

Vision Transformers (ViT)

In the Transformers class we built an architecture for sequences of text tokens. In 2021, Dosovitskiy et al.[1] asked a provocative question: what if we feed an image to that exact same encoder? Their answer, "An Image is Worth 16×16 Words", showed that, with enough data, a near-vanilla Transformer can beat Convolutional Neural Networks at image classification without relying on convolutions.

The ViT is the bridge between the convolutional world (class 11) and the Transformer world (class 13). Once you understand it, the Transformer-based vision components inside CLIP, Stable Diffusion and Diffusion Transformers stop being black boxes.


The inductive bias trade-off

A CNN comes with two strong inductive biases baked into its architecture:

  • Locality: a convolutional kernel only looks at a small neighborhood of pixels.
  • Translation equivariance: the same filter slides across the whole image, so a feature is detected regardless of where it appears.

These priors are exactly why CNNs learn so efficiently from small datasets: the architecture already "knows" that nearby pixels matter and that objects can move around.
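To make translation equivariance concrete, here is a minimal NumPy sketch, assuming a single-channel 10x10 image and a hand-picked 2x2 kernel (both purely illustrative). The helper `conv2d_valid` is a hypothetical loop-based stand-in for a convolution layer; like deep-learning libraries, it computes a cross-correlation. Moving the object moves the peak of the response map by exactly the same offset.

```python
import numpy as np

def conv2d_valid(img, kernel):
    """2D cross-correlation without padding (what 'conv' layers compute)."""
    kh, kw = kernel.shape
    H, W = img.shape
    out = np.zeros((H - kh + 1, W - kw + 1))
    for i in range(out.shape[0]):
        for j in range(out.shape[1]):
            out[i, j] = np.sum(img[i:i + kh, j:j + kw] * kernel)
    return out

def peak(response):
    """Location of the strongest response as plain ints."""
    return tuple(int(v) for v in np.unravel_index(response.argmax(), response.shape))

# A small "corner" pattern and a kernel matched to it.
pattern = np.array([[1.0, 1.0],
                    [1.0, 0.0]])
kernel = pattern.copy()

img = np.zeros((10, 10))
img[2:4, 2:4] = pattern          # object near the top-left

shifted = np.zeros((10, 10))
shifted[6:8, 5:7] = pattern      # same object, moved by (4, 3)

r1 = conv2d_valid(img, kernel)
r2 = conv2d_valid(shifted, kernel)

# The same filter fires on the object wherever it is, and the response moves with it.
print(peak(r1))  # (2, 2)
print(peak(r2))  # (6, 5)
```

Because the filter weights are shared across positions, the whole response map is simply shifted (`r2[4:, 3:]` equals `r1[:5, :6]`); the network does not have to learn the detector separately for every location.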

A pure self-attention layer has none of these priors. Every patch can attend to every other patch from layer 1, so the model must learn spatial relationships from data. This is a double-edged sword:

Weakness

  • Needs lots of data to learn what a CNN gets for free
  • On ImageNet-1k alone, a CNN of similar size wins

Strength

  • Global receptive field from the very first layer
  • Given enough data, it surpasses CNNs: the bias was a ceiling, not just a floor
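The missing spatial prior can be checked directly. The sketch below assumes a single attention head with random weights and no positional information (all illustrative choices). It shows that self-attention is permutation-equivariant: shuffling the patch tokens only shuffles the outputs in the same way, so the layer by itself cannot know which patch sat next to which.

```python
import numpy as np

rng = np.random.default_rng(0)

def softmax(a, axis=-1):
    a = a - a.max(axis=axis, keepdims=True)
    e = np.exp(a)
    return e / e.sum(axis=axis, keepdims=True)

def self_attention(x, Wq, Wk, Wv):
    """Single-head scaled dot-product self-attention over tokens x: (N, d)."""
    q, k, v = x @ Wq, x @ Wk, x @ Wv
    scores = q @ k.T / np.sqrt(k.shape[-1])   # (N, N): every token scores every other token
    return softmax(scores) @ v

N, d = 9, 8                                  # e.g. a 3x3 grid of patch tokens
x = rng.normal(size=(N, d))
Wq, Wk, Wv = (rng.normal(size=(d, d)) for _ in range(3))

perm = rng.permutation(N)                    # shuffle the patches
out = self_attention(x, Wq, Wk, Wv)
out_shuffled = self_attention(x[perm], Wq, Wk, Wv)

# Shuffled input -> identically shuffled output: no notion of spatial layout.
print(np.allclose(out_shuffled, out[perm]))  # True
```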

The ViT pipeline, step by step

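The trick is to turn an image into a sequence of tokens so that the Transformer encoder we already know can consume it. The pipeline has five steps: (1) split the image into fixed-size patches and linearly project each patch to a token; (2) prepend a learnable [CLS] token; (3) add learnable positional embeddings; (4) pass the sequence through a stack of standard Transformer encoder blocks; (5) read out the final [CLS] state and classify it with a small MLP head.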


Patch embedding: the only genuinely new piece

Everything after the first step is the encoder you already know. The one new mechanism is how the image becomes tokens.

An image \(x \in \mathbb{R}^{H \times W \times C}\) is reshaped into a sequence of \(N\) flattened patches, then each patch is projected by a single shared linear layer:

\[ x \in \mathbb{R}^{H \times W \times C} \;\xrightarrow{\text{patchify}}\; x_p \in \mathbb{R}^{N \times (P^2 C)} \;\xrightarrow{\;E\;}\; z \in \mathbb{R}^{N \times d} \]

where \(P\) is the patch size, \(N = HW/P^2\) is the number of patches, and \(E \in \mathbb{R}^{(P^2 C) \times d}\) is the patch embedding matrix. For a \(224 \times 224\) image with \(P = 16\), that is \(N = 196\) tokens.
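A minimal NumPy sketch of the patchify-then-project step, assuming a channels-last \((H, W, C)\) array as in the formula above (PyTorch tensors are usually \((C, H, W)\)) and a randomly initialized stand-in for \(E\). The helper name `patchify` is hypothetical. For \(P = 16\) and \(C = 3\) each flattened patch has \(16 \cdot 16 \cdot 3 = 768\) values, which happens to equal the width \(d = 768\) of ViT-Base.

```python
import numpy as np

def patchify(img, P):
    """(H, W, C) image -> (N, P*P*C) matrix of flattened, non-overlapping patches."""
    H, W, C = img.shape
    assert H % P == 0 and W % P == 0, "image size must be divisible by the patch size"
    x = img.reshape(H // P, P, W // P, P, C)   # split rows and columns into a grid of patches
    x = x.transpose(0, 2, 1, 3, 4)             # (H/P, W/P, P, P, C): grid first, pixels last
    return x.reshape(-1, P * P * C)            # patches in row-major order over the grid

rng = np.random.default_rng(0)
img = rng.normal(size=(224, 224, 3))
x_p = patchify(img, P=16)
print(x_p.shape)                               # (196, 768)

d = 768                                        # embedding width
E = rng.normal(scale=0.02, size=(16 * 16 * 3, d))  # hypothetical random init of E
z = x_p @ E                                    # one shared projection for all patches
print(z.shape)                                 # (196, 768)
```

Row \(i \cdot 14 + j\) of `x_p` is the patch at grid position \((i, j)\) of the \(14 \times 14\) grid, so the token order is left-to-right, top-to-bottom.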

A learnable [CLS] token \(z_{\text{cls}}\) is prepended, and learnable positional embeddings \(E_{\text{pos}} \in \mathbb{R}^{(N+1) \times d}\) are added (self-attention is permutation-equivariant: shuffling the input tokens merely shuffles the outputs, so without positions the model could not tell a patch in the top-left from one in the bottom-right):

\[ z_0 = [\, z_{\text{cls}};\; x_p^1 E;\; x_p^2 E;\; \dots;\; x_p^N E \,] + E_{\text{pos}} \]
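Building \(z_0\) is just a concatenation followed by an addition. A sketch with random stand-ins for all learned parameters, assuming \(N = 196\) patches and an arbitrary small width \(d = 64\) for illustration:

```python
import numpy as np

rng = np.random.default_rng(0)
N, d = 196, 64
patch_tokens = rng.normal(size=(N, d))           # stands in for x_p^i E, i = 1..N
z_cls = rng.normal(scale=0.02, size=(1, d))      # learnable [CLS] token (random stand-in)
E_pos = rng.normal(scale=0.02, size=(N + 1, d))  # one positional vector per slot, incl. [CLS]

# z_0 = [z_cls; x_p^1 E; ...; x_p^N E] + E_pos
z0 = np.concatenate([z_cls, patch_tokens], axis=0) + E_pos
print(z0.shape)                                  # (197, 64)
```

Each slot gets its own positional vector, so the same patch content produces a different token depending on where it sits in the grid.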

Note. A linear projection over non-overlapping \(P \times P\) patches is mathematically identical to a Conv2d with kernel_size = stride = P. That is exactly how it is implemented in practice: one convolution does the patchifying and the projection in a single op.
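The equivalence can be checked numerically. This NumPy sketch assumes a small 32x32 RGB input, \(P = 8\) and \(d = 16\) (arbitrary sizes). It writes the strided convolution out as explicit loops, using the same weight layout as `nn.Conv2d` (out_channels, in_channels, kH, kW), and compares it with patchify followed by one shared matrix multiply by the flattened kernel.

```python
import numpy as np

rng = np.random.default_rng(0)
C, H, W, P, d = 3, 32, 32, 8, 16
img = rng.normal(size=(C, H, W))            # channels-first, as PyTorch stores images
weight = rng.normal(size=(d, C, P, P))      # same layout as nn.Conv2d.weight: (out, in, kH, kW)
bias = rng.normal(size=d)

# 1) Convolution with kernel_size = stride = P, written out as explicit loops.
conv_out = np.zeros((d, H // P, W // P))
for i in range(H // P):
    for j in range(W // P):
        window = img[:, i * P:(i + 1) * P, j * P:(j + 1) * P]   # (C, P, P); windows never overlap
        conv_out[:, i, j] = np.tensordot(weight, window, axes=([1, 2, 3], [0, 1, 2])) + bias

# 2) Patchify, then a single shared linear layer whose matrix is the flattened kernel.
patches = (img.reshape(C, H // P, P, W // P, P)
              .transpose(1, 3, 0, 2, 4)      # (H/P, W/P, C, P, P)
              .reshape(-1, C * P * P))       # (N, C*P*P)
E = weight.reshape(d, -1).T                  # (C*P*P, d)
linear_out = patches @ E + bias              # (N, d)

# Flatten the conv output into tokens, like x.flatten(2).transpose(1, 2) in PyTorch.
tokens_from_conv = conv_out.reshape(d, -1).T # (N, d)
print(np.allclose(tokens_from_conv, linear_out))  # True
```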


The encoder and the classification head

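The sequence \(z_0\) goes through \(L\) encoder blocks[4] of identical architecture (each with its own weights), built from the same pieces as in the Transformers class: MHSA, residual connections, LayerNorm and an FFN. ViT uses the pre-norm variant, applying LayerNorm before each sublayer instead of after the residual add, and a GELU activation in the FFN: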

\[ z'_\ell = \text{MHSA}(\text{LN}(z_{\ell-1})) + z_{\ell-1}, \qquad z_\ell = \text{FFN}(\text{LN}(z'_\ell)) + z'_\ell \]
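The two update equations translate almost line by line into code. A minimal NumPy sketch of one pre-norm block, assuming random weights, LayerNorm without its learnable gain and bias, the tanh approximation of GELU, and no dropout (simplifications for illustration; parameter names such as `Wqkv` are hypothetical):

```python
import numpy as np

def layer_norm(x, eps=1e-6):
    # Normalize each token over its features (learnable gain/bias omitted).
    mu = x.mean(-1, keepdims=True)
    return (x - mu) / np.sqrt(x.var(-1, keepdims=True) + eps)

def gelu(x):
    # tanh approximation of GELU
    return 0.5 * x * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * x ** 3)))

def softmax(a):
    e = np.exp(a - a.max(-1, keepdims=True))
    return e / e.sum(-1, keepdims=True)

def mhsa(x, Wqkv, Wo, heads):
    """Multi-head self-attention over tokens x: (n, d)."""
    n, d = x.shape
    dh = d // heads
    q, k, v = np.split(x @ Wqkv, 3, axis=-1)                  # each (n, d)
    q, k, v = [t.reshape(n, heads, dh).transpose(1, 0, 2) for t in (q, k, v)]  # (heads, n, dh)
    attn = softmax(q @ k.transpose(0, 2, 1) / np.sqrt(dh))    # (heads, n, n)
    out = (attn @ v).transpose(1, 0, 2).reshape(n, d)         # concatenate the heads
    return out @ Wo

def encoder_block(z, p, heads):
    z_prime = mhsa(layer_norm(z), p["Wqkv"], p["Wo"], heads) + z  # z'_l = MHSA(LN(z_{l-1})) + z_{l-1}
    h = gelu(layer_norm(z_prime) @ p["W1"] + p["b1"])             # FFN hidden layer, width 4d
    return h @ p["W2"] + p["b2"] + z_prime                         # z_l = FFN(LN(z'_l)) + z'_l

rng = np.random.default_rng(0)
d, heads, n = 64, 4, 197
params = {
    "Wqkv": rng.normal(scale=0.02, size=(d, 3 * d)),
    "Wo": rng.normal(scale=0.02, size=(d, d)),
    "W1": rng.normal(scale=0.02, size=(d, 4 * d)), "b1": np.zeros(4 * d),
    "W2": rng.normal(scale=0.02, size=(4 * d, d)), "b2": np.zeros(d),
}
z0 = rng.normal(size=(n, d))
z1 = encoder_block(z0, params, heads)
print(z1.shape)                                                # (197, 64)

# With all weights at zero, both sublayers output 0 and the block is the identity.
zero = {name: np.zeros_like(w) for name, w in params.items()}
print(np.allclose(encoder_block(z0, zero, heads), z0))         # True
```

The zero-weight check shows the clean residual path of the pre-norm layout: \(z_\ell = z_{\ell-1}\) exactly, whereas in a post-norm block the final LayerNorm would still rescale the tokens.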

For classification, only the final state of the [CLS] token is read out and passed through a small MLP head:

\[ y = \text{MLP head}\big(\text{LN}(z_L^{0})\big) \]
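Reading out the prediction touches only row 0 of the final sequence. A sketch with random stand-ins, assuming a single linear layer as the head (the ViT paper uses an MLP with one hidden layer during pre-training and a single linear layer when fine-tuning):

```python
import numpy as np

def layer_norm(x, eps=1e-6):
    return (x - x.mean(-1, keepdims=True)) / np.sqrt(x.var(-1, keepdims=True) + eps)

rng = np.random.default_rng(0)
n_tokens, d, n_classes = 197, 64, 10
z_L = rng.normal(size=(n_tokens, d))                  # stand-in for the encoder output; row 0 is [CLS]
W_head = rng.normal(scale=0.02, size=(d, n_classes))  # hypothetical head weights
b_head = np.zeros(n_classes)

def classify(z):
    # y = head(LN(z_L^0)): only the [CLS] row is read
    return layer_norm(z[0]) @ W_head + b_head

logits = classify(z_L)
probs = np.exp(logits - logits.max())
probs /= probs.sum()                                  # softmax over the classes

# Overwriting every patch token in z_L leaves the prediction untouched.
z_perturbed = z_L.copy()
z_perturbed[1:] = rng.normal(size=(n_tokens - 1, d))
logits_perturbed = classify(z_perturbed)
print(np.allclose(logits, logits_perturbed))          # True
```

Patch information reaches the prediction only through attention in the earlier layers, where the [CLS] token gathers it.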

That is the whole model. No convolutional feature hierarchy, no pooling pyramid: just patchify, then a standard Transformer encoder.


Data hunger: why ViT needs pretraining

Because it lacks the convolutional priors, ViT only shines at scale. The original paper made this concrete:

  • Trained on ImageNet-1k only (~1.3M images), ViT underperforms a comparable ResNet.
  • Pre-trained on ImageNet-21k (~14M images), it draws level.
  • Pre-trained on JFT-300M (~300M images), it surpasses the best CNNs and transfers well to downstream tasks.

This is precisely the pretrain-then-finetune recipe of the Transfer Learning class: pretrain the encoder on a huge dataset, then fine-tune the cheap MLP head (or the whole model with a low learning rate) on your task. It is also why CLIP could train a ViT image encoder on 400M image-text pairs: at that scale, the weak inductive bias becomes an advantage.
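In its cheapest form, the fine-tuning step is a linear probe: the pretrained encoder stays frozen and only a new classification head is trained on its [CLS] features. A minimal NumPy sketch, assuming synthetic, well-separated features in place of real ViT outputs (the cluster setup, learning rate and step count are all illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
k, n_per_class, d = 3, 100, 32
# Hypothetical frozen [CLS] features: one Gaussian cluster per class.
centers = rng.normal(scale=3.0, size=(k, d))
feats = np.concatenate([c + rng.normal(size=(n_per_class, d)) for c in centers])
feats = (feats - feats.mean(0)) / feats.std(0)       # standardize features
labels = np.repeat(np.arange(k), n_per_class)

# The encoder is frozen: the only trainable parameters are the new head (W, b).
W = np.zeros((d, k))
b = np.zeros(k)
lr = 0.1
for step in range(200):
    logits = feats @ W + b
    p = np.exp(logits - logits.max(1, keepdims=True))
    p /= p.sum(1, keepdims=True)                      # softmax probabilities
    p[np.arange(len(labels)), labels] -= 1.0          # gradient of cross-entropy w.r.t. logits
    grad = p / len(labels)
    W -= lr * feats.T @ grad
    b -= lr * grad.sum(0)

acc = np.mean((feats @ W + b).argmax(1) == labels)
print(f"train accuracy: {acc:.2f}")
```

Full fine-tuning would also update the encoder weights, typically with a much smaller learning rate than the one used for the freshly initialized head.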

Two lines of work soften this data hunger. Careful augmentation and regularization recipes let ViTs train competitively on ImageNet-1k alone, without a giant private dataset[5]. And self-supervised pretraining learns strong ViT features without any labels, either by reconstructing masked patches (MAE[6]) or by self-distillation (DINO[7], whose [CLS] attention maps segment foreground objects without any segmentation labels).
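MAE's key step is simple: hide most of the patches and let the encoder see only the rest. A sketch of random masking in NumPy, assuming a 75% mask ratio and random patch embeddings as input. The helper `random_masking` is hypothetical, and the decoder that reconstructs the pixels of the masked patches is not shown.

```python
import numpy as np

def random_masking(tokens, mask_ratio, rng):
    """Keep a random subset of patch tokens: shuffle indices by random noise, keep the first ones."""
    n, d = tokens.shape
    n_keep = int(n * (1 - mask_ratio))
    noise = rng.random(n)
    shuffle = np.argsort(noise)          # random permutation of patch indices
    keep_idx = shuffle[:n_keep]
    mask = np.ones(n, dtype=bool)        # True = masked (to be reconstructed)
    mask[keep_idx] = False
    return tokens[keep_idx], keep_idx, mask

rng = np.random.default_rng(0)
tokens = rng.normal(size=(196, 768))     # patch embeddings for one 224x224 image, P = 16
visible, keep_idx, mask = random_masking(tokens, mask_ratio=0.75, rng=rng)
print(visible.shape, int(mask.sum()))    # (49, 768) 147
```

Because the encoder processes only the visible quarter of the tokens, each pretraining step is much cheaper than running the full sequence.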


CNN vs. ViT at a glance

|  | CNN | Vision Transformer |
|---|---|---|
| Core operation | Convolution (local) | Self-attention (global) |
| Inductive bias | Strong (locality, equivariance) | Weak; learned from data |
| Receptive field | Grows with depth | Global from layer 1 |
| Data efficiency | Strong on small datasets | Needs large-scale pretraining |
| Compute | Linear in the number of pixels | \(O(N^2)\) in the number of patches \(N\) |
| Scales with data | Saturates earlier | Keeps improving |
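The quadratic term in the Compute row is easy to quantify: for a square image the number of tokens is \(N = (H/P)^2\), and each attention head builds an \(N \times N\) score matrix. A small sketch (the helper `vit_tokens` is hypothetical, and the extra [CLS] token is ignored):

```python
def vit_tokens(img_size, patch):
    """Number of patch tokens N = (img_size / patch)^2 for a square image."""
    return (img_size // patch) ** 2

for img_size, patch in [(224, 16), (224, 8), (384, 16)]:
    n = vit_tokens(img_size, patch)
    print(f"{img_size}px, P={patch}: N={n}, attention scores per head per layer = {n * n:,}")

# 224px, P=16: N=196, attention scores per head per layer = 38,416
# 224px, P=8: N=784, attention scores per head per layer = 614,656
# 384px, P=16: N=576, attention scores per head per layer = 331,776
```

Halving the patch size quadruples \(N\) and multiplies the attention cost by 16, which is why high-resolution inputs push ViTs toward windowed or hierarchical designs.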

Hybrids and successors. Several follow-ups reduce the data hunger or reintroduce spatial bias to get the best of both worlds: DeiT[2] (data-efficient training with distillation from a CNN teacher, no JFT needed) and Swin Transformer[3] (attention restricted to local windows, with a hierarchical, pyramid-like structure that brings back locality and makes ViT practical for detection and segmentation). Conversely, ConvNeXt[8] modernized a pure CNN to match ViTs head-to-head, evidence that the training recipe and scale matter as much as the convolution-vs-attention choice itself.
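Swin's locality comes from computing attention only inside non-overlapping \(M \times M\) windows of tokens. A NumPy sketch of the window partition, assuming the 56x56 token grid with 96 channels and \(M = 7\) used in the first stage of Swin-T; the shifted windows that let information cross window boundaries, and the patch merging between stages, are not shown. The helper `window_partition` is hypothetical.

```python
import numpy as np

def window_partition(x, M):
    """(H, W, C) grid of tokens -> (num_windows, M*M, C); attention then runs inside each window."""
    H, W, C = x.shape
    x = x.reshape(H // M, M, W // M, M, C).transpose(0, 2, 1, 3, 4)  # (H/M, W/M, M, M, C)
    return x.reshape(-1, M * M, C)

tokens = np.arange(56 * 56 * 96, dtype=float).reshape(56, 56, 96)
windows = window_partition(tokens, M=7)
print(windows.shape)                      # (64, 49, 96)

global_cost = (56 * 56) ** 2              # pairwise scores for one global attention map
windowed_cost = windows.shape[0] * 49 ** 2
print(global_cost // windowed_cost)       # 64
```

With a fixed window size, the attention cost grows linearly with the number of windows (and hence with image area) instead of quadratically with the number of tokens.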


Implementation reference

```python
import torch
import torch.nn as nn

class PatchEmbed(nn.Module):
    """Image -> sequence of patch tokens (Conv2d does patchify + projection)."""
    def __init__(self, img_size=224, patch=16, in_ch=3, dim=768):
        super().__init__()
        self.n_patches = (img_size // patch) ** 2
        # kernel_size = stride = patch: one non-overlapping window per patch
        self.proj = nn.Conv2d(in_ch, dim, kernel_size=patch, stride=patch)

    def forward(self, x):                     # x: (B, C, H, W)
        x = self.proj(x)                      # (B, dim, H/p, W/p)
        return x.flatten(2).transpose(1, 2)   # (B, N, dim)


class ViT(nn.Module):
    def __init__(self, dim=768, depth=12, heads=12, n_classes=1000):
        super().__init__()
        self.patch_embed = PatchEmbed(dim=dim)
        n = self.patch_embed.n_patches
        self.cls = nn.Parameter(torch.zeros(1, 1, dim))
        self.pos = nn.Parameter(torch.zeros(1, n + 1, dim))
        nn.init.trunc_normal_(self.pos, std=0.02)      # small random init for the positions
        layer = nn.TransformerEncoderLayer(dim, heads, dim * 4,
                                           activation='gelu', norm_first=True,
                                           batch_first=True)
        # Pre-norm blocks leave the last output un-normalized, so norm= adds the
        # final LN of y = head(LN(z_L^0)). The nested-tensor fast path does not
        # support norm_first, so it is disabled explicitly to avoid a warning.
        self.encoder = nn.TransformerEncoder(layer, depth, norm=nn.LayerNorm(dim),
                                             enable_nested_tensor=False)
        self.head = nn.Linear(dim, n_classes)

    def forward(self, x):
        B = x.size(0)
        x = self.patch_embed(x)                       # (B, N, dim)
        cls = self.cls.expand(B, -1, -1)              # (B, 1, dim)
        x = torch.cat([cls, x], dim=1) + self.pos     # prepend CLS + add positions
        x = self.encoder(x)                           # (B, N+1, dim), final LN applied
        return self.head(x[:, 0])                     # classify from CLS token
```
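```python
import timm

# Load ImageNet-21k -> ImageNet-1k pretrained weights. num_classes=10 swaps in a
# freshly initialized 10-class head, which is then fine-tuned on your task.
model = timm.create_model('vit_base_patch16_224', pretrained=True,
                          num_classes=10)
```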