Skip to content

EMA

discrete_diffusion.models.ema

ExponentialMovingAverage

Maintains (exponential) moving average of a set of parameters.

Source code in src/discrete_diffusion/models/ema.py
class ExponentialMovingAverage:
  """
  Maintains (exponential) moving average of a set of parameters.

  Only parameters with `requires_grad=True` are tracked by the average.
  Shadow copies are paired with parameters purely by position (`zip`), so
  `update` and `copy_to` must receive the parameters in the same order that
  was used at construction time; surplus entries on either side are ignored.
  """

  def __init__(self, parameters, decay, use_num_updates=True):
    """
    Args:
        parameters: Iterable of `torch.nn.Parameter`; usually the result of
            `model.parameters()`.
        decay: The exponential decay.
        use_num_updates: Whether to use number of updates when computing
            averages.

    Raises:
        ValueError: If `decay` is smaller than 0 or larger than 1.
    """
    if decay < 0.0 or decay > 1.0:
      raise ValueError('Decay must be between 0 and 1')
    self.decay = decay
    # `None` disables the warm-up schedule applied in `update`.
    self.num_updates = 0 if use_num_updates else None
    # Detached copies: the averages never take part in autograd.
    self.shadow_params = [p.clone().detach()
                          for p in parameters if p.requires_grad]
    # Filled by `store`, consumed by `restore`.
    self.collected_params = []

  def move_shadow_params_to_device(self, device):
    """
    Replace every shadow parameter with the result of `.to(device)`.

    Args:
        device: Target passed unchanged to `Tensor.to`.
    """
    self.shadow_params = [i.to(device) for i in self.shadow_params]

  def update(self, parameters):
    """
    Update currently maintained parameters.

    Call this every time the parameters are updated, such as the result of
    the `optimizer.step()` call.

    When update counting is enabled, the counter is incremented first and
    the effective decay becomes
    `min(decay, (1 + num_updates) / (10 + num_updates))`, so early updates
    follow the live parameters more closely.

    Args:
        parameters: Iterable of `torch.nn.Parameter`; usually the same set of
            parameters used to initialize this object.
    """
    decay = self.decay
    if self.num_updates is not None:
      self.num_updates += 1
      decay = min(decay, (1 + self.num_updates) /
                  (10 + self.num_updates))
    one_minus_decay = 1.0 - decay
    with torch.no_grad():
      parameters = [p for p in parameters if p.requires_grad]
      for s_param, param in zip(self.shadow_params, parameters):
        # In place: shadow <- shadow - (1 - decay) * (shadow - param).
        s_param.sub_(one_minus_decay * (s_param - param))

  def copy_to(self, parameters):
    """
    Copy current parameters into given collection of parameters.

    Parameters with `requires_grad=False` are skipped and left untouched.

    Args:
        parameters: Iterable of `torch.nn.Parameter`; the parameters to be
            updated with the stored moving averages.
    """
    # The list is already restricted to trainable parameters, so no further
    # `requires_grad` check is needed inside the loop.
    parameters = [p for p in parameters if p.requires_grad]
    for s_param, param in zip(self.shadow_params, parameters):
      param.data.copy_(s_param.data)

  def store(self, parameters):
    """
    Save the current parameters for restoring later.

    Every given parameter is saved, whether or not it requires grad. The
    snapshots are taken with `detach().clone()` so they own their storage
    and do not stay attached to the autograd graph.

    Args:
        parameters: Iterable of `torch.nn.Parameter`; the parameters to be
            temporarily stored.
    """
    self.collected_params = [param.detach().clone() for param in parameters]

  def restore(self, parameters):
    """
    Restore the parameters stored with the `store` method.
    Useful to validate the model with EMA parameters without affecting the
    original optimization process. Store the parameters before the
    `copy_to` method. After validation (or model saving), use this to
    restore the former parameters.

    Args:
        parameters: Iterable of `torch.nn.Parameter`; the parameters to be
            updated with the stored parameters. They are matched by position
            with the tensors saved by `store`.
    """
    for c_param, param in zip(self.collected_params, parameters):
      param.data.copy_(c_param.data)

  def state_dict(self):
    """
    Return the EMA state as a dict with the keys `decay`, `num_updates` and
    `shadow_params`.

    The `shadow_params` entry is the live list held by this object, not a
    copy.
    """
    return dict(decay=self.decay,
                num_updates=self.num_updates,
                shadow_params=self.shadow_params)

  def load_state_dict(self, state_dict):
    """
    Load state produced by `state_dict`.

    Args:
        state_dict: Mapping with the keys `decay`, `num_updates` and
            `shadow_params`; the values are assigned as-is without
            validation or copying.
    """
    self.decay = state_dict['decay']
    self.num_updates = state_dict['num_updates']
    self.shadow_params = state_dict['shadow_params']

__init__(parameters, decay, use_num_updates=True)

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `parameters` | — | Iterable of `torch.nn.Parameter`; usually the result of `model.parameters()`. | required |
| `decay` | — | The exponential decay. | required |
| `use_num_updates` | — | Whether to use number of updates when computing averages. | `True` |

Source code in src/discrete_diffusion/models/ema.py
def __init__(self, parameters, decay, use_num_updates=True):
  """
  Set up the moving-average state for a collection of parameters.

  Args:
      parameters: Iterable of `torch.nn.Parameter`; usually the result of
          `model.parameters()`. Only entries with `requires_grad` set are
          tracked.
      decay: The exponential decay.
      use_num_updates: Whether to use number of updates when computing
          averages.

  Raises:
      ValueError: If `decay` is smaller than 0 or larger than 1.
  """
  if decay > 1.0 or decay < 0.0:
    raise ValueError('Decay must be between 0 and 1')
  self.decay = decay

  # A counter of `None` means the number of updates is not tracked.
  if use_num_updates:
    self.num_updates = 0
  else:
    self.num_updates = None

  # Keep a detached copy of each trainable parameter as its running average.
  shadows = []
  for candidate in parameters:
    if candidate.requires_grad:
      shadows.append(candidate.clone().detach())
  self.shadow_params = shadows

  # Scratch list for parameters saved temporarily; starts out empty.
  self.collected_params = []

copy_to(parameters)

Copy current parameters into given collection of parameters.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `parameters` | — | Iterable of `torch.nn.Parameter`; the parameters to be updated with the stored moving averages. | required |

Source code in src/discrete_diffusion/models/ema.py
def copy_to(self, parameters):
  """
  Overwrite the given parameters with the stored moving averages.

  Parameters that do not require grad are skipped; the remaining ones are
  matched with the shadow parameters by position.

  Args:
      parameters: Iterable of `torch.nn.Parameter`; the parameters to be
          updated with the stored moving averages.
  """
  trainable = [p for p in parameters if p.requires_grad]
  for average, target in zip(self.shadow_params, trainable):
    if not target.requires_grad:
      continue
    target.data.copy_(average.data)

restore(parameters)

Restore the parameters stored with the store method. Useful to validate the model with EMA parameters without affecting the original optimization process. Store the parameters before the copy_to method. After validation (or model saving), use this to restore the former parameters.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `parameters` | — | Iterable of `torch.nn.Parameter`; the parameters to be updated with the stored parameters. | required |

Source code in src/discrete_diffusion/models/ema.py
def restore(self, parameters):
  """
  Write the values saved by `store` back into the given parameters.

  Typical use: call `store`, then `copy_to` to validate (or save) the model
  with the EMA weights, and finally call this method so the original
  optimization process continues with the former parameters.

  Args:
      parameters: Iterable of `torch.nn.Parameter`; the parameters to be
          updated with the stored parameters. They are matched by position
          with the saved tensors.
  """
  saved_and_live = zip(self.collected_params, parameters)
  for saved, live in saved_and_live:
    source = saved.data
    live.data.copy_(source)

store(parameters)

Save the current parameters for restoring later.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `parameters` | — | Iterable of `torch.nn.Parameter`; the parameters to be temporarily stored. | required |

Source code in src/discrete_diffusion/models/ema.py
def store(self, parameters):
  """
  Save the current parameters for restoring later.

  Every given parameter is saved, whether or not it requires grad. The
  snapshots are taken with `detach().clone()` so they own their storage
  and do not stay attached to the autograd graph.

  Args:
      parameters: Iterable of `torch.nn.Parameter`; the parameters to be
          temporarily stored.
  """
  self.collected_params = [param.detach().clone() for param in parameters]

update(parameters)

Update currently maintained parameters.

Call this every time the parameters are updated, such as the result of the optimizer.step() call.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `parameters` | — | Iterable of `torch.nn.Parameter`; usually the same set of parameters used to initialize this object. | required |

Source code in src/discrete_diffusion/models/ema.py
def update(self, parameters):
  """
  Fold the current parameter values into the moving averages.

  Call this every time the parameters are updated, such as the result of
  the `optimizer.step()` call.

  When the update counter is enabled it is incremented first, and the
  decay actually applied is the smaller of the configured decay and
  `(1 + num_updates) / (10 + num_updates)`.

  Args:
      parameters: Iterable of `torch.nn.Parameter`; usually the same set of
          parameters used to initialize this object.
  """
  rate = self.decay
  if self.num_updates is not None:
    self.num_updates += 1
    warmup_rate = (1 + self.num_updates) / (10 + self.num_updates)
    rate = min(rate, warmup_rate)
  blend = 1.0 - rate
  with torch.no_grad():
    trainable = [p for p in parameters if p.requires_grad]
    for average, current in zip(self.shadow_params, trainable):
      # In place: average <- average - (1 - rate) * (average - current).
      gap = average - current
      average.sub_(blend * gap)

create_ema(parameters, decay)

Construct the EMA helper for the provided parameter iterable.

Source code in src/discrete_diffusion/models/ema.py
def create_ema(parameters, decay: float):
  """Build an `ExponentialMovingAverage` over `parameters` using `decay`.

  Args:
      parameters: Parameter iterable forwarded unchanged to the helper.
      decay: Decay value, passed by keyword.

  Returns:
      The newly constructed `ExponentialMovingAverage` instance.
  """
  ema = ExponentialMovingAverage(parameters, decay=decay)
  return ema