File size: 5,218 Bytes
5a29263
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
#version 450

#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require

#include "mul_mat_vec_base.comp"

// Workgroup width in x comes from specialization constant 0; y and z are 1.
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;

// Number of consecutive K elements an invocation handles per iter() call:
// 8 for quantized A (B is read as two vec4s), 2 for F32/F16 A (a b0/b1 pair).
#if !defined(DATA_A_F32) && !defined(DATA_A_F16)
#define K_PER_ITER 8
#else
#define K_PER_ITER 2
#endif


// Offsets written by compute_outputs() (a_offset is rescaled by QUANT_K there)
// and read by iter() on every call.
uint a_offset, b_offset, d_offset, y_offset;

// Accumulates one K_PER_ITER-wide slice of the dot products into temp.
//
// For each of the NUM_COLS B columns (j) and each of the num_rows A rows (n)
// starting at first_row, adds the products of K_PER_ITER consecutive K
// elements beginning at col = i*BLOCK_SIZE + K_PER_ITER*tid into temp[j][n].
//
// temp:      per-invocation partial sums, indexed [B column][A row].
// first_row: first A row handled by this workgroup.
// num_rows:  number of A rows to process (at most NUM_ROWS).
// tid:       local invocation index within the workgroup.
// i:         iteration base; compute_outputs passes (iteration index) * K_PER_ITER.
// lastiter:  only used when K_PER_ITER == 2; when true, the second element of
//            the pair is skipped if it lies at or beyond p.ncols.
void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const uint num_rows, const uint tid, const uint i, bool lastiter)
{
    [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
        const uint col = i*BLOCK_SIZE + K_PER_ITER*tid;
        const uint iqs = (col%QUANT_K)/QUANT_R; // quant index
        const uint iybs = col - col%QUANT_K; // y block start index

#if K_PER_ITER == 8
#if QUANT_R == 2
        // Load 4 B values at iybs+iqs and 4 at iybs+iqs+y_offset, then
        // interleave them as (lo0, hi0, lo1, hi1) / (lo2, hi2, lo3, hi3).
        // NOTE(review): presumably this matches the element order dequantize4
        // returns for QUANT_R == 2 formats — confirm against the dequant helpers.
        const vec4 bv02 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs) / 4]);
        const vec4 bv13 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs + y_offset) / 4]);
        const vec4 bv0 = vec4(bv02.x, bv13.x, bv02.y, bv13.y);
        const vec4 bv1 = vec4(bv02.z, bv13.z, bv02.w, bv13.w);
#else
        // 8 consecutive B values fetched as two adjacent vec4 loads.
        const vec4 bv0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs) / 4]);
        const vec4 bv1 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs) / 4 + 1]);
#endif
#else
        // Check if the second of the pair of elements is OOB, and don't fetch B or
        // accumulate it. We still fetch a pair of elements for A, which is fine for
        // quantized formats since they'll be within the same block. We should
        // probably skip fetching the second element for F16/F32, but as of now we
        // still do.
        const bool OOB = lastiter && (iybs + iqs + y_offset >= p.ncols);

        FLOAT_TYPE b0 = 0, b1 = 0;
        b0 = FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs]);
        if (!OOB) {
            b1 = FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs + y_offset]);
        }
#endif
        // ibi advances by p.ncols per row; (ibi + col)/QUANT_K is the A block
        // holding this row's element at col.
        uint ibi = first_row*p.ncols;
        [[unroll]] for (uint n = 0; n < num_rows; ++n) {
            const uint ib = (ibi + col)/QUANT_K; // block index
            ibi += p.ncols;

#if K_PER_ITER == 8
            vec4 v = dequantize4(ib, iqs, a_offset);
            vec4 v2 = dequantize4(ib, iqs+(4/QUANT_R), a_offset);

            const vec2 dm = get_dm(ib, a_offset);
            if (dm.y != 0) { // quant has min component
                v = v * dm.x + dm.y;
                v2 = v2 * dm.x + dm.y;
            }

            // matrix multiplication
            FLOAT_TYPE rowtmp = dot(bv0, v);
            rowtmp += dot(bv1, v2);

            // Without a min term the scale dm.x is common to all 8 products,
            // so it is applied once to the sum rather than per element.
            if (dm.y == 0)
                rowtmp *= dm.x;

            temp[j][n] += rowtmp;
#else
            // v.x pairs with b0 and v.y with b1.
            const vec2 v = dequantize(ib, iqs, a_offset);

            // matrix multiplication
            temp[j][n] = fma(FLOAT_TYPE(v.x), b0, temp[j][n]);
            if (!OOB) {
                temp[j][n] = fma(FLOAT_TYPE(v.y), b1, temp[j][n]);
            }
#endif
        }
    }
}

// Computes this invocation's partial dot products for num_rows A rows against
// NUM_COLS B columns, then hands them to reduce_result().
//
// first_row: first output row handled by this workgroup.
// num_rows:  rows to compute (NUM_ROWS, or fewer for the last workgroup).
void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
    const uint tid = gl_LocalInvocationID.x;

    get_offsets(a_offset, b_offset, d_offset);
    // NOTE(review): iter() passes a_offset to dequantize*/get_dm alongside a
    // block index, so it is presumably converted to block units here — confirm.
    a_offset /= QUANT_K;

    // Distance between the two B elements combined with each A fetch in iter().
    y_offset = QUANT_R == 1 ? 1 : QUANT_K/2;

    FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];

    [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
        [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
            temp[j][i] = FLOAT_TYPE(0);
        }
    }

    // Count the K_PER_ITER*BLOCK_SIZE strides this invocation covers. The
    // trailing partial stride counts only if this invocation's first element
    // in it is still inside the row.
    uint num_iters = p.ncols / (K_PER_ITER * BLOCK_SIZE);
    if (num_iters * K_PER_ITER * BLOCK_SIZE + K_PER_ITER*tid < p.ncols) {
        num_iters++;
    }
    int unroll_count = 4;
    uint unrolled_iters = num_iters & ~(unroll_count - 1);

#if K_PER_ITER == 2
    // iter() only drops an out-of-bounds second element when lastiter is true.
    // With odd ncols, this invocation's final pair may straddle the end of the
    // row, so that iteration must not run inside an unrolled (lastiter == false)
    // batch. Hold back one batch so it reaches the tail loop below.
    if ((p.ncols & 1) != 0 &&
        unrolled_iters == num_iters &&
        unrolled_iters > 0) {
        unrolled_iters -= unroll_count;
    }
#endif

    uint i = 0;
    while (i < unrolled_iters) {
        // Manually partially unroll the loop
        [[unroll]] for (uint k = 0; k < unroll_count; ++k) {
            iter(temp, first_row, num_rows, tid, i*K_PER_ITER, false);
            i++;
        }
    }

    unroll_count = 2;
    unrolled_iters = num_iters & ~(unroll_count - 1);

#if K_PER_ITER == 2
    // Same hold-back as above for the 2x unrolled phase.
    if ((p.ncols & 1) != 0 &&
        unrolled_iters == num_iters &&
        unrolled_iters > 0) {
        unrolled_iters -= unroll_count;
    }
#endif

    while (i < unrolled_iters) {
        // Manually partially unroll the loop
        [[unroll]] for (uint k = 0; k < unroll_count; ++k) {
            iter(temp, first_row, num_rows, tid, i*K_PER_ITER, false);
            i++;
        }
    }
    // Remaining iterations (always including the last one when ncols is odd)
    // run with bounds checking enabled.
    while (i < num_iters) {
        iter(temp, first_row, num_rows, tid, i*K_PER_ITER, true);
        i++;
    }

    reduce_result(temp, d_offset, first_row, num_rows, tid);
}

void main() {
    // Workgroups are spread over x and z; each owns NUM_ROWS consecutive
    // output rows beginning at first_row.
    const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);

#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL)
    init_iq_shmem(gl_WorkGroupSize);
#endif

    // A workgroup that starts at or past the last row has nothing to compute.
    if (first_row >= p.stride_d) {
        return;
    }

    const uint rows_left = p.stride_d - first_row;
    if (rows_left >= NUM_ROWS) {
        // Full tile: pass the compile-time constant NUM_ROWS.
        compute_outputs(first_row, NUM_ROWS);
    } else {
        // Final, partial tile.
        compute_outputs(first_row, rows_left);
    }
}