Spaces:
Build error
Build error
File size: 1,449 Bytes
452b173 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 |
#include "column_remap.cuh"
#include "../util.cuh"
// Tile geometry shared by column_remap_kernel and its launcher.
//
// SHUF_BLOCKSIZE_X: threads per block along x; each thread owns one column.
// SHUF_BLOCKSIZE_Y: rows covered by one block along y; every thread copies
//                   up to this many consecutive rows of its column.
constexpr int SHUF_BLOCKSIZE_X = 256;
constexpr int SHUF_BLOCKSIZE_Y = 16;
// Copies a band of rows from x into x_new with the columns permuted:
// for every row handled, x_new[row][c] = x[row][x_map[c]].
//
// Expected launch layout:
//   - blockDim.x == SHUF_BLOCKSIZE_X; thread (blockIdx.x, threadIdx.x) owns
//     destination column SHUF_BLOCKSIZE_X * blockIdx.x + threadIdx.x
//   - blockIdx.y selects the band of rows starting at
//     SHUF_BLOCKSIZE_Y * blockIdx.y; each thread walks that band row by row
//   - no shared memory is used
//
// Parameters:
//   x        source buffer, indexed as row * x_width + column
//   x_new    destination buffer, same indexing; declared __restrict__, so it
//            must not alias x
//   x_width  number of columns (also the row stride, in elements)
//   x_height number of rows
//   x_map    x_width entries; x_map[c] is the source column for destination
//            column c. Entries are not range-checked here -- assumes every
//            entry is < x_width (TODO confirm against the caller).
__global__ void column_remap_kernel
(
    const half* __restrict__ x,
    half* __restrict__ x_new,
    const int x_width,
    const int x_height,
    const uint32_t* x_map
)
{
    int x_column = SHUF_BLOCKSIZE_X * blockIdx.x + threadIdx.x;
    int x_row = SHUF_BLOCKSIZE_Y * blockIdx.y;

    // Threads past the right edge of the last column tile have nothing to do
    if (x_column >= x_width) return;

    // Clamp the band to the matrix height; the last band may be partial
    int x_row_end = min(x_row + SHUF_BLOCKSIZE_Y, x_height);

    // Element offsets are computed in signed 64-bit: row * width overflows a
    // 32-bit int once the matrix holds more than 2^31 - 1 elements.
    const long long x_stride = (long long) x_width;
    long long x_idx = (long long) x_row * x_stride + (long long) x_column;
    const long long x_idx_end = (long long) x_row_end * x_stride + (long long) x_column;

    // Source column for this thread's destination column
    const long long s_column = (long long) x_map[x_column];
    long long s_idx = (long long) x_row * x_stride + s_column;

    // Step down the band one row at a time; source and destination advance
    // in lockstep because both buffers share the same row stride
    while (x_idx < x_idx_end)
    {
        x_new[x_idx] = x[s_idx];
        x_idx += x_stride;
        s_idx += x_stride;
    }
}
// Remap columns in x to correspond to sequential group index before matmul
//
// perform x -> seq_x such that seq_x @ seq_w == x @ w
//
// Launches column_remap_kernel over a 2D grid: SHUF_BLOCKSIZE_X threads per
// block along the columns, and one block per band of SHUF_BLOCKSIZE_Y rows.
//
// Parameters:
//   x        source buffer passed through to the kernel
//   x_new    destination buffer passed through to the kernel
//   x_height number of rows
//   x_width  number of columns
//   x_map    per-column source index table passed through to the kernel
//
// Note: this function takes (x_height, x_width) while the kernel takes
// (x_width, x_height); the launch below swaps them accordingly.
//
// The kernel is enqueued on the default stream. This function neither
// synchronizes nor checks the launch result.
void column_remap_cuda
(
    const half* x,
    half* x_new,
    const int x_height,
    const int x_width,
    const uint32_t* x_map
)
{
    // Nothing to copy for an empty matrix. Launching anyway would request a
    // grid dimension of 0, which is an invalid launch configuration and
    // leaves a pending CUDA error for later, unrelated calls to trip over.
    if (x_height <= 0 || x_width <= 0) return;

    dim3 threads(SHUF_BLOCKSIZE_X, 1, 1);

    // Ceil-divide so partial tiles at the right and bottom edges are covered
    dim3 blocks
    (
        (x_width + SHUF_BLOCKSIZE_X - 1) / SHUF_BLOCKSIZE_X,
        (x_height + SHUF_BLOCKSIZE_Y - 1) / SHUF_BLOCKSIZE_Y,
        1
    );

    column_remap_kernel<<<blocks, threads>>>(x, x_new, x_width, x_height, x_map);
}
|