3 Commits

Author SHA1 Message Date
  Leo Dong 0a5339f4cb
[FIRST] Fix softcap scoremod kwargs typo. (#2072) 14 hours ago
  Driss Guessous 179f793bbc
[CUTE] Seeing if tvvm reduces cpu overhead (#2042) 17 hours ago
  Reuben Stern fd8d5eb363
[Cute,Fwd] Extend score_mod to variable sequence length (#2043) 19 hours ago
12 changed files with 2399 additions and 499 deletions
Split View
  1. +30
    -6
      flash_attn/cute/block_sparsity.py
  2. +41
    -5
      flash_attn/cute/flash_fwd.py
  3. +29
    -2
      flash_attn/cute/flash_fwd_sm100.py
  4. +234
    -180
      flash_attn/cute/interface.py
  5. +3
    -1
      flash_attn/cute/pyproject.toml
  6. +12
    -1
      flash_attn/cute/seqlen_info.py
  7. +13
    -9
      flash_attn/cute/softmax.py
  8. +1
    -1
      flash_attn/cute/utils.py
  9. +591
    -0
      tests/cute/score_mod_definitions.py
  10. +1
    -0
      tests/cute/test_flash_attn.py
  11. +396
    -294
      tests/cute/test_score_mod.py
  12. +1048
    -0
      tests/cute/test_score_mod_varlen.py

+ 30
- 6
flash_attn/cute/block_sparsity.py View File

@@ -14,6 +14,10 @@ import cutlass.cute as cute
from cutlass.cute.runtime import from_dlpack


def ceildiv(a: int, b: int) -> int:
    """Ceiling integer division: smallest int >= a / b (for positive b)."""
    quotient, remainder = divmod(a, b)
    return quotient + (1 if remainder else 0)


# Placeholder: a minimal empty class used purely as a type annotation for the
# duck-typed `config` objects consumed below (they carry attributes such as
# seqlen_q, seqlen_k, tile_m, tile_n, batch_size).
# NOTE(review): presumably to be replaced by a real dataclass once the config
# schema settles — TODO confirm.
Config = type("Config", (), {})

@@ -78,6 +82,26 @@ def _check_and_expand_block(
return expanded_cnt, expanded_idx


def get_block_sparse_expected_shapes(
    batch_size: int,
    num_head: int,
    seqlen_q: int,
    seqlen_k: int,
    m_block_size: int,
    n_block_size: int,
    compute_capability: int,
) -> Tuple[Tuple[int, int, int], Tuple[int, int, int, int]]:
    """Return (expected_count_shape, expected_index_shape) for block sparse normalization.

    Count shape is (batch, heads, m_blocks); index shape appends n_blocks.
    """
    # TODO: This multiplier should really be q_stage, wire up in later PR
    # On SM100 (compute capability 10) one CTA covers 2 * tile_m query rows.
    effective_m_block = m_block_size * (2 if compute_capability == 10 else 1)
    # -(-x // y) is ceiling division for positive y.
    m_blocks = -(-seqlen_q // effective_m_block)
    n_blocks = -(-seqlen_k // n_block_size)
    count_shape = (batch_size, num_head, m_blocks)
    index_shape = count_shape + (n_blocks,)
    return count_shape, index_shape


def normalize_block_sparse_tensors(
tensors: BlockSparseTensorsTorch,
*,
@@ -205,8 +229,8 @@ def _compute_sparsity(
config: Config, device: str, aux_tensors: Optional[List[torch.Tensor]]
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Computes block sparsity for fixed-length sequences."""
n_blocks_q = (config.seqlen_q + config.tile_m - 1) // config.tile_m
n_blocks_k = (config.seqlen_k + config.tile_n - 1) // config.tile_n
n_blocks_q = ceildiv(config.seqlen_q, config.tile_m)
n_blocks_k = ceildiv(config.seqlen_k, config.tile_n)

# Pre-allocate output tensors
full_block_cnt = torch.zeros(
@@ -325,12 +349,12 @@ def _compute_varlen_sparsity(
max_m_blocks = 0
for seq_idx in range(config.batch_size):
seq_len_q = (cu_seqlens_q[seq_idx + 1] - cu_seqlens_q[seq_idx]).item()
n_blocks_q = (seq_len_q + config.tile_m - 1) // config.tile_m
n_blocks_q = ceildiv(seq_len_q, config.tile_m)
max_m_blocks = max(max_m_blocks, n_blocks_q)

# The number of K blocks is determined by the total length of all sequences.
total_k_len = cu_seqlens_k[-1].item()
max_n_blocks = (total_k_len + config.tile_n - 1) // config.tile_n
max_n_blocks = ceildiv(total_k_len, config.tile_n)

# Pre-allocate padded output tensors
full_block_cnt = torch.zeros(
@@ -360,8 +384,8 @@ def _compute_varlen_sparsity(
seq_end_k = cu_seqlens_k[seq_idx + 1].item()
seq_len_k = seq_end_k - seq_start_k

n_blocks_q = (seq_len_q + config.tile_m - 1) // config.tile_m
n_blocks_k = (seq_len_k + config.tile_n - 1) // config.tile_n
n_blocks_q = ceildiv(seq_len_q, config.tile_m)
n_blocks_k = ceildiv(seq_len_k, config.tile_n)

# Global block indices are relative to the start of the entire batch tensor
first_m_block_global = seq_start_q // config.tile_m


+ 41
- 5
flash_attn/cute/flash_fwd.py View File

@@ -1050,6 +1050,7 @@ class FlashAttentionForwardSm80(FlashAttentionForwardBase):
batch_idx: cutlass.Int32,
head_idx: cutlass.Int32,
m_block: cutlass.Int32,
seqlen: SeqlenInfoQK,
aux_tensors=None,
fastdiv_mods=None,
mask_fn: Optional[Callable] = None,
@@ -1105,6 +1106,7 @@ class FlashAttentionForwardSm80(FlashAttentionForwardBase):
m_block,
acc_S,
n_block,
seqlen,
softmax_scale=softmax.softmax_scale,
aux_tensors=aux_tensors,
fastdiv_mods=fastdiv_mods,
@@ -1502,7 +1504,11 @@ class FlashAttentionForwardSm90(FlashAttentionForwardBase):
seqlen_q = cute.size(mQ.shape[0]) // (
self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1
)
seqlen_k = cute.size(mK.shape[0])
seqlen_k = (
cute.size(mK.shape[0])
if const_expr(mPageTable is None)
else mK.shape[0] * mPageTable.shape[1]
)
seqlen_q_divmod = FastDivmodDivisor(seqlen_q)
seqlen_k_divmod = FastDivmodDivisor(seqlen_k)
fastdiv_mods = (seqlen_q_divmod, seqlen_k_divmod)
@@ -1982,6 +1988,25 @@ class FlashAttentionForwardSm90(FlashAttentionForwardBase):
# shape: (atom_v_m * rest_m)
m_block, head_idx, batch_idx, _ = work_tile.tile_idx
seqlen = SeqlenInfoCls(batch_idx)

# Recompute fastdiv_mods if necessary for varlen with aux_tensors
recompute_fastdiv_mods_q = cutlass.const_expr(
aux_tensors is not None and (seqlen.has_cu_seqlens_q or seqlen.has_seqused_q)
)
recompute_fastdiv_mods_k = cutlass.const_expr(
aux_tensors is not None and (seqlen.has_cu_seqlens_k or seqlen.has_seqused_k)
)
if cutlass.const_expr(fastdiv_mods is not None):
seqlen_q_divmod, seqlen_k_divmod = fastdiv_mods
fastdiv_mods = (
seqlen_q_divmod
if not recompute_fastdiv_mods_q
else FastDivmodDivisor(seqlen.seqlen_q),
seqlen_k_divmod
if not recompute_fastdiv_mods_k
else FastDivmodDivisor(seqlen.seqlen_k),
)

mask = AttentionMaskCls(seqlen.seqlen_q, seqlen.seqlen_k)
mask_fn = partial(
mask.apply_mask,
@@ -2046,6 +2071,7 @@ class FlashAttentionForwardSm90(FlashAttentionForwardBase):
if const_expr(self.intra_wg_overlap):
kv_consumer_state = process_first_half_block(
n_block=n_block_max - 1,
seqlen=seqlen,
kv_consumer_state=kv_consumer_state,
mask_fn=partial(mask_fn, mask_mod=self.mask_mod),
score_mod_fn=score_mod_fn,
@@ -2058,6 +2084,7 @@ class FlashAttentionForwardSm90(FlashAttentionForwardBase):
kv_consumer_state = mma_one_n_block(
kv_consumer_state,
n_block=n_block_max - 1,
seqlen=seqlen,
mma_pv_fn=partial(mma_pv_fn, zero_init=True),
is_first_n_block=True,
mask_fn=partial(mask_fn, mask_mod=self.mask_mod, mask_seqlen=True),
@@ -2077,6 +2104,7 @@ class FlashAttentionForwardSm90(FlashAttentionForwardBase):
kv_consumer_state = mma_one_n_block(
kv_consumer_state,
n_block=n_block_max - 1 - n_tile,
seqlen=seqlen,
mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate),
mask_fn=partial(mask_fn, mask_mod=self.mask_mod, mask_seqlen=False),
)
@@ -2091,6 +2119,7 @@ class FlashAttentionForwardSm90(FlashAttentionForwardBase):
kv_consumer_state = mma_one_n_block(
kv_consumer_state,
n_block=n_block_max - 1 - n_tile,
seqlen=seqlen,
mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate),
mask_fn=partial(mask_fn, mask_mod=self.mask_mod, mask_seqlen=False),
)
@@ -2102,6 +2131,7 @@ class FlashAttentionForwardSm90(FlashAttentionForwardBase):
kv_consumer_state = mma_one_n_block(
kv_consumer_state,
n_block=n_block_max - 1 - n_tile,
seqlen=seqlen,
mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate),
mask_fn=partial(mask_fn, mask_mod=self.mask_mod, mask_seqlen=False),
)
@@ -2195,6 +2225,7 @@ class FlashAttentionForwardSm90(FlashAttentionForwardBase):
tOrP: cute.Tensor,
smem_copy_params: SimpleNamespace,
softmax: Softmax,
seqlen: SeqlenInfoQK,
mask_fn: Callable = None,
score_mod_fn: Optional[Callable] = None,
is_first_block: bool = False,
@@ -2207,7 +2238,7 @@ class FlashAttentionForwardSm90(FlashAttentionForwardBase):

# Apply score modification if present
if const_expr(score_mod_fn is not None):
score_mod_fn(acc_S, n_block=n_block)
score_mod_fn(acc_S, n_block=n_block, seqlen=seqlen)

# Apply mask; mask_seqlen always True for first block
# Caveat: if full block further right than mask block, seqlen masking is redundant;
@@ -2267,6 +2298,7 @@ class FlashAttentionForwardSm90(FlashAttentionForwardBase):
tOrP: cute.Tensor,
smem_copy_params: SimpleNamespace,
softmax: Softmax,
seqlen: SeqlenInfoQK,
score_mod_fn: Optional[Callable] = None,
mask_fn: Optional[Callable] = None,
is_first_n_block: cutlass.Constexpr = False,
@@ -2281,7 +2313,7 @@ class FlashAttentionForwardSm90(FlashAttentionForwardBase):

# handle score mods and masking
if const_expr(score_mod_fn is not None):
score_mod_fn(acc_S, n_block=n_block)
score_mod_fn(acc_S, n_block=n_block, seqlen=seqlen)
if const_expr(mask_fn is not None):
mask_fn(acc_S=acc_S, n_block=n_block)

@@ -2326,6 +2358,7 @@ class FlashAttentionForwardSm90(FlashAttentionForwardBase):
tOrP: cute.Tensor,
smem_copy_params: SimpleNamespace,
softmax: Softmax,
seqlen: SeqlenInfoQK,
score_mod_fn: Optional[Callable] = None,
mask_fn: Optional[Callable] = None,
check_inf: cutlass.Constexpr = True,
@@ -2345,7 +2378,7 @@ class FlashAttentionForwardSm90(FlashAttentionForwardBase):

# handle score mods and masking
if const_expr(score_mod_fn is not None):
score_mod_fn(acc_S, n_block=n_block)
score_mod_fn(acc_S, n_block=n_block, seqlen=seqlen)
if const_expr(mask_fn is not None):
mask_fn(acc_S=acc_S, n_block=n_block)
# if cute.arch.thread_idx()[0] == 128: cute.print_tensor(utils.make_acc_tensor_mn_view(acc_S))
@@ -2392,6 +2425,7 @@ class FlashAttentionForwardSm90(FlashAttentionForwardBase):
acc_S,
n_block,
softmax_scale,
seqlen,
aux_tensors: Optional[list] = None,
fastdiv_mods=None,
):
@@ -2411,6 +2445,7 @@ class FlashAttentionForwardSm90(FlashAttentionForwardBase):
self.qk_acc_dtype,
aux_tensors,
fastdiv_mods,
seqlen_info=seqlen,
constant_q_idx=None,
qhead_per_kvhead=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1,
)
@@ -2436,4 +2471,5 @@ class FlashAttentionForwardSm90(FlashAttentionForwardBase):
cute.arch.barrier_arrive(
barrier_id=int(NamedBarrierFwd.WarpSchedulerWG1) + next_wg,
number_of_threads=2 * self.num_threads_per_warp_group,
)
)


+ 29
- 2
flash_attn/cute/flash_fwd_sm100.py View File

@@ -658,7 +658,11 @@ class FlashAttentionForwardSm100:
seqlen_q = cute.size(mQ.shape[0]) // (
self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1
)
seqlen_k = cute.size(mK.shape[0])
seqlen_k = (
cute.size(mK.shape[0])
if const_expr(mPageTable is None)
else mK.shape[0] * mPageTable.shape[1]
)
seqlen_q_divmod = FastDivmodDivisor(seqlen_q)
seqlen_k_divmod = FastDivmodDivisor(seqlen_k)
fastdiv_mods = (seqlen_q_divmod, seqlen_k_divmod)
@@ -1624,6 +1628,26 @@ class FlashAttentionForwardSm100:
head_idx=head_idx,
aux_tensors=aux_tensors,
)

# Recompute fastdiv_mods if necessary
recompute_fastdiv_mods_q = cutlass.const_expr(
aux_tensors is not None and (seqlen.has_cu_seqlens_q or seqlen.has_seqused_q)
)
recompute_fastdiv_mods_k = cutlass.const_expr(
aux_tensors is not None and (seqlen.has_cu_seqlens_k or seqlen.has_seqused_k)
)

if cutlass.const_expr(fastdiv_mods is not None):
seqlen_q_divmod, seqlen_k_divmod = fastdiv_mods
fastdiv_mods = (
seqlen_q_divmod
if not recompute_fastdiv_mods_q
else FastDivmodDivisor(seqlen.seqlen_q),
seqlen_k_divmod
if not recompute_fastdiv_mods_k
else FastDivmodDivisor(seqlen.seqlen_k),
)

mask_mod = self.mask_mod if const_expr(self.mask_mod is not None) else None
mask_fn = partial(
mask.apply_mask_sm100,
@@ -1874,6 +1898,7 @@ class FlashAttentionForwardSm100:
m_block,
n_block,
softmax,
seqlen,
aux_tensors,
fastdiv_mods,
)
@@ -2369,7 +2394,7 @@ class FlashAttentionForwardSm100:
self.check_hdim_v_oob,
self.qhead_per_kvhead,
)
# load acc O from smem to rmem for wider vectorization
tOrO = cute.make_fragment_like(tOsO, self.o_dtype)
cute.autovec_copy(tOsO, tOrO)
@@ -2637,6 +2662,7 @@ class FlashAttentionForwardSm100:
m_block,
n_block,
softmax,
seqlen: SeqlenInfoQK,
aux_tensors=None,
fastdiv_mods=(None, None),
):
@@ -2673,6 +2699,7 @@ class FlashAttentionForwardSm100:
self.qk_acc_dtype,
aux_tensors,
fastdiv_mods,
seqlen_info=seqlen,
constant_q_idx=q_idx_logical,
qhead_per_kvhead=self.qhead_per_kvhead if cutlass.const_expr(self.pack_gqa) else 1,
)

+ 234
- 180
flash_attn/cute/interface.py View File

@@ -1,7 +1,5 @@
# Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
# [2025-07-04] Version in Cute-DSL, for Hopper and Blackwell. You'll need install nvidia-cutlass-dsl==4.2.0.
# [2025-07-04] Version in Cute-DSL, for Hopper and Blackwell. You'll need install nvidia-cutlass-dsl==4.2.0.
# [2025-07-04] Version in Cute-DSL, for Hopper and Blackwell. You'll need install nvidia-cutlass-dsl==4.2.0.

# Supported features:
# - BF16 & FP16 dtype
@@ -22,10 +20,17 @@
# - bwd pass optimized for Hopper/Blackwell

import math
from functools import lru_cache
from typing import Optional, Tuple, Callable

import torch


@lru_cache(maxsize=None)
def _get_device_capability():
    """Query the current CUDA device's major compute capability, cached per process."""
    major, _minor = torch.cuda.get_device_capability()
    return major

import cuda.bindings.driver as cuda

import cutlass
@@ -46,6 +51,7 @@ from flash_attn.cute.block_sparsity import (
BlockSparseTensorsTorch,
to_cute_block_sparse_tensors,
normalize_block_sparse_tensors,
get_block_sparse_expected_shapes,
)

def maybe_contiguous(x):
@@ -58,6 +64,15 @@ def _validate_tensor(t, name, expected_shape, expected_dtype, expected_device):
assert t.device == expected_device, f"{name} device {t.device} != expected {expected_device}"
assert t.is_cuda, f"{name} must be on CUDA"

def to_cute_tensor(t, assumed_align=16, leading_dim=-1, fully_dynamic=False):
    """Wrap a torch tensor as a cute tensor for TVM FFI.

    The tensor is detached before conversion. With ``fully_dynamic`` the whole
    layout is marked dynamic; otherwise only ``leading_dim`` is, where ``-1``
    selects the last dimension (``t.ndim - 1``).
    """
    cute_t = from_dlpack(t.detach(), assumed_align=assumed_align, enable_tvm_ffi=True)
    if fully_dynamic:
        return cute_t.mark_layout_dynamic()
    dim = t.ndim - 1 if leading_dim == -1 else leading_dim
    return cute_t.mark_layout_dynamic(leading_dim=dim)


torch2cute_dtype_map = {
torch.float16: cutlass.Float16,
@@ -114,7 +129,7 @@ def _flash_attn_fwd(
...
score_mod: A callable that takes the attention scores and applies a modification.
mask_mod: A callable that takes token position information and selectively masks
block_sparse_tensors: A tuple of tensors used for block sparsity.
block_sparse_tensors: A tuple of tensors used for block sparsity.
return_lse: Whether to return the log softmax of the attention scores. If set to True will always calculate
out: Optional pre-allocated output tensor. If None, will be allocated internally.
lse: Optional pre-allocated log-sum-exp tensor. If None, will be allocated when needed.
@@ -230,51 +245,15 @@ def _flash_attn_fwd(
_validate_tensor(lse, "lse", lse_shape, torch.float32, device)

dtype = torch2cute_dtype_map[q.dtype]
(
cu_seqlens_q_tensor,
cu_seqlens_k_tensor,
seqused_q_tensor,
seqused_k_tensor,
learnable_sink_tensor,
) = [
from_dlpack(t.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=0)
if t is not None
else None
for t in (cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k, learnable_sink)
]
page_table_tensor = (
from_dlpack(page_table.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=1)
if page_table is not None
else None
)
compute_capability = (
torch.cuda.get_device_capability()[0]
_get_device_capability()
if _compute_capability is None
else _compute_capability
)

assert compute_capability in [9, 10], "Unsupported compute capability. Supported: 9.x, 10.x"


sparse_tensors = None
if block_sparse_tensors is not None:
if seqlen_q is None:
raise ValueError("Block sparsity requires fixed-length sequences (seqlen_q must be known).")
m_block_size_block = m_block_size
if compute_capability == 10:
# TODO: This multiplier should really be q_stage, wire up in later PR
# 1 cta handles 2*tile_m row
m_block_size_block = 2 * m_block_size
expected_m_blocks = (seqlen_q + m_block_size_block - 1) // m_block_size_block
expected_n_blocks = (seqlen_k + n_block_size - 1) // n_block_size
block_sparse_tensors = normalize_block_sparse_tensors(
block_sparse_tensors,
expected_count_shape=(batch_size, num_head, expected_m_blocks),
expected_index_shape=(batch_size, num_head, expected_m_blocks, expected_n_blocks),
)
sparse_tensors = to_cute_block_sparse_tensors(block_sparse_tensors)

use_block_sparsity = sparse_tensors is not None
use_block_sparsity = block_sparse_tensors is not None

if mask_mod is None:
if causal:
@@ -294,6 +273,7 @@ def _flash_attn_fwd(
if compute_capability == 9: # TODO: tune block size according to hdim.
if head_dim == head_dim_v == 128 and not causal and not local and not use_block_sparsity:
n_block_size = 192

if compute_capability == 10:
# TODO: fix the varlen case
if (
@@ -326,17 +306,6 @@ def _flash_attn_fwd(
out_partial = torch.empty(num_splits, *q_batch_seqlen_shape, num_head, head_dim_v, dtype=torch.float32, device=device)
lse_partial = torch.empty(num_splits, *lse_shape, dtype=torch.float32, device=device)

q_tensor, k_tensor, v_tensor, o_tensor = [
from_dlpack(t.detach(), assumed_align=16).mark_layout_dynamic(leading_dim=t.ndim - 1)
for t in (q, k, v, out if not is_split_kv else out_partial)
]
if is_split_kv:
lse_tensor = from_dlpack(lse_partial.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=lse_partial.ndim - 1)
elif lse is not None:
lse_tensor = from_dlpack(lse.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=lse.ndim - 1)
else:
lse_tensor = None

# hash score and mask mods for compile cache
score_mod_hash = utils.hash_callable(score_mod) if score_mod is not None else False
mask_mod_hash = utils.hash_callable(mask_mod) if mask_mod is not None else False
@@ -351,11 +320,6 @@ def _flash_attn_fwd(
or seqused_q is not None
or seqused_k is not None
)
if score_mod is not None:
if is_varlen:
raise NotImplementedError(
"score_mod with aux_tensors is not yet supported for varlen sequences. This will be fixed in a future PR."
)

if mask_mod is not None:
if is_varlen:
@@ -381,10 +345,6 @@ def _flash_attn_fwd(
"Block sparsity is not yet supported with SplitKV. TODO: partition sparse block lists per split."
)

cute_aux_tensors = None
if aux_tensors is not None:
cute_aux_tensors = [from_dlpack(buf).mark_layout_dynamic() for buf in aux_tensors]

compile_key = (
dtype,
head_dim,
@@ -413,6 +373,52 @@ def _flash_attn_fwd(
page_size not in [None, 128], # paged KV non-TMA
)
if compile_key not in _flash_attn_fwd.compile_cache:
(
cu_seqlens_q_tensor,
cu_seqlens_k_tensor,
seqused_q_tensor,
seqused_k_tensor,
learnable_sink_tensor,
) = [
to_cute_tensor(t, assumed_align=4, leading_dim=0)
if t is not None
else None
for t in (cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k, learnable_sink)
]
page_table_tensor = (
to_cute_tensor(page_table, assumed_align=4, leading_dim=1)
if page_table is not None
else None
)
q_tensor, k_tensor, v_tensor, o_tensor = [
to_cute_tensor(t) for t in (q, k, v, out if not is_split_kv else out_partial)
]
if is_split_kv:
lse_tensor = to_cute_tensor(lse_partial, assumed_align=4)
elif lse is not None:
lse_tensor = to_cute_tensor(lse, assumed_align=4)
else:
lse_tensor = None

sparse_tensors = None
if block_sparse_tensors is not None:
if seqlen_q is None:
raise ValueError("Block sparsity requires fixed-length sequences (seqlen_q must be known).")
expected_count_shape, expected_index_shape = get_block_sparse_expected_shapes(
batch_size, num_head, seqlen_q, seqlen_k,
m_block_size, n_block_size, compute_capability,
)
compile_time_normalized = normalize_block_sparse_tensors(
block_sparse_tensors,
expected_count_shape=expected_count_shape,
expected_index_shape=expected_index_shape,
)
sparse_tensors = to_cute_block_sparse_tensors(compile_time_normalized)

cute_aux_tensors = None
if aux_tensors is not None:
cute_aux_tensors = [to_cute_tensor(buf, assumed_align=None, fully_dynamic=True) for buf in aux_tensors]

if compute_capability == 9:
assert page_table is None, "paged KV not supported on SM 9.0"
assert not is_split_kv, "SplitKV not supported on SM 9.0"
@@ -484,25 +490,40 @@ def _flash_attn_fwd(
learnable_sink_tensor,
sparse_tensors,
cute_aux_tensors,
options="--enable-tvm-ffi",
)

# Expand block sparse tensors to match actual head count (may be broadcast from 1)
normalized_block_sparse_tensors = None
if block_sparse_tensors is not None:
expected_count_shape, expected_index_shape = get_block_sparse_expected_shapes(
batch_size, num_head, seqlen_q, seqlen_k,
m_block_size, n_block_size, compute_capability,
)
normalized_block_sparse_tensors = normalize_block_sparse_tensors(
block_sparse_tensors,
expected_count_shape=expected_count_shape,
expected_index_shape=expected_index_shape,
)

_flash_attn_fwd.compile_cache[compile_key](
q_tensor,
k_tensor,
v_tensor,
o_tensor,
lse_tensor,
q,
k,
v,
out if not is_split_kv else out_partial,
lse_partial if is_split_kv else lse,
softmax_scale,
current_stream,
cu_seqlens_q_tensor,
cu_seqlens_k_tensor,
seqused_q_tensor,
seqused_k_tensor,
page_table_tensor,
cu_seqlens_q,
cu_seqlens_k,
seqused_q,
seqused_k,
page_table,
window_size_left,
window_size_right,
learnable_sink_tensor,
sparse_tensors,
cute_aux_tensors,
learnable_sink,
normalized_block_sparse_tensors,
aux_tensors,
)
if is_split_kv:
_flash_attn_fwd_combine(
@@ -553,7 +574,7 @@ def _flash_attn_bwd(
dk: Optional[torch.Tensor] = None,
dv: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
compute_capability = torch.cuda.get_device_capability()[0]
compute_capability = _get_device_capability()
assert compute_capability in [9, 10], "Unsupported compute capability. Supported: 9.x, 10.x"

if compute_capability == 9:
@@ -751,28 +772,8 @@ def _flash_attn_bwd(
)

dtype = torch2cute_dtype_map[q.dtype]
q_tensor, k_tensor, v_tensor, o_tensor, do_tensor, dq_tensor, dk_tensor, dv_tensor = [
from_dlpack(t.detach(), assumed_align=16).mark_layout_dynamic(leading_dim=t.ndim - 1)
for t in (q, k, v, out, dout, dq, dk, dv)
]
lse_tensor = from_dlpack(lse.detach(), assumed_align=4).mark_layout_dynamic(
leading_dim=lse.ndim - 1
)
dq_accum_tensor, dpsum_tensor, lse_log2_tensor = [
from_dlpack(t.detach(), assumed_align=16).mark_layout_dynamic(leading_dim=t.ndim - 1)
for t in (dq_accum, dpsum, lse_log2)
]
if qhead_per_kvhead > 1:
dk_accum_tensor, dv_accum_tensor = [
from_dlpack(t.detach(), assumed_align=16).mark_layout_dynamic(leading_dim=t.ndim - 1)
for t in (dk_accum, dv_accum)
]
cu_seqlens_q_tensor, cu_seqlens_k_tensor, seqused_q_tensor, seqused_k_tensor = [
from_dlpack(t.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=t.ndim - 1)
if t is not None
else None
for t in (cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
]
current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)

if deterministic:
dQ_semaphore = torch.zeros(batch_size, num_head, seqlen_q_rounded // m_block_size, 1, dtype=torch.int32, device="cuda")
else:
@@ -784,16 +785,19 @@ def _flash_attn_bwd(
else:
dK_semaphore = None
dV_semaphore = None
dQ_semaphore_tensor, dK_semaphore_tensor, dV_semaphore_tensor = [
utils.convert_from_dlpack_leading_static(t.detach(), leading_dim=3, alignment=4, stride_order=t.dim_order())
if t is not None else None
for t in (dQ_semaphore, dK_semaphore, dV_semaphore)
]
current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)

# Preprocess kernel: compute (o * dout).sum(dim=-1), lse * log2_e, and zero out dq_accum.
compile_key_pre = (compute_capability, dtype, head_dim_v, m_block_size, num_threads)
if compile_key_pre not in _flash_attn_bwd.compile_cache_pre:
o_tensor, do_tensor = [to_cute_tensor(t) for t in (out, dout)]
dq_accum_tensor, dpsum_tensor, lse_log2_tensor = [
to_cute_tensor(t) for t in (dq_accum, dpsum, lse_log2)
]
lse_tensor = to_cute_tensor(lse, assumed_align=4)
cu_seqlens_q_tensor, seqused_q_tensor = [
to_cute_tensor(t, assumed_align=4) if t is not None else None
for t in (cu_seqlens_q, seqused_q)
]
fa_bwd_pre = FlashAttentionBackwardPreprocess(
dtype,
head_dim_v,
@@ -812,16 +816,17 @@ def _flash_attn_bwd(
cu_seqlens_q_tensor,
seqused_q_tensor,
current_stream,
options="--enable-tvm-ffi",
)
_flash_attn_bwd.compile_cache_pre[compile_key_pre](
o_tensor,
do_tensor,
dpsum_tensor,
lse_tensor,
lse_log2_tensor,
dq_accum_tensor,
cu_seqlens_q_tensor,
seqused_q_tensor,
out,
dout,
dpsum,
lse,
lse_log2,
dq_accum,
cu_seqlens_q,
seqused_q,
current_stream,
)

@@ -869,6 +874,25 @@ def _flash_attn_bwd(
)
num_threads = 384
if compile_key not in _flash_attn_bwd.compile_cache:
q_tensor, k_tensor, v_tensor, do_tensor, dq_tensor, dk_tensor, dv_tensor = [
to_cute_tensor(t) for t in (q, k, v, dout, dq, dk, dv)
]
dq_accum_tensor, dpsum_tensor, lse_log2_tensor = [
to_cute_tensor(t) for t in (dq_accum, dpsum, lse_log2)
]
if qhead_per_kvhead > 1:
dk_accum_tensor, dv_accum_tensor = [
to_cute_tensor(t) for t in (dk_accum, dv_accum)
]
cu_seqlens_q_tensor, cu_seqlens_k_tensor, seqused_q_tensor, seqused_k_tensor = [
to_cute_tensor(t, assumed_align=4) if t is not None else None
for t in (cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
]
dQ_semaphore_tensor, dK_semaphore_tensor, dV_semaphore_tensor = [
utils.convert_from_dlpack_leading_static(t.detach(), leading_dim=3, alignment=4, stride_order=t.dim_order())
if t is not None else None
for t in (dQ_semaphore, dK_semaphore, dV_semaphore)
]
fa_bwd_sm80 = FlashAttentionBackwardSm80(
dtype,
head_dim,
@@ -941,39 +965,48 @@ def _flash_attn_bwd(
cu_seqlens_k_tensor,
seqused_q_tensor,
seqused_k_tensor,
window_size_left=window_size_left,
window_size_right=window_size_right,
mdQ_semaphore=dQ_semaphore_tensor,
mdK_semaphore=dK_semaphore_tensor,
mdV_semaphore=dV_semaphore_tensor,
None, # softcap - not yet supported in backward
window_size_left,
window_size_right,
dQ_semaphore_tensor,
dK_semaphore_tensor,
dV_semaphore_tensor,
options="--enable-tvm-ffi",
)
_flash_attn_bwd.compile_cache[compile_key](
q_tensor,
k_tensor,
v_tensor,
do_tensor,
lse_log2_tensor,
dpsum_tensor,
dq_accum_tensor,
dk_tensor if qhead_per_kvhead == 1 else dk_accum_tensor,
dv_tensor if qhead_per_kvhead == 1 else dv_accum_tensor,
q,
k,
v,
dout,
lse_log2,
dpsum,
dq_accum,
dk if qhead_per_kvhead == 1 else dk_accum,
dv if qhead_per_kvhead == 1 else dv_accum,
softmax_scale,
current_stream,
cu_seqlens_q_tensor,
cu_seqlens_k_tensor,
seqused_q_tensor,
seqused_k_tensor,
window_size_left=window_size_left,
window_size_right=window_size_right,
mdQ_semaphore=dQ_semaphore_tensor,
mdK_semaphore=dK_semaphore_tensor,
mdV_semaphore=dV_semaphore_tensor,
cu_seqlens_q,
cu_seqlens_k,
seqused_q,
seqused_k,
None, # softcap - not yet supported in backward
window_size_left,
window_size_right,
dQ_semaphore,
dK_semaphore,
dV_semaphore,
)

num_threads = 256 if compute_capability == 9 else 128
# Postprocess kernel: convert dq_accum from float32 to dq in bf16/fp16
compile_key_post = (dtype, head_dim, m_block_size, num_threads, AtomLayoutMdQ, dQ_swapAB)
if compile_key_post not in _flash_attn_bwd.compile_cache_post:
dq_accum_tensor = to_cute_tensor(dq_accum)
dq_tensor = to_cute_tensor(dq)
cu_seqlens_q_tensor, seqused_q_tensor = [
to_cute_tensor(t, assumed_align=4) if t is not None else None
for t in (cu_seqlens_q, seqused_q)
]
arch = compute_capability * 10
fa_bwd_post = FlashAttentionBackwardPostprocess(
dtype, head_dim, arch, m_block_size, num_threads, AtomLayoutMdQ, dQ_swapAB
@@ -987,13 +1020,14 @@ def _flash_attn_bwd(
cu_seqlens_q_tensor,
seqused_q_tensor,
current_stream,
options="--enable-tvm-ffi",
)
_flash_attn_bwd.compile_cache_post[compile_key_post](
dq_accum_tensor,
dq_tensor,
dq_accum,
dq,
softmax_scale,
cu_seqlens_q_tensor,
seqused_q_tensor,
cu_seqlens_q,
seqused_q,
current_stream,
)

@@ -1001,6 +1035,12 @@ def _flash_attn_bwd(
# Postprocess kernel: convert dk_accum & dv_accum from float32 to bf16/fp16
compile_key_post = (dtype, head_dim, n_block_size, num_threads, AtomLayoutNdKV, dKV_swapAB)
if compile_key_post not in _flash_attn_bwd.compile_cache_post:
dk_accum_tensor = to_cute_tensor(dk_accum)
dk_tensor = to_cute_tensor(dk)
cu_seqlens_k_tensor, seqused_k_tensor = [
to_cute_tensor(t, assumed_align=4) if t is not None else None
for t in (cu_seqlens_k, seqused_k)
]
fa_bwd_post = FlashAttentionBackwardPostprocess(
dtype, head_dim, n_block_size, num_threads, AtomLayoutNdKV, dKV_swapAB
)
@@ -1013,13 +1053,14 @@ def _flash_attn_bwd(
cu_seqlens_k_tensor,
seqused_k_tensor,
current_stream,
options="--enable-tvm-ffi",
)
_flash_attn_bwd.compile_cache_post[compile_key_post](
dk_accum_tensor,
dk_tensor,
dk_accum,
dk,
softmax_scale,
cu_seqlens_k_tensor,
seqused_k_tensor,
cu_seqlens_k,
seqused_k,
current_stream,
)
compile_key_post = (
@@ -1031,6 +1072,12 @@ def _flash_attn_bwd(
dKV_swapAB,
)
if compile_key_post not in _flash_attn_bwd.compile_cache_post:
dv_accum_tensor = to_cute_tensor(dv_accum)
dv_tensor = to_cute_tensor(dv)
cu_seqlens_k_tensor, seqused_k_tensor = [
to_cute_tensor(t, assumed_align=4) if t is not None else None
for t in (cu_seqlens_k, seqused_k)
]
fa_bwd_post = FlashAttentionBackwardPostprocess(
dtype, head_dim_v, n_block_size, num_threads, AtomLayoutNdKV, dKV_swapAB
)
@@ -1043,13 +1090,14 @@ def _flash_attn_bwd(
cu_seqlens_k_tensor,
seqused_k_tensor,
current_stream,
options="--enable-tvm-ffi",
)
_flash_attn_bwd.compile_cache_post[compile_key_post](
dv_accum_tensor,
dv_tensor,
cutlass.Float32(1.0),
cu_seqlens_k_tensor,
seqused_k_tensor,
dv_accum,
dv,
1.0,
cu_seqlens_k,
seqused_k,
current_stream,
)

@@ -1154,6 +1202,8 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
num_splits: int = 1,
pack_gqa: Optional[bool] = None,
deterministic: bool = False,
score_mod: Optional[Callable] = None,
aux_tensors: Optional[list] = None,
):
out, lse = _flash_attn_fwd(
q,
@@ -1172,6 +1222,8 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
softcap=softcap,
num_splits=num_splits,
pack_gqa=pack_gqa,
score_mod=score_mod,
aux_tensors=aux_tensors,
)
ctx.save_for_backward(q, k, v, out, lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
ctx.softmax_scale = softmax_scale
@@ -1261,6 +1313,8 @@ def flash_attn_varlen_func(
num_splits: int = 1,
pack_gqa: Optional[bool] = None,
deterministic: bool = False,
score_mod: Optional[Callable] = None,
aux_tensors: Optional[list] = None,
):
return FlashAttnVarlenFunc.apply(
q,
@@ -1279,6 +1333,8 @@ def flash_attn_varlen_func(
num_splits,
pack_gqa,
deterministic,
score_mod,
aux_tensors,
)


@@ -1360,30 +1416,6 @@ def _flash_attn_fwd_combine(
# TODO: we can deal w this by using 128 threads instead
log_max_splits = max(log_max_splits, 5)

# Convert to cute tensors (using kernel-formatted tensors)
out_partial_tensor = from_dlpack(out_partial.detach(), assumed_align=16).mark_layout_dynamic(
leading_dim=4 if not is_varlen else 3
)
lse_partial_tensor = from_dlpack(lse_partial.detach(), assumed_align=4).mark_layout_dynamic(
leading_dim=lse_partial.ndim - 2
)
out_tensor = from_dlpack(out.detach(), assumed_align=16).mark_layout_dynamic(leading_dim=3 if not is_varlen else 2)
lse_tensor = (
from_dlpack(lse.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=lse.ndim - 2)
if lse is not None
else None
)

optional_tensors = [
from_dlpack(t.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=0)
if t is not None
else None
for t in (cu_seqlens, seqused, num_splits_dynamic_ptr, semaphore_to_reset)
]
cu_seqlens_tensor, seqused_tensor, num_splits_dynamic_tensor, semaphore_tensor = (
optional_tensors
)

current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)

# Create combine kernel configuration
@@ -1403,6 +1435,28 @@ def _flash_attn_fwd_combine(
)

if compile_key not in _flash_attn_fwd_combine.compile_cache:
out_partial_tensor = to_cute_tensor(
out_partial, leading_dim=4 if not is_varlen else 3
)
lse_partial_tensor = to_cute_tensor(
lse_partial, assumed_align=4, leading_dim=lse_partial.ndim - 2
)
out_tensor = to_cute_tensor(out, leading_dim=3 if not is_varlen else 2)
lse_tensor = (
to_cute_tensor(lse, assumed_align=4, leading_dim=lse.ndim - 2)
if lse is not None
else None
)

optional_tensors = [
to_cute_tensor(t, assumed_align=4, leading_dim=0)
if t is not None
else None
for t in (cu_seqlens, seqused, num_splits_dynamic_ptr, semaphore_to_reset)
]
cu_seqlens_tensor, seqused_tensor, num_splits_dynamic_tensor, semaphore_tensor = (
optional_tensors
)
fa_combine = FlashAttentionForwardCombine(
dtype=dtype,
dtype_partial=dtype_partial,
@@ -1437,17 +1491,17 @@ def _flash_attn_fwd_combine(
num_splits_dynamic_tensor,
semaphore_tensor,
current_stream,
options="--enable-tvm-ffi",
)

_flash_attn_fwd_combine.compile_cache[compile_key](
out_partial_tensor,
lse_partial_tensor,
out_tensor,
lse_tensor,
cu_seqlens_tensor,
seqused_tensor,
num_splits_dynamic_tensor,
semaphore_tensor,
out_partial,
lse_partial,
out,
lse,
cu_seqlens,
seqused,
num_splits_dynamic_ptr,
semaphore_to_reset,
current_stream,
)



+ 3
- 1
flash_attn/cute/pyproject.toml View File

@@ -22,10 +22,12 @@ classifiers = [
]

dependencies = [
"nvidia-cutlass-dsl==4.3.0",
"nvidia-cutlass-dsl==4.3.3",
"torch",
"einops",
"typing_extensions",
"apache-tvm-ffi>=0.1.5,<0.2",
"torch-c-dlpack-ext",
]

[project.optional-dependencies]


+ 12
- 1
flash_attn/cute/seqlen_info.py View File

@@ -42,6 +42,8 @@ class SeqlenInfoQK:
seqlen_k: cutlass.Int32
has_cu_seqlens_q: cutlass.Constexpr[bool]
has_cu_seqlens_k: cutlass.Constexpr[bool]
has_seqused_q: cutlass.Constexpr[bool]
has_seqused_k: cutlass.Constexpr[bool]

@staticmethod
def create(
@@ -73,8 +75,17 @@ class SeqlenInfoQK:
)
has_cu_seqlens_q: int = mCuSeqlensQ is not None
has_cu_seqlens_k: int = mCuSeqlensK is not None
has_seqused_q: int = mSeqUsedQ is not None
has_seqused_k: int = mSeqUsedK is not None
return SeqlenInfoQK(
offset_q, offset_k, seqlen_q, seqlen_k, has_cu_seqlens_q, has_cu_seqlens_k
offset_q,
offset_k,
seqlen_q,
seqlen_k,
has_cu_seqlens_q,
has_cu_seqlens_k,
has_seqused_q,
has_seqused_k,
)

def offset_batch_Q(self, mQ: cute.Tensor, batch_idx: Int32, dim: int) -> cute.Tensor:


+ 13
- 9
flash_attn/cute/softmax.py View File

@@ -11,6 +11,7 @@ from cutlass import Float32

import flash_attn.cute.utils as utils
from flash_attn.cute.cute_dsl_utils import ParamsBase
from flash_attn.cute.seqlen_info import SeqlenInfoQK


@dataclass
@@ -29,8 +30,8 @@ class Softmax(ParamsBase):
arch: cutlass.Constexpr[int] = 80,
softmax_scale: Float32 | None = None,
):
row_max = cute.make_fragment(num_rows, Float32)
row_sum = cute.make_fragment(num_rows, Float32)
row_max = cute.make_rmem_tensor(num_rows, Float32)
row_sum = cute.make_rmem_tensor(num_rows, Float32)
return Softmax(scale_log2, num_rows, row_max, row_sum, arch, softmax_scale)

def reset(self) -> None:
@@ -168,8 +169,8 @@ class SoftmaxSm100(Softmax):
):
num_rows = 1
arch = 100
row_max = cute.make_fragment(num_rows, Float32)
row_sum = cute.make_fragment(num_rows, Float32)
row_max = cute.make_rmem_tensor(num_rows, Float32)
row_sum = cute.make_rmem_tensor(num_rows, Float32)
return SoftmaxSm100(
scale_log2,
num_rows,
@@ -339,6 +340,7 @@ def apply_score_mod_inner(
qk_acc_dtype: cutlass.Constexpr,
aux_tensors,
fastdiv_mods,
seqlen_info: SeqlenInfoQK,
constant_q_idx: cutlass.Constexpr,
qhead_per_kvhead: cutlass.Constexpr[int] = 1,
):
@@ -355,25 +357,26 @@ def apply_score_mod_inner(
qk_acc_dtype: Data type for accumulator
aux_tensors: Optional aux_tensors for FlexAttention
fastdiv_mods: Tuple of (seqlen_q_divmod, seqlen_k_divmod) for wrapping
seqlen_info: Sequence length info
constant_q_idx: If provided, use this constant for all q_idx values
If None, compute q_idx per-element
If None, compute q_idx per-element
qhead_per_kvhead_packgqa: Pack-GQA replication factor. Divide q_idx by this
when greater than 1 so score mods see logical heads.
"""
n_vals = cutlass.const_expr(cute.size(score_tensor.shape))
score_vec = cute.make_fragment(vec_size, qk_acc_dtype)
kv_idx_vec = cute.make_fragment(vec_size, cutlass.Int32)
score_vec = cute.make_rmem_tensor(vec_size, qk_acc_dtype)
kv_idx_vec = cute.make_rmem_tensor(vec_size, cutlass.Int32)

# SSA values for batch (constant across all elements)
batch_idx_ssa = utils.scalar_to_ssa(batch_idx, cutlass.Int32).broadcast_to((vec_size,))

# Handle q_idx based on whether it's constant
q_idx_vec = cute.make_fragment(vec_size, cutlass.Int32)
q_idx_vec = cute.make_rmem_tensor(vec_size, cutlass.Int32)

# For Pack-GQA with non-constant q_idx, we need per-element head indices
    # since a thread may process multiple query head indices
if cutlass.const_expr(qhead_per_kvhead > 1 and constant_q_idx is None):
head_idx_vec = cute.make_fragment(vec_size, cutlass.Int32)
head_idx_vec = cute.make_rmem_tensor(vec_size, cutlass.Int32)

for i in cutlass.range(0, n_vals, vec_size, unroll_full=True):
for j in cutlass.range(vec_size, unroll_full=True):
@@ -431,6 +434,7 @@ def apply_score_mod_inner(
head_idx_ssa,
q_idx=q_idx_ssa,
kv_idx=kv_idx_ssa,
seqlen_info=seqlen_info,
aux_tensors=aux_args,
)



+ 1
- 1
flash_attn/cute/utils.py View File

@@ -67,7 +67,7 @@ def create_softcap_scoremod(softcap_val):
inv_softcap = 1.0 / softcap_val

@cute.jit
def scoremod_premask_fn(acc_S_SSA, batch_idx, head_idx, q_idx, kv_idx, buffers):
def scoremod_premask_fn(acc_S_SSA, batch_idx, head_idx, q_idx, kv_idx, aux_tensors):
scores = acc_S_SSA * inv_softcap
return scores * cute.math.tanh(scores, fastmath=True)



+ 591
- 0
tests/cute/score_mod_definitions.py View File

@@ -0,0 +1,591 @@
import torch
import cutlass
import cutlass.cute as cute
from cutlass._mlir.dialects import math as mlir_math
import operator

# =============================================================================
# Score_mod functions that don't use global indices
# All use signature: (tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors)
# =============================================================================


@cute.jit
def score_mod_identity(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors):
    """Baseline no-op score mod: returns the attention scores unchanged."""
    return tSrS_ssa


@cute.jit
def score_mod_causal(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors):
    """Causal mask: keep score where q_idx >= kv_idx, set -inf elsewhere."""
    mask = operator.ge(q_idx, kv_idx)
    return cute.where(mask, tSrS_ssa, cute.full_like(tSrS_ssa, float("-inf")))


@cute.jit
def score_mod_rel_bias(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors):
    """Additive relative-position bias: score + |q_idx - kv_idx|."""
    diff = q_idx - kv_idx
    # mlir_math.absi is integer abs; wrap the raw MLIR value back into a TensorSSA.
    abs_diff = cute.TensorSSA(mlir_math.absi(diff), diff.shape, diff.dtype)
    return tSrS_ssa + abs_diff.to(cutlass.Float32)


@cute.jit
def score_mod_rel_bias_x2(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors):
    """Additive relative-position bias scaled by 2: score + 2 * |q_idx - kv_idx|."""
    diff = q_idx - kv_idx
    abs_diff = cute.TensorSSA(mlir_math.absi(diff), diff.shape, diff.dtype)
    scaled = abs_diff * cute.full_like(abs_diff, 2)
    return tSrS_ssa + scaled.to(cutlass.Float32)


@cute.jit
def score_mod_times_two(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors):
    """Multiplicative mod: doubles every score."""
    return tSrS_ssa * cute.full_like(tSrS_ssa, 2)


@cute.jit
def score_mod_alibi(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors):
    """ALiBi-style bias: score - slope(h) * |q_idx - kv_idx| with slope = 2^(-(h+1))."""
    score = tSrS_ssa.to(cutlass.Float32)
    # slope_exp = -8 * (h + 1); combined with the 0.125 factor below this gives
    # exp2(-(h + 1)), matching the eager reference 2 ** (-8 * (h + 1) / 8).
    slope_exp = (h_idx + cute.full_like(h_idx, 1)) * cute.full_like(h_idx, -8)
    # Note: 0.6931471805599453 * 1.4426950408889634 == ln(2) * log2(e) == 1,
    # so the constant reduces to 0.125; kept expanded to mirror generated code.
    slope = cute.math.exp2(
        slope_exp.to(cutlass.Float32)
        * cute.full_like(score, 0.125 * 0.6931471805599453 * 1.4426950408889634)
    )
    diff = q_idx - kv_idx
    abs_diff = cute.TensorSSA(mlir_math.absi(diff), diff.shape, diff.dtype).to(cutlass.Float32)
    return score - slope * abs_diff


@cute.jit
def score_mod_sliding_window(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors):
    """Sliding-window mask: keep score only when |q_idx - kv_idx| <= 256."""
    diff = q_idx - kv_idx
    abs_diff = cute.TensorSSA(mlir_math.absi(diff), diff.shape, diff.dtype)
    mask = operator.le(abs_diff, cute.full_like(abs_diff, 256))
    return cute.where(mask, tSrS_ssa, cute.full_like(tSrS_ssa, float("-inf")))


@cute.jit
def score_mod_block_diagonal(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors):
    """Block-diagonal mask over 64-wide blocks: attend only within the same block."""
    q_block = q_idx // 64
    kv_block = kv_idx // 64
    mask = operator.eq(q_block, kv_block)
    return cute.where(mask, tSrS_ssa, cute.full_like(tSrS_ssa, float("-inf")))


@cute.jit
def score_mod_causal_v2(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors):
    """Causal mask expressed via the sign of (q_idx - kv_idx) instead of a compare."""
    diff = q_idx - kv_idx
    mask = operator.ge(diff, cute.full_like(diff, 0))
    return cute.where(mask, tSrS_ssa, cute.full_like(tSrS_ssa, float("-inf")))


@cute.jit
def score_mod_batch_bias(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors):
    """Adds a per-batch bias gathered from aux_tensors[0][b_idx].

    The index/value round-trip through 1-element fragments routes the SSA
    index into a dynamic gather from the aux tensor; this pattern repeats
    across the aux-tensor score mods in this file.
    NOTE(review): other files in this change migrated cute.make_fragment to
    cute.make_rmem_tensor — confirm whether these test mods should follow.
    """
    batch_bias = aux_tensors[0]
    dtype = batch_bias.element_type
    b_frag = cute.make_fragment(1, cutlass.Int32)
    b_frag.store(b_idx)
    bias_frag = cute.make_fragment(1, dtype)
    bias_frag[0] = batch_bias[b_frag[0]]
    bias_val = (bias_frag.load()).to(cutlass.Float32)
    return tSrS_ssa + bias_val


@cute.jit
def score_mod_dual_buffer(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors):
    """Adds biases from two aux tensors: per-head (aux[0][h]) and per-q-position (aux[1][q])."""
    head_bias = aux_tensors[0]
    pos_bias = aux_tensors[1]
    dtype = head_bias.element_type

    # Gather head bias: index goes through a 1-element Int32 fragment.
    h_frag = cute.make_fragment(1, cutlass.Int32)
    h_frag.store(h_idx)
    head_val_frag = cute.make_fragment(1, dtype)
    head_val_frag[0] = head_bias[h_frag[0]]
    head_val = (head_val_frag.load()).to(cutlass.Float32)

    # Gather positional bias keyed on the (logical) query index.
    q_frag = cute.make_fragment(1, cutlass.Int32)
    q_frag.store(q_idx)
    pos_val_frag = cute.make_fragment(1, dtype)
    pos_val_frag[0] = pos_bias[q_frag[0]]
    pos_val = (pos_val_frag.load()).to(cutlass.Float32)

    return tSrS_ssa + head_val + pos_val


# =============================================================================
# Score_mod functions that use global indices
# All use signature: (tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors)
# Global indices computed as: q_idx_global = q_idx + seqlen_info.offset_q (and similarly for kv)
# =============================================================================


@cute.jit
def score_mod_global_kv_bias(
    tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors
):
    """Per-token bias using global kv index.

    Global index = logical kv_idx + seqlen_info.offset_k, i.e. the position
    in the packed (varlen) kv buffer rather than within the sequence.
    """
    offset_k = seqlen_info.offset_k
    kv_idx_global = kv_idx + offset_k
    token_bias = aux_tensors[0]
    dtype = token_bias.element_type
    kv_frag = cute.make_fragment(1, cutlass.Int32)
    kv_frag.store(kv_idx_global)
    bias_frag = cute.make_fragment(1, dtype)
    bias_frag[0] = token_bias[kv_frag[0]]

    return tSrS_ssa + (bias_frag.load()).to(cutlass.Float32)


@cute.jit
def score_mod_global_q_bias(
    tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors
):
    """Per-token bias using global q index (logical q_idx + seqlen_info.offset_q)."""
    offset_q = seqlen_info.offset_q
    q_idx_global = q_idx + offset_q
    token_bias = aux_tensors[0]
    dtype = token_bias.element_type
    q_frag = cute.make_fragment(1, cutlass.Int32)
    q_frag.store(q_idx_global)
    bias_frag = cute.make_fragment(1, dtype)
    bias_frag[0] = token_bias[q_frag[0]]
    return tSrS_ssa + (bias_frag.load()).to(cutlass.Float32)


@cute.jit
def score_mod_global_rel_plus_kv_bias(
    tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors
):
    """Relative position (logical) + per-token bias (global kv).

    Mixes a bias computed from logical (within-sequence) indices with a
    gather keyed on the global (packed-buffer) kv index.
    """
    offset_k = seqlen_info.offset_k
    kv_idx_global = kv_idx + offset_k
    token_bias = aux_tensors[0]
    dtype = token_bias.element_type

    # Logical relative distance, scaled by 0.1.
    rel_pos = q_idx - kv_idx
    rel_pos_abs = cute.TensorSSA(mlir_math.absi(rel_pos), rel_pos.shape, rel_pos.dtype)
    rel_bias = rel_pos_abs.to(cutlass.Float32) * cute.full_like(tSrS_ssa, 0.1)

    # Global-indexed per-token bias.
    kv_frag = cute.make_fragment(1, cutlass.Int32)
    kv_frag.store(kv_idx_global)
    bias_frag = cute.make_fragment(1, dtype)
    bias_frag[0] = token_bias[kv_frag[0]]

    return tSrS_ssa + rel_bias + (bias_frag.load()).to(cutlass.Float32)


@cute.jit
def score_mod_global_q_and_kv_bias(
    tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors
):
    """Both q and kv global indices: score + q_bias[q_global] + kv_bias[kv_global]."""
    offset_q = seqlen_info.offset_q
    q_idx_global = q_idx + offset_q
    offset_k = seqlen_info.offset_k
    kv_idx_global = kv_idx + offset_k
    q_bias = aux_tensors[0]
    kv_bias = aux_tensors[1]
    dtype = q_bias.element_type

    q_frag = cute.make_fragment(1, cutlass.Int32)
    q_frag.store(q_idx_global)
    q_bias_frag = cute.make_fragment(1, dtype)
    q_bias_frag[0] = q_bias[q_frag[0]]

    kv_frag = cute.make_fragment(1, cutlass.Int32)
    kv_frag.store(kv_idx_global)
    kv_bias_frag = cute.make_fragment(1, dtype)
    kv_bias_frag[0] = kv_bias[kv_frag[0]]

    return (
        tSrS_ssa
        + (q_bias_frag.load()).to(cutlass.Float32)
        + (kv_bias_frag.load()).to(cutlass.Float32)
    )


@cute.jit
def score_mod_global_logical_rel_plus_kv_bias(
    tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors
):
    """Logical relative + global-indexed per-token bias (rel scale 0.01)."""
    offset_k = seqlen_info.offset_k
    kv_idx_global = kv_idx + offset_k
    token_bias = aux_tensors[0]
    dtype = token_bias.element_type

    # Relative distance from logical (within-sequence) indices.
    rel_pos = q_idx - kv_idx
    rel_pos_abs = cute.TensorSSA(mlir_math.absi(rel_pos), rel_pos.shape, rel_pos.dtype)
    rel_bias = rel_pos_abs.to(cutlass.Float32) * cute.full_like(tSrS_ssa, 0.01)

    # Per-token bias gathered by global kv position.
    kv_frag = cute.make_fragment(1, cutlass.Int32)
    kv_frag.store(kv_idx_global)
    bias_frag = cute.make_fragment(1, dtype)
    bias_frag[0] = token_bias[kv_frag[0]]

    return tSrS_ssa + rel_bias + (bias_frag.load()).to(cutlass.Float32)


# "Stress tests" - score_mods with complex global index usage

@cute.jit
def score_mod_stress_complex_arithmetic(
    tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors
):
    """All indices in complex arithmetic: rel-distance term + globally gathered
    q bias scaled by a (batch+1)*(head+1)*0.001 factor."""
    offset_q = seqlen_info.offset_q
    q_idx_global = q_idx + offset_q
    bias = aux_tensors[0]
    dtype = bias.element_type

    # Use absolute value instead of squaring to avoid overflow with large sequences
    rel_pos = q_idx - kv_idx
    rel_pos_abs = cute.TensorSSA(mlir_math.absi(rel_pos), rel_pos.shape, rel_pos.dtype)
    rel_bias = rel_pos_abs.to(cutlass.Float32) * cute.full_like(tSrS_ssa, 0.001)

    q_frag = cute.make_fragment(1, cutlass.Int32)
    q_frag.store(q_idx_global)
    bias_q_frag = cute.make_fragment(1, dtype)
    bias_q_frag[0] = bias[q_frag[0]]
    bias_q = (bias_q_frag.load()).to(cutlass.Float32)

    # Scalar scale derived from batch and head indices.
    scale = (b_idx + cute.full_like(b_idx, 1)) * (h_idx + cute.full_like(h_idx, 1))
    scale_f32 = scale.to(cutlass.Float32) * 0.001

    result = tSrS_ssa + rel_bias + bias_q * scale_f32
    return result


@cute.jit
def score_mod_stress_conditional_mask(
    tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors
):
    """Conditional masking with global vs logical indices: keep (score + bias)
    only when both causal-on-logical and |global distance| <= 512 hold."""
    offset_q = seqlen_info.offset_q
    q_idx_global = q_idx + offset_q
    offset_k = seqlen_info.offset_k
    kv_idx_global = kv_idx + offset_k
    token_bias = aux_tensors[0]
    dtype = token_bias.element_type

    kv_frag = cute.make_fragment(1, cutlass.Int32)
    kv_frag.store(kv_idx_global)
    bias_frag = cute.make_fragment(1, dtype)
    bias_frag[0] = token_bias[kv_frag[0]]
    bias_val = (bias_frag.load()).to(cutlass.Float32)

    # Causal test uses logical (within-sequence) indices …
    is_causal = operator.ge(q_idx, kv_idx)

    # … while the proximity window is measured on global positions.
    global_diff = q_idx_global - kv_idx_global
    is_nearby = operator.le(
        cute.TensorSSA(mlir_math.absi(global_diff), global_diff.shape, global_diff.dtype),
        cute.full_like(global_diff, 512),
    )

    both_conditions = is_causal & is_nearby
    return cute.where(both_conditions, tSrS_ssa + bias_val, cute.full_like(tSrS_ssa, float("-inf")))


@cute.jit
def score_mod_stress_multi_buffer(
    tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors
):
    """Multiple aux tensors with different indexing.

    aux_tensors layout: [0] per-batch bias, [1] per-head scale, [2] q-position
    bias (global q index), [3] kv-position bias (global kv index), [4] relative
    position scale indexed by clamp(q - kv + 512, 0, 1024).
    Result: score * head_scale + batch_bias + q_pos + kv_pos + rel_scale * 0.1.
    """
    offset_q = seqlen_info.offset_q
    q_idx_global = q_idx + offset_q
    offset_k = seqlen_info.offset_k
    kv_idx_global = kv_idx + offset_k
    batch_bias = aux_tensors[0]
    head_scale = aux_tensors[1]
    q_pos_bias = aux_tensors[2]
    kv_pos_bias = aux_tensors[3]
    rel_pos_scale = aux_tensors[4]

    dtype = batch_bias.element_type

    # Per-batch bias.
    b_frag = cute.make_fragment(1, cutlass.Int32)
    b_frag.store(b_idx)
    bb_frag = cute.make_fragment(1, dtype)
    bb_frag[0] = batch_bias[b_frag[0]]
    bb_val = (bb_frag.load()).to(cutlass.Float32)

    # Per-head multiplicative scale.
    h_frag = cute.make_fragment(1, cutlass.Int32)
    h_frag.store(h_idx)
    hs_frag = cute.make_fragment(1, dtype)
    hs_frag[0] = head_scale[h_frag[0]]
    hs_val = (hs_frag.load()).to(cutlass.Float32)

    # Positional biases gathered by global indices.
    qg_frag = cute.make_fragment(1, cutlass.Int32)
    qg_frag.store(q_idx_global)
    qpb_frag = cute.make_fragment(1, dtype)
    qpb_frag[0] = q_pos_bias[qg_frag[0]]
    qpb_val = (qpb_frag.load()).to(cutlass.Float32)

    kvg_frag = cute.make_fragment(1, cutlass.Int32)
    kvg_frag.store(kv_idx_global)
    kvpb_frag = cute.make_fragment(1, dtype)
    kvpb_frag[0] = kv_pos_bias[kvg_frag[0]]
    kvpb_val = (kvpb_frag.load()).to(cutlass.Float32)

    # Relative-position lookup, shifted by +512 and clamped into [0, 1024]
    # (matches max_rel_pos=512 in the eager factory).
    rel_idx = q_idx - kv_idx + cute.full_like(q_idx, 512)
    rel_idx_clamped = cute.where(
        operator.lt(rel_idx, cute.full_like(rel_idx, 0)), cute.full_like(rel_idx, 0), rel_idx
    )
    rel_idx_clamped = cute.where(
        operator.gt(rel_idx_clamped, cute.full_like(rel_idx_clamped, 1024)),
        cute.full_like(rel_idx_clamped, 1024),
        rel_idx_clamped,
    )
    ri_frag = cute.make_fragment(1, cutlass.Int32)
    ri_frag.store(rel_idx_clamped)
    rps_frag = cute.make_fragment(1, dtype)
    rps_frag[0] = rel_pos_scale[ri_frag[0]]
    rps_val = (rps_frag.load()).to(cutlass.Float32)

    return tSrS_ssa * hs_val + bb_val + qpb_val + kvpb_val + rps_val * cute.full_like(tSrS_ssa, 0.1)


@cute.jit
def score_mod_stress_global_offset(
    tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors
):
    """Verify global - logical = offset: adds token_bias[kv_idx + offset_k]."""
    offset_k = seqlen_info.offset_k
    kv_idx_global = kv_idx + offset_k
    token_bias = aux_tensors[0]
    dtype = token_bias.element_type

    kv_frag = cute.make_fragment(1, cutlass.Int32)
    kv_frag.store(kv_idx_global)
    bias_frag = cute.make_fragment(1, dtype)
    bias_frag[0] = token_bias[kv_frag[0]]

    return tSrS_ssa + (bias_frag.load()).to(cutlass.Float32)


@cute.jit
def score_mod_stress_xor_pattern(
    tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors
):
    """XOR-based pattern using index bits: ((q ^ kv) & 0xFF) * 0.001 plus a
    global-kv-indexed token bias scaled by 0.1."""
    offset_k = seqlen_info.offset_k
    kv_idx_global = kv_idx + offset_k
    token_bias = aux_tensors[0]
    dtype = token_bias.element_type

    # Low 8 bits of the XOR of logical indices form a cheap pseudo-pattern.
    xor_logical = q_idx ^ kv_idx
    pattern_logical = xor_logical & cute.full_like(xor_logical, 0xFF)
    pattern_bias = pattern_logical.to(cutlass.Float32) * cute.full_like(tSrS_ssa, 0.001)

    kv_frag = cute.make_fragment(1, cutlass.Int32)
    kv_frag.store(kv_idx_global)
    bias_frag = cute.make_fragment(1, dtype)
    bias_frag[0] = token_bias[kv_frag[0]]

    return (
        tSrS_ssa
        + pattern_bias
        + (bias_frag.load()).to(cutlass.Float32) * cute.full_like(tSrS_ssa, 0.1)
    )


@cute.jit
def score_mod_debug_global_idx(
    tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors
):
    """Debug mod: adds the global kv index itself (scaled by 0.001) as the bias."""
    # Don't read from aux_tensors at all - just add the global index as bias
    offset_k = seqlen_info.offset_k
    kv_idx_global = kv_idx + offset_k
    bias = kv_idx_global.to(cutlass.Float32) * cute.full_like(tSrS_ssa, 0.001)
    return tSrS_ssa + bias


# =============================================================================
# Eager reference functions
# =============================================================================


def identity_eager(score, b, h, q_idx, kv_idx):
    """Eager reference: no-op score mod, returns the scores untouched."""
    return score


def causal_eager(score, b, h, q_idx, kv_idx):
    """Eager reference causal mask: keep score where q_idx >= kv_idx, else -inf."""
    allowed = q_idx >= kv_idx
    return torch.where(allowed, score, float("-inf"))


def rel_bias_eager(score, b, h, q_idx, kv_idx):
    """Eager reference: adds |q_idx - kv_idx| as an additive bias."""
    distance = (q_idx - kv_idx).abs()
    return score + distance


def rel_bias_x2_eager(score, b, h, q_idx, kv_idx):
    """Eager reference: adds 2 * |q_idx - kv_idx| as an additive bias."""
    distance = (q_idx - kv_idx).abs()
    return score + distance * 2


def times_two_eager(score, b, h, q_idx, kv_idx):
    """Eager reference: doubles every score."""
    doubled = 2 * score
    return doubled


def alibi_eager(score, b, h, q_idx, kv_idx):
    """Eager reference ALiBi: subtract slope(h) * |q - kv|, slope = 2^(-(h+1))."""
    distance = torch.abs(q_idx - kv_idx)
    slope = 2 ** (-8 * (h + 1) / 8)
    return score - slope * distance


def sliding_window_eager(score, b, h, q_idx, kv_idx):
    """Eager reference sliding window: mask positions more than 256 apart."""
    within_window = torch.abs(q_idx - kv_idx) <= 256
    return torch.where(within_window, score, float("-inf"))


def block_diagonal_eager(score, b, h, q_idx, kv_idx):
    """Eager reference block-diagonal mask over 64-wide blocks."""
    same_block = (q_idx // 64) == (kv_idx // 64)
    return torch.where(same_block, score, float("-inf"))


def causal_v2_eager(score, b, h, q_idx, kv_idx):
    """Eager reference causal mask expressed via the sign of (q_idx - kv_idx)."""
    keep = (q_idx - kv_idx) >= 0
    return torch.where(keep, score, float("-inf"))


def batch_bias_factory(bias_tensor):
    """Build an eager score mod that adds a per-batch bias from bias_tensor[b]."""

    def mod(score, b, h, q_idx, kv_idx):
        per_batch = bias_tensor[b]
        return score + per_batch

    return mod


def dual_buffer_factory(head_bias, pos_bias):
    """Build an eager score mod adding a per-head and a per-q-position bias."""

    def mod(score, b, h, q_idx, kv_idx):
        per_head = head_bias[h]
        per_position = pos_bias[q_idx]
        return score + per_head + per_position

    return mod


def packed_kv_bias_factory(bias_tensor, cu_seqlens_k):
    """Eager score mod reading a per-token kv bias from a packed (varlen) buffer.

    kv_idx is clamped to the sequence's valid range so out-of-range positions
    (masked out elsewhere) still index the buffer safely.
    """

    def mod(score, b, h, q_idx, kv_idx):
        start = cu_seqlens_k[b]
        valid_len = cu_seqlens_k[b + 1] - start
        safe_kv = torch.clamp(kv_idx, max=valid_len - 1)
        return score + bias_tensor[start + safe_kv]

    return mod


def packed_q_bias_factory(bias_tensor, cu_seqlens_q):
    """Eager score mod reading a per-token q bias from a packed (varlen) buffer.

    q_idx is clamped into the sequence's valid range before the gather.
    """

    def mod(score, b, h, q_idx, kv_idx):
        start = cu_seqlens_q[b]
        valid_len = cu_seqlens_q[b + 1] - start
        safe_q = torch.clamp(q_idx, max=valid_len - 1)
        return score + bias_tensor[start + safe_q]

    return mod


def packed_rel_plus_kv_bias_factory(bias_tensor, cu_seqlens_k):
    """Eager score mod: 0.1 * |q - kv| (logical) plus a packed-buffer kv bias.

    kv_idx is clamped to the valid length before indexing the packed buffer.
    """

    def mod(score, b, h, q_idx, kv_idx):
        start = cu_seqlens_k[b]
        valid_len = cu_seqlens_k[b + 1] - start
        safe_kv = torch.clamp(kv_idx, max=valid_len - 1)
        distance_term = torch.abs(q_idx - kv_idx).float() * 0.1
        return score + distance_term + bias_tensor[start + safe_kv]

    return mod


def packed_q_and_kv_bias_factory(q_bias, kv_bias, cu_seqlens_q, cu_seqlens_k):
    """Eager score mod adding packed-buffer biases for both q and kv tokens.

    Each index is clamped into its own sequence's valid range before the gather.
    """

    def mod(score, b, h, q_idx, kv_idx):
        # Query side: clamp into [0, q_len - 1] within sequence b.
        q_start = cu_seqlens_q[b]
        safe_q = torch.clamp(q_idx, max=cu_seqlens_q[b + 1] - q_start - 1)

        # Key/value side: clamp into [0, kv_len - 1] within sequence b.
        kv_start = cu_seqlens_k[b]
        safe_kv = torch.clamp(kv_idx, max=cu_seqlens_k[b + 1] - kv_start - 1)

        return score + q_bias[q_start + safe_q] + kv_bias[kv_start + safe_kv]

    return mod


def packed_logical_rel_plus_kv_bias_factory(bias_tensor, cu_seqlens_k):
    """Eager score mod: 0.01 * |q - kv| (logical) + unclamped packed kv bias."""

    def mod(score, b, h, q_idx, kv_idx):
        distance_term = torch.abs(q_idx - kv_idx).float() * 0.01
        packed_bias = bias_tensor[cu_seqlens_k[b] + kv_idx]
        return score + distance_term + packed_bias

    return mod


def stress_complex_arithmetic_factory(bias, cu_seqlens_q):
    """Eager stress mod: rel-distance term + global-q bias scaled per (b, h)."""

    def mod(score, b, h, q_idx, kv_idx):
        # Use absolute value instead of squaring to avoid overflow with large sequences
        distance_term = torch.abs(q_idx - kv_idx) * 0.001
        q_global = cu_seqlens_q[b] + q_idx
        head_batch_scale = (b + 1) * (h + 1) * 0.001
        return score + distance_term + bias[q_global] * head_batch_scale

    return mod


def stress_conditional_mask_factory(token_bias, cu_seqlens_q, cu_seqlens_k):
    """Eager stress mod: keep (score + global-kv bias) only where causal-on-logical
    AND |global q - global kv| <= 512 both hold; -inf otherwise."""

    def mod(score, b, h, q_idx, kv_idx):
        kv_global = cu_seqlens_k[b] + kv_idx
        q_global = cu_seqlens_q[b] + q_idx
        causal_ok = q_idx >= kv_idx
        nearby_ok = torch.abs(q_global - kv_global) <= 512
        keep = causal_ok & nearby_ok
        return torch.where(keep, score + token_bias[kv_global], float("-inf"))

    return mod


def stress_multi_buffer_factory(
    batch_bias,
    head_scale,
    q_pos_bias,
    kv_pos_bias,
    rel_pos_scale,
    cu_seqlens_q,
    cu_seqlens_k,
    max_rel_pos=512,
):
    """Eager stress mod combining five aux buffers with distinct indexing:
    per-batch bias, per-head scale, global q/kv positional biases, and a
    relative-position scale indexed by clamp(q - kv + max_rel_pos, 0, 2*max_rel_pos)."""

    def mod(score, b, h, q_idx, kv_idx):
        q_positional = q_pos_bias[cu_seqlens_q[b] + q_idx]
        kv_positional = kv_pos_bias[cu_seqlens_k[b] + kv_idx]
        rel_idx = (q_idx - kv_idx + max_rel_pos).clamp(0, max_rel_pos * 2)
        rel_term = rel_pos_scale[rel_idx] * 0.1
        return score * head_scale[h] + batch_bias[b] + q_positional + kv_positional + rel_term

    return mod


def stress_global_offset_factory(token_bias, cu_seqlens_k):
    """Eager stress mod: adds token_bias at the global (packed) kv position."""

    def mod(score, b, h, q_idx, kv_idx):
        kv_global = cu_seqlens_k[b] + kv_idx
        return score + token_bias[kv_global]

    return mod


def stress_xor_pattern_factory(token_bias, cu_seqlens_q, cu_seqlens_k):
    """Eager stress mod: ((q ^ kv) & 0xFF) * 0.001 plus global-kv bias * 0.1."""

    def mod(score, b, h, q_idx, kv_idx):
        bit_pattern = (q_idx ^ kv_idx) & 0xFF
        pattern_term = bit_pattern.float() * 0.001
        kv_global = cu_seqlens_k[b] + kv_idx
        return score + pattern_term + token_bias[kv_global] * 0.1

    return mod

def debug_global_idx_factory(bias, cu_seqlens_k):
    """Eager debug mod: adds 0.001 * global kv index as the bias.

    `bias` is unused; it is kept so the factory signature matches its siblings.
    Offsets are materialized once as a Python list outside the inner mod.
    """
    starts = cu_seqlens_k.tolist()

    def mod(score, b, h, q_idx, kv_idx):
        kv_global = starts[b] + kv_idx
        return score + kv_global.float() * 0.001

    return mod

+ 1
- 0
tests/cute/test_flash_attn.py View File

@@ -274,6 +274,7 @@ def test_flash_attn_output(
and dv == d
and learnable_sink is None
# and False
and not ((causal or local) and seqlen_k < seqlen_q)
):
g = torch.randn_like(out)
# do_o = ((g.float() * out.float()).sum(-1)).transpose(1, 2)


+ 396
- 294
tests/cute/test_score_mod.py View File

@@ -6,218 +6,34 @@ from cutlass._mlir.dialects import math as mlir_math
import operator
from torch.nn.attention.flex_attention import flex_attention
from flash_attn.cute.interface import _flash_attn_fwd


@cute.jit
def score_mod_1(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, aux_tensors):
tmp0 = tSrS_ssa
tSrS_ssa = tmp0
return tSrS_ssa


@cute.jit
def score_mod_2(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, aux_tensors):
tmp0 = q_idx
tmp1 = kv_idx
tmp2 = operator.ge(tmp0, tmp1)
tmp3 = tSrS_ssa
tmp4 = cute.where(tmp2, tmp3, cute.full_like(tmp3, float("-inf")))
tSrS_ssa = tmp4
return tSrS_ssa


@cute.jit
def score_mod_3(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, aux_tensors):
tmp0 = tSrS_ssa
tmp1 = q_idx
tmp2 = kv_idx
tmp3 = tmp1 - tmp2
tmp4 = cute.TensorSSA(mlir_math.absi(tmp3), tmp3.shape, tmp3.dtype)
tmp5 = tmp4.to(cutlass.Float32)
tmp6 = tmp0 + tmp5
tSrS_ssa = tmp6
return tSrS_ssa


@cute.jit
def score_mod_4(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, aux_tensors):
tmp0 = tSrS_ssa
tmp1 = q_idx
tmp2 = kv_idx
tmp3 = tmp1 - tmp2
tmp4 = cute.TensorSSA(mlir_math.absi(tmp3), tmp3.shape, tmp3.dtype)
tmp5 = tmp4 * cute.full_like(tmp4, 2)
tmp6 = tmp5.to(cutlass.Float32)
tmp7 = tmp0 + tmp6
tSrS_ssa = tmp7
return tSrS_ssa


@cute.jit
def score_mod_5(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, aux_tensors):
tmp0 = tSrS_ssa
tmp1 = tmp0 * cute.full_like(tmp0, 2)
tSrS_ssa = tmp1
return tSrS_ssa


@cute.jit
def score_mod_6(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, aux_tensors):
tmp0 = tSrS_ssa
tmp1 = tmp0.to(cutlass.Float32)
tmp2 = h_idx
tmp3 = tmp2 + cute.full_like(tmp2, 1)
tmp4 = tmp3 * cute.full_like(tmp3, -8)
tmp5 = tmp4.to(cutlass.Float32)
tmp6 = tmp5 * cute.full_like(tmp5, 0.125)
tmp7 = tmp6 * cute.full_like(tmp6, 0.6931471805599453)
tmp8 = cute.math.exp2(tmp7 * 1.4426950408889634)
tmp9 = q_idx
tmp10 = kv_idx
tmp11 = tmp9 - tmp10
tmp12 = cute.TensorSSA(mlir_math.absi(tmp11), tmp11.shape, tmp11.dtype)
tmp13 = tmp12.to(cutlass.Float32)
tmp14 = tmp8 * tmp13
tmp15 = tmp1 - tmp14
tSrS_ssa = tmp15
return tSrS_ssa


@cute.jit
def score_mod_7(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, aux_tensors):
tmp0 = q_idx
tmp1 = kv_idx
tmp2 = tmp0 - tmp1
tmp3 = cute.TensorSSA(mlir_math.absi(tmp2), tmp2.shape, tmp2.dtype)
tmp4 = operator.le(tmp3, cute.full_like(tmp3, 256))
tmp5 = tSrS_ssa
tmp6 = cute.where(tmp4, tmp5, cute.full_like(tmp5, float("-inf")))
tSrS_ssa = tmp6
return tSrS_ssa


@cute.jit
def score_mod_8(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, aux_tensors):
tmp0 = q_idx
tmp1 = kv_idx
tmp2 = tSrS_ssa
tmp3 = cute.where(
operator.eq(tmp0 // 64, tmp1 // 64), tmp2, cute.full_like(tmp2, float("-inf"))
)
tSrS_ssa = tmp3
return tSrS_ssa


@cute.jit
def score_mod_9(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, aux_tensors):
tmp0 = q_idx
tmp1 = kv_idx
tmp2 = tmp0 - tmp1
tmp3 = operator.ge(tmp2, cute.full_like(tmp2, 0))
tmp4 = tSrS_ssa
tmp5 = cute.where(tmp3, tmp4, cute.full_like(tmp4, float("-inf")))
tSrS_ssa = tmp5
return tSrS_ssa


@cute.jit
def score_mod_10(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, aux_tensors):
batch_bias = aux_tensors[0]

# Detect dtype from buffer element type
dtype = batch_bias.element_type

b_frag = cute.make_fragment(1, cutlass.Int32)
b_frag.store(b_idx)
bias_frag = cute.make_fragment(1, dtype)
bias_frag[0] = batch_bias[b_frag[0]]
bias_val = (bias_frag.load()).to(cutlass.Float32)

return tSrS_ssa + bias_val


@cute.jit
def score_mod_11(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, aux_tensors):
head_bias = aux_tensors[0]
pos_bias = aux_tensors[1]

# Detect dtype from buffer element type
dtype = head_bias.element_type

h_frag = cute.make_fragment(1, cutlass.Int32)
h_frag.store(h_idx)
head_val_frag = cute.make_fragment(1, dtype)
head_val_frag[0] = head_bias[h_frag[0]]
head_val = (head_val_frag.load()).to(cutlass.Float32)

q_frag = cute.make_fragment(1, cutlass.Int32)
q_frag.store(q_idx)
pos_val_frag = cute.make_fragment(1, dtype)
pos_val_frag[0] = pos_bias[q_frag[0]]
pos_val = (pos_val_frag.load()).to(cutlass.Float32)

return tSrS_ssa + head_val + pos_val


# Eager reference functions for comparison
def identity_eager(score, b, h, q_idx, kv_idx):
return score


def causal_mask_eager(score, b, h, q_idx, kv_idx):
return torch.where(q_idx >= kv_idx, score, float("-inf"))


def relative_bias_eager(score, b, h, q_idx, kv_idx):
return score + torch.abs(q_idx - kv_idx)


def relative_bias_v2_eager(score, b, h, q_idx, kv_idx):
return score + 2 * torch.abs(q_idx - kv_idx)


def times_two_eager(score, b, h, q_idx, kv_idx):
return score * 2


def alibi_bias_eager(score, b, h, q_idx, kv_idx):
slope = 2 ** (-8 * (h + 1) / 8)
return score - slope * torch.abs(q_idx - kv_idx)


def sliding_window_eager(score, b, h, q_idx, kv_idx):
return torch.where(torch.abs(q_idx - kv_idx) <= 256, score, float("-inf"))


def block_diagonal_eager(score, b, h, q_idx, kv_idx):
q_block = q_idx // 64
kv_block = kv_idx // 64
return torch.where(q_block == kv_block, score, float("-inf"))


def causal_mask_v2_eager(score, b, h, q_idx, kv_idx):
return torch.where(q_idx - kv_idx >= 0, score, float("-inf"))


def batch_bias(bias_tensor):
"""Per-batch bias (tests batch indexing)."""

def batch_bias_mod(score, b, h, q_idx, kv_idx):
return score + bias_tensor[b]

return batch_bias_mod


def dual_buffer_bias(head_bias, pos_scale):
"""Dual buffer loading (tests loading from 2 separate tensors)."""

def dual_buffer_mod(score, b, h, q_idx, kv_idx):
head_component = head_bias[h]
pos_component = pos_scale[q_idx]
return score + pos_component + head_component

return dual_buffer_mod

from score_mod_definitions import (
# TensorSSA-based score mods
score_mod_identity as score_mod_1,
score_mod_causal as score_mod_2,
score_mod_rel_bias as score_mod_3,
score_mod_rel_bias_x2 as score_mod_4,
score_mod_times_two as score_mod_5,
score_mod_alibi as score_mod_6,
score_mod_sliding_window as score_mod_7,
score_mod_block_diagonal as score_mod_8,
score_mod_causal_v2 as score_mod_9,
score_mod_batch_bias as score_mod_10,
score_mod_dual_buffer as score_mod_11,
) # isort: split
from score_mod_definitions import (
# Eager (torch) reference score mods
identity_eager,
causal_eager as causal_mask_eager,
rel_bias_eager as relative_bias_eager,
rel_bias_x2_eager as relative_bias_v2_eager,
times_two_eager,
alibi_eager as alibi_bias_eager,
sliding_window_eager,
block_diagonal_eager,
causal_v2_eager as causal_mask_v2_eager,
batch_bias_factory as batch_bias,
dual_buffer_factory as dual_buffer_bias,
)

# Test pairs: (cute_jit_function, eager_reference_function)
TEST_PAIRS = [
@@ -238,6 +54,29 @@ TEST_PAIRS_WITH_AUX_TENSORS = [
(score_mod_11, dual_buffer_bias),
]

# (seqlen_q, seqlen_kv) pairs shared by the score-mod tests: tiny/degenerate
# cases, tile-unaligned sizes, q > kv and kv > q shapes, and large sequences.
SEQLEN_CONFIGS = [
    (1, 1),
    (64, 128),
    (128, 192),
    (256, 256),
    (239, 1),
    (799, 3),
    (113, 203),
    (113, 128),
    (128, 217),
    (113, 211),
    (108, 256),
    (256, 512),
    (384, 256),
    (640, 128),
    (512, 256),
    (1024, 1024),
    (1023, 1024),
    (1024, 1023),
    (4096, 4096),
    (4224, 4224),
]


def create_tensors(
batch_size=2, num_heads=4, seqlen_q=64, seqlen_kv=64, dim=128, dtype=torch.bfloat16
@@ -277,31 +116,7 @@ def run_flex_reference(q, k, v, eager_score_mod, dtype=None) -> torch.Tensor:
)


@pytest.mark.parametrize(
"seqlen_q,seqlen_kv",
[
(1, 1),
(64, 128),
(128, 192),
(256, 256),
(239, 1),
(799, 3),
(113, 203),
(113, 128),
(128, 217),
(113, 211),
(108, 256),
(256, 512),
(384, 256),
(640, 128),
(512, 256),
(1024, 1024),
(1023, 1024),
(1024, 1023),
(4096, 4096),
(4224, 4224),
],
)
@pytest.mark.parametrize("seqlen_q,seqlen_kv", SEQLEN_CONFIGS)
@pytest.mark.parametrize("qhead_per_kvhead,num_kv_heads", [(1, 2), (4, 2)])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("score_mod_pair", TEST_PAIRS)
@@ -354,31 +169,7 @@ def test_cute_vs_flex_attention(
)


@pytest.mark.parametrize(
"seqlen_q,seqlen_kv",
[
(1, 1),
(64, 128),
(128, 192),
(256, 256),
(239, 1),
(799, 3),
(113, 203),
(113, 128),
(128, 217),
(113, 211),
(108, 256),
(256, 512),
(384, 256),
(640, 128),
(512, 256),
(1024, 1024),
(1023, 1024),
(1024, 1023),
(4096, 4096),
(4224, 4224),
],
)
@pytest.mark.parametrize("seqlen_q,seqlen_kv", SEQLEN_CONFIGS)
@pytest.mark.parametrize("qhead_per_kvhead,num_kv_heads", [(1, 1), (4, 2)])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("score_mod_pair", TEST_PAIRS_WITH_AUX_TENSORS)
@@ -451,48 +242,359 @@ def test_cute_vs_flex_attention_with_aux_tensors(
)


@pytest.mark.xfail(
raises=NotImplementedError, reason="Varlen with score_mod not yet supported"
def _generate_block_kvcache(
    seqlen_k, page_size, batch_size, nheads_k, d, device, dtype
):
    """Build a random paged KV cache plus its contiguous reference views.

    Returns (k_cache, v_cache, page_table, k_cache_paged, v_cache_paged,
    num_blocks) where k_cache/v_cache are the contiguous [b, h, s, d]
    equivalents reconstructed by gathering pages through page_table.
    """
    import math
    from einops import rearrange

    # Over-allocate 3x the minimum pages so the random page_table permutation
    # scatters each batch's pages across the pool.
    num_blocks = math.ceil(seqlen_k / page_size) * batch_size * 3
    k_cache_paged = torch.randn(
        num_blocks, page_size, nheads_k, d, device=device, dtype=dtype
    )
    v_cache_paged = torch.randn(
        num_blocks, page_size, nheads_k, d, device=device, dtype=dtype
    )
    # Random assignment of pages to (batch, logical-block) slots.
    page_table = rearrange(
        torch.randperm(num_blocks, dtype=torch.int32, device=device),
        "(b nblocks) -> b nblocks",
        b=batch_size,
    )
    # Gather pages back into contiguous [b, s, h, d] and trim to seqlen_k.
    k_cache_bshd = rearrange(
        k_cache_paged[page_table.flatten()],
        "(b nblocks) block_size ... -> b (nblocks block_size) ...",
        b=batch_size,
    )[:, :seqlen_k]
    v_cache_bshd = rearrange(
        v_cache_paged[page_table.flatten()],
        "(b nblocks) block_size ... -> b (nblocks block_size) ...",
        b=batch_size,
    )[:, :seqlen_k]
    # Reference layout is [b, h, s, d].
    k_cache = k_cache_bshd.transpose(1, 2)
    v_cache = v_cache_bshd.transpose(1, 2)
    return k_cache, v_cache, page_table, k_cache_paged, v_cache_paged, num_blocks


@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("page_size", [None, 1, 4, 128])
@pytest.mark.parametrize("qhead_per_kvhead,num_kv_heads", [(1, 2), (4, 2)])
@pytest.mark.parametrize(
"seqlen_q,seqlen_kv",
[
(1, 128),
(64, 256),
(64, 800),
(256, 256),
(113, 203),
],
)
def test_varlen_with_score_mod():
"""Test that varlen (variable length sequences) works with score_mod.
@pytest.mark.parametrize("score_mod_pair", TEST_PAIRS)
def test_score_mod_with_paged_kvcache(
seqlen_q,
seqlen_kv,
qhead_per_kvhead,
num_kv_heads,
page_size,
dtype,
score_mod_pair,
):
if page_size is not None and seqlen_kv % page_size != 0:
pytest.skip()

For varlen, tokens from different sequences should not attend to each other.
Without proper index mapping, the causal mask will be applied to the global
indices instead of per-sequence logical indices.
"""
torch.random.manual_seed(42)
cute_score_mod, eager_score_mod = score_mod_pair

batch_size = 2
num_q_heads = num_kv_heads * qhead_per_kvhead
pack_gqa = qhead_per_kvhead > 1
dim = 128
device = "cuda"

q = torch.randn(batch_size, num_q_heads, seqlen_q, dim, device=device, dtype=dtype)

if page_size is None:
k_cache = torch.randn(
batch_size, num_kv_heads, seqlen_kv, dim, device=device, dtype=dtype
)
v_cache = torch.randn(
batch_size, num_kv_heads, seqlen_kv, dim, device=device, dtype=dtype
)
page_table = None
k_cache_paged = None
v_cache_paged = None
else:
(
k_cache,
v_cache,
page_table,
k_cache_paged,
v_cache_paged,
num_blocks,
) = _generate_block_kvcache(
seqlen_kv, page_size, batch_size, num_kv_heads, dim, device, dtype
)

cache_seqlens = torch.randint(
1, seqlen_kv + 1, (batch_size,), dtype=torch.int32, device=device
)

from einops import rearrange

arange = rearrange(torch.arange(seqlen_kv, device=device), "s -> 1 s")
cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1")
key_padding_mask = arange < cache_seqlens_expanded

if pack_gqa:
k_cache_rep = k_cache.repeat_interleave(qhead_per_kvhead, dim=1)
v_cache_rep = v_cache.repeat_interleave(qhead_per_kvhead, dim=1)
else:
k_cache_rep = k_cache
v_cache_rep = v_cache

seqlens = [64, 56, 128]
total_seq = sum(seqlens)
num_heads = 4
dtype = torch.bfloat16
def make_masked_score_mod(base_score_mod, seqused_k_tensor):
seqused_k_dev = seqused_k_tensor

cu_seqlens = torch.tensor(
[0] + list(torch.tensor(seqlens).cumsum(0).tolist()),
device="cuda",
dtype=torch.int32,
def masked_score_mod(score, b, h, q_idx, kv_idx):
if base_score_mod is not None:
score = base_score_mod(score, b, h, q_idx, kv_idx)
seqlen_limit = torch.gather(seqused_k_dev, 0, b.long())
valid_mask = kv_idx < seqlen_limit
return torch.where(valid_mask, score, torch.full_like(score, float("-inf")))

return masked_score_mod

masked_score_mod_fp32 = make_masked_score_mod(eager_score_mod, cache_seqlens)
masked_score_mod = make_masked_score_mod(eager_score_mod, cache_seqlens)

out_ref_fp32 = run_flex_reference(
q, k_cache_rep, v_cache_rep, masked_score_mod_fp32, dtype=torch.float32
)
q = torch.randn(total_seq, num_heads, 128, device="cuda", dtype=dtype)
k = torch.randn(total_seq, num_heads, 128, device="cuda", dtype=dtype)
v = torch.randn(total_seq, num_heads, 128, device="cuda", dtype=dtype)
out_pt = run_flex_reference(q, k_cache_rep, v_cache_rep, masked_score_mod)

q_bshd = q.transpose(1, 2)
out_cute = torch.empty_like(q_bshd)

if page_size is None:
k_bshd = k_cache.transpose(1, 2)
v_bshd = v_cache.transpose(1, 2)
_flash_attn_fwd(
q_bshd,
k_bshd,
v_bshd,
seqused_k=cache_seqlens,
return_lse=True,
score_mod=cute_score_mod,
out=out_cute,
lse=None,
pack_gqa=pack_gqa,
)
else:
_flash_attn_fwd(
q_bshd,
k_cache_paged,
v_cache_paged,
seqused_k=cache_seqlens,
page_table=page_table,
return_lse=True,
score_mod=cute_score_mod,
out=out_cute,
lse=None,
pack_gqa=pack_gqa,
)

out_cute = out_cute.transpose(1, 2)

out_cute = torch.empty_like(q)
assert out_cute.shape == out_ref_fp32.shape == out_pt.shape
assert not torch.isnan(out_cute).any()
assert not torch.isnan(out_ref_fp32).any()
assert not torch.isnan(out_pt).any()
assert torch.isfinite(out_cute).all()
assert torch.isfinite(out_ref_fp32).all()
assert torch.isfinite(out_pt).all()

_flash_attn_fwd(
q,
k,
v,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
return_lse=True,
score_mod=score_mod_2,
out=out_cute,
lse=None,
fwd_atol = 2 * (out_ref_fp32 + 0.3 - 0.3 - out_ref_fp32).abs().max().item()
rtol = 2

pt_error = (out_pt - out_ref_fp32).abs().max().item()
cute_error = (out_cute - out_ref_fp32).abs().max().item()

print(
f"\nNumerical comparison for {cute_score_mod.__name__} (paged={page_size is not None}):"
)
print(f" PyTorch vs FP32 ref max error: {pt_error:.2e}")
print(f" CuTE vs FP32 ref max error: {cute_error:.2e}")
print(f" Dynamic absolute tolerance: {fwd_atol:.2e}")
print(f" Error ratio (CuTE/PyTorch): {cute_error / max(pt_error, 1e-10):.2f}")

assert cute_error <= rtol * pt_error + fwd_atol, (
f"CuTE error {cute_error:.2e} exceeds {rtol}x PyTorch error {pt_error:.2e} + {fwd_atol:.2e}"
)


@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("page_size", [None, 128])
@pytest.mark.parametrize("qhead_per_kvhead,num_kv_heads", [(1, 1), (4, 2)])
@pytest.mark.parametrize(
    "seqlen_q,seqlen_kv",
    [
        (64, 128),
        (128, 256),
        (256, 256),
    ],
)
@pytest.mark.parametrize("score_mod_pair", TEST_PAIRS_WITH_AUX_TENSORS)
def test_score_mod_with_paged_kvcache_aux_tensors(
    seqlen_q,
    seqlen_kv,
    qhead_per_kvhead,
    num_kv_heads,
    page_size,
    dtype,
    score_mod_pair,
):
    """Aux-tensor score_mods against a (possibly paged) KV cache with seqused_k.

    The flex_attention reference runs on the densely gathered cache, with the
    per-batch seqused_k limit folded into the score_mod as a -inf mask; the
    CuTE kernel runs either on the dense cache or through the page table.
    """
    if page_size is not None and seqlen_kv % page_size != 0:
        pytest.skip()

    torch.random.manual_seed(42)
    cute_score_mod, eager_score_mod_factory = score_mod_pair

    batch_size = 2
    num_q_heads = num_kv_heads * qhead_per_kvhead
    pack_gqa = qhead_per_kvhead > 1
    dim = 128
    device = "cuda"

    q = torch.randn(batch_size, num_q_heads, seqlen_q, dim, device=device, dtype=dtype)

    if page_size is None:
        k_cache = torch.randn(
            batch_size, num_kv_heads, seqlen_kv, dim, device=device, dtype=dtype
        )
        v_cache = torch.randn(
            batch_size, num_kv_heads, seqlen_kv, dim, device=device, dtype=dtype
        )
        page_table = None
        k_cache_paged = None
        v_cache_paged = None
    else:
        (
            k_cache,
            v_cache,
            page_table,
            k_cache_paged,
            v_cache_paged,
            num_blocks,
        ) = _generate_block_kvcache(
            seqlen_kv, page_size, batch_size, num_kv_heads, dim, device, dtype
        )

    # Per-batch effective KV lengths in [1, seqlen_kv].
    cache_seqlens = torch.randint(
        1, seqlen_kv + 1, (batch_size,), dtype=torch.int32, device=device
    )

    # Build the aux tensors this particular score_mod consumes. Fail loudly on
    # an unexpected pair instead of leaving aux_tensors/eager_score_mod unbound
    # (the original if/elif had no else and would NameError later).
    if cute_score_mod == score_mod_10:
        buffer = torch.randn(batch_size, device=device, dtype=dtype) * 0.1
        aux_tensors = [buffer]
        eager_score_mod = eager_score_mod_factory(buffer)
    elif cute_score_mod == score_mod_11:
        head_bias = torch.randn(num_q_heads, device=device, dtype=dtype) * 0.2
        pos_scale = torch.arange(seqlen_q, device=device, dtype=dtype) * 0.01
        aux_tensors = [head_bias, pos_scale]
        eager_score_mod = eager_score_mod_factory(head_bias, pos_scale)
    else:
        raise AssertionError(
            f"Unhandled score_mod in TEST_PAIRS_WITH_AUX_TENSORS: {cute_score_mod}"
        )

    if pack_gqa:
        k_cache_rep = k_cache.repeat_interleave(qhead_per_kvhead, dim=1)
        v_cache_rep = v_cache.repeat_interleave(qhead_per_kvhead, dim=1)
    else:
        k_cache_rep = k_cache
        v_cache_rep = v_cache

    def make_masked_score_mod(base_score_mod, seqused_k_tensor):
        # Wrap the score_mod so kv positions past each batch's seqused_k are
        # masked to -inf, matching the kernel's seqused_k semantics.
        seqused_k_dev = seqused_k_tensor

        def masked_score_mod(score, b, h, q_idx, kv_idx):
            if base_score_mod is not None:
                score = base_score_mod(score, b, h, q_idx, kv_idx)
            seqlen_limit = torch.gather(seqused_k_dev, 0, b.long())
            valid_mask = kv_idx < seqlen_limit
            return torch.where(valid_mask, score, torch.full_like(score, float("-inf")))

        return masked_score_mod

    # One wrapper suffices for both reference dtypes (the two originals were
    # identical objects built from the same factory).
    masked_score_mod = make_masked_score_mod(eager_score_mod, cache_seqlens)

    out_ref_fp32 = run_flex_reference(
        q, k_cache_rep, v_cache_rep, masked_score_mod, dtype=torch.float32
    )
    out_pt = run_flex_reference(q, k_cache_rep, v_cache_rep, masked_score_mod)

    q_bshd = q.transpose(1, 2)
    out_cute = torch.empty_like(q_bshd)

    if page_size is None:
        k_bshd = k_cache.transpose(1, 2)
        v_bshd = v_cache.transpose(1, 2)
        _flash_attn_fwd(
            q_bshd,
            k_bshd,
            v_bshd,
            seqused_k=cache_seqlens,
            return_lse=True,
            score_mod=cute_score_mod,
            out=out_cute,
            lse=None,
            aux_tensors=aux_tensors,
            pack_gqa=pack_gqa,
        )
    else:
        _flash_attn_fwd(
            q_bshd,
            k_cache_paged,
            v_cache_paged,
            seqused_k=cache_seqlens,
            page_table=page_table,
            return_lse=True,
            score_mod=cute_score_mod,
            out=out_cute,
            lse=None,
            aux_tensors=aux_tensors,
            pack_gqa=pack_gqa,
        )

    out_cute = out_cute.transpose(1, 2)

    assert out_cute.shape == out_ref_fp32.shape == out_pt.shape
    assert not torch.isnan(out_cute).any()
    assert not torch.isnan(out_ref_fp32).any()
    assert not torch.isnan(out_pt).any()
    assert torch.isfinite(out_cute).all()
    assert torch.isfinite(out_ref_fp32).all()
    assert torch.isfinite(out_pt).all()

    # Dynamic atol captures fp32 reference rounding; the relative bound allows
    # CuTE up to 2x the low-precision reference's own error.
    fwd_atol = 2 * (out_ref_fp32 + 0.3 - 0.3 - out_ref_fp32).abs().max().item()
    rtol = 2

    pt_error = (out_pt - out_ref_fp32).abs().max().item()
    cute_error = (out_cute - out_ref_fp32).abs().max().item()

    print(
        f"\nNumerical comparison for {cute_score_mod.__name__} (paged={page_size is not None}):"
    )
    print(f" PyTorch vs FP32 ref max error: {pt_error:.2e}")
    print(f" CuTE vs FP32 ref max error: {cute_error:.2e}")
    print(f" Dynamic absolute tolerance: {fwd_atol:.2e}")
    print(f" Error ratio (CuTE/PyTorch): {cute_error / max(pt_error, 1e-10):.2f}")

    assert cute_error <= rtol * pt_error + fwd_atol, (
        f"CuTE error {cute_error:.2e} exceeds {rtol}x PyTorch error {pt_error:.2e} + {fwd_atol:.2e}"
    )


if __name__ == "__main__":


# ============================================================================
# New file: tests/cute/test_score_mod_varlen.py (+1048 lines)
# ============================================================================
import pytest
import torch
from torch.nn.attention.flex_attention import flex_attention
from flash_attn.cute.interface import _flash_attn_fwd
from test_score_mod import _generate_block_kvcache
from score_mod_definitions import (
# TensorSSA-based score mods
score_mod_alibi,
score_mod_batch_bias,
score_mod_block_diagonal,
score_mod_causal,
score_mod_causal_v2,
score_mod_debug_global_idx,
score_mod_dual_buffer,
score_mod_global_kv_bias,
score_mod_global_logical_rel_plus_kv_bias,
score_mod_global_q_and_kv_bias,
score_mod_global_q_bias,
score_mod_global_rel_plus_kv_bias,
score_mod_identity,
score_mod_rel_bias,
score_mod_rel_bias_x2,
score_mod_sliding_window,
score_mod_stress_complex_arithmetic,
score_mod_stress_conditional_mask,
score_mod_stress_global_offset,
score_mod_stress_multi_buffer,
score_mod_stress_xor_pattern,
score_mod_times_two,
) # isort: split
from score_mod_definitions import (
# Eager (torch) reference score mods
identity_eager,
causal_eager,
rel_bias_eager,
rel_bias_x2_eager,
times_two_eager,
alibi_eager,
sliding_window_eager,
block_diagonal_eager,
causal_v2_eager,
batch_bias_factory,
dual_buffer_factory,
packed_kv_bias_factory,
packed_q_bias_factory,
packed_rel_plus_kv_bias_factory,
packed_q_and_kv_bias_factory,
packed_logical_rel_plus_kv_bias_factory,
stress_complex_arithmetic_factory,
stress_conditional_mask_factory,
stress_multi_buffer_factory,
stress_global_offset_factory,
stress_xor_pattern_factory,
debug_global_idx_factory,
)

# =============================================================================
# Test pairs
# =============================================================================

# (cute_score_mod, eager_factory_or_fn, aux_type)
# aux_type: None, "batch", "dual_buffer"
# All score_mods use 7-arg signature: (tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors)
TEST_PAIRS_NO_GLOBAL = [
    (score_mod_identity, identity_eager, None),
    (score_mod_causal, causal_eager, None),
    (score_mod_rel_bias, rel_bias_eager, None),
    (score_mod_rel_bias_x2, rel_bias_x2_eager, None),
    (score_mod_times_two, times_two_eager, None),
    (score_mod_alibi, alibi_eager, None),
    (score_mod_sliding_window, sliding_window_eager, None),
    (score_mod_block_diagonal, block_diagonal_eager, None),
    (score_mod_causal_v2, causal_v2_eager, None),
    (score_mod_batch_bias, batch_bias_factory, "batch"),
    (score_mod_dual_buffer, dual_buffer_factory, "dual_buffer"),
]

# (cute_score_mod, eager_factory, aux_type, requires_global)
# aux_type: "kv", "q", "q_and_kv", "q_concat", "kv_with_cu", "multi_buffer"
# requires_global: "q" (needs varlen_q), "kv" (needs varlen_k), "both" (needs both)
# All score_mods use 7-arg signature and compute global indices from seqlen_info
TEST_PAIRS_WITH_GLOBAL = [
    (score_mod_global_kv_bias, packed_kv_bias_factory, "kv", "kv"),
    (score_mod_global_q_bias, packed_q_bias_factory, "q", "q"),
    (score_mod_global_rel_plus_kv_bias, packed_rel_plus_kv_bias_factory, "kv", "kv"),
    (score_mod_global_q_and_kv_bias, packed_q_and_kv_bias_factory, "q_and_kv", "both"),
    (
        score_mod_global_logical_rel_plus_kv_bias,
        packed_logical_rel_plus_kv_bias_factory,
        "kv",
        "kv",
    ),
    (
        score_mod_stress_complex_arithmetic,
        stress_complex_arithmetic_factory,
        "q_concat",
        "q",
    ),
    (
        score_mod_stress_conditional_mask,
        stress_conditional_mask_factory,
        "kv_with_cu",
        "both",
    ),
    (
        score_mod_stress_multi_buffer,
        stress_multi_buffer_factory,
        "multi_buffer",
        "both",
    ),
    (score_mod_stress_global_offset, stress_global_offset_factory, "kv", "kv"),
    (score_mod_stress_xor_pattern, stress_xor_pattern_factory, "kv_with_cu", "kv"),
    (score_mod_debug_global_idx, debug_global_idx_factory, "kv", "kv"),
]

# Per-batch (seqlens_q, seqlens_k) lists: degenerate single-token batches,
# mixed/ragged lengths, non-power-of-two sizes, and long-context decode-style
# cases (tiny Q over very large K).
SEQLEN_CONFIGS = [
    ([1], [1]),
    ([1, 1], [1, 1]),
    ([2, 3], [2, 3]),
    ([8, 16], [8, 16]),
    ([32, 32], [32, 32]),
    ([64, 128], [64, 128]),
    ([64, 56, 128], [64, 56, 128]),
    ([256, 512], [256, 512]),
    ([113, 203], [113, 203]),
    ([239, 1], [239, 1]),
    ([64], [64]),
    ([128, 128], [128, 128]),
    ([32, 32, 32, 32], [32, 32, 32, 32]),
    ([16, 32, 64, 128, 256], [16, 32, 64, 128, 256]),
    ([1, 1024], [1, 1024]),
    ([1024, 1], [1024, 1]),
    ([1, 256, 1], [1, 256, 1]),
    ([256, 1, 256], [256, 1, 256]),
    ([17, 33, 65], [17, 33, 65]),
    ([64, 128], [32, 64]),
    ([100, 100], [50, 50]),
    ([256, 512, 256], [128, 256, 128]),
    ([2, 1], [16384, 32 * 1024]),
    ([1, 1], [128 * 1024] * 2),
    ([2, 1], [8192, 8192]),
    ([1, 3], [8192, 8192]),
    ([3, 3], [8192, 8192]),
    ([128, 128], [8192, 8192]),
    ([2, 2, 2], [8 * 1024] * 3),
    ([2, 1], [1024 * 32, 16384]),
    ([1, 2], [1024 * 32, 16384]),
    ([1, 1, 1], [128 * 1024] * 3),
    ([1, 1, 1], [256 * 1024] * 3),
]

# =============================================================================
# Helper functions
# =============================================================================


def run_cute_flash(
    q,
    k,
    v,
    score_mod,
    aux_tensors=None,
    pack_gqa=False,
    cu_seqlens_q=None,
    cu_seqlens_k=None,
    page_table=None,
    seqused_k=None,
):
    """Invoke the CuTE flash attention forward pass and return its output."""
    out = torch.empty_like(q)
    shared_kwargs = dict(
        seqused_k=seqused_k,
        page_table=page_table,
        return_lse=True,
        score_mod=score_mod,
        out=out,
        lse=None,
        aux_tensors=aux_tensors,
        pack_gqa=pack_gqa,
    )
    varlen = cu_seqlens_q is not None or cu_seqlens_k is not None
    if varlen:
        # Varlen path: forward the cumulative sequence lengths as well.
        _flash_attn_fwd(
            q,
            k,
            v,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            **shared_kwargs,
        )
    else:
        _flash_attn_fwd(q, k, v, **shared_kwargs)
    return out


def run_flex_varlen_ref(q, k, v, cu_seqlens_q, cu_seqlens_k, score_mod, dtype=None):
"""Run flex_attention per-sequence for varlen reference."""
if cu_seqlens_q is not None:
num_batches = len(cu_seqlens_q) - 1
else:
num_batches = len(cu_seqlens_k) - 1

results = []
for i in range(num_batches):
# Get Q slice
if cu_seqlens_q is not None:
q_slice = (
q[cu_seqlens_q[i] : cu_seqlens_q[i + 1]].unsqueeze(0).transpose(1, 2)
)
else:
q_slice = q[i : i + 1].transpose(1, 2)

# Get K/V slices
if cu_seqlens_k is not None:
k_slice = (
k[cu_seqlens_k[i] : cu_seqlens_k[i + 1]].unsqueeze(0).transpose(1, 2)
)
v_slice = (
v[cu_seqlens_k[i] : cu_seqlens_k[i + 1]].unsqueeze(0).transpose(1, 2)
)
else:
k_slice = k[i : i + 1].transpose(1, 2)
v_slice = v[i : i + 1].transpose(1, 2)

if dtype is not None:
q_slice, k_slice, v_slice = (
q_slice.to(dtype),
k_slice.to(dtype),
v_slice.to(dtype),
)

def wrapped_mod(score, b, h, q_idx, kv_idx):
return score_mod(score, i, h, q_idx, kv_idx)

out = flex_attention(
q_slice,
k_slice,
v_slice,
score_mod=wrapped_mod,
enable_gqa=q_slice.shape[1] != k_slice.shape[1],
)
results.append(out.transpose(1, 2).squeeze(0))

return torch.cat(results, dim=0)


def setup_tensors(
    seqlens_q, seqlens_k, varlen_q, varlen_k, num_heads, head_dim, dtype, device="cuda"
):
    """Create Q, K, V tensors and cu_seqlens based on varlen flags.

    For a varlen side, tensors are packed as (total_tokens, num_heads, head_dim)
    and the matching int32 cu_seqlens prefix-sum tensor is returned; for a
    batched side, tensors are (batch, seqlen, num_heads, head_dim) and the
    corresponding cu_seqlens is None.

    `device` defaults to "cuda" (previous hard-coded behavior) but is now a
    parameter so the helper also works on CPU.

    Returns:
        (q, k, v, cu_seqlens_q, cu_seqlens_k)
    """
    batch_size = len(seqlens_q)

    if varlen_q:
        total_q = sum(seqlens_q)
        q = torch.randn(total_q, num_heads, head_dim, device=device, dtype=dtype)
        cu_seqlens_q = torch.tensor(
            [0] + list(torch.tensor(seqlens_q).cumsum(0).tolist()),
            device=device,
            dtype=torch.int32,
        )
    else:
        seqlen_q = seqlens_q[0]  # All sequences have the same length for non-varlen
        q = torch.randn(
            batch_size, seqlen_q, num_heads, head_dim, device=device, dtype=dtype
        )
        cu_seqlens_q = None

    if varlen_k:
        total_k = sum(seqlens_k)
        k = torch.randn(total_k, num_heads, head_dim, device=device, dtype=dtype)
        v = torch.randn(total_k, num_heads, head_dim, device=device, dtype=dtype)
        cu_seqlens_k = torch.tensor(
            [0] + list(torch.tensor(seqlens_k).cumsum(0).tolist()),
            device=device,
            dtype=torch.int32,
        )
    else:
        seqlen_k = seqlens_k[0]  # All sequences have the same length for non-varlen
        k = torch.randn(
            batch_size, seqlen_k, num_heads, head_dim, device=device, dtype=dtype
        )
        v = torch.randn(
            batch_size, seqlen_k, num_heads, head_dim, device=device, dtype=dtype
        )
        cu_seqlens_k = None

    return q, k, v, cu_seqlens_q, cu_seqlens_k


def prepare_ref_tensors(
    q, k, v, cu_seqlens_q, cu_seqlens_k, varlen_q, varlen_k, batch_size, seqlens_q
):
    """Prepare tensors for flex_attention reference (handle mixed varlen formats).

    - batched Q + varlen K: pack Q to (B*S, H, D) and synthesize a uniform
      cu_seqlens_q so the reference can slice Q per sequence.
    - varlen Q + batched K: drop cu_seqlens_k (reference indexes K/V by batch).
    - both varlen: pass everything through unchanged.

    Returns:
        (q, k, v, cu_seqlens_q, cu_seqlens_k) in reference layout.
    """
    num_heads = q.shape[1] if varlen_q else q.shape[2]

    if not varlen_q and varlen_k:
        seqlen_q = q.shape[1]
        q_packed = q.reshape(-1, num_heads, q.shape[-1])
        # Build the synthetic prefix sums on q's own device (previously
        # hard-coded to "cuda", which broke CPU use of this helper).
        ref_cu_seqlens_q = torch.tensor(
            [seqlen_q * i for i in range(batch_size + 1)],
            device=q.device,
            dtype=torch.int32,
        )
        return q_packed, k, v, ref_cu_seqlens_q, cu_seqlens_k

    if varlen_q and not varlen_k:
        return q, k, v, cu_seqlens_q, None

    return q, k, v, cu_seqlens_q, cu_seqlens_k


def check_results(
    out_cute,
    out_ref_fp32,
    out_pt,
    test_name,
    rtol=2,
    extra_atol=1e-4,
    seqlens_q=None,
    cu_seqlens_q=None,
):
    """Assert the CuTE output is close to the fp32 flex reference.

    CuTE's max error must not exceed `rtol` times the low-precision
    reference's error, plus a dynamic fp32-rounding term and `extra_atol`.
    When `cu_seqlens_q` is given, packed varlen outputs are compared slice
    by slice and the worst per-sequence error is used.
    """
    assert not torch.isnan(out_cute).any(), f"{test_name}: NaN in output"
    assert torch.isfinite(out_cute).all(), f"{test_name}: Inf in output"

    if cu_seqlens_q is not None:
        assert seqlens_q is not None, "varlen_q requires use of seqlens_q"
        # Worst-case error over every per-sequence slice of the packed output.
        cute_error, pt_error = 0.0, 0.0
        for seq_idx in range(len(seqlens_q)):
            lo, hi = cu_seqlens_q[seq_idx], cu_seqlens_q[seq_idx + 1]
            ref_slice = out_ref_fp32[lo:hi]
            cute_error = max(
                cute_error, (out_cute[lo:hi] - ref_slice).abs().max().item()
            )
            pt_error = max(pt_error, (out_pt[lo:hi] - ref_slice).abs().max().item())
    else:
        pt_error = (out_pt - out_ref_fp32).abs().max().item()
        cute_error = (out_cute - out_ref_fp32).abs().max().item()

    # Rounding noise of the fp32 reference itself.
    fwd_atol = 2 * (out_ref_fp32 + 0.3 - 0.3 - out_ref_fp32).abs().max().item()

    print(f"\n{test_name}:")
    print(f" PyTorch vs FP32 ref: {pt_error:.2e}")
    print(f" CuTE vs FP32 ref: {cute_error:.2e}")

    tol = rtol * pt_error + fwd_atol + extra_atol
    assert cute_error <= tol, (
        f"{test_name}: CuTE error {cute_error:.2e} exceeds tolerance {tol:.2e}"
    )


# =============================================================================
# Tests
# =============================================================================


@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("varlen_q", [True, False])
@pytest.mark.parametrize("varlen_k", [True, False])
@pytest.mark.parametrize("qhead_per_kvhead,num_kv_heads", [(4, 2)])
@pytest.mark.parametrize("seqlens_q,seqlens_k", SEQLEN_CONFIGS)
@pytest.mark.parametrize("score_mod_tuple", TEST_PAIRS_NO_GLOBAL)
def test_varlen_with_score_mod(
    seqlens_q,
    seqlens_k,
    varlen_q,
    varlen_k,
    qhead_per_kvhead,
    num_kv_heads,
    dtype,
    score_mod_tuple,
):
    """Test varlen attention with score_mod functions that don't use global indices.

    Covers: both varlen, varlen Q only, varlen K only.
    Skips: neither varlen
    """
    if not varlen_q and not varlen_k:
        pytest.skip(
            "At least one of varlen_q or varlen_k must be True for varlen tests"
        )

    # For non-varlen dimension, all sequences must have same length
    if not varlen_q:
        seqlens_q = [seqlens_q[0]] * len(seqlens_q)
    if not varlen_k:
        seqlens_k = [seqlens_k[0]] * len(seqlens_k)

    torch.random.manual_seed(42)
    cute_score_mod, eager_factory, aux_type = score_mod_tuple

    num_heads = num_kv_heads * qhead_per_kvhead
    pack_gqa = qhead_per_kvhead > 1
    head_dim = 128
    batch_size = len(seqlens_q)

    q, k, v, cu_seqlens_q, cu_seqlens_k = setup_tensors(
        seqlens_q, seqlens_k, varlen_q, varlen_k, num_heads, head_dim, dtype
    )

    # setup_tensors allocates K/V with num_heads heads; for GQA keep only the
    # first num_kv_heads so the kernel sees the true KV head count.
    if pack_gqa:
        if varlen_k:
            k = k[:, :num_kv_heads, :].clone()
            v = v[:, :num_kv_heads, :].clone()
        else:
            k = k[:, :, :num_kv_heads, :].clone()
            v = v[:, :, :num_kv_heads, :].clone()

    aux_tensors = None
    if aux_type == "batch":
        # Fix: use randn, not zeros — `torch.zeros(...) * 0.1` is identically
        # zero, which made the per-batch bias a no-op and left
        # score_mod_batch_bias effectively untested.
        bias = torch.randn(batch_size, device="cuda", dtype=dtype) * 0.1
        aux_tensors = [bias]
        eager_score_mod = eager_factory(bias)
    elif aux_type == "dual_buffer":
        seqlen_q = seqlens_q[0] if not varlen_q else max(seqlens_q)
        head_bias = torch.randn(num_heads, device="cuda", dtype=dtype) * 0.2
        pos_bias = torch.arange(seqlen_q, device="cuda", dtype=dtype) * 0.01
        aux_tensors = [head_bias, pos_bias]
        eager_score_mod = eager_factory(head_bias, pos_bias)
    else:
        eager_score_mod = eager_factory

    # Prepare reference tensors
    q_ref, k_ref, v_ref, ref_cu_q, ref_cu_k = prepare_ref_tensors(
        q, k, v, cu_seqlens_q, cu_seqlens_k, varlen_q, varlen_k, batch_size, seqlens_q
    )

    out_ref_fp32 = run_flex_varlen_ref(
        q_ref, k_ref, v_ref, ref_cu_q, ref_cu_k, eager_score_mod, dtype=torch.float32
    )
    out_pt = run_flex_varlen_ref(
        q_ref, k_ref, v_ref, ref_cu_q, ref_cu_k, eager_score_mod, dtype=dtype
    )
    out_cute = run_cute_flash(
        q,
        k,
        v,
        cute_score_mod,
        aux_tensors=aux_tensors,
        pack_gqa=pack_gqa,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
    )

    # Batched-Q references come back packed; restore the (B, S, H, D) layout.
    if not varlen_q and varlen_k:
        seqlen_q = q.shape[1]
        out_ref_fp32 = out_ref_fp32.reshape(batch_size, seqlen_q, num_heads, head_dim)
        out_pt = out_pt.reshape(batch_size, seqlen_q, num_heads, head_dim)

    assert out_cute.shape == out_ref_fp32.shape, (
        f"Shape mismatch: {out_cute.shape} vs {out_ref_fp32.shape}"
    )

    test_name = f"{cute_score_mod.__name__} (varlen_q={varlen_q}, varlen_k={varlen_k})"
    extra_atol = 2e-3
    check_results(
        out_cute,
        out_ref_fp32,
        out_pt,
        test_name,
        extra_atol=extra_atol,
        seqlens_q=seqlens_q if varlen_q else None,
        cu_seqlens_q=cu_seqlens_q if varlen_q else None,
    )


@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("varlen_q", [True, False])
@pytest.mark.parametrize("varlen_k", [True, False])
@pytest.mark.parametrize("qhead_per_kvhead,num_kv_heads", [(1, 1), (4, 2)])
@pytest.mark.parametrize("seqlens_q,seqlens_k", SEQLEN_CONFIGS)
@pytest.mark.parametrize("score_mod_tuple", TEST_PAIRS_WITH_GLOBAL)
def test_varlen_with_global_idx_score_mod(
    seqlens_q,
    seqlens_k,
    varlen_q,
    varlen_k,
    qhead_per_kvhead,
    num_kv_heads,
    dtype,
    score_mod_tuple,
):
    """Test varlen attention with score_mod functions that use global indices.

    These score_mods compute q_idx_global and/or kv_idx_global from seqlen_info for packed tensor indexing.
    Skips tests where required global indices aren't available.
    """
    if not varlen_q and not varlen_k:
        pytest.skip(
            "At least one of varlen_q or varlen_k must be True for varlen tests"
        )

    cute_score_mod, eager_factory, aux_type, requires_global = score_mod_tuple

    # Skip if score_mod requires global indices we can't provide
    if requires_global == "q" and not varlen_q:
        pytest.skip(f"{cute_score_mod.__name__} requires varlen_q for q_idx_global")
    if requires_global == "kv" and not varlen_k:
        pytest.skip(f"{cute_score_mod.__name__} requires varlen_k for kv_idx_global")
    if requires_global == "both" and (not varlen_q or not varlen_k):
        pytest.skip(f"{cute_score_mod.__name__} requires both varlen_q and varlen_k")

    # For non-varlen dimension, all sequences must have same length
    if not varlen_q:
        seqlens_q = [seqlens_q[0]] * len(seqlens_q)
    if not varlen_k:
        seqlens_k = [seqlens_k[0]] * len(seqlens_k)

    torch.random.manual_seed(42)

    num_heads = num_kv_heads * qhead_per_kvhead
    pack_gqa = qhead_per_kvhead > 1
    head_dim = 128
    batch_size = len(seqlens_q)
    max_rel_pos = 512  # size bound for the multi_buffer relative-position table

    total_q = sum(seqlens_q)
    total_k = sum(seqlens_k)

    # cu_seqlens are built for BOTH sides (even a non-varlen one) because the
    # eager factories index packed aux buffers with them; only the varlen
    # side's cu_seqlens is later handed to the kernel.
    cu_seqlens_q = torch.tensor(
        [0] + list(torch.tensor(seqlens_q).cumsum(0).tolist()),
        device="cuda",
        dtype=torch.int32,
    )
    cu_seqlens_k = torch.tensor(
        [0] + list(torch.tensor(seqlens_k).cumsum(0).tolist()),
        device="cuda",
        dtype=torch.int32,
    )

    # Q is packed (total, H, D) when varlen, batched (B, S, H, D) otherwise.
    if varlen_q:
        q = torch.randn(total_q, num_heads, head_dim, device="cuda", dtype=dtype)
    else:
        seqlen_q = seqlens_q[0]
        q = torch.randn(
            batch_size, seqlen_q, num_heads, head_dim, device="cuda", dtype=dtype
        )

    if varlen_k:
        k = torch.randn(total_k, num_heads, head_dim, device="cuda", dtype=dtype)
        v = torch.randn(total_k, num_heads, head_dim, device="cuda", dtype=dtype)
    else:
        seqlen_k = seqlens_k[0]
        k = torch.randn(
            batch_size, seqlen_k, num_heads, head_dim, device="cuda", dtype=dtype
        )
        v = torch.randn(
            batch_size, seqlen_k, num_heads, head_dim, device="cuda", dtype=dtype
        )

    # For GQA, keep only the first num_kv_heads of the allocated K/V heads.
    if pack_gqa:
        if varlen_k:
            k = k[:, :num_kv_heads, :].clone()
            v = v[:, :num_kv_heads, :].clone()
        else:
            k = k[:, :, :num_kv_heads, :].clone()
            v = v[:, :, :num_kv_heads, :].clone()

    # Setup aux tensors based on indexing type
    if aux_type == "kv":
        bias = torch.randn(total_k, device="cuda", dtype=dtype) * 0.1
        aux_tensors = [bias]
        eager_score_mod = eager_factory(bias, cu_seqlens_k)
    elif aux_type == "q":
        bias = torch.randn(total_q, device="cuda", dtype=dtype) * 0.1
        aux_tensors = [bias]
        eager_score_mod = eager_factory(bias, cu_seqlens_q)
    elif aux_type == "q_and_kv":
        q_bias = torch.randn(total_q, device="cuda", dtype=dtype) * 0.1
        kv_bias = torch.randn(total_k, device="cuda", dtype=dtype) * 0.1
        aux_tensors = [q_bias, kv_bias]
        eager_score_mod = eager_factory(q_bias, kv_bias, cu_seqlens_q, cu_seqlens_k)
    elif aux_type == "q_concat":
        bias = torch.randn(total_q, device="cuda", dtype=dtype) * 0.1
        aux_tensors = [bias]
        eager_score_mod = eager_factory(bias, cu_seqlens_q)
    elif aux_type == "kv_with_cu":
        kv_bias = torch.randn(total_k, device="cuda", dtype=dtype) * 0.1
        aux_tensors = [kv_bias]
        eager_score_mod = eager_factory(kv_bias, cu_seqlens_q, cu_seqlens_k)
    elif aux_type == "multi_buffer":
        batch_bias = torch.randn(batch_size, device="cuda", dtype=dtype) * 0.1
        head_scale = torch.randn(num_heads, device="cuda", dtype=dtype) * 0.1 + 1.0
        q_pos_bias = torch.randn(total_q, device="cuda", dtype=dtype) * 0.1
        kv_pos_bias = torch.randn(total_k, device="cuda", dtype=dtype) * 0.1
        rel_pos_scale = (
            torch.randn(max_rel_pos * 2 + 1, device="cuda", dtype=dtype) * 0.1
        )
        aux_tensors = [batch_bias, head_scale, q_pos_bias, kv_pos_bias, rel_pos_scale]
        eager_score_mod = eager_factory(
            batch_bias,
            head_scale,
            q_pos_bias,
            kv_pos_bias,
            rel_pos_scale,
            cu_seqlens_q,
            cu_seqlens_k,
            max_rel_pos,
        )
    else:
        raise ValueError(f"Unknown aux_type: {aux_type}")

    # Prepare reference tensors for flex_attention
    q_ref, k_ref, v_ref, ref_cu_q, ref_cu_k = prepare_ref_tensors(
        q, k, v, cu_seqlens_q, cu_seqlens_k, varlen_q, varlen_k, batch_size, seqlens_q
    )

    out_ref_fp32 = run_flex_varlen_ref(
        q_ref, k_ref, v_ref, ref_cu_q, ref_cu_k, eager_score_mod, dtype=torch.float32
    )
    out_pt = run_flex_varlen_ref(
        q_ref, k_ref, v_ref, ref_cu_q, ref_cu_k, eager_score_mod, dtype=dtype
    )

    # The kernel only receives cu_seqlens for the side(s) that are packed.
    kernel_cu_seqlens_q = cu_seqlens_q if varlen_q else None
    kernel_cu_seqlens_k = cu_seqlens_k if varlen_k else None
    out_cute = run_cute_flash(
        q,
        k,
        v,
        cute_score_mod,
        aux_tensors=aux_tensors,
        pack_gqa=pack_gqa,
        cu_seqlens_q=kernel_cu_seqlens_q,
        cu_seqlens_k=kernel_cu_seqlens_k,
    )

    if varlen_q:
        out_ref_final = out_ref_fp32
        out_pt_final = out_pt
        out_cute_final = out_cute
    else:
        # References come back packed; reshape to out_cute's batched layout.
        seqlen_q = seqlens_q[0]
        out_ref_final = out_ref_fp32.reshape(batch_size, seqlen_q, num_heads, head_dim)
        out_pt_final = out_pt.reshape(batch_size, seqlen_q, num_heads, head_dim)
        out_cute_final = out_cute

    assert out_cute_final.shape == out_ref_final.shape, (
        f"Shape mismatch: {out_cute_final.shape} vs {out_ref_final.shape}"
    )

    test_name = f"{cute_score_mod.__name__} (varlen_q={varlen_q}, varlen_k={varlen_k}, {aux_type})"

    check_results(
        out_cute_final,
        out_ref_final,
        out_pt_final,
        test_name,
        extra_atol=1e-3,
        seqlens_q=seqlens_q if varlen_q else None,
        cu_seqlens_q=cu_seqlens_q if varlen_q else None,
    )


@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("page_size", [None, 128])
@pytest.mark.parametrize("varlen_q", [True, False])
@pytest.mark.parametrize("varlen_k", [True, False])
@pytest.mark.parametrize("qhead_per_kvhead,num_kv_heads", [(4, 2)])
@pytest.mark.parametrize("seqlens_q,seqlens_k", SEQLEN_CONFIGS)
@pytest.mark.parametrize("score_mod_tuple", TEST_PAIRS_NO_GLOBAL)
def test_varlen_score_mod_kvcache(
seqlens_q,
seqlens_k,
varlen_q,
varlen_k,
qhead_per_kvhead,
num_kv_heads,
page_size,
dtype,
score_mod_tuple,
):
"""Test varlen attention with score_mod and paged KV cache."""
if not varlen_q and not varlen_k:
pytest.skip(
"At least one of varlen_q or varlen_k must be True for varlen tests"
)

if page_size is not None and varlen_k:
pytest.skip("Paged KV requires batched (non-varlen) K")

if not varlen_q:
seqlens_q = [seqlens_q[0]] * len(seqlens_q)
if not varlen_k:
seqlens_k = [seqlens_k[0]] * len(seqlens_k)

# Skip if page_size doesn't divide seqlens evenly (for simplicity)
if page_size is not None and not varlen_k:
if seqlens_k[0] % page_size != 0:
pytest.skip("page_size must divide seqlen_k")

torch.random.manual_seed(42)
cute_score_mod, eager_factory, aux_type = score_mod_tuple

num_heads = num_kv_heads * qhead_per_kvhead
pack_gqa = qhead_per_kvhead > 1
head_dim = 128
batch_size = len(seqlens_q)
device = "cuda"

# Setup tensors
q, k, v, cu_seqlens_q, cu_seqlens_k = setup_tensors(
seqlens_q, seqlens_k, varlen_q, varlen_k, num_heads, head_dim, dtype
)

if pack_gqa:
if varlen_k:
k = k[:, :num_kv_heads, :].clone()
v = v[:, :num_kv_heads, :].clone()
else:
k = k[:, :, :num_kv_heads, :].clone()
v = v[:, :, :num_kv_heads, :].clone()

page_table = None
k_cache_paged = None
v_cache_paged = None
k_cache = k
v_cache = v

if page_size is not None:
seqlen_k = seqlens_k[0]
(
k_cache_bhsd,
v_cache_bhsd,
page_table,
k_cache_paged,
v_cache_paged,
num_blocks,
) = _generate_block_kvcache(
seqlen_k, page_size, batch_size, num_kv_heads, head_dim, device, dtype
)
k_cache = k_cache_bhsd.transpose(1, 2) # BHSD -> BSHD
v_cache = v_cache_bhsd.transpose(1, 2)
seqused_k = torch.tensor(seqlens_k, dtype=torch.int32, device=device)
else:
seqused_k = None

# Setup aux tensors and eager score_mod
aux_tensors = None
if aux_type == "batch":
bias = torch.zeros(batch_size, device=device, dtype=dtype) * 0.1
aux_tensors = [bias]
eager_score_mod = eager_factory(bias)
elif aux_type == "dual_buffer":
seqlen_q = seqlens_q[0] if not varlen_q else max(seqlens_q)
head_bias = torch.randn(num_heads, device=device, dtype=dtype) * 0.2
pos_bias = torch.arange(seqlen_q, device=device, dtype=dtype) * 0.01
aux_tensors = [head_bias, pos_bias]
eager_score_mod = eager_factory(head_bias, pos_bias)
else:
eager_score_mod = eager_factory

# Prepare reference tensors
q_ref, k_ref, v_ref, ref_cu_q, ref_cu_k = prepare_ref_tensors(
q,
k_cache,
v_cache,
cu_seqlens_q,
cu_seqlens_k,
varlen_q,
varlen_k,
batch_size,
seqlens_q,
)

out_ref_fp32 = run_flex_varlen_ref(
q_ref, k_ref, v_ref, ref_cu_q, ref_cu_k, eager_score_mod, dtype=torch.float32
)
out_pt = run_flex_varlen_ref(
q_ref, k_ref, v_ref, ref_cu_q, ref_cu_k, eager_score_mod, dtype=dtype
)

k_input = k_cache_paged if page_size is not None else k_cache
v_input = v_cache_paged if page_size is not None else v_cache

out_cute = run_cute_flash(
q,
k_input,
v_input,
cute_score_mod,
aux_tensors=aux_tensors,
pack_gqa=pack_gqa,
cu_seqlens_q=cu_seqlens_q if varlen_q else None,
cu_seqlens_k=cu_seqlens_k if (varlen_k and page_size is None) else None,
page_table=page_table if page_size is not None else None,
seqused_k=seqused_k if page_size is not None else None,
)

if not varlen_q and varlen_k:
seqlen_q = q.shape[1]
out_ref_fp32 = out_ref_fp32.reshape(batch_size, seqlen_q, num_heads, head_dim)
out_pt = out_pt.reshape(batch_size, seqlen_q, num_heads, head_dim)

assert out_cute.shape == out_ref_fp32.shape, (
f"Shape mismatch: {out_cute.shape} vs {out_ref_fp32.shape}"
)

test_name = f"{cute_score_mod.__name__} (varlen_q={varlen_q}, varlen_k={varlen_k}, paged={page_size is not None})"
extra_atol = 2e-3
check_results(
out_cute,
out_ref_fp32,
out_pt,
test_name,
extra_atol=extra_atol,
seqlens_q=seqlens_q if varlen_q else None,
cu_seqlens_q=cu_seqlens_q if varlen_q else None,
)


@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("page_size", [None, 128])
@pytest.mark.parametrize("varlen_q", [True, False])
@pytest.mark.parametrize("varlen_k", [True, False])
@pytest.mark.parametrize("qhead_per_kvhead,num_kv_heads", [(1, 1), (4, 2)])
@pytest.mark.parametrize("seqlens_q,seqlens_k", SEQLEN_CONFIGS)
@pytest.mark.parametrize("score_mod_tuple", TEST_PAIRS_WITH_GLOBAL)
def test_varlen_score_mod_with_paged_kvcache_global(
seqlens_q,
seqlens_k,
varlen_q,
varlen_k,
qhead_per_kvhead,
num_kv_heads,
page_size,
dtype,
score_mod_tuple,
):
"""Test varlen attention with global idx score_mod and paged KV cache."""
if page_size is not None and varlen_k:
pytest.skip("Paged KV cache requires batched (non-varlen) K")

if not varlen_q and not varlen_k:
pytest.skip(
"At least one of varlen_q or varlen_k must be True for varlen tests"
)

if not varlen_q:
seqlens_q = [seqlens_q[0]] * len(seqlens_q)
if not varlen_k:
seqlens_k = [seqlens_k[0]] * len(seqlens_k)

if page_size is not None and not varlen_k:
if seqlens_k[0] % page_size != 0:
pytest.skip("page_size must divide seqlen_k")

cute_score_mod, eager_factory, aux_type, requires_global = score_mod_tuple

if requires_global == "q" and not varlen_q:
pytest.skip(f"{cute_score_mod.__name__} requires varlen_q for q_idx_global")
if requires_global == "kv" and not varlen_k:
pytest.skip(f"{cute_score_mod.__name__} requires varlen_k for kv_idx_global")
if requires_global == "both" and (not varlen_q or not varlen_k):
pytest.skip(f"{cute_score_mod.__name__} requires both varlen_q and varlen_k")

torch.random.manual_seed(42)

num_heads = num_kv_heads * qhead_per_kvhead
pack_gqa = qhead_per_kvhead > 1
head_dim = 128
batch_size = len(seqlens_q)
max_rel_pos = 512
device = "cuda"

total_q = sum(seqlens_q)
total_k = sum(seqlens_k)

cu_seqlens_q = torch.tensor(
[0] + list(torch.tensor(seqlens_q).cumsum(0).tolist()),
device=device,
dtype=torch.int32,
)
cu_seqlens_k = torch.tensor(
[0] + list(torch.tensor(seqlens_k).cumsum(0).tolist()),
device=device,
dtype=torch.int32,
)
cu_seqlens_k_for_kernel = cu_seqlens_k if varlen_k else None

q = torch.randn(total_q, num_heads, head_dim, device=device, dtype=dtype)
if varlen_k:
k = torch.randn(total_k, num_heads, head_dim, device=device, dtype=dtype)
v = torch.randn(total_k, num_heads, head_dim, device=device, dtype=dtype)
else:
seqlen_k = seqlens_k[0]
k = torch.randn(
batch_size, seqlen_k, num_heads, head_dim, device=device, dtype=dtype
)
v = torch.randn(
batch_size, seqlen_k, num_heads, head_dim, device=device, dtype=dtype
)

if pack_gqa:
if varlen_k:
k = k[:, :num_kv_heads, :].clone()
v = v[:, :num_kv_heads, :].clone()
else:
k = k[:, :, :num_kv_heads, :].clone()
v = v[:, :, :num_kv_heads, :].clone()

page_table = None
k_cache_paged = None
v_cache_paged = None
k_cache = k
v_cache = v

if page_size is not None:
seqlen_k = seqlens_k[0]
(
k_cache_bhsd,
v_cache_bhsd,
page_table,
k_cache_paged,
v_cache_paged,
num_blocks,
) = _generate_block_kvcache(
seqlen_k, page_size, batch_size, num_kv_heads, head_dim, device, dtype
)
k_cache = k_cache_bhsd.transpose(1, 2) # BHSD -> BSHD
v_cache = v_cache_bhsd.transpose(1, 2)
seqused_k = torch.tensor(seqlens_k, dtype=torch.int32, device=device)
else:
seqused_k = None

if aux_type == "kv":
bias = torch.randn(total_k, device=device, dtype=dtype) * 0.1
aux_tensors = [bias]
eager_score_mod = eager_factory(bias, cu_seqlens_k)
elif aux_type == "q":
bias = torch.randn(total_q, device=device, dtype=dtype) * 0.1
aux_tensors = [bias]
eager_score_mod = eager_factory(bias, cu_seqlens_q)
elif aux_type == "q_and_kv":
q_bias = torch.randn(total_q, device=device, dtype=dtype) * 0.1
kv_bias = torch.randn(total_k, device=device, dtype=dtype) * 0.1
aux_tensors = [q_bias, kv_bias]
eager_score_mod = eager_factory(q_bias, kv_bias, cu_seqlens_q, cu_seqlens_k)
elif aux_type == "q_concat":
bias = torch.randn(total_q, device=device, dtype=dtype) * 0.1
aux_tensors = [bias]
eager_score_mod = eager_factory(bias, cu_seqlens_q)
elif aux_type == "kv_with_cu":
kv_bias = torch.randn(total_k, device=device, dtype=dtype) * 0.1
aux_tensors = [kv_bias]
eager_score_mod = eager_factory(kv_bias, cu_seqlens_q, cu_seqlens_k)
elif aux_type == "multi_buffer":
batch_bias = torch.randn(batch_size, device=device, dtype=dtype) * 0.1
head_scale = torch.randn(num_heads, device=device, dtype=dtype) * 0.1 + 1.0
q_pos_bias = torch.randn(total_q, device=device, dtype=dtype) * 0.1
kv_pos_bias = torch.randn(total_k, device=device, dtype=dtype) * 0.1
rel_pos_scale = (
torch.randn(max_rel_pos * 2 + 1, device=device, dtype=dtype) * 0.1
)
aux_tensors = [batch_bias, head_scale, q_pos_bias, kv_pos_bias, rel_pos_scale]
eager_score_mod = eager_factory(
batch_bias,
head_scale,
q_pos_bias,
kv_pos_bias,
rel_pos_scale,
cu_seqlens_q,
cu_seqlens_k,
max_rel_pos,
)
else:
raise ValueError(f"Unknown aux_type: {aux_type}")

q_ref, k_ref, v_ref, ref_cu_q, ref_cu_k = prepare_ref_tensors(
q,
k_cache,
v_cache,
cu_seqlens_q,
cu_seqlens_k,
True,
varlen_k,
batch_size,
seqlens_q,
)

out_ref_fp32 = run_flex_varlen_ref(
q_ref, k_ref, v_ref, ref_cu_q, ref_cu_k, eager_score_mod, dtype=torch.float32
)
out_pt = run_flex_varlen_ref(
q_ref, k_ref, v_ref, ref_cu_q, ref_cu_k, eager_score_mod, dtype=dtype
)

# Run CuTE
k_input = k_cache_paged if page_size is not None else k_cache
v_input = v_cache_paged if page_size is not None else v_cache

out_cute = torch.empty_like(q)
_flash_attn_fwd(
q,
k_input,
v_input,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k_for_kernel if page_size is None else None,
seqused_k=seqused_k if page_size is not None else None,
page_table=page_table,
return_lse=True,
score_mod=cute_score_mod,
out=out_cute,
lse=None,
aux_tensors=aux_tensors,
pack_gqa=pack_gqa,
)

assert out_cute.shape == out_ref_fp32.shape, (
f"Shape mismatch: {out_cute.shape} vs {out_ref_fp32.shape}"
)

test_name = f"{cute_score_mod.__name__} (paged={page_size is not None}, {aux_type})"
check_results(
out_cute,
out_ref_fp32,
out_pt,
test_name,
extra_atol=1e-3,
seqlens_q=seqlens_q,
cu_seqlens_q=cu_seqlens_q,
)


if __name__ == "__main__":
pytest.main([__file__, "-v"])
