Spirit LM Inference Gradio Demo
Clone the GitHub repo, build the spiritlm Python package, and place the models in the checkpoints folder before running the script. I would suggest using a conda environment for this.
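Before launching the demo, it can help to confirm the environment is actually in place. The snippet below is a minimal, hypothetical preflight check, not part of the demo itself; it only assumes the spiritlm package was built as described above:

import importlib.util
import torch

# Check that the spiritlm package built above is importable in this environment
assert importlib.util.find_spec("spiritlm") is not None, \
    "spiritlm is not installed; build the package from the cloned repo first"

# A 7B model is impractical on CPU, so check for a GPU
print("CUDA available:", torch.cuda.is_available())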
import gradio as gr
from spiritlm.model.spiritlm_model import Spiritlm, OutputModality, GenerationInput, ContentType
from transformers import GenerationConfig
import torchaudio
import torch
import tempfile
import os
import numpy as np
# Initialize the Spirit LM model with the modified class
spirit_lm = Spiritlm("spirit-lm-base-7b")
def generate_output(input_type, input_content_text, input_content_audio, output_modality, temperature, top_p, max_new_tokens, do_sample):
    generation_config = GenerationConfig(
        temperature=temperature,
        top_p=top_p,
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
    )

    if input_type == "text":
        interleaved_inputs = [GenerationInput(content=input_content_text, content_type=ContentType.TEXT)]
    elif input_type == "audio":
        # Load the audio file and resample to 16 kHz, the rate the model expects
        waveform, sample_rate = torchaudio.load(input_content_audio)
        if sample_rate != 16000:
            waveform = torchaudio.functional.resample(waveform, sample_rate, 16000)
        interleaved_inputs = [GenerationInput(content=waveform.squeeze(0), content_type=ContentType.SPEECH)]
    else:
        raise ValueError("Invalid input type")

    outputs = spirit_lm.generate(
        interleaved_inputs=interleaved_inputs,
        output_modality=OutputModality[output_modality.upper()],
        generation_config=generation_config,
    )

    text_output = ""
    audio_output = None
    for output in outputs:
        if output.content_type == ContentType.TEXT:
            text_output = output.content
        elif output.content_type == ContentType.SPEECH:
            # Ensure output.content is a NumPy array
            if not isinstance(output.content, np.ndarray):
                raise TypeError(f"Expected output.content to be a NumPy array, but got {type(output.content)}")
            # Debugging: print shape and dtype of the audio data
            print("Audio data shape:", output.content.shape)
            print("Audio data dtype:", output.content.dtype)
            # torchaudio.save expects a 2D (channels, samples) tensor
            if output.content.ndim == 1:
                # Mono audio data
                audio_data = torch.from_numpy(output.content).unsqueeze(0)
            else:
                # Multi-channel audio data
                audio_data = torch.from_numpy(output.content)
            # Save the audio content to a temporary file (float32 for safe WAV encoding)
            with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_audio_file:
                torchaudio.save(temp_audio_file.name, audio_data.float(), 16000)
                audio_output = temp_audio_file.name

    return text_output, audio_output
# Define the Gradio interface
iface = gr.Interface(
    fn=generate_output,
    inputs=[
        gr.Radio(["text", "audio"], label="Input Type"),
        gr.Textbox(label="Input Content (Text)"),
        gr.Audio(label="Input Content (Audio)", type="filepath"),
        gr.Radio(["TEXT", "SPEECH", "ARBITRARY"], label="Output Modality"),
        gr.Slider(0, 1, step=0.1, value=0.9, label="Temperature"),
        gr.Slider(0, 1, step=0.05, value=0.95, label="Top P"),
        gr.Slider(1, 800, step=1, value=500, label="Max New Tokens"),
        gr.Checkbox(value=True, label="Do Sample"),
    ],
    outputs=[gr.Textbox(label="Generated Text"), gr.Audio(label="Generated Audio")],
    title="Spirit LM WebUI Demo",
    description="Demo for generating text or audio using the Spirit LM model.",
)
# Launch the interface
iface.launch()
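You can also smoke-test the handler directly, without the web UI. This is just a sketch: the prompt and sampling values are arbitrary, and it must run in the same process as the script above, before iface.launch() (which blocks) is called:

# Call the handler directly, bypassing Gradio
text, audio_path = generate_output(
    input_type="text",
    input_content_text="The weather today is",
    input_content_audio=None,
    output_modality="SPEECH",
    temperature=0.9,
    top_p=0.95,
    max_new_tokens=200,
    do_sample=True,
)
print("Text:", text)
print("Audio written to:", audio_path)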
Spirit LM Checkpoints
Download Checkpoints
The checkpoints are available in this repo. Please note that Spirit LM is made available under the FAIR Noncommercial Research License: https://github.com/facebookresearch/spiritlm/blob/main/LICENSE
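If the checkpoints are hosted on the Hugging Face Hub, one way to fetch them into the expected folder is huggingface_hub; the repo id below is a placeholder, so substitute the actual one:

from huggingface_hub import snapshot_download

snapshot_download(
    repo_id="<checkpoints-repo-id>",  # placeholder: use the actual checkpoints repo id
    local_dir="checkpoints",
)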
Structure
The checkpoints directory should look like this:
checkpoints/
├── README.md
├── speech_tokenizer
│   ├── hifigan_spiritlm_base
│   │   ├── config.json
│   │   ├── generator.pt
│   │   ├── speakers.txt
│   │   └── styles.txt
│   ├── hifigan_spiritlm_expressive_w2v2
│   │   ├── config.json
│   │   ├── generator.pt
│   │   └── speakers.txt
│   ├── hubert_25hz
│   │   ├── L11_quantizer_500.pt
│   │   └── mhubert_base_25hz.pt
│   ├── style_encoder_w2v2
│   │   ├── config.json
│   │   └── pytorch_model.bin
│   └── vqvae_f0_quantizer
│       ├── config.yaml
│       └── model.pt
└── spiritlm_model
    ├── spirit-lm-base-7b
    │   ├── config.json
    │   ├── generation_config.json
    │   ├── pytorch_model.bin
    │   ├── special_tokens_map.json
    │   ├── tokenizer_config.json
    │   └── tokenizer.model
    └── spirit-lm-expressive-7b
        ├── config.json
        ├── generation_config.json
        ├── pytorch_model.bin
        ├── special_tokens_map.json
        ├── tokenizer_config.json
        └── tokenizer.model
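As a sanity check before loading the model, a short script can spot-check this layout. A minimal sketch, assuming it runs from the directory containing checkpoints/ (the file list below checks one representative file per component):

from pathlib import Path

expected = [
    "speech_tokenizer/hifigan_spiritlm_base/generator.pt",
    "speech_tokenizer/hubert_25hz/mhubert_base_25hz.pt",
    "speech_tokenizer/style_encoder_w2v2/pytorch_model.bin",
    "speech_tokenizer/vqvae_f0_quantizer/model.pt",
    "spiritlm_model/spirit-lm-base-7b/pytorch_model.bin",
]
root = Path("checkpoints")
missing = [p for p in expected if not (root / p).exists()]
print("All expected files present" if not missing else f"Missing: {missing}")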