[MLAS][KleidiAI] Add MLAS Arm64 half GEMM and convolution support #28786

Open
Laan33 wants to merge 8 commits into microsoft:main from Laan33:fp16-split/02-mlas-half-api-squashed

Laan33 commented Jun 4, 2026

Description

Adds MLAS half-precision GEMM and convolution support for Arm64 KleidiAI paths.

This change:

  • Adds public MLAS HalfGemm backend-native packed-B APIs.
  • Adds public MLAS HalfConv prepare, execute, and packed weights+bias APIs.
  • Wires KleidiAI HalfGemm and HalfConv overrides into the MLAS platform dispatch.
  • Adds SME/SME2 FP16 HGEMM and FP16 IMATMUL ukernel selection.
  • Adds MLAS unit coverage for packed-B behavior, selector opt-out, zero-K handling, HalfConv prepare behavior, and varied HalfGemm shapes/bias cases.

Motivation and Context

This is the second PR in the FP16 split and introduces the MLAS API and Arm64 backend plumbing needed for accelerated FP16 CPU kernels.

The later CPU operator changes can use these MLAS HalfGemm and HalfConv entry points without carrying KleidiAI-specific details at the operator layer. The backend selector support also preserves an explicit fallback path when KleidiAI should not be used for a given call.

Note: This PR is the second offshoot from #28487.

Laan33 added 8 commits May 27, 2026 10:24

Avoid holding a reference into lhs_ptrs_cache_by_pad across the lookup/update sequence so the cache remains keyed by the current pad buffer identity.

Source-commit: 923422f

Signed-off-by: Cathal Lawlor <cathal.lawlor@arm.com>

Use BatchN for the batched dynamic QGEMM parameter to match the surrounding MLAS batched GEMM naming.

Source-commit: 923422f

Signed-off-by: Cathal Lawlor <cathal.lawlor@arm.com>

The fp16 AttentionBatch1 test uses mask_index, which WebGPU Attention does not support.

Source-commit: 2483990

Signed-off-by: Cathal Lawlor <cathal.lawlor@arm.com>

Native fp16 accumulation paths need a slightly wider tolerance, and WebGPU needs a separate bound for this coverage.

Source-commit: 2483990

Source-commit: 08f4c6a8c51dc018b4440c63d279bf5995d3386a

Signed-off-by: Cathal Lawlor <cathal.lawlor@arm.com>

The Resize(13) MLFloat16 test should not run against EPs that do not provide this kernel.

Source-commit: 2483990

Signed-off-by: Cathal Lawlor <cathal.lawlor@arm.com>

Allow the fp16 FusedConvWithSum transformer test to tolerate native fp16 numerical drift.

Source-commit: 2483990

Signed-off-by: Cathal Lawlor <cathal.lawlor@arm.com>

These internal-domain NHWC fp16 pool tests target ORT CPU/MLAS coverage and should not be offered to unrelated registered EPs.

Source-commit: a504a0c

Signed-off-by: Cathal Lawlor <cathal.lawlor@arm.com>

Signed-off-by: Cathal Lawlor <cathal.lawlor@arm.com>

Laan33 changed the title from "Add MLAS Arm64 half GEMM and convolution support" to "[MLAS][KleidiAI] Add MLAS Arm64 half GEMM and convolution support" Jun 4, 2026
xadupre requested a review from Copilot June 4, 2026 16:37

Copilot AI left a comment

Pull request overview

This PR expands ONNX Runtime’s MLAS Arm64 KleidiAI integration to support FP16 (half) GEMM and convolution via new public MLAS APIs, KleidiAI override wiring, and additional kernel selection logic, with accompanying unit tests and some test-harness exclusions/tolerance tweaks.

Changes:

  • Adds public MLAS FP16 HalfGemm packed-B APIs (generic + backend-native) and FP16 HalfConv prepare/execute/pack APIs.
  • Wires KleidiAI FP16 HalfGemm/HalfConv overrides into MLAS platform dispatch and adds SME/SME2 FP16 ukernel selection.
  • Extends MLAS/unit and ORT operator tests to cover new behaviors (packed-B contracts, selector opt-out, zero-K handling, HalfConv prepare behavior), plus minor EP exclusions/tolerance adjustments.

Reviewed changes

Copilot reviewed 26 out of 26 changed files in this pull request and generated 3 comments.

| File | Description |
| --- | --- |
| onnxruntime/test/providers/cpu/tensor/resize_op_test.cc | Excludes additional EPs for FP16 Resize(13) coverage gaps. |
| onnxruntime/test/optimizer/nhwc_transformer_test.cc | Adjusts transformer tester invocation to use explicit tolerances/opset. |
| onnxruntime/test/mlas/unittest/test_util.h | Ensures guard buffer state is fully reset (guard pointer). |
| onnxruntime/test/mlas/unittest/test_halfgemm.h | Adds FP16-native comparison tolerance + new execution modes for HalfGemm tests. |
| onnxruntime/test/mlas/unittest/test_halfgemm.cpp | Adds HalfGemm tests for selector routing, packed-B behavior, overflow handling, and more shape/bias cases. |
| onnxruntime/test/mlas/unittest/test_conv2d.cpp | Adds HalfConv prepare tests for selector config handling and working-buffer sizing. |
| onnxruntime/test/contrib_ops/nhwc_pool_in_op_test.cc | Excludes non-owning EPs from internal NHWC fp16 pool tests. |
| onnxruntime/test/contrib_ops/matmul_4bits_test.cc | Adjusts abs error tolerance for FP16, conditional on WebGPU. |
| onnxruntime/test/contrib_ops/attention_op_test.cc | Disables WebGPU for a FP16 Attention test that uses unsupported mask_index. |
| onnxruntime/core/mlas/lib/platform.cpp | Enables KleidiAI overrides under SME or SME2 and wires HalfGemm/HalfConv overrides. |
| onnxruntime/core/mlas/lib/mlasi.h | Adds size_t overflow helpers used by KleidiAI/packing paths. |
| onnxruntime/core/mlas/lib/kleidiai/sgemm_kleidiai.cpp | Switches overflow checks to shared helpers. |
| onnxruntime/core/mlas/lib/kleidiai/sbgemm_kleidiai.cpp | Switches overflow checks to shared helpers. |
| onnxruntime/core/mlas/lib/kleidiai/qgemm_kleidiai.cpp | Renames BatchSize→BatchN for clarity/consistency and updates uses. |
| onnxruntime/core/mlas/lib/kleidiai/mlasi_kleidiai.h | Refactors includes and adds scratch-buffer shrinking helper declarations. |
| onnxruntime/core/mlas/lib/kleidiai/halfgemm_kleidiai.cpp | Adds KleidiAI FP16 HalfGemm implementation and native RHS pack. |
| onnxruntime/core/mlas/lib/kleidiai/halfconv_kleidiai.cpp | Adds KleidiAI FP16 convolution implementation (IMATMUL-based) and packing. |
| onnxruntime/core/mlas/lib/kleidiai/convolve_kleidiai.cpp | Minor cache-lookup refactor for indirection table cache. |
| onnxruntime/core/mlas/lib/kai_ukernel_interface.h | Adds FP16 IMATMUL/HGEMM kernel wrapper types and selection APIs. |
| onnxruntime/core/mlas/lib/kai_ukernel_interface.cpp | Implements FP16 kernel selection (SME vs SME2) and new ukernel wrappers. |
| onnxruntime/core/mlas/lib/halfgemm.h | Adds shared packed-B sizing helper + implements generic CopyPackB with padding. |
| onnxruntime/core/mlas/lib/halfgemm.cpp | Adds zero-K behavior, selector-config routing, native-pack APIs, and enables CopyPackB dispatch. |
| onnxruntime/core/mlas/lib/halfgemm_kernel_neon.cpp | Enables CopyPackB dispatch for NEON halfgemm. |
| onnxruntime/core/mlas/lib/halfconv.cpp | Adds public dispatch wrappers for HalfConv prepare/execute/pack APIs. |
| onnxruntime/core/mlas/inc/mlas.h | Exposes new HalfGemm/HalfConv APIs and extends parameter structs for new flags/config. |
| cmake/onnxruntime_mlas.cmake | Adds new MLAS sources (halfconv + KleidiAI halfgemm/halfconv). |
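
The size_t overflow helpers listed for mlasi.h are not shown on this page. As a rough illustration of what such checks usually look like, here is a minimal sketch. The names CheckedMulSizeT, CheckedAddSizeT, and PackedBytes, and the 2-byte element size, are assumptions made for this example and are not taken from the PR.

```cpp
#include <cstddef>
#include <limits>

// Hypothetical helper (not the name used in mlasi.h): computes a * b into *out
// and returns false, leaving *out untouched, if the product would overflow size_t.
inline bool CheckedMulSizeT(std::size_t a, std::size_t b, std::size_t* out) {
    if (a != 0 && b > std::numeric_limits<std::size_t>::max() / a) {
        return false;  // a * b would wrap around
    }
    *out = a * b;
    return true;
}

// Hypothetical helper: computes a + b with the same contract.
inline bool CheckedAddSizeT(std::size_t a, std::size_t b, std::size_t* out) {
    if (b > std::numeric_limits<std::size_t>::max() - a) {
        return false;  // a + b would wrap around
    }
    *out = a + b;
    return true;
}

// Assumed element size for an FP16 value, in bytes.
constexpr std::size_t kHalfBytes = 2;

// Example use: byte size of a buffer holding rows * cols FP16 elements.
inline bool PackedBytes(std::size_t rows, std::size_t cols, std::size_t* bytes) {
    std::size_t elems = 0;
    return CheckedMulSizeT(rows, cols, &elems) &&
           CheckedMulSizeT(elems, kHalfBytes, bytes);
}
```

A caller would typically treat a false return as "this shape is not supported on this path" and fall back or fail, rather than allocate from a wrapped size.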

Comment on lines 9 to +11
#include "../mlasi.h"
#include <iostream>
#include <limits>
#include <vector>
Comment on lines +105 to +112
if (BatchN == 0 || M == 0 || N == 0) {
    return;
}

if (K == 0) {
    MlasHalfGemmZeroKBatch(M, N, BatchN, DataParams);
    return;
}
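
MlasHalfGemmZeroKBatch itself is not shown on this page. As a sketch of what zero-K handling conventionally means for a GEMM of the form C = A·B + bias: with K == 0 the product term is an empty sum, so each output element reduces to the bias, or to zero when no bias is given. The function name ZeroKFill, the row-major layout with leading dimension ldc, the per-column bias, and the use of float in place of an FP16 type are all assumptions made for this illustration.

```cpp
#include <cstddef>

// Hypothetical stand-in, not the MLAS implementation. With K == 0 the A*B term
// contributes nothing, so every output element is the bias (or zero if absent).
// float replaces the FP16 element type to keep the sketch portable.
inline void ZeroKFill(std::size_t M, std::size_t N, float* C, std::size_t ldc,
                      const float* Bias /* length N, may be nullptr */) {
    for (std::size_t m = 0; m < M; ++m) {
        float* row = C + m * ldc;  // start of output row m
        for (std::size_t n = 0; n < N; ++n) {
            row[n] = (Bias != nullptr) ? Bias[n] : 0.0f;
        }
    }
}
```

Only the first N entries of each row are written, so any padding between N and ldc is left as it was.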
Comment on lines +64 to +66
if (PackedK == 0) {
    return false;
}

Copilot AI left a comment

Pull request overview

Copilot reviewed 26 out of 26 changed files in this pull request and generated 2 comments.

Comment on lines +154 to +158
  if (g_kai_tls_qgemm.lhs_packed.capacity() < LhsPackedStride * BatchN) {
-     g_kai_tls_qgemm.lhs_packed.reserve(LhsPackedStride * BatchSize);
+     g_kai_tls_qgemm.lhs_packed.reserve(LhsPackedStride * BatchN);
  }
- g_kai_tls_qgemm.lhs_packed.resize(LhsPackedStride * BatchSize);
+ g_kai_tls_qgemm.lhs_packed.resize(LhsPackedStride * BatchN);
Comment on lines +109 to +112
if (K == 0) {
    MlasHalfGemmZeroKBatch(M, N, BatchN, DataParams);
    return;
}