import tensorrt as trt
from tensorrt import PreviewFeature, MemoryPoolType
from itertools import tee

from polygraphy.backend.trt import (
    network_from_onnx_path,
    engine_from_network,
    save_engine,
    CreateConfig,
    Profile,
)

batch_size = 1
max_length = 2048  # longest sequence length the engine must support
opt_length = max_length // 2


# One optimization profile describing the min/opt/max shapes TensorRT should
# support for each dynamic input.
profiles = [Profile().add(
    "input_ids",
    min=(batch_size, 1),
    opt=(batch_size, opt_length),  # Optimized based on the inputs.
    max=(batch_size, max_length),
).add(
    "position_ids",
    min=(batch_size, 2, 1),
    opt=(batch_size, 2, opt_length),  # Optimized based on the inputs.
    max=(batch_size, 2, max_length),
).add(
    "attention_mask",
    min=(batch_size, 1, 1, 1),
    opt=(batch_size, 1, opt_length, opt_length),  # Optimized based on the inputs.
    max=(batch_size, 1, max_length, max_length),
)]
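# Note: TensorRT requires the profile to cover every input that has dynamic
# dimensions. If the ONNX export also exposes past-key-value inputs, add
# entries for them here as well.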





def get_network_definition(network_definition):
    """Pin ELEMENTWISE(POW) -> REDUCE layer pairs (typically the variance
    computation in normalization layers) to FP32 so the FP16 builder does not
    demote them and cause overflow."""
    def pairwise(iterable):
        a, b = tee(iterable)
        next(b, None)
        return zip(a, b)

    # network_definition is the (builder, network, parser) tuple returned by
    # network_from_onnx_path; index 1 is the INetworkDefinition.
    network = network_definition[1]
    indices = list(range(network.num_layers))
    for i, i_next in pairwise(indices):
        l = network.get_layer(i)
        l_next = network.get_layer(i_next)

        # Only touch layers whose outputs are execution tensors (skip shape tensors).
        if not all(l.get_output(j).is_execution_tensor for j in range(l.num_outputs)):
            continue

        if l.get_output_type(0) != trt.float32:
            continue

        if l.type == trt.LayerType.ELEMENTWISE and l_next.type == trt.LayerType.REDUCE:
            # Downcast the generic ILayer handle so the elementwise op is readable.
            l.__class__ = getattr(trt, "IElementWiseLayer")
            if l.op == trt.ElementWiseOperation.POW:
                l.precision = trt.float32
                l.set_output_type(0, trt.float32)

            l_next.precision = trt.float32
            l_next.set_output_type(0, trt.float32)

    return network_definition


input_fpath = "./model6b_onnx_pkv/model.onnx"


preview_features = [PreviewFeature.FASTER_DYNAMIC_SHAPES_0805]
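# FASTER_DYNAMIC_SHAPES_0805 is a TensorRT 8.5 preview feature; on later
# TensorRT releases the behavior is enabled by default and the flag is deprecated.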



trt_inference_config = CreateConfig(
    fp16=True,
    memory_pool_limits={MemoryPoolType.WORKSPACE: 2048 * 1024 * 1024},  # 2 GiB workspace
    profiles=profiles,
    precision_constraints="obey",  # honor the per-layer FP32 constraints set above
    preview_features=preview_features,
)


onnx_network = network_from_onnx_path(input_fpath)


network_definition = get_network_definition(onnx_network)
print(network_definition)
print(trt_inference_config)

trt_engine = engine_from_network(network_definition, trt_inference_config)
print(trt_engine)

output_fpath = "./model6b_trt_pkv/out.engine"
save_engine(trt_engine, output_fpath)
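
# Optional smoke test (a hedged sketch, not part of the original script): run the
# freshly built engine once through Polygraphy's TrtRunner with dummy inputs.
# The dtypes (int64 ids, boolean mask) and the exact input set are assumptions
# about the ONNX export; a past-key-values model will need additional feed tensors.
#
# import numpy as np
# from polygraphy.backend.trt import TrtRunner
#
# seq = 16  # any length inside the [1, max_length] profile range
# with TrtRunner(trt_engine) as runner:
#     outputs = runner.infer({
#         "input_ids": np.zeros((batch_size, seq), dtype=np.int64),
#         "position_ids": np.zeros((batch_size, 2, seq), dtype=np.int64),
#         "attention_mask": np.zeros((batch_size, 1, seq, seq), dtype=np.bool_),
#     })
#     print({name: out.shape for name, out in outputs.items()})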