Diffusion & Flow Matching Part 6: Conditional and Marginal Score Functions

diffusion
flow-matching
generative-models
score-function
Author

Hujie Wang

Published

December 2, 2025

We have constructed the training target for the flow matching model; now we will extend this reasoning to SDEs.

Let us start with the SDE extension trick.

Theorem: SDE Extension Trick

Define the conditional and marginal vector fields \(u_t^{target}(x | z)\) and \(u_t^{target}(x)\) as before. For diffusion coefficient \(\sigma_t \geq 0\), we can construct an SDE that follows the marginal probability path \(p_t(x)\):

\[ \begin{aligned} dX_t &= [u_t^{target}(X_t) + \frac{\sigma_t^2}{2} \nabla \log p_t(X_t)] \, dt + \sigma_t \, dW_t \\ X_0 &\sim p_{\text{init}} \\ \implies X_t &\sim p_t \quad \text{for all } 0 \leq t \leq 1 \end{aligned} \]

where \(\nabla \log p_t(x)\) is the marginal score function of \(p_t\).
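As a quick numerical sanity check of the trick (a sketch, not part of the original derivation), consider the simplest possible marginal path: \(p_t = \mathcal{N}(0, 1)\) for all \(t\), so \(u_t^{target} = 0\) and \(\nabla \log p_t(x) = -x\). The SDE reduces to \(dX_t = -\frac{\sigma_t^2}{2} X_t \, dt + \sigma_t \, dW_t\), and simulating it with Euler–Maruyama should leave \(\mathcal{N}(0, 1)\) invariant regardless of the (arbitrarily chosen) diffusion coefficient:

```python
import numpy as np

# Euler-Maruyama simulation of the SDE extension trick for the trivial
# marginal path p_t = N(0, 1): u_t^target = 0, score(x) = -x.
rng = np.random.default_rng(0)
n, steps = 100_000, 200
dt = 1.0 / steps
sigma = 1.5                          # arbitrary constant diffusion coefficient

x = rng.normal(size=n)               # X_0 ~ p_init = N(0, 1)
for _ in range(steps):
    score = -x                       # analytic score of N(0, 1)
    drift = 0.0 + 0.5 * sigma**2 * score   # u_t^target + (sigma^2 / 2) * score
    x += drift * dt + sigma * np.sqrt(dt) * rng.normal(size=n)

# The marginal at t = 1 should still be approximately N(0, 1).
assert abs(x.mean()) < 0.02 and abs(x.std() - 1.0) < 0.03
```

Changing `sigma` (even to a time-dependent schedule) should not change the final marginal, which is exactly what the theorem asserts.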

The same identity holds if we replace the marginal probability \(p_t(x)\) and vector field \(u_t^{target}(x)\) with the conditional probability \(p_t(x | z)\) and vector field \(u_t^{target}(x | z)\):

\[ \begin{aligned} dX_t &= [u_t^{target}(X_t | z) + \frac{\sigma_t^2}{2} \nabla \log p_t(X_t | z)] \, dt + \sigma_t \, dW_t \\ X_0 &\sim p_{\text{init}} \\ \implies X_t &\sim p_t(\cdot | z) \quad \text{for all } 0 \leq t \leq 1 \end{aligned} \]

We can express the marginal score function via the conditional score function \(\nabla \log p_t(x | z)\):

\[ \begin{aligned} \nabla \log p_t(x) &= \frac{\nabla p_t(x)}{p_t(x)} && \text{(definition of } \nabla \log) \\[0.5em] & = \frac{\nabla \int p_t(x | z) p_{data}(z) dz}{p_t(x)} && \text{(marginalization)} \\[0.5em] & = \frac{\int \nabla p_t(x | z) p_{data}(z) dz}{p_t(x)} && \text{(Leibniz rule}^* \text{)} \\[0.5em] & = \int \nabla \log p_t(x | z) \frac{p_t(x | z) p_{data}(z)}{p_t(x)} dz && \text{(see below)} \end{aligned} \]

\(^*\) Swapping \(\nabla\) and \(\int\) requires regularity conditions (the Leibniz integral rule). This holds when the integrand is smooth and its derivative is bounded by an integrable function — standard assumptions for “nice” probability densities.

The key is to rewrite \(\nabla p_t(x|z)\) using the chain rule. For gradients, the chain rule gives us: \[ \nabla_x \log f(x) = \frac{1}{f(x)} \nabla_x f(x) \] (This is the multivariate version of \(\frac{d}{dx}\log f = \frac{1}{f}\frac{df}{dx}\).)

Rearranging: \[ \nabla p_t(x|z) = p_t(x|z) \cdot \nabla \log p_t(x|z) \]

Substituting this into the integral: \[ \frac{\int \nabla p_t(x | z) p_{data}(z) dz}{p_t(x)} = \frac{\int p_t(x|z) \cdot \nabla \log p_t(x | z) \cdot p_{data}(z) dz}{p_t(x)} \]

Since \(p_t(x)\) doesn’t depend on \(z\), we can move it inside the integral: \[ = \int \nabla \log p_t(x | z) \cdot \frac{p_t(x | z) p_{data}(z)}{p_t(x)} dz \]

Notice that \(\frac{p_t(x | z) p_{data}(z)}{p_t(x)} = p(z | x, t)\) by Bayes’ theorem, so this is an expectation over the posterior: \[ \nabla \log p_t(x) = \mathbb{E}_{z \sim p(z | x, t)}[\nabla \log p_t(x | z)] \]
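This identity is easy to verify numerically (a sketch with assumptions not in the post: a two-point "dataset" \(z \in \{-1, +1\}\) standing in for \(p_{data}\), and arbitrary values for \(\alpha_t, \beta_t\)). The marginal \(p_t(x)\) is then a two-component Gaussian mixture, and the posterior-weighted average of conditional scores should match a direct finite-difference gradient of \(\log p_t(x)\):

```python
import numpy as np

alpha_t, beta_t = 0.7, 0.4
zs = np.array([-1.0, 1.0])   # discrete stand-in for p_data's support
w = np.array([0.3, 0.7])     # p_data(z) weights

def cond_pdf(x, z):
    """Gaussian conditional p_t(x | z) = N(x; alpha_t z, beta_t^2), 1-D."""
    return np.exp(-(x - alpha_t * z)**2 / (2 * beta_t**2)) / np.sqrt(2 * np.pi * beta_t**2)

def cond_score(x, z):
    """Analytic conditional score: -(x - alpha_t z) / beta_t^2."""
    return -(x - alpha_t * z) / beta_t**2

def marginal_score_via_posterior(x):
    probs = w * cond_pdf(x, zs)      # p_t(x | z) p_data(z) for each z
    post = probs / probs.sum()       # Bayes: p(z | x, t)
    return np.sum(post * cond_score(x, zs))

def marginal_score_direct(x, eps=1e-5):
    logp = lambda y: np.log(np.sum(w * cond_pdf(y, zs)))
    return (logp(x + eps) - logp(x - eps)) / (2 * eps)

for x in [-1.2, 0.0, 0.9]:
    assert abs(marginal_score_via_posterior(x) - marginal_score_direct(x)) < 1e-6
```

The posterior weights `post` are exactly the \(\frac{p_t(x | z) p_{data}(z)}{p_t(x)}\) factor from the derivation above.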

The conditional score function \(\nabla \log p_t(x | z)\) is usually something we can calculate analytically (e.g., for Gaussian conditionals).

Example: Gaussian Conditional Score Function

For a Gaussian conditional \(p_t(x | z) = \mathcal{N}(x; \alpha_t z, \beta_t^2 I_d)\) where \(x \in \mathbb{R}^d\), the conditional score function is: \[ \nabla_x \log p_t(x | z) = -\frac{x - \alpha_t z}{\beta_t^2} \]

Derivation: The multivariate Gaussian PDF with mean \(\mu\) and covariance \(\Sigma\) is: \[ p(x) = \frac{1}{(2\pi)^{d/2} |\Sigma|^{1/2}} \exp\left(-\frac{1}{2}(x - \mu)^T \Sigma^{-1} (x - \mu)\right) \]

For isotropic covariance \(\Sigma = \beta_t^2 I_d\), we have \(|\Sigma|^{1/2} = \beta_t^d\) and \(\Sigma^{-1} = \frac{1}{\beta_t^2}I_d\), so: \[ p_t(x | z) = \frac{1}{(2\pi \beta_t^2)^{d/2}} \exp\left(-\frac{\|x - \alpha_t z\|^2}{2\beta_t^2}\right) \]

Taking the log: \[ \log p_t(x | z) = \underbrace{-\frac{d}{2}\log(2\pi \beta_t^2)}_{\text{constant in } x} - \frac{\|x - \alpha_t z\|^2}{2\beta_t^2} \]

Taking the gradient with respect to \(x \in \mathbb{R}^d\) (the constant term vanishes): \[ \nabla_x \log p_t(x | z) = -\frac{1}{2\beta_t^2} \nabla_x \|x - \alpha_t z\|^2 = -\frac{1}{2\beta_t^2} \cdot 2(x - \alpha_t z) = -\frac{x - \alpha_t z}{\beta_t^2} \]

(We used \(\nabla_x \|x - a\|^2 = 2(x - a)\), which follows from \(\|x-a\|^2 = (x-a)^T(x-a)\).)
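The closed form is easy to check against a central finite difference of \(\log p_t(x | z)\) (a small sketch; the specific \(\alpha_t, \beta_t\) values are arbitrary assumptions for illustration):

```python
import numpy as np

def conditional_score(x, z, alpha_t, beta_t):
    """Analytic score of N(x; alpha_t z, beta_t^2 I): -(x - alpha_t z) / beta_t^2."""
    return -(x - alpha_t * z) / beta_t**2

def log_pdf(x, z, alpha_t, beta_t):
    """log N(x; alpha_t z, beta_t^2 I) for x in R^d."""
    d = x.shape[0]
    return (-0.5 * d * np.log(2 * np.pi * beta_t**2)
            - np.sum((x - alpha_t * z)**2) / (2 * beta_t**2))

def finite_diff_score(x, z, alpha_t, beta_t, eps=1e-5):
    """Central-difference gradient of log p_t(x | z) with respect to x."""
    grad = np.zeros_like(x)
    for i in range(x.size):
        e = np.zeros_like(x)
        e[i] = eps
        grad[i] = (log_pdf(x + e, z, alpha_t, beta_t)
                   - log_pdf(x - e, z, alpha_t, beta_t)) / (2 * eps)
    return grad

rng = np.random.default_rng(0)
x, z = rng.normal(size=3), rng.normal(size=3)
alpha_t, beta_t = 0.8, 0.5
assert np.allclose(conditional_score(x, z, alpha_t, beta_t),
                   finite_diff_score(x, z, alpha_t, beta_t), atol=1e-5)
```

Since \(\log p_t(x | z)\) is quadratic in \(x\), the central difference here is exact up to floating-point rounding.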