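// NOTE: The file preamble was lost in extraction. The standard includes below cover the
// PyTorch/CUDA APIs used in this file; the project-specific headers and their contents
// (tuning struct, Q4Matrix, kernel entry points) are assumptions about how the
// accompanying sources are laid out.

#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cstdint>
#include <cstdio>

#include "tuning.h"                   // ExLlamaTuning (assumed header name)
#include "cuda_buffers.cuh"           // prepare_buffers_cuda, cleanup_buffers_cuda, CUDA_MAX_DEVICES (assumed)
#include "cuda_func/q4_matrix.cuh"    // Q4Matrix, g_q4_keep_matrix, g_q4_free_matrices (assumed)
#include "cuda_func/q4_matmul.cuh"    // q4_matmul_cuda, q4_matmul_recons_cuda (assumed)
#include "cuda_func/half_matmul.cuh"  // half_matmul_cuda, half_matmul_cublas_cuda (assumed)
#include "cuda_func/q4_attn.cuh"      // q4_attn_cuda, q4_attn_2_cuda (assumed)
#include "cuda_func/q4_mlp.cuh"       // q4_mlp_cuda (assumed)
#include "cuda_func/column_remap.cuh" // column_remap_cuda (assumed)
#include "cuda_func/rms_norm.cuh"     // rms_norm_cuda (assumed)
#include "cuda_func/rope.cuh"         // rope_cuda (assumed)
#include "cpu_func/rep_penalty.h"     // rep_penalty_cpu, apply_rep_penalty_cpu (assumed)

// cudaUnspecified (used in check_cuda below) is assumed to be defined in one of the
// project headers as an alias for a generic cudaError_t value.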

// Check CUDA return code. We don't want to include Torch headers in the .cu files because
// parsing them adds almost a minute to the compile time on a 12900K. Also, passing
// exceptions back to Python is tricky, so in place of exceptions, CUDA functions return a
// cudaError_t which we can parse and dump to the console.

void check_cuda(cudaError_t ret)
{
    switch (ret)
    {
        case cudaSuccess:
            break;

        case cudaUnspecified:
            printf(" **** Unspecified error\n");
            TORCH_CHECK(false, "CUDA error");
            break;

        default:
            printf(" **** CUDA error\n");
            printf(" **** %s\n", cudaGetErrorString(ret));
            TORCH_CHECK(false, "CUDA error");
            break;
    }
}

// Some decluttering macros
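// The macro definitions themselves did not survive extraction. These are minimal
// reconstructions inferred from the call sites below; names and argument order match
// every use in this file, but the exact originals may differ.

#define TORCH_CHECK_DTYPE(__x, __dtype) TORCH_CHECK((__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype)
#define TORCH_CHECK_DTYPE_OPT(__x, __dtype) TORCH_CHECK((__x).device().is_meta() || (__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype)
#define TORCH_CHECK_SHAPES(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes")
#define TORCH_CHECK_BUFFER_SIZE(__buffer, __minimum_size) TORCH_CHECK((__buffer).numel() >= __minimum_size, #__buffer " is too small")

#ifndef CUDA_MAX_DEVICES
#define CUDA_MAX_DEVICES 16  // assumed upper bound; the real constant likely lives in a project header
#endif

#define TORCH_CHECK_DEVICE_INDEX(__index) \
do { \
    TORCH_CHECK(__index >= 0, "no device index"); \
    TORCH_CHECK(__index < CUDA_MAX_DEVICES, "invalid device index"); \
} while(0)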

int get_groupsize(torch::Tensor w, torch::Tensor w_zeros)
{
    int groupsize = w.size(0) * 8 / w_zeros.size(0);
    TORCH_CHECK(groupsize * w_zeros.size(0) == w.size(0) * 8, "w.shape[-2] must be a multiple of zeros.shape[-2]");
    return groupsize;
}
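
// Worked example (hypothetical shapes): qweight packs eight 4-bit rows into each int32
// row, so w.size(0) == 512 packed rows corresponds to 4096 logical rows. With
// w_zeros.size(0) == 32 zero-point groups, groupsize == 512 * 8 / 32 == 128 rows per
// quantization group.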

// Tuning parameters

ExLlamaTuning tuningParams;

void set_tuning_params
(
    int matmul_recons_thd,
    int fused_mlp_thd,
    int sdp_thd,
    bool matmul_fused_remap,
    bool rmsnorm_no_half2,
    bool rope_no_half2,
    bool matmul_no_half2,
    bool silu_no_half2,
    bool concurrent_streams
)
{
    tuningParams.matmul_recons_thd = matmul_recons_thd;
    tuningParams.fused_mlp_thd = fused_mlp_thd;
    tuningParams.sdp_thd = sdp_thd;
    tuningParams.matmul_fused_remap = matmul_fused_remap;
    tuningParams.rmsnorm_no_half2 = rmsnorm_no_half2;
    tuningParams.rope_no_half2 = rope_no_half2;
    tuningParams.matmul_no_half2 = matmul_no_half2;
    tuningParams.silu_no_half2 = silu_no_half2;
    tuningParams.concurrent_streams = concurrent_streams;
}

// Release all unmanaged objects allocated by the extension

void cleanup()
{
    cleanup_buffers_cuda();
    g_q4_free_matrices();
}

// Prepare buffers for forward pass

void prepare_buffers
(
    torch::Device device,
    torch::Tensor temp_state,
    torch::Tensor temp_mlp,
    torch::Tensor temp_zeros_float,
    torch::Tensor temp_dq
)
{
    int device_index = device.index();
    TORCH_CHECK_DEVICE_INDEX(device_index);
    const at::cuda::OptionalCUDAGuard device_guard(device);

    int max_zeros_float = temp_zeros_float.size(-1);

    prepare_buffers_cuda
    (
        device_index,
        (half*) temp_state.data_ptr(),
        // buffer size, used for sanity checks
        temp_state.numel(),
        (half*) temp_mlp.data_ptr(),
        (float*) temp_zeros_float.data_ptr(),
        (half*) temp_dq.data_ptr(),
        max_zeros_float
    );
}

// Create Q4Matrix, return handle

uintptr_t make_q4
(
    torch::Tensor qweight,
    torch::Tensor qzeros,
    torch::Tensor scales,
    torch::Tensor g_idx,
    int device
)
{
    TORCH_CHECK_DTYPE(qweight, kInt);
    TORCH_CHECK_DTYPE(qzeros, kInt);
    TORCH_CHECK_DTYPE(scales, kHalf);
    TORCH_CHECK_DTYPE_OPT(g_idx, kInt);
    TORCH_CHECK_SHAPES(qweight, 1, qzeros, 1, 8);
    TORCH_CHECK_SHAPES(scales, 1, qweight, 1, 1);
    TORCH_CHECK_SHAPES(qzeros, 0, scales, 0, 1);
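
    // qweight packs eight 4-bit values per 32-bit word along dim 0, hence the * 8 when
    // recovering the logical height; qzeros holds one packed row per quantization group.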
    int width = qweight.size(1);
    int height = qweight.size(0) * 8;
    int groups = qzeros.size(0);

    Q4Matrix* m = new Q4Matrix
    (
        height,
        width,
        groups,
        (uint32_t*) qweight.data_ptr(),
        (uint32_t*) qzeros.data_ptr(),
        (half*) scales.data_ptr(),
        g_idx.device().is_meta() ? NULL : (uint32_t*) g_idx.data_ptr(),
        device
    );

    g_q4_keep_matrix(m);
    return reinterpret_cast<uintptr_t>(m);
}

// Matmul half @ quant -> half

void q4_matmul
(
    torch::Tensor x,
    uintptr_t w,
    torch::Tensor out
)
{
    Q4Matrix* wm = reinterpret_cast<Q4Matrix*>(w);

    TORCH_CHECK_DTYPE(x, kHalf);
    TORCH_CHECK_DTYPE(out, kHalf);
    TORCH_CHECK_SHAPES(x, 0, out, 0, 1);
    TORCH_CHECK(wm->height == x.size(-1), "x and w have incompatible shapes");

    const at::cuda::OptionalCUDAGuard device_guard(device_of(x));

    int x_height = x.size(0);
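
    // Dispatch: for short inputs the custom kernel that multiplies directly against the
    // quantized weights wins; past matmul_recons_thd rows it is cheaper to reconstruct
    // the half-precision weight matrix and hand the GEMM to cuBLAS. A threshold of 0
    // forces the direct kernel unconditionally.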
    if (tuningParams.matmul_recons_thd == 0 || x_height < tuningParams.matmul_recons_thd)
    {
        q4_matmul_cuda
        (
            &tuningParams,
            (half*) x.data_ptr(),
            x_height,
            wm,
            (half*) out.data_ptr()
        );
    }
    else
    {
        q4_matmul_recons_cuda
        (
            &tuningParams,
            (half*) x.data_ptr(),
            x_height,
            wm,
            (half*) out.data_ptr(),
            at::cuda::getCurrentCUDABlasHandle()
        );
    }
}

// Matmul half @ quant + half @ half @ half -> half
// Same as q4_matmul, but adds (x @ lora_A) @ lora_B to the result

void q4_matmul_lora
(
    torch::Tensor x,
    uintptr_t w,
    torch::Tensor out,
    torch::Tensor lora_A,
    torch::Tensor lora_B,
    torch::Tensor lora_temp  // empty tensor, shape of (x @ lora_A)
)
{
    Q4Matrix* wm = reinterpret_cast<Q4Matrix*>(w);

    TORCH_CHECK(wm->height == x.size(-1), "x and w have incompatible shapes");
    TORCH_CHECK_DTYPE(x, kHalf);
    TORCH_CHECK_DTYPE(out, kHalf);
    TORCH_CHECK_SHAPES(x, 0, out, 0, 1);
    TORCH_CHECK_SHAPES(x, 0, lora_temp, 0, 1);
    TORCH_CHECK_SHAPES(x, 1, lora_A, 0, 1);
    TORCH_CHECK_SHAPES(lora_A, 1, lora_B, 0, 1);
    TORCH_CHECK_SHAPES(lora_B, 1, out, 1, 1);

    const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
    cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();

    // lora_temp = x @ lora_A
    half_matmul_cublas_cuda
    (
        &tuningParams,
        (half*) x.data_ptr(),
        (half*) lora_A.data_ptr(),
        (half*) lora_temp.data_ptr(),
        x.size(0),
        x.size(1),
        lora_A.size(1),
        handle
    );

    // out = lora_temp @ lora_B
    half_matmul_cublas_cuda
    (
        &tuningParams,
        (half*) lora_temp.data_ptr(),
        (half*) lora_B.data_ptr(),
        (half*) out.data_ptr(),
        lora_temp.size(0),
        lora_temp.size(1),
        lora_B.size(1),
        handle
    );

    int x_height = x.size(0);
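
    // Same dispatch as q4_matmul, except the trailing `true` argument asks the callee to
    // accumulate into out (which already holds the (x @ lora_A) @ lora_B term) instead of
    // zeroing it first. The flag's meaning is inferred from context; the callee's
    // parameter name is not visible in this file.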
    if (tuningParams.matmul_recons_thd == 0 || x_height < tuningParams.matmul_recons_thd)
    {
        q4_matmul_cuda
        (
            &tuningParams,
            (half*) x.data_ptr(),
            x_height,
            wm,
            (half*) out.data_ptr(),
            true
        );
    }
    else
    {
        q4_matmul_recons_cuda
        (
            &tuningParams,
            (half*) x.data_ptr(),
            x_height,
            wm,
            (half*) out.data_ptr(),
            handle,
            true
        );
    }
}

// Remap columns in half tensor

void column_remap
(
    torch::Tensor x,
    torch::Tensor x_new,
    torch::Tensor x_map
)
{
    TORCH_CHECK_DTYPE(x, kHalf);
    TORCH_CHECK_DTYPE(x_new, kHalf);
    TORCH_CHECK_DTYPE(x_map, kInt);
    TORCH_CHECK_SHAPES(x_map, 0, x, 1, 1);

    int height = x.size(0);
    int width = x.size(1);

    TORCH_CHECK_BUFFER_SIZE(x_new, height * width);

    const at::cuda::OptionalCUDAGuard device_guard(device_of(x));

    column_remap_cuda
    (
        (half*) x.data_ptr(),
        (half*) x_new.data_ptr(),
        height,
        width,
        (uint32_t*) x_map.data_ptr()
    );
}

// Matmul half @ half -> half, custom kernel

void half_matmul
(
    torch::Tensor x,
    torch::Tensor w,
    torch::Tensor out
)
{
    TORCH_CHECK_DTYPE(x, kHalf);
    TORCH_CHECK_DTYPE(w, kHalf);
    TORCH_CHECK_DTYPE(out, kHalf);
    TORCH_CHECK_SHAPES(x, 1, w, 0, 1);

    int height = x.size(0);
    int dim = x.size(1);
    int width = w.size(1);

    const at::cuda::OptionalCUDAGuard device_guard(device_of(x));

    half_matmul_cuda
    (
        (half*) x.data_ptr(),
        (half*) w.data_ptr(),
        (half*) out.data_ptr(),
        height,
        dim,
        width
    );
}

// Matmul half @ half -> half using cuBLAS

void half_matmul_cublas
(
    torch::Tensor x,
    torch::Tensor w,
    torch::Tensor out
)
{
    TORCH_CHECK_DTYPE(x, kHalf);
    TORCH_CHECK_DTYPE(w, kHalf);
    TORCH_CHECK_DTYPE(out, kHalf);
    TORCH_CHECK_SHAPES(x, 1, w, 0, 1);

    int height = x.size(0);
    int dim = x.size(1);
    int width = w.size(1);

    const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
    cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();

    half_matmul_cublas_cuda
    (
        &tuningParams,
        (half*) x.data_ptr(),
        (half*) w.data_ptr(),
        (half*) out.data_ptr(),
        height,
        dim,
        width,
        handle
    );
}

// Llama self-attention (WIP)

void q4_attn
(
    torch::Tensor x,                // shape == (bsz, q_len, dim)
    torch::Tensor rms_norm_weight,  // shape == (dim,)
    float epsilon,
    torch::Tensor query_states,     // shape == (bsz, q_len, dim)
    torch::Tensor key_states,       // shape == (bsz, q_len, dim)
    torch::Tensor value_states,     // shape == (bsz, q_len, dim)
    uintptr_t q_proj,
    uintptr_t k_proj,
    uintptr_t v_proj,
    torch::Tensor sin,
    torch::Tensor cos,
    int q_len,
    int past_len,
    int num_heads,
    int num_kv_heads,
    int head_dim,
    torch::Tensor key_cache,
    torch::Tensor value_cache,
    int max_seq_len,
    torch::Tensor q_a,
    torch::Tensor q_b,
    torch::Tensor k_a,
    torch::Tensor k_b,
    torch::Tensor v_a,
    torch::Tensor v_b,
    torch::Tensor lora_temp
)
{
    TORCH_CHECK_DTYPE(query_states, kHalf);
    TORCH_CHECK_DTYPE(key_states, kHalf);

    int bsz = query_states.size(0);
    int dim = query_states.size(2);

    torch::Device device = x.device();
    int device_index = device.index();
    TORCH_CHECK_DEVICE_INDEX(device_index);
    const at::cuda::OptionalCUDAGuard device_guard(device);
    cudaStream_t current_stream = at::cuda::getCurrentCUDAStream().stream();
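
    // Optional LoRA adapters are passed as tensors on the meta device when absent; a
    // rank of 0 tells the kernel to skip the corresponding low-rank update.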
    int q_rank = q_a.device().is_meta() ? 0 : q_a.size(1);
    int k_rank = k_a.device().is_meta() ? 0 : k_a.size(1);
    int v_rank = v_a.device().is_meta() ? 0 : v_a.size(1);

    q4_attn_cuda
    (
        &tuningParams,
        current_stream,
        at::cuda::getCurrentCUDABlasHandle(),
        (half*) x.data_ptr(),
        (half*) rms_norm_weight.data_ptr(),
        epsilon,
        (half*) query_states.data_ptr(),
        (half*) key_states.data_ptr(),
        (half*) value_states.data_ptr(),
        reinterpret_cast<Q4Matrix*>(q_proj),
        reinterpret_cast<Q4Matrix*>(k_proj),
        reinterpret_cast<Q4Matrix*>(v_proj),
        (half*) sin.data_ptr(),
        (half*) cos.data_ptr(),
        bsz,
        q_len,
        dim,
        head_dim,
        num_heads,
        num_kv_heads,
        past_len,
        (half*) key_cache.data_ptr(),
        (half*) value_cache.data_ptr(),
        q_rank ? (half*) q_a.data_ptr() : NULL,
        q_rank ? (half*) q_b.data_ptr() : NULL,
        q_rank,
        k_rank ? (half*) k_a.data_ptr() : NULL,
        k_rank ? (half*) k_b.data_ptr() : NULL,
        k_rank,
        v_rank ? (half*) v_a.data_ptr() : NULL,
        v_rank ? (half*) v_b.data_ptr() : NULL,
        v_rank,
        lora_temp.device().is_meta() ? NULL : (half*) lora_temp.data_ptr(),
        max_seq_len,
        device_index
    );
}

void q4_attn_2
(
    torch::Tensor x,
    torch::Tensor attn_output,
    uintptr_t o_proj,
    torch::Tensor o_a,
    torch::Tensor o_b,
    torch::Tensor lora_temp
)
{
    TORCH_CHECK_DTYPE(x, kHalf);
    TORCH_CHECK_DTYPE(attn_output, kHalf);

    const at::cuda::OptionalCUDAGuard device_guard(x.device());

    int height = x.size(0);
    int o_rank = o_a.device().is_meta() ? 0 : o_a.size(1);

    q4_attn_2_cuda
    (
        &tuningParams,
        at::cuda::getCurrentCUDABlasHandle(),
        (half*) x.data_ptr(),
        (half*) attn_output.data_ptr(),
        reinterpret_cast<Q4Matrix*>(o_proj),
        height,
        o_rank ? (half*) o_a.data_ptr() : NULL,
        o_rank ? (half*) o_b.data_ptr() : NULL,
        o_rank,
        lora_temp.device().is_meta() ? NULL : (half*) lora_temp.data_ptr()
    );
}

// Llama MLP

void q4_mlp
(
    torch::Tensor x,                // shape == (height, dim)
    torch::Tensor rms_norm_weight,  // shape == (x.shape[1],) == (dim,)
    float epsilon,
    uintptr_t gate,
    uintptr_t up,
    uintptr_t down,
    torch::Tensor gate_a,
    torch::Tensor gate_b,
    torch::Tensor up_a,
    torch::Tensor up_b,
    torch::Tensor down_a,
    torch::Tensor down_b,
    torch::Tensor lora_temp
)
{
    TORCH_CHECK_DTYPE(x, kHalf);
    TORCH_CHECK_DTYPE(rms_norm_weight, kHalf);

    int height = x.size(0);
    int dim = x.size(1);

    torch::Device device = x.device();
    int device_index = device.index();
    TORCH_CHECK_DEVICE_INDEX(device_index);
    const at::cuda::OptionalCUDAGuard device_guard(device);

    int gate_rank = gate_a.device().is_meta() ? 0 : gate_a.size(1);
    int up_rank = up_a.device().is_meta() ? 0 : up_a.size(1);
    int down_rank = down_a.device().is_meta() ? 0 : down_a.size(1);

    q4_mlp_cuda
    (
        &tuningParams,
        (half*) x.data_ptr(),
        (half*) rms_norm_weight.data_ptr(),
        epsilon,
        reinterpret_cast<Q4Matrix*>(gate),
        reinterpret_cast<Q4Matrix*>(up),
        reinterpret_cast<Q4Matrix*>(down),
        height,
        dim,
        gate_rank ? (half*) gate_a.data_ptr() : NULL,
        gate_rank ? (half*) gate_b.data_ptr() : NULL,
        gate_rank,
        up_rank ? (half*) up_a.data_ptr() : NULL,
        up_rank ? (half*) up_b.data_ptr() : NULL,
        up_rank,
        down_rank ? (half*) down_a.data_ptr() : NULL,
        down_rank ? (half*) down_b.data_ptr() : NULL,
        down_rank,
        lora_temp.device().is_meta() ? NULL : (half*) lora_temp.data_ptr(),
        at::cuda::getCurrentCUDABlasHandle(),
        device_index
    );
}

// RMS layernorm

void rms_norm
(
    torch::Tensor x,
    torch::Tensor w,
    torch::Tensor out,
    float epsilon
)
{
    TORCH_CHECK_DTYPE(x, kHalf);
    TORCH_CHECK_DTYPE(w, kHalf);
    TORCH_CHECK_DTYPE(out, kHalf);
    TORCH_CHECK_SHAPES(x, 1, w, 0, 1);
    TORCH_CHECK_SHAPES(x, 0, out, 0, 1);
    TORCH_CHECK_SHAPES(x, 1, out, 1, 1);

    int rows = x.size(0);
    int dim = x.size(1);

    torch::Device device = x.device();
    int device_index = device.index();
    TORCH_CHECK_DEVICE_INDEX(device_index);
    const at::cuda::OptionalCUDAGuard device_guard(device);

    rms_norm_cuda
    (
        &tuningParams,
        (half*) x.data_ptr(),
        (half*) w.data_ptr(),
        (half*) out.data_ptr(),
        epsilon,
        rows,
        dim,
        device_index
    );
}

// RoPE rotary positional embeddings

void rope_
(
    torch::Tensor x,
    torch::Tensor sin,
    torch::Tensor cos,
    int past_len,
    int num_heads,
    int head_dim
)
{
    TORCH_CHECK_DTYPE(x, kHalf);
    TORCH_CHECK_DTYPE(sin, kHalf);
    TORCH_CHECK_DTYPE(cos, kHalf);
    TORCH_CHECK(head_dim == cos.size(-1), "cos table does not match head_dim");
    TORCH_CHECK(head_dim == sin.size(-1), "sin table does not match head_dim");

    int bsz = x.size(0);
    int rows_per_batch = x.numel() / head_dim / bsz;

    const at::cuda::OptionalCUDAGuard device_guard(device_of(x));

    rope_cuda
    (
        &tuningParams,
        (half*) x.data_ptr(),
        (half*) sin.data_ptr(),
        (half*) cos.data_ptr(),
        bsz,
        rows_per_batch,
        head_dim,
        num_heads,
        past_len
    );
}

// Repetition penalty (CPU)

void rep_penalty
(
    torch::Tensor sequence,
    torch::Tensor rep_mask,
    float penalty_max,
    int sustain,
    int decay
)
{
    TORCH_CHECK_DTYPE(sequence, kLong);
    TORCH_CHECK_DTYPE(rep_mask, kFloat);

    int vocab_size = rep_mask.size(0);
    int seq_len = sequence.size(-1);

    // TODO: Support batch size

    rep_penalty_cpu
    (
        vocab_size,
        (uint64_t*) sequence.data_ptr(),
        (float*) rep_mask.data_ptr(),
        penalty_max,
        sustain,
        decay,
        seq_len
    );
}

void apply_rep_penalty
(
    torch::Tensor sequence,
    float penalty_max,
    int sustain,
    int decay,
    torch::Tensor logits
)
{
    TORCH_CHECK_DTYPE(sequence, kLong);
    TORCH_CHECK_DTYPE(logits, kFloat);
    TORCH_CHECK_SHAPES(sequence, 0, logits, 0, 1);

    int vocab_size = logits.size(-1);
    int bsz = sequence.size(0);
    int seq_len = sequence.size(-1);
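
    // sequence is (bsz, seq_len) int64 and logits is (bsz, vocab_size) float32, both
    // assumed contiguous, so row i starts at offsets i * seq_len and i * vocab_size.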
    for (int i = 0; i < bsz; i++)
    {
        apply_rep_penalty_cpu
        (
            vocab_size,
            ((uint64_t*) sequence.data_ptr()) + i * seq_len,
            penalty_max,
            sustain,
            decay,
            seq_len,
            ((float*) logits.data_ptr()) + i * vocab_size
        );
    }
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("set_tuning_params", &set_tuning_params, "set_tuning_params");
    m.def("prepare_buffers", &prepare_buffers, "prepare_buffers");
    m.def("cleanup", &cleanup, "cleanup");
    m.def("make_q4", &make_q4, "make_q4");
    m.def("q4_matmul", &q4_matmul, "q4_matmul");
    m.def("q4_matmul_lora", &q4_matmul_lora, "q4_matmul_lora");
    m.def("q4_attn", &q4_attn, "q4_attn");
    m.def("q4_attn_2", &q4_attn_2, "q4_attn_2");
    m.def("q4_mlp", &q4_mlp, "q4_mlp");
    m.def("column_remap", &column_remap, "column_remap");
    m.def("rms_norm", &rms_norm, "rms_norm");
    m.def("rope_", &rope_, "rope_");
    m.def("half_matmul", &half_matmul, "half_matmul");
    m.def("half_matmul_cublas", &half_matmul_cublas, "half_matmul_cublas");
    m.def("rep_penalty", &rep_penalty, "rep_penalty");
    m.def("apply_rep_penalty", &apply_rep_penalty, "apply_rep_penalty");
}