danieldk's picture
danieldk HF staff
Import EETQ kernels
1dc29e9
raw
history blame
864 Bytes
from typing import List
import torch
from ._ops import ops
def w8_a16_gemm(
    input: torch.Tensor,
    weight: torch.Tensor,
    scale: torch.Tensor,
) -> torch.Tensor:
    """Call the compiled ``w8_a16_gemm`` op and return its result.

    This is a thin Python entry point: the three tensors are passed
    positionally, unchanged, to ``ops.w8_a16_gemm``.

    NOTE(review): the name suggests a GEMM with 8-bit weights and 16-bit
    activations, with ``scale`` holding the weight quantization scales.
    Shapes, dtypes and device requirements are defined by the compiled
    kernel and are not visible here -- confirm against the kernel source.

    Args:
        input: Tensor forwarded as the op's first argument.
        weight: Tensor forwarded as the op's second argument.
        scale: Tensor forwarded as the op's third argument.

    Returns:
        The tensor produced by ``ops.w8_a16_gemm``.
    """
    result = ops.w8_a16_gemm(input, weight, scale)
    return result
def w8_a16_gemm_(
    input: torch.Tensor,
    weight: torch.Tensor,
    scale: torch.Tensor,
    output: torch.Tensor,
    m: int,
    n: int,
    k: int,
) -> torch.Tensor:
    """Call the compiled ``w8_a16_gemm_`` op and return its result.

    All arguments are passed positionally, in the same order, to
    ``ops.w8_a16_gemm_``.

    NOTE(review): the trailing underscore and the ``output`` argument
    suggest the kernel writes into a caller-provided buffer, and ``m``,
    ``n``, ``k`` look like the GEMM problem dimensions. None of that is
    established by this wrapper -- confirm against the kernel source.

    Args:
        input: Tensor forwarded as the op's first argument.
        weight: Tensor forwarded as the op's second argument.
        scale: Tensor forwarded as the op's third argument.
        output: Tensor forwarded as the op's fourth argument.
        m: Integer forwarded as the op's fifth argument.
        n: Integer forwarded as the op's sixth argument.
        k: Integer forwarded as the op's seventh argument.

    Returns:
        The tensor produced by ``ops.w8_a16_gemm_``.
    """
    result = ops.w8_a16_gemm_(input, weight, scale, output, m, n, k)
    return result
def preprocess_weights(
    origin_weight: torch.Tensor,
    is_int4: bool,
) -> torch.Tensor:
    """Call the compiled ``preprocess_weights`` op and return its result.

    Both arguments are passed positionally, unchanged, to
    ``ops.preprocess_weights``.

    NOTE(review): ``is_int4`` presumably selects a 4-bit rather than an
    8-bit weight layout; what the preprocessing actually does is defined
    by the compiled kernel -- confirm against the kernel source.

    Args:
        origin_weight: Tensor forwarded as the op's first argument.
        is_int4: Flag forwarded as the op's second argument.

    Returns:
        The tensor produced by ``ops.preprocess_weights``.
    """
    processed = ops.preprocess_weights(origin_weight, is_int4)
    return processed
def quant_weights(
    origin_weight: torch.Tensor,
    quant_type: torch.dtype,
    return_unprocessed_quantized_tensor: bool,
) -> List[torch.Tensor]:
    """Call the compiled ``quant_weights`` op and return its result.

    The three arguments are passed positionally, unchanged, to
    ``ops.quant_weights``.

    NOTE(review): the number and meaning of the returned tensors
    (presumably the quantized weights and their scales, plus the
    unprocessed quantized tensor when requested) are defined by the
    compiled kernel -- confirm against the kernel source.

    Args:
        origin_weight: Tensor forwarded as the op's first argument.
        quant_type: Dtype forwarded as the op's second argument.
        return_unprocessed_quantized_tensor: Flag forwarded as the op's
            third argument.

    Returns:
        The list of tensors produced by ``ops.quant_weights``.
    """
    quantized = ops.quant_weights(
        origin_weight,
        quant_type,
        return_unprocessed_quantized_tensor,
    )
    return quantized