import tensorrt as trt
from itertools import tee

from polygraphy.backend.trt import (
    network_from_onnx_path,
    engine_from_network,
    save_engine,
    Profile,
    CreateConfig,
)
from tensorrt import PreviewFeature, MemoryPoolType

batch_size = 1
max_length = 2048
opt_length = max_length // 2

# Dynamic-shape optimization profile: the engine accepts sequence lengths from
# 1 to max_length, with kernels tuned for opt_length.
profiles = [Profile().add(
    "input_ids",
    min=(batch_size, 1),
    opt=(batch_size, opt_length),  # Optimized based on the inputs.
    max=(batch_size, max_length),
).add(
    "position_ids",
    min=(batch_size, 2, 1),
    opt=(batch_size, 2, opt_length),  # Optimized based on the inputs.
    max=(batch_size, 2, max_length),
).add(
    "attention_mask",
    min=(batch_size, 1, 1, 1),
    opt=(batch_size, 1, opt_length, opt_length),  # Optimized based on the inputs.
    max=(batch_size, 1, max_length, max_length),
)]


def get_network_definition(network_definition):
    """Force the POW -> REDUCE pattern (the variance computation in LayerNorm)
    to run in FP32, so FP16 accumulation does not overflow or lose accuracy."""

    def pairwise(iterable):
        # s -> (s0, s1), (s1, s2), (s2, s3), ...
        a, b = tee(iterable)
        next(b, None)
        return zip(a, b)

    # network_from_onnx_path returns (builder, network, parser);
    # index 1 is the INetworkDefinition.
    network = network_definition[1]
    indices = list(range(0, network.num_layers))
    for i, i_next in pairwise(indices):
        l = network.get_layer(i)
        l_next = network.get_layer(i_next)

        # Skip layers whose outputs are shape tensors rather than execution tensors.
        if not all([l.get_output(k).is_execution_tensor for k in range(l.num_outputs)]):
            continue

        if l.get_output_type(0) != trt.float32:
            continue

        if l.type == trt.LayerType.ELEMENTWISE and l_next.type == trt.LayerType.REDUCE:
            # Downcast from the base ILayer so the elementwise `op` attribute is accessible.
            l.__class__ = getattr(trt, "IElementWiseLayer")
            if l.op == trt.ElementWiseOperation.POW:
                l.precision = trt.float32
                l.set_output_type(0, trt.float32)

                l_next.precision = trt.float32
                l_next.set_output_type(0, trt.float32)

    return network_definition


input_fpath = "./model6b_onnx_pkv/model.onnx"

preview_features = [PreviewFeature.FASTER_DYNAMIC_SHAPES_0805]

trt_inference_config = CreateConfig(
    fp16=True,
    memory_pool_limits={MemoryPoolType.WORKSPACE: 2048 * 1024 * 1024},  # 2 GiB
    profiles=profiles,
    precision_constraints="obey",  # Enforce the FP32 constraints set above.
    preview_features=preview_features,
)

onnx_network = network_from_onnx_path(input_fpath)

network_definition = get_network_definition(onnx_network)
print(network_definition)
print(trt_inference_config)

trt_engine = engine_from_network(network_definition, trt_inference_config)
print(trt_engine)

output_fpath = "./model6b_trt_pkv/out.engine"
save_engine(trt_engine, output_fpath)
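
# ---------------------------------------------------------------------------
# Optional smoke test of the freshly built engine. This is a minimal sketch,
# not part of the build itself: it assumes Polygraphy's TrtRunner API, and the
# input dtypes (int32 ids/positions, boolean mask) are assumptions that must
# match what the exported ONNX model actually declares. The zero-filled arrays
# are placeholders, not real tokenizer output.
import numpy as np
from polygraphy.backend.trt import TrtRunner

seq_len = 8  # Any length within the [min, max] range of the profile above.
feed_dict = {
    "input_ids": np.zeros((batch_size, seq_len), dtype=np.int32),
    "position_ids": np.zeros((batch_size, 2, seq_len), dtype=np.int32),
    "attention_mask": np.zeros((batch_size, 1, seq_len, seq_len), dtype=bool),
}
with TrtRunner(trt_engine) as runner:
    outputs = runner.infer(feed_dict=feed_dict)  # Returns {output_name: np.ndarray}.
    print({name: arr.shape for name, arr in outputs.items()})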