---
license: other
pipeline_tag: text-to-audio
library_name: transformers
---
# Spirit LM Inference Gradio Demo

Clone the GitHub repo, build the [spiritlm](https://github.com/facebookresearch/spiritlm) Python package, and put the models in the `checkpoints` folder before running the script. I suggest using a conda environment for this.

You need around 15.5 GB of VRAM to run the model with an output length of 200 tokens, and around 19 GB to output 800 tokens.

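As a rough aid for picking an output length, the two figures above can be interpolated. This is only a sketch: it assumes VRAM use grows linearly with the number of output tokens between the two measured points, which has not been verified, and `estimate_vram_gb` is a hypothetical helper name.

```python
# The two measurements quoted in this README: (output tokens, VRAM in GB).
MEASURED = [(200, 15.5), (800, 19.0)]


def estimate_vram_gb(max_new_tokens: int) -> float:
    """Linearly interpolate (or extrapolate) VRAM use from the two measured points.

    Assumption: memory grows linearly with output length. Treat the result
    as a ballpark figure only.
    """
    (t0, v0), (t1, v1) = MEASURED
    gb_per_token = (v1 - v0) / (t1 - t0)
    return v0 + (max_new_tokens - t0) * gb_per_token


if __name__ == "__main__":
    for tokens in (200, 500, 800):
        print(f"{tokens} tokens: ~{estimate_vram_gb(tokens):.2f} GB")
```

Actual usage depends on the GPU, driver, and library versions, so leave some headroom.
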
If you're concerned about pickled checkpoint files from an unknown uploader, grab them from a repo maintained by an HF staffer instead: [https://huggingface.co/spirit-lm/Meta-spirit-lm](https://huggingface.co/spirit-lm/Meta-spirit-lm)

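One possible way to fetch that repo is `snapshot_download` from the `huggingface_hub` package. This is a minimal sketch, not a tested recipe: `download_checkpoints` is a hypothetical helper, the repo may require logging in and accepting its license first, and its file layout may not match the `checkpoints` structure this script expects, so files may need to be moved afterwards.

```python
from pathlib import Path


def download_checkpoints(local_dir: str = "checkpoints") -> Path:
    """Download the Spirit LM files from the Hugging Face Hub into `local_dir`."""
    # Imported lazily so this module can be imported without huggingface_hub installed.
    from huggingface_hub import snapshot_download

    path = snapshot_download(
        repo_id="spirit-lm/Meta-spirit-lm",  # repo linked above
        local_dir=local_dir,  # assumption: download straight into the checkpoints folder
    )
    return Path(path)


if __name__ == "__main__":
    print("Downloaded to", download_checkpoints())
```
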
Audio-to-audio inference doesn't seem good at all. I may be tokenizing the audio incorrectly, or the model may simply not work well with audio in and audio out.

The script here works with a single speaker only. If you know how to get other speakers, let me know and I'll update it.

To install the requirements for the sample Gradio demo, run the commands below. `tempfile` is part of the Python standard library, so it does not need to be installed with pip.

```
pip install torch torchaudio --index-url https://download.pytorch.org/whl/cu121
pip install gradio transformers numpy
```

Remember that you also need to install the [spiritlm](https://github.com/facebookresearch/spiritlm) module.

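A quick way to confirm that everything is importable before launching the demo is to look each module up with `importlib`. This is a small sketch; `missing_modules` is a hypothetical helper, and the module list simply mirrors the imports used by the script in this README.

```python
import importlib.util

# Top-level modules imported by the demo script in this README.
REQUIRED_MODULES = ["gradio", "transformers", "torch", "torchaudio", "numpy", "spiritlm"]


def missing_modules(names):
    """Return the names that cannot be found in the current environment."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_modules(REQUIRED_MODULES)
    if missing:
        print("Missing modules:", ", ".join(missing))
    else:
        print("All required modules are installed.")
```
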
```python
import tempfile

import gradio as gr
import numpy as np
import torch
import torchaudio
from spiritlm.model.spiritlm_model import (
    ContentType,
    GenerationInput,
    OutputModality,
    Spiritlm,
)
from transformers import GenerationConfig

# Initialize the Spirit LM model with the modified class
spirit_lm = Spiritlm("spirit-lm-base-7b")


def generate_output(
    input_type,
    input_content_text,
    input_content_audio,
    output_modality,
    temperature,
    top_p,
    max_new_tokens,
    do_sample,
    speaker_id,
):
    """Run one generation request; return (generated text, path to generated audio)."""
    generation_config = GenerationConfig(
        temperature=temperature,
        top_p=top_p,
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
    )

    if input_type == "text":
        interleaved_inputs = [
            GenerationInput(content=input_content_text, content_type=ContentType.TEXT)
        ]
    elif input_type == "audio":
        if input_content_audio is None:
            raise gr.Error("Please provide an audio file.")
        # Load the audio file. sample_rate is not used: the waveform is passed on as loaded.
        waveform, sample_rate = torchaudio.load(input_content_audio)
        interleaved_inputs = [
            GenerationInput(content=waveform.squeeze(0), content_type=ContentType.SPEECH)
        ]
    else:
        raise ValueError("Invalid input type")

    outputs = spirit_lm.generate(
        interleaved_inputs=interleaved_inputs,
        output_modality=OutputModality[output_modality.upper()],
        generation_config=generation_config,
        speaker_id=speaker_id,  # Pass the selected speaker ID
    )

    text_output = ""
    audio_output = None

    for output in outputs:
        if output.content_type == ContentType.TEXT:
            text_output = output.content
        elif output.content_type == ContentType.SPEECH:
            # The generated audio is expected to be a NumPy array
            if not isinstance(output.content, np.ndarray):
                raise TypeError(
                    "Expected output.content to be a NumPy array, but got {}".format(
                        type(output.content)
                    )
                )

            # Debugging: print shape and dtype of the audio data
            print("Audio data shape:", output.content.shape)
            print("Audio data dtype:", output.content.dtype)

            # torchaudio.save expects a 2-D (channels, samples) tensor
            if output.content.ndim == 1:
                # Mono audio data: add a channel dimension
                audio_data = torch.from_numpy(output.content).unsqueeze(0)
            else:
                # Multi-channel audio data: use as-is
                audio_data = torch.from_numpy(output.content)

            # Reserve a temporary .wav path, then write the audio to it at 16 kHz
            with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_audio_file:
                audio_output = temp_audio_file.name
            torchaudio.save(audio_output, audio_data, 16000)

    return text_output, audio_output


# Define the Gradio interface
iface = gr.Interface(
    fn=generate_output,
    inputs=[
        gr.Radio(["text", "audio"], label="Input Type", value="text"),
        gr.Textbox(label="Input Content (Text)"),
        gr.Audio(label="Input Content (Audio)", type="filepath"),
        gr.Radio(["TEXT", "SPEECH", "ARBITRARY"], label="Output Modality", value="SPEECH"),
        # Minimum is 0.1: a temperature of 0 is rejected when sampling is enabled
        gr.Slider(0.1, 1, step=0.1, value=0.9, label="Temperature"),
        gr.Slider(0, 1, step=0.05, value=0.95, label="Top P"),
        gr.Slider(1, 800, step=1, value=500, label="Max New Tokens"),
        gr.Checkbox(value=True, label="Do Sample"),
        gr.Dropdown(choices=[0, 1, 2, 3], value=0, label="Speaker ID"),
    ],
    outputs=[gr.Textbox(label="Generated Text"), gr.Audio(label="Generated Audio")],
    title="Spirit LM WebUI Demo",
    description="Demo for generating text or audio using the Spirit LM model.",
    flagging_mode="never",
)

# Launch the interface
iface.launch()
```

# Spirit LM Checkpoints

## Download Checkpoints
The checkpoints are in this repo.

Please note that Spirit LM is made available under the **FAIR Noncommercial Research License**.

The license is available at https://github.com/facebookresearch/spiritlm/blob/main/LICENSE

## Structure
The checkpoints directory should look like this:
```
checkpoints/
├── README.md
├── speech_tokenizer
│   ├── hifigan_spiritlm_base
│   │   ├── config.json
│   │   ├── generator.pt
│   │   ├── speakers.txt
│   │   └── styles.txt
│   ├── hifigan_spiritlm_expressive_w2v2
│   │   ├── config.json
│   │   ├── generator.pt
│   │   └── speakers.txt
│   ├── hubert_25hz
│   │   ├── L11_quantizer_500.pt
│   │   └── mhubert_base_25hz.pt
│   ├── style_encoder_w2v2
│   │   ├── config.json
│   │   └── pytorch_model.bin
│   └── vqvae_f0_quantizer
│       ├── config.yaml
│       └── model.pt
└── spiritlm_model
    ├── spirit-lm-base-7b
    │   ├── config.json
    │   ├── generation_config.json
    │   ├── pytorch_model.bin
    │   ├── special_tokens_map.json
    │   ├── tokenizer_config.json
    │   └── tokenizer.model
    └── spirit-lm-expressive-7b
        ├── config.json
        ├── generation_config.json
        ├── pytorch_model.bin
        ├── special_tokens_map.json
        ├── tokenizer_config.json
        └── tokenizer.model
```
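
To check a local `checkpoints` folder against the layout above, a small script can list whatever is missing. This is a sketch using only the standard library. The file names are copied from the tree, while `expected_paths` and `missing_checkpoint_files` are hypothetical helper names. It checks both models, although the demo script only loads `spirit-lm-base-7b`.

```python
from pathlib import Path

# File names copied from the directory tree in this README.
SPEECH_TOKENIZER_FILES = {
    "hifigan_spiritlm_base": ["config.json", "generator.pt", "speakers.txt", "styles.txt"],
    "hifigan_spiritlm_expressive_w2v2": ["config.json", "generator.pt", "speakers.txt"],
    "hubert_25hz": ["L11_quantizer_500.pt", "mhubert_base_25hz.pt"],
    "style_encoder_w2v2": ["config.json", "pytorch_model.bin"],
    "vqvae_f0_quantizer": ["config.yaml", "model.pt"],
}
MODEL_NAMES = ["spirit-lm-base-7b", "spirit-lm-expressive-7b"]
MODEL_FILES = [
    "config.json",
    "generation_config.json",
    "pytorch_model.bin",
    "special_tokens_map.json",
    "tokenizer_config.json",
    "tokenizer.model",
]


def expected_paths(root="checkpoints"):
    """Build the list of file paths expected under the checkpoints root."""
    root = Path(root)
    paths = []
    for subdir, files in SPEECH_TOKENIZER_FILES.items():
        paths += [root / "speech_tokenizer" / subdir / name for name in files]
    for model in MODEL_NAMES:
        paths += [root / "spiritlm_model" / model / name for name in MODEL_FILES]
    return paths


def missing_checkpoint_files(root="checkpoints"):
    """Return the expected files that are not present under `root`."""
    return [path for path in expected_paths(root) if not path.is_file()]


if __name__ == "__main__":
    missing = missing_checkpoint_files()
    if missing:
        print("Missing files:")
        for path in missing:
            print("  ", path)
    else:
        print("All expected checkpoint files are present.")
```
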