Spaces:
Runtime error
Runtime error
OpenHathi-SI-v0.1
/
env
/Lib
/site-packages
/bitsandbytes
/triton
/quantize_columnwise_and_transpose.py
import math | |
import torch | |
import time | |
from bitsandbytes.triton.triton_utils import is_triton_available | |
if not is_triton_available(): | |
def quantize_columnwise_and_transpose(x: torch.Tensor):
    """Fallback stub defined when Triton is not available.

    The argument is accepted only so the signature matches the Triton-backed
    implementation; it is never inspected.

    Args:
        x: Ignored.

    Returns:
        Always ``None``.
    """
    return None
else: | |
import triton | |
import triton.language as tl | |
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time | |
# This kernel performs fused column-wise quantization and transpose.
# TODO: autotune this better.
@triton.jit
def _quantize_columnwise_and_transpose(
    x_ptr,
    output_ptr,
    output_maxs,
    n_elements,
    M: tl.constexpr, N: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    P2: tl.constexpr,
):
    # Each program instance handles one index `pid` along the second axis of
    # the input: it reads the M elements at flat offsets pid + i * N, scales
    # them by their absolute maximum, rounds 127 * x / max to an integer, and
    # writes the result to flat offsets pid * M + i of the output (i.e. the
    # column becomes a row). The absolute maximum is written to
    # output_maxs[pid].
    #
    # The function must be JIT-compiled (`@triton.jit`) because the host
    # wrapper launches it with the `kernel[grid](...)` syntax.
    #
    # x_ptr       -- input buffer, addressed as if element (i, j) lives at
    #                i * N + j (assumes a dense row-major (M, N) layout;
    #                strides are not consulted).
    # output_ptr  -- output buffer, written at pid * M + i.
    # output_maxs -- one absolute-maximum value per program instance.
    # n_elements  -- not read in the kernel body.
    # M, N        -- extents used for addressing and masking.
    # BLOCK_SIZE  -- not read in the kernel body; the host-side launch grid
    #                reads it from the launch metadata.
    # P2          -- extent of the index range; lanes with index >= M are
    #                masked off.
    pid = tl.program_id(axis=0)
    block_start = pid
    p2_arange = tl.arange(0, P2)
    p2_arange_mask = p2_arange < M
    # Stride of N between consecutive elements walks down one column.
    arange = p2_arange * N
    offsets = block_start + arange
    x = tl.load(x_ptr + offsets, mask=p2_arange_mask)
    abs_x = tl.abs(x)
    # Masked-off lanes hold unspecified values after the load, so replace
    # them with 0 before reducing to the maximum.
    max_val = tl.max(tl.where(p2_arange_mask, abs_x, 0), axis=0)
    # NOTE(review): if every loaded element is zero, max_val is 0 and the
    # division below is 0 / 0 -- confirm callers never pass such a column.
    # NOTE(review): the module path of `llrint` differs between Triton
    # releases -- confirm `tl.libdevice` exists in the pinned version.
    output = tl.libdevice.llrint(127. * (x / max_val))

    # Transposed write: program `pid` fills M consecutive output elements.
    new_start = pid * M
    new_offsets = new_start + p2_arange
    tl.store(output_ptr + new_offsets, output, mask=p2_arange_mask)
    tl.store(output_maxs + pid, max_val)
def quantize_columnwise_and_transpose(x: torch.Tensor):
    """Quantize each column of ``x`` to int8 and return the result transposed.

    Launches ``_quantize_columnwise_and_transpose`` with one program instance
    per column of ``x``.

    Args:
        x: 2-D CUDA tensor of shape ``(M, N)``.

    Returns:
        A tuple ``(output, output_maxs)`` where ``output`` is an int8 tensor
        of shape ``(N, M)`` and ``output_maxs`` is a float16 tensor of shape
        ``(N,)`` holding the per-column absolute maximum written by the
        kernel.

    Raises:
        ValueError: if ``x`` is not 2-D (shape unpacking fails) or ``M`` is 0
            (``math.log2(0)``).
        AssertionError: if ``x`` is not a CUDA tensor.
    """
    M, N = x.shape
    # The kernel indexes x as a dense row-major buffer (element (i, j) at
    # i * N + j) and never sees the tensor's strides, so a non-contiguous
    # view (e.g. a transpose) would be read incorrectly. This is a no-op for
    # tensors that are already contiguous.
    x = x.contiguous()
    output = torch.empty(N, M, device=x.device, dtype=torch.int8)
    output_maxs = torch.empty(N, device=x.device, dtype=torch.float16)

    # Smallest power of two >= M; used as the kernel's index-range extent.
    P2 = int(2 ** (math.ceil(math.log2(M))))

    assert x.is_cuda and output.is_cuda
    n_elements = output.numel()
    # BLOCK_SIZE is passed as M below, so this evaluates to N programs:
    # one per column of x.
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
    _quantize_columnwise_and_transpose[grid](x, output, output_maxs, n_elements, M, N, BLOCK_SIZE=M, P2=P2)
    return output, output_maxs