Ir para o conteúdo

22. Diffusion Transformers

Diffusion Transformers (DiT)

Em 2023, Peebles & Xie1 demonstraram algo simples e impactante: a U-Net não é necessária para modelos de difusão. Substituindo-a por blocos Transformer puros, o modelo não só manteve a qualidade — ele passou a escalar previsivelmente com mais parâmetros e dados, exatamente como modelos de linguagem.

Hoje, toda geração de imagem e vídeo de ponta usa DiT:

Modelo Arquitetura Objetivo
FLUX.1 DiT (duplo-stream) Flow Matching
Stable Diffusion 3 MMDiT Flow Matching
Sora (OpenAI) Spacetime DiT Difusão
Movie Gen (Meta) DiT Flow Matching
CogVideoX DiT 3D Flow Matching

De U-Net para Transformer

A U-Net clássica usa convoluções com skip connections hierárquicas — boa para capturar detalhes locais, mas difícil de escalar. O DiT substitui tudo isso por blocos de atenção global.


Passo 1 — Patchify: Imagens como Sequências de Tokens

Assim como o ViT divide imagens em patches, o DiT opera no espaço latente (após o encoder VAE). Um latente de forma \(H \times W \times C\) é dividido em patches de tamanho \(p \times p\):

\[ \text{Número de tokens: } N = \frac{H}{p} \times \frac{W}{p} \]

Cada patch é linearizado e projetado para a dimensão \(d_{\text{model}}\) — tornando-se um "token visual".


Passo 2 — Bloco DiT com AdaLN

O DiT usa Adaptive Layer Normalization (AdaLN) para injetar a informação de timestep e classe/texto diretamente nos parâmetros de normalização:

\[ \text{AdaLN}(h, c) = \gamma(c) \cdot \frac{h - \mu}{\sigma} + \beta(c) \]

onde \(c = \text{MLP}(\text{emb}(t) + \text{emb}(\text{classe}))\) é o vetor de condicionamento.

Os parâmetros \(\gamma\) e \(\beta\) são preditos — não aprendidos estaticamente — tornando a normalização sensível ao passo de difusão e ao prompt.


Passo 3 — MMDiT: Atenção Bidirecional Multi-Modal

MMDiT (SD3, FLUX) vai além do condicionamento por cross-attention. Texto e imagem participam da mesma operação de atenção:

\[ [Q_{img} \| Q_{txt}] \cdot [K_{img} \| K_{txt}]^\top \]

Os tokens de imagem veem os tokens de texto e vice-versa — condicionamento muito mais rico do que injetar texto apenas via cross-attention.

O FLUX usa um design de "duplo stream": pesos separados para imagem e texto nos blocos Q/K/V/FFN, mas atenção compartilhada:

Stream img:  x_img → W_q^img·x  ─┐
                                   ├─→ concat → Atenção(Q,K,V) → separar
Stream txt:  x_txt → W_q^txt·x  ─┘

Visualização: Processo Completo de Geração


Por Que DiT Escala Melhor?

U-Nets têm operações de pooling e upsample que destroem informação global. Skip connections ajudam, mas ainda há um gargalo hierárquico. Em DiTs, cada token vê todos os outros tokens desde o primeiro bloco — atenção \(O(N^2)\), mas com acesso global completo.

Isso significa que ao dobrar os parâmetros (mais blocos, mais dimensões), o DiT aproveita toda a capacidade extra, enquanto a U-Net tem retornos decrescentes mais rápidos.

Modelo Parâmetros Tokens d_model Blocos
DiT-XL/2 (original) 675M 256 1152 28
SD3 MMDiT 2B 1024 1536 38
FLUX.1-dev 12B 4096 3072 57

Implementação Simplificada

import torch
import torch.nn as nn

class AdaLN(nn.Module):
    def __init__(self, d_model, d_cond):
        super().__init__()
        self.norm = nn.LayerNorm(d_model, elementwise_affine=False)
        self.proj = nn.Linear(d_cond, 2 * d_model)  # → γ, β

    def forward(self, x, c):
        gamma, beta = self.proj(c).chunk(2, dim=-1)
        return (1 + gamma.unsqueeze(1)) * self.norm(x) + beta.unsqueeze(1)

class DiTBlock(nn.Module):
    def __init__(self, d_model, n_heads, d_ff, d_cond):
        super().__init__()
        self.adaln1 = AdaLN(d_model, d_cond)
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.adaln2 = AdaLN(d_model, d_cond)
        self.ff = nn.Sequential(nn.Linear(d_model, d_ff), nn.GELU(), nn.Linear(d_ff, d_model))

    def forward(self, x, c):
        h = self.adaln1(x, c)
        x = x + self.attn(h, h, h)[0]      # self-attention
        x = x + self.ff(self.adaln2(x, c)) # FFN
        return x

class DiT(nn.Module):
    def __init__(self, in_channels, patch_size, d_model, n_heads, d_ff, n_layers, d_cond):
        super().__init__()
        self.patch_size = patch_size
        p = patch_size
        self.patchify = nn.Conv2d(in_channels, d_model, p, stride=p)
        self.blocks = nn.ModuleList([DiTBlock(d_model, n_heads, d_ff, d_cond) for _ in range(n_layers)])
        self.norm_out = nn.LayerNorm(d_model)
        self.depatchify = nn.Linear(d_model, p*p*in_channels)

    def forward(self, x, t_emb, cond):
        # x: (B, C, H, W) latente ruidoso
        B, C, H, W = x.shape
        tokens = self.patchify(x)                       # (B, d, H/p, W/p)
        tokens = tokens.flatten(2).transpose(1, 2)      # (B, N, d)
        c = t_emb + cond                                # combinar condicionamento
        for block in self.blocks:
            tokens = block(tokens, c)
        tokens = self.norm_out(tokens)
        patches = self.depatchify(tokens)               # (B, N, p*p*C)
        # reformatar para (B, C, H, W)
        p = self.patch_size
        patches = patches.view(B, H//p, W//p, p, p, C).permute(0,5,1,3,2,4).reshape(B,C,H,W)
        return patches  # campo de velocidade predito v_θ(x_t, t)



  1. Peebles, W., & Xie, S. (2023). Scalable Diffusion Models with Transformers. ICCV 2023. 

  2. Esser, P. et al. (2024). Scaling Rectified Flow Transformers for High-Resolution Image Synthesis (SD3)

  3. Black Forest Labs. (2024). FLUX.1

  4. Dosovitskiy, A. et al. (2021). An Image is Worth 16×16 Words (ViT)