import gradio as gr
import torch
import torchaudio
from transformers import AutoTokenizer, AutoModelForCausalLM
from speechtokenizer import SpeechTokenizer
from audiotools import AudioSignal
import bitsandbytes as bnb # Import bitsandbytes for INT8 quantization
import numpy as np
from uuid import uuid4
# Load the necessary models and tokenizers
model_path = "Vikhrmodels/llama_asr_tts_24000"
tokenizer = AutoTokenizer.from_pretrained(model_path, cache_dir=".")
# Special tokens
start_audio_token = "<soa>"
end_audio_token = "<eoa>"
end_sequence_token = "<eos>"
# Constants
n_codebooks = 3
max_seq_length = 1024
top_k = 20

def convert_to_16_bit_wav(data):
    # Based on: https://docs.scipy.org/doc/scipy/reference/generated/scipy.io.wavfile.write.html
    if data.dtype == np.float32:
        # Normalize to [-1, 1], then scale to the int16 range
        data = data / np.abs(data).max()
        data = data * 32767
        data = data.astype(np.int16)
    elif data.dtype == np.int32:
        # Downscale from the int32 range to the int16 range
        data = data / 65538
        data = data.astype(np.int16)
    elif data.dtype == np.int16:
        pass
    elif data.dtype == np.uint8:
        # Map [0, 255] to [-32768, 32767]
        data = data * 257 - 32768
        data = data.astype(np.int16)
    else:
        raise ValueError("Audio data cannot be converted to 16-bit int format.")
    return data
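
# Quick sanity check (commented out so it does not run at import time; a
# minimal sketch): a float32 signal in [-1, 1] should convert to int16
# peaking near +/-32767. The sine wave below is illustrative only.
# t = np.linspace(0, 1, 24000, dtype=np.float32)
# wav_int16 = convert_to_16_bit_wav(np.sin(2 * np.pi * 440.0 * t))
# assert wav_int16.dtype == np.int16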

# Load the model with INT8 quantization
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    cache_dir=".",
    load_in_8bit=True,  # Enable loading in INT8
    device_map="auto",  # Automatically map model to available devices
)
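
# Note: newer transformers versions deprecate the bare `load_in_8bit` kwarg in
# favor of an explicit quantization config. An equivalent sketch (assuming a
# transformers version that ships BitsAndBytesConfig, with bitsandbytes installed):
# from transformers import BitsAndBytesConfig
# model = AutoModelForCausalLM.from_pretrained(
#     model_path,
#     cache_dir=".",
#     quantization_config=BitsAndBytesConfig(load_in_8bit=True),
#     device_map="auto",
# )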
# Configurations for Speech Tokenizer
config_path = "audiotokenizer/speechtokenizer_hubert_avg_config.json"
ckpt_path = "audiotokenizer/SpeechTokenizer.pt"
quantizer = SpeechTokenizer.load_from_checkpoint(config_path, ckpt_path)
quantizer.eval()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Move every quantizer submodule to the device and freeze its parameters
def freeze_entire_model(model):
    for n, p in model.named_parameters():
        p.requires_grad = False
    return model


for n, child in quantizer.named_children():
    child.to(device)
    freeze_entire_model(child)

# Build padding tokens for audio by encoding a minimal silent signal
def get_audio_padding_tokens(quantizer):
    audio = torch.zeros((1, 1, 1)).to(device)
    codes = quantizer.encode(audio)
    del audio
    torch.cuda.empty_cache()
    return {"audio_tokens": codes.squeeze(1)}
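
# Example (commented out; a sketch assuming quantizer.encode returns codes
# shaped (n_codebooks, batch, time) — the exact shape depends on SpeechTokenizer):
# pad = get_audio_padding_tokens(quantizer)["audio_tokens"]  # one code per codebook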

# Decode audio from model tokens
def decode_audio(tokens, quantizer, pad_tokens, n_original_tokens):
    # Find the audio span between the <soa> and <eoa> markers
    start = torch.nonzero(tokens == tokenizer(start_audio_token)["input_ids"][-1])
    end = torch.nonzero(tokens == tokenizer(end_audio_token)["input_ids"][-1])
    start = start[0, -1] + 1 if len(start) else 0
    end = end[0, -1] if len(end) else tokens.shape[-1]

    # Map token ids back into the quantizer's codebook range
    audio_tokens = tokens[start:end] % n_original_tokens

    # Pad so the length is a multiple of n_codebooks
    remainder = audio_tokens.shape[-1] % n_codebooks
    if remainder:
        audio_tokens = torch.cat([audio_tokens, pad_tokens[remainder:n_codebooks]], dim=0)

    # Regroup interleaved tokens into (n_codebooks, 1, T) codes and decode
    transposed = audio_tokens.view(-1, n_codebooks).t()
    codes = transposed.view(n_codebooks, 1, -1).to(device)
    audio = quantizer.decode(codes).squeeze(0)
    torch.cuda.empty_cache()

    # Write the waveform to a uniquely named .wav file and return its path
    xp = str(uuid4()) + ".wav"
    AudioSignal(audio.detach().cpu().numpy(), quantizer.sample_rate).write(xp)
    return xp
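
# Layout note: the reshape above implies the model emits audio tokens
# time-major and interleaved across codebooks, i.e. [t0_q0, t0_q1, t0_q2,
# t1_q0, ...]; view(-1, n_codebooks).t() regroups them into one row per
# codebook before decoding.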

# Inference: text in, audio out
def infer_text_to_audio(text, model, tokenizer, quantizer, max_seq_length=1024, top_k=20):
    # Tokenize the prompt and append the start-of-audio token
    text_tokenized = tokenizer(text, return_tensors="pt")
    text_input_tokens = text_tokenized["input_ids"].to(device)
    soa = tokenizer(start_audio_token, return_tensors="pt")["input_ids"][:, -1:].to(device)
    text_tokens = torch.cat([text_input_tokens, soa], dim=1)
    attention_mask = torch.ones(text_tokens.size(), device=device)

    # Sample audio tokens from the language model
    output_audio_tokens = model.generate(
        text_tokens,
        attention_mask=attention_mask,
        max_new_tokens=max_seq_length,
        top_k=top_k,
        do_sample=True,
    )

    # Decode the generated audio tokens into a .wav file
    padding_tokens = get_audio_padding_tokens(quantizer)["audio_tokens"].to(device)
    audio_signal = decode_audio(output_audio_tokens[0], quantizer, padding_tokens.t()[0], len(tokenizer) - 1024)
    return audio_signal
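
# Example (commented out; "HELLO WORLD" is an arbitrary prompt):
# wav_path = infer_text_to_audio("HELLO WORLD", model, tokenizer, quantizer)
# print(wav_path)  # path to the generated .wav file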

# Inference: audio in, text out
def infer_audio_to_text(audio_path, model, tokenizer, quantizer, max_seq_length=1024, top_k=20):
    # Load the audio and encode it with the speech tokenizer
    audio_data, sample_rate = torchaudio.load(audio_path)
    audio = audio_data.view(1, 1, -1).float().to(device)
    codes = quantizer.encode(audio)

    # Keep one codebook and shift its codes into the model's audio-token id range
    n_codebooks_a = 1
    raw_audio_tokens = codes[:, :n_codebooks_a] + len(tokenizer) - 1024

    # Wrap the audio tokens in <soa>/<eoa> markers
    soa = tokenizer(start_audio_token, return_tensors="pt")["input_ids"][:, -1:].to(device)
    eoa = tokenizer(end_audio_token, return_tensors="pt")["input_ids"][:, -1:].to(device)
    audio_tokens = torch.cat([soa, raw_audio_tokens.view(1, -1), eoa], dim=1)
    attention_mask = torch.ones(audio_tokens.size(), device=device)

    # Sample text tokens, then drop anything in the audio-token id range
    output_text_tokens = model.generate(
        audio_tokens,
        attention_mask=attention_mask,
        max_new_tokens=max_seq_length,
        top_k=top_k,
        do_sample=True,
    )
    output_text_tokens = output_text_tokens.cpu()[0]
    output_text_tokens = output_text_tokens[output_text_tokens < tokenizer(start_audio_token)["input_ids"][-1]]
    decoded_text = tokenizer.decode(output_text_tokens, skip_special_tokens=True)
    return decoded_text
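
# Example (commented out; "sample.wav" is a placeholder path):
# transcript = infer_audio_to_text("sample.wav", model, tokenizer, quantizer)
# print(transcript)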

# Gradio wrapper functions
def infer_text_to_audio_gr(text):
    audio_signal = infer_text_to_audio(text.strip().upper(), model, tokenizer, quantizer)
    return audio_signal


def infer_audio_to_text_gr(audio_path):
    generated_text = infer_audio_to_text(audio_path, model, tokenizer, quantizer)
    return generated_text

# Gradio Interface
text_to_audio_interface = gr.Interface(
    fn=infer_text_to_audio_gr,
    inputs=gr.Textbox(label="Input Text"),
    outputs=gr.Audio(label="Audio Response"),
    title="T2S",
    description="Model in audio-response mode",
    allow_flagging="never",
)

audio_to_text_interface = gr.Interface(
    fn=infer_audio_to_text_gr,
    inputs=gr.Audio(type="filepath", label="Input Audio"),
    outputs=gr.Textbox(label="Text Response"),
    title="S2T",
    description="Model in text-response mode",
    allow_flagging="never",
)

# Launch Gradio App
demo = gr.TabbedInterface([text_to_audio_interface, audio_to_text_interface], ["Text - Audio", "Audio - Text"])
demo.launch(share=True)