2 Commits

Author SHA1 Message Date
  Matthew Douglas 63f538a4e4
Remove deprecated code (#1798) 1 month ago
  Matthew Douglas a2cb49bc84
Tests: Run CPU tests against PyTorch 2.9 (#1797) 1 month ago
14 changed files with 1 additions and 468 deletions
Split View
  1. +1
    -1
      .github/workflows/tests.yml
  2. +0
    -1
      bitsandbytes/autograd/__init__.py
  3. +0
    -59
      bitsandbytes/autograd/_functions.py
  4. +0
    -96
      bitsandbytes/functional.py
  5. +0
    -45
      csrc/kernels.cu
  6. +0
    -7
      csrc/kernels.cuh
  7. +0
    -43
      csrc/kernels.hip
  8. +0
    -7
      csrc/kernels_hip.cuh
  9. +0
    -67
      csrc/ops.cu
  10. +0
    -9
      csrc/ops.cuh
  11. +0
    -81
      csrc/ops.hip
  12. +0
    -9
      csrc/ops_hip.cuh
  13. +0
    -4
      csrc/pythonInterface.cpp
  14. +0
    -39
      tests/test_functional.py

+ 1
- 1
.github/workflows/tests.yml View File

@@ -103,7 +103,7 @@ jobs:
matrix:
os: [ubuntu-22.04, ubuntu-22.04-arm, windows-2025, macos-15]
# Test with the oldest supported torch version and the two newest stable/RC releases.
torch_version: ["2.3.1", "2.7.1", "2.8.0"]
torch_version: ["2.3.1", "2.8.0", "2.9.0"]
include:
- os: ubuntu-22.04
arch: x86_64


+ 0
- 1
bitsandbytes/autograd/__init__.py View File

@@ -1 +0,0 @@
from ._functions import get_inverse_transform_indices, undo_layout

+ 0
- 59
bitsandbytes/autograd/_functions.py View File

@@ -1,4 +1,3 @@
from collections.abc import Callable
from dataclasses import dataclass
from math import prod
from typing import Optional
@@ -6,7 +5,6 @@ import warnings
from warnings import warn

import torch
from typing_extensions import deprecated

import bitsandbytes.functional as F

@@ -50,66 +48,9 @@ class GlobalOutlierPooler:
return torch.Tensor(list(self.outliers)).to(torch.int64)


@deprecated(
"This function is deprecated and will be removed in a future release.",
category=FutureWarning,
)
def get_inverse_transform_indices(
    transform_tile: Callable[[torch.Tensor], torch.Tensor],
    tile_size: tuple[int, int],
):
    """
    Compute a permutation of indices that inverts the specified (tiled) matrix transformation.

    Every flat position inside the tile is split into base-256 digits. Each digit plane is shifted
    into the int8 range, passed through ``transform_tile``, shifted back, and the planes are
    recombined, which reveals where every original position ended up.

    :param transform_tile: a function that applies the forward transform to a tensor of shape [dim1, dim2]
    :param tile_size: higher-level tile dimensions, i.e. (8, 32) for Turing and (32, 32) for Ampere
    :note: we assume that transform_tile applies to a cpu-based int8 tensor of shape tile_size
    :returns: int64 tensor of shape tile_size holding the permuted flat indices
    """
    d1, d2 = tile_size
    num_positions = d1 * d2
    assert 0 < num_positions < 2**64
    tile_indices = torch.arange(num_positions, dtype=torch.int64).view(d1, d2)
    # encode each position in tile as a tuple of <= 8 unique bytes
    permuted_tile_indices = torch.zeros_like(tile_indices)
    for i in range(8):
        # select i-th byte, apply transformation and trace where each index ended up
        ith_dim_indices = torch.div(tile_indices, 256**i, rounding_mode="trunc") % 256
        sample_tile_i = (ith_dim_indices - 128).to(torch.int8).contiguous()
        assert torch.all(sample_tile_i.int() + 128 == ith_dim_indices), "int overflow"
        permuted_tile_i = transform_tile(sample_tile_i)
        ith_permuted_indices = permuted_tile_i.to(tile_indices.dtype) + 128
        permuted_tile_indices += ith_permuted_indices * (256**i)
        # Bytes 0..i have been processed at this point. They cover every index in
        # [0, num_positions) as soon as num_positions <= 256**(i + 1); any further
        # pass would only push an all-zero byte plane through transform_tile.
        if num_positions <= 256 ** (i + 1):
            break
    return permuted_tile_indices


_is_compiling = torch.compiler.is_compiling


@deprecated(
"This function is deprecated and will be removed in a future release.",
category=FutureWarning,
)
def undo_layout(permuted_tensor: torch.Tensor, tile_indices: torch.LongTensor) -> torch.Tensor:
    """
    Undo a tiled permutation such as turing or ampere layout

    :param permuted_tensor: torch tensor in a permuted layout
    :param tile_indices: reverse transformation indices, from get_inverse_transform_indices
    :return: contiguous row-major tensor
    """
    tensor_shape, tile_shape = permuted_tensor.shape, tile_indices.shape
    rows, cols = tensor_shape
    tile_rows, tile_cols = tile_shape
    assert rows % tile_rows == cols % tile_cols == 0, "tensor must contain a whole number of tiles"

    # One column per tile, one row per position inside a tile.
    per_tile_columns = permuted_tensor.reshape(-1, tile_indices.numel()).t()
    restored = torch.empty_like(per_tile_columns)
    # note: plain index assignment is used instead of .index_copy because it was slower on cuda
    restored[tile_indices.flatten()] = per_tile_columns

    tiles_across = cols // tile_cols
    tiles_down = rows // tile_rows
    restored = restored.reshape(tile_rows, tile_cols, tiles_across, tiles_down)
    # -> (rows // tile_rows, tile_rows), (cols // tile_cols, tile_cols)
    restored = restored.permute(3, 0, 2, 1)
    return restored.reshape(rows, cols).contiguous()


@dataclass
class MatmulLtState:
_tile_indices: Optional[torch.Tensor] = None # TODO: remove


+ 0
- 96
bitsandbytes/functional.py View File

@@ -1795,102 +1795,6 @@ def int8_mm_dequant(
return result


@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
def get_colrow_absmax(
    A: torch.Tensor,
    row_stats: Optional[torch.Tensor] = None,
    col_stats: Optional[torch.Tensor] = None,
    nnz_block_ptr: Optional[torch.Tensor] = None,
    threshold=0.0,
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    """Determine the quantization statistics for input matrix `A` in accordance to the `LLM.int8()` algorithm.

    The row-wise and column-wise absmax values are determined.

    For more information, see the [LLM.int8() paper](https://arxiv.org/abs/2208.07339).

    <Tip>
    This function is useful for training, but for inference it is advised to use [`get_row_absmax`] instead.
    The column-wise quantization scales are not typically needed in inference scenarios.
    </Tip>

    Args:
        A (`torch.Tensor`): Floating-point input tensor. When `row_stats` is not supplied the row
            statistics come from [`get_row_absmax`], which additionally requires `torch.float16`.
        row_stats (`torch.Tensor`, *optional*): If provided, calculation of row statistics is skipped.
        col_stats (`torch.Tensor`, *optional*): If provided, calculation of column statistics is skipped.
        nnz_block_ptr (`torch.Tensor`, *optional*): Not used.
        threshold (`float`, *optional*):
            An optional threshold for sparse decomposition of outlier features.
            No outliers are held back when 0.0. Defaults to 0.0.

    Returns:
        `Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]`: A tuple containing quantization statistics.
        - `torch.Tensor`: The row-wise quantization statistics.
        - `torch.Tensor`: The column-wise quantization statistics (`torch.float32` when computed here).
        - `torch.Tensor` with dtype `torch.bool`, *optional*: A mask indicating the locations of outliers
          in the input tensor. `None` when `threshold` is not positive or when both statistics were supplied.
    """
    assert A.is_floating_point()

    # Nothing to compute when the caller already has both statistics.
    if row_stats is not None and col_stats is not None:
        return row_stats, col_stats, None

    magnitudes = A.abs().view(-1, A.shape[-1])

    outlier_mask = None
    if threshold > 0.0:
        # Filter outliers from stats when enabled
        outlier_mask = magnitudes >= threshold
        magnitudes.masked_fill_(outlier_mask, 0.0)

    if row_stats is None:
        # shape [rows]; unsqueeze(-1) gives [rows,1]
        # We have a CUDA kernel for row max, but not yet for cols.
        row_stats = get_row_absmax(A, threshold)

    if col_stats is None:
        # shape [cols]; unsqueeze(0) gives [1,cols]
        col_stats = magnitudes.amax(dim=0, keepdim=False).float()

    return row_stats, col_stats, outlier_mask


@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
def get_row_absmax(A: torch.Tensor, threshold=0.0):
    """Determine the quantization statistics for input matrix `A` in accordance to the `LLM.int8()` algorithm.

    For more information, see the [LLM.int8() paper](https://arxiv.org/abs/2208.07339).

    Args:
        A (`torch.Tensor` with dtype `torch.float16`): The input matrix.
        threshold (`float`, *optional*):
            An optional threshold for sparse decomposition of outlier features.
            No outliers are held back when 0.0. Defaults to 0.0.

    Returns:
        `torch.Tensor` with dtype `torch.float32`: The absolute maximum value for each row, with outliers ignored.
    """

    assert A.dtype == torch.float16

    # All leading dimensions are flattened into rows; the last dimension is the row length.
    rows = prod(A.shape[:-1])
    cols = A.shape[-1]

    # One float32 result per flattened row, allocated on the same device as the input.
    row_stats = torch.empty((rows,), dtype=torch.float32, device=A.device)

    is_on_gpu([A])

    with _cuda_device_of(A):
        native_args = (
            get_ptr(A),
            get_ptr(row_stats),
            ct.c_float(threshold),
            ct.c_int32(rows),
            ct.c_int32(cols),
            _get_tensor_stream(A),
        )
        lib.cget_row_stats(*native_args)

    return row_stats


class COOSparseTensor:
def __init__(
self, rows: int, cols: int, nnz: int, rowidx: torch.Tensor, colidx: torch.Tensor, values: torch.Tensor


+ 0
- 45
csrc/kernels.cu View File

@@ -1825,51 +1825,6 @@ __launch_bounds__(1024, BNB_MAX_THREADS_PER_SM / 1024) __global__
}
}

// Computes the absolute maximum of each row of A and stores it in rowStats[row].
//
// One block handles one row (the row index is blockIdx.x). Threads walk the row
// with a stride of THREADS, keep a thread-local absmax, and the block then
// reduces those values with cub::BlockReduce.
//
// When SPARSE_DECOMP is non-zero, values whose magnitude is >= threshold are
// left out of the row's absmax.
template <typename T, int THREADS, int SPARSE_DECOMP>
__launch_bounds__(1024, BNB_MAX_THREADS_PER_SM / 1024) __global__
    void kgetRowStats(T* __restrict__ A, float* rowStats, float threshold, int rows, int cols) {
    using BlockReduceT = cub::BlockReduce<float, THREADS>;

    // One block per row.
    // Threads load column values in a striped arrangement.
    // e.g. t0 reads row[0], row[0+nthreads], ..
    // and t1 reads row[1], row[1+nthreads], ..
    // Each thread will determine its local absmax.
    // We then do a blockwise reduction to determine the row's absmax.

    __shared__ typename BlockReduceT::TempStorage temp_storage;

    const int row_id = blockIdx.x;
    // Widen before multiplying: `row_id * cols` evaluated in 32-bit int overflows
    // once the matrix holds 2^31 or more elements.
    const T* __restrict__ row_data = A + static_cast<size_t>(row_id) * cols;

    // Threads will read the row values in a striped access pattern and find a local absmax.
    // NOTE(review): -FLT_MIN is a tiny negative value, so a row whose values are all
    // filtered out reports -FLT_MIN rather than 0 -- confirm this is intended.
    float row_local_absmax = -FLT_MIN;
    for (int i = threadIdx.x; i < cols; i += THREADS) {
        const float absval = fabsf(row_data[i]);

        // For sparse decomposition, values outside of the threshold are not to be
        // included when calculating the row's absmax.
        if constexpr (SPARSE_DECOMP) {
            row_local_absmax = fmaxf(row_local_absmax, absval < threshold ? absval : row_local_absmax);
        } else {
            row_local_absmax = fmaxf(row_local_absmax, absval);
        }
    }

    // Reduce thread-local absmax across the block.
    // TODO: Consider algorithm BLOCK_REDUCE_RAKING_COMMUTATIVE_ONLY
    const float row_absmax = BlockReduceT(temp_storage).Reduce(row_local_absmax, CUB_REDUCTIONOP_MAX, cols);
    if (threadIdx.x == 0) {
        // Thread 0 publishes the block's absmax to the rowStats output buffer.
        rowStats[row_id] = row_absmax;
    }
}

template __global__ void
kgetRowStats<half, 1024, 0>(half* __restrict__ A, float* rowStats, float threshold, int rows, int cols);
template __global__ void
kgetRowStats<half, 1024, 1>(half* __restrict__ A, float* rowStats, float threshold, int rows, int cols);

template __global__ void kInt8VectorQuant<half, 1024, 0>(
half* __restrict__ A, int8_t* out, float* rowStats, float threshold, int rows, int cols
);


+ 0
- 7
csrc/kernels.cuh View File

@@ -109,16 +109,9 @@ __global__ void kdequant_mm_int32_fp16(
half* __restrict__ const bias, const int numRows, const int numCols, const int n
);

template <typename T, int THREADS, int SPARSE_DECOMP>
__global__ void kgetRowStats(T* __restrict__ A, float* rowStats, float threshold, int rows, int cols);
template <typename T, int THREADS, int SPARSE_DECOMP>
__global__ void kInt8VectorQuant(T* __restrict__ A, int8_t* out, float* rowStats, float threshold, int rows, int cols);

template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int TRANSPOSE, int FORMAT>
__global__ void kTransformRowToFormat(
char* __restrict__ const A, char* out, int rows, int cols, int tiledCols, int outRows, int outCols
);

template <typename T, int BITS, int THREADS>
__global__ void gemm_device(int M, int N, int K, T* __restrict__ const A, T* B, T* out, int lda, int ldb, int ldc);
template <typename T, int THREADS>


+ 0
- 43
csrc/kernels.hip View File

@@ -1946,49 +1946,6 @@ __global__ void kInt8VectorQuant(T * __restrict__ A, int8_t* out, float* rowStat
}
}

// Computes the absolute maximum of each row of A and stores it in rowStats[row].
//
// One block handles one row (the row index is blockIdx.x). Threads walk the row
// with a stride of THREADS, keep a thread-local absmax, and the block then
// reduces those values with hipcub::BlockReduce.
//
// When SPARSE_DECOMP is non-zero, values whose magnitude is >= threshold are
// left out of the row's absmax.
template<typename T, int THREADS, int SPARSE_DECOMP>
__launch_bounds__(1024, BNB_MAX_THREADS_PER_SM / 1024)
__global__ void kgetRowStats(T * __restrict__ A, float *rowStats, float threshold, int rows, int cols) {
  using BlockReduceT = hipcub::BlockReduce<float, THREADS>;

  // One block per row.
  // Threads load column values in a striped arrangement.
  // e.g. t0 reads row[0], row[0+nthreads], ..
  // and t1 reads row[1], row[1+nthreads], ..
  // Each thread will determine its local absmax.
  // We then do a blockwise reduction to determine the row's absmax.

  __shared__ typename BlockReduceT::TempStorage temp_storage;

  const int row_id = blockIdx.x;
  // Widen before multiplying: `row_id * cols` evaluated in 32-bit int overflows
  // once the matrix holds 2^31 or more elements.
  const T* __restrict__ row_data = A + static_cast<size_t>(row_id) * cols;

  // Threads will read the row values in a striped access pattern and find a local absmax.
  // NOTE(review): -FLT_MIN is a tiny negative value, so a row whose values are all
  // filtered out reports -FLT_MIN rather than 0 -- confirm this is intended.
  float row_local_absmax = -FLT_MIN;
  for (int i = threadIdx.x; i < cols; i += THREADS) {
    const float absval = fabsf(row_data[i]);

    // For sparse decomposition, values outside of the threshold are not to be
    // included when calculating the row's absmax.
    if constexpr (SPARSE_DECOMP) {
      row_local_absmax = fmaxf(row_local_absmax, absval < threshold ? absval : row_local_absmax);
    } else {
      row_local_absmax = fmaxf(row_local_absmax, absval);
    }
  }

  // Reduce thread-local absmax across the block.
  // TODO: Consider algorithm BLOCK_REDUCE_RAKING_COMMUTATIVE_ONLY
  const float row_absmax = BlockReduceT(temp_storage).Reduce(row_local_absmax, hipcub::Max(), cols);
  if (threadIdx.x == 0) {
    // Thread 0 publishes the block's absmax to the rowStats output buffer.
    rowStats[row_id] = row_absmax;
  }
}

template __global__ void kgetRowStats<half, 1024, 0>(half * __restrict__ A, float *rowStats, float threshold, int rows, int cols);
template __global__ void kgetRowStats<half, 1024, 1>(half * __restrict__ A, float *rowStats, float threshold, int rows, int cols);

template __global__ void kInt8VectorQuant<half, 1024, 0>(half * __restrict__ A, int8_t *out, float *rowStats, float threshold, int rows, int cols);
template __global__ void kInt8VectorQuant<half, 1024, 1>(half * __restrict__ A, int8_t *out, float *rowStats, float threshold, int rows, int cols);



+ 0
- 7
csrc/kernels_hip.cuh View File

@@ -111,16 +111,9 @@ __global__ void kdequant_mm_int32_fp16(
half* __restrict__ const bias, const int numRows, const int numCols, const int n
);

template <typename T, int THREADS, int SPARSE_DECOMP>
__global__ void kgetRowStats(T* __restrict__ A, float* rowStats, float threshold, int rows, int cols);
template <typename T, int THREADS, int SPARSE_DECOMP>
__global__ void kInt8VectorQuant(T* __restrict__ A, int8_t* out, float* rowStats, float threshold, int rows, int cols);

template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int TRANSPOSE, int FORMAT>
__global__ void kTransformRowToFormat(
char* __restrict__ const A, char* out, int rows, int cols, int tiledCols, int outRows, int outCols
);

template <typename T, int BITS, int THREADS>
__global__ void gemm_device(int M, int N, int K, T* __restrict__ const A, T* B, T* out, int lda, int ldb, int ldc);
template <typename T, int THREADS>


+ 0
- 67
csrc/ops.cu View File

@@ -292,61 +292,6 @@ void strided_gemmex(

// Rounds v up to the nearest multiple of d (for non-negative v and positive d).
int roundoff(int v, int d) {
    const int chunks = (v + d - 1) / d;
    return chunks * d;
}

// Maps the layout selected by the ORDER template argument to the matching
// cuBLASLt matrix order. ROW and any value without a dedicated mapping yield
// CUBLASLT_ORDER_ROW.
template <int ORDER> cublasLtOrder_t get_order() {
    if (ORDER == COL)
        return CUBLASLT_ORDER_COL;
    if (ORDER == COL32)
        return CUBLASLT_ORDER_COL32;
    if (ORDER == COL_TURING)
        return CUBLASLT_ORDER_COL4_4R2_8C;
    if (ORDER == COL_AMPERE)
        return CUBLASLT_ORDER_COL32_2R_4R4;

    // ROW, plus the fallback for unrecognised values.
    return CUBLASLT_ORDER_ROW;
}

template cublasLtOrder_t get_order<ROW>();
template cublasLtOrder_t get_order<COL>();
template cublasLtOrder_t get_order<COL32>();
template cublasLtOrder_t get_order<COL_TURING>();
template cublasLtOrder_t get_order<COL_AMPERE>();

// Returns the leading dimension for a dim1 x dim2 matrix stored in the layout
// selected by the ORDER template argument. Unrecognised layouts yield 0.
template <int ORDER> int get_leading_dim(int dim1, int dim2) {
    if (ORDER == ROW)
        return dim2;
    if (ORDER == COL)
        return dim1;
    if (ORDER == COL32)
        return dim1 * 32; // 32*row tiles
    if (ORDER == COL_TURING)
        return 32 * roundoff(dim1, 8);
    if (ORDER == COL_AMPERE)
        return 32 * roundoff(dim1, 32); // 32*32 tiles
    return 0;
}

template <int DTYPE_OUT, int SCALE_ROWS>
int igemmlt(
cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t* A, const int8_t* B, void* C, float* row_scale,
@@ -449,14 +394,6 @@ void int8VectorQuant(
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}

// Launches kgetRowStats on `stream` with one 1024-thread block per row, writing
// each row's absmax of A into rowStats.
//
// A threshold of exactly 0 selects the variant without sparse decomposition;
// any other threshold selects the variant that leaves out values whose
// magnitude is >= threshold.
void getRowStats(half* A, float* rowStats, float threshold, int rows, int cols, cudaStream_t stream) {
    // A grid with zero blocks is an invalid launch configuration; with no rows
    // there is nothing to compute, so return before launching.
    if (rows <= 0)
        return;

    if (threshold == 0.0)
        kgetRowStats<half, 1024, 0><<<rows, 1024, 0, stream>>>(A, rowStats, threshold, rows, cols);
    else
        kgetRowStats<half, 1024, 1><<<rows, 1024, 0, stream>>>(A, rowStats, threshold, rows, cols);
    CUDA_CHECK_RETURN(cudaPeekAtLastError());
}

void spmm_coo(
cusparseHandle_t handle, int* A_rowidx, int* A_colidx, half* A_vals, int A_nnz, int A_rows, int A_cols, int B_cols,
int ldb, half* B, int ldc, half* C, bool transposed_B
@@ -730,7 +667,3 @@ MAKE_optimizerStatic8bitBlockwise(float, ADEMAMIX);

template void percentileClipping(float* g, float* gnorm_vec, int step, const int n);
template void percentileClipping(half* g, float* gnorm_vec, int step, const int n);

template int get_leading_dim<ROW>(int dim1, int dim2);
template int get_leading_dim<COL>(int dim1, int dim2);
template int get_leading_dim<COL32>(int dim1, int dim2);

+ 0
- 9
csrc/ops.cuh View File

@@ -69,14 +69,6 @@ typedef enum Optimizer_t {
ADEMAMIX = 6
} Optimizer_t;

// Matrix layout tags. They are used as the ORDER template argument of
// get_order<>() and get_leading_dim<>() in ops.cu.
typedef enum Transform_t {
    ROW = 0,        // get_order: CUBLASLT_ORDER_ROW
    COL = 1,        // get_order: CUBLASLT_ORDER_COL
    COL32 = 2,      // get_order: CUBLASLT_ORDER_COL32
    COL_TURING = 3, // get_order: CUBLASLT_ORDER_COL4_4R2_8C
    COL_AMPERE = 4, // get_order: CUBLASLT_ORDER_COL32_2R_4R4
} Transform_t;

typedef enum DataType_t {
General8bit = 0,
FP4 = 1,
@@ -177,7 +169,6 @@ void cutlass_igemm(
void dequant_mm_int32_fp16(
int* A, float* rowStats, float* colStats, half* out, half* bias, int numRows, int numCols, cudaStream_t stream
);
void getRowStats(half* A, float* rowStats, float threshold, int rows, int cols, cudaStream_t stream);
void int8VectorQuant(
half* __restrict__ A, int8_t* out, float* rowStats, float threshold, int rows, int cols, cudaStream_t stream
);


+ 0
- 81
csrc/ops.hip View File

@@ -326,75 +326,6 @@ int roundoff(int v, int d) {
return (v + d - 1) / d * d;
}

#ifdef NO_HIPBLASLT
#else
// Maps the layout selected by the ORDER template argument to a hipBLASLt matrix
// order. The tiled layouts (COL32, COL_TURING, COL_AMPERE) are all mapped to
// plain column-major here; ROW and unrecognised values yield HIPBLASLT_ORDER_ROW.
template<int ORDER> hipblasLtOrder_t get_order()
{
    const bool column_like =
        ORDER == COL ||
        ORDER == COL32 ||      // instead of HIPBLASLT_ORDER_COL32
        ORDER == COL_TURING || // instead of HIPBLASLT_ORDER_COL4_4R2_8C
        ORDER == COL_AMPERE;   // instead of HIPBLASLT_ORDER_COL32_2R_4R4

    if (column_like)
        return HIPBLASLT_ORDER_COL;

    return HIPBLASLT_ORDER_ROW;
}

template hipblasLtOrder_t get_order<ROW>();
template hipblasLtOrder_t get_order<COL>();
template hipblasLtOrder_t get_order<COL32>();
//template hipblasLtOrder_t get_order<COL_TURING>();
//template hipblasLtOrder_t get_order<COL_AMPERE>();
#endif

// Returns the leading dimension for a dim1 x dim2 matrix stored in the layout
// selected by the ORDER template argument: dim2 for ROW, dim1 for everything
// else. The tiled-layout formulas (COL32: dim1*32, COL_TURING:
// 32*roundoff(dim1, 8), COL_AMPERE: 32*roundoff(dim1, 32)) are not applied in
// this build; those layouts are treated like COL.
template<int ORDER> int get_leading_dim(int dim1, int dim2)
{
    if (ORDER == ROW)
        return dim2;

    // COL and every other layout.
    return dim1;
}

static std::string hipError_to_string(const hipError_t ret)
{
switch(ret)
@@ -603,14 +534,6 @@ void int8VectorQuant(half * __restrict__ A, int8_t *out, float *rowStats, float
CUDA_CHECK_RETURN(hipPeekAtLastError());
}

// Launches kgetRowStats on `stream` with one 1024-thread block per row, writing
// each row's absmax of A into rowStats.
//
// A threshold of exactly 0 selects the variant without sparse decomposition;
// any other threshold selects the variant that leaves out values whose
// magnitude is >= threshold.
void getRowStats(half *A, float *rowStats, float threshold, int rows, int cols, hipStream_t stream) {
  // A grid with zero blocks is an invalid launch configuration; with no rows
  // there is nothing to compute, so return before launching.
  if (rows <= 0)
    return;

  if (threshold == 0.0)
    kgetRowStats<half, 1024, 0><<<rows, 1024, 0, stream>>>(A, rowStats, threshold, rows, cols);
  else
    kgetRowStats<half, 1024, 1><<<rows, 1024, 0, stream>>>(A, rowStats, threshold, rows, cols);
  CUDA_CHECK_RETURN(hipPeekAtLastError());
}

void spmm_coo(hipsparseHandle_t handle, int *A_rowidx, int *A_colidx, half *A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, int ldb, half *B, int ldc, half* C, bool transposed_B)
{

@@ -835,7 +758,3 @@ MAKE_optimizerStatic8bitBlockwise(float, ADEMAMIX);

template void percentileClipping(float * g, float *gnorm_vec, int step, const int n);
template void percentileClipping(half * g, float *gnorm_vec, int step, const int n);

template int get_leading_dim<ROW>(int dim1, int dim2);
template int get_leading_dim<COL>(int dim1, int dim2);
template int get_leading_dim<COL32>(int dim1, int dim2);

+ 0
- 9
csrc/ops_hip.cuh View File

@@ -71,14 +71,6 @@ typedef enum Optimizer_t {
ADEMAMIX = 6,
} Optimizer_t;

// Matrix layout tags. They are used as the ORDER template argument of
// get_order<>() and get_leading_dim<>() in ops.hip, where every layout other
// than ROW is handled as plain column-major.
typedef enum Transform_t {
    ROW = 0,
    COL = 1,
    COL32 = 2,
    COL_TURING = 3,
    COL_AMPERE = 4,
} Transform_t;

typedef enum DataType_t {
General8bit = 0,
FP4 = 1,
@@ -179,7 +171,6 @@ void cutlass_igemm(
void dequant_mm_int32_fp16(
int* A, float* rowStats, float* colStats, half* out, half* bias, int numRows, int numCols, hipStream_t stream
);
void getRowStats(half* A, float* rowStats, float threshold, int rows, int cols, hipStream_t stream);
void int8VectorQuant(
half* __restrict__ A, int8_t* out, float* rowStats, float threshold, int rows, int cols, hipStream_t stream
);


+ 0
- 4
csrc/pythonInterface.cpp View File

@@ -641,10 +641,6 @@ void cdequant_mm_int32_fp16(
dequant_mm_int32_fp16(A, rowStats, colStats, out, bias, numRows, numCols, stream);
}

// C-linkage style entry point called from Python as `lib.cget_row_stats`.
// Forwards every argument unchanged to getRowStats, which writes the per-row
// absmax of A into rowStats on the given stream.
void cget_row_stats(half* A, float* rowStats, float threshold, int rows, int cols, cudaStream_t stream) {
    getRowStats(A, rowStats, threshold, rows, cols, stream);
}

void cint8_vector_quant(
half* __restrict__ A, int8_t* out, float* rowStats, float threshold, int rows, int cols, cudaStream_t stream
) {


+ 0
- 39
tests/test_functional.py View File

@@ -704,45 +704,6 @@ class TestLLMInt8Functional:
n = C5.numel()
assert_all_approx_close(C1, C4, atol=0.015, rtol=0.1, count=int(0.01 * n))

@pytest.mark.parametrize("dim1", [1 * 1024], ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [1 * 1024], ids=id_formatter("dim2"))
@pytest.mark.parametrize("dims", (2,), ids=id_formatter("dims"))
@pytest.mark.parametrize("threshold", [0.0, 3.0], ids=id_formatter("decomp"))
@pytest.mark.deprecated
def test_colrow_absmax(self, dim1, dim2, dims, threshold):
    """Check F.get_colrow_absmax against a plain torch reference.

    With a positive threshold, values whose magnitude is >= threshold are zeroed in the
    reference before the row/column maxima are taken. With threshold 0.0 the third return
    value must be None.
    """
    assert dims == 2

    for _ in range(k):
        A = torch.randn(dim1, dim2, device="cuda").half()

        if threshold > 0.0:
            # Reference: drop outliers first, then take the absmax per row / per column.
            A_truncated = A.clone()
            A_truncated[torch.abs(A_truncated) >= threshold] = 0.0
            row_stats1_trunc, _ = torch.abs(A_truncated.float()).max(1)
            col_stats1_trunc, _ = torch.abs(A_truncated.float()).max(0)

            row_stats2, col_stats2, _ = F.get_colrow_absmax(A, threshold=threshold)

            torch.testing.assert_close(col_stats1_trunc, col_stats2)
            torch.testing.assert_close(row_stats1_trunc, row_stats2)
        else:
            row_stats1, _ = torch.abs(A.float()).max(1)
            col_stats1, _ = torch.abs(A.float()).max(0)

            row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax(A, threshold=0.0)

            assert nnz_block_ptr2 is None
            torch.testing.assert_close(col_stats1, col_stats2)
            torch.testing.assert_close(row_stats1, row_stats2)

@pytest.mark.parametrize("dim1", [2048, 4096], ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [512, 1024], ids=id_formatter("dim2"))
@pytest.mark.deprecated


Loading…
Cancel
Save
Baidu
map