24 Commits

Author SHA1 Message Date
  fems14 b662d914a4
[bugfix] [main] Fix KV cache query inconsistency across different TP ranks in the KV Pool (#5030) 9 hours ago
  Jade Zheng c064d11fd7
[Cleanup] Remove unused attn_metadata parameter from Proposer classes (#4862) 9 hours ago
  whx a9625851ef
[Attention] Temporarily add back pa for small batch sizes. (#4765) 10 hours ago
  baxingpiaochong 95e6400128
[KVPool]Fix PP get bug (#5007) 10 hours ago
  InSec a5cb8e40f5
[doc]Modify quantization tutorials (#5026) 11 hours ago
  zhangyiming e90e8afc94
[E2E] Collect test run time. (#5018) 11 hours ago
  zhangxinyuehfad 019c8e03c2
[CI] Delete deepseek3.2-exp nightly test (#5028) 11 hours ago
  Li Wang 8d2998d0e4
[Misc] Upgrade vllm hash to 12_14 (#5000) 11 hours ago
  wangx700 3b7eb5179f
[Bugfix] fix the incorrect use of python's sum on tensors. (#4655) 11 hours ago
  zengzengran 6029bea480
[UT]add pcp dcp ut (#4949) 12 hours ago
  Icey 5fae65f3a8
[Graph][Fusion] Add AddRMSNorm(with bias) and Quant Fusion Pattern (#5011) 12 hours ago
  fluctlux 6de4bedd04
update release note for suffix decoding (#5009) 13 hours ago
  Levi df7e0fe916
[Bugfix] qwen3-vl-235b-w8a8 load weight ERROR when start service (#4292) 14 hours ago
  knight0528 e25c57b346
[Bugfix] Add support for PP intermediate value types in graph mode (#4902) 14 hours ago
  zzhxxx e16444f21f
[Bugfix] Fix the bug in initializing the shared_weight communication domain in sfa-cp, and fix the mtp weight load in pp>1 situation (#4913) 14 hours ago
  SILONG ZENG 70606e0bb9
[Test]update accuracy test of models (#4911) 16 hours ago
  Chao Lei b75bfc58f6
[Doc ] Supplement kvpool user guide (#5013) 16 hours ago
  Chen Chen aa02a85e4d
[bugfix] Fix dummy-run and multi-node issues in MoE routing and MTP (#4947) 17 hours ago
  dependabot[bot] cc7b302020
Bump actions/upload-artifact from 5 to 6 (#5014) 17 hours ago
  drslark 8fb0ef5ffa
[main][BugFix] Fixed an accuracy bug of Qwen3-next-MTP when batched inferring (#4932) 17 hours ago
  wujinyuan1 545e856971
[Refactor]3/N Refactor mla_v1.py & extract mla_cp (#4933) 18 hours ago
  ming1212 98b9e2e18e
Add Qwen3-Next tutorials (#4607) 19 hours ago
  Mengqing Cao 6beb4434e1
[CI][Bugfix] Fix scheduleroutput has no attr get error in prompt logprobs (#4998) 20 hours ago
  Li Wang 2497bbbaf6
[Misc] Update pooling example (#5002) 22 hours ago
65 changed files with 3019 additions and 2293 deletions
Split View
  1. +1
    -1
      .github/workflows/_e2e_nightly_single_node_models.yaml
  2. +47
    -47
      .github/workflows/_e2e_test.yaml
  3. +3
    -4
      .github/workflows/nightly_test_a2.yaml
  4. +8
    -6
      .github/workflows/nightly_test_a3.yaml
  5. +2
    -2
      .github/workflows/pr_tag_release_code_and_wheel.yml
  6. +1
    -1
      .github/workflows/pr_test_full.yaml
  7. +3
    -3
      .github/workflows/pr_test_light.yaml
  8. +1
    -1
      .github/workflows/schedule_test_benchmarks.yaml
  9. +0
    -1
      csrc/dispatch_ffn_combine/op_kernel/moe_init_routing_quant_v2/moe_init_routing_quant_v2.cpp
  10. +1
    -1
      docs/source/community/versioning_policy.md
  11. +3
    -1
      docs/source/tutorials/Qwen3-32B-W4A4.md
  12. +2
    -0
      docs/source/tutorials/Qwen3-8B-W4A8.md
  13. +60
    -14
      docs/source/tutorials/Qwen3-Next.md
  14. +2
    -2
      docs/source/tutorials/Qwen3_embedding.md
  15. +8
    -1
      docs/source/user_guide/feature_guide/kv_pool.md
  16. +2
    -1
      docs/source/user_guide/release_notes.md
  17. +1
    -1
      examples/offline_embed.py
  18. +0
    -11
      tests/e2e/models/configs/InternVL2-8B.yaml
  19. +0
    -11
      tests/e2e/models/configs/InternVL2_5-8B.yaml
  20. +0
    -11
      tests/e2e/models/configs/InternVL3-8B.yaml
  21. +3
    -4
      tests/e2e/models/configs/Llama-3.2-3B-Instruct.yaml
  22. +11
    -0
      tests/e2e/models/configs/Qwen3-Omni-30B-A3B-Instruct.yaml
  23. +4
    -6
      tests/e2e/models/configs/accuracy.txt
  24. +1
    -0
      tests/e2e/models/configs/gemma-3-4b-it.yaml
  25. +2
    -3
      tests/e2e/models/configs/llava-onevision-qwen2-0.5b-ov-hf.yaml
  26. +7
    -2
      tests/e2e/multicard/test_qwen3_next.py
  27. +60
    -4
      tests/e2e/singlecard/compile/test_norm_quant_fusion.py
  28. +321
    -0
      tests/ut/attention/test_attention_cp.py
  29. +403
    -0
      tests/ut/attention/test_mla_cp.py
  30. +45
    -10
      tests/ut/attention/test_mla_v1.py
  31. +5
    -2
      tests/ut/compilation/test_acl_graph.py
  32. +4
    -0
      tests/ut/spec_decode/test_eagle_proposer.py
  33. +0
    -2
      tests/ut/spec_decode/test_mtp_proposer.py
  34. +0
    -375
      tests/ut/worker/test_input_batch.py
  35. +7
    -0
      vllm_ascend/ascend_config.py
  36. +71
    -4
      vllm_ascend/attention/attention_v1.py
  37. +1274
    -0
      vllm_ascend/attention/mla_cp.py
  38. +87
    -720
      vllm_ascend/attention/mla_v1.py
  39. +2
    -2
      vllm_ascend/attention/sfa_v1.py
  40. +16
    -0
      vllm_ascend/attention/utils.py
  41. +69
    -2
      vllm_ascend/compilation/acl_graph.py
  42. +60
    -0
      vllm_ascend/compilation/passes/norm_quant_fusion_pass.py
  43. +2
    -5
      vllm_ascend/distributed/kvpool/backend/memcache_backend.py
  44. +1
    -1
      vllm_ascend/distributed/kvpool/config_data.py
  45. +148
    -78
      vllm_ascend/distributed/kvpool/kv_transfer.py
  46. +3
    -3
      vllm_ascend/distributed/kvpool/pool_scheduler.py
  47. +36
    -113
      vllm_ascend/distributed/kvpool/pool_worker.py
  48. +4
    -3
      vllm_ascend/distributed/parallel_state.py
  49. +1
    -1
      vllm_ascend/eplb/utils.py
  50. +16
    -10
      vllm_ascend/ops/fused_moe/fused_moe.py
  51. +8
    -4
      vllm_ascend/patch/__init__.py
  52. +6
    -1
      vllm_ascend/patch/platform/__init__.py
  53. +6
    -7
      vllm_ascend/patch/platform/patch_ec_connector.py
  54. +33
    -0
      vllm_ascend/patch/platform/patch_ec_connector012.py
  55. +2
    -0
      vllm_ascend/patch/worker/patch_module.py
  56. +4
    -0
      vllm_ascend/platform.py
  57. +34
    -0
      vllm_ascend/quantization/quant_config.py
  58. +0
    -1
      vllm_ascend/spec_decode/eagle_proposer.py
  59. +0
    -1
      vllm_ascend/spec_decode/interface.py
  60. +9
    -17
      vllm_ascend/spec_decode/mtp_proposer.py
  61. +0
    -1
      vllm_ascend/spec_decode/ngram_proposer.py
  62. +0
    -1
      vllm_ascend/spec_decode/suffix_proposer.py
  63. +16
    -6
      vllm_ascend/utils.py
  64. +65
    -32
      vllm_ascend/worker/model_runner_v1.py
  65. +28
    -752
      vllm_ascend/worker/npu_input_batch.py

+ 1
- 1
.github/workflows/_e2e_nightly_single_node_models.yaml View File

@@ -223,7 +223,7 @@ jobs:

- name: Upload Report
if: ${{ inputs.upload == true }}
uses: actions/upload-artifact@v5
uses: actions/upload-artifact@v6
with:
name: report-${{ env.GHA_VLLM_ASCEND_VERSION }}-${{ steps.ts.outputs.artifact_ts }}
path: ./benchmarks/accuracy/


+ 47
- 47
.github/workflows/_e2e_test.yaml View File

@@ -75,10 +75,10 @@ jobs:
PYTORCH_NPU_ALLOC_CONF: max_split_size_mb:256
if: ${{ inputs.type == 'light' }}
run: |
# pytest -sv tests/e2e/singlecard/test_aclgraph_accuracy.py
# pytest -sv tests/e2e/singlecard/test_quantization.py
pytest -sv tests/e2e/singlecard/test_vlm.py::test_multimodal_vl
pytest -sv tests/e2e/singlecard/pooling/test_classification.py::test_classify_correctness
# pytest -sv --durations=0 tests/e2e/singlecard/test_aclgraph_accuracy.py
# pytest -sv --durations=0 tests/e2e/singlecard/test_quantization.py
pytest -sv --durations=0 tests/e2e/singlecard/test_vlm.py::test_multimodal_vl
pytest -sv --durations=0 tests/e2e/singlecard/pooling/test_classification.py::test_classify_correctness

- name: Run e2e test
env:
@@ -90,25 +90,25 @@ jobs:
# We found that if running aclgraph tests in batch, it will cause AclmdlRICaptureBegin error. So we run
# the test separately.

pytest -sv tests/e2e/singlecard/test_completion_with_prompt_embeds.py
pytest -sv tests/e2e/singlecard/test_aclgraph_accuracy.py
pytest -sv tests/e2e/singlecard/test_aclgraph_mem.py
pytest -sv tests/e2e/singlecard/test_async_scheduling.py
pytest -sv tests/e2e/singlecard/test_camem.py
pytest -sv tests/e2e/singlecard/test_guided_decoding.py
pytest -sv --durations=0 tests/e2e/singlecard/test_completion_with_prompt_embeds.py
pytest -sv --durations=0 tests/e2e/singlecard/test_aclgraph_accuracy.py
pytest -sv --durations=0 tests/e2e/singlecard/test_aclgraph_mem.py
pytest -sv --durations=0 tests/e2e/singlecard/test_async_scheduling.py
pytest -sv --durations=0 tests/e2e/singlecard/test_camem.py
pytest -sv --durations=0 tests/e2e/singlecard/test_guided_decoding.py
# torch 2.8 doesn't work with lora, fix me
#pytest -sv tests/e2e/singlecard/test_ilama_lora.py
pytest -sv tests/e2e/singlecard/test_profile_execute_duration.py
pytest -sv tests/e2e/singlecard/test_quantization.py
pytest -sv tests/e2e/singlecard/test_sampler.py
pytest -sv tests/e2e/singlecard/test_vlm.py
pytest -sv tests/e2e/singlecard/test_xlite.py
pytest -sv tests/e2e/singlecard/pooling/
pytest -sv tests/e2e/singlecard/compile/test_norm_quant_fusion.py
#pytest -sv --durations=0 tests/e2e/singlecard/test_ilama_lora.py
pytest -sv --durations=0 tests/e2e/singlecard/test_profile_execute_duration.py
pytest -sv --durations=0 tests/e2e/singlecard/test_quantization.py
pytest -sv --durations=0 tests/e2e/singlecard/test_sampler.py
pytest -sv --durations=0 tests/e2e/singlecard/test_vlm.py
pytest -sv --durations=0 tests/e2e/singlecard/test_xlite.py
pytest -sv --durations=0 tests/e2e/singlecard/pooling/
pytest -sv --durations=0 tests/e2e/singlecard/compile/test_norm_quant_fusion.py

# ------------------------------------ v1 spec decode test ------------------------------------ #
pytest -sv tests/e2e/singlecard/spec_decode_v1/test_v1_mtp_correctness.py
pytest -sv tests/e2e/singlecard/spec_decode_v1/test_v1_spec_decode.py
pytest -sv --durations=0 tests/e2e/singlecard/spec_decode_v1/test_v1_mtp_correctness.py
pytest -sv --durations=0 tests/e2e/singlecard/spec_decode_v1/test_v1_spec_decode.py

e2e-2-cards:
name: multicard-2
@@ -170,7 +170,7 @@ jobs:
VLLM_USE_MODELSCOPE: True
if: ${{ inputs.type == 'light' }}
run: |
pytest -sv tests/e2e/multicard/test_qwen3_moe.py::test_qwen3_moe_distributed_mp_tp2_ep
pytest -sv --durations=0 tests/e2e/multicard/test_qwen3_moe.py::test_qwen3_moe_distributed_mp_tp2_ep

- name: Run vllm-project/vllm-ascend test (full)
env:
@@ -178,30 +178,30 @@ jobs:
VLLM_USE_MODELSCOPE: True
if: ${{ inputs.type == 'full' }}
run: |
pytest -sv tests/e2e/multicard/test_quantization.py
pytest -sv tests/e2e/multicard/test_aclgraph_capture_replay.py
pytest -sv tests/e2e/multicard/test_full_graph_mode.py
pytest -sv tests/e2e/multicard/test_data_parallel.py
pytest -sv tests/e2e/multicard/test_expert_parallel.py
pytest -sv tests/e2e/multicard/test_external_launcher.py
pytest -sv tests/e2e/multicard/test_single_request_aclgraph.py
pytest -sv tests/e2e/multicard/test_fused_moe_allgather_ep.py
pytest -sv --durations=0 tests/e2e/multicard/test_quantization.py
pytest -sv --durations=0 tests/e2e/multicard/test_aclgraph_capture_replay.py
pytest -sv --durations=0 tests/e2e/multicard/test_full_graph_mode.py
pytest -sv --durations=0 tests/e2e/multicard/test_data_parallel.py
pytest -sv --durations=0 tests/e2e/multicard/test_expert_parallel.py
pytest -sv --durations=0 tests/e2e/multicard/test_external_launcher.py
pytest -sv --durations=0 tests/e2e/multicard/test_single_request_aclgraph.py
pytest -sv --durations=0 tests/e2e/multicard/test_fused_moe_allgather_ep.py
# torch 2.8 doesn't work with lora, fix me
#pytest -sv tests/e2e/multicard/test_ilama_lora_tp2.py
#pytest -sv --durations=0 tests/e2e/multicard/test_ilama_lora_tp2.py

# To avoid oom, we need to run the test in a single process.
pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_multistream_moe
pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_Qwen3_W4A8DYNAMIC
pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_W4A8DYNAMIC
pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_sp_for_qwen3_moe
pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_fc2_for_qwen3_moe
pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_Qwen_Dense_with_flashcomm_v1
pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_Qwen_Dense_with_prefetch_mlp_weight
pytest -sv tests/e2e/multicard/test_prefix_caching.py
pytest -sv tests/e2e/multicard/test_pipeline_parallel.py
pytest -sv tests/e2e/multicard/test_qwen3_moe.py
pytest -sv tests/e2e/multicard/test_offline_weight_load.py
pytest -sv --durations=0 tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_multistream_moe
pytest -sv --durations=0 tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_Qwen3_W4A8DYNAMIC
pytest -sv --durations=0 tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_W4A8DYNAMIC
pytest -sv --durations=0 tests/e2e/multicard/test_offline_inference_distributed.py::test_sp_for_qwen3_moe
pytest -sv --durations=0 tests/e2e/multicard/test_offline_inference_distributed.py::test_fc2_for_qwen3_moe
pytest -sv --durations=0 tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_Qwen_Dense_with_flashcomm_v1
pytest -sv --durations=0 tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_Qwen_Dense_with_prefetch_mlp_weight
pytest -sv --durations=0 tests/e2e/multicard/test_prefix_caching.py
pytest -sv --durations=0 tests/e2e/multicard/test_pipeline_parallel.py
pytest -sv --durations=0 tests/e2e/multicard/test_qwen3_moe.py
pytest -sv --durations=0 tests/e2e/multicard/test_offline_weight_load.py

e2e-4-cards:
name: multicard-4
@@ -264,10 +264,10 @@ jobs:
VLLM_WORKER_MULTIPROC_METHOD: spawn
VLLM_USE_MODELSCOPE: True
run: |
pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_multistream_moe
pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_W4A8DYNAMIC
pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_Kimi_K2_Thinking_W4A16
pytest -sv tests/e2e/multicard/test_data_parallel_tp2.py
pytest -sv --durations=0 tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_multistream_moe
pytest -sv --durations=0 tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_W4A8DYNAMIC
pytest -sv --durations=0 tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_Kimi_K2_Thinking_W4A16
pytest -sv --durations=0 tests/e2e/multicard/test_data_parallel_tp2.py

- name: Install Ascend toolkit & triton_ascend (for Qwen3-Next-80B-A3B-Instruct)
shell: bash -l {0}
@@ -283,4 +283,4 @@ jobs:
VLLM_USE_MODELSCOPE: True
run: |
. /usr/local/Ascend/ascend-toolkit/8.3.RC2/bisheng_toolkit/set_env.sh
pytest -sv tests/e2e/multicard/test_qwen3_next.py
pytest -sv --durations=0 tests/e2e/multicard/test_qwen3_next.py

+ 3
- 4
.github/workflows/nightly_test_a2.yaml View File

@@ -86,15 +86,13 @@ jobs:
- Qwen3-8B-W8A8
- Qwen3-VL-8B-Instruct
- Qwen2.5-Omni-7B
- Meta-Llama-3.1-8B-Instruct
- os: linux-aarch64-a2-1
model_list:
- ERNIE-4.5-21B-A3B-PT
- gemma-3-4b-it
- internlm-7b
- InternVL3_5-8B-hf
- llava-1.5-7b-hf
- Molmo-7B-D-0924
- Llama-3.2-3B-Instruct
- llava-onevision-qwen2-0.5b-ov-hf
- os: linux-aarch64-a2-2
model_list:
- Qwen3-30B-A3B
@@ -103,6 +101,7 @@ jobs:
- os: linux-aarch64-a2-4
model_list:
- Qwen3-Next-80B-A3B-Instruct
- Qwen3-VL-30B-A3B-Instruct
uses: ./.github/workflows/_e2e_nightly_single_node_models.yaml
with:
vllm: v0.12.0


+ 8
- 6
.github/workflows/nightly_test_a3.yaml View File

@@ -61,9 +61,10 @@ jobs:
- name: multi-node-qwenw8a8-2node
config_file_path: Qwen3-235B-W8A8.yaml
size: 2
- name: multi-node-dpsk3.2-exp-2node
config_file_path: DeepSeek-V3_2-Exp-bf16.yaml
size: 2
# TODO: Replace deepseek3.2-exp with deepseek3.2 after nightly tests pass
# - name: multi-node-dpsk3.2-exp-2node
# config_file_path: DeepSeek-V3_2-Exp-bf16.yaml
# size: 2
- name: multi-node-deepseek-r1-w8a8-eplb
config_file_path: DeepSeek-R1-W8A8-EPLB.yaml
size: 4
@@ -128,9 +129,10 @@ jobs:
- name: qwen3-235b-w8a8
os: linux-aarch64-a3-16
tests: tests/e2e/nightly/models/test_qwen3_235b_w8a8.py
- name: deepseek3_2-exp-w8a8
os: linux-aarch64-a3-16
tests: tests/e2e/nightly/models/test_deepseek_v3_2_exp_w8a8.py
# TODO: Replace deepseek3.2-exp with deepseek3.2 after nightly tests pass
# - name: deepseek3_2-exp-w8a8
# os: linux-aarch64-a3-16
# tests: tests/e2e/nightly/models/test_deepseek_v3_2_exp_w8a8.py
uses: ./.github/workflows/_e2e_nightly_single_node.yaml
with:
vllm: v0.12.0


+ 2
- 2
.github/workflows/pr_tag_release_code_and_wheel.yml View File

@@ -70,7 +70,7 @@ jobs:
ls dist

- name: Archive tar.gz
uses: actions/upload-artifact@v5
uses: actions/upload-artifact@v6
with:
name: vllm-ascend-src
path: dist/*
@@ -155,7 +155,7 @@ jobs:
done

- name: Archive wheel
uses: actions/upload-artifact@v5
uses: actions/upload-artifact@v6
with:
name: vllm-ascend-${{ matrix.os }}-py${{ matrix.python-version }}-wheel
path: dist/*


+ 1
- 1
.github/workflows/pr_test_full.yaml View File

@@ -74,7 +74,7 @@ jobs:
name: e2e-full
strategy:
matrix:
vllm_version: [ad32e3e19ccf0526cb6744a5fed09a138a5fb2f9, v0.12.0]
vllm_version: [97f2f160fda2805f9149b0e44da76b5d3b1f7c7e, v0.12.0]
needs: [changes]
if: ${{ needs.changes.outputs.e2e_tracker == 'true' }}
uses: ./.github/workflows/_e2e_test.yaml


+ 3
- 3
.github/workflows/pr_test_light.yaml View File

@@ -42,7 +42,7 @@ jobs:
lint:
uses: ./.github/workflows/_pre_commit.yml
with:
vllm: ad32e3e19ccf0526cb6744a5fed09a138a5fb2f9
vllm: 97f2f160fda2805f9149b0e44da76b5d3b1f7c7e
changes:
runs-on: linux-aarch64-a2-0
outputs:
@@ -90,7 +90,7 @@ jobs:
SOC_VERSION: ascend910b1
strategy:
matrix:
vllm_version: [ad32e3e19ccf0526cb6744a5fed09a138a5fb2f9, v0.12.0]
vllm_version: [97f2f160fda2805f9149b0e44da76b5d3b1f7c7e, v0.12.0]

steps:
- name: Free up disk space
@@ -154,7 +154,7 @@ jobs:
name: e2e-light
strategy:
matrix:
vllm_version: [ad32e3e19ccf0526cb6744a5fed09a138a5fb2f9, v0.12.0]
vllm_version: [97f2f160fda2805f9149b0e44da76b5d3b1f7c7e, v0.12.0]
# Note (yikun): If CI resource are limited we can split job into two chain jobs
needs: [lint, changes]
# only trigger e2e test after lint passed and the change is e2e related with pull request.


+ 1
- 1
.github/workflows/schedule_test_benchmarks.yaml View File

@@ -134,7 +134,7 @@ jobs:

- name: Upload benchmark artifacts
if: github.event_name != 'schedule' && github.event_name != 'workflow_dispatch'
uses: actions/upload-artifact@v5
uses: actions/upload-artifact@v6
with:
name: "benchmark-performance-${{ matrix.vllm_branch }}-${{ matrix.vllm_ascend_branch }}-report"
path: ./benchmarks/results/benchmark_results.md


+ 0
- 1
csrc/dispatch_ffn_combine/op_kernel/moe_init_routing_quant_v2/moe_init_routing_quant_v2.cpp View File

@@ -114,7 +114,6 @@ __aicore__ inline void moe_init_routing_quant_v2(
srcToDstAndGatherOp.Init(x, scale, expandedRowIdx, expandedX, dynamicQuantScale, workspace, tilingData, &srcToDstGatherPipe);
srcToDstAndGatherOp.Process();
srcToDstGatherPipe.Destroy();
return;
}
}



+ 1
- 1
docs/source/community/versioning_policy.md View File

@@ -45,7 +45,7 @@ The table below is the release compatibility matrix for vLLM Ascend release.
For main branch of vLLM Ascend, we usually make it compatible with the latest vLLM release and a newer commit hash of vLLM. Please note that this table is usually updated. Please check it regularly.
| vLLM Ascend | vLLM | Python | Stable CANN | PyTorch/torch_npu |
|-------------|--------------|------------------|-------------|--------------------|
| main | ad32e3e19ccf0526cb6744a5fed09a138a5fb2f9, v0.12.0 tag | >= 3.10, < 3.12 | 8.3.RC2 | 2.8.0 / 2.8.0 |
| main | 97f2f160fda2805f9149b0e44da76b5d3b1f7c7e, v0.12.0 tag | >= 3.10, < 3.12 | 8.3.RC2 | 2.8.0 / 2.8.0 |

## Release cadence



+ 3
- 1
docs/source/tutorials/Qwen3-32B-W4A4.md View File

@@ -55,10 +55,12 @@ cd example/Qwen
MODEL_PATH=/home/models/Qwen3-32B
# Path to save converted weight, Replace with your local path
SAVE_PATH=/home/models/Qwen3-32B-w4a4
# Set two idle NPU cards
export ASCEND_RT_VISIBLE_DEVICES=0,1

python3 w4a4.py --model_path $MODEL_PATH \
--save_directory $SAVE_PATH \
--calib_file ../common/qwen_qwen3_cot_w4a4.json \
--calib_file ./calib_data/qwen3_cot_w4a4.json \
--trust_remote_code True \
--batch_size 1
```


+ 2
- 0
docs/source/tutorials/Qwen3-8B-W4A8.md View File

@@ -47,6 +47,8 @@ cd example/Qwen
MODEL_PATH=/home/models/Qwen3-8B
# Path to save converted weight, Replace with your local path
SAVE_PATH=/home/models/Qwen3-8B-w4a8
# Set an idle NPU card
export ASCEND_RT_VISIBLE_DEVICES=0

python quant_qwen.py \
--model_path $MODEL_PATH \


+ 60
- 14
docs/source/tutorials/Qwen3-Next.md View File

@@ -1,12 +1,25 @@
# Qwen3-Next

```{note}
The Qwen3 Next is using [Triton Ascend](https://gitee.com/ascend/triton-ascend) which is currently experimental. In future versions, there may be behavioral changes related to stability, accuracy, and performance improvement.
```
## Introduction

The Qwen3-Next model is a sparse MoE (Mixture of Experts) model with high sparsity. Compared to the MoE architecture of Qwen3, it has introduced key improvements in aspects such as the hybrid attention mechanism and multi-token prediction mechanism, enhancing the training and inference efficiency of the model under long contexts and large total parameter scales.

This document will present the core verification steps of the model, including supported features, environment preparation, as well as accuracy and performance evaluation. Qwen3 Next is currently using Triton Ascend, which is in the experimental phase. In subsequent versions, its performance related to stability and accuracy may change, and performance will be continuously optimized.

The `Qwen3-Next` model is first supported in `vllm-ascend:v0.10.2rc1`.

## Supported Features

Refer to [supported features](../user_guide/support_matrix/supported_models.md) to get the model's supported feature matrix.

Refer to [feature guide](../user_guide/feature_guide/index.md) to get the feature's configuration.

## Run vllm-ascend on Multi-NPU with Qwen3 Next
## Weight Preparation

Run docker container:
Download Link for the `Qwen3-Next-80B-A3B-Instruct` Model Weights: [Download model weight](https://modelers.cn/models/Modelers_Park/Qwen3-Next-80B-A3B-Instruct/tree/main)

## Deployment
### Run docker container

```{code-block} bash
:substitutions:
@@ -32,12 +45,7 @@ docker run --rm \
-it $IMAGE bash
```

Set up environment variables:

```bash
# Load model from ModelScope to speed up download
export VLLM_USE_MODELSCOPE=True
```
The Qwen3 Next is using [Triton Ascend](https://gitee.com/ascend/triton-ascend) which is currently experimental. In future versions, there may be behavioral changes related to stability, accuracy, and performance improvement.

### Install Triton Ascend

@@ -46,7 +54,7 @@ export VLLM_USE_MODELSCOPE=True

The [Triton Ascend](https://gitee.com/ascend/triton-ascend) is required when you run Qwen3 Next, please follow the instructions below to install it and its dependency.

Install the Ascend BiSheng toolkit:
Source the Ascend BiSheng toolkit, execute the command:

```bash
source /usr/local/Ascend/ascend-toolkit/8.3.RC2/bisheng_toolkit/set_env.sh
@@ -68,7 +76,7 @@ Coming soon ...
::::
:::::

### Inference on Multi-NPU
### Inference

Please make sure you have already executed the command:

@@ -84,7 +92,7 @@ Run the following script to start the vLLM server on multi-NPU:
For an Atlas A2 with 64 GB of NPU card memory, tensor-parallel-size should be at least 4, and for 32 GB of memory, tensor-parallel-size should be at least 8.

```bash
vllm serve Qwen/Qwen3-Next-80B-A3B-Instruct --tensor-parallel-size 4 --max-model-len 4096 --gpu-memory-utilization 0.7 --enforce-eager
vllm serve Qwen/Qwen3-Next-80B-A3B-Instruct --tensor-parallel-size 4 --max-model-len 4096 --gpu-memory-utilization 0.85 --compilation-config '{"cudagraph_mode":"FULL_DECODE_ONLY"}'
```

Once your server is started, you can query the model with input prompts.
@@ -152,3 +160,41 @@ Prompt: 'Who are you?', Generated text: ' What do you know about me?\n\nHello! I

::::
:::::

## Accuracy Evaluation

### Using AISBench

1. Refer to [Using AISBench](../developer_guide/evaluation/using_ais_bench.md) for details.

2. After execution, you can get the result, here is the result of `Qwen3-Next-80B-A3B-Instruct` in `vllm-ascend:0.11.0rc3` for reference only.

| dataset | version | metric | mode | vllm-api-general-chat |
|----- | ----- | ----- | ----- | -----|
| gsm8k | - | accuracy | gen | 96.3 |

## Performance

### Using AISBench

Refer to [Using AISBench for performance evaluation](../developer_guide/evaluation/using_ais_bench.md#execute-performance-evaluation) for details.

### Using vLLM Benchmark

Run performance evaluation of `Qwen3-Next` as an example.

Refer to [vllm benchmark](https://docs.vllm.ai/en/latest/contributing/benchmarks.html) for more details.

There are three `vllm bench` subcommand:
- `latency`: Benchmark the latency of a single batch of requests.
- `serve`: Benchmark the online serving throughput.
- `throughput`: Benchmark offline inference throughput.

Take the `serve` as an example. Run the code as follows.

```shell
export VLLM_USE_MODELSCOPE=true
vllm bench serve --model Qwen/Qwen3-Next-80B-A3B-Instruct --dataset-name random --random-input 200 --num-prompt 200 --request-rate 1 --save-result --result-dir ./
```

After about several minutes, you can get the performance evaluation result.

+ 2
- 2
docs/source/tutorials/Qwen3_embedding.md View File

@@ -40,7 +40,7 @@ export PYTORCH_NPU_ALLOC_CONF=max_split_size_mb:256
### Online Inference

```bash
vllm serve Qwen/Qwen3-Embedding-8B --task embed
vllm serve Qwen/Qwen3-Embedding-8B --runner pooling
```

Once your server is started, you can query the model with input prompts.
@@ -81,7 +81,7 @@ if __name__=="__main__":
input_texts = queries + documents

model = LLM(model="Qwen/Qwen3-Embedding-8B",
task="embed",
runner="pooling",
distributed_executor_backend="mp")

outputs = model.embed(input_texts)


+ 8
- 1
docs/source/user_guide/feature_guide/kv_pool.md View File

@@ -85,9 +85,16 @@ export PYTHONPATH=$PYTHONPATH:/xxxxx/vllm
export MOONCAKE_CONFIG_PATH="/xxxxxx/mooncake.json"
export ASCEND_RT_VISIBLE_DEVICES=0,1,2,3
export ACL_OP_INIT_MODE=1
export ASCEND_BUFFER_POOL=4:8
# ASCEND_BUFFER_POOL is the environment variable for configuring the number and size of buffers on the NPU device for aggregation and KV transfer; the value 4:8 means we allocate 4 buffers of size 8MB.
export ASCEND_BUFFER_POOL=4:8

# Unit: ms. The timeout for one-sided communication connection establishment is set to 10 seconds by default (see PR: https://github.com/kvcache-ai/Mooncake/pull/1039). Users can adjust this value based on their specific setup.
# The recommended formula is: ASCEND_CONNECT_TIMEOUT = connection_time_per_card (typically within 500ms) × total_number_of_Decode_cards.
# This ensures that even in the worst-case scenario—where all Decode cards simultaneously attempt to connect to the same Prefill card—the connection will not time out.
export ASCEND_CONNECT_TIMEOUT=10000

# Unit: ms. The timeout for one-sided communication transfer is set to 10 seconds by default (see PR: https://github.com/kvcache-ai/Mooncake/pull/1039).
export ASCEND_TRANSFER_TIMEOUT=10000

python3 -m vllm.entrypoints.openai.api_server \


+ 2
- 1
docs/source/user_guide/release_notes.md View File

@@ -14,7 +14,7 @@ This is the first release candidate of v0.12.0 for vLLM Ascend. We landed lots o
- Lots of triton kernel are added. The performance of vLLM Ascend, especially Qwen3-Next and DeepSeek 3.2 is improved. Please note that triton is not installed and enabled by default, but we suggest to enable it in most case. You can download and install it by hand from [package url](https://vllm-ascend.obs.cn-north-4.myhuaweicloud.com/vllm-ascend/triton_ascend-3.2.0.dev2025110717-cp311-cp311-manylinux_2_27_aarch64.whl). If you're running vLLM Ascend with X86, you need to build triton ascend by yourself from [source](https://gitcode.com/Ascend/triton-ascend)
- Lots of Ascend ops are added to improve the performance. It means that from this release vLLM Ascend only works with custom ops built. So we removed the env `COMPILE_CUSTOM_KERNELS`. You can not set it to 0 now.
- speculative decode method `MTP` is more stable now. It can be enabled with most case and decode token number can be 1,2,3.
- speculative decode method `suffix` is supported now.
- speculative decode method `suffix` is supported now. Thanks for the contribution from China Merchants Bank.
- llm-compressor quantization tool with W8A8 works now. You can now deploy the model with W8A8 quantization from this tool directly.
- W4A4 quantization works now.
- Support features flashcomm1 and flashcomm2 in paper [flashcomm](https://arxiv.org/pdf/2412.04964) [#3004](https://github.com/vllm-project/vllm-ascend/pull/3004) [#3334](https://github.com/vllm-project/vllm-ascend/pull/3334)
@@ -44,6 +44,7 @@ This is the first release candidate of v0.12.0 for vLLM Ascend. We landed lots o
- DeepSeek 3.2 doesn't work with chat template. It because that vLLM v0.12.0 doesn't support it. We'll support in the next v0.13.0rc1 version.
- DeepSeek 3.2 doesn't work with high concurrency in some case. We'll fix it in next release. [#4996](https://github.com/vllm-project/vllm-ascend/pull/4996)
- We notice that bf16/fp16 model doesn't perform well, it's mainly because that `VLLM_ASCEND_ENABLE_NZ` is enabled by default. Please set `VLLM_ASCEND_ENABLE_NZ=0` to disable it. We'll add the auto detection mechanism in next release.
- speculative decode method `suffix` doesn't work. We'll fix it in next release. You can pick this commit to fix the issue: [#4813](https://github.com/vllm-project/vllm-ascend/pull/4813)

## v0.11.0rc3 - 2025.12.03
This is the third release candidate of v0.11.0 for vLLM Ascend. For quality reasons, we released a new rc before the official release. Thanks for all your feedback. Please follow the [official doc](https://vllm-ascend.readthedocs.io/en/v0.11.0-dev) to get started.


+ 1
- 1
examples/offline_embed.py View File

@@ -44,7 +44,7 @@ def main():
]
input_texts = queries + documents

model = LLM(model="Qwen/Qwen3-Embedding-0.6B", task="embed")
model = LLM(model="Qwen/Qwen3-Embedding-0.6B", runner="pooling")

outputs = model.embed(input_texts)
embeddings = torch.tensor([o.outputs.embedding for o in outputs])


+ 0
- 11
tests/e2e/models/configs/InternVL2-8B.yaml View File

@@ -1,11 +0,0 @@
model_name: "OpenGVLab/InternVL2-8B"
runner: "linux-aarch64-a2-1"
hardware: "Atlas A2 Series"
model: "vllm-vlm"
tasks:
- name: "mmmu_val"
metrics:
- name: "acc,none"
value: 0.58
max_model_len: 32768
trust_remote_code: True

+ 0
- 11
tests/e2e/models/configs/InternVL2_5-8B.yaml View File

@@ -1,11 +0,0 @@
model_name: "OpenGVLab/InternVL2_5-8B"
runner: "linux-aarch64-a2-1"
hardware: "Atlas A2 Series"
model: "vllm-vlm"
tasks:
- name: "mmmu_val"
metrics:
- name: "acc,none"
value: 0.58
max_model_len: 32768
trust_remote_code: True

+ 0
- 11
tests/e2e/models/configs/InternVL3-8B.yaml View File

@@ -1,11 +0,0 @@
model_name: "OpenGVLab/InternVL3-8B"
runner: "linux-aarch64-a2-1"
hardware: "Atlas A2 Series"
model: "vllm-vlm"
tasks:
- name: "mmmu_val"
metrics:
- name: "acc,none"
value: 0.58
max_model_len: 32768
trust_remote_code: True

tests/e2e/models/configs/Meta-Llama-3.1-8B-Instruct.yaml → tests/e2e/models/configs/Llama-3.2-3B-Instruct.yaml View File

@@ -1,11 +1,10 @@
model_name: "LLM-Research/Meta-Llama-3.1-8B-Instruct"
model_name: "LLM-Research/Llama-3.2-3B-Instruct"
hardware: "Atlas A2 Series"
tasks:
- name: "gsm8k"
metrics:
- name: "exact_match,strict-match"
value: 0.82
value: 0.71
- name: "exact_match,flexible-extract"
value: 0.84

value: 0.76
num_fewshot: 5

+ 11
- 0
tests/e2e/models/configs/Qwen3-Omni-30B-A3B-Instruct.yaml View File

@@ -0,0 +1,11 @@
model_name: "Qwen/Qwen3-Omni-30B-A3B-Instruct"
hardware: "Atlas A2 Series"
model: "vllm-vlm"
tasks:
- name: "mmmu_val"
metrics:
- name: "acc,none"
value: 0.52
max_model_len: 8192
tensor_parallel_size: 4
enable_expert_parallel: True

+ 4
- 6
tests/e2e/models/configs/accuracy.txt View File

@@ -5,13 +5,11 @@ Qwen2-Audio-7B-Instruct.yaml
Qwen3-VL-30B-A3B-Instruct.yaml
Qwen3-VL-8B-Instruct.yaml
Qwen2.5-Omni-7B.yaml
Meta-Llama-3.1-8B-Instruct.yaml
InternVL2-8B.yaml
InternVL2_5-8B.yaml
InternVL3-8B.yaml
InternVL3_5-8B.yaml
Qwen3-Omni-30B-A3B-Instruct.yaml
InternVL3_5-8B-hf.yaml
ERNIE-4.5-21B-A3B-PT.yaml
gemma-3-4b-it.yaml
internlm3-8b-instruct.yaml
Molmo-7B-D-0924.yaml
llava-1.5-7b-hf.yaml
llava-onevision-qwen2-0.5b-ov-hf.yaml
Llama-3.2-3B-Instruct.yaml

+ 1
- 0
tests/e2e/models/configs/gemma-3-4b-it.yaml View File

@@ -11,3 +11,4 @@ num_fewshot: 5
apply_chat_template: False
fewshot_as_multiturn: False
gpu_memory_utilization: 0.7
enforce_eager: True

tests/e2e/models/configs/llava-1.5-7b-hf.yaml → tests/e2e/models/configs/llava-onevision-qwen2-0.5b-ov-hf.yaml View File

@@ -1,11 +1,10 @@
model_name: "llava-hf/llava-1.5-7b-hf"
model_name: "llava-hf/llava-onevision-qwen2-0.5b-ov-hf"
hardware: "Atlas A2 Series"
model: "vllm-vlm"
tasks:
- name: "ceval-valid"
metrics:
- name: "acc,none"
value: 0.30
value: 0.42
trust_remote_code: True
gpu_memory_utilization: 0.8
dtype: "bfloat16"

+ 7
- 2
tests/e2e/multicard/test_qwen3_next.py View File

@@ -61,9 +61,14 @@ def test_qwen3_next_distributed_mp_full_decode_only_tp4():
del vllm_model


# TODO: Fix the accuracy of batch chunked prefill
def test_qwen3_next_distributed_mp_eager_mtp_similarity_tp4():
example_prompts = ["Hello, my name is"]
example_prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]

max_tokens = 20

with VllmRunner(


+ 60
- 4
tests/e2e/singlecard/compile/test_norm_quant_fusion.py View File

@@ -29,10 +29,10 @@ from vllm_ascend.compilation.passes.norm_quant_fusion_pass import \
AddRMSNormQuantFusionPass


class TestModel(nn.Module):
class TestModelWithoutBias(nn.Module):
"""
A minimal test model that simulates the pattern:
AddRMSNorm → Quantization
AddRMSNorm → Quantization (without bias)
"""

def __init__(self, hidden_size: int, eps: float = 1e-6, device="npu"):
@@ -75,12 +75,65 @@ class TestModel(nn.Module):
return [torch.ops.npu.npu_add_rms_norm_quant.default]


class TestModelWithBias(nn.Module):
"""
A test model that simulates the pattern:
AddRMSNorm → Add Bias → Quantization (with bias)
"""

def __init__(self, hidden_size: int, eps: float = 1e-6, device="npu"):
super().__init__()
self.hidden_size = hidden_size
self.eps = eps
self.rms_norm_weight = nn.Parameter(
torch.randn(hidden_size, device=device))
self.bias = nn.Parameter(torch.randn(hidden_size, device=device))
self.quant_scale = torch.tensor([1.0], device=device)
self.quant_offset = torch.tensor([0.0], device=device)

def forward(self, x):
"""
Forward pass:
1. Perform npu_add_rms_norm
2. Add bias
3. Quantize to int8
Returns both quantized output and updated residual.
"""
residual = torch.zeros_like(x)

norm_output, _, new_residual = torch_npu.npu_add_rms_norm(
x, residual, self.rms_norm_weight, self.eps)

# Add bias
norm_output_with_bias = norm_output + self.bias

quantized_output = torch_npu.npu_quantize(norm_output_with_bias,
self.quant_scale,
self.quant_offset,
torch.qint8, -1, False)

return quantized_output, new_residual

def ops_in_model_before(self) -> List[OpOverload]:
"""Return the list of expected operators BEFORE fusion."""
return [
torch.ops.npu.npu_add_rms_norm.default,
torch.ops.aten.add.Tensor, # Add bias operation
torch.ops.npu.npu_quantize.default
]

def ops_in_model_after(self) -> List[OpOverload]:
"""Return the list of expected operators AFTER successful fusion."""
return [torch.ops.npu.npu_add_rms_norm_quant.default]


@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("hidden_size", [64])
@pytest.mark.parametrize("num_tokens", [257])
@pytest.mark.parametrize("eps", [1e-5, 1e-6])
@pytest.mark.parametrize("use_bias", [False, True])
def test_rmsnorm_quant_fusion(dtype: torch.dtype, hidden_size: int,
num_tokens: int, eps: float):
num_tokens: int, eps: float, use_bias: bool):
"""
End-to-end test for AddRMSNorm+Quantize fusion.
Compares: Operator presence/absence before and after graph transformation
@@ -93,7 +146,10 @@ def test_rmsnorm_quant_fusion(dtype: torch.dtype, hidden_size: int,
with vllm.config.set_current_vllm_config(vllm_config):
backend = TestBackend(
custom_passes=[AddRMSNormQuantFusionPass(vllm_config=vllm_config)])
model = TestModel(hidden_size, eps, device="npu")
if use_bias:
model = TestModelWithBias(hidden_size, eps, device="npu")
else:
model = TestModelWithoutBias(hidden_size, eps, device="npu")
model = model.to("npu")

x = torch.rand(num_tokens,


+ 321
- 0
tests/ut/attention/test_attention_cp.py View File

@@ -0,0 +1,321 @@
from unittest.mock import MagicMock, patch

import torch
from vllm.distributed.parallel_state import GroupCoordinator

from tests.ut.base import TestBase
from vllm_ascend.attention.attention_cp import AscendAttentionCPImpl


class TestAscendAttentionCPImpl(TestBase):

@patch('vllm_ascend.attention.attention_cp.get_pcp_group')
@patch('vllm.distributed.parallel_state._PCP',
new_callable=lambda: MagicMock(spec=GroupCoordinator))
@patch('vllm_ascend.attention.attention_cp.get_dcp_group')
@patch('vllm.distributed.parallel_state._DCP',
new_callable=lambda: MagicMock(spec=GroupCoordinator))
@patch("vllm.distributed.get_decode_context_model_parallel_world_size",
return_value=1)
def setUp(self, mock_get_dcp_size, mock_dcp, mock_get_dcp_group, mock_pcp,
mock_get_pcp_group):
mock_dcp.world_size = 2
mock_dcp.rank_in_group = 0
dcp_group = MagicMock(spec=GroupCoordinator)
dcp_group.rank_in_group = 0
dcp_group.world_size = 2
dcp_group.device_group = MagicMock()
mock_get_dcp_group.return_value = dcp_group

mock_pcp.world_size = 2
mock_pcp.rank_in_group = 0
pcp_group = MagicMock(spec=GroupCoordinator)
pcp_group.rank_in_group = 0
pcp_group.world_size = 2
pcp_group.device_group = MagicMock()
mock_get_pcp_group.return_value = pcp_group

self.layer = MagicMock()
self.layer.layer_name = "test_layer"
self.layer._k_scale_float = 1.0
self.layer._v_scale_float = 1.0

self.attention_type = MagicMock()
self.attention_type.DECODER = "decoder"
self.attention_type.ENCODER = "encoder"

self.attn_metadata = MagicMock()
self.attn_metadata.return_value = "1"

self.layer_no_quant = MagicMock(
spec=['layer_name', '_k_scale_float', '_v_scale_float'])
self.layer_no_quant.layer_name = "test_layer"
self.layer_no_quant._k_scale_float = 1.0
self.layer_no_quant._v_scale_float = 1.0

self.impl = AscendAttentionCPImpl(
num_heads=8,
head_size=64,
scale=1.0,
num_kv_heads=8,
alibi_slopes=None,
sliding_window=None,
kv_cache_dtype="float16",
logits_soft_cap=None,
attn_type=self.attention_type.DECODER,
kv_sharing_target_layer_name=None)

def test_init(self):
self.assertEqual(self.impl.pcp_size, 2)
self.assertEqual(self.impl.pcp_rank, 0)
self.assertEqual(self.impl.dcp_size, 2)
self.assertEqual(self.impl.dcp_rank, 0)

def test_forward_prefill_cp(self):
query = torch.randn(2, 4, 128)
key = torch.randn(4, 1, 128)
value = torch.randn(4, 1, 128)

def mock_attention_with_nomask_and_mask(q, k_mask, **kwargs):
mock_output = torch.randn_like(q)
mock_lse = torch.randn_like(k_mask)
return mock_output, mock_lse

self.impl._attention_with_nomask_and_mask = MagicMock()
self.impl._attention_with_nomask_and_mask.side_effect = mock_attention_with_nomask_and_mask

attn_metadata = MagicMock()
attn_metadata.prefill = MagicMock()
attn_metadata.prefill.pcp_metadata.q_head_idx = torch.tensor([0])
attn_metadata.prefill.pcp_metadata.q_tail_idx = torch.tensor([1])
attn_metadata.prefill.pcp_metadata.q_full_idx = torch.tensor([0, 1])
attn_metadata.prefill.pcp_metadata.kv_with_q_head_mask_idx = torch.tensor(
[0])
attn_metadata.prefill.pcp_metadata.kv_with_q_tail_nomask_idx = torch.tensor(
[0])
attn_metadata.prefill.pcp_metadata.kv_with_q_tail_mask_idx = torch.tensor(
[0])

output, attn_lse = self.impl._forward_prefill_cp(
query, key, value, attn_metadata)

self.assertEqual(output.shape[0], 2)
self.assertEqual(output.shape[1], 4)
self.assertEqual(output.shape[2], 128)

@patch('vllm_ascend.attention.attention_cp.get_dcp_group')
@patch('vllm.distributed.parallel_state._DCP')
@patch("torch_npu.npu_fused_infer_attention_score")
@patch("torch.distributed.all_gather")
@patch("torch.distributed.all_to_all_single")
@patch('vllm_ascend.attention.attention_cp.get_forward_context')
def test_forward_decode_pcp_dcp(self, mock_get_forward_context,
mock_all_to_all_single, mock_all_gather,
mock_npu_fused_infer_attention_score,
mock_dcp, mock_get_dcp_group):

def mock_dcp_all_gather_func(tensor, dim):
return torch.cat([tensor, tensor], dim=dim)

mock_dcp.world_size = 2
mock_dcp.rank_in_group = 0
dcp_group = MagicMock(spec=GroupCoordinator)
dcp_group.rank_in_group = 0
dcp_group.world_size = 2
dcp_group.device_group = MagicMock()
dcp_group.all_gather = mock_dcp_all_gather_func
mock_get_dcp_group.return_value = dcp_group

query = torch.randn(2, 4, 128)
self.impl.key_cache = torch.randn(100, 128, 1, 128)
self.impl.value_cache = torch.randn(100, 128, 1, 128)

def mock_npu_attention_update(attn_out_lse_list):
mock_output = torch.randn(attn_out_lse_list[0].shape[0],
attn_out_lse_list[0].shape[1],
attn_out_lse_list[0].shape[2] - 1)
return mock_output

self.impl._npu_attention_update = MagicMock()
self.impl._npu_attention_update.side_effect = mock_npu_attention_update

mock_get_forward_context.return_value = MagicMock(capturing=False)

mock_all_to_all_single.side_effect = lambda output, input, *args, **kwargs: output.copy_(
input)

def mock_all_gather_func(tensor_list, tensor, group=None):
tensor_list[0] = tensor
tensor_list[1] = tensor.clone()

mock_all_gather.side_effect = mock_all_gather_func

def mock_npu_fused_infer_attention_score_func(query, k_nope, value,
**common_kwargs):
mock_output = torch.randn_like(query)
mock_lse = torch.randn(query.shape[0], query.shape[1], 1)
return mock_output, mock_lse

mock_npu_fused_infer_attention_score.side_effect = mock_npu_fused_infer_attention_score_func

attn_metadata = MagicMock()
attn_metadata.decode_meta = MagicMock()
attn_metadata.decode_meta.batch_seq_mask = torch.tensor(
[1, 0], dtype=torch.bool)

output = self.impl._forward_decode_pcp_dcp(query, attn_metadata)

self.assertEqual(output.shape[0], 2)
self.assertEqual(output.shape[1], 4)
self.assertEqual(output.shape[2], 128)

@patch('vllm_ascend.attention.attention_cp.get_pcp_group')
@patch('vllm.distributed.parallel_state._PCP')
@patch('vllm_ascend.attention.attention_cp.get_dcp_group')
@patch('vllm.distributed.parallel_state._DCP')
def test_prefill_query_all_gather(self, mock_dcp, mock_get_dcp_group,
mock_pcp, mock_get_pcp_group):
query = torch.randn(2, 4, 128)

def mock_all_gather_func(tensor, dim):
return torch.cat([tensor, tensor], dim=dim)

dcp_group = MagicMock(spec=GroupCoordinator)
dcp_group.all_gather = mock_all_gather_func
mock_get_dcp_group.return_value = dcp_group

pcp_group = MagicMock(spec=GroupCoordinator)
pcp_group.all_gather = mock_all_gather_func
mock_get_pcp_group.return_value = pcp_group

attn_metadata = MagicMock()
attn_metadata.prefill = MagicMock()
attn_metadata.prefill.chunked_context = MagicMock()
attn_metadata.prefill.chunked_context.cp_kv_recover_idx_for_chunk = torch.tensor(
[1, 2, 3, 0])
output = self.impl._prefill_query_all_gather(attn_metadata, query)

self.assertEqual(output.shape[0], 4)
self.assertEqual(output.shape[1], 8)
self.assertEqual(output.shape[2], 128)

@patch('torch.ops.npu.npu_fused_infer_attention_score')
def test_compute_prefill_context(self, mock_npu_attention):

block_num = 100
block_size = 128
kv_num_heads = 1
head_size = 128
kv_cache = (torch.randn(block_num, block_size, kv_num_heads,
head_size),
torch.randn(block_num, block_size, kv_num_heads,
head_size))

batch_size = 1024
self.impl.head_size = head_size
self.impl.num_heads = 4
num_heads = self.impl.num_heads * self.impl.dcp_size
query = torch.randn(batch_size, num_heads, head_size)

attn_metadata = MagicMock()
attn_metadata.prefill = MagicMock()
attn_metadata.prefill.chunked_context = MagicMock()
attn_metadata.prefill.chunked_context.local_context_lens_allranks = torch.tensor(
[[[256, 256], [256, 256]]])
attn_metadata.prefill.chunked_context.batch_chunk_seq_mask = torch.randint(
0, 2, (1024, ), dtype=torch.bool)

def mock_load_kv_for_chunk(attn_metadata, kv_cache,
local_chunked_kv_lens_rank, query,
total_toks):
return torch.randn(total_toks, kv_num_heads,
head_size), torch.randn(total_toks,
kv_num_heads, head_size)

self.impl._load_kv_for_chunk = MagicMock()
self.impl._load_kv_for_chunk.side_effect = mock_load_kv_for_chunk

mock_npu_attention.return_value = torch.randn(batch_size, num_heads,
head_size), torch.randn(
batch_size,
num_heads, 1)

result_output, result_lse = self.impl._compute_prefill_context(
query, kv_cache, attn_metadata)

self.assertEqual(result_output.shape[0], batch_size)
self.assertEqual(result_output.shape[1], self.impl.num_heads)
self.assertEqual(result_output.shape[2], head_size)
self.assertEqual(result_lse.shape[0], batch_size)
self.assertEqual(result_lse.shape[1], self.impl.num_heads)
self.assertEqual(result_lse.shape[2], 1)

@patch('torch_npu.atb.npu_paged_cache_load')
def test_load_kv_for_chunk(self, mock_npu_paged_cache_load):
block_num = 100
block_size = 128
num_heads = 1
head_size = 128

kv_cache = (torch.randn(block_num, block_size, num_heads, head_size),
torch.randn(block_num, block_size, num_heads, head_size))
query = torch.randn(4, 8, 128)
total_toks = 256
local_chunked_kv_lens_rank = torch.randn(total_toks)

attn_metadata = MagicMock()

key, value = self.impl._load_kv_for_chunk(attn_metadata, kv_cache,
local_chunked_kv_lens_rank,
query, total_toks)

self.assertEqual(key.shape[0], total_toks)
self.assertEqual(key.shape[1], num_heads)
self.assertEqual(key.shape[2], head_size)
self.assertEqual(value.shape[0], total_toks)
self.assertEqual(value.shape[1], num_heads)
self.assertEqual(value.shape[2], head_size)

@patch('vllm_ascend.attention.attention_cp.get_pcp_group')
@patch('vllm.distributed.parallel_state._PCP')
@patch('torch_npu._npu_reshape_and_cache')
def test_reshape_and_cache(self, mock_npu_reshape_and_cache, mock_pcp,
mock_get_pcp_group):
num_tokens = 4
block_num = 100
block_size = 128
num_heads = 1
head_size = 128
self.impl.head_size = head_size

kv_cache = (torch.randn(block_num, block_size, num_heads, head_size),
torch.randn(block_num, block_size, num_heads, head_size))

attn_metadata = MagicMock()
attn_metadata.num_decode_tokens = 1
attn_metadata.num_decodes = 1
attn_metadata.num_prefills = 1
attn_metadata.slot_mapping = torch.randn(2)
attn_metadata.num_actual_tokens_pcp_padded = num_tokens * self.impl.pcp_size
attn_metadata.prefill = MagicMock()
attn_metadata.prefill.pcp_allgather_restore_idx = torch.tensor(
[0, 3, 1, 2, 0, 0, 0, 0])

key = torch.randn(num_tokens, num_heads, head_size)
value = torch.randn(num_tokens, num_heads, head_size)

def mock_all_gather_func(tensor, dim):
return torch.cat([tensor, tensor], dim=dim)

pcp_group = MagicMock(spec=GroupCoordinator)
pcp_group.all_gather = mock_all_gather_func
mock_get_pcp_group.return_value = pcp_group

key, value = self.impl.reshape_and_cache(key, value, kv_cache,
attn_metadata)
self.assertEqual(key.shape[0], num_tokens * self.impl.pcp_size)
self.assertEqual(key.shape[1], num_heads)
self.assertEqual(key.shape[2], head_size)
self.assertEqual(value.shape[0], num_tokens * self.impl.pcp_size)
self.assertEqual(value.shape[1], num_heads)
self.assertEqual(value.shape[2], head_size)

+ 403
- 0
tests/ut/attention/test_mla_cp.py View File

@@ -0,0 +1,403 @@
from unittest.mock import MagicMock, patch

import torch
from vllm.distributed.parallel_state import GroupCoordinator

from tests.ut.base import TestBase
from vllm_ascend.ascend_config import init_ascend_config
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.attention.mla_cp import AscendMlaCPImpl


class TestAscendMLAImpl(TestBase):

@patch('vllm.distributed.parallel_state._PCP',
new_callable=lambda: MagicMock(spec=GroupCoordinator))
@patch('vllm.distributed.parallel_state._DCP',
new_callable=lambda: MagicMock(spec=GroupCoordinator))
@patch("vllm.distributed.get_decode_context_model_parallel_world_size",
return_value=1)
@patch('vllm.distributed.parallel_state._TP',
new_callable=lambda: MagicMock(spec=GroupCoordinator))
@patch("vllm.distributed.get_tensor_model_parallel_world_size",
return_value=2)
@patch("vllm_ascend.attention.mla_v1.get_current_vllm_config")
@patch("vllm_ascend.attention.mla_v1.get_ascend_config")
def setUp(self, ascend_config, get_current_vllm_config, mock_get_tp_size,
mock_tp, mock_get_dcp_size, mock_dcp, mock_pcp):
mock_tp.world_size = 2
mock_tp.rank_in_group = MagicMock()
mock_tp.device_group = MagicMock()
mock_dcp.world_size = 2
mock_dcp.rank_in_group = MagicMock()
mock_dcp.device_group = MagicMock()
mock_pcp.world_size = 2
mock_pcp.rank_in_group = MagicMock()
mock_pcp.device_group = MagicMock()
vllm_config = MagicMock()
speculative_config = MagicMock()
model_config = MagicMock()
speculative_config.num_speculative_tokens = 4
vllm_config.speculative_config = speculative_config
model_config.dtype = torch.float16
vllm_config.model_config = model_config
get_current_vllm_config.return_value = vllm_config
vllm_config.additional_config = {"refresh": True}
init_ascend_config(vllm_config)

num_heads = 256
head_size = 1024
scale = 0.1
num_kv_heads = 8
kv_cache_dtype = "auto"

kv_a_layernorm = MagicMock()
kv_a_layernorm.weight = torch.randn(96)
kv_a_layernorm.variance_epsilon = 1e-6
kwargs = {
"kv_lora_rank": 32,
"qk_nope_head_dim": 64,
"qk_rope_head_dim": 32,
"qk_head_dim": 96,
"v_head_dim": 128,
"q_lora_rank": 64,
"q_proj": MagicMock(),
"q_b_proj": MagicMock(),
"kv_b_proj": MagicMock(),
"o_proj": MagicMock(),
"kv_a_proj_with_mqa": MagicMock(),
"fused_qkv_a_proj": MagicMock(),
"kv_a_layernorm": kv_a_layernorm,
"rotary_emb": MagicMock(),
}

self.impl = AscendMlaCPImpl(num_heads=num_heads,
head_size=head_size,
scale=scale,
num_kv_heads=num_kv_heads,
alibi_slopes=None,
sliding_window=None,
kv_cache_dtype=kv_cache_dtype,
blocksparse_params=None,
logits_soft_cap=None,
attn_type=None,
kv_sharing_target_layer_name=None,
**kwargs)

def test_init(self):
self.assertEqual(self.impl.num_heads, 256)
self.assertEqual(self.impl.head_size, 1024)
self.assertEqual(self.impl.scale, 0.1)
self.assertEqual(self.impl.num_kv_heads, 8)
self.assertEqual(self.impl.kv_cache_dtype, "auto")
self.assertEqual(self.impl.kv_lora_rank, 32)
self.assertEqual(self.impl.qk_nope_head_dim, 64)
self.assertEqual(self.impl.qk_rope_head_dim, 32)
self.assertEqual(self.impl.qk_head_dim, 96)
self.assertEqual(self.impl.v_head_dim, 128)
self.assertIsNotNone(self.impl.q_proj)
self.assertIsNotNone(self.impl.kv_b_proj)
self.assertIsNotNone(self.impl.o_proj)
self.assertIsNotNone(self.impl.kv_a_proj_with_mqa)
self.assertIsNotNone(self.impl.kv_a_layernorm)
self.assertEqual(self.impl.num_queries_per_kv, 32)
self.assertEqual(self.impl.pcp_size, 2)
self.assertEqual(self.impl.dcp_size, 2)

@patch('vllm_ascend.attention.mla_cp.get_dcp_group')
@patch("torch.ops.vllm.maybe_all_gather_and_maybe_unpad")
@patch("vllm_ascend.attention.mla_cp.maybe_npu_prefetch")
def test_mla_preprocess_dcp(self, magic_npu_fetch,
mock_maybe_all_gather_and_maybe_unpad,
mock_get_dcp_group):

self.impl.num_kv_heads = 1
self.impl.num_heads = 16
self.impl.qk_rope_head_dim = 64
self.impl.kv_lora_rank = 512
self.impl.q_lora_rank = 1536
self.impl.dcp_size = 2
self.impl.pcp_size = 2
block_num = 10
block_size = 128
batch_size = 2
hidden_size = 1024
hidden_states = torch.randn(batch_size, hidden_size)

kv_cache0 = torch.randn(block_num, block_size, self.impl.num_kv_heads,
self.impl.kv_lora_rank)
kv_cache1 = torch.randn(block_num, block_size, self.impl.num_kv_heads,
self.impl.qk_rope_head_dim)
kv_cache = (kv_cache0, kv_cache1)

mock_dcp_group = MagicMock()

def mock_all_gather_func(tensor, dim):
return torch.cat([tensor, tensor], dim=dim)

mock_dcp_group.all_gather = mock_all_gather_func
mock_get_dcp_group.return_value = mock_dcp_group

attn_metadata = MagicMock()
attn_metadata.num_decodes = 2
attn_metadata.num_prefills = 0
attn_metadata.num_prefill_tokens = 0
attn_metadata.num_decode_tokens = 2
attn_metadata.num_actual_tokens = 2
attn_metadata.slot_mapping = torch.arange(4)
attn_metadata.decode.cos = torch.randn(2, 64)
attn_metadata.decode.sin = torch.randn(2, 64)

self.impl.q_a_layernorm = MagicMock()
self.impl.q_a_layernorm.return_value = torch.randn(
attn_metadata.num_actual_tokens, self.impl.q_lora_rank)
self.impl.kv_a_proj_with_mqa = MagicMock()
self.impl.kv_a_proj_with_mqa.return_value = [
torch.randn(batch_size, self.impl.num_heads,
self.impl.qk_rope_head_dim + self.impl.kv_lora_rank)
]
self.impl.fused_qkv_a_proj = MagicMock()
self.impl.fused_qkv_a_proj.return_value = [
torch.randn(
attn_metadata.num_actual_tokens, self.impl.qk_rope_head_dim +
self.impl.kv_lora_rank + self.impl.q_lora_rank)
]

self.impl.rope_single = MagicMock(side_effect=lambda x, cos, sin: x)
self.impl.exec_kv_decode = MagicMock()
self.impl.exec_kv_decode.return_value = [MagicMock(), MagicMock()]

self.impl._q_proj_and_k_up_proj = MagicMock()
self.impl._q_proj_and_k_up_proj.return_value = [
torch.randn(attn_metadata.num_decodes, self.impl.num_heads,
self.impl.kv_lora_rank),
torch.randn(attn_metadata.num_decodes, self.impl.num_heads,
self.impl.qk_rope_head_dim)
]

magic_npu_fetch.return_value = MagicMock()
mock_maybe_all_gather_and_maybe_unpad.side_effect = lambda x, label: x

decode_res, prefill_res = self.impl._mla_preprocess(
"mock_layer",
hidden_states,
kv_cache,
attn_metadata,
need_gather_q_kv=False)

self.assertIsNotNone(decode_res)
self.assertIsNone(prefill_res)

@patch('torch_npu._npu_reshape_and_cache')
@patch('vllm_ascend.attention.mla_cp.get_pcp_group')
@patch("torch.ops.vllm.maybe_all_gather_and_maybe_unpad")
@patch("vllm_ascend.attention.mla_cp.maybe_npu_prefetch")
def test_mla_preprocess_pcp(self, magic_npu_fetch,
mock_maybe_all_gather_and_maybe_unpad,
mock_get_pcp_group,
mock_npu_reshape_and_cache):
self.impl.num_kv_heads = 1
self.impl.num_heads = 16
self.impl.qk_rope_head_dim = 64
self.impl.kv_lora_rank = 512
self.impl.q_lora_rank = 1536
self.impl.dcp_size = 2
self.impl.pcp_size = 2
block_num = 10
block_size = 128
batch_size = 2
hidden_size = 1024
hidden_states = torch.randn(batch_size, hidden_size)

kv_cache0 = torch.randn(block_num, block_size, self.impl.num_kv_heads,
self.impl.kv_lora_rank)
kv_cache1 = torch.randn(block_num, block_size, self.impl.num_kv_heads,
self.impl.qk_rope_head_dim)
kv_cache = (kv_cache0, kv_cache1)

mock_pcp_group = MagicMock()

def mock_all_gather_func(tensor, dim):
return torch.cat([tensor, tensor], dim=dim)

mock_pcp_group.all_gather = mock_all_gather_func
mock_get_pcp_group.return_value = mock_pcp_group

attn_metadata = MagicMock()
attn_metadata.num_decodes = 0
attn_metadata.num_prefills = 2
attn_metadata.num_prefill_tokens = 2
attn_metadata.num_decode_tokens = 0
attn_metadata.num_actual_tokens = 2
attn_metadata.num_actual_tokens_pcp_padded = 4
attn_metadata.prefill.pcp_metadata = MagicMock()
attn_metadata.prefill.pcp_metadata.pcp_allgather_restore_idx = torch.arange(
4)
attn_metadata.slot_mapping = torch.arange(4)
attn_metadata.prefill.cos = torch.randn(2, 64)
attn_metadata.prefill.sin = torch.randn(2, 64)

self.impl.q_a_layernorm = MagicMock()
self.impl.q_a_layernorm.return_value = torch.randn(
attn_metadata.num_actual_tokens, self.impl.q_lora_rank)
self.impl.kv_a_proj_with_mqa = MagicMock()
self.impl.kv_a_proj_with_mqa.return_value = [
torch.randn(batch_size, self.impl.num_heads,
self.impl.qk_rope_head_dim + self.impl.kv_lora_rank)
]
self.impl.fused_qkv_a_proj = MagicMock()
self.impl.fused_qkv_a_proj.return_value = [
torch.randn(
attn_metadata.num_actual_tokens, self.impl.qk_rope_head_dim +
self.impl.kv_lora_rank + self.impl.q_lora_rank)
]

self.impl.rope_single = MagicMock(side_effect=lambda x, cos, sin: x)
self.impl.exec_kv_decode = MagicMock()
self.impl.exec_kv_decode.return_value = [MagicMock(), MagicMock()]

self.impl._q_proj_and_k_up_proj = MagicMock()
self.impl._q_proj_and_k_up_proj.return_value = [
torch.randn(attn_metadata.num_decodes, self.impl.num_heads,
self.impl.kv_lora_rank),
torch.randn(attn_metadata.num_decodes, self.impl.num_heads,
self.impl.qk_rope_head_dim)
]

magic_npu_fetch.return_value = MagicMock()
mock_maybe_all_gather_and_maybe_unpad.side_effect = lambda x, label: x

self.impl.kv_a_layernorm = MagicMock()
self.impl.kv_a_layernorm.return_value = torch.randn(
attn_metadata.num_prefill_tokens, self.impl.num_kv_heads,
self.impl.kv_lora_rank)

self.impl.q_proj = MagicMock()
self.impl.q_proj.return_value = [
torch.randn(attn_metadata.num_prefill_tokens, self.impl.num_heads,
self.impl.qk_head_dim)
]
self.impl.kv_b_proj = MagicMock()
self.impl.kv_b_proj.return_value = [
torch.randn(attn_metadata.num_prefill_tokens * self.impl.pcp_size,
self.impl.num_heads,
self.impl.v_head_dim + self.impl.qk_nope_head_dim)
]
self.impl.rope_single = MagicMock(side_effect=lambda x, cos, sin: x)
self.impl.exec_kv_decode = MagicMock()
self.impl.exec_kv_decode.return_value = [MagicMock(), MagicMock()]
self.impl.exec_kv_prefill = MagicMock()
self.impl.exec_kv_prefill.return_value = [
torch.randn(attn_metadata.num_prefill_tokens, self.impl.num_heads,
self.impl.qk_rope_head_dim),
torch.randn(attn_metadata.num_prefill_tokens, self.impl.num_heads,
self.impl.kv_lora_rank)
]

decode_res, prefill_res = self.impl._mla_preprocess(
"mock_layer",
hidden_states,
kv_cache,
attn_metadata,
need_gather_q_kv=False)
self.assertIsNone(decode_res)
self.assertIsNotNone(prefill_res)

@patch("torch.distributed.all_gather")
@patch("torch.distributed.all_to_all_single")
def test_process_attn_out_lse(self, mock_all_to_all_single,
mock_all_gather):
self.impl.dcp_size = 2
self.impl.pcp_size = 2

B = 2
N = self.impl.num_heads
self.impl.kv_lora_rank = 512

attn_output = torch.randn(B, N, self.impl.kv_lora_rank)
softmax_lse = torch.randn(B, N, 1)

mock_all_to_all_single.side_effect = lambda output, input, *args, **kwargs: output.copy_(
input)

def mock_all_gather_func(tensor_list, tensor, group=None):
tensor_list[0] = tensor
tensor_list[1] = tensor.clone()

mock_all_gather.side_effect = mock_all_gather_func

decode_metadata = MagicMock()
decode_metadata.actual_seq_lengths_q = MagicMock()
decode_metadata.seq_lens_list = MagicMock()
decode_metadata.batch_seq_mask = torch.tensor([True, False],
dtype=torch.bool)

result = self.impl._process_attn_out_lse(attn_output, softmax_lse,
decode_metadata)

self.assertEqual(result[0].shape[0], B)
self.assertEqual(result[0].shape[1], N / self.impl.dcp_size)
self.assertEqual(result[0].shape[2], self.impl.kv_lora_rank + 1)

@patch("torch.distributed.all_gather")
@patch("torch.distributed.all_to_all_single")
@patch('vllm_ascend.attention.mla_cp.get_forward_context')
@patch("torch_npu.atb.npu_multi_head_latent_attention")
@patch('torch_npu.npu_attention_update')
def test_forward_decode_pcp_dcp(self, mock_npu_attention_update,
mock_npu_multi_head_latent_attention,
mock_get_forward_context,
mock_all_to_all_single, mock_all_gather):
self.impl.dcp_size = 2
self.impl.pcp_size = 2
self.impl.num_kv_heads = 1
self.impl.num_heads = 16
self.impl.kv_lora_rank = 64
self.impl.qk_nope_head_dim = 64
self.impl.spec_token_num = 1
B = 2
N = self.impl.num_heads * self.impl.dcp_size
BS = 128
NB = 100

q_nope = torch.randn(B, N, self.impl.qk_nope_head_dim)
q_pe = torch.randn(B, N, self.impl.qk_rope_head_dim)
k_nope = torch.randn(NB, BS, 1, self.impl.kv_lora_rank)
k_pe = torch.randn(NB, BS, 1, self.impl.qk_rope_head_dim)

attn_metadata = MagicMock()
attn_metadata.attn_state = AscendAttentionState.SpecDecoding
attn_metadata.decode = MagicMock()
attn_metadata.decode.actual_seq_lengths_q = MagicMock()
attn_metadata.decode.seq_lens_list = MagicMock()
attn_metadata.decode.batch_seq_mask = torch.tensor([False, False],
dtype=torch.bool)

self.impl.enable_kv_nz = True

mock_npu_attention_update.return_value = (torch.randn(
B, self.impl.num_heads, self.impl.kv_lora_rank), None)
mock_npu_multi_head_latent_attention.return_value = [
torch.randn(B, N, self.impl.kv_lora_rank),
torch.randn(B, N, 1)
]
mock_get_forward_context.return_value = MagicMock(capturing=False)

mock_all_to_all_single.side_effect = lambda output, input, *args, **kwargs: output.copy_(
input)

def mock_all_gather_func(tensor_list, tensor, group=None):
tensor_list[0] = tensor
tensor_list[1] = tensor.clone()

mock_all_gather.side_effect = mock_all_gather_func

self.impl._v_up_proj = MagicMock()
self.impl._v_up_proj.return_value = torch.randn(
B, self.impl.v_head_dim)

result = self.impl._forward_decode_pcp_dcp(q_nope, q_pe, k_nope, k_pe,
BS, attn_metadata)

self.assertEqual(result.shape[0], B)
self.assertEqual(result.shape[1], self.impl.v_head_dim)

+ 45
- 10
tests/ut/attention/test_mla_v1.py View File

@@ -75,6 +75,12 @@ class TestAscendMLAPrefillMetadata(TestBase):
max_seq_lens = [2, 2]
workspace = torch.randn(2, 4)
chunk_seq_lens = torch.tensor([2, 2])
padded_chunk_seq_lens_npu = torch.tensor([2, 2])
padded_local_chunk_seq_lens = [[2], [2]]
local_context_lens_allranks = [[1, 1], [1, 1]]
padded_local_cu_seq_lens = torch.tensor([0, 2, 4])
cu_seq_lens_lst = [[0, 2], [2, 4]]
chunk_size = 2

chunked_context = AscendMLAPrefillMetadata.ChunkedContextMetadata(
cu_seq_lens=cu_seq_lens,
@@ -83,7 +89,13 @@ class TestAscendMLAPrefillMetadata(TestBase):
max_seq_lens=max_seq_lens,
workspace=workspace,
chunk_seq_lens=chunk_seq_lens,
chunk_seq_lens_npu=chunk_seq_lens)
chunk_seq_lens_npu=chunk_seq_lens,
padded_chunk_seq_lens_npu=padded_chunk_seq_lens_npu,
padded_local_chunk_seq_lens=padded_local_chunk_seq_lens,
local_context_lens_allranks=local_context_lens_allranks,
padded_local_cu_seq_lens=padded_local_cu_seq_lens,
cu_seq_lens_lst=cu_seq_lens_lst,
chunk_size=chunk_size)

metadata = AscendMLAPrefillMetadata(
attn_mask=torch.tensor([[1, 0], [1, 1]], dtype=torch.bool),
@@ -106,6 +118,17 @@ class TestAscendMLAPrefillMetadata(TestBase):
self.assertIs(metadata.chunked_context.chunk_seq_lens, chunk_seq_lens)
self.assertIs(metadata.chunked_context.chunk_seq_lens_npu,
chunk_seq_lens)
self.assertIs(metadata.chunked_context.padded_chunk_seq_lens_npu,
padded_chunk_seq_lens_npu)
self.assertEqual(metadata.chunked_context.padded_local_chunk_seq_lens,
padded_local_chunk_seq_lens)
self.assertEqual(metadata.chunked_context.local_context_lens_allranks,
local_context_lens_allranks)
self.assertIs(metadata.chunked_context.padded_local_cu_seq_lens,
padded_local_cu_seq_lens)
self.assertEqual(metadata.chunked_context.cu_seq_lens_lst,
cu_seq_lens_lst)
self.assertEqual(metadata.chunked_context.chunk_size, chunk_size)


class TestAscendMLADecodeMetadata(TestBase):
@@ -117,10 +140,17 @@ class TestAscendMLADecodeMetadata(TestBase):
max_seq_lens = 4
seq_lens_list = [2, 3]
attn_mask = None

metadata = AscendMLADecodeMetadata(input_positions, block_table,
seq_lens, max_seq_lens,
seq_lens_list, attn_mask)
cp_seq_len = torch.tensor([2, 3])
batch_seq_mask = torch.tensor([[1, 1, 0, 0], [1, 1, 1, 0]])

metadata = AscendMLADecodeMetadata(input_positions=input_positions,
block_table=block_table,
seq_lens=seq_lens,
max_seq_lens=max_seq_lens,
seq_lens_list=seq_lens_list,
attn_mask=attn_mask,
cp_seq_len=cp_seq_len,
batch_seq_mask=batch_seq_mask)

self.assertIs(metadata.input_positions, input_positions)
self.assertIs(metadata.block_table, block_table)
@@ -128,6 +158,8 @@ class TestAscendMLADecodeMetadata(TestBase):
self.assertEqual(metadata.max_seq_lens, max_seq_lens)
self.assertEqual(metadata.seq_lens_list, seq_lens_list)
self.assertIsNone(attn_mask)
self.assertIs(metadata.cp_seq_len, cp_seq_len)
self.assertIs(metadata.batch_seq_mask, batch_seq_mask)


class TestAscendMLAMetadata(TestBase):
@@ -200,17 +232,19 @@ class TestAscendMLAMetadataBuilder(TestBase):
mock_vllm_config.scheduler_config.enable_chunked_prefill = False
mock_device = 'cpu'

mock_dcp.world_size = 1
mock_dcp.world_size = 2
mock_dcp.rank_in_group = 0
dcp_group = MagicMock(spec=GroupCoordinator)
dcp_group.rank_in_group = 0
dcp_group.world_size = 1
dcp_group.world_size = 2
dcp_group.device_group = MagicMock()
mock_get_dcp_group.return_value = dcp_group

mock_pcp.world_size = 1
mock_pcp.world_size = 2
mock_pcp.rank_in_group = 0
pcp_group = MagicMock(spec=GroupCoordinator)
pcp_group.rank_in_group = 0
pcp_group.world_size = 1
pcp_group.world_size = 2
pcp_group.device_group = MagicMock()
mock_get_pcp_group.return_value = pcp_group

@@ -227,6 +261,8 @@ class TestAscendMLAMetadataBuilder(TestBase):
self.assertEqual(
builder.chunked_prefill_enabled,
mock_vllm_config.scheduler_config.enable_chunked_prefill)
self.assertEqual(builder.dcp_size, mock_dcp.world_size)
self.assertEqual(builder.pcp_size, mock_pcp.world_size)

@patch('vllm.distributed.parallel_state.get_pcp_group')
@patch('vllm.distributed.parallel_state._PCP',
@@ -912,7 +948,6 @@ class TestAscendMLAImpl(TestBase):
self.assertIsNotNone(self.impl.kv_a_proj_with_mqa)
self.assertIsNotNone(self.impl.kv_a_layernorm)
self.assertEqual(self.impl.num_queries_per_kv, 32)
self.assertEqual(self.impl.tp_size, 2)

def test_q_proj_and_k_up_proj(self):
batch_size = 4


+ 5
- 2
tests/ut/compilation/test_acl_graph.py View File

@@ -803,7 +803,9 @@ class TestPCPDCPGraphParams(TestBase):
(q_nope, q_pe, k_nope, k_pe, block_table, seq_lens, num_heads,
scale, num_kv_heads, out, lse))

update_mla_attn_dcp_pcp_params(self.update_stream, forward_context, 4)
with patch("torch_npu._C._npu_setStream", return_value=None):
update_mla_attn_dcp_pcp_params(self.update_stream, forward_context,
4)

_mock_graph_task_end.assert_called_once()

@@ -842,6 +844,7 @@ class TestPCPDCPGraphParams(TestBase):
block_table, 128, actual_seq_lengths_kv, actual_seq_lengths_q,
out, lse, 2, 0, 0))

update_attn_dcp_pcp_params(self.update_stream, forward_context, 4)
with patch("torch_npu._C._npu_setStream", return_value=None):
update_attn_dcp_pcp_params(self.update_stream, forward_context, 4)

_mock_graph_task_end.assert_called_once()

+ 4
- 0
tests/ut/spec_decode/test_eagle_proposer.py View File

@@ -95,6 +95,8 @@ class TestEagleProposerLoadModel(TestBase):
mock_model = MagicMock()
mock_model.model.embed_tokens = MagicMock()
mock_model.lm_head = MagicMock()
mock_model.multimodal_cpu_fields = None
mock_model.merge_by_field_config = None
mock_get_model.return_value = MagicMock()
self.proposer.name = SpecDcodeType.EAGLE

@@ -117,6 +119,8 @@ class TestEagleProposerLoadModel(TestBase):

mock_model = MagicMock()
original_embed = MagicMock()
mock_model.multimodal_cpu_fields = None
mock_model.merge_by_field_config = None
mock_get_model.return_value = MagicMock(model=MagicMock(
embed_tokens=original_embed))



+ 0
- 2
tests/ut/spec_decode/test_mtp_proposer.py View File

@@ -238,7 +238,6 @@ class TestMtpProposer:
proposer.speculative_config = MagicMock(
disable_padded_drafter_batch=False)
proposer.pcp_size = mock_deps.runner.pcp_size
proposer._get_attn_metadata = MagicMock(return_value=MagicMock())
proposer.prepare_next_token_ids_padded = MagicMock(
return_value=(torch.tensor([101, 200, 302]), 3))
proposer.prepare_inputs_padded = MagicMock(
@@ -261,7 +260,6 @@ class TestMtpProposer:

proposer.prepare_next_token_ids_padded.assert_called_once()
proposer.prepare_inputs_padded.assert_called_once()
proposer._get_attn_metadata.assert_called_once()
proposer._propose.assert_called_once()
assert torch.equal(draft_token_ids, proposer._propose.return_value)



+ 0
- 375
tests/ut/worker/test_input_batch.py View File

@@ -1,375 +0,0 @@
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
import inspect
from collections.abc import Sequence
from typing import Optional

import numpy as np
import pytest
import torch
from vllm.sampling_params import SamplingParams
from vllm.utils.torch_utils import make_tensor_with_pad
from vllm.v1.pool.metadata import PoolingMetadata
from vllm.v1.sample.logits_processor import LogitsProcessors
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.utils import CpuGpuBuffer

from vllm_ascend.worker.block_table import BlockTable, MultiGroupBlockTable
from vllm_ascend.worker.npu_input_batch import CachedRequestState, InputBatch

VOCAB_SIZE = 1024
NUM_OUTPUT_TOKENS = 20
MAX_PROMPT_SIZE = 100
MAX_NUM_PROMPT_TOKENS = 64


def _compare_objs(obj1,
                  obj2,
                  skip: Sequence = ("logitsprocs", "batch_update_builder")):
    """Assert that every non-routine, non-dunder attribute of *obj1* equals
    the corresponding attribute of *obj2*.

    Comparison is type-dispatched: tensors/ndarrays via allclose, known
    container types (MultiGroupBlockTable, BlockTable, SamplingMetadata,
    PoolingMetadata) recursively, everything else via ``==``. Attribute
    names listed in *skip* are ignored. Raises AssertionError on the first
    mismatching attribute.
    """
    # getmembers with this predicate keeps data attributes only (no methods).
    attrs = inspect.getmembers(obj1, lambda a: not (inspect.isroutine(a)))
    attr_names = set([
        a[0] for a in attrs
        if not (a[0].startswith('__') and a[0].endswith('__'))
    ])
    for attr_name in attr_names:
        if attr_name in skip:
            continue

        a = getattr(obj1, attr_name)
        b = getattr(obj2, attr_name)

        is_same = False
        # NOTE: branch order matters — e.g. the CpuGpuBuffer branch is only
        # reached when the generic `a == b` check above it is falsy.
        if isinstance(a, torch.Tensor):
            # Two empty tensors compare equal regardless of shape/dtype.
            if (a.numel() == 0 or b.numel() == 0):
                is_same = (a.numel() == 0 and b.numel() == 0)
            elif torch.allclose(a, b):
                is_same = True
        elif isinstance(a, np.ndarray):
            if np.allclose(a, b):
                is_same = True
        elif isinstance(a, MultiGroupBlockTable):
            # Compare each per-group BlockTable pairwise, recursively.
            for a_i, b_i in zip(a.block_tables, b.block_tables):
                _compare_objs(a_i, b_i)
            is_same = True
        elif isinstance(a, (BlockTable, SamplingMetadata, PoolingMetadata)):
            _compare_objs(a, b)
            is_same = True  # if we make it here must be same
        elif a == b:
            is_same = True
        elif isinstance(a, CpuGpuBuffer):
            is_same = np.allclose(a.np, b.np) and torch.allclose(a.gpu, b.gpu)
        assert is_same, f"Attribute {attr_name} is different"\
            f" in {obj1} and {obj2}: {a} != {b}"


def _remove_requests(input_batch: InputBatch, batch_size: int,
                     reqs: list[CachedRequestState]) -> set[str]:
    """Randomly remove a subset of requests from *input_batch*.

    Picks a random number of random indices (duplicates collapse via the
    set), removes the corresponding requests, and returns the set of
    removed request ids.
    """
    removal_count = np.random.randint(0, batch_size)
    victim_indices: set[int] = {
        np.random.randint(0, batch_size)
        for _ in range(removal_count)
    }

    removed_ids: set[str] = set()
    for idx in victim_indices:
        victim_id = reqs[idx].req_id
        input_batch.remove_request(victim_id)
        removed_ids.add(victim_id)
    return removed_ids


def _construct_expected_sampling_metadata(
    reqs: list[CachedRequestState],
    req_ids_retained: set[str],
    req_id_index_in_input_batch: dict[str, int],
    device: torch.device,
) -> SamplingMetadata:
    """
    Constructs and returns the expected SamplingMetadata for this
    batch.

    Only requests whose id is in *req_ids_retained* contribute; each one's
    sampling params are scattered to the slot given by
    *req_id_index_in_input_batch* so the layout matches the InputBatch.
    """
    num_reqs = len(req_ids_retained)
    # Per-slot defaults for requests that were removed/never filled.
    output_token_ids: list[list[int]] = [list() for _ in range(num_reqs)]
    prompt_token_ids: list[list[int]] = [list() for _ in range(num_reqs)]
    presence_penalties = [0.0 for _ in range(num_reqs)]
    frequency_penalties = [0.0 for _ in range(num_reqs)]
    repetition_penalties = [1.0 for _ in range(num_reqs)]
    top_k = [0 for _ in range(num_reqs)]
    top_p = [0.0 for _ in range(num_reqs)]
    temperature = [0.0 for _ in range(num_reqs)]
    min_tokens = {}
    logit_bias = [None] * num_reqs
    allowed_token_ids_mask = torch.zeros(num_reqs,
                                         VOCAB_SIZE,
                                         dtype=torch.bool,
                                         device=device)
    bad_words_token_ids = {}
    for req in reqs:
        if req.req_id not in req_ids_retained:
            continue
        index_in_input_batch = req_id_index_in_input_batch[req.req_id]
        output_token_ids[index_in_input_batch] = req.output_token_ids
        prompt_token_ids[index_in_input_batch] = req.prompt_token_ids
        presence_penalties[
            index_in_input_batch] = req.sampling_params.presence_penalty
        frequency_penalties[index_in_input_batch] = (
            req.sampling_params.frequency_penalty)
        repetition_penalties[index_in_input_batch] = (
            req.sampling_params.repetition_penalty)
        top_k[index_in_input_batch] = req.sampling_params.top_k
        top_p[index_in_input_batch] = req.sampling_params.top_p
        temperature[index_in_input_batch] = req.sampling_params.temperature
        min_tokens[index_in_input_batch] = (
            req.sampling_params.min_tokens,
            req.sampling_params.all_stop_token_ids)
        logit_bias[index_in_input_batch] = req.sampling_params.logit_bias
        if req.sampling_params.allowed_token_ids:
            allowed_token_ids_mask[index_in_input_batch][
                req.sampling_params.allowed_token_ids] = True
        if req.sampling_params.bad_words_token_ids:
            bad_words_token_ids[
                index_in_input_batch] = req.sampling_params.bad_words_token_ids

    return SamplingMetadata(
        temperature=torch.tensor(temperature, dtype=torch.float,
                                 device=device),
        all_greedy=False,
        all_random=True,
        # top_p/top_k collapse to None when every slot holds the neutral value.
        top_p=None if all(x == 1.0 for x in top_p) else torch.tensor(
            top_p, dtype=torch.float, device=device),
        top_k=None if all(x == 0 for x in top_k) else torch.tensor(
            top_k, dtype=torch.int, device=device),
        generators={},
        max_num_logprobs=0,
        # Prompts are right-padded with VOCAB_SIZE (an out-of-vocab id).
        prompt_token_ids=make_tensor_with_pad(
            prompt_token_ids,
            pad=VOCAB_SIZE,
            device=torch.device(device),
            dtype=torch.int64,
        ),
        frequency_penalties=torch.tensor(frequency_penalties,
                                         dtype=torch.float,
                                         device=device),
        presence_penalties=torch.tensor(presence_penalties,
                                        dtype=torch.float,
                                        device=device),
        repetition_penalties=torch.tensor(repetition_penalties,
                                          dtype=torch.float,
                                          device=device),
        output_token_ids=output_token_ids,
        no_penalties=(all(x == 0 for x in presence_penalties)
                      and all(x == 0 for x in frequency_penalties)
                      and all(x == 1 for x in repetition_penalties)),
        allowed_token_ids_mask=allowed_token_ids_mask,
        bad_words_token_ids=bad_words_token_ids,
        logitsprocs=LogitsProcessors(),
    )


def _create_sampling_params():
    """Build a SamplingParams with randomized penalties and sampling knobs."""
    # Draw values in the same order as the keyword list below so the RNG
    # call sequence is unchanged.
    rand_top_k = np.random.randint(1, 10)
    rand_top_p = np.random.uniform(0.0, 1.0)
    rand_presence = np.random.uniform(-2.0, 2.0)
    rand_repetition = np.random.uniform(0.0, 2.0)
    rand_frequency = np.random.uniform(-2.0, 2.0)
    rand_min_tokens = np.random.randint(1, 10)
    rand_stop_ids = [
        np.random.randint(0, VOCAB_SIZE)
        for _ in range(np.random.randint(10))
    ]
    rand_logit_bias = {0: np.random.uniform(-3.0, 3.0)}
    return SamplingParams(
        top_k=rand_top_k,
        top_p=rand_top_p,
        presence_penalty=rand_presence,
        repetition_penalty=rand_repetition,
        frequency_penalty=rand_frequency,
        min_tokens=rand_min_tokens,
        stop_token_ids=rand_stop_ids,
        logit_bias=rand_logit_bias,
    )


def _construct_cached_request_state(req_id_suffix: int):
    """Create a CachedRequestState with random prompt and output tokens."""
    prompt = [
        np.random.randint(0, VOCAB_SIZE)
        for _ in range(np.random.randint(0, MAX_PROMPT_SIZE))
    ]
    generated = [
        np.random.randint(0, VOCAB_SIZE)
        for _ in range(np.random.randint(0, NUM_OUTPUT_TOKENS))
    ]
    return CachedRequestState(
        req_id=f"req_id_{req_id_suffix}",
        prompt_token_ids=prompt,
        sampling_params=_create_sampling_params(),
        pooling_params=None,
        mm_kwargs=[],
        mm_positions=[],
        block_ids=([], ),
        generator=None,
        num_computed_tokens=len(generated),
        output_token_ids=generated,
        mm_hashes=None,
    )


@pytest.mark.parametrize("device", ["cpu"])
@pytest.mark.parametrize("batch_size", [1, 2, 32, 64])
def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
    """
    Tests the logic for managing sampling metadata in the InputBatch.

    This test involves adding a set of requests to the InputBatch,
    followed by removing a subset of them. Afterward, the batch is compacted,
    and the `make_sampling_metadata` method is invoked on the batch. The
    output of `make_sampling_metadata` is then compared against the expected
    results to ensure correctness.

    Note: Ignore logits processor logic, which is tested separately
    """
    input_batch: InputBatch = InputBatch(
        max_num_reqs=batch_size,
        max_model_len=1024,
        max_num_batched_tokens=1024,
        device=torch.device(device),
        pin_memory=False,
        vocab_size=1024,
        block_sizes=[1],
    )
    reqs: list[CachedRequestState] = []
    req_id_reqs = {}
    req_id_output_token_ids = {}

    # Add requests; slots must be assigned in insertion order.
    for req_index in range(batch_size):
        req: CachedRequestState = _construct_cached_request_state(req_index)
        assigned_req_index = input_batch.add_request(req)
        assert req_index == assigned_req_index
        reqs.append(req)
        req_id_reqs[req.req_id] = req
        req_id_output_token_ids[req.req_id] = req.output_token_ids

    # Remove some requests
    req_ids_to_remove = _remove_requests(input_batch, batch_size, reqs)
    req_ids_retained = set(req_id_reqs.keys()) - req_ids_to_remove

    # Compact the input batch
    input_batch.condense()

    # Generate the sampling metadata
    sampling_metadata = input_batch._make_sampling_metadata()

    # Create expected output.
    expected_sampling_metadata = _construct_expected_sampling_metadata(
        reqs,
        req_ids_retained,
        input_batch.req_id_to_index,
        device=torch.device(device))

    def same(t1: Optional[torch.Tensor], t2: Optional[torch.Tensor]) -> bool:
        # Both None, or both tensors that compare allclose.
        return (t1 is None
                and t2 is None) or (t1 is not None and t2 is not None
                                    and torch.allclose(t1, t2))

    # Assert the actual and expected output.
    assert torch.allclose(expected_sampling_metadata.temperature,
                          sampling_metadata.temperature)
    assert same(expected_sampling_metadata.top_p, sampling_metadata.top_p)
    assert same(expected_sampling_metadata.top_k, sampling_metadata.top_k)
    assert torch.allclose(
        expected_sampling_metadata.frequency_penalties,
        sampling_metadata.frequency_penalties,
    )
    assert torch.allclose(
        expected_sampling_metadata.presence_penalties,
        sampling_metadata.presence_penalties,
    )
    assert torch.allclose(
        expected_sampling_metadata.repetition_penalties,
        sampling_metadata.repetition_penalties,
    )
    assert torch.allclose(expected_sampling_metadata.prompt_token_ids,
                          sampling_metadata.prompt_token_ids)
    assert (expected_sampling_metadata.output_token_ids ==
            sampling_metadata.output_token_ids)
    assert expected_sampling_metadata.no_penalties == \
        sampling_metadata.no_penalties
    # NOTE(review): this mask check is skipped when the mask is falsy —
    # presumably None/empty when no request restricts allowed ids; confirm.
    if sampling_metadata.allowed_token_ids_mask:
        assert torch.allclose(
            expected_sampling_metadata.allowed_token_ids_mask,
            sampling_metadata.allowed_token_ids_mask)
    assert expected_sampling_metadata.bad_words_token_ids == \
        sampling_metadata.bad_words_token_ids


@pytest.mark.parametrize("device", ["cpu"])
@pytest.mark.parametrize("batch_size", [32])
@pytest.mark.parametrize("swap_list", [((0, 1), )])
def test_swap_states_in_input_batch(device: str, batch_size: int,
                                    swap_list: list):
    """
    Tests InputBatch.swap_states.

    Fills one InputBatch with requests and applies `swap_states` for each
    pair in *swap_list*. A reference InputBatch is then built by adding the
    same requests already in the swapped order. After refreshing metadata,
    the two batches must be attribute-for-attribute identical.
    """
    input_batch: InputBatch = InputBatch(
        max_num_reqs=batch_size,
        max_model_len=1024,
        max_num_batched_tokens=1024,
        device=torch.device(device),
        pin_memory=False,
        vocab_size=1024,
        block_sizes=[1],
    )
    ref_input_batch: InputBatch = InputBatch(
        max_num_reqs=batch_size,
        max_model_len=1024,
        max_num_batched_tokens=1024,
        device=torch.device(device),
        pin_memory=False,
        vocab_size=1024,
        block_sizes=[1],
    )

    reqs: list[CachedRequestState] = []
    req_id_reqs = {}
    req_id_output_token_ids = {}
    # Add requests
    for req_index in range(batch_size):
        req: CachedRequestState = _construct_cached_request_state(req_index)
        assigned_req_index = input_batch.add_request(req)
        assert assigned_req_index == req_index
        reqs.append(req)
        req_id_reqs[req.req_id] = req
        req_id_output_token_ids[req.req_id] = req.output_token_ids

    # Apply the swaps to the batch under test and mirror them in a plain list.
    reordered_reqs = reqs.copy()
    for swap_pair in swap_list:
        reordered_reqs[swap_pair[0]], reordered_reqs[swap_pair[1]] = \
            reordered_reqs[swap_pair[1]], reordered_reqs[swap_pair[0]]
        input_batch.swap_states(swap_pair[0], swap_pair[1])

    # Build the reference batch directly in the post-swap order.
    for req_index in range(batch_size):
        req = reordered_reqs[req_index]
        assigned_req_index = ref_input_batch.add_request(req)
        assert assigned_req_index == req_index

    input_batch.refresh_metadata()
    ref_input_batch.refresh_metadata()

    _compare_objs(input_batch, ref_input_batch)

+ 7
- 0
vllm_ascend/ascend_config.py View File

@@ -153,6 +153,13 @@ class AscendConfig:
raise NotImplementedError(
"This feature is still in the experiment and will be supported soon."
)
        # We find that _npu_paged_attention still performs better than
        # npu_fused_infer_attention_score in some cases. We allow executing
        # _npu_paged_attention in these cases. This should be removed once
        # npu_fused_infer_attention_score performs better in all scenarios.
self.pa_shape_list = additional_config.get("pa_shape_list",
[1, 2, 3, 4])

kv_cfg = vllm_config.kv_transfer_config
if kv_cfg is not None and not getattr(kv_cfg, "_engine_id_patched",
False):


+ 71
- 4
vllm_ascend/attention/attention_v1.py View File

@@ -34,7 +34,8 @@ from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import AttentionSpec

from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
split_decodes_and_prefills)
split_decodes_and_prefills,
using_paged_attention)
from vllm_ascend.compilation.acl_graph import (get_graph_params,
update_graph_params_workspaces)
from vllm_ascend.utils import (AscendDeviceType, get_ascend_device_type,
@@ -488,6 +489,67 @@ class AscendAttentionBackendImpl(AttentionImpl):
graph_params.handles[num_tokens].append(handle)
return output, num_tokens

    def full_graph_attention_with_pa(
        self,
        query: torch.Tensor,
        attn_metadata: AscendMetadata,
        output: Optional[torch.Tensor] = None,
    ):
        """Run `_npu_paged_attention` under ACL full-graph capture.

        While the forward context is capturing, this caches/reuses the paged
        attention workspace keyed by token count, records an external event
        plus the kernel arguments into the global graph params, and launches
        the kernel inside a graph task group so it can be updated at replay
        time. Returns `output`, which the kernel writes in place.

        NOTE(review): when `forward_context.capturing` is False this function
        returns `output` untouched — presumably the captured graph replays
        the kernel outside this call; confirm against the replay path.
        """
        graph_params = get_graph_params()
        forward_context: ForwardContext = get_forward_context()
        num_tokens = query.shape[0]
        if forward_context.capturing:
            # Get workspace from cache or calculate it if not present.
            workspace = graph_params.workspaces.get(num_tokens)
            if workspace is None:
                workspace = torch_npu._npu_paged_attention_get_workspace(
                    query=query,
                    key_cache=self.key_cache,
                    value_cache=self.value_cache,
                    num_kv_heads=self.num_kv_heads,
                    num_heads=self.num_heads,
                    scale_value=self.scale,
                    block_table=attn_metadata.block_tables,
                    context_lens=attn_metadata.seq_lens,
                    out=output)
                update_graph_params_workspaces(num_tokens,
                                               weak_ref_tensors(workspace))

            # Handle graph capturing mode
            stream = torch_npu.npu.current_stream()

            # Event + saved params let the replay machinery re-issue the
            # kernel with fresh tensors; weak refs avoid extending lifetimes.
            event = torch.npu.ExternalEvent()
            event.wait(stream)
            event.reset(stream)
            graph_params.events[num_tokens].append(event)
            graph_params.attn_params[num_tokens].append((
                weak_ref_tensors(query),
                weak_ref_tensors(self.key_cache),
                weak_ref_tensors(self.value_cache),
                self.num_kv_heads,
                self.num_heads,
                self.scale,
                attn_metadata.block_tables,
                attn_metadata.seq_lens,
                weak_ref_tensors(output),
            ))

            # Launch inside a task group so the handle can be re-targeted
            # when the graph is updated/replayed.
            torch.npu.graph_task_group_begin(stream)
            torch_npu._npu_paged_attention(
                query=query,
                key_cache=self.key_cache,
                value_cache=self.value_cache,
                num_kv_heads=self.num_kv_heads,
                num_heads=self.num_heads,
                scale_value=self.scale,
                block_table=attn_metadata.block_tables,
                context_lens=attn_metadata.seq_lens,
                out=output,
                workspace=workspace)
            handle = torch.npu.graph_task_group_end(stream)
            graph_params.handles[num_tokens].append(handle)
        return output

def _forward_prefill(self, query: torch.Tensor, key: torch.Tensor,
value: torch.Tensor, attn_metadata: AscendMetadata,
output: torch.Tensor):
@@ -701,9 +763,14 @@ class AscendAttentionBackendImpl(AttentionImpl):
output = self._forward_prefill(query, key, value,
attn_metadata, output)
else:
attn_output, num_tokens = self.full_graph_attention(
query, key, value, attn_metadata, output)
output[:num_tokens] = attn_output[:num_tokens]
num_tokens = query.shape[0]
if using_paged_attention(num_tokens):
output = self.full_graph_attention_with_pa(
query, attn_metadata, output)
else:
attn_output, num_tokens = self.full_graph_attention(
query, key, value, attn_metadata, output)
output[:num_tokens] = attn_output[:num_tokens]

return output



+ 1274
- 0
vllm_ascend/attention/mla_cp.py View File

@@ -0,0 +1,1274 @@
from typing import ClassVar, List, Optional, Tuple, TypeVar

import numpy as np
import torch
import torch.distributed as dist
import torch_npu
from torch import nn
from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.config import VllmConfig
from vllm.distributed import (get_dcp_group,
get_decode_context_model_parallel_rank,
get_decode_context_model_parallel_world_size,
get_pcp_group)
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.utils.math_utils import cdiv, round_down
from vllm.v1.attention.backends.utils import AttentionCGSupport
from vllm.v1.kv_cache_interface import MLAAttentionSpec

from vllm_ascend.ascend_forward_context import get_cos_and_sin
from vllm_ascend.attention.mla_v1 import (AscendMLADecodeMetadata,
AscendMLAImpl, AscendMLAMetadata,
AscendMLAMetadataBuilder,
AscendMLAPrefillMetadata,
DecodeMLAPreprocessResult,
PrefillMLAPreprocessResult)
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
maybe_save_kv_layer_to_connector,
split_decodes_and_prefills,
wait_for_kv_layer_from_connector)
from vllm_ascend.compilation.acl_graph import (get_graph_params,
update_graph_params_workspaces)
from vllm_ascend.ops.shared_weight_layer import (
is_hidden_layer, reach_layer_for_shared_weight_series)
from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
from vllm_ascend.utils import weak_ref_tensors

MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024

M = TypeVar("M", bound=AscendMLAMetadata)


class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder):
# Does this backend/builder support ACL Graphs for attention (default: no).
aclgraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.UNIFORM_BATCH
"""
NOTE: Please read the comment at the top of the file before trying to
understand this class
"""

    def __init__(self,
                 kv_cache_spec: MLAAttentionSpec,
                 layer_names: list[str],
                 vllm_config: VllmConfig,
                 device: torch.device,
                 metadata_cls: Optional[AscendMLAMetadata] = None):
        """Initialize the CP-aware MLA metadata builder.

        Captures PCP/DCP group size and rank, derives the per-rank ("local")
        and cross-rank ("virtual") KV-cache interleave block sizes, and
        pre-allocates a reusable mask buffer sized for the worst-case
        decode batch.
        """
        super().__init__(kv_cache_spec, layer_names, vllm_config, device,
                         metadata_cls)

        # Rank defaults to 0 when the corresponding group is not active.
        self.pcp_size = get_pcp_group().world_size
        self.pcp_rank = get_pcp_group(
        ).rank_in_group if self.pcp_size > 1 else 0
        self.dcp_size = get_decode_context_model_parallel_world_size()
        self.dcp_rank = get_decode_context_model_parallel_rank(
        ) if self.dcp_size > 1 else 0
        # A "virtual" block spans one local block on every pcp*dcp rank.
        self.cp_local_block_size = vllm_config.parallel_config.cp_kv_cache_interleave_size
        self.cp_virtual_block_size = self.cp_local_block_size * self.dcp_size * self.pcp_size
        scheduler_config = vllm_config.scheduler_config
        decode_max_num_seqs = getattr(scheduler_config, 'decode_max_num_seqs',
                                      0)
        max_num_seqs = max(scheduler_config.max_num_seqs, decode_max_num_seqs)
        # Reused by build() each step to avoid per-step allocation; sized for
        # max sequences times tokens-per-request (decode_threshold).
        self.batch_seq_mask_buf = torch.empty(max_num_seqs *
                                              self.decode_threshold,
                                              dtype=torch.uint8,
                                              device=device)

    def build(
        self,
        common_prefix_len: int,
        common_attn_metadata: AscendCommonAttentionMetadata,
        model: nn.Module,
    ) -> AscendMLAMetadata:
        """Assemble the per-step AscendMLAMetadata for the context-parallel path.

        Splits the batch into decode and prefill portions, derives
        chunked-prefill context metadata padded per PCP/DCP rank, applies
        fullgraph padding where required, and builds the prefill/decode
        sub-metadata consumed by the MLA attention implementation.
        """
        num_reqs = common_attn_metadata.num_reqs
        num_actual_tokens = common_attn_metadata.num_actual_tokens
        query_start_loc = common_attn_metadata.query_start_loc
        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
        long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata

        if long_seq_metadata is None:
            raise AssertionError("long_seq_metadata should not be None.")

        num_actual_tokens_pcp_padded = long_seq_metadata.num_actual_tokens_pcp_padded
        if num_actual_tokens_pcp_padded is None:
            num_actual_tokens_pcp_padded = num_actual_tokens
        num_computed_tokens_of_pcp_dcp = long_seq_metadata.num_computed_tokens_of_pcp_dcp
        assert num_computed_tokens_of_pcp_dcp is not None

        num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
            split_decodes_and_prefills(common_attn_metadata, decode_threshold=self.decode_threshold)
        assert num_decodes + num_prefills == num_reqs
        assert num_decode_tokens + num_prefill_tokens == num_actual_tokens

        # Note(simon): be careful about the CPU <> GPU memory movement in this
        # function. We should avoid GPU -> CPU sync as much as possible because
        # it blocks on all previous kernels.
        device = self.device

        # A graph_pad_size > -1 means we are running in fullgraph mode.
        graph_pad_size = common_attn_metadata.graph_pad_size
        # NOTE: Maybe this block_table change can be removed when graph_pad_size > 1.
        if graph_pad_size > num_reqs and self.speculative_config.disable_padded_drafter_batch:
            block_table = (
                common_attn_metadata.block_table_tensor[:graph_pad_size])
        else:
            block_table = (common_attn_metadata.block_table_tensor[:num_reqs])
        # NOTE: MTP-fullgraph is currently incompatible with pcp.
        if self.pcp_size > 1:
            # With pcp, decode entries are flattened per speculative token.
            num_decodes_flatten = num_decodes * self.decode_threshold
            block_table = common_attn_metadata.block_table_tensor[:
                                                                  num_decodes_flatten
                                                                  +
                                                                  num_prefills]

        # NOTE: MTP-fullgraph is currently incompatible with pcp.
        slot_mapping = common_attn_metadata.slot_mapping[:
                                                         num_actual_tokens_pcp_padded]
        input_positions = common_attn_metadata.positions[:
                                                         num_actual_tokens_pcp_padded].long(
                                                         )

        # Lazily grab (and dtype-align) the rotary embedding caches from the
        # first model layer.
        if self.cos_cache is None:
            self.cos_cache = model.model.layers[
                model.model.start_layer].self_attn.rotary_emb.cos_cached
            self.sin_cache = model.model.layers[
                model.model.start_layer].self_attn.rotary_emb.sin_cached
        if self.cos_cache.dtype != self.model_config.dtype:  # type: ignore
            self.cos_cache = self.cos_cache.to(  # type: ignore
                self.model_config.dtype)  # type: ignore
            self.sin_cache = self.sin_cache.to(  # type: ignore
                self.model_config.dtype)  # type: ignore

        query_seq_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
        query_lens = query_seq_lens_cpu[:num_reqs]
        seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs]
        num_computed_tokens_cpu = (seq_lens - query_lens)

        prefill_metadata = None
        chunked_context_metadata = None
        if num_prefills > 0:
            pcp_metadata = AscendMLAPrefillMetadata.AscendPCPMetadata(
                q_head_idx=long_seq_metadata.q_head_idx_tensor,
                q_tail_idx=long_seq_metadata.q_tail_idx_tensor,
                kv_with_q_head_nomask_idx=long_seq_metadata.
                kv_with_q_head_nomask_idx_tensor,
                kv_with_q_head_mask_idx=long_seq_metadata.
                kv_with_q_head_mask_idx_tensor,
                kv_with_q_tail_nomask_idx=long_seq_metadata.
                kv_with_q_tail_nomask_idx_tensor,
                kv_with_q_tail_mask_idx=long_seq_metadata.
                kv_with_q_tail_mask_idx_tensor,
                attn_mask_seqlens=long_seq_metadata.attn_mask_seqlens,
                head_attn_nomask_seqlens=long_seq_metadata.
                head_attn_nomask_seqlens,
                tail_attn_nomask_seqlens=long_seq_metadata.
                tail_attn_nomask_seqlens,
                q_full_idx=long_seq_metadata.q_full_idx,
                pcp_prefill_mask=long_seq_metadata.pcp_prefill_mask,
                pcp_allgather_restore_idx=long_seq_metadata.
                pcp_allgather_restore_idx)

            reqs_start = num_decodes  # prefill_start
            tokens_start = num_decode_tokens
            max_query_len = query_lens[reqs_start:].max().item()
            max_seq_lens = seq_lens[reqs_start:].max().item()
            prefill_query_start_loc = query_start_loc[
                reqs_start:] - query_start_loc[reqs_start]

            context_lens_cpu = num_computed_tokens_cpu[reqs_start:num_reqs]
            max_context_len_cpu = context_lens_cpu.max().item()
            num_prefills_with_context_cpu = (context_lens_cpu > 0).sum().item()
            if self.chunked_prefill_enabled and max_context_len_cpu > 0:
                # Divide the workspace across prefills that carry context and
                # round down to a whole number of KV blocks.
                max_context_chunk = (self.chunked_prefill_workspace_size //
                                     num_prefills_with_context_cpu)
                max_context_chunk = round_down(max_context_chunk,
                                               self.block_size)

                assert max_context_chunk > 0
                num_chunks = cdiv(max_context_len_cpu, max_context_chunk)
                # Global (unsplit) chunk boundaries: [num_chunks, num_prefills]
                chunk_starts = torch.arange(num_chunks, dtype=torch.int32) \
                    .unsqueeze(1).expand(-1, num_prefills) * max_context_chunk
                chunk_ends = torch.min(context_lens_cpu.unsqueeze(0),
                                       chunk_starts + max_context_chunk)
                chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0)
                cu_seq_lens_cpu = torch.zeros(num_chunks,
                                              num_prefills + 1,
                                              dtype=torch.int32,
                                              pin_memory=True)
                torch.cumsum(chunk_seq_lens,
                             dim=1,
                             out=cu_seq_lens_cpu[:, 1:],
                             dtype=torch.int32)

                local_context_lens_allranks = torch.tensor(
                    num_computed_tokens_of_pcp_dcp[reqs_start:num_reqs]
                ).reshape(-1, self.dcp_size * self.pcp_size)
                # Note(qcs): The max local context lengths
                # padded to `cp_local_block_size`.
                padded_local_context_lens_cpu = (cdiv(
                    context_lens_cpu,
                    self.cp_virtual_block_size,
                ) * self.cp_local_block_size)
                padded_local_max_context_chunk_across_ranks = (cdiv(
                    max_context_chunk,
                    self.cp_virtual_block_size,
                ) * self.cp_local_block_size)
                # Per-rank ("local") chunk boundaries mirroring the global ones.
                local_chunk_starts = (
                    torch.arange(num_chunks,
                                 dtype=torch.int32).unsqueeze(1).expand(
                                     -1, num_prefills) *
                    padded_local_max_context_chunk_across_ranks)
                local_chunk_ends = torch.min(
                    padded_local_context_lens_cpu.unsqueeze(0),
                    local_chunk_starts +
                    padded_local_max_context_chunk_across_ranks,
                )
                padded_local_chunk_seq_lens = (local_chunk_ends -
                                               local_chunk_starts).clamp(min=0)
                padded_local_cu_chunk_seq_lens_cpu = torch.zeros(
                    num_chunks,
                    num_prefills + 1,
                    dtype=torch.int32,
                    pin_memory=True)
                torch.cumsum(
                    padded_local_chunk_seq_lens,
                    dim=1,
                    out=padded_local_cu_chunk_seq_lens_cpu[:, 1:],
                    dtype=torch.int32,
                )
                chunked_context_metadata = AscendMLAPrefillMetadata.ChunkedContextMetadata(
                    cu_seq_lens=cu_seq_lens_cpu.pin_memory().to(
                        device, non_blocking=True),
                    starts=local_chunk_starts.pin_memory().to(
                        device, non_blocking=True),
                    seq_tot=padded_local_chunk_seq_lens.sum(dim=1).tolist(),
                    max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
                    chunk_seq_lens=chunk_seq_lens,
                    chunk_seq_lens_npu=chunk_seq_lens.npu(),
                    workspace=self.chunked_prefill_workspace,
                    padded_chunk_seq_lens_npu=padded_local_chunk_seq_lens.npu(
                    ),
                    padded_local_chunk_seq_lens=padded_local_chunk_seq_lens.
                    tolist(),
                    local_context_lens_allranks=local_context_lens_allranks.
                    tolist(),
                    padded_local_cu_seq_lens=padded_local_cu_chunk_seq_lens_cpu
                    .pin_memory().to(device, non_blocking=True),
                    cu_seq_lens_lst=cu_seq_lens_cpu.tolist(),
                    chunk_size=padded_local_max_context_chunk_across_ranks,
                )

            prefill_input_positions = input_positions[tokens_start:]
            assert self.cos_cache is not None
            assert self.sin_cache is not None
            cos = self.cos_cache[prefill_input_positions].unsqueeze(
                1).unsqueeze(2)
            sin = self.sin_cache[prefill_input_positions].unsqueeze(
                1).unsqueeze(2)
            prefill_metadata = AscendMLAPrefillMetadata(
                attn_mask=common_attn_metadata.attn_mask,
                query_lens=query_lens[reqs_start:].to(torch.int32),
                seq_lens=seq_lens,
                context_lens=seq_lens[reqs_start:],
                input_positions=prefill_input_positions,
                block_table=block_table[reqs_start:, ...],
                max_query_len=max_query_len,
                max_seq_lens=max_seq_lens,
                query_start_loc=prefill_query_start_loc,
                chunked_context=chunked_context_metadata,
                sin=sin,
                cos=cos,
                pcp_metadata=pcp_metadata,
            )
            if self.pcp_size > 1:
                # With pcp the decode portion is flattened, so prefill rows
                # start at num_decodes_flatten rather than num_decodes.
                prefill_metadata.block_table = block_table[
                    num_decodes_flatten:, ...]

        decode_metadata = None
        if num_decodes > 0:
            cos, sin = get_cos_and_sin()
            # Notice that num_decodes != num_decode_tokens in SpecDecoding Scenario
            actual_seq_lengths_q = query_start_loc_cpu[1:num_decodes +
                                                       1].tolist()
            max_seq_lens = seq_lens[:num_decodes].max().item()
            seq_lens = seq_lens[:num_decodes]
            input_positions = input_positions[:num_decode_tokens]
            if self.pcp_size > 1:
                # For pcp + spec decode, we flatten seq_lens and block_table
                # to avoid irregular spec_attn_mask shape
                block_table = block_table[:num_decodes_flatten, ...]
            else:
                block_table = block_table[:num_decodes, ...]
            # NOTE: MTP-fullgraph is currently incompatible with pcp.
            # NOTE: Maybe this block_table change can be removed when graph_pad_size > 1.
            if graph_pad_size > num_decodes and \
                self.speculative_config.disable_padded_drafter_batch:
                block_table = block_table[:graph_pad_size, ...]
            seq_lens_list = seq_lens.tolist()

            # [bs, pcp_size, dcp_size]
            num_computed_tokens_of_cp_dcp_array = np.array(
                num_computed_tokens_of_pcp_dcp)[:num_decodes *
                                                self.decode_threshold]

            # Context length held by this (pcp, dcp) rank for every request.
            cp_seq_len = num_computed_tokens_of_cp_dcp_array[:, self.pcp_rank,
                                                             self.dcp_rank]
            cp_seq_len = torch.tensor(cp_seq_len, dtype=torch.int32)
            # Mask marks requests with no local context; staged through the
            # pre-allocated buffer to avoid a fresh device allocation.
            batch_seq_mask = (cp_seq_len == 0)
            self.batch_seq_mask_buf[:batch_seq_mask.shape[0]].copy_(
                batch_seq_mask, non_blocking=True)
            batch_seq_mask = self.batch_seq_mask_buf[:batch_seq_mask.shape[0]]
            # Avoid zero-length sequences downstream; masked entries use 1.
            cp_seq_len = torch.where(cp_seq_len == 0, 1, cp_seq_len)

            if graph_pad_size > num_reqs:
                if self.speculative_config.disable_padded_drafter_batch:
                    num_reqs_pad_size = graph_pad_size - num_reqs
                    actual_seq_lengths_q = self.pad_actual_seq_len_q_mtp_disable_pad(
                        num_reqs_pad_size, num_reqs, actual_seq_lengths_q)
                    seq_lens_list = seq_lens_list + [0] * (graph_pad_size -
                                                           num_decodes)
                    num_block_pad_size = graph_pad_size - block_table.shape[0]
                    if num_block_pad_size > 0:
                        block_table_padding = torch.zeros(
                            (num_block_pad_size, ) + block_table.shape[1:],
                            dtype=block_table.dtype,
                            device=block_table.device)
                        block_table = torch.cat(
                            [block_table, block_table_padding], dim=0)
                else:
                    # Padded drafter batch: pad tokens, requests, block table,
                    # slot mapping and positions up to the graph size.
                    num_token_pad_size = graph_pad_size - num_decode_tokens
                    num_reqs_pad_size = (
                        graph_pad_size //
                        common_attn_metadata.decode_token_per_req - num_reqs)
                    num_block_table_pad_size = (
                        graph_pad_size //
                        common_attn_metadata.decode_token_per_req -
                        num_decodes)
                    seq_lens_list = seq_lens.tolist() + [0] * num_reqs_pad_size
                    slot_padding = torch.full((num_token_pad_size, ),
                                              PAD_SLOT_ID,
                                              dtype=slot_mapping.dtype,
                                              device=slot_mapping.device)
                    slot_mapping = torch.cat([slot_mapping, slot_padding])
                    block_table_padding = torch.zeros(
                        (num_block_table_pad_size, ) + block_table.shape[1:],
                        dtype=block_table.dtype,
                        device=block_table.device)
                    block_table = torch.cat([block_table, block_table_padding],
                                            dim=0)
                    position_padding = torch.zeros(
                        num_token_pad_size,
                        dtype=input_positions.dtype,
                        device=input_positions.device)
                    input_positions = torch.cat(
                        [input_positions, position_padding])
                    actual_seq_lengths_q = self.pad_actual_seq_len_q_mtp_enable_pad(
                        num_reqs_pad_size, num_reqs, actual_seq_lengths_q,
                        common_attn_metadata)

            # TODO: After the fullgraph supports MTP, the if branch needs to be deleted
            assert self.cos_cache is not None
            assert self.sin_cache is not None
            if cos is None and sin is None:
                cos = self.cos_cache[
                    input_positions].unsqueeze(  # type: ignore
                        1).unsqueeze(2)
                sin = self.sin_cache[
                    input_positions].unsqueeze(  # type: ignore
                        1).unsqueeze(2)

                decode_metadata = AscendMLADecodeMetadata(
                    input_positions=input_positions,
                    block_table=block_table,
                    seq_lens=seq_lens,
                    seq_lens_list=seq_lens_list,
                    max_seq_lens=max_seq_lens,
                    attn_mask=common_attn_metadata.spec_attn_mask,
                    actual_seq_lengths_q=actual_seq_lengths_q,
                    sin=sin,
                    cos=cos,
                    cp_seq_len=cp_seq_len,
                    batch_seq_mask=batch_seq_mask)
            else:
                # Reuse the pre-allocated cos/sin buffers in place.
                cos[:num_decode_tokens,
                    ...] = self.cos_cache[input_positions].unsqueeze(
                        1).unsqueeze(2)
                sin[:num_decode_tokens,
                    ...] = self.sin_cache[input_positions].unsqueeze(
                        1).unsqueeze(2)

                decode_metadata = AscendMLADecodeMetadata(
                    input_positions=input_positions,
                    block_table=block_table,
                    seq_lens=seq_lens,
                    seq_lens_list=seq_lens_list,
                    max_seq_lens=max_seq_lens,
                    attn_mask=common_attn_metadata.spec_attn_mask,
                    actual_seq_lengths_q=actual_seq_lengths_q,
                    sin=sin[:num_decode_tokens, ...],
                    cos=cos[:num_decode_tokens, ...],
                    cp_seq_len=cp_seq_len,
                    batch_seq_mask=batch_seq_mask)

        return self.metadata_cls(  # type: ignore
            num_actual_tokens_pcp_padded=num_actual_tokens_pcp_padded,
            num_input_tokens=common_attn_metadata.num_input_tokens,
            num_actual_tokens=num_actual_tokens,
            query_lens=query_lens.tolist(),
            slot_mapping=slot_mapping,
            head_dim=self.model_config.get_head_size(),
            num_decodes=num_decodes,
            num_decode_tokens=num_decode_tokens,
            num_prefills=num_prefills,
            attn_mask=common_attn_metadata.attn_mask,
            attn_state=common_attn_metadata.attn_state,
            prefill=prefill_metadata,
            decode=decode_metadata,
            query_start_loc=query_start_loc,
            block_tables=block_table,
            seq_lens=seq_lens,
        )


class AscendMlaCPImpl(AscendMLAImpl):
"""
NOTE: Please read the comment at the top of the file before trying to
understand this class
"""

def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: Optional[list[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
logits_soft_cap: Optional[float],
attn_type: str,
kv_sharing_target_layer_name: Optional[str],
**kwargs,
):
super().__init__(num_heads, head_size, scale, num_kv_heads,
alibi_slopes, sliding_window, kv_cache_dtype,
logits_soft_cap, attn_type,
kv_sharing_target_layer_name, **kwargs)

self.pcp_size = get_pcp_group().world_size
self.pcp_rank = get_pcp_group(
).rank_in_group if self.pcp_size > 1 else 0
self.pcp_group = get_pcp_group(
).device_group if self.pcp_size > 1 else None

self.dcp_size = get_decode_context_model_parallel_world_size()
self.dcp_rank = get_decode_context_model_parallel_rank(
) if self.dcp_size > 1 else 0
self.dcp_group = get_dcp_group(
).device_group if self.dcp_size > 1 else None

def _v_up_proj(self, x):
# Convert from (B, N, L) to (N, B, L)
x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
# # Multiply (N, B, L) x (N, L, V) -> (N, B, V)
x = torch.bmm(x, self.W_UV)
# # Convert from (N, B, V) to (B, N * V)
x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
return x

    def _compute_prefill_context(
        self,
        q_nope: torch.Tensor,
        q_pe: torch.Tensor,
        kv_c_and_k_pe_cache: Tuple[torch.Tensor],
        rope_dim: int,
        attn_metadata: AscendMLAMetadata,
        prefix_output: torch.Tensor,
        prefix_lse: torch.Tensor,
    ):
        """Fold chunked historical (context) KV into the running attention.

        For each context chunk: load the cached latent KV pages, all-gather
        them across DCP and PCP ranks, strip the per-rank padding back to the
        TP token layout, up-project to per-head K/V, and accumulate into
        ``prefix_output`` / ``prefix_lse`` with the ring-MLA kernel.

        Returns:
            The (possibly updated in place) ``(prefix_output, prefix_lse)``.
        """
        assert len(kv_c_and_k_pe_cache) > 1
        prefill_metadata = attn_metadata.prefill
        # Nothing cached to merge: hand back the inputs untouched.
        if prefill_metadata is None or prefill_metadata.chunked_context is None:
            return prefix_output, prefix_lse

        iters = len(prefill_metadata.chunked_context.seq_tot)

        current_seq_len = torch.tensor(prefill_metadata.query_lens,
                                       dtype=torch.int32)
        cache_kv_c = kv_c_and_k_pe_cache[0]
        cache_k_pe = kv_c_and_k_pe_cache[1]
        num_heads = cache_k_pe.size(2)
        latent_kv_dim = kv_c_and_k_pe_cache[0].size(-1)
        for i in range(iters):
            toks = prefill_metadata.chunked_context.seq_tot[i]
            # chunk_seq_lens will be padded when pcp&dcp
            context_seq_len = prefill_metadata.chunked_context.chunk_seq_lens[
                i]
            context_seq_len_npu = prefill_metadata.chunked_context.padded_chunk_seq_lens_npu[
                i]
            # Row 0: query lengths, row 1: this chunk's context lengths.
            seq_len = torch.stack([current_seq_len, context_seq_len])
            kv_c_normed = torch.empty(toks,
                                      num_heads,
                                      latent_kv_dim,
                                      dtype=q_nope.dtype,
                                      device=q_nope.device)
            k_pe = torch.empty(toks,
                               num_heads,
                               rope_dim,
                               dtype=q_nope.dtype,
                               device=q_nope.device)

            # Gather this chunk's latent KV out of the paged caches.
            torch_npu.atb.npu_paged_cache_load(
                cache_kv_c,
                cache_k_pe,
                prefill_metadata.block_table,
                context_seq_len_npu,
                seq_starts=prefill_metadata.chunked_context.starts[i],
                key=kv_c_normed,
                value=k_pe,
            )

            # Collect the full chunk across DCP first, then PCP ranks.
            cache_kv_c_k_pe = torch.cat([kv_c_normed, k_pe], dim=-1)
            if self.dcp_size > 1:
                cache_kv_c_k_pe = get_dcp_group().all_gather(
                    cache_kv_c_k_pe, 0)

            if self.pcp_size > 1:
                cache_kv_c_k_pe = get_pcp_group().all_gather(
                    cache_kv_c_k_pe, 0)

            allgatered_kv_c_normed, allgatered_k_pe = cache_kv_c_k_pe.split(
                [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
            # Remove per-rank padding and restore contiguous TP ordering.
            kv_c_normed, k_pe = self._reorg_kvcache(
                allgatered_kv_c_normed,
                allgatered_k_pe,
                padded_local_chunk_seq_lens_lst=prefill_metadata.
                chunked_context.padded_local_chunk_seq_lens[i],
                local_context_lens_allranks=prefill_metadata.chunked_context.
                local_context_lens_allranks,
                sum_seq_len=prefill_metadata.chunked_context.cu_seq_lens_lst[i]
                [-1],
                max_seq_len=prefill_metadata.chunked_context.max_seq_lens[i],
                chunk_size=prefill_metadata.chunked_context.chunk_size,
                chunk_idx=i,
                toks=toks,
            )

            # Up-project the latent KV into per-head nope keys and values.
            kv_c_normed = kv_c_normed.squeeze()
            kv_nope = self.kv_b_proj(kv_c_normed)[0].view(
                -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
            k_nope, v = kv_nope \
                .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
            # Shared rope key is broadcast over all heads.
            k_pe = k_pe.expand((*k_nope.shape[:-1], -1))

            mask = attn_metadata.attn_mask
            # Ring update: merges this chunk's attention into the running
            # prefix_output/prefix_lse (passed as both prev and output).
            torch_npu.atb.npu_ring_mla(
                q_nope=q_nope,
                q_rope=q_pe,
                k_nope=k_nope,
                k_rope=k_pe,
                value=v,
                mask=mask,
                seqlen=seq_len,
                head_num=self.num_heads,
                kv_head_num=self.num_heads,
                pre_out=prefix_output,
                prev_lse=prefix_lse,
                qk_scale=self.scale,
                kernel_type="kernel_type_high_precision",
                mask_type="no_mask",
                input_layout="type_bsnd",
                calc_type="calc_type_default",
                output=prefix_output,
                softmax_lse=prefix_lse)
        return prefix_output, prefix_lse

    def forward(
        self,
        layer_name,
        hidden_states: torch.Tensor,  # query in unified attn
        kv_cache: Tuple[torch.Tensor],
        attn_metadata: M,
        need_gather_q_kv: bool = False,
        output: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run MLA attention for a batch split into decode + prefill tokens.

        The batch is assumed reordered so decode requests come first (see the
        builder's ``reorder_batch``). Decode and prefill segments are
        preprocessed and attended separately, written into a shared staging
        buffer, then projected through ``o_proj`` into ``output`` in place.

        Returns:
            ``output`` (possibly padded for graph capture), filled in place.
        """
        assert output is not None, "Output tensor must be provided."
        if attn_metadata is None:
            # Profiling run.
            if self.fc2_o_shared_enable and is_hidden_layer(
                    self.vllm_config, self.o_proj):
                reach_layer_for_shared_weight_series(self.o_proj)
            return output.fill_(0)
        # Under PCP each rank only holds 1/pcp_size of the padded tokens.
        if self.pcp_size > 1:
            num_actual_tokens = attn_metadata.num_actual_tokens_pcp_padded // self.pcp_size
        else:
            num_actual_tokens = attn_metadata.num_actual_tokens
        assert attn_metadata.num_decodes is not None and \
            attn_metadata.num_prefills is not None and \
            attn_metadata.num_decode_tokens is not None
        num_decode_tokens = attn_metadata.num_decode_tokens
        # Inputs and outputs may be padded for CUDA graphs
        output_padded = output
        # Staging buffer holding per-token attention results for o_proj.
        o_proj_input_shape = (get_forward_context().num_tokens,
                              self.num_heads * self.v_head_dim)
        o_proj_input = torch.empty(o_proj_input_shape,
                                   dtype=hidden_states.dtype,
                                   device=hidden_states.device)

        # MLA Preprocess
        forward_context = get_forward_context()
        # Fused decode-only preprocess kernel path (no prefill in the batch).
        if (self.enable_mlapo and
                (attn_metadata is None or not forward_context.with_prefill)):
            hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
                hidden_states.contiguous(), need_gather_q_kv)
            decode_preprocess_res, prefill_preprocess_res = self._mla_decode_preprocess(
                hidden_states, kv_cache, attn_metadata)
        else:
            decode_preprocess_res, prefill_preprocess_res = self._mla_preprocess(
                layer_name, hidden_states, kv_cache, attn_metadata,
                need_gather_q_kv)

        if decode_preprocess_res is not None:
            # MLA Preprocess for decoding
            if self.pcp_size * self.dcp_size > 1:
                # Context-parallel decode combines partial results across
                # PCP/DCP ranks.
                output_decode = self._forward_decode_pcp_dcp(
                    decode_preprocess_res.ql_nope,
                    decode_preprocess_res.q_pe,
                    decode_preprocess_res.k_nope,
                    decode_preprocess_res.k_pe,
                    kv_cache[0].shape[1],
                    attn_metadata,
                )
            else:
                output_decode = self._forward_decode(
                    decode_preprocess_res.ql_nope, decode_preprocess_res.q_pe,
                    decode_preprocess_res.k_nope, decode_preprocess_res.k_pe,
                    kv_cache[0].shape[1], attn_metadata)

            o_proj_input[:num_decode_tokens] = output_decode

        if prefill_preprocess_res is not None:
            # FIX: aicore move should be also placed on the comm stream in dbo,
            # otherwise it may affect the accuracy
            # TODO: use an elegant way to overlap
            if self.pcp_size > 1:
                output_prefill = self._forward_prefill_cp(
                    prefill_preprocess_res.q_nope, prefill_preprocess_res.q_pe,
                    prefill_preprocess_res.k_nope, prefill_preprocess_res.k_pe,
                    prefill_preprocess_res.value, kv_cache, attn_metadata)
            else:
                output_prefill = self._forward_prefill(
                    prefill_preprocess_res.q_nope, prefill_preprocess_res.q_pe,
                    prefill_preprocess_res.k_nope, prefill_preprocess_res.k_pe,
                    prefill_preprocess_res.value, kv_cache, attn_metadata)

            o_proj_input[num_decode_tokens:num_actual_tokens] = output_prefill
        # O proj
        MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024
        maybe_npu_prefetch(inputs=self.o_proj.weight,
                           dependency=o_proj_input,
                           max_size=MAX_O_PROJ_PREFETCH_SIZE,
                           enabled=self.enable_prefetch)

        output[...] = self.o_proj(o_proj_input,
                                  is_prefill=(prefill_preprocess_res
                                              is not None))[0]

        del o_proj_input

        has_prefill = attn_metadata.num_prefills > 0
        if has_prefill:
            # Hand freshly-written KV pages to the KV-transfer connector.
            maybe_save_kv_layer_to_connector(layer_name, list(kv_cache))
        return output_padded

    def _mla_preprocess(self, layer_name, hidden_states, kv_cache,
                        attn_metadata, need_gather_q_kv):
        """Build decode/prefill Q, K, V inputs and write new KV to the cache."""
        # MLA Preprocess:
        # 1. Perform fused_qkv_a_proj and q_a_layernorm to obtain q_c and kv_no_split
        # or
        # Perform kv_a_proj_with_mqa to obtain kv_no_split
        # 2. If need_gather_q_kv, perform all_gather.
        # 3. Preprocess decode tokens, write kv cache and get:
        # decode_ql_nope, decode_q_pe, decode_k_pe, decode_k_nope
        # 4. Preprocess prefill tokens, write kv cache and get:
        # prefill_q_nope, prefill_q_pe, prefill_k_nope, prefill_k_pe, prefill_value
        has_decode = attn_metadata.num_decodes > 0
        has_prefill = attn_metadata.num_prefills > 0
        num_decode_tokens = attn_metadata.num_decode_tokens
        num_actual_tokens = attn_metadata.num_actual_tokens
        if self.fused_qkv_a_proj is not None:
            maybe_npu_prefetch(inputs=self.fused_qkv_a_proj.weight,
                               dependency=hidden_states,
                               enabled=self.enable_prefetch)
            qkv_lora = self.fused_qkv_a_proj(hidden_states)[0]
            q_c, kv_no_split = qkv_lora.split(
                [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
                dim=-1,
            )
            q_c = self.q_a_layernorm(q_c)
            # allgather need contiguous data
            kv_no_split = kv_no_split.contiguous()
        else:
            # No low-rank Q projection: hidden states are used directly.
            q_c = hidden_states
            kv_no_split = self.kv_a_proj_with_mqa(hidden_states)[0]

        # Process for Flash Comm V1
        q_c = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
            q_c.contiguous(), need_gather_q_kv)
        kv_no_split = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
            kv_no_split.contiguous(), need_gather_q_kv)

        if self.fc2_o_shared_enable and is_hidden_layer(
                self.vllm_config, self.o_proj):
            reach_layer_for_shared_weight_series(self.o_proj)

        decode_preprocess_res = None
        prefill_preprocess_res = None
        if has_prefill:
            # Block until remote KV for this layer has arrived.
            wait_for_kv_layer_from_connector(layer_name)
        # Preprocess for decode tokens
        if has_decode:
            decode_q_c = q_c[:num_decode_tokens]
            cos = attn_metadata.decode.cos
            sin = attn_metadata.decode.sin
            decode_ql_nope, decode_q_pe = \
                self._q_proj_and_k_up_proj(decode_q_c)
            if self.dcp_size > 1:
                # Gather query heads across the DCP group (head dim = 1).
                decode_q_no_split = torch.cat([decode_ql_nope, decode_q_pe],
                                              dim=-1)
                decode_q_no_split = get_dcp_group().all_gather(
                    decode_q_no_split, 1)
                decode_ql_nope, decode_q_pe = decode_q_no_split.split(
                    [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
            decode_q_pe = self.rope_single(decode_q_pe, cos, sin)
            # Stride by pcp_size: slot_mapping is interleaved across PCP
            # ranks; every pcp_size-th slot belongs to this rank's decodes.
            decode_slots = attn_metadata.slot_mapping[:num_decode_tokens *
                                                      self.pcp_size:self.
                                                      pcp_size]
            decode_kv_no_split = kv_no_split[:num_decode_tokens]
            decode_k_pe, decode_k_nope = self.exec_kv_decode(
                decode_kv_no_split, cos, sin, kv_cache, decode_slots)
            decode_preprocess_res = DecodeMLAPreprocessResult(
                decode_ql_nope, decode_q_pe, decode_k_nope, decode_k_pe)
        # Preprocess for prefill tokens
        if has_prefill:
            if self.pcp_size > 1:
                # Convert the PCP-padded global token count into the local
                # per-rank prefill token count (decodes are not split).
                num_actual_tokens = (attn_metadata.num_actual_tokens_pcp_padded
                                     - self.pcp_size * num_decode_tokens
                                     ) // self.pcp_size + num_decode_tokens
            prefill_kv_no_split = kv_no_split[
                num_decode_tokens:num_actual_tokens]
            prefill_q_c = q_c[num_decode_tokens:num_actual_tokens]
            prefill_q = self.q_proj(prefill_q_c)[0] \
                .view(-1, self.num_heads, self.qk_head_dim)
            prefill_q_pe = prefill_q[..., self.qk_nope_head_dim:]
            prefill_q_nope = prefill_q[..., :self.qk_nope_head_dim]
            if self.pcp_size > 1:
                cos = attn_metadata.prefill.cos[:num_actual_tokens -
                                                num_decode_tokens]
                sin = attn_metadata.prefill.sin[:num_actual_tokens -
                                                num_decode_tokens]
            else:
                cos = attn_metadata.prefill.cos
                sin = attn_metadata.prefill.sin
            prefill_slots = attn_metadata.slot_mapping[
                num_decode_tokens:num_actual_tokens]
            prefill_q_pe = self.rope_single(prefill_q_pe, cos, sin)
            if self.pcp_size > 1:
                # PCP path: normalize/rope local KV, all-gather the full
                # prefill KV across PCP ranks, restore order, then write
                # this rank's pages into the paged cache.
                prefill_kv_no_split = kv_no_split[:num_actual_tokens]
                kv_c, k_pe = prefill_kv_no_split.split(
                    [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
                kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
                assert len(
                    kv_cache
                ) > 1, "the number of kv cache should be greater than 1, namely (nope_cache and rope_cache)"
                kv_c_normed = kv_c_normed.view(
                    [num_actual_tokens, self.num_kv_heads, -1])
                k_pe = k_pe.unsqueeze(1)
                prefill_k_pe = k_pe
                prefill_k_pe[
                    num_decode_tokens:num_actual_tokens] = self.rope_single(
                        prefill_k_pe[num_decode_tokens:num_actual_tokens], cos,
                        sin)
                prefill_k_c_normed = kv_c_normed[:num_actual_tokens]
                prefill_kv_c_k_pe = torch.cat(
                    [prefill_k_c_normed, prefill_k_pe], dim=-1)
                prefill_kv_c_k_pe = get_pcp_group().all_gather(
                    prefill_kv_c_k_pe, 0)
                # Undo the rank-interleaved ordering of the all-gather.
                prefill_kv_c_k_pe = torch.index_select(
                    prefill_kv_c_k_pe, 0, attn_metadata.prefill.pcp_metadata.
                    pcp_allgather_restore_idx)
                prefill_kv_c_k_pe = prefill_kv_c_k_pe[num_decode_tokens *
                                                      self.pcp_size:]
                prefill_k_c_normed, prefill_k_pe = prefill_kv_c_k_pe.split(
                    [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
                kv_c_normed, k_pe = prefill_k_c_normed, prefill_k_pe
                prefill_k_c_normed = prefill_k_c_normed.squeeze()
                slot_mapping = attn_metadata.slot_mapping[self.pcp_size *
                                                          num_decode_tokens:]
                torch_npu._npu_reshape_and_cache(key=kv_c_normed,
                                                 value=k_pe,
                                                 key_cache=kv_cache[0],
                                                 value_cache=kv_cache[1],
                                                 slot_indices=slot_mapping)
            else:
                prefill_k_pe, prefill_k_c_normed = self.exec_kv_prefill(
                    prefill_kv_no_split, cos, sin, kv_cache, prefill_slots)
            # Up-project latent KV to per-head nope keys and values.
            prefill_k_nope, prefill_value = self.kv_b_proj(
                prefill_k_c_normed)[0].view(
                    -1, self.num_heads,
                    self.qk_nope_head_dim + self.v_head_dim).split(
                        [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
            if not self.pcp_size > 1:
                prefill_k_pe = prefill_k_pe.view(prefill_q_c.shape[0],
                                                 self.num_kv_heads, -1)
            # Broadcast the shared rope key over all attention heads.
            prefill_k_pe = prefill_k_pe.expand(
                (*prefill_k_nope.shape[:-1], -1))
            prefill_preprocess_res = PrefillMLAPreprocessResult(
                prefill_q_nope, prefill_q_pe, prefill_k_nope, prefill_k_pe,
                prefill_value)
        return decode_preprocess_res, prefill_preprocess_res

    def _mla_decode_preprocess(self, hidden_states, kv_cache, attn_metadata):
        """Decode-only preprocess via the fused ``mla_preprocess`` kernel.

        Runs quantized QKV projection, layernorm, rope, and KV-cache write in
        a single custom op, producing latent-space Q (nope + rope parts) for
        the decode tokens. Returns ``(DecodeMLAPreprocessResult, None)`` —
        there is never a prefill result on this path.
        """
        bsz = attn_metadata.num_decode_tokens
        hidden_states = hidden_states[:bsz]

        cos_shape = attn_metadata.decode.cos.shape
        cos = attn_metadata.decode.cos.view(cos_shape[0], cos_shape[-1])
        sin = attn_metadata.decode.sin.view(cos_shape[0], cos_shape[-1])

        # kv_cache[0]/[1] are the nope and rope caches; the kernel writes
        # into them in place via kv_cache_out0/kv_cache_out1.
        decode_k_nope, decode_k_pe = kv_cache[0], kv_cache[1]
        decode_q_nope = torch.empty(
            (hidden_states.shape[0], self.W_UK_T.shape[0],
             decode_k_nope.shape[-1]),
            dtype=hidden_states.dtype,
            device=hidden_states.device,
        )
        decode_q_pe = torch.empty(
            (hidden_states.shape[0], self.W_UK_T.shape[0],
             decode_k_pe.shape[-1]),
            dtype=hidden_states.dtype,
            device=hidden_states.device,
        )

        torch.ops._C_ascend.mla_preprocess(
            hidden_states,
            self.wd_qkv,
            self.deq_scale_qkv,
            self.gamma1,
            self.beta1,
            self.wu_q,
            self.qb_deq_scl,
            self.gamma2,
            cos,
            sin,
            self.W_UK_T,
            decode_k_nope,
            decode_k_pe,
            attn_metadata.slot_mapping[:bsz].flatten(),
            quant_scale0=self.quant_scale0,
            quant_offset0=self.quant_offset0,
            bias0=self.quant_bias_qkv,
            quant_scale1=self.quant_scale1,
            quant_offset1=self.quant_offset1,
            bias1=self.qb_qt_bias,
            ctkv_scale=self.ctkv_scale,
            q_nope_scale=self.q_nope_scale,
            cache_mode="krope_ctkv",
            quant_mode="per_tensor_quant_asymm",
            q_out0=decode_q_nope,
            kv_cache_out0=decode_k_nope,
            q_out1=decode_q_pe,
            kv_cache_out1=decode_k_pe,
            enable_inner_out=False,
            inner_out=torch.tensor([], device=hidden_states.device))
        decode_q_nope = decode_q_nope.view(bsz, self.num_heads,
                                           self.kv_lora_rank)
        decode_q_pe = decode_q_pe.view(bsz, self.num_heads, -1)

        if self.dcp_size > 1:
            # Gather query heads across the DCP group (head dim = 1).
            decode_q_no_split = torch.cat([decode_q_nope, decode_q_pe], dim=-1)
            decode_q_no_split = get_dcp_group().all_gather(
                decode_q_no_split, 1)
            decode_q_nope, decode_q_pe = decode_q_no_split.split(
                [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)

        decode_preprocess_res = DecodeMLAPreprocessResult(
            decode_q_nope, decode_q_pe, decode_k_nope, decode_k_pe)
        return decode_preprocess_res, None

    def _forward_prefill_cp(
        self,
        q_nope: torch.Tensor,
        q_pe: torch.Tensor,
        k_nope: torch.Tensor,
        k_pe: torch.Tensor,
        value: torch.Tensor,
        kv_c_and_k_pe_cache: Tuple[torch.Tensor],
        attn_metadata: AscendMLAMetadata,
    ) -> torch.Tensor:
        """Prefill attention under prefill context parallelism (PCP).

        Local queries are partitioned into a "head" and a "tail" group via
        precomputed index tensors; each group attends to its masked (causal
        boundary) and unmasked KV slices, the two partial results are merged
        back into query order, and any chunked historical context is folded
        in afterwards. Returns a [num_tokens, num_heads * v_head_dim] tensor.
        """
        assert attn_metadata.prefill is not None
        assert attn_metadata.prefill.pcp_metadata is not None
        num_tokens = q_nope.size(0)
        # Use precomputed indices from the metadata (already converted to tensors and on device)
        q_head_idx = attn_metadata.prefill.pcp_metadata.q_head_idx
        q_tail_idx = attn_metadata.prefill.pcp_metadata.q_tail_idx
        kv_with_q_head_nomask_idx = attn_metadata.prefill.pcp_metadata.kv_with_q_head_nomask_idx
        kv_with_q_head_mask_idx = attn_metadata.prefill.pcp_metadata.kv_with_q_head_mask_idx
        kv_with_q_tail_nomask_idx = attn_metadata.prefill.pcp_metadata.kv_with_q_tail_nomask_idx
        kv_with_q_tail_mask_idx = attn_metadata.prefill.pcp_metadata.kv_with_q_tail_mask_idx
        attn_mask_seqlens = attn_metadata.prefill.pcp_metadata.attn_mask_seqlens
        head_attn_nomask_seqlens = attn_metadata.prefill.pcp_metadata.head_attn_nomask_seqlens
        tail_attn_nomask_seqlens = attn_metadata.prefill.pcp_metadata.tail_attn_nomask_seqlens
        mask = attn_metadata.prefill.pcp_metadata.pcp_prefill_mask
        # Head-half queries: masked + unmasked KV contributions.
        output_head, lse_head = self._attention_with_mask_and_nomask(
            q_nope=torch.index_select(q_nope, 0, q_head_idx),
            q_pe=torch.index_select(q_pe, 0, q_head_idx),
            k_nope=k_nope,
            k_pe=k_pe,
            value=value,
            kv_mask_idx=kv_with_q_head_mask_idx,
            kv_nomask_idx=kv_with_q_head_nomask_idx,
            attn_mask_seqlens=attn_mask_seqlens,
            attn_nomask_seqlens=head_attn_nomask_seqlens,
            mask=mask)

        # Tail-half queries: same computation with the tail index sets.
        output_tail, lse_tail = self._attention_with_mask_and_nomask(
            q_nope=torch.index_select(q_nope, 0, q_tail_idx),
            q_pe=torch.index_select(q_pe, 0, q_tail_idx),
            k_nope=k_nope,
            k_pe=k_pe,
            value=value,
            kv_mask_idx=kv_with_q_tail_mask_idx,
            kv_nomask_idx=kv_with_q_tail_nomask_idx,
            attn_mask_seqlens=attn_mask_seqlens,
            attn_nomask_seqlens=tail_attn_nomask_seqlens,
            mask=mask)

        # Re-interleave head/tail partial results back to query order.
        # NOTE: lse layout is (heads, tokens), hence the dim=1 select.
        q_full_idx = attn_metadata.prefill.pcp_metadata.q_full_idx
        attn_output = torch.index_select(
            torch.cat([output_head, output_tail], dim=0), 0, q_full_idx)
        attn_lse = torch.index_select(torch.cat([lse_head, lse_tail], dim=1),
                                      1, q_full_idx)

        # Merge in chunked historical context, if any.
        output, _ = self._compute_prefill_context(q_nope, q_pe,
                                                  kv_c_and_k_pe_cache,
                                                  self.qk_rope_head_dim,
                                                  attn_metadata, attn_output,
                                                  attn_lse)

        output = output.reshape([num_tokens, self.num_heads * self.v_head_dim])

        return output

    def _attention_with_mask_and_nomask(
            self, q_nope: torch.Tensor, q_pe: torch.Tensor,
            k_nope: torch.Tensor, k_pe: torch.Tensor, value: torch.Tensor,
            kv_mask_idx: torch.Tensor, kv_nomask_idx: torch.Tensor,
            attn_mask_seqlens: torch.Tensor, attn_nomask_seqlens: torch.Tensor,
            mask: torch.Tensor):
        """Two-pass ring-MLA attention: masked KV first, then unmasked KV.

        Pass 1 attends to the KV tokens selected by ``kv_mask_idx`` with a
        triangular mask (``calc_type_first_ring`` starts the ring). Pass 2
        folds in the ``kv_nomask_idx`` tokens without a mask, using pass 1's
        output/lse as the ring state. Returns ``(attn_output, attn_lse)``.
        """
        attn_output = torch.empty(q_nope.shape[0],
                                  self.num_heads,
                                  self.v_head_dim,
                                  dtype=k_pe.dtype,
                                  device=k_pe.device)
        # lse layout is (heads, tokens), matching the kernel's output.
        attn_lse = torch.empty(self.num_heads,
                               q_pe.shape[0],
                               dtype=torch.float32,
                               device=k_pe.device)
        # mask
        k_nope_mask = torch.index_select(k_nope, 0, kv_mask_idx)
        value_mask = torch.index_select(value, 0, kv_mask_idx)
        k_pe_mask = torch.index_select(k_pe, 0, kv_mask_idx)
        torch_npu.atb.npu_ring_mla(q_nope=q_nope,
                                   q_rope=q_pe,
                                   k_nope=k_nope_mask,
                                   k_rope=k_pe_mask,
                                   value=value_mask,
                                   mask=mask,
                                   seqlen=attn_mask_seqlens,
                                   head_num=self.num_heads,
                                   kv_head_num=self.num_heads,
                                   pre_out=None,
                                   prev_lse=None,
                                   qk_scale=self.scale,
                                   kernel_type="kernel_type_high_precision",
                                   mask_type="mask_type_triu",
                                   input_layout="type_bsnd",
                                   calc_type="calc_type_first_ring",
                                   output=attn_output,
                                   softmax_lse=attn_lse)

        # nomask
        # No unmasked KV tokens for this query group: pass 1 is the answer.
        if kv_nomask_idx.shape[0] == 0:
            return attn_output, attn_lse

        k_nope_nomask = torch.index_select(k_nope, 0, kv_nomask_idx)
        value_nomask = torch.index_select(value, 0, kv_nomask_idx)
        k_pe_nomask = torch.index_select(k_pe, 0, kv_nomask_idx)
        torch_npu.atb.npu_ring_mla(q_nope=q_nope,
                                   q_rope=q_pe,
                                   k_nope=k_nope_nomask,
                                   k_rope=k_pe_nomask,
                                   value=value_nomask,
                                   mask=mask,
                                   seqlen=attn_nomask_seqlens,
                                   head_num=self.num_heads,
                                   kv_head_num=self.num_heads,
                                   pre_out=attn_output,
                                   prev_lse=attn_lse,
                                   qk_scale=self.scale,
                                   kernel_type="kernel_type_high_precision",
                                   mask_type="no_mask",
                                   input_layout="type_bsnd",
                                   calc_type="calc_type_default",
                                   output=attn_output,
                                   softmax_lse=attn_lse)
        return attn_output, attn_lse

    def _forward_decode_pcp_dcp(
        self,
        q_nope: torch.Tensor,
        q_pe: torch.Tensor,
        k_nope: torch.Tensor,
        k_pe: torch.Tensor,
        block_size: int,
        attn_metadata: AscendMLAMetadata,
    ) -> torch.Tensor:
        """Decode attention when PCP and/or DCP context parallelism is on.

        Runs the paged MLA kernel in ring mode (returning the lse) over this
        rank's share of the KV cache, then exchanges partial out/lse across
        the DCP/PCP groups and merges them. Supports NPU graph capture by
        recording kernel params/workspace into the graph bookkeeping.
        """
        decode_meta = attn_metadata.decode
        assert decode_meta is not None
        num_tokens = q_nope.size(0)
        # shape of knope/k_pe for npu graph mode should be:
        # [num_blocks, num_kv_heads, block_size, self.kv_lora_rank/self.qk_rope_head_dim]
        # Under DCP the query heads were all-gathered, so this rank sees
        # dcp_size times the local head count.
        if self.dcp_size > 1:
            num_heads = self.num_heads * self.dcp_size
        else:
            num_heads = self.num_heads

        k_nope = k_nope.view(-1, block_size, self.num_kv_heads,
                             self.kv_lora_rank)
        k_pe = k_pe.view(-1, block_size, self.num_kv_heads,
                         self.qk_rope_head_dim)
        q_nope = q_nope.view(num_tokens, num_heads, -1)
        q_pe = q_pe.view(num_tokens, num_heads, -1)
        # use pcp & dcp split computed token nums from scheduler to compute actual seq_len and seq_mask
        seq_len = decode_meta.cp_seq_len

        common_kwargs = {
            "return_lse": True,
            "calc_type": "calc_type_ring",
        }
        graph_params = get_graph_params()
        forward_context: ForwardContext = get_forward_context()
        if forward_context.capturing:
            # Graph-capture path: register an external event and a reusable
            # workspace so the captured kernel can be replayed/updated later.
            stream = torch_npu.npu.current_stream()
            event = torch.npu.ExternalEvent()
            event.wait(stream)
            event.reset(stream)
            graph_params.events[num_tokens].append(event)
            workspace = graph_params.workspaces.get(num_tokens)
            if workspace is None:
                workspace = torch_npu.atb._npu_multi_head_latent_attention_get_workspace(
                    q_nope, q_pe, k_nope, k_pe, decode_meta.block_table,
                    seq_len, num_heads, self.scale, self.num_kv_heads,
                    **common_kwargs)
                update_graph_params_workspaces(num_tokens, workspace)
            attn_output = torch.empty_like(q_nope)
            softmax_lse = torch.empty((num_tokens, num_heads, 1),
                                      dtype=q_nope.dtype,
                                      device=q_nope.device)
            # Keep weak references so replay can rebind the tensors without
            # extending their lifetimes.
            graph_params.attn_params[num_tokens].append(
                (weak_ref_tensors(q_nope), weak_ref_tensors(q_pe),
                 weak_ref_tensors(k_nope), weak_ref_tensors(k_pe),
                 decode_meta.block_table, seq_len, num_heads, self.scale,
                 self.num_kv_heads, weak_ref_tensors(attn_output),
                 weak_ref_tensors(softmax_lse)))
            torch.npu.graph_task_group_begin(stream)
            torch_npu.atb.npu_multi_head_latent_attention(
                q_nope,
                q_pe,
                k_nope,
                k_pe,
                decode_meta.block_table,
                seq_len,
                num_heads,
                self.scale,
                self.num_kv_heads,
                **common_kwargs,
                workspace=workspace,
                output=attn_output,
                lse=softmax_lse)
            handle = torch.npu.graph_task_group_end(stream)
            graph_params.handles[num_tokens].append(handle)
        else:
            # Eager path: run the kernel directly.
            attn_output = torch.empty_like(q_nope)
            softmax_lse = torch.empty((num_tokens, num_heads, 1),
                                      dtype=q_nope.dtype,
                                      device=q_nope.device)
            torch_npu.atb.npu_multi_head_latent_attention(
                q_nope,
                q_pe,
                k_nope,
                k_pe,
                decode_meta.block_table,
                seq_len,
                num_heads,
                self.scale,
                self.num_kv_heads,
                return_lse=True,
                calc_type="calc_type_ring",
                output=attn_output,
                lse=softmax_lse)

        # Update out&lse
        attn_out_lse_list = self._process_attn_out_lse(attn_output,
                                                       softmax_lse,
                                                       decode_meta)
        attn_output = self._npu_attention_update(attn_out_lse_list)
        return self._v_up_proj(attn_output)

def _npu_attention_update(
self, attn_out_lse_list: List[torch.Tensor]) -> torch.Tensor:
attn_out_split_cp = []
attn_lse_split_cp = []

for attn_out_lse in attn_out_lse_list:
attn_out_allgather, attn_lse_allgather = self._out_lse_reshape(
*torch.split(attn_out_lse, [self.kv_lora_rank, 1], dim=-1))
attn_out_split_cp.append(attn_out_allgather)
attn_lse_split_cp.append(attn_lse_allgather)
attn_out, _ = torch_npu.npu_attention_update(attn_lse_split_cp,
attn_out_split_cp, 0)
attn_out = attn_out.view(-1, attn_out_lse_list[0].shape[1],
self.kv_lora_rank)
return attn_out

def _out_lse_reshape(self, attn_out: torch.Tensor,
attn_lse: torch.Tensor) -> torch.Tensor:
attn_out = attn_out.contiguous().view(
attn_out.shape[0] * attn_out.shape[1], attn_out.shape[2])
attn_lse = attn_lse.contiguous().view(
attn_lse.shape[0] * attn_lse.shape[1] * attn_lse.shape[2])
return attn_out, attn_lse

    def _process_attn_out_lse(
        self,
        attn_output: torch.Tensor,
        softmax_lse: torch.Tensor,
        decode_meta: AscendMLADecodeMetadata,
    ) -> List[torch.Tensor]:
        """Exchange partial decode out/lse across DCP and PCP ranks.

        Invalid batch positions (per ``decode_meta.batch_seq_mask``) are
        neutralized (output -> 0, lse -> -inf) so they contribute nothing to
        the later lse-weighted merge. The packed [out | lse] tensor is then
        all-to-all'd within the DCP group and/or all-gathered within the PCP
        group. Returns the list of partial chunks to be merged by
        ``_npu_attention_update``.
        """
        attn_out_lse_list = []
        # Zero out masked positions; -inf lse gives them zero merge weight.
        out_mask = decode_meta.batch_seq_mask[:, None,
                                              None].expand_as(attn_output)
        attn_output = torch.where(out_mask, 0, attn_output)
        lse_mask = decode_meta.batch_seq_mask[:, None,
                                              None].expand_as(softmax_lse)
        softmax_lse = torch.where(lse_mask, -torch.inf, softmax_lse)

        softmax_lse = softmax_lse.to(torch.float32)
        attn_output = attn_output.to(torch.float32)
        # Concat out&lse: [bs,num_heads,v_head_dim] + [bs,num_heads,1] -> [bs,num_heads,v_head_dim+1]
        attn_out_lse = torch.cat([attn_output, softmax_lse], dim=-1)
        if self.dcp_size > 1:
            # permute: [bs, num_heads, v_head_dim+1] -> [num_heads, v_head_dim+1, bs]
            attn_out_lse = attn_out_lse.permute([1, 2, 0]).contiguous()
            attn_out_lse_all2all = torch.empty_like(attn_out_lse)
            dist.all_to_all_single(attn_out_lse_all2all,
                                   attn_out_lse,
                                   group=self.dcp_group)
            # permute: [num_heads, v_head_dim+1, bs] -> [bs, num_heads, v_head_dim+1]
            attn_out_lse_all2all = attn_out_lse_all2all.permute([2, 0, 1])
            if self.pcp_size > 1:
                # Keep a contiguous copy for the PCP all_gather below.
                attn_out_lse = attn_out_lse_all2all.contiguous()
            attn_out_lse_list = list(
                torch.chunk(attn_out_lse_all2all, self.dcp_size, dim=1))

        if self.pcp_size > 1:
            # AllGather out&lse within PCP group
            attn_out_lse_list = [
                torch.empty_like(attn_out_lse) for _ in range(self.pcp_size)
            ]
            dist.all_gather(attn_out_lse_list,
                            attn_out_lse,
                            group=self.pcp_group)
            if self.dcp_size > 1 and self.pcp_size > 1:
                # Split every PCP chunk again along heads into DCP chunks.
                attn_out_lse_list_pcp_dcp = []
                for s in attn_out_lse_list:
                    attn_out_lse_list_split = list(
                        torch.chunk(s, self.dcp_size, dim=1))
                    attn_out_lse_list_pcp_dcp += attn_out_lse_list_split
                attn_out_lse_list = attn_out_lse_list_pcp_dcp

        return attn_out_lse_list

    def _reorg_kvcache(
        self,
        allgatered_kv_c_normed: torch.Tensor,
        allgatered_k_pe: torch.Tensor,
        padded_local_chunk_seq_lens_lst: list[int],
        local_context_lens_allranks: list[list[int]],
        sum_seq_len: int,
        max_seq_len: int,
        chunk_size: int,
        chunk_idx: int,
        toks: int,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        reorg and unpad kvcache after cp local gather to tp layout for attn kernel.
        e.g.
        kv_c_normed in rank0 = [T0_0, T0_1, T0_2, T0_3, T1_0, T1_1, ...]
        kv_c_normed in rank1 = [T0_4, T0_5, pad, pad, T1_2, pad, ...]
        allgatered_kv_c_normed = [T0_0, T0_1, T0_2, T0_3, T1_0, T1_1, ...,
                                  T0_4, T0_5, pad, pad, T1_2, pad, ...]
        -> reorganized_kv_c_normed = [T0_0, T0_1, T0_2, T0_3, T0_4, T0_5,
                                      T1_0, T1_1, T1_2, ...]
        Args:
            padded_local_chunk_seq_lens_lst: local chunk context lengths
                under current CP rank.
            local_context_lens_allranks: local context lengths on each CP rank.
            sum_seq_len: the sum of cp_chunk_seq_lens_lst.
            max_seq_len: the max value of cp_chunk_seq_lens_lst.
            chunk_size: the local padded max context chunk from
                chunked_context_metadata building.
            chunk_idx: chunk idx of chunked_prefill.
            toks: the number of tokens for local gather cache.
        Returns:
            The reorganized, unpadded ``(kv_c_normed, k_pe)`` tensors; both
            are asserted to contain exactly ``sum_seq_len`` tokens.
        """
        kv_c_segments = []
        k_pe_segments = []
        src_token_idx = 0
        max_seq_len_check = 0
        for padded_local_chunk_seq_len, local_context_lens in zip(
                padded_local_chunk_seq_lens_lst, local_context_lens_allranks):
            cur_seq_len = 0
            for rank, local_context_len in enumerate(local_context_lens):
                # Note(qcs): We split the context into multiple chunks,
                # depending on the size of the workspace.
                # local_context in dcp0: |-----------------|
                # local_context in dcp1: |--------------|
                # n*padded_local_chunk:  |-----|-----|-----|
                # local_chunk_len in dcp1: |-----|-----|--|
                # so we need update the last chunk length in dcp1.
                local_chunk_len = min(
                    max(0, local_context_len - chunk_idx * chunk_size),
                    padded_local_chunk_seq_len,
                )
                if local_chunk_len != 0:
                    # Each rank's gathered slab starts at rank * toks; take
                    # only the valid (unpadded) prefix of this sequence.
                    kv_c_segment = allgatered_kv_c_normed[rank * toks +
                                                          src_token_idx:rank *
                                                          toks +
                                                          src_token_idx +
                                                          local_chunk_len]
                    k_pe_segment = allgatered_k_pe[rank * toks +
                                                   src_token_idx:rank * toks +
                                                   src_token_idx +
                                                   local_chunk_len]
                    kv_c_segments.append(kv_c_segment)
                    k_pe_segments.append(k_pe_segment)
                    cur_seq_len += local_chunk_len
            max_seq_len_check = max(max_seq_len_check, cur_seq_len)
            src_token_idx += padded_local_chunk_seq_len
        reorganized_kv_c_normed = torch.cat(kv_c_segments, dim=0)
        reorganized_k_pe = torch.cat(k_pe_segments, dim=0)
        # Sanity checks against the metadata-declared totals.
        assert reorganized_kv_c_normed.shape[0] == sum_seq_len
        assert reorganized_k_pe.shape[0] == sum_seq_len
        assert max_seq_len_check == max_seq_len
        return reorganized_kv_c_normed, reorganized_k_pe

+ 87
- 720
vllm_ascend/attention/mla_v1.py View File

@@ -1,27 +1,24 @@
from dataclasses import dataclass
from typing import (TYPE_CHECKING, ClassVar, List, NamedTuple, Optional, Tuple,
Type, TypeVar)
from typing import (TYPE_CHECKING, ClassVar, NamedTuple, Optional, Tuple, Type,
TypeVar)

import numpy as np
import torch
import torch.distributed as dist
import torch_npu
from torch import nn
from vllm.attention.backends.abstract import AttentionBackend, MLAAttentionImpl
from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.distributed import (get_dcp_group,
get_decode_context_model_parallel_rank,
from vllm.distributed import (get_decode_context_model_parallel_rank,
get_decode_context_model_parallel_world_size,
get_pcp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
get_tp_group)
get_pcp_group)
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import logger
from vllm.model_executor.layers.linear import (LinearBase,
UnquantizedLinearMethod)
from vllm.utils.math_utils import cdiv, round_down
from vllm.v1.attention.backends.utils import AttentionCGSupport
from vllm.v1.kv_cache_interface import MLAAttentionSpec

from vllm_ascend import envs
from vllm_ascend.ascend_config import get_ascend_config
@@ -44,7 +41,7 @@ from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
flashcomm2_o_shared_enabled, is_enable_nz,
weak_ref_tensors)
from vllm_ascend.worker.npu_input_batch import InputBatch
from vllm_ascend.worker.npu_input_batch import NPUInputBatch

if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
@@ -53,7 +50,6 @@ MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024


class AscendMLABackend(AttentionBackend):

accept_output_buffer: bool = True

@staticmethod
@@ -62,34 +58,26 @@ class AscendMLABackend(AttentionBackend):

@staticmethod
def get_builder_cls():
prefill_config = get_current_vllm_config().parallel_config
if prefill_config.prefill_context_parallel_size > 1 or prefill_config.decode_context_parallel_size > 1:
from vllm_ascend.attention.mla_cp import AscendMlaCPMetadataBuilder
return AscendMlaCPMetadataBuilder
return AscendMLAMetadataBuilder

@staticmethod
def get_kv_cache_shape(num_blocks: int, block_size: int, num_kv_heads: int,
head_size: int) -> tuple[int, ...]:
return (num_blocks, block_size, num_kv_heads, head_size)
return num_blocks, block_size, num_kv_heads, head_size

@staticmethod
def get_impl_cls() -> Type["MLAAttentionImpl"]:
prefill_config = get_current_vllm_config().parallel_config
if prefill_config.prefill_context_parallel_size > 1 or prefill_config.decode_context_parallel_size > 1:
from vllm_ascend.attention.mla_cp import AscendMlaCPImpl
return AscendMlaCPImpl
return AscendMLAImpl


@dataclass
class AscendPCPMetadata:
q_head_idx: torch.Tensor = None
q_tail_idx: torch.Tensor = None
kv_with_q_head_nomask_idx: torch.Tensor = None
kv_with_q_head_mask_idx: torch.Tensor = None
kv_with_q_tail_nomask_idx: torch.Tensor = None
kv_with_q_tail_mask_idx: torch.Tensor = None
attn_mask_seqlens: torch.Tensor = None
head_attn_nomask_seqlens: torch.Tensor = None
tail_attn_nomask_seqlens: torch.Tensor = None
q_full_idx: torch.Tensor = None
pcp_prefill_mask: torch.Tensor = None
pcp_allgather_restore_idx: Optional[list[int]] = None


@dataclass
class AscendMLAPrefillMetadata:
""" Prefill Specific Metadata for Ascend"""
@@ -113,6 +101,21 @@ class AscendMLAPrefillMetadata:
cu_seq_lens_lst: Optional[list[list[int]]] = None
chunk_size: Optional[int] = None

@dataclass
class AscendPCPMetadata:
q_head_idx: torch.Tensor = None
q_tail_idx: torch.Tensor = None
kv_with_q_head_nomask_idx: torch.Tensor = None
kv_with_q_head_mask_idx: torch.Tensor = None
kv_with_q_tail_nomask_idx: torch.Tensor = None
kv_with_q_tail_mask_idx: torch.Tensor = None
attn_mask_seqlens: torch.Tensor = None
head_attn_nomask_seqlens: torch.Tensor = None
tail_attn_nomask_seqlens: torch.Tensor = None
q_full_idx: torch.Tensor = None
pcp_prefill_mask: torch.Tensor = None
pcp_allgather_restore_idx: Optional[list[int]] = None

attn_mask: torch.Tensor
query_lens: torch.Tensor
seq_lens: list[int]
@@ -148,7 +151,6 @@ class AscendMLADecodeMetadata:
@dataclass
class AscendMLAMetadata:
"""Metadata for MLACommon.

NOTE: Please read the comment at the top of the file before trying to
understand this class
"""
@@ -209,8 +211,8 @@ class AscendMLAMetadataBuilder:
"""

def __init__(self,
kv_cache_spec,
layer_names,
kv_cache_spec: MLAAttentionSpec,
layer_names: list[str],
vllm_config: VllmConfig,
device: torch.device,
metadata_cls: Optional[AscendMLAMetadata] = None):
@@ -251,7 +253,7 @@ class AscendMLAMetadataBuilder:
# (assuming 192 QK head dim, 128 heads, and fp16)
128 * 1024)
assert self.chunked_prefill_workspace_size >= \
scheduler_config.max_num_seqs * self.block_size
scheduler_config.max_num_seqs * self.block_size
self.chunked_prefill_workspace = torch.empty(
(self.chunked_prefill_workspace_size,
self.model_config.get_head_size()),
@@ -278,7 +280,7 @@ class AscendMLAMetadataBuilder:
dtype=torch.uint8,
device=device)

def reorder_batch(self, input_batch: "InputBatch",
def reorder_batch(self, input_batch: "NPUInputBatch",
scheduler_output: "SchedulerOutput") -> bool:
# We now want to reorder the batch so that the "decode" requests are at
# the front and the "prefill" requests are at the using the least amount
@@ -331,7 +333,7 @@ class AscendMLAMetadataBuilder:
actual_seq_lengths_q,
common_attn_metadata):
"""
Pads actual_seq_lengths_q evenly to not exceed 16 tokens per request
Pads actual_seq_lengths_q evenly to not exceed 16 tokens per request
in order to meet the requirement of npu_fused_infer_attention_score.

In Torchair scenario, the lengths of the queries must be padded to the same length.
@@ -339,18 +341,19 @@ class AscendMLAMetadataBuilder:

For example:
batch_size=36, num_reqs_pad_size=2, num_reqs=16
By default, each request should have inference 2 token, which means actual_seq_lengths_q should be
By default, each request should have inference 2 token, which means actual_seq_lengths_q should be
[2,4,6,8,10,12,14,16,18,20,22,24,26,28,30,32,34,36].

However, mtp torchair + PD scenario, the actual_seq_lengths_q may be
However, mtp torchair + PD scenario, the actual_seq_lengths_q may be
[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16] before padding, since the first decode request only has 1 token.
In order to meet the requirement of npu_fused_infer_attention_score, we need to pad actual_seq_lengths_q evenly to not exceed 16 tokens per request.
after padding actual_seq_lengths_q should be similar to [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,32,36]
"""
FIA_SEQ_LEN_LIMIT = 16
need_padding = num_reqs_pad_size != 0 and \
len(common_attn_metadata.actual_seq_lengths_q) > num_reqs and \
common_attn_metadata.actual_seq_lengths_q[num_reqs] - actual_seq_lengths_q[-1] > FIA_SEQ_LEN_LIMIT
len(common_attn_metadata.actual_seq_lengths_q) > num_reqs and \
common_attn_metadata.actual_seq_lengths_q[num_reqs] - actual_seq_lengths_q[
-1] > FIA_SEQ_LEN_LIMIT
if need_padding:
padding_seq_len_q = common_attn_metadata.actual_seq_lengths_q[
num_reqs:num_reqs + num_reqs_pad_size]
@@ -376,7 +379,7 @@ class AscendMLAMetadataBuilder:
Only use for acl full graph mode.
Pad the last element of the actual_seq_lengths_q equal to the TND(T) and
the num of dimensions equal to the batch_size of main model.
For example:
batch_size = 8, num_reqs = 4, num_speculative_tokens = 1
input actual_seq_lengths_q = [1, 2, 4, 5] (the 3rd req was accept a token)
@@ -408,7 +411,6 @@ class AscendMLAMetadataBuilder:
long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata

num_actual_tokens_pcp_padded = long_seq_metadata.num_actual_tokens_pcp_padded if long_seq_metadata else None
num_computed_tokens_of_pcp_dcp = long_seq_metadata.num_computed_tokens_of_pcp_dcp if long_seq_metadata else None

num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
split_decodes_and_prefills(common_attn_metadata, decode_threshold=self.decode_threshold)
@@ -428,13 +430,7 @@ class AscendMLAMetadataBuilder:
common_attn_metadata.block_table_tensor[:graph_pad_size])
else:
block_table = (common_attn_metadata.block_table_tensor[:num_reqs])
# NOTE: Currently, MTP-fullgraph is incompatibility pcp
if self.pcp_size > 1:
num_decodes_flatten = num_decodes * self.decode_threshold
block_table = common_attn_metadata.block_table_tensor[:
num_decodes_flatten
+
num_prefills]

if num_actual_tokens_pcp_padded is None:
num_actual_tokens_pcp_padded = num_actual_tokens

@@ -465,30 +461,6 @@ class AscendMLAMetadataBuilder:
chunked_context_metadata = None
if num_prefills > 0:
pcp_metadata = None
common_long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata
if common_long_seq_metadata is not None:
pcp_metadata = AscendPCPMetadata(
q_head_idx=common_long_seq_metadata.q_head_idx_tensor,
q_tail_idx=common_long_seq_metadata.q_tail_idx_tensor,
kv_with_q_head_nomask_idx=common_long_seq_metadata.
kv_with_q_head_nomask_idx_tensor,
kv_with_q_head_mask_idx=common_long_seq_metadata.
kv_with_q_head_mask_idx_tensor,
kv_with_q_tail_nomask_idx=common_long_seq_metadata.
kv_with_q_tail_nomask_idx_tensor,
kv_with_q_tail_mask_idx=common_long_seq_metadata.
kv_with_q_tail_mask_idx_tensor,
attn_mask_seqlens=common_long_seq_metadata.
attn_mask_seqlens,
head_attn_nomask_seqlens=common_long_seq_metadata.
head_attn_nomask_seqlens,
tail_attn_nomask_seqlens=common_long_seq_metadata.
tail_attn_nomask_seqlens,
q_full_idx=common_long_seq_metadata.q_full_idx,
pcp_prefill_mask=common_long_seq_metadata.pcp_prefill_mask
if long_seq_metadata else None,
pcp_allgather_restore_idx=long_seq_metadata.
pcp_allgather_restore_idx if long_seq_metadata else None)

reqs_start = num_decodes # prefill_start
tokens_start = num_decode_tokens
@@ -509,7 +481,7 @@ class AscendMLAMetadataBuilder:
assert max_context_chunk > 0
num_chunks = cdiv(max_context_len_cpu, max_context_chunk)
chunk_starts = torch.arange(num_chunks, dtype=torch.int32) \
.unsqueeze(1).expand(-1, num_prefills) * max_context_chunk
.unsqueeze(1).expand(-1, num_prefills) * max_context_chunk
chunk_ends = torch.min(context_lens_cpu.unsqueeze(0),
chunk_starts + max_context_chunk)
chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0)
@@ -522,82 +494,18 @@ class AscendMLAMetadataBuilder:
out=cu_seq_lens_cpu[:, 1:],
dtype=torch.int32)

if self.dcp_size * self.pcp_size > 1:
if num_computed_tokens_of_pcp_dcp is not None:
local_context_lens_allranks = torch.tensor(
num_computed_tokens_of_pcp_dcp[reqs_start:num_reqs]
).reshape(-1, self.dcp_size * self.pcp_size)
# Note(qcs): The max local context lengths
# padded to `cp_local_block_size`.
padded_local_context_lens_cpu = (cdiv(
context_lens_cpu,
self.cp_virtual_block_size,
) * self.cp_local_block_size)
padded_local_max_context_chunk_across_ranks = (cdiv(
max_context_chunk,
self.cp_virtual_block_size,
) * self.cp_local_block_size)
local_chunk_starts = (
torch.arange(num_chunks,
dtype=torch.int32).unsqueeze(1).expand(
-1, num_prefills) *
padded_local_max_context_chunk_across_ranks)
local_chunk_ends = torch.min(
padded_local_context_lens_cpu.unsqueeze(0),
local_chunk_starts +
padded_local_max_context_chunk_across_ranks,
)
padded_local_chunk_seq_lens = (local_chunk_ends -
local_chunk_starts).clamp(
min=0)
padded_local_cu_chunk_seq_lens_cpu = torch.zeros(
num_chunks,
num_prefills + 1,
dtype=torch.int32,
pin_memory=True)
torch.cumsum(
padded_local_chunk_seq_lens,
dim=1,
out=padded_local_cu_chunk_seq_lens_cpu[:, 1:],
dtype=torch.int32,
)
chunked_context_metadata = AscendMLAPrefillMetadata.ChunkedContextMetadata(
chunked_context_metadata = (
AscendMLAPrefillMetadata.ChunkedContextMetadata(
cu_seq_lens=cu_seq_lens_cpu.pin_memory().to(
device, non_blocking=True),
starts=local_chunk_starts.pin_memory().to(
device, non_blocking=True),
seq_tot=padded_local_chunk_seq_lens.sum(
dim=1).tolist(),
starts=chunk_starts.pin_memory().to(device,
non_blocking=True),
seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
chunk_seq_lens=chunk_seq_lens,
chunk_seq_lens_npu=chunk_seq_lens.npu(),
workspace=self.chunked_prefill_workspace,
padded_chunk_seq_lens_npu=padded_local_chunk_seq_lens.
npu(),
padded_local_chunk_seq_lens=padded_local_chunk_seq_lens
.tolist(),
local_context_lens_allranks=local_context_lens_allranks
.tolist(),
padded_local_cu_seq_lens=
padded_local_cu_chunk_seq_lens_cpu.pin_memory().to(
device, non_blocking=True),
cu_seq_lens_lst=cu_seq_lens_cpu.tolist(),
chunk_size=padded_local_max_context_chunk_across_ranks,
)
else:
chunked_context_metadata = (
AscendMLAPrefillMetadata.ChunkedContextMetadata(
cu_seq_lens=cu_seq_lens_cpu.pin_memory().to(
device, non_blocking=True),
starts=chunk_starts.pin_memory().to(
device, non_blocking=True),
seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
max_seq_lens=chunk_seq_lens.max(
dim=1).values.tolist(),
chunk_seq_lens=chunk_seq_lens,
chunk_seq_lens_npu=chunk_seq_lens.npu(),
workspace=self.chunked_prefill_workspace,
))
))
prefill_input_positions = input_positions[tokens_start:]
cos = self.cos_cache[
prefill_input_positions].unsqueeze( # type: ignore
@@ -620,9 +528,6 @@ class AscendMLAMetadataBuilder:
cos=cos,
pcp_metadata=pcp_metadata,
)
if self.pcp_size > 1:
prefill_metadata.block_table = block_table[
num_decodes_flatten:, ...]

decode_metadata = None
if num_decodes > 0:
@@ -633,12 +538,7 @@ class AscendMLAMetadataBuilder:
max_seq_lens = seq_lens[:num_decodes].max().item()
seq_lens = seq_lens[:num_decodes]
input_positions = input_positions[:num_decode_tokens]
if self.pcp_size > 1:
# For pcp + spec decode, we flatten seq_lens and block_table
# to avoid irregular spec_attn_mask shape
block_table = block_table[:num_decodes_flatten, ...]
else:
block_table = block_table[:num_decodes, ...]
block_table = block_table[:num_decodes, ...]
# NOTE: Currently, MTP-fullgraph is incompatibility pcp
# NOTE: Maybe this block_table change can be removed when graph_pad_size > 1.
if graph_pad_size > num_decodes and \
@@ -646,31 +546,14 @@ class AscendMLAMetadataBuilder:
block_table = block_table[:graph_pad_size, ...]
seq_lens_list = seq_lens.tolist()

if num_computed_tokens_of_pcp_dcp is not None:
# [bs, pcp_size, dcp_size]
num_computed_tokens_of_cp_dcp_array = np.array(
num_computed_tokens_of_pcp_dcp)[:num_decodes *
self.decode_threshold]

cp_seq_len = num_computed_tokens_of_cp_dcp_array[:,
self.pcp_rank,
self.dcp_rank]
cp_seq_len = torch.tensor(cp_seq_len, dtype=torch.int32)
batch_seq_mask = (cp_seq_len == 0)
self.batch_seq_mask_buf[:batch_seq_mask.shape[0]].copy_(
batch_seq_mask, non_blocking=True)
batch_seq_mask = self.batch_seq_mask_buf[:batch_seq_mask.
shape[0]]
cp_seq_len = torch.where(cp_seq_len == 0, 1, cp_seq_len)
else:
cp_seq_len, batch_seq_mask = None, None
cp_seq_len, batch_seq_mask = None, None

if graph_pad_size > num_reqs:
if self.speculative_config.disable_padded_drafter_batch:
num_reqs_pad_size = graph_pad_size - num_reqs
actual_seq_lengths_q = self.pad_actual_seq_len_q_mtp_disable_pad(
num_reqs_pad_size, num_reqs, actual_seq_lengths_q)
seq_lens_list = seq_lens_list + [0] * (graph_pad_size - \
seq_lens_list = seq_lens_list + [0] * (graph_pad_size -
num_decodes)
num_block_pad_size = graph_pad_size - block_table.shape[0]
if num_block_pad_size > 0:
@@ -833,7 +716,7 @@ class AscendMLAImpl(MLAAttentionImpl):
attn_type: str,
kv_sharing_target_layer_name: Optional[str],
**kwargs,
) -> None:
):
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
@@ -870,7 +753,6 @@ class AscendMLAImpl(MLAAttentionImpl):
self.kv_a_layernorm = kwargs.get('kv_a_layernorm', None)
self.q_a_layernorm = kwargs.get('q_a_layernorm', None)
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
self.tp_size = get_tensor_model_parallel_world_size()

ascend_config = get_ascend_config()
self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
@@ -881,47 +763,19 @@ class AscendMLAImpl(MLAAttentionImpl):
self.speculative_config = self.vllm_config.speculative_config
self.enable_mlapo = envs.VLLM_ASCEND_ENABLE_MLAPO

self.pcp_size = get_pcp_group().world_size
self.pcp_rank = get_pcp_group(
).rank_in_group if self.pcp_size > 1 else 0
self.pcp_group = get_pcp_group(
).device_group if self.pcp_size > 1 else None

self.dcp_size = get_decode_context_model_parallel_world_size()
self.dcp_rank = get_decode_context_model_parallel_rank(
) if self.dcp_size > 1 else 0
self.dcp_group = get_dcp_group(
).device_group if self.dcp_size > 1 else None

self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
self.tp_group = get_tp_group(
).device_group if self.tp_size > 1 else None

def _v_up_proj(self, x):
if x.dtype in [torch.float16, torch.bfloat16] \
and hasattr(torch.ops._C_ascend, "batch_matmul_transpose") \
and not self.dcp_size * self.pcp_size > 1:
x = x.view(-1, self.num_heads, self.kv_lora_rank)
b, _, _ = x.shape
res = torch.empty((b, self.num_heads, self.v_head_dim),
dtype=x.dtype,
device=x.device)
torch.ops._C_ascend.batch_matmul_transpose(x, self.W_UV, res)
x = res.reshape(-1, self.num_heads * self.v_head_dim)
else:
# Convert from (B, N, L) to (N, B, L)
x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
# # Multiply (N, B, L) x (N, L, V) -> (N, B, V)
x = torch.bmm(x, self.W_UV)
# # Convert from (N, B, V) to (B, N * V)
x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
# Convert from (B, N, L) to (N, B, L)
x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
# # Multiply (N, B, L) x (N, L, V) -> (N, B, V)
x = torch.bmm(x, self.W_UV)
# # Convert from (N, B, V) to (B, N * V)
x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
return x

# Return `ql_nope`, `q_pe`
def _q_proj_and_k_up_proj(self, x):
q_nope, q_pe = self.q_proj(x)[0]\
.view(-1, self.num_heads, self.qk_head_dim)\
q_nope, q_pe = self.q_proj(x)[0] \
.view(-1, self.num_heads, self.qk_head_dim) \
.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)

# Convert from (B, N, P) to (N, B, P)
@@ -1137,10 +991,6 @@ class AscendMLAImpl(MLAAttentionImpl):
dtype=q_nope.dtype,
device=q_nope.device)

if self.dcp_size * self.pcp_size > 1:
context_seq_len_npu = prefill_metadata.chunked_context.padded_chunk_seq_lens_npu[
i]

torch_npu.atb.npu_paged_cache_load(
cache_kv_c,
cache_k_pe,
@@ -1151,38 +1001,10 @@ class AscendMLAImpl(MLAAttentionImpl):
value=k_pe,
)

cache_kv_c_k_pe = torch.cat([kv_c_normed, k_pe], dim=-1)
if self.dcp_size > 1:
cache_kv_c_k_pe = get_dcp_group().all_gather(
cache_kv_c_k_pe, 0)

if self.pcp_size > 1:
cache_kv_c_k_pe = get_pcp_group().all_gather(
cache_kv_c_k_pe, 0)

if self.dcp_size * self.pcp_size > 1:
allgatered_kv_c_normed, allgatered_k_pe = cache_kv_c_k_pe.split(
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
kv_c_normed, k_pe = self._reorg_kvcache(
allgatered_kv_c_normed,
allgatered_k_pe,
padded_local_chunk_seq_lens_lst=prefill_metadata.
chunked_context.padded_local_chunk_seq_lens[i],
local_context_lens_allranks=prefill_metadata.
chunked_context.local_context_lens_allranks,
sum_seq_len=prefill_metadata.chunked_context.
cu_seq_lens_lst[i][-1],
max_seq_len=prefill_metadata.chunked_context.
max_seq_lens[i],
chunk_size=prefill_metadata.chunked_context.chunk_size,
chunk_idx=i,
toks=toks,
)

kv_c_normed = kv_c_normed.squeeze()
kv_nope = self.kv_b_proj(kv_c_normed)[0].view( \
kv_nope = self.kv_b_proj(kv_c_normed)[0].view(
-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
k_nope, v = kv_nope\
k_nope, v = kv_nope \
.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
k_pe = k_pe.expand((*k_nope.shape[:-1], -1))

@@ -1248,8 +1070,9 @@ class AscendMLAImpl(MLAAttentionImpl):
calc_type="calc_type_first_ring",
output=attn_output,
softmax_lse=attn_lse)
attn_output, attn_lse = self._compute_prefill_context( \
q_nope, q_pe, kv_c_and_k_pe_cache, self.qk_rope_head_dim, attn_metadata, attn_output, attn_lse)
attn_output, attn_lse = self._compute_prefill_context(
q_nope, q_pe, kv_c_and_k_pe_cache, self.qk_rope_head_dim,
attn_metadata, attn_output, attn_lse)

attn_output = attn_output.reshape(
[num_tokens, self.num_heads * self.v_head_dim])
@@ -1488,13 +1311,6 @@ class AscendMLAImpl(MLAAttentionImpl):
self.kv_lora_rank)
decode_q_pe = decode_q_pe.view(bsz, self.num_heads, -1)

if self.dcp_size > 1:
decode_q_no_split = torch.cat([decode_q_nope, decode_q_pe], dim=-1)
decode_q_no_split = get_dcp_group().all_gather(
decode_q_no_split, 1)
decode_q_nope, decode_q_pe = decode_q_no_split.split(
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)

decode_preprocess_res = DecodeMLAPreprocessResult(
decode_q_nope, decode_q_pe, decode_k_nope, decode_k_pe)
return decode_preprocess_res, None
@@ -1551,17 +1367,8 @@ class AscendMLAImpl(MLAAttentionImpl):
sin = attn_metadata.decode.sin
decode_ql_nope, decode_q_pe = \
self._q_proj_and_k_up_proj(decode_q_c)
if self.dcp_size > 1:
decode_q_no_split = torch.cat([decode_ql_nope, decode_q_pe],
dim=-1)
decode_q_no_split = get_dcp_group().all_gather(
decode_q_no_split, 1)
decode_ql_nope, decode_q_pe = decode_q_no_split.split(
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
decode_q_pe = self.rope_single(decode_q_pe, cos, sin)
decode_slots = attn_metadata.slot_mapping[:num_decode_tokens *
self.pcp_size:self.
pcp_size]
decode_slots = attn_metadata.slot_mapping[:num_decode_tokens:1]
decode_kv_no_split = kv_no_split[:num_decode_tokens]
decode_k_pe, decode_k_nope = self.exec_kv_decode(
decode_kv_no_split, cos, sin, kv_cache, decode_slots)
@@ -1569,76 +1376,27 @@ class AscendMLAImpl(MLAAttentionImpl):
decode_ql_nope, decode_q_pe, decode_k_nope, decode_k_pe)
# Preprocess for prefill tokens
if has_prefill:
if self.pcp_size > 1:
num_actual_tokens = (attn_metadata.num_actual_tokens_pcp_padded
- self.pcp_size * num_decode_tokens
) // self.pcp_size + num_decode_tokens
prefill_kv_no_split = kv_no_split[
num_decode_tokens:num_actual_tokens]
prefill_q_c = q_c[num_decode_tokens:num_actual_tokens]
prefill_q = self.q_proj(prefill_q_c)[0]\
prefill_q = self.q_proj(prefill_q_c)[0] \
.view(-1, self.num_heads, self.qk_head_dim)
prefill_q_pe = prefill_q[..., self.qk_nope_head_dim:]
prefill_q_nope = prefill_q[..., :self.qk_nope_head_dim]
if self.pcp_size > 1:
cos = attn_metadata.prefill.cos[:num_actual_tokens -
num_decode_tokens]
sin = attn_metadata.prefill.sin[:num_actual_tokens -
num_decode_tokens]
else:
cos = attn_metadata.prefill.cos
sin = attn_metadata.prefill.sin
cos = attn_metadata.prefill.cos
sin = attn_metadata.prefill.sin
prefill_slots = attn_metadata.slot_mapping[
num_decode_tokens:num_actual_tokens]
prefill_q_pe = self.rope_single(prefill_q_pe, cos, sin)
if self.pcp_size > 1:
prefill_kv_no_split = kv_no_split[:num_actual_tokens]
kv_c, k_pe = prefill_kv_no_split.split(
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
assert len(
kv_cache
) > 1, "the number of kv cache should be greater than 1, namely (nope_cache and rope_cache)"
kv_c_normed = kv_c_normed.view(
[num_actual_tokens, self.num_kv_heads, -1])
k_pe = k_pe.unsqueeze(1)
prefill_k_pe = k_pe
prefill_k_pe[
num_decode_tokens:num_actual_tokens] = self.rope_single(
prefill_k_pe[num_decode_tokens:num_actual_tokens], cos,
sin)
prefill_k_c_normed = kv_c_normed[:num_actual_tokens]
prefill_kv_c_k_pe = torch.cat(
[prefill_k_c_normed, prefill_k_pe], dim=-1)
prefill_kv_c_k_pe = get_pcp_group().all_gather(
prefill_kv_c_k_pe, 0)
prefill_kv_c_k_pe = torch.index_select(
prefill_kv_c_k_pe, 0, attn_metadata.prefill.pcp_metadata.
pcp_allgather_restore_idx)
prefill_kv_c_k_pe = prefill_kv_c_k_pe[num_decode_tokens *
self.pcp_size:]
prefill_k_c_normed, prefill_k_pe = prefill_kv_c_k_pe.split(
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
kv_c_normed, k_pe = prefill_k_c_normed, prefill_k_pe
prefill_k_c_normed = prefill_k_c_normed.squeeze()
slot_mapping = attn_metadata.slot_mapping[self.pcp_size *
num_decode_tokens:]
torch_npu._npu_reshape_and_cache(key=kv_c_normed,
value=k_pe,
key_cache=kv_cache[0],
value_cache=kv_cache[1],
slot_indices=slot_mapping)
else:
prefill_k_pe, prefill_k_c_normed = self.exec_kv_prefill(
prefill_kv_no_split, cos, sin, kv_cache, prefill_slots)
prefill_k_pe, prefill_k_c_normed = self.exec_kv_prefill(
prefill_kv_no_split, cos, sin, kv_cache, prefill_slots)
prefill_k_nope, prefill_value = self.kv_b_proj(
prefill_k_c_normed)[0].view(
-1, self.num_heads,
self.qk_nope_head_dim + self.v_head_dim).split(
[self.qk_nope_head_dim, self.v_head_dim], dim=-1)
if not self.pcp_size > 1:
prefill_k_pe = prefill_k_pe.view(prefill_q_c.shape[0],
self.num_kv_heads, -1)
prefill_k_pe = prefill_k_pe.view(prefill_q_c.shape[0],
self.num_kv_heads, -1)
prefill_k_pe = prefill_k_pe.expand(
(*prefill_k_nope.shape[:-1], -1))
prefill_preprocess_res = PrefillMLAPreprocessResult(
@@ -1662,13 +1420,10 @@ class AscendMLAImpl(MLAAttentionImpl):
self.vllm_config, self.o_proj):
reach_layer_for_shared_weight_series(self.o_proj)
return output.fill_(0)
if self.pcp_size > 1:
num_actual_tokens = attn_metadata.num_actual_tokens_pcp_padded // self.pcp_size
else:
num_actual_tokens = attn_metadata.num_actual_tokens
num_actual_tokens = attn_metadata.num_actual_tokens
assert attn_metadata.num_decodes is not None and \
attn_metadata.num_prefills is not None and \
attn_metadata.num_decode_tokens is not None
attn_metadata.num_prefills is not None and \
attn_metadata.num_decode_tokens is not None
num_decode_tokens = attn_metadata.num_decode_tokens
# Inputs and outputs may be padded for CUDA graphs
output_padded = output
@@ -1693,20 +1448,12 @@ class AscendMLAImpl(MLAAttentionImpl):

if decode_preprocess_res is not None:
# MLA Preprocess for decoding
if self.pcp_size * self.dcp_size > 1:
output_decode = self._forward_decode_pcp_dcp(
decode_preprocess_res.ql_nope,
decode_preprocess_res.q_pe,
decode_preprocess_res.k_nope,
decode_preprocess_res.k_pe,
kv_cache[0].shape[1],
attn_metadata,
)
else:
output_decode = self._forward_decode(
decode_preprocess_res.ql_nope, decode_preprocess_res.q_pe,
decode_preprocess_res.k_nope, decode_preprocess_res.k_pe,
kv_cache[0].shape[1], attn_metadata)
output_decode = self._forward_decode(decode_preprocess_res.ql_nope,
decode_preprocess_res.q_pe,
decode_preprocess_res.k_nope,
decode_preprocess_res.k_pe,
kv_cache[0].shape[1],
attn_metadata)

o_proj_input[:num_decode_tokens] = output_decode

@@ -1714,16 +1461,10 @@ class AscendMLAImpl(MLAAttentionImpl):
# FIX: aicore move should be also placed on the comm stream in dbo,
# otherwise it may affect the accuracy
# TODO: use an elegant way to overlap
if self.pcp_size > 1:
output_prefill = self._forward_prefill_cp(
prefill_preprocess_res.q_nope, prefill_preprocess_res.q_pe,
prefill_preprocess_res.k_nope, prefill_preprocess_res.k_pe,
prefill_preprocess_res.value, kv_cache, attn_metadata)
else:
output_prefill = self._forward_prefill(
prefill_preprocess_res.q_nope, prefill_preprocess_res.q_pe,
prefill_preprocess_res.k_nope, prefill_preprocess_res.k_pe,
prefill_preprocess_res.value, kv_cache, attn_metadata)
output_prefill = self._forward_prefill(
prefill_preprocess_res.q_nope, prefill_preprocess_res.q_pe,
prefill_preprocess_res.k_nope, prefill_preprocess_res.k_pe,
prefill_preprocess_res.value, kv_cache, attn_metadata)

o_proj_input[num_decode_tokens:num_actual_tokens] = output_prefill
# O proj
@@ -1743,377 +1484,3 @@ class AscendMLAImpl(MLAAttentionImpl):
if has_prefill:
maybe_save_kv_layer_to_connector(layer_name, list(kv_cache))
return output_padded

def _forward_prefill_cp(
    self,
    q_nope: torch.Tensor,
    q_pe: torch.Tensor,
    k_nope: torch.Tensor,
    k_pe: torch.Tensor,
    value: torch.Tensor,
    kv_c_and_k_pe_cache: Tuple[torch.Tensor],
    attn_metadata: AscendMLAMetadata,
) -> torch.Tensor:
    """Prefill attention under prefill context parallelism (PCP).

    Each rank attends its "head" and "tail" query segments separately
    (each against a masked and an unmasked KV slice), restores the
    original token order, then folds in any chunked prefill context.

    Returns:
        Attention output of shape (num_tokens, num_heads * v_head_dim).
    """
    assert attn_metadata.prefill is not None
    assert attn_metadata.prefill.pcp_metadata is not None
    num_tokens = q_nope.size(0)
    # Use precomputed indices from the metadata (already converted to tensors and on device)
    q_head_idx = attn_metadata.prefill.pcp_metadata.q_head_idx
    q_tail_idx = attn_metadata.prefill.pcp_metadata.q_tail_idx
    kv_with_q_head_nomask_idx = attn_metadata.prefill.pcp_metadata.kv_with_q_head_nomask_idx
    kv_with_q_head_mask_idx = attn_metadata.prefill.pcp_metadata.kv_with_q_head_mask_idx
    kv_with_q_tail_nomask_idx = attn_metadata.prefill.pcp_metadata.kv_with_q_tail_nomask_idx
    kv_with_q_tail_mask_idx = attn_metadata.prefill.pcp_metadata.kv_with_q_tail_mask_idx
    attn_mask_seqlens = attn_metadata.prefill.pcp_metadata.attn_mask_seqlens
    head_attn_nomask_seqlens = attn_metadata.prefill.pcp_metadata.head_attn_nomask_seqlens
    tail_attn_nomask_seqlens = attn_metadata.prefill.pcp_metadata.tail_attn_nomask_seqlens
    mask = attn_metadata.prefill.pcp_metadata.pcp_prefill_mask
    # "Head" query segment: attend masked + no-mask KV slices, get partial
    # output and log-sum-exp for later merging.
    output_head, lse_head = self._attention_with_mask_and_nomask(
        q_nope=torch.index_select(q_nope, 0, q_head_idx),
        q_pe=torch.index_select(q_pe, 0, q_head_idx),
        k_nope=k_nope,
        k_pe=k_pe,
        value=value,
        kv_mask_idx=kv_with_q_head_mask_idx,
        kv_nomask_idx=kv_with_q_head_nomask_idx,
        attn_mask_seqlens=attn_mask_seqlens,
        attn_nomask_seqlens=head_attn_nomask_seqlens,
        mask=mask)

    # "Tail" query segment: same procedure with the tail index sets.
    output_tail, lse_tail = self._attention_with_mask_and_nomask(
        q_nope=torch.index_select(q_nope, 0, q_tail_idx),
        q_pe=torch.index_select(q_pe, 0, q_tail_idx),
        k_nope=k_nope,
        k_pe=k_pe,
        value=value,
        kv_mask_idx=kv_with_q_tail_mask_idx,
        kv_nomask_idx=kv_with_q_tail_nomask_idx,
        attn_mask_seqlens=attn_mask_seqlens,
        attn_nomask_seqlens=tail_attn_nomask_seqlens,
        mask=mask)

    # Restore the original token order of the concatenated head/tail parts.
    # NOTE: output is indexed on dim 0 but lse on dim 1 (lse is head-major).
    q_full_idx = attn_metadata.prefill.pcp_metadata.q_full_idx
    attn_output = torch.index_select(
        torch.cat([output_head, output_tail], dim=0), 0, q_full_idx)
    attn_lse = torch.index_select(torch.cat([lse_head, lse_tail], dim=1),
                                  1, q_full_idx)

    # Merge in chunked prefill context (prior KV from the cache), if any.
    output, _ = self._compute_prefill_context( \
        q_nope, q_pe, kv_c_and_k_pe_cache, self.qk_rope_head_dim, attn_metadata, attn_output, attn_lse)

    output = output.reshape([num_tokens, self.num_heads * self.v_head_dim])

    return output

def _attention_with_mask_and_nomask(
        self, q_nope: torch.Tensor, q_pe: torch.Tensor,
        k_nope: torch.Tensor, k_pe: torch.Tensor, value: torch.Tensor,
        kv_mask_idx: torch.Tensor, kv_nomask_idx: torch.Tensor,
        attn_mask_seqlens: torch.Tensor, attn_nomask_seqlens: torch.Tensor,
        mask: torch.Tensor):
    """Ring-MLA attention over a masked KV slice, then merge the unmasked slice.

    First pass runs npu_ring_mla with a triangular (causal) mask on the KV
    rows selected by `kv_mask_idx` (calc_type_first_ring initializes the
    output/LSE). Second pass, skipped when `kv_nomask_idx` is empty, feeds
    the previous output/LSE back in (pre_out/prev_lse) and accumulates the
    fully-visible KV rows selected by `kv_nomask_idx` with no mask.

    Returns:
        (attn_output, attn_lse): output is [num_q, num_heads, v_head_dim];
        lse is [num_heads, num_q] in float32.
    """
    attn_output = torch.empty(q_nope.shape[0],
                              self.num_heads,
                              self.v_head_dim,
                              dtype=k_pe.dtype,
                              device=k_pe.device)
    # LSE is laid out head-major: [num_heads, num_q].
    attn_lse = torch.empty(self.num_heads,
                           q_pe.shape[0],
                           dtype=torch.float32,
                           device=k_pe.device)
    # mask
    k_nope_mask = torch.index_select(k_nope, 0, kv_mask_idx)
    value_mask = torch.index_select(value, 0, kv_mask_idx)
    k_pe_mask = torch.index_select(k_pe, 0, kv_mask_idx)
    torch_npu.atb.npu_ring_mla(q_nope=q_nope,
                               q_rope=q_pe,
                               k_nope=k_nope_mask,
                               k_rope=k_pe_mask,
                               value=value_mask,
                               mask=mask,
                               seqlen=attn_mask_seqlens,
                               head_num=self.num_heads,
                               kv_head_num=self.num_heads,
                               pre_out=None,
                               prev_lse=None,
                               qk_scale=self.scale,
                               kernel_type="kernel_type_high_precision",
                               mask_type="mask_type_triu",
                               input_layout="type_bsnd",
                               calc_type="calc_type_first_ring",
                               output=attn_output,
                               softmax_lse=attn_lse)

    # nomask
    if kv_nomask_idx.shape[0] == 0:
        # No fully-visible KV rows for this segment; masked pass is final.
        return attn_output, attn_lse

    k_nope_nomask = torch.index_select(k_nope, 0, kv_nomask_idx)
    value_nomask = torch.index_select(value, 0, kv_nomask_idx)
    k_pe_nomask = torch.index_select(k_pe, 0, kv_nomask_idx)
    torch_npu.atb.npu_ring_mla(q_nope=q_nope,
                               q_rope=q_pe,
                               k_nope=k_nope_nomask,
                               k_rope=k_pe_nomask,
                               value=value_nomask,
                               mask=mask,
                               seqlen=attn_nomask_seqlens,
                               head_num=self.num_heads,
                               kv_head_num=self.num_heads,
                               pre_out=attn_output,
                               prev_lse=attn_lse,
                               qk_scale=self.scale,
                               kernel_type="kernel_type_high_precision",
                               mask_type="no_mask",
                               input_layout="type_bsnd",
                               calc_type="calc_type_default",
                               output=attn_output,
                               softmax_lse=attn_lse)
    return attn_output, attn_lse

def _forward_decode_pcp_dcp(
    self,
    q_nope: torch.Tensor,
    q_pe: torch.Tensor,
    k_nope: torch.Tensor,
    k_pe: torch.Tensor,
    block_size: int,
    attn_metadata: AscendMLAMetadata,
) -> torch.Tensor:
    """Decode-phase MLA attention when PCP and/or DCP is enabled.

    Runs the local latent-attention kernel (recording it into the NPU
    graph when capturing), then merges the per-rank partial outputs and
    log-sum-exps across the DCP/PCP groups and applies the V up-projection.
    """
    decode_meta = attn_metadata.decode
    assert decode_meta is not None
    num_tokens = q_nope.size(0)
    # shape of knope/k_pe for npu graph mode should be:
    # [num_blocks, num_kv_heads, block_size, self.kv_lora_rank/self.qk_rope_head_dim]
    if self.dcp_size > 1:
        # Under DCP the local query carries the heads of all DCP ranks
        # (presumably gathered upstream along the head dim — see the
        # dcp all_gather in the preprocess path).
        num_heads = self.num_heads * self.dcp_size
    else:
        num_heads = self.num_heads

    k_nope = k_nope.view(-1, block_size, self.num_kv_heads,
                         self.kv_lora_rank)
    k_pe = k_pe.view(-1, block_size, self.num_kv_heads,
                     self.qk_rope_head_dim)
    q_nope = q_nope.view(num_tokens, num_heads, -1)
    q_pe = q_pe.view(num_tokens, num_heads, -1)
    # use pcp & dcp split computed token nums from scheduler to compute actual seq_len and seq_mask
    seq_len = decode_meta.cp_seq_len

    common_kwargs = {
        "return_lse": True,
        "calc_type": "calc_type_ring",
    }
    graph_params = get_graph_params()
    forward_context: ForwardContext = get_forward_context()
    if forward_context.capturing:
        # Graph-capture path: gate the launch on an external event so the
        # recorded graph can be replayed with updated parameters.
        stream = torch_npu.npu.current_stream()
        event = torch.npu.ExternalEvent()
        event.wait(stream)
        event.reset(stream)
        graph_params.events[num_tokens].append(event)
        # Workspace is cached per token count and reused across captures.
        workspace = graph_params.workspaces.get(num_tokens)
        if workspace is None:
            workspace = torch_npu.atb._npu_multi_head_latent_attention_get_workspace(
                q_nope, q_pe, k_nope, k_pe, decode_meta.block_table,
                seq_len, num_heads, self.scale, self.num_kv_heads,
                **common_kwargs)
            update_graph_params_workspaces(num_tokens, workspace)
        attn_output = torch.empty_like(q_nope)
        softmax_lse = torch.empty((num_tokens, num_heads, 1),
                                  dtype=q_nope.dtype,
                                  device=q_nope.device)
        # Store weak refs so graph replay can patch tensors without
        # extending their lifetime.
        graph_params.attn_params[num_tokens].append(
            (weak_ref_tensors(q_nope), weak_ref_tensors(q_pe),
             weak_ref_tensors(k_nope), weak_ref_tensors(k_pe),
             decode_meta.block_table, seq_len, num_heads, self.scale,
             self.num_kv_heads, weak_ref_tensors(attn_output),
             weak_ref_tensors(softmax_lse)))
        torch.npu.graph_task_group_begin(stream)
        torch_npu.atb.npu_multi_head_latent_attention(
            q_nope,
            q_pe,
            k_nope,
            k_pe,
            decode_meta.block_table,
            seq_len,
            num_heads,
            self.scale,
            self.num_kv_heads,
            **common_kwargs,
            workspace=workspace,
            output=attn_output,
            lse=softmax_lse)
        handle = torch.npu.graph_task_group_end(stream)
        graph_params.handles[num_tokens].append(handle)
    else:
        # Eager path: launch the kernel directly.
        attn_output = torch.empty_like(q_nope)
        softmax_lse = torch.empty((num_tokens, num_heads, 1),
                                  dtype=q_nope.dtype,
                                  device=q_nope.device)
        torch_npu.atb.npu_multi_head_latent_attention(
            q_nope,
            q_pe,
            k_nope,
            k_pe,
            decode_meta.block_table,
            seq_len,
            num_heads,
            self.scale,
            self.num_kv_heads,
            return_lse=True,
            calc_type="calc_type_ring",
            output=attn_output,
            lse=softmax_lse)

    # Update out&lse
    attn_out_lse_list = self._process_attn_out_lse(attn_output,
                                                   softmax_lse,
                                                   decode_meta)
    attn_output = self._npu_attention_update(attn_out_lse_list)
    return self._v_up_proj(attn_output)

def _npu_attention_update(
        self, attn_out_lse_list: List[torch.Tensor]) -> torch.Tensor:
    """Merge per-rank partial attention results via LSE-weighted update.

    Each entry of `attn_out_lse_list` fuses a partial output with its
    log-sum-exp in the last dim; they are split apart, flattened, and
    handed to npu_attention_update, whose merged output is reshaped back
    to (tokens, heads, kv_lora_rank).
    """
    flat_outs = []
    flat_lses = []
    for fused in attn_out_lse_list:
        out_part, lse_part = torch.split(fused,
                                         [self.kv_lora_rank, 1],
                                         dim=-1)
        out_flat, lse_flat = self._out_lse_reshape(out_part, lse_part)
        flat_outs.append(out_flat)
        flat_lses.append(lse_flat)
    merged, _ = torch_npu.npu_attention_update(flat_lses, flat_outs, 0)
    return merged.view(-1, attn_out_lse_list[0].shape[1],
                       self.kv_lora_rank)

def _out_lse_reshape(self, attn_out: torch.Tensor,
attn_lse: torch.Tensor) -> torch.Tensor:
attn_out = attn_out.contiguous().view(
attn_out.shape[0] * attn_out.shape[1], attn_out.shape[2])
attn_lse = attn_lse.contiguous().view(
attn_lse.shape[0] * attn_lse.shape[1] * attn_lse.shape[2])
return attn_out, attn_lse

    def _process_attn_out_lse(
        self,
        attn_output: torch.Tensor,
        softmax_lse: torch.Tensor,
        decode_meta: AscendMLADecodeMetadata,
    ) -> List[torch.Tensor]:
        """Mask, fuse, and exchange partial attention out/LSE across
        context-parallel (PCP) and decode-context-parallel (DCP) ranks.

        The output and LSE are concatenated along the last dim into one
        fused tensor, redistributed with all_to_all (DCP) and/or
        all_gather (PCP), then split back into a list of per-rank fused
        tensors for ``_npu_attention_update``.
        """
        attn_out_lse_list = []
        # Zero out entries flagged by batch_seq_mask so they contribute
        # nothing to the merged output...
        out_mask = decode_meta.batch_seq_mask[:, None,
                                              None].expand_as(attn_output)
        attn_output = torch.where(out_mask, 0, attn_output)
        lse_mask = decode_meta.batch_seq_mask[:, None,
                                              None].expand_as(softmax_lse)
        # ...and set their LSE to -inf, the neutral element of a
        # log-sum-exp combine.
        softmax_lse = torch.where(lse_mask, -torch.inf, softmax_lse)

        # float32 keeps the subsequent LSE-weighted merge numerically stable.
        softmax_lse = softmax_lse.to(torch.float32)
        attn_output = attn_output.to(torch.float32)
        # Concat out&lse: [bs,num_heads,v_head_dim] + [bs,num_heads,1] -> [bs,num_heads,v_head_dim+1]
        attn_out_lse = torch.cat([attn_output, softmax_lse], dim=-1)
        if self.dcp_size > 1:
            # permute: [bs, num_heads, v_head_dim+1] -> [num_heads, v_head_dim+1, bs]
            attn_out_lse = attn_out_lse.permute([1, 2, 0]).contiguous()
            attn_out_lse_all2all = torch.empty_like(attn_out_lse)
            # Exchange head-sharded partials across the DCP group.
            dist.all_to_all_single(attn_out_lse_all2all,
                                   attn_out_lse,
                                   group=self.dcp_group)
            # permute: [num_heads, v_head_dim+1, bs] -> [bs, num_heads, v_head_dim+1]
            attn_out_lse_all2all = attn_out_lse_all2all.permute([2, 0, 1])
            if self.pcp_size > 1:
                # NOTE(review): the contiguous copy feeds the PCP
                # all_gather below, while the DCP chunking on the next
                # line still uses the non-contiguous permuted view —
                # confirm this asymmetry is intended.
                attn_out_lse = attn_out_lse_all2all.contiguous()
            attn_out_lse_list = list(
                torch.chunk(attn_out_lse_all2all, self.dcp_size, dim=1))

        if self.pcp_size > 1:
            # AllGather out&lse within PCP group
            attn_out_lse_list = [
                torch.empty_like(attn_out_lse) for _ in range(self.pcp_size)
            ]
            dist.all_gather(attn_out_lse_list,
                            attn_out_lse,
                            group=self.pcp_group)
            if self.dcp_size > 1 and self.pcp_size > 1:
                # With both PCP and DCP active, further split each
                # gathered tensor along the head dim so the final list
                # has one fused tensor per (pcp, dcp) rank pair.
                attn_out_lse_list_pcp_dcp = []
                for s in attn_out_lse_list:
                    attn_out_lse_list_split = list(
                        torch.chunk(s, self.dcp_size, dim=1))
                    attn_out_lse_list_pcp_dcp += attn_out_lse_list_split
                attn_out_lse_list = attn_out_lse_list_pcp_dcp

        return attn_out_lse_list

def _reorg_kvcache(
self,
allgatered_kv_c_normed: torch.Tensor,
allgatered_k_pe: torch.Tensor,
padded_local_chunk_seq_lens_lst: list[int],
local_context_lens_allranks: list[list[int]],
sum_seq_len: int,
max_seq_len: int,
chunk_size: int,
chunk_idx: int,
toks: int,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
reorg and unpad kvcache after cp local gather to tp layout for attn kernel.
e.g.
kv_c_normed in rank0 = [T0_0, T0_1, T0_2, T0_3, T1_0, T1_1, ...]
kv_c_normed in rank1 = [T0_4, T0_5, pad, pad, T1_2, pad, ...]
allgatered_kv_c_normed = [T0_0, T0_1, T0_2, T0_3, T1_0, T1_1, ...,
T0_4, T0_5, pad, pad, T1_2, pad, ...]
-> reorganized_kv_c_normed = [T0_0, T0_1, T0_2, T0_3, T0_4, T0_5,
T1_0, T1_1, T1_2, ...]
Args:
padded_local_chunk_seq_lens_lst: local chunk context lengths
under current CP rank.
local_context_lens_allranks: local context lengths on each CP rank.
sum_seq_len: the sum of cp_chunk_seq_lens_lst.
max_seq_len: the max value of cp_chunk_seq_lens_lst.
chunk_size: the local padded max context chunk from
chunked_context_metadata building.
chunk_idx: chunk idx of chunked_prefill.
toks: the number of tokens for local gather cache.
"""
kv_c_segments = []
k_pe_segments = []
src_token_idx = 0
max_seq_len_check = 0
for padded_local_chunk_seq_len, local_context_lens in zip(
padded_local_chunk_seq_lens_lst, local_context_lens_allranks):
cur_seq_len = 0
for rank, local_context_len in enumerate(local_context_lens):
# Note(qcs): We split the context into multiple chunks,
# depending on the size of the workspace.
# local_context in dcp0: |-----------------|
# local_context in dcp1: |--------------|
# n*padded_local_chunk: |-----|-----|-----|
# local_chunk_len in dcp1: |-----|-----|--|
# so we need update the last chunk length in dcp1.
local_chunk_len = min(
max(0, local_context_len - chunk_idx * chunk_size),
padded_local_chunk_seq_len,
)
if local_chunk_len != 0:
kv_c_segment = allgatered_kv_c_normed[rank * toks +
src_token_idx:rank *
toks +
src_token_idx +
local_chunk_len]
k_pe_segment = allgatered_k_pe[rank * toks +
src_token_idx:rank * toks +
src_token_idx +
local_chunk_len]
kv_c_segments.append(kv_c_segment)
k_pe_segments.append(k_pe_segment)
cur_seq_len += local_chunk_len
max_seq_len_check = max(max_seq_len_check, cur_seq_len)
src_token_idx += padded_local_chunk_seq_len
reorganized_kv_c_normed = torch.cat(kv_c_segments, dim=0)
reorganized_k_pe = torch.cat(k_pe_segments, dim=0)
assert reorganized_kv_c_normed.shape[0] == sum_seq_len
assert reorganized_k_pe.shape[0] == sum_seq_len
assert max_seq_len_check == max_seq_len
return reorganized_kv_c_normed, reorganized_k_pe

+ 2
- 2
vllm_ascend/attention/sfa_v1.py View File

@@ -32,7 +32,7 @@ from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
_round_up, dispose_layer, enable_sp,
is_enable_nz, replace_layer)
from vllm_ascend.worker.npu_input_batch import InputBatch
from vllm_ascend.worker.npu_input_batch import NPUInputBatch

if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
@@ -149,7 +149,7 @@ class AscendSFAMetadataBuilder:
self.enable_sfa_cp = enable_sp() and \
hasattr(self.model_config.hf_config, "index_topk")

def reorder_batch(self, input_batch: "InputBatch",
def reorder_batch(self, input_batch: "NPUInputBatch",
scheduler_output: "SchedulerOutput") -> bool:
# No need to reorder for Ascend SFA
return False


+ 16
- 0
vllm_ascend/attention/utils.py View File

@@ -1,13 +1,29 @@
from dataclasses import dataclass
from functools import lru_cache
from typing import Any, List, Optional

import torch
import torch.nn.functional as F
from vllm.config import get_current_vllm_config
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
has_kv_transfer_group,
is_v1_kv_transfer_group)
from vllm.forward_context import ForwardContext, get_forward_context

from vllm_ascend.utils import get_ascend_config


@lru_cache
def using_paged_attention(runtime_shape: int) -> bool:
vllm_config = get_current_vllm_config()
if vllm_config.speculative_config is not None:
return False
from vllm.config.compilation import CUDAGraphMode
if vllm_config.compilation_config.cudagraph_mode != CUDAGraphMode.FULL_DECODE_ONLY:
return False

return runtime_shape in get_ascend_config().pa_shape_list


@dataclass
# class AscendCommonLongSequenceMetadata:


+ 69
- 2
vllm_ascend/compilation/acl_graph.py View File

@@ -19,6 +19,8 @@ from vllm.forward_context import BatchDescriptor, get_forward_context
from vllm.logger import logger
from vllm.platforms import current_platform

from vllm_ascend.attention.utils import using_paged_attention

from ..utils import weak_ref_tensors


@@ -193,7 +195,65 @@ class ACLGraphWrapper:
return entry.output


def update_attn_params(update_stream, forward_context, runtime_shape):
def _update_attn_pa_params(update_stream, forward_context, runtime_shape):
graph_params = get_graph_params()
# FIXME: Behold! We are using a temporary hack here to update the args
# for each layer's attention op in the graph.
with torch.npu.stream(update_stream):
for key, param, handle, event in zip(
forward_context.attn_metadata,
graph_params.attn_params[runtime_shape],
graph_params.handles[runtime_shape],
graph_params.events[runtime_shape],
):
(
query,
key_cache,
value_cache,
num_kv_heads,
num_heads,
scale,
block_table,
seq_lens,
output,
) = param
seq_lens = forward_context.attn_metadata[key].seq_lens

# When using FULL_DECODE_ONLY, there are some rare bugs for FULL_DECODE_ONLY
# mode with GQA. This is triggered by getting workspace for _npu_paged_attention
# in torch_npu. On some rare cases, _npu_paged_attention with smaller seq_lens
# might encounter a bigger workspace, while currently we use max_model_len to
# calculate max workspace in capturing. So additional get_workspace is added
# here to avoid such bugs.
# TODO(Angazenn): we will remove this once _npu_paged_attention is fully
# replaced by npu_fused_infer_attention_score which does not contain such bugs.
workspace = torch_npu._npu_paged_attention_get_workspace(
query=query,
key_cache=key_cache,
value_cache=value_cache,
num_kv_heads=num_kv_heads,
num_heads=num_heads,
scale_value=scale,
block_table=block_table,
context_lens=seq_lens,
out=output)
torch.npu.graph_task_update_begin(update_stream, handle)
torch_npu._npu_paged_attention(query=query,
key_cache=key_cache,
value_cache=value_cache,
num_kv_heads=num_kv_heads,
num_heads=num_heads,
scale_value=scale,
block_table=block_table,
context_lens=seq_lens,
out=output,
workspace=workspace)
torch.npu.graph_task_update_end(update_stream)

event.record(update_stream)


def _update_attn_fia_params(update_stream, forward_context, runtime_shape):
graph_params = get_graph_params()
# For Qwen3-next, since the kv_cache_config has already categorized
# linear_attn and self_attn, the attn_metadata is first arranged with
@@ -236,6 +296,13 @@ def update_attn_params(update_stream, forward_context, runtime_shape):
event.record(update_stream)


def update_attn_params(update_stream, forward_context, runtime_shape):
if using_paged_attention(runtime_shape):
_update_attn_pa_params(update_stream, forward_context, runtime_shape)
else:
_update_attn_fia_params(update_stream, forward_context, runtime_shape)


def update_mla_attn_params(update_stream, forward_context, runtime_shape,
speculative_config):
if forward_context.is_mtp_model:
@@ -446,7 +513,7 @@ def set_graph_params(aclgraph_capture_sizes: list[int]):
)


def update_graph_params_workspaces(num_tokens: int, workspace: int):
def update_graph_params_workspaces(num_tokens: int, workspace: torch.Tensor):
global _graph_params
if _graph_params is not None:
_graph_params.workspaces[num_tokens] = weak_ref_tensors(workspace)


+ 60
- 0
vllm_ascend/compilation/passes/norm_quant_fusion_pass.py View File

@@ -79,6 +79,64 @@ class AddRMSNormQuantPattern:
pm.fwd_only, pm_pass)


class AddRMSNormQuantPatternWithBias:

def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6):
self.vllm_config = vllm_config
self.eps = eps

def get_inputs(self):
"""
Generate example inputs for the AddRMSNormQuant fusion pattern.
"""
rms_norm_input = torch.randn(2, 4, device="npu")
residual = torch.randn(2, 4, device="npu")
rms_norm_weight = torch.randn(4, device="npu")
scale = torch.tensor([1.0], device="npu")
offset = torch.tensor([0.0], device="npu")
bias = torch.randn(4, device="npu")
return [rms_norm_input, residual, rms_norm_weight, scale, offset, bias]

def register(self, pm_pass: PatternMatcherPass):

def pattern(rms_norm_input: torch.Tensor, residual: torch.Tensor,
rms_norm_weight: torch.Tensor, scale: torch.Tensor,
offset: torch.Tensor, bias: torch.Tensor):
"""
Pattern for AddRMSNormQuant fusion.
"""
output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual,
rms_norm_weight, self.eps)
out0 = output[0]
out1 = output[2]
out0 = out0 + bias
quantized_output = torch.ops.npu.npu_quantize(
out0, scale, offset, torch.qint8, -1, False)
return quantized_output, out1

def replacement(rms_norm_input: torch.Tensor, residual: torch.Tensor,
rms_norm_weight: torch.Tensor, scale: torch.Tensor,
offset: torch.Tensor, bias: torch.Tensor):
"""
Replacement for the AddRMSNormQuant fusion.
"""
output = torch.ops.npu.npu_add_rms_norm_quant(
rms_norm_input,
residual,
rms_norm_weight,
1. /
scale, # The inverse of scale is required by npu_add_rms_norm_quant kernel which is opposite to the npu_quantize kernel.
offset,
epsilon=self.eps,
beta=bias)
quantized_output = output[0]
out1 = output[2]
return quantized_output, out1

pm.register_replacement(pattern, replacement, self.get_inputs(),
pm.fwd_only, pm_pass)


class AddRMSNormQuantFusionPass(VllmInductorPass):
"""
A pass for fusing AddRMSNorm and W8A8 quantization operations on Ascend.
@@ -99,6 +157,8 @@ class AddRMSNormQuantFusionPass(VllmInductorPass):
for eps in common_epsilons:
AddRMSNormQuantPattern(vllm_config,
eps=eps).register(self.pattern_match_passes)
AddRMSNormQuantPatternWithBias(vllm_config, eps=eps).register(
self.pattern_match_passes)

def __call__(self, graph: torch.fx.Graph):
self.begin()


+ 2
- 5
vllm_ascend/distributed/kvpool/backend/memcache_backend.py View File

@@ -19,7 +19,7 @@ class MemcacheBackend(Backend):

def __init__(self, parallel_config: ParallelConfig):
try:
from memcache import DistributedObjectStore # type: ignore
from memcache_hybrid import DistributedObjectStore # type: ignore
except ImportError as e:
raise ImportError(
"Please install memcache by following the instructions at "
@@ -43,10 +43,7 @@ class MemcacheBackend(Backend):
torch.npu.set_device(device)

def register_buffer(self, ptrs: list[int], sizes: list[int]):
for ptr, size in zip(ptrs, sizes):
ret_value = self.store.register_buffer(ptr, size)
if ret_value != 0:
raise RuntimeError("Memcache memory registration failed.")
pass

def exists(self, keys: list[str]) -> list[int]:
return self.store.batch_is_exist(keys)


+ 1
- 1
vllm_ascend/distributed/kvpool/config_data.py View File

@@ -374,4 +374,4 @@ class LasyerMultiBlockReqMeta:
ends: list[int]
block_ids: list[int]
layer_id: int
is_last_chunk: bool = True
is_last_chunk: Optional[bool] = True

+ 148
- 78
vllm_ascend/distributed/kvpool/kv_transfer.py View File

@@ -1,29 +1,31 @@
import queue
import threading
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Optional
from typing import Any

import torch
from vllm.logger import logger
from vllm.v1.core.kv_cache_utils import BlockHash

from vllm_ascend.distributed.kvpool.backend.backend import Backend

# isort: off
from vllm_ascend.distributed.kvpool.config_data import (ChunkedTokenDatabase,
LasyerMultiBlockReqMeta
)
from vllm_ascend.distributed.kvpool.config_data import (
ChunkedTokenDatabase,
LasyerMultiBlockReqMeta,
ReqMeta,
)
# isort: on


class KVTransferThread(threading.Thread):

def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase,
tp_rank: int, dcp_size: int, ready_event: threading.Event,
name: str):
block_size: int, tp_rank: int, dcp_size: int,
ready_event: threading.Event, name: str):
super().__init__(daemon=True, name=name)
self.m_store = m_store
self.ready_event = ready_event
self.block_size = block_size
self.tp_rank = tp_rank
self.dcp_size = dcp_size
self.token_database = token_database
@@ -35,22 +37,9 @@ class KVTransferThread(threading.Thread):

def add_request(
self,
req_id: str,
token_len: int,
block_ids: list[int],
block_hashes: list[BlockHash],
mask_num: int = 0,
is_last_chunk: Optional[bool] = None,
request: ReqMeta,
) -> torch.Tensor:
req = ({
"req_id": req_id,
"token_len": token_len,
"block_ids": block_ids,
"block_hashes": block_hashes,
"mask_num": mask_num,
"is_last_chunk": is_last_chunk,
})
self.request_queue.put(req)
self.request_queue.put(request)

def get_and_clear_finished_requests(self) -> set[str]:
"""
@@ -82,50 +71,98 @@ class KVTransferThread(threading.Thread):
except Exception as e:
logger.error(f"Error in KVCacheTransferThread: {e}")

def _handle_request(self, req_meta: dict[str, Any]):
def _handle_request(self, req_meta: Any):
pass

def lookup(
self,
keys: list[str],
) -> int:
"""
Checks the existence of KV cache of the tokens from the cache engine.
:param tokens: the input tokens, with shape [seq_len]
:return: An int indicating how many prefix tokens are cached.
"""
try:
res = self.m_store.exists(keys) # type: ignore[assignment]
for index, value in enumerate(res): # type: ignore[arg-type]
if value != 1:
return index
# all tokens where found, return the maximal end
except Exception as e:
logger.error(f"Remote connection failed in contains: {e}")
return 0
return len(keys)


class KVCacheStoreSendingThread(KVTransferThread):

def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase,
tp_rank: int, dcp_size: int, put_step: int,
block_size: int, tp_rank: int, dcp_size: int, put_step: int,
ready_event: threading.Event):
super().__init__(m_store,
token_database,
block_size,
tp_rank,
dcp_size,
ready_event,
name="KVCacheSendingThread")
self.put_step = put_step

def _handle_request(self, req_meta: dict[str, Any]):
token_len = req_meta["token_len"]
mask_num = req_meta["mask_num"]
block_ids = req_meta["block_ids"]
block_hashes = req_meta["block_hashes"]
req_id = req_meta["req_id"]
is_last_chunk = req_meta["is_last_chunk"]
addr_list = []
size_list = []
key_list = []
def _handle_request(self, req_meta: ReqMeta):
token_len = req_meta.token_len_chunk
block_ids = req_meta.block_ids
req_id = req_meta.req_id
is_last_chunk = req_meta.is_last_chunk
starts = []
ends = []
keys = []
for start, end, key in self.token_database.process_tokens(
token_len, block_hashes, mask_num):
token_len, req_meta.block_hashes):
starts.append(start)
ends.append(end)
keys.append(key.to_string())

if not self.dcp_size > 1:
starts = starts[self.tp_rank % self.put_step::self.put_step]
ends = ends[self.tp_rank % self.put_step::self.put_step]
keys = keys[self.tp_rank % self.put_step::self.put_step]

if not keys:
if is_last_chunk:
self.set_finished_request(req_id)
return

skip_block_num = self.lookup(keys)

if skip_block_num == len(keys):
if is_last_chunk:
self.set_finished_request(req_id)
return

starts = starts[skip_block_num:]
ends = ends[skip_block_num:]
keys = keys[skip_block_num:]

logger.info(
"Storing KV cache for %d out of %d blocks "
"(skip_block_num=%d) for request %s",
len(keys),
token_len // self.block_size,
skip_block_num,
req_id,
)

addrs = []
sizes = []
for index, start in enumerate(starts):
addr, size, _ = self.token_database.prepare_value(
start, end, block_ids)
key_list.append(key.to_string())
addr_list.append(addr)
size_list.append(size)
if self.dcp_size > 1:
self.m_store.put(key_list, addr_list, size_list)
else:
key_list_tp = key_list[self.tp_rank % self.put_step::self.put_step]
addr_list_tp = addr_list[self.tp_rank %
self.put_step::self.put_step]
size_list_tp = size_list[self.tp_rank %
self.put_step::self.put_step]
if key_list_tp:
self.m_store.put(key_list_tp, addr_list_tp, size_list_tp)
start, ends[index], block_ids)
addrs.append(addr)
sizes.append(size)
if keys:
self.m_store.put(keys, addrs, sizes)

if is_last_chunk:
self.set_finished_request(req_id)
self.request_queue.task_done()
@@ -134,27 +171,28 @@ class KVCacheStoreSendingThread(KVTransferThread):
class KVCacheStoreRecvingThread(KVTransferThread):

def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase,
tp_rank: int, dcp_size: int, ready_event: threading.Event):
block_size: int, tp_rank: int, dcp_size: int,
ready_event: threading.Event):
super().__init__(m_store,
token_database,
block_size,
tp_rank,
dcp_size,
ready_event,
name="KVCacheStoreRecvingThread")

def _handle_request(self, req_meta: dict[str, Any]):
token_len = req_meta["token_len"]
mask_num = req_meta["mask_num"]
block_ids = req_meta["block_ids"]
req_id = req_meta["req_id"]
block_hashes = req_meta["block_hashes"]
def _handle_request(self, req_meta: ReqMeta):
req_id = req_meta.req_id
mask_num = (
req_meta.load_spec.vllm_cached_tokens # type: ignore[union-attr]
// self.block_size * self.block_size)
addr_list = []
size_list = []
key_list = []
for start, end, key in self.token_database.process_tokens(
token_len, block_hashes, mask_num):
req_meta.token_len_chunk, req_meta.block_hashes, mask_num):
addr, size, _ = self.token_database.prepare_value(
start, end, block_ids)
start, end, req_meta.block_ids)
key_list.append(key.to_string())
addr_list.append(addr)
size_list.append(size)
@@ -175,10 +213,11 @@ class KVCacheStoreRecvingThread(KVTransferThread):
class KVCacheStoreLayerSendingThread(KVTransferThread):

def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase,
tp_rank: int, dcp_size: int, put_step: int,
block_size: int, tp_rank: int, dcp_size: int, put_step: int,
ready_event: threading.Event, num_layers: int):
super().__init__(m_store,
token_database,
block_size,
tp_rank,
dcp_size,
ready_event,
@@ -187,43 +226,74 @@ class KVCacheStoreLayerSendingThread(KVTransferThread):
self.put_step = put_step

def add_request( # type: ignore[override]
self, req_meta: LasyerMultiBlockReqMeta) -> torch.Tensor:
self, req_meta: ReqMeta) -> torch.Tensor:
self.request_queue.put(req_meta)

def _handle_request( # type: ignore[override]
self, req_meta: LasyerMultiBlockReqMeta):
starts = req_meta.starts
ends = req_meta.ends
keys = req_meta.keys
layer_id = req_meta.layer_id
total_block = len(keys)
is_last_chunk = req_meta.is_last_chunk
if not self.dcp_size > 1:
starts = starts[self.tp_rank % self.put_step::self.put_step]
ends = ends[self.tp_rank % self.put_step::self.put_step]
keys = keys[self.tp_rank % self.put_step::self.put_step]

if not keys:
if is_last_chunk:
self.set_finished_request(req_meta.req_id)
return

key_list = []
for key in keys:
key_list.append(key.to_string())

skip_block_num = self.lookup(key_list)

if skip_block_num == len(key_list):
if is_last_chunk and layer_id == self.final_layer_id:
self.set_finished_request(req_meta.req_id)
return

starts = starts[skip_block_num:]
ends = ends[skip_block_num:]
key_list = key_list[skip_block_num:]

addr_list = []
size_list = []
key_list = []
for index, key in enumerate(req_meta.keys):
for index, key in enumerate(key_list):
addr, size = self.token_database.prepare_value_layer(
req_meta.starts[index], req_meta.ends[index],
req_meta.block_ids, req_meta.layer_id)
key_list.append(key.to_string())
starts[index], ends[index], req_meta.block_ids, layer_id)
addr_list.append(addr)
size_list.append(size)
if self.dcp_size > 1:
self.m_store.put(key_list, addr_list, size_list)
else:
key_list_tp = key_list[self.tp_rank % self.put_step::self.put_step]
addr_list_tp = addr_list[self.tp_rank %
self.put_step::self.put_step]
size_list_tp = size_list[self.tp_rank %
self.put_step::self.put_step]
if key_list_tp:
self.m_store.put(key_list_tp, addr_list_tp, size_list_tp)
if req_meta.layer_id == self.final_layer_id and req_meta.is_last_chunk:

self.m_store.put(key_list, addr_list, size_list)

if layer_id == self.final_layer_id and is_last_chunk:
self.set_finished_request(req_meta.req_id)
self.request_queue.task_done()

logger.info(
"Storing KV cache for %d out of %d blocks "
"(skip_block_num=%d) for request %s",
len(keys),
total_block,
skip_block_num,
req_meta.req_id,
)


class KVCacheStoreLayerRecvingThread(KVTransferThread):

def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase,
tp_rank: int, dcp_size: int, ready_event: threading.Event,
get_event: threading.Event):
block_size: int, tp_rank: int, dcp_size: int,
ready_event: threading.Event, get_event: threading.Event):
super().__init__(m_store,
token_database,
block_size,
tp_rank,
dcp_size,
ready_event,


+ 3
- 3
vllm_ascend/distributed/kvpool/pool_scheduler.py View File

@@ -310,8 +310,8 @@ class LookupKeyClient:
self.socket.close(linger=0)


def get_zmq_rpc_path_lookup(
vllm_config: Optional["VllmConfig"] = None, ) -> str:
def get_zmq_rpc_path_lookup(vllm_config: "VllmConfig") -> str:
dp_rank = vllm_config.parallel_config.data_parallel_rank
base_url = envs.VLLM_RPC_BASE_PATH
# Default to 0 if not configured
rpc_port = 0
@@ -325,4 +325,4 @@ def get_zmq_rpc_path_lookup(
"It is recommended to use the lookup_rpc_port, as the mooncake_rpc_port will be removed in the future."
)
logger.debug("Base URL: %s, RPC Port: %s", base_url, rpc_port)
return f"ipc://{base_url}/lookup_rpc_port_{rpc_port}"
return f"ipc://{base_url}/lookup_rpc_port_{rpc_port}_dp_rank{dp_rank}"

+ 36
- 113
vllm_ascend/distributed/kvpool/pool_worker.py View File

@@ -18,7 +18,7 @@ from vllm_ascend.distributed.kvpool.backend.mooncake_backend import \
MooncakeBackend
from vllm_ascend.distributed.kvpool.config_data import (
AscendConnectorMetadata, ChunkedTokenDatabase, KeyMetadata,
LasyerMultiBlockReqMeta)
LasyerMultiBlockReqMeta, ReqMeta)
from vllm_ascend.distributed.kvpool.kv_transfer import (
KVCacheStoreLayerRecvingThread, KVCacheStoreLayerSendingThread,
KVCacheStoreRecvingThread, KVCacheStoreSendingThread, KVTransferThread)
@@ -165,28 +165,29 @@ class KVPoolWorker:
if self.kv_role in ['kv_producer', 'kv_both']:
ready_event_sending = threading.Event()
self.kv_send_thread = KVCacheStoreLayerSendingThread(
self.m_store, self.token_database, self.tp_rank,
self.dcp_size, self.put_step, ready_event_sending,
self.num_layers)
self.m_store, self.token_database, self.block_size,
self.tp_rank, self.dcp_size, self.put_step,
ready_event_sending, self.num_layers)
self.kv_send_thread.start()
ready_event = threading.Event()
self.kv_recv_thread = KVCacheStoreLayerRecvingThread(
self.m_store, self.token_database, self.tp_rank, self.dcp_size,
ready_event, self.get_event)
self.m_store, self.token_database, self.block_size,
self.tp_rank, self.dcp_size, ready_event, self.get_event)
self.kv_recv_thread.start()
ready_event.wait()
else:
if self.kv_role in ['kv_producer', 'kv_both']:
ready_event_sending = threading.Event()
self.kv_send_thread = KVCacheStoreSendingThread(
self.m_store, self.token_database, self.tp_rank,
self.dcp_size, self.put_step, ready_event_sending)
self.m_store, self.token_database, self.block_size,
self.tp_rank, self.dcp_size, self.put_step,
ready_event_sending)
self.kv_send_thread.start()
if self.load_async:
ready_event = threading.Event()
self.kv_recv_thread = KVCacheStoreRecvingThread(
self.m_store, self.token_database, self.tp_rank,
self.dcp_size, ready_event)
self.m_store, self.token_database, self.block_size,
self.tp_rank, self.dcp_size, ready_event)
self.kv_recv_thread.start()
ready_event.wait()

@@ -198,38 +199,27 @@ class KVPoolWorker:
if load_spec is None or not load_spec.can_load: #load =0
continue
token_len = request.token_len_chunk
req_id = request.req_id
if (load_spec.kvpool_cached_tokens % self.block_size
!= 0) and (load_spec.kvpool_cached_tokens
== token_len - 1):
token_len = request.load_spec.kvpool_cached_tokens + 1
else:
token_len = request.load_spec.kvpool_cached_tokens
mask_num = (request.load_spec.vllm_cached_tokens //
self.block_size * self.block_size)
request.token_len_chunk = token_len
if self.use_layerwise:
layerwise_retriever = self.retrieve_layer(
req_id,
token_len,
request.block_ids,
request.block_hashes,
mask_num,
)
layerwise_retriever = self.retrieve_layer(request)
next(layerwise_retriever) # first layer load
self.layerwise_retrievers.append(layerwise_retriever)
else:
if self.load_async:
self.kv_recv_thread.add_request( # type: ignore[union-attr]
req_id,
token_len,
request.block_ids,
request.block_hashes,
mask_num,
)
request, )
else:
addr_list = []
size_list = []
key_list = []
mask_num = (request.load_spec.vllm_cached_tokens //
self.block_size * self.block_size)
for start, end, key in self.token_database.process_tokens(
token_len, request.block_hashes, mask_num):
addr, size, _ = self.token_database.prepare_value(
@@ -266,40 +256,7 @@ class KVPoolWorker:
if can_save is None or not can_save:
continue

token_len = request.token_len_chunk
req_id = request.req_id

# TODO: whether need to remov saveThread
# no lookup, skipmask
skip_leading_tokens = self.lookup(token_len,
request.block_hashes,
self.use_layerwise)
if skip_leading_tokens == token_len:
if request.is_last_chunk:
self.kv_send_thread.set_finished_request( # type: ignore[union-attr]
req_id)
continue # skip this request

mask_num = (skip_leading_tokens // self.block_size *
self.block_size)

logger.info(
"Storing KV cache for %d out of %d tokens "
"(skip_leading_tokens=%d) for request %s",
token_len - skip_leading_tokens,
token_len,
skip_leading_tokens,
request.req_id,
)

layerwise_storer = self.store_layer(
req_id,
token_len,
block_hashes=request.block_hashes,
mask_num=mask_num,
block_ids=request.block_ids,
is_last_chunk=request.is_last_chunk,
)
layerwise_storer = self.store_layer(request)
self.layerwise_storers.append(layerwise_storer)
for layerwise_storer in self.layerwise_storers:
try:
@@ -314,45 +271,12 @@ class KVPoolWorker:
if can_save is None or not can_save:
continue

token_len = request.token_len_chunk
req_id = request.req_id

skip_leading_tokens = self.lookup(token_len, request.block_hashes,
self.use_layerwise)
if skip_leading_tokens == token_len:
if request.is_last_chunk:
self.kv_send_thread.set_finished_request( # type: ignore[union-attr]
req_id)
continue # skip this request

mask_num = (skip_leading_tokens // self.block_size *
self.block_size)

logger.info(
"Storing KV cache for %d out of %d tokens "
"(skip_leading_tokens=%d) for request %s",
token_len - skip_leading_tokens,
token_len,
skip_leading_tokens,
request.req_id,
)

self.kv_send_thread.add_request( # type: ignore[union-attr]
req_id,
token_len,
request.block_ids,
request.block_hashes,
mask_num,
request.is_last_chunk,
)
request, )

def retrieve_layer(
self,
req_id: str,
token_len: int,
block_ids: list[int],
block_hashes: list[BlockHash],
mask_num: int = 0,
request: ReqMeta,
) -> Generator[Optional[torch.Tensor], None, None]:
"""
Retrieve the KV cache in a layerwise manner.
@@ -370,6 +294,10 @@ class KVPoolWorker:
be the boolean mask indicating which tokens are retrieved and will
only be returned in the last iteration.
"""
token_len = request.token_len_chunk
mask_num = (
request.load_spec.vllm_cached_tokens # type: ignore[union-attr]
// self.block_size * self.block_size)
num_required_tokens = token_len - mask_num

ret_mask = torch.zeros(token_len, dtype=torch.bool, device="cpu")
@@ -379,7 +307,7 @@ class KVPoolWorker:
keys = []
first_flag = True
for start, end, key in self.token_database.process_tokens(
token_len, block_hashes, mask_num):
token_len, request.block_hashes, mask_num):
keys_multi_layer = key.split_layers(self.num_layers)
starts.append(start)
ends.append(end)
@@ -395,8 +323,9 @@ class KVPoolWorker:
if not is_finish:
logger.info("Layerwise get failed")
self.get_event.clear()
req_meta = LasyerMultiBlockReqMeta(req_id, keys_multi_chunk,
starts, ends, block_ids,
req_meta = LasyerMultiBlockReqMeta(request.req_id,
keys_multi_chunk, starts,
ends, request.block_ids,
layer_id)
self.kv_recv_thread.add_request( # type: ignore[union-attr, call-arg]
req_meta) # type: ignore[union-attr, call-arg, arg-type]
@@ -417,12 +346,7 @@ class KVPoolWorker:

def store_layer(
self,
req_id: str,
token_len: int,
block_ids: list[int],
block_hashes: list[BlockHash],
is_last_chunk: bool,
mask_num: int = 0,
request: ReqMeta,
) -> Generator[None, None, None]:
"""
Store the KV cache in a layerwise manner.
@@ -444,13 +368,11 @@ class KVPoolWorker:
storage backends. In the last iteration, it puts the memory objects
of the last layer to the storage backends.
"""
num_stored_tokens = token_len - mask_num

starts = []
ends = []
keys = []
for start, end, key in self.token_database.process_tokens(
token_len, block_hashes, mask_num):
request.token_len_chunk, request.block_hashes):
keys_multi_layer = key.split_layers(self.num_layers)
starts.append(start)
ends.append(end)
@@ -459,17 +381,17 @@ class KVPoolWorker:
if keys:
keys = [list(row) for row in zip(*keys)] #[layer_num,block_num]
for layer_id, keys_multi_chunk in enumerate(keys):
req_meta = LasyerMultiBlockReqMeta(req_id, keys_multi_chunk,
starts, ends, block_ids,
layer_id, is_last_chunk)
req_meta = LasyerMultiBlockReqMeta(request.req_id,
keys_multi_chunk, starts,
ends, request.block_ids,
layer_id,
request.is_last_chunk)
self.kv_send_thread.add_request( # type: ignore[union-attr, call-arg]
req_meta) # type: ignore[union-attr, call-arg, arg-type]
yield
else:
for layer_id in range(self.num_layers):
yield
logger.debug(
f"Stored {num_stored_tokens} out of total {token_len} tokens")

def get_finished(self) -> tuple[set[str], set[str]]:
done_sending = (
@@ -572,7 +494,8 @@ class KVPoolWorker:
num_block = len(keys) // self.num_layers
multi_tp_values = [
res[i * num_block:(i + 1) * num_block] # type: ignore[index]
for i in range(min(self.tp_size, self.num_kv_head))
for i in range(
min(self.tp_size, self.num_kv_head) * self.pp_size)
]
index = self.find_min_first_non_one_index(multi_tp_values)
if index != -1:


+ 4
- 3
vllm_ascend/distributed/parallel_state.py View File

@@ -182,9 +182,8 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
global _SHARED_WEIGHT
# TODO: Check if the model is Deepseek V3.2 with enabled SFA CP and activated shared weights. It will then be normalized within the PCP parameters. -- clrs97
is_ds_v32 = hasattr(vllm_config.model_config.hf_config, "index_topk")
if enable_sp() and is_ds_v32:
if enable_sp() and is_ds_v32 and _SHARED_WEIGHT is None:
_SHARED_WEIGHT = _create_shared_weight_group("CP_shared_weight")

# TODO: Extract and unify the logic across different communication group.
if flashcomm2_enable():
flashcomm2_otp_size = get_ascend_config(
@@ -240,7 +239,9 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
# Create shared weight group for flashcomm2 oproj
if flashcomm2_o_shared_enabled():
assert flashcomm2_otp_size == 1, "flashcomm2_o_shared is only supported when flashcomm2_otp_size is 1"
_SHARED_WEIGHT = _create_shared_weight_group("flashcomm2_o_shared")
if _SHARED_WEIGHT is None:
_SHARED_WEIGHT = _create_shared_weight_group(
"flashcomm2_o_shared")

if get_ascend_config().multistream_overlap_gate:
global _FC3_QUANT_X


+ 1
- 1
vllm_ascend/eplb/utils.py View File

@@ -24,7 +24,7 @@ _MOE_LOAD_ASYNC_STREAM = None


def get_expert_map(self, layer_id):
return self.model.layers[layer_id].mlp.experts.get_map()
return self.model.layers[layer_id].mlp.experts.expert_map


def get_log2phy_map(self, layer_id):


+ 16
- 10
vllm_ascend/ops/fused_moe/fused_moe.py View File

@@ -153,7 +153,7 @@ class AscendFusedMoE(FusedMoE):
AscendFusedMoE.moe_counter += 1
self.moe_instance_id = AscendFusedMoE.moe_counter

self.expert_map = None
self._expert_map = None
self.log2phy = None

if self.quant_config is None:
@@ -184,7 +184,7 @@ class AscendFusedMoE(FusedMoE):
dtype=vllm_config.model_config.dtype)

# init moe.
self.local_num_experts, self.expert_map, _ = determine_expert_map(
self.local_num_experts, self._expert_map, _ = determine_expert_map(
self.ep_size, self.ep_rank, self.global_num_experts)
# TODO: Temporary flag to indicate if static EPLB is enabled. This is a
# workaround to bypass a quantization check that fails with float weights.
@@ -200,7 +200,7 @@ class AscendFusedMoE(FusedMoE):
self.expert_load_balancer.get_global_redundant_expert_num())
self.global_num_experts = num_experts + self.global_redundant_expert_num
try:
self.local_num_experts, self.expert_map = (
self.local_num_experts, self._expert_map = (
self.expert_load_balancer.get_rank_placement_map(
self.moe_instance_id, self.ep_rank))
self.log2phy = self.expert_load_balancer.get_rank_log2phy_map(
@@ -216,16 +216,16 @@ class AscendFusedMoE(FusedMoE):
if self.dynamic_eplb:
self.log2phy = determine_default_log2phy_map(
self.global_num_experts, self.ep_size, self.ep_rank).npu()
if self.expert_map is not None and isinstance(self.expert_map,
torch.Tensor):
if self._expert_map is not None and isinstance(self._expert_map,
torch.Tensor):
logger.info_once(
"[EP Rank %s/%s] Expert parallelism is enabled. Local/global"
" number of experts: %s/%s. Experts local to global index map:"
" %s.", self.ep_rank, self.ep_size, self.local_num_experts,
self.global_num_experts,
get_compressed_expert_map(self.expert_map))
get_compressed_expert_map(self._expert_map))
local_num_experts = (torch.sum(
self.expert_map != -1) if self.expert_map is not None else
self._expert_map != -1) if self._expert_map is not None else
self.global_num_experts)
if self.dynamic_eplb:
self.moe_load = torch.zeros(local_num_experts,
@@ -276,10 +276,16 @@ class AscendFusedMoE(FusedMoE):
return QuantType.NONE

def update_expert_map(self, new_expert_map):
self.expert_map = new_expert_map
self._expert_map = new_expert_map

def get_map(self):
return self.expert_map
@property
def expert_map(self) -> torch.Tensor | None:
return self._expert_map

@expert_map.setter
def expert_map(self, new_expert_map):
# TODO(Potabk): Remove this once we drop vllm v0.12.0(This makes backward compatibility with vllm v0.12.0)
self._expert_map = new_expert_map

def get_log2phy_map(self):
return self.log2phy


+ 8
- 4
vllm_ascend/patch/__init__.py View File

@@ -237,7 +237,7 @@
# Replace with a new bind_kv_cache.
# Skip the raise.
# Related PR (if no, explain why):
# https://github.com/vllm-project/vllm/pull/4770
# It need discuss.
# Future Plan:
# Remove this patch after discussing with vllm community and adapting bind_kv_cache to npu.
#
@@ -245,11 +245,15 @@
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# 1. `vllm.v1.attention.backends.gdn_attn.torch.argsort`
# Why:
# 'torch.argsort' func of npu does not support bool.
# 1. 'torch.argsort' func of npu does not support bool.
# 2. Without `stable=True`, the output will have a lot of redundant tokens.
# How:
# Replace with a new torch.argsort that will cast the input to torch.int32.
# Replace with a new torch.argsort that will cast the input to torch.int32
# and do stable sort.
# Related PR (if no, explain why):
# https://github.com/vllm-project/vllm/pull/4770
# 1. It depends on torch_npu.
# 2. https://github.com/vllm-project/vllm/pull/30632
# Future Plan:
# Remove this patch when bool is supported in 'torch.argsort' func of npu.
# Make 'torch.argsort' in `vllm.v1.attention.backends.gdn_attn` be stable.
#

+ 6
- 1
vllm_ascend/patch/platform/__init__.py View File

@@ -17,10 +17,15 @@
import os

import vllm_ascend.patch.platform.patch_distributed # noqa
import vllm_ascend.patch.platform.patch_ec_connector # noqa
import vllm_ascend.patch.platform.patch_mamba_config # noqa
import vllm_ascend.patch.platform.patch_sched_yield # noqa
from vllm_ascend.utils import vllm_version_is

if os.getenv("DYNAMIC_EPLB", "false").lower() in ("true", "1") or os.getenv(
"EXPERT_MAP_RECORD", "false") == "true":
import vllm_ascend.patch.platform.patch_multiproc_executor # noqa

if vllm_version_is("0.12.0"):
import vllm_ascend.patch.platform.patch_ec_connector012 # noqa
else:
import vllm_ascend.patch.platform.patch_ec_connector # noqa

+ 6
- 7
vllm_ascend/patch/platform/patch_ec_connector.py View File

@@ -1,16 +1,15 @@
import vllm.distributed.ec_transfer.ec_connector.shared_storage_connector
import vllm.distributed.ec_transfer.ec_connector.example_connector
from safetensors.torch import load_file
from vllm.distributed.ec_transfer.ec_connector.base import ECConnectorMetadata
from vllm.distributed.ec_transfer.ec_connector.shared_storage_connector import (
ECSharedStorageConnector, ECSharedStorageConnectorMetadata)
from vllm.distributed.ec_transfer.ec_connector.example_connector import (
ECConnectorMetadata, ECExampleConnector)
from vllm.logger import logger


class AscendECSharedStorageConnector(ECSharedStorageConnector):
class AscendECExampleConnector(ECExampleConnector):

def start_load_caches(self, encoder_cache, **kwargs) -> None:
metadata: ECConnectorMetadata = self._get_connector_metadata()
assert isinstance(metadata, ECSharedStorageConnectorMetadata)
assert isinstance(metadata, ECConnectorMetadata)
assert encoder_cache is not None
if metadata is None:
logger.warning((
@@ -29,4 +28,4 @@ class AscendECSharedStorageConnector(ECSharedStorageConnector):
mm_data.mm_hash)


vllm.distributed.ec_transfer.ec_connector.shared_storage_connector.ECSharedStorageConnector = AscendECSharedStorageConnector
vllm.distributed.ec_transfer.ec_connector.example_connector.ECExampleConnector = AscendECExampleConnector

+ 33
- 0
vllm_ascend/patch/platform/patch_ec_connector012.py View File

@@ -0,0 +1,33 @@
import vllm.distributed.ec_transfer.ec_connector.shared_storage_connector # type: ignore[import-not-found] # noqa
from safetensors.torch import load_file
from vllm.distributed.ec_transfer.ec_connector.base import \
ECConnectorMetadata # type: ignore[import-not-found] # noqa
from vllm.distributed.ec_transfer.ec_connector.shared_storage_connector import ( # type: ignore[import-not-found] # noqa
ECSharedStorageConnector, ECSharedStorageConnectorMetadata)
from vllm.logger import logger


class AscendECSharedStorageConnector(ECSharedStorageConnector):

def start_load_caches(self, encoder_cache, **kwargs) -> None:
metadata: ECConnectorMetadata = self._get_connector_metadata()
assert isinstance(metadata, ECSharedStorageConnectorMetadata)
assert encoder_cache is not None
if metadata is None:
logger.warning((
"In connector.start_load_caches, ",
"but the connector metadata is None",
))
return
# Load the EC for each mm data
for mm_data in metadata.mm_datas:
if mm_data.mm_hash in encoder_cache:
continue
filename = self._generate_filename_debug(mm_data.mm_hash)
ec_cache = load_file(filename)["ec_cache"].npu()
encoder_cache[mm_data.mm_hash] = ec_cache
logger.debug("Success load encoder cache for hash %s",
mm_data.mm_hash)


vllm.distributed.ec_transfer.ec_connector.shared_storage_connector.ECSharedStorageConnector = AscendECSharedStorageConnector

+ 2
- 0
vllm_ascend/patch/worker/patch_module.py View File

@@ -5,6 +5,8 @@ import torch
# TODO When the operator of argsort is ready, this patch must be removed.
def _argsort(tensor, *args, **kwargs):
if tensor.dtype == torch.bool:
# If it is not stable, it will have redundant outputs.
kwargs["stable"] = True
return torch.argsort(tensor.to(torch.int32), *args, **kwargs)
else:
return torch.argsort(tensor, *args, **kwargs)


+ 4
- 0
vllm_ascend/platform.py View File

@@ -365,6 +365,10 @@ class NPUPlatform(Platform):
use_mla,
has_sink=False,
use_sparse=False,
# NOTE: Please pay special attention to the order of these parameters.
# Although we are only using some of them so far
# vllm passes them in sequence when using this interface.
use_mm_prefix: bool = False,
attn_type: str | None = None,
):
# choose attention backend based on use_mla


+ 34
- 0
vllm_ascend/quantization/quant_config.py View File

@@ -32,6 +32,7 @@ from vllm.model_executor.layers.quantization.base_config import (
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.vocab_parallel_embedding import (
UnquantizedEmbeddingMethod, VocabParallelEmbedding)
from vllm.model_executor.models.utils import WeightsMapper
from vllm.model_executor.parameter import PerTensorScaleParameter
from vllm.model_executor.utils import set_weight_attrs

@@ -103,6 +104,15 @@ class AscendQuantConfig(QuantizationConfig):
return ASCEND_QUANTIZATION_METHOD
return None

def quant_prefix_mapper(self, model_type: str, prefix: str) -> str:
# TODO (Levi-JQ): will be removed when QuantizationConfig.apply_vllm_mapper is implemented
prefix_mapping = QUANT_MODEL_PREFIX_MAPPINGS.get(model_type)
if prefix_mapping:
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix=prefix_mapping)
return hf_to_vllm_mapper._map_name(prefix)
return prefix

def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["QuantizeMethodBase"]:
vllm_config = get_current_vllm_config()
@@ -110,6 +120,7 @@ class AscendQuantConfig(QuantizationConfig):
if model_type in packed_modules_model_mapping:
self.packed_modules_mapping = packed_modules_model_mapping[
model_type]
prefix = self.quant_prefix_mapper(model_type, prefix)
from vllm.attention.layer import Attention
if prefix.startswith("language_model"):
prefix = prefix.split('.', 1)[-1]
@@ -174,6 +185,16 @@ class AscendQuantConfig(QuantizationConfig):
return []


# key: model_type
# value: orig_to_new_prefix
QUANT_MODEL_PREFIX_MAPPINGS = {
"qwen3_vl_moe": {
"visual.": "model.visual.",
"language_model.lm_head.": "lm_head.",
"language_model.model.": "model.language_model.",
},
}

packed_modules_model_mapping = {
"qwen3_moe": {
"qkv_proj": [
@@ -242,6 +263,19 @@ packed_modules_model_mapping = {
"up_proj",
],
},
"qwen3_vl_moe": {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
"experts":
["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
},
"glm4_moe": {
"qkv_proj": [
"q_proj",


+ 0
- 1
vllm_ascend/spec_decode/eagle_proposer.py View File

@@ -144,7 +144,6 @@ class EagleProposer(Proposer):
positions: torch.Tensor = None,
num_scheduled_tokens: int = 0,
hidden_states: torch.Tensor = None,
attn_metadata=None,
aux_hidden_states: torch.Tensor = None):

attn_metadata = self._get_eagle_atten_dict(scheduler_output)


+ 0
- 1
vllm_ascend/spec_decode/interface.py View File

@@ -48,7 +48,6 @@ class Proposer:
positions: torch.Tensor = None,
num_scheduled_tokens: int = 0,
hidden_states: torch.Tensor = None,
attn_metadata=None,
aux_hidden_states: torch.Tensor = None):
"""Called by execute_model in model_runner"""
raise NotImplementedError

+ 9
- 17
vllm_ascend/spec_decode/mtp_proposer.py View File

@@ -8,6 +8,7 @@ import torch.nn.functional as F
from vllm.config import (CUDAGraphMode, VllmConfig,
get_layers_from_vllm_config, set_current_vllm_config)
from vllm.distributed import get_pcp_group
from vllm.distributed.parallel_state import get_pp_group
from vllm.forward_context import get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
@@ -50,10 +51,6 @@ _MTP_MODELS = {
("vllm.model_executor.models.qwen3_next_mtp", "Qwen3NextMTP")
}

_DEFAULT_FIRST_LAYER = 'model.layers.0.self_attn.attn'

_FIRST_LAYERS = {"Qwen3NextForCausalLM": 'model.layers.3.self_attn.attn'}


def _load_model(architecture):
if architecture not in _MTP_MODELS:
@@ -205,9 +202,11 @@ class MtpProposer(Proposer):
if self.vllm_config.model_config.is_deepseek_mla:
# check if mtp model use main model's embedding and LMhead
main_model = model
if torch.equal(self.model.model.embed_tokens.weight,
main_model.model.embed_tokens.weight):
self.model.model.embed_tokens = main_model.model.embed_tokens
if get_pp_group().world_size == 1:
# If pp>1, the weights of mtp and the main model's embedding are not on the same device.
if torch.equal(self.model.model.embed_tokens.weight,
main_model.model.embed_tokens.weight):
self.model.model.embed_tokens = main_model.model.embed_tokens
for _, layer_module in self.model.model.layers.items():
if torch.equal(layer_module.shared_head.head.weight,
main_model.lm_head.weight):
@@ -342,10 +341,8 @@ class MtpProposer(Proposer):
positions: torch.Tensor = None,
num_scheduled_tokens: int = 0,
hidden_states: torch.Tensor = None,
attn_metadata=None,
aux_hidden_states: torch.Tensor = None):
common_attn_metadata = self.runner.spec_decode_common_attn_metadata
attn_metadata = self._get_attn_metadata(attn_metadata)

if self.speculative_config.disable_padded_drafter_batch:
# When padded-batch is disabled, the sampled_token_ids should be
@@ -484,14 +481,6 @@ class MtpProposer(Proposer):
model = _load_model(architecture)
self.model = model(vllm_config=self.vllm_config).to(target_device)

def _get_attn_metadata(self, attn_metadata):
if attn_metadata is not None and isinstance(attn_metadata, dict):
architecture = self.vllm_config.model_config.architecture
layer_name = _FIRST_LAYERS.get(architecture, _DEFAULT_FIRST_LAYER)
attn_metadata = attn_metadata[layer_name]

return attn_metadata

def _prepare_inputs(
self,
common_attn_metadata: CommonAttentionMetadata,
@@ -734,6 +723,9 @@ class MtpProposer(Proposer):
num_input_tokens, self.runner.with_prefill)

moe_comm_type = self.runner._select_moe_comm_method(num_input_tokens)
# TODO: remove this after moe_comm_type selection logic is finalized
moe_comm_type = (MoECommType.ALLTOALL if moe_comm_type
== MoECommType.FUSED_ALLTOALL else moe_comm_type)

# Enable shared_expert_dp and MTP FULL graph may cause accuracy issues.
if scheduler_output and not self.enable_shared_expert_dp:


+ 0
- 1
vllm_ascend/spec_decode/ngram_proposer.py View File

@@ -38,7 +38,6 @@ class NgramProposer(VllmNgramProposer, Proposer):
positions=None,
num_scheduled_tokens=None,
hidden_states=None,
attn_metadata=None,
aux_hidden_states=None) -> list[list[int]]:
valid_ngram_requests = []
for i, sampled_ids in enumerate(valid_sampled_token_ids):


+ 0
- 1
vllm_ascend/spec_decode/suffix_proposer.py View File

@@ -38,7 +38,6 @@ class SuffixDecodingProposer(VllmSuffixDecodingProposer, Proposer):
positions=None,
num_scheduled_tokens=None,
hidden_states=None,
attn_metadata=None,
aux_hidden_states=None) -> list[list[int]]:
draft_token_ids = self.propose(self.runner.input_batch,
valid_sampled_token_ids)


+ 16
- 6
vllm_ascend/utils.py View File

@@ -31,6 +31,7 @@ import torch_npu # noqa: F401
from packaging.version import InvalidVersion, Version
from torch_npu.npu.streams import Event
from vllm.logger import logger
from vllm.sequence import IntermediateTensors

import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import get_ascend_config
@@ -475,9 +476,10 @@ def update_aclgraph_sizes(vllm_config: VllmConfig) -> None:

# Calculate maximum supported batch sizes considering model architecture
resources_per_graph = num_hidden_layers + 1
if vllm_config.speculative_config is not None:
draft_model_hf_config = vllm_config.speculative_config.draft_model_config.hf_config
resources_per_graph += draft_model_hf_config.num_hidden_layers + 1
# For suffix decoding, use the suffix path when no draft_model_config is provided.
if (spec := vllm_config.speculative_config) and \
(draft := spec.draft_model_config):
resources_per_graph += draft.hf_config.num_hidden_layers + 1

# TODO: Find out whether we need to take into account the pp_size
num_comm_groups = sum(size > 1 for size in [
@@ -844,6 +846,13 @@ def weak_ref_tensors(
return [weak_ref_tensor(t) for t in tensors]
if isinstance(tensors, tuple):
return tuple(weak_ref_tensor(t) for t in tensors)
# For IntermediateTensors used in pipeline parallelism
if isinstance(tensors, IntermediateTensors):
ret = IntermediateTensors({
key: weak_ref_tensor(val)
for key, val in tensors.tensors.items()
})
return ret
raise ValueError("Invalid type for tensors")


@@ -920,16 +929,17 @@ def calculate_ep_buffer_size() -> int:
try:
from vllm.config import get_current_vllm_config
vllm_config = get_current_vllm_config()
tp_size = vllm_config.parallel_config.tensor_parallel_size
hf_config = vllm_config.model_config.hf_config

hidden_size = hf_config.hidden_size
topk = getattr(hf_config, "num_experts_per_token", 1)
batch_size = vllm_config.scheduler_config.max_num_batched_tokens
topk = getattr(hf_config, "num_experts_per_tok", 1)
batch_size = vllm_config.scheduler_config.max_num_batched_tokens // tp_size
int8_size = torch.iinfo(torch.int8).bits // 8
bf16_size = torch.finfo(torch.bfloat16).bits // 8
ep_buffer_size = math.ceil(
(batch_size * hidden_size * topk *
(int8_size * 2 + bf16_size)) / (1024 * 1024))
(int8_size + bf16_size) * 3) / (1024 * 1024))
except Exception:
pass
return max(ep_buffer_size, _DEFAULT_BUFFER_SIZE)


+ 65
- 32
vllm_ascend/worker/model_runner_v1.py View File

@@ -121,8 +121,8 @@ from vllm_ascend.spec_decode.mtp_proposer import MtpProposer
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
AscendDeviceType, ProfileExecuteDuration,
enable_sp, get_ascend_device_type, is_enable_nz,
is_moe_model, lmhead_tp_enable)
from vllm_ascend.worker.npu_input_batch import InputBatch
is_moe_model, lmhead_tp_enable, vllm_version_is)
from vllm_ascend.worker.npu_input_batch import NPUInputBatch

if TYPE_CHECKING:
import xgrammar as xgr # type: ignore[import-untyped]
@@ -249,13 +249,24 @@ class NPUModelRunner(GPUModelRunner):
# Set up Attention
self.use_sparse = hasattr(self.vllm_config.model_config.hf_config,
"index_topk")
self.attn_backend = get_attn_backend(0,
self.dtype,
None,
self.block_size,
use_mla=self.model_config.use_mla,
use_sparse=self.use_sparse)

if vllm_version_is('0.12.0'):
self.attn_backend = get_attn_backend(
0,
self.dtype,
None,
self.block_size,
use_mla=self.model_config.use_mla,
use_sparse=self.use_sparse)
else:
self.attn_backend = get_attn_backend(
0,
self.dtype,
None,
self.block_size,
use_mla=self.model_config.use_mla,
use_sparse=self.use_sparse,
use_mm_prefix=self.model_config is not None
and self.model_config.is_mm_prefix_lm)
self.attn_mask_builder = AttentionMaskBuilder(self.device)

self._set_up_drafter()
@@ -353,7 +364,7 @@ class NPUModelRunner(GPUModelRunner):
# solution, we initialize the input batch here, and re-initialize it
# in `initialize_kv_cache` if the block_sizes here is different from
# the block_sizes in the kv cache config.
self.input_batch = InputBatch(
self.input_batch = NPUInputBatch(
max_num_reqs=self.max_num_reqs,
max_model_len=self.model_config.max_model_len,
max_num_batched_tokens=self.max_num_tokens,
@@ -594,6 +605,8 @@ class NPUModelRunner(GPUModelRunner):
req_indices, positions_np)
self.input_batch.block_table.commit_slot_mapping(
total_num_scheduled_tokens)

total_num_pcp_pads = 0
if self.pcp_size > 1:
if not self.vllm_config.model_config.use_mla:
self.generate_kv_idx(scheduler_output)
@@ -601,18 +614,21 @@ class NPUModelRunner(GPUModelRunner):
tokens)
num_scheduled_tokens = np.array(tokens, dtype=np.int32)
total_num_scheduled_tokens = sum(num_scheduled_tokens[:num_reqs])
total_num_pcp_pads = torch.sum(self.num_pcp_pads).item()
else:
position_pcp, pcp_unpad_mask = None, None
self.num_pcp_pads = self.num_pcp_pads[:num_reqs]

total_num_pcp_pads = sum(self.num_pcp_pads)
max_num_scheduled_tokens = max(tokens)
num_valid_tokens = np.array([
num_tokens -
len(scheduler_output.scheduled_spec_decode_tokens.get(i, []))
for num_tokens, i in zip(tokens, req_ids)
],
dtype=np.int32)
if not scheduler_output.scheduled_spec_decode_tokens:
num_valid_tokens = np.array(tokens, dtype=np.int32)
else:
num_valid_tokens = np.array([
num_tokens -
len(scheduler_output.scheduled_spec_decode_tokens.get(i, []))
for num_tokens, i in zip(tokens, req_ids)
],
dtype=np.int32)

if (self.use_aclgraph and total_num_scheduled_tokens
<= self.cudagraph_batch_sizes[-1]):
@@ -1367,7 +1383,7 @@ class NPUModelRunner(GPUModelRunner):
draft_token_ids = self.drafter.generate_token_ids(
valid_sampled_token_ids, sampling_metadata, scheduler_output,
spec_decode_metadata, positions, num_scheduled_tokens,
hidden_states, attn_metadata, aux_hidden_states)
hidden_states, aux_hidden_states)
return draft_token_ids

def _select_moe_comm_method(self,
@@ -1762,7 +1778,7 @@ class NPUModelRunner(GPUModelRunner):
# Compute prompt logprobs if needed.
prompt_logprobs_dict = self._get_prompt_logprobs_dict(
hidden_states[:scheduler_output.total_num_scheduled_tokens],
scheduler_output,
scheduler_output.num_scheduled_tokens,
)

num_sampled_tokens = sampler_output.sampled_token_ids.shape[0]
@@ -1995,19 +2011,36 @@ class NPUModelRunner(GPUModelRunner):
self.speculative_config.method == "mtp":
attn_state = AscendAttentionState.SpecDecoding

common_metadata = CommonAttentionMetadata(
query_start_loc=self.query_start_loc.gpu[:num_reqs + 1],
query_start_loc_cpu=self.query_start_loc.cpu[:num_reqs +
if vllm_version_is("0.12.0"):
common_metadata = CommonAttentionMetadata(
query_start_loc=self.query_start_loc.gpu[:num_reqs +
1],
seq_lens_cpu=self.seq_lens.cpu[:num_reqs],
seq_lens=self.seq_lens.cpu[:num_reqs],
num_reqs=num_reqs,
num_actual_tokens=num_tokens,
block_table_tensor=block_table_tensor[:num_reqs],
slot_mapping=slot_mapping.gpu,
num_computed_tokens_cpu=num_computed_tokens_cpu,
max_query_len=max_query_len,
max_seq_len=seq_lens)
query_start_loc_cpu=self.query_start_loc.
cpu[:num_reqs + 1],
seq_lens_cpu=self.seq_lens.cpu[:num_reqs],
seq_lens=self.seq_lens.cpu[:num_reqs],
num_reqs=num_reqs,
num_actual_tokens=num_tokens,
block_table_tensor=block_table_tensor[:num_reqs],
slot_mapping=slot_mapping.gpu,
num_computed_tokens_cpu=num_computed_tokens_cpu,
max_query_len=max_query_len,
max_seq_len=seq_lens)
else:
common_metadata = CommonAttentionMetadata(
query_start_loc=self.query_start_loc.gpu[:num_reqs +
1],
query_start_loc_cpu=self.query_start_loc.
cpu[:num_reqs + 1],
_seq_lens_cpu=self.seq_lens.cpu[:num_reqs],
seq_lens=self.seq_lens.cpu[:num_reqs],
num_reqs=num_reqs,
num_actual_tokens=num_tokens,
block_table_tensor=block_table_tensor[:num_reqs],
slot_mapping=slot_mapping.gpu,
_num_computed_tokens_cpu=num_computed_tokens_cpu,
max_query_len=max_query_len,
max_seq_len=seq_lens)

for attn_group in self.attn_groups[kv_cache_group_id]:
builder = attn_group.get_metadata_builder()
@@ -2773,7 +2806,7 @@ class NPUModelRunner(GPUModelRunner):
"Cannot re-initialize the input batch when CPU weight "
"offloading is enabled. See https://github.com/vllm-project/vllm/pull/18298 " # noqa: E501
"for more details.")
self.input_batch = InputBatch(
self.input_batch = NPUInputBatch(
max_num_reqs=self.max_num_reqs,
max_model_len=self.model_config.max_model_len,
max_num_batched_tokens=self.max_num_tokens,


+ 28
- 752
vllm_ascend/worker/npu_input_batch.py View File

@@ -17,92 +17,29 @@
# Adapted from vllm-project/vllm/vllm/worker/gpu_input_batch.py
#

from dataclasses import dataclass
from typing import Optional, cast

import numpy as np
import torch
from typing_extensions import deprecated
from vllm.lora.request import LoRARequest
from vllm.multimodal.inputs import (MultiModalFeatureSpec,
MultiModalKwargsItem,
MultiModalKwargsItems, PlaceholderRange)
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.utils import length_from_prompt_token_ids_or_embeds
from vllm.utils.collection_utils import swap_dict_values
from vllm.v1.outputs import LogprobsTensors
from vllm.v1.pool.metadata import PoolingMetadata
from vllm.v1.sample.logits_processor import (BatchUpdateBuilder,
LogitsProcessors,
MoveDirectionality)
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.spec_decode.utils import is_spec_decode_unsupported
from vllm.v1.utils import copy_slice
LogitsProcessors)
from vllm.v1.worker.gpu_input_batch import InputBatch

from vllm_ascend.worker.block_table import MultiGroupBlockTable


@dataclass
class CachedRequestState:

req_id: str
prompt_token_ids: Optional[list[int]]
sampling_params: Optional[SamplingParams]
pooling_params: Optional[PoolingParams]
generator: Optional[torch.Generator]

block_ids: tuple[list[int], ...]
num_computed_tokens: int
output_token_ids: list[int]

mrope_positions: Optional[torch.Tensor] = None
mrope_position_delta: Optional[int] = None

mm_features: Optional[list[MultiModalFeatureSpec]] = None
# for back-compatibility, will be removed in next major release
mm_kwargs: Optional[list[MultiModalKwargsItem]] = None
mm_positions: Optional[list[PlaceholderRange]] = None
mm_hashes: Optional[list[PlaceholderRange]] = None

lora_request: Optional[LoRARequest] = None
prompt_embeds: Optional[torch.Tensor] = None

prev_num_draft_len: int = 0 # previous number of draft tokens
class PoolingStates:
# NOTE: This should be removed after we drop support of vLLM v0.12.0
def __init__(self):
# for chunked prefill with ALL pooling
self.hidden_states_cache: list[torch.Tensor] = []

def __post_init__(self):
self.num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
self.prompt_token_ids, self.prompt_embeds)
def clean(self):
self.hidden_states_cache.clear()

@property
def num_tokens(self) -> int:
return self.num_prompt_tokens + len(self.output_token_ids)

# Temporary back-compatibility for plugins that define model runner
@property
@deprecated("`mm_inputs` is superseded by `mm_kwargs` and will be "
"removed in v0.13. Please use `mm_kwargs` instead.")
def mm_inputs(self) -> list[MultiModalKwargsItems]:
assert self.mm_features is not None
return [
MultiModalKwargsItems.from_seq([f.data]) for f in self.mm_features
if f.data is not None
]

def get_token_id(self, idx: int) -> int:
if idx < self.num_prompt_tokens:
if self.prompt_token_ids is None:
raise ValueError(
f"Tried to access token index {idx}, but that token was "
"provided via prompt_embeds, and its ID is unknown.")
return self.prompt_token_ids[idx]
elif idx - self.num_prompt_tokens < len(self.output_token_ids):
return self.output_token_ids[idx - self.num_prompt_tokens]
else:
return -1


class InputBatch:
class NPUInputBatch(InputBatch):

def __init__(
self,
@@ -113,12 +50,12 @@ class InputBatch:
pin_memory: bool,
vocab_size: int,
block_sizes: list[int], # The block_size of each kv cache group
logitsprocs: Optional[LogitsProcessors] = None,
kernel_block_sizes: list[list[int]],
logitsprocs: LogitsProcessors | None = None,
logitsprocs_need_output_token_ids: bool = False,
is_spec_decode: bool = False,
is_pooling_model: bool = False,
num_speculative_tokens: int = 0,
kernel_block_sizes: Optional[list[list[int]]] = None,
cp_kv_cache_interleave_size: int = 1,
):
self.is_pooling_model = is_pooling_model
@@ -130,12 +67,12 @@ class InputBatch:
self.pin_memory = pin_memory
self.vocab_size = vocab_size

self._req_ids: list[Optional[str]] = []
self._req_ids: list[str | None] = []
self.req_id_to_index: dict[str, int] = {}

# TODO(woosuk): This buffer could be too large if max_model_len is big.
# Find a way to reduce the CPU memory usage.
# This buffer is not directly transferred to the NPU, so it does not
# This buffer is not directly transferred to the GPU, so it does not
# need to be pinned.
self.token_ids_cpu_tensor = torch.zeros(
(max_num_reqs, max_model_len),
@@ -162,8 +99,8 @@ class InputBatch:
dtype=torch.int32,
pin_memory=pin_memory,
)
self.num_computed_tokens_cpu = \
self.num_computed_tokens_cpu_tensor.numpy()
self.num_computed_tokens_cpu = self.num_computed_tokens_cpu_tensor.numpy(
)

# Block table.
self.block_table = MultiGroupBlockTable(
@@ -222,8 +159,8 @@ class InputBatch:
dtype=torch.float,
device="cpu",
pin_memory=pin_memory)
self.frequency_penalties_cpu = \
self.frequency_penalties_cpu_tensor.numpy()
self.frequency_penalties_cpu = self.frequency_penalties_cpu_tensor.numpy(
)
self.frequency_penalties_reqs: set[str] = set()

# Presence penalty related data structures
@@ -247,8 +184,8 @@ class InputBatch:
dtype=torch.float,
device="cpu",
pin_memory=pin_memory)
self.repetition_penalties_cpu = \
self.repetition_penalties_cpu_tensor.numpy()
self.repetition_penalties_cpu = self.repetition_penalties_cpu_tensor.numpy(
)
self.repetition_penalties_reqs: set[str] = set()

# Speculative decoding
@@ -256,12 +193,12 @@ class InputBatch:
dtype=torch.int64,
device="cpu",
pin_memory=pin_memory)
self.num_accepted_tokens_cpu = \
self.num_accepted_tokens_cpu_tensor.numpy()
self.num_accepted_tokens_cpu = self.num_accepted_tokens_cpu_tensor.numpy(
)

# lora related
self.request_lora_mapping = np.zeros((self.max_num_reqs, ),
dtype=np.int32)
dtype=np.int64)
self.lora_id_to_request_ids: dict[int, set[str]] = {}
self.lora_id_to_lora_request: dict[int, LoRARequest] = {}

@@ -271,9 +208,6 @@ class InputBatch:
self.generators: dict[int, torch.Generator] = {}

self.num_logprobs: dict[str, int] = {}
# NOTE(rob): num_prompt_logprobs only includes reqs
# that are currently in the prefill phase.
self.num_prompt_logprobs: dict[str, int] = {}

# To accumulate prompt logprobs tensor chunks across prefill steps.
self.in_progress_prompt_logprobs_cpu: dict[str, LogprobsTensors] = {}
@@ -287,8 +221,8 @@ class InputBatch:
self.has_allowed_token_ids: set[str] = set()
# NOTE(lufang): In the mask tensor, if the corresponding token allowed,
# the value is False. Since we use masked_fill_ to set -inf.
self.allowed_token_ids_mask: Optional[torch.Tensor] = None
self.allowed_token_ids_mask_cpu_tensor: Optional[torch.Tensor] = None
self.allowed_token_ids_mask: torch.Tensor | None = None
self.allowed_token_ids_mask_cpu_tensor: torch.Tensor | None = None

# req_index -> bad_words_token_ids
self.bad_words_token_ids: dict[int, list[list[int]]] = {}
@@ -296,7 +230,7 @@ class InputBatch:
self.logits_processing_needs_token_ids = np.zeros(max_num_reqs,
dtype=bool)

self.req_output_token_ids: list[Optional[list[int]]] = []
self.req_output_token_ids: list[list[int] | None] = []

# Store provided logitsprocs. If none are provided, initialize empty
# data structure
@@ -310,673 +244,15 @@ class InputBatch:
# This is updated each time the batch constituents change.
self.sampling_metadata = self._make_sampling_metadata()

# for pooling models
self.pooling_params: dict[str, PoolingParams] = {}
self.pooling_states: dict[str, PoolingStates] = {}

# Cached reference to the GPU tensor of previously sampled tokens
self.prev_sampled_token_ids: torch.Tensor | None = None
self.prev_sampled_token_ids_invalid_indices: Optional[set[int]] = None
self.prev_req_id_to_index: dict[str, int] | None = None
# These are used to update output_token_ids with real sampled
# ids from prior step, if required by current sampling params
# (e.g. penalties).
self.sampled_token_ids_cpu: torch.Tensor | None = None
self.async_copy_ready_event: torch.Event | None = None

@property
def req_ids(self) -> list[str]:
# None elements should only be present transiently
# while performing state updates to the batch.
return cast(list[str], self._req_ids)

def _register_add_request(self, request: "CachedRequestState") -> int:
"""Track add-request operations for logits processors.
Not applicable to pooling models.
"""

# Detailed added request metadata is only required for non-pooling
# models, to support logitsprocs
assert request.sampling_params

# Fill the next empty index if there is one.
if (new_req_index := self.batch_update_builder.pop_removed()) is None:
# Append to end otherwise.
new_req_index = self.num_reqs

assert new_req_index < self.max_num_reqs
self.batch_update_builder.added.append(
(new_req_index, request.sampling_params, request.prompt_token_ids,
request.output_token_ids))
return new_req_index

def add_request(
    self,
    request: "CachedRequestState",
) -> int:
    """Add a request's state to the batch and return its row index.

    Populates all per-row CPU buffers (token ids, sampling params,
    penalties, logprob counts, allowed/bad token bookkeeping, LoRA
    mapping, block table) for the chosen row. Exactly one of
    ``request.sampling_params`` / ``request.pooling_params`` must be
    set, otherwise ``NotImplementedError`` is raised.
    """
    if not self.is_pooling_model:
        # New request index bookkeeping for autoregressive models.
        req_index = self._register_add_request(request)
    else:
        # Pooling batches stay contiguous: always append at the end.
        req_index = self.num_reqs

    req_id = request.req_id
    if req_index == len(self._req_ids):
        # Appending a brand-new row at the end of the batch.
        self._req_ids.append(req_id)
        self.req_output_token_ids.append(request.output_token_ids)
        self.spec_token_ids.append([])
    else:
        # Re-using a previously vacated row.
        self._req_ids[req_index] = req_id
        self.req_output_token_ids[req_index] = request.output_token_ids
        self.spec_token_ids[req_index].clear()

    self.req_id_to_index[req_id] = req_index

    # Copy the prompt token ids and output token ids.
    num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
        request.prompt_token_ids, request.prompt_embeds)
    self.num_prompt_tokens[req_index] = num_prompt_tokens
    start_idx = num_prompt_tokens
    end_idx = start_idx + len(request.output_token_ids)
    if request.prompt_token_ids is not None:
        self.token_ids_cpu[
            req_index, :num_prompt_tokens] = request.prompt_token_ids
        self.is_token_ids[req_index, :num_prompt_tokens] = True
    else:
        # Prompt supplied only as embeddings: no token ids available.
        self.is_token_ids[req_index, :num_prompt_tokens] = False
    if request.prompt_embeds is not None:
        self.req_prompt_embeds[req_index] = request.prompt_embeds
    self.token_ids_cpu[req_index,
                       start_idx:end_idx] = request.output_token_ids
    self.is_token_ids[req_index, start_idx:end_idx] = True
    # Number of token ids in prompt (token_ids_cpu or prompt_embeds).
    # NOTE(woosuk): This may include spec decode tokens.
    self.num_tokens[req_index] = request.num_tokens
    # Number of tokens without spec decode tokens.
    self.num_tokens_no_spec[req_index] = request.num_tokens

    self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens
    self.block_table.add_row(request.block_ids, req_index)

    if sampling_params := request.sampling_params:
        if (self.is_spec_decode
                and is_spec_decode_unsupported(sampling_params)):
            self.spec_decode_unsupported_reqs.add(req_id)
        if sampling_params.sampling_type == SamplingType.GREEDY:
            # Avoid later division by zero.
            self.temperature_cpu[req_index] = -1.0
            self.greedy_reqs.add(req_id)
        else:
            self.temperature_cpu[req_index] = sampling_params.temperature
            self.random_reqs.add(req_id)

        self.top_p_cpu[req_index] = sampling_params.top_p
        if sampling_params.top_p < 1:
            self.top_p_reqs.add(req_id)
        top_k = sampling_params.top_k
        if 0 < top_k < self.vocab_size:
            self.top_k_reqs.add(req_id)
        else:
            # Out-of-range top_k means "disabled": use the full vocab.
            top_k = self.vocab_size
        self.top_k_cpu[req_index] = top_k
        self.frequency_penalties_cpu[
            req_index] = sampling_params.frequency_penalty
        if sampling_params.frequency_penalty != 0.0:
            self.frequency_penalties_reqs.add(req_id)
        self.presence_penalties_cpu[
            req_index] = sampling_params.presence_penalty
        if sampling_params.presence_penalty != 0.0:
            self.presence_penalties_reqs.add(req_id)
        self.repetition_penalties_cpu[
            req_index] = sampling_params.repetition_penalty
        if sampling_params.repetition_penalty != 1.0:
            self.repetition_penalties_reqs.add(req_id)

        # NOTE(woosuk): self.generators should not include the requests that
        # do not have their own generator.
        if request.generator is not None:
            self.generators[req_index] = request.generator

        if sampling_params.logprobs is not None:
            # -1 means "all logprobs", i.e. the full vocabulary size.
            self.num_logprobs[req_id] = (self.vocab_size
                                         if sampling_params.logprobs == -1
                                         else sampling_params.logprobs)
        if sampling_params.prompt_logprobs is not None:
            self.num_prompt_logprobs[
                req_id] = sampling_params.prompt_logprobs

        if sampling_params.allowed_token_ids:
            self.has_allowed_token_ids.add(req_id)
            if self.allowed_token_ids_mask_cpu_tensor is None:
                # Lazy allocation for this tensor, which can be large.
                # False means we don't fill with -inf.
                self.allowed_token_ids_mask = torch.zeros(
                    self.max_num_reqs,
                    self.vocab_size,
                    dtype=torch.bool,
                    device=self.device)
                self.allowed_token_ids_mask_cpu_tensor = torch.zeros(
                    self.max_num_reqs,
                    self.vocab_size,
                    dtype=torch.bool,
                    device="cpu")
            self.allowed_token_ids_mask_cpu_tensor[req_index] = True
            # False means we don't fill with -inf.
            self.allowed_token_ids_mask_cpu_tensor[req_index][
                sampling_params.allowed_token_ids] = False

        if sampling_params.bad_words_token_ids:
            self.bad_words_token_ids[
                req_index] = sampling_params.bad_words_token_ids
    elif pooling_params := request.pooling_params:
        self.pooling_params[req_id] = pooling_params
        self.logits_processing_needs_token_ids[req_index] = (
            pooling_params.requires_token_ids)
    else:
        raise NotImplementedError(request)

    # Speculative decoding: by default 1 token is generated.
    self.num_accepted_tokens_cpu[req_index] = 1

    # Add request lora ID
    if request.lora_request:
        lora_id = request.lora_request.lora_int_id
        if lora_id not in self.lora_id_to_request_ids:
            self.lora_id_to_request_ids[lora_id] = set()

        self.request_lora_mapping[req_index] = lora_id
        self.lora_id_to_request_ids[lora_id].add(request.req_id)
        self.lora_id_to_lora_request[lora_id] = request.lora_request
    else:
        # No LoRA
        self.request_lora_mapping[req_index] = 0

    return req_index

def remove_request(self, req_id: str) -> Optional[int]:
"""This method must always be followed by a call to condense().

Args:
req_id: request to remove

Returns:
Removed request index, or `None` if `req_id` not recognized
"""

req_index = self.req_id_to_index.pop(req_id, None)
if req_index is None:
return None
if not self.is_pooling_model:
# Autoregressive models require bookkeeping of removed requests to
# support logitsprocs.
self.batch_update_builder.removed_append(req_index)
self._req_ids[req_index] = None
self.req_output_token_ids[req_index] = None
self.spec_token_ids[req_index].clear()

# LoRA
lora_id = self.request_lora_mapping[req_index]
if lora_id != 0:
lora_req_ids = self.lora_id_to_request_ids[lora_id]
lora_req_ids.discard(req_id)
if not lora_req_ids:
del self.lora_id_to_request_ids[lora_id]
del self.lora_id_to_lora_request[lora_id]
self.request_lora_mapping[req_index] = 0

if self.is_pooling_model:
self.pooling_params.pop(req_id, None)
return req_index

self.greedy_reqs.discard(req_id)
self.random_reqs.discard(req_id)
self.top_p_reqs.discard(req_id)
self.top_k_reqs.discard(req_id)
self.spec_decode_unsupported_reqs.discard(req_id)
self.frequency_penalties_reqs.discard(req_id)
self.presence_penalties_reqs.discard(req_id)
self.repetition_penalties_reqs.discard(req_id)
self.generators.pop(req_index, None)
self.num_logprobs.pop(req_id, None)
self.num_prompt_logprobs.pop(req_id, None)
self.in_progress_prompt_logprobs_cpu.pop(req_id, None)

if self.prev_req_id_to_index is not None:
self.prev_req_id_to_index.pop(req_id, None)
# LoRA
lora_id = self.request_lora_mapping[req_index]
if lora_id != 0:
self.lora_id_to_request_ids[lora_id].discard(req_id)
if len(self.lora_id_to_request_ids[lora_id]) == 0:
self.lora_id_to_request_ids.pop(lora_id)
self.lora_id_to_lora_request.pop(lora_id)
self.request_lora_mapping[req_index] = 0

self.has_allowed_token_ids.discard(req_id)
if self.allowed_token_ids_mask_cpu_tensor is not None:
# False means we don't fill with -inf.
self.allowed_token_ids_mask_cpu_tensor[req_index].fill_(False)
self.bad_words_token_ids.pop(req_index, None)
self.pooling_params.pop(req_id, None)
return req_index

def swap_states(self, i1: int, i2: int) -> None:
    """Swap all per-request state between batch rows ``i1`` and ``i2``.

    Every per-row structure (request ids, token buffers, sampling
    parameters, LoRA mapping, block table rows, and logitsproc
    bookkeeping) must be swapped consistently.
    """
    # For autoregressive models, track detailed request reordering info
    # to support logitsprocs
    self.batch_update_builder.moved.append(
        (i1, i2, MoveDirectionality.SWAP))
    old_id_i1 = self._req_ids[i1]
    old_id_i2 = self._req_ids[i2]
    self._req_ids[i1], self._req_ids[i2] =\
        self._req_ids[i2], self._req_ids[i1]  # noqa
    self.req_output_token_ids[i1], self.req_output_token_ids[i2] =\
        self.req_output_token_ids[i2], self.req_output_token_ids[i1]
    self.spec_token_ids[i1], self.spec_token_ids[i2] = (
        self.spec_token_ids[i2],
        self.spec_token_ids[i1],
    )
    assert old_id_i1 is not None and old_id_i2 is not None
    self.req_id_to_index[old_id_i1], self.req_id_to_index[old_id_i2] =\
        self.req_id_to_index[old_id_i2], self.req_id_to_index[old_id_i1]
    # NOTE: scalar indexing of these 1-D numpy arrays yields copies, so
    # the tuple-swap idiom is safe here (unlike for 2-D rows below).
    self.num_tokens[i1], self.num_tokens[i2] =\
        self.num_tokens[i2], self.num_tokens[i1]
    self.num_tokens_no_spec[i1], self.num_tokens_no_spec[i2] =\
        self.num_tokens_no_spec[i2], self.num_tokens_no_spec[i1]
    self.num_prompt_tokens[i1], self.num_prompt_tokens[i2] =\
        self.num_prompt_tokens[i2], self.num_prompt_tokens[i1]
    self.num_computed_tokens_cpu[i1], self.num_computed_tokens_cpu[i2] =\
        self.num_computed_tokens_cpu[i2], self.num_computed_tokens_cpu[i1]
    self.temperature_cpu[i1], self.temperature_cpu[i2] =\
        self.temperature_cpu[i2], self.temperature_cpu[i1]
    self.top_p_cpu[i1], self.top_p_cpu[i2] =\
        self.top_p_cpu[i2], self.top_p_cpu[i1]
    self.top_k_cpu[i1], self.top_k_cpu[i2] =\
        self.top_k_cpu[i2], self.top_k_cpu[i1]
    self.frequency_penalties_cpu[i1], self.frequency_penalties_cpu[i2] =\
        self.frequency_penalties_cpu[i2], self.frequency_penalties_cpu[i1]
    self.presence_penalties_cpu[i1], self.presence_penalties_cpu[i2] =\
        self.presence_penalties_cpu[i2], self.presence_penalties_cpu[i1]
    self.repetition_penalties_cpu[i1], self.repetition_penalties_cpu[i2] =\
        self.repetition_penalties_cpu[i2], self.repetition_penalties_cpu[i1]
    self.num_accepted_tokens_cpu[i1], self.num_accepted_tokens_cpu[i2] =\
        self.num_accepted_tokens_cpu[i2], self.num_accepted_tokens_cpu[i1]

    # NOTE: the following is unsafe
    # self.token_ids_cpu[i1, ...], self.token_ids_cpu[i2, ...], =\
    # self.token_ids_cpu[i2, ...], self.token_ids_cpu[i1, ...]
    # instead, we need to temporarily copy the data for one of the indices
    # TODO(lucas): optimize this by only copying valid indices
    tmp = self.token_ids_cpu[i1, ...].copy()
    self.token_ids_cpu[i1, ...] = self.token_ids_cpu[i2, ...]
    self.token_ids_cpu[i2, ...] = tmp

    # Advanced (fancy) indexing on the RHS produces a copy, so this
    # double-row swap is safe.
    self.is_token_ids[[i1, i2], ...] = self.is_token_ids[[i2, i1], ...]

    # Swap prompt embeddings if they exist
    embeds_i1 = self.req_prompt_embeds.get(i1)
    embeds_i2 = self.req_prompt_embeds.get(i2)
    if embeds_i1 is not None:
        self.req_prompt_embeds[i2] = embeds_i1
    else:
        self.req_prompt_embeds.pop(i2, None)
    if embeds_i2 is not None:
        self.req_prompt_embeds[i1] = embeds_i2
    else:
        self.req_prompt_embeds.pop(i1, None)

    swap_dict_values(self.generators, i1, i2)
    swap_dict_values(self.bad_words_token_ids, i1, i2)

    self.request_lora_mapping[i1], self.request_lora_mapping[i2] =\
        self.request_lora_mapping[i2], self.request_lora_mapping[i1]

    if self.allowed_token_ids_mask_cpu_tensor is not None:
        # BUGFIX: rows of a 2-D torch tensor are views, so the tuple-swap
        # idiom (`t[i1], t[i2] = t[i2], t[i1]`) aliases: once the first
        # assignment overwrites row i1, the previously-captured view of
        # row i1 reflects the new data, leaving BOTH rows equal to the
        # old row i2. Swap through an explicit clone instead (same fix
        # already applied to token_ids_cpu above).
        tmp_mask = self.allowed_token_ids_mask_cpu_tensor[i1].clone()
        self.allowed_token_ids_mask_cpu_tensor[i1] = \
            self.allowed_token_ids_mask_cpu_tensor[i2]
        self.allowed_token_ids_mask_cpu_tensor[i2] = tmp_mask
    self.block_table.swap_row(i1, i2)

def condense(self) -> None:
    """Slide non-empty requests down into lower, empty indices.

    Compacts the batch after remove_request() calls so that rows
    [0, num_reqs) are all occupied. Any consecutive empty indices at
    the very end of the list are not filled; they are trimmed instead.
    Condense moves are recorded in batch_update_builder so logits
    processors can track request relocation.
    """
    num_reqs = self.num_reqs

    if self.is_pooling_model:
        # Will be contiguous in pooling case, just trim the lists.
        del self._req_ids[num_reqs:]
        del self.req_output_token_ids[num_reqs:]
        return

    if not (empty_req_indices := self.batch_update_builder.removed):
        # All removed requests were replaced by added requests, or else no
        # requests were removed at all. No condense() needed
        return
    if num_reqs == 0:
        # The batched states are empty.
        self._req_ids.clear()
        self.req_output_token_ids.clear()
        self.spec_token_ids.clear()
        return

    # NOTE(woosuk): This function assumes that the empty_req_indices
    # is sorted in descending order.
    last_req_index = num_reqs + len(empty_req_indices) - 1
    while empty_req_indices:
        # Find the largest non-empty index.
        while last_req_index in empty_req_indices:
            last_req_index -= 1

        # Find the smallest empty index.
        empty_index = self.batch_update_builder.peek_removed()
        assert empty_index is not None
        if empty_index >= last_req_index:
            # All remaining holes sit at/above the highest live row, so
            # the prefix [0, num_reqs) is already dense.
            break

        # Move active request down into empty request
        # index.
        self.batch_update_builder.pop_removed()
        # Autoregressive models require detailed tracking of condense
        # operations to support logitsprocs
        self.batch_update_builder.moved.append(
            (last_req_index, empty_index,
             MoveDirectionality.UNIDIRECTIONAL))
        req_id = self._req_ids[last_req_index]
        output_token_ids = self.req_output_token_ids[last_req_index]
        assert req_id is not None
        self._req_ids[empty_index] = req_id
        self._req_ids[last_req_index] = None
        self.req_output_token_ids[empty_index] = output_token_ids
        self.req_output_token_ids[last_req_index] = None
        self.req_id_to_index[req_id] = empty_index

        if last_req_index != empty_index:
            # Swap (not copy) the spec-token lists so list objects are
            # reused; the vacated slot is cleared just below.
            (
                self.spec_token_ids[last_req_index],
                self.spec_token_ids[empty_index],
            ) = (
                self.spec_token_ids[empty_index],
                self.spec_token_ids[last_req_index],
            )
        self.spec_token_ids[last_req_index].clear()

        num_tokens = self.num_tokens[last_req_index]
        self.token_ids_cpu[empty_index, :num_tokens] = self.token_ids_cpu[
            last_req_index, :num_tokens]
        self.is_token_ids[empty_index, :num_tokens] = self.is_token_ids[
            last_req_index, :num_tokens]
        if last_req_index in self.req_prompt_embeds:
            self.req_prompt_embeds[
                empty_index] = self.req_prompt_embeds.pop(last_req_index)
        self.num_tokens[empty_index] = num_tokens
        self.num_tokens_no_spec[empty_index] = self.num_tokens_no_spec[
            last_req_index]
        self.num_prompt_tokens[empty_index] = self.num_prompt_tokens[
            last_req_index]
        self.num_computed_tokens_cpu[
            empty_index] = self.num_computed_tokens_cpu[last_req_index]
        self.block_table.move_row(last_req_index, empty_index)
        self.temperature_cpu[empty_index] = self.temperature_cpu[
            last_req_index]
        self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index]
        self.top_k_cpu[empty_index] = self.top_k_cpu[last_req_index]
        self.frequency_penalties_cpu[
            empty_index] = self.frequency_penalties_cpu[last_req_index]
        self.presence_penalties_cpu[
            empty_index] = self.presence_penalties_cpu[last_req_index]
        self.repetition_penalties_cpu[
            empty_index] = self.repetition_penalties_cpu[last_req_index]
        self.num_accepted_tokens_cpu[
            empty_index] = self.num_accepted_tokens_cpu[last_req_index]
        generator = self.generators.pop(last_req_index, None)
        if generator is not None:
            self.generators[empty_index] = generator

        self.request_lora_mapping[empty_index] = self.request_lora_mapping[
            last_req_index]

        # TODO convert these to LogitsProcessors
        if self.allowed_token_ids_mask_cpu_tensor is not None:
            self.allowed_token_ids_mask_cpu_tensor[
                empty_index] = self.allowed_token_ids_mask_cpu_tensor[
                    last_req_index]

        bad_words_token_ids = self.bad_words_token_ids.pop(
            last_req_index, None)
        if bad_words_token_ids is not None:
            self.bad_words_token_ids[empty_index] = bad_words_token_ids

        # Decrement last_req_index since it is now empty.
        last_req_index -= 1

    # Trim lists to the batch size.
    del self._req_ids[num_reqs:]
    del self.req_output_token_ids[num_reqs:]
    del self.spec_token_ids[num_reqs:]

def refresh_metadata(self):
    """Apply any pending batch updates to the sampling metadata."""
    if self.is_pooling_model:
        # The pooling batch composition changes every step, so the
        # metadata is unconditionally rebuilt.
        self.sampling_metadata = self._make_sampling_metadata()
        return

    # Non-pooling models: produce the accumulated logitsprocs update,
    # reset the tracking state, and rebuild sampling metadata only if
    # the batch actually changed.
    batch_update = self.batch_update_builder.get_and_reset(self.num_reqs)
    for processor in self.logitsprocs.all:
        processor.update_state(batch_update)
    if batch_update:
        self.sampling_metadata = self._make_sampling_metadata()

def _make_sampling_metadata(self) -> SamplingMetadata:
    """Build SamplingMetadata covering the first num_reqs batch rows.

    Copies CPU-side parameter tensors (temperature, top-p/k, penalties,
    allowed-token mask) to the target device only when at least one
    request actually uses the corresponding feature.
    """
    num_reqs = self.num_reqs
    if not self.all_greedy:
        temperature = copy_slice(self.temperature_cpu_tensor,
                                 self.temperature, num_reqs)
    else:
        # All-greedy batches never read temperature.
        temperature = None
    if not self.no_top_p:
        copy_slice(self.top_p_cpu_tensor, self.top_p, num_reqs)
    if not self.no_top_k:
        copy_slice(self.top_k_cpu_tensor, self.top_k, num_reqs)

    if not self.no_penalties:
        # Since syncing these tensors is expensive only copy them
        # if necessary i.e. if there are requests which require
        # penalties to be applied during sampling.
        copy_slice(self.frequency_penalties_cpu_tensor,
                   self.frequency_penalties, num_reqs)
        copy_slice(self.presence_penalties_cpu_tensor,
                   self.presence_penalties, num_reqs)
        copy_slice(self.repetition_penalties_cpu_tensor,
                   self.repetition_penalties, num_reqs)

    needs_prompt_token_ids = (
        not self.no_penalties
        or self.logits_processing_needs_token_ids[:num_reqs].any())
    if needs_prompt_token_ids:
        # The prompt tokens are used only for applying penalties or
        # step pooling during the sampling/pooling process.
        # Hence copy these tensors only when there are requests which
        # need penalties/step_pooler to be applied.
        prompt_token_ids = self._make_prompt_token_ids_tensor()
    else:
        prompt_token_ids = None

    allowed_token_ids_mask: torch.Tensor | None = None
    if not self.no_allowed_token_ids:
        assert self.allowed_token_ids_mask is not None
        copy_slice(self.allowed_token_ids_mask_cpu_tensor,
                   self.allowed_token_ids_mask, num_reqs)
        allowed_token_ids_mask = self.allowed_token_ids_mask[:num_reqs]

    return SamplingMetadata(
        temperature=temperature,
        all_greedy=self.all_greedy,
        all_random=self.all_random,
        top_p=None if self.no_top_p else self.top_p[:num_reqs],
        top_k=None if self.no_top_k else self.top_k[:num_reqs],
        generators=self.generators,
        max_num_logprobs=self.max_num_logprobs,
        prompt_token_ids=prompt_token_ids,
        frequency_penalties=self.frequency_penalties[:num_reqs],
        presence_penalties=self.presence_penalties[:num_reqs],
        repetition_penalties=self.repetition_penalties[:num_reqs],
        output_token_ids=cast(list[list[int]], self.req_output_token_ids),
        spec_token_ids=cast(list[list[int]], self.spec_token_ids),
        no_penalties=self.no_penalties,
        allowed_token_ids_mask=allowed_token_ids_mask,
        bad_words_token_ids=self.bad_words_token_ids,
        logitsprocs=self.logitsprocs,
    )

def get_pooling_params(self) -> list[PoolingParams]:
    """Pooling params for every request, in batch-index order."""
    params_by_id = self.pooling_params
    assert len(self.req_ids) == len(params_by_id)
    return [params_by_id[rid] for rid in self.req_ids]

def get_pooling_metadata(self) -> PoolingMetadata:
    """Assemble PoolingMetadata for the current batch contents."""
    num_reqs = self.num_reqs
    return PoolingMetadata(
        prompt_lens=torch.from_numpy(self.num_prompt_tokens[:num_reqs]),
        prompt_token_ids=self.sampling_metadata.prompt_token_ids,
        pooling_params=self.get_pooling_params(),
    )

def _make_prompt_token_ids_tensor(self) -> torch.Tensor:
max_prompt_len = self.num_prompt_tokens[:self.num_reqs].max()
prompt_token_ids_cpu_tensor = torch.empty(
(self.num_reqs, max_prompt_len),
device="cpu",
dtype=torch.int64,
pin_memory=self.pin_memory,
)
prompt_token_ids = prompt_token_ids_cpu_tensor.numpy()
prompt_token_ids[:] = self.token_ids_cpu[:self.
num_reqs, :max_prompt_len]
# Use the value of vocab_size as a pad since we don't have a
# token_id of this value.
for i in range(self.num_reqs):
prompt_token_ids[i, self.num_prompt_tokens[i]:] = self.vocab_size
return prompt_token_ids_cpu_tensor.to(device=self.device,
non_blocking=True)

def make_lora_inputs(
    self, num_scheduled_tokens: np.ndarray, num_sampled_tokens: np.ndarray
) -> tuple[tuple[int, ...], tuple[int, ...], set[LoRARequest]]:
    """
    Given the num_scheduled_tokens for each request in the batch, return
    datastructures used to activate the current LoRAs.
    Returns:
        1. prompt_lora_mapping: A tuple of size self.num_reqs where,
           prompt_lora_mapping[i] is the LoRA id to use for the ith prompt.
        2. token_lora_mapping: A tuple of size np.sum(num_scheduled_tokens)
           where, token_lora_mapping[i] is the LoRA id to use for ith token.
        3. lora_requests: Set of relevant LoRA requests.
    """
    per_req_lora_ids = self.request_lora_mapping[:self.num_reqs]
    # One LoRA id per request (prompt-level mapping).
    prompt_lora_mapping = tuple(per_req_lora_ids)
    # Expand to one LoRA id per scheduled token.
    token_lora_mapping = tuple(
        per_req_lora_ids.repeat(num_scheduled_tokens))
    lora_requests: set[LoRARequest] = set(
        self.lora_id_to_lora_request.values())
    return prompt_lora_mapping, token_lora_mapping, lora_requests

def set_async_sampled_token_ids(
self,
sampled_token_ids_cpu: torch.Tensor,
async_copy_ready_event: torch.Event,
) -> None:
"""
In async scheduling case, store ref to sampled_token_ids_cpu
tensor and corresponding copy-ready event. Used to repair
output_token_ids prior to sampling, if needed by logits processors.
"""
if self.sampling_metadata.output_token_ids:
self.sampled_token_ids_cpu = sampled_token_ids_cpu
self.async_copy_ready_event = async_copy_ready_event
else:
self.sampled_token_ids_cpu = None
self.async_copy_ready_event = None

def update_async_output_token_ids(self) -> None:
    """
    In async scheduling case, update output_token_ids in sampling metadata
    from prior steps sampled token ids once they've finished copying to CPU.
    This is called right before they are needed by the logits processors.
    """
    output_token_ids = self.sampling_metadata.output_token_ids
    if self.sampled_token_ids_cpu is None or not output_token_ids:
        # Output token ids not needed or not async scheduling.
        return

    assert self.prev_req_id_to_index is not None
    # Lazily materialized list of last-step sampled ids (indexed by the
    # previous batch's row order); built only if a placeholder is found.
    sampled_token_ids = None
    for index, req_id in enumerate(self.req_ids):
        prev_index = self.prev_req_id_to_index.get(req_id)
        if prev_index is None:
            # Request was not present in the previous step's batch.
            continue
        req_output_token_ids = output_token_ids[index]
        if not req_output_token_ids or req_output_token_ids[-1] != -1:
            # Final output id is not a placeholder, some tokens must have
            # been discarded after a kv-load failure.
            continue
        if sampled_token_ids is None:
            # First placeholder found: block until the async D2H copy of
            # the sampled ids has completed before reading the tensor.
            assert self.async_copy_ready_event is not None
            self.async_copy_ready_event.synchronize()
            sampled_token_ids = self.sampled_token_ids_cpu.squeeze(
                -1).tolist()
        # Replace placeholder token id with actual sampled id.
        req_output_token_ids[-1] = sampled_token_ids[prev_index]

@property
def num_reqs(self) -> int:
    """Number of requests currently tracked in the batch."""
    return len(self.req_id_to_index)

@property
def all_greedy(self) -> bool:
return len(self.random_reqs) == 0

@property
def all_random(self) -> bool:
return len(self.greedy_reqs) == 0

@property
def no_top_p(self) -> bool:
return len(self.top_p_reqs) == 0

@property
def no_top_k(self) -> bool:
return len(self.top_k_reqs) == 0

@property
def no_penalties(self) -> bool:
return (len(self.presence_penalties_reqs) == 0
and len(self.frequency_penalties_reqs) == 0
and len(self.repetition_penalties_reqs) == 0)

@property
def max_num_logprobs(self) -> Optional[int]:
return max(self.num_logprobs.values()) if self.num_logprobs else None

@property
def no_prompt_logprob(self) -> bool:
return not self.num_prompt_logprobs

@property
def no_allowed_token_ids(self) -> bool:
return len(self.has_allowed_token_ids) == 0

Loading…
Cancel
Save
Baidu
map