
22. Diffusion Transformers

Diffusion Transformers (DiT)

In 2023, Peebles & Xie [1] demonstrated something simple and impactful: diffusion models do not need a U-Net. Replacing it with plain Transformer blocks preserved sample quality, and the resulting model improved predictably as model size and compute grew, much as language models do.

Today, most leading image and video generators are built on DiT variants:

| Model | Architecture | Objective |
|---|---|---|
| FLUX.1 | DiT (dual-stream) | Flow Matching |
| Stable Diffusion 3 | MMDiT | Flow Matching |
| Sora (OpenAI) | Spacetime DiT | Diffusion |
| Movie Gen (Meta) | DiT | Flow Matching |
| CogVideoX | DiT 3D | Diffusion (v-prediction) |

From U-Net to Transformer

The classic U-Net uses convolutions with hierarchical skip connections, which capture local detail well but have proven harder to scale. DiT replaces the convolutional hierarchy with a stack of Transformer blocks whose self-attention is global.


Step 1 — Patchify: Images as Token Sequences

Just as ViT [4] splits an image into patches, DiT splits its input into patches, except that it operates in latent space (after the VAE encoder) rather than on pixels. A latent of shape \(H \times W \times C\) is divided into non-overlapping patches of size \(p \times p\):

\[ \text{Number of tokens: } N = \frac{H}{p} \times \frac{W}{p} \]

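Each patch is flattened into a vector of \(p^2 C\) values and linearly projected to dimension \(d_{\text{model}}\), becoming a "visual token". For example, a \(32 \times 32\) latent with \(p = 2\) yields \(N = 16 \times 16 = 256\) tokens.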
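The patchify step can be sketched in PyTorch as follows. This is a minimal illustration rather than reference DiT code: the latent size (4 channels, 32×32), `p = 2`, and `d_model = 1152` are example values, and the helper name `patchify` is hypothetical. A convolution with kernel size and stride both equal to `p` is equivalent to flattening each patch and applying one shared linear projection.

```python
import torch
import torch.nn as nn

def patchify(latent: torch.Tensor, proj: nn.Conv2d) -> torch.Tensor:
    """Turn a (B, C, H, W) latent into a (B, N, d_model) token sequence."""
    tokens = proj(latent)                     # (B, d_model, H/p, W/p)
    return tokens.flatten(2).transpose(1, 2)  # (B, N, d_model)

# Example values (assumptions, chosen for illustration only)
C, H, W, p, d_model = 4, 32, 32, 2, 1152
proj = nn.Conv2d(C, d_model, kernel_size=p, stride=p)

latent = torch.randn(1, C, H, W)
tokens = patchify(latent, proj)
num_tokens = (H // p) * (W // p)
print(tokens.shape, num_tokens)  # torch.Size([1, 256, 1152]) 256
```

Halving the patch size quadruples the number of tokens, which is why \(p\) is one of the main knobs for trading compute against detail.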


Step 2 — DiT Block with AdaLN

DiT uses Adaptive Layer Normalization (AdaLN) to inject timestep and class/text information directly into the normalization parameters:

\[ \text{AdaLN}(h, c) = \gamma(c) \cdot \frac{h - \mu}{\sigma} + \beta(c) \]

where \(c = \text{MLP}(\text{emb}(t) + \text{emb}(\text{class}))\) is the conditioning vector.

The parameters \(\gamma\) and \(\beta\) are predicted from \(c\) rather than learned as fixed weights, so the normalization adapts to the diffusion step and the prompt.
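A minimal sketch of how the conditioning vector \(c\) might be built and applied. The sinusoidal timestep embedding, the embedding sizes, and the names `timestep_embedding` and `Conditioner` are illustrative assumptions, not taken from a specific codebase.

```python
import math
import torch
import torch.nn as nn

def timestep_embedding(t: torch.Tensor, dim: int) -> torch.Tensor:
    """Sinusoidal embedding of timesteps t with shape (B,) -> (B, dim); dim must be even."""
    half = dim // 2
    freqs = torch.exp(-math.log(10000.0) * torch.arange(half) / half)
    args = t.float()[:, None] * freqs[None, :]
    return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)

class Conditioner(nn.Module):
    """c = MLP(emb(t) + emb(class)), then a linear map from c to (gamma, beta)."""
    def __init__(self, num_classes: int, d_cond: int, d_model: int):
        super().__init__()
        self.d_cond = d_cond
        self.class_emb = nn.Embedding(num_classes, d_cond)
        self.mlp = nn.Sequential(
            nn.Linear(d_cond, d_cond), nn.SiLU(), nn.Linear(d_cond, d_cond)
        )
        self.to_gamma_beta = nn.Linear(d_cond, 2 * d_model)

    def forward(self, t: torch.Tensor, labels: torch.Tensor):
        c = self.mlp(timestep_embedding(t, self.d_cond) + self.class_emb(labels))
        gamma, beta = self.to_gamma_beta(c).chunk(2, dim=-1)
        return gamma, beta  # each (B, d_model)

B, N, d_model = 2, 16, 64
cond = Conditioner(num_classes=10, d_cond=32, d_model=d_model)
gamma, beta = cond(torch.tensor([0, 500]), torch.tensor([3, 7]))

h = torch.randn(B, N, d_model)
norm = nn.LayerNorm(d_model, elementwise_affine=False)  # computes (h - mu) / sigma
out = gamma[:, None, :] * norm(h) + beta[:, None, :]    # AdaLN(h, c)
print(out.shape)  # torch.Size([2, 16, 64])
```

Each sample in the batch gets its own \(\gamma\) and \(\beta\), broadcast over all \(N\) tokens, so two images at different timesteps are normalized differently by the same layer.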


Step 3 — MMDiT: Multi-Modal Bidirectional Attention

MMDiT (used in SD3 [2] and FLUX [3]) goes beyond cross-attention conditioning. Text and image tokens participate in the same attention operation:

\[ [Q_{img} \| Q_{txt}] \cdot [K_{img} \| K_{txt}]^\top \]

Image tokens attend to text tokens and vice versa, so the text representation is itself updated layer by layer instead of serving as a fixed context, as it does when text is injected only via cross-attention.
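Expanding the product makes the four kinds of interaction explicit. Reading \(\|\) as concatenation along the token axis (an interpretation of the notation above), the joint attention logits form a block matrix:

\[
\begin{bmatrix} Q_{img} \\ Q_{txt} \end{bmatrix}
\begin{bmatrix} K_{img} \\ K_{txt} \end{bmatrix}^\top
=
\begin{bmatrix}
Q_{img} K_{img}^\top & Q_{img} K_{txt}^\top \\
Q_{txt} K_{img}^\top & Q_{txt} K_{txt}^\top
\end{bmatrix}
\]

The diagonal blocks are image-to-image and text-to-text attention; the off-diagonal blocks are the image-to-text and text-to-image terms. The softmax is taken over each full row, so every token normalizes over both modalities at once.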

FLUX uses a "dual-stream" design: image and text tokens have separate weights for the Q/K/V projections and the FFN, but share a single attention operation:

Img stream:  x_img → W_q^img·x  ─┐
                                   ├─→ concat → Attention(Q,K,V) → split
Txt stream:  x_txt → W_q^txt·x  ─┘
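A minimal sketch of the shared-attention idea in the diagram, assuming PyTorch 2.0 or later for `F.scaled_dot_product_attention`. The class name `JointAttention` and all sizes are illustrative; the per-stream FFNs, normalization, and AdaLN modulation of a full block are left out, and this is not the FLUX or SD3 implementation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class JointAttention(nn.Module):
    """Per-modality Q/K/V and output projections, one attention over all tokens."""
    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        self.n_heads = n_heads
        self.qkv_img = nn.Linear(d_model, 3 * d_model)  # image-stream weights
        self.qkv_txt = nn.Linear(d_model, 3 * d_model)  # text-stream weights
        self.out_img = nn.Linear(d_model, d_model)
        self.out_txt = nn.Linear(d_model, d_model)

    def _heads(self, x: torch.Tensor) -> torch.Tensor:
        # (B, N, d_model) -> (B, n_heads, N, d_head)
        B, N, D = x.shape
        return x.view(B, N, self.n_heads, D // self.n_heads).transpose(1, 2)

    def forward(self, x_img: torch.Tensor, x_txt: torch.Tensor):
        n_img = x_img.shape[1]
        q_i, k_i, v_i = self.qkv_img(x_img).chunk(3, dim=-1)
        q_t, k_t, v_t = self.qkv_txt(x_txt).chunk(3, dim=-1)
        # Concatenate along the token axis so both modalities share one attention
        q = self._heads(torch.cat([q_i, q_t], dim=1))
        k = self._heads(torch.cat([k_i, k_t], dim=1))
        v = self._heads(torch.cat([v_i, v_t], dim=1))
        out = F.scaled_dot_product_attention(q, k, v)  # (B, heads, N_img+N_txt, d_head)
        out = out.transpose(1, 2).flatten(2)           # (B, N_img+N_txt, d_model)
        # Split back into the two streams, each with its own output projection
        return self.out_img(out[:, :n_img]), self.out_txt(out[:, n_img:])

attn = JointAttention(d_model=64, n_heads=4)
img, txt = torch.randn(2, 16, 64), torch.randn(2, 5, 64)
y_img, y_txt = attn(img, txt)
print(y_img.shape, y_txt.shape)  # torch.Size([2, 16, 64]) torch.Size([2, 5, 64])
```

Because the text stream also receives an output, the text tokens are updated in every block, which is the practical difference from cross-attention, where text acts only as keys and values.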


Why Does DiT Scale Better?

U-Nets repeatedly downsample and upsample feature maps, so long-range context is mixed mainly at low resolution. Skip connections help, but there is still a hierarchical bottleneck. In DiT, every token attends to every other token from the very first block. The cost is \(O(N^2)\) attention, but with full global access.

This is the usual argument for why added capacity (more blocks, wider layers) keeps paying off in DiT, while U-Nets tend to hit diminishing returns sooner. Peebles & Xie [1] report that FID improves steadily as model Gflops increase.

| Model | Parameters | Tokens | d_model | Blocks |
|---|---|---|---|---|
| DiT-XL/2 (original) | 675M | 256 | 1152 | 28 |
| SD3 MMDiT | 2B | 1024 | 1536 | 24 |
| FLUX.1-dev | 12B | 4096 | 3072 | 57 |
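The token counts in the table can be reproduced with a short calculation. The sketch below assumes a VAE that downsamples each spatial side by 8 and a patch size of 2; the image resolutions are assumptions chosen to match the table, and the helper `num_tokens` is hypothetical.

```python
def num_tokens(image_size: int, vae_factor: int = 8, patch_size: int = 2) -> int:
    """Number of image tokens for a square image (assumed VAE factor and patch size)."""
    side = image_size // vae_factor // patch_size
    return side * side

for size in (256, 512, 1024):
    n = num_tokens(size)
    # Self-attention compares every token with every other token: N^2 pairs per head
    print(f"{size}x{size} image -> N = {n} tokens, N^2 = {n * n:,} attention pairs")
```

Under these assumptions the three resolutions give 256, 1024, and 4096 tokens. Going from 256 to 4096 tokens multiplies the number of attention pairs by 256, which is the price of the \(O(N^2)\) global access discussed above.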

Simplified Implementation

```python
import torch
import torch.nn as nn


class AdaLN(nn.Module):
    """LayerNorm whose scale and shift are predicted from the conditioning vector."""

    def __init__(self, d_model, d_cond):
        super().__init__()
        # No learned affine parameters: γ and β come from the conditioning instead
        self.norm = nn.LayerNorm(d_model, elementwise_affine=False)
        self.proj = nn.Linear(d_cond, 2 * d_model)  # → γ, β

    def forward(self, x, c):
        # x: (B, N, d_model), c: (B, d_cond)
        gamma, beta = self.proj(c).chunk(2, dim=-1)
        # The scale is parameterized as (1 + γ), so γ = 0 leaves the normalized input unscaled
        return (1 + gamma.unsqueeze(1)) * self.norm(x) + beta.unsqueeze(1)


class DiTBlock(nn.Module):
    """Pre-norm Transformer block: AdaLN → self-attention, then AdaLN → FFN."""

    def __init__(self, d_model, n_heads, d_ff, d_cond):
        super().__init__()
        self.adaln1 = AdaLN(d_model, d_cond)
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.adaln2 = AdaLN(d_model, d_cond)
        self.ff = nn.Sequential(nn.Linear(d_model, d_ff), nn.GELU(), nn.Linear(d_ff, d_model))

    def forward(self, x, c):
        h = self.adaln1(x, c)
        x = x + self.attn(h, h, h, need_weights=False)[0]  # self-attention
        x = x + self.ff(self.adaln2(x, c))                 # FFN
        return x


class DiT(nn.Module):
    """Simplified DiT: patchify → conditioned Transformer blocks → unpatchify."""

    def __init__(self, in_channels, patch_size, d_model, n_heads, d_ff, n_layers, d_cond):
        super().__init__()
        self.patch_size = patch_size
        p = patch_size
        # Strided convolution = flatten each p×p patch + shared linear projection
        self.patchify = nn.Conv2d(in_channels, d_model, p, stride=p)
        self.blocks = nn.ModuleList([DiTBlock(d_model, n_heads, d_ff, d_cond) for _ in range(n_layers)])
        self.norm_out = nn.LayerNorm(d_model)
        self.depatchify = nn.Linear(d_model, p*p*in_channels)

    def forward(self, x, t_emb, cond):
        # x: (B, C, H, W) noisy latent; t_emb and cond: (B, d_cond) embeddings
        B, C, H, W = x.shape
        p = self.patch_size
        assert H % p == 0 and W % p == 0, "H and W must be divisible by patch_size"
        tokens = self.patchify(x)                       # (B, d, H/p, W/p)
        tokens = tokens.flatten(2).transpose(1, 2)      # (B, N, d)
        # NOTE: positional embeddings are omitted for brevity. Without them the
        # blocks cannot tell patches apart by position; a real model adds them here.
        c = t_emb + cond                                # combine conditioning
        for block in self.blocks:
            tokens = block(tokens, c)
        tokens = self.norm_out(tokens)
        patches = self.depatchify(tokens)               # (B, N, p*p*C)
        # Reshape back to (B, C, H, W): each token holds a (p, p, C) patch
        patches = patches.view(B, H//p, W//p, p, p, C).permute(0,5,1,3,2,4).reshape(B,C,H,W)
        return patches  # predicted velocity field v_θ(x_t, t) under a flow-matching objective
```
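The final reshape is easy to get wrong, so a quick round-trip check can help. The sketch below uses two hypothetical helpers, `to_patches` and `from_patches`, with no learned layers, and verifies that the `view`/`permute`/`reshape` sequence used at the end of `DiT.forward` puts every patch value back at its original position.

```python
import torch

def to_patches(x: torch.Tensor, p: int) -> torch.Tensor:
    """(B, C, H, W) -> (B, N, p*p*C), each token laid out as (p_row, p_col, C)."""
    B, C, H, W = x.shape
    x = x.view(B, C, H // p, p, W // p, p)        # (B, C, h, p_row, w, p_col)
    x = x.permute(0, 2, 4, 3, 5, 1)               # (B, h, w, p_row, p_col, C)
    return x.reshape(B, (H // p) * (W // p), p * p * C)

def from_patches(patches: torch.Tensor, p: int, C: int, H: int, W: int) -> torch.Tensor:
    """Inverse of to_patches; same reshape as at the end of DiT.forward."""
    B = patches.shape[0]
    x = patches.view(B, H // p, W // p, p, p, C)  # (B, h, w, p_row, p_col, C)
    x = x.permute(0, 5, 1, 3, 2, 4)               # (B, C, h, p_row, w, p_col)
    return x.reshape(B, C, H, W)

x = torch.randn(2, 4, 8, 8)
restored = from_patches(to_patches(x, p=2), p=2, C=4, H=8, W=8)
print(torch.equal(x, restored))  # True
```

A wrong permutation order here would still produce a tensor of the right shape, so an exact round-trip test like this catches errors that a shape check alone would miss.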



  1. Peebles, W., & Xie, S. (2023). Scalable Diffusion Models with Transformers. ICCV 2023.

  2. Esser, P. et al. (2024). Scaling Rectified Flow Transformers for High-Resolution Image Synthesis (SD3). ICML 2024.

  3. Black Forest Labs. (2024). FLUX.1.

  4. Dosovitskiy, A. et al. (2021). An Image is Worth 16×16 Words: Transformers for Image Recognition at Scale (ViT). ICLR 2021.