vikhyatk's picture
Upload HfMoondream
80427a0 verified
import torch
def dequantize_tensor(W_q, scale, zero, orig_shape, dtype=torch.bfloat16):
    """Unpack 4-bit packed codes from ``W_q`` and apply an affine dequantization.

    Each element of ``W_q`` carries two 4-bit codes:
    - the high nibble of packed row ``i`` becomes unpacked row ``i``;
    - the low nibble becomes unpacked row ``i + W_q.shape[0]``.

    The unpacked matrix therefore has twice as many rows as ``W_q`` and the
    same number of columns.

    Codes are mapped to values as ``(code - zero) * scale``. The arithmetic is
    done in place on a buffer of ``dtype``, so the intermediate after
    subtracting ``zero`` is stored in ``dtype`` before scaling.

    Args:
        W_q: 2-D packed tensor. Presumably ``torch.uint8`` (TODO confirm with
            callers); the nibble masks assume 8-bit unsigned storage.
        scale: Multiplier. Must broadcast against
            ``(2 * W_q.shape[0], W_q.shape[1])``.
        zero: Offset subtracted before scaling. Same broadcasting rule as
            ``scale``.
        orig_shape: Shape of the result. Must hold
            ``2 * W_q.shape[0] * W_q.shape[1]`` elements.
        dtype: Dtype of the returned tensor (default ``torch.bfloat16``).

    Returns:
        Tensor of ``dtype`` on ``W_q.device`` with shape ``orig_shape``.
    """
    n_rows = W_q.shape[0]
    n_cols = W_q.shape[1]
    unpacked = torch.empty([2 * n_rows, n_cols], dtype=dtype, device=W_q.device)

    high_nibbles = (W_q & 0xF0) >> 4  # upper 4 bits -> values 0..15
    low_nibbles = W_q & 0x0F  # lower 4 bits -> values 0..15
    unpacked[:n_rows] = high_nibbles
    unpacked[n_rows:] = low_nibbles

    # In-place ops keep the buffer in `dtype` even if `zero`/`scale` are
    # tensors of a wider dtype (an out-of-place op would promote the result).
    return unpacked.sub_(zero).mul_(scale).reshape(orig_shape)