Metrics

discrete_diffusion.evaluations.metrics

BD3Metrics

Bases: Metrics

Extension of Metrics with BD3-specific variance tracking.

Source code in src/discrete_diffusion/evaluations/metrics.py
class BD3Metrics(Metrics):
  """Extension of Metrics with BD3-specific variance tracking."""

  def __init__(self, config) -> None:
    super().__init__()
    self.config = config
    self.block_size = getattr(config, 'block_size', config.model.length)
    self.nfes = NFEs()
    self.gen_entropy = NLL()
    self.gen_nfes: List[float] = []
    self.gen_entropies: List[float] = []
    self.gen_lengths: List[int] = []

    self.sampling_eps = config.training.sampling_eps
    self.clip_search_delta = getattr(config.algo, 'clip_search_delta', None)
    self.valid_vars: Dict[Tuple[float, float], List[torch.Tensor]] = {
      (self.sampling_eps, 1.0): []
    }
    if getattr(config.algo, 'var_min', None):
      self.init_valid_vars()

  def init_valid_vars(self):
    eps = self.sampling_eps
    if self.block_size > 1:
      # One variance bucket per (eps_min, eps_max) clipping window: each
      # search width slides across [0, 1] in steps of clip_search_delta.
      self.valid_vars = {(eps, 1): []}
      for width in self.config.algo.clip_search_widths:
        for i in torch.arange(0, 1 - width + self.clip_search_delta,
                              self.clip_search_delta):
          eps_min = torch.clamp(i, min=self.sampling_eps).item()
          eps_max = torch.clamp(i + width, min=self.sampling_eps).item()
          self.valid_vars[(eps_min, eps_max)] = []
    else:
      self.valid_vars = {
        (eps, 1): [],
        (1, 1): []
      }

  def update_train(self,
                   nll_sum: torch.Tensor,
                   num_tokens: torch.Tensor):
    self.train_nlls.update(nll_sum, num_tokens)

  def update_valid(self,
                   nll_sum: torch.Tensor,
                   num_tokens: torch.Tensor):
    self.valid_nlls.update(nll_sum, num_tokens)

  def to(self, *args, **kwargs):
    super().to(*args, **kwargs)
    self.nfes = self.nfes.to(*args, **kwargs)
    self.gen_entropy = self.gen_entropy.to(*args, **kwargs)

  def reset(self):
    super().reset()
    self.gen_nfes, self.gen_entropies, self.gen_lengths = [], [], []
    self.nfes.reset()
    self.gen_entropy.reset()
    if getattr(self.config.algo, 'var_min', None):
      self.init_valid_vars()

  @torch.no_grad()
  def record_entropy(self, tokens):
    for sample in tokens:
      entropy = _token_entropy(sample)
      self._record_entropy_value(entropy)

  def _record_entropy_value(self, entropy: float) -> None:
    self.sample_entropy.update(entropy)
    self.gen_entropies.append(entropy)
    self.gen_entropy.update(entropy, 1)
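
Below is a minimal usage sketch. The config here is a hypothetical stand-in built with SimpleNamespace that only carries the attributes read in __init__; the real project supplies its own, much richer config object, and the numeric values are illustrative.

from types import SimpleNamespace

import torch

from discrete_diffusion.evaluations.metrics import BD3Metrics

# Hypothetical minimal config; the real config carries many more fields.
config = SimpleNamespace(
  block_size=16,
  model=SimpleNamespace(length=1024),
  training=SimpleNamespace(sampling_eps=1e-3),
  algo=SimpleNamespace(clip_search_delta=None, var_min=None),
)

metrics = BD3Metrics(config)
# Feed the summed NLL of each batch together with its token count.
metrics.update_train(nll_sum=torch.tensor(1234.5),
                     num_tokens=torch.tensor(512.0))
metrics.update_valid(nll_sum=torch.tensor(987.0),
                     num_tokens=torch.tensor(480.0))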

BPD

Bases: NLL

Source code in src/discrete_diffusion/evaluations/metrics.py
class BPD(NLL):
  def compute(self) -> torch.Tensor:
    """Computes the bits per dimension.

    Returns:
      bpd
    """
    return self.mean_value / self.weight / LOG2

compute()

Computes the bits per dimension.

Returns:

  Tensor: bpd

Source code in src/discrete_diffusion/evaluations/metrics.py
def compute(self) -> torch.Tensor:
  """Computes the bits per dimension.

  Returns:
    bpd
  """
  return self.mean_value / self.weight / LOG2
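
Since BPD only rescales the accumulated NLL from nats to bits (assuming LOG2 is the natural log of 2, as the formula implies), a quick illustrative check looks like this:

import torch

from discrete_diffusion.evaluations.metrics import BPD

bpd = BPD()
# Summed NLL in nats over a batch, weighted by its token count.
bpd.update(torch.tensor(693.15), weight=torch.tensor(1000.0))
print(bpd.compute())  # 0.69315 nats/token divided by ln(2), roughly 1.0 bit/token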

NFEs

Bases: MeanMetric

Average number of function evaluations per sample.

Source code in src/discrete_diffusion/evaluations/metrics.py
class NFEs(torchmetrics.aggregation.MeanMetric):
  """Average number of function evaluations per sample."""

NLL

Bases: MeanMetric

Source code in src/discrete_diffusion/evaluations/metrics.py
class NLL(torchmetrics.aggregation.MeanMetric):
  def update(self,
             value: Value,
             weight: Value = 1.0) -> None:
    """Update state with data.

    Args:
      value: Either a float or tensor containing data.
        Additional tensor dimensions will be flattened.
      weight: Either a float or tensor containing weights
        for calculating the average. Shape of weight should
        be able to broadcast with the shape of `value`.
        Defaults to `1.0`, corresponding to a simple
        unweighted average.
    """
    # broadcast weight to value shape
    if not isinstance(value, torch.Tensor):
      value = torch.as_tensor(value,
                              dtype=self.dtype,
                              device=self.device)
    else:
      value = value.to(dtype=self.dtype, device=self.device)

    if (weight is not None and
        not isinstance(weight, torch.Tensor)):
      weight = torch.as_tensor(weight,
                               dtype=self.dtype,
                               device=self.device)
    else:
      weight = weight.to(dtype=self.dtype, device=self.device)

    # Handle edge case where torch.compile infers scalar value but sees tensor inputs
    if value.ndim == 0 and weight.ndim > 0:
      weight = weight.squeeze()
    if value.ndim > 0:
      weight = torch.broadcast_to(weight, value.shape)

    if value.numel() == 0:
      return
    self.mean_value += value.sum()
    self.weight += weight.sum()

update(value, weight=1.0)

Update state with data.

Parameters:

  value (Value, required): Either a float or tensor containing data. Additional tensor dimensions will be flattened.

  weight (Value, default 1.0): Either a float or tensor containing weights for calculating the average. The shape of weight should be broadcastable with the shape of value. Defaults to 1.0, corresponding to a simple unweighted average.
Source code in src/discrete_diffusion/evaluations/metrics.py
def update(self,
           value: Value,
           weight: Value = 1.0) -> None:
  """Update state with data.

  Args:
    value: Either a float or tensor containing data.
      Additional tensor dimensions will be flattened.
    weight: Either a float or tensor containing weights
      for calculating the average. Shape of weight should
      be able to broadcast with the shape of `value`.
      Defaults to `1.0`, corresponding to a simple
      unweighted average.
  """
  # broadcast weight to value shape
  if not isinstance(value, torch.Tensor):
    value = torch.as_tensor(value,
                            dtype=self.dtype,
                            device=self.device)
  else:
    value = value.to(dtype=self.dtype, device=self.device)

  if (weight is not None and
      not isinstance(weight, torch.Tensor)):
    weight = torch.as_tensor(weight,
                             dtype=self.dtype,
                             device=self.device)
  else:
    weight = weight.to(dtype=self.dtype, device=self.device)

  # Handle edge case where torch.compile infers scalar value but sees tensor inputs
  if value.ndim == 0 and weight.ndim > 0:
    weight = weight.squeeze()
  if value.ndim > 0:
    weight = torch.broadcast_to(weight, value.shape)

  if value.numel() == 0:
    return
  self.mean_value += value.sum()
  self.weight += weight.sum()
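
An illustrative accumulation pattern (the values are made up): pass each batch's summed NLL as value and its token count as weight, so that compute() returns the per-token average in nats.

import torch

from discrete_diffusion.evaluations.metrics import NLL

nll = NLL()
nll.update(torch.tensor(1234.5), weight=torch.tensor(512.0))
nll.update(torch.tensor(987.0), weight=torch.tensor(480.0))
print(nll.compute())  # (1234.5 + 987.0) / (512 + 480), about 2.24 nats/token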

Perplexity

Bases: NLL

Source code in src/discrete_diffusion/evaluations/metrics.py
class Perplexity(NLL):
  def compute(self) -> torch.Tensor:
    """Computes the Perplexity.

    Returns:
      Perplexity
    """
    return torch.exp(self.mean_value / self.weight)

compute()

Computes the Perplexity.

Returns:

  Tensor: Perplexity

Source code in src/discrete_diffusion/evaluations/metrics.py
def compute(self) -> torch.Tensor:
  """Computes the Perplexity.

  Returns:
    Perplexity
  """
  return torch.exp(self.mean_value / self.weight)
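
Perplexity exponentiates the per-token average NLL accumulated through NLL.update. A quick check with illustrative numbers:

import torch

from discrete_diffusion.evaluations.metrics import Perplexity

ppl = Perplexity()
# 200 nats of summed NLL over 100 tokens is 2.0 nats/token on average.
ppl.update(torch.tensor(200.0), weight=torch.tensor(100.0))
print(ppl.compute())  # exp(2.0), about 7.389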