22. Diffusion Transformers
Diffusion Transformers (DiT)
Em 2023, Peebles & Xie1 demonstraram algo simples e impactante: a U-Net não é necessária para modelos de difusão. Substituindo-a por blocos Transformer puros, o modelo não só manteve a qualidade — ele passou a escalar previsivelmente com mais parâmetros e dados, exatamente como modelos de linguagem.
Hoje, toda geração de imagem e vídeo de ponta usa DiT:
| Modelo | Arquitetura | Objetivo |
|---|---|---|
| FLUX.1 | DiT (duplo-stream) | Flow Matching |
| Stable Diffusion 3 | MMDiT | Flow Matching |
| Sora (OpenAI) | Spacetime DiT | Difusão |
| Movie Gen (Meta) | DiT | Flow Matching |
| CogVideoX | DiT 3D | Flow Matching |
De U-Net para Transformer
A U-Net clássica usa convoluções com skip connections hierárquicas — boa para capturar detalhes locais, mas difícil de escalar. O DiT substitui tudo isso por blocos de atenção global.
Passo 1 — Patchify: Imagens como Sequências de Tokens
Assim como o ViT divide imagens em patches, o DiT opera no espaço latente (após o encoder VAE). Um latente de forma \(H \times W \times C\) é dividido em patches de tamanho \(p \times p\):
Cada patch é linearizado e projetado para a dimensão \(d_{\text{model}}\) — tornando-se um "token visual".
Passo 2 — Bloco DiT com AdaLN
O DiT usa Adaptive Layer Normalization (AdaLN) para injetar a informação de timestep e classe/texto diretamente nos parâmetros de normalização:
onde \(c = \text{MLP}(\text{emb}(t) + \text{emb}(\text{classe}))\) é o vetor de condicionamento.
Os parâmetros \(\gamma\) e \(\beta\) são preditos — não aprendidos estaticamente — tornando a normalização sensível ao passo de difusão e ao prompt.
Passo 3 — MMDiT: Atenção Bidirecional Multi-Modal
MMDiT (SD3, FLUX) vai além do condicionamento por cross-attention. Texto e imagem participam da mesma operação de atenção:
Os tokens de imagem veem os tokens de texto e vice-versa — condicionamento muito mais rico do que injetar texto apenas via cross-attention.
O FLUX usa um design de "duplo stream": pesos separados para imagem e texto nos blocos Q/K/V/FFN, mas atenção compartilhada:
Stream img: x_img → W_q^img·x ─┐
├─→ concat → Atenção(Q,K,V) → separar
Stream txt: x_txt → W_q^txt·x ─┘
Visualização: Processo Completo de Geração
Por Que DiT Escala Melhor?
U-Nets têm operações de pooling e upsample que destroem informação global. Skip connections ajudam, mas ainda há um gargalo hierárquico. Em DiTs, cada token vê todos os outros tokens desde o primeiro bloco — atenção \(O(N^2)\), mas com acesso global completo.
Isso significa que ao dobrar os parâmetros (mais blocos, mais dimensões), o DiT aproveita toda a capacidade extra, enquanto a U-Net tem retornos decrescentes mais rápidos.
| Modelo | Parâmetros | Tokens | d_model | Blocos |
|---|---|---|---|---|
| DiT-XL/2 (original) | 675M | 256 | 1152 | 28 |
| SD3 MMDiT | 2B | 1024 | 1536 | 38 |
| FLUX.1-dev | 12B | 4096 | 3072 | 57 |
Implementação Simplificada
import torch
import torch.nn as nn
class AdaLN(nn.Module):
def __init__(self, d_model, d_cond):
super().__init__()
self.norm = nn.LayerNorm(d_model, elementwise_affine=False)
self.proj = nn.Linear(d_cond, 2 * d_model) # → γ, β
def forward(self, x, c):
gamma, beta = self.proj(c).chunk(2, dim=-1)
return (1 + gamma.unsqueeze(1)) * self.norm(x) + beta.unsqueeze(1)
class DiTBlock(nn.Module):
def __init__(self, d_model, n_heads, d_ff, d_cond):
super().__init__()
self.adaln1 = AdaLN(d_model, d_cond)
self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
self.adaln2 = AdaLN(d_model, d_cond)
self.ff = nn.Sequential(nn.Linear(d_model, d_ff), nn.GELU(), nn.Linear(d_ff, d_model))
def forward(self, x, c):
h = self.adaln1(x, c)
x = x + self.attn(h, h, h)[0] # self-attention
x = x + self.ff(self.adaln2(x, c)) # FFN
return x
class DiT(nn.Module):
def __init__(self, in_channels, patch_size, d_model, n_heads, d_ff, n_layers, d_cond):
super().__init__()
self.patch_size = patch_size
p = patch_size
self.patchify = nn.Conv2d(in_channels, d_model, p, stride=p)
self.blocks = nn.ModuleList([DiTBlock(d_model, n_heads, d_ff, d_cond) for _ in range(n_layers)])
self.norm_out = nn.LayerNorm(d_model)
self.depatchify = nn.Linear(d_model, p*p*in_channels)
def forward(self, x, t_emb, cond):
# x: (B, C, H, W) latente ruidoso
B, C, H, W = x.shape
tokens = self.patchify(x) # (B, d, H/p, W/p)
tokens = tokens.flatten(2).transpose(1, 2) # (B, N, d)
c = t_emb + cond # combinar condicionamento
for block in self.blocks:
tokens = block(tokens, c)
tokens = self.norm_out(tokens)
patches = self.depatchify(tokens) # (B, N, p*p*C)
# reformatar para (B, C, H, W)
p = self.patch_size
patches = patches.view(B, H//p, W//p, p, p, C).permute(0,5,1,3,2,4).reshape(B,C,H,W)
return patches # campo de velocidade predito v_θ(x_t, t)
-
Peebles, W., & Xie, S. (2023). Scalable Diffusion Models with Transformers. ICCV 2023. ↩
-
Esser, P. et al. (2024). Scaling Rectified Flow Transformers for High-Resolution Image Synthesis (SD3). ↩
-
Dosovitskiy, A. et al. (2021). An Image is Worth 16×16 Words (ViT). ↩