from typing import List
import torch

from ._ops import ops


def w8_a16_gemm(
    input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor
) -> torch.Tensor:
    """Run the ``w8_a16_gemm`` custom operator and return its result.

    This is a thin pass-through to ``ops.w8_a16_gemm``; all arguments are
    forwarded positionally in the order received.

    NOTE(review): the name suggests 8-bit weights multiplied by 16-bit
    activations, with ``scale`` used for dequantization; dtypes and shapes
    are not checked here. Confirm them against the kernel implementation.

    Args:
        input: Activation tensor handed to the kernel.
        weight: Weight tensor handed to the kernel (presumably produced by
            ``quant_weights`` / ``preprocess_weights``; verify).
        scale: Scale tensor handed to the kernel.

    Returns:
        Whatever tensor the underlying operator produces.
    """
    gemm_kernel = ops.w8_a16_gemm
    return gemm_kernel(input, weight, scale)


def w8_a16_gemm_(
    input: torch.Tensor,
    weight: torch.Tensor,
    scale: torch.Tensor,
    output: torch.Tensor,
    m: int,
    n: int,
    k: int,
) -> torch.Tensor:
    """Run the ``w8_a16_gemm_`` custom operator with explicit dimensions.

    A pass-through to ``ops.w8_a16_gemm_``. The four tensors are forwarded
    positionally, followed by ``m``, ``n`` and ``k``.

    NOTE(review): the trailing underscore and the ``output`` argument suggest
    the kernel writes its result into ``output``. ``m``/``n``/``k`` look like
    GEMM problem sizes. Neither is verified here, so confirm against the kernel.

    Args:
        input: Activation tensor handed to the kernel.
        weight: Weight tensor handed to the kernel.
        scale: Scale tensor handed to the kernel.
        output: Tensor handed to the kernel, presumably as the destination.
        m: First integer dimension argument.
        n: Second integer dimension argument.
        k: Third integer dimension argument.

    Returns:
        Whatever tensor the underlying operator produces.
    """
    tensor_args = (input, weight, scale, output)
    dim_args = (m, n, k)
    return ops.w8_a16_gemm_(*tensor_args, *dim_args)


def preprocess_weights(origin_weight: torch.Tensor, is_int4: bool) -> torch.Tensor:
    """Forward ``origin_weight`` and ``is_int4`` to ``ops.preprocess_weights``.

    NOTE(review): presumably this rearranges already-quantized weights into
    the layout the GEMM kernel expects. ``is_int4`` presumably selects 4-bit
    versus 8-bit handling. Confirm against the kernel implementation.

    Args:
        origin_weight: Weight tensor passed through unchanged.
        is_int4: Boolean flag passed through unchanged.

    Returns:
        The tensor returned by the underlying operator.
    """
    preprocess_kernel = ops.preprocess_weights
    return preprocess_kernel(origin_weight, is_int4)


def quant_weights(
    origin_weight: torch.Tensor,
    quant_type: torch.dtype,
    return_unprocessed_quantized_tensor: bool,
) -> List[torch.Tensor]:
    """Quantize ``origin_weight`` via the ``ops.quant_weights`` custom operator.

    All three arguments are forwarded positionally. The operator's result is
    returned as-is.

    NOTE(review): the number and meaning of the returned tensors (for example
    quantized weights plus scales, and optionally an unprocessed copy when
    ``return_unprocessed_quantized_tensor`` is set) are determined by the
    kernel, not this wrapper. Verify against it.

    Args:
        origin_weight: Weight tensor to quantize.
        quant_type: Target ``torch.dtype`` handed to the operator.
        return_unprocessed_quantized_tensor: Boolean flag handed to the
            operator.

    Returns:
        The list of tensors produced by the underlying operator.
    """
    quantize_kernel = ops.quant_weights
    quantized = quantize_kernel(
        origin_weight, quant_type, return_unprocessed_quantized_tensor
    )
    return quantized