Skip to content

Uniform Sampling

discrete_diffusion.sampling.uniform

Uniform (UDLM) sampler abstraction.

UniformSampler

Bases: Sampler

Sampler whose updates mirror UDLM's uniform transition.

Source code in src/discrete_diffusion/sampling/uniform.py
class UniformSampler(Sampler):
  """Sampler whose updates mirror UDLM's uniform transition.

  `generate` draws an initial sample from `model.prior_sample`, applies
  `compute_posterior` once per step on a time grid running from 1 down to
  `eps`, and ends with a single noise-removal step.
  """

  def __init__(self, config, forward_process=None):
    """Store the sampling configuration.

    Args:
      config: Configuration object. This class reads
        `config.sampling.use_float64`, `config.sampling.steps`,
        `config.sampling.inject_bos` and `config.sampling.predictor`.
      forward_process: Accepted but not referenced anywhere in this class
        (presumably kept for signature parity with other samplers --
        confirm against the `Sampler` base class).
    """
    self.config = config

  def compute_posterior(self, model, x, t, dt, p_x0=None,
                        noise_removal_step=False):
    """Sample the next state `xs` from the posterior given the state `x`.

    Args:
      model: Object providing `noise.alpha_t`, `_sigma_from_alphat`,
        `forward`, `vocab_size` and `limiting_distribution`.
      x: Integer token ids of the current state; used to build a one-hot
        encoding and to gather from `p_x0`.
        (assumed shape (batch, seq) -- TODO confirm)
      t: Current time passed to `model.noise.alpha_t`; `generate` passes a
        (batch, 1) tensor.
      dt: Step size; the target time is `t - dt`. Not read when
        `noise_removal_step` is True, so it may be None in that case.
      p_x0: Optional probabilities over the vocabulary to reuse. When None
        they are computed as `exp(model.forward(...))`.
      noise_removal_step: When True, `alpha_s` is set to all ones instead
        of being evaluated at `t - dt`.

    Returns:
      Tuple `(p_x0, xs)`: the probabilities that were used (given or
      freshly computed) and the result of `sample_categorical` on the
      posterior.
    """
    alpha_t = model.noise.alpha_t(t)
    if noise_removal_step:
      alpha_s = torch.ones_like(alpha_t)
    else:
      alpha_s = model.noise.alpha_t(t - dt)

    if p_x0 is None:
      sigma_t = model._sigma_from_alphat(alpha_t)
      # The model output is treated as log-probabilities: it is
      # exponentiated below to obtain p_x0.
      log_p_x0 = model.forward(xt=x, sigma=sigma_t)
      if self.config.sampling.use_float64:
        log_p_x0 = log_p_x0.to(torch.float64)
      p_x0 = log_p_x0.exp()

    V = model.vocab_size
    # A trailing axis is added so the alphas broadcast against the
    # vocabulary axis of p_x0 / the one-hot encoding.
    alpha_t3 = alpha_t[..., None]
    alpha_s3 = alpha_s[..., None]
    alpha_ts = alpha_t3 / alpha_s3
    xt_one_hot = F.one_hot(x, V)
    limiting = model.limiting_distribution.view(1, 1, -1)

    numerator = (
      (alpha_t3 * V * p_x0 * xt_one_hot)
      + ((alpha_ts - alpha_t3) * xt_one_hot)
      + ((alpha_s3 - alpha_t3) * p_x0)
      + ((1 - alpha_ts) * (1 - alpha_s3) * limiting)
    )
    # The denominator only needs p_x0 evaluated at the current token.
    denom = (
      (alpha_t3 * V * torch.gather(p_x0, -1, x[..., None]))
      + (1 - alpha_t3)
    )
    q_xs = numerator / denom

    xs = sample_categorical(q_xs)
    return p_x0, xs

  @torch.no_grad()
  def generate(self, model, *, num_samples, num_steps, eps, inject_bos):
    """Generate samples by iterating posterior updates from t=1 to t=eps.

    Args:
      model: Object providing `prior_sample`, `num_tokens`, `device`,
        `time_conditioning`, `tokenizer.bos_token_id` and everything
        `compute_posterior` needs.
      num_samples: Number of samples requested from `model.prior_sample`.
      num_steps: Number of update steps; None falls back to
        `config.sampling.steps`.
      eps: Final time of the grid; the grid is `linspace(1, eps,
        num_steps + 1)`.
      inject_bos: Whether to overwrite position 0 of the initial sample
        with the BOS token id; None falls back to
        `config.sampling.inject_bos`.

    Returns:
      The state returned by the final noise-removal step.

    Raises:
      ValueError: If `config.sampling.predictor` is neither 'ddpm' nor
        'ddpm_cache'.
    """
    # The predictor is loop-invariant: validate it once, before any
    # sampling work, instead of on every iteration.
    predictor = self.config.sampling.predictor
    if predictor not in ('ddpm', 'ddpm_cache'):
      raise ValueError(
        f'Uniform sampler only supports ddpm predictors, got {predictor}')

    if num_steps is None:
      num_steps = self.config.sampling.steps
    x = model.prior_sample(num_samples, model.num_tokens)
    if inject_bos is None:
      inject_bos = self.config.sampling.inject_bos
    if inject_bos:
      # NOTE(review): BOS is only written into the initial sample; later
      # updates may resample position 0. Confirm this is intended.
      x[:, 0] = model.tokenizer.bos_token_id

    timesteps = torch.linspace(
      1, eps, num_steps + 1, device=model.device)
    dt = (1 - eps) / num_steps
    p_x0_cache = None

    for i in range(num_steps):
      t = timesteps[i] * torch.ones(
        x.shape[0], 1, device=model.device)
      if predictor == 'ddpm':
        _, x = self.compute_posterior(
          model=model, x=x, t=t, dt=dt, p_x0=None)
      else:  # 'ddpm_cache'
        p_x0_cache, x_next = self.compute_posterior(
          model=model, x=x, t=t, dt=dt, p_x0=p_x0_cache)
        # The cached p_x0 is reusable only if no token changed and the
        # model does not depend on time. Use an exact comparison:
        # torch.allclose applies a relative tolerance (default
        # rtol=1e-5), so token ids >= 1e5 that differ by 1 would
        # compare as "close" and a stale cache would be kept.
        if (not torch.equal(x_next, x)
            or model.time_conditioning):
          p_x0_cache = None
        x = x_next

    # Final noise-removal step at the last grid time (alpha_s forced to
    # one inside compute_posterior, so dt is not needed).
    t0 = timesteps[-1] * torch.ones(x.shape[0], 1,
                                    device=model.device)
    _, x = self.compute_posterior(
      model=model, x=x, t=t0, dt=None, p_x0=p_x0_cache,
      noise_removal_step=True)
    return x