Skip to content

CANDI Hybrid Process

discrete_diffusion.forward_process.candi_hybrid

Hybrid Noising process for CANDI.

HybridForwardCANDI

Bases: ForwardProcess

Hybrid noising kernel for CANDI. https://arxiv.org/pdf/2510.22510

Selects each position with probability (1 - alpha_t) and adds Gaussian noise with standard deviation sigma_t to its one-hot encoding. All other positions are left unchanged.

Source code in src/discrete_diffusion/forward_process/candi_hybrid.py
class HybridForwardCANDI(ForwardProcess):
    """Hybrid noising kernel for CANDI.
    https://arxiv.org/pdf/2510.22510

    Selects positions with probability `(1 - alpha_t)` and adds Gaussian noise
    with standard deviation `sigma_t` to their one-hot encodings.
    Leaves the other positions as clean one-hot vectors.

    Note: a selected position is only treated as corrupted when the id drawn
    for it in `_discrete_noising` differs from the original id, so the
    effective corruption probability is slightly below `(1 - alpha_t)`.
    """

    def __init__(
        self, tokenizer, schedule: NoiseSchedule, name: str | None = None
    ) -> None:
        """Store tokenizer-derived sizes and check the schedule's interface.

        Args:
            tokenizer: Tokenizer; `len(tokenizer)` is used as the vocabulary size.
            schedule: Noise schedule; must expose `r_t` and `sigma_t`.
            name: Optional name forwarded to the base class.
        """
        assert hasattr(schedule, "r_t") and hasattr(schedule, "sigma_t"), (
            "CANDI schedule must implement r_t and sigma_t methods."
        )

        super().__init__(tokenizer=tokenizer, schedule=schedule, name=name)
        self.mask_id = _mask_token_id(tokenizer)
        self.vocab_size = len(tokenizer)
        # One extra id beyond the tokenizer's range is included when drawing
        # replacement ids in `_discrete_noising`.
        self._corruption_vocab_size = len(tokenizer) + 1

    def _discrete_noising(self, x, alpha_t):
        """Computes the noisy discrete sample.

        Each entry of `x` is replaced, with probability `1 - alpha_t`, by an id
        drawn uniformly from `[0, self._corruption_vocab_size)`; the remaining
        entries are returned unchanged.
        """
        p_mask = (1.0 - alpha_t).to(dtype=torch.float32)
        move_indices = torch.rand(*x.shape, device=x.device) < p_mask
        uniform_tensor = torch.randint(0, self._corruption_vocab_size, x.shape, device=x.device)
        xt = torch.where(move_indices, uniform_tensor, x)
        return xt

    @torch.no_grad()
    def forward(self, input_ids: torch.Tensor, t: torch.Tensor):
        """Applies hybrid noising process to input_ids at time t. Implements Equation 11 of CANDI paper.

        Args:
            input_ids: Integer token ids; the broadcasting below expects shape
                (batch, seq_len).
            t: Per-example diffusion times, one value per batch element.

        Returns:
            A dict with:
                "xt": float tensor (batch, seq_len, vocab_size); clean one-hot
                    rows at revealed positions, one-hot plus Gaussian noise
                    elsewhere.
                "reveal_mask": float tensor (batch, seq_len); 1.0 where the
                    position kept its original id, 0.0 where it was corrupted.
                "continuous_noise": sigma_t, shape (batch,).
                "discrete_noise": 1 - alpha_t, shape (batch,).
                "alpha_t": alpha_t, shape (batch, 1).
                "dalpha_t": time derivative of alpha_t, shape (batch, 1).
        """
        alpha_t = self.schedule.alpha_t(t).view(-1, 1)
        dalpha_t = self.schedule.alpha_prime_t(t).view(-1, 1)
        sigma_t = self.schedule.sigma_t(t).view(-1, 1, 1)

        # The discrete draw only decides WHICH positions get corrupted; the
        # sampled replacement ids themselves are not used any further.
        disc_xt = self._discrete_noising(input_ids, alpha_t)
        is_revealed = disc_xt == input_ids
        reveal_mask = is_revealed.float()

        X_0 = F.one_hot(input_ids, num_classes=self.vocab_size).to(torch.float32)
        X_t_prime = X_0 + torch.randn_like(X_0) * sigma_t
        X_t = torch.where(is_revealed.unsqueeze(-1), X_0, X_t_prime)

        return {
            "xt": X_t,
            "reveal_mask": reveal_mask,
            # reshape(-1) rather than a bare squeeze(): squeeze() would also
            # drop the batch axis when the batch size is 1, yielding a 0-d tensor.
            "continuous_noise": sigma_t.reshape(-1),
            "discrete_noise": (1 - alpha_t).reshape(-1),
            "alpha_t": alpha_t,
            "dalpha_t": dalpha_t,
        }

forward(input_ids, t)

Applies hybrid noising process to input_ids at time t. Implements Equation 11 of CANDI paper.

Source code in src/discrete_diffusion/forward_process/candi_hybrid.py
@torch.no_grad()
def forward(self, input_ids: torch.Tensor, t: torch.Tensor):
    """Applies hybrid noising process to input_ids at time t. Implements Equation 11 of CANDI paper.

    Args:
        input_ids: Integer token ids; the broadcasting below expects shape
            (batch, seq_len).
        t: Per-example diffusion times, one value per batch element.

    Returns:
        A dict with:
            "xt": float tensor (batch, seq_len, vocab_size); clean one-hot
                rows at revealed positions, one-hot plus Gaussian noise
                elsewhere.
            "reveal_mask": float tensor (batch, seq_len); 1.0 where the
                position kept its original id, 0.0 where it was corrupted.
            "continuous_noise": sigma_t, shape (batch,).
            "discrete_noise": 1 - alpha_t, shape (batch,).
            "alpha_t": alpha_t, shape (batch, 1).
            "dalpha_t": time derivative of alpha_t, shape (batch, 1).
    """
    alpha_t = self.schedule.alpha_t(t).view(-1, 1)
    dalpha_t = self.schedule.alpha_prime_t(t).view(-1, 1)
    sigma_t = self.schedule.sigma_t(t).view(-1, 1, 1)

    # The discrete draw only decides WHICH positions get corrupted; the
    # sampled replacement ids themselves are not used any further.
    disc_xt = self._discrete_noising(input_ids, alpha_t)
    is_revealed = disc_xt == input_ids
    reveal_mask = is_revealed.float()

    X_0 = F.one_hot(input_ids, num_classes=self.vocab_size).to(torch.float32)
    X_t_prime = X_0 + torch.randn_like(X_0) * sigma_t
    X_t = torch.where(is_revealed.unsqueeze(-1), X_0, X_t_prime)

    return {
        "xt": X_t,
        "reveal_mask": reveal_mask,
        # reshape(-1) rather than a bare squeeze(): squeeze() would also
        # drop the batch axis when the batch size is 1, yielding a 0-d tensor.
        "continuous_noise": sigma_t.reshape(-1),
        "discrete_noise": (1 - alpha_t).reshape(-1),
        "alpha_t": alpha_t,
        "dalpha_t": dalpha_t,
    }