
GIDD Sampling

discrete_diffusion.sampling.gidd

GIDD sampler for hybrid diffusion models.

GIDDSampler

Bases: Sampler

Sampler for GIDD (Generalized Interpolating Discrete Diffusion) models.

Source code in src/discrete_diffusion/sampling/gidd.py
class GIDDSampler(Sampler):
  """Sampler for GIDD (Generalized Iterative Discrete Diffusion) models."""

  def __init__(self, config, forward_process=None):
    # forward_process is accepted for interface parity with other samplers
    # but is not used by this sampler.
    self.config = config

  def compute_posterior(self, model, z_t, t, s):
    """Compute posterior q(z_s | z_t, x_0) for GIDD.

    Args:
      model: The GIDD model instance.
      z_t: Current noisy samples at time t.
      t: Current timestep.
      s: Next (less noisy) timestep.

    Returns:
      Samples from the posterior distribution.
    """
    # Get model prediction: logits over the clean tokens x_0
    sigma_t = model._sigma_from_alphat(model._loglinear.alpha_t(t))
    sigma_t = model._process_sigma(sigma_t)
    logits = model.backbone(z_t, sigma_t)

    # Mask out the mask token from predictions
    logits = logits.clone()
    logits[..., model.mask_id] = model.neg_infinity
    probs = logits.softmax(-1)

    # Marginals at s and t under the hybrid noise, given the model's
    # predicted x_0 distribution
    q_s = model.hybrid_noise.probs_at_t(probs, s)
    q_t = model.hybrid_noise.probs_at_t(probs, t)
    # Probability mass the marginal at t assigns to the observed token z_t
    q_zt = q_t.gather(-1, z_t.unsqueeze(-1))

    alpha_t, beta_pi_t = model.hybrid_noise.get_alpha_betapi(t)
    alpha_s, beta_pi_s = model.hybrid_noise.get_alpha_betapi(s)

    # Coefficients of the transition kernel q(z_t | z_s):
    # alpha_{t|s} = alpha_t / alpha_s, beta_pi_{t|s} = beta_pi_t - alpha_{t|s} * beta_pi_s
    alpha_ts = alpha_t / alpha_s
    beta_pi_ts = beta_pi_t - alpha_ts * beta_pi_s

    # One-hot of the observed token plus beta_pi_{t|s} evaluated at z_t,
    # giving q(z_t | z_s) as a function of z_s
    vz_t = F.one_hot(z_t, num_classes=model.vocab_size)
    beta_pi_ts_at_zt = beta_pi_ts.unsqueeze(1).expand_as(vz_t).gather(
      -1, z_t.unsqueeze(-1))
    q_ts = alpha_ts * vz_t + beta_pi_ts_at_zt

    # Posterior via Bayes' rule:
    # q(z_s | z_t, x_0) = q(z_t | z_s) * q(z_s | x_0) / q(z_t | x_0)
    q_st = q_ts * q_s / q_zt

    # Optionally zero out probabilities below min_p and renormalize
    min_p = getattr(self.config.sampling, 'min_p', 0.0)
    if min_p > 0.0:
      is_small = (q_st < min_p).float()
      q_st = (1 - is_small) * q_st
      q_st = q_st / q_st.sum(-1, keepdim=True)

    return sample_categorical(q_st)

  @torch.no_grad()
  def generate(self, model, *, num_samples, num_steps, eps, inject_bos):
    """Generate samples using GIDD reverse diffusion process.

    Args:
      model: The GIDD model instance.
      num_samples: Number of samples to generate.
      num_steps: Number of denoising steps; falls back to
        config.sampling.steps when None.
      eps: Minimum timestep value (epsilon); falls back to
        config.algo.t_eps when None.
      inject_bos: Whether to inject the BOS token at position 0.

    Returns:
      Generated token sequences of shape [num_samples, num_tokens].
    """
    if num_steps is None:
      num_steps = self.config.sampling.steps
    if eps is None:
      eps = getattr(self.config.algo, 't_eps', 1e-4)

    # Sample the initial sequence from the hybrid-noise prior
    z_t = model.hybrid_noise.sample_prior(
      (num_samples, model.num_tokens))
    if inject_bos:
      z_t[:, 0] = model.tokenizer.bos_token_id

    # Decreasing timestep schedule from 1 - eps to eps (num_steps + 1 endpoints)
    timesteps = torch.linspace(
      1 - eps, eps, num_steps + 1, device=model.device)

    # Iteratively denoise
    for i in range(num_steps):
      t = timesteps[i] * torch.ones(
        num_samples, device=model.device)
      s = timesteps[i + 1] * torch.ones(
        num_samples, device=model.device)

      z_t = self.compute_posterior(model, z_t, t, s)

    return z_t
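
A minimal usage sketch, not part of the documented API: model is assumed to be a trained GIDD model exposing the attributes referenced above (hybrid_noise, tokenizer, num_tokens, device), and config a run config with a sampling section.

from discrete_diffusion.sampling.gidd import GIDDSampler

# Assumed available for this sketch: a trained GIDD model and its run config.
sampler = GIDDSampler(config)
samples = sampler.generate(
  model,
  num_samples=4,
  num_steps=128,
  eps=1e-4,
  inject_bos=True,
)
# Decoding assumes an HF-style tokenizer attached to the model.
texts = model.tokenizer.batch_decode(samples)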

compute_posterior(model, z_t, t, s)

Compute the posterior q(z_s | z_t, x_0) for GIDD and sample from it.

Parameters:

  model: The GIDD model instance. (required)
  z_t: Current noisy samples at time t. (required)
  t: Current timestep. (required)
  s: Next (less noisy) timestep. (required)

Returns:

  Samples from the posterior distribution.
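
The posterior computed here is Bayes' rule applied to the hybrid-noise marginals. A sketch of the relations the code implements, in the code's own notation (this rendering is our reconstruction, not taken from the library docs):

\[
q(z_s \mid z_t, x_0) = \frac{q(z_t \mid z_s)\, q(z_s \mid x_0)}{q(z_t \mid x_0)},
\qquad
q(z_t \mid z_s) = \alpha_{t\mid s}\,\mathbf{1}[z_t = z_s] + (\beta\pi)_{t\mid s}(z_t),
\]
\[
\alpha_{t\mid s} = \frac{\alpha_t}{\alpha_s},
\qquad
(\beta\pi)_{t\mid s} = (\beta\pi)_t - \alpha_{t\mid s}\,(\beta\pi)_s,
\qquad
q_t(z \mid x_0) = \alpha_t\, x_0(z) + (\beta\pi)_t(z).
\]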

Source code in src/discrete_diffusion/sampling/gidd.py
def compute_posterior(self, model, z_t, t, s):
  """Compute posterior q(z_s | z_t, x_0) for GIDD.

  Args:
    model: The GIDD model instance.
    z_t: Current noisy samples at time t.
    t: Current timestep.
    s: Next (less noisy) timestep.

  Returns:
    Samples from the posterior distribution.
  """
  # Get model prediction: logits over the clean tokens x_0
  sigma_t = model._sigma_from_alphat(model._loglinear.alpha_t(t))
  sigma_t = model._process_sigma(sigma_t)
  logits = model.backbone(z_t, sigma_t)

  # Mask out the mask token from predictions
  logits = logits.clone()
  logits[..., model.mask_id] = model.neg_infinity
  probs = logits.softmax(-1)

  # Marginals at s and t under the hybrid noise, given the model's
  # predicted x_0 distribution
  q_s = model.hybrid_noise.probs_at_t(probs, s)
  q_t = model.hybrid_noise.probs_at_t(probs, t)
  # Probability mass the marginal at t assigns to the observed token z_t
  q_zt = q_t.gather(-1, z_t.unsqueeze(-1))

  alpha_t, beta_pi_t = model.hybrid_noise.get_alpha_betapi(t)
  alpha_s, beta_pi_s = model.hybrid_noise.get_alpha_betapi(s)

  # Coefficients of the transition kernel q(z_t | z_s):
  # alpha_{t|s} = alpha_t / alpha_s, beta_pi_{t|s} = beta_pi_t - alpha_{t|s} * beta_pi_s
  alpha_ts = alpha_t / alpha_s
  beta_pi_ts = beta_pi_t - alpha_ts * beta_pi_s

  # One-hot of the observed token plus beta_pi_{t|s} evaluated at z_t,
  # giving q(z_t | z_s) as a function of z_s
  vz_t = F.one_hot(z_t, num_classes=model.vocab_size)
  beta_pi_ts_at_zt = beta_pi_ts.unsqueeze(1).expand_as(vz_t).gather(
    -1, z_t.unsqueeze(-1))
  q_ts = alpha_ts * vz_t + beta_pi_ts_at_zt

  # Posterior via Bayes' rule:
  # q(z_s | z_t, x_0) = q(z_t | z_s) * q(z_s | x_0) / q(z_t | x_0)
  q_st = q_ts * q_s / q_zt

  # Optionally zero out probabilities below min_p and renormalize
  min_p = getattr(self.config.sampling, 'min_p', 0.0)
  if min_p > 0.0:
    is_small = (q_st < min_p).float()
    q_st = (1 - is_small) * q_st
    q_st = q_st / q_st.sum(-1, keepdim=True)

  return sample_categorical(q_st)
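
A toy numerical check of the algebra above on a single position, assuming a mask-only mixing distribution pi; every value is invented for illustration and only the formulas mirror compute_posterior:

import torch
import torch.nn.functional as F

# Vocab of size 4; index 3 plays the role of the mask token.
alpha_t, alpha_s = 0.3, 0.6                 # schedule values with alpha_s > alpha_t
pi = torch.tensor([0.0, 0.0, 0.0, 1.0])     # mask-only mixing distribution
beta_pi_t = (1 - alpha_t) * pi
beta_pi_s = (1 - alpha_s) * pi

x0 = torch.tensor([0.7, 0.2, 0.1, 0.0])     # model's predicted x_0 distribution
q_s = alpha_s * x0 + beta_pi_s              # marginal at time s
q_t = alpha_t * x0 + beta_pi_t              # marginal at time t

z_t = 3                                     # observed token: mask
alpha_ts = alpha_t / alpha_s
beta_pi_ts = beta_pi_t - alpha_ts * beta_pi_s
vz_t = F.one_hot(torch.tensor(z_t), num_classes=4)
q_ts = alpha_ts * vz_t + beta_pi_ts[z_t]    # q(z_t | z_s) as a function of z_s
q_st = q_ts * q_s / q_t[z_t]

print(q_st)        # approximately [0.3000, 0.0857, 0.0429, 0.5714]
print(q_st.sum())  # 1.0 up to floating point error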

generate(model, *, num_samples, num_steps, eps, inject_bos)

Generate samples using the GIDD reverse diffusion process.

Parameters:

  model: The GIDD model instance. (required)
  num_samples: Number of samples to generate. (required)
  num_steps: Number of denoising steps; falls back to config.sampling.steps when None. (required)
  eps: Minimum timestep value (epsilon); falls back to config.algo.t_eps when None. (required)
  inject_bos: Whether to inject the BOS token at position 0. (required)

Returns:

  Generated token sequences of shape [num_samples, num_tokens].
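
For intuition, the schedule built inside generate has num_steps + 1 endpoints that the denoising loop consumes pairwise as (t, s); a quick check with invented values:

import torch

# Mirrors the linspace in generate() for num_steps=4, eps=1e-4.
timesteps = torch.linspace(1 - 1e-4, 1e-4, 4 + 1)
print(timesteps)  # roughly [0.9999, 0.75, 0.5, 0.25, 0.0001]
# The loop then steps through (t, s) = (0.9999, 0.75), (0.75, 0.5), ...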

Source code in src/discrete_diffusion/sampling/gidd.py
@torch.no_grad()
def generate(self, model, *, num_samples, num_steps, eps, inject_bos):
  """Generate samples using GIDD reverse diffusion process.

  Args:
    model: The GIDD model instance.
    num_samples: Number of samples to generate.
    num_steps: Number of denoising steps; falls back to
      config.sampling.steps when None.
    eps: Minimum timestep value (epsilon); falls back to
      config.algo.t_eps when None.
    inject_bos: Whether to inject the BOS token at position 0.

  Returns:
    Generated token sequences of shape [num_samples, num_tokens].
  """
  if num_steps is None:
    num_steps = self.config.sampling.steps
  if eps is None:
    eps = getattr(self.config.algo, 't_eps', 1e-4)

  # Sample the initial sequence from the hybrid-noise prior
  z_t = model.hybrid_noise.sample_prior(
    (num_samples, model.num_tokens))
  if inject_bos:
    z_t[:, 0] = model.tokenizer.bos_token_id

  # Decreasing timestep schedule from 1 - eps to eps (num_steps + 1 endpoints)
  timesteps = torch.linspace(
    1 - eps, eps, num_steps + 1, device=model.device)

  # Iteratively denoise
  for i in range(num_steps):
    t = timesteps[i] * torch.ones(
      num_samples, device=model.device)
    s = timesteps[i + 1] * torch.ones(
      num_samples, device=model.device)

    z_t = self.compute_posterior(model, z_t, t, s)

  return z_t
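
When num_steps or eps is None, generate falls back to the config. A sketch of the attributes it reads, using a hypothetical SimpleNamespace stand-in for the real config object (which this page does not specify):

from types import SimpleNamespace

from discrete_diffusion.sampling.gidd import GIDDSampler

# Hypothetical minimal config exposing exactly the attributes the sampler
# reads: config.sampling.steps, config.sampling.min_p, config.algo.t_eps.
config = SimpleNamespace(
  sampling=SimpleNamespace(steps=256, min_p=0.0),
  algo=SimpleNamespace(t_eps=1e-4),
)
sampler = GIDDSampler(config)
# model is assumed to be a trained GIDD model, as in the earlier sketch.
samples = sampler.generate(
  model, num_samples=2, num_steps=None, eps=None, inject_bos=False)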