Skip to content

Block DiT

discrete_diffusion.models.block_dit

Block-diffusion-aware DiT backbone.

BlockDiT

Bases: `torch.nn.Module`, `huggingface_hub.PyTorchModelHubMixin`

DiT backbone extended with block-diffusion attention masks.

Source code in src/discrete_diffusion/models/block_dit.py
class BlockDiT(nn.Module, huggingface_hub.PyTorchModelHubMixin):
  """DiT backbone extended with block-diffusion attention masks.

  Depending on the config the backbone is built from causal blocks
  (``DDiTBlockCausal``) or bidirectional blocks (``DDiTBlock``). When
  ``config.algo.cross_attn`` is set, a block-diffusion mask over a
  ``2 * model.length`` sequence is precomputed by ``gen_mask``. Outside of
  sample mode, ``forward`` then computes rotary embeddings for, and returns,
  only the first ``model.length`` positions. The input is presumably the
  doubled sequence the mask was built for (TODO confirm against callers).

  Args:
    config: OmegaConf config, or a plain dict that is converted with
      ``OmegaConf.create``. Reads the ``model``, ``algo`` and ``loader``
      sections, and ``sampling`` during sampling.
    vocab_size: Token vocabulary size, used for the input embedding and the
      output projection.
  """

  def __init__(self, config, vocab_size: int):
    super().__init__()
    if isinstance(config, dict):
      config = omegaconf.OmegaConf.create(config)
    self.config = config
    self.n = config.model.length
    self.causal = getattr(
      config.model, 'causal_attention',
      config.algo.parameterization == 'ar')
    # Non-causal models always use adaLN conditioning; causal ones opt in.
    self.adaLN = (not self.causal) or getattr(config.model, 'adaln', False)
    self.vocab_size = vocab_size
    self.block_size = getattr(config, 'block_size', config.model.length)
    dim = config.model.hidden_size
    cond_dim = config.model.cond_dim
    self.n_heads = config.model.n_heads
    self.vocab_embed = EmbeddingLayer(dim, vocab_size)
    # The timestep embedder is only needed for adaLN conditioning.
    # ``self.adaLN`` is already truthy for every non-causal model.
    if self.adaLN:
      self.sigma_map = TimestepEmbedder(cond_dim)
    self.rotary_emb = Rotary(dim // config.model.n_heads)
    self.attn_backend = getattr(config.model, 'attn_backend', 'flash_attn')
    # Capacity passed to the blocks and used for the KV cache allocated in
    # reset_kv_cache. Configurable via model.max_seqlen; defaults to 1024.
    self.max_seqlen = getattr(config.model, 'max_seqlen', 1024)

    blocks = []
    for _ in range(config.model.n_blocks):
      if self.causal:
        block = DDiTBlockCausal(
          n=config.model.length,
          dim=dim,
          n_heads=config.model.n_heads,
          dropout=config.model.dropout,
          max_batch_size=config.loader.eval_batch_size,
          adaLN=self.adaLN,
          cond_dim=cond_dim,
          attn_backend=self.attn_backend)
      else:
        block = DDiTBlock(
          n=config.model.length,
          dim=dim,
          n_heads=config.model.n_heads,
          cond_dim=cond_dim,
          adaLN=self.adaLN,
          dropout=config.model.dropout,
          block_size=self.block_size,
          attn_backend=self.attn_backend,
          max_seqlen=self.max_seqlen)
      blocks.append(block)
    self.blocks = nn.ModuleList(blocks)

    tie_embeddings = getattr(config.model, 'tie_word_embeddings', False)
    self.output_layer = DDiTFinalLayer(
      hidden_size=dim,
      out_channels=vocab_size,
      cond_dim=cond_dim,
      adaLN=self.adaLN,
      tie_word_embeddings=tie_embeddings)
    # Tie the output projection to the input embeddings if requested.
    if tie_embeddings:
      self.output_layer.linear.weight = self.vocab_embed.embedding
    if getattr(config.algo, 'cross_attn', False):
      self.gen_mask(config.model.length, self.block_size, self.attn_backend)

  def _get_bias_dropout_scale(self):
    """Return the fused bias/dropout/add helper for the current mode."""
    if self.training:
      return bias_dropout_add_scale_fused_train
    return bias_dropout_add_scale_fused_inference

  def gen_mask(self, seqlen, block_size, attn_backend='sdpa'):
    """Precompute the block-diffusion mask and store it as ``self.block_diff_mask``.

    The mask covers a ``2 * seqlen`` query/key sequence.

    Args:
      seqlen: Length ``n`` of each half of the doubled sequence.
      block_size: Number of tokens per diffusion block.
      attn_backend: ``'flex'`` builds a flex-attention ``BlockMask``.
        ``'sdpa'`` builds a dense boolean tensor of shape
        ``(2 * seqlen, 2 * seqlen)``.

    Raises:
      ValueError: If ``'flex'`` is requested but flex attention is not
        available, or if the backend is neither ``'flex'`` nor ``'sdpa'``.
    """
    if attn_backend == 'flex':
      # Previously an unavailable flex backend fell through to the generic
      # "unknown backend" error, which hid the real cause.
      if not FLEX_ATTN_AVAILABLE:
        raise ValueError(
          "attn_backend='flex' requested but flex attention is not available")
      assert create_block_mask is not None
      self.block_diff_mask = create_block_mask(
        partial(block_diff_mask, block_size=block_size, n=seqlen),
        B=None, H=None, Q_LEN=seqlen * 2, KV_LEN=seqlen * 2)
    elif attn_backend == 'sdpa':
      self.block_diff_mask = block_diff_mask(
        b=None, h=None, q_idx=torch.arange(seqlen * 2)[:, None],
        kv_idx=torch.arange(seqlen * 2)[None, :],
        block_size=block_size, n=seqlen)
    else:
      raise ValueError(f'Unknown attention backend: {attn_backend!r}')

  def reset_kv_cache(self):
    """Give every block a fresh zeroed KV cache and rewind its cache index.

    Each cache is a bfloat16 tensor of shape
    ``(loader.eval_batch_size, max_seqlen, 3 * hidden_size)``. The
    ``3 * hidden_size`` width is presumably packed q/k/v. The cache is
    allocated on the device that holds the model's parameters, not a
    hard-coded ``'cuda'``.
    """
    device = next(self.parameters()).device
    for block in self.blocks:
      block.kv_cache = torch.zeros(
        self.config.loader.eval_batch_size,
        self.max_seqlen,
        self.config.model.hidden_size * 3,
        device=device,
        dtype=torch.bfloat16)
      block.cache_idx = 0

  def forward(self, indices, sigma, sample_mode=False, store_kv=False):
    """Run the backbone on token ``indices``.

    Args:
      indices: Token ids passed to the embedding layer.
      sigma: Noise level for the timestep embedder. ``None`` disables
        conditioning.
      sample_mode: Selects the sampling-time mask/rotary handling.
      store_kv: Forwarded unchanged to every block.

    Returns:
      The output of ``output_layer``. When a block-diffusion mask exists and
      ``sample_mode`` is False, only the first ``self.n`` positions are
      returned.
    """
    x = self.vocab_embed(indices)
    if sigma is None:
      t_cond = None
    else:
      t_cond = F.silu(self.sigma_map(sigma))

    cross_attn = hasattr(self, 'block_diff_mask')
    if cross_attn:
      mask = self.block_diff_mask
      if sample_mode:
        if getattr(self.config.sampling, 'kv_cache', False):
          # With a KV cache no mask is passed. Rotary embeddings must cover
          # everything cached so far plus the current block.
          mask = None
          accum_length = self.blocks[0].cache_idx + self.block_size
          x_full = torch.zeros((
            x.shape[0], accum_length, x.shape[2]), device=x.device)
          rotary_cos_sin = self.rotary_emb(x_full)
        else:
          # Use only the (second half, second half) quadrant of the mask,
          # sized to the current input.
          mask = mask[
            self.n:self.n + x.shape[1], self.n:self.n + x.shape[1]]
          rotary_cos_sin = self.rotary_emb(x)
      else:
        rotary_cos_sin = self.rotary_emb(x[:, :self.n])
    else:
      rotary_cos_sin = self.rotary_emb(x)
      mask = None

    with torch.amp.autocast('cuda', dtype=torch.bfloat16):
      for block in self.blocks:
        x = block(
          x,
          rotary_cos_sin,
          c=t_cond,
          causal=self.causal,
          sample_mode=sample_mode,
          mask=mask,
          store_kv=store_kv)
      x = self.output_layer(x, t_cond)
    if cross_attn and not sample_mode:
      x = x[:, :self.n]
    return x

block_diff_mask(b, h, q_idx, kv_idx, block_size=None, n=None)

Construct the block diffusion attention mask.

Line-by-line match to upstream BD3-LM implementation.

Source code in src/discrete_diffusion/models/block_dit.py
def block_diff_mask(b, h, q_idx, kv_idx, block_size=None, n=None):  # noqa: D401, ignored args
  """Return the boolean block-diffusion attention mask.

  Positions are indices into a sequence of length ``2 * n``. Positions
  ``< n`` form the first half. Positions ``>= n`` form the second (x0)
  half. Each half is split into blocks of ``block_size`` tokens, numbered
  from 0 within that half. Attention from a query to a key is allowed when:

  * both are in the same half and the same block (block-diagonal);
  * the query is in the first half and the key is in a strictly earlier
    x0 block (offset block-causal);
  * both are in the x0 half and the key's block is not after the query's
    block (block-causal).

  Args:
    b: Batch index. Unused; present to fit the flex-attention mask signature.
    h: Head index. Unused.
    q_idx: Query position(s), broadcastable against ``kv_idx``.
    kv_idx: Key/value position(s).
    block_size: Tokens per block.
    n: Length of each half.

  Returns:
    A boolean tensor that is True where attention is allowed. It is
    semantically equivalent to the upstream BD3-LM mask.
  """
  q_in_x0 = q_idx >= n
  kv_in_x0 = kv_idx >= n

  # Shift x0-half positions back by n so both halves number blocks from 0.
  q_block = torch.where(q_in_x0, q_idx - n, q_idx) // block_size
  kv_block = torch.where(kv_in_x0, kv_idx - n, kv_idx) // block_size

  same_half = q_in_x0 == kv_in_x0
  first_to_x0 = kv_in_x0 & ~q_in_x0
  x0_to_x0 = kv_in_x0 & q_in_x0

  within_block = same_half & (q_block == kv_block)
  earlier_x0_block = first_to_x0 & (q_block > kv_block)
  causal_x0_block = x0_to_x0 & (q_block >= kv_block)
  return within_block | earlier_x0_block | causal_x0_block