Xinsheng Wang committed
Commit 1c17251 · unverified · 2 Parent(s): 882c25f b228154

Merge pull request #92 from yuekaizhang/triton
runtime/triton_trtllm/Dockerfile.server ADDED
@@ -0,0 +1,5 @@
1
+ FROM nvcr.io/nvidia/tritonserver:25.02-trtllm-python-py3
2
+ RUN apt-get update && apt-get install -y cmake
3
+ RUN git clone https://github.com/pytorch/audio.git && cd audio && git checkout c670ad8 && PATH=/usr/local/cuda/bin:$PATH python3 setup.py develop
4
+ RUN pip install einx==0.3.0 omegaconf==2.3.0 soundfile==0.12.1 soxr==0.5.0.post1 gradio tritonclient librosa
5
+ WORKDIR /workspace
runtime/triton_trtllm/README.md ADDED
@@ -0,0 +1,45 @@
1
+ ## Nvidia Triton Inference Serving Best Practice for Spark TTS
2
+
3
+ ### Quick Start
4
+ Launch the service directly with Docker Compose.
5
+ ```sh
6
+ docker compose up
7
+ ```
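+ 
+ The compose file added in this PR reads a `MODEL_ID` environment variable. A minimal sketch, assuming `run.sh` expects a Hugging Face model id (the value below is an assumption; adjust it to the checkpoint you want served):
+ ```sh
+ # hypothetical value for MODEL_ID; run.sh inside the container decides how it is used
+ MODEL_ID=SparkAudio/Spark-TTS-0.5B docker compose up
+ ```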
8
+
9
+ ### Build Image
10
+ Build the docker image from scratch.
11
+ ```sh
12
+ docker build . -f Dockerfile.server -t soar97/triton-spark-tts:25.02
13
+ ```
14
+
15
+ ### Create Docker Container
16
+ ```sh
17
+ your_mount_dir=/mnt:/mnt
18
+ docker run -it --name "spark-tts-server" --gpus all --net host -v $your_mount_dir --shm-size=2g soar97/triton-spark-tts:25.02
19
+ ```
20
+
21
+ ### Export Models to TensorRT-LLM and Launch Server
22
+ Inside the docker container, follow the official TensorRT-LLM guide to build the TensorRT-LLM engines. See [here](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/qwen).
23
+
24
+ ```sh
25
+ bash run.sh 0 3
26
+ ```
27
+ ### Simple HTTP client
28
+ ```sh
29
+ python3 client_http.py
30
+ ```
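+ 
+ `client_http.py` has defaults for every argument; a minimal sketch overriding them (paths and texts below are placeholders):
+ ```sh
+ python3 client_http.py \
+     --server-url localhost:8000 \
+     --reference-audio ../../example/prompt_audio.wav \
+     --reference-text "transcript of the prompt audio" \
+     --target-text "text to synthesize" \
+     --output-audio output.wav
+ ```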
31
+
32
+ ### Benchmark using Dataset
33
+ ```sh
34
+ num_task=2
35
+ python3 client_grpc.py --model-name spark_tts --num-tasks $num_task --huggingface-dataset yuekai/seed_tts --split-name wenetspeech4tts
36
+ ```
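+ 
+ `client_grpc.py` can also send a single reference audio instead of a dataset; a minimal sketch using the arguments defined in this PR (prompt path and texts are placeholders):
+ ```sh
+ python3 client_grpc.py \
+     --server-addr localhost \
+     --model-name spark_tts \
+     --reference-audio ../../example/prompt_audio.wav \
+     --reference-text "transcript of the prompt audio" \
+     --target-text "text to synthesize" \
+     --log-dir ./single_utt_test
+ ```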
37
+
38
+ ### Benchmark Results
39
+ Decoding on a single L20 GPU with 26 different prompt_audio/target_text pairs (total audio duration: 169 seconds).
40
+
41
+ | Model | Note | Concurrency | Avg Latency | RTF |
42
+ |-------|-----------|-----------------------|---------|--|
43
+ | Spark-TTS-0.5B | [Code Commit](https://github.com/SparkAudio/Spark-TTS/tree/4d769ff782a868524f29e0be851ca64f8b22ebf1/runtime/triton_trtllm) | 1 | 876.24 ms | 0.1362|
44
+ | Spark-TTS-0.5B | [Code Commit](https://github.com/SparkAudio/Spark-TTS/tree/4d769ff782a868524f29e0be851ca64f8b22ebf1/runtime/triton_trtllm) | 2 | 920.97 ms | 0.0737|
45
+ | Spark-TTS-0.5B | [Code Commit](https://github.com/SparkAudio/Spark-TTS/tree/4d769ff782a868524f29e0be851ca64f8b22ebf1/runtime/triton_trtllm) | 4 | 1611.51 ms | 0.0704|
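+ 
+ The RTF values above follow the computation in `client_grpc.py`: total wall-clock processing time divided by the total estimated duration of the synthesized audio. A minimal sketch of that metric, reusing the names from the client script:
+ ```python
+ # RTF as computed in client_grpc.py: wall-clock time / seconds of audio produced
+ elapsed = end_time - start_time                        # total processing time for all tasks
+ total_duration = sum(dur for _, dur in latency_data)   # estimated seconds of generated audio
+ rtf = elapsed / total_duration                         # RTF < 1 means faster than real time
+ ```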
runtime/triton_trtllm/client_grpc.py ADDED
@@ -0,0 +1,482 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
3
+ # 2023 Nvidia (authors: Yuekai Zhang)
4
+ # 2023 Recurrent.ai (authors: Songtao Shi)
5
+ # See LICENSE for clarification regarding multiple authors
6
+ #
7
+ # Licensed under the Apache License, Version 2.0 (the "License");
8
+ # you may not use this file except in compliance with the License.
9
+ # You may obtain a copy of the License at
10
+ #
11
+ # http://www.apache.org/licenses/LICENSE-2.0
12
+ #
13
+ # Unless required by applicable law or agreed to in writing, software
14
+ # distributed under the License is distributed on an "AS IS" BASIS,
15
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16
+ # See the License for the specific language governing permissions and
17
+ # limitations under the License.
18
+ """
19
+ This script loads a dataset from the HuggingFace hub and sends it to the server
20
+ for decoding, in parallel.
21
+
22
+ Usage:
23
+ # For offline Spark-TTS-0.5B
24
+ # huggingface dataset
25
+ num_task=2
26
+ python3 client_grpc.py \
27
+ --server-addr localhost \
28
+ --model-name spark_tts \
29
+ --num-tasks $num_task \
30
+ --huggingface-dataset yuekai/seed_tts \
31
+ --split-name wenetspeech4tts \
32
+ --log-dir ./log_concurrent_tasks_${num_task}
33
+ """
34
+
35
+ import argparse
36
+ import asyncio
37
+ import json
38
+
39
+ import os
40
+ import time
41
+ import types
42
+ from pathlib import Path
43
+
44
+ import numpy as np
45
+ import soundfile as sf
46
+ import tritonclient
47
+ import tritonclient.grpc.aio as grpcclient
48
+ from tritonclient.utils import np_to_triton_dtype
49
+
50
+
51
+
52
+ def write_triton_stats(stats, summary_file):
53
+ with open(summary_file, "w") as summary_f:
54
+ model_stats = stats["model_stats"]
55
+ # write a note explaining that the log comes from triton_client.get_inference_statistics(), for better readability
56
+ summary_f.write(
57
+ "The log is parsed from triton_client.get_inference_statistics(), for better human readability. \n"
58
+ )
59
+ summary_f.write("To learn more about the log, please refer to: \n")
60
+ summary_f.write(
61
+ "1. https://github.com/triton-inference-server/server/blob/main/docs/user_guide/metrics.md \n"
62
+ )
63
+ summary_f.write(
64
+ "2. https://github.com/triton-inference-server/server/issues/5374 \n\n"
65
+ )
66
+ summary_f.write(
67
+ "To improve throughput, we usually let requests wait in the queue for a while, and then execute them with a larger batch size. \n"
68
+ )
69
+ summary_f.write(
70
+ "However, there is a trade-off between the increased queue time and the increased batch size. \n"
71
+ )
72
+ summary_f.write(
73
+ "You may change 'max_queue_delay_microseconds' and 'preferred_batch_size' in the model configuration file to achieve this. \n"
74
+ )
75
+ summary_f.write(
76
+ "See https://github.com/triton-inference-server/server/blob/main/docs/user_guide/model_configuration.md#delayed-batching for more details. \n\n"
77
+ )
78
+ for model_state in model_stats:
79
+ if "last_inference" not in model_state:
80
+ continue
81
+ summary_f.write(f"model name is {model_state['name']} \n")
82
+ model_inference_stats = model_state["inference_stats"]
83
+ total_queue_time_s = int(model_inference_stats["queue"]["ns"]) / 1e9
84
+ total_infer_time_s = int(model_inference_stats["compute_infer"]["ns"]) / 1e9
85
+ total_input_time_s = int(model_inference_stats["compute_input"]["ns"]) / 1e9
86
+ total_output_time_s = (
87
+ int(model_inference_stats["compute_output"]["ns"]) / 1e9
88
+ )
89
+ summary_f.write(
90
+ f"queue time {total_queue_time_s:<5.2f} s, compute infer time {total_infer_time_s:<5.2f} s, compute input time {total_input_time_s:<5.2f} s, compute output time {total_output_time_s:<5.2f} s \n" # noqa
91
+ )
92
+ model_batch_stats = model_state["batch_stats"]
93
+ for batch in model_batch_stats:
94
+ batch_size = int(batch["batch_size"])
95
+ compute_input = batch["compute_input"]
96
+ compute_output = batch["compute_output"]
97
+ compute_infer = batch["compute_infer"]
98
+ batch_count = int(compute_infer["count"])
99
+ assert (
100
+ compute_infer["count"]
101
+ == compute_output["count"]
102
+ == compute_input["count"]
103
+ )
104
+ compute_infer_time_ms = int(compute_infer["ns"]) / 1e6
105
+ compute_input_time_ms = int(compute_input["ns"]) / 1e6
106
+ compute_output_time_ms = int(compute_output["ns"]) / 1e6
107
+ summary_f.write(
108
+ f"execute inference with batch_size {batch_size:<2} total {batch_count:<5} times, total_infer_time {compute_infer_time_ms:<9.2f} ms, avg_infer_time {compute_infer_time_ms:<9.2f}/{batch_count:<5}={compute_infer_time_ms/batch_count:.2f} ms, avg_infer_time_per_sample {compute_infer_time_ms:<9.2f}/{batch_count:<5}/{batch_size}={compute_infer_time_ms/batch_count/batch_size:.2f} ms \n" # noqa
109
+ )
110
+ # summary_f.write(
111
+ # f"input {compute_input_time_ms:<9.2f} ms, avg {compute_input_time_ms/batch_count:.2f} ms, " # noqa
112
+ # )
113
+ # summary_f.write(
114
+ # f"output {compute_output_time_ms:<9.2f} ms, avg {compute_output_time_ms/batch_count:.2f} ms \n" # noqa
115
+ # )
116
+
117
+
118
+
119
+ def get_args():
120
+ parser = argparse.ArgumentParser(
121
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
122
+ )
123
+
124
+ parser.add_argument(
125
+ "--server-addr",
126
+ type=str,
127
+ default="localhost",
128
+ help="Address of the server",
129
+ )
130
+
131
+ parser.add_argument(
132
+ "--server-port",
133
+ type=int,
134
+ default=8001,
135
+ help="Grpc port of the triton server, default is 8001",
136
+ )
137
+
138
+ parser.add_argument(
139
+ "--reference-audio",
140
+ type=str,
141
+ default=None,
142
+ help="Path to a single reference audio file. It cannot be used together with --manifest-path",
143
+ )
144
+
145
+ parser.add_argument(
146
+ "--reference-text",
147
+ type=str,
148
+ default="",
149
+ help="Transcript of the reference audio",
150
+ )
151
+
152
+ parser.add_argument(
153
+ "--target-text",
154
+ type=str,
155
+ default="",
156
+ help="Text to synthesize",
157
+ )
158
+
159
+ parser.add_argument(
160
+ "--huggingface-dataset",
161
+ type=str,
162
+ default="yuekai/seed_tts",
163
+ help="dataset name in huggingface dataset hub",
164
+ )
165
+
166
+ parser.add_argument(
167
+ "--split-name",
168
+ type=str,
169
+ default="wenetspeech4tts",
170
+ choices=["wenetspeech4tts", "test_zh", "test_en", "test_hard"],
171
+ help="dataset split name, default is 'wenetspeech4tts'",
172
+ )
173
+
174
+ parser.add_argument(
175
+ "--manifest-path",
176
+ type=str,
177
+ default=None,
178
+ help="Path to the manifest dir which includes wav.scp trans.txt files.",
179
+ )
180
+
181
+ parser.add_argument(
182
+ "--model-name",
183
+ type=str,
184
+ default="f5_tts",
185
+ choices=[
186
+ "f5_tts", "spark_tts"
187
+ ],
188
+ help="Triton model repository name to request, e.g. spark_tts",
189
+ )
190
+
191
+ parser.add_argument(
192
+ "--num-tasks",
193
+ type=int,
194
+ default=1,
195
+ help="Number of concurrent tasks for sending",
196
+ )
197
+
198
+ parser.add_argument(
199
+ "--log-interval",
200
+ type=int,
201
+ default=5,
202
+ help="Controls how frequently we print the log.",
203
+ )
204
+
205
+ parser.add_argument(
206
+ "--compute-wer",
207
+ action="store_true",
208
+ default=False,
209
+ help="""True to compute WER.
210
+ """,
211
+ )
212
+
213
+ parser.add_argument(
214
+ "--log-dir",
215
+ type=str,
216
+ required=False,
217
+ default="./tmp",
218
+ help="log directory",
219
+ )
220
+
221
+ parser.add_argument(
222
+ "--batch-size",
223
+ type=int,
224
+ default=1,
225
+ help="Inference batch_size per request for offline mode.",
226
+ )
227
+
228
+ return parser.parse_args()
229
+
230
+
231
+ def load_audio(wav_path, target_sample_rate=16000):
232
+ assert target_sample_rate == 16000, "the server hard-codes a 16 kHz sample rate"
233
+ if isinstance(wav_path, dict):
234
+ waveform = wav_path["array"]
235
+ sample_rate = wav_path["sampling_rate"]
236
+ else:
237
+ waveform, sample_rate = sf.read(wav_path)
238
+ if sample_rate != target_sample_rate:
239
+ from scipy.signal import resample
240
+ num_samples = int(len(waveform) * (target_sample_rate / sample_rate))
241
+ waveform = resample(waveform, num_samples)
242
+ return waveform, target_sample_rate
243
+
244
+ async def send(
245
+ manifest_item_list: list,
246
+ name: str,
247
+ triton_client: tritonclient.grpc.aio.InferenceServerClient,
248
+ protocol_client: types.ModuleType,
249
+ log_interval: int,
250
+ model_name: str,
251
+ padding_duration: int = None,
252
+ audio_save_dir: str = "./",
253
+ ):
254
+ total_duration = 0.0
255
+ results = []
256
+ latency_data = []
257
+ task_id = int(name[5:])
258
+
259
+ print(f"manifest_item_list: {manifest_item_list}")
260
+ for i, item in enumerate(manifest_item_list):
261
+ if i % log_interval == 0:
262
+ print(f"{name}: {i}/{len(manifest_item_list)}")
263
+ waveform, sample_rate = load_audio(item["audio_filepath"], target_sample_rate=16000)
264
+ duration = len(waveform) / sample_rate
265
+ lengths = np.array([[len(waveform)]], dtype=np.int32)
266
+
267
+ reference_text, target_text = item["reference_text"], item["target_text"]
268
+
269
+ estimated_target_duration = duration / len(reference_text) * len(target_text)
270
+
271
+ if padding_duration:
272
+ # pad to the next multiple of padding_duration seconds
273
+ samples = np.zeros(
274
+ (
275
+ 1,
276
+ padding_duration
277
+ * sample_rate
278
+ * ((int(duration) // padding_duration) + 1),
279
+ ),
280
+ dtype=np.float32,
281
+ )
282
+
283
+ samples[0, : len(waveform)] = waveform
284
+ else:
285
+ samples = waveform
286
+
287
+ samples = samples.reshape(1, -1).astype(np.float32)
288
+
289
+ inputs = [
290
+ protocol_client.InferInput(
291
+ "reference_wav", samples.shape, np_to_triton_dtype(samples.dtype)
292
+ ),
293
+ protocol_client.InferInput(
294
+ "reference_wav_len", lengths.shape, np_to_triton_dtype(lengths.dtype)
295
+ ),
296
+ protocol_client.InferInput("reference_text", [1, 1], "BYTES"),
297
+ protocol_client.InferInput("target_text", [1, 1], "BYTES")
298
+ ]
299
+ inputs[0].set_data_from_numpy(samples)
300
+ inputs[1].set_data_from_numpy(lengths)
301
+
302
+ input_data_numpy = np.array([reference_text], dtype=object)
303
+ input_data_numpy = input_data_numpy.reshape((1, 1))
304
+ inputs[2].set_data_from_numpy(input_data_numpy)
305
+
306
+ input_data_numpy = np.array([target_text], dtype=object)
307
+ input_data_numpy = input_data_numpy.reshape((1, 1))
308
+ inputs[3].set_data_from_numpy(input_data_numpy)
309
+
310
+ outputs = [protocol_client.InferRequestedOutput("waveform")]
311
+
312
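+ # derive a per-request id from the task index and item index (assumption: used only to tag the request)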
+ sequence_id = 100000000 + i + task_id * 10
313
+ start = time.time()
314
+ response = await triton_client.infer(
315
+ model_name, inputs, request_id=str(sequence_id), outputs=outputs
316
+ )
317
+
318
+ audio = response.as_numpy("waveform").reshape(-1)
319
+
320
+ end = time.time() - start
321
+
322
+ audio_save_path = os.path.join(
323
+ audio_save_dir, f"{item['target_audio_path']}.wav"
324
+ )
325
+ sf.write(audio_save_path, audio, 16000, "PCM_16")
326
+
327
+ latency_data.append((end, estimated_target_duration))
328
+ total_duration += estimated_target_duration
329
+
330
+ return total_duration, latency_data
331
+
332
+ def load_manifests(manifest_path):
333
+ with open(manifest_path, "r") as f:
334
+ manifest_list = []
335
+ for line in f:
336
+ assert len(line.strip().split("|")) == 4
337
+ utt, prompt_text, prompt_wav, gt_text = line.strip().split("|")
338
+ utt = Path(utt).stem
339
+ # gt_wav = os.path.join(os.path.dirname(manifest_path), "wavs", utt + ".wav")
340
+ if not os.path.isabs(prompt_wav):
341
+ prompt_wav = os.path.join(os.path.dirname(manifest_path), prompt_wav)
342
+ manifest_list.append(
343
+ {
344
+ "audio_filepath": prompt_wav,
345
+ "reference_text": prompt_text,
346
+ "target_text": gt_text,
347
+ "target_audio_path": utt
348
+ }
349
+ )
350
+ return manifest_list
351
+
352
+
353
+ def split_data(data, k):
354
+ n = len(data)
355
+ if n < k:
356
+ print(
357
+ f"Warning: the length of the input list ({n}) is less than k ({k}). Setting k to {n}."
358
+ )
359
+ k = n
360
+
361
+ quotient = n // k
362
+ remainder = n % k
363
+
364
+ result = []
365
+ start = 0
366
+ for i in range(k):
367
+ if i < remainder:
368
+ end = start + quotient + 1
369
+ else:
370
+ end = start + quotient
371
+
372
+ result.append(data[start:end])
373
+ start = end
374
+
375
+ return result
376
+
377
+
378
+ async def main():
379
+ args = get_args()
380
+ url = f"{args.server_addr}:{args.server_port}"
381
+
382
+ triton_client = grpcclient.InferenceServerClient(url=url, verbose=False)
383
+ protocol_client = grpcclient
384
+
385
+ if args.reference_audio:
386
+ args.num_tasks = 1
387
+ args.log_interval = 1
388
+ manifest_item_list = [
389
+ {
390
+ "reference_text": args.reference_text,
391
+ "target_text": args.target_text,
392
+ "audio_filepath": args.reference_audio,
393
+ "target_audio_path": "test",
394
+ }
395
+ ]
396
+ elif args.huggingface_dataset:
397
+ import datasets
398
+
399
+ dataset = datasets.load_dataset(
400
+ args.huggingface_dataset,
401
+ split=args.split_name,
402
+ trust_remote_code=True,
403
+ )
404
+ manifest_item_list = []
405
+ for i in range(len(dataset)):
406
+ manifest_item_list.append(
407
+ {
408
+ "audio_filepath": dataset[i]["prompt_audio"],
409
+ "reference_text": dataset[i]["prompt_text"],
410
+ "target_audio_path": dataset[i]["id"],
411
+ "target_text": dataset[i]["target_text"],
412
+ }
413
+ )
414
+ else:
415
+ manifest_item_list = load_manifests(args.manifest_path)
416
+
417
+ args.num_tasks = min(args.num_tasks, len(manifest_item_list))
418
+ manifest_item_list = split_data(manifest_item_list, args.num_tasks)
419
+
420
+ os.makedirs(args.log_dir, exist_ok=True)
421
+ tasks = []
422
+ start_time = time.time()
423
+ for i in range(args.num_tasks):
424
+ task = asyncio.create_task(
425
+ send(
426
+ manifest_item_list[i],
427
+ name=f"task-{i}",
428
+ triton_client=triton_client,
429
+ protocol_client=protocol_client,
430
+ log_interval=args.log_interval,
431
+ model_name=args.model_name,
432
+ audio_save_dir=args.log_dir,
433
+ padding_duration=None,
434
+ )
435
+ )
436
+ tasks.append(task)
437
+
438
+ ans_list = await asyncio.gather(*tasks)
439
+
440
+ end_time = time.time()
441
+ elapsed = end_time - start_time
442
+
443
+
444
+ total_duration = 0.0
445
+ latency_data = []
446
+ for ans in ans_list:
447
+ total_duration += ans[0]
448
+ latency_data += ans[1]
449
+
450
+ rtf = elapsed / total_duration
451
+
452
+ s = f"RTF: {rtf:.4f}\n"
453
+ s += f"total_duration: {total_duration:.3f} seconds\n"
454
+ s += f"({total_duration/3600:.2f} hours)\n"
455
+ s += f"processing time: {elapsed:.3f} seconds " f"({elapsed/3600:.2f} hours)\n"
456
+
457
+ latency_list = [chunk_end for (chunk_end, chunk_duration) in latency_data]
458
+ latency_ms = sum(latency_list) / float(len(latency_list)) * 1000.0
459
+ latency_variance = np.var(latency_list, dtype=np.float64) * 1000.0
460
+ s += f"latency_variance: {latency_variance:.2f}\n"
461
+ s += f"latency_50_percentile_ms: {np.percentile(latency_list, 50) * 1000.0:.2f}\n"
462
+ s += f"latency_90_percentile_ms: {np.percentile(latency_list, 90) * 1000.0:.2f}\n"
463
+ s += f"latency_95_percentile_ms: {np.percentile(latency_list, 95) * 1000.0:.2f}\n"
464
+ s += f"latency_99_percentile_ms: {np.percentile(latency_list, 99) * 1000.0:.2f}\n"
465
+ s += f"average_latency_ms: {latency_ms:.2f}\n"
466
+
467
+ print(s)
468
+ if args.manifest_path:
469
+ name = Path(args.manifest_path).stem
470
+ elif args.split_name:
471
+ name = args.split_name
472
+ with open(f"{args.log_dir}/rtf-{name}.txt", "w") as f:
473
+ f.write(s)
474
+
475
+ stats = await triton_client.get_inference_statistics(model_name="", as_json=True)
476
+ write_triton_stats(stats, f"{args.log_dir}/stats_summary-{name}.txt")
477
+
478
+ metadata = await triton_client.get_model_config(model_name=args.model_name, as_json=True)
479
+ with open(f"{args.log_dir}/model_config-{name}.json", "w") as f:
480
+ json.dump(metadata, f, indent=4)
481
+ if __name__ == "__main__":
482
+ asyncio.run(main())
runtime/triton_trtllm/client_http.py ADDED
@@ -0,0 +1,165 @@
1
+ # Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # Redistribution and use in source and binary forms, with or without
4
+ # modification, are permitted provided that the following conditions
5
+ # are met:
6
+ # * Redistributions of source code must retain the above copyright
7
+ # notice, this list of conditions and the following disclaimer.
8
+ # * Redistributions in binary form must reproduce the above copyright
9
+ # notice, this list of conditions and the following disclaimer in the
10
+ # documentation and/or other materials provided with the distribution.
11
+ # * Neither the name of NVIDIA CORPORATION nor the names of its
12
+ # contributors may be used to endorse or promote products derived
13
+ # from this software without specific prior written permission.
14
+ #
15
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
16
+ # EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
17
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
18
+ # PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
19
+ # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
20
+ # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
21
+ # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
22
+ # PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
23
+ # OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
24
+ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
25
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26
+ import requests
27
+ import soundfile as sf
28
+ import json
29
+ import numpy as np
30
+ import argparse
31
+
32
+ def get_args():
33
+ parser = argparse.ArgumentParser(
34
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
35
+ )
36
+
37
+ parser.add_argument(
38
+ "--server-url",
39
+ type=str,
40
+ default="localhost:8000",
41
+ help="Address of the server",
42
+ )
43
+
44
+ parser.add_argument(
45
+ "--reference-audio",
46
+ type=str,
47
+ default="../../example/prompt_audio.wav",
48
+ help="Path to the reference (prompt) audio file",
49
+ )
50
+
51
+ parser.add_argument(
52
+ "--reference-text",
53
+ type=str,
54
+ default="吃燕窝就选燕之屋,本节目由26年专注高品质燕窝的燕之屋冠名播出。豆奶牛奶换着喝,营养更均衡,本节目由豆本豆豆奶特约播出。",
55
+ help="Transcript of the reference audio",
56
+ )
57
+
58
+ parser.add_argument(
59
+ "--target-text",
60
+ type=str,
61
+ default="身临其境,换新体验。塑造开源语音合成新范式,让智能语音更自然。",
62
+ help="Text to synthesize",
63
+ )
64
+
65
+ parser.add_argument(
66
+ "--model-name",
67
+ type=str,
68
+ default="spark_tts",
69
+ choices=[
70
+ "f5_tts", "spark_tts"
71
+ ],
72
+ help="Triton model repository name to request, e.g. spark_tts",
73
+ )
74
+
75
+ parser.add_argument(
76
+ "--output-audio",
77
+ type=str,
78
+ default="output.wav",
79
+ help="Path to save the output audio",
80
+ )
81
+ return parser.parse_args()
82
+
83
+ def prepare_request(
84
+ waveform,
85
+ reference_text,
86
+ target_text,
87
+ sample_rate=16000,
88
+ padding_duration: int = None,
89
+ audio_save_dir: str = "./",
90
+ ):
91
+ assert len(waveform.shape) == 1, "waveform should be 1D"
92
+ lengths = np.array([[len(waveform)]], dtype=np.int32)
93
+ if padding_duration:
94
+ # pad to the next multiple of padding_duration seconds
95
+ samples = np.zeros(
96
+ (
97
+ 1,
98
+ padding_duration
99
+ * sample_rate
100
+ * ((int(len(waveform) / sample_rate) // padding_duration) + 1),
101
+ ),
102
+ dtype=np.float32,
103
+ )
104
+
105
+ samples[0, : len(waveform)] = waveform
106
+ else:
107
+ samples = waveform
108
+
109
+ samples = samples.reshape(1, -1).astype(np.float32)
110
+
111
+ data = {
112
+ "inputs":[
113
+ {
114
+ "name": "reference_wav",
115
+ "shape": samples.shape,
116
+ "datatype": "FP32",
117
+ "data": samples.tolist()
118
+ },
119
+ {
120
+ "name": "reference_wav_len",
121
+ "shape": lengths.shape,
122
+ "datatype": "INT32",
123
+ "data": lengths.tolist(),
124
+ },
125
+ {
126
+ "name": "reference_text",
127
+ "shape": [1, 1],
128
+ "datatype": "BYTES",
129
+ "data": [reference_text]
130
+ },
131
+ {
132
+ "name": "target_text",
133
+ "shape": [1, 1],
134
+ "datatype": "BYTES",
135
+ "data": [target_text]
136
+ }
137
+ ]
138
+ }
139
+
140
+ return data
141
+
142
+ if __name__ == "__main__":
143
+ args = get_args()
144
+ server_url = args.server_url
145
+ if not server_url.startswith(("http://", "https://")):
146
+ server_url = f"http://{server_url}"
147
+
148
+ url = f"{server_url}/v2/models/{args.model_name}/infer"
149
+ waveform, sr = sf.read(args.reference_audio)
150
+ assert sr == 16000, "sample rate hardcoded in server"
151
+
152
+ samples = np.array(waveform, dtype=np.float32)
153
+ data = prepare_request(samples, args.reference_text, args.target_text)
154
+
155
+ rsp = requests.post(
156
+ url,
157
+ headers={"Content-Type": "application/json"},
158
+ json=data,
159
+ verify=False,
160
+ params={"request_id": '0'}
161
+ )
162
+ result = rsp.json()
163
+ audio = result["outputs"][0]["data"]
164
+ audio = np.array(audio, dtype=np.float32)
165
+ sf.write(args.output_audio, audio, 16000, "PCM_16")
runtime/triton_trtllm/docker-compose.yml ADDED
@@ -0,0 +1,20 @@
1
+ services:
2
+ tts:
3
+ image: soar97/triton-spark-tts:25.02
4
+ shm_size: '1gb'
5
+ ports:
6
+ - "8000:8000"
7
+ - "8001:8001"
8
+ - "8002:8002"
9
+ environment:
10
+ - PYTHONIOENCODING=utf-8
11
+ - MODEL_ID=${MODEL_ID}
12
+ deploy:
13
+ resources:
14
+ reservations:
15
+ devices:
16
+ - driver: nvidia
17
+ device_ids: ['0']
18
+ capabilities: [gpu]
19
+ command: >
20
+ /bin/bash -c "rm -rf Spark-TTS && git clone https://github.com/SparkAudio/Spark-TTS.git && cd Spark-TTS/runtime/triton_trtllm && bash run.sh 0 3"
runtime/triton_trtllm/model_repo/audio_tokenizer/1/model.py ADDED
@@ -0,0 +1,137 @@
1
+ # Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # Redistribution and use in source and binary forms, with or without
4
+ # modification, are permitted provided that the following conditions
5
+ # are met:
6
+ # * Redistributions of source code must retain the above copyright
7
+ # notice, this list of conditions and the following disclaimer.
8
+ # * Redistributions in binary form must reproduce the above copyright
9
+ # notice, this list of conditions and the following disclaimer in the
10
+ # documentation and/or other materials provided with the distribution.
11
+ # * Neither the name of NVIDIA CORPORATION nor the names of its
12
+ # contributors may be used to endorse or promote products derived
13
+ # from this software without specific prior written permission.
14
+ #
15
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
16
+ # EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
17
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
18
+ # PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
19
+ # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
20
+ # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
21
+ # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
22
+ # PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
23
+ # OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
24
+ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
25
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26
+ import json
27
+ import torch
28
+ from torch.utils.dlpack import to_dlpack
29
+
30
+ import triton_python_backend_utils as pb_utils
31
+
32
+ import os
33
+ import numpy as np
34
+
35
+ from sparktts.models.audio_tokenizer import BiCodecTokenizer
36
+
37
+ class TritonPythonModel:
38
+ """Triton Python model for audio tokenization.
39
+
40
+ This model takes reference audio input and extracts semantic and global tokens
41
+ using BiCodec tokenizer.
42
+ """
43
+
44
+ def initialize(self, args):
45
+ """Initialize the model.
46
+
47
+ Args:
48
+ args: Dictionary containing model configuration
49
+ """
50
+ # Parse model parameters
51
+ parameters = json.loads(args['model_config'])['parameters']
52
+ model_params = {k: v["string_value"] for k, v in parameters.items()}
53
+
54
+ # Initialize tokenizer
55
+ self.device = torch.device("cuda")
56
+ self.audio_tokenizer = BiCodecTokenizer(model_params["model_dir"],
57
+ device=self.device)
58
+
59
+ def get_ref_clip(self, wav: np.ndarray) -> np.ndarray:
60
+ """Extract reference audio clip for speaker embedding.
61
+
62
+ Args:
63
+ wav: Input waveform array
64
+
65
+ Returns:
66
+ Reference clip of fixed duration
67
+ """
68
+ SAMPLE_RATE = 16000
69
+ REF_SEGMENT_DURATION = 6 # seconds
70
+ LATENT_HOP_LENGTH = 320
71
+
72
+ ref_segment_length = (
73
+ int(SAMPLE_RATE * REF_SEGMENT_DURATION)
74
+ // LATENT_HOP_LENGTH
75
+ * LATENT_HOP_LENGTH
76
+ )
77
+ wav_length = len(wav)
78
+
79
+ if ref_segment_length > wav_length:
80
+ # Repeat and truncate if input is too short
81
+ repeat_times = ref_segment_length // wav_length + 1
82
+ wav = np.tile(wav, repeat_times)
83
+
84
+ return wav[:ref_segment_length]
85
+
86
+ def execute(self, requests):
87
+ """Execute inference on the batched requests.
88
+
89
+ Args:
90
+ requests: List of inference requests
91
+
92
+ Returns:
93
+ List of inference responses containing tokenized outputs
94
+ """
95
+ reference_wav_list = []
96
+ reference_wav_ref_clip_list = []
97
+
98
+ # Process each request in batch
99
+ for request in requests:
100
+ # Extract input tensors
101
+ wav_array = pb_utils.get_input_tensor_by_name(
102
+ request, "reference_wav").as_numpy()
103
+ wav_len = pb_utils.get_input_tensor_by_name(
104
+ request, "reference_wav_len").as_numpy().item()
105
+
106
+ # Prepare inputs
107
+ wav = wav_array[:, :wav_len].squeeze(0)
108
+ reference_wav_list.append(wav)
109
+
110
+ wav_ref_clip = self.get_ref_clip(wav)
111
+ reference_wav_ref_clip_list.append(torch.from_numpy(wav_ref_clip))
112
+
113
+ # Batch process through tokenizer
114
+ ref_wav_clip_tensor = torch.stack(reference_wav_ref_clip_list, dim=0)
115
+ wav2vec2_features = self.audio_tokenizer.extract_wav2vec2_features(
116
+ reference_wav_list)
117
+
118
+ audio_tokenizer_input = {
119
+ "ref_wav": ref_wav_clip_tensor.to(self.device),
120
+ "feat": wav2vec2_features.to(self.device),
121
+ }
122
+ semantic_tokens, global_tokens = self.audio_tokenizer.model.tokenize(
123
+ audio_tokenizer_input)
124
+
125
+ # Prepare responses
126
+ responses = []
127
+ for i in range(len(requests)):
128
+ global_tokens_tensor = pb_utils.Tensor.from_dlpack(
129
+ "global_tokens", to_dlpack(global_tokens[i]))
130
+ semantic_tokens_tensor = pb_utils.Tensor.from_dlpack(
131
+ "semantic_tokens", to_dlpack(semantic_tokens[i]))
132
+
133
+ inference_response = pb_utils.InferenceResponse(
134
+ output_tensors=[global_tokens_tensor, semantic_tokens_tensor])
135
+ responses.append(inference_response)
136
+
137
+ return responses
runtime/triton_trtllm/model_repo/audio_tokenizer/config.pbtxt ADDED
@@ -0,0 +1,58 @@
1
+ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ name: "audio_tokenizer"
16
+ backend: "python"
17
+ max_batch_size: ${triton_max_batch_size}
18
+ dynamic_batching {
19
+ max_queue_delay_microseconds: ${max_queue_delay_microseconds}
20
+ }
21
+ parameters [
22
+ {
23
+ key: "model_dir",
24
+ value: {string_value:"${model_dir}"}
25
+ }
26
+ ]
27
+
28
+ input [
29
+ {
30
+ name: "reference_wav"
31
+ data_type: TYPE_FP32
32
+ dims: [-1]
33
+ },
34
+ {
35
+ name: "reference_wav_len"
36
+ data_type: TYPE_INT32
37
+ dims: [1]
38
+ }
39
+ ]
40
+ output [
41
+ {
42
+ name: "global_tokens"
43
+ data_type: TYPE_INT32
44
+ dims: [-1]
45
+ },
46
+ {
47
+ name: "semantic_tokens"
48
+ data_type: TYPE_INT32
49
+ dims: [-1]
50
+ }
51
+ ]
52
+
53
+ instance_group [
54
+ {
55
+ count: 1
56
+ kind: KIND_CPU
57
+ }
58
+ ]
runtime/triton_trtllm/model_repo/spark_tts/1/model.py ADDED
@@ -0,0 +1,311 @@
1
+ # Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # Redistribution and use in source and binary forms, with or without
4
+ # modification, are permitted provided that the following conditions
5
+ # are met:
6
+ # * Redistributions of source code must retain the above copyright
7
+ # notice, this list of conditions and the following disclaimer.
8
+ # * Redistributions in binary form must reproduce the above copyright
9
+ # notice, this list of conditions and the following disclaimer in the
10
+ # documentation and/or other materials provided with the distribution.
11
+ # * Neither the name of NVIDIA CORPORATION nor the names of its
12
+ # contributors may be used to endorse or promote products derived
13
+ # from this software without specific prior written permission.
14
+ #
15
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
16
+ # EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
17
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
18
+ # PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
19
+ # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
20
+ # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
21
+ # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
22
+ # PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
23
+ # OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
24
+ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
25
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26
+
27
+ import json
28
+ import os
29
+ import re
30
+ from typing import Dict, List, Tuple, Optional, Union
31
+
32
+ import numpy as np
33
+ import torch
34
+ from torch.utils.dlpack import from_dlpack, to_dlpack
35
+ import triton_python_backend_utils as pb_utils
36
+ from transformers import AutoTokenizer
37
+
38
+ from sparktts.utils.token_parser import TASK_TOKEN_MAP
39
+
40
+ def process_prompt(
41
+ text: str,
42
+ prompt_text: Optional[str] = None,
43
+ global_token_ids: torch.Tensor = None,
44
+ semantic_token_ids: torch.Tensor = None,
45
+ ) -> Tuple[str, torch.Tensor]:
46
+ """
47
+ Process input for voice cloning.
48
+
49
+ Args:
50
+ text: The text input to be converted to speech.
51
+ prompt_text: Transcript of the prompt audio.
52
+ global_token_ids: Global token IDs extracted from reference audio.
53
+ semantic_token_ids: Semantic token IDs extracted from reference audio.
54
+
55
+ Returns:
56
+ Tuple containing the formatted input prompt and global token IDs.
57
+ """
58
+ # Convert global tokens to string format
59
+ global_tokens = "".join(
60
+ [f"<|bicodec_global_{i}|>" for i in global_token_ids.squeeze()]
61
+ )
62
+
63
+
64
+ # Prepare the input tokens for the model
65
+ if prompt_text is not None:
66
+ # Include semantic tokens when prompt text is provided
67
+ semantic_tokens = "".join(
68
+ [f"<|bicodec_semantic_{i}|>" for i in semantic_token_ids.squeeze()]
69
+ )
70
+
71
+ inputs = [
72
+ TASK_TOKEN_MAP["tts"],
73
+ "<|start_content|>",
74
+ prompt_text,
75
+ text,
76
+ "<|end_content|>",
77
+ "<|start_global_token|>",
78
+ global_tokens,
79
+ "<|end_global_token|>",
80
+ "<|start_semantic_token|>",
81
+ semantic_tokens,
82
+ ]
83
+ else:
84
+ # Without prompt text, exclude semantic tokens
85
+ inputs = [
86
+ TASK_TOKEN_MAP["tts"],
87
+ "<|start_content|>",
88
+ text,
89
+ "<|end_content|>",
90
+ "<|start_global_token|>",
91
+ global_tokens,
92
+ "<|end_global_token|>",
93
+ ]
94
+
95
+ # Join all input components into a single string
96
+ inputs = "".join(inputs)
97
+ return inputs, global_token_ids
98
+
99
+
100
+ class TritonPythonModel:
101
+ """Triton Python model for Spark TTS.
102
+
103
+ This model orchestrates the end-to-end TTS pipeline by coordinating
104
+ between audio tokenizer, LLM, and vocoder components.
105
+ """
106
+
107
+ def initialize(self, args):
108
+ """Initialize the model.
109
+
110
+ Args:
111
+ args: Dictionary containing model configuration
112
+ """
113
+ # Parse model parameters
114
+ parameters = json.loads(args['model_config'])['parameters']
115
+ model_params = {k: v["string_value"] for k, v in parameters.items()}
116
+
117
+ # Initialize tokenizer
118
+ llm_tokenizer_dir = model_params["llm_tokenizer_dir"]
119
+ self.tokenizer = AutoTokenizer.from_pretrained(llm_tokenizer_dir)
120
+ self.device = torch.device("cuda")
121
+ self.decoupled = False
122
+
123
+ def forward_llm(self, input_ids):
124
+ """
125
+ Prepares the response from the language model based on the provided
126
+ inputs. Creates a `pb_utils.InferenceRequest` object with passed
127
+ `input_ids` to send to the TensorRT-LLM model.
128
+ For each response from the language model:
129
+ - Checks for errors and raises an exception if any are found.
130
+ - Extracts the "output_ids" tensor from the response.
131
+ - Determines the finish reason based on the presence of the
132
+ end-of-sequence token or reaching the maximum length.
133
+ - Appends the generated token IDs to `output_ids`.
134
+ - If the finish reason is determined, decodes the output IDs to text
135
+ and prepares the final response.
136
+
137
+ The method returns the generated token IDs; the caller decodes them
138
+ into text and extracts the semantic tokens.
139
+
140
+ Parameters
141
+ ----------
142
+ - input_ids (torch.Tensor): The tokenized prompt to send to the language model.
143
+
144
+ Returns
145
+ -------
146
+ - np.ndarray: The generated output token IDs, truncated to the reported sequence length.
147
+ """
148
+ # convert input_ids to numpy, with shape [1, sequence_length]
149
+ input_ids = input_ids.cpu().numpy()
150
+ max_tokens = 512
151
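+ # build the request for the tensorrt_llm backend; sampling settings (top_p, top_k, temperature) are fixed here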
+ input_dict = {
152
+ "request_output_len": np.array([[max_tokens]], dtype=np.int32),
153
+ "end_id": np.array([[self.tokenizer.eos_token_id]], dtype=np.int32),
154
+ "pad_id": np.array([[self.tokenizer.pad_token_id]], dtype=np.int32),
155
+ "streaming": np.array([[self.decoupled]], dtype=np.bool_),
156
+ "runtime_top_p": np.array([[0.95]], dtype=np.float32),
157
+ "runtime_top_k": np.array([[50]], dtype=np.int32),
158
+ "temperature": np.array([[0.8]], dtype=np.float32),
159
+ "input_ids": input_ids,
160
+ "input_lengths": np.array([[input_ids.shape[1]]], dtype=np.int32),
161
+ }
162
+
163
+ # Convert inputs to Triton tensors
164
+ input_tensor_list = [
165
+ pb_utils.Tensor(k, v) for k, v in input_dict.items()
166
+ ]
167
+
168
+ # Create and execute inference request
169
+ llm_request = pb_utils.InferenceRequest(
170
+ model_name="tensorrt_llm",
171
+ requested_output_names=["output_ids", "sequence_length"],
172
+ inputs=input_tensor_list,
173
+ )
174
+
175
+ llm_response = llm_request.exec(decoupled=self.decoupled)
176
+ if llm_response.has_error():
177
+ raise pb_utils.TritonModelException(llm_response.error().message())
178
+
179
+ # Extract and process output
180
+ output_ids = pb_utils.get_output_tensor_by_name(
181
+ llm_response, "output_ids").as_numpy()
182
+ seq_lens = pb_utils.get_output_tensor_by_name(
183
+ llm_response, "sequence_length").as_numpy()
184
+
185
+ # Get actual output IDs up to the sequence length
186
+ actual_output_ids = output_ids[0][0][:seq_lens[0][0]]
187
+
188
+ return actual_output_ids
189
+
190
+ def forward_audio_tokenizer(self, wav, wav_len):
191
+ """Forward pass through the audio tokenizer component.
192
+
193
+ Args:
194
+ wav: Input waveform tensor
195
+ wav_len: Waveform length tensor
196
+
197
+ Returns:
198
+ Tuple of global and semantic tokens
199
+ """
200
+ inference_request = pb_utils.InferenceRequest(
201
+ model_name='audio_tokenizer',
202
+ requested_output_names=['global_tokens', 'semantic_tokens'],
203
+ inputs=[wav, wav_len]
204
+ )
205
+
206
+ inference_response = inference_request.exec()
207
+ if inference_response.has_error():
208
+ raise pb_utils.TritonModelException(inference_response.error().message())
209
+
210
+ # Extract and convert output tensors
211
+ global_tokens = pb_utils.get_output_tensor_by_name(inference_response, 'global_tokens')
212
+ global_tokens = torch.utils.dlpack.from_dlpack(global_tokens.to_dlpack()).cpu()
213
+
214
+ semantic_tokens = pb_utils.get_output_tensor_by_name(inference_response, 'semantic_tokens')
215
+ semantic_tokens = torch.utils.dlpack.from_dlpack(semantic_tokens.to_dlpack()).cpu()
216
+
217
+ return global_tokens, semantic_tokens
218
+
219
+ def forward_vocoder(self, global_token_ids: torch.Tensor, pred_semantic_ids: torch.Tensor) -> torch.Tensor:
220
+ """Forward pass through the vocoder component.
221
+
222
+ Args:
223
+ global_token_ids: Global token IDs tensor
224
+ pred_semantic_ids: Predicted semantic token IDs tensor
225
+
226
+ Returns:
227
+ Generated waveform tensor
228
+ """
229
+ # Convert tensors to Triton format
230
+ global_token_ids_tensor = pb_utils.Tensor.from_dlpack("global_tokens", to_dlpack(global_token_ids))
231
+ pred_semantic_ids_tensor = pb_utils.Tensor.from_dlpack("semantic_tokens", to_dlpack(pred_semantic_ids))
232
+
233
+ # Create and execute inference request
234
+ inference_request = pb_utils.InferenceRequest(
235
+ model_name='vocoder',
236
+ requested_output_names=['waveform'],
237
+ inputs=[global_token_ids_tensor, pred_semantic_ids_tensor]
238
+ )
239
+
240
+ inference_response = inference_request.exec()
241
+ if inference_response.has_error():
242
+ raise pb_utils.TritonModelException(inference_response.error().message())
243
+
244
+ # Extract and convert output waveform
245
+ waveform = pb_utils.get_output_tensor_by_name(inference_response, 'waveform')
246
+ waveform = torch.utils.dlpack.from_dlpack(waveform.to_dlpack()).cpu()
247
+
248
+ return waveform
249
+
250
+ def execute(self, requests):
251
+ """Execute inference on the batched requests.
252
+
253
+ Args:
254
+ requests: List of inference requests
255
+
256
+ Returns:
257
+ List of inference responses containing generated audio
258
+ """
259
+ responses = []
260
+
261
+ for request in requests:
262
+ # Extract input tensors
263
+ wav = pb_utils.get_input_tensor_by_name(request, "reference_wav")
264
+ wav_len = pb_utils.get_input_tensor_by_name(request, "reference_wav_len")
265
+
266
+ # Process reference audio through audio tokenizer
267
+ global_tokens, semantic_tokens = self.forward_audio_tokenizer(wav, wav_len)
268
+
269
+ # Extract text inputs
270
+ reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy()
271
+ reference_text = reference_text[0][0].decode('utf-8')
272
+
273
+ target_text = pb_utils.get_input_tensor_by_name(request, "target_text").as_numpy()
274
+ target_text = target_text[0][0].decode('utf-8')
275
+
276
+ # Prepare prompt for LLM
277
+ prompt, global_token_ids = process_prompt(
278
+ text=target_text,
279
+ prompt_text=reference_text,
280
+ global_token_ids=global_tokens,
281
+ semantic_token_ids=semantic_tokens,
282
+ )
283
+
284
+
285
+ # Tokenize prompt for LLM
286
+ model_inputs = self.tokenizer([prompt], return_tensors="pt").to(self.device)
287
+ input_ids = model_inputs.input_ids.to(torch.int32)
288
+
289
+ # Generate semantic tokens with LLM
290
+ generated_ids = self.forward_llm(input_ids)
291
+
292
+ # Decode and extract semantic token IDs from generated text
293
+ predicted_text = self.tokenizer.batch_decode([generated_ids], skip_special_tokens=True)[0]
294
+ pred_semantic_ids = (
295
+ torch.tensor([int(token) for token in re.findall(r"bicodec_semantic_(\d+)", predicted_text)])
296
+ .unsqueeze(0).to(torch.int32)
297
+ )
298
+
299
+
300
+ # Generate audio with vocoder
301
+ audio = self.forward_vocoder(
302
+ global_token_ids.to(self.device),
303
+ pred_semantic_ids.to(self.device),
304
+ )
305
+
306
+ # Prepare response
307
+ audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio))
308
+ inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
309
+ responses.append(inference_response)
310
+
311
+ return responses
runtime/triton_trtllm/model_repo/spark_tts/config.pbtxt ADDED
@@ -0,0 +1,65 @@
1
+ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ name: "spark_tts"
16
+ backend: "python"
17
+ max_batch_size: ${triton_max_batch_size}
18
+ dynamic_batching {
19
+ max_queue_delay_microseconds: ${max_queue_delay_microseconds}
20
+ }
21
+ parameters [
22
+ {
23
+ key: "llm_tokenizer_dir",
24
+ value: {string_value:"${llm_tokenizer_dir}"}
25
+ }
26
+ ]
27
+
28
+ input [
29
+ {
30
+ name: "reference_wav"
31
+ data_type: TYPE_FP32
32
+ dims: [-1]
33
+ optional: True
34
+ },
35
+ {
36
+ name: "reference_wav_len"
37
+ data_type: TYPE_INT32
38
+ dims: [1]
39
+ optional: True
40
+ },
41
+ {
42
+ name: "reference_text"
43
+ data_type: TYPE_STRING
44
+ dims: [1]
45
+ },
46
+ {
47
+ name: "target_text"
48
+ data_type: TYPE_STRING
49
+ dims: [1]
50
+ }
51
+ ]
52
+ output [
53
+ {
54
+ name: "waveform"
55
+ data_type: TYPE_FP32
56
+ dims: [ -1 ]
57
+ }
58
+ ]
59
+
60
+ instance_group [
61
+ {
62
+ count: ${bls_instance_num}
63
+ kind: KIND_CPU
64
+ }
65
+ ]
runtime/triton_trtllm/model_repo/tensorrt_llm/1/.gitkeep ADDED
File without changes
runtime/triton_trtllm/model_repo/tensorrt_llm/config.pbtxt ADDED
@@ -0,0 +1,857 @@
1
+ # Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # Redistribution and use in source and binary forms, with or without
4
+ # modification, are permitted provided that the following conditions
5
+ # are met:
6
+ # * Redistributions of source code must retain the above copyright
7
+ # notice, this list of conditions and the following disclaimer.
8
+ # * Redistributions in binary form must reproduce the above copyright
9
+ # notice, this list of conditions and the following disclaimer in the
10
+ # documentation and/or other materials provided with the distribution.
11
+ # * Neither the name of NVIDIA CORPORATION nor the names of its
12
+ # contributors may be used to endorse or promote products derived
13
+ # from this software without specific prior written permission.
14
+ #
15
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
16
+ # EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
17
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
18
+ # PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
19
+ # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
20
+ # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
21
+ # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
22
+ # PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
23
+ # OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
24
+ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
25
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26
+
27
+ name: "tensorrt_llm"
28
+ backend: "${triton_backend}"
29
+ max_batch_size: ${triton_max_batch_size}
30
+
31
+ model_transaction_policy {
32
+ decoupled: ${decoupled_mode}
33
+ }
34
+
35
+ dynamic_batching {
36
+ preferred_batch_size: [ ${triton_max_batch_size} ]
37
+ max_queue_delay_microseconds: ${max_queue_delay_microseconds}
38
+ default_queue_policy: { max_queue_size: ${max_queue_size} }
39
+ }
40
+
41
+ input [
42
+ {
43
+ name: "input_ids"
44
+ data_type: TYPE_INT32
45
+ dims: [ -1 ]
46
+ allow_ragged_batch: true
47
+ optional: true
48
+ },
49
+ {
50
+ name: "encoder_input_features"
51
+ data_type: ${encoder_input_features_data_type}
52
+ dims: [ -1, -1 ]
53
+ allow_ragged_batch: true
54
+ optional: true
55
+ },
56
+ {
57
+ name: "encoder_output_lengths"
58
+ data_type: TYPE_INT32
59
+ dims: [ 1 ]
60
+ reshape: { shape: [ ] }
61
+ optional: true
62
+ },
63
+ {
64
+ name: "input_lengths"
65
+ data_type: TYPE_INT32
66
+ dims: [ 1 ]
67
+ reshape: { shape: [ ] }
68
+ },
69
+ {
70
+ name: "request_output_len"
71
+ data_type: TYPE_INT32
72
+ dims: [ 1 ]
73
+ reshape: { shape: [ ] }
74
+ },
75
+ {
76
+ name: "num_return_sequences"
77
+ data_type: TYPE_INT32
78
+ dims: [ 1 ]
79
+ reshape: { shape: [ ] }
80
+ optional: true
81
+ },
82
+ {
83
+ name: "draft_input_ids"
84
+ data_type: TYPE_INT32
85
+ dims: [ -1 ]
86
+ optional: true
87
+ allow_ragged_batch: true
88
+ },
89
+ {
90
+ name: "decoder_input_ids"
91
+ data_type: TYPE_INT32
92
+ dims: [ -1 ]
93
+ optional: true
94
+ allow_ragged_batch: true
95
+ },
96
+ {
97
+ name: "decoder_input_lengths"
98
+ data_type: TYPE_INT32
99
+ dims: [ 1 ]
100
+ optional: true
101
+ reshape: { shape: [ ] }
102
+ },
103
+ {
104
+ name: "draft_logits"
105
+ data_type: ${logits_datatype}
106
+ dims: [ -1, -1 ]
107
+ optional: true
108
+ allow_ragged_batch: true
109
+ },
110
+ {
111
+ name: "draft_acceptance_threshold"
112
+ data_type: TYPE_FP32
113
+ dims: [ 1 ]
114
+ reshape: { shape: [ ] }
115
+ optional: true
116
+ },
117
+ {
118
+ name: "end_id"
119
+ data_type: TYPE_INT32
120
+ dims: [ 1 ]
121
+ reshape: { shape: [ ] }
122
+ optional: true
123
+ },
124
+ {
125
+ name: "pad_id"
126
+ data_type: TYPE_INT32
127
+ dims: [ 1 ]
128
+ reshape: { shape: [ ] }
129
+ optional: true
130
+ },
131
+ {
132
+ name: "stop_words_list"
133
+ data_type: TYPE_INT32
134
+ dims: [ 2, -1 ]
135
+ optional: true
136
+ allow_ragged_batch: true
137
+ },
138
+ {
139
+ name: "bad_words_list"
140
+ data_type: TYPE_INT32
141
+ dims: [ 2, -1 ]
142
+ optional: true
143
+ allow_ragged_batch: true
144
+ },
145
+ {
146
+ name: "embedding_bias"
147
+ data_type: TYPE_FP32
148
+ dims: [ -1 ]
149
+ optional: true
150
+ allow_ragged_batch: true
151
+ },
152
+ {
153
+ name: "beam_width"
154
+ data_type: TYPE_INT32
155
+ dims: [ 1 ]
156
+ reshape: { shape: [ ] }
157
+ optional: true
158
+ },
159
+ {
160
+ name: "temperature"
161
+ data_type: TYPE_FP32
162
+ dims: [ 1 ]
163
+ reshape: { shape: [ ] }
164
+ optional: true
165
+ },
166
+ {
167
+ name: "runtime_top_k"
168
+ data_type: TYPE_INT32
169
+ dims: [ 1 ]
170
+ reshape: { shape: [ ] }
171
+ optional: true
172
+ },
173
+ {
174
+ name: "runtime_top_p"
175
+ data_type: TYPE_FP32
176
+ dims: [ 1 ]
177
+ reshape: { shape: [ ] }
178
+ optional: true
179
+ },
180
+ {
181
+ name: "runtime_top_p_min"
182
+ data_type: TYPE_FP32
183
+ dims: [ 1 ]
184
+ reshape: { shape: [ ] }
185
+ optional: true
186
+ },
187
+ {
188
+ name: "runtime_top_p_decay"
189
+ data_type: TYPE_FP32
190
+ dims: [ 1 ]
191
+ reshape: { shape: [ ] }
192
+ optional: true
193
+ },
194
+ {
195
+ name: "runtime_top_p_reset_ids"
196
+ data_type: TYPE_INT32
197
+ dims: [ 1 ]
198
+ reshape: { shape: [ ] }
199
+ optional: true
200
+ },
201
+ {
202
+ name: "len_penalty"
203
+ data_type: TYPE_FP32
204
+ dims: [ 1 ]
205
+ reshape: { shape: [ ] }
206
+ optional: true
207
+ },
208
+ {
209
+ name: "early_stopping"
210
+ data_type: TYPE_BOOL
211
+ dims: [ 1 ]
212
+ reshape: { shape: [ ] }
213
+ optional: true
214
+ },
215
+ {
216
+ name: "repetition_penalty"
217
+ data_type: TYPE_FP32
218
+ dims: [ 1 ]
219
+ reshape: { shape: [ ] }
220
+ optional: true
221
+ },
222
+ {
223
+ name: "min_length"
224
+ data_type: TYPE_INT32
225
+ dims: [ 1 ]
226
+ reshape: { shape: [ ] }
227
+ optional: true
228
+ },
229
+ {
230
+ name: "beam_search_diversity_rate"
231
+ data_type: TYPE_FP32
232
+ dims: [ 1 ]
233
+ reshape: { shape: [ ] }
234
+ optional: true
235
+ },
236
+ {
237
+ name: "presence_penalty"
238
+ data_type: TYPE_FP32
239
+ dims: [ 1 ]
240
+ reshape: { shape: [ ] }
241
+ optional: true
242
+ },
243
+ {
244
+ name: "frequency_penalty"
245
+ data_type: TYPE_FP32
246
+ dims: [ 1 ]
247
+ reshape: { shape: [ ] }
248
+ optional: true
249
+ },
250
+ {
251
+ name: "random_seed"
252
+ data_type: TYPE_UINT64
253
+ dims: [ 1 ]
254
+ reshape: { shape: [ ] }
255
+ optional: true
256
+ },
257
+ {
258
+ name: "return_log_probs"
259
+ data_type: TYPE_BOOL
260
+ dims: [ 1 ]
261
+ reshape: { shape: [ ] }
262
+ optional: true
263
+ },
264
+ {
265
+ name: "return_context_logits"
266
+ data_type: TYPE_BOOL
267
+ dims: [ 1 ]
268
+ reshape: { shape: [ ] }
269
+ optional: true
270
+ },
271
+ {
272
+ name: "return_generation_logits"
273
+ data_type: TYPE_BOOL
274
+ dims: [ 1 ]
275
+ reshape: { shape: [ ] }
276
+ optional: true
277
+ },
278
+ {
279
+ name: "return_perf_metrics"
280
+ data_type: TYPE_BOOL
281
+ dims: [ 1 ]
282
+ reshape: { shape: [ ] }
283
+ optional: true
284
+ },
285
+ {
286
+ name: "exclude_input_in_output"
287
+ data_type: TYPE_BOOL
288
+ dims: [ 1 ]
289
+ reshape: { shape: [ ] }
290
+ optional: true
291
+ },
292
+ {
293
+ name: "stop"
294
+ data_type: TYPE_BOOL
295
+ dims: [ 1 ]
296
+ reshape: { shape: [ ] }
297
+ optional: true
298
+ },
299
+ {
300
+ name: "streaming"
301
+ data_type: TYPE_BOOL
302
+ dims: [ 1 ]
303
+ reshape: { shape: [ ] }
304
+ optional: true
305
+ },
306
+ {
307
+ name: "prompt_embedding_table"
308
+ data_type: TYPE_FP16
309
+ dims: [ -1, -1 ]
310
+ optional: true
311
+ allow_ragged_batch: true
312
+ },
313
+ {
314
+ name: "prompt_table_extra_ids"
315
+ data_type: TYPE_UINT64
316
+ dims: [ -1 ]
317
+ optional: true
318
+ allow_ragged_batch: true
319
+ },
320
+ {
321
+ name: "prompt_vocab_size"
322
+ data_type: TYPE_INT32
323
+ dims: [ 1 ]
324
+ reshape: { shape: [ ] }
325
+ optional: true
326
+ },
327
+ # cross_attention_mask shape `[bs, seq_len, num_images*num_tiles]`
328
+ {
329
+ name: "cross_attention_mask"
330
+ data_type: TYPE_BOOL
331
+ dims: [ -1, -1 ]
332
+ optional: true
333
+ allow_ragged_batch: true
334
+ },
335
+ # Mrope param when mrope is used
336
+ {
337
+ name: "mrope_rotary_cos_sin"
338
+ data_type: TYPE_FP32
339
+ dims: [ -1 ]
340
+ optional: true
341
+ },
342
+ {
343
+ name: "mrope_position_deltas"
344
+ data_type: TYPE_INT64
345
+ dims: [ 1 ]
346
+ optional: true
347
+ },
348
+ # the unique task ID for the given LoRA.
349
+ # To perform inference with a specific LoRA for the first time, `lora_task_id`, `lora_weights` and `lora_config` must all be given.
350
+ # The LoRA will be cached, so that subsequent requests for the same task only require `lora_task_id`.
351
+ # If the cache is full, the oldest LoRA will be evicted to make space for new ones. An error is returned if `lora_task_id` is not cached.
352
+ {
353
+ name: "lora_task_id"
354
+ data_type: TYPE_UINT64
355
+ dims: [ 1 ]
356
+ reshape: { shape: [ ] }
357
+ optional: true
358
+ },
359
+ # weights for a lora adapter shape [ num_lora_modules_layers, D x Hi + Ho x D ]
360
+ # where the last dimension holds the in / out adapter weights for the associated module (e.g. attn_qkv) and model layer
361
+ # each of the in / out tensors is first flattened and then concatenated together in the format above.
362
+ # D=adapter_size (R value), Hi=hidden_size_in, Ho=hidden_size_out.
363
+ {
364
+ name: "lora_weights"
365
+ data_type: TYPE_FP16
366
+ dims: [ -1, -1 ]
367
+ optional: true
368
+ allow_ragged_batch: true
369
+ },
370
+ # module identifier (same size as the first dimension of lora_weights)
371
+ # See LoraModule::ModuleType for model id mapping
372
+ #
373
+ # "attn_qkv": 0 # compbined qkv adapter
374
+ # "attn_q": 1 # q adapter
375
+ # "attn_k": 2 # k adapter
376
+ # "attn_v": 3 # v adapter
377
+ # "attn_dense": 4 # adapter for the dense layer in attention
378
+ # "mlp_h_to_4h": 5 # for llama2 adapter for gated mlp layer after attention / RMSNorm: up projection
379
+ # "mlp_4h_to_h": 6 # for llama2 adapter for gated mlp layer after attention / RMSNorm: down projection
380
+ # "mlp_gate": 7 # for llama2 adapter for gated mlp later after attention / RMSNorm: gate
381
+ #
382
+ # last dim holds [ module_id, layer_idx, adapter_size (D aka R value) ]
383
+ {
384
+ name: "lora_config"
385
+ data_type: TYPE_INT32
386
+ dims: [ -1, 3 ]
387
+ optional: true
388
+ allow_ragged_batch: true
389
+ },
390
+ {
391
+ name: "context_phase_params"
392
+ data_type: TYPE_UINT8
393
+ dims: [ -1 ]
394
+ optional: true
395
+ allow_ragged_batch: true
396
+ },
397
+ # skip_cross_attn_blocks shape `[bs, 1]`, only used in mllama
398
+ {
399
+ name: "skip_cross_attn_blocks"
400
+ data_type: TYPE_BOOL
401
+ dims: [ 1 ]
402
+ optional: true
403
+ allow_ragged_batch: true
404
+ },
405
+ {
406
+ name: "retention_token_range_starts"
407
+ data_type: TYPE_INT32
408
+ dims: [ -1 ]
409
+ optional: true
410
+ allow_ragged_batch: true
411
+ },
412
+ {
413
+ name: "retention_token_range_ends"
414
+ data_type: TYPE_INT32
415
+ dims: [ -1 ]
416
+ optional: true
417
+ allow_ragged_batch: true
418
+ },
419
+ {
420
+ name: "retention_token_range_priorities"
421
+ data_type: TYPE_INT32
422
+ dims: [ -1 ]
423
+ optional: true
424
+ allow_ragged_batch: true
425
+ },
426
+ {
427
+ name: "retention_token_range_durations_ms"
428
+ data_type: TYPE_INT32
429
+ dims: [ -1 ]
430
+ optional: true
431
+ allow_ragged_batch: true
432
+ },
433
+ {
434
+ name: "retention_decode_priority"
435
+ data_type: TYPE_INT32
436
+ dims: [ 1 ]
437
+ optional: true
438
+ allow_ragged_batch: true
439
+ },
440
+ {
441
+ name: "retention_decode_duration_ms"
442
+ data_type: TYPE_INT32
443
+ dims: [ 1 ]
444
+ optional: true
445
+ allow_ragged_batch: true
446
+ },
447
+ {
448
+ name: "guided_decoding_guide_type"
449
+ data_type: TYPE_STRING
450
+ dims: [ 1 ]
451
+ optional: true
452
+ allow_ragged_batch: true
453
+ },
454
+ {
455
+ name: "guided_decoding_guide"
456
+ data_type: TYPE_STRING
457
+ dims: [ 1 ]
458
+ optional: true
459
+ allow_ragged_batch: true
460
+ },
461
+ {
462
+ name: "lookahead_window_size"
463
+ data_type: TYPE_INT32
464
+ dims: [ 1 ]
465
+ optional: true
466
+ allow_ragged_batch: true
467
+ },
468
+ {
469
+ name: "lookahead_ngram_size"
470
+ data_type: TYPE_INT32
471
+ dims: [ 1 ]
472
+ optional: true
473
+ allow_ragged_batch: true
474
+ },
475
+ {
476
+ name: "lookahead_verification_set_size"
477
+ data_type: TYPE_INT32
478
+ dims: [ 1 ]
479
+ optional: true
480
+ allow_ragged_batch: true
481
+ }
482
+ ]
483
+ output [
484
+ {
485
+ name: "output_ids"
486
+ data_type: TYPE_INT32
487
+ dims: [ -1, -1 ]
488
+ },
489
+ {
490
+ name: "sequence_length"
491
+ data_type: TYPE_INT32
492
+ dims: [ -1 ]
493
+ },
494
+ {
495
+ name: "cum_log_probs"
496
+ data_type: TYPE_FP32
497
+ dims: [ -1 ]
498
+ },
499
+ {
500
+ name: "output_log_probs"
501
+ data_type: TYPE_FP32
502
+ dims: [ -1, -1 ]
503
+ },
504
+ {
505
+ name: "context_logits"
506
+ data_type: ${logits_datatype}
507
+ dims: [ -1, -1 ]
508
+ },
509
+ {
510
+ name: "generation_logits"
511
+ data_type: ${logits_datatype}
512
+ dims: [ -1, -1, -1 ]
513
+ },
514
+ {
515
+ name: "batch_index"
516
+ data_type: TYPE_INT32
517
+ dims: [ 1 ]
518
+ },
519
+ {
520
+ name: "sequence_index"
521
+ data_type: TYPE_INT32
522
+ dims: [ 1 ]
523
+ },
524
+ {
525
+ name: "context_phase_params"
526
+ data_type: TYPE_UINT8
527
+ dims: [ -1 ]
528
+ },
529
+ {
530
+ name: "kv_cache_alloc_new_blocks"
531
+ data_type: TYPE_INT32
532
+ dims: [ 1 ]
533
+ },
534
+ {
535
+ name: "kv_cache_reused_blocks"
536
+ data_type: TYPE_INT32
537
+ dims: [ 1 ]
538
+ },
539
+ {
540
+ name: "kv_cache_alloc_total_blocks"
541
+ data_type: TYPE_INT32
542
+ dims: [ 1 ]
543
+ },
544
+ {
545
+ name: "arrival_time_ns"
546
+ data_type: TYPE_INT64
547
+ dims: [ 1 ]
548
+ },
549
+ {
550
+ name: "first_scheduled_time_ns"
551
+ data_type: TYPE_INT64
552
+ dims: [ 1 ]
553
+ },
554
+ {
555
+ name: "first_token_time_ns"
556
+ data_type: TYPE_INT64
557
+ dims: [ 1 ]
558
+ },
559
+ {
560
+ name: "last_token_time_ns"
561
+ data_type: TYPE_INT64
562
+ dims: [ 1 ]
563
+ },
564
+ {
565
+ name: "acceptance_rate"
566
+ data_type: TYPE_FP32
567
+ dims: [ 1 ]
568
+ },
569
+ {
570
+ name: "total_accepted_draft_tokens"
571
+ data_type: TYPE_INT32
572
+ dims: [ 1 ]
573
+ },
574
+ {
575
+ name: "total_draft_tokens"
576
+ data_type: TYPE_INT32
577
+ dims: [ 1 ]
578
+ }
579
+ ]
580
+ instance_group [
581
+ {
582
+ count: 1
583
+ kind : KIND_CPU
584
+ }
585
+ ]
586
+ parameters: {
587
+ key: "max_beam_width"
588
+ value: {
589
+ string_value: "${max_beam_width}"
590
+ }
591
+ }
592
+ parameters: {
593
+ key: "FORCE_CPU_ONLY_INPUT_TENSORS"
594
+ value: {
595
+ string_value: "no"
596
+ }
597
+ }
598
+ parameters: {
599
+ key: "gpt_model_type"
600
+ value: {
601
+ string_value: "${batching_strategy}"
602
+ }
603
+ }
604
+ parameters: {
605
+ key: "gpt_model_path"
606
+ value: {
607
+ string_value: "${engine_dir}"
608
+ }
609
+ }
610
+ parameters: {
611
+ key: "encoder_model_path"
612
+ value: {
613
+ string_value: "${encoder_engine_dir}"
614
+ }
615
+ }
616
+ parameters: {
617
+ key: "max_tokens_in_paged_kv_cache"
618
+ value: {
619
+ string_value: "${max_tokens_in_paged_kv_cache}"
620
+ }
621
+ }
622
+ parameters: {
623
+ key: "max_attention_window_size"
624
+ value: {
625
+ string_value: "${max_attention_window_size}"
626
+ }
627
+ }
628
+ parameters: {
629
+ key: "sink_token_length"
630
+ value: {
631
+ string_value: "${sink_token_length}"
632
+ }
633
+ }
634
+ parameters: {
635
+ key: "batch_scheduler_policy"
636
+ value: {
637
+ string_value: "${batch_scheduler_policy}"
638
+ }
639
+ }
640
+ parameters: {
641
+ key: "kv_cache_free_gpu_mem_fraction"
642
+ value: {
643
+ string_value: "${kv_cache_free_gpu_mem_fraction}"
644
+ }
645
+ }
646
+ parameters: {
647
+ key: "cross_kv_cache_fraction"
648
+ value: {
649
+ string_value: "${cross_kv_cache_fraction}"
650
+ }
651
+ }
652
+ parameters: {
653
+ key: "kv_cache_host_memory_bytes"
654
+ value: {
655
+ string_value: "${kv_cache_host_memory_bytes}"
656
+ }
657
+ }
658
+ # kv_cache_onboard_blocks is for internal implementation.
659
+ parameters: {
660
+ key: "kv_cache_onboard_blocks"
661
+ value: {
662
+ string_value: "${kv_cache_onboard_blocks}"
663
+ }
664
+ }
665
+ # enable_trt_overlap is deprecated and doesn't have any effect on the runtime
666
+ # parameters: {
667
+ # key: "enable_trt_overlap"
668
+ # value: {
669
+ # string_value: "${enable_trt_overlap}"
670
+ # }
671
+ # }
672
+ parameters: {
673
+ key: "exclude_input_in_output"
674
+ value: {
675
+ string_value: "${exclude_input_in_output}"
676
+ }
677
+ }
678
+ parameters: {
679
+ key: "cancellation_check_period_ms"
680
+ value: {
681
+ string_value: "${cancellation_check_period_ms}"
682
+ }
683
+ }
684
+ parameters: {
685
+ key: "stats_check_period_ms"
686
+ value: {
687
+ string_value: "${stats_check_period_ms}"
688
+ }
689
+ }
690
+ parameters: {
691
+ key: "iter_stats_max_iterations"
692
+ value: {
693
+ string_value: "${iter_stats_max_iterations}"
694
+ }
695
+ }
696
+ parameters: {
697
+ key: "request_stats_max_iterations"
698
+ value: {
699
+ string_value: "${request_stats_max_iterations}"
700
+ }
701
+ }
702
+ parameters: {
703
+ key: "enable_kv_cache_reuse"
704
+ value: {
705
+ string_value: "${enable_kv_cache_reuse}"
706
+ }
707
+ }
708
+ parameters: {
709
+ key: "normalize_log_probs"
710
+ value: {
711
+ string_value: "${normalize_log_probs}"
712
+ }
713
+ }
714
+ parameters: {
715
+ key: "enable_chunked_context"
716
+ value: {
717
+ string_value: "${enable_chunked_context}"
718
+ }
719
+ }
720
+ parameters: {
721
+ key: "gpu_device_ids"
722
+ value: {
723
+ string_value: "${gpu_device_ids}"
724
+ }
725
+ }
726
+ parameters: {
727
+ key: "participant_ids"
728
+ value: {
729
+ string_value: "${participant_ids}"
730
+ }
731
+ }
732
+ parameters: {
733
+ key: "lora_cache_optimal_adapter_size"
734
+ value: {
735
+ string_value: "${lora_cache_optimal_adapter_size}"
736
+ }
737
+ }
738
+ parameters: {
739
+ key: "lora_cache_max_adapter_size"
740
+ value: {
741
+ string_value: "${lora_cache_max_adapter_size}"
742
+ }
743
+ }
744
+ parameters: {
745
+ key: "lora_cache_gpu_memory_fraction"
746
+ value: {
747
+ string_value: "${lora_cache_gpu_memory_fraction}"
748
+ }
749
+ }
750
+ parameters: {
751
+ key: "lora_cache_host_memory_bytes"
752
+ value: {
753
+ string_value: "${lora_cache_host_memory_bytes}"
754
+ }
755
+ }
756
+ parameters: {
757
+ key: "lora_prefetch_dir"
758
+ value: {
759
+ string_value: "${lora_prefetch_dir}"
760
+ }
761
+ }
762
+ parameters: {
763
+ key: "decoding_mode"
764
+ value: {
765
+ string_value: "${decoding_mode}"
766
+ }
767
+ }
768
+ parameters: {
769
+ key: "executor_worker_path"
770
+ value: {
771
+ string_value: "/opt/tritonserver/backends/tensorrtllm/trtllmExecutorWorker"
772
+ }
773
+ }
774
+ parameters: {
775
+ key: "lookahead_window_size"
776
+ value: {
777
+ string_value: "${lookahead_window_size}"
778
+ }
779
+ }
780
+ parameters: {
781
+ key: "lookahead_ngram_size"
782
+ value: {
783
+ string_value: "${lookahead_ngram_size}"
784
+ }
785
+ }
786
+ parameters: {
787
+ key: "lookahead_verification_set_size"
788
+ value: {
789
+ string_value: "${lookahead_verification_set_size}"
790
+ }
791
+ }
792
+ parameters: {
793
+ key: "medusa_choices"
794
+ value: {
795
+ string_value: "${medusa_choices}"
796
+ }
797
+ }
798
+ parameters: {
799
+ key: "eagle_choices"
800
+ value: {
801
+ string_value: "${eagle_choices}"
802
+ }
803
+ }
804
+ parameters: {
805
+ key: "gpu_weights_percent"
806
+ value: {
807
+ string_value: "${gpu_weights_percent}"
808
+ }
809
+ }
810
+ parameters: {
811
+ key: "enable_context_fmha_fp32_acc"
812
+ value: {
813
+ string_value: "${enable_context_fmha_fp32_acc}"
814
+ }
815
+ }
816
+ parameters: {
817
+ key: "multi_block_mode"
818
+ value: {
819
+ string_value: "${multi_block_mode}"
820
+ }
821
+ }
822
+ parameters: {
823
+ key: "cuda_graph_mode"
824
+ value: {
825
+ string_value: "${cuda_graph_mode}"
826
+ }
827
+ }
828
+ parameters: {
829
+ key: "cuda_graph_cache_size"
830
+ value: {
831
+ string_value: "${cuda_graph_cache_size}"
832
+ }
833
+ }
834
+ parameters: {
835
+ key: "speculative_decoding_fast_logits"
836
+ value: {
837
+ string_value: "${speculative_decoding_fast_logits}"
838
+ }
839
+ }
840
+ parameters: {
841
+ key: "tokenizer_dir"
842
+ value: {
843
+ string_value: "${tokenizer_dir}"
844
+ }
845
+ }
846
+ parameters: {
847
+ key: "guided_decoding_backend"
848
+ value: {
849
+ string_value: "${guided_decoding_backend}"
850
+ }
851
+ }
852
+ parameters: {
853
+ key: "xgrammar_tokenizer_info_path"
854
+ value: {
855
+ string_value: "${xgrammar_tokenizer_info_path}"
856
+ }
857
+ }
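
The LoRA comments above describe the expected layouts of `lora_weights` (`[num_lora_modules_layers, D x Hi + Ho x D]`) and `lora_config` (`[num_lora_modules_layers, 3]` holding `[module_id, layer_idx, adapter_size]`). A minimal client-side sketch of building those tensors for a hypothetical rank-8 adapter on `attn_qkv` and `attn_dense` in the first two layers (the hidden sizes and module choices are illustrative placeholders, not values taken from this repository):

```python
import numpy as np

adapter_size = 8                    # D (R value) - hypothetical rank
hidden_in, hidden_out = 896, 896    # Hi / Ho - placeholders, not the real model dims

# One row per adapted module/layer: [module_id, layer_idx, adapter_size],
# using the module-id mapping listed in the config comments above.
lora_config = np.array(
    [[0, 0, adapter_size],   # attn_qkv,   layer 0
     [4, 0, adapter_size],   # attn_dense, layer 0
     [0, 1, adapter_size],   # attn_qkv,   layer 1
     [4, 1, adapter_size]],  # attn_dense, layer 1
    dtype=np.int32,
)

# Flattened in/out weights per row: D x Hi values followed by Ho x D values.
lora_weights = np.zeros(
    (lora_config.shape[0], adapter_size * hidden_in + hidden_out * adapter_size),
    dtype=np.float16,
)

# Both tensors are sent with a leading batch dimension alongside `lora_task_id`.
lora_config = lora_config[None, ...]
lora_weights = lora_weights[None, ...]
```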
runtime/triton_trtllm/model_repo/vocoder/1/model.py ADDED
@@ -0,0 +1,106 @@
1
+ # Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # Redistribution and use in source and binary forms, with or without
4
+ # modification, are permitted provided that the following conditions
5
+ # are met:
6
+ # * Redistributions of source code must retain the above copyright
7
+ # notice, this list of conditions and the following disclaimer.
8
+ # * Redistributions in binary form must reproduce the above copyright
9
+ # notice, this list of conditions and the following disclaimer in the
10
+ # documentation and/or other materials provided with the distribution.
11
+ # * Neither the name of NVIDIA CORPORATION nor the names of its
12
+ # contributors may be used to endorse or promote products derived
13
+ # from this software without specific prior written permission.
14
+ #
15
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
16
+ # EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
17
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
18
+ # PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
19
+ # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
20
+ # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
21
+ # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
22
+ # PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
23
+ # OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
24
+ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
25
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26
+
27
+ import json
28
+ import os
29
+ import logging
30
+ from typing import List, Dict
31
+
32
+ import torch
33
+ from torch.utils.dlpack import to_dlpack
34
+
35
+ import triton_python_backend_utils as pb_utils
36
+
37
+ from sparktts.models.bicodec import BiCodec
38
+
39
+ # Configure logging
40
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
41
+ logger = logging.getLogger(__name__)
42
+
43
+ class TritonPythonModel:
44
+ """Triton Python model for vocoder.
45
+
46
+ This model takes global and semantic tokens as input and generates audio waveforms
47
+ using the BiCodec vocoder.
48
+ """
49
+
50
+ def initialize(self, args):
51
+ """Initialize the model.
52
+
53
+ Args:
54
+ args: Dictionary containing model configuration
55
+ """
56
+ # Parse model parameters
57
+ parameters = json.loads(args['model_config'])['parameters']
58
+ model_params = {key: value["string_value"] for key, value in parameters.items()}
59
+ model_dir = model_params["model_dir"]
60
+
61
+ # Initialize device and vocoder
62
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
63
+ logger.info(f"Initializing vocoder from {model_dir} on {self.device}")
64
+
65
+ self.vocoder = BiCodec.load_from_checkpoint(f"{model_dir}/BiCodec")
66
+ del self.vocoder.encoder, self.vocoder.postnet
67
+ self.vocoder.eval().to(self.device)  # Set to eval mode and move to the target device
68
+
69
+ logger.info("Vocoder initialized successfully")
70
+
71
+
72
+ def execute(self, requests):
73
+ """Execute inference on the batched requests.
74
+
75
+ Args:
76
+ requests: List of inference requests
77
+
78
+ Returns:
79
+ List of inference responses containing generated waveforms
80
+ """
81
+ global_tokens_list, semantic_tokens_list = [], []
82
+
83
+ # Process each request in batch
84
+ for request in requests:
85
+ global_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "global_tokens").as_numpy()
86
+ semantic_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "semantic_tokens").as_numpy()
87
+ global_tokens_list.append(torch.from_numpy(global_tokens_tensor).to(self.device))
88
+ semantic_tokens_list.append(torch.from_numpy(semantic_tokens_tensor).to(self.device))
89
+
90
+ # Concatenate tokens for batch processing
91
+ global_tokens = torch.cat(global_tokens_list, dim=0)
92
+ semantic_tokens = torch.cat(semantic_tokens_list, dim=0)
93
+
94
+
95
+ # Generate waveforms
96
+ with torch.no_grad():
97
+ wavs = self.vocoder.detokenize(semantic_tokens, global_tokens.unsqueeze(1))
98
+
99
+ # Prepare responses
100
+ responses = []
101
+ for i in range(len(requests)):
102
+ wav_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(wavs[i]))
103
+ inference_response = pb_utils.InferenceResponse(output_tensors=[wav_tensor])
104
+ responses.append(inference_response)
105
+
106
+ return responses
runtime/triton_trtllm/model_repo/vocoder/config.pbtxt ADDED
@@ -0,0 +1,53 @@
1
+ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ name: "vocoder"
16
+ backend: "python"
17
+ max_batch_size: ${triton_max_batch_size}
18
+ dynamic_batching {
19
+ max_queue_delay_microseconds: ${max_queue_delay_microseconds}
20
+ }
21
+ parameters [
22
+ {
23
+ key: "model_dir",
24
+ value: {string_value:"${model_dir}"}
25
+ }
26
+ ]
27
+
28
+ input [
29
+ {
30
+ name: "global_tokens"
31
+ data_type: TYPE_INT32
32
+ dims: [-1]
33
+ },
34
+ {
35
+ name: "semantic_tokens"
36
+ data_type: TYPE_INT32
37
+ dims: [-1]
38
+ }
39
+ ]
40
+ output [
41
+ {
42
+ name: "waveform"
43
+ data_type: TYPE_FP32
44
+ dims: [ -1 ]
45
+ }
46
+ ]
47
+
48
+ instance_group [
49
+ {
50
+ count: 1
51
+ kind: KIND_CPU
52
+ }
53
+ ]
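
Taken together, `model.py` and this config expose a `vocoder` model with two INT32 token inputs and a single FP32 `waveform` output. A minimal sketch of querying it directly over gRPC (the token values and lengths below are placeholders; in the full pipeline they come from the audio_tokenizer and LLM stages, and the 16 kHz sample rate is an assumption):

```python
import numpy as np
import soundfile as sf
import tritonclient.grpc as grpcclient

# Placeholder token ids - real values are produced by the audio_tokenizer / LLM stages.
global_tokens = np.zeros((1, 32), dtype=np.int32)
semantic_tokens = np.zeros((1, 200), dtype=np.int32)

client = grpcclient.InferenceServerClient("localhost:8001")
inputs = [
    grpcclient.InferInput("global_tokens", list(global_tokens.shape), "INT32"),
    grpcclient.InferInput("semantic_tokens", list(semantic_tokens.shape), "INT32"),
]
inputs[0].set_data_from_numpy(global_tokens)
inputs[1].set_data_from_numpy(semantic_tokens)

result = client.infer("vocoder", inputs,
                      outputs=[grpcclient.InferRequestedOutput("waveform")])
waveform = result.as_numpy("waveform").squeeze()   # float32 samples
sf.write("vocoder_out.wav", waveform, 16000)        # sample rate is an assumption
```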
runtime/triton_trtllm/run.sh ADDED
@@ -0,0 +1,68 @@
1
+ export PYTHONPATH=../../../Spark-TTS/
2
+ export CUDA_VISIBLE_DEVICES=0
3
+ stage=$1
4
+ stop_stage=$2
5
+ echo "Start stage: $stage, Stop stage: $stop_stage"
6
+
7
+ huggingface_model_local_dir=../../pretrained_models/Spark-TTS-0.5B
8
+ trt_dtype=bfloat16
9
+ trt_weights_dir=./tllm_checkpoint_${trt_dtype}
10
+ trt_engines_dir=./trt_engines_${trt_dtype}
11
+
12
+ model_repo=./model_repo_test
13
+
14
+ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
15
+ echo "Downloading Spark-TTS-0.5B from HuggingFace"
16
+ huggingface-cli download SparkAudio/Spark-TTS-0.5B --local-dir $huggingface_model_local_dir || exit 1
17
+ # pip install -r /workspace_yuekai/spark-tts/Spark-TTS/requirements.txt
18
+ fi
19
+
20
+
21
+ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
22
+ echo "Converting checkpoint to TensorRT weights"
23
+ python scripts/convert_checkpoint.py --model_dir $huggingface_model_local_dir/LLM \
24
+ --output_dir $trt_weights_dir \
25
+ --dtype $trt_dtype || exit 1
26
+
27
+ echo "Building TensorRT engines"
28
+ trtllm-build --checkpoint_dir $trt_weights_dir \
29
+ --output_dir $trt_engines_dir \
30
+ --max_batch_size 16 \
31
+ --max_num_tokens 32768 \
32
+ --gemm_plugin $trt_dtype || exit 1
33
+ fi
34
+
35
+ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
36
+ echo "Creating model repository"
37
+ rm -rf $model_repo
38
+ cp -r ./model_repo $model_repo
39
+
40
+ ENGINE_PATH=$trt_engines_dir
41
+ MAX_QUEUE_DELAY_MICROSECONDS=0
42
+ MODEL_DIR=$huggingface_model_local_dir
43
+ LLM_TOKENIZER_DIR=$huggingface_model_local_dir/LLM
44
+ BLS_INSTANCE_NUM=4
45
+ TRITON_MAX_BATCH_SIZE=16
46
+
47
+ python3 scripts/fill_template.py -i ${model_repo}/vocoder/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
48
+ python3 scripts/fill_template.py -i ${model_repo}/audio_tokenizer/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
49
+ python3 scripts/fill_template.py -i ${model_repo}/spark_tts/config.pbtxt bls_instance_num:${BLS_INSTANCE_NUM},llm_tokenizer_dir:${LLM_TOKENIZER_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
50
+ python3 scripts/fill_template.py -i ${model_repo}/tensorrt_llm/config.pbtxt triton_backend:tensorrtllm,triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},decoupled_mode:False,max_beam_width:1,engine_dir:${ENGINE_PATH},max_tokens_in_paged_kv_cache:2560,max_attention_window_size:2560,kv_cache_free_gpu_mem_fraction:0.5,exclude_input_in_output:True,enable_kv_cache_reuse:False,batching_strategy:inflight_fused_batching,max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS},encoder_input_features_data_type:TYPE_FP16,logits_datatype:TYPE_FP32
51
+
52
+ fi
53
+
54
+ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
55
+ echo "Starting Triton server"
56
+ tritonserver --model-repository ${model_repo}
57
+ fi
58
+
59
+
60
+ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
61
+ echo "Running client"
62
+ num_task=2
63
+ python3 client_grpc.py \
64
+ --server-addr localhost \
65
+ --model-name spark_tts \
66
+ --num-tasks $num_task \
67
+ --log-dir ./log_concurrent_tasks_${num_task}
68
+ fi
runtime/triton_trtllm/scripts/convert_checkpoint.py ADDED
@@ -0,0 +1,335 @@
1
+ import argparse
2
+ import os
3
+ import time
4
+ import traceback
5
+ from concurrent.futures import ThreadPoolExecutor, as_completed
6
+
7
+ from transformers import AutoConfig
8
+
9
+ import tensorrt_llm
10
+ from tensorrt_llm._utils import release_gc
11
+ from tensorrt_llm.logger import logger
12
+ from tensorrt_llm.mapping import Mapping
13
+ from tensorrt_llm.models import QWenForCausalLM
14
+ from tensorrt_llm.models.modeling_utils import QuantConfig
15
+ from tensorrt_llm.quantization import QuantAlgo
16
+
17
+
18
+ def parse_arguments():
19
+ parser = argparse.ArgumentParser()
20
+ parser.add_argument('--model_dir', type=str, default=None, required=True)
21
+ parser.add_argument('--tp_size',
22
+ type=int,
23
+ default=1,
24
+ help='N-way tensor parallelism size')
25
+ parser.add_argument('--pp_size',
26
+ type=int,
27
+ default=1,
28
+ help='N-way pipeline parallelism size')
29
+ parser.add_argument(
30
+ '--dtype',
31
+ type=str,
32
+ default='auto',
33
+ choices=['auto', 'float16', 'bfloat16', 'float32'],
34
+ help=
35
+ "The data type for the model weights and activations if not quantized. "
36
+ "If 'auto', the data type is automatically inferred from the source model; "
37
+ "however, if the source dtype is float32, it is converted to float16.")
38
+ parser.add_argument(
39
+ '--use_weight_only',
40
+ default=False,
41
+ action="store_true",
42
+ help='Quantize weights for the various GEMMs to INT4/INT8. '
43
+ 'See --weight_only_precision to set the precision')
44
+ parser.add_argument(
45
+ '--disable_weight_only_quant_plugin',
46
+ default=False,
47
+ action="store_true",
48
+ help=
49
+ 'By default, the plugin implementation is used for weight quantization. Enabling the disable_weight_only_quant_plugin flag will use the OOTB implementation instead of the plugin. '
50
+ 'You must also use --use_weight_only for that argument to have an impact.'
51
+ )
52
+ parser.add_argument(
53
+ '--weight_only_precision',
54
+ const='int8',
55
+ type=str,
56
+ nargs='?',
57
+ default='int8',
58
+ choices=['int8', 'int4', 'int4_gptq'],
59
+ help=
60
+ 'Define the precision for the weights when using weight-only quantization. '
61
+ 'You must also use --use_weight_only for that argument to have an impact.'
62
+ )
63
+ parser.add_argument(
64
+ '--calib_dataset',
65
+ type=str,
66
+ default='ccdv/cnn_dailymail',
67
+ help=
68
+ "The huggingface dataset name or the local directory of the dataset for calibration."
69
+ )
70
+ parser.add_argument(
71
+ "--smoothquant",
72
+ "-sq",
73
+ type=float,
74
+ default=None,
75
+ help="Set the α parameter (see https://arxiv.org/pdf/2211.10438.pdf)"
76
+ " to Smoothquant the model, and output int8 weights."
77
+ " A good first try is 0.5. Must be in [0, 1]")
78
+ parser.add_argument(
79
+ '--per_channel',
80
+ action="store_true",
81
+ default=False,
82
+ help=
83
+ 'By default, we use a single static scaling factor for the GEMM\'s result. '
84
+ 'per_channel instead uses a different static scaling factor for each channel. '
85
+ 'The latter is usually more accurate, but a little slower.')
86
+ parser.add_argument(
87
+ '--per_token',
88
+ action="store_true",
89
+ default=False,
90
+ help=
91
+ 'By default, we use a single static scaling factor to scale activations in the int8 range. '
92
+ 'per_token chooses at run time, and for each token, a custom scaling factor. '
93
+ 'The latter is usually more accurate, but a little slower.')
94
+ parser.add_argument(
95
+ '--int8_kv_cache',
96
+ default=False,
97
+ action="store_true",
98
+ help=
99
+ 'By default, we use dtype for KV cache. int8_kv_cache chooses int8 quantization for KV'
100
+ )
101
+ parser.add_argument(
102
+ '--per_group',
103
+ default=False,
104
+ action="store_true",
105
+ help=
106
+ 'By default, we use a single static scaling factor to scale weights in the int4 range. '
107
+ 'per_group chooses at run time, and for each group, a custom scaling factor. '
108
+ 'The flag is built for GPTQ/AWQ quantization.')
109
+
110
+ parser.add_argument('--group_size',
111
+ type=int,
112
+ default=128,
113
+ help='Group size used in GPTQ quantization.')
114
+
115
+ parser.add_argument("--load_model_on_cpu", action="store_true")
116
+ parser.add_argument(
117
+ '--use_parallel_embedding',
118
+ action="store_true",
119
+ default=False,
120
+ help=
121
+ 'By default embedding parallelism is disabled. By setting this flag, embedding parallelism is enabled'
122
+ )
123
+ parser.add_argument(
124
+ '--embedding_sharding_dim',
125
+ type=int,
126
+ default=0,
127
+ choices=[0, 1],
128
+ help=
129
+ 'By default the embedding lookup table is sharded along vocab dimension (embedding_sharding_dim=0). '
130
+ 'To shard it along the hidden dimension, set embedding_sharding_dim=1. '
131
+ 'Note: embedding sharing is only enabled when embedding_sharding_dim = 0'
132
+ )
133
+ parser.add_argument('--output_dir',
134
+ type=str,
135
+ default='tllm_checkpoint',
136
+ help='The path to save the TensorRT-LLM checkpoint')
137
+ parser.add_argument(
138
+ '--workers',
139
+ type=int,
140
+ default=1,
141
+ help='The number of workers for converting checkpoint in parallel')
142
+ parser.add_argument(
143
+ '--moe_tp_size',
144
+ type=int,
145
+ default=-1,
146
+ help=
147
+ 'N-way tensor parallelism size for MOE, default is tp_size, which will do tp-only for MoE'
148
+ )
149
+ parser.add_argument(
150
+ '--moe_ep_size',
151
+ type=int,
152
+ default=-1,
153
+ help=
154
+ 'N-way expert parallelism size for MOE, default is 1, which will do tp-only for MoE'
155
+ )
156
+ args = parser.parse_args()
157
+ return args
158
+
159
+
160
+ def args_to_quant_config(args: argparse.Namespace) -> QuantConfig:
161
+ '''Return a QuantConfig with quantization info based on the command-line args
162
+ '''
163
+ quant_config = QuantConfig()
164
+ if args.use_weight_only:
165
+ if args.weight_only_precision == 'int8':
166
+ quant_config.quant_algo = QuantAlgo.W8A16
167
+ elif args.weight_only_precision == 'int4':
168
+ quant_config.quant_algo = QuantAlgo.W4A16
169
+ elif args.smoothquant:
170
+ quant_config.smoothquant_val = args.smoothquant
171
+ if args.per_channel:
172
+ if args.per_token:
173
+ quant_config.quant_algo = QuantAlgo.W8A8_SQ_PER_CHANNEL_PER_TOKEN_PLUGIN
174
+ else:
175
+ quant_config.quant_algo = QuantAlgo.W8A8_SQ_PER_CHANNEL_PER_TENSOR_PLUGIN
176
+ else:
177
+ if args.per_token:
178
+ quant_config.quant_algo = QuantAlgo.W8A8_SQ_PER_TENSOR_PER_TOKEN_PLUGIN
179
+ else:
180
+ quant_config.quant_algo = QuantAlgo.W8A8_SQ_PER_TENSOR_PLUGIN
181
+
182
+ if args.int8_kv_cache:
183
+ quant_config.kv_cache_quant_algo = QuantAlgo.INT8
184
+
185
+ if args.weight_only_precision == 'int4_gptq':
186
+ quant_config.group_size = args.group_size
187
+ quant_config.has_zero_point = True
188
+ quant_config.pre_quant_scale = False
189
+ quant_config.quant_algo = QuantAlgo.W4A16_GPTQ
190
+
191
+ return quant_config
192
+
193
+
194
+ def update_quant_config_from_hf(quant_config, hf_config,
195
+ override_fields) -> tuple[QuantConfig, dict]:
196
+ hf_config_dict = hf_config.to_dict()
197
+ if hf_config_dict.get('quantization_config'):
198
+ # update the quant_algo, and clamp_val.
199
+ if hf_config_dict['quantization_config'].get('quant_method') == 'awq':
200
+ logger.info(
201
+ "Load quantization configs from huggingface model_config.")
202
+ quant_config.quant_algo = QuantAlgo.W4A16_GPTQ
203
+ quant_config.group_size = hf_config_dict['quantization_config'].get(
204
+ 'group_size', 128)
205
+ quant_config.has_zero_point = hf_config_dict[
206
+ 'quantization_config'].get('zero_point', False)
207
+ override_fields.update({"use_autoawq": True})
208
+ elif hf_config_dict['quantization_config'].get(
209
+ 'quant_method') == 'gptq':
210
+ logger.info(
211
+ "Load quantization configs from huggingface model_config.")
212
+ desc_act = hf_config_dict['quantization_config'].get(
213
+ 'desc_act', False)
214
+ if desc_act:
215
+ raise ValueError("GPTQ with desc_act=True is not implemented!")
216
+ quant_config.quant_algo = QuantAlgo.W4A16_GPTQ
217
+ quant_config.group_size = hf_config_dict['quantization_config'].get(
218
+ 'group_size', 128)
219
+ quant_config.has_zero_point = hf_config_dict[
220
+ 'quantization_config'].get('sym', False)
221
+ return quant_config, override_fields
222
+
223
+
224
+ def args_to_build_options(args):
225
+ return {
226
+ 'use_parallel_embedding': args.use_parallel_embedding,
227
+ 'embedding_sharding_dim': args.embedding_sharding_dim,
228
+ 'disable_weight_only_quant_plugin':
229
+ args.disable_weight_only_quant_plugin
230
+ }
231
+
232
+
233
+ def convert_and_save_hf(args):
234
+ model_dir = args.model_dir
235
+ world_size = args.tp_size * args.pp_size
236
+ # Need to convert the CLI args to key-value pairs and override them in the generated config dict.
238
+ # Ideally these fields will be moved out of the config and passed into the build API; keep them here for compatibility for now,
238
+ # before the refactor is done.
239
+ override_fields = {}
240
+ override_fields.update(args_to_build_options(args))
241
+ quant_config = args_to_quant_config(args)
242
+
243
+ try:
244
+ hf_config = AutoConfig.from_pretrained(model_dir,
245
+ trust_remote_code=True)
246
+ quant_config, override_fields = update_quant_config_from_hf(
247
+ quant_config, hf_config, override_fields)
248
+ except Exception:
249
+ logger.warning("AutoConfig cannot load the huggingface config.")
250
+
251
+ if args.smoothquant is not None or args.int8_kv_cache:
252
+ mapping = Mapping(
253
+ world_size=world_size,
254
+ tp_size=args.tp_size,
255
+ pp_size=args.pp_size,
256
+ moe_tp_size=args.moe_tp_size,
257
+ moe_ep_size=args.moe_ep_size,
258
+ )
259
+ QWenForCausalLM.quantize(args.model_dir,
260
+ args.output_dir,
261
+ dtype=args.dtype,
262
+ mapping=mapping,
263
+ quant_config=quant_config,
264
+ calib_dataset=args.calib_dataset,
265
+ **override_fields)
266
+ else:
267
+
268
+ def convert_and_save_rank(args, rank):
269
+ mapping = Mapping(world_size=world_size,
270
+ rank=rank,
271
+ tp_size=args.tp_size,
272
+ pp_size=args.pp_size,
273
+ moe_tp_size=args.moe_tp_size,
274
+ moe_ep_size=args.moe_ep_size)
275
+ qwen = QWenForCausalLM.from_hugging_face(model_dir,
276
+ args.dtype,
277
+ mapping=mapping,
278
+ quant_config=quant_config,
279
+ **override_fields)
280
+ qwen.save_checkpoint(args.output_dir, save_config=(rank == 0))
281
+ del qwen
282
+
283
+ execute(args.workers, [convert_and_save_rank] * world_size, args)
284
+ release_gc()
285
+
286
+
287
+ def execute(workers, func, args):
288
+ if workers == 1:
289
+ for rank, f in enumerate(func):
290
+ f(args, rank)
291
+ else:
292
+ with ThreadPoolExecutor(max_workers=workers) as p:
293
+ futures = [p.submit(f, args, rank) for rank, f in enumerate(func)]
294
+ exceptions = []
295
+ for future in as_completed(futures):
296
+ try:
297
+ future.result()
298
+ except Exception as e:
299
+ traceback.print_exc()
300
+ exceptions.append(e)
301
+ assert len(
302
+ exceptions
303
+ ) == 0, "Checkpoint conversion failed, please check error log."
304
+
305
+
306
+ def main():
307
+ print(tensorrt_llm.__version__)
308
+ args = parse_arguments()
309
+
310
+ if (args.moe_tp_size == -1 and args.moe_ep_size == -1):
311
+ # moe default to tp-only
312
+ args.moe_tp_size = args.tp_size
313
+ args.moe_ep_size = 1
314
+ elif (args.moe_tp_size == -1):
315
+ args.moe_tp_size = args.tp_size // args.moe_ep_size
316
+ elif (args.moe_ep_size == -1):
317
+ args.moe_ep_size = args.tp_size // args.moe_tp_size
318
+ assert (args.moe_tp_size * args.moe_ep_size == args.tp_size
319
+ ), "moe_tp_size * moe_ep_size must equal to tp_size"
320
+
321
+ tik = time.time()
322
+
323
+ if not os.path.exists(args.output_dir):
324
+ os.makedirs(args.output_dir)
325
+
326
+ assert args.model_dir is not None
327
+ convert_and_save_hf(args)
328
+
329
+ tok = time.time()
330
+ t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
331
+ print(f'Total time of converting checkpoints: {t}')
332
+
333
+
334
+ if __name__ == '__main__':
335
+ main()
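
As a hedged illustration of how the weight-only flags above map onto a QuantConfig, the snippet below inspects the result without running a full conversion (run inside the container where tensorrt_llm is installed; the `scripts.convert_checkpoint` import path assumes the working directory is `runtime/triton_trtllm`):

```python
import sys

# Simulate the CLI flags for an int8 weight-only conversion of the Spark-TTS LLM.
sys.argv = [
    "convert_checkpoint.py",
    "--model_dir", "../../pretrained_models/Spark-TTS-0.5B/LLM",
    "--use_weight_only",
    "--weight_only_precision", "int8",
]

from scripts.convert_checkpoint import parse_arguments, args_to_quant_config

args = parse_arguments()
print(args_to_quant_config(args).quant_algo)   # expected: QuantAlgo.W8A16
```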
runtime/triton_trtllm/scripts/fill_template.py ADDED
@@ -0,0 +1,70 @@
1
+ #! /usr/bin/env python3
2
+ from argparse import ArgumentParser
3
+ from string import Template
4
+
5
+
6
+ def split(string, delimiter):
7
+ """Split a string using delimiter. Supports escaping.
8
+
9
+ Args:
10
+ string (str): The string to split.
11
+ delimiter (str): The delimiter to split the string with.
12
+
13
+ Returns:
14
+ list: A list of strings.
15
+ """
16
+ result = []
17
+ current = ""
18
+ escape = False
19
+ for char in string:
20
+ if escape:
21
+ current += char
22
+ escape = False
23
+ elif char == delimiter:
24
+ result.append(current)
25
+ current = ""
26
+ elif char == "\\":
27
+ escape = True
28
+ else:
29
+ current += char
30
+ result.append(current)
31
+ return result
32
+
33
+
34
+ def main(file_path, substitutions, in_place):
35
+ with open(file_path) as f:
36
+ pbtxt = Template(f.read())
37
+
38
+ sub_dict = {
39
+ "max_queue_size": 0,
40
+ 'max_queue_delay_microseconds': 0,
41
+ }
42
+ for sub in split(substitutions, ","):
43
+ key, value = split(sub, ":")
44
+ sub_dict[key] = value
45
+
46
+ assert key in pbtxt.template, f"key '{key}' does not exist in the file {file_path}."
47
+
48
+ pbtxt = pbtxt.safe_substitute(sub_dict)
49
+
50
+ if in_place:
51
+ with open(file_path, "w") as f:
52
+ f.write(pbtxt)
53
+ else:
54
+ print(pbtxt)
55
+
56
+
57
+ if __name__ == "__main__":
58
+ parser = ArgumentParser()
59
+ parser.add_argument("file_path", help="path of the .pbtxt to modify")
60
+ parser.add_argument(
61
+ "substitutions",
62
+ help=
63
+ "substitutions to perform, in the format variable_name_1:value_1,variable_name_2:value_2..."
64
+ )
65
+ parser.add_argument("--in_place",
66
+ "-i",
67
+ action="store_true",
68
+ help="do the operation in-place")
69
+ args = parser.parse_args()
70
+ main(**vars(args))
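
`split()` supports backslash escaping, which matters when a substituted value itself contains the `,` or `:` separators (for example a `gpu_device_ids` list). A small sketch of the behaviour, assuming the script is importable as `scripts.fill_template` from `runtime/triton_trtllm`:

```python
from scripts.fill_template import split

subs = r"gpu_device_ids:0\,1,engine_dir:./trt_engines_bfloat16"
print(split(subs, ","))
# ['gpu_device_ids:0,1', 'engine_dir:./trt_engines_bfloat16']
print(split(split(subs, ",")[0], ":"))
# ['gpu_device_ids', '0,1']
```

The same escaping applies on the command line, so a value containing a comma would be written as `gpu_device_ids:0\,1` in the substitutions string.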