| from typing import List | |
| import torch | |
| from ._ops import ops | |
def w8_a16_gemm(
    input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor
) -> torch.Tensor:
    """Run the w8_a16 GEMM kernel and return a freshly allocated output.

    Thin wrapper that delegates directly to the native ``ops.w8_a16_gemm``
    binding. Presumably "w8_a16" means 8-bit weights with 16-bit
    activations, with ``scale`` holding the dequantization scales — the
    exact shape/dtype contract is defined by the native kernel (verify
    against the extension's C++/CUDA source).

    Args:
        input: Activation tensor.
        weight: Quantized weight tensor.
        scale: Dequantization scale tensor.

    Returns:
        The GEMM result produced by the native op.
    """
    result = ops.w8_a16_gemm(input, weight, scale)
    return result
def w8_a16_gemm_(
    input: torch.Tensor,
    weight: torch.Tensor,
    scale: torch.Tensor,
    output: torch.Tensor,
    m: int,
    n: int,
    k: int,
) -> torch.Tensor:
    """Run the w8_a16 GEMM kernel, writing into a caller-provided buffer.

    Out-parameter variant of :func:`w8_a16_gemm` (trailing underscore
    follows the PyTorch in-place naming convention). Delegates to the
    native ``ops.w8_a16_gemm_`` binding.

    Args:
        input: Activation tensor.
        weight: Quantized weight tensor.
        scale: Dequantization scale tensor.
        output: Pre-allocated destination tensor the kernel writes into.
        m: GEMM M dimension — presumably rows of the output; confirm
            against the native kernel.
        n: GEMM N dimension.
        k: GEMM K (reduction) dimension.

    Returns:
        Whatever the native op returns — presumably ``output``; verify
        against the extension source.
    """
    gemm_out = ops.w8_a16_gemm_(input, weight, scale, output, m, n, k)
    return gemm_out
def preprocess_weights(origin_weight: torch.Tensor, is_int4: bool) -> torch.Tensor:
    """Rearrange quantized weights into the layout the GEMM kernel expects.

    Delegates to the native ``ops.preprocess_weights`` binding. The flag
    ``is_int4`` presumably selects int4 vs. int8 packing — confirm against
    the native implementation.

    Args:
        origin_weight: Quantized weight tensor in its original layout.
        is_int4: Whether the weights are 4-bit packed (else 8-bit).

    Returns:
        The preprocessed weight tensor produced by the native op.
    """
    packed = ops.preprocess_weights(origin_weight, is_int4)
    return packed
def quant_weights(
    origin_weight: torch.Tensor,
    quant_type: torch.dtype,
    return_unprocessed_quantized_tensor: bool,
) -> List[torch.Tensor]:
    """Quantize full-precision weights via the native quantization op.

    Delegates to ``ops.quant_weights``. Returns a list of tensors —
    presumably the processed quantized weights plus scales, and optionally
    the raw (unprocessed) quantized tensor when
    ``return_unprocessed_quantized_tensor`` is set; the exact contents are
    defined by the native kernel (verify against the extension source).

    Args:
        origin_weight: Full-precision weight tensor to quantize.
        quant_type: Target quantized dtype (e.g. ``torch.int8``).
        return_unprocessed_quantized_tensor: If True, the native op also
            returns the pre-layout-processing quantized tensor.

    Returns:
        List of tensors produced by the native op.
    """
    results = ops.quant_weights(
        origin_weight,
        quant_type,
        return_unprocessed_quantized_tensor,
    )
    return results