adamo1139 committed
Commit a2a9bf4
1 Parent(s): 4e5accf

Update README.md

Files changed (1)
  1. README.md +94 -0
README.md CHANGED
@@ -1,3 +1,97 @@
# Spirit LM Inference Gradio Demo

Clone the GitHub repo, build the `spiritlm` Python package, and place the models in the `checkpoints` folder before running the script. Using a conda environment for this is recommended.
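Save the script below (for example as `app.py`; the filename is arbitrary) and launch it with `python app.py`. Gradio serves the demo at `http://127.0.0.1:7860` by default.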
```python
import gradio as gr
from spiritlm.model.spiritlm_model import Spiritlm, OutputModality, GenerationInput, ContentType
from transformers import GenerationConfig
import torchaudio
import torch
import tempfile
import numpy as np

# Initialize the Spirit LM model
spirit_lm = Spiritlm("spirit-lm-base-7b")

def generate_output(input_type, input_content_text, input_content_audio, output_modality, temperature, top_p, max_new_tokens, do_sample):
    generation_config = GenerationConfig(
        temperature=temperature,
        top_p=top_p,
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
    )

    if input_type == "text":
        interleaved_inputs = [GenerationInput(content=input_content_text, content_type=ContentType.TEXT)]
    elif input_type == "audio":
        # Load the audio file and drop the channel dimension
        waveform, sample_rate = torchaudio.load(input_content_audio)
        interleaved_inputs = [GenerationInput(content=waveform.squeeze(0), content_type=ContentType.SPEECH)]
    else:
        raise ValueError("Invalid input type")

    outputs = spirit_lm.generate(
        interleaved_inputs=interleaved_inputs,
        output_modality=OutputModality[output_modality.upper()],
        generation_config=generation_config,
    )

    text_output = ""
    audio_output = None

    for output in outputs:
        if output.content_type == ContentType.TEXT:
            text_output = output.content
        elif output.content_type == ContentType.SPEECH:
            if not isinstance(output.content, np.ndarray):
                raise TypeError("Expected output.content to be a NumPy array, but got {}".format(type(output.content)))

            # torchaudio.save expects a (channels, samples) tensor, so add a
            # channel dimension to mono output
            if output.content.ndim == 1:
                audio_data = torch.from_numpy(output.content).unsqueeze(0)
            else:
                audio_data = torch.from_numpy(output.content)

            # Save the audio content to a temporary 16 kHz WAV file
            with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_audio_file:
                torchaudio.save(temp_audio_file.name, audio_data, 16000)
                audio_output = temp_audio_file.name

    return text_output, audio_output

# Define the Gradio interface
iface = gr.Interface(
    fn=generate_output,
    inputs=[
        gr.Radio(["text", "audio"], label="Input Type"),
        gr.Textbox(label="Input Content (Text)"),
        gr.Audio(label="Input Content (Audio)", type="filepath"),
        gr.Radio(["TEXT", "SPEECH", "ARBITRARY"], label="Output Modality"),
        gr.Slider(0, 1, step=0.1, value=0.9, label="Temperature"),
        gr.Slider(0, 1, step=0.05, value=0.95, label="Top P"),
        gr.Slider(1, 800, step=1, value=500, label="Max New Tokens"),
        gr.Checkbox(value=True, label="Do Sample"),
    ],
    outputs=[gr.Textbox(label="Generated Text"), gr.Audio(label="Generated Audio")],
    title="Spirit LM WebUI Demo",
    description="Demo for generating text or audio using the Spirit LM model.",
)

# Launch the interface
iface.launch()
```
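
For reference, here is a minimal sketch of driving the same generation API without the Gradio layer, mirroring the calls the demo makes. It assumes the package and checkpoints are set up as above; the prompt text and output filename are placeholders.

```python
import torch
import torchaudio
import numpy as np
from transformers import GenerationConfig
from spiritlm.model.spiritlm_model import Spiritlm, OutputModality, GenerationInput, ContentType

spirit_lm = Spiritlm("spirit-lm-base-7b")

# Ask for a speech continuation of a text prompt (prompt is a placeholder)
outputs = spirit_lm.generate(
    interleaved_inputs=[GenerationInput(content="One day, a little girl", content_type=ContentType.TEXT)],
    output_modality=OutputModality.SPEECH,
    generation_config=GenerationConfig(temperature=0.9, top_p=0.95, max_new_tokens=200, do_sample=True),
)

for output in outputs:
    if output.content_type == ContentType.SPEECH and isinstance(output.content, np.ndarray):
        # The speech decoder returns 16 kHz mono samples as a NumPy array
        torchaudio.save("continuation.wav", torch.from_numpy(output.content).unsqueeze(0), 16000)
```

Requesting `OutputModality.TEXT` instead returns a plain text continuation, matching the "Output Modality" choices exposed in the demo UI.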
# Spirit LM Checkpoints

## Download Checkpoints