import torch | |
def dequantize_tensor(
    W_q: torch.Tensor,
    scale,
    zero,
    orig_shape,
    dtype: torch.dtype = torch.bfloat16,
) -> torch.Tensor:
    """Unpack two 4-bit codes per element of ``W_q`` and dequantize them.

    Each element of ``W_q`` holds two 4-bit codes. The high nibble of row
    ``i`` becomes row ``i`` of the unpacked tensor, and the low nibble becomes
    row ``i + W_q.shape[0]``. The unpacked tensor therefore has twice as many
    rows as ``W_q``. The codes are then mapped back to real values as
    ``(code - zero) * scale`` and reshaped to ``orig_shape``.

    Args:
        W_q: Packed integer tensor of rank >= 1. Presumably ``torch.uint8``
            (assumption, TODO confirm with callers): the ``>> 4`` relies on a
            logical shift, and a signed dtype would sign-extend high nibbles
            >= 8.
        scale: Multiplier applied after subtracting ``zero``. Either a scalar
            or a tensor that broadcasts to (without enlarging) the unpacked
            shape ``(2 * W_q.shape[0], *W_q.shape[1:])``, because the update
            is done in place.
        zero: Zero point subtracted from each code. Same broadcasting rules as
            ``scale``.
        orig_shape: Shape of the returned tensor. It must have the same number
            of elements as the unpacked tensor.
        dtype: Dtype the codes are unpacked into. The subtraction and
            multiplication are also performed in this dtype (expected to be
            floating point).

    Returns:
        A new tensor of dtype ``dtype`` on ``W_q.device`` with shape
        ``orig_shape``.
    """
    n_rows = W_q.shape[0]
    # Trailing dims are copied from W_q, so packed tensors of any rank >= 1
    # are supported (previously only 2-D inputs were accepted).
    W_r = torch.empty((2 * n_rows, *W_q.shape[1:]), dtype=dtype, device=W_q.device)
    # High nibble -> first half of the rows; low nibble -> second half.
    W_r[:n_rows] = (W_q & 0b11110000) >> 4
    W_r[n_rows:] = W_q & 0b00001111
    # In-place is safe here: W_r is a fresh buffer allocated above.
    W_r.sub_(zero).mul_(scale)
    return W_r.reshape(orig_shape)