# Spirit LM Inference Gradio Demo

Clone the GitHub repo, build the `spiritlm` Python package, and put the models in the `checkpoints` folder before running the script (see the setup sketch below). I would suggest using a conda environment for this. You need around 15.5 GB of VRAM to run the model with a short output length, and around 19 GB to output 800 tokens.

Edit: Audio-to-audio inference doesn't seem great. Potentially I am tokenizing the audio wrong (see the preprocessing sketch below). It could also be that the model doesn't work well with audio in, audio out.
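For reference, a sketch of the setup mentioned above. The clone URL comes from the license link below, while the Python version and the plain `pip install -e .` are my assumptions, so check the repo README for the exact commands:

```bash
git clone https://github.com/facebookresearch/spiritlm
cd spiritlm
conda create -n spiritlm python=3.10 -y   # Python version is an assumption
conda activate spiritlm
pip install -e .                          # install extras may differ; see the repo README
# place the downloaded model files under ./checkpoints (layout shown at the bottom of this page)
```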
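On the audio-to-audio note above: the demo feeds the model the waveform at whatever sample rate `torchaudio.load` returns, while the script saves generated audio at 16 kHz, which suggests the tokenizer also expects 16 kHz mono input (my assumption). A minimal preprocessing sketch to rule out a sample-rate or channel mismatch; `load_audio_16k_mono` is a hypothetical helper, not part of the original script:

```python
import torch
import torchaudio

TARGET_SR = 16_000  # assumed tokenizer sample rate; the demo also saves output at 16 kHz


def load_audio_16k_mono(path: str) -> torch.Tensor:
    """Load an audio file, downmix to mono, and resample to 16 kHz."""
    waveform, sample_rate = torchaudio.load(path)  # (channels, samples)
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)  # downmix stereo to mono
    if sample_rate != TARGET_SR:
        waveform = torchaudio.functional.resample(waveform, sample_rate, TARGET_SR)
    return waveform.squeeze(0)  # 1-D tensor, same shape the demo passes to GenerationInput
```

In the demo below, `waveform.squeeze(0)` could then be replaced with `load_audio_16k_mono(input_content_audio)`.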
The demo script:

```python
import gradio as gr
from spiritlm.model.spiritlm_model import Spiritlm, OutputModality, GenerationInput, ContentType
from transformers import GenerationConfig
import torchaudio
import torch
import tempfile
import numpy as np

# Initialize the Spirit LM model
spirit_lm = Spiritlm("spirit-lm-base-7b")


def generate_output(input_type, input_content_text, input_content_audio,
                    output_modality, temperature, top_p, max_new_tokens, do_sample):
    generation_config = GenerationConfig(
        temperature=temperature,
        top_p=top_p,
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
    )

    if input_type == "text":
        interleaved_inputs = [GenerationInput(content=input_content_text, content_type=ContentType.TEXT)]
    elif input_type == "audio":
        # Load the audio file; it is passed to the model at whatever sample rate
        # it was stored at (see the preprocessing sketch above)
        waveform, sample_rate = torchaudio.load(input_content_audio)
        interleaved_inputs = [GenerationInput(content=waveform.squeeze(0), content_type=ContentType.SPEECH)]
    else:
        raise ValueError("Invalid input type")

    outputs = spirit_lm.generate(
        interleaved_inputs=interleaved_inputs,
        output_modality=OutputModality[output_modality.upper()],
        generation_config=generation_config,
    )

    text_output = ""
    audio_output = None

    for output in outputs:
        if output.content_type == ContentType.TEXT:
            text_output = output.content
        elif output.content_type == ContentType.SPEECH:
            # Ensure output.content is a NumPy array
            if isinstance(output.content, np.ndarray):
                # Debugging: print shape and dtype of the audio data
                print("Audio data shape:", output.content.shape)
                print("Audio data dtype:", output.content.dtype)

                # torchaudio.save expects a (channels, samples) tensor
                if len(output.content.shape) == 1:
                    # Mono audio data
                    audio_data = torch.from_numpy(output.content).unsqueeze(0)
                else:
                    # Multi-channel audio data
                    audio_data = torch.from_numpy(output.content)

                # Save the audio content to a temporary file
                with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_audio_file:
                    torchaudio.save(temp_audio_file.name, audio_data, 16000)
                    audio_output = temp_audio_file.name
            else:
                raise TypeError("Expected output.content to be a NumPy array, but got {}".format(type(output.content)))

    return text_output, audio_output


# Define the Gradio interface
iface = gr.Interface(
    fn=generate_output,
    inputs=[
        gr.Radio(["text", "audio"], label="Input Type"),
        gr.Textbox(label="Input Content (Text)"),
        gr.Audio(label="Input Content (Audio)", type="filepath"),
        gr.Radio(["TEXT", "SPEECH", "ARBITRARY"], label="Output Modality"),
        gr.Slider(0, 1, step=0.1, value=0.9, label="Temperature"),
        gr.Slider(0, 1, step=0.05, value=0.95, label="Top P"),
        gr.Slider(1, 800, step=1, value=500, label="Max New Tokens"),
        gr.Checkbox(value=True, label="Do Sample"),
    ],
    outputs=[gr.Textbox(label="Generated Text"), gr.Audio(label="Generated Audio")],
    title="Spirit LM WebUI Demo",
    description="Demo for generating text or audio using the Spirit LM model.",
)

# Launch the interface
iface.launch()
```

# Spirit LM Checkpoints

## Download Checkpoints

Checkpoints are in this repo. Please note that Spirit LM is made available under the **FAIR Noncommercial Research License**; the license is here: https://github.com/facebookresearch/spiritlm/blob/main/LICENSE

## Structure

The checkpoints directory should look like this:

```
checkpoints/
├── README.md
├── speech_tokenizer
│   ├── hifigan_spiritlm_base
│   │   ├── config.json
│   │   ├── generator.pt
│   │   ├── speakers.txt
│   │   └── styles.txt
│   ├── hifigan_spiritlm_expressive_w2v2
│   │   ├── config.json
│   │   ├── generator.pt
│   │   └── speakers.txt
│   ├── hubert_25hz
│   │   ├── L11_quantizer_500.pt
│   │   └── mhubert_base_25hz.pt
│   ├── style_encoder_w2v2
│   │   ├── config.json
│   │   └── pytorch_model.bin
│   └── vqvae_f0_quantizer
│       ├── config.yaml
│       └── model.pt
└── spiritlm_model
    ├── spirit-lm-base-7b
    │   ├── config.json
    │   ├── generation_config.json
    │   ├── pytorch_model.bin
    │   ├── special_tokens_map.json
    │   ├── tokenizer_config.json
    │   └── tokenizer.model
    └── spirit-lm-expressive-7b
        ├── config.json
        ├── generation_config.json
        ├── pytorch_model.bin
        ├── special_tokens_map.json
        ├── tokenizer_config.json
        └── tokenizer.model
```
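To pull the files into this layout programmatically, a hedged sketch using `huggingface_hub`, assuming the hosting repo mirrors the tree above; the repo id is a placeholder because the text above doesn't spell it out:

```python
from huggingface_hub import snapshot_download

# REPO_ID is a placeholder: substitute the id of the repo hosting these checkpoints.
REPO_ID = "<this-repo-id>"
snapshot_download(repo_id=REPO_ID, local_dir="checkpoints")
```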