Score-based diffusion model code example for stochastic differential equations

Score-Based Generative Modeling through Stochastic Differential Equations

A score-based diffusion model learns the score function, that is, the gradient of the log-density of the data distribution, and uses it to generate samples. It can produce images of quality comparable to GANs without adversarial training. The approach comes from the paper: Yang Song, Jascha Sohl-Dickstein, Diederik P. Kingma, Abhishek Kumar, Stefano Ermon, and Ben Poole. “Score-Based Generative Modeling through Stochastic Differential Equations.” International Conference on Learning Representations, 2021

This work was another milestone after diffusion models became popular: it unifies diffusion models and score-based generative models in a single stochastic differential equation (SDE) framework. The original diffusion models are also slow to sample from, often requiring thousands of network evaluations to draw a single sample. The SDE framework opens the door to faster samplers, for example the black-box ODE solver in section 5.3, which can reduce the number of evaluations.

Many introductions to the principles and applications of score-based diffusion, along with walkthroughs of the paper, are available online. Code walkthroughs are much rarer, so this post provides a simple, runnable code example of a score-based diffusion model.

1. Define time-dependent score-based model

Import related modules

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

import functools
from torch.optim import Adam
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torchvision.datasets import MNIST
import tqdm.notebook  # binds the name tqdm and loads tqdm.notebook (trange, tqdm)

1.1 The projection layer embedding time t

Strictly speaking, this is not a learned projection layer. When the module is created, a vector of frequencies ω is sampled from a Gaussian distribution and then frozen. A time step t is mapped to the Gaussian random feature vector [sin(2πωt); cos(2πωt)]. Note that these weights are not trainable.

class GaussianFourierProjection(nn.Module):
  """Gaussian random features for encoding time steps."""
  def __init__(self, embed_dim, scale=30.):
    super().__init__()
    # Randomly sample weights during initialization. These weights are fixed
    # during optimization and are not trainable.
    self.W = nn.Parameter(torch.randn(embed_dim // 2) * scale, requires_grad=False)
  def forward(self, x):
    # x has shape (batch,); the output has shape (batch, embed_dim).
    x_proj = x[:, None] * self.W[None, :] * 2 * np.pi
    return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
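
To make the embedding concrete, here is a small sketch of the same mapping in plain Python, written without PyTorch so that it runs anywhere. The helper name fourier_features, the seed, and the tiny embed_dim are illustrative assumptions and not part of the original code.

```python
import math
import random

def fourier_features(t, weights):
    """Map a scalar time t to [sin(2*pi*w*t); cos(2*pi*w*t)] for fixed weights w."""
    proj = [2.0 * math.pi * w * t for w in weights]
    return [math.sin(p) for p in proj] + [math.cos(p) for p in proj]

# Frequencies are drawn once from N(0, scale^2) and then kept fixed,
# mirroring GaussianFourierProjection with requires_grad=False.
rng = random.Random(0)          # hypothetical seed, for reproducibility only
embed_dim, scale = 8, 30.0      # embed_dim is kept tiny for readability
weights = [rng.gauss(0.0, scale) for _ in range(embed_dim // 2)]

features = fourier_features(0.5, weights)
print(len(features))            # embed_dim values, each in [-1, 1]
```

Because the weights are large (scale=30), nearby time steps get clearly different feature vectors, which is what lets the network tell noise levels apart.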

The time embedding is needed because a score-based diffusion model is trained over a continuum of noise levels. During training, a random time t is drawn for each sample, x is perturbed with random noise whose scale depends on t, and the network receives both the perturbed x and t as inputs when the loss is computed.

A fully connected layer that reshapes its output into a feature map:

class Dense(nn.Module):
  """A fully connected layer that reshapes outputs to feature maps."""
  def __init__(self, input_dim, output_dim):
    super().__init__()
    self.dense = nn.Linear(input_dim, output_dim)
  def forward(self, x):
    # Append two singleton spatial dimensions so the result broadcasts over H and W.
    return self.dense(x)[..., None, None]

1.2 Time-dependent score-based Unet model

The time-dependent score-based model is a U-Net whose forward function takes not only x but also the time t as input. The time t is embedded by GaussianFourierProjection and the embedding is merged into every layer. The network output is then divided by marginal_prob_std(t), so that the predicted score has the right scale at each noise level.

class ScoreNet(nn.Module):
  """A time-dependent score-based model built on a U-Net architecture."""

  def __init__(self, marginal_prob_std, channels=[32, 64, 128, 256], embed_dim=256):
    """Initialize a time-dependent score-based Unet network.

    Args:
      marginal_prob_std: A function that takes a time t and gives the standard deviation of the perturbation kernel p_{0t}(x(t) | x(0)).
      channels: The number of channels of each resolution feature map.
      embed_dim: The dimension of Gaussian random feature embedding, the same as GaussianFourierProjection in 1.1.
    """
    super().__init__()
    # Gaussian random feature embedding layer at time t
    self.embed = nn.Sequential(GaussianFourierProjection(embed_dim=embed_dim),
         nn.Linear(embed_dim, embed_dim))
    # Encoding layers where the resolution decreases
    self.conv1 = nn.Conv2d(1, channels[0], 3, stride=1, bias=False)
    self.dense1 = Dense(embed_dim, channels[0])
    self.gnorm1 = nn.GroupNorm(4, num_channels=channels[0])
    self.conv2 = nn.Conv2d(channels[0], channels[1], 3, stride=2, bias=False)
    self.dense2 = Dense(embed_dim, channels[1])
    self.gnorm2 = nn.GroupNorm(32, num_channels=channels[1])
    self.conv3 = nn.Conv2d(channels[1], channels[2], 3, stride=2, bias=False)
    self.dense3 = Dense(embed_dim, channels[2])
    self.gnorm3 = nn.GroupNorm(32, num_channels=channels[2])
    self.conv4 = nn.Conv2d(channels[2], channels[3], 3, stride=2, bias=False)
    self.dense4 = Dense(embed_dim, channels[3])
    self.gnorm4 = nn.GroupNorm(32, num_channels=channels[3])

    # Decoding layer with increased resolution
    self.tconv4 = nn.ConvTranspose2d(channels[3], channels[2], 3, stride=2, bias=False)
    self.dense5 = Dense(embed_dim, channels[2])
    self.tgnorm4 = nn.GroupNorm(32, num_channels=channels[2])
    self.tconv3 = nn.ConvTranspose2d(channels[2] + channels[2], channels[1], 3, stride=2, bias=False, output_padding=1)
    self.dense6 = Dense(embed_dim, channels[1])
    self.tgnorm3 = nn.GroupNorm(32, num_channels=channels[1])
    self.tconv2 = nn.ConvTranspose2d(channels[1] + channels[1], channels[0], 3, stride=2, bias=False, output_padding=1)
    self.dense7 = Dense(embed_dim, channels[0])
    self.tgnorm2 = nn.GroupNorm(32, num_channels=channels[0])
    self.tconv1 = nn.ConvTranspose2d(channels[0] + channels[0], 1, 3, stride=1)
    # Swish activation function
    self.act = lambda x: x * torch.sigmoid(x)
    self.marginal_prob_std = marginal_prob_std
  def forward(self, x, t):
    embed = self.act(self.embed(t))
    # Encoding path
    h1 = self.conv1(x)
    ## Merge information from t
    h1 += self.dense1(embed)
    ## Group normalization
    h1 = self.gnorm1(h1)
    h1 = self.act(h1)
    h2 = self.conv2(h1)
    h2 += self.dense2(embed)
    h2 = self.gnorm2(h2)
    h2 = self.act(h2)
    h3 = self.conv3(h2)
    h3 += self.dense3(embed)
    h3 = self.gnorm3(h3)
    h3 = self.act(h3)
    h4 = self.conv4(h3)
    h4 += self.dense4(embed)
    h4 = self.gnorm4(h4)
    h4 = self.act(h4)

    # Decoding path
    h = self.tconv4(h4)
    h += self.dense5(embed)
    h = self.tgnorm4(h)
    h = self.act(h)
    ## Skip connection: concatenate with the matching encoder feature map
    h = self.tconv3(torch.cat([h, h3], dim=1))
    h += self.dense6(embed)
    h = self.tgnorm3(h)
    h = self.act(h)
    h = self.tconv2(torch.cat([h, h2], dim=1))
    h += self.dense7(embed)
    h = self.tgnorm2(h)
    h = self.act(h)
    h = self.tconv1(torch.cat([h, h1], dim=1))

    # Normalize the output by the standard deviation of the perturbation kernel
    h = h / self.marginal_prob_std(t)[:, None, None, None]
    return h
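
The skip connections only work if each decoder feature map has the same spatial size as the encoder map it is concatenated with. The sketch below checks this for 28 x 28 MNIST inputs using the standard output-size formulas for convolutions with dilation 1. The helper names conv_out and tconv_out are illustrative assumptions; they are not PyTorch functions.

```python
def conv_out(size, kernel=3, stride=1, padding=0):
    """Output size of nn.Conv2d along one spatial dimension (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1

def tconv_out(size, kernel=3, stride=1, padding=0, output_padding=0):
    """Output size of nn.ConvTranspose2d along one spatial dimension (dilation 1)."""
    return (size - 1) * stride - 2 * padding + kernel + output_padding

# Encoder: conv1 .. conv4 as configured in ScoreNet.
h1 = conv_out(28, stride=1)   # 26
h2 = conv_out(h1, stride=2)   # 12
h3 = conv_out(h2, stride=2)   # 5
h4 = conv_out(h3, stride=2)   # 2

# Decoder: tconv4 .. tconv1 as configured in ScoreNet.
d4 = tconv_out(h4, stride=2)                     # 5, matches h3
d3 = tconv_out(d4, stride=2, output_padding=1)   # 12, matches h2
d2 = tconv_out(d3, stride=2, output_padding=1)   # 26, matches h1
out = tconv_out(d2, stride=1)                    # 28, the input size
print(h1, h2, h3, h4, d4, d3, d2, out)
```

This also shows why tconv3 and tconv2 need output_padding=1: without it they would produce sizes 11 and 25, which could not be concatenated with h2 and h1.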

2. Set up SDE

The SDE perturbs the data distribution p_0 into the prior distribution p_T. This example uses dx = σ^t dw for t in [0, 1]. Two functions describe it: marginal_prob_std, mentioned above, computes the standard deviation of the perturbation kernel p_{0t}(x(t) | x(0)), whose mean is x(0); diffusion_coeff computes the diffusion coefficient σ^t of the SDE.

device = 'cuda' #@param ['cuda', 'cpu'] {'type':'string'}

def marginal_prob_std(t, sigma):
  """Calculate the standard deviation of p_{0t}(x(t) | x(0)).

  Args:
    t: A vector of time steps.
    sigma: The $\sigma$ in our SDE.

  Returns:
    The standard deviation.
  """
  t = torch.tensor(t, device=device)
  return torch.sqrt((sigma**(2 * t) - 1.) / 2. / np.log(sigma))

def diffusion_coeff(t, sigma):
  """Calculate the diffusion coefficient of SDE.

  Args:
    t: A vector of time steps.
    sigma: The $\sigma$ in our SDE.

  Returns:
    Diffusion coefficient vector.
  """
  return torch.tensor(sigma**t, device=device)

sigma = 25.0 #@param {'type':'number'}
marginal_prob_std_fn = functools.partial(marginal_prob_std, sigma=sigma)
diffusion_coeff_fn = functools.partial(diffusion_coeff, sigma=sigma)
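
The closed form in marginal_prob_std is the square root of the integral of the squared diffusion coefficient σ^(2s) from 0 to t. As a sanity check, the plain-Python sketch below compares the closed form with a numerical integral. The helper integrated_variance and its step count are illustrative assumptions.

```python
import math

sigma = 25.0

def marginal_prob_std(t, sigma):
    """Closed form: sqrt((sigma^(2t) - 1) / (2 ln sigma))."""
    return math.sqrt((sigma ** (2 * t) - 1.0) / (2.0 * math.log(sigma)))

def integrated_variance(t, sigma, n=100_000):
    """Midpoint-rule approximation of the integral of sigma^(2s) over [0, t]."""
    h = t / n
    return sum(sigma ** (2 * (i + 0.5) * h) for i in range(n)) * h

for t in (0.1, 0.5, 1.0):
    closed = marginal_prob_std(t, sigma)
    numeric = math.sqrt(integrated_variance(t, sigma))
    print(f"t={t}: closed form {closed:.4f}, numerical {numeric:.4f}")
```

At t = 1 the standard deviation is about 9.85, far larger than the pixel range [0, 1], which is why a sample from the prior is essentially pure noise.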

3. Define loss function

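The loss is the denoising score matching objective. The formula looks complicated, but its form is fixed: perturb x with Gaussian noise z at a random time t, then penalize the squared norm of std(t) * score + z. The code is as follows: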

def loss_fn(model, x, marginal_prob_std, eps=1e-5):
  """The loss function for training score-based generative models.

  Args:
    model: Time-dependent, score-based PyTorch model.
    x: A mini-batch of training data.
    marginal_prob_std: A function that gives the standard deviation of
      the perturbation kernel.
    eps: A tolerance value for numerical stability.
  """
  # Sample t uniformly from [eps, 1] and perturb x with noise of scale std(t).
  random_t = torch.rand(x.shape[0], device=x.device) * (1. - eps) + eps
  z = torch.randn_like(x)
  std = marginal_prob_std(random_t)
  perturbed_x = x + z * std[:, None, None, None]
  score = model(perturbed_x, random_t)
  loss = torch.mean(torch.sum((score * std[:, None, None, None] + z)**2, dim=(1,2,3)))
  return loss
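
Why is score * std + z the right residual? The score of the Gaussian perturbation kernel N(x; x(0), std^2) at the perturbed point is -z / std, so a network that outputs this value makes the residual vanish. The one-dimensional sketch below checks this with a finite difference. The values of x0 and std and the helper log_kernel are illustrative assumptions.

```python
import math
import random

def log_kernel(x, x0, std):
    """Log density of the Gaussian perturbation kernel N(x; x0, std^2) in 1-D."""
    return -0.5 * math.log(2.0 * math.pi * std ** 2) - (x - x0) ** 2 / (2.0 * std ** 2)

rng = random.Random(0)                # hypothetical seed
x0, std = 0.3, 2.0                    # toy clean sample and noise level
z = rng.gauss(0.0, 1.0)
perturbed_x = x0 + std * z            # same perturbation as in loss_fn

# Score of the kernel at perturbed_x via a central finite difference.
h = 1e-5
kernel_score = (log_kernel(perturbed_x + h, x0, std)
                - log_kernel(perturbed_x - h, x0, std)) / (2.0 * h)

# The residual used in loss_fn, evaluated for this kernel score.
residual = kernel_score * std + z
print(abs(residual) < 1e-6)
```

The network never sees x(0), so it cannot match the kernel score for every sample. Averaged over the data, the minimizer of the loss is instead the score of the perturbed data distribution p_t, which is what the samplers need.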

4. Training model

As in ordinary model training, we instantiate the model, create the optimizer, and backpropagate the loss. The code is as follows:

score_model = torch.nn.DataParallel(ScoreNet(marginal_prob_std=marginal_prob_std_fn))
score_model = score_model.to(device)

n_epochs = 50 #@param {'type':'integer'}
## size of a mini-batch
batch_size = 32 #@param {'type':'integer'}
## learning rate
lr = 1e-4 #@param {'type':'number'}

dataset = MNIST('.', train=True, transform=transforms.ToTensor(), download=True)
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)

optimizer = Adam(score_model.parameters(), lr=lr)
tqdm_epoch = tqdm.notebook.trange(n_epochs)
for epoch in tqdm_epoch:
  avg_loss = 0.
  num_items = 0
  for x, y in data_loader:
    x = x.to(device)
    loss = loss_fn(score_model, x, marginal_prob_std_fn)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    avg_loss += loss.item() * x.shape[0]
    num_items += x.shape[0]
  # Print the averaged training loss so far.
  tqdm_epoch.set_description('Average Loss: {:5f}'.format(avg_loss / num_items))
  # Update the checkpoint after each epoch of training.
  torch.save(score_model.state_dict(), 'ckpt.pth')
The output of the training process is a progress bar that reports the average loss of each epoch.

5. Sampler/Solver

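There are multiple solvers for score-based diffusion models. Three are shown here: the Euler-Maruyama sampler, the predictor-corrector sampler, and a black-box ODE solver.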

5.1 Euler-Maruyama sampler

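The Euler-Maruyama method is a numerical solver for SDEs. Here it discretizes the reverse-time SDE, using the score predicted by the neural network in place of the true score: each step moves x by g(t)^2 * score * Δt and then adds Gaussian noise scaled by g(t) * sqrt(Δt), where g is the diffusion coefficient.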

## Number of sampling steps
num_steps = 500 #@param {'type':'integer'}
def Euler_Maruyama_sampler(score_model,
                           marginal_prob_std,
                           diffusion_coeff,
                           batch_size=64,
                           num_steps=num_steps,
                           device='cuda',
                           eps=1e-3):
  """Generate samples from a score-based model using the Euler-Maruyama solver.

  Args:
    score_model: Time-dependent, score-based PyTorch model.
    marginal_prob_std: A function that gives the standard deviation of
      the perturbation kernel.
    diffusion_coeff: A function that gives the diffusion coefficient of the SDE.
    batch_size: Batch size.
    num_steps: Number of sampling steps, equivalent to the number of discrete time steps.
    device: 'cuda' for running on GPUs, and 'cpu' for running on CPUs.
    eps: Minimum time step for numerical stability.

  Returns:
    Samples.
  """
  t = torch.ones(batch_size, device=device)
  init_x = torch.randn(batch_size, 1, 28, 28, device=device) \
    * marginal_prob_std(t)[:, None, None, None]
  time_steps = torch.linspace(1., eps, num_steps, device=device)
  step_size = time_steps[0] - time_steps[1]
  x = init_x
  with torch.no_grad():
    for time_step in tqdm.notebook.tqdm(time_steps):
      batch_time_step = torch.ones(batch_size, device=device) * time_step
      g = diffusion_coeff(batch_time_step)
      mean_x = x + (g**2)[:, None, None, None] * score_model(x, batch_time_step) * step_size
      x = mean_x + torch.sqrt(step_size) * g[:, None, None, None] * torch.randn_like(x)
  # Do not include any noise in the last sampling step.
  return mean_x
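
The sampler can be tested without a trained network on a toy problem where the score is known exactly. The sketch below assumes one-dimensional "data" distributed as N(2, 0.5^2); the perturbed distribution p_t is then Gaussian with variance 0.5^2 + std(t)^2, and its score has a closed form. The toy distribution, the seed, and the function names are illustrative assumptions; the update rule is the same as in Euler_Maruyama_sampler.

```python
import math
import random
import statistics

sigma = 25.0
data_mean, data_std = 2.0, 0.5   # toy 1-D data distribution (an assumption)

def marginal_var(t):
    """Variance of the perturbation kernel, (sigma^(2t) - 1) / (2 ln sigma)."""
    return (sigma ** (2 * t) - 1.0) / (2.0 * math.log(sigma))

def true_score(x, t):
    """Exact score of p_t = N(data_mean, data_std^2 + marginal_var(t))."""
    return -(x - data_mean) / (data_std ** 2 + marginal_var(t))

def euler_maruyama_1d(score_fn, num_steps=500, eps=1e-3, rng=random):
    x = rng.gauss(0.0, 1.0) * math.sqrt(marginal_var(1.0))   # sample from the prior
    step_size = (1.0 - eps) / (num_steps - 1)
    mean_x = x
    for i in range(num_steps):
        t = 1.0 - i * step_size            # time runs from 1 down to eps
        g = sigma ** t                     # diffusion coefficient
        mean_x = x + g ** 2 * score_fn(x, t) * step_size
        x = mean_x + math.sqrt(step_size) * g * rng.gauss(0.0, 1.0)
    return mean_x                          # no noise in the last step

rng = random.Random(0)                     # hypothetical seed
samples = [euler_maruyama_1d(true_score, rng=rng) for _ in range(1000)]
print(statistics.mean(samples), statistics.stdev(samples))
```

The sample mean and standard deviation should land close to 2 and 0.5, up to Monte Carlo and discretization error. With a trained ScoreNet in place of true_score, the same loop produces images.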

5.2 Predictor-corrector sampler

The predictor-corrector sampler combines a numerical solver for the reverse-time SDE with the Langevin MCMC method. A step of the numerical SDE solver takes x_t to x_{t-Δt}; this is the “predictor” step. Langevin MCMC steps then refine the sample so that it follows p_{t-Δt}(x) more closely; this is the “corrector” step, because MCMC helps reduce the error of the numerical SDE solver. In the code below, each iteration runs one corrector step at the current time and then one predictor step.

signal_to_noise_ratio = 0.16 #@param {'type':'number'}

## The number of sampling steps.
num_steps = 500#@param {'type':'integer'}
def pc_sampler(score_model,
               marginal_prob_std,
               diffusion_coeff,
               batch_size=64,
               num_steps=num_steps,
               snr=signal_to_noise_ratio,
               device='cuda',
               eps=1e-3):
  """Generate samples from a score-based model using a predictor-corrector method.

  Args:
    score_model: Time-dependent, score-based PyTorch model.
    marginal_prob_std: A function that gives the standard deviation
      of the perturbation kernel.
    diffusion_coeff: A function that gives the diffusion coefficient
      of the SDE.
    batch_size: Batch size.
    num_steps: Number of sampling steps, equivalent to the number of discrete time steps.
    snr: Signal-to-noise ratio that sets the Langevin step size.
    device: 'cuda' for running on GPUs, and 'cpu' for running on CPUs.
    eps: Minimum time step for numerical stability.
  Returns:
    Samples.
  """
  t = torch.ones(batch_size, device=device)
  init_x = torch.randn(batch_size, 1, 28, 28, device=device) * marginal_prob_std(t)[:, None, None, None]
  time_steps = np.linspace(1., eps, num_steps)
  step_size = time_steps[0] - time_steps[1]
  x = init_x
  with torch.no_grad():
    for time_step in tqdm.notebook.tqdm(time_steps):
      batch_time_step = torch.ones(batch_size, device=device) * time_step
      # Corrector step (Langevin MCMC)
      grad = score_model(x, batch_time_step)
      grad_norm = torch.norm(grad.reshape(grad.shape[0], -1), dim=-1).mean()
      noise_norm = np.sqrt(np.prod(x.shape[1:]))
      langevin_step_size = 2 * (snr * noise_norm / grad_norm)**2
      x = x + langevin_step_size * grad + torch.sqrt(2 * langevin_step_size) * torch.randn_like(x)

      # Predictor step (Euler-Maruyama)
      g = diffusion_coeff(batch_time_step)
      x_mean = x + (g**2)[:, None, None, None] * score_model(x, batch_time_step) * step_size
      x = x_mean + torch.sqrt(g**2 * step_size)[:, None, None, None] * torch.randn_like(x)
    # The last step does not include any noise
    return x_mean
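
The corrector is easiest to understand in isolation. The sketch below runs Langevin MCMC with the exact score of a standard normal distribution and shows that chains started far from the target are pulled back to it. It is a toy illustration under stated assumptions: the step size is a fixed hypothetical value rather than the signal-to-noise rule used in pc_sampler, and the target is one-dimensional.

```python
import math
import random
import statistics

def langevin_chain(score_fn, x, step_size, num_steps, rng):
    """Unadjusted Langevin dynamics: x <- x + step * score(x) + sqrt(2 * step) * noise."""
    for _ in range(num_steps):
        x = x + step_size * score_fn(x) + math.sqrt(2.0 * step_size) * rng.gauss(0.0, 1.0)
    return x

def standard_normal_score(x):
    """Score of N(0, 1): d/dx log p(x) = -x."""
    return -x

rng = random.Random(0)   # hypothetical seed
# Start every chain at x = 5, far from the target, to show the correction effect.
samples = [langevin_chain(standard_normal_score, 5.0, 0.05, 200, rng) for _ in range(2000)]
print(statistics.mean(samples), statistics.variance(samples))
```

The mean ends up near 0 and the variance near 1. In pc_sampler only one such step is taken per time step, which is enough to nudge x toward p_t before the predictor moves on.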

5.3 ODE numerical solver

Each SDE has a corresponding ODE. By solving this ODE in the reverse time direction, we sample from the same distribution as when solving the reverse-time SDE. This ODE is called the probability flow ODE; for the SDE used here it reads dx = -1/2 * σ^{2t} * score(x, t) dt. It can be solved with the black-box ODE solvers provided by packages such as scipy.

from scipy import integrate

## The error tolerance for the black-box ODE solver
error_tolerance = 1e-5 #@param {'type': 'number'}
def ode_sampler(score_model,
                marginal_prob_std,
                diffusion_coeff,
                batch_size=64,
                atol=error_tolerance,
                rtol=error_tolerance,
                device='cuda',
                z=None,
                eps=1e-3):
  """Generate samples from score-based models with black-box ODE solvers.

  Args:
    score_model: A PyTorch model that represents the time-dependent score-based model.
    marginal_prob_std: A function that returns the standard deviation
      of the perturbation kernel.
    diffusion_coeff: A function that returns the diffusion coefficient of the SDE.
    batch_size: The number of samples to generate by calling this function once.
    atol: Tolerance of absolute errors.
    rtol: Tolerance of relative errors.
    device: 'cuda' for running on GPUs, and 'cpu' for running on CPUs.
    z: The latent code that governs the final sample. If None, we start from p_1;
      otherwise, we start from the given z.
    eps: The smallest time step for numerical stability.
  """
  t = torch.ones(batch_size, device=device)
  # Create the latent code
  if z is None:
    init_x = torch.randn(batch_size, 1, 28, 28, device=device) \
      * marginal_prob_std(t)[:, None, None, None]
  else:
    init_x = z
  shape = init_x.shape

  def score_eval_wrapper(sample, time_steps):
    """A wrapper of the score-based model for use by the ODE solver."""
    sample = torch.tensor(sample, device=device, dtype=torch.float32).reshape(shape)
    time_steps = torch.tensor(time_steps, device=device, dtype=torch.float32).reshape((sample.shape[0], ))
    with torch.no_grad():
      score = score_model(sample, time_steps)
    return score.cpu().numpy().reshape((-1,)).astype(np.float64)
  def ode_func(t, x):
    """The ODE function for use by the ODE solver."""
    time_steps = np.ones((shape[0],)) * t
    g = diffusion_coeff(torch.tensor(t)).cpu().numpy()
    return -0.5 * (g**2) * score_eval_wrapper(x, time_steps)
  # Run the black-box ODE solver.
  res = integrate.solve_ivp(ode_func, (1., eps), init_x.reshape(-1).cpu().numpy(), rtol=rtol, atol=atol, method='RK45')
  print(f"Number of function evaluations: {res.nfev}")
  x = torch.tensor(res.y[:, -1], device=device).reshape(shape)

  return x
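
Unlike the SDE samplers, the ODE sampler is deterministic: a given latent value always maps to the same sample. The sketch below illustrates this on the same kind of toy problem as before, one-dimensional data distributed as N(2, 0.5^2) with an exact score, where the ODE has a closed-form solution to compare against. The toy distribution and the fixed-step Runge-Kutta integrator rk4 are illustrative assumptions; the original code uses scipy's adaptive RK45 solver instead.

```python
import math

sigma = 25.0
data_mean, data_std = 2.0, 0.5   # toy 1-D data distribution (an assumption)

def total_var(t):
    """Variance of p_t for Gaussian data: data_std^2 + (sigma^(2t) - 1) / (2 ln sigma)."""
    return data_std ** 2 + (sigma ** (2 * t) - 1.0) / (2.0 * math.log(sigma))

def ode_func(t, x):
    """Right-hand side dx/dt = -0.5 * g(t)^2 * score(x, t) with the exact Gaussian score."""
    score = -(x - data_mean) / total_var(t)
    return -0.5 * sigma ** (2 * t) * score

def rk4(f, x, t_start, t_end, num_steps=1000):
    """Classical fourth-order Runge-Kutta with a fixed (here negative) step."""
    h = (t_end - t_start) / num_steps
    t = t_start
    for _ in range(num_steps):
        k1 = f(t, x)
        k2 = f(t + h / 2, x + h * k1 / 2)
        k3 = f(t + h / 2, x + h * k2 / 2)
        k4 = f(t + h, x + h * k3)
        x += h * (k1 + 2 * k2 + 2 * k3 + k4) / 6
        t += h
    return x

eps = 1e-3
x1 = 10.0                                   # a fixed latent value at t = 1
x_eps = rk4(ode_func, x1, 1.0, eps)
# For Gaussian data, the distance to the mean shrinks with the standard deviation of p_t.
expected = data_mean + (x1 - data_mean) * math.sqrt(total_var(eps) / total_var(1.0))
print(x_eps, expected)
```

The two printed numbers agree to many digits. This one-to-one map between latent code and sample is what makes the likelihood computation in section 7 possible.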

6. Sampling

from torchvision.utils import make_grid

## Load the pre-trained checkpoint from disk.
device = 'cuda' #@param ['cuda', 'cpu'] {'type':'string'}
ckpt = torch.load('ckpt.pth', map_location=device)
score_model.load_state_dict(ckpt)

sample_batch_size = 64 #@param {'type':'integer'}
# Sampler configuration
sampler = ode_sampler #@param ['Euler_Maruyama_sampler', 'pc_sampler', 'ode_sampler'] {'type': 'raw'}

## Generate samples using the specified sampler.
samples = sampler(score_model,
                  marginal_prob_std_fn,
                  diffusion_coeff_fn,
                  sample_batch_size,
                  device=device)

## Sample visualization.
samples = samples.clamp(0.0, 1.0)
%matplotlib inline
import matplotlib.pyplot as plt
sample_grid = make_grid(samples, nrow=int(np.sqrt(sample_batch_size)))

plt.imshow(sample_grid.permute(1, 2, 0).cpu(), vmin=0., vmax=1.)

The output is an 8 x 8 grid of generated digit images.

You can try other samplers to see the difference in the output results of different samplers.

7. Likelihood Computation

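A by-product of the probability flow ODE formulation is likelihood computation. By the instantaneous change-of-variables formula, log p_0(x(0)) = log p_1(x(1)) - 1/2 * ∫_0^1 σ^{2t} div score(x(t), t) dt, where the divergence of the score model is estimated with the Skilling-Hutchinson trace estimator.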

def prior_likelihood(z, sigma):
  """The likelihood of a Gaussian distribution with mean zero and
      standard deviation sigma."""
  shape = z.shape
  N = np.prod(shape[1:])
  return -N / 2. * torch.log(2*np.pi*sigma**2) - torch.sum(z**2, dim=(1,2,3)) / (2 * sigma**2 )

def ode_likelihood(x,
                   score_model,
                   marginal_prob_std,
                   diffusion_coeff,
                   batch_size=64,
                   device='cuda',
                   eps=1e-5):
  """Compute the likelihood with probability flow ODE.

  Args:
    x: Input data.
    score_model: A PyTorch model representing the score-based model.
    marginal_prob_std: A function that gives the standard deviation of the
      perturbation kernel.
    diffusion_coeff: A function that gives the diffusion coefficient of the
      forward SDE.
    batch_size: The batch size. Equals the leading dimension of `x`.
    device: 'cuda' for evaluation on GPUs, and 'cpu' for evaluation on CPUs.
    eps: A `float` number. The smallest time step for numerical stability.

  Returns:
    z: The latent code for `x`.
    bpd: The log-likelihoods in bits/dim.
  """
  # Draw the random Gaussian sample for Skilling-Hutchinson's estimator.
  epsilon = torch.randn_like(x)
  def divergence_eval(sample, time_steps, epsilon):
    """Compute the divergence of the score-based model with Skilling-Hutchinson."""
    with torch.enable_grad():
      # Gradients with respect to the input are needed for the Jacobian-vector product.
      sample.requires_grad_(True)
      score_e = torch.sum(score_model(sample, time_steps) * epsilon)
      grad_score_e = torch.autograd.grad(score_e, sample)[0]
    return torch.sum(grad_score_e * epsilon, dim=(1, 2, 3))
  shape = x.shape

  def score_eval_wrapper(sample, time_steps):
    """A wrapper for evaluating the score-based model for the black-box ODE solver."""
    sample = torch.tensor(sample, device=device, dtype=torch.float32).reshape(shape)
    time_steps = torch.tensor(time_steps, device=device, dtype=torch.float32).reshape((sample.shape[0], ))
    with torch.no_grad():
      score = score_model(sample, time_steps)
    return score.cpu().numpy().reshape((-1,)).astype(np.float64)
  def divergence_eval_wrapper(sample, time_steps):
    """A wrapper for evaluating the divergence of score for the black-box ODE solver."""
    with torch.no_grad():
      # Obtain x(t) by solving the probability flow ODE.
      sample = torch.tensor(sample, device=device, dtype=torch.float32).reshape(shape)
      time_steps = torch.tensor(time_steps, device=device, dtype=torch.float32).reshape((sample.shape[0], ))
      # Compute likelihood.
      div = divergence_eval(sample, time_steps, epsilon)
      return div.cpu().numpy().reshape((-1,)).astype(np.float64)
  def ode_func(t, x):
    """The ODE function for the black-box solver."""
    time_steps = np.ones((shape[0],)) * t
    sample = x[:-shape[0]]
    logp = x[-shape[0]:]
    g = diffusion_coeff(torch.tensor(t)).cpu().numpy()
    sample_grad = -0.5 * g**2 * score_eval_wrapper(sample, time_steps)
    logp_grad = -0.5 * g**2 * divergence_eval_wrapper(sample, time_steps)
    return np.concatenate([sample_grad, logp_grad], axis=0)

  init = np.concatenate([x.cpu().numpy().reshape((-1,)), np.zeros((shape[0],))], axis=0)
  # Black-box ODE solver
  res = integrate.solve_ivp(ode_func, (eps, 1.), init, rtol=1e-5, atol=1e-5, method='RK45')
  zp = torch.tensor(res.y[:, -1], device=device)
  z = zp[:-shape[0]].reshape(shape)
  delta_logp = zp[-shape[0]:].reshape(shape[0])
  sigma_max = marginal_prob_std(1.)
  prior_logp = prior_likelihood(z, sigma_max)
  bpd = -(prior_logp + delta_logp) / np.log(2)
  N = np.prod(shape[1:])
  bpd = bpd / N + 8.
  return z, bpd
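
The divergence in ode_likelihood is the trace of the Jacobian of the score model, which is too expensive to compute exactly for images. The Skilling-Hutchinson estimator replaces it with epsilon^T J epsilon for a random probe epsilon, whose expectation is the trace. The sketch below demonstrates the estimator on a small fixed matrix in plain Python. The matrix A, the seed, and the number of probes are illustrative assumptions; ode_likelihood uses a single probe and obtains the product with the Jacobian from autograd.

```python
import random

def hutchinson_trace(matvec, dim, num_samples, rng):
    """Estimate tr(A) as the average of eps^T (A eps) over Gaussian probes eps."""
    total = 0.0
    for _ in range(num_samples):
        eps = [rng.gauss(0.0, 1.0) for _ in range(dim)]
        a_eps = matvec(eps)
        total += sum(e * a for e, a in zip(eps, a_eps))
    return total / num_samples

# A small stand-in for the Jacobian of the score model.
A = [[2.0, 1.0, 0.0],
     [1.0, 3.0, 0.0],
     [0.0, 0.0, -1.0]]

def matvec(v):
    """Matrix-vector product A v; only this product is needed, never A itself."""
    return [sum(a_ij * v_j for a_ij, v_j in zip(row, v)) for row in A]

rng = random.Random(0)   # hypothetical seed
estimate = hutchinson_trace(matvec, dim=3, num_samples=20000, rng=rng)
exact = sum(A[i][i] for i in range(3))
print(estimate, exact)
```

The estimate is unbiased but noisy; averaging over many probes, or over many data points as in the data set loop, reduces the variance.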

Calculate the likelihood of a data set:

batch_size = 32 #@param {'type':'integer'}

dataset = MNIST('.', train=False, transform=transforms.ToTensor(), download=True)
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)

ckpt = torch.load('ckpt.pth', map_location=device)
score_model.load_state_dict(ckpt)

all_bpds = 0.
all_items = 0
try:
  tqdm_data = tqdm.notebook.tqdm(data_loader)
  for x, _ in tqdm_data:
    x = x.to(device)
    # uniform dequantization
    x = (x * 255. + torch.rand_like(x)) / 256.
    _, bpd = ode_likelihood(x, score_model, marginal_prob_std_fn,
                            diffusion_coeff_fn,
                            x.shape[0], device=device, eps=1e-5)
    all_bpds += bpd.sum()
    all_items += bpd.shape[0]
    tqdm_data.set_description("Average bits/dim: {:5f}".format(all_bpds / all_items))

except KeyboardInterrupt:
  # Remove the error message when interrupted by keyboard or GUI.
  pass

8. Summary of my own experience:

(1) A score-based diffusion model built on stochastic differential equations requires a time-dependent score-based neural network;

(2) The forward function of the time-dependent score-based network takes the perturbed x and the time t as input and outputs the score. This differs from a traditional diffusion model mainly in the output: there the network predicts the denoised x or the noise;

(3) The forward function of the time-dependent score-based network relies on several important support functions. GaussianFourierProjection takes the time t and outputs a Gaussian random feature vector, so that t can be merged into the features of x. marginal_prob_std computes the standard deviation at time t and is used to normalize the score output by the network. This means that a score-based diffusion model requires rewriting the model architecture;

(4) The loss function of the score-based diffusion model is very simple: loss = torch.mean(torch.sum((score * std[:, None, None, None] + z)**2, dim=(1,2,3)));

(5) There are several ways to sample from the scores output by the neural network, including the Euler-Maruyama sampler, the predictor-corrector sampler, and the ODE numerical solver;

(6) Every sampler needs the SDE to be set up first. An important function here is diffusion_coeff_fn, which computes the diffusion coefficient of the SDE;

(7) Each sampler has a fixed form and can be used directly.

To close, I have not covered the principles of score-based diffusion models here. Many blogs, public accounts, and videos already explain them in detail, including full derivations of the formulas. I am also not a mathematics major and only half understand much of the underlying mathematics, so I will not risk misleading anyone. Please consult the relevant material.

Code for the above part:

Extraction code: m5o9

If you have thoughts on the principles or would like a simpler explanation of them, leave a comment, and I will consider writing a dedicated post on the topic.