|
|
|
@@ -2025,429 +2025,6 @@ __global__ void kspmm_coo_very_sparse_naive( |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
// Warps per thread block for the tensor-core GEMM kernels below:
// warps 0..WARPS-2 stage A/B tiles into shared memory, the last warp
// (warp_id == WARPS-1) consumes them with wmma MMA operations.
#define WARPS 3
|
|
|
|
|
|
|
// Tensor-core GEMM computing one 32-column tile of `out` per thread block.
// Pipeline: warps 0..WARPS-2 stream A and B from global memory into
// double-buffered shared tiles (selected by `ticktock`), while the last warp
// accumulates 8x32x16 half-precision wmma fragments.
// Requires SM75+; on older architectures the body compiles to an empty kernel.
// NOTE(review): template parameters BITS and THREADS are unused in this body.
// NOTE(review): only out[col_offset + warp_lane] (one row of up to 32 values,
// guarded by M) is written — this appears specialized for a 1xK * KxN
// product; confirm against callers.
template <typename T, int BITS, int THREADS>
__global__ void gemm_device(int M, int N, int K, T* __restrict__ const A, T* B, T* out, int lda, int ldb, int ldc) {

#if __CUDA_ARCH__ >= 750

    using namespace nvcuda;

    int col_offset = blockIdx.x * 32; // first output column owned by this block

    const int warp_id = threadIdx.x / 32;

    const int half_warp_id = threadIdx.x / 16;

    const int half_warp_lane = threadIdx.x % 16;

    // Number of 16-row half-warp tiles staged per buffer by the loader warps.
    const int batch_size_warps = (WARPS - 1) * 2;

    // K-stride between consecutive refills: the consumer warp (last 32
    // threads) does not load, hence blockDim.x - 32 loader threads.
    const int val_per_iter = blockDim.x - 32;

    T local_A[4];   // 4 prefetched A values per loader thread

    T local_B[128]; // 4 prefetched groups of 32 B values per loader thread

    const int a_tile_offset = 16;            // elements per staged A sub-tile
    const int b_tile_offset = (16 * 32 + 16); // 16x32 B tile plus 16-element pad

    // Double-buffered staging tiles; `ticktock` (0/1) selects the half being
    // written while the other half is consumed.
    __shared__ T smem_A[8 * 16 + (2 * 16 * (batch_size_warps - 1))];

    __shared__ T smem_B[2 * batch_size_warps * 16 * 32 + (2 * 16 * (batch_size_warps - 1))];

    //__shared__ T smem_C[8*32];

    wmma::fragment<wmma::matrix_a, 8, 32, 16, half, wmma::row_major> a_frag;

    wmma::fragment<wmma::matrix_b, 8, 32, 16, half, wmma::col_major> b_frag;

    wmma::fragment<wmma::accumulator, 8, 32, 16, half> c_frag;

    wmma::fill_fragment(c_frag, 0.0f);

    int ticktock = 0;

    int idx = 0 + threadIdx.x;

    // loaded_values counts down 3..0: one wide global load serves 4 pipeline
    // steps, with values shifted down from local_A[1..3]/local_B[32..127].
    int loaded_values = 0;

    // prefetch
    // NOTE(review): the guard only checks idx < K, but the wide load below
    // reads up to idx + 3*val_per_iter — potential out-of-bounds read of A/B
    // near the tail of K; confirm callers guarantee padded inputs.
    if (idx < K && warp_id < (WARPS - 1)) {

        if (loaded_values == 0) {

            local_A[0] = A[idx];

            local_A[1] = A[idx + (1 * val_per_iter)];

            local_A[2] = A[idx + (2 * val_per_iter)];

            local_A[3] = A[idx + (3 * val_per_iter)];

#pragma unroll 32
            for (int col = 0; col < 32; col++) {

                local_B[col] = B[(col_offset + col) * ldb + idx];

                local_B[col + 32] = B[(col_offset + col) * ldb + idx + (1 * val_per_iter)];

                local_B[col + 64] = B[(col_offset + col) * ldb + idx + (2 * val_per_iter)];

                local_B[col + 96] = B[(col_offset + col) * ldb + idx + (3 * val_per_iter)];

            }

            loaded_values = 3;

        } else {

            // Shift the next already-loaded group into slot 0 instead of
            // touching global memory again.
            if (loaded_values == 3) {

                local_A[0] = local_A[1];

#pragma unroll 32
                for (int col = 0; col < 32; col++)

                    local_B[col] = local_B[col + (32)];

            } else if (loaded_values == 2) {

                local_A[0] = local_A[2];

#pragma unroll 32
                for (int col = 0; col < 32; col++)

                    local_B[col] = local_B[col + (64)];

            } else {

                local_A[0] = local_A[3];

#pragma unroll 32
                for (int col = 0; col < 32; col++)

                    local_B[col] = local_B[col + (96)];

            }

            loaded_values--;

        }

        // Stage this half-warp's A value and 32 B values into the active
        // (ticktock) shared-memory buffer.
        smem_A[half_warp_lane + (((batch_size_warps * ticktock) + half_warp_id) * a_tile_offset)] = local_A[0];

#pragma unroll 32
        for (int col = 0; col < 32; col++)

            smem_B[half_warp_lane + (((batch_size_warps * ticktock) + half_warp_id) * b_tile_offset) + (col * 16)] =

                local_B[col];

    } else if (warp_id < (WARPS - 1)) {

        // Loader threads past the end of K zero-fill their tile slots so the
        // consumer warp multiplies by zeros instead of stale data.
        local_A[0] = T(0.0);

        smem_A[half_warp_lane + (((batch_size_warps * ticktock) + half_warp_id) * a_tile_offset)] = 0.0f;

#pragma unroll 32
        for (int col = 0; col < 32; col++)

            local_B[col] = 0.0f;

#pragma unroll 32
        for (int col = 0; col < 32; col++)

            smem_B[half_warp_lane + (((batch_size_warps * ticktock) + half_warp_id) * b_tile_offset) + (col * 16)] =

                0.0f;

    }

    ticktock = ticktock == 0 ? 1 : 0; // flip buffers after the prefetch

    // for(int base_idx = blockDim.x-32; base_idx < K; base_idx+=blockDim.x-32)
    // Main loop: loaders fill buffer `ticktock` while the last warp consumes
    // the opposite buffer; __syncthreads() separates the two phases.
    for (int base_idx = blockDim.x - 32; base_idx < K; base_idx += blockDim.x - 32) {

        idx = base_idx + threadIdx.x;

        __syncthreads();

        if (idx < K && warp_id < (WARPS - 1)) {

            // local_A[0] = A[idx];

            // #pragma unroll 32

            // for(int col = 0; col < 32; col++)

            //     local_B[col] = B[(col_offset+col)*ldb+idx];

            if (loaded_values == 0) {

                local_A[0] = A[idx];

                local_A[1] = A[idx + (1 * val_per_iter)];

                local_A[2] = A[idx + (2 * val_per_iter)];

                local_A[3] = A[idx + (3 * val_per_iter)];

#pragma unroll 32
                for (int col = 0; col < 32; col++) {

                    local_B[col] = B[(col_offset + col) * ldb + idx];

                    local_B[col + 32] = B[(col_offset + col) * ldb + idx + (1 * val_per_iter)];

                    local_B[col + 64] = B[(col_offset + col) * ldb + idx + (2 * val_per_iter)];

                    local_B[col + 96] = B[(col_offset + col) * ldb + idx + (3 * val_per_iter)];

                }

                loaded_values = 3;

            } else {

                // Reuse the previously loaded groups (see prefetch above).
                if (loaded_values == 3) {

                    local_A[0] = local_A[1];

#pragma unroll 32
                    for (int col = 0; col < 32; col++)

                        local_B[col] = local_B[col + (32)];

                } else if (loaded_values == 2) {

                    local_A[0] = local_A[2];

#pragma unroll 32
                    for (int col = 0; col < 32; col++)

                        local_B[col] = local_B[col + (64)];

                } else {

                    local_A[0] = local_A[3];

#pragma unroll 32
                    for (int col = 0; col < 32; col++)

                        local_B[col] = local_B[col + (96)];

                }

                loaded_values--;

            }

            smem_A[half_warp_lane + (((batch_size_warps * ticktock) + half_warp_id) * a_tile_offset)] = local_A[0];

#pragma unroll 32
            for (int col = 0; col < 32; col++)

                smem_B[half_warp_lane + (((batch_size_warps * ticktock) + half_warp_id) * b_tile_offset) + (col * 16)] =

                    local_B[col];

        } else if (warp_id < (WARPS - 1)) {

            // Tail zero-fill, same as in the prefetch phase.
            local_A[0] = T(0.0);

            smem_A[half_warp_lane + (((batch_size_warps * ticktock) + half_warp_id) * a_tile_offset)] = 0.0f;

#pragma unroll 32
            for (int col = 0; col < 32; col++)

                local_B[col] = 0.0f;

#pragma unroll 32
            for (int col = 0; col < 32; col++)

                smem_B[half_warp_lane + (((batch_size_warps * ticktock) + half_warp_id) * b_tile_offset) + (col * 16)] =

                    0.0f;

        }

        ticktock = ticktock == 0 ? 1 : 0;

        // Consumer warp: multiply-accumulate all staged sub-tiles of the
        // buffer the loaders just finished (note ticktock was flipped above).
        if (warp_id == (WARPS - 1))

            for (int k = 0; k < batch_size_warps; k++) {

                wmma::load_matrix_sync(

                    a_frag, &(smem_A[(ticktock * batch_size_warps + k) * a_tile_offset]), 16

                ); // 111 mu

                wmma::load_matrix_sync(

                    b_frag, &(smem_B[(ticktock * batch_size_warps + k) * b_tile_offset]), 16

                ); // 35 mu

                wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);

            }

    }

    __syncthreads();

    if (warp_id != (WARPS - 1)) {

        return;

    }

    // only warp_id == (WARPS-1) from here

    int warp_lane = threadIdx.x % 32;

    // Drain the final buffer filled by the last loop iteration.
    ticktock = ticktock == 0 ? 1 : 0;

    for (int k = 0; k < batch_size_warps; k++) {

        wmma::load_matrix_sync(a_frag, &(smem_A[(ticktock * batch_size_warps + k) * a_tile_offset]), 16); // 111 mu

        wmma::load_matrix_sync(b_frag, &(smem_B[(ticktock * batch_size_warps + k) * b_tile_offset]), 16); // 35 mu

        wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);

    }

    // 129 mu

    // Spill the accumulator tile into shared memory (smem_A is reused as the
    // output staging area here), then write one guarded row to global memory.
    if (warp_id == (WARPS - 1))

        wmma::store_matrix_sync(&(smem_A[0]), c_frag, 32, wmma::mem_row_major);

    if (col_offset + warp_lane < M)

        out[col_offset + warp_lane] = smem_A[warp_lane];

#endif

}
|
|
|
|
|
|
|
// Debug helper: prints the index and value of every nonzero element of A,
// each line prefixed with strval. Device-side printf — debugging only.
template <typename T> __device__ void printnonzero(T* A, int num_values, const char* strval) {
    int i = 0;
    while (i < num_values) {
        float v = (float)A[i];
        if (v != 0.0)
            printf("%s %i %f\n", strval, i, v);
        ++i;
    }
}
|
|
|
|
|
|
|
// Tensor-core GEMM for 4-bit (NF4) quantized B: same loader/consumer pipeline
// as gemm_device, but loader warps additionally dequantize packed 4-bit B
// values (two values per byte, scaled by per-block absmax) before staging.
// Requires SM75+; empty kernel on older architectures.
// NOTE(review): this kernel contains several suspicious constructs flagged
// inline below (out-of-bounds quant_map/local_B_4bit indexing, a disabled
// __syncthreads, and an absmax scale that is never applied) — it looks like
// experimental/debug code; verify before relying on its output.
template <typename T, int THREADS>
__global__ void kgemm_4bit_inference(
    int M, int N, int K, T* __restrict__ const A, unsigned char* B, float* absmax, T* out, int lda, int ldb, int ldc,
    int blocksize
) {

    //// element-wise kernel
    //// 1. Load batch x k into registers
    //// 2. Load k x k into registers
    //// 3. dequantize and store in second pair of k x k
    //// 4. matmul
    //// 5. sum with cub
    //// 6. store outputs
    //// TC kernel
    //// use k warps per thread block
    //// 1. threadblock use read-only cache to read in register tile for A into shared memory
    //// 2. each warp loops over shared memory tiles of A of size 8x16 and loads them into fragments
    //// 3. each warp reads a segment of values 16x32 from B
    //// 4. do dequantization from register of B into second pair of registers
    //// 5. store (4) into fragment
    //// 6. matmul aggregate into fragment C
    //// 7. aggregate files of C into shared memory block C
    //// 8. sum (7)
    //// 9. write outputs to matmul output matrix

#if __CUDA_ARCH__ >= 750

    using namespace nvcuda;

    int col_offset = blockIdx.x * 32; // first output column owned by this block

    const int warp_id = threadIdx.x / 32;

    const int warp_idx = threadIdx.x % 32;

    const int half_warp_id = threadIdx.x / 16;

    const int half_warp_lane = threadIdx.x % 16;

    const int batch_size_warps = (WARPS - 1) * 2;

    // Per-thread copy of the 16-entry NF4 code -> value lookup table
    // (nf4_dequantization_lut is defined elsewhere in this file).
    T quant_map[16];

#pragma unroll 16
    for (int i = 0; i < 16; i++)

        quant_map[i] = nf4_dequantization_lut[i];

    //__shared__ T quant_map[16*160];

    T local_A[2];   // 2 prefetched A values per loader thread

    T local_B[64];  // 64 dequantized B values (from 32 packed bytes)

    unsigned char local_B_4bit[32]; // packed 4-bit B values, two per byte

    const int a_tile_offset = 16;

    const int b_tile_offset = (16 * 32 + 16);

    __shared__ T smem_A[8 * 16 + (16 * (batch_size_warps - 1))];

    __shared__ T smem_B[2 * batch_size_warps * 16 * 32 + (2 * 16 * (batch_size_warps - 1))];

    __shared__ T smem_C[8 * 32]; // output staging tile

    wmma::fragment<wmma::matrix_a, 8, 32, 16, half, wmma::row_major> a_frag;

    wmma::fragment<wmma::matrix_b, 8, 32, 16, half, wmma::col_major> b_frag;

    wmma::fragment<wmma::accumulator, 8, 32, 16, half> c_frag;

    wmma::fill_fragment(c_frag, 0.0f);

    // Zero the output tile cooperatively before use.
    for (int i = threadIdx.x; i < (8 * 32); i += blockDim.x)

        smem_C[i] = 0.0f;

    __syncthreads();

    int ticktock = 0;

    int idx = 0 + threadIdx.x;

    // loaded_values toggles 1/0: one global load serves two pipeline steps.
    int loaded_values = 0;

    // prefetch
    if (idx < K && warp_id < (WARPS - 1)) {

        if (loaded_values == 0) {

            local_A[0] = A[idx];

            local_A[1] = A[idx + blockDim.x - 32];

#pragma unroll 32
            for (int col = 0; col < 32; col++)

                local_B_4bit[col] = B[(col_offset + col) * ldb + idx];

            loaded_values = 1;

        } else {

            local_A[0] = local_A[1];

            loaded_values--;

#pragma unroll 64
            for (int col = 0; col < 64; col += 2) {

                // local_B[col] = dhDequantizeNF4(local_B_4bit[col/2] >> 4)*T(1.0f);

                // local_B[col+1] = dhDequantizeNF4(local_B_4bit[col/2] & 0x0F)*T(1.0f);

                // local_B[col] = d2DequantizeFP4(local_B_4bit[col/2] >> 4)*(float)(17.0);

                // local_B[col+1] = d2DequantizeFP4(local_B_4bit[col/2] & 0x0F)*(float)(17.0);

                // local_B[col] = 127*(local_B_4bit[col/2] >> 4)*(float)(17.0);

                // local_B[col+1] = 127*(local_B_4bit[col/2] & 0x0F)*(float)(17.0);

                // local_B[col] = quant_map[(local_B_4bit[col/2] >> 4)]*T(17.0);

                // local_B[col+1] = quant_map[(local_B_4bit[col/2] & 0x0F)]*T(17.0);

                // NOTE(review): quant_map has only 16 elements, but this
                // indexes 160*nibble + warp_idx — the layout of the
                // commented-out __shared__ T quant_map[16*160] above. On the
                // register array this is out of bounds; T(17.0) also looks
                // like a debug scale constant. Confirm intent.
                local_B[col] = quant_map[160 * (local_B_4bit[col / 2] >> 4) + warp_idx] * T(17.0);

                local_B[col + 1] = quant_map[160 * (local_B_4bit[col / 2] & 0x0F) + warp_idx] * T(17.0);

            }

        }

        smem_A[half_warp_lane + (((batch_size_warps * ticktock) + half_warp_id) * a_tile_offset)] = local_A[0];

#pragma unroll 32
        for (int col = 0; col < 32; col++)

            smem_B[half_warp_lane + (((batch_size_warps * ticktock) + half_warp_id) * b_tile_offset) + (col * 16)] =

                local_B[col];

    } else if (warp_id < (WARPS - 1)) {

        // Tail zero-fill so the consumer warp reads zeros, not stale data.
        local_A[0] = T(0.0);

        smem_A[half_warp_lane + (((batch_size_warps * ticktock) + half_warp_id) * a_tile_offset)] = 0.0f;

#pragma unroll 32
        for (int col = 0; col < 32; col++)

            local_B[col] = 0.0f;

#pragma unroll 32
        for (int col = 0; col < 32; col++)

            smem_B[half_warp_lane + (((batch_size_warps * ticktock) + half_warp_id) * b_tile_offset) + (col * 16)] =

                0.0f;

    }

    ticktock = ticktock == 0 ? 1 : 0; // flip buffers after the prefetch

    // if(threadIdx.x == 0)

    // printf("aa %i %i\n", idx, loaded_values);

    // for(int base_idx = blockDim.x-32; base_idx < K; base_idx+=blockDim.x-32)
    // Main loop: loaders refill one buffer while the consumer warp drains the
    // other.
    for (int base_idx = blockDim.x - 32; base_idx < K; base_idx += blockDim.x - 32) {

        idx = base_idx + threadIdx.x;

        // if(threadIdx.x == 0)

        // printf("%i %i\n", idx, loaded_values);

        // NOTE(review): this barrier is commented out (unlike gemm_device),
        // so loader warps can overwrite a buffer the consumer warp is still
        // reading — a shared-memory race. Confirm whether this is deliberate.
        //__syncthreads();

        if (idx < K && warp_id < (WARPS - 1)) {

            if (loaded_values == 0) {

                local_A[0] = A[idx];

                local_A[1] = A[idx + blockDim.x - 32];

#pragma unroll 32
                for (int col = 0; col < 32; col++) {

                    local_B_4bit[col] = B[(col_offset + col) * ldb + idx];

                    // NOTE(review): for col >= 16 this writes
                    // local_B_4bit[32..47] — past the 32-byte array — and it
                    // stores the same global value as the line above. Looks
                    // like a bug; confirm intended indexing.
                    local_B_4bit[col + 16] = B[(col_offset + col) * ldb + idx];

                }

                loaded_values = 1;

            } else {

                local_A[0] = local_A[1];

                loaded_values--;

                // Per-block quantization scale for this K position.
                // NOTE(review): mixing the K index with col_offset (a column
                // offset) to compute absidx is suspicious; verify the absmax
                // layout against the quantization code.
                int absidx = (idx + col_offset) / blocksize;

                half local_absmax = __ldg(&(absmax[absidx]));

#pragma unroll 64
                for (int col = 0; col < 64; col += 2) {

                    // local_B[col] = dhDequantizeNF4(local_B_4bit[col/2] >> 4)*T(absidx);

                    // local_B[col+1] = dhDequantizeNF4(local_B_4bit[col/2] & 0x0F)*T(absidx);

                    // local_B[col] = T(127)*T(local_B_4bit[col/2] >> 4)*T(absidx);

                    // local_B[col+1] = T(127)*T(local_B_4bit[col/2] & 0x0F)*T(absidx);

                    // local_B[col] = quant_map[160*(local_B_4bit[col/2] >> 4)+warp_idx]*T(local_absmax);

                    // local_B[col+1] = quant_map[160*(local_B_4bit[col/2] & 0x0F)+warp_idx]*T(local_absmax);

                    // NOTE(review): scales by T(absidx) — the block *index* —
                    // while local_absmax (the actual scale, loaded above) is
                    // unused. Almost certainly should be T(local_absmax).
                    local_B[col] = quant_map[(local_B_4bit[col / 2] >> 4)] * T(absidx);

                    local_B[col + 1] = quant_map[(local_B_4bit[col / 2] & 0x0F)] * T(absidx);

                }

                // printnonzero<T>(local_B, 128, "");

            }

            smem_A[half_warp_lane + (((batch_size_warps * ticktock) + half_warp_id) * a_tile_offset)] = local_A[0];

#pragma unroll 32
            for (int col = 0; col < 32; col++)

                smem_B[half_warp_lane + (((batch_size_warps * ticktock) + half_warp_id) * b_tile_offset) + (col * 16)] =

                    local_B[col];

        } else if (warp_id < (WARPS - 1)) {

            local_A[0] = T(0.0);

            smem_A[half_warp_lane + (((batch_size_warps * ticktock) + half_warp_id) * a_tile_offset)] = 0.0f;

#pragma unroll 32
            for (int col = 0; col < 32; col++)

                local_B[col] = 0.0f;

#pragma unroll 32
            for (int col = 0; col < 32; col++)

                smem_B[half_warp_lane + (((batch_size_warps * ticktock) + half_warp_id) * b_tile_offset) + (col * 16)] =

                    0.0f;

        }

        ticktock = ticktock == 0 ? 1 : 0;

        // Consumer warp drains the buffer the loaders just finished.
        if (warp_id == (WARPS - 1))

            for (int k = 0; k < batch_size_warps; k++) {

                wmma::load_matrix_sync(

                    a_frag, &(smem_A[(ticktock * batch_size_warps + k) * a_tile_offset]), 16

                ); // 111 mu

                wmma::load_matrix_sync(

                    b_frag, &(smem_B[(ticktock * batch_size_warps + k) * b_tile_offset]), 16

                ); // 35 mu

                wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);

            }

    }

    __syncthreads();

    // if(threadIdx.x == 0)

    //{

    // printnonzero<T>(smem_A, 8*16 + (2*16*(batch_size_warps-1)), "A: ");

    // printnonzero<T>(smem_B, 2*batch_size_warps*16*32 + (2*16*(batch_size_warps-1)), "B: ");

    // }

    if (warp_id != (WARPS - 1)) {

        return;

    }

    // only warp_id == (WARPS-1) from here

    int warp_lane = threadIdx.x % 32;

    // Drain the final buffer from the last loop iteration.
    ticktock = ticktock == 0 ? 1 : 0;

    for (int k = 0; k < batch_size_warps; k++) {

        // if(warp_lane == 0)

        // printf("%i %i %i %i\n", (ticktock*batch_size_warps + k)*a_tile_offset, k, ticktock, threadIdx.x);

        wmma::load_matrix_sync(a_frag, &(smem_A[(ticktock * batch_size_warps + k) * a_tile_offset]), 16); // 111 mu

        wmma::load_matrix_sync(b_frag, &(smem_B[(ticktock * batch_size_warps + k) * b_tile_offset]), 16); // 35 mu

        wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);

    }

    // 129 mu

    // Spill the accumulator into smem_C, then write one guarded output row.
    if (warp_id == (WARPS - 1))

        wmma::store_matrix_sync(&(smem_C[0]), c_frag, 32, wmma::mem_row_major);

    // printnonzero<T>(smem_C, 32, "");

    if (col_offset + warp_lane < M)

        out[col_offset + warp_lane] = smem_C[warp_lane];

#endif

}
|
|
|
|
|
|
|
// Number of 4-bit values processed per thread per load in the 4-bit
// inference kernels that follow (their bodies are outside this chunk).
#define num_values_4bit 32
|
|
|
|
|
|
|
template <typename T, int THREADS, int BITS> |
|
|
|
@@ -2592,77 +2169,6 @@ template __global__ void kfunc<unsigned char, FILL>(unsigned char* A, unsigned c |
|
|
|
// ---- Explicit template instantiations ----
// These force code generation for every template configuration reachable from
// the host-side launch wrappers (the kernels live in this translation unit).
template __global__ void kfunc<float, ARANGE>(float* A, float* B, float value, long n);

template __global__ void kfunc<float, _MUL>(float* A, float* B, float value, long n);

// these are not used and make no sense, but the compiler needs them

// template __global__ void gemm_device<float, 16, 128>(int M, int N, int K, float * __restrict__ const A, float* B,

// float * out, int lda, int ldb, int ldc);

// gemm_device instantiations, BITS=32, for each supported THREADS value.
template __global__ void gemm_device<half, 32, 256>(

    int M, int N, int K, half* __restrict__ const A, half* B, half* out, int lda, int ldb, int ldc

);

template __global__ void gemm_device<half, 32, 192>(

    int M, int N, int K, half* __restrict__ const A, half* B, half* out, int lda, int ldb, int ldc

);

template __global__ void gemm_device<half, 32, 160>(

    int M, int N, int K, half* __restrict__ const A, half* B, half* out, int lda, int ldb, int ldc

);

template __global__ void gemm_device<half, 32, 128>(

    int M, int N, int K, half* __restrict__ const A, half* B, half* out, int lda, int ldb, int ldc

);

// template __global__ void gemm_device<float, 16, 32>(int M, int N, int K, float * __restrict__ const A, float* B,

// float * out, int lda, int ldb, int ldc);

template __global__ void gemm_device<half, 32, 32>(

    int M, int N, int K, half* __restrict__ const A, half* B, half* out, int lda, int ldb, int ldc

);

template __global__ void gemm_device<half, 32, 64>(

    int M, int N, int K, half* __restrict__ const A, half* B, half* out, int lda, int ldb, int ldc

);

template __global__ void gemm_device<half, 32, 96>(

    int M, int N, int K, half* __restrict__ const A, half* B, half* out, int lda, int ldb, int ldc

);

// these are not used and make no sense, but the compiler needs them

// template __global__ void gemm_device<float, 32, 128>(int M, int N, int K, float * __restrict__ const A, float* B,

// float * out, int lda, int ldb, int ldc);

// gemm_device instantiations, BITS=16, for each supported THREADS value.
template __global__ void gemm_device<half, 16, 256>(

    int M, int N, int K, half* __restrict__ const A, half* B, half* out, int lda, int ldb, int ldc

);

template __global__ void gemm_device<half, 16, 192>(

    int M, int N, int K, half* __restrict__ const A, half* B, half* out, int lda, int ldb, int ldc

);

template __global__ void gemm_device<half, 16, 160>(

    int M, int N, int K, half* __restrict__ const A, half* B, half* out, int lda, int ldb, int ldc

);

template __global__ void gemm_device<half, 16, 128>(

    int M, int N, int K, half* __restrict__ const A, half* B, half* out, int lda, int ldb, int ldc

);

// template __global__ void gemm_device<float, 32, 32>(int M, int N, int K, float * __restrict__ const A, float* B,

// float * out, int lda, int ldb, int ldc);

template __global__ void gemm_device<half, 16, 32>(

    int M, int N, int K, half* __restrict__ const A, half* B, half* out, int lda, int ldb, int ldc

);

template __global__ void gemm_device<half, 16, 64>(

    int M, int N, int K, half* __restrict__ const A, half* B, half* out, int lda, int ldb, int ldc

);

template __global__ void gemm_device<half, 16, 96>(

    int M, int N, int K, half* __restrict__ const A, half* B, half* out, int lda, int ldb, int ldc

);

// kgemm_4bit_inference instantiations for each supported THREADS value.
template __global__ void kgemm_4bit_inference<half, 96>(

    int M, int N, int K, half* __restrict__ const A, unsigned char* B, float* absmax, half* out, int lda, int ldb,

    int ldc, int blocksize

);

template __global__ void kgemm_4bit_inference<half, 128>(

    int M, int N, int K, half* __restrict__ const A, unsigned char* B, float* absmax, half* out, int lda, int ldb,

    int ldc, int blocksize

);

template __global__ void kgemm_4bit_inference<half, 160>(

    int M, int N, int K, half* __restrict__ const A, unsigned char* B, float* absmax, half* out, int lda, int ldb,

    int ldc, int blocksize

);

template __global__ void kgemm_4bit_inference<half, 256>(

    int M, int N, int K, half* __restrict__ const A, unsigned char* B, float* absmax, half* out, int lda, int ldb,

    int ldc, int blocksize

);
|
|
|
|
|
|
|
template __global__ void kgemm_4bit_inference_naive<half, 128, 16>( |
|
|
|
int M, int N, int K, half* __restrict__ const A, unsigned char* B, float* absmax, const float* datatype, half* out, |
|
|
|
int lda, int ldb, int ldc, int blocksize |
|
|
|
@@ -2996,6 +2502,3 @@ MAKE_OptimizerStatic8bit1StateBlockwise(LION, __nv_bfloat16, 256, 1) |
|
|
|
// Blockwise 8-bit single-state ADAGRAD optimizer kernels, instantiated via
// the MAKE_OptimizerStatic8bit1StateBlockwise macro (defined elsewhere in
// this file) for each supported parameter dtype.
MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, float, 256, 1)

MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, half, 256, 1)

MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, __nv_bfloat16, 256, 1)

// Explicit instantiations of the printnonzero debug helper.
template __device__ void printnonzero<float>(float* A, int num_values, const char* strval);

template __device__ void printnonzero<half>(half* A, int num_values, const char* strval);