File size: 2,361 Bytes
dc2106c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
# Copyright (c) ONNX Project Contributors
#
# SPDX-License-Identifier: Apache-2.0

from typing import Tuple, Union

import numpy as np

# Inclusive value ranges of the 4-bit integer types, derived from the bit width.
INT4_MIN = -(1 << 3)  # -8: most negative two's-complement 4-bit value
INT4_MAX = (1 << 3) - 1  # 7
UINT4_MIN = 0
UINT4_MAX = (1 << 4) - 1  # 15


def float32_to_4bit_unpacked(
    x: Union[np.ndarray, np.dtype, float], signed: bool
) -> np.ndarray:
    """Saturate and round value(s) to the 4-bit integer range, without packing.

    The input is first clipped to the target range and then rounded with
    ``np.rint`` (round half to even). The operation is elementwise, so
    scalars and arrays are both accepted.

    Args:
        x: value(s) to be converted.
        signed: if True, target the signed int4 range [INT4_MIN, INT4_MAX]
            and produce int8; otherwise target the uint4 range
            [UINT4_MIN, UINT4_MAX] and produce uint8.

    Returns:
        The converted value(s), one 4-bit value per int8/uint8 element
        (sign-extended for the signed case).
    """
    if signed:
        lo, hi, out_dtype = INT4_MIN, INT4_MAX, np.int8
    else:
        lo, hi, out_dtype = UINT4_MIN, UINT4_MAX, np.uint8
    arr = x if isinstance(x, np.ndarray) else np.asarray(x)
    saturated = np.clip(arr, lo, hi)
    return np.rint(saturated).astype(out_dtype)  # type: ignore[no-any-return]


def float32x2_to_4bitx2(
    val_low: np.dtype, val_high: np.dtype, signed: bool
) -> np.ndarray:
    """Quantize two values to 4 bits each and pack them into a single byte.

    Both values are converted with ``float32_to_4bit_unpacked``. ``val_low``
    lands in the 4 least-significant bits of the result and ``val_high`` in
    the 4 most-significant bits.

    Args:
        val_low: value stored in the low nibble.
        val_high: value stored in the high nibble.
        signed: if True, quantize as signed int4; otherwise as uint4.

    Returns:
        An int8/uint8 value holding both 4-bit elements.
    """
    high_nibble = float32_to_4bit_unpacked(val_high, signed)
    low_nibble = float32_to_4bit_unpacked(val_low, signed)
    shifted_high = high_nibble << 4
    # A negative signed value comes back sign-extended to 8 bits; keep only
    # its low 4 bits so the extension does not overwrite the high nibble.
    masked_low = low_nibble & 0x0F
    return shifted_high | masked_low  # type: ignore[operator]


def unpack_single_4bitx2(
    x: Union[np.ndarray, np.dtype, float], signed: bool
) -> Tuple[np.ndarray, np.ndarray]:
    """Unpack a single byte 4bitx2 into its two 4-bit elements.

    Works elementwise, so an array of packed bytes is also accepted. The
    input must have an integer dtype (bitwise operators are applied to it);
    both unsigned (e.g. uint8) and signed (e.g. int8, as produced by
    ``float32x2_to_4bitx2(..., signed=True)``) byte encodings are handled.

    Args:
        x: packed byte(s); the low nibble holds the first element and the
            high nibble the second.
        signed: boolean, whether to interpret each nibble as signed int4.

    Returns:
        A tuple ``(low, high)`` of int4 elements, sign-extended to int8 when
        ``signed`` is True, otherwise zero-extended to uint8.
    """

    def _sign_extend(nibble):
        # Widen to int16 first so `nibble - 16` can neither wrap (unsigned
        # dtypes) nor overflow; nibble values are in [0, 15] and bit 3 set
        # means a negative int4.
        wide = np.asarray(nibble).astype(np.int16)
        return np.where(wide >= 8, wide - 16, wide)

    if not isinstance(x, np.ndarray):
        x = np.asarray(x)
    x_low = x & 0x0F
    # Mask after shifting: on signed input dtypes `>>` is arithmetic and
    # would otherwise carry the byte's sign bits into the high nibble.
    x_high = (x >> 4) & 0x0F
    if signed:
        x_low = _sign_extend(x_low)
        x_high = _sign_extend(x_high)
    dtype = np.int8 if signed else np.uint8
    return (x_low.astype(dtype), x_high.astype(dtype))