Skip to content

Encoder-Decoder

discrete_diffusion.models.encoder_decoder

Decoder

Bases: Module

Source code in src/discrete_diffusion/models/encoder_decoder.py
class Decoder(nn.Module):
  """Decoder that builds initial query states with GroupSwap and then
  refines them with a stack of cross-attention blocks attending to the
  encoder output.

  Args:
    n_blocks: number of decoder blocks to stack.
    dim: hidden dimension, forwarded to GroupSwapLayer and every block.
    n_heads: number of attention heads, forwarded likewise.
    cond_dim: conditioning dimension forwarded to every block.
    mlp_ratio: MLP ratio forwarded to every block.
    dropout: dropout value forwarded to every block.
    adaLN: forwarded to every block; presumably toggles adaptive layer
      norm conditioning (TODO confirm in CrossAttnDDiTBlock).
    model_length: sequence length passed to GroupSwapLayer.
    swap_pre_query_mode, swap_query_process_mode, swap_normalize_mode:
      configuration forwarded unchanged to GroupSwapLayer.
  """

  def __init__(self, n_blocks, dim, n_heads, cond_dim, 
               mlp_ratio, dropout, adaLN, model_length, 
               swap_pre_query_mode, swap_query_process_mode, 
               swap_normalize_mode):
    super().__init__()

    self.hidden_dim = dim
    self.n_heads = n_heads
    self.cond_dim = cond_dim
    self.mlp_ratio = mlp_ratio
    self.adaLN = adaLN
    self.model_length = model_length
    self.dropout = dropout
    self.n_blocks = n_blocks
    self.swap_pre_query_mode = swap_pre_query_mode
    self.group_swap = GroupSwapLayer(dim, n_heads, 
                swap_pre_query_mode, swap_query_process_mode, 
                model_length, swap_normalize_mode)

    # All decoder layers are cross-attention blocks.
    self.layers = nn.ModuleList([self._make_cross_attn_block() 
                                for _ in range(self.n_blocks)])

  def _make_cross_attn_block(self):
    """Build one cross-attention block from the stored hyperparameters."""
    return CrossAttnDDiTBlock(
      self.hidden_dim, self.n_heads, self.adaLN, self.cond_dim, 
      self.mlp_ratio, self.dropout)

  def forward(
    self, 
    encoder_output,
    t_cond, 
    rotary_cos_sin_queries, 
    rotary_cos_sin_keys, 
    self_attn_mask,
    # Training
    group_idxs, 
    # Inference
    position_queries,
    concrete_lengths_keys, 
    use_inference_mode,
  ):
    """Run the decoder.

    Steps:
      1. Build the cross-attention mask. During training/validation it
         is built from ``group_idxs``; during sampling it is built from
         ``concrete_lengths_keys``.
      2. Apply GroupSwap to obtain the initial decoder states.
      3. Apply the decoder layers. Cross-attention layers attend from
         the decoder states to ``encoder_output``.

    Args:
      encoder_output: encoder hidden states, used as keys/values. At
        sampling time its second dimension is taken as the key length.
      t_cond: conditioning passed to every layer.
      rotary_cos_sin_queries: rotary embedding for the queries.
      rotary_cos_sin_keys: rotary embedding for the keys.
      self_attn_mask: mask for self-attention layers. It is only valid
        in training/validation mode.
      group_idxs: training-only; used to build the group
        cross-attention mask.
      position_queries: sampling-only; its second dimension is the
        number of queries.
      concrete_lengths_keys: sampling-only; the number of denoised
        tokens in each batch element.
      use_inference_mode: if True, take the sampling code path.

    Returns:
      The decoder hidden states after the last layer.

    Raises:
      NotImplementedError: if a self-attention (``DDiTBlock``) layer is
        encountered in inference mode. No inference-time self-attention
        mask is built.
    """
    if not use_inference_mode:  # Training / Valid
      cross_attn_mask = make_group_cross_attn_mask(group_idxs)
    else:  # Sampling
      # TODO: Make sure we don't attend to pad tokens
      q_len = position_queries.shape[1]
      kv_len = encoder_output.shape[1]
      cross_attn_mask = make_inference_cross_attn_mask(
        kv_len, q_len, concrete_lengths_keys)
      # IMPORTANT NOTE: during inference, the self attention
      #  mask is different than during training, since the
      #  decoder input has a different shape than the encoder
      #  input. The given `self_attn_mask` must therefore not be
      #  used; the layer loop below rejects self-attention layers
      #  in this mode instead of failing with an unbound name.

    x = self.group_swap(encoder_output, rotary_cos_sin_queries,
      rotary_cos_sin_keys, group_idxs, position_queries, 
      concrete_lengths_keys, cross_attn_mask, use_inference_mode)

    for layer in self.layers:
      if isinstance(layer, DDiTBlock):  # self attention
        if use_inference_mode:
          raise NotImplementedError(
            'Self-attention decoder layers are not supported in '
            'inference mode: the training-time self-attention mask '
            'does not match the decoder input shape.')
        x = layer(x, t_cond, rotary_cos_sin_queries, 
                  self_attn_mask)
      else:  # cross attention
        x = layer(
          q_x=x,
          kv_x=encoder_output,
          t_cond=t_cond, 
          rotary_cos_sin_queries=rotary_cos_sin_queries,
          rotary_cos_sin_keys=rotary_cos_sin_keys,
          attn_mask=cross_attn_mask)
    return x

forward(encoder_output, t_cond, rotary_cos_sin_queries, rotary_cos_sin_keys, self_attn_mask, group_idxs, position_queries, concrete_lengths_keys, use_inference_mode)

  1. Prepare the cross-attention mask, then apply GroupSwap.
  2. Apply the decoder layers.
Source code in src/discrete_diffusion/models/encoder_decoder.py
def forward(
  self, 
  encoder_output,
  t_cond, 
  rotary_cos_sin_queries, 
  rotary_cos_sin_keys, 
  self_attn_mask,
  # Training
  group_idxs, 
  # Inference
  position_queries,
  concrete_lengths_keys, 
  use_inference_mode,
):
  """Run the decoder.

  Steps:
    1. Build the cross-attention mask. During training/validation it is
       built from ``group_idxs``; during sampling it is built from
       ``concrete_lengths_keys``.
    2. Apply GroupSwap to obtain the initial decoder states.
    3. Apply the decoder layers. Cross-attention layers attend from the
       decoder states to ``encoder_output``.

  Args:
    encoder_output: encoder hidden states, used as keys/values. At
      sampling time its second dimension is taken as the key length.
    t_cond: conditioning passed to every layer.
    rotary_cos_sin_queries: rotary embedding for the queries.
    rotary_cos_sin_keys: rotary embedding for the keys.
    self_attn_mask: mask for self-attention layers. It is only valid in
      training/validation mode.
    group_idxs: training-only; used to build the group cross-attention
      mask.
    position_queries: sampling-only; its second dimension is the number
      of queries.
    concrete_lengths_keys: sampling-only; the number of denoised tokens
      in each batch element.
    use_inference_mode: if True, take the sampling code path.

  Returns:
    The decoder hidden states after the last layer.

  Raises:
    NotImplementedError: if a self-attention (``DDiTBlock``) layer is
      encountered in inference mode. No inference-time self-attention
      mask is built.
  """
  if not use_inference_mode:  # Training / Valid
    cross_attn_mask = make_group_cross_attn_mask(group_idxs)
  else:  # Sampling
    # TODO: Make sure we don't attend to pad tokens
    q_len = position_queries.shape[1]
    kv_len = encoder_output.shape[1]
    cross_attn_mask = make_inference_cross_attn_mask(
      kv_len, q_len, concrete_lengths_keys)
    # IMPORTANT NOTE: during inference, the self attention 
    #  mask is different than during training, since the
    #  decoder input has a different shape than the encoder
    #  input. The given `self_attn_mask` must therefore not be
    #  used; the layer loop below rejects self-attention layers
    #  in this mode instead of failing with an unbound name.

  x = self.group_swap(encoder_output, rotary_cos_sin_queries,
    rotary_cos_sin_keys, group_idxs, position_queries, 
    concrete_lengths_keys, cross_attn_mask, use_inference_mode)

  for layer in self.layers:
    if isinstance(layer, DDiTBlock):  # self attention
      if use_inference_mode:
        raise NotImplementedError(
          'Self-attention decoder layers are not supported in '
          'inference mode: the training-time self-attention mask '
          'does not match the decoder input shape.')
      x = layer(x, t_cond, rotary_cos_sin_queries, 
                self_attn_mask)
    else:  # cross attention
      x = layer(
        q_x=x,
        kv_x=encoder_output,
        t_cond=t_cond, 
        rotary_cos_sin_queries=rotary_cos_sin_queries,
        rotary_cos_sin_keys=rotary_cos_sin_keys,
        attn_mask=cross_attn_mask)
  return x

make_inference_cross_attn_mask(keys_tensor_length, queries_tensor_length, concrete_lengths_keys)

Query positions correspond to noisy positions; key positions correspond to denoised positions. Concrete length: the number of denoised tokens in each element of the batch.

Source code in src/discrete_diffusion/models/encoder_decoder.py
def make_inference_cross_attn_mask(
    keys_tensor_length,
    queries_tensor_length,
    concrete_lengths_keys,
):
  """Build the boolean cross-attention mask used at sampling time.

  Queries correspond to noisy positions and keys to denoised positions.
  ``concrete_lengths_keys`` holds, for each batch element, the number
  of denoised tokens. A key position is marked ``True`` only when its
  index is below that count.

  Args:
    keys_tensor_length: number of key positions (KV_LEN).
    queries_tensor_length: number of query positions (Q_LEN).
    concrete_lengths_keys: per-batch-element denoised lengths,
      presumably a 1-D integer tensor of shape (BS,).

  Returns:
    A bool tensor of shape BS x Q_LEN x KV_LEN in which every query row
    of a given batch element is identical.
  """
  key_positions = torch.arange(
    keys_tensor_length, device=concrete_lengths_keys.device)
  row_limits = concrete_lengths_keys.unsqueeze(1)  # BS x 1
  valid_keys = key_positions.unsqueeze(0) < row_limits  # BS x KV_LEN
  # Add a query axis and tile it so every query sees the same keys.
  return valid_keys.unsqueeze(1).repeat(
    1, queries_tensor_length, 1)  # BS x Q_LEN x KV_LEN