class BlockDiT(nn.Module, huggingface_hub.PyTorchModelHubMixin):
  """DiT backbone extended with block-diffusion attention masks."""

  def __init__(self, config, vocab_size: int):
    super().__init__()
    if isinstance(config, dict):
      config = omegaconf.OmegaConf.create(config)
    self.config = config
    self.n = config.model.length
    # Causal (AR) attention is used when explicitly requested or when the
    # parameterization is autoregressive; otherwise attention is bidirectional.
    self.causal = getattr(config.model, 'causal_attention',
                          config.algo.parameterization == 'ar')
    # adaLN (adaptive LayerNorm) conditioning is always on for the non-causal
    # diffusion model and optional for the causal one.
    self.adaLN = (not self.causal) or getattr(config.model, 'adaln', False)
    self.vocab_size = vocab_size
    self.block_size = getattr(config, 'block_size', config.model.length)
    dim = config.model.hidden_size
    cond_dim = config.model.cond_dim
    self.n_heads = config.model.n_heads
    self.vocab_embed = EmbeddingLayer(dim, vocab_size)
    # Noise-level conditioning is only needed when the model is conditioned
    # through adaLN or runs in the non-causal diffusion mode.
    if self.adaLN or not self.causal:
      self.sigma_map = TimestepEmbedder(cond_dim)
    self.rotary_emb = Rotary(dim // config.model.n_heads)
    self.attn_backend = getattr(config.model, 'attn_backend', 'flash_attn')
    self.max_seqlen = 1024
    # Transformer stack: causal (AR) blocks, or bidirectional blocks that are
    # aware of the block-diffusion layout.
    blocks = []
    for _ in range(config.model.n_blocks):
      if self.causal:
        block = DDiTBlockCausal(
          n=config.model.length,
          dim=dim,
          n_heads=config.model.n_heads,
          dropout=config.model.dropout,
          max_batch_size=config.loader.eval_batch_size,
          adaLN=self.adaLN,
          cond_dim=cond_dim,
          attn_backend=self.attn_backend)
      else:
        block = DDiTBlock(
          n=config.model.length,
          dim=dim,
          n_heads=config.model.n_heads,
          cond_dim=cond_dim,
          adaLN=self.adaLN,
          dropout=config.model.dropout,
          block_size=self.block_size,
          attn_backend=self.attn_backend,
          max_seqlen=self.max_seqlen)
      blocks.append(block)
    self.blocks = nn.ModuleList(blocks)
    self.output_layer = DDiTFinalLayer(
      hidden_size=dim,
      out_channels=vocab_size,
      cond_dim=cond_dim,
      adaLN=self.adaLN,
      tie_word_embeddings=getattr(config.model, 'tie_word_embeddings', False))
    # Tie the output projection to the input embedding matrix if requested.
    if getattr(config.model, 'tie_word_embeddings', False):
      self.output_layer.linear.weight = self.vocab_embed.embedding
    # Precompute the block-diffusion attention mask when the model attends
    # across the two length-n halves of its 2n-token input (see forward).
    if getattr(config.algo, 'cross_attn', False):
      self.gen_mask(config.model.length, self.block_size, self.attn_backend)

  def _get_bias_dropout_scale(self):
    # Fused bias + dropout + scale kernel; the training variant applies dropout.
    if self.training:
      return bias_dropout_add_scale_fused_train
    return bias_dropout_add_scale_fused_inference

  def gen_mask(self, seqlen, block_size, attn_backend='sdpa'):
    """Builds the block-diffusion attention mask over a 2 * seqlen sequence."""
    if attn_backend == 'flex' and FLEX_ATTN_AVAILABLE:
      # FlexAttention: compile the mask once as a BlockMask.
      assert create_block_mask is not None
      self.block_diff_mask = create_block_mask(
        partial(block_diff_mask, block_size=block_size, n=seqlen),
        B=None, H=None, Q_LEN=seqlen * 2, KV_LEN=seqlen * 2)
    elif attn_backend == 'sdpa':
      # SDPA: materialize the full (2 * seqlen, 2 * seqlen) mask as a tensor.
      self.block_diff_mask = block_diff_mask(
        b=None, h=None, q_idx=torch.arange(seqlen * 2)[:, None],
        kv_idx=torch.arange(seqlen * 2)[None, :],
        block_size=block_size, n=seqlen)
    else:
      # Reached for unsupported backends, or for 'flex' when FlexAttention is
      # unavailable in the installed torch version.
      raise ValueError(
        f'Attention backend {attn_backend!r} does not support '
        'block-diffusion masks')
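
  # Illustrative note (not used by the model): with the 'sdpa' backend the mask
  # above is an explicit tensor, so for a toy size one could inspect it with,
  # e.g.,
  #   block_diff_mask(
  #       b=None, h=None,
  #       q_idx=torch.arange(8)[:, None], kv_idx=torch.arange(8)[None, :],
  #       block_size=2, n=4)
  # assuming block_diff_mask follows the FlexAttention mask_mod convention of
  # returning a keep/drop decision per (q_idx, kv_idx) pair.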

  def reset_kv_cache(self):
    """Clears the per-block KV caches before block-wise sampling.

    Intended to be called when sampling with `config.sampling.kv_cache`
    enabled restarts, so that each block's cache index begins at zero.
    """
    for block in self.blocks:
      # Shaped (eval_batch_size, max_seqlen, 3 * hidden_size); allocated on
      # CUDA in bfloat16 to match the autocast context used in forward().
      block.kv_cache = torch.zeros(
        self.config.loader.eval_batch_size,
        self.max_seqlen,
        self.config.model.hidden_size * 3,
        device='cuda',
        dtype=torch.bfloat16)
      block.cache_idx = 0

  def forward(self, indices, sigma, sample_mode=False, store_kv=False):
    """Runs the backbone on token `indices` conditioned on noise level `sigma`.

    When trained with cross-attention, the input is expected to be a length-2n
    concatenation and only the first n positions are returned.
    """
    x = self.vocab_embed(indices)
    # Noise-level conditioning (None for the unconditional AR case).
    if sigma is None:
      t_cond = None
    else:
      t_cond = F.silu(self.sigma_map(sigma))
    cross_attn = hasattr(self, 'block_diff_mask')
    if cross_attn:
      mask = self.block_diff_mask
      if sample_mode:
        if getattr(self.config.sampling, 'kv_cache', False):
          # KV-cached sampling: no explicit mask; rotary embeddings must cover
          # every position generated so far plus the current block.
          mask = None
          accum_length = self.blocks[0].cache_idx + self.block_size
          x_full = torch.zeros((
            x.shape[0], accum_length, x.shape[2]), device=x.device)
          rotary_cos_sin = self.rotary_emb(x_full)
        else:
          # Sampling without a KV cache: crop the precomputed 2n x 2n mask to
          # the slice of positions currently being denoised.
          mask = mask[
            self.n:self.n + x.shape[1], self.n:self.n + x.shape[1]]
          rotary_cos_sin = self.rotary_emb(x)
      else:
        # Training: rotary embeddings are computed for the first n positions.
        rotary_cos_sin = self.rotary_emb(x[:, :self.n])
    else:
      rotary_cos_sin = self.rotary_emb(x)
      mask = None
    # Run the transformer stack in bfloat16 under autocast.
    with torch.amp.autocast('cuda', dtype=torch.bfloat16):
      for block in self.blocks:
        x = block(
          x,
          rotary_cos_sin,
          c=t_cond,
          causal=self.causal,
          sample_mode=sample_mode,
          mask=mask,
          store_kv=store_kv)
      x = self.output_layer(x, t_cond)
    if cross_attn and not sample_mode:
      # With cross-attention training, return predictions for the first n
      # positions only.
      x = x[:, :self.n]
    return x
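

# Illustrative usage sketch (assumption): the config keys below mirror the ones
# read in BlockDiT.__init__/forward (model.length, model.hidden_size, ...), but
# the real project config is OmegaConf/Hydra based and may differ. This helper
# is not called anywhere; it only documents the expected wiring.
def _example_build_blockdit():
  """Builds a tiny BlockDiT and runs one non-causal forward pass."""
  cfg = omegaconf.OmegaConf.create({
    'block_size': 4,
    'model': {
      'length': 16,
      'hidden_size': 64,
      'cond_dim': 64,
      'n_heads': 4,
      'n_blocks': 2,
      'dropout': 0.1,
      'attn_backend': 'sdpa',
      'adaln': True,
    },
    # Any parameterization other than 'ar' selects bidirectional attention.
    'algo': {'parameterization': 'subs', 'cross_attn': False},
    'loader': {'eval_batch_size': 2},
    'sampling': {'kv_cache': False},
  })
  model = BlockDiT(cfg, vocab_size=32)
  indices = torch.randint(0, 32, (2, cfg.model.length))
  sigma = torch.rand(2)
  # Returns logits of shape (batch, length, vocab_size).
  return model(indices, sigma)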