22. Diffusion Transformers
Diffusion Transformers (DiT)
In 2023, Peebles & Xie [1] showed something simple and consequential: diffusion models do not need a U-Net. Replacing it with plain Transformer blocks matched U-Net-based models on class-conditional ImageNet generation. It also made quality improve predictably as model size and compute grew, much as it does for language models.
Today, most state-of-the-art image and video generators use a DiT-style backbone:
| Model | Architecture | Objective |
|---|---|---|
| FLUX.1 | DiT (dual-stream) | Flow Matching |
| Stable Diffusion 3 | MMDiT | Flow Matching |
| Sora (OpenAI) | Spacetime DiT | Diffusion |
| Movie Gen (Meta) | DiT | Flow Matching |
| CogVideoX | DiT 3D | Diffusion (v-prediction) |
From U-Net to Transformer
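The classic diffusion U-Net is a convolutional encoder-decoder with skip connections between matching resolutions. It captures local detail well, but its multi-resolution layout makes it harder to scale cleanly. DiT replaces it with a uniform stack of Transformer blocks that apply global self-attention over latent patches.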
Step 1 — Patchify: Images as Token Sequences
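Like ViT [3], DiT splits its input into patches, but it works in the latent space of a pretrained VAE rather than on pixels. A latent of shape \(H \times W \times C\) is divided into non-overlapping patches of size \(p \times p\), giving a sequence of
\[ N = \frac{H}{p} \cdot \frac{W}{p} \]
tokens. Each patch holds \(p^2 C\) values. It is flattened and linearly projected to dimension \(d_{\text{model}}\), becoming a "visual token". For DiT-XL/2 at 256×256 resolution, the \(32 \times 32 \times 4\) latent with \(p = 2\) yields \(16 \times 16 = 256\) tokens.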
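Patchify can be written as a pure reshape. The sketch below assumes channels-first latents `(B, C, H, W)` with \(H\) and \(W\) divisible by \(p\). The helper names `patchify` and `unpatchify` are illustrative. In a model, a `Linear(p*p*C, d_model)` layer, or an equivalent strided `Conv2d`, would then map each row to a token.

```python
import torch


def patchify(x: torch.Tensor, p: int) -> torch.Tensor:
    """(B, C, H, W) -> (B, N, p*p*C) with N = (H/p) * (W/p), patches in row-major order."""
    B, C, H, W = x.shape
    assert H % p == 0 and W % p == 0, "H and W must be divisible by p"
    x = x.reshape(B, C, H // p, p, W // p, p)   # split each spatial axis into (grid, within-patch)
    x = x.permute(0, 2, 4, 3, 5, 1)              # (B, H/p, W/p, p, p, C)
    return x.reshape(B, (H // p) * (W // p), p * p * C)


def unpatchify(tokens: torch.Tensor, p: int, C: int, H: int, W: int) -> torch.Tensor:
    """Inverse of patchify: (B, N, p*p*C) -> (B, C, H, W)."""
    B = tokens.shape[0]
    x = tokens.reshape(B, H // p, W // p, p, p, C)
    x = x.permute(0, 5, 1, 3, 2, 4)              # (B, C, H/p, p, W/p, p)
    return x.reshape(B, C, H, W)


latent = torch.randn(2, 4, 32, 32)               # latent for a 256x256 image (8x VAE, 4 channels)
tokens = patchify(latent, p=2)
print(tokens.shape)                              # torch.Size([2, 256, 16])
assert torch.equal(unpatchify(tokens, 2, 4, 32, 32), latent)  # exact round trip
```

The round trip is exact because patchify only rearranges values and never mixes them.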
Step 2 — DiT Block with AdaLN
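DiT uses Adaptive Layer Normalization (AdaLN) to inject timestep and class/text information directly into the normalization parameters:
\[ \text{AdaLN}(x, c) = \gamma(c) \odot \text{LayerNorm}(x) + \beta(c), \qquad (\gamma(c), \beta(c)) = \text{Linear}(c) \]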
where \(c = \text{MLP}(\text{emb}(t) + \text{emb}(\text{class}))\) is the conditioning vector.
In a standard LayerNorm, \(\gamma\) and \(\beta\) are fixed learned parameters. In AdaLN they are predicted from \(c\) for each sample, so the normalization changes with the diffusion timestep and the prompt.
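The DiT paper's best-performing variant, adaLN-Zero, goes one step further. It also regresses a per-channel gate \(\alpha\) for each residual branch, and it zero-initializes the layer that produces all modulation parameters. As a result, every gate starts at 0 and each block starts as the identity function. The sketch below is minimal and assumes a precomputed conditioning vector `c` of size `d_model`. The class name `AdaLNZeroBlock` and its layout are illustrative, not a reference implementation.

```python
import torch
import torch.nn as nn


def modulate(x, shift, scale):
    # x: (B, N, D); shift, scale: (B, D), broadcast over the token axis
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)


class AdaLNZeroBlock(nn.Module):
    def __init__(self, d_model: int, n_heads: int, mlp_ratio: float = 4.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(d_model, elementwise_affine=False, eps=1e-6)
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(d_model, elementwise_affine=False, eps=1e-6)
        hidden = int(d_model * mlp_ratio)
        self.mlp = nn.Sequential(nn.Linear(d_model, hidden), nn.GELU(), nn.Linear(hidden, d_model))
        # regress (shift, scale, gate) for both the attention and the MLP branch from c
        self.ada = nn.Sequential(nn.SiLU(), nn.Linear(d_model, 6 * d_model))
        nn.init.zeros_(self.ada[1].weight)  # "Zero": all modulation outputs start at 0
        nn.init.zeros_(self.ada[1].bias)

    def forward(self, x, c):
        shift1, scale1, gate1, shift2, scale2, gate2 = self.ada(c).chunk(6, dim=-1)
        h = modulate(self.norm1(x), shift1, scale1)
        x = x + gate1.unsqueeze(1) * self.attn(h, h, h, need_weights=False)[0]
        h = modulate(self.norm2(x), shift2, scale2)
        x = x + gate2.unsqueeze(1) * self.mlp(h)
        return x


block = AdaLNZeroBlock(d_model=64, n_heads=4)
x, c = torch.randn(2, 16, 64), torch.randn(2, 64)
out = block(x, c)
print(torch.equal(out, x))  # True: both gates are zero at initialization
```

Starting each residual branch at zero lets a deep stack begin as an identity map. The conditioning then learns how strongly to open each branch during training.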
Step 3 — MMDiT: Multi-Modal Bidirectional Attention
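MMDiT (SD3 [2], FLUX) goes beyond cross-attention conditioning: text and image tokens take part in the same attention operation. Each modality computes its own queries, keys, and values, and these are concatenated along the sequence axis:
\[ \text{Attention}\big([Q_{\text{txt}}; Q_{\text{img}}],\ [K_{\text{txt}}; K_{\text{img}}],\ [V_{\text{txt}}; V_{\text{img}}]\big) \]
Image tokens attend to text tokens and vice versa, so the text representation is also refined by the image at every layer. With cross-attention, information flows one way only: from a fixed text embedding into the image tokens.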
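In the MMDiT "dual-stream" design, image and text tokens keep separate weights for the Q/K/V projections and the FFN but share one attention operation. FLUX.1 uses 19 of these double-stream blocks, followed by 38 single-stream blocks that process the concatenated sequence with shared weights. One double-stream attention step looks like this: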
```
Img stream: x_img → W_qkv^img · x_img ─┐
                                       ├─→ concat → Attention(Q, K, V) → split → img, txt
Txt stream: x_txt → W_qkv^txt · x_txt ─┘
```
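Below is a minimal sketch of one dual-stream joint-attention layer, assuming PyTorch 2.x for `torch.nn.functional.scaled_dot_product_attention`. Each stream has its own QKV and output projections. The two sequences are concatenated for a single attention call, and the result is split back by length. `JointAttention` and its attribute names are illustrative. The QK normalization, AdaLN modulation, and per-stream FFNs of the real blocks are omitted.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class JointAttention(nn.Module):
    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        assert d_model % n_heads == 0
        self.n_heads = n_heads
        # separate weights per modality ("dual stream")
        self.qkv_img = nn.Linear(d_model, 3 * d_model)
        self.qkv_txt = nn.Linear(d_model, 3 * d_model)
        self.out_img = nn.Linear(d_model, d_model)
        self.out_txt = nn.Linear(d_model, d_model)

    def _heads(self, t):
        # (B, L, D) -> (B, heads, L, D / heads)
        B, L, D = t.shape
        return t.reshape(B, L, self.n_heads, D // self.n_heads).transpose(1, 2)

    def forward(self, x_img, x_txt):
        n_txt = x_txt.shape[1]
        # project each stream with its own weights, then concatenate along the sequence axis
        qkv = torch.cat([self.qkv_txt(x_txt), self.qkv_img(x_img)], dim=1)
        q, k, v = (self._heads(t) for t in qkv.chunk(3, dim=-1))
        # one shared attention: every text or image token attends to all tokens
        out = F.scaled_dot_product_attention(q, k, v)
        B, H, L, Dh = out.shape
        out = out.transpose(1, 2).reshape(B, L, H * Dh)
        # split back into streams and apply the per-stream output projections
        return self.out_img(out[:, n_txt:]), self.out_txt(out[:, :n_txt])


attn = JointAttention(d_model=64, n_heads=4)
img, txt = torch.randn(2, 256, 64), torch.randn(2, 77, 64)
new_img, new_txt = attn(img, txt)
print(new_img.shape, new_txt.shape)  # torch.Size([2, 256, 64]) torch.Size([2, 77, 64])
```

Because text and image share one softmax, changing the image tokens also changes the updated text tokens. This is the bidirectional flow that cross-attention lacks.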
Why Does DiT Scale Better?
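A U-Net's downsampling and upsampling stages trade spatial resolution for receptive field. Fine detail must be carried across by skip connections, and long-range interactions happen mostly at the coarse, low-resolution levels. In a DiT, every token can attend to every other token from the very first block. Attention costs \(O(N^2)\) in the number of tokens \(N\), but it gives full global access at every layer.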
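Peebles & Xie [1] found that sample quality (FID) improves steadily as model compute in Gflops grows. This held whether compute was added through more blocks, a wider \(d_{\text{model}}\), or smaller patches (more tokens). Because every block is identical, scaling a DiT means turning those knobs. A U-Net's multi-resolution layout offers no single, equally clean axis to scale.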
| Model | Parameters | Image tokens | d_model | Blocks |
|---|---|---|---|---|
| DiT-XL/2 (original) | 675M | 256 | 1152 | 28 |
| SD3 Medium (MMDiT) | 2B | 1024 | 1536 | 24 |
| FLUX.1-dev | 12B | 4096 | 3072 | 57 (19 double-stream + 38 single-stream) |
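The token counts in the table follow from the patchify formula. The small sketch below assumes an 8× downsampling VAE and 2×2 patches. It shows how \(N\) and the \(N \times N\) attention-score matrix grow with resolution. The function name `image_tokens` is illustrative.

```python
def image_tokens(resolution: int, vae_factor: int = 8, patch: int = 2) -> int:
    """Image tokens for a square image: N = (resolution / vae_factor / patch) ** 2."""
    side = resolution // vae_factor // patch
    return side * side


for res in (256, 512, 1024):
    n = image_tokens(res)
    # each attention head scores every token against every token: N * N entries
    print(f"{res}x{res}: N = {n}, scores per head = {n * n:,}")
# 256x256: N = 256, scores per head = 65,536
# 512x512: N = 1024, scores per head = 1,048,576
# 1024x1024: N = 4096, scores per head = 16,777,216
```

Under these assumptions, 256 tokens corresponds to a 256×256 image, 1024 to 512×512, and 4096 to 1024×1024. Doubling the resolution multiplies \(N\) by 4 and the attention matrix by 16.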
Simplified Implementation
```python
import torch
import torch.nn as nn


class AdaLN(nn.Module):
    """LayerNorm whose scale and shift are regressed from the conditioning vector c."""
    def __init__(self, d_model, d_cond):
        super().__init__()
        self.norm = nn.LayerNorm(d_model, elementwise_affine=False)
        self.proj = nn.Linear(d_cond, 2 * d_model)  # → γ, β

    def forward(self, x, c):
        # x: (B, N, d_model), c: (B, d_cond)
        gamma, beta = self.proj(c).chunk(2, dim=-1)  # each (B, d_model)
        # scale is parameterized as 1 + γ, so γ = 0 leaves the normalized tokens unscaled
        return (1 + gamma.unsqueeze(1)) * self.norm(x) + beta.unsqueeze(1)


class DiTBlock(nn.Module):
    def __init__(self, d_model, n_heads, d_ff, d_cond):
        super().__init__()
        self.adaln1 = AdaLN(d_model, d_cond)
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.adaln2 = AdaLN(d_model, d_cond)
        self.ff = nn.Sequential(nn.Linear(d_model, d_ff), nn.GELU(), nn.Linear(d_ff, d_model))

    def forward(self, x, c):
        h = self.adaln1(x, c)
        x = x + self.attn(h, h, h, need_weights=False)[0]  # global self-attention
        x = x + self.ff(self.adaln2(x, c))                 # position-wise FFN
        return x


class DiT(nn.Module):
    def __init__(self, input_size, in_channels, patch_size, d_model, n_heads, d_ff, n_layers, d_cond):
        super().__init__()
        self.patch_size = patch_size
        p = patch_size
        n_tokens = (input_size // p) ** 2  # assumes a square latent of side input_size
        self.patchify = nn.Conv2d(in_channels, d_model, p, stride=p)
        # Attention is permutation-equivariant, so tokens need explicit position information.
        # A learned embedding keeps this short; the DiT paper uses fixed 2D sin-cos embeddings.
        self.pos_embed = nn.Parameter(torch.zeros(1, n_tokens, d_model))
        self.blocks = nn.ModuleList([DiTBlock(d_model, n_heads, d_ff, d_cond) for _ in range(n_layers)])
        self.norm_out = nn.LayerNorm(d_model)
        self.depatchify = nn.Linear(d_model, p * p * in_channels)

    def forward(self, x, t_emb, cond):
        # x: (B, C, H, W) noisy latent; t_emb, cond: (B, d_cond) precomputed embeddings
        B, C, H, W = x.shape
        p = self.patch_size
        tokens = self.patchify(x)                   # (B, d, H/p, W/p)
        tokens = tokens.flatten(2).transpose(1, 2)  # (B, N, d), N = (H/p)(W/p)
        tokens = tokens + self.pos_embed            # add patch positions
        c = t_emb + cond                            # combine conditioning
        for block in self.blocks:
            tokens = block(tokens, c)
        tokens = self.norm_out(tokens)
        patches = self.depatchify(tokens)           # (B, N, p*p*C)
        # unpatchify: (B, H/p, W/p, p, p, C) → (B, C, H/p, p, W/p, p) → (B, C, H, W)
        patches = patches.view(B, H // p, W // p, p, p, C).permute(0, 5, 1, 3, 2, 4)
        # flow-matching models (SD3, FLUX) read this output as the velocity v_θ(x_t, t);
        # the original DiT predicts the noise ε (plus a variance) instead
        return patches.reshape(B, C, H, W)
```
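Sampling with a velocity-predicting model means integrating an ODE from noise to data. The minimal fixed-step Euler sampler below assumes the rectified-flow convention of SD3 [2], \(x_t = (1 - t)\,x_0 + t\,\varepsilon\). Under that convention the target velocity is \(v = \varepsilon - x_0\), and sampling runs from \(t = 1\) (pure noise) to \(t = 0\). `euler_sample` is a hypothetical helper. `model` stands for any callable `(x_t, t) -> v`. Plugging in the `DiT` class above would also require timestep and prompt embeddings, which are omitted here. Production samplers typically use non-uniform timestep schedules rather than this uniform one.

```python
import torch


@torch.no_grad()
def euler_sample(model, shape, n_steps: int = 50, generator=None):
    """Integrate dx/dt = v(x, t) from t = 1 (noise) to t = 0 (data) with Euler steps."""
    x = torch.randn(shape, generator=generator)    # x_1 ~ N(0, I)
    ts = torch.linspace(1.0, 0.0, n_steps + 1)     # uniform schedule 1 -> 0
    for t_cur, t_next in zip(ts[:-1], ts[1:]):
        t = torch.full((shape[0],), float(t_cur))  # per-sample timestep
        v = model(x, t)                            # predicted velocity, eps - x_0
        x = x + (t_next - t_cur) * v               # dt < 0: step toward the data
    return x


# Toy check: if every sample should become the same target x0, the exact velocity is
# v = (x_t - x0) / t. The paths are straight lines, so Euler is exact for any step count.
x0 = torch.tensor([[1.0, -2.0, 3.0]])
exact_v = lambda x_t, t: (x_t - x0) / t.view(-1, 1)
sample = euler_sample(exact_v, shape=(4, 3), n_steps=10)
print(torch.allclose(sample, x0.expand(4, 3), atol=1e-5))  # True
```

A trained network's paths are only approximately straight, so the step count trades speed against accuracy.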
1. Peebles, W., & Xie, S. (2023). Scalable Diffusion Models with Transformers. ICCV 2023.
2. Esser, P. et al. (2024). Scaling Rectified Flow Transformers for High-Resolution Image Synthesis (SD3). ICML 2024.
3. Dosovitskiy, A. et al. (2021). An Image is Worth 16×16 Words: Transformers for Image Recognition at Scale (ViT). ICLR 2021.