7 Commits

Author SHA1 Message Date
  Jia Ningyu fdac87060f
DeepEP normal dispatch stream bugfix (#76901) 2 days ago
  Lucas de56830107
[XPU] update xhpc to 20251213 (#76902) 2 days ago
  HU Shenwei 3e32d54ca0
support batched_gemm fp32 (#76897) 2 days ago
  Lucas b0394b13c0
[XPU] support bfloat16, float16 type for op scatter_nd (#76893) 2 days ago
  ZhouMinhao98 58a6a3aa47
[XPU] binding ceil and gaussian_inplace op on xpu3 (#76874) 2 days ago
  ZhouMinhao98 d3d79e6619
[XPU] binding moe permute unpermute (#76869) 2 days ago
  gouzil 4a2d0f100c
[SOT][3.14] Enable all SOT unittests in Python 3.14 (#76804) 2 days ago
19 changed files with 800 additions and 189 deletions
Split View
  1. +26
    -26
      .github/workflows/_SOT.yml
  2. +27
    -3
      ci/run_sot_test.sh
  3. +1
    -1
      cmake/external/xpu.cmake
  4. +40
    -36
      paddle/fluid/distributed/collective/deep_ep_xpu/deep_ep.cpp
  5. +2
    -10
      paddle/fluid/pybind/sot/cpython_internals.c
  6. +1
    -1
      paddle/fluid/pybind/sot/eval_frame.c
  7. +16
    -3
      paddle/phi/backends/xpu/xpu3_op_list.cc
  8. +6
    -2
      paddle/phi/infermeta/binary.cc
  9. +58
    -26
      paddle/phi/kernels/legacy/gpu/batched_gemm.cu
  10. +22
    -0
      paddle/phi/kernels/xpu/activation_kernel.cc
  11. +33
    -0
      paddle/phi/kernels/xpu/gaussian_kernel.cc
  12. +274
    -0
      paddle/phi/kernels/xpu/moe_permute_kernel.cc
  13. +118
    -0
      paddle/phi/kernels/xpu/moe_unpermute_kernel.cc
  14. +2
    -1
      paddle/phi/kernels/xpu/scatter_nd_add_grad_kernel.cc
  15. +32
    -27
      paddle/phi/kernels/xpu/scatter_nd_add_kernel.cc
  16. +68
    -46
      test/legacy_test/test_batched_gemm.py
  17. +64
    -0
      test/legacy_test/test_ceil.py
  18. +0
    -2
      test/sot/skip_files_py314
  19. +10
    -5
      test/xpu/test_scatter_nd_add_op_xpu.py

+ 26
- 26
.github/workflows/_SOT.yml View File

@@ -142,22 +142,22 @@ jobs:
echo "determine_excode=$determine_excode" >> ${{ github.env }}
'

- name: Build with python3.14
- name: Build with python3.9
env:
work_dir: ${{ github.workspace }}
PADDLE_ROOT: ${{ github.workspace }}
if: ${{ env.determine_excode == 0 }}
run: |
docker exec -t ${{ env.container_name }} /bin/bash -c '
export PY_VERSION=3.14
export PY_VERSION=3.9
source ${{ github.workspace }}/../../../proxy
bash ${ci_scripts}/cmake-predownload.sh
bash ${ci_scripts}/run_setup.sh bdist_wheel
EXCODE=$?
rm -rf ${PADDLE_ROOT}/build/CMakeCache.txt
exit $EXCODE
'

- name: Test with python3.14
- name: Test with python3.9
env:
work_dir: ${{ github.workspace }}
PADDLE_ROOT: ${{ github.workspace }}
@@ -165,28 +165,27 @@ jobs:
run: |
docker exec -t ${{ env.container_name }} /bin/bash -c '
source ${{ github.workspace }}/../../../proxy
bash ${ci_scripts}/run_sot_test.sh 3.14
bash ${ci_scripts}/run_sot_test.sh 3.9
EXCODE=$?
rm -rf ${PADDLE_ROOT}/build/CMakeCache.txt
exit $EXCODE
'

- name: Build with python3.9
- name: Build with python3.10
env:
work_dir: ${{ github.workspace }}
PADDLE_ROOT: ${{ github.workspace }}
if: ${{ env.determine_excode == 0 }}
run: |
docker exec -t ${{ env.container_name }} /bin/bash -c '
export PY_VERSION=3.9
export PY_VERSION=3.10
source ${{ github.workspace }}/../../../proxy
bash ${ci_scripts}/cmake-predownload.sh
bash ${ci_scripts}/run_setup.sh bdist_wheel
EXCODE=$?
exit $EXCODE
'

- name: Test with python3.9
- name: Test with python3.10
env:
work_dir: ${{ github.workspace }}
PADDLE_ROOT: ${{ github.workspace }}
@@ -194,27 +193,27 @@ jobs:
run: |
docker exec -t ${{ env.container_name }} /bin/bash -c '
source ${{ github.workspace }}/../../../proxy
bash ${ci_scripts}/run_sot_test.sh 3.9
bash ${ci_scripts}/run_sot_test.sh 3.10
EXCODE=$?
rm -rf ${PADDLE_ROOT}/build/CMakeCache.txt
exit $EXCODE
'

- name: Build with python3.10
- name: Build with python3.11
env:
work_dir: ${{ github.workspace }}
PADDLE_ROOT: ${{ github.workspace }}
if: ${{ env.determine_excode == 0 }}
run: |
docker exec -t ${{ env.container_name }} /bin/bash -c '
export PY_VERSION=3.10
export PY_VERSION=3.11
source ${{ github.workspace }}/../../../proxy
bash ${ci_scripts}/run_setup.sh bdist_wheel
EXCODE=$?
exit $EXCODE
'

- name: Test with python3.10
- name: Test with python3.11
env:
work_dir: ${{ github.workspace }}
PADDLE_ROOT: ${{ github.workspace }}
@@ -222,27 +221,27 @@ jobs:
run: |
docker exec -t ${{ env.container_name }} /bin/bash -c '
source ${{ github.workspace }}/../../../proxy
bash ${ci_scripts}/run_sot_test.sh 3.10
bash ${ci_scripts}/run_sot_test.sh 3.11
EXCODE=$?
rm -rf ${PADDLE_ROOT}/build/CMakeCache.txt
exit $EXCODE
'

- name: Build with python3.11
- name: Build with python3.12
env:
work_dir: ${{ github.workspace }}
PADDLE_ROOT: ${{ github.workspace }}
if: ${{ env.determine_excode == 0 }}
run: |
docker exec -t ${{ env.container_name }} /bin/bash -c '
export PY_VERSION=3.11
export PY_VERSION=3.12
source ${{ github.workspace }}/../../../proxy
bash ${ci_scripts}/run_setup.sh bdist_wheel
EXCODE=$?
exit $EXCODE
'

- name: Test with python3.11
- name: Test with python3.12
env:
work_dir: ${{ github.workspace }}
PADDLE_ROOT: ${{ github.workspace }}
@@ -250,27 +249,27 @@ jobs:
run: |
docker exec -t ${{ env.container_name }} /bin/bash -c '
source ${{ github.workspace }}/../../../proxy
bash ${ci_scripts}/run_sot_test.sh 3.11
bash ${ci_scripts}/run_sot_test.sh 3.12
EXCODE=$?
rm -rf ${PADDLE_ROOT}/build/CMakeCache.txt
exit $EXCODE
'

- name: Build with python3.12
- name: Build with python3.13
env:
work_dir: ${{ github.workspace }}
PADDLE_ROOT: ${{ github.workspace }}
if: ${{ env.determine_excode == 0 }}
run: |
docker exec -t ${{ env.container_name }} /bin/bash -c '
export PY_VERSION=3.12
export PY_VERSION=3.13
source ${{ github.workspace }}/../../../proxy
bash ${ci_scripts}/run_setup.sh bdist_wheel
EXCODE=$?
exit $EXCODE
'

- name: Test with python3.12
- name: Test with python3.13
env:
work_dir: ${{ github.workspace }}
PADDLE_ROOT: ${{ github.workspace }}
@@ -278,27 +277,28 @@ jobs:
run: |
docker exec -t ${{ env.container_name }} /bin/bash -c '
source ${{ github.workspace }}/../../../proxy
bash ${ci_scripts}/run_sot_test.sh 3.12
bash ${ci_scripts}/run_sot_test.sh 3.13
EXCODE=$?
rm -rf ${PADDLE_ROOT}/build/CMakeCache.txt
exit $EXCODE
'

- name: Build with python3.13
- name: Build with python3.14
env:
work_dir: ${{ github.workspace }}
PADDLE_ROOT: ${{ github.workspace }}
if: ${{ env.determine_excode == 0 }}
run: |
docker exec -t ${{ env.container_name }} /bin/bash -c '
export PY_VERSION=3.13
export PY_VERSION=3.14
source ${{ github.workspace }}/../../../proxy
bash ${ci_scripts}/run_setup.sh bdist_wheel
EXCODE=$?
rm -rf ${PADDLE_ROOT}/build/CMakeCache.txt
exit $EXCODE
'

- name: Test with python3.13
- name: Test with python3.14
env:
work_dir: ${{ github.workspace }}
PADDLE_ROOT: ${{ github.workspace }}
@@ -306,7 +306,7 @@ jobs:
run: |
docker exec -t ${{ env.container_name }} /bin/bash -c '
source ${{ github.workspace }}/../../../proxy
bash ${ci_scripts}/run_sot_test.sh 3.13
bash ${ci_scripts}/run_sot_test.sh 3.14
EXCODE=$?
rm -rf ${PADDLE_ROOT}/build/CMakeCache.txt
exit $EXCODE


+ 27
- 3
ci/run_sot_test.sh View File

@@ -12,6 +12,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.

run_and_check() {
local desc=$1
shift
echo "::group::${desc}"
local output
output=$("$@" 2>&1)
local code=$?
echo "${output}"
echo "::endgroup::"
if [ "$code" -ne 0 ]; then
echo "$desc with exit code $code"
exit "$code"
fi
}

function run_sot_test() {
PY_VERSION=$1
PYTHON_WITH_SPECIFY_VERSION=python$PY_VERSION
@@ -24,9 +39,18 @@ function run_sot_test() {
export SOT_ENABLE_STRICT_GUARD_CHECK=True

# Install PaddlePaddle
echo "::group::Installing paddle wheel..."
$PYTHON_WITH_SPECIFY_VERSION -m pip install ${PADDLE_ROOT}/dist/paddlepaddle-0.0.0-cp${PY_VERSION_NO_DOT}-cp${PY_VERSION_NO_DOT}-linux_x86_64.whl
echo "::endgroup::"
run_and_check "Installing paddle wheel..." \
$PYTHON_WITH_SPECIFY_VERSION -m pip install ${PADDLE_ROOT}/dist/paddlepaddle-0.0.0-cp${PY_VERSION_NO_DOT}-cp${PY_VERSION_NO_DOT}-linux_x86_64.whl

# Only python3.14 needs to install numpy>=2.3.5, because opencv-python will downgrade numpy to 2.2.6
# see: https://github.com/opencv/opencv-python/issues/1155
if [ "$PY_VERSION" == "3.14" ]; then
run_and_check "Uninstalling numpy for Python 3.14..." \
$PYTHON_WITH_SPECIFY_VERSION -m pip uninstall -y "numpy"
run_and_check "Installing numpy>=2.3.5 for Python 3.14..." \
$PYTHON_WITH_SPECIFY_VERSION -m pip install "numpy>=2.3.5"
fi

# cd to sot test dir
cd $PADDLE_ROOT/test/sot/



+ 1
- 1
cmake/external/xpu.cmake View File

@@ -36,7 +36,7 @@ set(XPU_FFT_LIB_NAME "libcufft.so")
add_compile_definitions(XPUAPI_NOT_INCLUDE_DEPRECATED)

if(NOT DEFINED XPU_XHPC_BASE_DATE)
set(XPU_XHPC_BASE_DATE "dev/20251210")
set(XPU_XHPC_BASE_DATE "dev/20251213")
endif()
if(WITH_ARM)
set(XPU_XCCL_BASE_VERSION "20251104") # For XRE5


+ 40
- 36
paddle/fluid/distributed/collective/deep_ep_xpu/deep_ep.cpp View File

@@ -378,24 +378,26 @@ Buffer::intranode_dispatch(
<< " last_topk_idx dim " << last_topk_idx->dim()
<< " last_topk_weights dim " << last_topk_weights->dim();

ret = bkcl_normal_dispatch_standard(comm_ctx->GetBKCLComm(),
x.data_ptr(), // sendbuf
x_scales_ptr,
last_topk_idx->data_ptr<int>(),
last_topk_weights->data_ptr<float>(),
recv_x.data_ptr(),
recv_x_scales_ptr,
recv_topk_idx->data_ptr<int>(),
recv_topk_weights->data_ptr<float>(),
num_scales,
-1, // UNUSED
hidden_size,
num_tokens,
num_topk,
last_num_experts,
ToBKCLDataType(x.dtype()),
use_int8,
reinterpret_cast<XPUStream>(comm_stream));
ret = bkcl_normal_dispatch_standard(
comm_ctx->GetBKCLComm(),
x.data_ptr(), // sendbuf
x_scales_ptr,
last_topk_idx->data_ptr<int>(),
last_topk_weights->data_ptr<float>(),
recv_x.data_ptr(),
recv_x_scales_ptr,
recv_topk_idx->data_ptr<int>(),
recv_topk_weights->data_ptr<float>(),
num_scales,
-1, // UNUSED
hidden_size,
num_tokens,
num_topk,
last_num_experts,
ToBKCLDataType(x.dtype()),
use_int8,
async ? reinterpret_cast<XPUStream>(comm_stream)
: reinterpret_cast<XPUStream>(compute_stream));
EP_HOST_ASSERT(ret == 0 && "bkcl_normal_dispatch_standard failed");

// Wait streams
@@ -729,24 +731,26 @@ Buffer::internode_dispatch(
<< " num_tokens " << num_tokens << " last_num_experts "
<< last_num_experts << " num_recv_tokens " << num_recv_tokens;

ret = bkcl_normal_dispatch_standard(comm_ctx->GetBKCLComm(),
x.data_ptr(), // sendbuf
x_scales_ptr,
last_topk_idx->data_ptr<int>(),
last_topk_weights->data_ptr<float>(),
recv_x.data_ptr(),
recv_x_scales_ptr,
recv_topk_idx->data_ptr<int>(),
recv_topk_weights->data_ptr<float>(),
num_scales,
-1, // UNUSED
hidden_size,
num_tokens,
num_topk,
last_num_experts,
ToBKCLDataType(x.dtype()),
use_int8,
reinterpret_cast<XPUStream>(comm_stream));
ret = bkcl_normal_dispatch_standard(
comm_ctx->GetBKCLComm(),
x.data_ptr(), // sendbuf
x_scales_ptr,
last_topk_idx->data_ptr<int>(),
last_topk_weights->data_ptr<float>(),
recv_x.data_ptr(),
recv_x_scales_ptr,
recv_topk_idx->data_ptr<int>(),
recv_topk_weights->data_ptr<float>(),
num_scales,
-1, // UNUSED
hidden_size,
num_tokens,
num_topk,
last_num_experts,
ToBKCLDataType(x.dtype()),
use_int8,
async ? reinterpret_cast<XPUStream>(comm_stream)
: reinterpret_cast<XPUStream>(compute_stream));
EP_HOST_ASSERT(ret == 0 && "bkcl_normal_dispatch_standard failed");

// Wait streams


+ 2
- 10
paddle/fluid/pybind/sot/cpython_internals.c View File

@@ -817,15 +817,6 @@ void Internal_PyFrame_ClearLocals(_PyInterpreterFrame *frame) {
}
#endif

#if PY_3_14_PLUS
static inline PyGenObject *_PyGen_GetGeneratorFromFrame(
_PyInterpreterFrame *frame) {
assert(frame->owner == FRAME_OWNED_BY_GENERATOR);
size_t offset_in_gen = offsetof(PyGenObject, gi_iframe);
return (PyGenObject *)(((char *)frame) - offset_in_gen);
}
#endif

// Call on 3.11 _PyFrame_Clear is called on 3.12+ _PyFrame_ClearExceptCode
#if PY_3_12_PLUS
void Internal_PyFrame_ClearExceptCode(_PyInterpreterFrame *frame) {
@@ -836,7 +827,8 @@ void Internal_PyFrame_Clear(_PyInterpreterFrame *frame) {
* to have cleared the enclosing generator, if any. */
#if PY_3_14_PLUS
assert(frame->owner != FRAME_OWNED_BY_GENERATOR ||
_PyGen_GetGeneratorFromFrame(frame)->gi_frame_state == FRAME_CLEARED);
Internal_PyGen_GetGeneratorFromFrame(frame)->gi_frame_state ==
FRAME_CLEARED);
#else
assert(frame->owner != FRAME_OWNED_BY_GENERATOR ||
_PyFrame_GetGenerator(frame)->gi_frame_state == FRAME_CLEARED);


+ 1
- 1
paddle/fluid/pybind/sot/eval_frame.c View File

@@ -152,7 +152,7 @@ inline static PyObject *eval_custom_code_py311_plus(PyThreadState *tstate,
}
#if PY_3_14_PLUS
if (PyStackRef_IsNull(fastlocals_old[i])) {
fastlocals_new[PyLong_AsSize_t(index)] = PyStackRef_NULL;
fastlocals_new[PyLong_AsSize_t(index)] = fastlocals_old[i];
} else {
fastlocals_new[PyLong_AsSize_t(index)] =
PyStackRef_DUP(fastlocals_old[i]);


+ 16
- 3
paddle/phi/backends/xpu/xpu3_op_list.cc View File

@@ -745,6 +745,10 @@ XPUOpMap& get_kl3_ops() {
{"unfold_grad",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"floor", XPUKernelSet({phi::DataType::FLOAT32})},
{"ceil",
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::FLOAT16,
phi::DataType::BFLOAT16})},
{"gather_grad",
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::FLOAT16,
@@ -779,6 +783,10 @@ XPUOpMap& get_kl3_ops() {
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::FLOAT16,
phi::DataType::BFLOAT16})},
{"gaussian_inplace",
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::FLOAT16,
phi::DataType::BFLOAT16})},
{"gelu_grad",
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::FLOAT16,
@@ -1338,12 +1346,15 @@ XPUOpMap& get_kl3_ops() {
phi::DataType::BFLOAT16,
phi::DataType::FLOAT32})},
{"scatter_nd_add",
XPUKernelSet({phi::DataType::FLOAT32,
XPUKernelSet({phi::DataType::FLOAT16,
phi::DataType::BFLOAT16,
phi::DataType::FLOAT32,
phi::DataType::INT32,
phi::DataType::INT64})},
{"scatter_nd_add_grad",
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::FLOAT16,
XPUKernelSet({phi::DataType::FLOAT16,
phi::DataType::BFLOAT16,
phi::DataType::FLOAT32,
phi::DataType::INT32,
phi::DataType::INT64})},
{"set_value",
@@ -1962,6 +1973,8 @@ XPUOpMap& get_kl3_ops() {
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::FLOAT16,
phi::DataType::BFLOAT16})},
{"moe_permute", XPUKernelSet({phi::DataType::BFLOAT16})},
{"moe_unpermute", XPUKernelSet({phi::DataType::BFLOAT16})},
{"fused_rms_norm_ext",
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::FLOAT16,


+ 6
- 2
paddle/phi/infermeta/binary.cc View File

@@ -4803,10 +4803,14 @@ void BatchedGemmInferMeta(const MetaTensor& lhs,
common::errors::InvalidArgument(
"We don't support both lhs and rhs are transposed at the same time"));
PADDLE_ENFORCE_EQ(
lhs.dtype() == DataType::BFLOAT16 && rhs.dtype() == DataType::BFLOAT16,
(lhs.dtype() == DataType::BFLOAT16 || lhs.dtype() == DataType::FLOAT32) &&
(rhs.dtype() == DataType::BFLOAT16 ||
rhs.dtype() == DataType::FLOAT32) &&
lhs.dtype() == rhs.dtype(),
true,
common::errors::InvalidArgument(
"The dtype of lhs and rhs must be BFLOAT16, but got [%s] and [%s]",
"The dtype of lhs and rhs must both be BFLOAT16 or both be FLOAT32, "
"but got [%s] and [%s]",
lhs.dtype(),
rhs.dtype()));
PADDLE_ENFORCE_EQ(


+ 58
- 26
paddle/phi/kernels/legacy/gpu/batched_gemm.cu View File

@@ -75,16 +75,17 @@ inline cublasComputeType_t GetCublasComputeType(paddle::DataType dtype) {
}
} // namespace

template <typename T>
void CublasGemm(cublasHandle_t cublas_handle,
phi::bfloat16 *a,
T *a,
int64_t a_rows,
int64_t a_cols,
bool trans_a,
phi::bfloat16 *b,
T *b,
int64_t b_rows,
int64_t b_cols,
bool trans_b,
phi::bfloat16 *c,
T *c,
int64_t c_rows,
int64_t c_cols) {
// NOTE(Pan Zhaowu): We use int32_t because cuBLAS requires int32_t for
@@ -100,25 +101,45 @@ void CublasGemm(cublasHandle_t cublas_handle,
cublasOperation_t transpose_b = trans_b ? CUBLAS_OP_T : CUBLAS_OP_N;

constexpr float alpha = 1.0f, beta = 0.0f;
CUBLAS_CALL(phi::dynload::cublasGemmEx(cublas_handle,
transpose_b,
transpose_a,
m,
n,
k,
&alpha,
b,
CUDA_R_16BF,
ldb,
a,
CUDA_R_16BF,
lda,
&beta,
c,
CUDA_R_16BF,
c_cols,
CUDA_R_32F,
CUBLAS_GEMM_DEFAULT));

if constexpr (std::is_same<T, phi::bfloat16>::value) {
CUBLAS_CALL(phi::dynload::cublasGemmEx(cublas_handle,
transpose_b,
transpose_a,
m,
n,
k,
&alpha,
b,
CUDA_R_16BF,
ldb,
a,
CUDA_R_16BF,
lda,
&beta,
c,
CUDA_R_16BF,
c_cols,
CUDA_R_32F,
CUBLAS_GEMM_DEFAULT));
} else if constexpr (std::is_same<T, float>::value) {
CUBLAS_CALL(phi::dynload::cublasSgemm(cublas_handle,
transpose_b,
transpose_a,
m,
n,
k,
&alpha,
b,
ldb,
a,
lda,
&beta,
c,
c_cols));
} else {
PD_CHECK(false, "Unsupported data type in CublasGemm");
}
}

// Grouped GEMM forward kernel
@@ -140,7 +161,8 @@ void m_grouped_gemm_cuda_forward(const Context &dev_ctx,
const int64_t input_hidden_size = a_shape[1];
const int64_t output_hidden_size = trans_rhs ? b_shape[1] : b_shape[2];

if constexpr (std::is_same<T, paddle::bfloat16>::value) {
if constexpr (std::is_same<T, paddle::bfloat16>::value ||
std::is_same<T, float>::value) {
T *a_data = const_cast<T *>(a.data<T>()); // alias for a.data
T *b_data = const_cast<T *>(b.data<T>()); // alias for b.data
T *output_data = output->data<T>();
@@ -165,7 +187,7 @@ void m_grouped_gemm_cuda_forward(const Context &dev_ctx,
output_data += expert_bs * output_hidden_size;
}
} else {
PD_CHECK(false, "Unsupported data type");
PD_CHECK(false, "Unsupported data type, only support bfloat16 and float32");
}
}

@@ -186,7 +208,8 @@ void k_grouped_gemm_cuda_forward(const Context &dev_ctx,
const int64_t input_hidden_size = a_shape[1];
const int64_t output_hidden_size = b_shape[1];

if constexpr (std::is_same<T, paddle::bfloat16>::value) {
if constexpr (std::is_same<T, paddle::bfloat16>::value ||
std::is_same<T, float>::value) {
T *a_data = const_cast<T *>(a.data<T>()); // alias for a.data
T *b_data = const_cast<T *>(b.data<T>()); // alias for b.data
T *output_data = output->data<T>();
@@ -211,7 +234,7 @@ void k_grouped_gemm_cuda_forward(const Context &dev_ctx,
output_data += input_hidden_size * output_hidden_size;
}
} else {
PD_CHECK(false, "Unsupported data type");
PD_CHECK(false, "Unsupported data type, only support bfloat16 and float32");
}
}

@@ -275,6 +298,15 @@ void BatchedGEMM(const Context &dev_ctx,
dev_ctx, lhs, rhs, batch_sizes, output);
}
break;
case paddle::DataType::FLOAT32:
if (!trans_lhs) {
m_grouped_gemm_cuda_forward<float>(
dev_ctx, lhs, rhs, batch_sizes, trans_rhs, output);
} else {
k_grouped_gemm_cuda_forward<float>(
dev_ctx, lhs, rhs, batch_sizes, output);
}
break;
default:
PD_CHECK(false, "Unsupported data type");
}


+ 22
- 0
paddle/phi/kernels/xpu/activation_kernel.cc View File

@@ -523,6 +523,19 @@ struct XPUFloorFunctor : public funcs::BaseActivationFunctor<T> {
}
};

template <typename T>
struct XPUCeilFunctor : public funcs::BaseActivationFunctor<T> {
using XPUType = typename XPUTypeTrait<T>::Type;
template <typename Context>
void operator()(const Context& dev_ctx,
const DenseTensor& x,
DenseTensor* out) const {
int r = xpu_activation_func<Context, T, XPUType>(
dev_ctx, x, out, xpu::ceil<XPUType>);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "ceil");
}
};

template <typename T>
struct XPUSinFunctor : public funcs::BaseActivationFunctor<T> {
using XPUType = typename XPUTypeTrait<T>::Type;
@@ -590,6 +603,7 @@ struct XPUAcosFunctor : public funcs::BaseActivationFunctor<T> {

DEFINE_XPU_ACTIVATION_KERNEL(Exp, XPUExpFunctor)
DEFINE_XPU_ACTIVATION_KERNEL(Floor, XPUFloorFunctor)
DEFINE_XPU_ACTIVATION_KERNEL(Ceil, XPUCeilFunctor)
DEFINE_XPU_ACTIVATION_KERNEL(Log, XPULogFunctor)
DEFINE_XPU_ACTIVATION_KERNEL(Reciprocal, XPUReciprocalFunctor)
DEFINE_XPU_ACTIVATION_KERNEL(Relu, XPUReluFunctor)
@@ -775,3 +789,11 @@ PD_REGISTER_KERNEL(floor,
int64_t,
phi::float16,
phi::bfloat16) {}

PD_REGISTER_KERNEL(ceil,
XPU,
ALL_LAYOUT,
phi::CeilKernel,
float,
phi::float16,
phi::bfloat16) {}

+ 33
- 0
paddle/phi/kernels/xpu/gaussian_kernel.cc View File

@@ -49,6 +49,31 @@ void GaussianKernel(const Context& dev_ctx,
PADDLE_ENFORCE_XDNN_SUCCESS(r, "normal");
}

template <typename T, typename Context>
void GaussianInplaceKernel(const Context& dev_ctx,
const DenseTensor& x,
float mean,
float std,
int seed,
DenseTensor* out) {
T* data = dev_ctx.template Alloc<T>(out);

if (out->numel() == 0) {
return;
}

using XPUType = typename XPUTypeTrait<T>::Type;
int64_t real_seed = seed != 0 ? seed : dev_ctx.GetGenerator()->Random64();

int r = xpu::normal_<XPUType>(dev_ctx.x_context(),
reinterpret_cast<XPUType*>(data),
mean,
std,
out->numel(),
real_seed);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "normal");
}

} // namespace phi

PD_REGISTER_KERNEL(gaussian,
@@ -58,3 +83,11 @@ PD_REGISTER_KERNEL(gaussian,
float,
phi::float16,
phi::bfloat16) {}

PD_REGISTER_KERNEL(gaussian_inplace,
XPU,
ALL_LAYOUT,
phi::GaussianInplaceKernel,
float,
phi::float16,
phi::bfloat16) {}

+ 274
- 0
paddle/phi/kernels/xpu/moe_permute_kernel.cc View File

@@ -0,0 +1,274 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/backends/xpu/xpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/utils/optional.h"

namespace phi {

#ifndef MAX_NUM_EXPERTS
#define MAX_NUM_EXPERTS 80
#endif

template <typename T, typename Context>
void dispatch_tokens_unzip_stable(const Context &dev_ctx,
const DenseTensor &X,
const DenseTensor &expert_routemap_topk,
const DenseTensor &expert_prob_topk,
const paddle::optional<DenseTensor> &XScale,
const DenseTensor &expert_offsets,
DenseTensor *X_unzipped,
DenseTensor *zipped_expertwise_rowmap,
DenseTensor *token_prob_unzipped,
DenseTensor *XScale_unzipped,
const int total_zipped_tokens_num,
const int token_length,
const int total_tokens_after_broadcast,
const int topk,
const int num_experts,
const int scale_length,
const bool do_gather) {
#define DTYPE_CASE(dtype, type) dtype == phi::DataType::type
#define GET_DATA(tensor, type) tensor.data<type>()
#define GET_XPU_DATA(tensor, type, xpu_type) \
reinterpret_cast<const xpu_type *>(tensor.data<type>())
#define GET_PTR_XPU_DATA(tensor, type, xpu_type) \
reinterpret_cast<xpu_type *>(tensor->data<type>())

#define DISPATCH_CASE(TOKEN_T, PROB_T, INT_T, HAS_SCALE, DO_GATHER) \
using XPU_TOKEN_T = typename XPUTypeTrait<TOKEN_T>::Type; \
using XPU_PROB_T = typename XPUTypeTrait<PROB_T>::Type; \
using XPU_INT_T = typename XPUTypeTrait<INT_T>::Type; \
\
int r = xpu::moe_permute<XPU_TOKEN_T, XPU_INT_T, XPU_PROB_T>( \
dev_ctx.x_context(), \
reinterpret_cast<const XPU_TOKEN_T *>( \
X.data<TOKEN_T>()), /* hidden_states */ \
(XScale ? XScale.get_ptr()->data<float>() : nullptr), /* scale */ \
reinterpret_cast<const XPU_INT_T *>( \
expert_routemap_topk.data<INT_T>()), /* expert_routemap_topk */ \
reinterpret_cast<const XPU_PROB_T *>( \
expert_prob_topk.data<PROB_T>()), /* expert_prob_topk */ \
reinterpret_cast<const XPU_INT_T *>( \
expert_offsets.data<int>()), /* expert_base_offset */ \
reinterpret_cast<XPU_TOKEN_T *>( \
X_unzipped->data<TOKEN_T>()), /* hidden_states_unzipped */ \
reinterpret_cast<XPU_INT_T *>( \
zipped_expertwise_rowmap \
->data<INT_T>()), /* zipped_expertwise_rowmap */ \
reinterpret_cast<XPU_PROB_T *>( \
token_prob_unzipped->data<PROB_T>()), /* token_prob_unzipped */ \
XScale_unzipped->data<float>(), /* scale_unzipped */ \
static_cast<int64_t>(total_zipped_tokens_num), /* sequence_length */ \
static_cast<int64_t>(token_length), /* hidden_size */ \
static_cast<int64_t>( \
total_tokens_after_broadcast), /* total_tokens_after_broadcast */ \
static_cast<int64_t>(topk), /* topk */ \
static_cast<int64_t>(num_experts), /* num_experts */ \
128, /* num_scale */ \
DO_GATHER /* do_gather */ \
); \
\
PADDLE_ENFORCE_XDNN_SUCCESS(r, "moe_permute");

#define HANDLE_GATHER_CASE(TOKEN_T, PROB_T, INT_T, HAS_SCALE) \
if (do_gather) { \
DISPATCH_CASE(TOKEN_T, PROB_T, INT_T, HAS_SCALE, true) \
} else { \
DISPATCH_CASE(TOKEN_T, PROB_T, INT_T, HAS_SCALE, false) \
}

// HANDLE_GATHER_CASE(phi::float8_e4m3fn, PROB_T, INT_T, true)
#define HANDLE_TOKEN_TYPE(PROB_T, INT_T) \
if (DTYPE_CASE(X.dtype(), BFLOAT16)) { \
HANDLE_GATHER_CASE(phi::bfloat16, PROB_T, INT_T, false) \
} else if (DTYPE_CASE(X.dtype(), FLOAT8_E4M3FN)) { \
PADDLE_THROW(common::errors::Unimplemented( \
"moe_permute input only support bfloat16")); \
}

#define HANDLE_PROB_TYPE(INT_T) \
if (DTYPE_CASE(expert_prob_topk.dtype(), BFLOAT16)) { \
PADDLE_THROW(common::errors::Unimplemented( \
"moe_permute expert_prob_topk only support float32")); \
} else if (DTYPE_CASE(expert_prob_topk.dtype(), FLOAT32)) { \
HANDLE_TOKEN_TYPE(float, INT_T) \
}

if (DTYPE_CASE(zipped_expertwise_rowmap->dtype(), INT32)) {
HANDLE_PROB_TYPE(int)
}

#undef DTYPE_CASE
#undef GET_DATA
#undef GET_XPU_DATA
#undef GET_PTR_XPU_DATA
#undef DISPATCH_CASE
#undef HANDLE_EXPERT_CASE
#undef HANDLE_TOKEN_TYPE
#undef HANDLE_PROB_TYPE
}

template <typename T, typename Context>
void MoePermuteKernel(const Context &dev_ctx,
const DenseTensor &X, // hidden_states
const paddle::optional<DenseTensor> &XScale,
const DenseTensor &expert_routemap_topk,
const DenseTensor &expert_prob_topk,
const int num_experts,
const std::vector<int> &tokens_per_expert,
const int padding_multiplex,
const bool do_gather,
DenseTensor *X_unzipped,
DenseTensor *zipped_expertwise_rowmap,
DenseTensor *token_prob_unzipped,
DenseTensor *XScale_unzipped) {
const int64_t rows = X.dims()[0];
const int64_t cols = X.dims()[1];
PADDLE_ENFORCE_LE(
rows,
std::numeric_limits<int32_t>::max(),
common::errors::InvalidArgument("X.dims()[0] should be less than "
"INT_MAX, received X.dims()[0]: (%ld)",
rows));
PADDLE_ENFORCE_LE(
cols,
std::numeric_limits<int32_t>::max(),
common::errors::InvalidArgument("X.dims()[1] should be less than "
"INT_MAX, received X.dims()[1]: (%ld)",
cols));
PADDLE_ENFORCE_LE(
num_experts,
MAX_NUM_EXPERTS,
common::errors::InvalidArgument(
"Currently we support no more than (%ld), received num_expert: "
"(%ld). Please check input "
"value.",
MAX_NUM_EXPERTS,
num_experts));
const int64_t quanted_cols = (XScale) ? XScale.get_ptr()->dims()[1] : 0;
PADDLE_ENFORCE_LE(
quanted_cols,
std::numeric_limits<int32_t>::max(),
common::errors::InvalidArgument("quanted_cols should be less than "
"INT_MAX, received quanted_cols: (%ld)",
quanted_cols));

// Expert base offset initialization, tensor numeric range [0, max_token_num]
int expert_offset[MAX_NUM_EXPERTS];
int tokens_cumulated = 0;
for (int i = 0; i < MAX_NUM_EXPERTS; i++) {
if (i < num_experts) {
expert_offset[i] = tokens_cumulated;
tokens_cumulated +=
((tokens_per_expert[i] + padding_multiplex - 1) / padding_multiplex) *
padding_multiplex;
} else {
expert_offset[i] = 0;
}
}
DenseTensor expert_offset_tensor;
expert_offset_tensor.Resize({MAX_NUM_EXPERTS});
dev_ctx.template Alloc<int>(&expert_offset_tensor);
PADDLE_ENFORCE_XPU_SUCCESS(
cudaMemcpyAsync(expert_offset_tensor.data<int>(),
expert_offset,
sizeof(int) * MAX_NUM_EXPERTS,
cudaMemcpyHostToDevice,
reinterpret_cast<cudaStream_t>(dev_ctx.stream())));
// ------------------- resource allocate -------------------------
const int output_rows = tokens_cumulated;
const int64_t topk = expert_routemap_topk.dims()[1];
PADDLE_ENFORCE_LE(
topk,
std::numeric_limits<int32_t>::max(),
common::errors::InvalidArgument(
"topk should be less than INT_MAX, received topk: (%ld)", topk));
token_prob_unzipped->Resize({output_rows});
if (do_gather) { // no gather, no resize.
X_unzipped->Resize({output_rows, cols});
if (XScale) {
const int quanted_cols = XScale.get_ptr()->dims()[1];
XScale_unzipped->Resize({output_rows, quanted_cols});
}
}
dev_ctx.template Alloc<T>(X_unzipped);
dev_ctx.template Alloc<float>(XScale_unzipped);
dev_ctx.template Alloc<int>(zipped_expertwise_rowmap);
dev_ctx.template Alloc<float>(token_prob_unzipped);
auto X_unzipped_ptr = reinterpret_cast<void *>(X_unzipped->data<T>());
auto token_prob_unzipped_ptr =
reinterpret_cast<void *>(token_prob_unzipped->data<float>());
auto XScale_unzipped_ptr =
reinterpret_cast<void *>(XScale_unzipped->data<float>());

// -------- Memset all padding area to zero, with regard to do_gather
auto memset_invalid_rows =
[&](void *ptr, int64_t element_size, int64_t stride) {
for (int i = 0; i < num_experts; i++) {
int64_t next_expert_offset =
i < num_experts - 1 ? expert_offset[i + 1] : output_rows;
int64_t invalid_rows =
next_expert_offset - expert_offset[i] - tokens_per_expert[i];
int64_t cur_expert_end = expert_offset[i] + tokens_per_expert[i];

PADDLE_ENFORCE_XPU_SUCCESS(cudaMemsetAsync(
ptr + cur_expert_end * stride * element_size,
0,
element_size * invalid_rows * stride,
reinterpret_cast<cudaStream_t>(dev_ctx.stream())));
}
};
if (do_gather) { // no gather, no memset
memset_invalid_rows(X_unzipped_ptr, sizeof(T), cols);
if (XScale) {
memset_invalid_rows(XScale_unzipped_ptr, sizeof(float), quanted_cols);
}
}
// Probs will be memset to zero whatsoever
memset_invalid_rows(token_prob_unzipped_ptr, sizeof(float), 1);

// Handle 0-size input
if (X.numel() == 0) return;

// -------- Initialize semaphore for cumsum ---------------
dispatch_tokens_unzip_stable<T, Context>(dev_ctx,
X,
expert_routemap_topk,
expert_prob_topk,
XScale,
expert_offset_tensor,
X_unzipped,
zipped_expertwise_rowmap,
token_prob_unzipped,
XScale_unzipped,
static_cast<int>(rows),
static_cast<int>(cols),
static_cast<int>(output_rows),
static_cast<int>(topk),
num_experts,
static_cast<int>(quanted_cols),
do_gather);
}
#undef MAX_NUM_EXPERTS
} // namespace phi

PD_REGISTER_KERNEL(moe_permute,
XPU,
ALL_LAYOUT,
phi::MoePermuteKernel,
// phi::float8_e4m3fn,
phi::bfloat16) {}

+ 118
- 0
paddle/phi/kernels/xpu/moe_unpermute_kernel.cc View File

@@ -0,0 +1,118 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/backends/xpu/xpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/utils/optional.h"

namespace phi {
#ifndef MAX_NUM_EXPERTS
#define MAX_NUM_EXPERTS 64
#endif

template <typename T, typename Context>
void dispatch_tokens_zip(const Context &dev_ctx,
const DenseTensor &unzipped_tokens,
const DenseTensor &zipped_expertwise_rowmap,
const DenseTensor &expert_routemap_topk,
const DenseTensor &unzipped_token_probs,
DenseTensor *zipped_tokens,
DenseTensor *zipped_probs_topk,
const int total_zipped_tokens_num,
const int num_experts,
const int token_length,
const int topk,
const bool MP) {
using XPU_BF16 = typename XPUTypeTrait<phi::bfloat16>::Type;
// Map data types to C++ types
if (unzipped_token_probs.dtype() == paddle::DataType::FLOAT32) {
int r = xpu::moe_unpermute(
dev_ctx.x_context(),
reinterpret_cast<const XPU_BF16 *>(
unzipped_tokens.data<phi::bfloat16>()),
reinterpret_cast<const int *>(zipped_expertwise_rowmap.data<int>()),
reinterpret_cast<const int *>(expert_routemap_topk.data<int>()),
reinterpret_cast<const float *>(unzipped_token_probs.data<float>()),
reinterpret_cast<XPU_BF16 *>(zipped_tokens->data<phi::bfloat16>()),
zipped_probs_topk->data<float>(),
total_zipped_tokens_num,
num_experts,
token_length,
topk,
MP,
unzipped_tokens.dims()[0]);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "moe_unpermute");
}
}

template <typename T, typename Context>
void MoeUnpermuteKernel(const Context &dev_ctx,
const DenseTensor &unzipped_tokens,
const DenseTensor &zipped_expertwise_rowmap,
const DenseTensor &expert_routemap_topk,
const DenseTensor &unzipped_token_probs,
const int total_zipped_tokens_num,
const int num_experts,
const bool MP,
DenseTensor *zipped_tokens,
DenseTensor *zipped_probs_topk) {
const int64_t cols = unzipped_tokens.dims()[1];
PADDLE_ENFORCE_LE(cols,
std::numeric_limits<int32_t>::max(),
common::errors::InvalidArgument(
"unzipped_tokens.dims()[1] should be less than "
"INT_MAX, received unzipped_tokens.dims()[1]: (%ld)",
cols));
PADDLE_ENFORCE_LE(
num_experts,
MAX_NUM_EXPERTS,
common::errors::InvalidArgument(
"Currently we support no more than (%ld), received num_expert: "
"(%ld). Please check input "
"value.",
MAX_NUM_EXPERTS,
num_experts));
const int64_t topk = expert_routemap_topk.dims()[1];
PADDLE_ENFORCE_LE(
topk,
std::numeric_limits<int32_t>::max(),
common::errors::InvalidArgument(
"topk should be less than INT_MAX, received topk: (%ld)", topk));
dev_ctx.template Alloc<T>(zipped_tokens);
dev_ctx.template Alloc<float>(zipped_probs_topk);
if (unzipped_tokens.numel() == 0) return; // 0-size tensor
void *zipped_probs_topk_ptr =
reinterpret_cast<void *>(zipped_probs_topk->data<float>());
PADDLE_ENFORCE_XPU_SUCCESS(
cudaMemsetAsync(zipped_probs_topk_ptr,
0,
sizeof(float) * int64_t(total_zipped_tokens_num) * topk,
reinterpret_cast<cudaStream_t>(dev_ctx.stream())));
dispatch_tokens_zip<T, Context>(dev_ctx,
unzipped_tokens,
zipped_expertwise_rowmap,
expert_routemap_topk,
unzipped_token_probs,
zipped_tokens,
zipped_probs_topk,
total_zipped_tokens_num,
num_experts,
static_cast<int>(cols),
static_cast<int>(topk),
MP);
}
} // namespace phi

// Register the XPU moe_unpermute kernel. Only bfloat16 token data is
// registered here; the probability tensors are always float32.
PD_REGISTER_KERNEL(
    moe_unpermute, XPU, ALL_LAYOUT, phi::MoeUnpermuteKernel, phi::bfloat16) {}

+ 2
- 1
paddle/phi/kernels/xpu/scatter_nd_add_grad_kernel.cc View File

@@ -124,7 +124,8 @@ PD_REGISTER_KERNEL(scatter_nd_add_grad,
XPU,
ALL_LAYOUT,
phi::ScatterNdAddGradKernel,
float,
phi::float16,
phi::bfloat16,
float,
int,
int64_t) {}

+ 32
- 27
paddle/phi/kernels/xpu/scatter_nd_add_kernel.cc View File

@@ -24,14 +24,17 @@ void ScatterNdAddKernel(const Context &dev_ctx,
const DenseTensor &index,
const DenseTensor &updates,
DenseTensor *out) {
using XPUType = typename XPUTypeTrait<T>::Type;
if (out && out->numel() == 0) {
dev_ctx.template Alloc<T>(out);
return;
}
const T *x_ptr = x.data<T>();
const T *updates_ptr = updates.data<T>();
const XPUType *x_ptr = reinterpret_cast<const XPUType *>(x.data<T>());
const XPUType *updates_ptr =
reinterpret_cast<const XPUType *>(updates.data<T>());

T *out_ptr = dev_ctx.template Alloc<T>(out);
dev_ctx.template Alloc<T>(out);
XPUType *out_ptr = reinterpret_cast<XPUType *>(out->data<T>());
int r = xpu::copy(dev_ctx.x_context(), x_ptr, out_ptr, x.numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "copy");

@@ -45,12 +48,12 @@ void ScatterNdAddKernel(const Context &dev_ctx,
index.dims(), 0, index_dims_size - 1));

for (int64_t i = 0; i < loop_time; i++) {
r = xpu::broadcast_add<T>(dev_ctx.x_context(),
updates_ptr + out->numel() * i,
out_ptr,
out_ptr,
{out->numel()},
{out->numel()});
r = xpu::broadcast_add<XPUType>(dev_ctx.x_context(),
updates_ptr + out->numel() * i,
out_ptr,
out_ptr,
{out->numel()},
{out->numel()});
PADDLE_ENFORCE_XDNN_SUCCESS(r, "copy");
}
return;
@@ -81,25 +84,25 @@ void ScatterNdAddKernel(const Context &dev_ctx,
if (index_type == phi::DataType::INT32) {
auto index_data = const_cast<int *>(index.data<int>());
xpu::VectorParam<int> index_vec{nullptr, index_size, index_data};
r = xpu::scatter_nd<T, int>(dev_ctx.x_context(),
nullptr,
updates_ptr,
out_ptr,
index_vec,
x_vec,
index_shape,
false);
r = xpu::scatter_nd<XPUType, int>(dev_ctx.x_context(),
nullptr,
updates_ptr,
out_ptr,
index_vec,
x_vec,
index_shape,
false);
} else {
auto index_data = const_cast<int64_t *>(index.data<int64_t>());
xpu::VectorParam<int64_t> index_vec{nullptr, index_size, index_data};
r = xpu::scatter_nd<T, int64_t>(dev_ctx.x_context(),
nullptr,
updates_ptr,
out_ptr,
index_vec,
x_vec,
index_shape,
false);
r = xpu::scatter_nd<XPUType, int64_t>(dev_ctx.x_context(),
nullptr,
updates_ptr,
out_ptr,
index_vec,
x_vec,
index_shape,
false);
}

PADDLE_ENFORCE_XDNN_SUCCESS(r, "scatter_nd_add");
@@ -110,6 +113,8 @@ PD_REGISTER_KERNEL(scatter_nd_add,
XPU,
ALL_LAYOUT,
phi::ScatterNdAddKernel,
phi::float16,
phi::bfloat16,
float,
int64_t,
int) {}
int,
int64_t) {}

+ 68
- 46
test/legacy_test/test_batched_gemm.py View File

@@ -25,8 +25,12 @@ os.environ["FLAGS_cudnn_deterministic"] = "1"
os.environ["FLAGS_embedding_deterministic"] = "1"


def allclose(x, y):
mask = np.testing.assert_allclose(x.numpy(), y.numpy(), rtol=1e-5)
def allclose(x, y, dtype):
if dtype == paddle.bfloat16:
rtol = 1e-5
else:
rtol = 1e-5
np.testing.assert_allclose(x.numpy(), y.numpy(), rtol=rtol)


_TEST_PROBLEMS = (
@@ -41,9 +45,9 @@ _TEST_PROBLEMS = (
m_group_layout_cases = [(False, True), (False, False)]


def randn(bs, x, y):
def randn(bs, x, y, dtype=paddle.bfloat16):
out = (paddle.rand([bs, x, y]) - 0.5 * 2) / (y * x)
return out.astype(paddle.bfloat16)
return out.astype(dtype)


def pyref_gmm(a, b, batch_sizes, trans_b=False):
@@ -74,58 +78,76 @@ class TestGroupedGemm(unittest.TestCase):

def test_m_grouped_gemm_fixed_sizes(self):
"""Test grouped GEMM with fixed sizes"""
for z, m, k, n in _TEST_PROBLEMS:
for trans_lhs, trans_rhs in m_group_layout_cases:
# Test both bfloat16 and float32 dtypes
dtypes = [paddle.bfloat16, paddle.float32]

for dtype in dtypes:
for z, m, k, n in _TEST_PROBLEMS:
for trans_lhs, trans_rhs in m_group_layout_cases:
with self.subTest(
dtype=dtype,
z=z,
m=m,
k=k,
n=n,
trans_a=trans_lhs,
trans_b=trans_rhs,
) and paddle.amp.auto_cast(False):
a = randn(z, m, k, dtype).reshape([-1, k]).astype(dtype)
b = randn(z, k, n, dtype).astype(dtype)
if trans_rhs:
b = b.mT
batch_sizes = [m] * z
a.stop_gradient = False
b.stop_gradient = False
a_ref = a.clone().detach()
b_ref = b.clone().detach()
a_ref.stop_gradient = False
b_ref.stop_gradient = False
print(
f"Testing dtype={dtype}, shape={a.shape}, {b.shape}"
)
out = grouped_gemm(a, b, batch_sizes, False, trans_rhs)
expected_out = pyref_gmm(
a_ref, b_ref, batch_sizes, trans_rhs
)
allclose(out, expected_out.reshape(out.shape), dtype)

def test_k_grouped_gemm_variable_sizes(self):
"""Test grouped GEMM with variable sizes"""
# Test both bfloat16 and float32 dtypes
dtypes = [paddle.bfloat16, paddle.float32]

for dtype in dtypes:
for z, m, k, n in _TEST_PROBLEMS:
with self.subTest(
z=z, m=m, k=k, n=n, trans_a=trans_lhs, trans_b=trans_rhs
dtype=dtype, z=z, m=m, k=k, n=n, trans_a=True, trans_b=False
) and paddle.amp.auto_cast(False):
a = randn(z, m, k).reshape([-1, k]).astype(paddle.bfloat16)
b = randn(z, k, n).astype(paddle.bfloat16)
if trans_rhs:
b = b.mT
a = randn(z, m, k, dtype).astype(dtype)
b = randn(z, m, n, dtype).astype(dtype)

batch_sizes = [m] * z

a.stop_gradient = False
b.stop_gradient = False
a_ref = a.clone().detach()
b_ref = b.clone().detach()
a_ref.stop_gradient = False
b_ref.stop_gradient = False
print(f"{a.shape}, {b.shape}")
out = grouped_gemm(a, b, batch_sizes, False, trans_rhs)
expected_out = pyref_gmm(
a_ref, b_ref, batch_sizes, trans_rhs
)
allclose(out, expected_out.reshape(out.shape))

def test_k_grouped_gemm_variable_sizes(self):
"""Test grouped GEMM with variable sizes"""
for z, m, k, n in _TEST_PROBLEMS:
with self.subTest(
z=z, m=m, k=k, n=n, trans_a=True, trans_b=False
) and paddle.amp.auto_cast(False):
a = randn(z, m, k).astype(paddle.bfloat16)
b = randn(z, m, n).astype(paddle.bfloat16)

batch_sizes = [m] * z

a.stop_gradient = False
b.stop_gradient = False
a_ref = a.clone().detach()
b_ref = b.clone().detach()
a_ref.stop_gradient = False
b_ref.stop_gradient = False

out = grouped_gemm(
a.reshape([-1, k]),
b.reshape([-1, n]),
batch_sizes,
True,
False,
)
expected_out = pyref_k_gmm(
a_ref.reshape([-1, k]), b_ref.reshape([-1, n]), batch_sizes
)
allclose(out, expected_out.reshape(out.shape))
out = grouped_gemm(
a.reshape([-1, k]),
b.reshape([-1, n]),
batch_sizes,
True,
False,
)
expected_out = pyref_k_gmm(
a_ref.reshape([-1, k]),
b_ref.reshape([-1, n]),
batch_sizes,
)
allclose(out, expected_out.reshape(out.shape), dtype)


if __name__ == '__main__':


+ 64
- 0
test/legacy_test/test_ceil.py View File

@@ -0,0 +1,64 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np

import paddle


class TestCeilOutAndParamDecorator(unittest.TestCase):
    """Verify paddle.ceil gives identical results and gradients across
    call styles: positional, keyword (``input=``), and ``out=`` variants."""

    def setUp(self):
        paddle.disable_static()
        self.x_np = np.random.uniform(-10, 10, [3, 4]).astype(np.float32)
        self.test_types = ["decorator", "out", "out_decorator"]

    def do_test(self, test_type):
        # A fresh leaf tensor per variant keeps gradients independent.
        x = paddle.to_tensor(self.x_np, stop_gradient=False)
        if test_type in ('raw', 'decorator'):
            if test_type == 'raw':
                result = paddle.ceil(x)
            else:
                result = paddle.ceil(input=x)
            result.mean().backward()
            return result, x.grad
        if test_type in ('out', 'out_decorator'):
            out = paddle.empty_like(x)
            out.stop_gradient = False
            if test_type == 'out':
                paddle.ceil(x, out=out)
            else:
                paddle.ceil(input=x, out=out)
            out.mean().backward()
            return out, x.grad
        raise ValueError(f"Unknown test type: {test_type}")

    def test_all(self):
        out_std, grad_std = self.do_test('raw')
        for test_type in self.test_types:
            out, grad = self.do_test(test_type)
            np.testing.assert_allclose(out.numpy(), out_std.numpy(), rtol=1e-20)
            np.testing.assert_allclose(
                grad.numpy(), grad_std.numpy(), rtol=1e-20
            )

if __name__ == "__main__":
unittest.main()

+ 0
- 2
test/sot/skip_files_py314 View File

@@ -1,2 +0,0 @@
./test_force_dynamic.py
./test_numpy.py

+ 10
- 5
test/xpu/test_scatter_nd_add_op_xpu.py View File

@@ -20,6 +20,7 @@ from get_test_cover_info import (
create_test_class,
get_xpu_op_support_types,
)
from op_test import convert_float_to_uint16
from op_test_xpu import XPUOpTest

import paddle
@@ -78,11 +79,15 @@ class XPUTestScatterNdAdd(XPUOpTestWrapper):

self.init_data() # only test float32 because of its register type

self.inputs = {
'X': self.x_np,
'Index': self.index_np,
'Updates': self.updates_np,
}
self.inputs = {"Index": self.index_np}
if self.dtype == np.uint16:
self.inputs["X"] = convert_float_to_uint16(self.x_np)
self.inputs["Updates"] = convert_float_to_uint16(
self.updates_np
)
else:
self.inputs["X"] = self.x_np
self.inputs["Updates"] = self.updates_np
output = numpy_scatter_nd_add(
self.x_np.copy(), self.index_np, self.updates_np
)


Loading…
Cancel
Save
Baidu
map