11 Commits

Author SHA1 Message Date
  Tofu 24c29dc7eb
[CodeStyle][Xdoctest][17,21,27,28] Fix example code(`paddle.Tensor.matmul`,`paddle.Tensor.new_empty`,`paddle.Tensor.sgn`,`paddle.Tensor.shape`,) (#76691) 14 hours ago
  Yuqiang Ge 107306f398
enable avx (#76835) 15 hours ago
  xuanyuanminzheng 24e18c9948
Update forwards.h to metax (#76875) 16 hours ago
  Yuqiang Ge 52f76d38a4
fix custom device all to all (#76880) 16 hours ago
  Zhou Xin 704e3eeb87
[API Compatibility][CustomDevice] Remove force set grad to None (#76883) 16 hours ago
  SUN Dong 3e93582eff
Fix 0 size for some api (#76468) 20 hours ago
  xinruiM e4cb191c0b
[XPU]Support concat complex64 (#76904) 20 hours ago
  Lucas 0a50223cc5
[XPU] support bool type for multiply op (#76903) 20 hours ago
  Leo Guo a75775142a
[XPU] Binding fused_linear_param_grad_add (#76907) 21 hours ago
  bigwhite37 4af899ffd3
[MaskedFill] Update XPU grad kernel and add XPU masked_fill_grad tests (#76861) 21 hours ago
  baoqiwen 9529b19fe1
Reduce: align precision with PyTorch 2.9.1 (#76590) 21 hours ago
29 changed files with 2303 additions and 278 deletions
Split View
  1. +2
    -1
      .github/workflows/_Auto-Parallel.yml
  2. +1
    -1
      ci/run_setup.sh
  3. +1
    -3
      paddle/fluid/pybind/eager_method.cc
  4. +2
    -2
      paddle/fluid/pybind/eager_properties.cc
  5. +3
    -1
      paddle/phi/backends/custom/custom_device.cc
  6. +0
    -5
      paddle/phi/backends/gpu/forwards.h
  7. +18
    -5
      paddle/phi/backends/xpu/xpu3_op_list.cc
  8. +6
    -0
      paddle/phi/infermeta/fusion.cc
  9. +4
    -0
      paddle/phi/infermeta/ternary.cc
  10. +118
    -2
      paddle/phi/kernels/funcs/dense_tensor_iterator.cc
  11. +46
    -4
      paddle/phi/kernels/funcs/dense_tensor_iterator.h
  12. +1427
    -0
      paddle/phi/kernels/funcs/reduce_gpu_kernel.h
  13. +1
    -0
      paddle/phi/kernels/fusion/xpu/fused_linear_param_grad_add_kernel.cc
  14. +93
    -27
      paddle/phi/kernels/gpu/reduce.h
  15. +6
    -2
      paddle/phi/kernels/gpu/reduce_amin_amax_common.h
  16. +21
    -23
      paddle/phi/kernels/kps/reduce_kernel.cu
  17. +2
    -1
      paddle/phi/kernels/stride/reduce_grad_stride_kernel.cu
  18. +12
    -180
      paddle/phi/kernels/stride/reduce_stride_kernel.cu
  19. +9
    -2
      paddle/phi/kernels/xpu/concat_kernel.cc
  20. +1
    -0
      paddle/phi/kernels/xpu/elementwise_multiply_kernel.cc
  21. +1
    -0
      paddle/phi/kernels/xpu/expand_grad_kernel.cc
  22. +159
    -0
      paddle/phi/kernels/xpu/masked_fill_grad_kernel.cc
  23. +0
    -1
      python/paddle/_paddle_docs.py
  24. +1
    -1
      python/paddle/tensor/attribute.py
  25. +5
    -9
      python/paddle/tensor/math.py
  26. +28
    -0
      test/legacy_test/test_flash_attention.py
  27. +52
    -0
      test/legacy_test/test_reduce_op.py
  28. +5
    -8
      test/legacy_test/test_tensor.py
  29. +279
    -0
      test/xpu/test_masked_fill_grad_op_xpu.py

+ 2
- 1
.github/workflows/_Auto-Parallel.yml View File

@@ -100,7 +100,8 @@ jobs:
git pull
git checkout 6ac04028757dfbcc089916997493611f62de81b2
git switch -c 6ac04028757dfbcc089916997493611f62de81b2
git cherry-pick bc08aeec91d2c992c3d8d39755bea7c6213b0e82
git cherry-pick bc08aeec91d2c992c3d8d39755bea7c6213b0e82
git cherry-pick 7ab35ce94eca977bcf3b44bfb42deb0e0b5ef158
git submodule update --init --recursive --force
'



+ 1
- 1
ci/run_setup.sh View File

@@ -234,7 +234,7 @@ EOF
export WITH_CINN=${WITH_CINN:-OFF}
export WITH_DISTRIBUTE=${distributed_flag}
export WITH_MKL=${WITH_MKL:-ON}
export WITH_AVX=${WITH_AVX:-OFF}
export WITH_AVX=${WITH_AVX:-ON}
export CUDA_ARCH_NAME=${CUDA_ARCH_NAME:-All}
export NEW_RELEASE_PYPI=${NEW_RELEASE_PYPI:-OFF}
export NEW_RELEASE_ALL=${NEW_RELEASE_ALL:-OFF}


+ 1
- 3
paddle/fluid/pybind/eager_method.cc View File

@@ -977,9 +977,7 @@ static PyObject* tensor_clear_gradient(TensorObject* self,
static_cast<phi::distributed::DistTensor*>(grad->impl().get())
->unsafe_mutable_value();
}
bool is_mismatched = self->tensor.place() != grad_t->place() ||
self->tensor.dtype() != grad_t->dtype();
if (set_to_zero && !is_mismatched) {
if (set_to_zero) {
EagerSetDeviceId();
auto* dev_ctx =
phi::DeviceContextPool::Instance().Get(grad_t->place());


+ 2
- 2
paddle/fluid/pybind/eager_properties.cc View File

@@ -585,13 +585,13 @@ Returns:
List: shape.

Examples:
.. code-block:: python
.. code-block:: pycon

>>> import paddle

>>> x = paddle.to_tensor(1.0, stop_gradient=False)
>>> print(x.shape)
[]
paddle.Size([])
)DOC");

PyObject* tensor_properties_get_shape(TensorObject* self, void* closure) {


+ 3
- 1
paddle/phi/backends/custom/custom_device.cc View File

@@ -972,7 +972,9 @@ class CustomDevice : public DeviceInterface {
}
const phi::stream::Stream stream_wrapper(
Place(AllocationType::CUSTOM, Type()), stream);
MemoryCopyD2D(rank,

int current_device_id = GetDevice();
MemoryCopyD2D(current_device_id,
recv_buf[rank],
send_buf[rank],
send_count[rank] * phi::SizeOf(send_dtype[rank]),


+ 0
- 5
paddle/phi/backends/gpu/forwards.h View File

@@ -67,11 +67,6 @@ using cusolverDnHandle_t = struct cusolverDnContext *;
// Forward declaration of cuSparse types.
using cusparseHandle_t = struct cusparseContext *;

#ifdef PADDLE_WITH_CUDA
// Forward declaration of cuFFT types.
using cufftHandle = int;
#endif

// Forward declaration of NCCL types.
using ncclComm_t = struct ncclComm *;



+ 18
- 5
paddle/phi/backends/xpu/xpu3_op_list.cc View File

@@ -315,7 +315,8 @@ XPUOpMap& get_kl3_ops() {
phi::DataType::INT8,
phi::DataType::INT16,
phi::DataType::INT32,
phi::DataType::INT64})},
phi::DataType::INT64,
phi::DataType::COMPLEX64})},
{"conv2d_grad",
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::FLOAT16,
@@ -463,9 +464,11 @@ XPUOpMap& get_kl3_ops() {
#ifdef PADDLE_WITH_XPU_FFT
phi::DataType::COMPLEX64,
#endif
phi::DataType::BFLOAT16})},
phi::DataType::BFLOAT16,
phi::DataType::INT64})},
{"elementwise_mul",
XPUKernelSet({phi::DataType::FLOAT32,
XPUKernelSet({phi::DataType::BOOL,
phi::DataType::FLOAT32,
phi::DataType::FLOAT16,
phi::DataType::BFLOAT16,
#ifdef PADDLE_WITH_XPU_FFT
@@ -972,6 +975,11 @@ XPUOpMap& get_kl3_ops() {
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::FLOAT16,
phi::DataType::BFLOAT16})},
{"masked_fill_grad",
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::INT64,
phi::DataType::FLOAT16,
phi::DataType::BFLOAT16})},
{"masked_select",
XPUKernelSet({phi::DataType::INT32,
phi::DataType::INT64,
@@ -1029,8 +1037,10 @@ XPUOpMap& get_kl3_ops() {
{"mul_grad",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"multiply",
XPUKernelSet({phi::DataType::FLOAT32,
XPUKernelSet({phi::DataType::BOOL,
phi::DataType::FLOAT32,
phi::DataType::FLOAT16,
phi::DataType::BFLOAT16,
phi::DataType::INT32,
phi::DataType::INT64})},
{"multi_encoder_xpu",
@@ -1862,6 +1872,7 @@ XPUOpMap& get_kl3_ops() {
phi::DataType::INT64})},
{"expand_v2_grad",
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::INT64,
phi::DataType::BFLOAT16,
phi::DataType::FLOAT16})},
{"eye",
@@ -1908,7 +1919,9 @@ XPUOpMap& get_kl3_ops() {
phi::DataType::FLOAT16,
phi::DataType::BFLOAT16})},
{"fused_linear_param_grad_add",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::FLOAT16,
phi::DataType::BFLOAT16})},
{"fused_attention",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"fused_attention_grad",


+ 6
- 0
paddle/phi/infermeta/fusion.cc View File

@@ -4603,6 +4603,12 @@ void VariableLengthMemoryEfficientAttentionInferMeta(
true,
common::errors::InvalidArgument(
"The batch size of Query, Key, Value and Mask should be equal."));
PADDLE_ENFORCE_EQ(
mask.dims()[1],
1,
common::errors::InvalidArgument(
"The second dim of mask should be 1, but received mask dim is [%s]",
mask.dims()));
}

std::vector<int64_t> out_dims(


+ 4
- 0
paddle/phi/infermeta/ternary.cc View File

@@ -711,6 +711,10 @@ void FlashAttnInferMeta(const MetaTensor& q,
if (out_dims.size() == 4) {
out_dims[3] = v.dims()[3];
}
// for 0-size
if (q.dims()[0] == 0 || k.dims()[0] == 0 || v.dims()[0] == 0) {
out_dims[0] = 0;
}
out->set_dims(out_dims);
out->set_dtype(q.dtype());
out->set_layout(q.layout());


+ 118
- 2
paddle/phi/kernels/funcs/dense_tensor_iterator.cc View File

@@ -232,6 +232,7 @@ void DenseTensorIteratorBase::coalesce_dimensions() {
stride[dim0] = stride[dim1];
}
};

int prev_dim = 0;
for (auto dim = 1; dim < ndim(); dim++) {
if (can_coalesce(prev_dim, dim)) {
@@ -262,8 +263,8 @@ int64_t DenseTensorIteratorBase::numel() const {
return numel;
}

const void* DenseTensorIteratorBase::data_ptr(int64_t arg) const {
return static_cast<void*>(operands_[arg].tensor().data());
void* DenseTensorIteratorBase::data_ptr(int64_t arg) const {
return static_cast<void*>(operands_[arg].data);
}

static inline std::vector<int64_t> infer_size_dimvector(
@@ -449,6 +450,12 @@ void DenseTensorIteratorBase::build(DenseTensorIteratorConfig& config) {
allocate_or_resize_outputs();
coalesce_dimensions();
}

for (auto& op : operands_) {
op.data = const_cast<void*>(op.tensor().data());
}
int64_t ndim_offsets = (ndim() ? ndim() : 1);
view_offsets_ = std::vector<int64_t>(ndim_offsets, 0);
}

DimIter::DimIter(std::vector<int64_t> shape, int64_t start, int64_t end)
@@ -507,4 +514,113 @@ std::array<int64_t, 2> DimIter::iter_for_step() const {
return {step0, step1};
}

// Restrict iteration to a [start, start + size) window along `dim`:
// shrink the logical shape, record the window origin in view_offsets_,
// and advance every operand's data pointer by `start` strides' worth of
// bytes so index 0 of the narrowed iterator maps to the window start.
void DenseTensorIteratorBase::narrow(int dim, int64_t start, int64_t size) {
  shape_[dim] = size;
  view_offsets_[dim] += start;
  for (auto& op : operands_) {
    op.data = (static_cast<char*>(op.data)) + op.stride_bytes[dim] * start;
  }
  // A dimension narrowed to extent 1 may now be mergeable with its
  // neighbours; reductions keep their dimension layout untouched.
  if (size == 1 && !is_reduction_) {
    coalesce_dimensions();
  }
}

// A dimension is "reduced" when some output operand has a zero stride
// over an extent greater than one, i.e. multiple input positions along
// `dim` fold into the same output element.
bool DenseTensorIteratorBase::is_dim_reduced(int dim) const {
  if (shape_[dim] <= 1) {
    return false;
  }
  for (const auto& op : operands_) {
    if (op.is_output && op.stride_bytes[dim] == 0) {
      return true;
    }
  }
  return false;
}

// Split this iterator in two along `dim`. The returned iterator covers
// the first half; *this is narrowed to the remaining half. When `dim`
// is a reduced dimension both halves target the same output elements,
// so the first half cannot be treated as the final output and the
// second half must accumulate into the partial result.
std::unique_ptr<DenseTensorIterator> DenseTensorIteratorBase::split(int dim) {
  auto split_iter = std::make_unique<DenseTensorIterator>(*this);
  bool has_overlap = is_dim_reduced(dim);
  int64_t split_size = shape_[dim] / 2;
  int64_t remaining_size = shape_[dim] - split_size;

  split_iter->narrow(dim, 0, split_size);
  split_iter->final_output_ = !has_overlap;

  narrow(dim, split_size, remaining_size);
  accumulate_ |= has_overlap;

  return split_iter;
}

// Choose the dimension whose byte span -- (size - 1) * |stride| over
// any operand -- is the largest; splitting it halves the maximum
// addressable offset fastest. Returns -1 when every dimension is empty.
int DenseTensorIteratorBase::get_dim_to_split() const {
  int chosen_dim = -1;
  int64_t chosen_span = -1;

  for (int dim = ndim() - 1; dim >= 0; --dim) {
    const int64_t size = shape_[dim];
    if (size == 0) {
      continue;
    }
    for (const auto& op : operands_) {
      const int64_t span = (size - 1) * std::abs(op.stride_bytes[dim]);
      if (span > chosen_span) {
        chosen_span = span;
        chosen_dim = dim;
      }
    }
  }
  return chosen_dim;
}

// True when the element count and every operand's largest reachable
// byte offset fit in int32, so kernels may do 32-bit index arithmetic.
bool DenseTensorIteratorBase::can_use_32bit_indexing() const {
  constexpr int64_t max_32bit_value = std::numeric_limits<int32_t>::max();

  if (numel() > max_32bit_value) {
    return false;
  }

  for (auto& op : operands_) {
    // Largest offset this operand can be addressed at, accumulated
    // from the extreme index of every dimension.
    int64_t max_offset = 1;
    for (int dim = 0; dim < ndim(); ++dim) {
      max_offset += (shape_[dim] - 1) * op.stride_bytes[dim];
    }

    if (max_offset > max_32bit_value) {
      return false;
    }
  }
  return true;
}

// Range adaptor: iterating the returned splitter yields sub-iterators
// that each satisfy can_use_32bit_indexing().
Tensor32BitSplitter DenseTensorIteratorBase::with_32bit_indexing() const {
  return Tensor32BitSplitter(*this);
}

// Seed the stack with a copy of `iter` plus a null sentinel, then run
// operator++ once: it pops the sentinel and splits the copy until the
// top of the stack is 32-bit indexable.
Tensor32BitSplitter::iterator::iterator(const DenseTensorIteratorBase& iter) {
  iterator_stack_.emplace_back(std::make_unique<DenseTensorIterator>(iter));
  iterator_stack_.emplace_back(nullptr);
  ++(*this);
}

// Advance to the next 32-bit-indexable piece: discard the piece just
// consumed, then repeatedly halve the top of the stack along its widest
// dimension until it fits 32-bit indexing. An empty stack equals end().
Tensor32BitSplitter::iterator& Tensor32BitSplitter::iterator::operator++() {
  iterator_stack_.pop_back();

  while (!iterator_stack_.empty() &&
         !iterator_stack_.back()->can_use_32bit_indexing()) {
    auto& current_iter = *iterator_stack_.back();
    int split_dim = current_iter.get_dim_to_split();
    // split() narrows current_iter in place and returns the first half,
    // which becomes the new top of the stack.
    iterator_stack_.emplace_back(current_iter.split(split_dim));
  }

  return *this;
}

// Current 32-bit-indexable sub-iterator (top of the stack).
DenseTensorIterator& Tensor32BitSplitter::iterator::operator*() const {
  return *iterator_stack_.back();
}

// First 32-bit-indexable piece of the source iterator.
Tensor32BitSplitter::iterator Tensor32BitSplitter::begin() const {
  return Tensor32BitSplitter::iterator(source_iterator_);
}

// Past-the-end sentinel: an iterator holding an empty stack.
Tensor32BitSplitter::iterator Tensor32BitSplitter::end() const {
  return Tensor32BitSplitter::iterator();
}
} // namespace phi

+ 46
- 4
paddle/phi/kernels/funcs/dense_tensor_iterator.h View File

@@ -21,8 +21,10 @@
#include "paddle/utils/small_vector.h"

namespace phi {

struct DenseTensorIteratorConfig;
struct DenseTensorIterator;
struct Tensor32BitSplitter;

enum struct FastSetupType : uint8_t { NONE, CONTIGUOUS };

@@ -79,9 +81,14 @@ struct DenseTensorIteratorBase {
const std::vector<int64_t>& strides(int64_t arg) const {
return operands_[arg].stride_bytes;
}
const void* data_ptr(int64_t arg) const;
DataType dtype(int64_t arg = 0) const { return operands_[arg].current_dtype; }
std::vector<int64_t> view_offsets() const { return view_offsets_; }
void* data_ptr(int64_t arg) const;
bool should_accumulate() const { return accumulate_; }
bool is_final_output() const { return final_output_; }
int get_dim_to_split() const;
bool is_dim_reduced(int dim) const;
std::unique_ptr<DenseTensorIterator> split(int dim);

protected:
void populate_operands(DenseTensorIteratorConfig&);
@@ -93,10 +100,12 @@ struct DenseTensorIteratorBase {
bool fast_set_up(const DenseTensorIteratorConfig&);
FastSetupType compute_fast_setup_type(const DenseTensorIteratorConfig&);
void coalesce_dimensions();
void narrow(int dim, int64_t start, int64_t size);

protected:
std::vector<int64_t> shape_;
std::vector<int64_t> perm_;
std::vector<int64_t> view_offsets_;
bool has_coalesced_dimensions_ = false;
size_t num_outputs_ = 0;
bool all_ops_same_shape_ = false;
@@ -106,6 +115,8 @@ struct DenseTensorIteratorBase {
std::vector<DenseOperandInfo> operands_;
std::vector<int64_t> compatible_stride(int64_t element_size) const;
std::vector<int64_t> invert_perm(std::vector<int64_t> input) const;
bool can_use_32bit_indexing() const;
Tensor32BitSplitter with_32bit_indexing() const;
virtual void set_output_raw_strided(int64_t output_idx,
std::vector<int64_t> sizes,
std::vector<int64_t> strides);
@@ -116,9 +127,9 @@ struct DenseTensorIteratorBase {
};

/**
* DenseTensorIterator: Used for preprocessing metadata of tensors participating
* in computation. Can be directly used as OffsetCalculator input parameter to
* assist with index calculations.
* DenseTensorIterator: Used for preprocessing metadata of tensors
* participating in computation. Can be directly used as OffsetCalculator
* input parameter to assist with index calculations.
*/
struct DenseTensorIterator final : public DenseTensorIteratorBase {
DenseTensorIterator() : DenseTensorIteratorBase() {}
@@ -223,4 +234,35 @@ struct DimIter {
int64_t offset;
};

// Lazily splits a DenseTensorIterator into pieces that are each
// addressable with 32-bit indices (see
// DenseTensorIteratorBase::with_32bit_indexing). Intended for
// range-for: `for (auto& sub : iter.with_32bit_indexing()) { ... }`.
struct Tensor32BitSplitter {
  // Move-only forward iterator; owns a stack of pending sub-iterators.
  struct iterator {
    iterator() = default;
    explicit iterator(const DenseTensorIteratorBase& iter);
    iterator(iterator&&) = default;
    iterator& operator=(iterator&&) = default;
    ~iterator() = default;

    DenseTensorIterator& operator*() const;
    iterator& operator++();

    // Two iterators compare equal when identical, or when both are
    // exhausted (empty stacks) -- sufficient for the begin()/end() loop.
    bool operator==(const iterator& other) const {
      return this == &other ||
             (iterator_stack_.empty() && other.iterator_stack_.empty());
    }

    bool operator!=(const iterator& other) const { return !(*this == other); }

    // Pending sub-iterators; back() is the current element.
    std::vector<std::unique_ptr<DenseTensorIterator>> iterator_stack_;
  };

  explicit Tensor32BitSplitter(const DenseTensorIteratorBase& iter)
      : source_iterator_(iter) {}

  iterator begin() const;
  iterator end() const;

 private:
  // Borrowed; must outlive this splitter and its iterators.
  const DenseTensorIteratorBase& source_iterator_;
};

} // namespace phi

+ 1427
- 0
paddle/phi/kernels/funcs/reduce_gpu_kernel.h View File

@@ -0,0 +1,1427 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <bitset>
#include <limits>
#include <set>

#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/platform/device/gpu/gpu_info.h"
#include "paddle/phi/kernels/full_kernel.h"
#include "paddle/phi/kernels/funcs/dense_tensor_iterator.h"
#include "paddle/phi/kernels/funcs/index_elementwise.cu.h"
#include "paddle/phi/kernels/funcs/scatter.cu.h"
#include "paddle/phi/kernels/gpu/reduce.h"
#include "paddle/phi/kernels/legacy/reduce_max_kernel.h"
#include "paddle/phi/kernels/prod_kernel.h"
#include "paddle/phi/kernels/reduce_all_kernel.h"
#include "paddle/phi/kernels/reduce_amin_kernel.h"
#include "paddle/phi/kernels/reduce_any_kernel.h"
#include "paddle/phi/kernels/reduce_max_kernel.h"
#include "paddle/phi/kernels/reduce_mean_kernel.h"
#include "paddle/phi/kernels/reduce_min_kernel.h"
#include "paddle/phi/kernels/reduce_sum_kernel.h"

#define WARP_SIZE 32

// The GPUReduceScheduler splits tensors with indices exceeding 32-bit range to
// ensure that all incoming tensors can be addressed within 32-bit index space.
using IndexType = uint32_t;

// Typed scalar load from an untyped address.
template <typename T>
struct LoadImpl {
  HOSTDEVICE static T Apply(const void* src) {
    return *reinterpret_cast<const T*>(src);
  }
};

// bool is read through unsigned char so any nonzero byte pattern yields
// true, instead of reading a possibly ill-formed bool representation.
template <>
struct LoadImpl<bool> {
  HOSTDEVICE static bool Apply(const void* src) {
    static_assert(sizeof(bool) == sizeof(char));
    return *reinterpret_cast<const unsigned char*>(src);
  }
};

// Load a T from an untyped pointer (explicit template argument).
template <typename T>
HOSTDEVICE constexpr T LoadData(const void* src) {
  return LoadImpl<T>::Apply(src);
}

// Typed-pointer overload; the scalar type is deduced from the argument.
template <typename ScalarT>
HOSTDEVICE constexpr ScalarT LoadData(const ScalarT* src) {
  return LoadImpl<ScalarT>::Apply(src);
}

namespace phi {
// Translate a dimension list into a 64-bit mask: bit d is set for each
// listed dimension d. An empty list selects all `ndims` dimensions.
inline std::bitset<64> DimListToBitset(std::vector<int> opt_dims,
                                       size_t ndims) {
  std::bitset<64> mask;
  if (opt_dims.empty()) {
    for (size_t d = 0; d < ndims; ++d) {
      mask.set(d);
    }
  } else {
    for (int d : opt_dims) {
      mask.set(d);
    }
  }
  return mask;
}

// Normalize reduce axes into [0, ndim): each axis must lie within
// [-ndim, ndim) (enforced), and negative axes are wrapped by adding
// ndim. Returns the normalized copy; the input list is not modified.
inline std::vector<int> ConvertToPositiveDims(
    const std::vector<int>& origin_reduce_dims, int64_t ndim) {
  std::vector<int> positive_reduce_dims = origin_reduce_dims;
  for (size_t i = 0; i < origin_reduce_dims.size(); ++i) {
    PADDLE_ENFORCE_GE(
        origin_reduce_dims[i],
        -ndim,
        common::errors::InvalidArgument(
            "ReduceOp: invalid axis, when x_dims is %d, "
            "axis[i] should be in the range of [-%d, %d), but got %d.",
            ndim,
            ndim,
            ndim,
            origin_reduce_dims[i]));
    PADDLE_ENFORCE_LT(
        origin_reduce_dims[i],
        ndim,
        common::errors::InvalidArgument(
            "ReduceOp: invalid axis, when x_dims is %d, "
            "axis[i] should be in the range of [-%d, %d), but got %d.",
            ndim,
            ndim,
            ndim,
            origin_reduce_dims[i]));

    if (origin_reduce_dims[i] < 0) {
      positive_reduce_dims[i] = ndim + origin_reduce_dims[i];
    }
  }
  return positive_reduce_dims;
}

// Build the 64-bit reduction mask for `opt_dims`. An empty list with
// allow_empty_dims == false means "reduce everything" (all bits set);
// otherwise the mask is derived from the listed dimensions via
// DimListToBitset.
inline std::bitset<64> MakeDimMask(std::vector<int> opt_dims,
                                   int64_t ndim,
                                   bool allow_empty_dims = false) {
  const bool reduce_all_dims = opt_dims.empty() && !allow_empty_dims;
  if (reduce_all_dims) {
    // flip() on a default-constructed bitset turns every bit on.
    std::bitset<64> all_set;
    all_set.flip();
    return all_set;
  }
  return DimListToBitset(opt_dims, ndim);
}

// Re-view the reduction `result` against the source tensor's layout:
// dimensions whose mask bit is set (reduced) appear with extent 1,
// while kept dimensions take the source extent; strides are computed
// right-to-left over the kept dimensions only, so reduced dimensions
// broadcast over the packed result buffer.
inline DenseTensor ReviewReduceResult(const DenseTensor& src,
                                      const DenseTensor& result,
                                      int ndim,
                                      std::bitset<64> mask) {
  std::vector<int64_t> shape;
  std::vector<int64_t> stride;

  int64_t cal_stride = 1;
  const auto& src_dims = src.dims();

  // Walk from the innermost dimension outwards, prepending so the
  // final vectors are in original dimension order.
  for (int dim = ndim - 1; dim >= 0; dim--) {
    if (!mask[dim]) {
      shape.insert(shape.begin(), src_dims[dim]);
      stride.insert(stride.begin(), cal_stride);
      cal_stride *= src_dims[dim];
    } else {
      // Reduced: extent 1, stride value irrelevant for indexing.
      shape.insert(shape.begin(), 1);
      stride.insert(stride.begin(), cal_stride);
    }
  }

  return funcs::as_strided(result, shape, stride);
}

// Vectorized load: read the `offset`-th aligned group of `Size`
// elements starting at base_ptr.
template <typename T, int Size>
DEVICE AlignedVector<T, Size> LoadVector(const T* base_ptr, uint32_t offset) {
  using vec_t = AlignedVector<T, Size>;
  auto* from = reinterpret_cast<const vec_t*>(base_ptr);
  return from[offset];
}

// bool overload: load through uint8_t and normalize each byte to
// true/false, mirroring LoadImpl<bool>.
template <int Size>
DEVICE AlignedVector<bool, Size> LoadVector(const bool* base_ptr,
                                            uint32_t offset) {
  auto tmp = LoadVector<uint8_t, Size>(
      reinterpret_cast<const uint8_t*>(base_ptr), offset);
  AlignedVector<bool, Size> ret;
  for (int i = 0; i < Size; ++i) {
    ret.val[i] = static_cast<bool>(tmp.val[i]);
  }
  return ret;
}

// Choose the maximum number of threads per block for a scalar type.
// complex<double> gets a smaller cap — presumably to limit per-thread
// register/shared-memory pressure; TODO(review): confirm rationale.
template <typename T>
struct MaxThreadsConfig {
  static constexpr int MAX_NUM_THREADS = 512;
};

template <>
struct MaxThreadsConfig<phi::dtype::complex<double>> {
  static constexpr int MAX_NUM_THREADS = 256;
};

// Kernel entry point: runs the reducer functor with a compile-time
// output vector width. __launch_bounds__ caps the block at kNumThreads
// threads and requests at least 4 resident blocks per SM.
template <int kNumThreads, int kOutputVecSize, typename Reducer>
__launch_bounds__(kNumThreads, 4) __global__
    void VecReduceKernel(Reducer reduction) {
  reduction.template Run<kOutputVecSize>();
}

// Offset calculator over the non-reduced (output) dimensions for two
// tensors: the first output (index 0) and the input (last operand).
// Reduced dimensions occupy the front of shape/strides here, so the
// output view starts num_reduce_dims entries in.
template <typename IndexType>
static funcs::OffsetCalculator<2, IndexType> MakeOutputOffsetCalculator(
    const DenseTensorIterator& iter) {
  int num_reduce_dims = iter.num_reduce_dims();
  int num_output_dims = iter.ndim() - num_reduce_dims;
  int input_index = iter.ntensors() - 1;
  int output_index = 0;

  std::array<const int64_t*, 2> stride_ptrs = {
      iter.strides(output_index).data() + num_reduce_dims,
      iter.strides(input_index).data() + num_reduce_dims,
  };

  auto output_shape_ptr = iter.shape().data() + num_reduce_dims;

  return funcs::OffsetCalculator<2, IndexType>(
      num_output_dims, output_shape_ptr, stride_ptrs.data());
}

// Offset calculator over the reduced dimensions of the input operand
// (the leading num_reduce_dims entries of shape/strides).
template <typename IndexType>
static funcs::OffsetCalculator<1, IndexType> MakeInputOffsetCalculator(
    const DenseTensorIterator& iter) {
  int num_reduce_dims = iter.num_reduce_dims();
  int input_index = iter.ntensors() - 1;

  std::array<const int64_t*, 1> strides = {
      iter.strides(input_index).data(),
  };

  auto input_shape_ptr = iter.shape().data();

  return funcs::OffsetCalculator<1, IndexType>(
      num_reduce_dims, input_shape_ptr, strides.data());
}

// Pick the widest output vector width (4, 2 or 1) that divides the
// output base address (in elements), the output-dimension extent, and
// every non-output stride of the input, so vectorized stores remain
// aligned. T is the scalar element type used for byte -> element
// conversion.
template <typename T>
int GetOutputVecSize(const DenseTensorIterator& iter) {
  int vec_size = 4;

  // Halve vec_size until it evenly divides n.
  auto UpdateVectorSize = [&vec_size](uint64_t n) {
    while (n % vec_size != 0) {
      vec_size /= 2;
    }
  };

  // Check base address alignment.
  uint64_t base_address =
      reinterpret_cast<uint64_t>(iter.data_ptr(iter.noutputs())) / sizeof(T);
  UpdateVectorSize(base_address);

  // Check output dimension size.
  const int output_index = iter.num_reduce_dims();
  UpdateVectorSize(iter.shape()[output_index]);

  // Check strides alignment for all dimensions except output dimension.
  // Bind by const reference: strides() returns a const
  // std::vector<int64_t>&, and the original `auto` copy duplicated the
  // whole vector on every call.
  auto input_tensor_index = iter.noutputs();
  const auto& input_strides = iter.strides(input_tensor_index);

  const int input_ndim = static_cast<int>(input_strides.size());
  for (int dim = 0; dim < input_ndim; ++dim) {
    if (dim != output_index) {
      UpdateVectorSize(input_strides[dim] / sizeof(T));
    }
  }

  return vec_size;
}

// Simplify fraction by dividing both numerator and denominator by their
// GCD (Greatest Common Divisor), computed with the iterative Euclidean
// algorithm.
HOSTDEVICE static void ReduceFraction(size_t* numerator, size_t* denominator) {
  size_t a = *denominator;
  size_t b = *numerator;
  while (b != 0) {
    size_t remainder = a % b;
    a = b;
    b = remainder;
  }

  // gcd(0, 0) == 0; dividing by it would be undefined behavior. Leave
  // the degenerate 0/0 fraction unchanged instead.
  if (a == 0) {
    return;
  }

  *numerator /= a;
  *denominator /= a;
}

// Launch configuration for the GPU reduction: describes how the
// input/output element space is tiled over threadIdx.x, threadIdx.y and
// blockIdx.y, and how much shared/global scratch the kernel needs.
// input_multiplier / output_multiplier map a thread's hardware
// coordinates to linear input/output indices; a zero multiplier means
// that hardware axis does not advance over the corresponding space.
struct ReduceConfig {
  // Indices into input_multiplier/output_multiplier.
  static constexpr int BLOCK_X = 0;
  static constexpr int BLOCK_Y = 1;
  static constexpr int CTA = 2;

  ReduceConfig(int element_size_bytes, int num_outputs, int num_inputs)
      : element_size_bytes(element_size_bytes),
        num_inputs(num_inputs),
        num_outputs(num_outputs) {}

  // Basic configuration.
  int element_size_bytes;  // size of the accumulation type, in bytes
  int num_inputs;          // reduced elements per output
  int num_outputs;         // total output elements

  // Parallelism control.
  int step_input = 1;
  int step_output = 1;
  int ctas_per_output = 1;

  // Multiplier arrays for index calculation.
  int input_multiplier[3] = {0, 0, 0};
  int output_multiplier[2] = {0, 0};

  // Dimensions.
  int block_width;
  int block_height;
  int num_threads;

  // Vectorization control.
  bool vectorize_input = false;
  int output_vec_size = 1;

  // Choose block_width x block_height as powers of two bounded by the
  // per-type thread cap, trading width against height so the block
  // covers both extents without exceeding max_num_threads.
  template <typename T>
  void SetBlockDimensions(int64_t dim0, int64_t dim1) {
    const int max_num_threads =
        MaxThreadsConfig<T>::MAX_NUM_THREADS / output_vec_size;

    int dim0_pow2 =
        (dim0 < max_num_threads)
            ? static_cast<int>(phi::backends::gpu::GetLastPow2(dim0))
            : max_num_threads;
    int dim1_pow2 =
        (dim1 < max_num_threads)
            ? static_cast<int>(phi::backends::gpu::GetLastPow2(dim1))
            : max_num_threads;
    block_width = std::min(dim0_pow2, WARP_SIZE);
    block_height =
        std::min(dim1_pow2, static_cast<int>(max_num_threads / block_width));
    block_width =
        std::min(dim0_pow2, static_cast<int>(max_num_threads / block_height));
    num_threads = block_width * block_height;
  }

  // Multiply the input step by `parallelism`; returns the previous step
  // (the multiplier to assign for the hardware axis being split).
  int SplitInput(int parallelism) {
    const int current_step = step_input;
    step_input *= parallelism;
    return current_step;
  }

  // Same as SplitInput, but over the output element space.
  int SplitOutput(int parallelism) {
    const int current_step = step_output;
    step_output *= parallelism;
    return current_step;
  }

  dim3 GetBlockDim() const { return dim3(block_width, block_height); }

  // grid.x covers outputs (in output-vector units); grid.y holds the
  // extra CTAs cooperating on one output.
  dim3 GetGridDim() const {
    return dim3(phi::backends::gpu::DivUp<int64_t>(
                    num_outputs / output_vec_size, step_output),
                ctas_per_output);
  }

  HOSTDEVICE bool ShouldReduceBlockX() const {
    return input_multiplier[BLOCK_X] != 0;
  }

  HOSTDEVICE bool ShouldReduceBlockY() const {
    return input_multiplier[BLOCK_Y] != 0;
  }

  HOSTDEVICE bool ShouldReduceGlobal() const {
    return input_multiplier[CTA] != 0;
  }

  DEVICE bool ShouldStore(int output_idx) const {
    // 1. Boundary Check: Ensure the output index is within the valid range.
    // If out of bounds, no storage is necessary.
    if (output_idx >= num_outputs) {
      return false;
    }

    // 2. X-Reduction Check: If block-wide X-reduction is active, only the
    // thread with index 0 in the X-dimension (the "leader") is allowed to
    // store.
    if (ShouldReduceBlockX() && threadIdx.x != 0) {
      return false;
    }

    // 3. Y-Reduction Check: If block-wide Y-reduction is active, only the
    // thread with index 0 in the Y-dimension (the "leader") is allowed to
    // store.
    if (ShouldReduceBlockY() && threadIdx.y != 0) {
      return false;
    }

    // If the thread passes all checks, it is the designated thread to store the
    // result.
    return true;
  }

  // Whether this thread participates in reducing the non-vectorized
  // tail elements.
  DEVICE bool ShouldReduceTail() const {
    return (!ShouldReduceBlockY() || threadIdx.y == 0) &&
           (!ShouldReduceGlobal() || blockIdx.y == 0);
  }

  // Linear index of this thread's first input element within one
  // output's reduction slice.
  HOSTDEVICE int GetInIdx() const {
    int thread_x = threadIdx.x;
    int thread_y = threadIdx.y;
    int block_y = blockIdx.y;
    return (thread_x * input_multiplier[BLOCK_X] +
            thread_y * input_multiplier[BLOCK_Y] +
            block_y * input_multiplier[CTA]);
  }

  // Linear index of this thread's first output element, scaled by the
  // compile-time output vector width.
  template <int kOutputVecSize>
  HOSTDEVICE int GetOutIdx() const {
    int thread_x = threadIdx.x;
    int thread_y = threadIdx.y;
    int block_x = blockIdx.x;
    return (thread_x * output_multiplier[BLOCK_X] +
            thread_y * output_multiplier[BLOCK_Y] + block_x * step_output) *
           kOutputVecSize;
  }

  // Slot in the block's shared-memory scratch for (threadIdx.x,
  // threadIdx.y + offset).
  DEVICE int SharedMemoryOffset(int offset) const {
    return threadIdx.x + (threadIdx.y + offset) * blockDim.x;
  }

  // Slot in the global staging buffer used for cross-CTA reduction.
  DEVICE int StagingMemoryOffset(int block_y) const {
    IndexType offset = block_y + static_cast<IndexType>(blockIdx.x) *
                                     static_cast<IndexType>(gridDim.y);
    if (!ShouldReduceBlockX()) {
      offset = threadIdx.x + offset * blockDim.x;
    }

    return offset;
  }

  // Bytes of dynamic shared memory needed; zero when no intra-block
  // exchange happens (or X-reduction fits in a single warp).
  int SharedMemorySize() const {
    if (!ShouldReduceBlockY() &&
        (!ShouldReduceBlockX() || block_width <= WARP_SIZE)) {
      return 0;
    }

    return element_size_bytes * num_threads * output_vec_size;
  }

  // Bytes of global staging memory for cross-CTA partial results.
  int64_t GlobalMemorySize() const {
    if (!ShouldReduceGlobal()) {
      return 0;
    }

    auto size = (int64_t)element_size_bytes * num_outputs * ctas_per_output;
    if (!ShouldReduceBlockX()) {
      size *= GetBlockDim().x * output_vec_size;
    }

    return size;
  }

  // Bytes for the per-output-column semaphores that serialize the final
  // cross-CTA combine.
  int SemaphoreSize() const {
    if (!ShouldReduceGlobal()) {
      return 0;
    }
    return sizeof(int) * GetGridDim().x;
  }

  // How many input elements each thread accumulates.
  int ValuesPerThread() const {
    return phi::backends::gpu::DivUp<int64_t>(num_inputs, step_input);
  }
};

// Build a ReduceConfig for `iter`: pick the block shape, decide whether
// to vectorize along the input or the output, and split the work across
// threads (level 1), warps (level 2) and thread blocks (level 3).
// MPType is the accumulation type; ScalarT is the input element type.
template <typename MPType,
          typename ScalarT,
          int kVecSize,
          int kInputVecSize = kVecSize>
ReduceConfig SetReduceConfig(const DenseTensorIterator& iter) {
  int device_id = paddle::platform::GetCurrentDeviceId();

  int64_t num_outputs = iter.num_output_elements();
  int64_t inputs_per_output = iter.numel() / num_outputs;
  int input_index = iter.ntensors() - 1;

  auto config = ReduceConfig(sizeof(MPType), num_outputs, inputs_per_output);

  int64_t dim0;
  int64_t dim1;
  int64_t fastest_moving_stride;
  bool reduce_fastest_dim;

  if (iter.ndim() > 0) {
    // Check if we're reducing along the fastest-changing dimension
    // This affects memory access patterns for better performance.
    reduce_fastest_dim = (iter.num_reduce_dims() == iter.ndim()) ||
                         (iter.strides(input_index)[0] <
                          iter.strides(input_index)[iter.num_reduce_dims()]);

    // Set block dimensions based on reduction pattern.
    if (reduce_fastest_dim) {
      // Reducing along fastest dimension: use block.x for reduction.
      // block.x handles reduction elements.
      // block.y handles output elements.
      dim0 = inputs_per_output;
      dim1 = num_outputs;
      fastest_moving_stride = iter.strides(input_index)[0];
    } else {
      // Not reducing along fastest dimension: use block.x for outputs.
      // block.x handles output elements.
      // block.y handles reduction elements.
      dim0 = num_outputs;
      dim1 = inputs_per_output;
      fastest_moving_stride = iter.strides(input_index)[iter.num_reduce_dims()];
    }
  } else {
    // Handle 0-dimensional case.
    reduce_fastest_dim = true;
    fastest_moving_stride = sizeof(ScalarT);
    dim0 = 1;
    dim1 = 1;
  }

  // Use vectorization for better memory access. Two cases:
  // Case 1: "Vectorize along input" - when reducing on fastest dimension,
  //         data in same vector corresponds to the same output.
  // Case 2: "Vectorize along output" - when fastest dimension is not reduced,
  //         data in same vector corresponds to different outputs.
  // Vectorization only applies when the input is contiguous along its
  // fastest dimension (stride of one element).
  if (fastest_moving_stride == sizeof(ScalarT)) {
    if (reduce_fastest_dim && dim0 > 128 && iter.num_reduce_dims() == 1 &&
        kVecSize >= kInputVecSize) {
      // Case 1: Vectorize along input (load data for same output together).
      config.vectorize_input = true;
      dim0 /= kInputVecSize;
    } else if (!reduce_fastest_dim) {
      // Case 2: Vectorize along output (load data for multiple outputs
      // together).
      config.output_vec_size = GetOutputVecSize<ScalarT>(iter);
      dim0 /= config.output_vec_size;
    }
  }

  // Adjust block_width and block_height.
  config.SetBlockDimensions<ScalarT>(dim0, dim1);

  int block_width = config.block_width;
  int block_height = config.block_height;

  // Level 1 parallelization: split work at thread level.
  if (iter.ndim() == 0 || reduce_fastest_dim) {
    // Case 1: Split input across threads (requires thread synchronization).
    config.input_multiplier[0] = config.SplitInput(block_width);
  } else {
    // Case 2: Split output across threads (each thread handles different
    // output).
    config.output_multiplier[0] = config.SplitOutput(block_width);
  }

  // Min elements per thread.
  constexpr int min_values_per_thread = 16;
  // Max elements per thread.
  constexpr int max_values_per_thread = 256;

  // Decide if we need to split work across warps.
  const int warp_split_threshold =
      std::min<int>(block_height * 16, max_values_per_thread);
  bool split_across_warps = config.ValuesPerThread() >= warp_split_threshold;

  const int num_mp = paddle::platform::GetGPUMultiProcessors(device_id);

  // Level 2 parallelization: split work at warp level.
  if (split_across_warps) {
    // Case 1: Split input across warps (requires warp synchronization).
    config.input_multiplier[1] = config.SplitInput(block_height);
  } else {
    // Case 2: Each warp handles independent outputs.
    config.output_multiplier[1] = config.SplitOutput(block_height);
  }

  int max_threads_per_mp =
      paddle::platform::GetGPUMaxThreadsPerMultiProcessor(device_id);

  const int blocks_per_sm = max_threads_per_mp / config.num_threads;
  const int target_grid_size = num_mp * blocks_per_sm;
  int grid = config.GetGridDim().x;

  // Level 3 parallelization: split work at block level (for large datasets).
  if (config.input_multiplier[1] != 0 &&
      config.ValuesPerThread() >= max_values_per_thread &&
      grid <= target_grid_size) {
    // Calculate optimal block splitting strategy.
    // Based on SM utilization.
    int ctas_per_output1 =
        phi::backends::gpu::DivUp<int64_t>(target_grid_size, grid);
    // Based on min workload.
    int ctas_per_output2 = phi::backends::gpu::DivUp<int64_t>(
        config.ValuesPerThread(), min_values_per_thread);
    // Based on max workload.
    int ctas_per_output3 = phi::backends::gpu::DivUp<int64_t>(
        config.ValuesPerThread(), max_values_per_thread);

    // Choose best splitting strategy to balance parallelism and per-thread
    // workload.
    config.ctas_per_output = std::max(
        std::min<int>(ctas_per_output1, ctas_per_output2), ctas_per_output3);

    if (config.ctas_per_output > 1) {
      // Case 3: Split input across blocks (requires global memory
      // synchronization).
      config.input_multiplier[2] = config.SplitInput(config.ctas_per_output);
    }
  }
  return config;
}

// Device-side reduction executor. Bundles the reduction functor, the launch
// configuration, the input/output offset calculators, and every data pointer
// one kernel launch needs, then implements the reduction in up to four
// phases: per-thread, block-y (shared memory), block-x (shared memory +
// warp shuffle), and cross-block (staging buffer + semaphore).
//
// Template parameters:
//   ScalarT       - input element type.
//   ReduceOp      - binary reduction functor; operates on MPType values.
//   OutScalarT    - output element type (defaults to ScalarT).
//   kVecSize      - unroll factor for the strided (non-vectorized) path.
//   kInputVecSize - vector width for the vectorized input path.
template <typename ScalarT,
          typename ReduceOp,
          typename OutScalarT = ScalarT,
          int kVecSize = 4,
          int kInputVecSize = kVecSize>
struct ReduceExecutor {
  // Higher-precision type used for intermediate accumulation.
  using MPType = typename phi::dtype::MPTypeTrait<OutScalarT>::Type;

  using InputCalculator = funcs::OffsetCalculator<1, IndexType>;
  using OutputCalculator = funcs::OffsetCalculator<2, IndexType>;

  // True when MPType and OutScalarT convert both ways, so partial results
  // may be kept directly in the output tensor between kernel launches.
  static constexpr bool can_accumulate_in_output =
      std::is_convertible_v<MPType, OutScalarT> &&
      std::is_convertible_v<OutScalarT, MPType>;

  // Core reduction algorithm configuration.
  ReduceOp reducer;
  ReduceConfig config;
  MPType ident;   // identity element of the reduction
  MPType factor;  // final scaling applied in SetResultsToOutput (e.g. 1/n)

  // Data access calculators for input and output indexing.
  InputCalculator input_calc;
  OutputCalculator output_calc;

  // Data pointers for source, destination, and buffers.
  const void* src;
  char* dst[2];   // dst[1] only valid when noutputs >= 2
  void* acc_buf;  // optional higher-precision accumulation buffer
  void* cta_buf;  // per-CTA staging buffer for the global-reduce phase

  // Parallel synchronization primitives.
  int* semaphores;  // one counter per blockIdx.x, see MarkBlockFinished

  // Runtime state and control flags.
  int64_t base_idx;
  bool accumulate;    // combine with values already in output/acc buffer
  bool final_output;  // apply `factor` and write the final result
  int noutputs;

  ReduceExecutor(ReduceOp reducer,
                 ReduceConfig config,
                 MPType ident,
                 MPType factor,
                 InputCalculator input_calc,
                 OutputCalculator output_calc,
                 const void* src,
                 char* dst0,
                 std::optional<char*> dst1,
                 void* acc_buf,
                 void* cta_buf,
                 int* semaphores,
                 int base_idx,
                 bool accumulate,
                 bool final_output,
                 int64_t noutputs)
      : reducer(reducer),
        config(config),
        ident(ident),
        factor(factor),
        input_calc(input_calc),
        output_calc(output_calc),
        src(src),
        acc_buf(acc_buf),
        cta_buf(cta_buf),
        semaphores(semaphores),
        base_idx(base_idx),
        accumulate(accumulate),
        final_output(final_output),
        noutputs(noutputs) {
    dst[0] = dst0;
    if (dst1.has_value()) {
      dst[1] = dst1.value();
    }
  }

  // Kernel body: runs the per-thread reduce, then the intra-block phases
  // enabled by `config`, then either hands off to GlobalReduce or writes
  // this block's result. kOutputVecSize outputs are handled per thread.
  template <int kOutputVecSize>
  DEVICE void Run() const {
    extern __shared__ char shared_memory[];

    IndexType output_idx = config.GetOutIdx<kOutputVecSize>();
    IndexType input_idx = config.GetInIdx();
    auto base_offsets1 = output_calc.get(output_idx)[1];

    using MPTypeVec = std::array<MPType, kOutputVecSize>;
    // NOTE(review): `value` stays default-initialized (indeterminate) for
    // threads whose indices fall outside the valid range below; confirm the
    // later block reductions never consume such lanes' contributions.
    MPTypeVec value;

    if (output_idx < config.num_outputs && input_idx < config.num_inputs) {
      const ScalarT* input_slice =
          (const ScalarT*)((const char*)src + base_offsets1);
      value = ThreadReduce<kOutputVecSize>(input_slice);
    }

    if (config.ShouldReduceBlockY()) {
      value = BlockYReduce<kOutputVecSize>(value, shared_memory);
    }

    if (config.ShouldReduceBlockX()) {
      value = BlockXReduce<kOutputVecSize>(value, shared_memory);
    }

    using OutPtrVec = std::array<OutScalarT*, kOutputVecSize>;
    using OffsetVec = std::array<IndexType, kOutputVecSize>;

    // Per-lane output byte offsets and the pointers they address.
    OffsetVec base_offsets;
    OutPtrVec out;

#pragma unroll
    for (int i = 0; i < kOutputVecSize; i++) {
      base_offsets[i] = output_calc.get(output_idx + i)[0];
      out[i] = reinterpret_cast<OutScalarT*>(dst[0] + base_offsets[i]);
    }

    // Locate this output's slot in the accumulation buffer: output-space
    // byte offsets are rescaled by sizeof(MPType)/sizeof(OutScalarT), kept
    // as a reduced fraction so the integer arithmetic stays exact.
    MPTypeVec* acc = nullptr;
    if (acc_buf != nullptr) {
      size_t numerator = sizeof(MPType);
      size_t denominator = sizeof(OutScalarT);
      ReduceFraction(&numerator, &denominator);
      acc = reinterpret_cast<MPTypeVec*>(
          reinterpret_cast<char*>(acc_buf) +
          (base_offsets[0] * numerator / denominator));
    }

    if (config.ShouldReduceGlobal()) {
      // Multiple blocks cooperate on one output: go through the staging
      // buffer / semaphore protocol.
      value = GlobalReduce<kOutputVecSize>(value, acc, shared_memory);
    } else if (config.ShouldStore(output_idx)) {
      if (acc == nullptr) {
        // No separate accumulation buffer: accumulate in the output itself
        // (only meaningful when can_accumulate_in_output holds).
        if (accumulate) {
          value = AccumulateInOutput<kOutputVecSize, can_accumulate_in_output>(
              out, value);
        }
        if (final_output) {
          SetResultsToOutput<kOutputVecSize>(value, base_offsets);
        } else {
#pragma unroll
          for (int i = 0; i < kOutputVecSize; i++) {
            *(out[i]) = GetAccumulatedOutput<can_accumulate_in_output>(
                out[i], value[i]);
          }
        }
      } else {
        // Accumulate in the higher-precision side buffer.
        if (accumulate) {
#pragma unroll
          for (int i = 0; i < kOutputVecSize; i++) {
            value[i] = reducer((*acc)[i], value[i]);
          }
        }
        if (final_output) {
          SetResultsToOutput<kOutputVecSize>(value, base_offsets);
        } else {
          *acc = value;
        }
      }
    }
  }

  // Per-thread reduce over this thread's input slice. Dispatches to the
  // vectorized path or to a strided path whose offset computation is picked
  // by how the input is laid out (contiguous / single-stride / generic).
  template <int kOutputVecSize>
  DEVICE std::array<MPType, kOutputVecSize> ThreadReduce(
      const ScalarT* data) const {
    if (config.vectorize_input) {
      // NOTE(review): brace-init from one scalar zero-fills the remaining
      // lanes; presumably this path only runs with kOutputVecSize == 1 —
      // confirm against the config that sets vectorize_input.
      return {InVectorizedThreadReduceImpl(data)};
    } else {
      IndexType element_stride = input_calc.strides_[0][0] / sizeof(ScalarT);
      bool is_contiguous = (input_calc.dims == 1 && element_stride == 1);
      if (is_contiguous) {
        return ThreadReduceImpl<kOutputVecSize>(
            data, [](IndexType idx) { return idx; });
      } else if (input_calc.dims == 1) {
        return ThreadReduceImpl<kOutputVecSize>(
            data, [&](IndexType idx) { return idx * element_stride; });
      } else {
        return ThreadReduceImpl<kOutputVecSize>(data, [&](IndexType idx) {
          return input_calc.get(idx)[0] / sizeof(ScalarT);
        });
      }
    }
  }

  // Vectorized per-thread reduce along a contiguous input row. Handles a
  // misaligned head and the non-multiple tail with scalar loads, and the
  // aligned middle with kInputVecSize-wide vector loads.
  DEVICE MPType InVectorizedThreadReduceImpl(const ScalarT* data) const {
    IndexType end = config.num_inputs;
    MPType value = ident;
    constexpr int align_bytes =
        alignof(phi::AlignedVector<ScalarT, kInputVecSize>);

    constexpr int align_elements = align_bytes / sizeof(ScalarT);
    // Number of elements `data` sits past the previous aligned boundary.
    int shift = ((uint64_t)data) % align_bytes / sizeof(ScalarT);

    if (shift > 0) {
      // Rewind to the aligned boundary; designated threads reduce the
      // misaligned head elements scalar-wise.
      data -= shift;
      end += shift;
      if (threadIdx.x >= shift && threadIdx.x < align_elements &&
          config.ShouldReduceTail()) {
        value = reducer(value, LoadData(data + threadIdx.x));
      }
      end -= align_elements;
      data += align_elements;
      shift = align_elements - shift;
    }

    IndexType idx = config.GetInIdx();
    const IndexType stride = config.step_input;

    // One partial accumulator per vector lane; lane 0 seeds with the head
    // contribution computed above.
    std::array<MPType, kInputVecSize> value_list;
    value_list[0] = value;

#pragma unroll
    for (int i = 1; i < kInputVecSize; i++) {
      value_list[i] = ident;
    }

    using load_t = phi::AlignedVector<ScalarT, kInputVecSize>;

    while (idx * kInputVecSize + kInputVecSize - 1 < end) {
      const auto values_vec = LoadVector<ScalarT, kInputVecSize>(data, idx);

#pragma unroll
      for (IndexType i = 0; i < kInputVecSize; i++) {
        value_list[i] = reducer(value_list[i], values_vec.val[i]);
      }
      idx += stride;
    }

    // Tile processing.
    IndexType tail_start = end - end % kInputVecSize;

    if (config.ShouldReduceTail()) {
      // Note: this `idx` intentionally shadows the outer loop index.
      int idx = tail_start + threadIdx.x;
      if (idx < end) {
        const auto value = LoadData(data + idx);
        value_list[0] = reducer(value_list[0], value);
      }
    }

    // Fold the per-lane partials into lane 0.
#pragma unroll
    for (int i = 1; i < kInputVecSize; i++) {
      value_list[0] = reducer(value_list[0], value_list[i]);
    }

    return value_list[0];
  }

  // Strided per-thread reduce: walks the input with stride `step_input`,
  // unrolled kVecSize deep, using `calc` to translate a logical index into
  // an element offset. Produces one partial per output lane.
  template <int kOutputVecSize, typename offset_calc_t>
  DEVICE std::array<MPType, kOutputVecSize> ThreadReduceImpl(
      const ScalarT* data_, offset_calc_t calc) const {
    IndexType idx = config.GetInIdx();
    const IndexType end = config.num_inputs;
    const IndexType stride = config.step_input;

    using MPTypeVec = std::array<MPType, kOutputVecSize>;
    using load_t = phi::AlignedVector<ScalarT, kOutputVecSize>;

    // kVecSize independent accumulators, each kOutputVecSize wide.
    std::array<MPTypeVec, kVecSize> value_list;

#pragma unroll
    for (int i = 0; i < kVecSize; i++) {
#pragma unroll
      for (int j = 0; j < kOutputVecSize; j++) {
        value_list[i][j] = ident;
      }
    }

    std::array<load_t, kVecSize> values;

    // Main loop: kVecSize full strides fit before `end`.
    while (idx + (kVecSize - 1) * stride < end) {
#pragma unroll
      for (IndexType i = 0; i < kVecSize; i++) {
        const auto offset = calc(idx + i * stride) / kOutputVecSize;
        values[i] = LoadVector<ScalarT, kOutputVecSize>(data_, offset);
      }
#pragma unroll
      for (IndexType i = 0; i < kVecSize; i++) {
#pragma unroll
        for (IndexType j = 0; j < kOutputVecSize; j++) {
          value_list[i][j] = reducer(value_list[i][j], values[i].val[j]);
        }
      }
      idx += stride * kVecSize;
    }

    // tail: first gather the remaining loads, then (re-walking the same
    // indices) fold them into the accumulators.
    int idx_ = idx;
#pragma unroll
    for (IndexType i = 0; i < kVecSize; i++) {
      if (idx >= end) {
        break;
      }
      const auto offset = calc(idx) / kOutputVecSize;
      values[i] = LoadVector<ScalarT, kOutputVecSize>(data_, offset);
      idx += stride;
    }
    idx = idx_;
#pragma unroll
    for (IndexType i = 0; i < kVecSize; i++) {
      if (idx >= end) {
        break;
      }
#pragma unroll
      for (IndexType j = 0; j < kOutputVecSize; j++) {
        value_list[i][j] = reducer(value_list[i][j], values[i].val[j]);
      }
      idx += stride;
    }

    // Fold the kVecSize accumulators into slot 0.
#pragma unroll
    for (int i = 1; i < kVecSize; i++) {
#pragma unroll
      for (IndexType j = 0; j < kOutputVecSize; j++) {
        value_list[0][j] = reducer(value_list[0][j], value_list[i][j]);
      }
    }
    return value_list[0];
  }

  // Reduce across threadIdx.x. For wide blocks, first a shared-memory tree
  // narrows participation down to one warp, then a warp-shuffle reduction
  // finishes within the warp.
  template <int kOutputVecSize>
  DEVICE std::array<MPType, kOutputVecSize> BlockXReduce(
      std::array<MPType, kOutputVecSize> value, char* shared_memory) const {
    using MPTypeVec = std::array<MPType, kOutputVecSize>;
    int dim_x = blockDim.x;
    MPTypeVec* shared = reinterpret_cast<MPTypeVec*>(shared_memory);
    if (dim_x > WARP_SIZE) {
      IndexType address_base = static_cast<IndexType>(threadIdx.x) +
                               static_cast<IndexType>(threadIdx.y) *
                                   static_cast<IndexType>(blockDim.x);

      shared[address_base] = value;
      // Shared-memory tree reduction until only WARP_SIZE lanes remain.
      for (int offset = dim_x / 2; offset >= WARP_SIZE; offset >>= 1) {
        __syncthreads();

        if (threadIdx.x < offset && threadIdx.x + offset < blockDim.x) {
          MPTypeVec other = shared[address_base + offset];
#pragma unroll
          for (int i = 0; i < kOutputVecSize; i++) {
            value[i] = reducer(value[i], other[i]);
          }
          shared[address_base] = value;
        }
      }
      dim_x = WARP_SIZE;
    }

    __syncthreads();

    // Finish within one warp using shuffle-down; NOTE(review): the loop
    // direction (offset 1,2,4,...) leaves the combined result in lane 0 —
    // confirm downstream only consumes lane 0's value.
    unsigned mask = 0u;
    CREATE_SHFL_MASK(mask, true);
    for (int offset = 1; offset < dim_x; offset <<= 1) {
#pragma unroll
      for (int i = 0; i < kOutputVecSize; i++) {
        MPType other =
            phi::backends::gpu::CudaShuffleDownSync(mask, value[i], offset);
        value[i] = reducer(value[i], other);
      }
    }
    return value;
  }

  // Reduce across threadIdx.y with a shared-memory tree; each thread's slot
  // is provided by config.SharedMemoryOffset.
  template <int kOutputVecSize>
  DEVICE std::array<MPType, kOutputVecSize> BlockYReduce(
      std::array<MPType, kOutputVecSize> value, char* shared_memory) const {
    using MPTypeVec = std::array<MPType, kOutputVecSize>;
    MPTypeVec* shared = reinterpret_cast<MPTypeVec*>(shared_memory);
    shared[config.SharedMemoryOffset(0)] = value;

    for (int offset = blockDim.y / 2; offset > 0; offset >>= 1) {
      __syncthreads();
      if (threadIdx.y < offset && threadIdx.y + offset < blockDim.y) {
        MPTypeVec other = shared[config.SharedMemoryOffset(offset)];
#pragma unroll
        for (int i = 0; i < kOutputVecSize; i++) {
          value[i] = reducer(value[i], other[i]);
        }
        shared[config.SharedMemoryOffset(0)] = value;
      }
    }
    return value;
  }

  // Atomically counts finished blocks for this blockIdx.x; returns true in
  // every thread of the last block (along gridDim.y) to arrive, which then
  // performs the final cross-block combine in GlobalReduce.
  DEVICE bool MarkBlockFinished() const {
    __shared__ bool is_last_block_done_shared;

    __syncthreads();
    if (threadIdx.x == 0 && threadIdx.y == 0) {
      int prev_blocks_finished = atomicAdd(&semaphores[blockIdx.x], 1);
      is_last_block_done_shared = (prev_blocks_finished == gridDim.y - 1);
    }

    __syncthreads();

    return is_last_block_done_shared;
  }

  // Combine `value` with what is already stored in the output pointers.
  // When MPType<->OutScalarT conversion is impossible the branch is
  // compiled out and a dummy value is returned (callers guard on can_acc).
  template <int kOutputVecSize, bool can_acc>
  DEVICE std::array<MPType, kOutputVecSize> AccumulateInOutput(
      std::array<OutScalarT*, kOutputVecSize> out,
      std::array<MPType, kOutputVecSize> value) const {
    if constexpr (can_acc) {
      std::array<MPType, kOutputVecSize> ret;
#pragma unroll
      for (int i = 0; i < kOutputVecSize; i++) {
        ret[i] = reducer(*(out[i]), value[i]);
      }
      return ret;
    } else {
      return {MPType{}};
    }
  }

  // Value to store when this launch is not the final one: the converted
  // partial when conversion is possible, otherwise the existing output.
  template <bool can_acc>
  DEVICE OutScalarT GetAccumulatedOutput(OutScalarT* out, MPType value) const {
    if constexpr (can_acc) {
      return (OutScalarT)value;
    } else {
      return *out;
    }
  }

  // Single-output store at a byte offset into dst[0].
  template <class T>
  DEVICE void SetResults(const T x, const IndexType base_offset) const {
    auto res = reinterpret_cast<OutScalarT*>(dst[0] + base_offset);
    *res = x;
  }

  // Two-output store (e.g. value + index reductions): first component goes
  // to dst[0]; the offset is rescaled element-wise for dst[1]'s type.
  template <class T1, class T2>
  DEVICE void SetResults(const thrust::pair<T1, T2> x,
                         const IndexType base_offset) const {
    if (noutputs >= 1) {
      auto res0 = reinterpret_cast<T1*>(dst[0] + base_offset);
      *res0 = x.first;
    }
    if (noutputs >= 2) {
      auto res1 =
          reinterpret_cast<T2*>(dst[1] + base_offset / sizeof(T1) * sizeof(T2));
      *res1 = x.second;
    }
  }

  // Final store: applies the scaling `factor` (e.g. 1/n for mean) and
  // converts to the output type.
  template <int kOutputVecSize>
  DEVICE void SetResultsToOutput(
      std::array<MPType, kOutputVecSize> value,
      std::array<IndexType, kOutputVecSize> base_offset) const {
#pragma unroll
    for (int i = 0; i < kOutputVecSize; i++) {
      SetResults(static_cast<OutScalarT>(value[i] * factor), base_offset[i]);
    }
  }

  // Cross-block reduce. Every block writes its partial into the staging
  // buffer (cta_buf); the last block to finish (per MarkBlockFinished)
  // re-reduces all staged partials and stores the result, mirroring the
  // store logic of Run().
  template <int kOutputVecSize>
  DEVICE std::array<MPType, kOutputVecSize> GlobalReduce(
      std::array<MPType, kOutputVecSize> value,
      std::array<MPType, kOutputVecSize>* acc,
      char* shared_memory) const {
    using MPTypeVec = std::array<MPType, kOutputVecSize>;
    using OutPtrVec = std::array<OutScalarT*, kOutputVecSize>;
    using OffsetVec = std::array<IndexType, kOutputVecSize>;

    MPTypeVec* reduce_buffer = reinterpret_cast<MPTypeVec*>(cta_buf);
    IndexType output_idx = config.GetOutIdx<kOutputVecSize>();
    OffsetVec base_offsets;
    OutPtrVec out;

#pragma unroll
    for (int i = 0; i < kOutputVecSize; i++) {
      base_offsets[i] = output_calc.get(output_idx + i)[0];
      out[i] = reinterpret_cast<OutScalarT*>(dst[0] + base_offsets[i]);
    }

    // Stage this block's partial result.
    bool should_store = config.ShouldStore(output_idx);
    if (should_store) {
      IndexType offset = config.StagingMemoryOffset(blockIdx.y);
      reduce_buffer[offset] = value;
    }

    // Make the staged write visible device-wide before signaling.
    __threadfence();

    __syncthreads();

    bool is_last_block_done = MarkBlockFinished();

    if (is_last_block_done) {
      // Ensure this block observes all other blocks' staged writes.
      __threadfence();

      for (auto& v : value) {
        v = ident;
      }

      // Re-reduce the staged partials, distributing them across either the
      // whole block (x*y threads) or just the y dimension, matching how the
      // final intra-block reduction below is configured.
      if (config.ShouldReduceBlockX()) {
        IndexType input_offset = static_cast<IndexType>(threadIdx.x) +
                                 static_cast<IndexType>(threadIdx.y) *
                                     static_cast<IndexType>(blockDim.x);
        IndexType step = static_cast<IndexType>(blockDim.x) *
                         static_cast<IndexType>(blockDim.y);

        for (; input_offset < config.ctas_per_output; input_offset += step) {
          IndexType idx = config.StagingMemoryOffset(input_offset);
          MPTypeVec next = reduce_buffer[idx];
#pragma unroll
          for (int i = 0; i < kOutputVecSize; i++) {
            value[i] = reducer(value[i], next[i]);
          }
        }
      } else {
        IndexType input_offset = threadIdx.y;
        IndexType step = blockDim.y;

        for (; input_offset < config.ctas_per_output; input_offset += step) {
          IndexType idx = config.StagingMemoryOffset(input_offset);
          MPTypeVec next = reduce_buffer[idx];
#pragma unroll
          for (int i = 0; i < kOutputVecSize; i++) {
            value[i] = reducer(value[i], next[i]);
          }
        }
      }
      value = BlockYReduce<kOutputVecSize>(value, shared_memory);
      if (config.ShouldReduceBlockX()) {
        value = BlockXReduce<kOutputVecSize>(value, shared_memory);
      }
      // Store logic mirrors the non-global branch of Run().
      if (should_store) {
        if (acc == nullptr) {
          if (accumulate) {
            value =
                AccumulateInOutput<kOutputVecSize, can_accumulate_in_output>(
                    out, value);
          }
          if (final_output) {
            SetResultsToOutput<kOutputVecSize>(value, base_offsets);
          } else {
#pragma unroll
            for (int i = 0; i < kOutputVecSize; i++) {
              *(out[i]) = GetAccumulatedOutput<can_accumulate_in_output>(
                  out[i], value[i]);
            }
          }
        } else {
          if (accumulate) {
#pragma unroll
            for (int i = 0; i < kOutputVecSize; i++) {
              value[i] = reducer((*acc)[i], value[i]);
            }
          }
          if (final_output) {
            SetResultsToOutput<kOutputVecSize>(value, base_offsets);
          } else {
            *acc = value;
          }
        }
      }
    }

    return value;
  }
};

// Owns (or aliases) the buffer used to accumulate partial reduction results
// at a higher precision (MPType) than the output type.
//
// If the output element is at least as wide as the accumulation element, the
// output tensor itself is reused as the accumulation buffer. Otherwise a
// standalone allocation of `size` bytes is obtained from the device
// allocator, and byte offsets into the output are rescaled by
// acc_t_size / out_t_size (stored as a reduced fraction so the integer
// arithmetic in GetAccSlice stays exact).
class AccumulationBuffer {
 public:
  AccumulationBuffer() = default;

  // dev_ctx    - device context supplying the allocator.
  // acc_t_size - sizeof the accumulation (MPType) element.
  // out_t_size - sizeof the output element.
  // out_ptr    - base pointer of the output tensor.
  // size       - byte size of the standalone buffer, when one is needed.
  AccumulationBuffer(const KPDevice& dev_ctx,
                     size_t acc_t_size,
                     size_t out_t_size,
                     char* out_ptr,
                     int64_t size) {
    out_ptr_ = out_ptr;
    if (out_t_size >= acc_t_size) {
      // Output elements are wide enough: accumulate in place.
      acc_ptr_ = out_ptr;
      numerator_ = 1;
      denominator_ = 1;
    } else {
      phi::Allocator* allocator =
          const_cast<phi::Allocator*>(&(dev_ctx.GetAllocator()));  // NOLINT
      buffer_ = allocator->Allocate(size);
      acc_ptr_ = reinterpret_cast<char*>(buffer_->ptr());
      numerator_ = acc_t_size;
      denominator_ = out_t_size;
      ReduceFraction(&numerator_, &denominator_);
    }
  }

  // Maps a pointer into the output tensor to the matching slot in the
  // accumulation buffer; nullptr when no buffer was configured.
  char* GetAccSlice(char* out_ptr) {
    if (acc_ptr_ == nullptr) {
      return nullptr;
    }
    return acc_ptr_ + ((out_ptr - out_ptr_) * numerator_ / denominator_);
  }

 private:
  char* acc_ptr_ = nullptr;
  char* out_ptr_ = nullptr;
  // Default-initialized so a default-constructed object never holds
  // indeterminate values.
  size_t numerator_ = 1;
  size_t denominator_ = 1;
  Allocator::AllocationPtr buffer_;
};

// Launches VecReduceKernel with the grid/block/shared-memory shape described
// by `config`, instantiated for the configured output vector width (4, 2, or
// scalar). The per-thread budget max_threads is divided by the vector width.
template <int max_threads, typename R>
static void LaunchReduceKernel(const KPDevice& dev_ctx,
                               const ReduceConfig& config,
                               const R& reduction) {
  const dim3 threads = config.GetBlockDim();
  const dim3 blocks = config.GetGridDim();
  const int smem_bytes = config.SharedMemorySize();

  auto stream = dev_ctx.stream();

  const int vec_width = config.output_vec_size;
  if (vec_width == 4) {
    VecReduceKernel<max_threads / 4, 4, R>
        <<<blocks, threads, smem_bytes, stream>>>(reduction);
  } else if (vec_width == 2) {
    VecReduceKernel<max_threads / 2, 2, R>
        <<<blocks, threads, smem_bytes, stream>>>(reduction);
  } else {
    VecReduceKernel<max_threads / 1, 1, R>
        <<<blocks, threads, smem_bytes, stream>>>(reduction);
  }
}

// Schedules a reduction described by `iter` onto the GPU.
//
// Responsibilities:
//  * Sets up a higher-precision accumulation buffer when the result cannot
//    be accumulated in the output dtype and 64-bit indexing is required.
//  * Splits iterators that exceed the 32-bit index range into sub-iterators
//    and recurses on each.
//  * Builds the ReduceConfig, allocates the cross-block staging buffer and
//    semaphores when blocks must cooperate on one output, and launches the
//    kernel via LaunchReduceKernel.
//
// Parameters:
//   dev_ctx     - device context (stream + allocator).
//   iter        - tensor iterator describing inputs/outputs of the reduce.
//   reducer     - binary reduction functor.
//   ident       - identity element of the reduction.
//   factor      - final scaling factor (e.g. output/input ratio for mean).
//   acc_buf_ptr - accumulation buffer shared across recursive sub-calls;
//                 created here when null.
//   base_idx    - view offset of this sub-iterator within the full iterator.
template <typename Tx,
          typename Ty,
          int kVecSize = 4,
          int kInputVecSize = kVecSize,
          typename ReduceOp,
          typename ident_t = double>
inline void GPUReduceScheduler(
    const KPDevice& dev_ctx,
    const DenseTensorIterator& iter,
    const ReduceOp& reducer,
    ident_t ident = 0,
    typename phi::dtype::MPTypeTrait<Ty>::Type factor = 1,
    AccumulationBuffer* acc_buf_ptr = nullptr,
    int64_t base_idx = 0) {
  auto stream = dev_ctx.stream();

  using MPType = typename phi::dtype::MPTypeTrait<Ty>::Type;

  // half->half and complex<half>->complex<half> reductions accumulate in a
  // wider type, so they cannot keep partials in the output tensor.
  static constexpr bool is_inp_out_type_half_or_chalf =
      (std::is_same_v<phi::float16, Tx> && std::is_same_v<phi::float16, Ty>) ||
      (std::is_same_v<phi::dtype::complex<float16>, Tx> &&
       std::is_same_v<phi::dtype::complex<float16>, Ty>);
  static constexpr bool is_inp_out_type_bfloat16 =
      (std::is_same_v<phi::bfloat16, Tx> && std::is_same_v<phi::bfloat16, Ty>);
  static constexpr bool can_accumulate_in_output =
      std::is_convertible_v<MPType, Ty> &&
      !(is_inp_out_type_half_or_chalf || is_inp_out_type_bfloat16);

  bool can_use_32bit_indexing = iter.can_use_32bit_indexing();
  std::unique_ptr<AccumulationBuffer> owned_buf_ptr;
  if (acc_buf_ptr == nullptr) {
    if (!can_accumulate_in_output && !can_use_32bit_indexing) {
      // The iterator will be split below; sub-iterations must accumulate
      // into a shared MPType buffer spanning the whole output extent.
      int64_t output_memory_size = sizeof(iter.dtype(0));
      for (int dim = 0; dim < iter.ndim(); dim++) {
        output_memory_size = std::max(output_memory_size,
                                      iter.shape()[dim] * iter.strides(0)[dim]);
      }
      output_memory_size /= sizeof(iter.dtype(0));
      owned_buf_ptr.reset(
          new AccumulationBuffer(dev_ctx,
                                 sizeof(MPType),
                                 sizeof(Ty),
                                 reinterpret_cast<char*>(iter.data_ptr(0)),
                                 output_memory_size * sizeof(MPType)));
    } else {
      owned_buf_ptr.reset(new AccumulationBuffer());
    }
    acc_buf_ptr = owned_buf_ptr.get();
  }

  // Split iter if index exceeds 32-bit range.
  if (!can_use_32bit_indexing) {
    for (auto& sub_iter : iter.with_32bit_indexing()) {
      int64_t sub_iter_base_idx = sub_iter.view_offsets()[0];
      GPUReduceScheduler<Tx, Ty, kVecSize, kInputVecSize, ReduceOp>(
          dev_ctx,
          sub_iter,
          reducer,
          ident,
          factor,
          acc_buf_ptr,
          sub_iter_base_idx);
    }
    return;
  }

  // Convention: last tensor of the iterator is the input, tensor 0 the
  // (primary) output.
  const char* in_data =
      reinterpret_cast<const char*>(iter.data_ptr(iter.ntensors() - 1));
  char* out_data = reinterpret_cast<char*>(iter.data_ptr(0));
  const auto noutputs = iter.noutputs();

  std::optional<char*> out_data_extra;
  if (noutputs > 1) {
    out_data_extra = reinterpret_cast<char*>(iter.data_ptr(1));
  } else {
    out_data_extra = std::nullopt;
  }

  char* acc_data = acc_buf_ptr->GetAccSlice(out_data);

  ReduceConfig config =
      SetReduceConfig<MPType, Tx, kVecSize, kInputVecSize>(iter);

  // Cross-block staging buffer and semaphores, only needed when several
  // blocks cooperate on one output. Initialize the raw pointers so the
  // executor never receives indeterminate values when global reduction is
  // disabled (fixes a read of uninitialized locals below).
  Allocator::AllocationPtr buffer;
  Allocator::AllocationPtr semaphores;
  void* buffer_ptr = nullptr;
  void* semaphores_ptr = nullptr;

  if (config.ShouldReduceGlobal()) {
    phi::Allocator* allocator =
        const_cast<phi::Allocator*>(&(dev_ctx.GetAllocator()));  // NOLINT
    buffer = allocator->Allocate(config.GlobalMemorySize());
    semaphores = allocator->Allocate(config.SemaphoreSize());
    buffer_ptr = buffer->ptr();
    semaphores_ptr = semaphores->ptr();

    // Semaphores count finished blocks; they must start at zero.
    phi::backends::gpu::GpuMemsetAsync(
        semaphores_ptr, 0, config.SemaphoreSize(), stream);
  }

  auto output_calc = MakeOutputOffsetCalculator<uint32_t>(iter);
  auto input_calc = MakeInputOffsetCalculator<uint32_t>(iter);
  auto should_accumulate = iter.should_accumulate();
  auto is_final_output = iter.is_final_output();

  auto reduce = ReduceExecutor<Tx, ReduceOp, Ty, kVecSize, kInputVecSize>(
      reducer,
      config,
      ident,
      factor,
      input_calc,
      output_calc,
      in_data,
      out_data,
      out_data_extra,
      acc_data,
      buffer_ptr,
      reinterpret_cast<int*>(semaphores_ptr),
      base_idx,
      should_accumulate,
      is_final_output,
      noutputs);

  LaunchReduceKernel<MaxThreadsConfig<Tx>::MAX_NUM_THREADS>(
      dev_ctx, config, reduce);
}

namespace funcs {
// Entry point for GPU reductions over `x` along `origin_reduce_dims`,
// writing into `y`. `transform` is applied element-wise on the degenerate
// zero-rank path; IsMean scales the result by num_outputs / num_inputs.
template <typename Tx,
          typename Ty,
          template <typename>
          class ReduceOp,
          typename TransformOp,
          bool IsMean = false>
void ReduceGpuKernel(const KPDevice& dev_ctx,
                     const phi::DenseTensor& x,
                     phi::DenseTensor* y,
                     const TransformOp& transform,
                     const std::vector<int>& origin_reduce_dims) {
  // The output is always materialized, even for an empty input.
  dev_ctx.Alloc<Ty>(y);
  if (x.numel() == 0) {
    return;
  }

  const int64_t rank = x.dims().size();
  auto pos_dims = ConvertToPositiveDims(origin_reduce_dims, rank);
  auto reduce_mask = MakeDimMask(pos_dims, rank);
  auto result_view = ReviewReduceResult(x, *y, rank, reduce_mask);

  auto input_shape = common::vectorize<int64_t>(x.dims());

  // Zero-rank input: nothing to reduce, just apply the transform.
  if (input_shape.empty()) {
    std::vector<const DenseTensor*> ew_inputs = {&x};
    std::vector<DenseTensor*> ew_outputs = {&result_view};
    funcs::ElementwiseKernel<Ty>(dev_ctx, ew_inputs, &ew_outputs, transform);
    return;
  }

  DenseTensorIteratorConfig iter_config;
  iter_config.is_reduction(true);
  iter_config.add_output(result_view);
  iter_config.add_const_input(x);
  DenseTensorIterator iter = iter_config.build();

  // TODO(baoqiwen): When ReduceOp is WelfordOps, kVecSize is 2.
  constexpr int kVecSize = 4;
  constexpr int kInputVecSize = kVecSize;
  using MPType = typename phi::dtype::MPTypeTrait<Ty>::Type;
  auto reduce_functor = ReduceOp<MPType>();

  // For mean, fold the 1/n scaling into the final store.
  MPType mean_factor = 1.0f;
  if (IsMean) {
    mean_factor = static_cast<MPType>(iter.num_output_elements()) /
                  static_cast<MPType>(iter.numel());
  }

  // Identity element matching the reduction functor.
  const Tx init_value = []() -> Tx {
    if constexpr (std::is_same_v<ReduceOp<MPType>, kps::MaxFunctor<MPType>>) {
      return std::numeric_limits<Tx>::lowest();
    } else if constexpr (std::is_same_v<ReduceOp<MPType>,
                                        kps::MinFunctor<MPType>>) {
      return std::numeric_limits<Tx>::max();
    } else if constexpr (std::is_same_v<ReduceOp<MPType>,
                                        kps::LogicalAndFunctor<MPType>> ||
                         std::is_same_v<ReduceOp<MPType>,
                                        kps::MulFunctor<MPType>>) {
      return Tx{1};
    } else {
      // AddFunctor, LogicalOrFunctor and others.
      return Tx{0};
    }
  }();

  GPUReduceScheduler<Tx, Ty, kVecSize, kInputVecSize, ReduceOp<MPType>>(
      dev_ctx, iter, reduce_functor, init_value, mean_factor);
}
}  // namespace funcs
} // namespace phi

+ 1
- 0
paddle/phi/kernels/fusion/xpu/fused_linear_param_grad_add_kernel.cc View File

@@ -271,4 +271,5 @@ PD_REGISTER_KERNEL(fused_linear_param_grad_add,
ALL_LAYOUT,
phi::fusion::FusedLinearParamGradAdd,
float,
phi::bfloat16,
phi::float16) {}

+ 93
- 27
paddle/phi/kernels/gpu/reduce.h View File

@@ -14,12 +14,13 @@

#pragma once

// CUDA and HIP use same api
// CUDA and HIP use ReduceGpuKernel API, XPU use ReduceKernel API.
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) || \
defined(PADDLE_WITH_XPU_KP)

#include "paddle/phi/core/visit_type.h"
#include "paddle/phi/kernels/funcs/reduce_function.h"
#include "paddle/phi/kernels/funcs/reduce_gpu_kernel.h"

namespace phi {

@@ -33,7 +34,7 @@ void Reduce(const KPDevice& dev_ctx,
const DenseTensor& x,
bool reduce_all,
const std::vector<int64_t>& dims,
bool keep_dim,
bool keep_dim, // unused
DataType out_dtype,
DenseTensor* out) {
reduce_all = recompute_reduce_all(x, dims, reduce_all);
@@ -44,33 +45,98 @@ void Reduce(const KPDevice& dev_ctx,
for (auto i : reduce_dims) {
reduce_num *= (x.dims())[i];
}
#ifndef PADDLE_WITH_XPU_KP
if (out_dtype != phi::DataType::UNDEFINED && out_dtype != x.dtype()) {
auto tmp_tensor = phi::Cast<T>(dev_ctx, x, out_dtype);
PD_VISIT_BOOL_AND_FLOATING_AND_COMPLEX_AND_4_TYPES(
phi::DataType::INT32,
phi::DataType::INT64,
phi::DataType::FLOAT16,
phi::DataType::BFLOAT16,
out_dtype,
"ReduceKernel",
([&] {
using MPType = typename phi::dtype::MPTypeTrait<data_t>::Type;
funcs::ReduceKernel<data_t,
data_t,
ReduceOp,
TransformOp<data_t, MPType>,
IsMean>(dev_ctx,
tmp_tensor,
out,
TransformOp<data_t, MPType>(reduce_num),
reduce_dims);
}));

// CUDA and HIP use ReduceGpuKernel API
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
constexpr bool is_identity_v =
std::is_same_v<TransformOp<T, T>, kps::IdentityFunctor<T, T>>;

if constexpr (is_identity_v) {
if (out_dtype != phi::DataType::UNDEFINED && out_dtype != x.dtype()) {
if (x.dtype() == phi::DataType::BFLOAT16 &&
out_dtype == phi::DataType::FLOAT32) {
phi::funcs::ReduceGpuKernel<phi::bfloat16,
float,
ReduceOp,
TransformOp<phi::bfloat16, float>,
IsMean>(
dev_ctx,
x,
out,
TransformOp<phi::bfloat16, float>(reduce_num),
reduce_dims);
} else if (x.dtype() == phi::DataType::FLOAT16 &&
out_dtype == phi::DataType::FLOAT32) {
phi::funcs::ReduceGpuKernel<phi::float16,
float,
ReduceOp,
TransformOp<phi::float16, float>,
IsMean>(
dev_ctx,
x,
out,
TransformOp<phi::float16, float>(reduce_num),
reduce_dims);
} else {
auto tmp_tensor = phi::Cast<T>(dev_ctx, x, out_dtype);
tmp_tensor.set_strides(x.strides());

PD_VISIT_BOOL_AND_FLOATING_AND_COMPLEX_AND_4_TYPES(
phi::DataType::INT32,
phi::DataType::INT64,
phi::DataType::FLOAT16,
phi::DataType::BFLOAT16,
out_dtype,
"ReduceGpuKernel",
([&] {
using MPType = typename phi::dtype::MPTypeTrait<data_t>::Type;
phi::funcs::ReduceGpuKernel<data_t,
data_t,
ReduceOp,
TransformOp<data_t, MPType>,
IsMean>(
dev_ctx,
tmp_tensor,
out,
TransformOp<data_t, MPType>(reduce_num),
reduce_dims);
}));
}
} else {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
phi::funcs::
ReduceGpuKernel<T, T, ReduceOp, TransformOp<T, MPType>, IsMean>(
dev_ctx, x, out, TransformOp<T, MPType>(reduce_num), reduce_dims);
}
} else {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
funcs::ReduceKernel<T, T, ReduceOp, TransformOp<T, MPType>, IsMean>(
dev_ctx, x, out, TransformOp<T, MPType>(reduce_num), reduce_dims);
if (out_dtype != phi::DataType::UNDEFINED && out_dtype != x.dtype()) {
auto tmp_tensor = phi::Cast<T>(dev_ctx, x, out_dtype);
PD_VISIT_BOOL_AND_FLOATING_AND_COMPLEX_AND_4_TYPES(
phi::DataType::INT32,
phi::DataType::INT64,
phi::DataType::FLOAT16,
phi::DataType::BFLOAT16,
out_dtype,
"ReduceKernel",
([&] {
using MPType = typename phi::dtype::MPTypeTrait<data_t>::Type;
funcs::ReduceKernel<data_t,
data_t,
ReduceOp,
TransformOp<data_t, MPType>,
IsMean>(dev_ctx,
tmp_tensor,
out,
TransformOp<data_t, MPType>(reduce_num),
reduce_dims);
}));
} else {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
funcs::ReduceKernel<T, T, ReduceOp, TransformOp<T, MPType>, IsMean>(
dev_ctx, x, out, TransformOp<T, MPType>(reduce_num), reduce_dims);
}
}
// XPU use ReduceKernel API
#else
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
funcs::ReduceKernel<T, T, ReduceOp, TransformOp<T, MPType>, IsMean>(


+ 6
- 2
paddle/phi/kernels/gpu/reduce_amin_amax_common.h View File

@@ -85,8 +85,12 @@ void ReduceCudaAMaxAMinGrad(const Context& dev_ctx,
funcs::BroadcastKernel<T>(
dev_ctx, equal_inputs, &equal_outputs, funcs::EqualFunctor<T>(), 0);
// 2. equal_count = reduceSum(equal_out)
phi::SumKernel<T, Context>(
dev_ctx, equal_out, reduce_dims, equal_out.dtype(), false, &equal_count);
phi::SumKernel<T, Context>(dev_ctx,
equal_out,
reduce_dims,
equal_out.dtype(),
keep_dim,
&equal_count);
// 3. dx = dout * 1
phi::MultiplyKernel<T, Context>(dev_ctx, new_dout, equal_out, &equal_out);



+ 21
- 23
paddle/phi/kernels/kps/reduce_kernel.cu View File

@@ -44,15 +44,28 @@ void ProdKernel(const Context& dev_ctx,
bool keep_dim,
bool reduce_all,
DenseTensor* out) {
auto out_dtype = x.dtype();

if (x.numel() == 0) {
// fill with 1.
phi::Full<T, Context>(
dev_ctx, phi::IntArray(common::vectorize(out->dims())), 1, out);
dev_ctx.template Alloc<T>(out);
if (out_dtype == DataType::INT64) {
FullKernel<int64_t, Context>(
dev_ctx,
phi::IntArray(common::vectorize(out->dims())),
1,
out_dtype, // not used
out);
} else {
FullKernel<T, Context>(dev_ctx,
phi::IntArray(common::vectorize(out->dims())),
1,
out_dtype, // not used
out);
}
return;
}

reduce_all = recompute_reduce_all(x, dims, reduce_all);
auto out_dtype = x.dtype();
phi::Reduce<T, kps::MulFunctor, kps::IdentityFunctor>(
dev_ctx, x, reduce_all, dims.GetData(), keep_dim, out_dtype, out);
}
@@ -212,10 +225,10 @@ void SumRawKernel(const Context& dev_ctx,
bool reduce_all,
DataType out_dtype,
DenseTensor* out) {
reduce_all = recompute_reduce_all(x, dims, reduce_all);
if (out_dtype == DataType::UNDEFINED && out->dtype() != x.dtype()) {
out_dtype = out->dtype();
}

if (x.numel() == 0) {
dev_ctx.template Alloc<T>(out);
if (out_dtype == DataType::INT64) {
@@ -235,24 +248,9 @@ void SumRawKernel(const Context& dev_ctx,
return;
}

if (x.dtype() == phi::DataType::BFLOAT16 &&
out_dtype == phi::DataType::FLOAT32) {
std::vector<int> reduce_dims = phi::funcs::details::GetReduceDim(
dims.GetData(), x.dims().size(), reduce_all);

phi::funcs::ReduceKernel<phi::bfloat16,
float,
kps::AddFunctor,
kps::IdentityFunctor<phi::bfloat16, float>>(
dev_ctx,
x,
out,
kps::IdentityFunctor<phi::bfloat16, float>(),
reduce_dims);
} else {
phi::Reduce<T, kps::AddFunctor, kps::IdentityFunctor>(
dev_ctx, x, reduce_all, dims.GetData(), keep_dim, out_dtype, out);
}
reduce_all = recompute_reduce_all(x, dims, reduce_all);
phi::Reduce<T, kps::AddFunctor, kps::IdentityFunctor>(
dev_ctx, x, reduce_all, dims.GetData(), keep_dim, out_dtype, out);
}
} // namespace phi



+ 2
- 1
paddle/phi/kernels/stride/reduce_grad_stride_kernel.cu View File

@@ -129,7 +129,8 @@ void ReduceSumGradStrideKernel(const Context& dev_ctx,
std::vector<int64_t> out_dims;
std::vector<int64_t> out_strides;

if (!FLAGS_use_stride_compute_kernel || !out_grad.dims().size() > 0) {
if ((!FLAGS_use_stride_compute_kernel) || !(out_grad.dims().size() > 0) ||
(out_grad.dtype() != x.dtype())) {
invalid = true;
}



+ 12
- 180
paddle/phi/kernels/stride/reduce_stride_kernel.cu View File

@@ -248,47 +248,12 @@ void ProdStrideKernel(const Context& dev_ctx,
"be called, something wrong has happened!"));
}

DenseTensor x_;
if (!FLAGS_use_stride_compute_kernel || (out->dims().size() > 0)) {
if (!x.meta().is_contiguous()) {
x_ = Tensor2Contiguous<Context>(dev_ctx, x);
} else {
x_ = x;
}
} else {
x_ = x;
}
if (x_.meta().is_contiguous() || (out->dims().size() > 0)) {
auto meta = out->meta();
meta.strides = meta.calc_strides(out->dims());
out->set_meta(meta);
phi::ProdKernel<T, Context>(dev_ctx, x_, dims, keep_dim, reduce_all, out);
return;
}
auto meta = out->meta();
meta.strides = meta.calc_strides(out->dims());
out->set_meta(meta);

if (!FLAGS_use_stride_compute_kernel) {
PADDLE_THROW(
common::errors::Fatal("FLAGS_use_stride_compute_kernel is closed. "
"Kernel using DenseTensorIterator "
"be called, something wrong has happened!"));
}
phi::ProdKernel<T, Context>(dev_ctx, x, dims, keep_dim, reduce_all, out);

if (FLAGS_force_stride_compute_contig_out) {
auto meta = out->meta();
meta.strides = meta.calc_strides(out->dims());
out->set_meta(meta);
}

if (x_.numel() == 0) {
// fill with 1.
phi::Full<T, Context>(
dev_ctx, phi::IntArray(common::vectorize(out->dims())), 1, out);
return;
}

T ident = T(1);
ReduceStrideImpl<T, Context, kps::MulFunctor>(
dev_ctx, x_, dims.GetData(), keep_dim, ident, out);
return;
}

@@ -442,92 +407,17 @@ void SumStrideKernel(const Context& dev_ctx,
DataType out_dtype,
bool keep_dim,
DenseTensor* out) {
bool reduce_all = recompute_reduce_all(x, dims);
if (!FLAGS_use_stride_kernel) {
PADDLE_THROW(common::errors::Fatal(
"FLAGS_use_stride_kernel is closed. Strided kernel "
"be called, something wrong has happened!"));
}

DenseTensor x_;
if (!FLAGS_use_stride_compute_kernel || out->dims().size() > 0) {
if (!x.meta().is_contiguous()) {
x_ = Tensor2Contiguous<Context>(dev_ctx, x);
} else {
x_ = x;
}
} else {
x_ = x;
}

if (x_.meta().is_contiguous() || (out->dims().size() > 0)) {
auto meta = out->meta();
meta.strides = meta.calc_strides(out->dims());
out->set_meta(meta);
phi::SumKernel<T, Context>(dev_ctx, x_, dims, out_dtype, keep_dim, out);
return;
}

if (!FLAGS_use_stride_compute_kernel) {
PADDLE_THROW(
common::errors::Fatal("FLAGS_use_stride_compute_kernel is closed. "
"Kernel using DenseTensorIterator "
"be called, something wrong has happened!"));
}

if (FLAGS_force_stride_compute_contig_out) {
auto meta = out->meta();
meta.strides = meta.calc_strides(out->dims());
out->set_meta(meta);
}

if (out_dtype == DataType::UNDEFINED && out->dtype() != x_.dtype()) {
out_dtype = out->dtype();
}
if (x_.numel() == 0) {
dev_ctx.template Alloc<T>(out);
if (out_dtype == DataType::INT64) {
FullKernel<int64_t, Context>(
dev_ctx,
phi::IntArray(common::vectorize(out->dims())),
0,
out_dtype, // not used
out);
} else {
FullKernel<T, Context>(dev_ctx,
phi::IntArray(common::vectorize(out->dims())),
0,
out_dtype, // not used
out);
}
return;
}
auto meta = out->meta();
meta.strides = meta.calc_strides(out->dims());
out->set_meta(meta);

if (x.dtype() == phi::DataType::BFLOAT16 &&
out_dtype == phi::DataType::FLOAT32) {
phi::dtype::bfloat16 ident = static_cast<phi::dtype::bfloat16>(0);
ReduceStrideImpl<phi::dtype::bfloat16, Context, kps::AddFunctor>(
dev_ctx, x_, dims.GetData(), keep_dim, ident, out);
*out = phi::Cast<phi::dtype::bfloat16>(dev_ctx, x_, out_dtype);
} else if (out_dtype != phi::DataType::UNDEFINED && out_dtype != x_.dtype()) {
auto tmp_tensor = phi::Cast<T>(dev_ctx, x_, out_dtype);
PD_VISIT_BOOL_AND_FLOATING_AND_COMPLEX_AND_4_TYPES(
phi::DataType::INT32,
phi::DataType::INT64,
phi::DataType::FLOAT16,
phi::DataType::BFLOAT16,
out_dtype,
"ReduceStrideImpl",
([&] {
data_t ident = static_cast<data_t>(0);
ReduceStrideImpl<data_t, Context, kps::AddFunctor>(
dev_ctx, tmp_tensor, dims.GetData(), keep_dim, ident, out);
}));
} else {
T ident = static_cast<T>(0);
ReduceStrideImpl<T, Context, kps::AddFunctor>(
dev_ctx, x_, dims.GetData(), keep_dim, ident, out);
}
phi::SumKernel<T, Context>(dev_ctx, x, dims, out_dtype, keep_dim, out);
return;
}

@@ -537,75 +427,17 @@ void MeanStrideKernel(const Context& dev_ctx,
const IntArray& dims,
bool keep_dim,
DenseTensor* out) {
bool reduce_all = recompute_reduce_all(x, dims);
if (!FLAGS_use_stride_kernel) {
PADDLE_THROW(common::errors::Fatal(
"FLAGS_use_stride_kernel is closed. Strided kernel "
"be called, something wrong has happened!"));
}

DenseTensor x_;
if (!FLAGS_use_stride_compute_kernel || (out->dims().size() > 0)) {
if (!x.meta().is_contiguous()) {
x_ = Tensor2Contiguous<Context>(dev_ctx, x);
} else {
x_ = x;
}
} else {
x_ = x;
}
if (x_.meta().is_contiguous() || (out->dims().size() > 0)) {
auto meta = out->meta();
meta.strides = meta.calc_strides(out->dims());
out->set_meta(meta);
phi::MeanKernel<T, Context>(dev_ctx, x_, dims, keep_dim, out);
return;
}
auto meta = out->meta();
meta.strides = meta.calc_strides(out->dims());
out->set_meta(meta);

if (!FLAGS_use_stride_compute_kernel) {
PADDLE_THROW(
common::errors::Fatal("FLAGS_use_stride_compute_kernel is closed. "
"Kernel using DenseTensorIterator "
"be called, something wrong has happened!"));
}

if (FLAGS_force_stride_compute_contig_out) {
auto meta = out->meta();
meta.strides = meta.calc_strides(out->dims());
out->set_meta(meta);
}

if (x_.numel() == 0) {
phi::Full<T, Context>(
dev_ctx, phi::IntArray(common::vectorize(out->dims())), NAN, out);
return;
}

if (std::is_same<T, int>::value || std::is_same<T, int64_t>::value ||
std::is_same<T, bool>::value) {
using Type =
typename std::conditional<std::is_same<T, int>::value ||
std::is_same<T, int64_t>::value ||
std::is_same<T, bool>::value,
float,
T>::type;
DenseTensor x_float =
phi::Cast<T, Context>(dev_ctx, x_, phi::DataType::FLOAT32);
DenseTensor* out_float = new DenseTensor();
out_float->Resize(out->dims());
MeanRawKernel<Type>(
dev_ctx, x_float, dims, keep_dim, reduce_all, out_float);

Type ident = static_cast<Type>(0);
ReduceStrideImpl<Type, Context, kps::AddFunctor, true>(
dev_ctx, x_float, dims.GetData(), keep_dim, ident, out_float);

phi::CastKernel<Type, Context>(dev_ctx, *out_float, x_.dtype(), out);
} else {
T ident = static_cast<T>(0);
ReduceStrideImpl<T, Context, kps::AddFunctor, true>(
dev_ctx, x_, dims.GetData(), keep_dim, ident, out);
}
phi::MeanKernel<T, Context>(dev_ctx, x, dims, keep_dim, out);
return;
}



+ 9
- 2
paddle/phi/kernels/xpu/concat_kernel.cc View File

@@ -27,7 +27,13 @@ void ConcatKernel(const Context& dev_ctx,
const std::vector<const DenseTensor*>& x,
const Scalar& axis_scalar,
DenseTensor* out) {
using XPUType = typename XPUTypeTrait<T>::Type;
// handle complex64 by treating data as raw 64-bit units
using Complex64 = phi::complex64;
using DefaultXPUType = typename XPUTypeTrait<T>::Type;
using XPUType = typename std::conditional<std::is_same<T, Complex64>::value,
int64_t,
DefaultXPUType>::type;

int64_t axis = axis_scalar.to<int64_t>();
PADDLE_ENFORCE_NE(
x[0],
@@ -134,4 +140,5 @@ PD_REGISTER_KERNEL(concat,
int8_t,
int16_t,
int32_t,
int64_t) {}
int64_t,
phi::complex64) {}

+ 1
- 0
paddle/phi/kernels/xpu/elementwise_multiply_kernel.cc View File

@@ -85,6 +85,7 @@ PD_REGISTER_KERNEL(multiply,
XPU,
ALL_LAYOUT,
phi::MultiplyKernel,
bool,
phi::float16,
phi::bfloat16,
#ifdef PADDLE_WITH_XPU_FFT


+ 1
- 0
paddle/phi/kernels/xpu/expand_grad_kernel.cc View File

@@ -57,5 +57,6 @@ PD_REGISTER_KERNEL(expand_grad,
ALL_LAYOUT,
phi::ExpandGradKernel,
float,
int64_t,
phi::bfloat16,
phi::float16) {}

+ 159
- 0
paddle/phi/kernels/xpu/masked_fill_grad_kernel.cc View File

@@ -0,0 +1,159 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/masked_fill_grad_kernel.h"

#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/expand_grad_kernel.h"
#include "paddle/phi/kernels/expand_kernel.h"
#include "paddle/phi/kernels/full_kernel.h"
#include "paddle/phi/kernels/funcs/common_infer_shape_functions.h"

namespace phi {

template <typename T, typename Context>
void MaskedFillGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& mask,
const DenseTensor& value,
const DenseTensor& out_grad,
DenseTensor* x_grad,
DenseTensor* v_grad) {
using XPUType = typename XPUTypeTrait<T>::Type;

if (out_grad.numel() == 0 || mask.numel() == 0) {
if (x_grad) {
phi::Full<T, Context>(
dev_ctx, phi::IntArray(common::vectorize(x_grad->dims())), 0, x_grad);
}
if (v_grad) {
phi::Full<T, Context>(
dev_ctx, phi::IntArray(common::vectorize(v_grad->dims())), 0, v_grad);
}
return;
}

auto x_dims = x.dims();
auto mask_dims = mask.dims();

auto expanded_size =
common::vectorize(funcs::BroadcastTwoDims(x_dims, mask_dims, -1));
auto expanded_dims = common::make_ddim(expanded_size);

DenseTensor mask_expand;
DenseTensor x_grad_expand;
DenseTensor value_grad_expand;

bool expand_x = false;
bool expand_value = false;

if (mask.dims() != expanded_dims) {
ExpandKernel<bool, Context>(
dev_ctx, mask, IntArray(expanded_size), &mask_expand);
} else {
mask_expand = mask;
}

DenseTensor* x_grad_tmp = nullptr;
if (x_grad) {
if (x_grad->dims() != expanded_dims) {
x_grad_expand = Empty<T, Context>(dev_ctx, IntArray(expanded_size));
x_grad_tmp = &x_grad_expand;
expand_x = true;
} else {
x_grad_tmp = x_grad;
}
}

DenseTensor* value_grad_tmp = nullptr;
if (v_grad) {
if (v_grad->dims() != expanded_dims) {
value_grad_expand = Empty<T, Context>(dev_ctx, IntArray(expanded_size));
value_grad_tmp = &value_grad_expand;
expand_value = true;
} else {
value_grad_tmp = v_grad;
}
}

auto* cond_data = mask_expand.data<bool>();
auto* dout_data = out_grad.data<T>();
const int64_t len = mask_expand.numel();
if (len <= 0) {
return;
}

if (x_grad_tmp) {
dev_ctx.template Alloc<T>(x_grad_tmp);
}
if (value_grad_tmp) {
dev_ctx.template Alloc<T>(value_grad_tmp);
}

DenseTensor dx_dummy;
DenseTensor dy_dummy;

T* dx_ptr = nullptr;
T* dy_ptr = nullptr;

if (x_grad_tmp) {
dx_ptr = x_grad_tmp->data<T>();
} else {
dx_dummy = Empty<T, Context>(dev_ctx, IntArray(expanded_size));
dx_ptr = dx_dummy.data<T>();
}

if (value_grad_tmp) {
dy_ptr = value_grad_tmp->data<T>();
} else {
dy_dummy = Empty<T, Context>(dev_ctx, IntArray(expanded_size));
dy_ptr = dy_dummy.data<T>();
}

int r = xpu::masked_fill_grad<XPUType>(
dev_ctx.x_context(),
cond_data,
reinterpret_cast<const XPUType*>(dout_data),
reinterpret_cast<XPUType*>(dx_ptr),
reinterpret_cast<XPUType*>(dy_ptr),
len);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "masked_fill_grad");

if (x_grad && expand_x) {
ExpandGradKernel<T, Context>(
dev_ctx, x, x_grad_expand, IntArray(expanded_size), x_grad);
}

if (v_grad) {
if (expand_value) {
ExpandGradKernel<T, Context>(
dev_ctx, value, value_grad_expand, IntArray(expanded_size), v_grad);
}
}
}

} // namespace phi

// Register the XPU backward kernel for masked_fill. Input index 1 is the
// mask and is pinned to BOOL regardless of the value dtype T.
PD_REGISTER_KERNEL(masked_fill_grad,
                   XPU,
                   ALL_LAYOUT,
                   phi::MaskedFillGradKernel,
                   float,
                   int64_t,
                   phi::float16,
                   phi::bfloat16) {
  kernel->InputAt(1).SetDataType(phi::DataType::BOOL);
}

+ 0
- 1
python/paddle/_paddle_docs.py View File

@@ -901,7 +901,6 @@ add_doc_and_signature(
>>> z = paddle.matmul(x, y)
>>> print(z.shape)
paddle.Size([10, 3, 5, 5])

""",
""" def matmul(
x: Tensor,


+ 1
- 1
python/paddle/tensor/attribute.py View File

@@ -91,7 +91,7 @@ def shape(input: Tensor) -> Tensor:
Tensor: The shape of the input variable.

Examples:
.. code-block:: python
.. code-block:: pycon

>>> import numpy as np
>>> import paddle


+ 5
- 9
python/paddle/tensor/math.py View File

@@ -6082,21 +6082,17 @@ def sgn(x: Tensor, name: str | None = None) -> Tensor:
Tensor: A sign Tensor for real input, or normalized Tensor for complex input, shape and data type are same as input.

Examples:
.. code-block:: python
.. code-block:: pycon

>>> import paddle

>>> x = paddle.to_tensor([[3 + 4j, 7 - 24j, 0, 1 + 2j], [6 + 8j, 3, 0, -2]])
>>> paddle.sgn(x)
Tensor(shape=[2, 4], dtype=complex64, place=Place(cpu), stop_gradient=True,
[[ (0.6000000238418579+0.800000011920929j),
(0.2800000011920929-0.9599999785423279j),
0j ,
(0.4472135901451111+0.8944271802902222j)],
[ (0.6000000238418579+0.800000011920929j),
(1+0j) ,
0j ,
(-1+0j) ]])
[[ (0.60000002+0.80000001j), (0.28000000-0.95999998j),
(0.00000000+0.00000000j), (0.44721359+0.89442718j)],
[ (0.60000002+0.80000001j), (1.00000000+0.00000000j),
(0.00000000+0.00000000j), (-1.00000000+0.00000000j)]])

"""
if x.dtype not in [


+ 28
- 0
test/legacy_test/test_flash_attention.py View File

@@ -492,6 +492,20 @@ class TestFlashAttentionAPITest5(TestFlashAttentionAPI):
self.use_sdp_kernel = False


class TestFlashAttentionAPITest6(TestFlashAttentionAPI):
    """Flash-attention API coverage for an empty batch (shape[0] == 0)."""

    def setUp(self):
        # Same knobs as the sibling cases, driven from one config table.
        config = {
            'place': get_device_place(),
            'shape': (0, 256, 8, 16),
            'dtype': 'float16',
            'dropout': 0.0,
            'causal': True,
            'return_softmax': False,
            'use_sdp_kernel': False,
        }
        for attr, val in config.items():
            setattr(self, attr, val)

    def test_unpadded(self):
        # Intentionally a no-op: the inherited unpadded check is not
        # exercised for the zero-batch configuration.
        pass


class TestMathAttentionAPITest(TestFlashAttentionAPI):
def setUp(self):
self.place = get_device_place()
@@ -566,6 +580,20 @@ class TestSDPAttentionWithMaskAPITest3(TestFlashAttentionWithMaskAPI):
self.causal = False


@unittest.skipIf(
    is_sm_supported,
    "core is not compiled with CUDA and cuda version need larger than or equal to 11.4"
    "and device's compute capability must be 7.5 or 8.x",
)
# NOTE(review): the skip condition is `is_sm_supported`, i.e. this case is
# skipped ON devices that satisfy the requirement, while the message (and the
# sibling guards that use `not is_flashattn_supported()`) describe skipping
# when the requirement is NOT met. Confirm this is not meant to be
# `not is_sm_supported`.
class TestSDPAttentionWithMaskAPITest4(TestFlashAttentionWithMaskAPI):
    def setUp(self):
        # Zero-batch (shape[0] == 0) SDP-with-mask configuration.
        self.place = get_device_place()
        self.shape = (0, 1024, 16, 128)
        self.dtype = 'float32'
        self.dropout = 0.0
        self.causal = True

@unittest.skipIf(
not is_flashattn_supported(),
"core is not compiled with CUDA and cuda version need larger than or equal to 11.4"


+ 52
- 0
test/legacy_test/test_reduce_op.py View File

@@ -2770,6 +2770,58 @@ class TestAnyCompatibility(unittest.TestCase):
)


# Dimension exceeds int32 range.
class TestSumOpIndexInt32OverflowCase0(unittest.TestCase):
    """paddle.sum over a 1-D tensor longer than INT32_MAX elements
    (dimension-overflow path), cross-checked against numpy.sum."""

    def setUp(self):
        self.shape = [2147483678]
        self.axis = 0
        self.input_dtype = 'float32'
        self.test_dtypes = [np.float32]

    def test_dygraph(self):
        with dygraph_guard():
            x = paddle.ones(shape=self.shape, dtype=self.input_dtype)
            for out_dtype in self.test_dtypes:
                # numpy reference, accumulated in the requested dtype.
                expected = np.sum(
                    x.numpy(),
                    axis=self.axis,
                    dtype=np.dtype(out_dtype),
                    keepdims=False,
                )
                got = paddle.sum(x, self.axis, out_dtype)
                np.testing.assert_allclose(got, expected, rtol=1e-05)


# Index exceeds int32 range.
class TestSumOpIndexInt32OverflowCase1(TestSumOpIndexInt32OverflowCase0):
    """paddle.sum over a tensor large enough that element *index*
    arithmetic exceeds int32 range (> 2**30 float32 elements).

    Only the shape differs from Case0, so the whole dygraph check is
    inherited instead of duplicating the test body verbatim.
    """

    def setUp(self):
        super().setUp()
        self.shape = [1073741830]


if __name__ == '__main__':
paddle.enable_static()
unittest.main()

+ 5
- 8
test/legacy_test/test_tensor.py View File

@@ -549,7 +549,7 @@ class TestTensorDataSetter(unittest.TestCase):

def test_new_shape_same_dtype_same_place(self):
x: paddle.Tensor = paddle.tensor([[1, 2], [3, 4]], dtype="float32")
y = paddle.tensor([2], dtype="float32")
y = paddle.rand([3, 4, 5], dtype="float32")
x.requires_grad = True

loss = x.sum()
@@ -569,8 +569,7 @@ class TestTensorDataSetter(unittest.TestCase):
True,
"x's requires_grad should be True after data setting.",
)

with self.assertRaises(RuntimeError):
with self.assertRaises((ValueError, RuntimeError)):
loss = x.sum()
loss.backward()

@@ -595,8 +594,7 @@ class TestTensorDataSetter(unittest.TestCase):
np.testing.assert_equal(x.grad.numpy(), x_grad_expected.numpy())
assert x.grad.dtype == x_grad_expected.dtype

x.clear_gradient(set_to_zero=True)
# dtype changed, clear_gradient must set grad to None
x.clear_gradient(False)
assert x.grad is None
z = x.sum()
z.backward()
@@ -627,8 +625,7 @@ class TestTensorDataSetter(unittest.TestCase):
assert x.grad.dtype == x_grad_expected.dtype
assert x.grad.place == x_grad_expected.place

x.clear_gradient(set_to_zero=True)
# place changed, clear_gradient must set grad to None
x.clear_gradient(False)
assert x.grad is None
loss = x.sum()
loss.backward()
@@ -638,4 +635,4 @@ class TestTensorDataSetter(unittest.TestCase):


if __name__ == '__main__':
unittest.main(argv=["", "TestTensorDataSetter"])
unittest.main()

+ 279
- 0
test/xpu/test_masked_fill_grad_op_xpu.py View File

@@ -0,0 +1,279 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
from get_test_cover_info import (
XPUOpTestWrapper,
create_test_class,
get_xpu_op_support_types,
)
from op_test import convert_float_to_uint16
from op_test_xpu import XPUOpTest

import paddle

# XPUOpTest cases in this file run under static-graph mode.
paddle.enable_static()


def np_masked_fill(x, mask, value):
    """NumPy reference for masked_fill: where ``mask`` is True take
    ``value`` (broadcast to the output shape), otherwise keep ``x``."""
    data, cond = np.broadcast_arrays(x, mask)
    fill = np.broadcast_to(value, data.shape)
    return np.where(cond, fill, data)


def _np_reduce_to_shape(x, target_shape, out_dtype):
if x.shape == tuple(target_shape):
return x.astype(out_dtype, copy=False)

target_shape = tuple(target_shape)
target_shape_padded = (1,) * (x.ndim - len(target_shape)) + target_shape
reduce_axes = [
axis
for axis, (t_dim, x_dim) in enumerate(zip(target_shape_padded, x.shape))
if t_dim == 1 and x_dim != 1
]
if reduce_axes:
x = x.sum(axis=tuple(reduce_axes), keepdims=True)
x = x.reshape(target_shape)
return x.astype(out_dtype, copy=False)


def np_masked_fill_grad(x, mask, value, out_grad):
    """NumPy reference gradients for masked_fill.

    dx receives out_grad where the mask is False, dv receives it where the
    mask is True; both are then sum-reduced back to the original shapes of
    ``x`` and ``value`` respectively.
    """
    full_shape = np.broadcast_shapes(x.shape, mask.shape)
    cond = np.broadcast_to(mask, full_shape)
    dout = np.broadcast_to(out_grad, full_shape)
    zeros = np.zeros_like(dout)
    dx_full = np.where(cond, zeros, dout)
    dv_full = np.where(cond, dout, zeros)
    return (
        _np_reduce_to_shape(dx_full, x.shape, x.dtype),
        _np_reduce_to_shape(dv_full, value.shape, value.dtype),
    )


class XPUTestMaskedFillGradOp(XPUOpTestWrapper):
    """Test family for the XPU ``masked_fill`` op, exercising its backward
    kernel across tensor/scalar values, broadcasting, and empty tensors."""

    def __init__(self):
        # Op name consumed by get_xpu_op_support_types / create_test_class.
        self.op_name = 'masked_fill'

    class TestMaskedFillGradBase(XPUOpTest):
        """Base case: builds x/mask/value inputs and the numpy reference
        output, then checks forward output and (float types only) grads."""

        def setUp(self):
            # Order matters: init() picks shapes, init_config() derives
            # defaults from them, init_data() materializes the arrays.
            self.init()
            self.init_config()
            self.init_data()

            self.inputs = {
                'x': self.x,
                'mask': self.mask,
                'value': self.value,
            }
            self.attrs = {}
            self.outputs = {'out': self.out}

        def init_config(self):
            self.op_type = "masked_fill"
            # self.in_type is injected per-dtype by create_test_class.
            self.dtype = self.in_type
            self.place = paddle.XPUPlace(0)
            self.__class__.no_need_check_grad = False
            # getattr defaults let subclasses override only the attributes
            # they care about in init()/init_config().
            self.is_scalar_value = getattr(self, 'is_scalar_value', False)
            self.mask_shape = getattr(self, 'mask_shape', self.x_shape)
            self.value_shape = getattr(self, 'value_shape', self.x_shape)

        def init_data(self):
            # Random boolean mask over mask_shape.
            self.mask = np.random.randint(0, 2, size=self.mask_shape).astype(
                'bool'
            )

            if self.dtype == np.uint16:
                # uint16 is the carrier type for bfloat16 in OpTest
                # (convert_float_to_uint16): generate float32 data, keep it
                # for the numpy reference, feed the op converted bits.
                x_fp32 = np.random.randn(*self.x_shape).astype('float32')
                if x_fp32.size == 0:
                    # Empty arrays need no bit conversion.
                    self.x = x_fp32.astype(np.uint16)
                else:
                    self.x = convert_float_to_uint16(x_fp32)

                if self.is_scalar_value:
                    scalar_fp32 = float(np.random.randn())
                    scalar_arr = np.array([scalar_fp32], dtype='float32')
                    self.value = convert_float_to_uint16(scalar_arr)
                    v_np = scalar_fp32
                else:
                    v_fp32 = np.random.randn(*self.value_shape).astype(
                        'float32'
                    )
                    if v_fp32.size == 0:
                        self.value = v_fp32.astype(np.uint16)
                    else:
                        self.value = convert_float_to_uint16(v_fp32)
                    v_np = v_fp32

                # Reference output stays in float32, the precision the
                # converted inputs represent.
                self.out = np_masked_fill(
                    x_fp32,
                    self.mask,
                    v_np,
                ).astype(x_fp32.dtype)
            else:
                self.x = np.random.randn(*self.x_shape).astype(self.dtype)

                if self.is_scalar_value:
                    scalar = float(np.random.randn())
                    self.value = np.array([scalar]).astype(self.dtype)
                    v_np = scalar
                else:
                    v_np = np.random.randn(*self.value_shape).astype(self.dtype)
                    self.value = v_np

                self.out = np_masked_fill(self.x, self.mask, v_np).astype(
                    self.dtype
                )

        def test_check_output(self):
            self.check_output_with_place(self.place)

        def test_check_grad(self):
            # Only floating types support numeric grad check in OpTest.
            if self.dtype not in [np.float32, np.float16, np.uint16]:
                self.__class__.no_need_check_grad = True
                return
            self.check_grad_with_place(self.place, ['x'], 'out')

        def init(self):
            # Default: all three inputs share one 2-D shape (no broadcast).
            self.x_shape = (16, 32)
            self.mask_shape = self.x_shape
            self.value_shape = self.x_shape

    # ------------------ Tensor Value ------------------
    class TestMaskedFillGradTensorValue1D(TestMaskedFillGradBase):
        def init(self):
            self.x_shape = (64,)
            self.mask_shape = self.x_shape
            self.value_shape = self.x_shape

    class TestMaskedFillGradTensorValue4D(TestMaskedFillGradBase):
        def init(self):
            self.x_shape = (2, 3, 4, 5)
            self.mask_shape = self.x_shape
            self.value_shape = self.x_shape

    class TestMaskedFillGradTensorValueBroadcastMask(TestMaskedFillGradBase):
        def init(self):
            self.x_shape = (10, 4, 5)
            self.mask_shape = (1, 4, 1)
            self.value_shape = (10, 4, 5)

    class TestMaskedFillGradTensorValueBroadcastValue(TestMaskedFillGradBase):
        def init(self):
            self.x_shape = (2, 4, 5)
            self.mask_shape = (2, 4, 5)
            self.value_shape = (1, 5)

    # ------------------ Scalar Value ------------------
    class TestMaskedFillGradScalarValue4D(TestMaskedFillGradBase):
        def init(self):
            self.x_shape = (2, 3, 4, 5)
            self.mask_shape = self.x_shape

        def init_config(self):
            super().init_config()
            self.is_scalar_value = True

    class TestMaskedFillGradScalarValueBroadcastMask(TestMaskedFillGradBase):
        def init(self):
            self.x_shape = (10, 4, 5)
            self.mask_shape = (1, 4, 1)

        def init_config(self):
            super().init_config()
            self.is_scalar_value = True

    class TestMaskedFillGradScalarValueBroadcastMask2(TestMaskedFillGradBase):
        def init(self):
            self.x_shape = (10, 1)
            self.mask_shape = (1, 5)

        def init_config(self):
            super().init_config()
            self.is_scalar_value = True

    # ------------------ Empty Tensor ------------------
    class TestMaskedFillGradEmptyX(TestMaskedFillGradBase):
        """Test empty x tensor (numel == 0)"""

        def init(self):
            self.x_shape = (0, 5)
            self.mask_shape = (0, 5)
            self.value_shape = (0, 5)

        def test_check_grad(self):
            # Skip grad check for empty tensor
            pass

    class TestMaskedFillGradEmptyMask(TestMaskedFillGradBase):
        """Test empty mask tensor (numel == 0)"""

        def init(self):
            self.x_shape = (5, 0)
            self.mask_shape = (5, 0)
            self.value_shape = (5, 0)

        def test_check_grad(self):
            # Skip grad check for empty tensor
            pass

    # ------------------ Both Mask and Value Broadcast ------------------
    class TestMaskedFillGradBothBroadcast(TestMaskedFillGradBase):
        """Test both mask and value need broadcast"""

        def init(self):
            self.x_shape = (4, 5, 6)
            self.mask_shape = (1, 5, 1)
            self.value_shape = (4, 1, 6)

    # ------------------ Large Tensor ------------------
    class TestMaskedFillGradLargeTensor(TestMaskedFillGradBase):
        """Test with larger tensor to ensure kernel handles large data"""

        def init(self):
            self.x_shape = (128, 256)
            self.mask_shape = self.x_shape
            self.value_shape = self.x_shape

    # ------------------ Broadcast X ------------------
    class TestMaskedFillGradBroadcastX(TestMaskedFillGradBase):
        def init(self):
            self.x_shape = (1, 5)
            self.mask_shape = (10, 5)
            self.value_shape = self.x_shape

    # ------------------ Broadcast X and Value ------------------
    class TestMaskedFillGradBroadcastXAndValue(TestMaskedFillGradBase):
        def init(self):
            self.x_shape = (1, 5)
            self.mask_shape = (10, 5)
            self.value_shape = (1, 1)

    # ------------------ Complex Broadcast ------------------
    class TestMaskedFillGradComplexBroadcast(TestMaskedFillGradBase):
        def init(self):
            self.x_shape = (10, 1, 5)
            self.mask_shape = (1, 4, 5)
            self.value_shape = (10, 1, 1)


# Instantiate one concrete test class per dtype the XPU masked_fill op
# registers (dtype list is queried from the op registry).
support_types = get_xpu_op_support_types('masked_fill')
for stype in support_types:
    create_test_class(globals(), XPUTestMaskedFillGradOp, stype)


if __name__ == '__main__':
    unittest.main()

Loading…
Cancel
Save
Baidu
map