# chatglm_6b_onnx / onnx2engine.py
# (Hugging Face repo artifact — commit bda3cf5 by tonycloud: "Add onnx2engine.py")
import tensorrt as trt
from itertools import tee
from polygraphy.backend.trt import (
network_from_onnx_path,
engine_from_network,
save_engine,
Profile,
)
from polygraphy.backend.trt import CreateConfig
from tensorrt import PreviewFeature, MemoryPoolType
batch_size = 1
max_length = 2048
opt_length = max_length // 2
profiles = [Profile().add(
"input_ids",
min=(batch_size, 1),
opt=(batch_size, opt_length), # Optimized based on the inputs.
max=(batch_size, max_length),
).add(
"position_ids",
min=(batch_size, 2,1),
opt=(batch_size, 2, opt_length), # Optimized based on the inputs.
max=(batch_size, 2,max_length),
).add(
"attention_mask",
min=(batch_size, 1,1,1),
opt=(batch_size, 1,opt_length,opt_length), # Optimized based on the inputs.
max=(batch_size, 1,max_length,max_length),
)]
def get_network_definition(network_definition):
def pairwise(iterable):
a, b = tee(iterable)
next(b, None)
return zip(a, b)
indices = list(range(0, network_definition[1].num_layers))
for i, i_next in pairwise(indices):
l = network_definition[1].get_layer(i)
l_next = network_definition[1].get_layer(i_next)
if not all([l.get_output(i).is_execution_tensor for i in range(l.num_outputs)]):
continue
if l.get_output_type(0) != trt.float32:
continue
if l.type == trt.LayerType.ELEMENTWISE and l_next.type == trt.LayerType.REDUCE:
l.__class__ = getattr(trt, "IElementWiseLayer")
if l.op == trt.ElementWiseOperation.POW:
l.precision = trt.float32
l.set_output_type(0, trt.float32)
l_next.precision = trt.float32
l_next.set_output_type(0, trt.float32)
return network_definition
input_fpath = "./model6b_onnx_pkv/model.onnx"
preview_features = [PreviewFeature.FASTER_DYNAMIC_SHAPES_0805]
trt_inference_config = CreateConfig(
fp16=True,
memory_pool_limits = {MemoryPoolType.WORKSPACE: 2048 * 1024 * 1024},
profiles=profiles,
precision_constraints=("obey"),
preview_features=preview_features
)
onnx_network = network_from_onnx_path(input_fpath)
network_definition = get_network_definition(onnx_network)
print(network_definition)
print(trt_inference_config)
trt_engine = engine_from_network(network_definition, trt_inference_config)
print(trt_engine)
output_fpath = "./model6b_trt_pkv/out.engine"
save_engine(trt_engine, output_fpath)