Skip to content

UDLM

discrete_diffusion.algorithms.udlm

UDLM

Bases: Diffusion

Uniform Discrete Latent Model (UDLM).

  • Forward process: with probability (1 - alpha_t), replace the token with one drawn uniformly from the vocabulary; otherwise keep it unchanged.
  • Parameterization: reuse 'subs' head; logits are turned into log-probabilities with log_softmax and no special mask handling.
Source code in src/discrete_diffusion/algorithms/udlm.py
class UDLM(trainer_base.Diffusion):
  """Uniform Discrete Latent Model (UDLM).

  - Forward process: with prob (1 - alpha_t), replace token with a uniform token
    over the vocabulary; otherwise keep it unchanged.
  - Parameterization: reuse 'subs' head; logits are turned into log-probabilities
    with log_softmax and no special mask handling.
  """

  def __init__(self, config, tokenizer):
    """Register the uniform limiting distribution and build the forward process.

    Args:
      config: Configuration object. `config.algo.zero_recon_loss` and
        `config.algo.forward_process` are read here when present.
      tokenizer: Tokenizer passed to the base class and to the forward process.

    Raises:
      ValueError: If `_validate_configuration` rejects the configuration.
    """
    # Local import: only the fallback diagnostics below need it.
    import logging

    super().__init__(config, tokenizer)
    # Limiting distribution π(y) = 1 / V for all tokens y
    self.register_buffer(
      'limiting_distribution',
      torch.full((self.vocab_size,), 1.0 / float(self.vocab_size)))
    # When True (the default), nll_per_token returns only the diffusion term;
    # when False, the reconstruction loss is added on top of it.
    self.zero_recon_loss = getattr(config.algo, 'zero_recon_loss', True)
    self._validate_configuration()

    # Build a forward process using Hydra instantiation. If the config doesn't
    # specify one, default UDLM to the 'uniform' forward process to match
    # semantics. This stays best-effort (a bad config never aborts
    # construction), but every fallback other than "nothing configured" is
    # logged so that a misconfiguration is visible.
    logger = logging.getLogger(__name__)
    fp = None
    try:
      fp_config = getattr(self.config.algo, 'forward_process', None)
      if fp_config is not None:
        fp = hydra.utils.instantiate(
          fp_config,
          tokenizer=self.tokenizer,
          schedule=self.noise
        )
    except Exception:
      logger.warning(
        'UDLM could not instantiate algo.forward_process; '
        'falling back to UniformForwardProcess.',
        exc_info=True)
      fp = None

    if not isinstance(fp, UniformForwardProcess):
      if fp is not None:
        logger.warning(
          'UDLM ignores configured forward process of type %s; '
          'using UniformForwardProcess instead.',
          type(fp).__name__)
      fp = UniformForwardProcess(
        tokenizer=self.tokenizer, schedule=self.noise, name='uniform')
    self._forward_process = fp

  def _validate_configuration(self):
    """Run the base-class checks, then enforce the UDLM-specific constraints.

    Raises:
      ValueError: If time conditioning is enabled, or if `self.noise` is not a
        `LogLinear` schedule.
    """
    super()._validate_configuration()
    # UDLM uses no time-conditioning by default and subs parameterization
    if self.time_conditioning:
      raise ValueError('UDLM expects algo.time_conditioning=False')
    # Only log-linear noise is supported for UDLM currently
    # This constraint exists because UDLM's loss computation hardcodes the log-linear
    # schedule form (see nll_per_token method). Other schedules may be supported in the future.
    if not isinstance(self.noise, LogLinear):
      raise ValueError(
        'UDLM currently supports only LogLinear noise schedule. '
        'Set config.algo.noise_schedule.name=log_linear')

  def prior_sample(self, *batch_dims):
    """Draw int64 token ids uniformly from [0, vocab_size).

    Args:
      *batch_dims: Dimensions of the returned tensor.

    Returns:
      An int64 tensor of shape `batch_dims` on `self.device`.
    """
    # Uniform prior over [0, vocab_size)
    return torch.randint(
      low=0,
      high=self.vocab_size,
      size=batch_dims,
      device=self.device,
      dtype=torch.int64,
    )

  def q_xt(self, x, t, sampling_eps_min=None, sampling_eps_max=None):
    """Corrupt `x` at time `t` with the configured forward process.

    Args:
      x: Clean token ids; indexed as `x[:, 0]`, so at least 2-D.
      t: Time argument handed to the forward process unchanged.
      sampling_eps_min: Unused; discarded immediately.
      sampling_eps_max: Unused; discarded immediately.

    Returns:
      The noised tokens. If the forward process returns a tuple or list, only
      its first element is returned.
    """
    del sampling_eps_min, sampling_eps_max
    # Route through the forward-process registry (uniform replacement).
    out = self._forward_process(x, t)
    xt = out[0] if isinstance(out, (tuple, list)) else out
    if getattr(self, 'ignore_bos', False):
      # Keep the first position clean. This writes into the tensor returned by
      # the forward process in place.
      xt[:, 0] = x[:, 0]
    return xt

  def _process_model_output(self, model_output, xt, sigma):
    """Convert raw logits to log-probabilities over the last dimension.

    `xt` and `sigma` are accepted but not used.
    """
    # No mask handling; UDLM uses plain log-probabilities over vocab
    return torch.log_softmax(model_output, dim=-1)

  def nll_per_token(self, log_x_theta, xt, x0, alpha_t, dalpha_t, low_var=False):
    """Compute the per-token continuous-time UDLM loss.

    Args:
      log_x_theta: Model log-probabilities, shape (B, L, V).
      xt: Noised token ids, shape (B, L).
      x0: Clean token ids, shape (B, L).
      alpha_t: Schedule value, shape (B, 1), assumed to follow
        alpha_t = 1 - (1 - eps) * t.
      dalpha_t: Unused; the derivative is hardcoded to -1 below.
      low_var: Unused.

    Returns:
      A (B, L) tensor: the diffusion loss alone when `self.zero_recon_loss` is
      true, otherwise the diffusion loss plus `self._reconstruction_loss(x0)`.

    Raises:
      ValueError: If `log_x_theta` is not 3-D, or if its last dimension differs
        from `self.vocab_size`.
    """
    del low_var, dalpha_t
    # Shapes
    #  log_x_theta: (B, L, V)
    #  xt, x0: (B, L)
    #  alpha_t: (B, 1) from our LogLinear schedule: alpha_t = 1 - (1 - eps) * t
    _, _, vocab_size = log_x_theta.shape
    # x_bar (built from one-hot x0) and x_bar_theta (built from the logits) are
    # combined elementwise, so both must span the same vocabulary.
    if vocab_size != self.vocab_size:
      raise ValueError(
        f'log_x_theta has vocabulary dimension {vocab_size}, '
        f'expected {self.vocab_size}')

    # Hardcode loglinear continuous-time forms (match guidance repo):
    #   alpha_t = 1 - t
    #   alpha_t' = -1
    # Recover t from our alpha_t definition: alpha_t = 1 - (1 - eps) * t
    eps = getattr(self.noise, 'eps', 1e-3)
    t = (1 - alpha_t.to(log_x_theta.dtype)) / (1 - eps)
    alpha_t_prime = -1.
    alpha_t = 1. - t[..., None]  # B, 1, 1
    x_bar = vocab_size * alpha_t * F.one_hot(x0, vocab_size).float() + 1 - alpha_t
    x_bar_theta = vocab_size * alpha_t * log_x_theta.exp() + 1 - alpha_t

    # α_t' / (N*α_t) with α_t' = -1
    coeff = alpha_t_prime / (vocab_size * alpha_t)

    # Term 1: indices where z_t = 1
    x_bar_zt = torch.gather(x_bar, -1, xt[..., None])  # (B, L, 1)
    x_bar_theta_zt = torch.gather(x_bar_theta, -1, xt[..., None])  # (B, L, 1)
    term1 = (vocab_size / x_bar_zt) - (vocab_size / x_bar_theta_zt)  # (B, L, 1)

    # Term 2: indices where z_t = 0
    term2 = (
      (x_bar / x_bar_zt) * (
        x_bar_theta_zt.log() - x_bar_theta.log() +
        x_bar.log() - x_bar_zt.log()
      )
    ).sum(dim=-1, keepdim=True)  # (B, L, 1)

    diffusion_loss = (coeff * (term1 - term2)).squeeze(-1)  # (B, L)

    # Optionally include reconstruction term at t=0 based on config
    if self.zero_recon_loss:
      # For UDLM with log-linear schedule, we only return the diffusion loss.
      # This is the correct formulation for continuous-time UDLM (equivalent to
      # zero_recon_loss=True in the discrete-diffusion-guidance implementation).
      return diffusion_loss
    # Include reconstruction loss (used for discrete-time or other variants)
    reconstruction_loss = self._reconstruction_loss(x0)  # (B, L)
    return diffusion_loss + reconstruction_loss