Skip to content

Absorbing Process

discrete_diffusion.forward_process.absorbing

Absorbing-state forward processes.

AbsorbingForwardProcess

Bases: ForwardProcess

Absorbing-state forward process.

Replaces each token with the mask token independently with probability (1 - alpha_t). Returns the noised ids together with the mask probability p_mask, which is reshaped to (-1, 1) so that it broadcasts across token positions.

Source code in src/discrete_diffusion/forward_process/absorbing.py
class AbsorbingForwardProcess(ForwardProcess):
  """Forward (noising) process with a single absorbing mask state.

  Every entry of the input is independently swapped for the mask token with
  probability `1 - alpha_t`, where `alpha_t` is read from the noise schedule.
  Entries that are not selected are passed through unchanged.
  """

  def __init__(self, tokenizer, schedule: NoiseSchedule, name: str | None = None) -> None:
    """Initialise the base process and cache the mask token id.

    Args:
      tokenizer: Forwarded to the base class and to `_mask_token_id`, whose
        result is stored as `self.mask_id`.
      schedule: Noise schedule forwarded to the base class; `forward` calls
        its `alpha_t` method.
      name: Optional name forwarded to the base class.
    """
    super().__init__(tokenizer=tokenizer, schedule=schedule, name=name)
    self.mask_id = _mask_token_id(tokenizer)

  @torch.no_grad()
  def forward(self, input_ids: torch.Tensor, t: torch.Tensor):
    """Draw a noised copy of `input_ids` at time `t`.

    Args:
      input_ids: Token ids to corrupt (presumably `(batch, seq_len)` -- TODO
        confirm against callers).
      t: Time values handed to `schedule.alpha_t`.

    Returns:
      A tuple `(xt, p_mask)`: `xt` holds `input_ids` with the sampled
      entries replaced by `self.mask_id`; `p_mask` is the float32 masking
      probability `1 - alpha_t`, reshaped to `(-1, 1)` so that it broadcasts
      across the last dimension of `input_ids`.
    """
    keep_prob = self.schedule.alpha_t(t).view(-1, 1)
    p_mask = (1.0 - keep_prob).float()

    # One uniform draw per element; a draw below p_mask selects that element
    # for masking. The comparison already produces a boolean tensor.
    uniform = torch.rand_like(input_ids, dtype=torch.float32)
    selected = uniform < p_mask

    # 0-dim fill value sharing dtype and device with the ids.
    mask_value = input_ids.new_tensor(self.mask_id)
    xt = torch.where(selected, mask_value, input_ids)
    return xt, p_mask