duzx16
commited on
Commit
•
8b97bf2
1
Parent(s):
3cbefb1
Add cpu kernel
Browse files- quantization.py +185 -9
quantization.py
CHANGED
@@ -1,20 +1,22 @@
|
|
1 |
-
from torch.nn import Linear
|
2 |
-
from torch.nn.parameter import Parameter
|
3 |
-
|
4 |
import bz2
|
5 |
import torch
|
6 |
import base64
|
7 |
import ctypes
|
|
|
|
|
|
|
|
|
|
|
8 |
from transformers.utils import logging
|
9 |
|
10 |
from typing import List
|
11 |
-
from functools import partial
|
12 |
|
13 |
logger = logging.get_logger(__name__)
|
14 |
|
15 |
try:
|
16 |
from cpm_kernels.kernels.base import LazyKernelCModule, KernelFunction, round_up
|
17 |
|
|
|
18 |
class Kernel:
|
19 |
def __init__(self, code: bytes, function_names: List[str]):
|
20 |
self.code = code
|
@@ -24,6 +26,7 @@ try:
|
|
24 |
for name in self._function_names:
|
25 |
setattr(self, name, KernelFunction(self._cmodule, name))
|
26 |
|
|
|
27 |
quantization_code = "$QlpoOTFBWSZTWU9yuJUAQHN//////////f/n/8/n///n//bt4dTidcVx8X3V9FV/92/v4B7/AD5FBQFAAAChSgKpFCFAFVSigUAAAEKhSgUUqgFBKigqVREQAABQBQIANDTTIGI00BkZBkNGE0A0BkBkGQGRkaNAaAGQNBoGgDIAAYIGTI0DQAQAaGmmQMRpoDIyDIaMJoBoDIDIMgMjI0aA0AMgaDQNAGQAAwQMmRoGgAgA0NNMgYjTQGRkGQ0YTQDQGQGQZAZGRo0BoAZA0GgaAMgABggZMjQNABABoaaZAxGmgMjIMhowmgGgMgMgyAyMjRoDQAyBoNA0AZAADBAyZGgaAAmqU1NEgJqnptU/Sn4jRR6J6epk2pqb1Q/SgAPUGgyNNGjQ2SBpoAZAAGg0NB6mgDIAAAAA2oaApSREBNAARhGiYEaEwU8pvImlP0k2aam1GaGqbFNM1MHpTwmkepmyU9R6nqPKekHqNNPUxNGhp6n6p6QaZ6o9TG1GMqcoV9ly6nRanHlq6zPNbnGZNi6HSug+2nPiZ13XcnFYZW+45W11CumhzYhchOJ2GLLV1OBjBjGf4TptOddTSOcVxhqYZMYwZXZZY00zI1paX5X9J+b+f4e+x43RXSxXPOdquiGpduatGyXneN696M9t4HU2eR5XX/kPhP261NTx3JO1Ow7LyuDmeo9a7d351T1ZxnvnrvYnrXv/hXxPCeuYx2XsNmO003eg9J3Z6U7b23meJ4ri01OdzTk9BNO96brz+qT5nuvvH3ds/G+m/JcG/F2XYuhXlvO+jP7U3XgrzPN/lr8Sf1n6j4j7jZs+s/T0tNaNNYzTs12rxjwztHlnire3Nzc3N1wuBwOBwXBvZfoHpD7rFmR99V5vj3aXza3xdBbXMalubTg/jIv5dfAi54Pdc75j4z412n3Npj3Ld/ENm7a3b/Cod6h/ret1/5vn/C+l+gdslMvgPSLJ8d8q+U66fevYn/tW1chleEtNTGlcHCbLRlq0tHzF5tsbbZZfHjjLgZu42XCuC3NrdjTasZGNzgxPIrGqp7r3p7L2p5XjnpPSmTd5XtzqnB6U87zzg1Ol0zd0zsLszxR6lkxp35u6/teL0L0W922cR7Lu1lpL9CsHirzuM2T+BgsyViT6LHcm0/Vr6U/7LGGyJeqTEjt0PHWhF5mCT7R9mtlDwriYv0Tyr/OxYt6qp5r0mPVT0608TqnqMZaarU2nFwrTzzlrs1ed7z1ux60wyr4ydCaTi3enW8x68x0zU7tXSlcmPSW1mGpWJMg4zmPC2lK96tp0OE80y4MfEvnZj8zGluR6b22ki1Ou9V2nCd9xovcPvcYMZYy0lvN60ScZ45vN6yeCeeXFb1lVjnnCar5fwXwE2bzJ4HI1XVPXfXZMm44GUsMpYsmLB65TuVdm0cl0b+i/wGNN66XjeV7zuPpHcnK/juhhjdfId5jMdE5nN0dGmmm2zZs2cexD5n9p/dY352XsvXHaZNWWsmmS1atjR452nYudzvqv2HMRyvNNnlMcDl3R2+yx2uVrBubTW9icHDVtbNXlZm7jma1rM4VurZZd2y6nUau7ZXZ7bVU+mnoOVxZGMrVmvX60605JwmzGZhhhjTWtaaaMaaGTGmNMZasY0iX8VMUl8eepaIrzGSpemWOQyZORk2bNpjUybMmxqYmknCGCFynutfksaZpjTNMaaatM0xsxcGR0sociNqxNSmhhR1ZJPbsn8qyF0t2qH6iYBclclalbtTTcHTDsPaX6rlnElph2Jyumumtynv2Kk8GI7rsvXbIcJgHJOSaSXnnGaI3m87RtVXJOZ/YtgdTE6Wpha6ZlE8ayXkef1fh602r2WwvfMXtMdLlkfnLFdYYwYso+bWqm7yJqHXZGw2nrS5ZanSYnWlxBxMF1V940K2wdrI7R6OYf7DGGamMmTSbRhlS45xmVOumF1EyPCmHrrN8wwZOOrdNtLeMtzFzDlWnfTBxMk2NaXIZHBYxYLD4w8yju0ao65Vz1OIXoS9dLanwCe1PWrYuWMqf1if1z2k2yYfKJ741PDgno1ZQ8DRqvUny3mNoWTzGO6m1DkrJI8JiR5cSd+vZdGOO8nrMoc5+NDUFsMSXaZJeNlMmGLtJsovOsUp7I9S5VojKxF6bTVEelXqlfJobQr3LozSh2Jk7VcrVMfhXqszGWMzNqGhqZY0OadxkyyMssKugZR0KNFXBHlqwmJgTE/BNVMk6ItJXZMR0H47GpXv/DMOvNkmVuaV1PRfEdxuqc7Hcd+ZV/zTLaRxWk0nl9CdCeM6mn5rstHIBcpiuwmUZXeq81DacHI2rmrZ5SuE5mOZd6LQrZg9mx32TprA8BMo5jKN6yLTCi3WzQaZSuhzTtM1fUTGVpG8Tw+KXI0tjEpiWxtLYynOlktSbVlaI5kxP8TDH8kx50xoxi5KcA4pcja8KWLRlO/Ks6q06ergnvm1ca3Tq8Uw7LTUsmWyctXPWmpitl/uvGcWTGXGuAXDfhqazGmjkxcJW5hMMMMpYsXl2TZYtVOddG3XCarUt6Ptq9CZXSNzyuRzqRZOjsxdBbFVz6OA5HI43r1jityVlVpVkxmOsyaYWE1NTGq1sOVh36mHMcxtSvcy70edG0ZGR3I1Go1GRlV7mWWo1G0ZGRqlvH40l7o4m5xMWLLLYyNjnqc8556mdPqLJ31n/1nWOncxzG1tizrHs/Z+d2vP/B/l8wdJ6rHUn2nbbDq4p6htFtYzMMMTaZis1K5GKzGNmxhmUx2DDlZ/qNnIx41xnaMfCZWYaZWtNLTNW8ND4Fw1MyZOCdM428suKG1ehW8TesOydg7J+YYcD4cYR+8dFK6M4E3HM9ZfRNNL+Sn6rsl4DsrDl2HpPCnfxjGXtbZtYys1ttlyJ4T+BvexjGWRjMszK4Jpc77D3GyuVD7q0+G8m9G+2+rGm7cOR2y7FdtY2XUYx/oNlfRYxhMYyYZkyyg55enna9Kt/FFi6GMMwYwdwxWgxGMLKYmUyGExTKMZkMFhkymKuh0NOBNnBu+23LdwDoZYYzGGMxtORaTU1pjTGWTTGGtMrNWUsyyTTLLG1qy2ZjbK2DBllWqxMtBMaYZQmcE7zvvRcTkclUwdkxTaSdyySt/7fpL+T1v516Ji97fwr5JbLu305zMn5+GMTTZ9F+y7ExwmGVfG44yxn3dLv6l5i+Wth1jCrDq21nW9LqvvDzz3Vf3LLH/O/32TJ/erx3bXftO4eF+G956D952K/An4NfvOpjFjExjevP/UmE0fIoZXx6/w6lX/no3D0bLt+ixjieBM6ksRd0yB4Lt2SwYNE+gd1detlZWUnpiZfGfFaK+4PyCa/v18V8X75pe9fLXzp7l3VjF76vWZmHwGz1IZNWT7b8yddJ4q5kyrVdfru6atWc7bVYztL9Jf4GXvT+Y8m9/YsXP6H018a8D4XVOqvfzqeR+6yZOD8dPv0+U7/q5Pl+2dNb0MjzGVH5p6MNQ7cOWvw62U9aHE8DprDek+McLyvDz+te+9Zhq5+YTruufMcWMabqysTmZVWjKPfnK0wyVcrsuhjZRdLkHNvD72b9abriOSGIxiLixMOoalNPXzy+wT/tf+U6HHONfsz+xe8ufHBdQWWGWLA9if0rsnmrxK5LvRZQeWsTCsrmOYy8VteVfuRfcVTtDLItLIsMYxZLdU/DbtSemxF6Z6Zo5WBXE4tFdCyVMMXMTEMZXVlS6Xec2T4e0tHsRcEuWshcJ2YsNF5rUx1E8ifCq6Z+ZP7qdCeu/aTwFd53l16/o0NOw6O3dLavP4Hbi4RdmuDk6DoYaninC0+o4uZjbJ7Rxeu0/FbuFg+q7DVS6fQe0rZ6NDGUNNU6DEqOaLTicKnYZMnBWruljQxoaS3dZhocDge0bSTyOvdAbG5hxe2xji7E/L55xX13wWNDi6HCekcFxfCPGxY0MXC+s7afWaMdDyjyr+o8Rudm/NabOZvdl274zH4f5XK9z6On1Pe/K5TdPAslg77BjuO6Y3eO7GqvOPG/stknp1leyvLL0Z7bl9I4noMvLkzytLhWYzrOZzLXCORe028rORzOg4N/L0HlMOQ3Pgmnbb6KczlabORpu980q37TBqRu0/p3PO6234Bl03Ynuz+9W7gnsEcmvYaYY3aMYY0wx3pYd+ujsXauWdaY5Xkbtl23fPzFHiDB/QMo0yFjBllYxTQYYyxkrwn7JufwJ/PfgJ+C83X69ni6zvXcnyXabv0ncbLwsceS+RNlyN2mnneJtX0ngYO0+e+0+UnA+Wch3ji8hj5an4h+i6XBySU4n+R0roVcbw5yvHrmr4Yw8Y7x6c+9POPYHI5HI5HI5HI5HGXGww4nE4nrVyOR8XeqPEO7PLOiukYa3Novk5hV4cdtYZLI93e+uxff2jRo0aNGjRo0aNG1bVtW1dy3m83m8+tQ5ZzHw3nObwOu8La9Rc1dtkdS8A3eTk823tnktXWlxN6Oixe06zrN70Isd9jiOgZFq9yfkPqP/SLhN2Myl8jDM43bl1nbcb4cO57jlh8Jow6pzXZdL4dyODTuuhu77FyO27DdwdRxmvO+O+3N2+BdqyTwLHVczDVY4UPE4O66/ZO2cx1LFzVdSXtF7G4HMbrauOHRw6c8FdZ5m9fHZHYZXfTlZquyynSyTTKke6vcffSD9pzPA/G7n7jxPmuhc1DHMynPMrGL6AdewYmwu5ko+UUyTwrMv27rPH1v1nGqd87+p6N6LU8k3NEng53xXyHS97+44OSg/sy/hn+Se6yfYNjW0/uTgP+PvWYzLMmjhcLB/gGpri6H83/84eUXWT6T9Hsv7785z/7z4icpW+zfXypuR7rx/gMdZb1/wC678pcs8/2a3mDitGHxl9mfPlll5MafWWqxk/eYuTDgcNMzDGWLWvsuglNxs53GtN6uWpktlW1tZZYcuinMMWmnNnJydze3b2Y1McBxrBkXw799izLMZZYyy0TkbsGM4p03S2uVu5s/XXUdSdec6smVxZYYGpVmT8A+8ajuEyV5FatkvVru2x6uxGXXbH4A+jvgP4GMYy3iPLXzq/6z65+E005ey+cwMZD3fZcqc6xpjTFjQ0P3U+e++cPYmTIwj0nrK5NPTfl3WvpfLtXDcb2HQMudYOxFXQBor4L4T6vrOauFctYXJQ++NUWmJe5bmx1jDiZS1dTqWxo4GR8jm3fttpmPHppk9PEyv4/y8/sO07XacOmcqc0x2Vi9BvNJvN5oW8x4mOsydpidRxMYJPx06m1bqPzq9KtK8sxXNXFodD/+MYYaJTLwOhc9brCsV18oOR1i4tXChyTkq4lf4y1Ke+9axjDHqs1mfBbMXuP4Hzi+X7t8vzv7bHerrUPgPCxhjre4fXdfLNtNM+Jd+Zdh8xd8wP87uNPoPgv4W7/5P2BuxfsMabNnMnza+54Pdi5U671GPZY8CehX8Voeoo7FHpkeEc6715FwHZrIrUrHaviPUbPZHND+IhczrP6FcYvhOZ0Di/ETt0OI+YwNWR9r7tpf6WDeZKZDB1+z2IthOl1mPyb5FluvEx9h9d0NnM0Y1XPFkWIsk1WotJ0PBMmkvjvQTd0e71tfeV+8r8lQ/tpzpsmxJ+InrI/dj2UajUajVTUajatRqNRtGo1Go1Go4wjeMpZFMVV9CHbofPraLsJ3JpWV2XOoanCuFky4y3PPNxucK2uKC1Lbdb1eo+m5XomN6HfeZsabHLHRX/K+offtNGGmHWctcVcG44MdSqsOLY9VzX+Zxfxn2HPdWTpzWvkrtJ8M5zorrKcquRytJ5N5DZmcaW02l76nWO+BqPXm1A2Ry/0q71dH/mqrqeFjkYxjEXtsX8qubTk67rGycyqsdm4tZx5D6D5hhi0waaWmiaMP81Yjii5qxPlPuU/GfTL1Y5E6Jyfiq63qTa39A4J0sOGDgO9WF9bOXl0XfPRbsY2bPNKPy1YrFYrFYmRhhlTIyMjJWJYZHXuCXI8OoXsvfljGLFicNifpp2XunoPiG1wtx3p1Tah+/DD66OnVtVXP9rKbVxOnL0tR/rHtqB5UDErUVcl11D4qqvjpOcxX7armUNJB3LpW6bxVvD08e8h3odKKvyCFZBdSh2FVcST9xV3n3T8t1j7Kr9qgrqXg+13Pt5U7JCvFXVIV1YG5lRhkVYZJYYDDD4KOIMoHCp26WS8GB7uBh2zIdgq/PKyInjV2STShuoapUdCpX1yTwqq/z1VvET7Kh5nVPkO8YyxjLt2MaaMmWTLQvx3qnzltnXW0p2jxgbEtSny/Osv8Y9pLMXYoHVPAhkVdWVeODhR6q9/Sxe2liwwZWMVvFXfRkeIDxAePUPIrdJ4ey6yquzH+PD/bUOWAu05qVHtFd8rrKHSoeNIOUqrYr3FXyToqfYJgwmJdKpXXOwYYegNNGMzfZPp/t3t/DVs4zjNTN61rRqaWaa4NYbRjTa0tWwy2Y2tGN8ZO8ofNKq4j9SL7I+cSm4/6ovLV5HNXLI0jJidwrtk6ynCaP6Z++GjRlWS3tLeW129Mi9evxU9mtz6s5J3Z7M2ngTgnKvmpomxpaLCzPfmx0JWE+m3NLDDGOX47RctdYYNK5jakdqLkRlI39n590T5zctGSwwZZDJj6kW8XSi6ot2MmWWJ0DUT3nuvebBudScjZ79g8cWJ8av0k+/bE5WKd5MdbFpbDVMxu1DVMmtNZGJvq1mtRbn6M+g/kP0FwDwr7quZs7xosNGpbscyxhhd9TyJyFwbLcxlTasg75vW7TsV5K7ji44XPMMrdoj+Y3rT0Hie62nlYV/pwczzOmdLqLhYkzGMzCZWGMQzGMSsZYY6Di1t4nlJ+Em63mJxrVLxPbYxNEdgc1dU2iOKyoYYWjNrEeHTYybVk0atSa7ehuwsWMWTqn1TrnS6hYsi71d1+s+k+ic70e20fzE/VaTdxT9ZtU4GIXdeNx3X77guYYfpHeTQjaMX6brOu4OY4K7Y2d9mbHarI5ox3p4GpJ2Vd/Tst60f7j999pppjR+Q/Qf8J/VaORs3cji7FfFuN61+ui9s8hix1OCh5KGVV23BPXvZfz3CLyHpix+exi8z/KnCnosY2eunor+cxyPO/xJ0vKey9OvE9VjqaYu0x3Z3jd6o2b1T12D+F8l232lwaaacD5LE8LBxu7WTlbWraWpew8Xexjel3E+wWD4APITdNqR8F3R3T0lunCQ4GaE9R37DxeCYfcHi4xci5ovKfxVs55y2hf+65E/Xdp6jR5nrebTmi5incpkyOjs50JvrZwstbbW6kfuuQw+2mykf/EXNFzxfKTrxew929TR6bWnGL//F3JFOFCQT3K4lQ"
|
28 |
|
29 |
kernels = Kernel(
|
@@ -64,6 +67,175 @@ class W8A16Linear(torch.autograd.Function):
|
|
64 |
return grad_input.view(ctx.inp_shape), grad_weight.view(ctx.weight_shape), None, None
|
65 |
|
66 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
67 |
def compress_int4_weight(weight: torch.Tensor): # (n, m)
|
68 |
with torch.cuda.device(weight.device):
|
69 |
n, m = weight.size(0), weight.size(1)
|
@@ -133,6 +305,7 @@ class QuantizedLinear(torch.nn.Module):
|
|
133 |
self.weight = torch.empty(shape[0], shape[1] * weight_bit_width // 8, dtype=torch.int8, device=device)
|
134 |
self.weight_scale = torch.empty(shape[0], dtype=dtype, device=device)
|
135 |
else:
|
|
|
136 |
self.weight_scale = weight.abs().max(dim=-1).values / ((2 ** (weight_bit_width - 1)) - 1)
|
137 |
self.weight = torch.round(weight / self.weight_scale[:, None]).to(torch.int8)
|
138 |
if weight_bit_width == 4:
|
@@ -143,7 +316,10 @@ class QuantizedLinear(torch.nn.Module):
|
|
143 |
self.bias = Parameter(bias.to(device), requires_grad=False) if bias is not None else None
|
144 |
|
145 |
def forward(self, input):
|
146 |
-
|
|
|
|
|
|
|
147 |
if self.bias is not None:
|
148 |
output = output + self.bias
|
149 |
return output
|
@@ -154,7 +330,7 @@ def quantize(model, weight_bit_width, empty_init=False, device=None):
|
|
154 |
for layer in model.layers:
|
155 |
layer.self_attention.query_key_value = QuantizedLinear(
|
156 |
weight_bit_width=weight_bit_width,
|
157 |
-
weight=layer.self_attention.query_key_value.weight
|
158 |
bias=layer.self_attention.query_key_value.bias,
|
159 |
dtype=layer.self_attention.query_key_value.weight.dtype,
|
160 |
device=layer.self_attention.query_key_value.weight.device if device is None else device,
|
@@ -162,7 +338,7 @@ def quantize(model, weight_bit_width, empty_init=False, device=None):
|
|
162 |
)
|
163 |
layer.self_attention.dense = QuantizedLinear(
|
164 |
weight_bit_width=weight_bit_width,
|
165 |
-
weight=layer.self_attention.dense.weight
|
166 |
bias=layer.self_attention.dense.bias,
|
167 |
dtype=layer.self_attention.dense.weight.dtype,
|
168 |
device=layer.self_attention.dense.weight.device if device is None else device,
|
@@ -170,7 +346,7 @@ def quantize(model, weight_bit_width, empty_init=False, device=None):
|
|
170 |
)
|
171 |
layer.mlp.dense_h_to_4h = QuantizedLinear(
|
172 |
weight_bit_width=weight_bit_width,
|
173 |
-
weight=layer.mlp.dense_h_to_4h.weight
|
174 |
bias=layer.mlp.dense_h_to_4h.bias,
|
175 |
dtype=layer.mlp.dense_h_to_4h.weight.dtype,
|
176 |
device=layer.mlp.dense_h_to_4h.weight.device if device is None else device,
|
@@ -178,7 +354,7 @@ def quantize(model, weight_bit_width, empty_init=False, device=None):
|
|
178 |
)
|
179 |
layer.mlp.dense_4h_to_h = QuantizedLinear(
|
180 |
weight_bit_width=weight_bit_width,
|
181 |
-
weight=layer.mlp.dense_4h_to_h.weight
|
182 |
bias=layer.mlp.dense_4h_to_h.bias,
|
183 |
dtype=layer.mlp.dense_4h_to_h.weight.dtype,
|
184 |
device=layer.mlp.dense_4h_to_h.weight.device if device is None else device,
|
|
|
|
|
|
|
|
|
1 |
import bz2
|
2 |
import torch
|
3 |
import base64
|
4 |
import ctypes
|
5 |
+
import os
|
6 |
+
import sys
|
7 |
+
import traceback
|
8 |
+
|
9 |
+
from torch.nn.parameter import Parameter
|
10 |
from transformers.utils import logging
|
11 |
|
12 |
from typing import List
|
|
|
13 |
|
14 |
logger = logging.get_logger(__name__)
|
15 |
|
16 |
try:
|
17 |
from cpm_kernels.kernels.base import LazyKernelCModule, KernelFunction, round_up
|
18 |
|
19 |
+
|
20 |
class Kernel:
|
21 |
def __init__(self, code: bytes, function_names: List[str]):
|
22 |
self.code = code
|
|
|
26 |
for name in self._function_names:
|
27 |
setattr(self, name, KernelFunction(self._cmodule, name))
|
28 |
|
29 |
+
|
30 |
quantization_code = "$QlpoOTFBWSZTWU9yuJUAQHN//////////f/n/8/n///n//bt4dTidcVx8X3V9FV/92/v4B7/AD5FBQFAAAChSgKpFCFAFVSigUAAAEKhSgUUqgFBKigqVREQAABQBQIANDTTIGI00BkZBkNGE0A0BkBkGQGRkaNAaAGQNBoGgDIAAYIGTI0DQAQAaGmmQMRpoDIyDIaMJoBoDIDIMgMjI0aA0AMgaDQNAGQAAwQMmRoGgAgA0NNMgYjTQGRkGQ0YTQDQGQGQZAZGRo0BoAZA0GgaAMgABggZMjQNABABoaaZAxGmgMjIMhowmgGgMgMgyAyMjRoDQAyBoNA0AZAADBAyZGgaAAmqU1NEgJqnptU/Sn4jRR6J6epk2pqb1Q/SgAPUGgyNNGjQ2SBpoAZAAGg0NB6mgDIAAAAA2oaApSREBNAARhGiYEaEwU8pvImlP0k2aam1GaGqbFNM1MHpTwmkepmyU9R6nqPKekHqNNPUxNGhp6n6p6QaZ6o9TG1GMqcoV9ly6nRanHlq6zPNbnGZNi6HSug+2nPiZ13XcnFYZW+45W11CumhzYhchOJ2GLLV1OBjBjGf4TptOddTSOcVxhqYZMYwZXZZY00zI1paX5X9J+b+f4e+x43RXSxXPOdquiGpduatGyXneN696M9t4HU2eR5XX/kPhP261NTx3JO1Ow7LyuDmeo9a7d351T1ZxnvnrvYnrXv/hXxPCeuYx2XsNmO003eg9J3Z6U7b23meJ4ri01OdzTk9BNO96brz+qT5nuvvH3ds/G+m/JcG/F2XYuhXlvO+jP7U3XgrzPN/lr8Sf1n6j4j7jZs+s/T0tNaNNYzTs12rxjwztHlnire3Nzc3N1wuBwOBwXBvZfoHpD7rFmR99V5vj3aXza3xdBbXMalubTg/jIv5dfAi54Pdc75j4z412n3Npj3Ld/ENm7a3b/Cod6h/ret1/5vn/C+l+gdslMvgPSLJ8d8q+U66fevYn/tW1chleEtNTGlcHCbLRlq0tHzF5tsbbZZfHjjLgZu42XCuC3NrdjTasZGNzgxPIrGqp7r3p7L2p5XjnpPSmTd5XtzqnB6U87zzg1Ol0zd0zsLszxR6lkxp35u6/teL0L0W922cR7Lu1lpL9CsHirzuM2T+BgsyViT6LHcm0/Vr6U/7LGGyJeqTEjt0PHWhF5mCT7R9mtlDwriYv0Tyr/OxYt6qp5r0mPVT0608TqnqMZaarU2nFwrTzzlrs1ed7z1ux60wyr4ydCaTi3enW8x68x0zU7tXSlcmPSW1mGpWJMg4zmPC2lK96tp0OE80y4MfEvnZj8zGluR6b22ki1Ou9V2nCd9xovcPvcYMZYy0lvN60ScZ45vN6yeCeeXFb1lVjnnCar5fwXwE2bzJ4HI1XVPXfXZMm44GUsMpYsmLB65TuVdm0cl0b+i/wGNN66XjeV7zuPpHcnK/juhhjdfId5jMdE5nN0dGmmm2zZs2cexD5n9p/dY352XsvXHaZNWWsmmS1atjR452nYudzvqv2HMRyvNNnlMcDl3R2+yx2uVrBubTW9icHDVtbNXlZm7jma1rM4VurZZd2y6nUau7ZXZ7bVU+mnoOVxZGMrVmvX60605JwmzGZhhhjTWtaaaMaaGTGmNMZasY0iX8VMUl8eepaIrzGSpemWOQyZORk2bNpjUybMmxqYmknCGCFynutfksaZpjTNMaaatM0xsxcGR0sociNqxNSmhhR1ZJPbsn8qyF0t2qH6iYBclclalbtTTcHTDsPaX6rlnElph2Jyumumtynv2Kk8GI7rsvXbIcJgHJOSaSXnnGaI3m87RtVXJOZ/YtgdTE6Wpha6ZlE8ayXkef1fh602r2WwvfMXtMdLlkfnLFdYYwYso+bWqm7yJqHXZGw2nrS5ZanSYnWlxBxMF1V940K2wdrI7R6OYf7DGGamMmTSbRhlS45xmVOumF1EyPCmHrrN8wwZOOrdNtLeMtzFzDlWnfTBxMk2NaXIZHBYxYLD4w8yju0ao65Vz1OIXoS9dLanwCe1PWrYuWMqf1if1z2k2yYfKJ741PDgno1ZQ8DRqvUny3mNoWTzGO6m1DkrJI8JiR5cSd+vZdGOO8nrMoc5+NDUFsMSXaZJeNlMmGLtJsovOsUp7I9S5VojKxF6bTVEelXqlfJobQr3LozSh2Jk7VcrVMfhXqszGWMzNqGhqZY0OadxkyyMssKugZR0KNFXBHlqwmJgTE/BNVMk6ItJXZMR0H47GpXv/DMOvNkmVuaV1PRfEdxuqc7Hcd+ZV/zTLaRxWk0nl9CdCeM6mn5rstHIBcpiuwmUZXeq81DacHI2rmrZ5SuE5mOZd6LQrZg9mx32TprA8BMo5jKN6yLTCi3WzQaZSuhzTtM1fUTGVpG8Tw+KXI0tjEpiWxtLYynOlktSbVlaI5kxP8TDH8kx50xoxi5KcA4pcja8KWLRlO/Ks6q06ergnvm1ca3Tq8Uw7LTUsmWyctXPWmpitl/uvGcWTGXGuAXDfhqazGmjkxcJW5hMMMMpYsXl2TZYtVOddG3XCarUt6Ptq9CZXSNzyuRzqRZOjsxdBbFVz6OA5HI43r1jityVlVpVkxmOsyaYWE1NTGq1sOVh36mHMcxtSvcy70edG0ZGR3I1Go1GRlV7mWWo1G0ZGRqlvH40l7o4m5xMWLLLYyNjnqc8556mdPqLJ31n/1nWOncxzG1tizrHs/Z+d2vP/B/l8wdJ6rHUn2nbbDq4p6htFtYzMMMTaZis1K5GKzGNmxhmUx2DDlZ/qNnIx41xnaMfCZWYaZWtNLTNW8ND4Fw1MyZOCdM428suKG1ehW8TesOydg7J+YYcD4cYR+8dFK6M4E3HM9ZfRNNL+Sn6rsl4DsrDl2HpPCnfxjGXtbZtYys1ttlyJ4T+BvexjGWRjMszK4Jpc77D3GyuVD7q0+G8m9G+2+rGm7cOR2y7FdtY2XUYx/oNlfRYxhMYyYZkyyg55enna9Kt/FFi6GMMwYwdwxWgxGMLKYmUyGExTKMZkMFhkymKuh0NOBNnBu+23LdwDoZYYzGGMxtORaTU1pjTGWTTGGtMrNWUsyyTTLLG1qy2ZjbK2DBllWqxMtBMaYZQmcE7zvvRcTkclUwdkxTaSdyySt/7fpL+T1v516Ji97fwr5JbLu305zMn5+GMTTZ9F+y7ExwmGVfG44yxn3dLv6l5i+Wth1jCrDq21nW9LqvvDzz3Vf3LLH/O/32TJ/erx3bXftO4eF+G956D952K/An4NfvOpjFjExjevP/UmE0fIoZXx6/w6lX/no3D0bLt+ixjieBM6ksRd0yB4Lt2SwYNE+gd1detlZWUnpiZfGfFaK+4PyCa/v18V8X75pe9fLXzp7l3VjF76vWZmHwGz1IZNWT7b8yddJ4q5kyrVdfru6atWc7bVYztL9Jf4GXvT+Y8m9/YsXP6H018a8D4XVOqvfzqeR+6yZOD8dPv0+U7/q5Pl+2dNb0MjzGVH5p6MNQ7cOWvw62U9aHE8DprDek+McLyvDz+te+9Zhq5+YTruufMcWMabqysTmZVWjKPfnK0wyVcrsuhjZRdLkHNvD72b9abriOSGIxiLixMOoalNPXzy+wT/tf+U6HHONfsz+xe8ufHBdQWWGWLA9if0rsnmrxK5LvRZQeWsTCsrmOYy8VteVfuRfcVTtDLItLIsMYxZLdU/DbtSemxF6Z6Zo5WBXE4tFdCyVMMXMTEMZXVlS6Xec2T4e0tHsRcEuWshcJ2YsNF5rUx1E8ifCq6Z+ZP7qdCeu/aTwFd53l16/o0NOw6O3dLavP4Hbi4RdmuDk6DoYaninC0+o4uZjbJ7Rxeu0/FbuFg+q7DVS6fQe0rZ6NDGUNNU6DEqOaLTicKnYZMnBWruljQxoaS3dZhocDge0bSTyOvdAbG5hxe2xji7E/L55xX13wWNDi6HCekcFxfCPGxY0MXC+s7afWaMdDyjyr+o8Rudm/NabOZvdl274zH4f5XK9z6On1Pe/K5TdPAslg77BjuO6Y3eO7GqvOPG/stknp1leyvLL0Z7bl9I4noMvLkzytLhWYzrOZzLXCORe028rORzOg4N/L0HlMOQ3Pgmnbb6KczlabORpu980q37TBqRu0/p3PO6234Bl03Ynuz+9W7gnsEcmvYaYY3aMYY0wx3pYd+ujsXauWdaY5Xkbtl23fPzFHiDB/QMo0yFjBllYxTQYYyxkrwn7JufwJ/PfgJ+C83X69ni6zvXcnyXabv0ncbLwsceS+RNlyN2mnneJtX0ngYO0+e+0+UnA+Wch3ji8hj5an4h+i6XBySU4n+R0roVcbw5yvHrmr4Yw8Y7x6c+9POPYHI5HI5HI5HI5HGXGww4nE4nrVyOR8XeqPEO7PLOiukYa3Novk5hV4cdtYZLI93e+uxff2jRo0aNGjRo0aNG1bVtW1dy3m83m8+tQ5ZzHw3nObwOu8La9Rc1dtkdS8A3eTk823tnktXWlxN6Oixe06zrN70Isd9jiOgZFq9yfkPqP/SLhN2Myl8jDM43bl1nbcb4cO57jlh8Jow6pzXZdL4dyODTuuhu77FyO27DdwdRxmvO+O+3N2+BdqyTwLHVczDVY4UPE4O66/ZO2cx1LFzVdSXtF7G4HMbrauOHRw6c8FdZ5m9fHZHYZXfTlZquyynSyTTKke6vcffSD9pzPA/G7n7jxPmuhc1DHMynPMrGL6AdewYmwu5ko+UUyTwrMv27rPH1v1nGqd87+p6N6LU8k3NEng53xXyHS97+44OSg/sy/hn+Se6yfYNjW0/uTgP+PvWYzLMmjhcLB/gGpri6H83/84eUXWT6T9Hsv7785z/7z4icpW+zfXypuR7rx/gMdZb1/wC678pcs8/2a3mDitGHxl9mfPlll5MafWWqxk/eYuTDgcNMzDGWLWvsuglNxs53GtN6uWpktlW1tZZYcuinMMWmnNnJydze3b2Y1McBxrBkXw799izLMZZYyy0TkbsGM4p03S2uVu5s/XXUdSdec6smVxZYYGpVmT8A+8ajuEyV5FatkvVru2x6uxGXXbH4A+jvgP4GMYy3iPLXzq/6z65+E005ey+cwMZD3fZcqc6xpjTFjQ0P3U+e++cPYmTIwj0nrK5NPTfl3WvpfLtXDcb2HQMudYOxFXQBor4L4T6vrOauFctYXJQ++NUWmJe5bmx1jDiZS1dTqWxo4GR8jm3fttpmPHppk9PEyv4/y8/sO07XacOmcqc0x2Vi9BvNJvN5oW8x4mOsydpidRxMYJPx06m1bqPzq9KtK8sxXNXFodD/+MYYaJTLwOhc9brCsV18oOR1i4tXChyTkq4lf4y1Ke+9axjDHqs1mfBbMXuP4Hzi+X7t8vzv7bHerrUPgPCxhjre4fXdfLNtNM+Jd+Zdh8xd8wP87uNPoPgv4W7/5P2BuxfsMabNnMnza+54Pdi5U671GPZY8CehX8Voeoo7FHpkeEc6715FwHZrIrUrHaviPUbPZHND+IhczrP6FcYvhOZ0Di/ETt0OI+YwNWR9r7tpf6WDeZKZDB1+z2IthOl1mPyb5FluvEx9h9d0NnM0Y1XPFkWIsk1WotJ0PBMmkvjvQTd0e71tfeV+8r8lQ/tpzpsmxJ+InrI/dj2UajUajVTUajatRqNRtGo1Go1Go4wjeMpZFMVV9CHbofPraLsJ3JpWV2XOoanCuFky4y3PPNxucK2uKC1Lbdb1eo+m5XomN6HfeZsabHLHRX/K+offtNGGmHWctcVcG44MdSqsOLY9VzX+Zxfxn2HPdWTpzWvkrtJ8M5zorrKcquRytJ5N5DZmcaW02l76nWO+BqPXm1A2Ry/0q71dH/mqrqeFjkYxjEXtsX8qubTk67rGycyqsdm4tZx5D6D5hhi0waaWmiaMP81Yjii5qxPlPuU/GfTL1Y5E6Jyfiq63qTa39A4J0sOGDgO9WF9bOXl0XfPRbsY2bPNKPy1YrFYrFYmRhhlTIyMjJWJYZHXuCXI8OoXsvfljGLFicNifpp2XunoPiG1wtx3p1Tah+/DD66OnVtVXP9rKbVxOnL0tR/rHtqB5UDErUVcl11D4qqvjpOcxX7armUNJB3LpW6bxVvD08e8h3odKKvyCFZBdSh2FVcST9xV3n3T8t1j7Kr9qgrqXg+13Pt5U7JCvFXVIV1YG5lRhkVYZJYYDDD4KOIMoHCp26WS8GB7uBh2zIdgq/PKyInjV2STShuoapUdCpX1yTwqq/z1VvET7Kh5nVPkO8YyxjLt2MaaMmWTLQvx3qnzltnXW0p2jxgbEtSny/Osv8Y9pLMXYoHVPAhkVdWVeODhR6q9/Sxe2liwwZWMVvFXfRkeIDxAePUPIrdJ4ey6yquzH+PD/bUOWAu05qVHtFd8rrKHSoeNIOUqrYr3FXyToqfYJgwmJdKpXXOwYYegNNGMzfZPp/t3t/DVs4zjNTN61rRqaWaa4NYbRjTa0tWwy2Y2tGN8ZO8ofNKq4j9SL7I+cSm4/6ovLV5HNXLI0jJidwrtk6ynCaP6Z++GjRlWS3tLeW129Mi9evxU9mtz6s5J3Z7M2ngTgnKvmpomxpaLCzPfmx0JWE+m3NLDDGOX47RctdYYNK5jakdqLkRlI39n590T5zctGSwwZZDJj6kW8XSi6ot2MmWWJ0DUT3nuvebBudScjZ79g8cWJ8av0k+/bE5WKd5MdbFpbDVMxu1DVMmtNZGJvq1mtRbn6M+g/kP0FwDwr7quZs7xosNGpbscyxhhd9TyJyFwbLcxlTasg75vW7TsV5K7ji44XPMMrdoj+Y3rT0Hie62nlYV/pwczzOmdLqLhYkzGMzCZWGMQzGMSsZYY6Di1t4nlJ+Em63mJxrVLxPbYxNEdgc1dU2iOKyoYYWjNrEeHTYybVk0atSa7ehuwsWMWTqn1TrnS6hYsi71d1+s+k+ic70e20fzE/VaTdxT9ZtU4GIXdeNx3X77guYYfpHeTQjaMX6brOu4OY4K7Y2d9mbHarI5ox3p4GpJ2Vd/Tst60f7j999pppjR+Q/Qf8J/VaORs3cji7FfFuN61+ui9s8hix1OCh5KGVV23BPXvZfz3CLyHpix+exi8z/KnCnosY2eunor+cxyPO/xJ0vKey9OvE9VjqaYu0x3Z3jd6o2b1T12D+F8l232lwaaacD5LE8LBxu7WTlbWraWpew8Xexjel3E+wWD4APITdNqR8F3R3T0lunCQ4GaE9R37DxeCYfcHi4xci5ovKfxVs55y2hf+65E/Xdp6jR5nrebTmi5incpkyOjs50JvrZwstbbW6kfuuQw+2mykf/EXNFzxfKTrxew929TR6bWnGL//F3JFOFCQT3K4lQ"
|
31 |
|
32 |
kernels = Kernel(
|
|
|
67 |
return grad_input.view(ctx.inp_shape), grad_weight.view(ctx.weight_shape), None, None
|
68 |
|
69 |
|
70 |
+
default_cpu_kernel_code_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "quantization_kernels.c")
|
71 |
+
default_cpu_kernel_code = "QlpoOTFBWSZTWXLbSoQAAgzbgERwQXxmTwAAr/ff3kABt0Q2oRVT0hpo9RtEAAAAyBEiSQ9EGjQGQAAAwANGhowjJoNGmgMEUplMTNSMJ5TQaDJpsoMyRMj8P4mZzFSVVwqSXG8GG7MlVwiToYEQwVD7noBxMhNfkeZYtYFtbgOBUSIGtIQjhNHCEnPJsadhb3yBmRIOD3TeAtNLSaU5GgvKUBWSNuuOIHmVt0YhW6rsmDMDUjeUJGJ64R1Jm5lrh0Aa0tKjhFwPdWcGogxLDSXPWQUWTM8Sd3Qz1HMYNxx3HMeiNqNo4jeRDEfZ3gUSHIcU/heomq0vEzL1Msz5KKGxH8FrNOYw3KaxdqaEmNHYMxJFgQbR0DyRknL2L4kwUSxKRdhjRpEtUqilVfggFL1klaMS3PPRDfNqbBOPWO7m4JTVGhS9QTBDDJaEbLbrUQNB+IpJSKQbG5SZZ5gkwJEhJ3aYKJipZ/i7kinChIOW2lQg"
|
72 |
+
default_cpu_parallel_kernel_code_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),
|
73 |
+
"quantization_kernels_parallel.c")
|
74 |
+
default_cpu_parallel_kernel_code = "QlpoOTFBWSZTWUzax5EAALXbgERwSX1mTwAAr/ff3kACNyXUbZYwBpoaNGIyAaADQwRSaVP9QoMg0A2oAPU0AEUkU9GaaKMaQB6gA09T1ARRKnpk0niaJkaaNDJ6g0DTIKVKfZ/g6v1Kem5LJLa0WmkukkuCIHUqWbtJGJMsCSQFiPEIYHgBIZDzR8R6REbYxIqD2Cu7lMkFoPu6LmHeOAy0GF83Tc40jgmTs4HnCe60QfJa2bDBZ0Y1lhgbiZjW8SNsAKCk42UOEdjWN3KoiCIYeQUCCKWIyHewhtSoInLKSG22l4jKM2ZDCVKtBm3OTYBl3jsVqMImtj7PQw7xKxLXQzwgJaPPgW1fRhrvPJICl4YFDYfNbkbBh5JDgrazFml50xEQQwQUjxNwE0IDSofLzSg7UNVKn+Rr1KErzBHUxBqdHRlXzqYsIa5K9Y0UuE2ugw3g5KYofm7AaGNTzJSMhcchhxdaU4JZ0F1UNgQ8XcGDguypqYza8yFaEoGgNRcLej+g2t0feGKFE5OY2PFluQ3q4HgycxlfvzHqo0KcM0JI8OKXtzayJFgsqC1NdUQVu8rChnA6FO3MFyGOoC9KO8ITPpYM5pRqTlczFkLES/4u5IpwoSCZtY8i"
|
75 |
+
|
76 |
+
|
77 |
+
class CPUKernel:
|
78 |
+
def __init__(self, kernel_file="", source_code=default_cpu_kernel_code_path, compile_parallel_kernel=None,
|
79 |
+
parallel_num=None):
|
80 |
+
self.load = False
|
81 |
+
self.int8WeightExtractionFloat = None
|
82 |
+
self.int4WeightExtractionFloat = None
|
83 |
+
self.int4WeightCompression = None
|
84 |
+
self.SetNumThreads = lambda x: x
|
85 |
+
|
86 |
+
try:
|
87 |
+
if not os.path.exists(default_cpu_kernel_code_path):
|
88 |
+
with open(default_cpu_kernel_code_path, "w", encoding="utf-8") as file:
|
89 |
+
code = default_cpu_kernel_code
|
90 |
+
cpu_quantization_code = bz2.decompress(base64.b64decode(code)).decode()
|
91 |
+
file.write(cpu_quantization_code)
|
92 |
+
|
93 |
+
if not os.path.exists(default_cpu_parallel_kernel_code_path):
|
94 |
+
with open(default_cpu_parallel_kernel_code_path, "w", encoding="utf-8") as file:
|
95 |
+
code = default_cpu_parallel_kernel_code
|
96 |
+
cpu_quantization_code = bz2.decompress(base64.b64decode(code)).decode()
|
97 |
+
file.write(cpu_quantization_code)
|
98 |
+
|
99 |
+
except Exception:
|
100 |
+
logger.warning("Error when generating default cpu kernel code.")
|
101 |
+
|
102 |
+
if compile_parallel_kernel is None:
|
103 |
+
compile_parallel_kernel = bool(int(os.cpu_count()) >= 4)
|
104 |
+
|
105 |
+
if compile_parallel_kernel and source_code == default_cpu_kernel_code_path:
|
106 |
+
source_code = default_cpu_parallel_kernel_code_path
|
107 |
+
|
108 |
+
kernels = None
|
109 |
+
|
110 |
+
if (not kernel_file) or (not os.path.exists(kernel_file)):
|
111 |
+
try:
|
112 |
+
if os.path.exists(source_code):
|
113 |
+
kernel_file = source_code[:-2] + ".so"
|
114 |
+
|
115 |
+
if compile_parallel_kernel:
|
116 |
+
if sys.platform != 'darwin':
|
117 |
+
compile_command = "gcc -O3 -fPIC -pthread -fopenmp -std=c99 {} -shared -o {}".format(
|
118 |
+
source_code, kernel_file)
|
119 |
+
else:
|
120 |
+
compile_command = "clang -O3 -fPIC -pthread -Xclang -fopenmp -lomp -std=c99 {} -shared -o {}".format(
|
121 |
+
source_code, kernel_file)
|
122 |
+
exit_state = os.system(compile_command)
|
123 |
+
if not exit_state:
|
124 |
+
try:
|
125 |
+
kernels = ctypes.cdll.LoadLibrary(kernel_file)
|
126 |
+
except:
|
127 |
+
logger.warning(
|
128 |
+
f"Load parallel cpu kernel failed {kernel_file}: {traceback.format_exc()}")
|
129 |
+
else:
|
130 |
+
logger.warning(f"Compile parallel cpu kernel {compile_command} failed.")
|
131 |
+
|
132 |
+
if kernels is None: # adjust config, use default cpu kernel
|
133 |
+
compile_parallel_kernel = False
|
134 |
+
source_code = default_cpu_kernel_code_path
|
135 |
+
kernel_file = source_code[:-2] + ".so"
|
136 |
+
|
137 |
+
if kernels is None:
|
138 |
+
compile_command = "gcc -O3 -fPIC -std=c99 {} -shared -o {}".format(source_code, kernel_file)
|
139 |
+
exit_state = os.system(compile_command)
|
140 |
+
if not exit_state:
|
141 |
+
try:
|
142 |
+
kernels = ctypes.cdll.LoadLibrary(kernel_file)
|
143 |
+
except:
|
144 |
+
logger.warning(f"Load cpu kernel {kernel_file} failed: {traceback.format_exc()}")
|
145 |
+
else:
|
146 |
+
logger.warning(f"Compile cpu kernel {compile_command} failed.")
|
147 |
+
else:
|
148 |
+
logger.warning("Kernel source code not found.")
|
149 |
+
return
|
150 |
+
except:
|
151 |
+
logger.warning(f"Failed to build cpu kernel: {traceback.format_exc()}")
|
152 |
+
return
|
153 |
+
else:
|
154 |
+
try:
|
155 |
+
kernels = ctypes.cdll.LoadLibrary(kernel_file)
|
156 |
+
except:
|
157 |
+
logger.warning(f"Load custom cpu kernel {kernel_file} failed: {traceback.format_exc()}")
|
158 |
+
|
159 |
+
if kernels is not None:
|
160 |
+
self.int8WeightExtractionFloat = kernels.extract_int8_weight_to_float
|
161 |
+
self.int4WeightExtractionFloat = kernels.extract_int4_weight_to_float
|
162 |
+
self.int4WeightCompression = kernels.compress_int4_weight
|
163 |
+
if compile_parallel_kernel:
|
164 |
+
try:
|
165 |
+
self.SetNumThreads = kernels.set_num_threads
|
166 |
+
except:
|
167 |
+
logger.warning("No set_num_threads() found in kernel.")
|
168 |
+
self.load = True
|
169 |
+
|
170 |
+
if compile_parallel_kernel:
|
171 |
+
if parallel_num is None:
|
172 |
+
parallel_num = max(os.cpu_count(), 1)
|
173 |
+
self.SetNumThreads(parallel_num)
|
174 |
+
|
175 |
+
self.parallel_num = parallel_num
|
176 |
+
|
177 |
+
|
178 |
+
cpu_kernels = CPUKernel()
|
179 |
+
|
180 |
+
|
181 |
+
def extract_weight_to_float(weight: torch.Tensor, scale_list: torch.Tensor, source_bit_width: int,
|
182 |
+
quantization_cache=None):
|
183 |
+
"""extract weight on cpu to float32"""
|
184 |
+
if source_bit_width == 8:
|
185 |
+
func = cpu_kernels.int8WeightExtractionFloat
|
186 |
+
elif source_bit_width == 4:
|
187 |
+
func = cpu_kernels.int4WeightExtractionFloat
|
188 |
+
else:
|
189 |
+
assert False, "Unsupported bit-width"
|
190 |
+
|
191 |
+
n, m = weight.size(0), weight.size(1)
|
192 |
+
|
193 |
+
if quantization_cache is not None:
|
194 |
+
out = quantization_cache
|
195 |
+
func(
|
196 |
+
ctypes.c_void_p(weight.data_ptr()),
|
197 |
+
ctypes.c_void_p(scale_list.data_ptr()),
|
198 |
+
ctypes.c_void_p(out.data_ptr()),
|
199 |
+
ctypes.c_int32(n),
|
200 |
+
ctypes.c_int32(m)
|
201 |
+
)
|
202 |
+
return out.tensor
|
203 |
+
else:
|
204 |
+
out = torch.empty(n, m * (8 // source_bit_width), dtype=torch.float, device="cpu")
|
205 |
+
func(
|
206 |
+
ctypes.c_void_p(weight.data_ptr()),
|
207 |
+
ctypes.c_void_p(scale_list.data_ptr()),
|
208 |
+
ctypes.c_void_p(out.data_ptr()),
|
209 |
+
ctypes.c_int32(n),
|
210 |
+
ctypes.c_int32(m)
|
211 |
+
)
|
212 |
+
return out
|
213 |
+
|
214 |
+
|
215 |
+
class W8A16LinearCPU(torch.autograd.Function):
|
216 |
+
@staticmethod
|
217 |
+
def forward(ctx, inp: torch.Tensor, quant_w: torch.Tensor, scale_w: torch.Tensor, weight_bit_width,
|
218 |
+
quantization_cache=None):
|
219 |
+
ctx.inp_shape = inp.size()
|
220 |
+
ctx.weight_bit_width = weight_bit_width
|
221 |
+
out_features = quant_w.size(0)
|
222 |
+
inp = inp.contiguous().view(-1, inp.size(-1))
|
223 |
+
weight = extract_weight_to_float(quant_w, scale_w, weight_bit_width, quantization_cache=quantization_cache)
|
224 |
+
ctx.weight_shape = weight.size()
|
225 |
+
output = inp.mm(weight.t())
|
226 |
+
ctx.save_for_backward(inp, quant_w, scale_w)
|
227 |
+
return output.view(*(ctx.inp_shape[:-1] + (out_features,)))
|
228 |
+
|
229 |
+
@staticmethod
|
230 |
+
def backward(ctx, grad_output: torch.Tensor):
|
231 |
+
inp, quant_w, scale_w = ctx.saved_tensors
|
232 |
+
weight = extract_weight_to_float(quant_w, scale_w, ctx.weight_bit_width)
|
233 |
+
grad_output = grad_output.contiguous().view(-1, weight.size(0))
|
234 |
+
grad_input = grad_output.mm(weight)
|
235 |
+
grad_weight = grad_output.t().mm(inp)
|
236 |
+
return grad_input.view(ctx.inp_shape), grad_weight.view(ctx.weight_shape), None, None
|
237 |
+
|
238 |
+
|
239 |
def compress_int4_weight(weight: torch.Tensor): # (n, m)
|
240 |
with torch.cuda.device(weight.device):
|
241 |
n, m = weight.size(0), weight.size(1)
|
|
|
305 |
self.weight = torch.empty(shape[0], shape[1] * weight_bit_width // 8, dtype=torch.int8, device=device)
|
306 |
self.weight_scale = torch.empty(shape[0], dtype=dtype, device=device)
|
307 |
else:
|
308 |
+
weight = weight.to(torch.cuda.current_device())
|
309 |
self.weight_scale = weight.abs().max(dim=-1).values / ((2 ** (weight_bit_width - 1)) - 1)
|
310 |
self.weight = torch.round(weight / self.weight_scale[:, None]).to(torch.int8)
|
311 |
if weight_bit_width == 4:
|
|
|
316 |
self.bias = Parameter(bias.to(device), requires_grad=False) if bias is not None else None
|
317 |
|
318 |
def forward(self, input):
|
319 |
+
if self.weight.device == torch.device("cpu"):
|
320 |
+
output = W8A16LinearCPU.apply(input, self.weight, self.weight_scale, self.weight_bit_width)
|
321 |
+
else:
|
322 |
+
output = W8A16Linear.apply(input, self.weight, self.weight_scale, self.weight_bit_width)
|
323 |
if self.bias is not None:
|
324 |
output = output + self.bias
|
325 |
return output
|
|
|
330 |
for layer in model.layers:
|
331 |
layer.self_attention.query_key_value = QuantizedLinear(
|
332 |
weight_bit_width=weight_bit_width,
|
333 |
+
weight=layer.self_attention.query_key_value.weight,
|
334 |
bias=layer.self_attention.query_key_value.bias,
|
335 |
dtype=layer.self_attention.query_key_value.weight.dtype,
|
336 |
device=layer.self_attention.query_key_value.weight.device if device is None else device,
|
|
|
338 |
)
|
339 |
layer.self_attention.dense = QuantizedLinear(
|
340 |
weight_bit_width=weight_bit_width,
|
341 |
+
weight=layer.self_attention.dense.weight,
|
342 |
bias=layer.self_attention.dense.bias,
|
343 |
dtype=layer.self_attention.dense.weight.dtype,
|
344 |
device=layer.self_attention.dense.weight.device if device is None else device,
|
|
|
346 |
)
|
347 |
layer.mlp.dense_h_to_4h = QuantizedLinear(
|
348 |
weight_bit_width=weight_bit_width,
|
349 |
+
weight=layer.mlp.dense_h_to_4h.weight,
|
350 |
bias=layer.mlp.dense_h_to_4h.bias,
|
351 |
dtype=layer.mlp.dense_h_to_4h.weight.dtype,
|
352 |
device=layer.mlp.dense_h_to_4h.weight.device if device is None else device,
|
|
|
354 |
)
|
355 |
layer.mlp.dense_4h_to_h = QuantizedLinear(
|
356 |
weight_bit_width=weight_bit_width,
|
357 |
+
weight=layer.mlp.dense_4h_to_h.weight,
|
358 |
bias=layer.mlp.dense_4h_to_h.bias,
|
359 |
dtype=layer.mlp.dense_4h_to_h.weight.dtype,
|
360 |
device=layer.mlp.dense_4h_to_h.weight.device if device is None else device,
|