Update README.md
Browse files
README.md
CHANGED
@@ -1,3 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
# Spirit LM Checkpoints
|
2 |
|
3 |
## Download Checkpoints
|
|
|
1 |
+
# Spirit LM Inference Gradio Demo
|
2 |
+
|
3 |
+
Copy the github repo, build the spiritlm python package and put models in `checkpoints` folder before running the script. I would suggest to use conda environment for this.
|
4 |
+
|
5 |
+
```python
|
6 |
+
import gradio as gr
|
7 |
+
from spiritlm.model.spiritlm_model import Spiritlm, OutputModality, GenerationInput, ContentType
|
8 |
+
from transformers import GenerationConfig
|
9 |
+
import torchaudio
|
10 |
+
import torch
|
11 |
+
import tempfile
|
12 |
+
import os
|
13 |
+
import numpy as np
|
14 |
+
|
15 |
+
# Initialize the Spirit LM model with the modified class
|
16 |
+
spirit_lm = Spiritlm("spirit-lm-base-7b")
|
17 |
+
|
18 |
+
def generate_output(input_type, input_content_text, input_content_audio, output_modality, temperature, top_p, max_new_tokens, do_sample):
|
19 |
+
generation_config = GenerationConfig(
|
20 |
+
temperature=temperature,
|
21 |
+
top_p=top_p,
|
22 |
+
max_new_tokens=max_new_tokens,
|
23 |
+
do_sample=do_sample,
|
24 |
+
)
|
25 |
+
|
26 |
+
if input_type == "text":
|
27 |
+
interleaved_inputs = [GenerationInput(content=input_content_text, content_type=ContentType.TEXT)]
|
28 |
+
elif input_type == "audio":
|
29 |
+
# Load audio file
|
30 |
+
waveform, sample_rate = torchaudio.load(input_content_audio)
|
31 |
+
interleaved_inputs = [GenerationInput(content=waveform.squeeze(0), content_type=ContentType.SPEECH)]
|
32 |
+
else:
|
33 |
+
raise ValueError("Invalid input type")
|
34 |
+
|
35 |
+
outputs = spirit_lm.generate(
|
36 |
+
interleaved_inputs=interleaved_inputs,
|
37 |
+
output_modality=OutputModality[output_modality.upper()],
|
38 |
+
generation_config=generation_config,
|
39 |
+
)
|
40 |
+
|
41 |
+
text_output = ""
|
42 |
+
audio_output = None
|
43 |
+
|
44 |
+
for output in outputs:
|
45 |
+
if output.content_type == ContentType.TEXT:
|
46 |
+
text_output = output.content
|
47 |
+
elif output.content_type == ContentType.SPEECH:
|
48 |
+
# Ensure output.content is a NumPy array
|
49 |
+
if isinstance(output.content, np.ndarray):
|
50 |
+
# Debugging: Print shape and dtype of the audio data
|
51 |
+
print("Audio data shape:", output.content.shape)
|
52 |
+
print("Audio data dtype:", output.content.dtype)
|
53 |
+
|
54 |
+
# Ensure the audio data is in the correct format
|
55 |
+
if len(output.content.shape) == 1:
|
56 |
+
# Mono audio data
|
57 |
+
audio_data = torch.from_numpy(output.content).unsqueeze(0)
|
58 |
+
else:
|
59 |
+
# Stereo audio data
|
60 |
+
audio_data = torch.from_numpy(output.content)
|
61 |
+
|
62 |
+
# Save the audio content to a temporary file
|
63 |
+
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_audio_file:
|
64 |
+
torchaudio.save(temp_audio_file.name, audio_data, 16000)
|
65 |
+
audio_output = temp_audio_file.name
|
66 |
+
else:
|
67 |
+
raise TypeError("Expected output.content to be a NumPy array, but got {}".format(type(output.content)))
|
68 |
+
|
69 |
+
return text_output, audio_output
|
70 |
+
|
71 |
+
# Define the Gradio interface
|
72 |
+
iface = gr.Interface(
|
73 |
+
fn=generate_output,
|
74 |
+
inputs=[
|
75 |
+
gr.Radio(["text", "audio"], label="Input Type"),
|
76 |
+
gr.Textbox(label="Input Content (Text)"),
|
77 |
+
gr.Audio(label="Input Content (Audio)", type="filepath"),
|
78 |
+
gr.Radio(["TEXT", "SPEECH", "ARBITRARY"], label="Output Modality"),
|
79 |
+
gr.Slider(0, 1, step=0.1, value=0.9, label="Temperature"),
|
80 |
+
gr.Slider(0, 1, step=0.05, value=0.95, label="Top P"),
|
81 |
+
gr.Slider(1, 800, step=1, value=500, label="Max New Tokens"),
|
82 |
+
gr.Checkbox(value=True, label="Do Sample"),
|
83 |
+
],
|
84 |
+
outputs=[gr.Textbox(label="Generated Text"), gr.Audio(label="Generated Audio")],
|
85 |
+
title="Spirit LM WebUI Demo",
|
86 |
+
description="Demo for generating text or audio using the Spirit LM model.",
|
87 |
+
)
|
88 |
+
|
89 |
+
# Launch the interface
|
90 |
+
iface.launch()
|
91 |
+
|
92 |
+
```
|
93 |
+
|
94 |
+
|
95 |
# Spirit LM Checkpoints
|
96 |
|
97 |
## Download Checkpoints
|