# Spirit LM Inference Gradio Demo
Clone the GitHub repo, install the spiritlm Python package, and put the models in the `checkpoints` folder before running the script. I suggest using a conda environment for this.
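Something along these lines should work (a rough sketch; exact Python/package versions and the upstream repo's install steps may differ):

```
conda create -n spiritlm python=3.10
conda activate spiritlm
git clone https://github.com/facebookresearch/spiritlm.git
cd spiritlm
pip install -e .
# download the model files from this repo into ./checkpoints (layout shown below)
```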
```python
import gradio as gr
from spiritlm.model.spiritlm_model import Spiritlm, OutputModality, GenerationInput, ContentType
from transformers import GenerationConfig
import torchaudio
import torch
import tempfile
import os
import numpy as np
# Initialize the Spirit LM model
spirit_lm = Spiritlm("spirit-lm-base-7b")


def generate_output(input_type, input_content_text, input_content_audio, output_modality, temperature, top_p, max_new_tokens, do_sample):
    generation_config = GenerationConfig(
        temperature=temperature,
        top_p=top_p,
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
    )

    # Build the interleaved inputs from whichever component the user filled in
    if input_type == "text":
        interleaved_inputs = [GenerationInput(content=input_content_text, content_type=ContentType.TEXT)]
    elif input_type == "audio":
        # Load the audio file from the path provided by the Gradio component
        waveform, sample_rate = torchaudio.load(input_content_audio)
        interleaved_inputs = [GenerationInput(content=waveform.squeeze(0), content_type=ContentType.SPEECH)]
    else:
        raise ValueError("Invalid input type")

    outputs = spirit_lm.generate(
        interleaved_inputs=interleaved_inputs,
        output_modality=OutputModality[output_modality.upper()],
        generation_config=generation_config,
    )

    text_output = ""
    audio_output = None
    for output in outputs:
        if output.content_type == ContentType.TEXT:
            text_output = output.content
        elif output.content_type == ContentType.SPEECH:
            # Ensure output.content is a NumPy array
            if isinstance(output.content, np.ndarray):
                # Debugging: print shape and dtype of the audio data
                print("Audio data shape:", output.content.shape)
                print("Audio data dtype:", output.content.dtype)
                # torchaudio.save expects a 2D (channels, samples) tensor
                if output.content.ndim == 1:
                    # Mono audio data
                    audio_data = torch.from_numpy(output.content).unsqueeze(0)
                else:
                    # Multi-channel audio data
                    audio_data = torch.from_numpy(output.content)
                # Save the audio content to a temporary 16 kHz WAV file
                with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_audio_file:
                    torchaudio.save(temp_audio_file.name, audio_data, 16000)
                    audio_output = temp_audio_file.name
            else:
                raise TypeError("Expected output.content to be a NumPy array, but got {}".format(type(output.content)))
    return text_output, audio_output


# Define the Gradio interface
iface = gr.Interface(
    fn=generate_output,
    inputs=[
        gr.Radio(["text", "audio"], label="Input Type"),
        gr.Textbox(label="Input Content (Text)"),
        gr.Audio(label="Input Content (Audio)", type="filepath"),
        gr.Radio(["TEXT", "SPEECH", "ARBITRARY"], label="Output Modality"),
        gr.Slider(0, 1, step=0.1, value=0.9, label="Temperature"),
        gr.Slider(0, 1, step=0.05, value=0.95, label="Top P"),
        gr.Slider(1, 800, step=1, value=500, label="Max New Tokens"),
        gr.Checkbox(value=True, label="Do Sample"),
    ],
    outputs=[gr.Textbox(label="Generated Text"), gr.Audio(label="Generated Audio")],
    title="Spirit LM WebUI Demo",
    description="Demo for generating text or audio using the Spirit LM model.",
)

# Launch the interface
iface.launch()
```
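If you just want to sanity-check the model without the web UI, the same generation call can be run directly. A minimal sketch using only the calls from the script above (the prompt text and parameter values are just examples):

```python
from spiritlm.model.spiritlm_model import Spiritlm, OutputModality, GenerationInput, ContentType
from transformers import GenerationConfig

spirit_lm = Spiritlm("spirit-lm-base-7b")

# Text-in, text-out generation with the same parameters the UI exposes
outputs = spirit_lm.generate(
    interleaved_inputs=[GenerationInput(content="The weather today is", content_type=ContentType.TEXT)],
    output_modality=OutputModality.TEXT,
    generation_config=GenerationConfig(temperature=0.9, top_p=0.95, max_new_tokens=100, do_sample=True),
)
for output in outputs:
    if output.content_type == ContentType.TEXT:
        print(output.content)
```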
# Spirit LM Checkpoints
## Download Checkpoints
The checkpoints are included in this repository.
Please note that Spirit LM is made available under the **FAIR Noncommercial Research License**; the license text is at https://github.com/facebookresearch/spiritlm/blob/main/LICENSE
## Structure
The checkpoints directory should look like this:
```
checkpoints/
├── README.md
├── speech_tokenizer
│   ├── hifigan_spiritlm_base
│   │   ├── config.json
│   │   ├── generator.pt
│   │   ├── speakers.txt
│   │   └── styles.txt
│   ├── hifigan_spiritlm_expressive_w2v2
│   │   ├── config.json
│   │   ├── generator.pt
│   │   └── speakers.txt
│   ├── hubert_25hz
│   │   ├── L11_quantizer_500.pt
│   │   └── mhubert_base_25hz.pt
│   ├── style_encoder_w2v2
│   │   ├── config.json
│   │   └── pytorch_model.bin
│   └── vqvae_f0_quantizer
│       ├── config.yaml
│       └── model.pt
└── spiritlm_model
    ├── spirit-lm-base-7b
    │   ├── config.json
    │   ├── generation_config.json
    │   ├── pytorch_model.bin
    │   ├── special_tokens_map.json
    │   ├── tokenizer_config.json
    │   └── tokenizer.model
    └── spirit-lm-expressive-7b
        ├── config.json
        ├── generation_config.json
        ├── pytorch_model.bin
        ├── special_tokens_map.json
        ├── tokenizer_config.json
        └── tokenizer.model
```
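A quick way to confirm the layout matches before launching the demo (a hypothetical helper script, not part of the spiritlm package; it only spot-checks a few of the files from the tree above):

```python
from pathlib import Path

# Hypothetical sanity check: verify a few of the expected checkpoint files exist.
EXPECTED = [
    "speech_tokenizer/hifigan_spiritlm_base/generator.pt",
    "speech_tokenizer/hubert_25hz/mhubert_base_25hz.pt",
    "speech_tokenizer/style_encoder_w2v2/pytorch_model.bin",
    "speech_tokenizer/vqvae_f0_quantizer/model.pt",
    "spiritlm_model/spirit-lm-base-7b/pytorch_model.bin",
]

root = Path("checkpoints")
missing = [p for p in EXPECTED if not (root / p).exists()]
print("All expected files found." if not missing else f"Missing: {missing}")
```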