Spaces:
Sleeping
Sleeping
# Copyright (c) ONNX Project Contributors | |
# | |
# SPDX-License-Identifier: Apache-2.0 | |
from typing import Tuple, Union | |
import numpy as np | |
# Inclusive bounds of the values a 4-bit field can hold.
INT4_MIN, INT4_MAX = -8, 7  # signed (two's complement) nibble
UINT4_MIN, UINT4_MAX = 0, 15  # unsigned nibble
def float32_to_4bit_unpacked(
    x: Union[np.ndarray, np.dtype, float], signed: bool
) -> np.ndarray:
    """Clip and round values into the 4-bit range, without packing.

    Values are first clipped to the int4 range ([INT4_MIN, INT4_MAX]) when
    ``signed`` is true, or the uint4 range ([UINT4_MIN, UINT4_MAX]) otherwise.
    They are then rounded with ``np.rint`` (round half to even).

    Args:
        x: scalar or array of values to convert.
        signed: boolean, whether to convert to signed int4.

    Returns:
        An ndarray with the same shape as ``x``. Its dtype is int8 when
        ``signed`` is true and uint8 otherwise; each element holds one
        4-bit value.
    """
    if signed:
        lower, upper, out_type = INT4_MIN, INT4_MAX, np.int8
    else:
        lower, upper, out_type = UINT4_MIN, UINT4_MAX, np.uint8
    values = x if isinstance(x, np.ndarray) else np.asarray(x)
    # The bounds are integers, so rounding after clipping cannot leave the range.
    clipped = np.clip(values, lower, upper)
    return np.rint(clipped).astype(out_type)  # type: ignore[no-any-return]
def float32x2_to_4bitx2(
    val_low: np.dtype, val_high: np.dtype, signed: bool
) -> np.ndarray:
    """Convert two values to 4 bits (clip + round) and pack them into one byte.

    Args:
        val_low: value stored in the 4 least-significant bits.
        val_high: value stored in the 4 most-significant bits.
        signed: boolean, whether to convert to signed int4.

    Returns:
        The packed byte(s), in the int8/uint8 dtype produced by
        ``float32_to_4bit_unpacked``. Each byte contains both 4-bit values.
    """
    hi_nibble = float32_to_4bit_unpacked(val_high, signed)
    lo_nibble = float32_to_4bit_unpacked(val_low, signed)
    # A negative signed low value has its upper bits set (e.g. -1 -> 0xFF).
    # Masking keeps only its nibble so it cannot clobber the high half.
    return (hi_nibble << 4) | (lo_nibble & 0x0F)  # type: ignore[operator]
def unpack_single_4bitx2(
    x: Union[np.ndarray, np.dtype, float], signed: bool
) -> Tuple[np.ndarray, np.ndarray]:
    """Unpack a single byte 4bitx2 into its two 4-bit elements.

    Args:
        x: Input data. Each element holds one packed byte: the first 4-bit
            value in the 4 LSB and the second in the 4 MSB. Integer input of
            any width is reinterpreted as raw bytes (e.g. int8 -16 is 0xF0).
        signed: boolean, whether to interpret the nibbles as signed int4.

    Returns:
        A tuple ``(low, high)`` of ndarrays holding the 4-bit elements.
        They are int8 and sign-extended when ``signed`` is true, and uint8
        otherwise.
    """
    x = np.asarray(x)
    if np.issubdtype(x.dtype, np.integer):
        # Work on raw uint8 bytes. On a negative int8, ``>>`` is an arithmetic
        # shift that smears the sign bit into the high nibble. Also, ``| 0xF0``
        # overflows int8 under NumPy 2 promotion rules.
        x = x.astype(np.uint8)
    x_low = x & 0x0F
    x_high = x >> 4
    if signed:
        # Bit 3 set means a negative int4. Fill the upper 4 bits with ones so
        # the int8 cast below yields the two's complement value (0xF8 -> -8).
        x_low = np.where((x_low & 0x08) == 0, x_low, x_low | 0xF0)
        x_high = np.where((x_high & 0x08) == 0, x_high, x_high | 0xF0)
    dtype = np.int8 if signed else np.uint8
    return (x_low.astype(dtype), x_high.astype(dtype))