Ir para o conteúdo

14. Vision Transformers

Vision Transformers (ViT)

Na aula de Transformers construímos uma arquitetura para sequências de tokens de texto. Em 2021, Dosovitskiy et al.1 fizeram uma pergunta provocadora: e se alimentássemos uma imagem nesse mesmo encoder? A resposta — "An Image is Worth 16×16 Words" — mostrou que, com dados suficientes, um Transformer quase sem modificações pode superar as Redes Neurais Convolucionais na classificação de imagens, sem nenhuma convolução.

O ViT é a ponte entre o mundo convolucional (aula 11) e o mundo dos Transformers (aula 13). Uma vez que você o entende, os codificadores de imagem dentro do CLIP, do Stable Diffusion e dos Diffusion Transformers deixam de ser caixas-pretas.


O trade-off do viés indutivo

Uma CNN traz dois fortes vieses indutivos embutidos na arquitetura:

  • Localidade — um kernel convolucional olha apenas para uma pequena vizinhança de pixels.
  • Equivariância à translação — o mesmo filtro desliza por toda a imagem, então uma feature é detectada independentemente de onde ela aparece.

Esses priors são exatamente o motivo pelo qual CNNs aprendem com tanta eficiência a partir de datasets pequenos: a arquitetura já "sabe" que pixels próximos importam e que objetos podem se deslocar.

Uma camada de self-attention pura não tem nenhum desses priors. Cada patch pode atender a todos os outros patches desde a camada 1, então o modelo precisa aprender as relações espaciais a partir dos dados. Isso é uma faca de dois gumes:

Fraqueza

  • Precisa de muitos dados para aprender o que a CNN ganha de graça
  • Só no ImageNet-1k, uma CNN de tamanho similar vence
Força

  • Campo receptivo global desde a primeira camada
  • Com dados suficientes, supera as CNNs — o viés era um teto, não só um piso

O pipeline do ViT, passo a passo

O truque é transformar uma imagem em uma sequência de tokens para que o encoder Transformer que já conhecemos possa consumi-la. Percorra o pipeline abaixo:


Patch embedding — a única peça realmente nova

Tudo depois do primeiro passo é o encoder que você já conhece. O único mecanismo novo é como a imagem vira tokens.

Uma imagem \(x \in \mathbb{R}^{H \times W \times C}\) é remodelada em uma sequência de \(N\) patches achatados e, em seguida, cada patch é projetado por uma única camada linear compartilhada:

\[ x \in \mathbb{R}^{H \times W \times C} \;\xrightarrow{\text{patchify}}\; x_p \in \mathbb{R}^{N \times (P^2 C)} \;\xrightarrow{\;E\;}\; z \in \mathbb{R}^{N \times d} \]

onde \(P\) é o tamanho do patch, \(N = HW/P^2\) é o número de patches e \(E \in \mathbb{R}^{(P^2 C) \times d}\) é a matriz de patch embedding. Para uma imagem \(224 \times 224\) com \(P = 16\), isso dá \(N = 196\) tokens.

Um token [CLS] aprendível \(z_{\text{cls}}\) é adicionado ao início e positional embeddings aprendíveis \(E_{\text{pos}}\) são somados (a atenção sozinha é invariante a permutação, então sem posições o modelo não conseguiria distinguir um patch no canto superior esquerdo de um no canto inferior direito):

\[ z_0 = [\, z_{\text{cls}};\; x_p^1 E;\; x_p^2 E;\; \dots;\; x_p^N E \,] + E_{\text{pos}} \]

Nota. Uma projeção linear sobre patches \(P \times P\) não-sobrepostos é matematicamente idêntica a uma Conv2d com kernel_size = stride = P. É exatamente assim que se implementa na prática — uma única convolução faz o patchify e a projeção em uma só operação.


O encoder e a cabeça de classificação

A sequência \(z_0\) passa por \(L\) blocos encoder Transformer idênticos4 — os mesmos blocos MHSA + Add&Norm + FFN da aula de Transformers (o ViT usa a variante pre-norm e GELU):

\[ z'_\ell = \text{MHSA}(\text{LN}(z_{\ell-1})) + z_{\ell-1}, \qquad z_\ell = \text{FFN}(\text{LN}(z'_\ell)) + z'_\ell \]

Para classificação, apenas o estado final do token [CLS] é lido e passado por uma pequena cabeça MLP:

\[ y = \text{cabeça MLP}\big(\text{LN}(z_L^{0})\big) \]

É o modelo inteiro. Sem convoluções, sem pirâmides de pooling — apenas patchify e, em seguida, um encoder Transformer padrão.


Fome de dados: por que o ViT precisa de pré-treinamento

Como lhe faltam os priors convolucionais, o ViT só brilha em escala. O artigo original tornou isso concreto:

  • Treinado apenas no ImageNet-1k (~1,3M imagens), o ViT fica abaixo de uma ResNet comparável.
  • Pré-treinado no ImageNet-21k (~14M) ele se iguala.
  • Pré-treinado no JFT-300M (~300M) ele supera as melhores CNNs e transfere muito bem para tarefas posteriores.

Essa é exatamente a receita de pré-treinar e depois ajustar (finetune) da aula de Transfer Learning: pré-treinar o encoder em um dataset enorme e depois ajustar a cabeça MLP barata (ou o modelo inteiro com taxa de aprendizado baixa) na sua tarefa. É também por isso que o CLIP pôde treinar um encoder de imagem ViT em 400M pares imagem-texto — nessa escala, o viés indutivo fraco vira vantagem.

Duas linhas de pesquisa amenizam essa fome de dados. Receitas cuidadosas de augmentation e regularização permitem treinar ViTs de forma competitiva só com o ImageNet-1k, sem um dataset privado gigante5. E o pré-treinamento auto-supervisionado aprende features fortes do ViT sem nenhum rótulo — reconstruindo patches mascarados (MAE6) ou por auto-destilação (DINO7, cujos mapas de atenção segmentam objetos de graça).


CNN vs. ViT em um relance

CNN Vision Transformer
Operação central Convolução (local) Self-attention (global)
Viés indutivo Forte (localidade, equivariância) Fraco — aprendido dos dados
Campo receptivo Cresce com a profundidade Global desde a camada 1
Eficiência em dados Forte em datasets pequenos Precisa de pré-treinamento em larga escala
Custo computacional \(O(N)\) em pixels \(O(N^2)\) em patches
Escala com dados Satura mais cedo Continua melhorando

Híbridos e sucessores. Várias variantes reintroduzem algum viés espacial para obter o melhor dos dois mundos: o DeiT2 (treinamento eficiente em dados com destilação, sem precisar do JFT) e o Swin Transformer3 (atenção em janelas com estrutura hierárquica, tipo pirâmide, que devolve a localidade e torna o ViT prático para detecção e segmentação). Por outro lado, o ConvNeXt8 modernizou uma CNN pura para igualar os ViTs de igual para igual — evidência de que a receita de treinamento e a escala importam tanto quanto a escolha entre convolução e atenção em si.


Referência de implementação

import torch
import torch.nn as nn

class PatchEmbed(nn.Module):
    """Imagem -> sequência de tokens (Conv2d faz patchify + projeção)."""
    def __init__(self, img_size=224, patch=16, in_ch=3, dim=768):
        super().__init__()
        self.n_patches = (img_size // patch) ** 2
        self.proj = nn.Conv2d(in_ch, dim, kernel_size=patch, stride=patch)

    def forward(self, x):                  # x: (B, C, H, W)
        x = self.proj(x)                   # (B, dim, H/p, W/p)
        return x.flatten(2).transpose(1, 2)  # (B, N, dim)
class ViT(nn.Module):
    def __init__(self, dim=768, depth=12, heads=12, n_classes=1000):
        super().__init__()
        self.patch_embed = PatchEmbed(dim=dim)
        n = self.patch_embed.n_patches
        self.cls = nn.Parameter(torch.zeros(1, 1, dim))
        self.pos = nn.Parameter(torch.zeros(1, n + 1, dim))
        layer = nn.TransformerEncoderLayer(dim, heads, dim * 4,
                                           activation='gelu', norm_first=True,
                                           batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, depth)
        self.head = nn.Linear(dim, n_classes)

    def forward(self, x):
        B = x.size(0)
        x = self.patch_embed(x)                       # (B, N, dim)
        cls = self.cls.expand(B, -1, -1)
        x = torch.cat([cls, x], dim=1) + self.pos     # prepor CLS + somar posições
        x = self.encoder(x)
        return self.head(x[:, 0])                      # classifica a partir do token CLS
import timm
# carrega pesos pré-treinados ImageNet-21k -> 1k e ajusta a cabeça
model = timm.create_model('vit_base_patch16_224', pretrained=True,
                          num_classes=10)