AlexchIA's picture
Upload folder using huggingface_hub
452b173
#include "rope.cuh"
#include "../util.cuh"
#include "../matrix.cuh"
// Block shape shared by rope_cuda_kernel and its launcher.
constexpr int THREADS_X = 32;   // threads per block along the (half) head dimension
constexpr int THREADS_Y = 4;    // threads per block along the rows of a batch slice

// Nominal height given to the sin/cos matrix views. The exact value is not
// significant: the views' heights are reportedly not read by the accessors used here.
constexpr int MAX_POS_EMBEDDINGS = 32768;
// Pointer to one instantiation of rope_cuda_kernel.
// Parameters, in order: x (modified in place), sin table, cos table,
// rows_per_batch, head_dim, num_heads, past_len.
using fp_rope_cuda_kernel = void (*)(half*, const half*, const half*, int, int, int, int);
// Applies rotary position embedding (rotate-half form) to x in place.
//
// Thread mapping:
//   x: column within the first half of a row (each thread covers a half2
//      pair of columns when use_half2 is true, a single column otherwise)
//   y: row within the current batch slice
//   z: batch index; batch z owns rows [z * rows_per_batch, (z + 1) * rows_per_batch)
//
// With half = head_dim / 2 and pos = past_len + row / num_heads, each column c < half becomes:
//   x[c]        <- x[c]        * cos[pos][c]        - x[c + half] * sin[pos][c]
//   x[c + half] <- x[c + half] * cos[pos][c + half] + x[c]        * sin[pos][c + half]
template<bool use_half2>
__global__ void rope_cuda_kernel
(
    half* __restrict__ x,
    const half* __restrict__ sin,
    const half* __restrict__ cos,
    int rows_per_batch,
    int head_dim,
    int num_heads,
    int past_len
)
{
    // Heights passed here are placeholders; only the widths (head_dim) drive indexing.
    MatrixView_half_rw x_view(x, rows_per_batch, head_dim);
    MatrixView_half sin_view(sin, MAX_POS_EMBEDDINGS, head_dim);
    MatrixView_half cos_view(cos, MAX_POS_EMBEDDINGS, head_dim);

    const int half_dim = head_dim / 2;

    int col = (blockIdx.x * THREADS_X + threadIdx.x);
    if constexpr (use_half2) col *= 2;
    if (col >= half_dim) return;        // past the first half of the row: no work

    const int local_row = blockIdx.y * THREADS_Y + threadIdx.y;
    if (local_row >= rows_per_batch) return;   // past the last row of this batch slice

    const int row = blockIdx.z * rows_per_batch + local_row;   // absolute row in x
    const int pos = past_len + local_row / num_heads;          // row of the sin/cos tables
    const int col_r = col + half_dim;                          // partner column in the second half

    if constexpr (use_half2)
    {
        // Gather every input before writing anything back.
        const half2 a = x_view.item_half2(row, col);
        const half2 b = x_view.item_half2(row, col_r);
        const half2 cos_a = cos_view.item_half2(pos, col);
        const half2 cos_b = cos_view.item_half2(pos, col_r);
        const half2 neg_sin_a = __hneg2(sin_view.item_half2(pos, col));
        const half2 sin_b = sin_view.item_half2(pos, col_r);

        // a' = a*cos_a + b*(-sin_a),  b' = b*cos_b + a*sin_b  (multiply, then fused multiply-add)
        const half2 a_out = __hfma2(a, cos_a, __hmul2(b, neg_sin_a));
        const half2 b_out = __hfma2(b, cos_b, __hmul2(a, sin_b));

        x_view.set_half2(row, col, a_out);
        x_view.set_half2(row, col_r, b_out);
    }
    else
    {
        // Gather every input before writing anything back.
        const half a = x_view.item(row, col);
        const half b = x_view.item(row, col_r);
        const half cos_a = cos_view.item(pos, col);
        const half cos_b = cos_view.item(pos, col_r);
        const half neg_sin_a = __hneg(sin_view.item(pos, col));
        const half sin_b = sin_view.item(pos, col_r);

        // a' = a*cos_a + b*(-sin_a),  b' = b*cos_b + a*sin_b  (multiply, then fused multiply-add)
        const half a_out = __hfma(a, cos_a, __hmul(b, neg_sin_a));
        const half b_out = __hfma(b, cos_b, __hmul(a, sin_b));

        x_view.set(row, col, a_out);
        x_view.set(row, col_r, b_out);
    }
}
// Chooses which rope_cuda_kernel instantiation to launch.
//
// Returns the scalar (half) variant when tuningParams->rope_no_half2 is set,
// otherwise the half2 variant. The choice is keyed on the RoPE-specific flag,
// which is the flag that also governs the RoPE launch configuration. A flag
// belonging to another kernel (e.g. matmul_no_half2) can disagree with it. The
// scalar kernel would then run on a half2-sized grid, leaving half of the
// rotated columns untouched.
fp_rope_cuda_kernel rope_cuda_kernel_pick(ExLlamaTuning* tuningParams)
{
    // <bool use_half2>
    if (tuningParams->rope_no_half2) return rope_cuda_kernel<false>;
    return rope_cuda_kernel<true>;
}
// Host launcher: applies rotary position embeddings to x in place.
//
// Launch layout expected by rope_cuda_kernel:
//   blockDim = (THREADS_X, THREADS_Y, 1)
//   gridDim.x covers the first head_dim / 2 columns: one column per thread,
//             or two per thread when the half2 variant is selected
//   gridDim.y covers rows_per_batch in steps of THREADS_Y
//   gridDim.z = bsz (one slice of rows_per_batch rows per batch)
// The launch is asynchronous on alt_stream; no error status is checked or returned.
// Empty inputs (bsz, rows_per_batch or head_dim / 2 not positive) are a no-op,
// because a zero grid dimension is an invalid launch configuration.
// NOTE(review): the half2 variant accesses half2 pairs at column + head_dim / 2;
// presumably head_dim / 2 must be even for those accesses to be aligned — confirm.
void rope_cuda
(
    ExLlamaTuning* tuningParams,
    half* x,
    const half* sin,
    const half* cos,
    const int bsz,
    const int rows_per_batch,
    const int head_dim,
    const int num_heads,
    const int past_len,
    cudaStream_t alt_stream
)
{
    const int half_dim = head_dim / 2;
    if (bsz <= 0 || rows_per_batch <= 0 || half_dim <= 0) return;

    fp_rope_cuda_kernel kernel = rope_cuda_kernel_pick(tuningParams);

    // Size the grid from the variant actually picked so the two can never
    // disagree; the half2 variant covers two columns per thread.
    const bool use_half2 = (kernel == rope_cuda_kernel<true>);
    const int cols_per_thread = use_half2 ? 2 : 1;
    const int threads_needed_x = (half_dim + cols_per_thread - 1) / cols_per_thread;

    // Proper ceil-divs: the previous formula truncated to 0 blocks for e.g.
    // head_dim = 64 in half2 mode, so the kernel silently never ran.
    dim3 threads(THREADS_X, THREADS_Y, 1);
    dim3 blocks
    (
        (threads_needed_x + THREADS_X - 1) / THREADS_X,
        (rows_per_batch + THREADS_Y - 1) / THREADS_Y,
        bsz
    );

    kernel<<<blocks, threads, 0, alt_stream>>>(x, sin, cos, rows_per_batch, head_dim, num_heads, past_len);
}