import gradio as gr
import torch
import torchaudio
import numpy as np
import bitsandbytes as bnb  # optional: only needed if INT8 loading is enabled below
from uuid import uuid4

from transformers import AutoTokenizer, AutoModelForCausalLM
from speechtokenizer import SpeechTokenizer
from audiotools import AudioSignal
from safetensors.torch import load_file
# Load the tokenizer for the language model
model_path = "Vikhrmodels/salt-116k"
tokenizer = AutoTokenizer.from_pretrained(model_path)
# Special tokens marking the audio span in a sequence
start_audio_token = "<soa>"
end_audio_token = "<eoa>"
end_sequence_token = "<eos>"

# Constants
n_codebooks = 3        # number of SpeechTokenizer codebooks used when decoding audio
max_seq_length = 1024  # maximum number of new tokens to generate
top_k = 20             # top-k sampling parameter
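# Note: the code below assumes that the last 1024 ids of the tokenizer vocabulary are
# reserved for audio tokens, which is why audio codes are offset by len(tokenizer) - 1024
# before generation and mapped back modulo that offset when decoding.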
def convert_to_16_bit_wav(data):
    # Convert an audio buffer of various dtypes to 16-bit signed integer PCM.
    if data.dtype == np.float32:
        data = data / np.abs(data).max()  # normalize to [-1, 1]
        data = data * 32767
        data = data.astype(np.int16)
    elif data.dtype == np.int32:
        data = data / 65538  # scale the int32 range down to the int16 range
        data = data.astype(np.int16)
    elif data.dtype == np.int16:
        pass
    elif data.dtype == np.uint8:
        data = data * 257 - 32768  # map [0, 255] to [-32768, 32767]
        data = data.astype(np.int16)
    else:
        raise ValueError("Audio data cannot be converted to 16-bit int format.")
    return data
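# Illustrative (hypothetical) example of the conversion above, shown only to document
# the expected scaling; it is not executed as part of the app:
#
#   convert_to_16_bit_wav(np.array([0.0, 0.5, -1.0], dtype=np.float32))
#   # -> roughly [0, 16383, -32767] as np.int16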
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load the language model (INT8 quantization is currently disabled)
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    cache_dir=".",
    load_in_8bit=False,  # set to True to load the weights in INT8 via bitsandbytes
    device_map="auto",   # automatically map the model to available devices
)
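# If INT8 loading is desired, newer transformers releases prefer passing a
# BitsAndBytesConfig instead of the bare `load_in_8bit` flag. A minimal sketch,
# assuming a transformers version that ships BitsAndBytesConfig (kept commented
# out so the default load above stays in effect):
#
#   from transformers import BitsAndBytesConfig
#   model = AutoModelForCausalLM.from_pretrained(
#       model_path,
#       cache_dir=".",
#       quantization_config=BitsAndBytesConfig(load_in_8bit=True),
#       device_map="auto",
#   )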
# Configurations for Speech Tokenizer
config_path = "audiotokenizer/speechtokenizer_hubert_avg_config.json"
ckpt_path = "audiotokenizer/SpeechTokenizer.pt"
quantizer = SpeechTokenizer.load_from_checkpoint(config_path, ckpt_path)
quantizer.eval()
# Freeze layers in the quantizer
def freeze_entire_model(model):
    for n, p in model.named_parameters():
        p.requires_grad = False
    return model

for n, child in quantizer.named_children():
    child.to(device)
    child = freeze_entire_model(child)
# Create padding tokens for audio
def get_audio_padding_tokens(quantizer):
    audio = torch.zeros((1, 1, 1)).to(device)
    codes = quantizer.encode(audio)
    del audio
    torch.cuda.empty_cache()
    return {"audio_tokens": codes.squeeze(1)}
# Decode audio from generated tokens
def decode_audio(tokens, quantizer, pad_tokens, n_original_tokens):
    # Locate the audio span between <soa> and <eoa>
    start = torch.nonzero(tokens == tokenizer(start_audio_token)["input_ids"][-1])
    end = torch.nonzero(tokens == tokenizer(end_audio_token)["input_ids"][-1])
    start = start[0, -1] + 1 if len(start) else 0
    end = end[0, -1] if len(end) else tokens.shape[-1]

    # Map token ids back to codebook indices
    audio_tokens = tokens[start:end] % n_original_tokens
    remainder = audio_tokens.shape[-1] % n_codebooks
    if remainder:
        # Pad the last frame so the length is a multiple of n_codebooks
        audio_tokens = torch.cat([audio_tokens, pad_tokens[remainder:n_codebooks]], dim=0)

    # Regroup the flat, frame-interleaved stream into shape (n_codebooks, 1, n_frames)
    transposed = audio_tokens.view(-1, n_codebooks).t()
    codes = transposed.view(n_codebooks, 1, -1).to(device)

    audio = quantizer.decode(codes).squeeze(0)
    torch.cuda.empty_cache()

    # Write the waveform to a uniquely named .wav file and return its path
    xp = str(uuid4()) + ".wav"
    AudioSignal(audio.detach().cpu().numpy(), quantizer.sample_rate).write(xp)
    return xp
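# For reference, the generated sequence that decode_audio expects looks roughly like
# (an illustrative layout inferred from the reshaping above, not a format guarantee):
#
#   <text tokens> <soa> f1_cb1 f1_cb2 f1_cb3 f2_cb1 f2_cb2 f2_cb3 ... <eoa>
#
# i.e. each audio frame contributes one token per codebook, interleaved frame by frame.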
# Inference functions
def infer_text_to_audio(text):
    max_seq_length = 1024
    top_k = 20

    text_tokenized = tokenizer(str(text), return_tensors="pt")
    text_input_tokens = text_tokenized["input_ids"].to(device)

    soa = tokenizer(start_audio_token, return_tensors="pt")["input_ids"][:, -1:].to(device)
    eoa = tokenizer(end_audio_token, return_tensors="pt")["input_ids"][:, -1:].to(device)

    # Prompt = text tokens followed by <soa>; the model then generates audio tokens
    text_tokens = torch.cat([text_input_tokens, soa], dim=1)
    attention_mask = torch.ones(text_tokens.size(), device=device)

    output_audio_tokens = model.generate(text_tokens, attention_mask=attention_mask, max_new_tokens=max_seq_length, top_k=top_k, do_sample=True)

    padding_tokens = get_audio_padding_tokens(quantizer)["audio_tokens"].to(device)
    audio_signal = decode_audio(output_audio_tokens[0], quantizer, padding_tokens.t()[0], len(tokenizer) - 1024)
    return audio_signal
def infer_audio_to_text(audio_path):
    max_seq_length = 1024
    top_k = 20

    audio_data, sample_rate = torchaudio.load(audio_path)
    audio = audio_data.view(1, 1, -1).float().to(device)
    codes = quantizer.encode(audio)

    # Use only the first codebook and shift its indices into the audio-token id range
    n_codebooks_a = 1
    raw_audio_tokens = codes[:, :n_codebooks_a] + len(tokenizer) - 1024

    soa = tokenizer(start_audio_token, return_tensors="pt")["input_ids"][:, -1:].to(device)
    eoa = tokenizer(end_audio_token, return_tensors="pt")["input_ids"][:, -1:].to(device)

    # Prompt = <soa> audio tokens <eoa>; the model then generates text tokens
    audio_tokens = torch.cat([soa, raw_audio_tokens.view(1, -1), eoa], dim=1)
    attention_mask = torch.ones(audio_tokens.size(), device=device)

    output_text_tokens = model.generate(audio_tokens, attention_mask=attention_mask, max_new_tokens=max_seq_length, top_k=top_k, do_sample=True)

    # Keep only ids below the <soa> id, i.e. drop generated audio and special tokens
    output_text_tokens = output_text_tokens.cpu()[0]
    output_text_tokens = output_text_tokens[output_text_tokens < tokenizer(start_audio_token)["input_ids"][-1]]
    decoded_text = tokenizer.decode(output_text_tokens, skip_special_tokens=True)
    return decoded_text
# Functions for the Gradio Interface
def infer_text_to_audio_gr(text):
    audio_signal = infer_text_to_audio(text.strip().upper())
    return audio_signal

def infer_audio_to_text_gr(audio_path):
    generated_text = infer_audio_to_text(audio_path)
    return generated_text
# Gradio Interface
text_to_audio_interface = gr.Interface(
    fn=infer_text_to_audio_gr,
    inputs=gr.Textbox(label="Input Text"),
    outputs=gr.Audio(label="Audio Answer"),
    title="T2S",
    description="Model in text-to-audio mode",
    allow_flagging="never",
)

audio_to_text_interface = gr.Interface(
    fn=infer_audio_to_text_gr,
    inputs=gr.Audio(type="filepath", label="Input Audio"),
    outputs=gr.Textbox(label="Text Answer"),
    title="S2T",
    description="Model in audio-to-text mode",
    allow_flagging="never",
)
# Gradio Demo
#demo = gr.TabbedInterface([text_to_audio_interface, audio_to_text_interface], ["Text - Audio", "Audio - Text"])
# Custom CSS for centered links
custom_css = """
<style>
.center {
text-align: center;
}
</style>
"""
# Add Gradio description with centered links
description = f"""
# **Salt: Speech And Language Transformer**
Welcome to the demo of **Salt**, a speech and language model. Vikhr Salt handles both **Text-to-Speech (T2S)** and **Speech-to-Text (S2T)** tasks, making it a versatile tool for converting text into speech and speech into text. Built on a pre-trained large language model, Vikhr Salt incorporates audio tokens using techniques such as **Encodec** and **SpeechTokenizer**, enabling robust performance across both modalities.
## **🛠 Features**
- **Text-to-Speech (T2S)**: Enter text and generate high-quality audio outputs.
- **Speech-to-Text (S2T)**: Upload an audio file and convert it into accurate text.
## **🚀 Try it out:**
Explore the tabs to try the **Text - Audio** and **Audio - Text** modes!
### **📄 Preprint**
[Read the paper](https://docs.google.com/document/d/1ZvV47W4BCyZM_JfDC1BKj-0ozwPck5t2yNB8jORVshI/edit?usp=sharing)
### **📂 Code**
[Explore the code](https://github.com/VikhrModels/Vikhr4o)
"""
with gr.Blocks() as demo:
    gr.Markdown(description)
    with gr.Tabs():
        with gr.TabItem("Text - Audio"):
            gr.Markdown("### Text-to-Speech (T2S) Mode")
            input_text = gr.Textbox(label="Input Text")
            output_audio = gr.Audio(label="Audio Answer")
            generate_button = gr.Button("Generate")
            generate_button.click(infer_text_to_audio, inputs=input_text, outputs=output_audio)
        with gr.TabItem("Audio - Text"):
            gr.Markdown("### Speech-to-Text (S2T) Mode")
            input_audio = gr.Audio(type="filepath", label="Input Audio")
            output_text = gr.Textbox(label="Text Answer")
            generate_button = gr.Button("Generate")
            generate_button.click(infer_audio_to_text, inputs=input_audio, outputs=output_text)

# Launch the demo
demo.launch(share=True)