2 Commits

Author SHA1 Message Date
  Matthew Douglas 4d19869189
CUDA/ROCm: Remove dead code (#1827) 1 week ago
  jiqing-feng 3c71007afc
Hf kernel (#1814) 1 week ago
12 changed files with 48 additions and 1109 deletions
Split View
  1. +47
    -32
      bitsandbytes/backends/cpu/ops.py
  2. +0
    -497
      csrc/kernels.cu
  3. +0
    -7
      csrc/kernels.cuh
  4. +0
    -475
      csrc/kernels.hip
  5. +0
    -7
      csrc/kernels_hip.cuh
  6. +0
    -27
      csrc/ops.cu
  7. +0
    -7
      csrc/ops.cuh
  8. +0
    -23
      csrc/ops.hip
  9. +0
    -7
      csrc/ops_hip.cuh
  10. +0
    -25
      csrc/pythonInterface.cpp
  11. +1
    -1
      pyproject.toml
  12. +0
    -1
      tests/test_ops.py

+ 47
- 32
bitsandbytes/backends/cpu/ops.py View File

@@ -219,6 +219,17 @@ if not isinstance(lib, ErrorHandlerMockBNBNativeLibrary):
return out

if has_avx512bf16():
gemm_4bit_forward_kernel = None
try:
from kernels import get_kernel

gemm_4bit_forward_kernel = get_kernel("kernels-community/quantization_bitsandbytes").gemm_4bit_forward
except Exception as exc: # pragma: no cover - best effort fallback
gemm_4bit_forward_kernel = None
logger.warning(
"Failed to load CPU gemm_4bit_forward from kernels-community: %s. Please make sure you already `pip install kernels` and the kernels >= 0.11.1",
exc,
)

@register_kernel("bitsandbytes::gemv_4bit", "cpu")
def _(
@@ -239,38 +250,42 @@ if not isinstance(lib, ErrorHandlerMockBNBNativeLibrary):
final_out_shape = (*A.shape[:-1], shapeB[0])
A = A.reshape(-1, A.shape[-1])
out_shape = (*A.shape[:-1], shapeB[0])
out = torch.empty(out_shape, dtype=A.dtype, device=A.device)
M = A.shape[0]
N = shapeB[0]
K = A.shape[1]
x_strideM = A.stride(0)
out_strideM = out.stride(0)
if quant_type == "fp4":
lib.gemv_4bit_inference_cpu_fp4_bf16(
ct.c_int64(M),
ct.c_int64(N),
ct.c_int64(K),
get_ptr(A),
get_ptr(B),
get_ptr(absmax),
get_ptr(out),
ct.c_int64(blocksize),
ct.c_int64(x_strideM),
ct.c_int64(out_strideM),
)
elif quant_type == "nf4":
lib.gemv_4bit_inference_cpu_nf4_bf16(
ct.c_int64(M),
ct.c_int64(N),
ct.c_int64(K),
get_ptr(A),
get_ptr(B),
get_ptr(absmax),
get_ptr(out),
ct.c_int64(blocksize),
ct.c_int64(x_strideM),
ct.c_int64(out_strideM),
)
if gemm_4bit_forward_kernel is not None:
quant_type_num = 1 if quant_type == "fp4" else 0
out = gemm_4bit_forward_kernel(A, B, absmax, blocksize, quant_type_num)
else:
out = torch.empty(out_shape, dtype=A.dtype, device=A.device)
M = A.shape[0]
N = shapeB[0]
K = A.shape[1]
x_strideM = A.stride(0)
out_strideM = out.stride(0)
if quant_type == "fp4":
lib.gemv_4bit_inference_cpu_fp4_bf16(
ct.c_int64(M),
ct.c_int64(N),
ct.c_int64(K),
get_ptr(A),
get_ptr(B),
get_ptr(absmax),
get_ptr(out),
ct.c_int64(blocksize),
ct.c_int64(x_strideM),
ct.c_int64(out_strideM),
)
elif quant_type == "nf4":
lib.gemv_4bit_inference_cpu_nf4_bf16(
ct.c_int64(M),
ct.c_int64(N),
ct.c_int64(K),
get_ptr(A),
get_ptr(B),
get_ptr(absmax),
get_ptr(out),
ct.c_int64(blocksize),
ct.c_int64(x_strideM),
ct.c_int64(out_strideM),
)

if dtype != torch.bfloat16:
out = out.to(dtype)


+ 0
- 497
csrc/kernels.cu View File

@@ -2025,429 +2025,6 @@ __global__ void kspmm_coo_very_sparse_naive(
}
}

#define WARPS 3

template <typename T, int BITS, int THREADS>
__global__ void gemm_device(int M, int N, int K, T* __restrict__ const A, T* B, T* out, int lda, int ldb, int ldc) {

#if __CUDA_ARCH__ >= 750
using namespace nvcuda;
int col_offset = blockIdx.x * 32;
const int warp_id = threadIdx.x / 32;
const int half_warp_id = threadIdx.x / 16;
const int half_warp_lane = threadIdx.x % 16;
const int batch_size_warps = (WARPS - 1) * 2;
const int val_per_iter = blockDim.x - 32;

T local_A[4];
T local_B[128];

const int a_tile_offset = 16;
const int b_tile_offset = (16 * 32 + 16);

__shared__ T smem_A[8 * 16 + (2 * 16 * (batch_size_warps - 1))];
__shared__ T smem_B[2 * batch_size_warps * 16 * 32 + (2 * 16 * (batch_size_warps - 1))];
//__shared__ T smem_C[8*32];

wmma::fragment<wmma::matrix_a, 8, 32, 16, half, wmma::row_major> a_frag;
wmma::fragment<wmma::matrix_b, 8, 32, 16, half, wmma::col_major> b_frag;
wmma::fragment<wmma::accumulator, 8, 32, 16, half> c_frag;
wmma::fill_fragment(c_frag, 0.0f);

int ticktock = 0;
int idx = 0 + threadIdx.x;
int loaded_values = 0;
// prefetch
if (idx < K && warp_id < (WARPS - 1)) {
if (loaded_values == 0) {
local_A[0] = A[idx];
local_A[1] = A[idx + (1 * val_per_iter)];
local_A[2] = A[idx + (2 * val_per_iter)];
local_A[3] = A[idx + (3 * val_per_iter)];

#pragma unroll 32
for (int col = 0; col < 32; col++) {
local_B[col] = B[(col_offset + col) * ldb + idx];
local_B[col + 32] = B[(col_offset + col) * ldb + idx + (1 * val_per_iter)];
local_B[col + 64] = B[(col_offset + col) * ldb + idx + (2 * val_per_iter)];
local_B[col + 96] = B[(col_offset + col) * ldb + idx + (3 * val_per_iter)];
}
loaded_values = 3;
} else {

if (loaded_values == 3) {
local_A[0] = local_A[1];
#pragma unroll 32
for (int col = 0; col < 32; col++)
local_B[col] = local_B[col + (32)];
} else if (loaded_values == 2) {
local_A[0] = local_A[2];
#pragma unroll 32
for (int col = 0; col < 32; col++)
local_B[col] = local_B[col + (64)];
} else {
local_A[0] = local_A[3];
#pragma unroll 32
for (int col = 0; col < 32; col++)
local_B[col] = local_B[col + (96)];
}
loaded_values--;
}

smem_A[half_warp_lane + (((batch_size_warps * ticktock) + half_warp_id) * a_tile_offset)] = local_A[0];

#pragma unroll 32
for (int col = 0; col < 32; col++)
smem_B[half_warp_lane + (((batch_size_warps * ticktock) + half_warp_id) * b_tile_offset) + (col * 16)] =
local_B[col];
} else if (warp_id < (WARPS - 1)) {
local_A[0] = T(0.0);
smem_A[half_warp_lane + (((batch_size_warps * ticktock) + half_warp_id) * a_tile_offset)] = 0.0f;

#pragma unroll 32
for (int col = 0; col < 32; col++)
local_B[col] = 0.0f;

#pragma unroll 32
for (int col = 0; col < 32; col++)
smem_B[half_warp_lane + (((batch_size_warps * ticktock) + half_warp_id) * b_tile_offset) + (col * 16)] =
0.0f;
}
ticktock = ticktock == 0 ? 1 : 0;

// for(int base_idx = blockDim.x-32; base_idx < K; base_idx+=blockDim.x-32)
for (int base_idx = blockDim.x - 32; base_idx < K; base_idx += blockDim.x - 32) {
idx = base_idx + threadIdx.x;

__syncthreads();
if (idx < K && warp_id < (WARPS - 1)) {
// local_A[0] = A[idx];

// #pragma unroll 32
// for(int col = 0; col < 32; col++)
// local_B[col] = B[(col_offset+col)*ldb+idx];
if (loaded_values == 0) {
local_A[0] = A[idx];
local_A[1] = A[idx + (1 * val_per_iter)];
local_A[2] = A[idx + (2 * val_per_iter)];
local_A[3] = A[idx + (3 * val_per_iter)];

#pragma unroll 32
for (int col = 0; col < 32; col++) {
local_B[col] = B[(col_offset + col) * ldb + idx];
local_B[col + 32] = B[(col_offset + col) * ldb + idx + (1 * val_per_iter)];
local_B[col + 64] = B[(col_offset + col) * ldb + idx + (2 * val_per_iter)];
local_B[col + 96] = B[(col_offset + col) * ldb + idx + (3 * val_per_iter)];
}
loaded_values = 3;

} else {

if (loaded_values == 3) {
local_A[0] = local_A[1];
#pragma unroll 32
for (int col = 0; col < 32; col++)
local_B[col] = local_B[col + (32)];
} else if (loaded_values == 2) {
local_A[0] = local_A[2];
#pragma unroll 32
for (int col = 0; col < 32; col++)
local_B[col] = local_B[col + (64)];
} else {
local_A[0] = local_A[3];
#pragma unroll 32
for (int col = 0; col < 32; col++)
local_B[col] = local_B[col + (96)];
}
loaded_values--;
}

smem_A[half_warp_lane + (((batch_size_warps * ticktock) + half_warp_id) * a_tile_offset)] = local_A[0];

#pragma unroll 32
for (int col = 0; col < 32; col++)
smem_B[half_warp_lane + (((batch_size_warps * ticktock) + half_warp_id) * b_tile_offset) + (col * 16)] =
local_B[col];
} else if (warp_id < (WARPS - 1)) {
local_A[0] = T(0.0);
smem_A[half_warp_lane + (((batch_size_warps * ticktock) + half_warp_id) * a_tile_offset)] = 0.0f;

#pragma unroll 32
for (int col = 0; col < 32; col++)
local_B[col] = 0.0f;

#pragma unroll 32
for (int col = 0; col < 32; col++)
smem_B[half_warp_lane + (((batch_size_warps * ticktock) + half_warp_id) * b_tile_offset) + (col * 16)] =
0.0f;
}
ticktock = ticktock == 0 ? 1 : 0;

if (warp_id == (WARPS - 1))
for (int k = 0; k < batch_size_warps; k++) {
wmma::load_matrix_sync(
a_frag, &(smem_A[(ticktock * batch_size_warps + k) * a_tile_offset]), 16
); // 111 mu
wmma::load_matrix_sync(
b_frag, &(smem_B[(ticktock * batch_size_warps + k) * b_tile_offset]), 16
); // 35 mu
wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
}
}

__syncthreads();
if (warp_id != (WARPS - 1)) {
return;
}
// only warp_id == (WARPS-1) from here
int warp_lane = threadIdx.x % 32;

ticktock = ticktock == 0 ? 1 : 0;
for (int k = 0; k < batch_size_warps; k++) {
wmma::load_matrix_sync(a_frag, &(smem_A[(ticktock * batch_size_warps + k) * a_tile_offset]), 16); // 111 mu
wmma::load_matrix_sync(b_frag, &(smem_B[(ticktock * batch_size_warps + k) * b_tile_offset]), 16); // 35 mu
wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
}

// 129 mu
if (warp_id == (WARPS - 1))
wmma::store_matrix_sync(&(smem_A[0]), c_frag, 32, wmma::mem_row_major);

if (col_offset + warp_lane < M)
out[col_offset + warp_lane] = smem_A[warp_lane];
#endif
}

template <typename T> __device__ void printnonzero(T* A, int num_values, const char* strval) {
for (int i = 0; i < num_values; i++)
if ((float)A[i] != 0.0)
printf("%s %i %f\n", strval, i, (float)A[i]);
}

template <typename T, int THREADS>
__global__ void kgemm_4bit_inference(
int M, int N, int K, T* __restrict__ const A, unsigned char* B, float* absmax, T* out, int lda, int ldb, int ldc,
int blocksize
) {

//// element-wise kernel
//// 1. Load batch x k into registers
//// 2. Load k x k into registers
//// 3. dequantize and store in second pair of k x k
//// 4. matmul
//// 5. sum with cub
//// 6. store outputs
//// TC kernel
//// use k warps per thread block
//// 1. threadblock use read-only cache to read in register tile for A into shared memory
//// 2. each warp loops over shared memory tiles of A of size 8x16 and loads them into fragments
//// 3. each warp reads a segment of values 16x32 from B
//// 4. do dequantization from register of B into second pair of registers
//// 5. store (4) into fragment
//// 6. matmul aggregate into fragment C
//// 7. aggregate files of C into shared memory block C
//// 8. sum (7)
//// 9. write outputs to matmul output matrix
#if __CUDA_ARCH__ >= 750
using namespace nvcuda;
int col_offset = blockIdx.x * 32;
const int warp_id = threadIdx.x / 32;
const int warp_idx = threadIdx.x % 32;
const int half_warp_id = threadIdx.x / 16;
const int half_warp_lane = threadIdx.x % 16;
const int batch_size_warps = (WARPS - 1) * 2;

T quant_map[16];

#pragma unroll 16
for (int i = 0; i < 16; i++)
quant_map[i] = nf4_dequantization_lut[i];
//__shared__ T quant_map[16*160];

T local_A[2];
T local_B[64];
unsigned char local_B_4bit[32];

const int a_tile_offset = 16;
const int b_tile_offset = (16 * 32 + 16);

__shared__ T smem_A[8 * 16 + (16 * (batch_size_warps - 1))];
__shared__ T smem_B[2 * batch_size_warps * 16 * 32 + (2 * 16 * (batch_size_warps - 1))];
__shared__ T smem_C[8 * 32];

wmma::fragment<wmma::matrix_a, 8, 32, 16, half, wmma::row_major> a_frag;
wmma::fragment<wmma::matrix_b, 8, 32, 16, half, wmma::col_major> b_frag;
wmma::fragment<wmma::accumulator, 8, 32, 16, half> c_frag;
wmma::fill_fragment(c_frag, 0.0f);

for (int i = threadIdx.x; i < (8 * 32); i += blockDim.x)
smem_C[i] = 0.0f;

__syncthreads();

int ticktock = 0;
int idx = 0 + threadIdx.x;
int loaded_values = 0;
// prefetch
if (idx < K && warp_id < (WARPS - 1)) {
if (loaded_values == 0) {
local_A[0] = A[idx];
local_A[1] = A[idx + blockDim.x - 32];

#pragma unroll 32
for (int col = 0; col < 32; col++)
local_B_4bit[col] = B[(col_offset + col) * ldb + idx];

loaded_values = 1;
} else {
local_A[0] = local_A[1];
loaded_values--;

#pragma unroll 64
for (int col = 0; col < 64; col += 2) {
// local_B[col] = dhDequantizeNF4(local_B_4bit[col/2] >> 4)*T(1.0f);
// local_B[col+1] = dhDequantizeNF4(local_B_4bit[col/2] & 0x0F)*T(1.0f);
// local_B[col] = d2DequantizeFP4(local_B_4bit[col/2] >> 4)*(float)(17.0);
// local_B[col+1] = d2DequantizeFP4(local_B_4bit[col/2] & 0x0F)*(float)(17.0);
// local_B[col] = 127*(local_B_4bit[col/2] >> 4)*(float)(17.0);
// local_B[col+1] = 127*(local_B_4bit[col/2] & 0x0F)*(float)(17.0);

// local_B[col] = quant_map[(local_B_4bit[col/2] >> 4)]*T(17.0);
// local_B[col+1] = quant_map[(local_B_4bit[col/2] & 0x0F)]*T(17.0);
local_B[col] = quant_map[160 * (local_B_4bit[col / 2] >> 4) + warp_idx] * T(17.0);
local_B[col + 1] = quant_map[160 * (local_B_4bit[col / 2] & 0x0F) + warp_idx] * T(17.0);
}
}

smem_A[half_warp_lane + (((batch_size_warps * ticktock) + half_warp_id) * a_tile_offset)] = local_A[0];

#pragma unroll 32
for (int col = 0; col < 32; col++)
smem_B[half_warp_lane + (((batch_size_warps * ticktock) + half_warp_id) * b_tile_offset) + (col * 16)] =
local_B[col];
} else if (warp_id < (WARPS - 1)) {
local_A[0] = T(0.0);
smem_A[half_warp_lane + (((batch_size_warps * ticktock) + half_warp_id) * a_tile_offset)] = 0.0f;

#pragma unroll 32
for (int col = 0; col < 32; col++)
local_B[col] = 0.0f;

#pragma unroll 32
for (int col = 0; col < 32; col++)
smem_B[half_warp_lane + (((batch_size_warps * ticktock) + half_warp_id) * b_tile_offset) + (col * 16)] =
0.0f;
}
ticktock = ticktock == 0 ? 1 : 0;
// if(threadIdx.x == 0)
// printf("aa %i %i\n", idx, loaded_values);

// for(int base_idx = blockDim.x-32; base_idx < K; base_idx+=blockDim.x-32)
for (int base_idx = blockDim.x - 32; base_idx < K; base_idx += blockDim.x - 32) {
idx = base_idx + threadIdx.x;
// if(threadIdx.x == 0)
// printf("%i %i\n", idx, loaded_values);

//__syncthreads();
if (idx < K && warp_id < (WARPS - 1)) {
if (loaded_values == 0) {
local_A[0] = A[idx];
local_A[1] = A[idx + blockDim.x - 32];

#pragma unroll 32
for (int col = 0; col < 32; col++) {
local_B_4bit[col] = B[(col_offset + col) * ldb + idx];
local_B_4bit[col + 16] = B[(col_offset + col) * ldb + idx];
}

loaded_values = 1;
} else {
local_A[0] = local_A[1];
loaded_values--;

int absidx = (idx + col_offset) / blocksize;
half local_absmax = __ldg(&(absmax[absidx]));

#pragma unroll 64
for (int col = 0; col < 64; col += 2) {
// local_B[col] = dhDequantizeNF4(local_B_4bit[col/2] >> 4)*T(absidx);
// local_B[col+1] = dhDequantizeNF4(local_B_4bit[col/2] & 0x0F)*T(absidx);
// local_B[col] = T(127)*T(local_B_4bit[col/2] >> 4)*T(absidx);
// local_B[col+1] = T(127)*T(local_B_4bit[col/2] & 0x0F)*T(absidx);

// local_B[col] = quant_map[160*(local_B_4bit[col/2] >> 4)+warp_idx]*T(local_absmax);
// local_B[col+1] = quant_map[160*(local_B_4bit[col/2] & 0x0F)+warp_idx]*T(local_absmax);
local_B[col] = quant_map[(local_B_4bit[col / 2] >> 4)] * T(absidx);
local_B[col + 1] = quant_map[(local_B_4bit[col / 2] & 0x0F)] * T(absidx);
}
// printnonzero<T>(local_B, 128, "");
}

smem_A[half_warp_lane + (((batch_size_warps * ticktock) + half_warp_id) * a_tile_offset)] = local_A[0];

#pragma unroll 32
for (int col = 0; col < 32; col++)
smem_B[half_warp_lane + (((batch_size_warps * ticktock) + half_warp_id) * b_tile_offset) + (col * 16)] =
local_B[col];
} else if (warp_id < (WARPS - 1)) {
local_A[0] = T(0.0);
smem_A[half_warp_lane + (((batch_size_warps * ticktock) + half_warp_id) * a_tile_offset)] = 0.0f;

#pragma unroll 32
for (int col = 0; col < 32; col++)
local_B[col] = 0.0f;

#pragma unroll 32
for (int col = 0; col < 32; col++)
smem_B[half_warp_lane + (((batch_size_warps * ticktock) + half_warp_id) * b_tile_offset) + (col * 16)] =
0.0f;
}
ticktock = ticktock == 0 ? 1 : 0;

if (warp_id == (WARPS - 1))
for (int k = 0; k < batch_size_warps; k++) {
wmma::load_matrix_sync(
a_frag, &(smem_A[(ticktock * batch_size_warps + k) * a_tile_offset]), 16
); // 111 mu
wmma::load_matrix_sync(
b_frag, &(smem_B[(ticktock * batch_size_warps + k) * b_tile_offset]), 16
); // 35 mu
wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
}
}

__syncthreads();
// if(threadIdx.x == 0)
//{
// printnonzero<T>(smem_A, 8*16 + (2*16*(batch_size_warps-1)), "A: ");
// printnonzero<T>(smem_B, 2*batch_size_warps*16*32 + (2*16*(batch_size_warps-1)), "B: ");
// }
if (warp_id != (WARPS - 1)) {
return;
}
// only warp_id == (WARPS-1) from here
int warp_lane = threadIdx.x % 32;

ticktock = ticktock == 0 ? 1 : 0;
for (int k = 0; k < batch_size_warps; k++) {
// if(warp_lane == 0)
// printf("%i %i %i %i\n", (ticktock*batch_size_warps + k)*a_tile_offset, k, ticktock, threadIdx.x);
wmma::load_matrix_sync(a_frag, &(smem_A[(ticktock * batch_size_warps + k) * a_tile_offset]), 16); // 111 mu
wmma::load_matrix_sync(b_frag, &(smem_B[(ticktock * batch_size_warps + k) * b_tile_offset]), 16); // 35 mu
wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
}

// 129 mu
if (warp_id == (WARPS - 1))
wmma::store_matrix_sync(&(smem_C[0]), c_frag, 32, wmma::mem_row_major);

// printnonzero<T>(smem_C, 32, "");

if (col_offset + warp_lane < M)
out[col_offset + warp_lane] = smem_C[warp_lane];
#endif
}

#define num_values_4bit 32

template <typename T, int THREADS, int BITS>
@@ -2592,77 +2169,6 @@ template __global__ void kfunc<unsigned char, FILL>(unsigned char* A, unsigned c
template __global__ void kfunc<float, ARANGE>(float* A, float* B, float value, long n);
template __global__ void kfunc<float, _MUL>(float* A, float* B, float value, long n);

// these are not used and make no sense, but the compiler needs them
// template __global__ void gemm_device<float, 16, 128>(int M, int N, int K, float * __restrict__ const A, float* B,
// float * out, int lda, int ldb, int ldc);
template __global__ void gemm_device<half, 32, 256>(
int M, int N, int K, half* __restrict__ const A, half* B, half* out, int lda, int ldb, int ldc
);
template __global__ void gemm_device<half, 32, 192>(
int M, int N, int K, half* __restrict__ const A, half* B, half* out, int lda, int ldb, int ldc
);
template __global__ void gemm_device<half, 32, 160>(
int M, int N, int K, half* __restrict__ const A, half* B, half* out, int lda, int ldb, int ldc
);
template __global__ void gemm_device<half, 32, 128>(
int M, int N, int K, half* __restrict__ const A, half* B, half* out, int lda, int ldb, int ldc
);
// template __global__ void gemm_device<float, 16, 32>(int M, int N, int K, float * __restrict__ const A, float* B,
// float * out, int lda, int ldb, int ldc);
template __global__ void gemm_device<half, 32, 32>(
int M, int N, int K, half* __restrict__ const A, half* B, half* out, int lda, int ldb, int ldc
);
template __global__ void gemm_device<half, 32, 64>(
int M, int N, int K, half* __restrict__ const A, half* B, half* out, int lda, int ldb, int ldc
);
template __global__ void gemm_device<half, 32, 96>(
int M, int N, int K, half* __restrict__ const A, half* B, half* out, int lda, int ldb, int ldc
);
// these are not used and make no sense, but the compiler needs them

// template __global__ void gemm_device<float, 32, 128>(int M, int N, int K, float * __restrict__ const A, float* B,
// float * out, int lda, int ldb, int ldc);
template __global__ void gemm_device<half, 16, 256>(
int M, int N, int K, half* __restrict__ const A, half* B, half* out, int lda, int ldb, int ldc
);
template __global__ void gemm_device<half, 16, 192>(
int M, int N, int K, half* __restrict__ const A, half* B, half* out, int lda, int ldb, int ldc
);
template __global__ void gemm_device<half, 16, 160>(
int M, int N, int K, half* __restrict__ const A, half* B, half* out, int lda, int ldb, int ldc
);
template __global__ void gemm_device<half, 16, 128>(
int M, int N, int K, half* __restrict__ const A, half* B, half* out, int lda, int ldb, int ldc
);
// template __global__ void gemm_device<float, 32, 32>(int M, int N, int K, float * __restrict__ const A, float* B,
// float * out, int lda, int ldb, int ldc);
template __global__ void gemm_device<half, 16, 32>(
int M, int N, int K, half* __restrict__ const A, half* B, half* out, int lda, int ldb, int ldc
);
template __global__ void gemm_device<half, 16, 64>(
int M, int N, int K, half* __restrict__ const A, half* B, half* out, int lda, int ldb, int ldc
);
template __global__ void gemm_device<half, 16, 96>(
int M, int N, int K, half* __restrict__ const A, half* B, half* out, int lda, int ldb, int ldc
);

template __global__ void kgemm_4bit_inference<half, 96>(
int M, int N, int K, half* __restrict__ const A, unsigned char* B, float* absmax, half* out, int lda, int ldb,
int ldc, int blocksize
);
template __global__ void kgemm_4bit_inference<half, 128>(
int M, int N, int K, half* __restrict__ const A, unsigned char* B, float* absmax, half* out, int lda, int ldb,
int ldc, int blocksize
);
template __global__ void kgemm_4bit_inference<half, 160>(
int M, int N, int K, half* __restrict__ const A, unsigned char* B, float* absmax, half* out, int lda, int ldb,
int ldc, int blocksize
);
template __global__ void kgemm_4bit_inference<half, 256>(
int M, int N, int K, half* __restrict__ const A, unsigned char* B, float* absmax, half* out, int lda, int ldb,
int ldc, int blocksize
);

template __global__ void kgemm_4bit_inference_naive<half, 128, 16>(
int M, int N, int K, half* __restrict__ const A, unsigned char* B, float* absmax, const float* datatype, half* out,
int lda, int ldb, int ldc, int blocksize
@@ -2996,6 +2502,3 @@ MAKE_OptimizerStatic8bit1StateBlockwise(LION, __nv_bfloat16, 256, 1)
MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, float, 256, 1)
MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, half, 256, 1)
MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, __nv_bfloat16, 256, 1)

template __device__ void printnonzero<float>(float* A, int num_values, const char* strval);
template __device__ void printnonzero<half>(half* A, int num_values, const char* strval);

+ 0
- 7
csrc/kernels.cuh View File

@@ -112,13 +112,6 @@ __global__ void kdequant_mm_int32_fp16(
template <typename T, int THREADS, int SPARSE_DECOMP>
__global__ void kInt8VectorQuant(T* __restrict__ A, int8_t* out, float* rowStats, float threshold, int rows, int cols);

template <typename T, int BITS, int THREADS>
__global__ void gemm_device(int M, int N, int K, T* __restrict__ const A, T* B, T* out, int lda, int ldb, int ldc);
template <typename T, int THREADS>
__global__ void kgemm_4bit_inference(
int M, int N, int K, T* __restrict__ const A, unsigned char* B, float* absmax, T* out, int lda, int ldb, int ldc,
int blocksize
);
template <typename T, int THREADS, int BITS>
__global__ void kgemm_4bit_inference_naive(
int M, int N, int K, T* __restrict__ const A, unsigned char* B, float* absmax, const float* datatype, T* out,


+ 0
- 475
csrc/kernels.hip View File

@@ -2162,451 +2162,6 @@ __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *o
}
}

#define WARPS 3
template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc)
{

#if __CUDA_ARCH__ >= 750
using namespace nvcuda;
int col_offset = blockIdx.x *32;
const int warp_id = threadIdx.x / 32;
const int half_warp_id = threadIdx.x / 16;
const int half_warp_lane = threadIdx.x % 16;
const int batch_size_warps = (WARPS-1)*2;
const int val_per_iter = blockDim.x-32;

T local_A[4];
T local_B[128];

const int a_tile_offset = 16;
const int b_tile_offset = (16*32 + 16);

__shared__ T smem_A[8*16 + (2*16*(batch_size_warps-1))];
__shared__ T smem_B[2*batch_size_warps*16*32 + (2*16*(batch_size_warps-1))];
//__shared__ T smem_C[8*32];

rocwmma::fragment<rocwmma::matrix_a, 8, 32, 16, half, rocwmma::row_major> a_frag;
rocwmma::fragment<rocwmma::matrix_b, 8, 32, 16, half, rocwmma::col_major> b_frag;
rocwmma::fragment<rocwmma::accumulator, 8, 32, 16, half> c_frag;
rocwmma::fill_fragment(c_frag, 0.0f);

int ticktock = 0;
int idx = 0 + threadIdx.x;
int loaded_values = 0;
// prefetch
if(idx < K && warp_id < (WARPS-1))
{
if(loaded_values == 0)
{
local_A[0] = A[idx];
local_A[1] = A[idx+(1*val_per_iter)];
local_A[2] = A[idx+(2*val_per_iter)];
local_A[3] = A[idx+(3*val_per_iter)];

#pragma unroll 32
for(int col = 0; col < 32; col++)
{
local_B[col] = B[(col_offset+col)*ldb+idx];
local_B[col+32] = B[(col_offset+col)*ldb+idx+(1*val_per_iter)];
local_B[col+64] = B[(col_offset+col)*ldb+idx+(2*val_per_iter)];
local_B[col+96] = B[(col_offset+col)*ldb+idx+(3*val_per_iter)];
}
loaded_values = 3;
}
else
{

if(loaded_values == 3)
{
local_A[0] = local_A[1];
#pragma unroll 32
for(int col = 0; col < 32; col++)
local_B[col] = local_B[col+(32)];
}
else if(loaded_values == 2)
{
local_A[0] = local_A[2];
#pragma unroll 32
for(int col = 0; col < 32; col++)
local_B[col] = local_B[col+(64)];
}
else
{
local_A[0] = local_A[3];
#pragma unroll 32
for(int col = 0; col < 32; col++)
local_B[col] = local_B[col+(96)];
}
loaded_values--;
}

smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = local_A[0];

#pragma unroll 32
for(int col = 0; col < 32; col++)
smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = local_B[col];
}
else if(warp_id < (WARPS-1))
{
local_A[0] = T(0.0);
smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = 0.0f;

#pragma unroll 32
for(int col = 0; col < 32; col++)
local_B[col] = 0.0f;

#pragma unroll 32
for(int col = 0; col < 32; col++)
smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = 0.0f;
}
ticktock = ticktock == 0 ? 1 : 0;

//for(int base_idx = blockDim.x-32; base_idx < K; base_idx+=blockDim.x-32)
for(int base_idx = blockDim.x-32; base_idx < K; base_idx+=blockDim.x-32)
{
idx = base_idx + threadIdx.x;

__syncthreads();
if(idx < K && warp_id < (WARPS-1))
{
//local_A[0] = A[idx];

//#pragma unroll 32
//for(int col = 0; col < 32; col++)
// local_B[col] = B[(col_offset+col)*ldb+idx];
if(loaded_values == 0)
{
local_A[0] = A[idx];
local_A[1] = A[idx+(1*val_per_iter)];
local_A[2] = A[idx+(2*val_per_iter)];
local_A[3] = A[idx+(3*val_per_iter)];

#pragma unroll 32
for(int col = 0; col < 32; col++)
{
local_B[col] = B[(col_offset+col)*ldb+idx];
local_B[col+32] = B[(col_offset+col)*ldb+idx+(1*val_per_iter)];
local_B[col+64] = B[(col_offset+col)*ldb+idx+(2*val_per_iter)];
local_B[col+96] = B[(col_offset+col)*ldb+idx+(3*val_per_iter)];
}
loaded_values = 3;

}
else
{

if(loaded_values == 3)
{
local_A[0] = local_A[1];
#pragma unroll 32
for(int col = 0; col < 32; col++)
local_B[col] = local_B[col+(32)];
}
else if(loaded_values == 2)
{
local_A[0] = local_A[2];
#pragma unroll 32
for(int col = 0; col < 32; col++)
local_B[col] = local_B[col+(64)];
}
else
{
local_A[0] = local_A[3];
#pragma unroll 32
for(int col = 0; col < 32; col++)
local_B[col] = local_B[col+(96)];
}
loaded_values--;
}

smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = local_A[0];

#pragma unroll 32
for(int col = 0; col < 32; col++)
smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = local_B[col];
}
else if(warp_id < (WARPS-1))
{
local_A[0] = T(0.0);
smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = 0.0f;

#pragma unroll 32
for(int col = 0; col < 32; col++)
local_B[col] = 0.0f;

#pragma unroll 32
for(int col = 0; col < 32; col++)
smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = 0.0f;
}
ticktock = ticktock == 0 ? 1 : 0;

if(warp_id == (WARPS-1))
for(int k = 0; k < batch_size_warps; k++)
{
rocwmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); // 111 mu
rocwmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu
rocwmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
}
}

__syncthreads();
if(warp_id != (WARPS-1)){ return; }
// only warp_id == (WARPS-1) from here
int warp_lane = threadIdx.x % 32;

ticktock = ticktock == 0 ? 1 : 0;
for(int k = 0; k < batch_size_warps; k++)
{
rocwmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); // 111 mu
rocwmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu
rocwmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
}

// 129 mu
if(warp_id == (WARPS-1))
rocwmma::store_matrix_sync(&(smem_A[0]), c_frag, 32, rocwmma::mem_row_major);

if(col_offset + warp_lane < M)
out[col_offset + warp_lane] = smem_A[warp_lane];
#endif
}


template <typename T> __device__ void printnonzero(T *A, int num_values, const char * strval)
{
for(int i = 0; i < num_values; i++)
if((float)A[i] != 0.0)
printf("%s %i %f\n", strval, i, (float)A[i]);
}

template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize)
{

//// element-wise kernel
//// 1. Load batch x k into registers + //// 2. Load k x k into registers
//// 3. dequantize and store in second pair of k x k + //// 4. matmul
//// 5. sum with cub
//// 6. store outputs
//// TC kernel
//// use k warps per thread block
//// 1. threadblock use read-only cache to read in register tile for A into shared memory
//// 2. each warp loops over shared memory tiles of A of size 8x16 and loads them into fragments
//// 3. each warp reads a segment of values 16x32 from B
//// 4. do dequantization from register of B into second pair of registers
//// 5. store (4) into fragment
//// 6. matmul aggregate into fragment C
//// 7. aggregate files of C into shared memory block C
//// 8. sum (7)
//// 9. write outputs to matmul output matrix
#if __CUDA_ARCH__ >= 750
using namespace nvcuda;
int col_offset = blockIdx.x *32;
const int warp_id = threadIdx.x / 32;
const int warp_idx = threadIdx.x % 32;
const int half_warp_id = threadIdx.x / 16;
const int half_warp_lane = threadIdx.x % 16;
const int batch_size_warps = (WARPS-1)*2;

T quant_map[16];

#pragma unroll 16
for(int i = 0; i < 16; i++)
quant_map[i] = nf4_dequantization_lut[i];
//__shared__ T quant_map[16*160];

T local_A[2];
T local_B[64];
unsigned char local_B_4bit[32];


const int a_tile_offset = 16;
const int b_tile_offset = (16*32 + 16);

__shared__ T smem_A[8*16 + (16*(batch_size_warps-1))];
__shared__ T smem_B[2*batch_size_warps*16*32 + (2*16*(batch_size_warps-1))];
__shared__ T smem_C[8*32];

rocwmma::fragment<rocwmma::matrix_a, 8, 32, 16, half, rocwmma::row_major> a_frag;
rocwmma::fragment<rocwmma::matrix_b, 8, 32, 16, half, rocwmma::col_major> b_frag;
rocwmma::fragment<rocwmma::accumulator, 8, 32, 16, half> c_frag;
rocwmma::fill_fragment(c_frag, 0.0f);

for(int i = threadIdx.x; i < (8*32); i+=blockDim.x)
smem_C[i] = 0.0f;

__syncthreads();

int ticktock = 0;
int idx = 0 + threadIdx.x;
int loaded_values = 0;
// prefetch
if(idx < K && warp_id < (WARPS-1))
{
if(loaded_values == 0)
{
local_A[0] = A[idx];
local_A[1] = A[idx+blockDim.x-32];

#pragma unroll 32
for(int col = 0; col < 32; col++)
local_B_4bit[col] = B[(col_offset+col)*ldb+idx];

loaded_values = 1;
}
else
{
local_A[0] = local_A[1];
loaded_values--;

#pragma unroll 64
for(int col = 0; col < 64; col+=2)
{
//local_B[col] = dhDequantizeNF4(local_B_4bit[col/2] >> 4)*T(1.0f);
//local_B[col+1] = dhDequantizeNF4(local_B_4bit[col/2] & 0x0F)*T(1.0f);
//local_B[col] = d2DequantizeFP4(local_B_4bit[col/2] >> 4)*(float)(17.0);
//local_B[col+1] = d2DequantizeFP4(local_B_4bit[col/2] & 0x0F)*(float)(17.0);
//local_B[col] = 127*(local_B_4bit[col/2] >> 4)*(float)(17.0);
//local_B[col+1] = 127*(local_B_4bit[col/2] & 0x0F)*(float)(17.0);

//local_B[col] = quant_map[(local_B_4bit[col/2] >> 4)]*T(17.0);
//local_B[col+1] = quant_map[(local_B_4bit[col/2] & 0x0F)]*T(17.0);
local_B[col] = quant_map[160*(local_B_4bit[col/2] >> 4)+warp_idx]*T(17.0);
local_B[col+1] = quant_map[160*(local_B_4bit[col/2] & 0x0F)+warp_idx]*T(17.0);
}
}

smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = local_A[0];

#pragma unroll 32
for(int col = 0; col < 32; col++)
smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = local_B[col];
}
else if(warp_id < (WARPS-1))
{
local_A[0] = T(0.0);
smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = 0.0f;

#pragma unroll 32
for(int col = 0; col < 32; col++)
local_B[col] = 0.0f;

#pragma unroll 32
for(int col = 0; col < 32; col++)
smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = 0.0f;
}
ticktock = ticktock == 0 ? 1 : 0;
//if(threadIdx.x == 0)
//printf("aa %i %i\n", idx, loaded_values);

//for(int base_idx = blockDim.x-32; base_idx < K; base_idx+=blockDim.x-32)
for(int base_idx = blockDim.x-32; base_idx < K; base_idx+=blockDim.x-32)
{
idx = base_idx + threadIdx.x;
//if(threadIdx.x == 0)
//printf("%i %i\n", idx, loaded_values);

//__syncthreads();
if(idx < K && warp_id < (WARPS-1))
{
if(loaded_values == 0)
{
local_A[0] = A[idx];
local_A[1] = A[idx+blockDim.x-32];

#pragma unroll 32
for(int col = 0; col < 32; col++)
{
local_B_4bit[col] = B[(col_offset+col)*ldb+idx];
local_B_4bit[col+16] = B[(col_offset+col)*ldb+idx];
}

loaded_values = 1;
}
else
{
local_A[0] = local_A[1];
loaded_values--;

int absidx = (idx + col_offset)/blocksize;
half local_absmax = __ldg(&(absmax[absidx]));

#pragma unroll 64
for(int col = 0; col < 64; col+=2)
{
//local_B[col] = dhDequantizeNF4(local_B_4bit[col/2] >> 4)*T(absidx);
//local_B[col+1] = dhDequantizeNF4(local_B_4bit[col/2] & 0x0F)*T(absidx);
//local_B[col] = T(127)*T(local_B_4bit[col/2] >> 4)*T(absidx);
//local_B[col+1] = T(127)*T(local_B_4bit[col/2] & 0x0F)*T(absidx);

//local_B[col] = quant_map[160*(local_B_4bit[col/2] >> 4)+warp_idx]*T(local_absmax);
//local_B[col+1] = quant_map[160*(local_B_4bit[col/2] & 0x0F)+warp_idx]*T(local_absmax);
local_B[col] = quant_map[(local_B_4bit[col/2] >> 4)]*T(absidx);
local_B[col+1] = quant_map[(local_B_4bit[col/2] & 0x0F)]*T(absidx);
}
//printnonzero<T>(local_B, 128, "");
}

smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = local_A[0];

#pragma unroll 32
for(int col = 0; col < 32; col++)
smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = local_B[col];
}
else if(warp_id < (WARPS-1))
{
local_A[0] = T(0.0);
smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = 0.0f;

#pragma unroll 32
for(int col = 0; col < 32; col++)
local_B[col] = 0.0f;

#pragma unroll 32
for(int col = 0; col < 32; col++)
smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = 0.0f;
}
ticktock = ticktock == 0 ? 1 : 0;

if(warp_id == (WARPS-1))
for(int k = 0; k < batch_size_warps; k++)
{
rocwmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); // 111 mu
rocwmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu
rocwmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
}
}

__syncthreads();
//if(threadIdx.x == 0)
//{
// printnonzero<T>(smem_A, 8*16 + (2*16*(batch_size_warps-1)), "A: ");
// printnonzero<T>(smem_B, 2*batch_size_warps*16*32 + (2*16*(batch_size_warps-1)), "B: ");
//}
if(warp_id != (WARPS-1)){ return; }
// only warp_id == (WARPS-1) from here
int warp_lane = threadIdx.x % 32;

ticktock = ticktock == 0 ? 1 : 0;
for(int k = 0; k < batch_size_warps; k++)
{
//if(warp_lane == 0)
//printf("%i %i %i %i\n", (ticktock*batch_size_warps + k)*a_tile_offset, k, ticktock, threadIdx.x);
rocwmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); // 111 mu
rocwmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu
rocwmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
}

// 129 mu
if(warp_id == (WARPS-1))
rocwmma::store_matrix_sync(&(smem_C[0]), c_frag, 32, rocwmma::mem_row_major);

//printnonzero<T>(smem_C, 32, "");

if(col_offset + warp_lane < M)
out[col_offset + warp_lane] = smem_C[warp_lane];
#endif
}

// No of 4bit values processed by each thread
#define num_values_4bit 32
template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inference_naive(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, T * out, int lda, int ldb, int ldc, int blocksize)
@@ -2764,33 +2319,6 @@ template __global__ void kfunc<unsigned char, FILL>(unsigned char *A, unsigned c
template __global__ void kfunc<float, ARANGE>(float *A, float *B, float value, long n);
template __global__ void kfunc<float, _MUL>(float *A, float *B, float value, long n);

// these are not used and make no sense, but the compiler needs them
//template __global__ void gemm_device<float, 16, 128>(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc);
template __global__ void gemm_device<half, 32, 256>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);
template __global__ void gemm_device<half, 32, 192>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);
template __global__ void gemm_device<half, 32, 160>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);
template __global__ void gemm_device<half, 32, 128>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);
//template __global__ void gemm_device<float, 16, 32>(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc);
template __global__ void gemm_device<half, 32, 32>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);
template __global__ void gemm_device<half, 32, 64>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);
template __global__ void gemm_device<half, 32, 96>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);
// these are not used and make no sense, but the compiler needs them

//template __global__ void gemm_device<float, 32, 128>(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc);
template __global__ void gemm_device<half, 16, 256>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);
template __global__ void gemm_device<half, 16, 192>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);
template __global__ void gemm_device<half, 16, 160>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);
template __global__ void gemm_device<half, 16, 128>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);
//template __global__ void gemm_device<float, 32, 32>(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc);
template __global__ void gemm_device<half, 16, 32>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);
template __global__ void gemm_device<half, 16, 64>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);
template __global__ void gemm_device<half, 16, 96>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);

template __global__ void kgemm_4bit_inference<half, 96>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
template __global__ void kgemm_4bit_inference<half, 128>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
template __global__ void kgemm_4bit_inference<half, 160>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
template __global__ void kgemm_4bit_inference<half, 256>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);

template __global__ void kgemm_4bit_inference_naive<half, 128, 16>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, half * out, int lda, int ldb, int ldc, int blocksize);
template __global__ void kgemm_4bit_inference_naive<hip_bfloat16, 128, 16>(int M, int N, int K, hip_bfloat16 * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, hip_bfloat16 * out, int lda, int ldb, int ldc, int blocksize);
template __global__ void kgemm_4bit_inference_naive<float, 128, 32>(int M, int N, int K, float * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, float * out, int lda, int ldb, int ldc, int blocksize);
@@ -3086,6 +2614,3 @@ MAKE_OptimizerStatic8bit1StateBlockwise(LION, hip_bfloat16, 256, 1)
MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, float, 256, 1)
MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, half, 256, 1)
MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, hip_bfloat16, 256, 1)

template __device__ void printnonzero<float>(float *A, int num_values, const char*strval);
template __device__ void printnonzero<half>(half *A, int num_values, const char*strval);

+ 0
- 7
csrc/kernels_hip.cuh View File

@@ -114,13 +114,6 @@ __global__ void kdequant_mm_int32_fp16(
template <typename T, int THREADS, int SPARSE_DECOMP>
__global__ void kInt8VectorQuant(T* __restrict__ A, int8_t* out, float* rowStats, float threshold, int rows, int cols);

template <typename T, int BITS, int THREADS>
__global__ void gemm_device(int M, int N, int K, T* __restrict__ const A, T* B, T* out, int lda, int ldb, int ldc);
template <typename T, int THREADS>
__global__ void kgemm_4bit_inference(
int M, int N, int K, T* __restrict__ const A, unsigned char* B, float* absmax, T* out, int lda, int ldb, int ldc,
int blocksize
);
template <typename T, int THREADS, int BITS>
__global__ void kgemm_4bit_inference_naive(
int M, int N, int K, T* __restrict__ const A, unsigned char* B, float* absmax, const float* datatype, T* out,


+ 0
- 27
csrc/ops.cu View File

@@ -451,26 +451,6 @@ void spmm_coo_very_sparse_naive(
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}

template <typename T> void gemm_host(int m, int n, int k, T* A, T* B, T* out, int lda, int ldb, int ldc, int bits) {

int num_blocks = (m + 31) / 32;

if (bits == 32)
gemm_device<T, 32, 32><<<num_blocks, 32, 0, 0>>>(m, n, k, A, B, out, lda, ldb, ldc);
if (bits == 16)
gemm_device<T, 16, 160><<<num_blocks, 160, 0, 0>>>(m, n, k, A, B, out, lda, ldb, ldc);
}

template <typename T>
void gemm_4bit_inference(
int m, int n, int k, T* A, unsigned char* B, float* absmax, T* out, int lda, int ldb, int ldc, int blocksize
) {

int num_blocks = (m + 31) / 32;

kgemm_4bit_inference<T, 96><<<num_blocks, 96, 0, 0>>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize);
}

template <typename T, int BITS>
void gemm_4bit_inference_naive(
int m, int n, int k, T* A, unsigned char* B, float* absmax, float* datatype, T* out, int lda, int ldb, int ldc,
@@ -501,9 +481,6 @@ template void func<unsigned char, FILL>(unsigned char* A, unsigned char* B, unsi
template void func<float, ARANGE>(float* A, float* B, float value, long n);
template void func<float, _MUL>(float* A, float* B, float value, long n);

template void gemm_4bit_inference<half>(
int m, int n, int k, half* A, unsigned char* B, float* absmax, half* out, int lda, int ldb, int ldc, int blocksize
);
template void gemm_4bit_inference_naive<half, 16>(
int m, int n, int k, half* A, unsigned char* B, float* absmax, float* datatype, half* out, int lda, int ldb,
int ldc, int blocksize, cudaStream_t stream
@@ -517,10 +494,6 @@ template void gemm_4bit_inference_naive<float, 32>(
int ldc, int blocksize, cudaStream_t stream
);

// template void gemm_host<float>(int m, int n, int k, float * A, float* B, float * out, int lda, int ldb, int ldc,
// int bits);
template void gemm_host<half>(int m, int n, int k, half* A, half* B, half* out, int lda, int ldb, int ldc, int bits);

template void spmm_coo_very_sparse_naive<half, 16>(
int* max_count, int* max_idx, int* offset_rowidx, int* rowidx, int* colidx, half* values, half* B, half* out,
float* dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB


+ 0
- 7
csrc/ops.cuh View File

@@ -179,13 +179,6 @@ void spmm_coo_very_sparse_naive(
float* dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB
);

void matmul4bite(half* A, unsigned char* B, half* out, int lda, int ldb, int rowsA, int colsA, int colsB);

template <typename T> void gemm_host(int m, int n, int k, T* A, T* B, T* out, int lda, int ldb, int ldc, int bits);
template <typename T>
void gemm_4bit_inference(
int m, int n, int k, T* A, unsigned char* B, float* absmax, T* out, int lda, int ldb, int ldc, int blocksize
);
template <typename T, int BITS>
void gemm_4bit_inference_naive(
int m, int n, int k, T* A, unsigned char* B, float* absmax, float* datatype, T* out, int lda, int ldb, int ldc,


+ 0
- 23
csrc/ops.hip View File

@@ -589,25 +589,6 @@ template <typename T, int BITS> void spmm_coo_very_sparse_naive(int *max_count,
CUDA_CHECK_RETURN(hipPeekAtLastError());
}

template <typename T> void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc, int bits)
{

int num_blocks = (m+31)/32;

if(bits == 32)
hipLaunchKernelGGL(( gemm_device<T, 32, 32>), dim3(num_blocks), dim3(32), 0, 0, m, n, k, A, B, out, lda, ldb, ldc);
if(bits == 16)
hipLaunchKernelGGL(( gemm_device<T, 16, 160>), dim3(num_blocks), dim3(160), 0, 0, m, n, k, A, B, out, lda, ldb, ldc);
}

template <typename T> void gemm_4bit_inference(int m, int n, int k, T * A, unsigned char* B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize)
{

int num_blocks = (m+31)/32;

hipLaunchKernelGGL(( kgemm_4bit_inference<T, 96>), dim3(num_blocks), dim3(96), 0, 0, m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize);
}

template <typename T, int BITS> void gemm_4bit_inference_naive(int m, int n, int k, T * A, unsigned char* B, float *absmax, float *datatype, T * out, int lda, int ldb, int ldc, int blocksize, hipStream_t stream)
{

@@ -641,14 +622,10 @@ template void func<unsigned char, FILL>(unsigned char *A, unsigned char *B, unsi
template void func<float, ARANGE>(float *A, float *B, float value, long n);
template void func<float, _MUL>(float *A, float *B, float value, long n);

template void gemm_4bit_inference<half>(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
template void gemm_4bit_inference_naive<half, 16>(int m, int n, int k, half * A, unsigned char* B, float *absmax, float *datatype, half * out, int lda, int ldb, int ldc, int blocksize, hipStream_t stream);
template void gemm_4bit_inference_naive<hip_bfloat16, 16>(int m, int n, int k, hip_bfloat16 * A, unsigned char* B, float *absmax, float *datatype, hip_bfloat16 * out, int lda, int ldb, int ldc, int blocksize, hipStream_t stream);
template void gemm_4bit_inference_naive<float, 32>(int m, int n, int k, float * A, unsigned char* B, float *absmax, float *datatype, float * out, int lda, int ldb, int ldc, int blocksize, hipStream_t stream);

//template void gemm_host<float>(int m, int n, int k, float * A, float* B, float * out, int lda, int ldb, int ldc, int bits);
template void gemm_host<half>(int m, int n, int k, half * A, half* B, half * out, int lda, int ldb, int ldc, int bits);

template void spmm_coo_very_sparse_naive<half, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB);
template void spmm_coo_very_sparse_naive<signed char, 8>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB);



+ 0
- 7
csrc/ops_hip.cuh View File

@@ -181,13 +181,6 @@ void spmm_coo_very_sparse_naive(
float* dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB
);

void matmul4bite(half* A, unsigned char* B, half* out, int lda, int ldb, int rowsA, int colsA, int colsB);

template <typename T> void gemm_host(int m, int n, int k, T* A, T* B, T* out, int lda, int ldb, int ldc, int bits);
template <typename T>
void gemm_4bit_inference(
int m, int n, int k, T* A, unsigned char* B, float* absmax, T* out, int lda, int ldb, int ldc, int blocksize
);
template <typename T, int BITS>
void gemm_4bit_inference_naive(
int m, int n, int k, T* A, unsigned char* B, float* absmax, float* datatype, T* out, int lda, int ldb, int ldc,


+ 0
- 25
csrc/pythonInterface.cpp View File

@@ -42,18 +42,6 @@

#if BUILD_CUDA || BUILD_HIP

// void gemm_host_fp32(int M, int N, int K, float * A, float* B, float * out, int lda, int ldb, int ldc)
//{ gemm_host<float>(M, N, K, A, B, out, lda, ldb, ldc, 32); }
void gemm_host_fp16(int M, int N, int K, half* A, half* B, half* out, int lda, int ldb, int ldc) {
gemm_host<half>(M, N, K, A, B, out, lda, ldb, ldc, 16);
}

void gemm_4bit_inference(
int m, int n, int k, half* A, unsigned char* B, float* absmax, half* out, int lda, int ldb, int ldc, int blocksize
) {
gemm_4bit_inference<half>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize);
}

void gemm_4bit_inference_naive_fp16(
int m, int n, int k, half* A, unsigned char* B, float* absmax, float* datatype, half* out, int lda, int ldb,
int ldc, int blocksize, cudaStream_t stream
@@ -677,19 +665,6 @@ void cspmm_coo_very_sparse_naive_int8(
);
}

// void cgemm_host_fp32(int M, int N, int K, float * A, float* B, float * out, int lda, int ldb, int ldc)
//{ gemm_host_fp32(M, N, K, A, B, out, lda, ldb, ldc); }

void cgemm_host_fp16(int M, int N, int K, half* A, half* B, half* out, int lda, int ldb, int ldc) {
gemm_host_fp16(M, N, K, A, B, out, lda, ldb, ldc);
}

void cgemm_4bit_inference(
int m, int n, int k, half* A, unsigned char* B, float* absmax, half* out, int lda, int ldb, int ldc, int blocksize
) {
gemm_4bit_inference(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize);
}

void* cget_managed_ptr(size_t bytes) {
void* ptr;
CUDA_CHECK_RETURN(cudaMallocManaged(&ptr, bytes, cudaMemAttachHost));


+ 1
- 1
pyproject.toml View File

@@ -45,7 +45,7 @@ classifiers = [
dependencies = [
"torch>=2.3,<3",
"numpy>=1.17",
"packaging>=20.9"
"packaging>=20.9",
]

[project.urls]


+ 0
- 1
tests/test_ops.py View File

@@ -237,7 +237,6 @@ class Test4bitBlockwiseQuantOps:
quant_type=quant_type,
)
B_q, state = bitsandbytes.functional._convert_weight_packed_for_cpu(B_q, state)
B_q = B_q.t()
absmax = state.absmax
out = torch.ops.bitsandbytes.gemv_4bit.default(A, B_q, B.shape, absmax, code, blocksize)



Loading…
Cancel
Save
Baidu
map