File size: 864 Bytes
a80ce84
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
from typing import List
import torch

from ._ops import ops


def w8_a16_gemm(
    input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor
) -> torch.Tensor:
    """Run the out-of-place w8_a16 GEMM native kernel.

    Thin pass-through: forwards all arguments unchanged to
    ``ops.w8_a16_gemm`` and returns whatever tensor the kernel produces.
    NOTE(review): the name suggests 8-bit weights with 16-bit activations
    and a dequantization ``scale`` — confirm against the kernel source.
    """
    result = ops.w8_a16_gemm(input, weight, scale)
    return result


def w8_a16_gemm_(
    input: torch.Tensor,
    weight: torch.Tensor,
    scale: torch.Tensor,
    output: torch.Tensor,
    m: int,
    n: int,
    k: int,
) -> torch.Tensor:
    """Run the w8_a16 GEMM native kernel with an explicit output tensor.

    Forwards every argument unchanged to ``ops.w8_a16_gemm_`` and returns
    its result. NOTE(review): the trailing underscore follows the PyTorch
    in-place convention, and ``m``/``n``/``k`` look like the usual GEMM
    problem dimensions — verify both against the kernel implementation.
    """
    result = ops.w8_a16_gemm_(input, weight, scale, output, m, n, k)
    return result


def preprocess_weights(origin_weight: torch.Tensor, is_int4: bool) -> torch.Tensor:
    """Preprocess a weight tensor via the native ``ops.preprocess_weights`` kernel.

    Pure pass-through of ``origin_weight`` and the ``is_int4`` flag; the
    transformation itself (presumably a kernel-specific weight layout /
    packing step — confirm in the kernel source) happens entirely on the
    native side.
    """
    processed = ops.preprocess_weights(origin_weight, is_int4)
    return processed


def quant_weights(
    origin_weight: torch.Tensor,
    quant_type: torch.dtype,
    return_unprocessed_quantized_tensor: bool,
) -> List[torch.Tensor]:
    """Quantize a weight tensor via the native ``ops.quant_weights`` kernel.

    Forwards all three arguments unchanged and returns the kernel's list of
    tensors. NOTE(review): the flag name suggests the kernel can optionally
    include the raw (pre-layout) quantized tensor in the returned list —
    confirm against the kernel binding.
    """
    tensors = ops.quant_weights(
        origin_weight,
        quant_type,
        return_unprocessed_quantized_tensor,
    )
    return tensors