adamo1139's picture
Update README.md
dbd39df verified
|
raw
history blame
5.61 kB

Spirit LM Inference Gradio Demo

Clone the GitHub repo, build the spiritlm Python package, and put the models in the checkpoints folder before running the script. I suggest using a conda environment for this.

You need around 15.5 GB of VRAM to run the model with a short output length, and around 19 GB to generate 800 tokens.

Edit: Audio-to-audio inference doesn't seem to work well. I may be tokenizing the audio incorrectly, or the model may simply not perform well with audio in and audio out.

import gradio as gr
from spiritlm.model.spiritlm_model import Spiritlm, OutputModality, GenerationInput, ContentType
from transformers import GenerationConfig
import torchaudio
import torch
import tempfile
import os
import numpy as np

# Load the Spirit LM model once at import time; generate_output reuses this
# single module-level instance for every request. The name presumably maps to
# the checkpoints/spiritlm_model/spirit-lm-base-7b folder described below.
_MODEL_NAME = "spirit-lm-base-7b"
spirit_lm = Spiritlm(_MODEL_NAME)

def _load_speech_input(audio_path, target_sample_rate=16000):
    """Load an audio file as a 1-D mono waveform at ``target_sample_rate`` Hz.

    Args:
        audio_path: Path to the audio file (Gradio ``type="filepath"`` input).
        target_sample_rate: Sample rate the waveform is resampled to.

    Returns:
        A 1-D float tensor of samples.
    """
    waveform, sample_rate = torchaudio.load(audio_path)
    # torchaudio.load returns (channels, frames); squeezing dim 0 alone would
    # leave stereo input 2-D, so average all channels down to mono.
    waveform = waveform.mean(dim=0)
    if sample_rate != target_sample_rate:
        # Uploads and microphone recordings are often 44.1/48 kHz. The speech
        # tokenizer is assumed to expect 16 kHz (the rate this demo also uses
        # for generated speech) — TODO confirm against spiritlm.
        waveform = torchaudio.functional.resample(waveform, sample_rate, target_sample_rate)
    return waveform


def _write_wav(samples, sample_rate):
    """Write ``samples`` to a new temporary ``.wav`` file and return its path.

    A 1-D array is written as a single channel. Any other array is passed to
    ``torchaudio.save`` as-is (NOTE(review): assumed channels-first — confirm).
    The file is not deleted here because its path is handed to Gradio.
    """
    audio = torch.from_numpy(np.ascontiguousarray(samples, dtype=np.float32))
    if audio.dim() == 1:
        audio = audio.unsqueeze(0)
    # Close the handle before writing: on Windows a file still held open by
    # NamedTemporaryFile cannot be reopened by torchaudio.save.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as handle:
        path = handle.name
    torchaudio.save(path, audio, sample_rate)
    return path


def generate_output(input_type, input_content_text, input_content_audio, output_modality, temperature, top_p, max_new_tokens, do_sample):
    """Run Spirit LM on one text or audio prompt and return its outputs.

    Args:
        input_type: ``"text"`` or ``"audio"``; selects which input is used.
        input_content_text: Prompt text, used when ``input_type == "text"``.
        input_content_audio: Path to a prompt audio file, used when
            ``input_type == "audio"``.
        output_modality: Name of an ``OutputModality`` member, case-insensitive
            (the UI offers ``"TEXT"``, ``"SPEECH"``, ``"ARBITRARY"``).
        temperature: Sampling temperature. A value <= 0 falls back to greedy
            decoding, because sampling needs a strictly positive temperature.
        top_p: Nucleus-sampling probability mass; only used when sampling.
        max_new_tokens: Maximum number of tokens to generate.
        do_sample: Whether to sample instead of decoding greedily.

    Returns:
        ``(text, audio_path)``. ``text`` holds all generated text segments
        joined by newlines (``""`` if none). ``audio_path`` is the path of a
        temporary 16 kHz ``.wav`` file holding all generated speech segments in
        order (``None`` if none).

    Raises:
        ValueError: Unknown input type, missing audio input, or a missing or
            unknown output modality.
        TypeError: A speech output whose content is not a NumPy array.
    """
    sample_rate_hz = 16000

    # Gradio sliders may deliver ints (e.g. top_p == 1), and transformers'
    # sampling warpers may reject non-float values, so coerce explicitly.
    temperature = float(temperature)
    do_sample = bool(do_sample) and temperature > 0
    if do_sample:
        generation_config = GenerationConfig(
            temperature=temperature,
            top_p=float(top_p),
            max_new_tokens=int(max_new_tokens),
            do_sample=True,
        )
    else:
        # temperature/top_p only affect sampling; leaving them out avoids
        # "flag is ignored" warnings during greedy decoding.
        generation_config = GenerationConfig(
            max_new_tokens=int(max_new_tokens),
            do_sample=False,
        )

    if input_type == "text":
        interleaved_inputs = [GenerationInput(content=input_content_text, content_type=ContentType.TEXT)]
    elif input_type == "audio":
        if not input_content_audio:
            raise ValueError("No input audio provided")
        waveform = _load_speech_input(input_content_audio, sample_rate_hz)
        interleaved_inputs = [GenerationInput(content=waveform, content_type=ContentType.SPEECH)]
    else:
        raise ValueError("Invalid input type")

    if not output_modality:
        raise ValueError("No output modality selected")
    try:
        modality = OutputModality[output_modality.upper()]
    except KeyError:
        raise ValueError(f"Invalid output modality: {output_modality}") from None

    outputs = spirit_lm.generate(
        interleaved_inputs=interleaved_inputs,
        output_modality=modality,
        generation_config=generation_config,
    )

    # ARBITRARY output can interleave several text and speech segments, so
    # collect all of them rather than keeping only the last of each kind.
    text_segments = []
    speech_segments = []
    for output in outputs:
        if output.content_type == ContentType.TEXT:
            text_segments.append(output.content)
        elif output.content_type == ContentType.SPEECH:
            if not isinstance(output.content, np.ndarray):
                raise TypeError("Expected output.content to be a NumPy array, but got {}".format(type(output.content)))
            speech_segments.append(output.content)

    text_output = "\n".join(text_segments)
    audio_output = None
    if speech_segments:
        # One file for all speech, so no segment is dropped or orphaned.
        audio_output = _write_wav(np.concatenate(speech_segments, axis=-1), sample_rate_hz)

    return text_output, audio_output

# Define the Gradio interface. Gradio passes the input components' values to
# generate_output in list order, so this order must match its parameters.
iface = gr.Interface(
    fn=generate_output,
    inputs=[
        # Pre-select both radios so an untouched form never sends None for the
        # input type or the output modality.
        gr.Radio(["text", "audio"], value="text", label="Input Type"),
        gr.Textbox(label="Input Content (Text)"),
        gr.Audio(label="Input Content (Audio)", type="filepath"),
        gr.Radio(["TEXT", "SPEECH", "ARBITRARY"], value="TEXT", label="Output Modality"),
        gr.Slider(0, 1, step=0.1, value=0.9, label="Temperature"),
        gr.Slider(0, 1, step=0.05, value=0.95, label="Top P"),
        gr.Slider(1, 800, step=1, value=500, label="Max New Tokens"),
        gr.Checkbox(value=True, label="Do Sample"),
    ],
    outputs=[gr.Textbox(label="Generated Text"), gr.Audio(label="Generated Audio")],
    title="Spirit LM WebUI Demo",
    description="Demo for generating text or audio using the Spirit LM model.",
)

# Launch the interface
iface.launch()

Spirit LM Checkpoints

Download Checkpoints

The checkpoints are in this repo.

Please note that Spirit LM is made available under the FAIR Noncommercial Research License.

The license is available here: https://github.com/facebookresearch/spiritlm/blob/main/LICENSE

Structure

The checkpoints directory should look like this:

checkpoints/
├── README.md
├── speech_tokenizer
│   ├── hifigan_spiritlm_base
│   │   ├── config.json
│   │   ├── generator.pt
│   │   ├── speakers.txt
│   │   └── styles.txt
│   ├── hifigan_spiritlm_expressive_w2v2
│   │   ├── config.json
│   │   ├── generator.pt
│   │   └── speakers.txt
│   ├── hubert_25hz
│   │   ├── L11_quantizer_500.pt
│   │   └── mhubert_base_25hz.pt
│   ├── style_encoder_w2v2
│   │   ├── config.json
│   │   └── pytorch_model.bin
│   └── vqvae_f0_quantizer
│       ├── config.yaml
│       └── model.pt
└── spiritlm_model
    ├── spirit-lm-base-7b
    │   ├── config.json
    │   ├── generation_config.json
    │   ├── pytorch_model.bin
    │   ├── special_tokens_map.json
    │   ├── tokenizer_config.json
    │   └── tokenizer.model
    └── spirit-lm-expressive-7b
        ├── config.json
        ├── generation_config.json
        ├── pytorch_model.bin
        ├── special_tokens_map.json
        ├── tokenizer_config.json
        └── tokenizer.model