import torch


# Pack pairs of int4 values into int8 values, in row-major order; the
# first int4 value goes into the lower-order bits, and the second int4
# value into the higher-order bits of the resulting int8 value.
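# For example (illustrative), a row [1, 2, 3, 4] packs into the two
# bytes [0x21, 0x43].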
def pack_int4_to_int8(weight):
    assert weight.dim() == 2
    assert weight.shape[1] % 2 == 0
    assert weight.dtype == torch.int8
    return ((weight[:, 1::2] & 0xF) << 4) | (weight[:, 0::2] & 0xF)


# Unpack each int8 value into a pair of int4 values, in row-major
# order; the lower 4 bits go into the first int4 value, and the upper
# 4 bits go into the second int4 value.
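# For example (illustrative), the bytes [0x21, 0x43] unpack into
# [1, 2, 3, 4]; note that no sign extension is performed, so negative
# int4 values come back as their raw 4-bit fields (in the range 0..15).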
def unpack_int8_to_int4(weight):
    assert weight.dim() == 2
    assert weight.dtype == torch.int8
    return torch.stack((weight & 0xF, (weight >> 4) & 0xF), dim=2).view(
        weight.shape[0], 2 * weight.shape[1]
    )


# Transpose the weight matrix, and then reorder its elements according
# to the underlying requirements of the CUTLASS library, so that it can
# be used for the CUTLASS-based mixed-datatypes linear operation.
def quantized_weight_reorder_for_mixed_dtypes_linear_cutlass(
    weight, dtypeq, transpose=False
):
    assert weight.dim() == 2
    assert weight.dtype == torch.int8
    assert dtypeq == torch.int8 or dtypeq == torch.quint4x2
    assert weight.device.type == "cuda"

    device = weight.device

    # subbyte_transpose
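    # (for 4-bit weights, each byte holds two elements, so the tensor
    # is unpacked to one element per byte, transposed, and repacked)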
    if not transpose:
        if dtypeq == torch.int8:
            outp = weight.T
        elif dtypeq == torch.quint4x2:
            outp = pack_int4_to_int8(unpack_int8_to_int4(weight.view(torch.int8)).T)
    else:
        outp = weight

    ncols, nrows = outp.shape  # type: ignore[possibly-undefined]
    assert nrows % (32 if dtypeq == torch.quint4x2 else 64) == 0
    assert ncols % 64 == 0

    # permute_B_rows_for_mixed_gemm
    # (this actually permutes columns, since the transpose has already
    # been applied above)
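    # Each group of 16 columns is permuted by a fixed 16-element
    # pattern (one pattern for 4-bit weights, another for 8-bit ones).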
    if dtypeq == torch.quint4x2:
        cols_permuted = (
            torch.tensor(
                [0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15],
                device=device,
            )
            + (torch.arange(0, nrows // 16, device=device).reshape(-1, 1) * 16).expand(
                nrows // 16, 16
            )
        ).view(-1)
    else:
        cols_permuted = (
            torch.tensor(
                [0, 1, 4, 5, 8, 9, 12, 13, 2, 3, 6, 7, 10, 11, 14, 15],
                device=device,
            )
            + (torch.arange(0, nrows // 16, device=device).reshape(-1, 1) * 16).expand(
                nrows // 16, 16
            )
        ).view(-1)
    outp = outp.index_copy(1, cols_permuted, outp)

    # interleave_column_major_tensor
    magic0 = 4 if dtypeq == torch.quint4x2 else 2
    magic1 = 32 // magic0
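
    # Within each group of magic0 consecutive columns, the columns'
    # 32-bit words are interleaved in tiles of magic1 words each; the
    # offsets computed below give each source word its destination in
    # that interleaved layout, and the data is scattered accordingly.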

    tmp0 = (
        (torch.arange(0, ncols // magic0, device=device) * (nrows // 4 * magic0))
        .view(-1, 1)
        .repeat(1, nrows // 4 * magic0)
        .view(-1)
    )
    tmp1 = (
        (torch.arange(0, nrows // 4 // magic1, device=device) * (magic0 * magic1))
        .view(-1, 1)
        .repeat(1, magic1)
        .view(-1)
        .repeat(ncols)
    )
    tmp2 = (
        (torch.arange(0, magic0, device=device) * magic1)
        .view(-1, 1)
        .repeat(1, nrows // 4)
        .view(-1)
        .repeat(ncols // magic0)
    )
    tmp3 = torch.arange(0, magic1, device=device).repeat(nrows // 4 * ncols // magic1)

    outp_offsets = tmp0 + tmp1 + tmp2 + tmp3

    tmp = outp.view(-1).view(torch.int32)
    outp = torch.zeros_like(tmp)
    outp.scatter_(0, outp_offsets, tmp)
    outp = outp.view(weight.dtype)

    # add_bias_and_interleave_quantized_tensor_inplace
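    # Shift the quantized values into an unsigned representation (bias
    # of 128 for int8, 8 for int4) and permute the elements within each
    # group of four bytes (int8) or eight nibbles (int4).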
    tmp = outp.view(-1)

    outp = torch.empty_like(tmp)
    if dtypeq == torch.int8:
        tmp = (tmp.to(torch.int) + 128).to(tmp.dtype)
        outp[0::4] = tmp[0::4]
        outp[1::4] = tmp[2::4]
        outp[2::4] = tmp[1::4]
        outp[3::4] = tmp[3::4]
    elif dtypeq == torch.quint4x2:
        tmp0 = ((tmp & 0xF) + 8) & 0xF
        tmp0 = (tmp0[1::2] << 4) | tmp0[0::2]
        tmp1 = (((tmp >> 4) & 0xF) + 8) & 0xF
        tmp1 = (tmp1[1::2] << 4) | tmp1[0::2]
        outp[0::4] = tmp0[0::2]
        outp[1::4] = tmp0[1::2]
        outp[2::4] = tmp1[0::2]
        outp[3::4] = tmp1[1::2]

    if dtypeq == torch.quint4x2:
        nrows *= 2
        ncols //= 2

    return outp.view(nrows, ncols).view(torch.uint8)
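

# A minimal usage sketch (illustrative only, not part of the module):
# pack a random 4-bit weight matrix and reorder it for the CUTLASS
# mixed-dtypes linear operation. The shapes are assumptions chosen only
# to satisfy the alignment asserts above; a CUDA device is required.
if __name__ == "__main__":
    if torch.cuda.is_available():
        weight_int4 = torch.randint(
            -8, 8, (128, 64), dtype=torch.int8, device="cuda"
        )
        weight_packed = pack_int4_to_int8(weight_int4)
        weight_reordered = quantized_weight_reorder_for_mixed_dtypes_linear_cutlass(
            weight_packed, torch.quint4x2, transpose=False
        )
        print(weight_reordered.shape, weight_reordered.dtype)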