Google TPUs documentation

Advanced TGI Server Configuration

Jetstream Pytorch and Pytorch XLA backends

Jetstream Pytorch is a highly optimized Pytorch engine for serving LLMs on Cloud TPU. This engine is selected by default if the dependency is available.

We recommend using Jetstream with TGI for the best performance. If for some reason you want to use the Pytorch/XLA backend instead, you can set the JETSTREAM_PT_DISABLE=1 environment variable.
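As a minimal sketch, assuming you launch the server with `docker run` and pass settings through `-e` flags as elsewhere on this page, choosing the backend comes down to whether JETSTREAM_PT_DISABLE is set. The helper names `backend_env` and `env_to_docker_flags` are hypothetical.

```python
def backend_env(use_jetstream: bool = True) -> dict[str, str]:
    """Environment variables that select the TGI backend on TPU (hypothetical helper).

    Jetstream Pytorch is used by default when installed, so nothing needs to be set;
    JETSTREAM_PT_DISABLE=1 switches to the Pytorch/XLA backend.
    """
    return {} if use_jetstream else {"JETSTREAM_PT_DISABLE": "1"}


def env_to_docker_flags(env: dict[str, str]) -> list[str]:
    """Turn {"NAME": "value"} into ["-e", "NAME=value", ...] for `docker run`."""
    flags: list[str] = []
    for name, value in env.items():
        flags += ["-e", f"{name}={value}"]
    return flags


print(env_to_docker_flags(backend_env(use_jetstream=False)))
# ['-e', 'JETSTREAM_PT_DISABLE=1']
```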

For more information, see our discussion on the difference between Jetstream and Pytorch XLA.

Quantization

When using the Jetstream Pytorch engine, you can enable quantization to reduce the memory footprint and increase throughput. To enable it, set the QUANTIZATION=1 environment variable. For instance, on a 2x4 TPU v5e (16 GB per chip * 8 chips = 128 GB in total), you can serve models of up to 70B parameters, such as Llama 3.3-70B. Weights are quantized to int8 on the fly as they are loaded. As with any quantization option, you can expect a small drop in model accuracy. Without quantization enabled, the model is served in bf16.
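The 70B example follows from simple arithmetic on the weights: bf16 stores 2 bytes per parameter and int8 stores 1. This sketch only counts weights and ignores the KV cache, activations, and runtime overhead, which also need memory, so treat the result as a lower bound rather than a sizing tool.

```python
# Approximate weight memory only: ignores KV cache, activations and runtime overhead.
BYTES_PER_PARAM = {"bf16": 2, "int8": 1}


def weight_memory_gb(num_params: float, dtype: str) -> float:
    """Return the approximate size of the model weights in GB (10**9 bytes)."""
    return num_params * BYTES_PER_PARAM[dtype] / 1e9


total_hbm_gb = 16 * 8  # 2x4 TPU v5e: 8 chips with 16 GB each

for dtype in ("bf16", "int8"):
    need = weight_memory_gb(70e9, dtype)
    print(f"{dtype}: ~{need:.0f} GB of weights, fits in {total_hbm_gb} GB: {need < total_hbm_gb}")
# bf16: ~140 GB of weights, fits in 128 GB: False
# int8: ~70 GB of weights, fits in 128 GB: True
```

The int8 weights take half the space of the bf16 ones, which is why quantization roughly halves the memory needed for the model itself.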

How to solve memory requirements

If you encounter a Backend(NotEnoughMemory(2048)) error, the following options can help reduce memory usage in TGI:

Optimum-TPU specific arguments:

  • -e QUANTIZATION=1: Enables quantization, which should reduce memory requirements by almost half.
  • -e MAX_BATCH_SIZE=n: Manually reduces the batch size.

TGI specific arguments:

  • --max-input-length: Maximum input sequence length
  • --max-total-tokens: Maximum combined input and output tokens
  • --max-batch-prefill-tokens: Maximum tokens for batch processing
  • --max-batch-total-tokens: Maximum total tokens in a batch

To reduce memory usage, you can try smaller values for --max-input-length, --max-total-tokens, --max-batch-prefill-tokens, and --max-batch-total-tokens.

Make sure that `max-batch-prefill-tokens ≤ max-input-length * max_batch_size`, where max_batch_size is the batch size set with MAX_BATCH_SIZE. Otherwise, you will get an error, because the configuration is inconsistent: if max-batch-prefill-tokens were larger, you would not be able to process any request.
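A hypothetical pre-flight check, not part of TGI or Optimum-TPU, can catch an inconsistent combination before you launch the container. It encodes the rule above and, as an extra sanity check derived from the definition of --max-total-tokens (input and output tokens combined), flags an input limit that leaves no room for generated tokens.

```python
def check_token_limits(max_input_length: int, max_total_tokens: int,
                       max_batch_prefill_tokens: int, max_batch_size: int) -> list[str]:
    """Return a list of problems found in the launch parameters (empty if none).

    Hypothetical helper for illustration; TGI performs its own validation at startup.
    """
    problems = []
    # Rule from this page: max-batch-prefill-tokens <= max-input-length * max_batch_size.
    limit = max_input_length * max_batch_size
    if max_batch_prefill_tokens > limit:
        problems.append(
            f"max-batch-prefill-tokens={max_batch_prefill_tokens} exceeds "
            f"max-input-length * max_batch_size = {limit}"
        )
    # max-total-tokens counts input and output together, so a maximum-length
    # input must still leave room for at least one generated token.
    if max_input_length >= max_total_tokens:
        problems.append(
            f"max-input-length={max_input_length} leaves no room for output "
            f"tokens within max-total-tokens={max_total_tokens}"
        )
    return problems


print(check_token_limits(1024, 2048, 4096, 4))  # [] (4096 <= 1024 * 4)
print(check_token_limits(1024, 2048, 8192, 4))  # one problem: 8192 > 4096
```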

Sharding

Sharding is done automatically by the TGI server, so your model uses all available TPUs. We use tensor parallelism, so the layers are automatically split across all available TPUs. However, the TGI router only sees one shard.

More information on tensor parallelism can be found here: https://huggingface.co/docs/text-generation-inference/conceptual/tensor_parallelism.
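To build intuition for tensor parallelism, this toy example splits a weight matrix column-wise into shards, lets each simulated "device" multiply the input by its own shard, and concatenates the partial outputs. It is plain Python for illustration only and does not reflect how Optimum-TPU actually shards weights internally.

```python
def matmul(x, w):
    """Multiply a row vector x (length k) by a k x n matrix w given as a list of rows."""
    return [sum(x[i] * w[i][j] for i in range(len(x))) for j in range(len(w[0]))]


def split_columns(w, num_shards):
    """Split w column-wise into num_shards contiguous blocks, one per simulated device."""
    n = len(w[0])
    if n % num_shards != 0:
        raise ValueError("columns must divide evenly across shards")
    size = n // num_shards
    return [[row[s * size:(s + 1) * size] for row in w] for s in range(num_shards)]


def sharded_matmul(x, w, num_shards):
    """Each shard computes x @ w_shard; concatenating the pieces reproduces x @ w."""
    out = []
    for shard in split_columns(w, num_shards):
        out.extend(matmul(x, shard))
    return out


x = [1, 2, 3]
w = [[1, 0, 2, 1],
     [0, 1, 1, 3],
     [2, 1, 0, 1]]
print(matmul(x, w))             # [7, 5, 4, 10]  (whole matrix on one device)
print(sharded_matmul(x, w, 2))  # [7, 5, 4, 10]  (same result from 2 shards)
```

Each device only stores and multiplies a slice of the weights, which is what lets a model too large for one TPU chip be served across several.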

Understanding the configuration

Key parameters explained:

Required parameters

These are needed to run a TPU container so that the container can properly access the TPU hardware:

  • --shm-size 16GB: Increases the default shared memory allocation.
  • --privileged: Required for TPU access.
  • --net host: Uses host network mode.

`--privileged --shm-size 16GB --net host` is required, as specified in https://github.com/pytorch/xla

Optional parameters

These are parameters used by TGI and optimum-TPU to configure the server behavior:

  • -v ~/hf_data:/data: Volume mount for model storage, so you do not have to re-download the model weights on each startup. You can use any host folder, as long as it is mapped to /data.
  • -e SKIP_WARMUP=1: Disables warmup for quick testing (not recommended for production).
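Putting the required and optional parameters together, here is a minimal sketch that assembles a full `docker run` command as a Python argument list. Several parts are assumptions: `tgi_docker_command` is a hypothetical helper, the image name is a placeholder (check the Optimum-TPU documentation for the current TGI image), the model ID is only an example, and it assumes the image's entrypoint is the TGI launcher, so arguments after the image name (such as --model-id) are forwarded to it.

```python
import os
import shlex

# Placeholder, not a real tag: replace with the Optimum-TPU TGI image you actually use.
IMAGE = "<optimum-tpu-tgi-image>"


def tgi_docker_command(model_id, data_dir="~/hf_data", quantize=False,
                       skip_warmup=False, tgi_args=None):
    """Return the `docker run` argument list for a TGI server on TPU (hypothetical helper)."""
    cmd = [
        "docker", "run",
        # Required so the container can access the TPU hardware.
        "--shm-size", "16GB",
        "--privileged",
        "--net", "host",
        # Optional: cache model weights on the host. "~" is expanded here because
        # it is not expanded inside a quoted shell argument or a subprocess argv.
        "-v", f"{os.path.expanduser(data_dir)}:/data",
    ]
    if quantize:
        cmd += ["-e", "QUANTIZATION=1"]
    if skip_warmup:
        cmd += ["-e", "SKIP_WARMUP=1"]  # quick testing only, not for production
    # Assumption: everything after the image name is passed to the TGI launcher.
    cmd += [IMAGE, "--model-id", model_id]
    for flag, value in (tgi_args or {}).items():
        cmd += [flag, str(value)]
    return cmd


cmd = tgi_docker_command(
    "meta-llama/Llama-3.3-70B-Instruct",  # example model ID
    quantize=True,
    tgi_args={"--max-input-length": 512, "--max-total-tokens": 1024},
)
print(shlex.join(cmd))  # shell-safe string you can copy into a terminal
```

Keeping the command as a list makes it easy to pass to `subprocess.run(cmd)` directly, while `shlex.join` produces a correctly quoted version for the shell.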

Next steps

Please check the TGI docs for more TGI server configuration options.