|
from typing import List |
|
import torch |
|
|
|
from ._ops import ops |
|
|
|
|
|
def w8_a16_gemm(
    input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor
) -> torch.Tensor:
    """Run the fused weight-quantized GEMM kernel and return a new tensor.

    Thin wrapper around the compiled ``ops.w8_a16_gemm`` kernel.
    NOTE(review): the name suggests 8-bit weights with 16-bit activations,
    dequantized via ``scale`` — confirm against the kernel implementation.
    Shapes and dtypes are validated (if at all) by the kernel, not here.
    """
    result = ops.w8_a16_gemm(input, weight, scale)
    return result
|
|
|
|
|
def w8_a16_gemm_(
    input: torch.Tensor,
    weight: torch.Tensor,
    scale: torch.Tensor,
    output: torch.Tensor,
    m: int,
    n: int,
    k: int,
) -> torch.Tensor:
    """Out-parameter variant of :func:`w8_a16_gemm` (trailing-underscore name).

    Delegates to ``ops.w8_a16_gemm_`` with a caller-provided ``output`` buffer
    and explicit GEMM dimensions ``m``, ``n``, ``k``.
    NOTE(review): presumably writes into ``output`` and returns it, mirroring
    PyTorch's in-place convention — verify against the kernel binding.
    """
    # Pass everything straight through; the compiled op owns all validation.
    return ops.w8_a16_gemm_(
        input, weight, scale, output, m, n, k
    )
|
|
|
|
|
def preprocess_weights(origin_weight: torch.Tensor, is_int4: bool) -> torch.Tensor:
    """Return ``origin_weight`` rearranged by the ``ops.preprocess_weights`` kernel.

    ``is_int4`` selects the 4-bit layout path; otherwise the kernel's default
    (presumably 8-bit) path is used — TODO confirm against the extension.
    """
    processed = ops.preprocess_weights(origin_weight, is_int4)
    return processed
|
|
|
|
|
def quant_weights(
    origin_weight: torch.Tensor,
    quant_type: torch.dtype,
    return_unprocessed_quantized_tensor: bool,
) -> List[torch.Tensor]:
    """Quantize ``origin_weight`` to ``quant_type`` via ``ops.quant_weights``.

    Returns a list of tensors produced by the kernel; when
    ``return_unprocessed_quantized_tensor`` is true the raw (pre-layout)
    quantized tensor is presumably included as well — verify with the binding.
    """
    # Single pass-through call; kept as a named tuple of args for readability.
    args = (origin_weight, quant_type, return_unprocessed_quantized_tensor)
    return ops.quant_weights(*args)
|
|