Merge pull request #92 from yuekaizhang/triton
Browse files- runtime/triton_trtllm/Dockerfile.server +5 -0
- runtime/triton_trtllm/README.md +45 -0
- runtime/triton_trtllm/client_grpc.py +482 -0
- runtime/triton_trtllm/client_http.py +165 -0
- runtime/triton_trtllm/docker-compose.yml +20 -0
- runtime/triton_trtllm/model_repo/audio_tokenizer/1/model.py +137 -0
- runtime/triton_trtllm/model_repo/audio_tokenizer/config.pbtxt +58 -0
- runtime/triton_trtllm/model_repo/spark_tts/1/model.py +311 -0
- runtime/triton_trtllm/model_repo/spark_tts/config.pbtxt +65 -0
- runtime/triton_trtllm/model_repo/tensorrt_llm/1/.gitkeep +0 -0
- runtime/triton_trtllm/model_repo/tensorrt_llm/config.pbtxt +857 -0
- runtime/triton_trtllm/model_repo/vocoder/1/model.py +106 -0
- runtime/triton_trtllm/model_repo/vocoder/config.pbtxt +53 -0
- runtime/triton_trtllm/run.sh +68 -0
- runtime/triton_trtllm/scripts/convert_checkpoint.py +335 -0
- runtime/triton_trtllm/scripts/fill_template.py +70 -0
runtime/triton_trtllm/Dockerfile.server
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
FROM nvcr.io/nvidia/tritonserver:25.02-trtllm-python-py3
|
2 |
+
RUN apt-get update && apt-get install -y cmake
|
3 |
+
RUN git clone https://github.com/pytorch/audio.git && cd audio && git checkout c670ad8 && PATH=/usr/local/cuda/bin:$PATH python3 setup.py develop
|
4 |
+
RUN pip install einx==0.3.0 omegaconf==2.3.0 soundfile==0.12.1 soxr==0.5.0.post1 gradio tritonclient librosa
|
5 |
+
WORKDIR /workspace
|
runtime/triton_trtllm/README.md
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
## Nvidia Triton Inference Serving Best Practice for Spark TTS
|
2 |
+
|
3 |
+
### Quick Start
|
4 |
+
Directly launch the service using docker compose.
|
5 |
+
```sh
|
6 |
+
docker compose up
|
7 |
+
```
|
8 |
+
|
9 |
+
### Build Image
|
10 |
+
Build the docker image from scratch.
|
11 |
+
```sh
|
12 |
+
docker build . -f Dockerfile.server -t soar97/triton-spark-tts:25.02
|
13 |
+
```
|
14 |
+
|
15 |
+
### Create Docker Container
|
16 |
+
```sh
|
17 |
+
your_mount_dir=/mnt:/mnt
|
18 |
+
docker run -it --name "spark-tts-server" --gpus all --net host -v $your_mount_dir --shm-size=2g soar97/triton-spark-tts:25.02
|
19 |
+
```
|
20 |
+
|
21 |
+
### Export Models to TensorRT-LLM and Launch Server
|
22 |
+
Inside docker container, we would follow the official guide of TensorRT-LLM to build TensorRT-LLM engines. See [here](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/qwen).
|
23 |
+
|
24 |
+
```sh
|
25 |
+
bash run.sh 0 3
|
26 |
+
```
|
27 |
+
### Simple HTTP client
|
28 |
+
```sh
|
29 |
+
python3 client_http.py
|
30 |
+
```
|
31 |
+
|
32 |
+
### Benchmark using Dataset
|
33 |
+
```sh
|
34 |
+
num_task=2
|
35 |
+
python3 client_grpc.py --num-tasks $num_task --huggingface-dataset yuekai/seed_tts --split-name wenetspeech4tts
|
36 |
+
```
|
37 |
+
|
38 |
+
### Benchmark Results
|
39 |
+
Decoding on a single L20 GPU, using 26 different prompt_audio/target_text pairs, total audio duration 169 secs.
|
40 |
+
|
41 |
+
| Model | Note | Concurrency | Avg Latency | RTF |
|
42 |
+
|-------|-----------|-----------------------|---------|--|
|
43 |
+
| Spark-TTS-0.5B | [Code Commit](https://github.com/SparkAudio/Spark-TTS/tree/4d769ff782a868524f29e0be851ca64f8b22ebf1/runtime/triton_trtllm) | 1 | 876.24 ms | 0.1362|
|
44 |
+
| Spark-TTS-0.5B | [Code Commit](https://github.com/SparkAudio/Spark-TTS/tree/4d769ff782a868524f29e0be851ca64f8b22ebf1/runtime/triton_trtllm) | 2 | 920.97 ms | 0.0737|
|
45 |
+
| Spark-TTS-0.5B | [Code Commit](https://github.com/SparkAudio/Spark-TTS/tree/4d769ff782a868524f29e0be851ca64f8b22ebf1/runtime/triton_trtllm) | 4 | 1611.51 ms | 0.0704|
|
runtime/triton_trtllm/client_grpc.py
ADDED
@@ -0,0 +1,482 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
|
3 |
+
# 2023 Nvidia (authors: Yuekai Zhang)
|
4 |
+
# 2023 Recurrent.ai (authors: Songtao Shi)
|
5 |
+
# See LICENSE for clarification regarding multiple authors
|
6 |
+
#
|
7 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
8 |
+
# you may not use this file except in compliance with the License.
|
9 |
+
# You may obtain a copy of the License at
|
10 |
+
#
|
11 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
12 |
+
#
|
13 |
+
# Unless required by applicable law or agreed to in writing, software
|
14 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
15 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
16 |
+
# See the License for the specific language governing permissions and
|
17 |
+
# limitations under the License.
|
18 |
+
"""
|
19 |
+
This script supports to load dataset from huggingface and sends it to the server
|
20 |
+
for decoding, in parallel.
|
21 |
+
|
22 |
+
Usage:
|
23 |
+
# For offline Spark-TTS-0.5B
|
24 |
+
# huggingface dataset
|
25 |
+
num_task=2
|
26 |
+
python3 client_grpc.py \
|
27 |
+
--server-addr localhost \
|
28 |
+
--model-name spark_tts \
|
29 |
+
--num-tasks $num_task \
|
30 |
+
--huggingface-dataset yuekai/seed_tts \
|
31 |
+
--split-name wenetspeech4tts \
|
32 |
+
--log-dir ./log_concurrent_tasks_${num_task}
|
33 |
+
"""
|
34 |
+
|
35 |
+
import argparse
|
36 |
+
import asyncio
|
37 |
+
import json
|
38 |
+
|
39 |
+
import os
|
40 |
+
import time
|
41 |
+
import types
|
42 |
+
from pathlib import Path
|
43 |
+
|
44 |
+
import numpy as np
|
45 |
+
import soundfile as sf
|
46 |
+
import tritonclient
|
47 |
+
import tritonclient.grpc.aio as grpcclient
|
48 |
+
from tritonclient.utils import np_to_triton_dtype
|
49 |
+
|
50 |
+
|
51 |
+
|
52 |
+
def write_triton_stats(stats, summary_file):
|
53 |
+
with open(summary_file, "w") as summary_f:
|
54 |
+
model_stats = stats["model_stats"]
|
55 |
+
# write a note, the log is from triton_client.get_inference_statistics(), to better human readability
|
56 |
+
summary_f.write(
|
57 |
+
"The log is parsing from triton_client.get_inference_statistics(), to better human readability. \n"
|
58 |
+
)
|
59 |
+
summary_f.write("To learn more about the log, please refer to: \n")
|
60 |
+
summary_f.write(
|
61 |
+
"1. https://github.com/triton-inference-server/server/blob/main/docs/user_guide/metrics.md \n"
|
62 |
+
)
|
63 |
+
summary_f.write(
|
64 |
+
"2. https://github.com/triton-inference-server/server/issues/5374 \n\n"
|
65 |
+
)
|
66 |
+
summary_f.write(
|
67 |
+
"To better improve throughput, we always would like let requests wait in the queue for a while, and then execute them with a larger batch size. \n"
|
68 |
+
)
|
69 |
+
summary_f.write(
|
70 |
+
"However, there is a trade-off between the increased queue time and the increased batch size. \n"
|
71 |
+
)
|
72 |
+
summary_f.write(
|
73 |
+
"You may change 'max_queue_delay_microseconds' and 'preferred_batch_size' in the model configuration file to achieve this. \n"
|
74 |
+
)
|
75 |
+
summary_f.write(
|
76 |
+
"See https://github.com/triton-inference-server/server/blob/main/docs/user_guide/model_configuration.md#delayed-batching for more details. \n\n"
|
77 |
+
)
|
78 |
+
for model_state in model_stats:
|
79 |
+
if "last_inference" not in model_state:
|
80 |
+
continue
|
81 |
+
summary_f.write(f"model name is {model_state['name']} \n")
|
82 |
+
model_inference_stats = model_state["inference_stats"]
|
83 |
+
total_queue_time_s = int(model_inference_stats["queue"]["ns"]) / 1e9
|
84 |
+
total_infer_time_s = int(model_inference_stats["compute_infer"]["ns"]) / 1e9
|
85 |
+
total_input_time_s = int(model_inference_stats["compute_input"]["ns"]) / 1e9
|
86 |
+
total_output_time_s = (
|
87 |
+
int(model_inference_stats["compute_output"]["ns"]) / 1e9
|
88 |
+
)
|
89 |
+
summary_f.write(
|
90 |
+
f"queue time {total_queue_time_s:<5.2f} s, compute infer time {total_infer_time_s:<5.2f} s, compute input time {total_input_time_s:<5.2f} s, compute output time {total_output_time_s:<5.2f} s \n" # noqa
|
91 |
+
)
|
92 |
+
model_batch_stats = model_state["batch_stats"]
|
93 |
+
for batch in model_batch_stats:
|
94 |
+
batch_size = int(batch["batch_size"])
|
95 |
+
compute_input = batch["compute_input"]
|
96 |
+
compute_output = batch["compute_output"]
|
97 |
+
compute_infer = batch["compute_infer"]
|
98 |
+
batch_count = int(compute_infer["count"])
|
99 |
+
assert (
|
100 |
+
compute_infer["count"]
|
101 |
+
== compute_output["count"]
|
102 |
+
== compute_input["count"]
|
103 |
+
)
|
104 |
+
compute_infer_time_ms = int(compute_infer["ns"]) / 1e6
|
105 |
+
compute_input_time_ms = int(compute_input["ns"]) / 1e6
|
106 |
+
compute_output_time_ms = int(compute_output["ns"]) / 1e6
|
107 |
+
summary_f.write(
|
108 |
+
f"execuate inference with batch_size {batch_size:<2} total {batch_count:<5} times, total_infer_time {compute_infer_time_ms:<9.2f} ms, avg_infer_time {compute_infer_time_ms:<9.2f}/{batch_count:<5}={compute_infer_time_ms/batch_count:.2f} ms, avg_infer_time_per_sample {compute_infer_time_ms:<9.2f}/{batch_count:<5}/{batch_size}={compute_infer_time_ms/batch_count/batch_size:.2f} ms \n" # noqa
|
109 |
+
)
|
110 |
+
# summary_f.write(
|
111 |
+
# f"input {compute_input_time_ms:<9.2f} ms, avg {compute_input_time_ms/batch_count:.2f} ms, " # noqa
|
112 |
+
# )
|
113 |
+
# summary_f.write(
|
114 |
+
# f"output {compute_output_time_ms:<9.2f} ms, avg {compute_output_time_ms/batch_count:.2f} ms \n" # noqa
|
115 |
+
# )
|
116 |
+
|
117 |
+
|
118 |
+
|
119 |
+
def get_args():
|
120 |
+
parser = argparse.ArgumentParser(
|
121 |
+
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
122 |
+
)
|
123 |
+
|
124 |
+
parser.add_argument(
|
125 |
+
"--server-addr",
|
126 |
+
type=str,
|
127 |
+
default="localhost",
|
128 |
+
help="Address of the server",
|
129 |
+
)
|
130 |
+
|
131 |
+
parser.add_argument(
|
132 |
+
"--server-port",
|
133 |
+
type=int,
|
134 |
+
default=8001,
|
135 |
+
help="Grpc port of the triton server, default is 8001",
|
136 |
+
)
|
137 |
+
|
138 |
+
parser.add_argument(
|
139 |
+
"--reference-audio",
|
140 |
+
type=str,
|
141 |
+
default=None,
|
142 |
+
help="Path to a single audio file. It can't be specified at the same time with --manifest-dir",
|
143 |
+
)
|
144 |
+
|
145 |
+
parser.add_argument(
|
146 |
+
"--reference-text",
|
147 |
+
type=str,
|
148 |
+
default="",
|
149 |
+
help="",
|
150 |
+
)
|
151 |
+
|
152 |
+
parser.add_argument(
|
153 |
+
"--target-text",
|
154 |
+
type=str,
|
155 |
+
default="",
|
156 |
+
help="",
|
157 |
+
)
|
158 |
+
|
159 |
+
parser.add_argument(
|
160 |
+
"--huggingface-dataset",
|
161 |
+
type=str,
|
162 |
+
default="yuekai/seed_tts",
|
163 |
+
help="dataset name in huggingface dataset hub",
|
164 |
+
)
|
165 |
+
|
166 |
+
parser.add_argument(
|
167 |
+
"--split-name",
|
168 |
+
type=str,
|
169 |
+
default="wenetspeech4tts",
|
170 |
+
choices=["wenetspeech4tts", "test_zh", "test_en", "test_hard"],
|
171 |
+
help="dataset split name, default is 'test'",
|
172 |
+
)
|
173 |
+
|
174 |
+
parser.add_argument(
|
175 |
+
"--manifest-path",
|
176 |
+
type=str,
|
177 |
+
default=None,
|
178 |
+
help="Path to the manifest dir which includes wav.scp trans.txt files.",
|
179 |
+
)
|
180 |
+
|
181 |
+
parser.add_argument(
|
182 |
+
"--model-name",
|
183 |
+
type=str,
|
184 |
+
default="f5_tts",
|
185 |
+
choices=[
|
186 |
+
"f5_tts", "spark_tts"
|
187 |
+
],
|
188 |
+
help="triton model_repo module name to request: transducer for k2, attention_rescoring for wenet offline, streaming_wenet for wenet streaming, infer_pipeline for paraformer large offline",
|
189 |
+
)
|
190 |
+
|
191 |
+
parser.add_argument(
|
192 |
+
"--num-tasks",
|
193 |
+
type=int,
|
194 |
+
default=1,
|
195 |
+
help="Number of concurrent tasks for sending",
|
196 |
+
)
|
197 |
+
|
198 |
+
parser.add_argument(
|
199 |
+
"--log-interval",
|
200 |
+
type=int,
|
201 |
+
default=5,
|
202 |
+
help="Controls how frequently we print the log.",
|
203 |
+
)
|
204 |
+
|
205 |
+
parser.add_argument(
|
206 |
+
"--compute-wer",
|
207 |
+
action="store_true",
|
208 |
+
default=False,
|
209 |
+
help="""True to compute WER.
|
210 |
+
""",
|
211 |
+
)
|
212 |
+
|
213 |
+
parser.add_argument(
|
214 |
+
"--log-dir",
|
215 |
+
type=str,
|
216 |
+
required=False,
|
217 |
+
default="./tmp",
|
218 |
+
help="log directory",
|
219 |
+
)
|
220 |
+
|
221 |
+
parser.add_argument(
|
222 |
+
"--batch-size",
|
223 |
+
type=int,
|
224 |
+
default=1,
|
225 |
+
help="Inference batch_size per request for offline mode.",
|
226 |
+
)
|
227 |
+
|
228 |
+
return parser.parse_args()
|
229 |
+
|
230 |
+
|
231 |
+
def load_audio(wav_path, target_sample_rate=16000):
|
232 |
+
assert target_sample_rate == 16000, "hard coding in server"
|
233 |
+
if isinstance(wav_path, dict):
|
234 |
+
waveform = wav_path["array"]
|
235 |
+
sample_rate = wav_path["sampling_rate"]
|
236 |
+
else:
|
237 |
+
waveform, sample_rate = sf.read(wav_path)
|
238 |
+
if sample_rate != target_sample_rate:
|
239 |
+
from scipy.signal import resample
|
240 |
+
num_samples = int(len(waveform) * (target_sample_rate / sample_rate))
|
241 |
+
waveform = resample(waveform, num_samples)
|
242 |
+
return waveform, target_sample_rate
|
243 |
+
|
244 |
+
async def send(
|
245 |
+
manifest_item_list: list,
|
246 |
+
name: str,
|
247 |
+
triton_client: tritonclient.grpc.aio.InferenceServerClient,
|
248 |
+
protocol_client: types.ModuleType,
|
249 |
+
log_interval: int,
|
250 |
+
model_name: str,
|
251 |
+
padding_duration: int = None,
|
252 |
+
audio_save_dir: str = "./",
|
253 |
+
):
|
254 |
+
total_duration = 0.0
|
255 |
+
results = []
|
256 |
+
latency_data = []
|
257 |
+
task_id = int(name[5:])
|
258 |
+
|
259 |
+
print(f"manifest_item_list: {manifest_item_list}")
|
260 |
+
for i, item in enumerate(manifest_item_list):
|
261 |
+
if i % log_interval == 0:
|
262 |
+
print(f"{name}: {i}/{len(manifest_item_list)}")
|
263 |
+
waveform, sample_rate = load_audio(item["audio_filepath"], target_sample_rate=16000)
|
264 |
+
duration = len(waveform) / sample_rate
|
265 |
+
lengths = np.array([[len(waveform)]], dtype=np.int32)
|
266 |
+
|
267 |
+
reference_text, target_text = item["reference_text"], item["target_text"]
|
268 |
+
|
269 |
+
estimated_target_duration = duration / len(reference_text) * len(target_text)
|
270 |
+
|
271 |
+
if padding_duration:
|
272 |
+
# padding to nearset 10 seconds
|
273 |
+
samples = np.zeros(
|
274 |
+
(
|
275 |
+
1,
|
276 |
+
padding_duration
|
277 |
+
* sample_rate
|
278 |
+
* ((int(duration) // padding_duration) + 1),
|
279 |
+
),
|
280 |
+
dtype=np.float32,
|
281 |
+
)
|
282 |
+
|
283 |
+
samples[0, : len(waveform)] = waveform
|
284 |
+
else:
|
285 |
+
samples = waveform
|
286 |
+
|
287 |
+
samples = samples.reshape(1, -1).astype(np.float32)
|
288 |
+
|
289 |
+
inputs = [
|
290 |
+
protocol_client.InferInput(
|
291 |
+
"reference_wav", samples.shape, np_to_triton_dtype(samples.dtype)
|
292 |
+
),
|
293 |
+
protocol_client.InferInput(
|
294 |
+
"reference_wav_len", lengths.shape, np_to_triton_dtype(lengths.dtype)
|
295 |
+
),
|
296 |
+
protocol_client.InferInput("reference_text", [1, 1], "BYTES"),
|
297 |
+
protocol_client.InferInput("target_text", [1, 1], "BYTES")
|
298 |
+
]
|
299 |
+
inputs[0].set_data_from_numpy(samples)
|
300 |
+
inputs[1].set_data_from_numpy(lengths)
|
301 |
+
|
302 |
+
input_data_numpy = np.array([reference_text], dtype=object)
|
303 |
+
input_data_numpy = input_data_numpy.reshape((1, 1))
|
304 |
+
inputs[2].set_data_from_numpy(input_data_numpy)
|
305 |
+
|
306 |
+
input_data_numpy = np.array([target_text], dtype=object)
|
307 |
+
input_data_numpy = input_data_numpy.reshape((1, 1))
|
308 |
+
inputs[3].set_data_from_numpy(input_data_numpy)
|
309 |
+
|
310 |
+
outputs = [protocol_client.InferRequestedOutput("waveform")]
|
311 |
+
|
312 |
+
sequence_id = 100000000 + i + task_id * 10
|
313 |
+
start = time.time()
|
314 |
+
response = await triton_client.infer(
|
315 |
+
model_name, inputs, request_id=str(sequence_id), outputs=outputs
|
316 |
+
)
|
317 |
+
|
318 |
+
audio = response.as_numpy("waveform").reshape(-1)
|
319 |
+
|
320 |
+
end = time.time() - start
|
321 |
+
|
322 |
+
audio_save_path = os.path.join(
|
323 |
+
audio_save_dir, f"{item['target_audio_path']}.wav"
|
324 |
+
)
|
325 |
+
sf.write(audio_save_path, audio, 16000, "PCM_16")
|
326 |
+
|
327 |
+
latency_data.append((end, estimated_target_duration))
|
328 |
+
total_duration += estimated_target_duration
|
329 |
+
|
330 |
+
return total_duration, latency_data
|
331 |
+
|
332 |
+
def load_manifests(manifest_path):
|
333 |
+
with open(manifest_path, "r") as f:
|
334 |
+
manifest_list = []
|
335 |
+
for line in f:
|
336 |
+
assert len(line.strip().split("|")) == 4
|
337 |
+
utt, prompt_text, prompt_wav, gt_text = line.strip().split("|")
|
338 |
+
utt = Path(utt).stem
|
339 |
+
# gt_wav = os.path.join(os.path.dirname(manifest_path), "wavs", utt + ".wav")
|
340 |
+
if not os.path.isabs(prompt_wav):
|
341 |
+
prompt_wav = os.path.join(os.path.dirname(manifest_path), prompt_wav)
|
342 |
+
manifest_list.append(
|
343 |
+
{
|
344 |
+
"audio_filepath": prompt_wav,
|
345 |
+
"reference_text": prompt_text,
|
346 |
+
"target_text": gt_text,
|
347 |
+
"target_audio_path": utt
|
348 |
+
}
|
349 |
+
)
|
350 |
+
return manifest_list
|
351 |
+
|
352 |
+
|
353 |
+
def split_data(data, k):
|
354 |
+
n = len(data)
|
355 |
+
if n < k:
|
356 |
+
print(
|
357 |
+
f"Warning: the length of the input list ({n}) is less than k ({k}). Setting k to {n}."
|
358 |
+
)
|
359 |
+
k = n
|
360 |
+
|
361 |
+
quotient = n // k
|
362 |
+
remainder = n % k
|
363 |
+
|
364 |
+
result = []
|
365 |
+
start = 0
|
366 |
+
for i in range(k):
|
367 |
+
if i < remainder:
|
368 |
+
end = start + quotient + 1
|
369 |
+
else:
|
370 |
+
end = start + quotient
|
371 |
+
|
372 |
+
result.append(data[start:end])
|
373 |
+
start = end
|
374 |
+
|
375 |
+
return result
|
376 |
+
|
377 |
+
|
378 |
+
async def main():
|
379 |
+
args = get_args()
|
380 |
+
url = f"{args.server_addr}:{args.server_port}"
|
381 |
+
|
382 |
+
triton_client = grpcclient.InferenceServerClient(url=url, verbose=False)
|
383 |
+
protocol_client = grpcclient
|
384 |
+
|
385 |
+
if args.reference_audio:
|
386 |
+
args.num_tasks = 1
|
387 |
+
args.log_interval = 1
|
388 |
+
manifest_item_list = [
|
389 |
+
{
|
390 |
+
"reference_text": args.reference_text,
|
391 |
+
"target_text": args.target_text,
|
392 |
+
"audio_filepath": args.reference_audio,
|
393 |
+
"target_audio_path": "test",
|
394 |
+
}
|
395 |
+
]
|
396 |
+
elif args.huggingface_dataset:
|
397 |
+
import datasets
|
398 |
+
|
399 |
+
dataset = datasets.load_dataset(
|
400 |
+
args.huggingface_dataset,
|
401 |
+
split=args.split_name,
|
402 |
+
trust_remote_code=True,
|
403 |
+
)
|
404 |
+
manifest_item_list = []
|
405 |
+
for i in range(len(dataset)):
|
406 |
+
manifest_item_list.append(
|
407 |
+
{
|
408 |
+
"audio_filepath": dataset[i]["prompt_audio"],
|
409 |
+
"reference_text": dataset[i]["prompt_text"],
|
410 |
+
"target_audio_path": dataset[i]["id"],
|
411 |
+
"target_text": dataset[i]["target_text"],
|
412 |
+
}
|
413 |
+
)
|
414 |
+
else:
|
415 |
+
manifest_item_list = load_manifests(args.manifest_path)
|
416 |
+
|
417 |
+
args.num_tasks = min(args.num_tasks, len(manifest_item_list))
|
418 |
+
manifest_item_list = split_data(manifest_item_list, args.num_tasks)
|
419 |
+
|
420 |
+
os.makedirs(args.log_dir, exist_ok=True)
|
421 |
+
tasks = []
|
422 |
+
start_time = time.time()
|
423 |
+
for i in range(args.num_tasks):
|
424 |
+
task = asyncio.create_task(
|
425 |
+
send(
|
426 |
+
manifest_item_list[i],
|
427 |
+
name=f"task-{i}",
|
428 |
+
triton_client=triton_client,
|
429 |
+
protocol_client=protocol_client,
|
430 |
+
log_interval=args.log_interval,
|
431 |
+
model_name=args.model_name,
|
432 |
+
audio_save_dir=args.log_dir,
|
433 |
+
padding_duration=None,
|
434 |
+
)
|
435 |
+
)
|
436 |
+
tasks.append(task)
|
437 |
+
|
438 |
+
ans_list = await asyncio.gather(*tasks)
|
439 |
+
|
440 |
+
end_time = time.time()
|
441 |
+
elapsed = end_time - start_time
|
442 |
+
|
443 |
+
|
444 |
+
total_duration = 0.0
|
445 |
+
latency_data = []
|
446 |
+
for ans in ans_list:
|
447 |
+
total_duration += ans[0]
|
448 |
+
latency_data += ans[1]
|
449 |
+
|
450 |
+
rtf = elapsed / total_duration
|
451 |
+
|
452 |
+
s = f"RTF: {rtf:.4f}\n"
|
453 |
+
s += f"total_duration: {total_duration:.3f} seconds\n"
|
454 |
+
s += f"({total_duration/3600:.2f} hours)\n"
|
455 |
+
s += f"processing time: {elapsed:.3f} seconds " f"({elapsed/3600:.2f} hours)\n"
|
456 |
+
|
457 |
+
latency_list = [chunk_end for (chunk_end, chunk_duration) in latency_data]
|
458 |
+
latency_ms = sum(latency_list) / float(len(latency_list)) * 1000.0
|
459 |
+
latency_variance = np.var(latency_list, dtype=np.float64) * 1000.0
|
460 |
+
s += f"latency_variance: {latency_variance:.2f}\n"
|
461 |
+
s += f"latency_50_percentile_ms: {np.percentile(latency_list, 50) * 1000.0:.2f}\n"
|
462 |
+
s += f"latency_90_percentile_ms: {np.percentile(latency_list, 90) * 1000.0:.2f}\n"
|
463 |
+
s += f"latency_95_percentile_ms: {np.percentile(latency_list, 95) * 1000.0:.2f}\n"
|
464 |
+
s += f"latency_99_percentile_ms: {np.percentile(latency_list, 99) * 1000.0:.2f}\n"
|
465 |
+
s += f"average_latency_ms: {latency_ms:.2f}\n"
|
466 |
+
|
467 |
+
print(s)
|
468 |
+
if args.manifest_path:
|
469 |
+
name = Path(args.manifest_path).stem
|
470 |
+
elif args.split_name:
|
471 |
+
name = args.split_name
|
472 |
+
with open(f"{args.log_dir}/rtf-{name}.txt", "w") as f:
|
473 |
+
f.write(s)
|
474 |
+
|
475 |
+
stats = await triton_client.get_inference_statistics(model_name="", as_json=True)
|
476 |
+
write_triton_stats(stats, f"{args.log_dir}/stats_summary-{name}.txt")
|
477 |
+
|
478 |
+
metadata = await triton_client.get_model_config(model_name=args.model_name, as_json=True)
|
479 |
+
with open(f"{args.log_dir}/model_config-{name}.json", "w") as f:
|
480 |
+
json.dump(metadata, f, indent=4)
|
481 |
+
if __name__ == "__main__":
|
482 |
+
asyncio.run(main())
|
runtime/triton_trtllm/client_http.py
ADDED
@@ -0,0 +1,165 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# Redistribution and use in source and binary forms, with or without
|
4 |
+
# modification, are permitted provided that the following conditions
|
5 |
+
# are met:
|
6 |
+
# * Redistributions of source code must retain the above copyright
|
7 |
+
# notice, this list of conditions and the following disclaimer.
|
8 |
+
# * Redistributions in binary form must reproduce the above copyright
|
9 |
+
# notice, this list of conditions and the following disclaimer in the
|
10 |
+
# documentation and/or other materials provided with the distribution.
|
11 |
+
# * Neither the name of NVIDIA CORPORATION nor the names of its
|
12 |
+
# contributors may be used to endorse or promote products derived
|
13 |
+
# from this software without specific prior written permission.
|
14 |
+
#
|
15 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
|
16 |
+
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
17 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
|
18 |
+
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
|
19 |
+
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
20 |
+
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
21 |
+
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
22 |
+
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
|
23 |
+
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
24 |
+
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
25 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
26 |
+
import requests
|
27 |
+
import soundfile as sf
|
28 |
+
import json
|
29 |
+
import numpy as np
|
30 |
+
import argparse
|
31 |
+
|
32 |
+
def get_args():
|
33 |
+
parser = argparse.ArgumentParser(
|
34 |
+
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
35 |
+
)
|
36 |
+
|
37 |
+
parser.add_argument(
|
38 |
+
"--server-url",
|
39 |
+
type=str,
|
40 |
+
default="localhost:8000",
|
41 |
+
help="Address of the server",
|
42 |
+
)
|
43 |
+
|
44 |
+
parser.add_argument(
|
45 |
+
"--reference-audio",
|
46 |
+
type=str,
|
47 |
+
default="../../example/prompt_audio.wav",
|
48 |
+
help="Path to a single audio file. It can't be specified at the same time with --manifest-dir",
|
49 |
+
)
|
50 |
+
|
51 |
+
parser.add_argument(
|
52 |
+
"--reference-text",
|
53 |
+
type=str,
|
54 |
+
default="吃燕窝就选燕之屋,本节目由26年专注高品质燕窝的燕之屋冠名播出。豆奶牛奶换着喝,营养更均衡,本节目由豆本豆豆奶特约播出。",
|
55 |
+
help="",
|
56 |
+
)
|
57 |
+
|
58 |
+
parser.add_argument(
|
59 |
+
"--target-text",
|
60 |
+
type=str,
|
61 |
+
default="身临其境,换新体验。塑造开源语音合成新范式,让智能语音更自然。",
|
62 |
+
help="",
|
63 |
+
)
|
64 |
+
|
65 |
+
parser.add_argument(
|
66 |
+
"--model-name",
|
67 |
+
type=str,
|
68 |
+
default="spark_tts",
|
69 |
+
choices=[
|
70 |
+
"f5_tts", "spark_tts"
|
71 |
+
],
|
72 |
+
help="triton model_repo module name to request: transducer for k2, attention_rescoring for wenet offline, streaming_wenet for wenet streaming, infer_pipeline for paraformer large offline",
|
73 |
+
)
|
74 |
+
|
75 |
+
parser.add_argument(
|
76 |
+
"--output-audio",
|
77 |
+
type=str,
|
78 |
+
default="output.wav",
|
79 |
+
help="Path to save the output audio",
|
80 |
+
)
|
81 |
+
return parser.parse_args()
|
82 |
+
|
83 |
+
def prepare_request(
|
84 |
+
waveform,
|
85 |
+
reference_text,
|
86 |
+
target_text,
|
87 |
+
sample_rate=16000,
|
88 |
+
padding_duration: int = None,
|
89 |
+
audio_save_dir: str = "./",
|
90 |
+
):
|
91 |
+
assert len(waveform.shape) == 1, "waveform should be 1D"
|
92 |
+
lengths = np.array([[len(waveform)]], dtype=np.int32)
|
93 |
+
if padding_duration:
|
94 |
+
# padding to nearset 10 seconds
|
95 |
+
samples = np.zeros(
|
96 |
+
(
|
97 |
+
1,
|
98 |
+
padding_duration
|
99 |
+
* sample_rate
|
100 |
+
* ((int(duration) // padding_duration) + 1),
|
101 |
+
),
|
102 |
+
dtype=np.float32,
|
103 |
+
)
|
104 |
+
|
105 |
+
samples[0, : len(waveform)] = waveform
|
106 |
+
else:
|
107 |
+
samples = waveform
|
108 |
+
|
109 |
+
samples = samples.reshape(1, -1).astype(np.float32)
|
110 |
+
|
111 |
+
data = {
|
112 |
+
"inputs":[
|
113 |
+
{
|
114 |
+
"name": "reference_wav",
|
115 |
+
"shape": samples.shape,
|
116 |
+
"datatype": "FP32",
|
117 |
+
"data": samples.tolist()
|
118 |
+
},
|
119 |
+
{
|
120 |
+
"name": "reference_wav_len",
|
121 |
+
"shape": lengths.shape,
|
122 |
+
"datatype": "INT32",
|
123 |
+
"data": lengths.tolist(),
|
124 |
+
},
|
125 |
+
{
|
126 |
+
"name": "reference_text",
|
127 |
+
"shape": [1, 1],
|
128 |
+
"datatype": "BYTES",
|
129 |
+
"data": [reference_text]
|
130 |
+
},
|
131 |
+
{
|
132 |
+
"name": "target_text",
|
133 |
+
"shape": [1, 1],
|
134 |
+
"datatype": "BYTES",
|
135 |
+
"data": [target_text]
|
136 |
+
}
|
137 |
+
]
|
138 |
+
}
|
139 |
+
|
140 |
+
return data
|
141 |
+
|
142 |
+
if __name__ == "__main__":
|
143 |
+
args = get_args()
|
144 |
+
server_url = args.server_url
|
145 |
+
if not server_url.startswith(("http://", "https://")):
|
146 |
+
server_url = f"http://{server_url}"
|
147 |
+
|
148 |
+
url = f"{server_url}/v2/models/{args.model_name}/infer"
|
149 |
+
waveform, sr = sf.read(args.reference_audio)
|
150 |
+
assert sr == 16000, "sample rate hardcoded in server"
|
151 |
+
|
152 |
+
samples = np.array(waveform, dtype=np.float32)
|
153 |
+
data = prepare_request(samples, args.reference_text, args.target_text)
|
154 |
+
|
155 |
+
rsp = requests.post(
|
156 |
+
url,
|
157 |
+
headers={"Content-Type": "application/json"},
|
158 |
+
json=data,
|
159 |
+
verify=False,
|
160 |
+
params={"request_id": '0'}
|
161 |
+
)
|
162 |
+
result = rsp.json()
|
163 |
+
audio = result["outputs"][0]["data"]
|
164 |
+
audio = np.array(audio, dtype=np.float32)
|
165 |
+
sf.write(args.output_audio, audio, 16000, "PCM_16")
|
runtime/triton_trtllm/docker-compose.yml
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
services:
|
2 |
+
tts:
|
3 |
+
image: soar97/triton-spark-tts:25.02
|
4 |
+
shm_size: '1gb'
|
5 |
+
ports:
|
6 |
+
- "8000:8000"
|
7 |
+
- "8001:8001"
|
8 |
+
- "8002:8002"
|
9 |
+
environment:
|
10 |
+
- PYTHONIOENCODING=utf-8
|
11 |
+
- MODEL_ID=${MODEL_ID}
|
12 |
+
deploy:
|
13 |
+
resources:
|
14 |
+
reservations:
|
15 |
+
devices:
|
16 |
+
- driver: nvidia
|
17 |
+
device_ids: ['0']
|
18 |
+
capabilities: [gpu]
|
19 |
+
command: >
|
20 |
+
/bin/bash -c "rm -rf Spark-TTS && git clone https://github.com/SparkAudio/Spark-TTS.git && cd Spark-TTS/runtime/triton_trtllm && bash run.sh 0 3"
|
runtime/triton_trtllm/model_repo/audio_tokenizer/1/model.py
ADDED
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# Redistribution and use in source and binary forms, with or without
|
4 |
+
# modification, are permitted provided that the following conditions
|
5 |
+
# are met:
|
6 |
+
# * Redistributions of source code must retain the above copyright
|
7 |
+
# notice, this list of conditions and the following disclaimer.
|
8 |
+
# * Redistributions in binary form must reproduce the above copyright
|
9 |
+
# notice, this list of conditions and the following disclaimer in the
|
10 |
+
# documentation and/or other materials provided with the distribution.
|
11 |
+
# * Neither the name of NVIDIA CORPORATION nor the names of its
|
12 |
+
# contributors may be used to endorse or promote products derived
|
13 |
+
# from this software without specific prior written permission.
|
14 |
+
#
|
15 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
|
16 |
+
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
17 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
|
18 |
+
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
|
19 |
+
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
20 |
+
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
21 |
+
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
22 |
+
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
|
23 |
+
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
24 |
+
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
25 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
26 |
+
import json
|
27 |
+
import torch
|
28 |
+
from torch.utils.dlpack import to_dlpack
|
29 |
+
|
30 |
+
import triton_python_backend_utils as pb_utils
|
31 |
+
|
32 |
+
import os
|
33 |
+
import numpy as np
|
34 |
+
|
35 |
+
from sparktts.models.audio_tokenizer import BiCodecTokenizer
|
36 |
+
|
37 |
+
class TritonPythonModel:
|
38 |
+
"""Triton Python model for audio tokenization.
|
39 |
+
|
40 |
+
This model takes reference audio input and extracts semantic and global tokens
|
41 |
+
using BiCodec tokenizer.
|
42 |
+
"""
|
43 |
+
|
44 |
+
def initialize(self, args):
|
45 |
+
"""Initialize the model.
|
46 |
+
|
47 |
+
Args:
|
48 |
+
args: Dictionary containing model configuration
|
49 |
+
"""
|
50 |
+
# Parse model parameters
|
51 |
+
parameters = json.loads(args['model_config'])['parameters']
|
52 |
+
model_params = {k: v["string_value"] for k, v in parameters.items()}
|
53 |
+
|
54 |
+
# Initialize tokenizer
|
55 |
+
self.device = torch.device("cuda")
|
56 |
+
self.audio_tokenizer = BiCodecTokenizer(model_params["model_dir"],
|
57 |
+
device=self.device)
|
58 |
+
|
59 |
+
def get_ref_clip(self, wav: np.ndarray) -> np.ndarray:
|
60 |
+
"""Extract reference audio clip for speaker embedding.
|
61 |
+
|
62 |
+
Args:
|
63 |
+
wav: Input waveform array
|
64 |
+
|
65 |
+
Returns:
|
66 |
+
Reference clip of fixed duration
|
67 |
+
"""
|
68 |
+
SAMPLE_RATE = 16000
|
69 |
+
REF_SEGMENT_DURATION = 6 # seconds
|
70 |
+
LATENT_HOP_LENGTH = 320
|
71 |
+
|
72 |
+
ref_segment_length = (
|
73 |
+
int(SAMPLE_RATE * REF_SEGMENT_DURATION)
|
74 |
+
// LATENT_HOP_LENGTH
|
75 |
+
* LATENT_HOP_LENGTH
|
76 |
+
)
|
77 |
+
wav_length = len(wav)
|
78 |
+
|
79 |
+
if ref_segment_length > wav_length:
|
80 |
+
# Repeat and truncate if input is too short
|
81 |
+
repeat_times = ref_segment_length // wav_length + 1
|
82 |
+
wav = np.tile(wav, repeat_times)
|
83 |
+
|
84 |
+
return wav[:ref_segment_length]
|
85 |
+
|
86 |
+
def execute(self, requests):
|
87 |
+
"""Execute inference on the batched requests.
|
88 |
+
|
89 |
+
Args:
|
90 |
+
requests: List of inference requests
|
91 |
+
|
92 |
+
Returns:
|
93 |
+
List of inference responses containing tokenized outputs
|
94 |
+
"""
|
95 |
+
reference_wav_list = []
|
96 |
+
reference_wav_ref_clip_list = []
|
97 |
+
|
98 |
+
# Process each request in batch
|
99 |
+
for request in requests:
|
100 |
+
# Extract input tensors
|
101 |
+
wav_array = pb_utils.get_input_tensor_by_name(
|
102 |
+
request, "reference_wav").as_numpy()
|
103 |
+
wav_len = pb_utils.get_input_tensor_by_name(
|
104 |
+
request, "reference_wav_len").as_numpy().item()
|
105 |
+
|
106 |
+
# Prepare inputs
|
107 |
+
wav = wav_array[:, :wav_len].squeeze(0)
|
108 |
+
reference_wav_list.append(wav)
|
109 |
+
|
110 |
+
wav_ref_clip = self.get_ref_clip(wav)
|
111 |
+
reference_wav_ref_clip_list.append(torch.from_numpy(wav_ref_clip))
|
112 |
+
|
113 |
+
# Batch process through tokenizer
|
114 |
+
ref_wav_clip_tensor = torch.stack(reference_wav_ref_clip_list, dim=0)
|
115 |
+
wav2vec2_features = self.audio_tokenizer.extract_wav2vec2_features(
|
116 |
+
reference_wav_list)
|
117 |
+
|
118 |
+
audio_tokenizer_input = {
|
119 |
+
"ref_wav": ref_wav_clip_tensor.to(self.device),
|
120 |
+
"feat": wav2vec2_features.to(self.device),
|
121 |
+
}
|
122 |
+
semantic_tokens, global_tokens = self.audio_tokenizer.model.tokenize(
|
123 |
+
audio_tokenizer_input)
|
124 |
+
|
125 |
+
# Prepare responses
|
126 |
+
responses = []
|
127 |
+
for i in range(len(requests)):
|
128 |
+
global_tokens_tensor = pb_utils.Tensor.from_dlpack(
|
129 |
+
"global_tokens", to_dlpack(global_tokens[i]))
|
130 |
+
semantic_tokens_tensor = pb_utils.Tensor.from_dlpack(
|
131 |
+
"semantic_tokens", to_dlpack(semantic_tokens[i]))
|
132 |
+
|
133 |
+
inference_response = pb_utils.InferenceResponse(
|
134 |
+
output_tensors=[global_tokens_tensor, semantic_tokens_tensor])
|
135 |
+
responses.append(inference_response)
|
136 |
+
|
137 |
+
return responses
|
runtime/triton_trtllm/model_repo/audio_tokenizer/config.pbtxt
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
name: "audio_tokenizer"
|
16 |
+
backend: "python"
|
17 |
+
max_batch_size: ${triton_max_batch_size}
|
18 |
+
dynamic_batching {
|
19 |
+
max_queue_delay_microseconds: ${max_queue_delay_microseconds}
|
20 |
+
}
|
21 |
+
parameters [
|
22 |
+
{
|
23 |
+
key: "model_dir",
|
24 |
+
value: {string_value:"${model_dir}"}
|
25 |
+
}
|
26 |
+
]
|
27 |
+
|
28 |
+
input [
|
29 |
+
{
|
30 |
+
name: "reference_wav"
|
31 |
+
data_type: TYPE_FP32
|
32 |
+
dims: [-1]
|
33 |
+
},
|
34 |
+
{
|
35 |
+
name: "reference_wav_len"
|
36 |
+
data_type: TYPE_INT32
|
37 |
+
dims: [1]
|
38 |
+
}
|
39 |
+
]
|
40 |
+
output [
|
41 |
+
{
|
42 |
+
name: "global_tokens"
|
43 |
+
data_type: TYPE_INT32
|
44 |
+
dims: [-1]
|
45 |
+
},
|
46 |
+
{
|
47 |
+
name: "semantic_tokens"
|
48 |
+
data_type: TYPE_INT32
|
49 |
+
dims: [-1]
|
50 |
+
}
|
51 |
+
]
|
52 |
+
|
53 |
+
instance_group [
|
54 |
+
{
|
55 |
+
count: 1
|
56 |
+
kind: KIND_CPU
|
57 |
+
}
|
58 |
+
]
|
runtime/triton_trtllm/model_repo/spark_tts/1/model.py
ADDED
@@ -0,0 +1,311 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# Redistribution and use in source and binary forms, with or without
|
4 |
+
# modification, are permitted provided that the following conditions
|
5 |
+
# are met:
|
6 |
+
# * Redistributions of source code must retain the above copyright
|
7 |
+
# notice, this list of conditions and the following disclaimer.
|
8 |
+
# * Redistributions in binary form must reproduce the above copyright
|
9 |
+
# notice, this list of conditions and the following disclaimer in the
|
10 |
+
# documentation and/or other materials provided with the distribution.
|
11 |
+
# * Neither the name of NVIDIA CORPORATION nor the names of its
|
12 |
+
# contributors may be used to endorse or promote products derived
|
13 |
+
# from this software without specific prior written permission.
|
14 |
+
#
|
15 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
|
16 |
+
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
17 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
|
18 |
+
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
|
19 |
+
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
20 |
+
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
21 |
+
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
22 |
+
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
|
23 |
+
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
24 |
+
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
25 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
26 |
+
|
27 |
+
import json
|
28 |
+
import os
|
29 |
+
import re
|
30 |
+
from typing import Dict, List, Tuple, Optional, Union
|
31 |
+
|
32 |
+
import numpy as np
|
33 |
+
import torch
|
34 |
+
from torch.utils.dlpack import from_dlpack, to_dlpack
|
35 |
+
import triton_python_backend_utils as pb_utils
|
36 |
+
from transformers import AutoTokenizer
|
37 |
+
|
38 |
+
from sparktts.utils.token_parser import TASK_TOKEN_MAP
|
39 |
+
|
40 |
+
def process_prompt(
|
41 |
+
text: str,
|
42 |
+
prompt_text: Optional[str] = None,
|
43 |
+
global_token_ids: torch.Tensor = None,
|
44 |
+
semantic_token_ids: torch.Tensor = None,
|
45 |
+
) -> Tuple[str, torch.Tensor]:
|
46 |
+
"""
|
47 |
+
Process input for voice cloning.
|
48 |
+
|
49 |
+
Args:
|
50 |
+
text: The text input to be converted to speech.
|
51 |
+
prompt_text: Transcript of the prompt audio.
|
52 |
+
global_token_ids: Global token IDs extracted from reference audio.
|
53 |
+
semantic_token_ids: Semantic token IDs extracted from reference audio.
|
54 |
+
|
55 |
+
Returns:
|
56 |
+
Tuple containing the formatted input prompt and global token IDs.
|
57 |
+
"""
|
58 |
+
# Convert global tokens to string format
|
59 |
+
global_tokens = "".join(
|
60 |
+
[f"<|bicodec_global_{i}|>" for i in global_token_ids.squeeze()]
|
61 |
+
)
|
62 |
+
|
63 |
+
|
64 |
+
# Prepare the input tokens for the model
|
65 |
+
if prompt_text is not None:
|
66 |
+
# Include semantic tokens when prompt text is provided
|
67 |
+
semantic_tokens = "".join(
|
68 |
+
[f"<|bicodec_semantic_{i}|>" for i in semantic_token_ids.squeeze()]
|
69 |
+
)
|
70 |
+
|
71 |
+
inputs = [
|
72 |
+
TASK_TOKEN_MAP["tts"],
|
73 |
+
"<|start_content|>",
|
74 |
+
prompt_text,
|
75 |
+
text,
|
76 |
+
"<|end_content|>",
|
77 |
+
"<|start_global_token|>",
|
78 |
+
global_tokens,
|
79 |
+
"<|end_global_token|>",
|
80 |
+
"<|start_semantic_token|>",
|
81 |
+
semantic_tokens,
|
82 |
+
]
|
83 |
+
else:
|
84 |
+
# Without prompt text, exclude semantic tokens
|
85 |
+
inputs = [
|
86 |
+
TASK_TOKEN_MAP["tts"],
|
87 |
+
"<|start_content|>",
|
88 |
+
text,
|
89 |
+
"<|end_content|>",
|
90 |
+
"<|start_global_token|>",
|
91 |
+
global_tokens,
|
92 |
+
"<|end_global_token|>",
|
93 |
+
]
|
94 |
+
|
95 |
+
# Join all input components into a single string
|
96 |
+
inputs = "".join(inputs)
|
97 |
+
return inputs, global_token_ids
|
98 |
+
|
99 |
+
|
100 |
+
class TritonPythonModel:
|
101 |
+
"""Triton Python model for Spark TTS.
|
102 |
+
|
103 |
+
This model orchestrates the end-to-end TTS pipeline by coordinating
|
104 |
+
between audio tokenizer, LLM, and vocoder components.
|
105 |
+
"""
|
106 |
+
|
107 |
+
def initialize(self, args):
|
108 |
+
"""Initialize the model.
|
109 |
+
|
110 |
+
Args:
|
111 |
+
args: Dictionary containing model configuration
|
112 |
+
"""
|
113 |
+
# Parse model parameters
|
114 |
+
parameters = json.loads(args['model_config'])['parameters']
|
115 |
+
model_params = {k: v["string_value"] for k, v in parameters.items()}
|
116 |
+
|
117 |
+
# Initialize tokenizer
|
118 |
+
llm_tokenizer_dir = model_params["llm_tokenizer_dir"]
|
119 |
+
self.tokenizer = AutoTokenizer.from_pretrained(llm_tokenizer_dir)
|
120 |
+
self.device = torch.device("cuda")
|
121 |
+
self.decoupled = False
|
122 |
+
|
123 |
+
def forward_llm(self, input_ids):
|
124 |
+
"""
|
125 |
+
Prepares the response from the language model based on the provided
|
126 |
+
inputs. Creates a `pb_utils.InferenceRequest` object with passed
|
127 |
+
`llm_request_inputs` to send to a decoupled TensorRTLLM model.
|
128 |
+
For each response from the language model:
|
129 |
+
- Checks for errors and raise an exception if any are found.
|
130 |
+
- Extracts the "output_ids" tensor from the response.
|
131 |
+
- Determines the finish reason based on the presence of the
|
132 |
+
end-of-sequence token or reaching the maximum length.
|
133 |
+
- Appends the generated token IDs to `output_ids`.
|
134 |
+
- If the finish reason is determined, decodes the output IDs to text
|
135 |
+
and prepares the final response.
|
136 |
+
|
137 |
+
The final response includes the generated text, finish reason,
|
138 |
+
completion tokens, prompt tokens, and total tokens.
|
139 |
+
|
140 |
+
Parameters
|
141 |
+
----------
|
142 |
+
- llm_request_inputs (dict): A dictionary containing the inputs for the language model.
|
143 |
+
|
144 |
+
Returns
|
145 |
+
-------
|
146 |
+
- pb_utils.InferenceResponse: The response object containing the generated text and additional metadata.
|
147 |
+
"""
|
148 |
+
# convert input_ids to numpy, with shape [1, sequence_length]
|
149 |
+
input_ids = input_ids.cpu().numpy()
|
150 |
+
max_tokens = 512
|
151 |
+
input_dict = {
|
152 |
+
"request_output_len": np.array([[max_tokens]], dtype=np.int32),
|
153 |
+
"end_id": np.array([[self.tokenizer.eos_token_id]], dtype=np.int32),
|
154 |
+
"pad_id": np.array([[self.tokenizer.pad_token_id]], dtype=np.int32),
|
155 |
+
"streaming": np.array([[self.decoupled]], dtype=np.bool_),
|
156 |
+
"runtime_top_p": np.array([[0.95]], dtype=np.float32),
|
157 |
+
"runtime_top_k": np.array([[50]], dtype=np.int32),
|
158 |
+
"temperature": np.array([[0.8]], dtype=np.float32),
|
159 |
+
"input_ids": input_ids,
|
160 |
+
"input_lengths": np.array([[input_ids.shape[1]]], dtype=np.int32),
|
161 |
+
}
|
162 |
+
|
163 |
+
# Convert inputs to Triton tensors
|
164 |
+
input_tensor_list = [
|
165 |
+
pb_utils.Tensor(k, v) for k, v in input_dict.items()
|
166 |
+
]
|
167 |
+
|
168 |
+
# Create and execute inference request
|
169 |
+
llm_request = pb_utils.InferenceRequest(
|
170 |
+
model_name="tensorrt_llm",
|
171 |
+
requested_output_names=["output_ids", "sequence_length"],
|
172 |
+
inputs=input_tensor_list,
|
173 |
+
)
|
174 |
+
|
175 |
+
llm_response = llm_request.exec(decoupled=self.decoupled)
|
176 |
+
if llm_response.has_error():
|
177 |
+
raise pb_utils.TritonModelException(llm_response.error().message())
|
178 |
+
|
179 |
+
# Extract and process output
|
180 |
+
output_ids = pb_utils.get_output_tensor_by_name(
|
181 |
+
llm_response, "output_ids").as_numpy()
|
182 |
+
seq_lens = pb_utils.get_output_tensor_by_name(
|
183 |
+
llm_response, "sequence_length").as_numpy()
|
184 |
+
|
185 |
+
# Get actual output IDs up to the sequence length
|
186 |
+
actual_output_ids = output_ids[0][0][:seq_lens[0][0]]
|
187 |
+
|
188 |
+
return actual_output_ids
|
189 |
+
|
190 |
+
def forward_audio_tokenizer(self, wav, wav_len):
|
191 |
+
"""Forward pass through the audio tokenizer component.
|
192 |
+
|
193 |
+
Args:
|
194 |
+
wav: Input waveform tensor
|
195 |
+
wav_len: Waveform length tensor
|
196 |
+
|
197 |
+
Returns:
|
198 |
+
Tuple of global and semantic tokens
|
199 |
+
"""
|
200 |
+
inference_request = pb_utils.InferenceRequest(
|
201 |
+
model_name='audio_tokenizer',
|
202 |
+
requested_output_names=['global_tokens', 'semantic_tokens'],
|
203 |
+
inputs=[wav, wav_len]
|
204 |
+
)
|
205 |
+
|
206 |
+
inference_response = inference_request.exec()
|
207 |
+
if inference_response.has_error():
|
208 |
+
raise pb_utils.TritonModelException(inference_response.error().message())
|
209 |
+
|
210 |
+
# Extract and convert output tensors
|
211 |
+
global_tokens = pb_utils.get_output_tensor_by_name(inference_response, 'global_tokens')
|
212 |
+
global_tokens = torch.utils.dlpack.from_dlpack(global_tokens.to_dlpack()).cpu()
|
213 |
+
|
214 |
+
semantic_tokens = pb_utils.get_output_tensor_by_name(inference_response, 'semantic_tokens')
|
215 |
+
semantic_tokens = torch.utils.dlpack.from_dlpack(semantic_tokens.to_dlpack()).cpu()
|
216 |
+
|
217 |
+
return global_tokens, semantic_tokens
|
218 |
+
|
219 |
+
def forward_vocoder(self, global_token_ids: torch.Tensor, pred_semantic_ids: torch.Tensor) -> torch.Tensor:
|
220 |
+
"""Forward pass through the vocoder component.
|
221 |
+
|
222 |
+
Args:
|
223 |
+
global_token_ids: Global token IDs tensor
|
224 |
+
pred_semantic_ids: Predicted semantic token IDs tensor
|
225 |
+
|
226 |
+
Returns:
|
227 |
+
Generated waveform tensor
|
228 |
+
"""
|
229 |
+
# Convert tensors to Triton format
|
230 |
+
global_token_ids_tensor = pb_utils.Tensor.from_dlpack("global_tokens", to_dlpack(global_token_ids))
|
231 |
+
pred_semantic_ids_tensor = pb_utils.Tensor.from_dlpack("semantic_tokens", to_dlpack(pred_semantic_ids))
|
232 |
+
|
233 |
+
# Create and execute inference request
|
234 |
+
inference_request = pb_utils.InferenceRequest(
|
235 |
+
model_name='vocoder',
|
236 |
+
requested_output_names=['waveform'],
|
237 |
+
inputs=[global_token_ids_tensor, pred_semantic_ids_tensor]
|
238 |
+
)
|
239 |
+
|
240 |
+
inference_response = inference_request.exec()
|
241 |
+
if inference_response.has_error():
|
242 |
+
raise pb_utils.TritonModelException(inference_response.error().message())
|
243 |
+
|
244 |
+
# Extract and convert output waveform
|
245 |
+
waveform = pb_utils.get_output_tensor_by_name(inference_response, 'waveform')
|
246 |
+
waveform = torch.utils.dlpack.from_dlpack(waveform.to_dlpack()).cpu()
|
247 |
+
|
248 |
+
return waveform
|
249 |
+
|
250 |
+
def execute(self, requests):
|
251 |
+
"""Execute inference on the batched requests.
|
252 |
+
|
253 |
+
Args:
|
254 |
+
requests: List of inference requests
|
255 |
+
|
256 |
+
Returns:
|
257 |
+
List of inference responses containing generated audio
|
258 |
+
"""
|
259 |
+
responses = []
|
260 |
+
|
261 |
+
for request in requests:
|
262 |
+
# Extract input tensors
|
263 |
+
wav = pb_utils.get_input_tensor_by_name(request, "reference_wav")
|
264 |
+
wav_len = pb_utils.get_input_tensor_by_name(request, "reference_wav_len")
|
265 |
+
|
266 |
+
# Process reference audio through audio tokenizer
|
267 |
+
global_tokens, semantic_tokens = self.forward_audio_tokenizer(wav, wav_len)
|
268 |
+
|
269 |
+
# Extract text inputs
|
270 |
+
reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy()
|
271 |
+
reference_text = reference_text[0][0].decode('utf-8')
|
272 |
+
|
273 |
+
target_text = pb_utils.get_input_tensor_by_name(request, "target_text").as_numpy()
|
274 |
+
target_text = target_text[0][0].decode('utf-8')
|
275 |
+
|
276 |
+
# Prepare prompt for LLM
|
277 |
+
prompt, global_token_ids = process_prompt(
|
278 |
+
text=target_text,
|
279 |
+
prompt_text=reference_text,
|
280 |
+
global_token_ids=global_tokens,
|
281 |
+
semantic_token_ids=semantic_tokens,
|
282 |
+
)
|
283 |
+
|
284 |
+
|
285 |
+
# Tokenize prompt for LLM
|
286 |
+
model_inputs = self.tokenizer([prompt], return_tensors="pt").to(self.device)
|
287 |
+
input_ids = model_inputs.input_ids.to(torch.int32)
|
288 |
+
|
289 |
+
# Generate semantic tokens with LLM
|
290 |
+
generated_ids = self.forward_llm(input_ids)
|
291 |
+
|
292 |
+
# Decode and extract semantic token IDs from generated text
|
293 |
+
predicted_text = self.tokenizer.batch_decode([generated_ids], skip_special_tokens=True)[0]
|
294 |
+
pred_semantic_ids = (
|
295 |
+
torch.tensor([int(token) for token in re.findall(r"bicodec_semantic_(\d+)", predicted_text)])
|
296 |
+
.unsqueeze(0).to(torch.int32)
|
297 |
+
)
|
298 |
+
|
299 |
+
|
300 |
+
# Generate audio with vocoder
|
301 |
+
audio = self.forward_vocoder(
|
302 |
+
global_token_ids.to(self.device),
|
303 |
+
pred_semantic_ids.to(self.device),
|
304 |
+
)
|
305 |
+
|
306 |
+
# Prepare response
|
307 |
+
audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio))
|
308 |
+
inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
|
309 |
+
responses.append(inference_response)
|
310 |
+
|
311 |
+
return responses
|
runtime/triton_trtllm/model_repo/spark_tts/config.pbtxt
ADDED
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
name: "spark_tts"
|
16 |
+
backend: "python"
|
17 |
+
max_batch_size: ${triton_max_batch_size}
|
18 |
+
dynamic_batching {
|
19 |
+
max_queue_delay_microseconds: ${max_queue_delay_microseconds}
|
20 |
+
}
|
21 |
+
parameters [
|
22 |
+
{
|
23 |
+
key: "llm_tokenizer_dir",
|
24 |
+
value: {string_value:"${llm_tokenizer_dir}"}
|
25 |
+
}
|
26 |
+
]
|
27 |
+
|
28 |
+
input [
|
29 |
+
{
|
30 |
+
name: "reference_wav"
|
31 |
+
data_type: TYPE_FP32
|
32 |
+
dims: [-1]
|
33 |
+
optional: True
|
34 |
+
},
|
35 |
+
{
|
36 |
+
name: "reference_wav_len"
|
37 |
+
data_type: TYPE_INT32
|
38 |
+
dims: [1]
|
39 |
+
optional: True
|
40 |
+
},
|
41 |
+
{
|
42 |
+
name: "reference_text"
|
43 |
+
data_type: TYPE_STRING
|
44 |
+
dims: [1]
|
45 |
+
},
|
46 |
+
{
|
47 |
+
name: "target_text"
|
48 |
+
data_type: TYPE_STRING
|
49 |
+
dims: [1]
|
50 |
+
}
|
51 |
+
]
|
52 |
+
output [
|
53 |
+
{
|
54 |
+
name: "waveform"
|
55 |
+
data_type: TYPE_FP32
|
56 |
+
dims: [ -1 ]
|
57 |
+
}
|
58 |
+
]
|
59 |
+
|
60 |
+
instance_group [
|
61 |
+
{
|
62 |
+
count: ${bls_instance_num}
|
63 |
+
kind: KIND_CPU
|
64 |
+
}
|
65 |
+
]
|
runtime/triton_trtllm/model_repo/tensorrt_llm/1/.gitkeep
ADDED
File without changes
|
runtime/triton_trtllm/model_repo/tensorrt_llm/config.pbtxt
ADDED
@@ -0,0 +1,857 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# Redistribution and use in source and binary forms, with or without
|
4 |
+
# modification, are permitted provided that the following conditions
|
5 |
+
# are met:
|
6 |
+
# * Redistributions of source code must retain the above copyright
|
7 |
+
# notice, this list of conditions and the following disclaimer.
|
8 |
+
# * Redistributions in binary form must reproduce the above copyright
|
9 |
+
# notice, this list of conditions and the following disclaimer in the
|
10 |
+
# documentation and/or other materials provided with the distribution.
|
11 |
+
# * Neither the name of NVIDIA CORPORATION nor the names of its
|
12 |
+
# contributors may be used to endorse or promote products derived
|
13 |
+
# from this software without specific prior written permission.
|
14 |
+
#
|
15 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
|
16 |
+
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
17 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
|
18 |
+
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
|
19 |
+
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
20 |
+
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
21 |
+
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
22 |
+
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
|
23 |
+
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
24 |
+
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
25 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
26 |
+
|
27 |
+
name: "tensorrt_llm"
|
28 |
+
backend: "${triton_backend}"
|
29 |
+
max_batch_size: ${triton_max_batch_size}
|
30 |
+
|
31 |
+
model_transaction_policy {
|
32 |
+
decoupled: ${decoupled_mode}
|
33 |
+
}
|
34 |
+
|
35 |
+
dynamic_batching {
|
36 |
+
preferred_batch_size: [ ${triton_max_batch_size} ]
|
37 |
+
max_queue_delay_microseconds: ${max_queue_delay_microseconds}
|
38 |
+
default_queue_policy: { max_queue_size: ${max_queue_size} }
|
39 |
+
}
|
40 |
+
|
41 |
+
input [
|
42 |
+
{
|
43 |
+
name: "input_ids"
|
44 |
+
data_type: TYPE_INT32
|
45 |
+
dims: [ -1 ]
|
46 |
+
allow_ragged_batch: true
|
47 |
+
optional: true
|
48 |
+
},
|
49 |
+
{
|
50 |
+
name: "encoder_input_features"
|
51 |
+
data_type: ${encoder_input_features_data_type}
|
52 |
+
dims: [ -1, -1 ]
|
53 |
+
allow_ragged_batch: true
|
54 |
+
optional: true
|
55 |
+
},
|
56 |
+
{
|
57 |
+
name: "encoder_output_lengths"
|
58 |
+
data_type: TYPE_INT32
|
59 |
+
dims: [ 1 ]
|
60 |
+
reshape: { shape: [ ] }
|
61 |
+
optional: true
|
62 |
+
},
|
63 |
+
{
|
64 |
+
name: "input_lengths"
|
65 |
+
data_type: TYPE_INT32
|
66 |
+
dims: [ 1 ]
|
67 |
+
reshape: { shape: [ ] }
|
68 |
+
},
|
69 |
+
{
|
70 |
+
name: "request_output_len"
|
71 |
+
data_type: TYPE_INT32
|
72 |
+
dims: [ 1 ]
|
73 |
+
reshape: { shape: [ ] }
|
74 |
+
},
|
75 |
+
{
|
76 |
+
name: "num_return_sequences"
|
77 |
+
data_type: TYPE_INT32
|
78 |
+
dims: [ 1 ]
|
79 |
+
reshape: { shape: [ ] }
|
80 |
+
optional: true
|
81 |
+
},
|
82 |
+
{
|
83 |
+
name: "draft_input_ids"
|
84 |
+
data_type: TYPE_INT32
|
85 |
+
dims: [ -1 ]
|
86 |
+
optional: true
|
87 |
+
allow_ragged_batch: true
|
88 |
+
},
|
89 |
+
{
|
90 |
+
name: "decoder_input_ids"
|
91 |
+
data_type: TYPE_INT32
|
92 |
+
dims: [ -1 ]
|
93 |
+
optional: true
|
94 |
+
allow_ragged_batch: true
|
95 |
+
},
|
96 |
+
{
|
97 |
+
name: "decoder_input_lengths"
|
98 |
+
data_type: TYPE_INT32
|
99 |
+
dims: [ 1 ]
|
100 |
+
optional: true
|
101 |
+
reshape: { shape: [ ] }
|
102 |
+
},
|
103 |
+
{
|
104 |
+
name: "draft_logits"
|
105 |
+
data_type: ${logits_datatype}
|
106 |
+
dims: [ -1, -1 ]
|
107 |
+
optional: true
|
108 |
+
allow_ragged_batch: true
|
109 |
+
},
|
110 |
+
{
|
111 |
+
name: "draft_acceptance_threshold"
|
112 |
+
data_type: TYPE_FP32
|
113 |
+
dims: [ 1 ]
|
114 |
+
reshape: { shape: [ ] }
|
115 |
+
optional: true
|
116 |
+
},
|
117 |
+
{
|
118 |
+
name: "end_id"
|
119 |
+
data_type: TYPE_INT32
|
120 |
+
dims: [ 1 ]
|
121 |
+
reshape: { shape: [ ] }
|
122 |
+
optional: true
|
123 |
+
},
|
124 |
+
{
|
125 |
+
name: "pad_id"
|
126 |
+
data_type: TYPE_INT32
|
127 |
+
dims: [ 1 ]
|
128 |
+
reshape: { shape: [ ] }
|
129 |
+
optional: true
|
130 |
+
},
|
131 |
+
{
|
132 |
+
name: "stop_words_list"
|
133 |
+
data_type: TYPE_INT32
|
134 |
+
dims: [ 2, -1 ]
|
135 |
+
optional: true
|
136 |
+
allow_ragged_batch: true
|
137 |
+
},
|
138 |
+
{
|
139 |
+
name: "bad_words_list"
|
140 |
+
data_type: TYPE_INT32
|
141 |
+
dims: [ 2, -1 ]
|
142 |
+
optional: true
|
143 |
+
allow_ragged_batch: true
|
144 |
+
},
|
145 |
+
{
|
146 |
+
name: "embedding_bias"
|
147 |
+
data_type: TYPE_FP32
|
148 |
+
dims: [ -1 ]
|
149 |
+
optional: true
|
150 |
+
allow_ragged_batch: true
|
151 |
+
},
|
152 |
+
{
|
153 |
+
name: "beam_width"
|
154 |
+
data_type: TYPE_INT32
|
155 |
+
dims: [ 1 ]
|
156 |
+
reshape: { shape: [ ] }
|
157 |
+
optional: true
|
158 |
+
},
|
159 |
+
{
|
160 |
+
name: "temperature"
|
161 |
+
data_type: TYPE_FP32
|
162 |
+
dims: [ 1 ]
|
163 |
+
reshape: { shape: [ ] }
|
164 |
+
optional: true
|
165 |
+
},
|
166 |
+
{
|
167 |
+
name: "runtime_top_k"
|
168 |
+
data_type: TYPE_INT32
|
169 |
+
dims: [ 1 ]
|
170 |
+
reshape: { shape: [ ] }
|
171 |
+
optional: true
|
172 |
+
},
|
173 |
+
{
|
174 |
+
name: "runtime_top_p"
|
175 |
+
data_type: TYPE_FP32
|
176 |
+
dims: [ 1 ]
|
177 |
+
reshape: { shape: [ ] }
|
178 |
+
optional: true
|
179 |
+
},
|
180 |
+
{
|
181 |
+
name: "runtime_top_p_min"
|
182 |
+
data_type: TYPE_FP32
|
183 |
+
dims: [ 1 ]
|
184 |
+
reshape: { shape: [ ] }
|
185 |
+
optional: true
|
186 |
+
},
|
187 |
+
{
|
188 |
+
name: "runtime_top_p_decay"
|
189 |
+
data_type: TYPE_FP32
|
190 |
+
dims: [ 1 ]
|
191 |
+
reshape: { shape: [ ] }
|
192 |
+
optional: true
|
193 |
+
},
|
194 |
+
{
|
195 |
+
name: "runtime_top_p_reset_ids"
|
196 |
+
data_type: TYPE_INT32
|
197 |
+
dims: [ 1 ]
|
198 |
+
reshape: { shape: [ ] }
|
199 |
+
optional: true
|
200 |
+
},
|
201 |
+
{
|
202 |
+
name: "len_penalty"
|
203 |
+
data_type: TYPE_FP32
|
204 |
+
dims: [ 1 ]
|
205 |
+
reshape: { shape: [ ] }
|
206 |
+
optional: true
|
207 |
+
},
|
208 |
+
{
|
209 |
+
name: "early_stopping"
|
210 |
+
data_type: TYPE_BOOL
|
211 |
+
dims: [ 1 ]
|
212 |
+
reshape: { shape: [ ] }
|
213 |
+
optional: true
|
214 |
+
},
|
215 |
+
{
|
216 |
+
name: "repetition_penalty"
|
217 |
+
data_type: TYPE_FP32
|
218 |
+
dims: [ 1 ]
|
219 |
+
reshape: { shape: [ ] }
|
220 |
+
optional: true
|
221 |
+
},
|
222 |
+
{
|
223 |
+
name: "min_length"
|
224 |
+
data_type: TYPE_INT32
|
225 |
+
dims: [ 1 ]
|
226 |
+
reshape: { shape: [ ] }
|
227 |
+
optional: true
|
228 |
+
},
|
229 |
+
{
|
230 |
+
name: "beam_search_diversity_rate"
|
231 |
+
data_type: TYPE_FP32
|
232 |
+
dims: [ 1 ]
|
233 |
+
reshape: { shape: [ ] }
|
234 |
+
optional: true
|
235 |
+
},
|
236 |
+
{
|
237 |
+
name: "presence_penalty"
|
238 |
+
data_type: TYPE_FP32
|
239 |
+
dims: [ 1 ]
|
240 |
+
reshape: { shape: [ ] }
|
241 |
+
optional: true
|
242 |
+
},
|
243 |
+
{
|
244 |
+
name: "frequency_penalty"
|
245 |
+
data_type: TYPE_FP32
|
246 |
+
dims: [ 1 ]
|
247 |
+
reshape: { shape: [ ] }
|
248 |
+
optional: true
|
249 |
+
},
|
250 |
+
{
|
251 |
+
name: "random_seed"
|
252 |
+
data_type: TYPE_UINT64
|
253 |
+
dims: [ 1 ]
|
254 |
+
reshape: { shape: [ ] }
|
255 |
+
optional: true
|
256 |
+
},
|
257 |
+
{
|
258 |
+
name: "return_log_probs"
|
259 |
+
data_type: TYPE_BOOL
|
260 |
+
dims: [ 1 ]
|
261 |
+
reshape: { shape: [ ] }
|
262 |
+
optional: true
|
263 |
+
},
|
264 |
+
{
|
265 |
+
name: "return_context_logits"
|
266 |
+
data_type: TYPE_BOOL
|
267 |
+
dims: [ 1 ]
|
268 |
+
reshape: { shape: [ ] }
|
269 |
+
optional: true
|
270 |
+
},
|
271 |
+
{
|
272 |
+
name: "return_generation_logits"
|
273 |
+
data_type: TYPE_BOOL
|
274 |
+
dims: [ 1 ]
|
275 |
+
reshape: { shape: [ ] }
|
276 |
+
optional: true
|
277 |
+
},
|
278 |
+
{
|
279 |
+
name: "return_perf_metrics"
|
280 |
+
data_type: TYPE_BOOL
|
281 |
+
dims: [ 1 ]
|
282 |
+
reshape: { shape: [ ] }
|
283 |
+
optional: true
|
284 |
+
},
|
285 |
+
{
|
286 |
+
name: "exclude_input_in_output"
|
287 |
+
data_type: TYPE_BOOL
|
288 |
+
dims: [ 1 ]
|
289 |
+
reshape: { shape: [ ] }
|
290 |
+
optional: true
|
291 |
+
},
|
292 |
+
{
|
293 |
+
name: "stop"
|
294 |
+
data_type: TYPE_BOOL
|
295 |
+
dims: [ 1 ]
|
296 |
+
reshape: { shape: [ ] }
|
297 |
+
optional: true
|
298 |
+
},
|
299 |
+
{
|
300 |
+
name: "streaming"
|
301 |
+
data_type: TYPE_BOOL
|
302 |
+
dims: [ 1 ]
|
303 |
+
reshape: { shape: [ ] }
|
304 |
+
optional: true
|
305 |
+
},
|
306 |
+
{
|
307 |
+
name: "prompt_embedding_table"
|
308 |
+
data_type: TYPE_FP16
|
309 |
+
dims: [ -1, -1 ]
|
310 |
+
optional: true
|
311 |
+
allow_ragged_batch: true
|
312 |
+
},
|
313 |
+
{
|
314 |
+
name: "prompt_table_extra_ids"
|
315 |
+
data_type: TYPE_UINT64
|
316 |
+
dims: [ -1 ]
|
317 |
+
optional: true
|
318 |
+
allow_ragged_batch: true
|
319 |
+
},
|
320 |
+
{
|
321 |
+
name: "prompt_vocab_size"
|
322 |
+
data_type: TYPE_INT32
|
323 |
+
dims: [ 1 ]
|
324 |
+
reshape: { shape: [ ] }
|
325 |
+
optional: true
|
326 |
+
},
|
327 |
+
# cross_attention_mask shape `[bs, seq_len, num_images*num_tiles]`
|
328 |
+
{
|
329 |
+
name: "cross_attention_mask"
|
330 |
+
data_type: TYPE_BOOL
|
331 |
+
dims: [ -1, -1 ]
|
332 |
+
optional: true
|
333 |
+
allow_ragged_batch: true
|
334 |
+
},
|
335 |
+
# Mrope param when mrope is used
|
336 |
+
{
|
337 |
+
name: "mrope_rotary_cos_sin"
|
338 |
+
data_type: TYPE_FP32
|
339 |
+
dims: [ -1 ]
|
340 |
+
optional: true
|
341 |
+
},
|
342 |
+
{
|
343 |
+
name: "mrope_position_deltas"
|
344 |
+
data_type: TYPE_INT64
|
345 |
+
dims: [ 1 ]
|
346 |
+
optional: true
|
347 |
+
},
|
348 |
+
# the unique task ID for the given LoRA.
|
349 |
+
# To perform inference with a specific LoRA for the first time `lora_task_id` `lora_weights` and `lora_config` must all be given.
|
350 |
+
# The LoRA will be cached, so that subsequent requests for the same task only require `lora_task_id`.
|
351 |
+
# If the cache is full the oldest LoRA will be evicted to make space for new ones. An error is returned if `lora_task_id` is not cached.
|
352 |
+
{
|
353 |
+
name: "lora_task_id"
|
354 |
+
data_type: TYPE_UINT64
|
355 |
+
dims: [ 1 ]
|
356 |
+
reshape: { shape: [ ] }
|
357 |
+
optional: true
|
358 |
+
},
|
359 |
+
# weights for a lora adapter shape [ num_lora_modules_layers, D x Hi + Ho x D ]
|
360 |
+
# where the last dimension holds the in / out adapter weights for the associated module (e.g. attn_qkv) and model layer
|
361 |
+
# each of the in / out tensors are first flattened and then concatenated together in the format above.
|
362 |
+
# D=adapter_size (R value), Hi=hidden_size_in, Ho=hidden_size_out.
|
363 |
+
{
|
364 |
+
name: "lora_weights"
|
365 |
+
data_type: TYPE_FP16
|
366 |
+
dims: [ -1, -1 ]
|
367 |
+
optional: true
|
368 |
+
allow_ragged_batch: true
|
369 |
+
},
|
370 |
+
# module identifier (same size a first dimension of lora_weights)
|
371 |
+
# See LoraModule::ModuleType for model id mapping
|
372 |
+
#
|
373 |
+
# "attn_qkv": 0 # compbined qkv adapter
|
374 |
+
# "attn_q": 1 # q adapter
|
375 |
+
# "attn_k": 2 # k adapter
|
376 |
+
# "attn_v": 3 # v adapter
|
377 |
+
# "attn_dense": 4 # adapter for the dense layer in attention
|
378 |
+
# "mlp_h_to_4h": 5 # for llama2 adapter for gated mlp layer after attention / RMSNorm: up projection
|
379 |
+
# "mlp_4h_to_h": 6 # for llama2 adapter for gated mlp layer after attention / RMSNorm: down projection
|
380 |
+
# "mlp_gate": 7 # for llama2 adapter for gated mlp later after attention / RMSNorm: gate
|
381 |
+
#
|
382 |
+
# last dim holds [ module_id, layer_idx, adapter_size (D aka R value) ]
|
383 |
+
{
|
384 |
+
name: "lora_config"
|
385 |
+
data_type: TYPE_INT32
|
386 |
+
dims: [ -1, 3 ]
|
387 |
+
optional: true
|
388 |
+
allow_ragged_batch: true
|
389 |
+
},
|
390 |
+
{
|
391 |
+
name: "context_phase_params"
|
392 |
+
data_type: TYPE_UINT8
|
393 |
+
dims: [ -1 ]
|
394 |
+
optional: true
|
395 |
+
allow_ragged_batch: true
|
396 |
+
},
|
397 |
+
# skip_cross_attn_blocks shape `[bs, 1]`, only used in mllama
|
398 |
+
{
|
399 |
+
name: "skip_cross_attn_blocks"
|
400 |
+
data_type: TYPE_BOOL
|
401 |
+
dims: [ 1 ]
|
402 |
+
optional: true
|
403 |
+
allow_ragged_batch: true
|
404 |
+
},
|
405 |
+
{
|
406 |
+
name: "retention_token_range_starts"
|
407 |
+
data_type: TYPE_INT32
|
408 |
+
dims: [ -1 ]
|
409 |
+
optional: true
|
410 |
+
allow_ragged_batch: true
|
411 |
+
},
|
412 |
+
{
|
413 |
+
name: "retention_token_range_ends"
|
414 |
+
data_type: TYPE_INT32
|
415 |
+
dims: [ -1 ]
|
416 |
+
optional: true
|
417 |
+
allow_ragged_batch: true
|
418 |
+
},
|
419 |
+
{
|
420 |
+
name: "retention_token_range_priorities"
|
421 |
+
data_type: TYPE_INT32
|
422 |
+
dims: [ -1 ]
|
423 |
+
optional: true
|
424 |
+
allow_ragged_batch: true
|
425 |
+
},
|
426 |
+
{
|
427 |
+
name: "retention_token_range_durations_ms"
|
428 |
+
data_type: TYPE_INT32
|
429 |
+
dims: [ -1 ]
|
430 |
+
optional: true
|
431 |
+
allow_ragged_batch: true
|
432 |
+
},
|
433 |
+
{
|
434 |
+
name: "retention_decode_priority"
|
435 |
+
data_type: TYPE_INT32
|
436 |
+
dims: [ 1 ]
|
437 |
+
optional: true
|
438 |
+
allow_ragged_batch: true
|
439 |
+
},
|
440 |
+
{
|
441 |
+
name: "retention_decode_duration_ms"
|
442 |
+
data_type: TYPE_INT32
|
443 |
+
dims: [ 1 ]
|
444 |
+
optional: true
|
445 |
+
allow_ragged_batch: true
|
446 |
+
},
|
447 |
+
{
|
448 |
+
name: "guided_decoding_guide_type"
|
449 |
+
data_type: TYPE_STRING
|
450 |
+
dims: [ 1 ]
|
451 |
+
optional: true
|
452 |
+
allow_ragged_batch: true
|
453 |
+
},
|
454 |
+
{
|
455 |
+
name: "guided_decoding_guide"
|
456 |
+
data_type: TYPE_STRING
|
457 |
+
dims: [ 1 ]
|
458 |
+
optional: true
|
459 |
+
allow_ragged_batch: true
|
460 |
+
},
|
461 |
+
{
|
462 |
+
name: "lookahead_window_size"
|
463 |
+
data_type: TYPE_INT32
|
464 |
+
dims: [ 1 ]
|
465 |
+
optional: true
|
466 |
+
allow_ragged_batch: true
|
467 |
+
},
|
468 |
+
{
|
469 |
+
name: "lookahead_ngram_size"
|
470 |
+
data_type: TYPE_INT32
|
471 |
+
dims: [ 1 ]
|
472 |
+
optional: true
|
473 |
+
allow_ragged_batch: true
|
474 |
+
},
|
475 |
+
{
|
476 |
+
name: "lookahead_verification_set_size"
|
477 |
+
data_type: TYPE_INT32
|
478 |
+
dims: [ 1 ]
|
479 |
+
optional: true
|
480 |
+
allow_ragged_batch: true
|
481 |
+
}
|
482 |
+
]
|
483 |
+
output [
|
484 |
+
{
|
485 |
+
name: "output_ids"
|
486 |
+
data_type: TYPE_INT32
|
487 |
+
dims: [ -1, -1 ]
|
488 |
+
},
|
489 |
+
{
|
490 |
+
name: "sequence_length"
|
491 |
+
data_type: TYPE_INT32
|
492 |
+
dims: [ -1 ]
|
493 |
+
},
|
494 |
+
{
|
495 |
+
name: "cum_log_probs"
|
496 |
+
data_type: TYPE_FP32
|
497 |
+
dims: [ -1 ]
|
498 |
+
},
|
499 |
+
{
|
500 |
+
name: "output_log_probs"
|
501 |
+
data_type: TYPE_FP32
|
502 |
+
dims: [ -1, -1 ]
|
503 |
+
},
|
504 |
+
{
|
505 |
+
name: "context_logits"
|
506 |
+
data_type: ${logits_datatype}
|
507 |
+
dims: [ -1, -1 ]
|
508 |
+
},
|
509 |
+
{
|
510 |
+
name: "generation_logits"
|
511 |
+
data_type: ${logits_datatype}
|
512 |
+
dims: [ -1, -1, -1 ]
|
513 |
+
},
|
514 |
+
{
|
515 |
+
name: "batch_index"
|
516 |
+
data_type: TYPE_INT32
|
517 |
+
dims: [ 1 ]
|
518 |
+
},
|
519 |
+
{
|
520 |
+
name: "sequence_index"
|
521 |
+
data_type: TYPE_INT32
|
522 |
+
dims: [ 1 ]
|
523 |
+
},
|
524 |
+
{
|
525 |
+
name: "context_phase_params"
|
526 |
+
data_type: TYPE_UINT8
|
527 |
+
dims: [ -1 ]
|
528 |
+
},
|
529 |
+
{
|
530 |
+
name: "kv_cache_alloc_new_blocks"
|
531 |
+
data_type: TYPE_INT32
|
532 |
+
dims: [ 1 ]
|
533 |
+
},
|
534 |
+
{
|
535 |
+
name: "kv_cache_reused_blocks"
|
536 |
+
data_type: TYPE_INT32
|
537 |
+
dims: [ 1 ]
|
538 |
+
},
|
539 |
+
{
|
540 |
+
name: "kv_cache_alloc_total_blocks"
|
541 |
+
data_type: TYPE_INT32
|
542 |
+
dims: [ 1 ]
|
543 |
+
},
|
544 |
+
{
|
545 |
+
name: "arrival_time_ns"
|
546 |
+
data_type: TYPE_INT64
|
547 |
+
dims: [ 1 ]
|
548 |
+
},
|
549 |
+
{
|
550 |
+
name: "first_scheduled_time_ns"
|
551 |
+
data_type: TYPE_INT64
|
552 |
+
dims: [ 1 ]
|
553 |
+
},
|
554 |
+
{
|
555 |
+
name: "first_token_time_ns"
|
556 |
+
data_type: TYPE_INT64
|
557 |
+
dims: [ 1 ]
|
558 |
+
},
|
559 |
+
{
|
560 |
+
name: "last_token_time_ns"
|
561 |
+
data_type: TYPE_INT64
|
562 |
+
dims: [ 1 ]
|
563 |
+
},
|
564 |
+
{
|
565 |
+
name: "acceptance_rate"
|
566 |
+
data_type: TYPE_FP32
|
567 |
+
dims: [ 1 ]
|
568 |
+
},
|
569 |
+
{
|
570 |
+
name: "total_accepted_draft_tokens"
|
571 |
+
data_type: TYPE_INT32
|
572 |
+
dims: [ 1 ]
|
573 |
+
},
|
574 |
+
{
|
575 |
+
name: "total_draft_tokens"
|
576 |
+
data_type: TYPE_INT32
|
577 |
+
dims: [ 1 ]
|
578 |
+
}
|
579 |
+
]
|
580 |
+
instance_group [
|
581 |
+
{
|
582 |
+
count: 1
|
583 |
+
kind : KIND_CPU
|
584 |
+
}
|
585 |
+
]
|
586 |
+
parameters: {
|
587 |
+
key: "max_beam_width"
|
588 |
+
value: {
|
589 |
+
string_value: "${max_beam_width}"
|
590 |
+
}
|
591 |
+
}
|
592 |
+
parameters: {
|
593 |
+
key: "FORCE_CPU_ONLY_INPUT_TENSORS"
|
594 |
+
value: {
|
595 |
+
string_value: "no"
|
596 |
+
}
|
597 |
+
}
|
598 |
+
parameters: {
|
599 |
+
key: "gpt_model_type"
|
600 |
+
value: {
|
601 |
+
string_value: "${batching_strategy}"
|
602 |
+
}
|
603 |
+
}
|
604 |
+
parameters: {
|
605 |
+
key: "gpt_model_path"
|
606 |
+
value: {
|
607 |
+
string_value: "${engine_dir}"
|
608 |
+
}
|
609 |
+
}
|
610 |
+
parameters: {
|
611 |
+
key: "encoder_model_path"
|
612 |
+
value: {
|
613 |
+
string_value: "${encoder_engine_dir}"
|
614 |
+
}
|
615 |
+
}
|
616 |
+
parameters: {
|
617 |
+
key: "max_tokens_in_paged_kv_cache"
|
618 |
+
value: {
|
619 |
+
string_value: "${max_tokens_in_paged_kv_cache}"
|
620 |
+
}
|
621 |
+
}
|
622 |
+
parameters: {
|
623 |
+
key: "max_attention_window_size"
|
624 |
+
value: {
|
625 |
+
string_value: "${max_attention_window_size}"
|
626 |
+
}
|
627 |
+
}
|
628 |
+
parameters: {
|
629 |
+
key: "sink_token_length"
|
630 |
+
value: {
|
631 |
+
string_value: "${sink_token_length}"
|
632 |
+
}
|
633 |
+
}
|
634 |
+
parameters: {
|
635 |
+
key: "batch_scheduler_policy"
|
636 |
+
value: {
|
637 |
+
string_value: "${batch_scheduler_policy}"
|
638 |
+
}
|
639 |
+
}
|
640 |
+
parameters: {
|
641 |
+
key: "kv_cache_free_gpu_mem_fraction"
|
642 |
+
value: {
|
643 |
+
string_value: "${kv_cache_free_gpu_mem_fraction}"
|
644 |
+
}
|
645 |
+
}
|
646 |
+
parameters: {
|
647 |
+
key: "cross_kv_cache_fraction"
|
648 |
+
value: {
|
649 |
+
string_value: "${cross_kv_cache_fraction}"
|
650 |
+
}
|
651 |
+
}
|
652 |
+
parameters: {
|
653 |
+
key: "kv_cache_host_memory_bytes"
|
654 |
+
value: {
|
655 |
+
string_value: "${kv_cache_host_memory_bytes}"
|
656 |
+
}
|
657 |
+
}
|
658 |
+
# kv_cache_onboard_blocks is for internal implementation.
|
659 |
+
parameters: {
|
660 |
+
key: "kv_cache_onboard_blocks"
|
661 |
+
value: {
|
662 |
+
string_value: "${kv_cache_onboard_blocks}"
|
663 |
+
}
|
664 |
+
}
|
665 |
+
# enable_trt_overlap is deprecated and doesn't have any effect on the runtime
|
666 |
+
# parameters: {
|
667 |
+
# key: "enable_trt_overlap"
|
668 |
+
# value: {
|
669 |
+
# string_value: "${enable_trt_overlap}"
|
670 |
+
# }
|
671 |
+
# }
|
672 |
+
parameters: {
|
673 |
+
key: "exclude_input_in_output"
|
674 |
+
value: {
|
675 |
+
string_value: "${exclude_input_in_output}"
|
676 |
+
}
|
677 |
+
}
|
678 |
+
parameters: {
|
679 |
+
key: "cancellation_check_period_ms"
|
680 |
+
value: {
|
681 |
+
string_value: "${cancellation_check_period_ms}"
|
682 |
+
}
|
683 |
+
}
|
684 |
+
parameters: {
|
685 |
+
key: "stats_check_period_ms"
|
686 |
+
value: {
|
687 |
+
string_value: "${stats_check_period_ms}"
|
688 |
+
}
|
689 |
+
}
|
690 |
+
parameters: {
|
691 |
+
key: "iter_stats_max_iterations"
|
692 |
+
value: {
|
693 |
+
string_value: "${iter_stats_max_iterations}"
|
694 |
+
}
|
695 |
+
}
|
696 |
+
parameters: {
|
697 |
+
key: "request_stats_max_iterations"
|
698 |
+
value: {
|
699 |
+
string_value: "${request_stats_max_iterations}"
|
700 |
+
}
|
701 |
+
}
|
702 |
+
parameters: {
|
703 |
+
key: "enable_kv_cache_reuse"
|
704 |
+
value: {
|
705 |
+
string_value: "${enable_kv_cache_reuse}"
|
706 |
+
}
|
707 |
+
}
|
708 |
+
parameters: {
|
709 |
+
key: "normalize_log_probs"
|
710 |
+
value: {
|
711 |
+
string_value: "${normalize_log_probs}"
|
712 |
+
}
|
713 |
+
}
|
714 |
+
parameters: {
|
715 |
+
key: "enable_chunked_context"
|
716 |
+
value: {
|
717 |
+
string_value: "${enable_chunked_context}"
|
718 |
+
}
|
719 |
+
}
|
720 |
+
parameters: {
|
721 |
+
key: "gpu_device_ids"
|
722 |
+
value: {
|
723 |
+
string_value: "${gpu_device_ids}"
|
724 |
+
}
|
725 |
+
}
|
726 |
+
parameters: {
|
727 |
+
key: "participant_ids"
|
728 |
+
value: {
|
729 |
+
string_value: "${participant_ids}"
|
730 |
+
}
|
731 |
+
}
|
732 |
+
parameters: {
|
733 |
+
key: "lora_cache_optimal_adapter_size"
|
734 |
+
value: {
|
735 |
+
string_value: "${lora_cache_optimal_adapter_size}"
|
736 |
+
}
|
737 |
+
}
|
738 |
+
parameters: {
|
739 |
+
key: "lora_cache_max_adapter_size"
|
740 |
+
value: {
|
741 |
+
string_value: "${lora_cache_max_adapter_size}"
|
742 |
+
}
|
743 |
+
}
|
744 |
+
parameters: {
|
745 |
+
key: "lora_cache_gpu_memory_fraction"
|
746 |
+
value: {
|
747 |
+
string_value: "${lora_cache_gpu_memory_fraction}"
|
748 |
+
}
|
749 |
+
}
|
750 |
+
parameters: {
|
751 |
+
key: "lora_cache_host_memory_bytes"
|
752 |
+
value: {
|
753 |
+
string_value: "${lora_cache_host_memory_bytes}"
|
754 |
+
}
|
755 |
+
}
|
756 |
+
parameters: {
|
757 |
+
key: "lora_prefetch_dir"
|
758 |
+
value: {
|
759 |
+
string_value: "${lora_prefetch_dir}"
|
760 |
+
}
|
761 |
+
}
|
762 |
+
parameters: {
|
763 |
+
key: "decoding_mode"
|
764 |
+
value: {
|
765 |
+
string_value: "${decoding_mode}"
|
766 |
+
}
|
767 |
+
}
|
768 |
+
parameters: {
|
769 |
+
key: "executor_worker_path"
|
770 |
+
value: {
|
771 |
+
string_value: "/opt/tritonserver/backends/tensorrtllm/trtllmExecutorWorker"
|
772 |
+
}
|
773 |
+
}
|
774 |
+
parameters: {
|
775 |
+
key: "lookahead_window_size"
|
776 |
+
value: {
|
777 |
+
string_value: "${lookahead_window_size}"
|
778 |
+
}
|
779 |
+
}
|
780 |
+
parameters: {
|
781 |
+
key: "lookahead_ngram_size"
|
782 |
+
value: {
|
783 |
+
string_value: "${lookahead_ngram_size}"
|
784 |
+
}
|
785 |
+
}
|
786 |
+
parameters: {
|
787 |
+
key: "lookahead_verification_set_size"
|
788 |
+
value: {
|
789 |
+
string_value: "${lookahead_verification_set_size}"
|
790 |
+
}
|
791 |
+
}
|
792 |
+
parameters: {
|
793 |
+
key: "medusa_choices"
|
794 |
+
value: {
|
795 |
+
string_value: "${medusa_choices}"
|
796 |
+
}
|
797 |
+
}
|
798 |
+
parameters: {
|
799 |
+
key: "eagle_choices"
|
800 |
+
value: {
|
801 |
+
string_value: "${eagle_choices}"
|
802 |
+
}
|
803 |
+
}
|
804 |
+
parameters: {
|
805 |
+
key: "gpu_weights_percent"
|
806 |
+
value: {
|
807 |
+
string_value: "${gpu_weights_percent}"
|
808 |
+
}
|
809 |
+
}
|
810 |
+
parameters: {
|
811 |
+
key: "enable_context_fmha_fp32_acc"
|
812 |
+
value: {
|
813 |
+
string_value: "${enable_context_fmha_fp32_acc}"
|
814 |
+
}
|
815 |
+
}
|
816 |
+
parameters: {
|
817 |
+
key: "multi_block_mode"
|
818 |
+
value: {
|
819 |
+
string_value: "${multi_block_mode}"
|
820 |
+
}
|
821 |
+
}
|
822 |
+
parameters: {
|
823 |
+
key: "cuda_graph_mode"
|
824 |
+
value: {
|
825 |
+
string_value: "${cuda_graph_mode}"
|
826 |
+
}
|
827 |
+
}
|
828 |
+
parameters: {
|
829 |
+
key: "cuda_graph_cache_size"
|
830 |
+
value: {
|
831 |
+
string_value: "${cuda_graph_cache_size}"
|
832 |
+
}
|
833 |
+
}
|
834 |
+
parameters: {
|
835 |
+
key: "speculative_decoding_fast_logits"
|
836 |
+
value: {
|
837 |
+
string_value: "${speculative_decoding_fast_logits}"
|
838 |
+
}
|
839 |
+
}
|
840 |
+
parameters: {
|
841 |
+
key: "tokenizer_dir"
|
842 |
+
value: {
|
843 |
+
string_value: "${tokenizer_dir}"
|
844 |
+
}
|
845 |
+
}
|
846 |
+
parameters: {
|
847 |
+
key: "guided_decoding_backend"
|
848 |
+
value: {
|
849 |
+
string_value: "${guided_decoding_backend}"
|
850 |
+
}
|
851 |
+
}
|
852 |
+
parameters: {
|
853 |
+
key: "xgrammar_tokenizer_info_path"
|
854 |
+
value: {
|
855 |
+
string_value: "${xgrammar_tokenizer_info_path}"
|
856 |
+
}
|
857 |
+
}
|
runtime/triton_trtllm/model_repo/vocoder/1/model.py
ADDED
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# Redistribution and use in source and binary forms, with or without
|
4 |
+
# modification, are permitted provided that the following conditions
|
5 |
+
# are met:
|
6 |
+
# * Redistributions of source code must retain the above copyright
|
7 |
+
# notice, this list of conditions and the following disclaimer.
|
8 |
+
# * Redistributions in binary form must reproduce the above copyright
|
9 |
+
# notice, this list of conditions and the following disclaimer in the
|
10 |
+
# documentation and/or other materials provided with the distribution.
|
11 |
+
# * Neither the name of NVIDIA CORPORATION nor the names of its
|
12 |
+
# contributors may be used to endorse or promote products derived
|
13 |
+
# from this software without specific prior written permission.
|
14 |
+
#
|
15 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
|
16 |
+
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
17 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
|
18 |
+
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
|
19 |
+
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
20 |
+
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
21 |
+
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
22 |
+
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
|
23 |
+
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
24 |
+
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
25 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
26 |
+
|
27 |
+
import json
|
28 |
+
import os
|
29 |
+
import logging
|
30 |
+
from typing import List, Dict
|
31 |
+
|
32 |
+
import torch
|
33 |
+
from torch.utils.dlpack import to_dlpack
|
34 |
+
|
35 |
+
import triton_python_backend_utils as pb_utils
|
36 |
+
|
37 |
+
from sparktts.models.bicodec import BiCodec
|
38 |
+
|
39 |
+
# Configure logging
|
40 |
+
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
41 |
+
logger = logging.getLogger(__name__)
|
42 |
+
|
43 |
+
class TritonPythonModel:
|
44 |
+
"""Triton Python model for vocoder.
|
45 |
+
|
46 |
+
This model takes global and semantic tokens as input and generates audio waveforms
|
47 |
+
using the BiCodec vocoder.
|
48 |
+
"""
|
49 |
+
|
50 |
+
def initialize(self, args):
|
51 |
+
"""Initialize the model.
|
52 |
+
|
53 |
+
Args:
|
54 |
+
args: Dictionary containing model configuration
|
55 |
+
"""
|
56 |
+
# Parse model parameters
|
57 |
+
parameters = json.loads(args['model_config'])['parameters']
|
58 |
+
model_params = {key: value["string_value"] for key, value in parameters.items()}
|
59 |
+
model_dir = model_params["model_dir"]
|
60 |
+
|
61 |
+
# Initialize device and vocoder
|
62 |
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
63 |
+
logger.info(f"Initializing vocoder from {model_dir} on {self.device}")
|
64 |
+
|
65 |
+
self.vocoder = BiCodec.load_from_checkpoint(f"{model_dir}/BiCodec")
|
66 |
+
del self.vocoder.encoder, self.vocoder.postnet
|
67 |
+
self.vocoder.eval().to(self.device) # Set model to evaluation mode
|
68 |
+
|
69 |
+
logger.info("Vocoder initialized successfully")
|
70 |
+
|
71 |
+
|
72 |
+
def execute(self, requests):
|
73 |
+
"""Execute inference on the batched requests.
|
74 |
+
|
75 |
+
Args:
|
76 |
+
requests: List of inference requests
|
77 |
+
|
78 |
+
Returns:
|
79 |
+
List of inference responses containing generated waveforms
|
80 |
+
"""
|
81 |
+
global_tokens_list, semantic_tokens_list = [], []
|
82 |
+
|
83 |
+
# Process each request in batch
|
84 |
+
for request in requests:
|
85 |
+
global_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "global_tokens").as_numpy()
|
86 |
+
semantic_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "semantic_tokens").as_numpy()
|
87 |
+
global_tokens_list.append(torch.from_numpy(global_tokens_tensor).to(self.device))
|
88 |
+
semantic_tokens_list.append(torch.from_numpy(semantic_tokens_tensor).to(self.device))
|
89 |
+
|
90 |
+
# Concatenate tokens for batch processing
|
91 |
+
global_tokens = torch.cat(global_tokens_list, dim=0)
|
92 |
+
semantic_tokens = torch.cat(semantic_tokens_list, dim=0)
|
93 |
+
|
94 |
+
|
95 |
+
# Generate waveforms
|
96 |
+
with torch.no_grad():
|
97 |
+
wavs = self.vocoder.detokenize(semantic_tokens, global_tokens.unsqueeze(1))
|
98 |
+
|
99 |
+
# Prepare responses
|
100 |
+
responses = []
|
101 |
+
for i in range(len(requests)):
|
102 |
+
wav_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(wavs[i]))
|
103 |
+
inference_response = pb_utils.InferenceResponse(output_tensors=[wav_tensor])
|
104 |
+
responses.append(inference_response)
|
105 |
+
|
106 |
+
return responses
|
runtime/triton_trtllm/model_repo/vocoder/config.pbtxt
ADDED
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
name: "vocoder"
|
16 |
+
backend: "python"
|
17 |
+
max_batch_size: ${triton_max_batch_size}
|
18 |
+
dynamic_batching {
|
19 |
+
max_queue_delay_microseconds: ${max_queue_delay_microseconds}
|
20 |
+
}
|
21 |
+
parameters [
|
22 |
+
{
|
23 |
+
key: "model_dir",
|
24 |
+
value: {string_value:"${model_dir}"}
|
25 |
+
}
|
26 |
+
]
|
27 |
+
|
28 |
+
input [
|
29 |
+
{
|
30 |
+
name: "global_tokens"
|
31 |
+
data_type: TYPE_INT32
|
32 |
+
dims: [-1]
|
33 |
+
},
|
34 |
+
{
|
35 |
+
name: "semantic_tokens"
|
36 |
+
data_type: TYPE_INT32
|
37 |
+
dims: [-1]
|
38 |
+
}
|
39 |
+
]
|
40 |
+
output [
|
41 |
+
{
|
42 |
+
name: "waveform"
|
43 |
+
data_type: TYPE_FP32
|
44 |
+
dims: [ -1 ]
|
45 |
+
}
|
46 |
+
]
|
47 |
+
|
48 |
+
instance_group [
|
49 |
+
{
|
50 |
+
count: 1
|
51 |
+
kind: KIND_CPU
|
52 |
+
}
|
53 |
+
]
|
runtime/triton_trtllm/run.sh
ADDED
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
export PYTHONPATH=../../../Spark-TTS/
|
2 |
+
export CUDA_VISIBLE_DEVICES=0
|
3 |
+
stage=$1
|
4 |
+
stop_stage=$2
|
5 |
+
echo "Start stage: $stage, Stop stage: $stop_stage"
|
6 |
+
|
7 |
+
huggingface_model_local_dir=../../pretrained_models/Spark-TTS-0.5B
|
8 |
+
trt_dtype=bfloat16
|
9 |
+
trt_weights_dir=./tllm_checkpoint_${trt_dtype}
|
10 |
+
trt_engines_dir=./trt_engines_${trt_dtype}
|
11 |
+
|
12 |
+
model_repo=./model_repo_test
|
13 |
+
|
14 |
+
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
|
15 |
+
echo "Downloading Spark-TTS-0.5B from HuggingFace"
|
16 |
+
huggingface-cli download SparkAudio/Spark-TTS-0.5B --local-dir $huggingface_model_local_dir || exit 1
|
17 |
+
# pip install -r /workspace_yuekai/spark-tts/Spark-TTS/requirements.txt
|
18 |
+
fi
|
19 |
+
|
20 |
+
|
21 |
+
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
|
22 |
+
echo "Converting checkpoint to TensorRT weights"
|
23 |
+
python scripts/convert_checkpoint.py --model_dir $huggingface_model_local_dir/LLM \
|
24 |
+
--output_dir $trt_weights_dir \
|
25 |
+
--dtype $trt_dtype || exit 1
|
26 |
+
|
27 |
+
echo "Building TensorRT engines"
|
28 |
+
trtllm-build --checkpoint_dir $trt_weights_dir \
|
29 |
+
--output_dir $trt_engines_dir \
|
30 |
+
--max_batch_size 16 \
|
31 |
+
--max_num_tokens 32768 \
|
32 |
+
--gemm_plugin $trt_dtype || exit 1
|
33 |
+
fi
|
34 |
+
|
35 |
+
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
|
36 |
+
echo "Creating model repository"
|
37 |
+
rm -rf $model_repo
|
38 |
+
cp -r ./model_repo $model_repo
|
39 |
+
|
40 |
+
ENGINE_PATH=$trt_engines_dir
|
41 |
+
MAX_QUEUE_DELAY_MICROSECONDS=0
|
42 |
+
MODEL_DIR=$huggingface_model_local_dir
|
43 |
+
LLM_TOKENIZER_DIR=$huggingface_model_local_dir/LLM
|
44 |
+
BLS_INSTANCE_NUM=4
|
45 |
+
TRITON_MAX_BATCH_SIZE=16
|
46 |
+
|
47 |
+
python3 scripts/fill_template.py -i ${model_repo}/vocoder/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
|
48 |
+
python3 scripts/fill_template.py -i ${model_repo}/audio_tokenizer/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
|
49 |
+
python3 scripts/fill_template.py -i ${model_repo}/spark_tts/config.pbtxt bls_instance_num:${BLS_INSTANCE_NUM},llm_tokenizer_dir:${LLM_TOKENIZER_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
|
50 |
+
python3 scripts/fill_template.py -i ${model_repo}/tensorrt_llm/config.pbtxt triton_backend:tensorrtllm,triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},decoupled_mode:False,max_beam_width:1,engine_dir:${ENGINE_PATH},max_tokens_in_paged_kv_cache:2560,max_attention_window_size:2560,kv_cache_free_gpu_mem_fraction:0.5,exclude_input_in_output:True,enable_kv_cache_reuse:False,batching_strategy:inflight_fused_batching,max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS},encoder_input_features_data_type:TYPE_FP16,logits_datatype:TYPE_FP32
|
51 |
+
|
52 |
+
fi
|
53 |
+
|
54 |
+
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
|
55 |
+
echo "Starting Triton server"
|
56 |
+
tritonserver --model-repository ${model_repo}
|
57 |
+
fi
|
58 |
+
|
59 |
+
|
60 |
+
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
|
61 |
+
echo "Running client"
|
62 |
+
num_task=2
|
63 |
+
python3 client_grpc.py \
|
64 |
+
--server-addr localhost \
|
65 |
+
--model-name spark_tts \
|
66 |
+
--num-tasks $num_task \
|
67 |
+
--log-dir ./log_concurrent_tasks_${num_task}
|
68 |
+
fi
|
runtime/triton_trtllm/scripts/convert_checkpoint.py
ADDED
@@ -0,0 +1,335 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
import time
|
4 |
+
import traceback
|
5 |
+
from concurrent.futures import ThreadPoolExecutor, as_completed
|
6 |
+
|
7 |
+
from transformers import AutoConfig
|
8 |
+
|
9 |
+
import tensorrt_llm
|
10 |
+
from tensorrt_llm._utils import release_gc
|
11 |
+
from tensorrt_llm.logger import logger
|
12 |
+
from tensorrt_llm.mapping import Mapping
|
13 |
+
from tensorrt_llm.models import QWenForCausalLM
|
14 |
+
from tensorrt_llm.models.modeling_utils import QuantConfig
|
15 |
+
from tensorrt_llm.quantization import QuantAlgo
|
16 |
+
|
17 |
+
|
18 |
+
def parse_arguments():
|
19 |
+
parser = argparse.ArgumentParser()
|
20 |
+
parser.add_argument('--model_dir', type=str, default=None, required=True)
|
21 |
+
parser.add_argument('--tp_size',
|
22 |
+
type=int,
|
23 |
+
default=1,
|
24 |
+
help='N-way tensor parallelism size')
|
25 |
+
parser.add_argument('--pp_size',
|
26 |
+
type=int,
|
27 |
+
default=1,
|
28 |
+
help='N-way pipeline parallelism size')
|
29 |
+
parser.add_argument(
|
30 |
+
'--dtype',
|
31 |
+
type=str,
|
32 |
+
default='auto',
|
33 |
+
choices=['auto', 'float16', 'bfloat16', 'float32'],
|
34 |
+
help=
|
35 |
+
"The data type for the model weights and activations if not quantized. "
|
36 |
+
"If 'auto', the data type is automatically inferred from the source model; "
|
37 |
+
"however, if the source dtype is float32, it is converted to float16.")
|
38 |
+
parser.add_argument(
|
39 |
+
'--use_weight_only',
|
40 |
+
default=False,
|
41 |
+
action="store_true",
|
42 |
+
help='Quantize weights for the various GEMMs to INT4/INT8.'
|
43 |
+
'See --weight_only_precision to set the precision')
|
44 |
+
parser.add_argument(
|
45 |
+
'--disable_weight_only_quant_plugin',
|
46 |
+
default=False,
|
47 |
+
action="store_true",
|
48 |
+
help=
|
49 |
+
'By default, using plugin implementation for weight quantization. Enabling disable_weight_only_quant_plugin flag will use ootb implementation instead of plugin.'
|
50 |
+
'You must also use --use_weight_only for that argument to have an impact.'
|
51 |
+
)
|
52 |
+
parser.add_argument(
|
53 |
+
'--weight_only_precision',
|
54 |
+
const='int8',
|
55 |
+
type=str,
|
56 |
+
nargs='?',
|
57 |
+
default='int8',
|
58 |
+
choices=['int8', 'int4', 'int4_gptq'],
|
59 |
+
help=
|
60 |
+
'Define the precision for the weights when using weight-only quantization.'
|
61 |
+
'You must also use --use_weight_only for that argument to have an impact.'
|
62 |
+
)
|
63 |
+
parser.add_argument(
|
64 |
+
'--calib_dataset',
|
65 |
+
type=str,
|
66 |
+
default='ccdv/cnn_dailymail',
|
67 |
+
help=
|
68 |
+
"The huggingface dataset name or the local directory of the dataset for calibration."
|
69 |
+
)
|
70 |
+
parser.add_argument(
|
71 |
+
"--smoothquant",
|
72 |
+
"-sq",
|
73 |
+
type=float,
|
74 |
+
default=None,
|
75 |
+
help="Set the α parameter (see https://arxiv.org/pdf/2211.10438.pdf)"
|
76 |
+
" to Smoothquant the model, and output int8 weights."
|
77 |
+
" A good first try is 0.5. Must be in [0, 1]")
|
78 |
+
parser.add_argument(
|
79 |
+
'--per_channel',
|
80 |
+
action="store_true",
|
81 |
+
default=False,
|
82 |
+
help=
|
83 |
+
'By default, we use a single static scaling factor for the GEMM\'s result. '
|
84 |
+
'per_channel instead uses a different static scaling factor for each channel. '
|
85 |
+
'The latter is usually more accurate, but a little slower.')
|
86 |
+
parser.add_argument(
|
87 |
+
'--per_token',
|
88 |
+
action="store_true",
|
89 |
+
default=False,
|
90 |
+
help=
|
91 |
+
'By default, we use a single static scaling factor to scale activations in the int8 range. '
|
92 |
+
'per_token chooses at run time, and for each token, a custom scaling factor. '
|
93 |
+
'The latter is usually more accurate, but a little slower.')
|
94 |
+
parser.add_argument(
|
95 |
+
'--int8_kv_cache',
|
96 |
+
default=False,
|
97 |
+
action="store_true",
|
98 |
+
help=
|
99 |
+
'By default, we use dtype for KV cache. int8_kv_cache chooses int8 quantization for KV'
|
100 |
+
)
|
101 |
+
parser.add_argument(
|
102 |
+
'--per_group',
|
103 |
+
default=False,
|
104 |
+
action="store_true",
|
105 |
+
help=
|
106 |
+
'By default, we use a single static scaling factor to scale weights in the int4 range. '
|
107 |
+
'per_group chooses at run time, and for each group, a custom scaling factor. '
|
108 |
+
'The flag is built for GPTQ/AWQ quantization.')
|
109 |
+
|
110 |
+
parser.add_argument('--group_size',
|
111 |
+
type=int,
|
112 |
+
default=128,
|
113 |
+
help='Group size used in GPTQ quantization.')
|
114 |
+
|
115 |
+
parser.add_argument("--load_model_on_cpu", action="store_true")
|
116 |
+
parser.add_argument(
|
117 |
+
'--use_parallel_embedding',
|
118 |
+
action="store_true",
|
119 |
+
default=False,
|
120 |
+
help=
|
121 |
+
'By default embedding parallelism is disabled. By setting this flag, embedding parallelism is enabled'
|
122 |
+
)
|
123 |
+
parser.add_argument(
|
124 |
+
'--embedding_sharding_dim',
|
125 |
+
type=int,
|
126 |
+
default=0,
|
127 |
+
choices=[0, 1],
|
128 |
+
help=
|
129 |
+
'By default the embedding lookup table is sharded along vocab dimension (embedding_sharding_dim=0). '
|
130 |
+
'To shard it along hidden dimension, set embedding_sharding_dim=1'
|
131 |
+
'Note: embedding sharing is only enabled when embedding_sharding_dim = 0'
|
132 |
+
)
|
133 |
+
parser.add_argument('--output_dir',
|
134 |
+
type=str,
|
135 |
+
default='tllm_checkpoint',
|
136 |
+
help='The path to save the TensorRT-LLM checkpoint')
|
137 |
+
parser.add_argument(
|
138 |
+
'--workers',
|
139 |
+
type=int,
|
140 |
+
default=1,
|
141 |
+
help='The number of workers for converting checkpoint in parallel')
|
142 |
+
parser.add_argument(
|
143 |
+
'--moe_tp_size',
|
144 |
+
type=int,
|
145 |
+
default=-1,
|
146 |
+
help=
|
147 |
+
'N-way tensor parallelism size for MOE, default is tp_size, which will do tp-only for MoE'
|
148 |
+
)
|
149 |
+
parser.add_argument(
|
150 |
+
'--moe_ep_size',
|
151 |
+
type=int,
|
152 |
+
default=-1,
|
153 |
+
help=
|
154 |
+
'N-way expert parallelism size for MOE, default is 1, which will do tp-only for MoE'
|
155 |
+
)
|
156 |
+
args = parser.parse_args()
|
157 |
+
return args
|
158 |
+
|
159 |
+
|
160 |
+
def args_to_quant_config(args: argparse.Namespace) -> QuantConfig:
|
161 |
+
'''return config dict with quantization info based on the command line args
|
162 |
+
'''
|
163 |
+
quant_config = QuantConfig()
|
164 |
+
if args.use_weight_only:
|
165 |
+
if args.weight_only_precision == 'int8':
|
166 |
+
quant_config.quant_algo = QuantAlgo.W8A16
|
167 |
+
elif args.weight_only_precision == 'int4':
|
168 |
+
quant_config.quant_algo = QuantAlgo.W4A16
|
169 |
+
elif args.smoothquant:
|
170 |
+
quant_config.smoothquant_val = args.smoothquant
|
171 |
+
if args.per_channel:
|
172 |
+
if args.per_token:
|
173 |
+
quant_config.quant_algo = QuantAlgo.W8A8_SQ_PER_CHANNEL_PER_TOKEN_PLUGIN
|
174 |
+
else:
|
175 |
+
quant_config.quant_algo = QuantAlgo.W8A8_SQ_PER_CHANNEL_PER_TENSOR_PLUGIN
|
176 |
+
else:
|
177 |
+
if args.per_token:
|
178 |
+
quant_config.quant_algo = QuantAlgo.W8A8_SQ_PER_TENSOR_PER_TOKEN_PLUGIN
|
179 |
+
else:
|
180 |
+
quant_config.quant_algo = QuantAlgo.W8A8_SQ_PER_TENSOR_PLUGIN
|
181 |
+
|
182 |
+
if args.int8_kv_cache:
|
183 |
+
quant_config.kv_cache_quant_algo = QuantAlgo.INT8
|
184 |
+
|
185 |
+
if args.weight_only_precision == 'int4_gptq':
|
186 |
+
quant_config.group_size = args.group_size
|
187 |
+
quant_config.has_zero_point = True
|
188 |
+
quant_config.pre_quant_scale = False
|
189 |
+
quant_config.quant_algo = QuantAlgo.W4A16_GPTQ
|
190 |
+
|
191 |
+
return quant_config
|
192 |
+
|
193 |
+
|
194 |
+
def update_quant_config_from_hf(quant_config, hf_config,
|
195 |
+
override_fields) -> tuple[QuantConfig, dict]:
|
196 |
+
hf_config_dict = hf_config.to_dict()
|
197 |
+
if hf_config_dict.get('quantization_config'):
|
198 |
+
# update the quant_algo, and clamp_val.
|
199 |
+
if hf_config_dict['quantization_config'].get('quant_method') == 'awq':
|
200 |
+
logger.info(
|
201 |
+
"Load quantization configs from huggingface model_config.")
|
202 |
+
quant_config.quant_algo = QuantAlgo.W4A16_GPTQ
|
203 |
+
quant_config.group_size = hf_config_dict['quantization_config'].get(
|
204 |
+
'group_size', 128)
|
205 |
+
quant_config.has_zero_point = hf_config_dict[
|
206 |
+
'quantization_config'].get('zero_point', False)
|
207 |
+
override_fields.update({"use_autoawq": True})
|
208 |
+
elif hf_config_dict['quantization_config'].get(
|
209 |
+
'quant_method') == 'gptq':
|
210 |
+
logger.info(
|
211 |
+
"Load quantization configs from huggingface model_config.")
|
212 |
+
desc_act = hf_config_dict['quantization_config'].get(
|
213 |
+
'desc_act', False)
|
214 |
+
if desc_act:
|
215 |
+
raise ValueError("GPTQ with desc_act=True is not implemented!")
|
216 |
+
quant_config.quant_algo = QuantAlgo.W4A16_GPTQ
|
217 |
+
quant_config.group_size = hf_config_dict['quantization_config'].get(
|
218 |
+
'group_size', 128)
|
219 |
+
quant_config.has_zero_point = hf_config_dict[
|
220 |
+
'quantization_config'].get('sym', False)
|
221 |
+
return quant_config, override_fields
|
222 |
+
|
223 |
+
|
224 |
+
def args_to_build_options(args):
|
225 |
+
return {
|
226 |
+
'use_parallel_embedding': args.use_parallel_embedding,
|
227 |
+
'embedding_sharding_dim': args.embedding_sharding_dim,
|
228 |
+
'disable_weight_only_quant_plugin':
|
229 |
+
args.disable_weight_only_quant_plugin
|
230 |
+
}
|
231 |
+
|
232 |
+
|
233 |
+
def convert_and_save_hf(args):
|
234 |
+
model_dir = args.model_dir
|
235 |
+
world_size = args.tp_size * args.pp_size
|
236 |
+
# Need to convert the cli args to the kay-value pairs and override them in the generate config dict.
|
237 |
+
# Ideally these fields will be moved out of the config and pass them into build API, keep them here for compatibility purpose for now,
|
238 |
+
# before the refactor is done.
|
239 |
+
override_fields = {}
|
240 |
+
override_fields.update(args_to_build_options(args))
|
241 |
+
quant_config = args_to_quant_config(args)
|
242 |
+
|
243 |
+
try:
|
244 |
+
hf_config = AutoConfig.from_pretrained(model_dir,
|
245 |
+
trust_remote_code=True)
|
246 |
+
quant_config, override_fields = update_quant_config_from_hf(
|
247 |
+
quant_config, hf_config, override_fields)
|
248 |
+
except:
|
249 |
+
logger.warning("AutoConfig cannot load the huggingface config.")
|
250 |
+
|
251 |
+
if args.smoothquant is not None or args.int8_kv_cache:
|
252 |
+
mapping = Mapping(
|
253 |
+
world_size=world_size,
|
254 |
+
tp_size=args.tp_size,
|
255 |
+
pp_size=args.pp_size,
|
256 |
+
moe_tp_size=args.moe_tp_size,
|
257 |
+
moe_ep_size=args.moe_ep_size,
|
258 |
+
)
|
259 |
+
QWenForCausalLM.quantize(args.model_dir,
|
260 |
+
args.output_dir,
|
261 |
+
dtype=args.dtype,
|
262 |
+
mapping=mapping,
|
263 |
+
quant_config=quant_config,
|
264 |
+
calib_dataset=args.calib_dataset,
|
265 |
+
**override_fields)
|
266 |
+
else:
|
267 |
+
|
268 |
+
def convert_and_save_rank(args, rank):
|
269 |
+
mapping = Mapping(world_size=world_size,
|
270 |
+
rank=rank,
|
271 |
+
tp_size=args.tp_size,
|
272 |
+
pp_size=args.pp_size,
|
273 |
+
moe_tp_size=args.moe_tp_size,
|
274 |
+
moe_ep_size=args.moe_ep_size)
|
275 |
+
qwen = QWenForCausalLM.from_hugging_face(model_dir,
|
276 |
+
args.dtype,
|
277 |
+
mapping=mapping,
|
278 |
+
quant_config=quant_config,
|
279 |
+
**override_fields)
|
280 |
+
qwen.save_checkpoint(args.output_dir, save_config=(rank == 0))
|
281 |
+
del qwen
|
282 |
+
|
283 |
+
execute(args.workers, [convert_and_save_rank] * world_size, args)
|
284 |
+
release_gc()
|
285 |
+
|
286 |
+
|
287 |
+
def execute(workers, func, args):
|
288 |
+
if workers == 1:
|
289 |
+
for rank, f in enumerate(func):
|
290 |
+
f(args, rank)
|
291 |
+
else:
|
292 |
+
with ThreadPoolExecutor(max_workers=workers) as p:
|
293 |
+
futures = [p.submit(f, args, rank) for rank, f in enumerate(func)]
|
294 |
+
exceptions = []
|
295 |
+
for future in as_completed(futures):
|
296 |
+
try:
|
297 |
+
future.result()
|
298 |
+
except Exception as e:
|
299 |
+
traceback.print_exc()
|
300 |
+
exceptions.append(e)
|
301 |
+
assert len(
|
302 |
+
exceptions
|
303 |
+
) == 0, "Checkpoint conversion failed, please check error log."
|
304 |
+
|
305 |
+
|
306 |
+
def main():
|
307 |
+
print(tensorrt_llm.__version__)
|
308 |
+
args = parse_arguments()
|
309 |
+
|
310 |
+
if (args.moe_tp_size == -1 and args.moe_ep_size == -1):
|
311 |
+
# moe default to tp-only
|
312 |
+
args.moe_tp_size = args.tp_size
|
313 |
+
args.moe_ep_size = 1
|
314 |
+
elif (args.moe_tp_size == -1):
|
315 |
+
args.moe_tp_size = args.tp_size // args.moe_ep_size
|
316 |
+
elif (args.moe_ep_size == -1):
|
317 |
+
args.moe_ep_size = args.tp_size // args.moe_tp_size
|
318 |
+
assert (args.moe_tp_size * args.moe_ep_size == args.tp_size
|
319 |
+
), "moe_tp_size * moe_ep_size must equal to tp_size"
|
320 |
+
|
321 |
+
tik = time.time()
|
322 |
+
|
323 |
+
if not os.path.exists(args.output_dir):
|
324 |
+
os.makedirs(args.output_dir)
|
325 |
+
|
326 |
+
assert args.model_dir is not None
|
327 |
+
convert_and_save_hf(args)
|
328 |
+
|
329 |
+
tok = time.time()
|
330 |
+
t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
|
331 |
+
print(f'Total time of converting checkpoints: {t}')
|
332 |
+
|
333 |
+
|
334 |
+
if __name__ == '__main__':
|
335 |
+
main()
|
runtime/triton_trtllm/scripts/fill_template.py
ADDED
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#! /usr/bin/env python3
|
2 |
+
from argparse import ArgumentParser
|
3 |
+
from string import Template
|
4 |
+
|
5 |
+
|
6 |
+
def split(string, delimiter):
|
7 |
+
"""Split a string using delimiter. Supports escaping.
|
8 |
+
|
9 |
+
Args:
|
10 |
+
string (str): The string to split.
|
11 |
+
delimiter (str): The delimiter to split the string with.
|
12 |
+
|
13 |
+
Returns:
|
14 |
+
list: A list of strings.
|
15 |
+
"""
|
16 |
+
result = []
|
17 |
+
current = ""
|
18 |
+
escape = False
|
19 |
+
for char in string:
|
20 |
+
if escape:
|
21 |
+
current += char
|
22 |
+
escape = False
|
23 |
+
elif char == delimiter:
|
24 |
+
result.append(current)
|
25 |
+
current = ""
|
26 |
+
elif char == "\\":
|
27 |
+
escape = True
|
28 |
+
else:
|
29 |
+
current += char
|
30 |
+
result.append(current)
|
31 |
+
return result
|
32 |
+
|
33 |
+
|
34 |
+
def main(file_path, substitutions, in_place):
|
35 |
+
with open(file_path) as f:
|
36 |
+
pbtxt = Template(f.read())
|
37 |
+
|
38 |
+
sub_dict = {
|
39 |
+
"max_queue_size": 0,
|
40 |
+
'max_queue_delay_microseconds': 0,
|
41 |
+
}
|
42 |
+
for sub in split(substitutions, ","):
|
43 |
+
key, value = split(sub, ":")
|
44 |
+
sub_dict[key] = value
|
45 |
+
|
46 |
+
assert key in pbtxt.template, f"key '{key}' does not exist in the file {file_path}."
|
47 |
+
|
48 |
+
pbtxt = pbtxt.safe_substitute(sub_dict)
|
49 |
+
|
50 |
+
if in_place:
|
51 |
+
with open(file_path, "w") as f:
|
52 |
+
f.write(pbtxt)
|
53 |
+
else:
|
54 |
+
print(pbtxt)
|
55 |
+
|
56 |
+
|
57 |
+
if __name__ == "__main__":
|
58 |
+
parser = ArgumentParser()
|
59 |
+
parser.add_argument("file_path", help="path of the .pbtxt to modify")
|
60 |
+
parser.add_argument(
|
61 |
+
"substitutions",
|
62 |
+
help=
|
63 |
+
"substitutions to perform, in the format variable_name_1:value_1,variable_name_2:value_2..."
|
64 |
+
)
|
65 |
+
parser.add_argument("--in_place",
|
66 |
+
"-i",
|
67 |
+
action="store_true",
|
68 |
+
help="do the operation in-place")
|
69 |
+
args = parser.parse_args()
|
70 |
+
main(**vars(args))
|