Ir para o conteúdo

21. Flow-Matching

Flow Matching

Flow Matching1 é um framework de geração que treina um campo de velocidade \(v_\theta\) para transformar uma distribuição simples (ruído Gaussiano \(p_0\)) em uma distribuição de dados complexa (\(p_1\)) ao longo de uma trajetória contínua no tempo.

Comparado a modelos de difusão (DDPM), o Flow Matching tem trajetórias mais retas, requerendo menos passos de integração na inferência, sendo ao mesmo tempo mais simples de treinar.


Intuição: Mover Partículas

Imagine que você tem partículas espalhadas segundo uma Gaussiana \(p_0 = \mathcal{N}(0, I)\). Você quer movê-las para que, ao final do tempo \(t=1\), estejam distribuídas como seus dados de treino \(p_1\).

O Flow Matching aprende um campo vetorial \(v_\theta(x, t)\) que "empurra" cada partícula na direção certa a cada instante \(t \in [0, 1]\).


Formulação Matemática

O objetivo é aprender um campo de velocidade \(v_\theta : \mathbb{R}^d \times [0,1] \to \mathbb{R}^d\) tal que integrar a ODE:

\[ \frac{dx}{dt} = v_\theta(x, t), \quad x_0 \sim p_0 \]

produza \(x_1 \sim p_1\).

Conditional Flow Matching (CFM)

Dado um caminho condicional \(x_t = (1-t)x_0 + t x_1\) (interpolação linear — "caminho de Transporte Ótimo"), o campo de velocidade condicional é simplesmente:

\[ u_t(x \mid x_0, x_1) = x_1 - x_0 \]

A loss do CFM:

\[ \mathcal{L}_{\text{CFM}} = \mathbb{E}_{t, p(x_0), p(x_1)} \left[ \left\| v_\theta(x_t, t) - (x_1 - x_0) \right\|^2 \right] \]

onde \(x_t = (1-t)x_0 + t x_1\). A loss é simplesmente o MSE entre o campo predito e a direção de interpolação linear — sem cronograma de ruído complexo.


Flow Matching vs. Difusão

Aspecto DDPM Flow Matching
Trajetória Curva (ruído incremental) Reta (caminho OT)
Passos de inferência 50–1000 10–50
Loss Predição de ruído \(\epsilon\) Predição de velocidade \(v\)
Cronograma \(\beta_t\) complexo \(t\) uniforme \([0,1]\)
Velocidade de inferência Mais lento 2–10× mais rápido

FLUX.1 — Estado da Arte (2024)

FLUX.12 (Black Forest Labs) usa Flow Matching com Diffusion Transformers (DiT) — substituindo a U-Net por blocos Transformer puros:

flowchart LR
    A["Texto\n(prompt)"] --> B["Encoder de Texto\n(CLIP + T5-XXL)"]
    N["Ruído\nz₀ ~ N(0,I)"] --> C
    B --> C["Diffusion Transformer\n12B params\n(Flow Matching)"]
    C -->|"ODE: 20-50 passos"| D["Latente z₁"]
    D --> E["Decoder VAE"]
    E --> F["Imagem 1024x1024"]
  • 12B parâmetros (FLUX.1-dev, open-source)
  • Suporta múltiplas proporções de aspecto nativamente
  • Qualidade superior ao SDXL e SD3 em benchmarks

Inferência: Resolvendo a ODE

import torch

def sample_flow_matching(model, n_samples, n_steps=50, device='cuda'):
    dt = 1.0 / n_steps
    x = torch.randn(n_samples, *data_shape, device=device)  # z0 ~ N(0,I)

    for i in range(n_steps):
        t = torch.full((n_samples,), i * dt, device=device)
        v = model(x, t)          # campo de velocidade predito
        x = x + dt * v           # integração Euler simples

    return x  # z1 ~ p_data

# Com solucionador Heun (2ª ordem, melhor qualidade):
def sample_heun(model, n_samples, n_steps=20, device='cuda'):
    dt = 1.0 / n_steps
    x = torch.randn(n_samples, *data_shape, device=device)
    for i in range(n_steps):
        t = torch.full((n_samples,), i*dt, device=device)
        v1 = model(x, t)
        x_pred = x + dt * v1
        t2 = torch.full((n_samples,), (i+1)*dt, device=device)
        v2 = model(x_pred, t2)
        x = x + dt * (v1 + v2) / 2  # média de Heun
    return x



  1. Lipman, Y. et al. (2022). Flow Matching for Generative Modeling

  2. Black Forest Labs. (2024). FLUX.1: State-of-the-art text-to-image generation

  3. Liu, X. et al. (2022). Flow Straight and Fast: Rectified Flow