Base Sampling

discrete_diffusion.sampling.base

Sampler interface for discrete diffusion generation routines.

Sampler

Bases: ABC

Base interface defining hooks used by all samplers.

Samplers orchestrate the iterative generation process, managing the transition from noise to clean data.

Source code in src/discrete_diffusion/sampling/base.py
from abc import ABC, abstractmethod
from typing import Any, Optional


class Sampler(ABC):
  """Base interface defining hooks used by all samplers.

  Samplers orchestrate the iterative generation process, managing the
  transition from noise to clean data.
  """

  @abstractmethod
  def generate(self, model: Any, *, num_samples: int, num_steps: int, eps: float,
               inject_bos: bool) -> Any:
    """Generate new samples from the provided model.

    Args:
        model: The trained model to sample from.
        num_samples: Number of samples to generate.
        num_steps: Number of sampling steps.
        eps: Small epsilon for numerical stability or time bounds.
        inject_bos: Whether to inject a Beginning-Of-Sequence token.

    Returns:
        Tensor: Generated samples.
    """
    raise NotImplementedError

  def compute_posterior(self, x: Any, t: Any, dt: Any, p_x0_cache: Optional[Any]) -> Any:
    """Optional posterior computation hook for samplers that need incremental steps."""
    raise NotImplementedError

  def step_analytic(self, x: Any, t: Any, dt: Any) -> Any:
    """Optional analytic update hook for samplers that support closed-form steps."""
    raise NotImplementedError

  def denoise(self, x: Any, t: Any) -> Any:
    """Optional denoiser update hook for samplers that clean up predictions."""
    raise NotImplementedError
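Because generate is the only abstract method, a subclass must override it before it can be instantiated, while the three optional hooks keep raising NotImplementedError until a subclass overrides them. The sketch below re-declares the interface locally so it runs standalone; ConstantSampler and its all-zeros output are hypothetical placeholders, not part of the library.

```python
from abc import ABC, abstractmethod
from typing import Any, List, Optional


class Sampler(ABC):
    """Local stand-in mirroring the interface documented above."""

    @abstractmethod
    def generate(self, model: Any, *, num_samples: int, num_steps: int,
                 eps: float, inject_bos: bool) -> Any:
        raise NotImplementedError

    def compute_posterior(self, x: Any, t: Any, dt: Any,
                          p_x0_cache: Optional[Any]) -> Any:
        raise NotImplementedError

    def step_analytic(self, x: Any, t: Any, dt: Any) -> Any:
        raise NotImplementedError

    def denoise(self, x: Any, t: Any) -> Any:
        raise NotImplementedError


class ConstantSampler(Sampler):
    """Hypothetical sampler that ignores the model and returns zeros."""

    def __init__(self, seq_len: int = 4) -> None:
        self.seq_len = seq_len

    def generate(self, model: Any, *, num_samples: int, num_steps: int,
                 eps: float, inject_bos: bool) -> List[List[int]]:
        # A real sampler would run num_steps denoising iterations here.
        return [[0] * self.seq_len for _ in range(num_samples)]


# The base class cannot be instantiated because generate() is abstract.
try:
    Sampler()
except TypeError as err:
    print("abstract:", err)

sampler = ConstantSampler()
samples = sampler.generate(None, num_samples=2, num_steps=10, eps=1e-5,
                           inject_bos=False)
print(samples)  # [[0, 0, 0, 0], [0, 0, 0, 0]]

# Optional hooks that were not overridden still raise NotImplementedError.
try:
    sampler.denoise(samples, 0.0)
except NotImplementedError:
    print("denoise not implemented")
```

Because of the bare * in the signature, num_samples, num_steps, eps, and inject_bos must be passed by keyword; a positional call such as sampler.generate(None, 2, 10, 1e-5, False) raises TypeError.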

compute_posterior(x, t, dt, p_x0_cache)

Optional posterior computation hook for samplers that need incremental steps.
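The docstring does not say which posterior is computed. As one illustration, masked (absorbing-state) diffusion samplers often step from time t to s = t - dt by sampling from q(x_s | x_t, x_0), with x_0 replaced by the denoiser's prediction. The sketch below assumes a linear schedule alpha(t) = 1 - t, under which a masked token is revealed with probability dt / t and stays masked with probability s / t. MASK_ID, the predict_x0 callable, and the list-based layout are hypothetical stand-ins, and the library's own hook may differ.

```python
from typing import Callable, List, Optional, Tuple

MASK_ID = -1  # hypothetical id of the absorbing [MASK] state
Probs = List[List[float]]  # per-position distribution over a vocab of size V


def compute_posterior(
    x: List[int],
    t: float,
    dt: float,
    p_x0_cache: Optional[Probs],
    predict_x0: Callable[[List[int], float], Probs],
) -> Tuple[Probs, Probs]:
    """Posterior q(x_s | x_t, x_0) for masked diffusion with alpha(t) = 1 - t.

    Each returned row has V + 1 entries: V token probabilities followed by
    the probability of staying masked.
    """
    s = t - dt
    # Reuse the cached denoiser output if the caller supplies one.
    p_x0 = p_x0_cache if p_x0_cache is not None else predict_x0(x, t)
    posterior = []
    for token, probs in zip(x, p_x0):
        vocab_size = len(probs)
        if token == MASK_ID:
            # Reveal with probability dt / t, stay masked with probability s / t.
            row = [p * dt / t for p in probs] + [s / t]
        else:
            # Unmasked tokens are carried over unchanged.
            row = [0.0] * (vocab_size + 1)
            row[token] = 1.0
        posterior.append(row)
    return posterior, p_x0


def uniform_predictor(x: List[int], t: float) -> Probs:
    # Hypothetical denoiser: uniform over a 2-token vocabulary.
    return [[0.5, 0.5] for _ in x]


post, cache = compute_posterior([MASK_ID, 1], t=1.0, dt=0.25,
                                p_x0_cache=None, predict_x0=uniform_predictor)
print(post)  # [[0.125, 0.125, 0.75], [0.0, 1.0, 0.0]]
```

Returning p_x0 alongside the posterior would let a caller pass it back as p_x0_cache on the next step when no token changed, skipping a model call; this is one plausible reading of the p_x0_cache argument, not a documented contract.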

denoise(x, t)

Optional denoiser update hook for samplers that clean up predictions.
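A common final cleanup in masked-diffusion sampling is to replace any positions still masked after the last step with the denoiser's most likely token. The sketch below shows that greedy variant; MASK_ID and the predict_x0 callable are hypothetical, and this is only one possible behavior for denoise.

```python
from typing import Callable, List

MASK_ID = -1  # hypothetical id of the [MASK] token


def denoise(
    x: List[int],
    t: float,
    predict_x0: Callable[[List[int], float], List[List[float]]],
) -> List[int]:
    """Replace every remaining mask with the argmax of the predicted x_0."""
    p_x0 = predict_x0(x, t)
    out = []
    for token, probs in zip(x, p_x0):
        if token == MASK_ID:
            # Greedy choice: index of the highest-probability token.
            token = max(range(len(probs)), key=probs.__getitem__)
        out.append(token)
    return out


def toy_predictor(x: List[int], t: float) -> List[List[float]]:
    # Hypothetical denoiser that always favours token 2 out of 3.
    return [[0.1, 0.2, 0.7] for _ in x]


print(denoise([MASK_ID, 0, MASK_ID], t=1e-5, predict_x0=toy_predictor))
# [2, 0, 2]
```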

generate(model, *, num_samples, num_steps, eps, inject_bos)  [abstractmethod]

Generate new samples from the provided model.

Parameters:

Name         Type   Description                                            Default
model        Any    The trained model to sample from.                      required
num_samples  int    Number of samples to generate.                         required
num_steps    int    Number of sampling steps.                              required
eps          float  Small epsilon for numerical stability or time bounds.  required
inject_bos   bool   Whether to inject a Beginning-Of-Sequence token.       required

Returns:

Type   Description
Any    Generated samples (described in the docstring as a Tensor).
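The docstring describes eps only loosely. One common reading in continuous-time samplers is that eps is the smallest time the reverse process reaches, so the schedule never evaluates t = 0. Under that assumption, the sketch below shows how num_steps and eps could define the time grid, and how num_samples and inject_bos could shape the starting batch. MASK_ID, BOS_ID, and seq_len are hypothetical, and the library may handle these arguments differently.

```python
from typing import List

MASK_ID = -1  # hypothetical [MASK] token id
BOS_ID = 0    # hypothetical Beginning-Of-Sequence token id


def time_grid(num_steps: int, eps: float) -> List[float]:
    """num_steps + 1 evenly spaced times running from 1.0 down to eps."""
    return [1.0 - (1.0 - eps) * i / num_steps for i in range(num_steps + 1)]


def initial_batch(num_samples: int, seq_len: int,
                  inject_bos: bool) -> List[List[int]]:
    """Start every sample fully masked, optionally pinning position 0 to BOS."""
    batch = [[MASK_ID] * seq_len for _ in range(num_samples)]
    if inject_bos:
        for row in batch:
            row[0] = BOS_ID
    return batch


ts = time_grid(num_steps=4, eps=1e-3)
# Every step size dt is (1 - eps) / num_steps, up to floating-point rounding.
dts = [a - b for a, b in zip(ts, ts[1:])]
print(initial_batch(num_samples=2, seq_len=3, inject_bos=True))
# [[0, -1, -1], [0, -1, -1]]
```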

step_analytic(x, t, dt)

Optional analytic update hook for samplers that support closed-form steps.
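For masked diffusion, the probability that a masked token is revealed between t and s = t - dt has a closed form for any schedule alpha(t), the fraction of tokens still clean at time t: (alpha(s) - alpha(t)) / (1 - alpha(t)). The sketch below uses that formula for one ancestral step. The cosine schedule, MASK_ID, predict_x0, and the explicit rng argument are assumptions added for illustration; a method on a real sampler would presumably hold the model and RNG itself, and this is not necessarily what step_analytic does in this library.

```python
import math
import random
from typing import Callable, List

MASK_ID = -1  # hypothetical [MASK] token id


def alpha(t: float) -> float:
    """Hypothetical cosine schedule: alpha(0) = 1 (clean), alpha(1) ~ 0 (masked)."""
    return math.cos(math.pi * t / 2)


def step_analytic(
    x: List[int],
    t: float,
    dt: float,
    predict_x0: Callable[[List[int], float], List[List[float]]],
    rng: random.Random,
) -> List[int]:
    """One ancestral step from time t > 0 to s = t - dt for masked diffusion."""
    s = max(t - dt, 0.0)
    # Closed-form probability that a masked token is revealed in this step.
    # Requires t > 0, otherwise 1 - alpha(t) is zero.
    reveal = (alpha(s) - alpha(t)) / (1.0 - alpha(t))
    p_x0 = predict_x0(x, t)
    out = []
    for token, probs in zip(x, p_x0):
        if token == MASK_ID and rng.random() < reveal:
            # Draw the revealed token from the denoiser's prediction.
            token = rng.choices(range(len(probs)), weights=probs)[0]
        out.append(token)
    return out


def toy_predictor(x: List[int], t: float) -> List[List[float]]:
    # Hypothetical denoiser that always predicts token 1 out of 2.
    return [[0.0, 1.0] for _ in x]


# With dt == t the step lands on s = 0, so every mask is revealed.
print(step_analytic([MASK_ID, 0, MASK_ID], t=0.5, dt=0.5,
                    predict_x0=toy_predictor, rng=random.Random(0)))
# [1, 0, 1]
```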
