File size: 1,267 Bytes
452b173
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
#ifndef _q4_attn_cuh
#define _q4_attn_cuh

#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cstdint>

#include "../tuning.h"
#include "q4_matrix.cuh"

void q4_attn_cuda
(
    ExLlamaTuning* tuningParams,
    cudaStream_t stream,
    cublasHandle_t handle,
    half* x,
    const half* rms_norm_weight,    // shape == (x.shape[1],) == (dim,)
    float epsilon,
    half* query_states,
    half* key_states,
    half* value_states,
    Q4Matrix* q_proj,
    Q4Matrix* k_proj,
    Q4Matrix* v_proj,
    half* sin,
    half* cos,
    const int bsz,
    const int q_len,
    const int dim,
    const int head_dim,
    const int num_heads,
    const int num_kv_heads,
    const int past_len,
    half* key_cache,
    half* value_cache,
    const half* q_a,
    const half* q_b,
    const int q_rank,
    const half* k_a,
    const half* k_b,
    const int k_rank,
    const half* v_a,
    const half* v_b,
    const int v_rank,
    half* lora_temp,
    const int max_seq_len,
    const int device_index
);

void q4_attn_2_cuda
(
    ExLlamaTuning* tuningParams,
    cublasHandle_t handle,
    half* x,
    half* attn_output,
    Q4Matrix* o_proj,
    const int height,
    const half* o_a,
    const half* o_b,
    const int o_rank,
    half* lora_temp
);

#endif