diff --git "a/modeling_minicpmo.py" "b/modeling_minicpmo.py" --- "a/modeling_minicpmo.py" +++ "b/modeling_minicpmo.py" @@ -1,91 +1,40 @@ -# coding=utf-8 -# Copyright 2025 The OpenBMB Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +from typing import Dict, List, Optional, Tuple, Union, Literal +from dataclasses import dataclass import json -import logging import math -import os -import types -from collections.abc import Iterator -from copy import deepcopy -from dataclasses import dataclass -from threading import Thread -from typing import List -from typing import Literal -from typing import Optional -from typing import Tuple -from typing import Union +import logging import numpy as np +from tqdm import tqdm +from threading import Thread +from PIL import Image import soundfile as sf +from copy import deepcopy import torch import torch.nn as nn import torch.nn.functional as F import torch.nn.utils.parametrize as P -from huggingface_hub import hf_hub_download -from PIL import Image from torch.nn.utils.parametrizations import weight_norm -from tqdm import tqdm -from transformers import AutoProcessor -from transformers import BertTokenizerFast -from transformers import LlamaConfig -from transformers import LlamaModel -from transformers import LogitsWarper -from transformers import PreTrainedModel -from transformers import Qwen2ForCausalLM -from transformers import Qwen2PreTrainedModel -from transformers import TextIteratorStreamer -from transformers import TopKLogitsWarper -from transformers import TopPLogitsWarper -from transformers.cache_utils import Cache -from transformers.cache_utils import DynamicCache -from transformers.cache_utils import EncoderDecoderCache -from transformers.cache_utils import StaticCache -from transformers.modeling_outputs import BaseModelOutputWithPast -from transformers.modeling_outputs import ModelOutput -from transformers.models.whisper.modeling_whisper import ACT2FN -from transformers.models.whisper.modeling_whisper import WHISPER_ATTENTION_CLASSES -from transformers.models.whisper.modeling_whisper import WhisperConfig -from transformers.models.whisper.modeling_whisper import WhisperEncoder - -try: - from vector_quantize_pytorch import GroupedResidualFSQ - from vocos import Vocos - from vocos.pretrained import instantiate_class - - _tts_deps = True -except: - _tts_deps = False - -from .configuration_minicpm import ConditionalChatTTSConfig -from .configuration_minicpm import MiniCPMOConfig +from torch.nn.utils.rnn import pad_sequence +from vector_quantize_pytorch import GroupedResidualFSQ +from vocos import Vocos +from vocos.pretrained import instantiate_class + +from transformers import AutoProcessor, TextIteratorStreamer, PreTrainedModel, LogitsWarper, BertTokenizerFast, \ + TopPLogitsWarper, TopKLogitsWarper, Qwen2PreTrainedModel, Qwen2ForCausalLM +from transformers.modeling_outputs import ModelOutput, BaseModelOutputWithPast +from transformers.models.whisper.modeling_whisper import WhisperEncoder, WhisperConfig, 
WHISPER_ATTENTION_CLASSES, ACT2FN +from transformers.cache_utils import EncoderDecoderCache, DynamicCache +from transformers import LlamaConfig, LlamaModel + +from .configuration_minicpm import MiniCPMOConfig, ConditionalChatTTSConfig from .modeling_navit_siglip import SiglipVisionTransformer from .resampler import Resampler -from .utils import NumberToTextConverter -from .utils import sentence_end -from .utils import VoiceChecker - -logger = logging.getLogger(__name__) -@dataclass -class OmniOutput(ModelOutput): - text: Optional[Union[str, List[str], Iterator]] = None - spk_embeds: Optional[torch.FloatTensor] = None - audio_wav: Optional[np.ndarray] = None - sampling_rate: Optional[int] = None +logger = logging.getLogger(__name__) +padding_logged = False class MiniCPMOPreTrainedModel(Qwen2PreTrainedModel): @@ -96,109 +45,73 @@ class MiniCPMO(MiniCPMOPreTrainedModel): def __init__(self, config): super().__init__(config) self.llm = Qwen2ForCausalLM(config) - self.llm.prepare_inputs_for_generation = types.MethodType(prepare_inputs_for_generation, self.llm) # patch llm - + self.vpm = self.init_vision_module() + self.apm = self.init_audio_module() + self.tts = self.init_tts_module() + self.vision_dim = self.vpm.embed_dim self.embed_dim = self.llm.config.hidden_size + self.resampler = self.init_resampler(self.embed_dim, self.vision_dim) + + audio_output_dim = int(self.apm.config.encoder_ffn_dim // 4) + embed_dim = self.llm.config.hidden_size - # init vision module - if self.config.init_vision: - self.vpm = self.init_vision_module() - self.vision_dim = self.vpm.embed_dim - self.resampler = self.init_resampler(self.embed_dim, self.vision_dim) - - # init audio module - if self.config.init_audio: - self.apm = self.init_audio_module() - audio_output_dim = int(self.apm.config.encoder_ffn_dim // 4) - self.audio_avg_pooler = nn.AvgPool1d(self.config.audio_pool_step, stride=self.config.audio_pool_step) - self.audio_projection_layer = MultiModalProjector(in_dim=audio_output_dim, out_dim=self.embed_dim) - self.audio_encoder_layer = -1 - - # init tts module - if self.config.init_tts: - assert _tts_deps, "please make sure vector_quantize_pytorch and vocos are installed." - self.tts = self.init_tts_module() + self.audio_avg_pooler = nn.AvgPool1d(self.config.audio_pool_step, stride=self.config.audio_pool_step) + self.audio_projection_layer = MultiModalProjector( + in_dim=audio_output_dim, + out_dim=embed_dim + ) + self.audio_encoder_layer = -1 self.processor = AutoProcessor.from_pretrained(self.config._name_or_path, trust_remote_code=True) - self.terminators = ["<|im_end|>", "<|endoftext|>"] + self.terminators = ['<|im_end|>', ''] self.default_tts_chat_template = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n<|spk_bos|><|spk|><|spk_eos|><|tts_bos|>' }}{% endif %}" - self.force_no_stop = False - - # for stream api - self.reset_session() - - def reset_session(self): - self.session_id = None - self.new_user_msg = True - self.llm_generated = False - self.llm_generate_completed = False - self.llm_past_key_values = None - self.audio_past_key_values = None # apm kv cache - def init_tts( - self, - tts_text_tokenizer_path=None, - vocos_ckpt_path=None, - ): - """ - load tts tokenizer and vocos - 1. try load form local 2. 
try load from huggingface - """ + # todo: merge in omni processor + tts_text_tokenizer = BertTokenizerFast.from_pretrained("/mnt/data/user/tc_agi/xubokai/ChatTTS/asset/tokenizer") from .processing_minicpmo import ChatTTSProcessor - - if tts_text_tokenizer_path is None: - tts_text_tokenizer_path = os.path.join(self.config._name_or_path, "assets/chattts_tokenizer") - if not os.path.exists(tts_text_tokenizer_path): - # try from hf model_id - tts_text_tokenizer_path = "openbmb/chattts_tokenizer" - - tts_text_tokenizer = BertTokenizerFast.from_pretrained(tts_text_tokenizer_path) self.tts_processor = ChatTTSProcessor(text_tokenizer=tts_text_tokenizer) - if vocos_ckpt_path is None: - vocos_ckpt_path = os.path.join(self.config._name_or_path, "assets/Vocos.pt") - if not os.path.exists(vocos_ckpt_path): - vocos_ckpt_path = hf_hub_download(repo_id="openbmb/MiniCPM-o-2_6", subfolder="assets", filename="Vocos.pt") + # todo: merge to omni model + self.vocos = None - assert os.path.exists(vocos_ckpt_path) - self.vocos = self.initialize_vocos(vocos_ckpt_path) + self.streaming_text_chunk_size = 11 + self.force_no_stop=False + self._generate = self.generate - def initialize_vocos(self, ckpt_path): + def initialize_vocos(self): feature_extractor = instantiate_class( - args=(), - init={ - "class_path": "vocos.feature_extractors.MelSpectrogramFeatures", - "init_args": {"sample_rate": 24000, "n_fft": 1024, "hop_length": 256, "n_mels": 100}, - }, + args=(), init={'class_path': 'vocos.feature_extractors.MelSpectrogramFeatures', + 'init_args': {'sample_rate': 24000, 'n_fft': 1024, 'hop_length': 256, 'n_mels': 100}} ) backbone = instantiate_class( - args=(), - init={ - "class_path": "vocos.models.VocosBackbone", - "init_args": {"input_channels": 100, "dim": 512, "intermediate_dim": 1536, "num_layers": 8}, - }, + args=(), init={'class_path': 'vocos.models.VocosBackbone', + 'init_args': {'input_channels': 100, 'dim': 512, 'intermediate_dim': 1536, + 'num_layers': 8}} ) head = instantiate_class( - args=(), - init={"class_path": "vocos.heads.ISTFTHead", "init_args": {"dim": 512, "n_fft": 1024, "hop_length": 256}}, + args=(), init={'class_path': 'vocos.heads.ISTFTHead', + 'init_args': {'dim': 512, 'n_fft': 1024, 'hop_length': 256}} ) vocos = Vocos(feature_extractor, backbone, head).to("cuda").eval().to(torch.float32) - vocos.load_state_dict(torch.load(ckpt_path, weights_only=True, mmap=True)) + vocos.load_state_dict( + torch.load('/mnt/data/user/tc_agi/xubokai/ChatTTS/asset/Vocos.pt', weights_only=True, mmap=True)) return vocos def init_vision_module(self): - if self.config._attn_implementation == "flash_attention_2": - self.config.vision_config._attn_implementation = "flash_attention_2" + # same as HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit add tgt_sizes + if self.config._attn_implementation == 'flash_attention_2': + self.config.vision_config._attn_implementation = 'flash_attention_2' else: - self.config.vision_config._attn_implementation = "eager" + # not suport sdpa + self.config.vision_config._attn_implementation = 'eager' model = SiglipVisionTransformer(self.config.vision_config) if self.config.drop_vision_last_layer: model.encoder.layers = model.encoder.layers[:-1] - setattr(model, "embed_dim", model.embeddings.embed_dim) - setattr(model, "patch_size", model.embeddings.patch_size) + setattr(model, 'embed_dim', model.embeddings.embed_dim) + setattr(model, 'patch_size', model.embeddings.patch_size) return model @@ -208,13 +121,15 @@ class MiniCPMO(MiniCPMOPreTrainedModel): embed_dim=embed_dim, 
num_heads=embed_dim // 128, kv_dim=vision_dim, - adaptive=True, + adaptive=True ) + def init_audio_module(self): model = MiniCPMWhisperEncoder(self.config.audio_config) return model + def init_tts_module(self): model = ConditionalChatTTS(self.config.tts_config) return model @@ -238,12 +153,12 @@ class MiniCPMO(MiniCPMOPreTrainedModel): return self.llm def subsequent_chunk_mask( - self, - size: int, - chunk_size: int, - num_left_chunks: int = -1, - device: torch.device = torch.device("cpu"), - num_lookhead: int = 0, + self, + size: int, + chunk_size: int, + num_left_chunks: int = -1, + device: torch.device = torch.device("cpu"), + num_lookhead: int = 0 ) -> torch.Tensor: """Create mask for subsequent steps (size, size) with chunk size, this is for streaming encoder @@ -281,30 +196,17 @@ class MiniCPMO(MiniCPMOPreTrainedModel): Computes the output length of the convolutional layers and the output length of the audio encoder """ input_lengths_after_cnn = (input_lengths - 1) // 2 + 1 - input_lengths_after_pooling = ( - input_lengths_after_cnn - self.config.audio_pool_step - ) // self.config.audio_pool_step + 1 + input_lengths_after_pooling = (input_lengths_after_cnn - self.config.audio_pool_step) // self.config.audio_pool_step + 1 input_lengths_after_pooling = input_lengths_after_pooling.to(dtype=torch.int32) return input_lengths_after_cnn, input_lengths_after_pooling def get_vllm_embedding(self, data): - """ - Compute all visual embeddings, and set into llm embeddings. - Args: - data: Dict - tgt_sizes: image size after patch embedding - pixel_values: image features - image_bound: position of each picture corresponding to input_ids - input_ids: full input_ids, include placeholder - Returns: - embedding with vision, vision_hidden_states - """ - if "vision_hidden_states" not in data: + if 'vision_hidden_states' not in data: dtype = self.llm.model.embed_tokens.weight.dtype device = self.llm.model.embed_tokens.weight.device - tgt_sizes = data["tgt_sizes"] - pixel_values_list = data["pixel_values"] + tgt_sizes = data['tgt_sizes'] + pixel_values_list = data['pixel_values'] vision_hidden_states = [] all_pixel_values = [] img_cnt = [] @@ -319,15 +221,14 @@ class MiniCPMO(MiniCPMOPreTrainedModel): max_patches = torch.max(tgt_sizes[:, 0] * tgt_sizes[:, 1]) - all_pixel_values = torch.nn.utils.rnn.pad_sequence( - all_pixel_values, batch_first=True, padding_value=0.0 - ) + all_pixel_values = torch.nn.utils.rnn.pad_sequence(all_pixel_values, batch_first=True, + padding_value=0.0) B, L, _ = all_pixel_values.shape all_pixel_values = all_pixel_values.permute(0, 2, 1).reshape(B, 3, -1, L) patch_attn_mask = torch.zeros((B, 1, max_patches), dtype=torch.bool, device=device) for i in range(B): - patch_attn_mask[i, 0, : tgt_sizes[i][0] * tgt_sizes[i][1]] = True + patch_attn_mask[i, 0, :tgt_sizes[i][0] * tgt_sizes[i][1]] = True vision_batch_size = self.config.vision_batch_size all_pixel_values = all_pixel_values.type(dtype) @@ -336,33 +237,28 @@ class MiniCPMO(MiniCPMOPreTrainedModel): for i in range(0, B, vision_batch_size): start_idx = i end_idx = i + vision_batch_size - tmp_hs = self.vpm( - all_pixel_values[start_idx:end_idx], - patch_attention_mask=patch_attn_mask[start_idx:end_idx], - tgt_sizes=tgt_sizes[start_idx:end_idx], - ).last_hidden_state + tmp_hs = self.vpm(all_pixel_values[start_idx:end_idx], patch_attention_mask=patch_attn_mask[start_idx:end_idx], tgt_sizes=tgt_sizes[start_idx:end_idx]).last_hidden_state hs.append(tmp_hs) vision_embedding = torch.cat(hs, dim=0) else: - vision_embedding = self.vpm( - 
all_pixel_values, patch_attention_mask=patch_attn_mask, tgt_sizes=tgt_sizes - ).last_hidden_state + vision_embedding = self.vpm(all_pixel_values, patch_attention_mask=patch_attn_mask, tgt_sizes=tgt_sizes).last_hidden_state vision_embedding = self.resampler(vision_embedding, tgt_sizes) start = 0 for pixel_values in pixel_values_list: img_cnt = len(pixel_values) if img_cnt > 0: - vision_hidden_states.append(vision_embedding[start : start + img_cnt]) + vision_hidden_states.append(vision_embedding[start: start + img_cnt]) start += img_cnt else: vision_hidden_states.append([]) - else: # no image + else: # no image if self.training: - dummy_image = torch.zeros((1, 3, 224, 224), device=device, dtype=dtype) - tgt_sizes = torch.Tensor( - [[(224 // self.config.patch_size), math.ceil(224 / self.config.patch_size)]] - ).type(torch.int32) + dummy_image = torch.zeros( + (1, 3, 224, 224), + device=device, dtype=dtype + ) + tgt_sizes = torch.Tensor([[(224 // self.config.patch_size), math.ceil(224 / self.config.patch_size)]]).type(torch.int32) dummy_feature = self.resampler(self.vpm(dummy_image).last_hidden_state, tgt_sizes) else: dummy_feature = [] @@ -370,126 +266,48 @@ class MiniCPMO(MiniCPMOPreTrainedModel): vision_hidden_states.append(dummy_feature) else: - vision_hidden_states = data["vision_hidden_states"] + vision_hidden_states = data['vision_hidden_states'] - if hasattr(self.llm.config, "scale_emb"): - vllm_embedding = self.llm.model.embed_tokens(data["input_ids"]) * self.llm.config.scale_emb + if hasattr(self.llm.config, 'scale_emb'): + vllm_embedding = self.llm.model.embed_tokens(data['input_ids']) * self.llm.config.scale_emb else: - vllm_embedding = self.llm.model.embed_tokens(data["input_ids"]) + vllm_embedding = self.llm.model.embed_tokens(data['input_ids']) - vision_hidden_states = [ - i.type(vllm_embedding.dtype) if isinstance(i, torch.Tensor) else i for i in vision_hidden_states - ] + vision_hidden_states = [i.type(vllm_embedding.dtype) if isinstance( + i, torch.Tensor) else i for i in vision_hidden_states] - bs = len(data["input_ids"]) + bs = len(data['input_ids']) for i in range(bs): cur_vs_hs = vision_hidden_states[i] if len(cur_vs_hs) > 0: cur_vllm_emb = vllm_embedding[i] - cur_image_bound = data["image_bound"][i] + cur_image_bound = data['image_bound'][i] if len(cur_image_bound) > 0: image_indices = torch.stack( [torch.arange(r[0], r[1], dtype=torch.long) for r in cur_image_bound] ).to(vllm_embedding.device) - cur_vllm_emb.scatter_( - 0, - image_indices.view(-1, 1).repeat(1, cur_vllm_emb.shape[-1]), - cur_vs_hs.view(-1, cur_vs_hs.shape[-1]), - ) + cur_vllm_emb = cur_vllm_emb.scatter(0, image_indices.view(-1, 1).repeat(1, cur_vllm_emb.shape[-1]), + cur_vs_hs.view(-1, cur_vs_hs.shape[-1])) + # cur_vllm_emb.scatter_(0, image_indices.view(-1, 1).repeat(1, cur_vllm_emb.shape[-1]), + # cur_vs_hs.view(-1, cur_vs_hs.shape[-1])) elif self.training: cur_vllm_emb += cur_vs_hs[0].mean() * 0 return vllm_embedding, vision_hidden_states - def get_audio_embedding_streaming(self, data): - r""" - Extract audio embeddings in a streaming manner using cached key-value pairs. - - This method processes incoming audio features incrementally and stores/updates `past_key_values` - for faster inference on subsequent audio frames. It only supports batch_size=1 and is intended - for streaming scenarios. 
+ def get_audio_embedding(self, data, chunk_length=-1, dummy=True): + dtype = self.apm.embed_positions.weight.dtype + device = self.apm.embed_positions.weight.device - Args: - data (dict): - - **"audio_features"** (`torch.FloatTensor`): Input mel-spectrograms of shape `(batch_size, 80, frames)`. - - **"audio_feature_lens"** (List[List[int]]): Lengths of each audio segment for each item in the batch. - - Returns: - List[List[torch.Tensor]]: audio embeddings - """ - wavforms = data.get("audio_features", []) # (bs, 80, frames) or [], multi audios need filled in advance - audio_feature_lens_raw = data.get("audio_feature_lens", []) # list, [[x1, x2], [y1], [z1]] + wavforms = data.get('audio_features', []) # (bs, 80, frames) or [], multiple audios per sample are padded together in advance + audio_feature_lens_raw = data.get('audio_feature_lens', []) # list, [[x1, x2], [y1], [z1]] # exist audio if len(wavforms) > 0: audio_feature_lens = torch.hstack(audio_feature_lens_raw) batch_size, _, max_mel_seq_len = wavforms.shape - assert batch_size == 1 - max_seq_len = (max_mel_seq_len - 1) // 2 + 1 - - if self.audio_past_key_values is not None: - cache_length = self.audio_past_key_values[0][0].shape[2] - apm_max_len = self.apm.embed_positions.weight.shape[0] - if cache_length + max_seq_len >= apm_max_len: - logger.warning( - f"audio_past_key_values length {cache_length + max_seq_len} exceed {apm_max_len}, reset." - ) - self.audio_past_key_values = None - - audio_outputs = self.apm(wavforms, past_key_values=self.audio_past_key_values, use_cache=True) - audio_states = audio_outputs.last_hidden_state # [:, :audio_feat_lengths, :] - self.audio_past_key_values = audio_outputs.past_key_values - - audio_embeds = self.audio_projection_layer(audio_states) - - audio_embeds = audio_embeds.transpose(1, 2) - audio_embeds = self.audio_avg_pooler(audio_embeds) - audio_embeds = audio_embeds.transpose(1, 2) - - _, feature_lens_after_pooling = self._get_feat_extract_output_lengths(audio_feature_lens) - - num_audio_tokens = feature_lens_after_pooling - - final_audio_embeds = [] - idx = 0 - for i in range(len(audio_feature_lens_raw)): - target_audio_embeds = [] - for _ in range(len(audio_feature_lens_raw[i])): - target_audio_embeds.append(audio_embeds[idx, : num_audio_tokens[idx], :]) - idx += 1 - final_audio_embeds.append(target_audio_embeds) - return final_audio_embeds - else: - return [] - - def get_audio_embedding(self, data, chunk_length=-1): - r""" - Extract full audio embeddings with optional chunk-based attention. - - This method computes embeddings for all audio frames at once, either using full attention (when - `chunk_length` is -1) or chunk-based attention (when `chunk_length` is a positive number). It does - not use key-value caching and is suitable for non-streaming inference. - - Args: - data (dict): - - **"audio_features"** (`torch.FloatTensor`): Input mel-spectrograms of shape `(batch_size, 80, frames)`. - - **"audio_feature_lens"** (List[List[int]]): Lengths of each audio segment for each item in the batch. - chunk_length (int, optional): Determines whether to use full attention (-1) or chunk-based - attention (>0) during embedding computation.
- - Returns: - List[List[torch.Tensor]]: audio embeddings - """ - - wavforms = data.get("audio_features", []) # (bs, 80, frames) or [], multi audios need filled in advance - audio_feature_lens_raw = data.get("audio_feature_lens", []) # list, [[x1, x2], [y1], [z1]] - - # exist audio - if len(wavforms) > 0: - audio_feature_lens = torch.hstack(audio_feature_lens_raw) - batch_size, _, max_mel_seq_len = wavforms.shape - max_seq_len = (max_mel_seq_len - 1) // 2 + 1 + max_seq_len = (max_mel_seq_len - 1) // 2 + 1 # the original code used (max_mel_seq_len - 2) // 2 + 1, which is off by one when the input length is odd # Create a sequence tensor of shape (batch_size, max_seq_len) seq_range = ( @@ -505,7 +323,8 @@ class MiniCPMO(MiniCPMOPreTrainedModel): batch_size, 1, max_seq_len, max_seq_len ) audio_attention_mask = audio_attention_mask_.to( - dtype=self.apm.conv1.weight.dtype, device=self.apm.conv1.weight.device + dtype=self.apm.conv1.weight.dtype, + device=self.apm.conv1.weight.device ) if chunk_length > 0: @@ -514,14 +333,16 @@ class MiniCPMO(MiniCPMOPreTrainedModel): size=max_seq_len, chunk_size=chunk_num_frame, num_left_chunks=-1, - device=audio_attention_mask_.device, + device=audio_attention_mask_.device ) audio_attention_mask_ = torch.logical_or(audio_attention_mask_, torch.logical_not(chunk_mask)) audio_attention_mask[audio_attention_mask_] = float("-inf") audio_states = self.apm( - wavforms, output_hidden_states=True, attention_mask=audio_attention_mask - ).hidden_states[self.audio_encoder_layer] + wavforms, + output_hidden_states=True, + attention_mask=audio_attention_mask).hidden_states[self.audio_encoder_layer] + audio_embeds = self.audio_projection_layer(audio_states) audio_embeds = audio_embeds.transpose(1, 2) @@ -537,45 +358,42 @@ class MiniCPMO(MiniCPMOPreTrainedModel): for i in range(len(audio_feature_lens_raw)): target_audio_embeds = [] for _ in range(len(audio_feature_lens_raw[i])): - target_audio_embeds.append(audio_embeds[idx, : num_audio_tokens[idx], :]) + target_audio_embeds.append(audio_embeds[idx, :num_audio_tokens[idx], :]) idx += 1 final_audio_embeds.append(target_audio_embeds) return final_audio_embeds + elif self.training and dummy: + dummy_wavs = torch.zeros((1, 80, 100), device=device, dtype=dtype) + audio_states = self.apm(dummy_wavs, output_hidden_states=True).hidden_states[self.audio_encoder_layer] + + audio_embeds = self.audio_projection_layer(audio_states) + + audio_embeds = audio_embeds.transpose(1, 2) + audio_embeds = self.audio_avg_pooler(audio_embeds) + audio_embeds = audio_embeds.transpose(1, 2) + return [audio_embeds] + else: return [] - def get_omni_embedding(self, data, input_embeddings, chunk_length=-1, stream_input=False): - """ - Args: - data: - input_embeddings: - chunk_length: whisper use full attention or chunk attention - stream_input: use streaming audio embedding - Returns: - final embeddings with audio feature - """ - if stream_input: - audio_embeddings = self.get_audio_embedding_streaming(data) - else: - audio_embeddings = self.get_audio_embedding(data, chunk_length) + def get_omni_embedding(self, data, input_embeddings, chunk_length=-1): + audio_embeddings = self.get_audio_embedding(data, chunk_length) bs = len(input_embeddings) - if len(data.get("audio_features", [])) > 0: + if len(data.get('audio_features', [])) > 0: assert len(audio_embeddings) == len(input_embeddings) if len(audio_embeddings) > 0: - audio_bounds = data["audio_bounds"] + audio_bounds = data['audio_bounds'] - if self.config.chunk_input: + if self.config.stream_input: for i in range(bs): - audio_embs =
torch.cat(audio_embeddings[i], dim=0).to( - device=input_embeddings.device, dtype=input_embeddings.dtype - ) + audio_embs = torch.cat(audio_embeddings[i], dim=0).to(device=input_embeddings.device, + dtype=input_embeddings.dtype) audio_start_pos = 0 for bound in audio_bounds[i]: audio_len = bound[1] - bound[0] - input_embeddings[0, bound[0] : bound[1]] = audio_embs[ - audio_start_pos : audio_start_pos + audio_len, : - ] + input_embeddings[0, bound[0]:bound[1]] = audio_embs[ + audio_start_pos:audio_start_pos + audio_len, :] audio_start_pos += audio_len else: for i in range(bs): @@ -583,10 +401,13 @@ class MiniCPMO(MiniCPMOPreTrainedModel): bounds = audio_bounds[i] for embs, bound in zip(audio_embs, bounds): audio_indices = torch.arange(bound[0], bound[1], dtype=torch.long).to( - input_embeddings.device - ) + input_embeddings.device) if embs.shape[0] != len(audio_indices): + print(f"Sample {i}:") + print(f" Bounds: {bound}, Indices Length: {len(audio_indices)}") + print(f" Embeddings Shape: {embs.shape}") + print(f" Input Embedding Shape at Indices: {input_embeddings[i, audio_indices].shape}") raise ValueError( f"Shape mismatch: Trying to assign embeddings of shape {embs.shape} " f"to input indices of length {len(audio_indices)}" @@ -594,30 +415,30 @@ class MiniCPMO(MiniCPMOPreTrainedModel): input_embeddings[i, audio_indices] = embs.to(input_embeddings.dtype) elif self.training: for i in range(bs): - # dummy audio_embeddings - input_embeddings += audio_embeddings[0].mean() * 0 + # dummy audio_embedings + input_embeddings = input_embeddings + audio_embeddings[0].mean() * 0 return input_embeddings def forward(self, data, **kwargs): vllm_embedding, vision_hidden_states = self.get_vllm_embedding(data) - - if self.config.init_audio: - vllm_embedding = self.get_omni_embedding( - data, input_embeddings=vllm_embedding, chunk_length=self.config.audio_chunk_length - ) + vllm_embedding = self.get_omni_embedding(data, input_embeddings=vllm_embedding, chunk_length=self.config.audio_chunk_length) position_ids = data["position_ids"] if position_ids.dtype != torch.int64: position_ids = position_ids.long() - - # compatible with llama factory - for key in ["input_ids", "inputs_embeds", "position_ids"]: + + for key in ['input_ids', 'inputs_embeds', 'position_ids']: if key in kwargs: del kwargs[key] - return self.llm(input_ids=None, position_ids=position_ids, inputs_embeds=vllm_embedding, **kwargs) - + return self.llm( + input_ids=None, + position_ids=position_ids, + inputs_embeds=vllm_embedding, + **kwargs + ) + def _decode(self, inputs_embeds, tokenizer, attention_mask, **kwargs): terminators = [tokenizer.convert_tokens_to_ids(i) for i in self.terminators] outputs = self.llm.generate( @@ -627,7 +448,7 @@ class MiniCPMO(MiniCPMOPreTrainedModel): attention_mask=attention_mask, output_hidden_states=True, return_dict_in_generate=True, - **kwargs, + **kwargs ) return outputs @@ -635,16 +456,16 @@ class MiniCPMO(MiniCPMOPreTrainedModel): terminators = [tokenizer.convert_tokens_to_ids(i) for i in self.terminators] streamer = TextIteratorStreamer(tokenizer=tokenizer) generation_kwargs = { - "inputs_embeds": inputs_embeds, - "pad_token_id": 0, - "eos_token_id": terminators, - "streamer": streamer, + 'inputs_embeds': inputs_embeds, + 'pad_token_id': 0, + 'eos_token_id': terminators, + 'streamer': streamer } generation_kwargs.update(kwargs) thread = Thread(target=self.llm.generate, kwargs=generation_kwargs) thread.start() - + return streamer def _decode_text(self, result_ids, tokenizer): @@ -656,102 +477,15 @@ class 
MiniCPMO(MiniCPMOPreTrainedModel): result = result[1:] if result[-1] in terminators: result = result[:-1] - result_text.append(tokenizer.decode(result)) + result_text.append(tokenizer.decode(result).strip()) return result_text - def get_sys_prompt(self, ref_audio=None, mode="default", language="zh"): - """ - Choose different system prompts according to different tasks - Args: - ref_audio: if ref_audio is not None, will use the voice cloning prompts, and the voice - generated by the model will refer to the timbre of ref audio - mode: - "default": default system prompt and not refer to any task - "omni": input video and audio simultaneously - "audio_assistant": Default voice-only mode, the model will use the ref_audio's voice to reply user's question as a helpful assistant. - "audio_roleplay": Roleplay voice-only mode, the model will use the ref_audio's voice to reply, and also role-play the character based on the audio prompt. - "voice_cloning": TTS mode, the model will clone the voice of ref_audio. - language: prompts language, the model has the ability to automatically select the response language - based on the question language - Returns: - - """ - if ref_audio is not None: - assert isinstance(ref_audio, np.ndarray), "ref_audio error" - if mode == "omni": - if language == "zh": - sys_prompt = "你是一个AI助手。你能接受视频,音频和文本输入并输出语音和文本。" - vc_prompt_prefix = sys_prompt + "模仿输入音频中的声音特征。" - vc_prompt_suffix = "作为助手,你将使用这种声音风格说话。" - else: - sys_prompt = "You are a helpful assistant. You can accept video, audio and text input and output voice and text. " - vc_prompt_prefix = sys_prompt + "Clone the voice in the provided audio prompt." - vc_prompt_suffix = "As an assistant, you will speak using this voice style." - - if ref_audio is not None: - sys_msgs = {"role": "user", "content": [vc_prompt_prefix, ref_audio, vc_prompt_suffix]} - - else: - sys_msgs = {"role": "user", "content": [sys_prompt]} - - return sys_msgs - elif mode == "audio_assistant": - if language == "zh": - vc_prompt_prefix = "模仿输入音频中的声音特征。" - vc_prompt_suffix = "作为助手,你将使用这种声音风格说话。" - else: - vc_prompt_prefix = "Clone the voice in the provided audio prompt." - vc_prompt_suffix = "As an assistant, you will speak using this voice style." - - if ref_audio is not None: - sys_msgs = {"role": "user", "content": [vc_prompt_prefix, ref_audio, vc_prompt_suffix]} - - else: - logger.warning( - "Warning: ref_audio is None, speech generation will be performed based on the default voice." - ) - sys_msgs = {"role": "user", "content": ["Use the voice.", vc_prompt_suffix]} - - return sys_msgs - elif mode == "audio_roleplay": - if language == "zh": - vc_prompt_prefix = "模仿输入音频中的声音特征。" - vc_prompt_suffix = "假装你是上述音频中的人物,与我进行对话。" - else: - vc_prompt_prefix = "Clone the voice in the provided audio prompt." - vc_prompt_suffix = "Try to role-play the character based on the audio prompt above." - - if ref_audio is not None: - sys_msgs = {"role": "user", "content": [vc_prompt_prefix, ref_audio, vc_prompt_suffix]} - else: - print("Warning: ref_audio is None, speech generation will be performed based on the default voice.") - sys_msgs = {"role": "user", "content": ["Use the voice.", vc_prompt_suffix]} - - return sys_msgs - elif mode == "voice_cloning": - if language == "zh": - vc_prompt_prefix = "模仿输入音频中的声音特征。" - else: - vc_prompt_prefix = "Clone the voice in the provided audio prompt." 
- - if ref_audio is not None: - sys_msgs = {"role": "user", "content": [vc_prompt_prefix, ref_audio]} - else: - raise ValueError("ref_audio con't be None in voice_cloning mode.") - - return sys_msgs - else: - sys_prompt = "You are a helpful assistant. You can accept audio and text input and output voice and text." - sys_msgs = {"role": "user", "content": [sys_prompt]} - - return sys_msgs - def generate( self, input_ids=None, pixel_values=None, tgt_sizes=None, - audio_features=None, + audio_features=[], audio_feature_lens=None, image_bound=None, audio_bounds=None, @@ -760,7 +494,7 @@ class MiniCPMO(MiniCPMOPreTrainedModel): tokenizer=None, vision_hidden_states=None, stream=False, - **kwargs, + **kwargs ): assert input_ids is not None assert len(input_ids) == len(pixel_values) @@ -776,7 +510,7 @@ class MiniCPMO(MiniCPMOPreTrainedModel): if vision_hidden_states is None: model_inputs["pixel_values"] = pixel_values - model_inputs["tgt_sizes"] = tgt_sizes + model_inputs['tgt_sizes'] = tgt_sizes else: model_inputs["vision_hidden_states"] = vision_hidden_states @@ -784,9 +518,7 @@ class MiniCPMO(MiniCPMOPreTrainedModel): with torch.inference_mode(): model_inputs["inputs_embeds"], vision_hidden_states = self.get_vllm_embedding(model_inputs) model_inputs["inputs_embeds"] = self.get_omni_embedding( - model_inputs, - input_embeddings=model_inputs["inputs_embeds"], - chunk_length=self.config.audio_chunk_length, + model_inputs, input_embeddings=model_inputs["inputs_embeds"], chunk_length=self.config.audio_chunk_length ) if stream: @@ -795,69 +527,38 @@ class MiniCPMO(MiniCPMOPreTrainedModel): outputs = {} else: outputs = self._decode(model_inputs["inputs_embeds"], tokenizer, attention_mask, **kwargs) - result = self._decode_text(outputs.sequences, tokenizer) return result, outputs def chat( self, - image=None, - msgs=None, - tokenizer=None, + image, + msgs, + tokenizer, processor=None, vision_hidden_states=None, max_new_tokens=2048, min_new_tokens=0, sampling=True, - max_inp_length=32768, + max_inp_length=8192, stream=False, - chunk_input=True, + stream_input=True, omni_input=False, max_slice_nums=None, use_image_id=None, - use_tts_template=False, - generate_audio=False, - return_spk_embed=False, - return_dict=False, + use_tts=False, output_audio_path=None, - **kwargs, + return_spk_embed=False, + **kwargs ): - """ - Unified chat function - - Args: - image: use for batch_size=1 vqa, It is not recommended to continue to use this parameter - msgs: the input chat msgs, support text: (string) / image: (PIL.Image) / audio (numpy.ndarray) - tokenizer: tokenizer for llm - processor: if None, use the default processor - max_new_tokens: the maximum length of the generation - min_new_tokens: the minimum length of the generation - sampling: whether to use sampling decoding or beam search decoding - max_inp_length: the maximum length of input - stream: whether to return generator, only used when tts is not required - chunk_input: whether to split audio into 1s chunks - omni_input: determine whether it is omni mode - max_slice_nums: control the maximum number of image slices - use_image_id: for video understanding or omni understanding, use_image_id should be False - use_tts_template: if the msgs contain audio, use_tts_template should be True - generate_audio: whether to generate audio output, only used when return_dict=True - return_spk_embed: whether to return spk embedding, only used when return_dict=True - return_dict: whether to return dict - output_audio_path: audio save path when generate_audio - **kwargs: - 
""" if isinstance(msgs[0], list): batched = True else: batched = False - - if generate_audio or return_spk_embed: - return_dict = True - msgs_list = msgs images_list = image - + if batched is False: images_list, msgs_list = [images_list], [msgs_list] else: @@ -869,22 +570,12 @@ class MiniCPMO(MiniCPMOPreTrainedModel): if self.processor is None: self.processor = AutoProcessor.from_pretrained(self.config._name_or_path, trust_remote_code=True) processor = self.processor - - assert ( - self.config.query_num == processor.image_processor.image_feature_size - ), "These two values should be the same. Check `config.json` and `preprocessor_config.json`." - assert ( - self.config.patch_size == processor.image_processor.patch_size - ), "These two values should be the same. Check `config.json` and `preprocessor_config.json`." - assert ( - self.config.use_image_id == processor.image_processor.use_image_id - ), "These two values should be the same. Check `config.json` and `preprocessor_config.json`." - assert ( - self.config.slice_config.max_slice_nums == processor.image_processor.max_slice_nums - ), "These two values should be the same. Check `config.json` and `preprocessor_config.json`." - assert ( - self.config.slice_mode == processor.image_processor.slice_mode - ), "These two values should be the same. Check `config.json` and `preprocessor_config.json`." + + assert self.config.query_num == processor.image_processor.image_feature_size, "These two values should be the same. Check `config.json` and `preprocessor_config.json`." + assert self.config.patch_size == processor.image_processor.patch_size, "These two values should be the same. Check `config.json` and `preprocessor_config.json`." + assert self.config.use_image_id == processor.image_processor.use_image_id, "These two values should be the same. Check `config.json` and `preprocessor_config.json`." + assert self.config.slice_config.max_slice_nums == processor.image_processor.max_slice_nums, "These two values should be the same. Check `config.json` and `preprocessor_config.json`." + assert self.config.slice_mode == processor.image_processor.slice_mode, "These two values should be the same. Check `config.json` and `preprocessor_config.json`." 
prompts_lists = [] input_images_list = [] @@ -917,12 +608,11 @@ class MiniCPMO(MiniCPMOPreTrainedModel): for c in content: if isinstance(c, Image.Image): images.append(c) - cur_msgs.append("(./)") - elif isinstance(c, np.ndarray): # audio + cur_msgs.append("./") + elif isinstance(c, np.ndarray): # audio audios.append(c) audio_parts.append(i) - cur_msgs.append("()") - use_tts_template = True + cur_msgs.append("") elif isinstance(c, str): cur_msgs.append(c) if omni_input: @@ -935,23 +625,24 @@ class MiniCPMO(MiniCPMOPreTrainedModel): copy_msgs, tokenize=False, add_generation_prompt=True, - chat_template=self.default_tts_chat_template if use_tts_template else None, + chat_template=self.default_tts_chat_template if use_tts else None ) ) input_images_list.append(images) input_audios_list.append(audios) audio_parts_list.append(audio_parts) + inputs = processor( - prompts_lists, + prompts_lists, input_images_list, input_audios_list, audio_parts_list, max_slice_nums=max_slice_nums, use_image_id=use_image_id, - chunk_input=chunk_input, - return_tensors="pt", - max_length=max_inp_length, + stream_input=stream_input, + return_tensors="pt", + max_length=max_inp_length ).to(self.device) if sampling: @@ -960,18 +651,20 @@ class MiniCPMO(MiniCPMOPreTrainedModel): "top_k": 100, "temperature": 0.7, "do_sample": True, - "repetition_penalty": 1.01, + "repetition_penalty": 1.05 } else: generation_config = { "num_beams": 3, "repetition_penalty": 1.2, } - + if min_new_tokens > 0: - generation_config["min_new_tokens"] = min_new_tokens + generation_config['min_new_tokens'] = min_new_tokens - generation_config.update((k, kwargs[k]) for k in generation_config.keys() & kwargs.keys()) + generation_config.update( + (k, kwargs[k]) for k in generation_config.keys() & kwargs.keys() + ) inputs.pop("image_sizes") with torch.inference_mode(): @@ -981,290 +674,33 @@ class MiniCPMO(MiniCPMOPreTrainedModel): max_new_tokens=max_new_tokens, vision_hidden_states=vision_hidden_states, stream=stream, - **generation_config, + **generation_config ) - + if stream: - def stream_gen(): for text in res: for term in self.terminators: - text = text.replace(term, "") + text = text.replace(term, '') yield text - - if return_dict: - return OmniOutput(text=stream_gen()) - else: - return stream_gen() + return stream_gen() else: - spk_embeds = wav_numpy = sr = None - if batched: answer = res else: answer = res[0] - if use_tts_template and generate_audio: + if use_tts and output_audio_path: mel_spec = self._generate_mel_spec(inputs, outputs, answer) - wav_numpy, sr = self.decode_mel_to_audio(mel_spec, output_audio_path) + self.decode_mel_to_audio(mel_spec, output_audio_path) if return_spk_embed: spk_embeds = self._get_last_spk_embeds(inputs, outputs) - - if isinstance(answer, list): - answer = [i.replace(tokenizer.tts_end, "") for i in answer] - else: - answer = answer.replace(tokenizer.tts_end, "") - - if return_dict: - return OmniOutput(text=answer, spk_embeds=spk_embeds, audio_wav=wav_numpy, sampling_rate=sr) + return answer, spk_embeds else: return answer - @torch.inference_mode() - def streaming_prefill( - self, - session_id, - msgs, - tokenizer, - omni_input=True, - max_slice_nums=None, - ls_temperature=1.0, - **kwargs, - ): - """ - Streaming video/audio input and output audio stream, Only support batch_size=1 - Args: - session_id: Note: new connection should use a new session_id - """ - assert session_id is not None - if self.session_id is None or session_id != self.session_id: # new session - self.is_first = True - else: - self.is_first = 
False - - images = [] - audios = [] - - assert len(msgs) == 1 - copy_msgs = deepcopy(msgs) - msg = copy_msgs[0] - - assert msg["role"] in ["system", "user", "assistant"] - - content = msg["content"] - cur_msgs = [] - for j, c in enumerate(content): - if isinstance(c, Image.Image): - images.append(c) - cur_msgs.append("(./)") - elif isinstance(c, np.ndarray): # audio - audios.append(c) - cur_msgs.append("()") - elif isinstance(c, str): - cur_msgs.append(c) - else: - logger.error("Invalid content type:", c) - - cur_contents = "".join(cur_msgs) if omni_input else "\n".join(omni_input) - if not self.is_first and self.new_user_msg and msg["role"] == "user": # new user add im_start - if self.llm_generated: - if self.llm_generate_completed: - msg["content"] = "<|im_end|>\n<|im_start|>user\n" + cur_contents - else: # break llm gen, add tts_eos - msg["content"] = "<|tts_eos|><|im_end|>\n<|im_start|>user\n" + cur_contents - else: - msg["content"] = "<|im_start|>user\n" + cur_contents - self.new_user_msg = False - else: - msg["content"] = cur_contents - - if msg["role"] in ["system", "assistant"]: - self.new_user_msg = True - self.audio_past_key_values = None # apm kv cache - - if self.is_first: - # init pask_key_values - logger.info(f"new session_id: {session_id}, reset kv cache") - self.reset_session() - self.session_id = session_id - - prompt = tokenizer.apply_chat_template( - copy_msgs, tokenize=False, add_generation_prompt=False, chat_template=self.default_tts_chat_template - ) - add_special_tokens = True # add bos - else: - prompt = copy_msgs[0]["content"] - add_special_tokens = False - - model_inputs = self.processor( - [prompt], - [images], - [audios], - max_slice_nums=1 if max_slice_nums is None else max_slice_nums, - use_image_id=False, - chunk_input=True, - return_tensors="pt", - max_length=None, - sampling_rate=16000, - add_special_tokens=add_special_tokens, - ).to(self.device) - - # 1. prepare input embeddings - model_inputs["inputs_embeds"], _ = self.get_vllm_embedding(model_inputs) - # get audio embedding with audio_past_key_values - inputs_embeds = self.get_omni_embedding( - model_inputs, input_embeddings=model_inputs["inputs_embeds"], stream_input=True - ) - - if self.is_first: - # clean audio_past_key_values after first prefill - self.audio_past_key_values = None - - if self.llm_past_key_values is not None: - cache_length = self.llm_past_key_values[0][0].shape[2] - else: - cache_length = 0 - - attention_mask = torch.ones((1, cache_length + inputs_embeds.shape[1]), dtype=torch.bool, device=self.device) - - # 2. 
do prefill and predict listen/speak label - outputs = self.llm( - past_key_values=self.llm_past_key_values, - inputs_embeds=inputs_embeds, - attention_mask=attention_mask, - position_ids=None, # position_ids, - use_cache=True, - return_dict=True, - ) - self.llm_past_key_values = outputs["past_key_values"] - return - - @torch.inference_mode() - def streaming_generate( - self, - session_id, - tokenizer, - max_new_tokens=512, - min_new_tokens=0, - sampling=True, - generate_audio=True, - enable_regenerate=False, - **kwargs, - ): - """ - Streaming video/audio input and output audio stream - Args: - """ - if sampling: - generation_config = { - "top_p": 0.8, - "top_k": 100, - "temperature": 0.7, - "do_sample": True, - "repetition_penalty": 1.01, - } - else: - generation_config = { - "num_beams": 3, - "repetition_penalty": 1.2, - } - generation_config["min_new_tokens"] = min_new_tokens - generation_config.update((k, kwargs[k]) for k in generation_config.keys() & kwargs.keys()) - - # do generate - # reset buffer - self.new_user_msg = True - self.llm_generated = True - self.llm_generate_completed = False - self.audio_past_key_values = None # apm kv cache - - terminators = [tokenizer.convert_tokens_to_ids(i) for i in self.terminators] - generate_prompt = "<|im_end|>\n<|im_start|>assistant\n<|spk_bos|><|spk|><|spk_eos|><|tts_bos|>" - input_ids = tokenizer(generate_prompt, return_tensors="pt", add_special_tokens=False)["input_ids"].cuda() - - spk_start_idx = torch.where(input_ids[0] == tokenizer.spk_start_id)[0] - spk_end_idx = torch.where(input_ids[0] == tokenizer.spk_end_id)[0] - spk_bounds = [ - torch.hstack([(spk_start_idx + 1).unsqueeze(-1), spk_end_idx.unsqueeze(-1)]) - ] # List[Tensor], (1,2) - - cache_length = past_length = self.llm_past_key_values[0][0].shape[2] - attention_mask = torch.ones((1, cache_length + input_ids.shape[1]), dtype=torch.bool, device=self.device) - - generation_config["max_new_tokens"] = max_new_tokens - streamer = self.llm_generate_chunk(input_ids, attention_mask, tokenizer, terminators, generation_config) - - if generate_audio: - result = self._generate_mel_spec_audio_streaming( - spk_bounds, streamer, output_chunk_size=25, enable_regenerate=enable_regenerate - ) - return result - else: - return streamer - - def llm_generate_chunk(self, input_ids, attention_mask, tokenizer, terminators, generation_config): - def check_uncompleted_token(ids): - cur_text = tokenizer.decode(ids) - end = len(ids) - while cur_text[-1] == "�": - end -= 1 - if end == 0: - break - cur_text = tokenizer.decode(ids[:end]) - return end - - max_new_tokens = int(generation_config.pop("max_new_tokens", 2048)) - new_len = 0 - first_chunk = True - eos = False - left_ids = None - - while True: - outputs = self.llm.generate( - input_ids=input_ids, - past_key_values=self.llm_past_key_values, - attention_mask=attention_mask, - use_cache=True, - max_new_tokens=3, # reduce first token delay - pad_token_id=0, - output_hidden_states=True if first_chunk else False, - return_dict_in_generate=True, - eos_token_id=terminators, - **generation_config, - ) - if outputs.sequences[0, -1] in terminators: - eos = True - input_len = input_ids.shape[1] - cur_ids = outputs.sequences[:, input_len:] - new_len += cur_ids.shape[1] - - if left_ids is not None and left_ids.shape[1] > 0: - cur_ids = torch.cat([left_ids, cur_ids], dim=1) - end = check_uncompleted_token(cur_ids[0]) - left_ids = cur_ids[:, end:] - cur_ids = cur_ids[:, :end] - text = self._decode_text(cur_ids, tokenizer)[0] if end > 0 else "" - - 
self.llm_past_key_values = outputs.past_key_values - input_ids = outputs.sequences[:, -1:] - cache_length = past_length = self.llm_past_key_values[0][0].shape[2] - attention_mask = torch.ones((1, cache_length + input_ids.shape[1]), dtype=torch.bool, device=self.device) - - res = {"text": text} - if first_chunk: - res["hidden_states"] = outputs.hidden_states - first_chunk = False - yield res - - if eos: - self.llm_generate_completed = True - break - if new_len >= max_new_tokens: - logger.debug(f"LLM generation {new_len} exceeds max_new_tokens({max_new_tokens}), break.") - break - def prepare_tts_text(self, text): tts_tokens = self.tts_processor.text_tokenizer.encode(text, add_special_tokens=False) tts_tokens_len = len(tts_tokens) @@ -1273,7 +709,7 @@ class MiniCPMO(MiniCPMOPreTrainedModel): pad_str = "[Etts]" + "[PAD]" * (num_pad_tokens - 1) else: - tts_tokens = tts_tokens[0 : self.tts.streaming_text_reserved_len] + tts_tokens = tts_tokens[0: self.tts.streaming_text_reserved_len] tts_tokens_len = len(tts_tokens) text = self.tts_processor.text_tokenizer.decode(tts_tokens, add_special_tokens=False) pad_str = "" @@ -1282,22 +718,13 @@ class MiniCPMO(MiniCPMOPreTrainedModel): new_text_tts = f"[Stts]{spk_emb_placeholder_tts}{text}{pad_str}[Ptts]" return new_text_tts, tts_tokens_len - def get_tts_text_start_token_ids(self): - text = "[Stts]" + "[spk_emb]" * self.tts.num_spk_embs - tts_input_ids = self.tts_processor.text_tokenizer(text, return_tensors="pt", add_special_tokens=False)[ - "input_ids" - ].cuda() - return tts_input_ids - def _build_streaming_mask(self, tts_tokens_len): - tts_sequence_full_length = ( - 1 + self.tts.num_spk_embs * self.tts.use_speaker_embedding + self.tts.streaming_text_reserved_len + 1 - ) + tts_sequence_full_length = 1 + self.tts.num_spk_embs * self.tts.use_speaker_embedding + self.tts.streaming_text_reserved_len + 1 streaming_attention_mask = torch.zeros(tts_sequence_full_length, dtype=torch.int8) - streaming_attention_mask[0 : 1 + 1 + tts_tokens_len + 1] = 1 + streaming_attention_mask[0: 1 + 1 + tts_tokens_len + 1] = 1 streaming_attention_mask[-1] = 1 return streaming_attention_mask - + def _get_last_spk_embeds(self, inputs, outputs): last_hidden_states = [hs[-1] for hs in outputs.hidden_states] @@ -1305,562 +732,130 @@ class MiniCPMO(MiniCPMOPreTrainedModel): last_hidden_states = torch.vstack([i[0] for i in last_hidden_states]) # last spk - spk_bound = inputs["spk_bounds"][0][-1] - - spk_embeds = last_hidden_states[spk_bound[0] : spk_bound[1]] - return spk_embeds - - def _generate_mel_spec(self, inputs, outputs, text, output_chunk_size=25, tts_max_new_tokens=2048): - spk_embeds = self._get_last_spk_embeds(inputs, outputs) - - text = text.split("<|tts_bos|>")[-1] - gen_text = text.split("<|tts_eos|>")[0] - tts_text, tts_token_lens = self.prepare_tts_text(gen_text) - tts_inputs = self.tts_processor.text_tokenizer.encode(tts_text, add_special_tokens=False) - tts_input_ids = torch.Tensor(tts_inputs).unsqueeze(0).to("cuda", dtype=torch.long) - streaming_tts_text_mask = self._build_streaming_mask(tts_token_lens).to(device=self.tts.device) - - logits_warpers, logits_processors = gen_logits( - num_code=626, top_P=self.tts.top_p, top_K=self.tts.top_k, repetition_penalty=self.tts.repetition_penalty - ) - - condition_length = ( - 1 + self.tts.use_speaker_embedding * self.tts.num_spk_embs + self.tts.streaming_text_reserved_len + 1 - ) - - dtype = self.tts.emb_text.weight.dtype - emb = torch.zeros(1, condition_length, self.tts.num_vq, dtype=dtype, device=self.tts.device) - 
past_key_values = [ - ( - torch.zeros( - 1, - self.tts.config.num_attention_heads, - condition_length - 1, - self.tts.config.hidden_size // self.tts.config.num_attention_heads, - dtype=emb.dtype, - device=self.tts.device, - ), - torch.zeros( - 1, - self.tts.config.num_attention_heads, - condition_length - 1, - self.tts.config.hidden_size // self.tts.config.num_attention_heads, - dtype=emb.dtype, - device=self.tts.device, - ), - ) - for _ in range(self.tts.config.num_hidden_layers) - ] - - audio_input_ids = torch.zeros(1, condition_length, self.tts.num_vq, dtype=torch.long, device=self.tts.device) - - eos_lab = False - for chunk_idx in range(math.ceil(emb.shape[1] / self.tts.streaming_text_chunk_size)): - if chunk_idx == 0: - begin = chunk_idx * self.tts.streaming_text_chunk_size + 0 - end = ( - (chunk_idx + 1) * self.tts.streaming_text_chunk_size - + 1 - + self.tts.use_speaker_embedding * self.tts.num_spk_embs - ) - else: - begin = ( - chunk_idx * self.tts.streaming_text_chunk_size - + 1 - + self.tts.use_speaker_embedding * self.tts.num_spk_embs - ) - end = min( - (chunk_idx + 1) * self.tts.streaming_text_chunk_size - + 1 - + self.tts.use_speaker_embedding * self.tts.num_spk_embs, - condition_length - 1, - ) - - if end - begin > 0: - text_input_ids = tts_input_ids[:, begin:end] - position_ids = torch.arange(begin, end, dtype=torch.long, device=self.tts.device).unsqueeze(0) - - if begin == 0: - past_key_values = self.tts.prefill_text( - input_ids=text_input_ids, - position_ids=position_ids, - past_key_values=past_key_values, - lm_spk_emb_last_hidden_states=spk_embeds, - ) - else: - past_key_values = self.tts.prefill_text( - input_ids=text_input_ids, position_ids=position_ids, past_key_values=past_key_values - ) - - outputs = self.tts.generate( - input_ids=audio_input_ids, - past_key_values=past_key_values, - streaming_tts_text_mask=streaming_tts_text_mask, - max_new_token=output_chunk_size, - force_no_stop=self.force_no_stop, - temperature=torch.tensor([0.1, 0.3, 0.1, 0.3], dtype=torch.float, device=self.tts.device), - eos_token=torch.tensor([625], dtype=torch.long, device=self.tts.device), - logits_warpers=logits_warpers, - logits_processors=logits_processors, - ) - audio_input_ids = outputs.audio_input_ids - past_key_values = outputs.past_key_values - - if outputs.finished: - logger.debug("Generation finished.") - eos_lab = True - break - - if not eos_lab: - logger.debug("eos_lab False, Generation continue.") - while True: - outputs = self.tts.generate( - input_ids=audio_input_ids, - past_key_values=past_key_values, - streaming_tts_text_mask=streaming_tts_text_mask, - max_new_token=output_chunk_size, - force_no_stop=self.force_no_stop, - temperature=torch.tensor([0.1, 0.3, 0.1, 0.3], dtype=torch.float, device=self.tts.device), - eos_token=torch.tensor([625], dtype=torch.long, device=self.tts.device), - logits_warpers=logits_warpers, - logits_processors=logits_processors, - ) - - audio_input_ids = outputs.audio_input_ids - past_key_values = outputs.past_key_values - - if outputs.finished: - logger.debug("Generation finished.") - break - if outputs.new_ids.shape[1] > tts_max_new_tokens: - logger.debug(f"Generation length > {tts_max_new_tokens}, stopped.") - break - - mel_spec = self.tts.decode_to_mel_specs(outputs.new_ids) - return mel_spec - - def _linear_overlap_add2_wav(self, frames: List[torch.Tensor], overlap: int): - """ - Merge two audio waveforms with smooth in streaming audio generation. 
- Borrowed some codes from `https://github.com/huggingface/transformers/blob/main/src/transformers/models/encodec/modeling_encodec.py` - """ - assert len(frames) == 2 - device = frames[0].device - dtype = frames[0].dtype - # shape = frames[0].shape[:-1] - - frame0_length = frames[0].shape[-1] - frame1_length = frames[1].shape[-1] - total_size = frame0_length + frame1_length - overlap - weight_len = max(frame0_length, frame1_length) + overlap - t = torch.linspace(0, 1, weight_len + 2, device=device, dtype=dtype)[1:-1] - weight = 0.5 - (t - 0.5).abs() - - sum_weight = torch.zeros(total_size, device=device, dtype=dtype) - out = torch.zeros(total_size, device=device, dtype=dtype) - offset: int = 0 - - out[offset : offset + frame0_length] += weight[-frame0_length:] * frames[0] - sum_weight[offset : offset + frame0_length] += weight[-frame0_length:] - offset += frame0_length - overlap - out[offset : offset + frame1_length] += weight[:frame1_length] * frames[1] - sum_weight[offset : offset + frame1_length] += weight[:frame1_length] - - assert sum_weight.min() > 0 - out = out / sum_weight - return out[:frame0_length], out[frame0_length:] - - def _generate_mel_spec_audio_streaming( - self, - spk_bounds, - streamer, - output_chunk_size=25, - spk_embeds=None, - prev_seg_text_ids=None, - prev_seg_text_left="", - prev_seg_audio_ids=None, - enable_regenerate=False, - ): - # get spk_embedding - gen_text = "" - tts_text = "" - new_segment_gen = False - if spk_embeds is None: - spk_bound = spk_bounds[0][-1] - r = next(streamer) - txt = r["text"] - gen_text += txt.split("<|tts_eos|>")[0] - tts_text, tts_token_lens = self.prepare_tts_text(gen_text) - last_hidden_states = r["hidden_states"][0][-1][0] # output: (input_seq_len, dim) - spk_embeds = last_hidden_states[spk_bound[0] : spk_bound[1]] - - # init past_key_values - logits_warpers, logits_processors = gen_logits( - num_code=626, top_P=self.tts.top_p, top_K=self.tts.top_k, repetition_penalty=self.tts.repetition_penalty - ) - condition_length = ( - 1 + self.tts.use_speaker_embedding * self.tts.num_spk_embs + self.tts.streaming_text_reserved_len + 1 - ) - tts_start_token_len = 1 + self.tts.use_speaker_embedding * self.tts.num_spk_embs - dtype = self.tts.emb_text.weight.dtype - past_key_values = [ - ( - torch.zeros( - 1, - self.tts.config.num_attention_heads, - condition_length - 1, - self.tts.config.hidden_size // self.tts.config.num_attention_heads, - dtype=dtype, - device=self.tts.device, - ), - torch.zeros( - 1, - self.tts.config.num_attention_heads, - condition_length - 1, - self.tts.config.hidden_size // self.tts.config.num_attention_heads, - dtype=dtype, - device=self.tts.device, - ), - ) - for _ in range(self.tts.config.num_hidden_layers) - ] - audio_input_ids = torch.zeros(1, condition_length, self.tts.num_vq, dtype=torch.long, device=self.tts.device) - - # prefill prev segment for smooth - chunk_idx = 0 - new_ids_len = 0 - prev_text_len = 0 - if prev_seg_text_ids is not None and prev_seg_audio_ids is not None: - tts_token_lens = prev_seg_text_ids.shape[1] - # assert tts_token_lens % self.tts.streaming_text_chunk_size == 0 - streaming_tts_text_mask = self._build_streaming_mask(tts_token_lens).to(device=self.tts.device) - position_ids = torch.arange( - 0, tts_token_lens + tts_start_token_len, dtype=torch.long, device=self.tts.device - ).unsqueeze(0) - - text_input_ids = self.get_tts_text_start_token_ids() - text_input_ids = torch.cat([text_input_ids, prev_seg_text_ids], dim=1) - past_key_values = self.tts.prefill_text( - input_ids=text_input_ids, - 
position_ids=position_ids, - past_key_values=past_key_values, - lm_spk_emb_last_hidden_states=spk_embeds, - ) - past_key_values = self.tts.prefill_audio_ids( - input_ids=prev_seg_audio_ids[:, :-1, :], - # not prefill last id, which will be input_id of next generation - past_key_values=past_key_values, - streaming_tts_text_mask=streaming_tts_text_mask, - ) - - # update init - chunk_idx += int(tts_token_lens / self.tts.streaming_text_chunk_size) - audio_input_ids = torch.cat([audio_input_ids, prev_seg_audio_ids], dim=1) - text = self.tts_processor.text_tokenizer.decode(prev_seg_text_ids[0].tolist(), add_special_tokens=False) - - gen_text += text - gen_text += prev_seg_text_left - prev_text_len = len(gen_text) # takecare the position - new_ids_len += prev_seg_audio_ids.shape[1] - - prev_wav = None - eos_lab = False - stop = False - shift_len = 180 - voice_checker = VoiceChecker() - number_converter = NumberToTextConverter() - lang = None - gen_text_raw = gen_text - for t, r in enumerate(streamer): - t += 1 - txt = r["text"] - txt = txt.split("<|tts_eos|>")[0] - gen_text_raw += txt - if t == 1 and txt == "" and prev_seg_text_ids is not None: - logger.warning("New segment is empty, generation finished.") - return - if t <= 2: # do just one time, more token greater certainty - lang = number_converter.detect_language(gen_text_raw) - gen_text += number_converter.replace_numbers_with_text(txt, lang).replace("*", "") # markdown ** - - # TODO speed up - tts_text, tts_token_lens = self.prepare_tts_text(gen_text) - - if tts_token_lens >= self.tts.streaming_text_reserved_len - shift_len: - end_c = sentence_end(txt) - if end_c: - end_c_idx = gen_text.rfind(end_c) - assert end_c_idx != -1 - text_left = gen_text[end_c_idx + 1 :] - gen_text = gen_text[: end_c_idx + 1] - tts_text, tts_token_lens = self.prepare_tts_text(gen_text) - new_segment_gen = True - logger.debug( - f"tts_text tokens {tts_token_lens} exceed {self.tts.streaming_text_reserved_len - shift_len}, starting a new segment generation" - ) - break - - if tts_token_lens >= (chunk_idx + 1) * self.tts.streaming_text_chunk_size: - - # do prefill and generate - if chunk_idx == 0: - begin = 0 - end = (chunk_idx + 1) * self.tts.streaming_text_chunk_size + tts_start_token_len - else: - begin = chunk_idx * self.tts.streaming_text_chunk_size + tts_start_token_len - end = min( - (chunk_idx + 1) * self.tts.streaming_text_chunk_size + tts_start_token_len, condition_length - 1 - ) - - tts_input_ids = self.tts_processor.text_tokenizer( - tts_text, return_tensors="pt", add_special_tokens=False - )["input_ids"].cuda() - text_input_ids = tts_input_ids[:, begin:end] - streaming_tts_text_mask = self._build_streaming_mask(tts_token_lens).to(device=self.tts.device) - position_ids = torch.arange(begin, end, dtype=torch.long, device=self.tts.device).unsqueeze(0) - - past_key_values = self.tts.prefill_text( - input_ids=text_input_ids, - position_ids=position_ids, - past_key_values=past_key_values, - lm_spk_emb_last_hidden_states=spk_embeds if chunk_idx == 0 else None, - ) - outputs = self.tts.generate( - input_ids=audio_input_ids, - past_key_values=past_key_values, - streaming_tts_text_mask=streaming_tts_text_mask, - max_new_token=output_chunk_size, - force_no_stop=self.force_no_stop, - temperature=torch.tensor([0.1, 0.3, 0.1, 0.3], dtype=torch.float, device=self.tts.device), - eos_token=torch.tensor([625], dtype=torch.long, device=self.tts.device), - logits_warpers=logits_warpers, - logits_processors=logits_processors, - ) - audio_input_ids = ( - outputs.audio_input_ids 
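# Sketch of the segmentation rule applied above: once the tokenized TTS text approaches the
# reserved length (minus the safety margin `shift_len`), cut at the last sentence-ending
# character and carry the remainder into a new segment. `sentence_end` is approximated here
# with a fixed character set; the real helper lives in utils and is not shown in this hunk.
def split_for_new_segment(gen_text: str, last_piece: str):
    enders = set(".!?。！？")
    end_c = next((c for c in reversed(last_piece) if c in enders), None)
    if end_c is None:
        return gen_text, ""          # no safe cut point yet
    cut = gen_text.rfind(end_c)
    return gen_text[:cut + 1], gen_text[cut + 1:]

kept, text_left = split_for_new_segment("Hello world. This part continues", "d. This part continues")
# kept -> "Hello world."   text_left -> " This part continues"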
- ) # [1,seq_len,4] seq_len=tts.streaming_text_reserved_len + 3 + len(new_ids) - past_key_values = outputs.past_key_values - chunk_idx += 1 - - mel_spec = self.tts.decode_to_mel_specs(outputs.new_ids[:, max(new_ids_len - 4, 0) :, :]) - new_ids_len = outputs.new_ids.shape[1] # [1, seq_len, 4] - - wav_np, sr = self.decode_mel_to_audio(mel_spec) # [1,100,50] -> [50*256] - - if enable_regenerate: - if prev_wav is not None: - check_wav_np = wav_np[2048:].cpu().numpy() # 2*4*256(hop) - check_mel = mel_spec[0, :, 8:].cpu().numpy() # 2*4 - else: - check_wav_np = wav_np.cpu().numpy() - check_mel = mel_spec[0].cpu().numpy() - if enable_regenerate and voice_checker.is_bad(check_wav_np, check_mel, chunk_size=2560): - voice_checker.reset() - # regenerate - N = output_chunk_size if prev_wav is None else output_chunk_size * 2 - past_kv = [] - for i in range(len(past_key_values)): - past_kv.append( - ( - past_key_values[i][0][:, :, :-N, :], # .clone(), - past_key_values[i][1][:, :, :-N, :], # .clone(), - ) - ) - outputs = self.tts.generate( - input_ids=audio_input_ids[:, :-N, :], - past_key_values=past_kv, - streaming_tts_text_mask=streaming_tts_text_mask, - max_new_token=N, - force_no_stop=self.force_no_stop, - temperature=torch.tensor([0.1, 0.3, 0.1, 0.3], dtype=torch.float, device=self.tts.device), - eos_token=torch.tensor([625], dtype=torch.long, device=self.tts.device), - logits_warpers=logits_warpers, - logits_processors=logits_processors, - ) - audio_input_ids = outputs.audio_input_ids - past_key_values = outputs.past_key_values + spk_bound = inputs['spk_bounds'][0][-1] - new_ids_len -= N - mel_spec = self.tts.decode_to_mel_specs(outputs.new_ids[:, new_ids_len:, :]) - new_ids_len = outputs.new_ids.shape[1] # [1, seq_len, 4] - wav_np, sr = self.decode_mel_to_audio(mel_spec) + spk_embeds = last_hidden_states[spk_bound[0]: spk_bound[1]] + return spk_embeds - if prev_wav is not None: - wav_y = wav_np[: len(prev_wav)] - prev_wav = wav_np[len(prev_wav) :] - cur_text = gen_text_raw[prev_text_len:] - prev_text_len = len(gen_text_raw) - yield OmniOutput(text=cur_text, audio_wav=wav_y, sampling_rate=sr) + def _generate_mel_spec(self, inputs, outputs, text): + spk_embeds = self._get_last_spk_embeds(inputs, outputs) - else: - prev_wav = wav_np - else: - # smooth wav - if prev_wav is not None: - wav_np, prev_wav = self._linear_overlap_add2_wav( - [prev_wav, wav_np], overlap=512 * 4 - ) # tts_hop256*2 - cur_text = gen_text_raw[prev_text_len:] - prev_text_len = len(gen_text_raw) - yield OmniOutput(text=cur_text, audio_wav=wav_np, sampling_rate=sr) + gen_text = text.replace('<|tts_eos|>', '') + tts_text, tts_token_lens = self.prepare_tts_text(gen_text) + tts_inputs = self.tts_processor.text_tokenizer.encode(tts_text, add_special_tokens=False) + tts_input_ids = torch.Tensor(tts_inputs).unsqueeze(0).to("cuda", dtype=torch.long) + streaming_tts_text_mask = self._build_streaming_mask(tts_token_lens).to(device=self.tts.device) - else: - prev_wav = wav_np + logits_warpers, logits_processors = gen_logits( + num_code=626, top_P=self.tts.top_p, top_K=self.tts.top_k, repetition_penalty=self.tts.repetition_penalty + ) - if outputs.finished: - logger.debug("Generation finished.") - eos_lab = True - break + condition_length = 1 + self.tts.use_speaker_embedding * self.tts.num_spk_embs + self.tts.streaming_text_reserved_len + 1 + + dtype = self.tts.emb_text.weight.dtype + emb = torch.zeros(1, condition_length, self.tts.num_vq, dtype=dtype, device=self.tts.device) + past_key_values = [ + ( + torch.zeros(1, 
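# Sketch of the cache rollback used by the regeneration path above: drop the last N cached
# positions from every layer's (key, value) pair and the last N audio ids, then regenerate
# those N steps. The tensors here are dummies with the same layout [1, heads, seq, head_dim].
import torch

def roll_back_kv(past_key_values, n: int):
    return [(k[:, :, :-n, :], v[:, :, :-n, :]) for k, v in past_key_values]

past_key_values = [(torch.zeros(1, 12, 80, 64), torch.zeros(1, 12, 80, 64)) for _ in range(2)]
audio_input_ids = torch.zeros(1, 80, 4, dtype=torch.long)

N = 25
past_kv = roll_back_kv(past_key_values, N)          # seq dim shrinks from 80 to 55
audio_input_ids_trimmed = audio_input_ids[:, :-N, :]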
self.tts.config.num_attention_heads, condition_length - 1, + self.tts.config.hidden_size // self.tts.config.num_attention_heads, dtype=emb.dtype, + device=self.tts.device), + torch.zeros(1, self.tts.config.num_attention_heads, condition_length - 1, + self.tts.config.hidden_size // self.tts.config.num_attention_heads, dtype=emb.dtype, + device=self.tts.device) + ) + for _ in range(self.tts.config.num_hidden_layers) + ] - if not eos_lab and tts_text: - logger.debug("eos_lab False, Generation continue.") + audio_input_ids = torch.zeros(1, condition_length, self.tts.num_vq, dtype=torch.long, device=self.tts.device) + eos_lab = False + for chunk_idx in range(math.ceil(emb.shape[1] / self.streaming_text_chunk_size)): if chunk_idx == 0: - begin = 0 + begin = chunk_idx * self.streaming_text_chunk_size + 0 + end = (chunk_idx + 1) * self.streaming_text_chunk_size + 1 + self.tts.use_speaker_embedding * self.tts.num_spk_embs else: - begin = chunk_idx * self.tts.streaming_text_chunk_size + tts_start_token_len - end = tts_token_lens + tts_start_token_len + 1 # 1 for [Etts] - if end > begin: - tts_input_ids = self.tts_processor.text_tokenizer( - tts_text, return_tensors="pt", add_special_tokens=False - )["input_ids"].cuda() - text_input_ids = tts_input_ids[:, begin:end] - streaming_tts_text_mask = self._build_streaming_mask(tts_token_lens).to(device=self.tts.device) + begin = chunk_idx * self.streaming_text_chunk_size + 1 + self.tts.use_speaker_embedding * self.tts.num_spk_embs + end = min((chunk_idx + 1) * self.streaming_text_chunk_size + 1 + self.tts.use_speaker_embedding * self.tts.num_spk_embs, + condition_length - 1) + if end - begin < 1: + print(f"BKing has break by the end of {end} and begin of {begin}") + else: + text_input_ids = tts_input_ids[:, begin: end] position_ids = torch.arange(begin, end, dtype=torch.long, device=self.tts.device).unsqueeze(0) + # print("预填充块:", begin, end) + if begin == 0: + past_key_values = self.tts.prefill_text( + input_ids=text_input_ids, + position_ids=position_ids, + past_key_values=past_key_values, + lm_spk_emb_last_hidden_states=spk_embeds + ) + else: + past_key_values = self.tts.prefill_text( + input_ids=text_input_ids, + position_ids=position_ids, + past_key_values=past_key_values + ) - past_key_values = self.tts.prefill_text( - input_ids=text_input_ids, - position_ids=position_ids, - past_key_values=past_key_values, - lm_spk_emb_last_hidden_states=spk_embeds if chunk_idx == 0 else None, - ) + outputs = self.tts.generate( + input_ids=audio_input_ids, + past_key_values=past_key_values, + streaming_tts_text_mask=streaming_tts_text_mask, + max_new_token=25, + force_no_stop=self.force_no_stop, + temperature=torch.tensor([0.1, 0.3, 0.1, 0.3], dtype=torch.float, device=self.tts.device), + eos_token=torch.tensor([625], dtype=torch.long, device=self.tts.device), + logits_warpers=logits_warpers, + logits_processors=logits_processors, + ) + audio_input_ids = outputs.audio_input_ids + past_key_values = outputs.past_key_values + + if outputs.finished: + print("Generation finished.") + eos_lab = True + break + + if not eos_lab: + print("Generation not finished.") while True: - # temp = [0.1, 0.3, 0.1, 0.3] if chunk_idx < 21 else [0.1] * self.tts.num_vq outputs = self.tts.generate( input_ids=audio_input_ids, past_key_values=past_key_values, streaming_tts_text_mask=streaming_tts_text_mask, - max_new_token=output_chunk_size, + max_new_token=25, force_no_stop=self.force_no_stop, - # temperature=torch.tensor([0.1] * self.tts.num_vq, dtype=torch.float, device=self.tts.device), 
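# Hedged sketch of what `_build_streaming_mask(tts_token_lens)` plausibly produces (its body is
# not shown in this hunk): a 1/0 mask over [Stts] + speaker-embedding slots + the reserved text
# region + [Ptts], where only the already-prefilled text positions and the special tokens are 1.
# The helper name, argument meaning, and sizes below are assumptions for illustration.
import torch

def build_streaming_mask(tts_token_lens: int, reserved_len: int = 300, num_spk_embs: int = 1):
    mask = torch.zeros(1 + num_spk_embs + reserved_len + 1, dtype=torch.int8)
    mask[: 1 + num_spk_embs + tts_token_lens] = 1   # [Stts], spk embs, valid text so far
    mask[-1] = 1                                    # [Ptts] / audio BOS stays visible
    return mask

m = build_streaming_mask(tts_token_lens=37)
# Positions left at 0 are padding and end up masked to -inf in the 4D attention mask.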
temperature=torch.tensor([0.1, 0.3, 0.1, 0.3], dtype=torch.float, device=self.tts.device), eos_token=torch.tensor([625], dtype=torch.long, device=self.tts.device), logits_warpers=logits_warpers, logits_processors=logits_processors, ) + audio_input_ids = outputs.audio_input_ids past_key_values = outputs.past_key_values - chunk_idx += 1 - - mel_spec = self.tts.decode_to_mel_specs(outputs.new_ids[:, max(new_ids_len - 4, 0) :, :]) - new_ids_len = outputs.new_ids.shape[1] # [1, seq_len, 4] - - wav_np, sr = self.decode_mel_to_audio(mel_spec) - - if enable_regenerate: - if prev_wav is not None: - check_wav_np = wav_np[2048:].cpu().numpy() # 2*4*256(hop) - check_mel = mel_spec[0, :, 8:].cpu().numpy() # 2*4 - else: - check_wav_np = wav_np.cpu().numpy() - check_mel = mel_spec[0].cpu().numpy() - if enable_regenerate and voice_checker.is_bad(check_wav_np, check_mel, chunk_size=2560): - voice_checker.reset() - # regenerate - N = output_chunk_size if prev_wav is None else output_chunk_size * 2 - past_kv = [] - for i in range(len(past_key_values)): - past_kv.append( - ( - past_key_values[i][0][:, :, :-N, :], # .clone(), - past_key_values[i][1][:, :, :-N, :], # .clone(), - ) - ) - outputs = self.tts.generate( - input_ids=audio_input_ids[:, :-N, :], - past_key_values=past_kv, - streaming_tts_text_mask=streaming_tts_text_mask, - max_new_token=N, - force_no_stop=self.force_no_stop, - temperature=torch.tensor([0.1, 0.3, 0.1, 0.3], dtype=torch.float, device=self.tts.device), - eos_token=torch.tensor([625], dtype=torch.long, device=self.tts.device), - logits_warpers=logits_warpers, - logits_processors=logits_processors, - ) - audio_input_ids = outputs.audio_input_ids - past_key_values = outputs.past_key_values - - new_ids_len -= N - mel_spec = self.tts.decode_to_mel_specs(outputs.new_ids[:, new_ids_len:, :]) - new_ids_len = outputs.new_ids.shape[1] # [1, seq_len, 4] - wav_np, sr = self.decode_mel_to_audio(mel_spec) - - if prev_wav is not None: - wav_y = wav_np[: len(prev_wav)] - prev_wav = wav_np[len(prev_wav) :] - cur_text = gen_text_raw[prev_text_len:] - prev_text_len = len(gen_text_raw) - yield OmniOutput(text=cur_text, audio_wav=wav_y, sampling_rate=sr) - else: - prev_wav = wav_np - else: - # smooth wav - if prev_wav is not None: - wav_np, prev_wav = self._linear_overlap_add2_wav( - [prev_wav, wav_np], overlap=512 * 4 - ) # tts_hop256*2 - cur_text = gen_text_raw[prev_text_len:] - prev_text_len = len(gen_text_raw) - yield OmniOutput(text=cur_text, audio_wav=wav_np, sampling_rate=sr) - else: - prev_wav = wav_np if outputs.finished: - logger.debug("Generation finished.") + print("Generation finished.") break if outputs.new_ids.shape[1] > 2048: - stop = True - logger.debug("Generation length > 2048, stopped.") + print("Generation not finished but break.") break - if prev_wav is not None: - cur_text = gen_text_raw[prev_text_len:] - yield OmniOutput(text=cur_text, audio_wav=prev_wav, sampling_rate=sr) # yield last chunk wav without smooth + mel_spec = self.tts.decode_to_mel_specs(outputs.new_ids) + print("Mel spectrogram generated.") + return mel_spec - if new_segment_gen and not stop: - logger.debug( - f"tts_text tokens {tts_token_lens} exceed {self.tts.streaming_text_reserved_len - shift_len}, start a new segment generation" - ) - tid_len = 5 # self.tts.streaming_text_chunk_size - prev_seg_text_ids = tts_input_ids[:, end - 1 - tid_len : end - 1] # exclude last Etts - aid_len = 50 # int(tid_len * new_ids_len / tts_token_lens) - prev_seg_audio_ids = outputs.new_ids[:, -aid_len:, :] - - result = 
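# Sketch of the hand-off into the next streaming segment, as set up above: keep the last few
# text ids (excluding the final [Etts]) and roughly the last second of audio ids so the new
# segment can be prefilled with them and the waveform stays continuous. Shapes are illustrative.
import torch

tts_input_ids = torch.randint(0, 600, (1, 320))        # [1, text_len]
new_ids = torch.randint(0, 626, (1, 400, 4))            # [1, audio_len, num_vq]
end = 310                                                # position just past the prefilled text

tid_len, aid_len = 5, 50
prev_seg_text_ids = tts_input_ids[:, end - 1 - tid_len : end - 1]   # last 5 text ids, no [Etts]
prev_seg_audio_ids = new_ids[:, -aid_len:, :]                        # last 50 audio steps
# Both are passed to the recursive streaming call together with the leftover text.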
self._generate_mel_spec_audio_streaming( - spk_bounds, - streamer, - output_chunk_size, - spk_embeds, - prev_seg_text_ids, - text_left, - prev_seg_audio_ids, - enable_regenerate=enable_regenerate, - ) - for res in result: - yield res + def decode_mel_to_audio(self, mel_spec, output_path="test.wav"): + if self.vocos is None: + self.vocos = self.initialize_vocos() - def decode_mel_to_audio(self, mel_spec, output_path=""): with torch.inference_mode(): - wav_numpy = self.vocos.decode(mel_spec.float()).cpu().squeeze() - sr = 24000 - if output_path: - sf.write(output_path, wav_numpy.numpy(), samplerate=sr) - logger.info(f"Audio saved to {output_path}") - return wav_numpy, sr + wav_numpy = self.vocos.decode(mel_spec.float()).cpu().numpy().squeeze() + sf.write(output_path, wav_numpy, samplerate=24000) + print(f"Audio saved to {output_path}.") -# Copied from transformers.models.whisper.modeling_whisper.WhisperEncoderLayer and add use_cache for streaming inference class MiniCPMWhisperEncoderLayer(nn.Module): def __init__(self, config: WhisperConfig, layer_idx: int = None): super().__init__() @@ -1870,7 +865,7 @@ class MiniCPMWhisperEncoderLayer(nn.Module): num_heads=config.encoder_attention_heads, dropout=config.attention_dropout, config=config, - layer_idx=layer_idx, + layer_idx=layer_idx ) self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.dropout = config.dropout @@ -1881,32 +876,14 @@ class MiniCPMWhisperEncoderLayer(nn.Module): self.final_layer_norm = nn.LayerNorm(self.embed_dim) def forward( - self, - hidden_states: torch.Tensor, - attention_mask: torch.Tensor, - layer_head_mask: torch.Tensor, - output_attentions: bool = False, - past_key_values: Optional[EncoderDecoderCache] = None, - use_cache: Optional[bool] = False, + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + layer_head_mask: torch.Tensor, + output_attentions: bool = False, + past_key_values: Optional[EncoderDecoderCache] = None, + use_cache: Optional[bool] = False, ) -> torch.Tensor: - r""" - Args: - hidden_states (`torch.FloatTensor` of shape `(batch_size, seq_len, embed_dim)`): - Hidden states to be fed into the encoder layer. - attention_mask (`torch.FloatTensor` of shape `(batch_size, 1, tgt_len, src_len)`): - Attention mask where padding elements are indicated by large negative values. - layer_head_mask (`torch.FloatTensor` of shape `(encoder_attention_heads,)`): - Mask to nullify selected heads of the attention modules. - output_attentions (`bool`, *optional*): - Whether or not to return the attention weights. - past_key_values (`EncoderDecoderCache`, *optional*): - Past key-value pairs used for incremental decoding. - use_cache (`bool`, *optional*): - Whether or not to return updated `past_key_values` for caching. - - Returns: - A tuple of shape `(hidden_states, optional(attn_weights), optional(past_key_values))`. 
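# Sketch of the mel-to-waveform step performed in decode_mel_to_audio, assuming a generic
# pretrained 24 kHz mel Vocos vocoder; the model builds its own Vocos instance from its
# checkpoint, so the repo id below is only an illustrative stand-in.
import torch
import soundfile as sf
from vocos import Vocos

vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
mel_spec = torch.randn(1, 100, 200)                     # [batch, n_mels, frames]
with torch.inference_mode():
    wav = vocos.decode(mel_spec.float()).cpu().squeeze()
sf.write("chunk.wav", wav.numpy(), samplerate=24000)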
- """ residual = hidden_states hidden_states = self.self_attn_layer_norm(hidden_states) hidden_states, attn_weights, past_key_values = self.self_attn( @@ -1914,7 +891,7 @@ class MiniCPMWhisperEncoderLayer(nn.Module): attention_mask=attention_mask, layer_head_mask=layer_head_mask, output_attentions=output_attentions, - past_key_value=past_key_values, + past_key_value=past_key_values ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states @@ -1928,7 +905,7 @@ class MiniCPMWhisperEncoderLayer(nn.Module): hidden_states = residual + hidden_states if hidden_states.dtype == torch.float16 and ( - torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any() + torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any() ): clamp_value = torch.finfo(hidden_states.dtype).max - 1000 hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) @@ -1943,128 +920,25 @@ class MiniCPMWhisperEncoderLayer(nn.Module): return outputs - -# Copied from from transformers.models.whisper.modeling_whisper.WhisperEncoder and add use_cache for streaming inference class MiniCPMWhisperEncoder(WhisperEncoder): def __init__(self, config: WhisperConfig): super().__init__(config) - self.layers = nn.ModuleList( - [MiniCPMWhisperEncoderLayer(config, layer_idx=i) for i in range(config.encoder_layers)] - ) + self.layers = nn.ModuleList([ + MiniCPMWhisperEncoderLayer(config, layer_idx=i) for i in range(config.encoder_layers) + ]) def forward( - self, - input_features, - attention_mask=None, - head_mask=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - past_key_values: Optional[EncoderDecoderCache] = None, - use_cache: Optional[bool] = None, + self, + input_features, + attention_mask=None, + head_mask=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + past_key_values: Optional[EncoderDecoderCache] = None, + use_cache: Optional[bool] = None, ): - r""" - Forward pass of the Whisper encoder. - - Args: - input_features (`torch.FloatTensor` of shape `(batch_size, feature_size, sequence_length)`): - Float values of log-mel features extracted from the raw audio waveform. Typically generated - by a feature extractor (e.g., `WhisperFeatureExtractor`) that processes `.flac` or `.wav` - files into padded 2D mel spectrogram frames. These features are projected via convolution layers - (`conv1` and `conv2`) and then transformed into embeddings for the encoder. - - attention_mask (`torch.Tensor`, *optional*): - Not used by Whisper for masking `input_features`, but included for API compatibility with - other models. If provided, it is simply ignored within the model. By default, Whisper - effectively ignores silence in the input log-mel spectrogram. - - head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): - Mask to nullify selected attention heads. The elements should be either 1 or 0, where: - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked** (i.e., the attention head is dropped). - - output_attentions (`bool`, *optional*): - Whether or not to return the attention tensors of all encoder layers. If set to `True`, the - returned tuple (or `BaseModelOutputWithPast`) will contain an additional element with - attention weights for each encoder layer. - - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. 
If set to `True`, the returned - tuple (or `BaseModelOutputWithPast`) will contain a tuple of hidden states, including the - initial embedding output as well as the outputs of each layer. - - return_dict (`bool`, *optional*): - Whether or not to return a `BaseModelOutputWithPast` (a subclass of `ModelOutput`) instead - of a plain tuple. If set to `True`, the output will be a `BaseModelOutputWithPast` object, - otherwise it will be a tuple. - - past_key_values (`EncoderDecoderCache`, *optional*): - When using caching for faster inference, this is an object that stores the key-value pairs - for attention states. If provided, the model will append new states to the existing cache - and return the updated cache. This speeds up sequential decoding or chunked inference. - - - If `past_key_values` is `None`, no past states are used or returned. - - If `past_key_values` is not `None` and `use_cache=True`, the model will use the provided - cache and return the updated cache (as `next_encoder_cache`). - - use_cache (`bool`, *optional*): - Whether or not the model should use caching (`past_key_values`) to speed up processing - during inference. When set to `True`, the model will: - - Inspect and use `past_key_values` if provided. - - Return updated `past_key_values` (under the name `next_encoder_cache` in - `BaseModelOutputWithPast`). - - Returns: - `BaseModelOutputWithPast` or `tuple` (depending on `return_dict`): - If `return_dict=True`, a `BaseModelOutputWithPast` is returned, which contains: - - **last_hidden_state** (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): - The output of the final encoder layer. - - **hidden_states** (`tuple(torch.FloatTensor)`, *optional*, returned if `output_hidden_states=True`): - Hidden states of the model at each layer (including the initial projection). - - **attentions** (`tuple(torch.FloatTensor)`, *optional*, returned if `output_attentions=True`): - Attention weights from each encoder layer. - - **past_key_values** (an object of type `EncoderDecoderCache` or `None`, *optional*): - Updated cache of key-value pairs if `use_cache=True`. - - If `return_dict=False`, a tuple is returned, where the format is: - `(last_hidden_state, hidden_states, attentions)`, with `hidden_states` and `attentions` - only present if their respective `output_*` arguments are set to `True`. - - Example: - >>> from transformers import AutoFeatureExtractor, WhisperConfig, WhisperForConditionalGeneration - >>> import torch - - >>> # Load a feature extractor and a Whisper model - >>> feature_extractor = AutoFeatureExtractor.from_pretrained("openai/whisper-tiny.en") - >>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en") - - >>> # Assume you have audio (list of floats or numpy array) loaded from a file - >>> # Then extract the mel features: - >>> input_features = feature_extractor(audio, sampling_rate=16000, return_tensors="pt").input_features - - >>> # Forward pass - >>> outputs = model.encoder( - ... input_features=input_features, - ... output_hidden_states=True, - ... output_attentions=True, - ... use_cache=True - ... 
) - - >>> # Retrieve the last hidden state - >>> last_hidden_state = outputs.last_hidden_state - >>> print(last_hidden_state.shape) - torch.Size([batch_size, seq_length, hidden_size]) - - >>> # Retrieve the intermediate hidden states if output_hidden_states=True - >>> all_encoder_hidden_states = outputs.hidden_states - - >>> # Retrieve attention weights if output_attentions=True - >>> all_encoder_attentions = outputs.attentions - - >>> # Retrieve updated past key values if use_cache=True - >>> encoder_cache = outputs.past_key_values - """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states @@ -2085,29 +959,30 @@ class MiniCPMWhisperEncoder(WhisperEncoder): if past_key_values is None: past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache()) elif isinstance(past_key_values, list): - past_key_values = EncoderDecoderCache(DynamicCache.from_legacy_cache(past_key_values), DynamicCache()) + past_key_values = EncoderDecoderCache( + DynamicCache.from_legacy_cache(past_key_values), DynamicCache()) elif isinstance(past_key_values, DynamicCache): past_key_values = EncoderDecoderCache(past_key_values, DynamicCache()) else: pass past_key_values_length = past_key_values.self_attention_cache.get_usable_length(inputs_embeds.shape[1]) if inputs_embeds.shape[1] + past_key_values_length > embed_pos.shape[0]: - logger.warning("seems the audio is longer than 30s. repeating the last part of the audio") + if not padding_logged: + padding_logged = True + logger.warning("seems the audio is longer than 30s. repeating the last part of the audio") embed_pos_front = embed_pos[past_key_values_length:, :] - embed_pos = torch.cat( - ( - embed_pos_front, - torch.repeat_interleave( - embed_pos[-1, :].unsqueeze(0), - inputs_embeds.shape[1] - embed_pos.shape[0] + past_key_values_length, - dim=0, - ), + embed_pos = torch.cat(( + embed_pos_front, + torch.repeat_interleave( + embed_pos[-1, :].unsqueeze(0), + inputs_embeds.shape[1] - embed_pos.shape[0] + past_key_values_length, + dim=0 ) - ) + )) else: - embed_pos = embed_pos[past_key_values_length : inputs_embeds.shape[1] + past_key_values_length, :] + embed_pos = embed_pos[past_key_values_length:inputs_embeds.shape[1] + past_key_values_length, :] else: - embed_pos = embed_pos[: inputs_embeds.shape[1], :] + embed_pos = embed_pos[:inputs_embeds.shape[1], :] hidden_states = inputs_embeds + embed_pos hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) @@ -2143,7 +1018,7 @@ class MiniCPMWhisperEncoder(WhisperEncoder): (head_mask[idx] if head_mask is not None else None), output_attentions, past_key_values, - use_cache, + use_cache ) else: layer_outputs = encoder_layer( @@ -2152,7 +1027,7 @@ class MiniCPMWhisperEncoder(WhisperEncoder): layer_head_mask=(head_mask[idx] if head_mask is not None else None), output_attentions=output_attentions, past_key_values=past_key_values, - use_cache=use_cache, + use_cache=use_cache ) hidden_states = layer_outputs[0] @@ -2175,19 +1050,18 @@ class MiniCPMWhisperEncoder(WhisperEncoder): last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions, - past_key_values=next_encoder_cache, + past_key_values=next_encoder_cache ) - -# Borrowed from `https://github.com/2noise/ChatTTS/blob/main/ChatTTS/model/dvae.py` +# dvae module class ConvNeXtBlock(nn.Module): def __init__( - self, - dim: int, - 
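# Sketch of the position-embedding slicing used above for chunked/streaming audio encoding:
# each new chunk reuses the positions after the cached prefix, and once the 30 s table is
# exhausted the last row is repeated. Sizes are illustrative (Whisper reserves 1500 positions).
import torch

embed_pos = torch.randn(1500, 1280)          # full positional table
past_len, chunk_len = 1400, 200              # cached frames, new frames

if chunk_len + past_len > embed_pos.shape[0]:
    front = embed_pos[past_len:, :]                                   # remaining 100 rows
    tail = embed_pos[-1, :].unsqueeze(0).repeat_interleave(
        chunk_len - embed_pos.shape[0] + past_len, dim=0)             # repeat last row 100x
    chunk_pos = torch.cat((front, tail), dim=0)
else:
    chunk_pos = embed_pos[past_len : past_len + chunk_len, :]
# chunk_pos.shape == (chunk_len, 1280) either way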
intermediate_dim: int, - kernel: int, - dilation: int, - layer_scale_init_value: float = 1e-6, + self, + dim: int, + intermediate_dim: int, + kernel: int, + dilation: int, + layer_scale_init_value: float = 1e-6, ): # ConvNeXt Block copied from Vocos. super().__init__() @@ -2201,7 +1075,9 @@ class ConvNeXtBlock(nn.Module): ) self.norm = nn.LayerNorm(dim, eps=1e-6) - self.pwconv1 = nn.Linear(dim, intermediate_dim) + self.pwconv1 = nn.Linear( + dim, intermediate_dim + ) self.act = nn.GELU() self.pwconv2 = nn.Linear(intermediate_dim, dim) self.coef = ( @@ -2233,16 +1109,15 @@ class ConvNeXtBlock(nn.Module): return x -# Borrowed from `https://github.com/2noise/ChatTTS/blob/main/ChatTTS/model/dvae.py` class GFSQ(nn.Module): def __init__( - self, - dim: int, - levels: List[int], - G: int, - R: int, - eps=1e-5, - transpose=True, + self, + dim: int, + levels: List[int], + G: int, + R: int, + eps=1e-5, + transpose=True, ): super(GFSQ, self).__init__() self.quantizer = GroupedResidualFSQ( @@ -2276,18 +1151,17 @@ class GFSQ(nn.Module): return ind.transpose_(1, 2) if self.transpose else ind -# Borrowed from `https://github.com/2noise/ChatTTS/blob/main/ChatTTS/model/dvae.py` class DVAEDecoder(nn.Module): def __init__( - self, - idim: int, - odim: int, - n_layer=12, - bn_dim=64, - hidden=256, - kernel=7, - dilation=2, - up=False, + self, + idim: int, + odim: int, + n_layer=12, + bn_dim=64, + hidden=256, + kernel=7, + dilation=2, + up=False, ): super().__init__() self.up = up @@ -2321,10 +1195,9 @@ class DVAEDecoder(nn.Module): return x -# Borrowed from `https://github.com/2noise/ChatTTS/blob/main/ChatTTS/model/dvae.py` class DVAE(nn.Module): def __init__( - self, + self, ): super().__init__() @@ -2364,7 +1237,11 @@ class DVAE(nn.Module): ) @torch.inference_mode() - def forward(self, inp: torch.Tensor, mode: Literal["encode", "decode"] = "decode") -> torch.Tensor: + def forward( + self, + inp: torch.Tensor, + mode: Literal["encode", "decode"] = "decode" + ) -> torch.Tensor: if mode == "encode" and hasattr(self, "encoder") and self.vq_layer is not None: mel = inp.clone() x: torch.Tensor = self.downsample_conv( @@ -2400,15 +1277,16 @@ class DVAE(nn.Module): return torch.mul(dec_out, self.coef, out=dec_out) +# tts module def apply_spk_emb( - input_ids: torch.Tensor = None, - spk_emb: torch.Tensor = None, - input_embeds: torch.Tensor = None, - spk_emb_token_id: int = 0, - num_spk_embs: int = 1, + input_ids: torch.Tensor = None, + spk_emb: torch.Tensor = None, + input_embeds: torch.Tensor = None, + spk_emb_token_id: int = 0, + num_spk_embs: int = 1, ): """ - Replace consecutive `num_spk_embs` speaker embedding placeholders in input_embeds with pre-prepared speaker embeddings. This is an in-place replacement, no new tensor is created, so no value is returned. + Replace consecutive speaker embedding placeholders in input_embeds with pre-prepared speaker embeddings. This is an in-place replacement, no new tensor is created, so no value is returned. 
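# Simplified re-statement of the ConvNeXt-style block pattern used by the DVAE decoder above
# (depthwise conv -> LayerNorm -> pointwise MLP -> learnable layer scale -> residual); a sketch
# for illustration, not the exact module or its trained weights.
import torch
import torch.nn as nn

class TinyConvNeXtBlock(nn.Module):
    def __init__(self, dim=256, intermediate_dim=1024, kernel=7, dilation=1):
        super().__init__()
        pad = dilation * (kernel // 2)
        self.dwconv = nn.Conv1d(dim, dim, kernel, padding=pad, dilation=dilation, groups=dim)
        self.norm = nn.LayerNorm(dim, eps=1e-6)
        self.pwconv1 = nn.Linear(dim, intermediate_dim)
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(intermediate_dim, dim)
        self.coef = nn.Parameter(1e-6 * torch.ones(dim))

    def forward(self, x):                      # x: [batch, dim, time]
        residual = x
        x = self.dwconv(x).transpose(1, 2)     # -> [batch, time, dim]
        x = self.pwconv2(self.act(self.pwconv1(self.norm(x))))
        x = (self.coef * x).transpose(1, 2)    # layer scale, back to [batch, dim, time]
        return residual + x

y = TinyConvNeXtBlock()(torch.randn(2, 120, 256).transpose(1, 2))   # shape preserved: [2, 256, 120]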
Args: input_ids (torch.Tensor): Input ID tensor, shape [batch_size, seq_len_max] @@ -2431,23 +1309,143 @@ def apply_spk_emb( assert nonzero_position_idx.shape[0] == num_spk_embs begin_idx = nonzero_position_idx.min() end_idx = nonzero_position_idx.max() - input_embeds[idx, begin_idx : end_idx + 1, :] = spk_emb_ + input_embeds[idx, begin_idx: end_idx + 1, :] = spk_emb_ return +def make_streaming_chunk_mask( + input_embeds: torch.Tensor, + tts_text_scopes: List[List[int]], + tts_audio_scopes: List[List[int]], + tts_text_masks: List[torch.Tensor], + min_chunk_num_token: int = 5, + max_chunk_num_token: int = 7, + streaming_audio_chunk_size: int = 50, +): + """ + Create a look-ahead chunked attention mask that allows the TTS transformer to see only the first M tokens when generating each N to N+1 seconds of audio, enabling streaming TTS. + + Args: + input_embeds (torch.Tensor): Input embeddings combining text and audio, shape [batch_size, seq_len, hidden_dim] + tts_text_scopes (List[List[int]]): Range of text tokens for each sample + tts_audio_scopes (List[List[int]]): Range of audio tokens for each sample + tts_text_masks (List[torch.Tensor]): Text masks for each sample + min_chunk_num_token (int): Minimum number of new text tokens the model can see per audio chunk + max_chunk_num_token (int): Maximum number of new text tokens the model can see per audio chunk + streaming_audio_chunk_size (int): Size of audio chunk, 50 corresponds to approximately 1 second of audio + + Returns: + torch.Tensor: 4D causal mask with shape [batch_size, 1, seq_len, seq_len] + + Example: + Input sequence: + [t1, t2, t3, t4, t5, [Ptts], a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, ...] + Output 4D causal mask: + ------- text positions ------- + [0] <- here is [Stts] + [0, 0] <- here is [spk_emb] * N + [0, 0, 0] + [0, 0, 0, 0] + [0, 0, 0, 0, 0] + ------- audio positions -------- + [0, 0, -inf, -inf, -inf, 0] <- here is [Ptts], [Ptts]'s last hidden state should predict the first audio token + v- here is [Ptts] + [0, 0, -inf, -inf, -inf, 0, 0] + [0, 0, -inf, -inf, -inf, 0, 0, 0] + [0, 0, -inf, -inf, -inf, 0, 0, 0, 0] + [0, 0, -inf, -inf, -inf, 0, 0, 0, 0, 0] + [0, 0, -inf, -inf, -inf, 0, 0, 0, 0, 0, 0] # end of first 1s audio chunk + [0, 0, 0 , -inf, -inf, 0, 0, 0, 0, 0, 0, 0] + [0, 0, 0 , -inf, -inf, 0, 0, 0, 0, 0, 0, 0, 0] + [0, 0, 0 , -inf, -inf, 0, 0, 0, 0, 0, 0, 0, 0, 0] + [0, 0, 0 , -inf, -inf, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] + [0, 0, 0 , -inf, -inf, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] + """ + + # Create a complete attention mask for input embeds [batch_size, seq_len], without considering audio mask as audio is always at the end + batch_size = input_embeds.shape[0] + input_embeds_attention_mask = torch.ones(input_embeds.shape[0], input_embeds.shape[1], dtype=torch.int8, + device=input_embeds.device) + + for idx in range(batch_size): + input_embeds_attention_mask[idx, tts_text_scopes[idx][0]: tts_text_scopes[idx][1]] = tts_text_masks[idx] + + # Initialize a standard upper triangular causal mask + dtype = input_embeds.dtype + device = input_embeds.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_embeds.shape[1] + causal_mask = torch.full((sequence_length, sequence_length), fill_value=min_dtype, dtype=dtype, device=device) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + else: + raise ValueError("sequence_length of tts could not be 1.") + causal_mask = causal_mask.unsqueeze(0).repeat(input_embeds.shape[0], 1, 1) + + # For each data sample + for idx in 
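# Sketch of the in-place speaker-embedding splice performed by apply_spk_emb above: the rows of
# input_embeds whose token id equals the placeholder id are overwritten with the projected
# speaker embedding. Token ids and embedding values below are dummies.
import torch

spk_emb_token_id, num_spk_embs, hidden = 7, 2, 16
input_ids = torch.tensor([[1, 7, 7, 5, 9]])                 # two consecutive placeholders
input_embeds = torch.zeros(1, 5, hidden)
spk_emb = torch.randn(1, num_spk_embs, hidden)

for idx in range(input_ids.shape[0]):
    pos = (input_ids[idx] == spk_emb_token_id).nonzero(as_tuple=False).squeeze(-1)
    assert pos.numel() == num_spk_embs
    input_embeds[idx, int(pos.min()) : int(pos.max()) + 1, :] = spk_emb[idx]
# Rows 1..2 of input_embeds now hold the speaker embedding; every other row is untouched.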
range(input_embeds.shape[0]): + tts_audio_scope = tts_audio_scopes[idx] + tts_text_scope = tts_text_scopes[idx] + + audio_token_start = tts_audio_scope[0] + audio_duration = tts_audio_scope[1] - tts_audio_scope[0] + + # Record which text chunk the current audio chunk can see up to + text_pivot = 0 + num_valid_text_tokens = torch.sum(tts_text_masks[idx]).item() - 1 # [Ptts] excluded + # How many audio chunks are in total, the num of buckets should be smaller as possible + num_buckets = max(1, math.floor(audio_duration / streaming_audio_chunk_size)) + # print("num_buckets", num_buckets) + + num_text_tokens_per_audio_chunk = math.ceil( + num_valid_text_tokens / num_buckets) # 这里 10 是超参数 比如每个audio chunk最多说10个文本token,再多就不正常了。 + if num_text_tokens_per_audio_chunk > 10: + num_text_tokens_per_audio_chunk = 10 + elif num_text_tokens_per_audio_chunk < 4: + num_text_tokens_per_audio_chunk = 4 + else: + pass + + # print("num_text_tokens_per_audio_chunk", num_text_tokens_per_audio_chunk) + + # For each chunk of audio + for chunk_idx in range(math.ceil(audio_duration / streaming_audio_chunk_size)): + audio_chunk_start = audio_token_start + chunk_idx * streaming_audio_chunk_size + audio_chunk_end = audio_token_start + (chunk_idx + 1) * streaming_audio_chunk_size + # New text seen by this new audio chunk + new_text_this_chunk = num_text_tokens_per_audio_chunk + # The right bound of visible text tokens + text_pivot = min(new_text_this_chunk + text_pivot, num_valid_text_tokens) + # Mask all text chunks after the visible ones + # -> [text_pivot, len(tts_text_scope)-1] excluding [Ptts] + causal_mask[ + idx, + audio_chunk_start - 1: audio_chunk_end - 1, + tts_text_scope[0] + text_pivot: tts_text_scope[1] - 1 + ] = min_dtype + + # Mask the padding parts in tts_text_masks (no position will attend to it) + causal_mask[idx, :, input_embeds_attention_mask[idx] == 0] = min_dtype + + # Add extra dimensions, [batch_size, seq_len, seq_len] -> [batch_size, 1, seq_len, seq_len] + causal_mask = causal_mask.unsqueeze(1) + + return causal_mask + + def make_streaming_chunk_mask_generation( - inputs_embeds: torch.Tensor, - past_seen_tokens: int, - streaming_tts_text_mask: torch.Tensor, - streaming_reserved_length: int = 300, - streaming_audio_chunk_size: int = 50, - streaming_text_chunk_size: int = 10, - num_spk_emb: int = 1, - use_spk_emb: bool = True, + inputs_embeds: torch.Tensor, + past_seen_tokens: int, + streaming_tts_text_mask: torch.Tensor, + streaming_reserved_length: int = 300, + streaming_audio_chunk_size: int = 50, + streaming_text_chunk_size: int = 10, + num_spk_emb: int = 1, + use_spk_emb: bool = True, ) -> torch.Tensor: """ - In streaming audio generation, determine which `text` positions the TTS model can attend to when generating each chunk of `audio` tokens. + Determine which `text` tokens the model can attend to when generating each chunk of `audio` tokens. This function creates a mask that allows the model to attend to a specific chunk of text tokens when generating each chunk of audio tokens, enabling streaming TTS generation. 
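# Sketch of the look-ahead schedule computed above: how many text tokens each ~1 s audio chunk
# is allowed to see (clamped to [4, 10] per chunk) and where the visibility pivot lands per
# chunk. The numbers are illustrative.
import math

def text_pivots(num_valid_text_tokens: int, audio_duration: int, audio_chunk: int = 50):
    num_buckets = max(1, math.floor(audio_duration / audio_chunk))
    per_chunk = min(10, max(4, math.ceil(num_valid_text_tokens / num_buckets)))
    pivots, pivot = [], 0
    for _ in range(math.ceil(audio_duration / audio_chunk)):
        pivot = min(pivot + per_chunk, num_valid_text_tokens)
        pivots.append(pivot)      # text positions beyond `pivot` are masked for this chunk
    return pivots

print(text_pivots(num_valid_text_tokens=42, audio_duration=260))
# -> [9, 18, 27, 36, 42, 42]: later audio chunks gradually unlock the remaining text.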
@@ -2473,30 +1471,24 @@ def make_streaming_chunk_mask_generation( min_dtype = torch.finfo(dtype).min # Add `1` to the past seen tokens to account for new `tokens` during `generate` - causal_mask = torch.full((1, past_seen_tokens + inputs_embeds.shape[1]), fill_value=0, dtype=dtype, device=device) + causal_mask = torch.full((1, past_seen_tokens + 1), fill_value=0, dtype=dtype, device=device) # Calculate the start of invisible text tokens - invisible_text_tokens_start = ( - min( - math.ceil((past_seen_tokens - streaming_reserved_length) / streaming_audio_chunk_size) - * streaming_text_chunk_size, - streaming_reserved_length, - ) - + 1 - + num_spk_emb * use_spk_emb - ) # Add 1 for [Stts] and N for [spk_emb] tokens if `use_spk_emb` is True + invisible_text_tokens_start = min( + math.ceil( + (past_seen_tokens - streaming_reserved_length) / streaming_audio_chunk_size + ) * streaming_text_chunk_size, + streaming_reserved_length + ) + 1 + num_spk_emb * use_spk_emb # Add 1 for [Stts] and N for [spk_emb] tokens if `use_spk_emb` is True - invisible_text_tokens_end = ( - streaming_reserved_length + 1 + num_spk_emb * use_spk_emb + 1 - ) # Add 1 for [Ptts] (aka `audio_bos_token_id`) + invisible_text_tokens_end = streaming_reserved_length + 1 + num_spk_emb * use_spk_emb + 1 # Add 1 for [Ptts] (aka `audio_bos_token_id`) # Set invisible text tokens to min_dtype (effectively -inf) - causal_mask[0, invisible_text_tokens_start:invisible_text_tokens_end] = min_dtype + causal_mask[0, invisible_text_tokens_start: invisible_text_tokens_end] = min_dtype # Mask padding positions in the text mask - causal_mask[0, 0 : 1 + num_spk_emb * use_spk_emb + streaming_reserved_length + 1].masked_fill_( - streaming_tts_text_mask == 0, min_dtype - ) + causal_mask[0, 0: 1 + num_spk_emb * use_spk_emb + streaming_reserved_length + 1].masked_fill_( + streaming_tts_text_mask == 0, min_dtype) # Add extra dimensions for batch and heads causal_mask = causal_mask.unsqueeze(0).unsqueeze(0) @@ -2504,22 +1496,27 @@ def make_streaming_chunk_mask_generation( return causal_mask -# Borrowed from `https://github.com/2noise/ChatTTS/blob/main/ChatTTS/model/processors.py` class CustomRepetitionPenaltyLogitsProcessorRepeat: def __init__(self, penalty: float, max_input_ids: int, past_window: int): if not isinstance(penalty, float) or not (penalty > 0): - raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}") + raise ValueError( + f"`penalty` has to be a strictly positive float, but is {penalty}" + ) self.penalty = penalty self.max_input_ids = max_input_ids self.past_window = past_window - def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: + def __call__( + self, input_ids: torch.LongTensor, scores: torch.FloatTensor + ) -> torch.FloatTensor: if input_ids.size(1) > self.past_window: input_ids = input_ids.narrow(1, -self.past_window, self.past_window) freq = F.one_hot(input_ids, scores.size(1)).sum(1) if freq.size(0) > self.max_input_ids: - freq.narrow(0, self.max_input_ids, freq.size(0) - self.max_input_ids).zero_() + freq.narrow( + 0, self.max_input_ids, freq.size(0) - self.max_input_ids + ).zero_() alpha = torch.pow(self.penalty, freq) scores = scores.contiguous() inp = scores.multiply(alpha) @@ -2563,100 +1560,13 @@ class MultiModalProjector(nn.Module): class ConditionalChatTTS(PreTrainedModel): - """A conditional text-to-speech model that can generate speech from text with speaker conditioning. 
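# Sketch of the windowed repetition penalty used for audio-code sampling (the processor above):
# count how often each code id appeared in the last `past_window` steps and shrink its logit
# accordingly (positive logits are divided, negative ones multiplied). The final torch.where
# step is reconstructed here under that standard convention; values are dummies.
import torch
import torch.nn.functional as F

def penalize(scores, input_ids, penalty=1.05, max_input_ids=626, past_window=16):
    if input_ids.size(1) > past_window:
        input_ids = input_ids[:, -past_window:]
    freq = F.one_hot(input_ids, scores.size(1)).sum(1)        # [batch, vocab] occurrence counts
    freq[:, max_input_ids:] = 0                                # never penalize ids past the codebook
    alpha = penalty ** freq
    return torch.where(scores < 0, scores * alpha, scores / alpha)

scores = torch.randn(1, 626 + 2)
recent = torch.randint(0, 626, (1, 40))
new_scores = penalize(scores, recent)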
- - This model extends PreTrainedModel to provide text-to-speech capabilities with: - - LLM hidden state conditioning - - Streaming generation - - The model uses a transformer architecture with LLM hidden states and can operate in both - streaming and non-streaming modes for flexible deployment. - - The model process sequence in the following format: - | text bos token | LLM embedding projected to tts embedding space | text tokens (fixed length, reserved for future tokens) | audio bos token | audio tokens (audio token length is not fixed)| audio eos token | - - The format is designed to support LLM-conditioned streaming audio generation. - - Usage: - To support streaming generation, two global variables should be maintained outside of the model. - 1. `audio_input_ids`: stores *discrete* audio codes. It is a tensor with shape [1, sequence length+1, num_vq]. - 2. `past_key_values`: stores the KV cache for both text tokens and audio codes. It is a list of tuples, each tuple contains two tensors with shape [1, num_attention_heads, sequence length, hidden_size // num_attention_heads] - - where `num_vq` is the number of audio codebooks, in default setting, it is `4`. - - 1. Create an empty `past_key_values` with - ```python - initial_kv_cache_length = 1 + model.num_spk_embs + model.streaming_text_reserved_len # where `1` denotes the `bos` token - dtype = model.emb_text.weight.dtype - device = model.emb_text.weight.device - past_key_values = [ - ( - torch.zeros(1, model.config.num_attention_heads, initial_kv_cache_length, model.config.hidden_size // model.config.num_attention_heads, dtype=dtype, device=device), - torch.zeros(1, model.config.num_attention_heads, initial_kv_cache_length, model.config.hidden_size // model.config.num_attention_heads, dtype=dtype, device=device) - ) - for _ in range(model.config.num_hidden_layers) - ] - - 2. At the same time, create an empty `audio_input_ids` with shape [1, sequence length, num_vq], `num_vq` denotes multiple layer audio codebooks. But here we also include text tokens in the sequence, but they will be zeros, and will not be used, just a placeholder. - - ```python - initial_audio_input_ids_length = 1 + model.num_spk_embs + model.streaming_text_reserved_len + 1 - # [bos token, speaker embeddings, text tokens, audio bos token] - audio_input_ids = torch.zeros(batch_size=1, initial_audio_input_ids_length, model.num_vq) - ``` - - 2. Prefill some text tokens to TTS model (for example, 10 tokens) using `prefill_text` method. - - ```python - outputs = llm.generate(**kwargs) - llm_tokens = some_function_to_extract_llm_tokens(outputs) - lm_spk_emb_last_hidden_states = some_function_to_extract_lm_spk_emb_last_hidden_states(outputs) - tts_text_input_ids = tts_tokenizer.encode(llm_tokenizer.decode(llm_tokens)) - # here assume we are prefilling text token 0 to text token 9 (included), totally 10 tokens. - begin = 0 - end = 9+1 - position_ids = torch.arange(begin, end, dtype=torch.long, device=device) - - past_key_values = model.prefill_text( - input_ids=tts_text_input_ids, - position_ids=position_ids, - past_key_values=past_key_values, - lm_spk_emb_last_hidden_states=lm_spk_emb_last_hidden_states, - ) - ``` - - 3. Make a `streaming_tts_text_mask` to denote which position contains valid text tokens, similar to `attention_mask` in standard causal attention. - - ```python - streaming_tts_text_mask = torch.zeros(model.streaming_reserved_length) - streaming_tts_text_mask[0:end] = 1 # denotes these post - ``` - - 3. Generate audio codes using `generate` method. 
- - ```python - outputs = model.generate( - input_ids=audio_input_ids, - past_key_values=past_key_values, - streaming_tts_text_mask=streaming_tts_text_mask, - max_new_token=50, - ) - - # update past_key_values and input_ids - past_key_values = outputs.past_key_values - audio_input_ids = outputs.input_ids - ``` - - The `past_key_values` is extended by `max_new_token=50`, and `audio_input_ids` is also extended by `max_new_token=50` after `generate` calling. - - 4. Notice that after prefilling `10` text tokens, the model can generate up to `50` audio tokens, if you want to generate more audio tokens, you need to prefill next `10` text tokens. And it is okay to only generate `25` audio tokens for faster initial response. - - 5. Repeat steps `2,3,4` as needed in your streaming audio generation cases, but ensure usage complies with the following guidelines discussed above. - """ - config_class = ConditionalChatTTSConfig - - def __init__(self, config: ConditionalChatTTSConfig): + _no_split_modules = [] + + def __init__( + self, + config: ConditionalChatTTSConfig + ): super().__init__(config) self.use_speaker_embedding = config.use_speaker_embedding @@ -2666,6 +1576,8 @@ class ConditionalChatTTS(PreTrainedModel): self.use_text = config.use_text self.streaming = config.streaming + self.streaming_text_chunk_min = config.streaming_text_chunk_min + self.streaming_text_chunk_max = config.streaming_text_chunk_max self.streaming_text_chunk_size = config.streaming_text_chunk_size self.streaming_audio_chunk_size = config.streaming_audio_chunk_size self.streaming_text_reserved_len = config.streaming_text_reserved_len @@ -2683,16 +1595,19 @@ class ConditionalChatTTS(PreTrainedModel): else: self.projector = nn.Linear(config.llm_dim, config.hidden_size, bias=False) self.emb_code = nn.ModuleList( - [nn.Embedding(config.num_audio_tokens, config.hidden_size) for _ in range(config.num_vq)] + [ + nn.Embedding(config.num_audio_tokens, config.hidden_size) for _ in range(config.num_vq) + ] + ) + self.emb_text = nn.Embedding( + config.num_text_tokens, config.hidden_size ) - self.emb_text = nn.Embedding(config.num_text_tokens, config.hidden_size) self.head_code = nn.ModuleList( [ weight_norm( nn.Linear(config.hidden_size, config.num_audio_tokens, bias=False), name="weight", - ) - for _ in range(config.num_vq) + ) for _ in range(config.num_vq) ] ) dvae = DVAE() @@ -2710,17 +1625,369 @@ class ConditionalChatTTS(PreTrainedModel): model = LlamaModel(model_config) self.model = model + return + + def forward( + self, + input_ids, + lm_spk_emb_last_hidden_states=None, + lm_last_hidden_states=None, + target_audio_features=None, + streaming_tts_text_masks=None, + **kwargs, + ): + """ + Calculate TTS modeling loss. Only used in training. + + Process: + - LLM last hidden states (obtained from LLM, with gradients) + - Text ground truth (without gradients) + - Target audio features (without gradients) + + Updates: + - 2024/10/3: Support empty input (dummy train) for tasks without audio, preventing training stalls due to unused parameters. + - 2024/10/11: Support EOS token + + Args: + input_ids (List[Tensor[seq_len]]): Text ground truth input_ids for each model's speech area. Each element is a variable-length Tensor. + lm_spk_emb_last_hidden_states (List[Tensor[gpt_dim]], optional): Speaker embedding last hidden states from the language model. + lm_last_hidden_states (List[Tensor[seq_len, gpt_dim]], optional): LLM last hidden states for each model's speech area. Each element is a variable-length Tensor. 
+ target_audio_features (List[Tensor[num_channels, num_samples]], optional): Mel spectrogram ground truth for each model's speech area. Each element is a variable-length Tensor. + streaming_tts_text_masks (List[Tensor[seq_len_max]], optional): Masks used to pad text to fixed length in streaming training. Shape is Tensor[seq_len_max]. + """ + + # consider the case of dummy training + dummy = False + if self.train: + if len(input_ids) == 0: + dummy = True + dummy_seq_len = 100 + input_ids = [ + torch.full( + (dummy_seq_len,), + fill_value=1, + device=self.model.embed_tokens.weight.device, + dtype=torch.int64 + ) + ] + input_ids[0][0: self.num_spk_embs] = self.spk_emb_token_id + + if self.config.use_speaker_embedding: + lm_spk_emb_last_hidden_states = [ + torch.full( + (self.num_spk_embs, self.config.llm_dim), + fill_value=0, + device=self.model.embed_tokens.weight.device, + dtype=self.model.embed_tokens.weight.dtype + ) + ] + else: + lm_last_hidden_states = [ + torch.full( + (dummy_seq_len, self.config.llm_dim), + fill_value=0, + device=self.model.embed_tokens.weight.device, + dtype=self.model.embed_tokens.weight.dtype + ) + ] + + target_audio_features = [ + torch.full( + (self.num_mel_bins, dummy_seq_len), + fill_value=0, + device=self.model.embed_tokens.weight.device, + dtype=self.model.embed_tokens.weight.dtype + ) + ] + streaming_tts_text_masks = None + + if lm_last_hidden_states is not None: + assert not self.use_speaker_embedding + # Project llm last hidden states (QwenAudio, Qwen2) to tts gpt decoder hidden size (as tts condition) first + # Keep track of the length of each tts condition + assert len(lm_last_hidden_states) != 0 + all_tts_condition_seq_len = [i.shape[0] for i in lm_last_hidden_states] + + # Pad hidden states to be a big tensor for high efficiency ---- [batch_size, seq_len_max, lm_hidden_size] + input_data = pad_sequence(lm_last_hidden_states, batch_first=True) + + # all_lm_last_hidden_states -> all_tts_conditions + all_tts_condition = self.projector(input_data) + + # Perform L2 norm # [batch_size, seq_len_max, gpt_hidden_size] + all_tts_condition = F.normalize(all_tts_condition, p=2, dim=2) + + # Split whole tensor into list[Tensor] and remove padding positions + all_tts_condition_varlen = [] + for idx in range(all_tts_condition.shape[0]): + all_tts_condition_varlen.append(all_tts_condition[idx, 0:all_tts_condition_seq_len[idx]]) + + else: + all_tts_condition_varlen = None + + if lm_spk_emb_last_hidden_states is not None: # List[Tensor[num_spk_emb, lm_hidden_dim]] + assert self.use_speaker_embedding + if len(lm_spk_emb_last_hidden_states) == 0: + raise ValueError("lm_spk_emb_last_hidden_states is empty.") + # [bs, num_spk_emb, lm_hidden_dim] This will raise an error if spk_emb is not equal for each data + stacked_lm_spk_emb_last_hidden_states = torch.stack(lm_spk_emb_last_hidden_states, dim=0) + + # Check if the number of num_spk_embs matches the expectation + assert stacked_lm_spk_emb_last_hidden_states.shape[1] == self.num_spk_embs + + # Project to tts decoder dimension uniformly + gpt_spk_emb_last_hidden_states = self.projector( + stacked_lm_spk_emb_last_hidden_states) # [bs, num_spk_emb, gpt_dim] + + # Normalize + gpt_spk_emb_last_hidden_states = F.normalize(gpt_spk_emb_last_hidden_states, p=2, dim=-1) + + else: + gpt_spk_emb_last_hidden_states = None + + # means training, encoding audio features to audio tokens using dVAE on the fly + if target_audio_features is not None: + assert self.dvae.coef.requires_grad == False + with torch.inference_mode(): + eos_token_id = 
int(self.emb_code[0].num_embeddings - 1) + all_audio_codes = [] + # For speech, it might be necessary to keep float32 encoding, even if it's slower + with torch.cuda.amp.autocast(dtype=torch.float): + for audio_waveform in target_audio_features: + audio_codes = self.dvae(audio_waveform, mode="encode") # Tensor[1, num_vq, audio_seq_len] + # Add eos token + audio_codes_with_eos = torch.cat( + ( + audio_codes.squeeze(0), # [num_vq, seq_len] + torch.ones(self.num_vq, 1, device=audio_codes.device, + dtype=audio_codes.dtype) * eos_token_id # [num_vq, 1] + ), dim=-1 + ) + all_audio_codes.append(audio_codes_with_eos) # Tensor[4, audio_seq_len] + + all_audio_codes_seq_len = [i.shape[1] for i in all_audio_codes] + + # Encode 4 layers of codes to audio embedding by layer + audio_embed_all_layers = [] + for i in range(self.num_vq): + audio_codes_layer_i = [] + for codes in all_audio_codes: + audio_codes_layer_i.append( + codes[i, :].squeeze(0), + ) + # Pad each layer of audio codes to fixed length + audio_codes_layer_i = pad_sequence(audio_codes_layer_i, batch_first=True) + # Encode each layer of audio codes into embedding (parallelized) + audio_embed_layer_i = self.emb_code[i](audio_codes_layer_i) # [batch_size, seq_len, gpt_hidden_dim] + audio_embed_all_layers.append(audio_embed_layer_i) + + # Here we need to calculate the audio_embed of four layers and add them up + # According to the official implementation of ChatTTS https://github.com/2noise/ChatTTS/blob/51ec0c784c2795b257d7a6b64274e7a36186b731/ChatTTS/model/gpt.py#L451 + audio_embed_all_layers = torch.stack(audio_embed_all_layers, dim=0) # [num_vq, seq_len, gpt_hidden_dim] + audio_embed_all_layers = torch.sum(audio_embed_all_layers, dim=0, + keepdim=False) # [seq_len, gpt_hidden_dim] + + # Convert back to variable-length sequences based on the original lengths of stored audio codes + audio_embed_all_layers_varlen = [] + for idx in range(audio_embed_all_layers.shape[0]): + audio_embed_all_layers_varlen.append( + audio_embed_all_layers[idx, 0:all_audio_codes_seq_len[idx]] + ) + + # Encode the text into embeds + all_input_ids_seq_len = [i.shape[0] for i in input_ids] + input_ids = pad_sequence(input_ids, batch_first=True) + all_text_embeds = self.emb_text(input_ids) # [batch_size, seq_len] -> [batch_size, seq_len, gpt_hidden_dim] + + # Merge spk_emb: If spk_emb is provided, it needs to be replaced in the embeds + if lm_spk_emb_last_hidden_states is not None: + # This is an in-place replacement of some positions in all_text_embeds with spk emb + apply_spk_emb( + input_ids=input_ids, + spk_emb=gpt_spk_emb_last_hidden_states, + input_embeds=all_text_embeds, + spk_emb_token_id=self.spk_emb_token_id, + num_spk_embs=self.num_spk_embs, + ) + + all_text_embeds_varlen = [] + # Convert back to variable-length sequences for easier fusion of different tokens later + for idx in range(all_text_embeds.shape[0]): + all_text_embeds_varlen.append( + all_text_embeds[idx, 0:all_input_ids_seq_len[idx], :] + ) # List[ Tensor[seq_len, gpt_hidden_dim] ] + + # Merge tts condition, audio embeds, and text token embeds. 
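# Sketch of the multi-codebook embedding fusion performed above (following ChatTTS): each of
# the num_vq code streams is embedded separately and the num_vq embeddings are summed
# position-wise into a single sequence. Sizes are illustrative.
import torch
import torch.nn as nn

num_vq, vocab, hidden, seq = 4, 626, 768, 120
emb_code = nn.ModuleList([nn.Embedding(vocab, hidden) for _ in range(num_vq)])
codes = torch.randint(0, vocab, (num_vq, seq))               # one id stream per codebook

fused = torch.stack([emb_code[i](codes[i]) for i in range(num_vq)], dim=0).sum(dim=0)
# fused: [seq, hidden], the single embedding sequence fed to the decoder for this audio clip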
+ # Final concatenation format: llm last hidden state | text_embeds embeds | audio embeds + + # Merge embeds from multiple sources + embeds_to_merge = [] + + # Add lm condition + if lm_last_hidden_states is not None: + embeds_to_merge.append(all_tts_condition_varlen) + + # Add text + if self.use_text: + embeds_to_merge.append(all_text_embeds_varlen) + + # If audio feature is provided, add audio embeds + if target_audio_features is not None: + embeds_to_merge.append(audio_embed_all_layers_varlen) + + # Merge embeds + all_merged_embeds_ = [] + for item_tuple in zip(*embeds_to_merge): + # [seq_len_tts_condition+seq_len_text+seq_len_audio, gpt_hidden_dim] + merged_embed = torch.cat(item_tuple, dim=0) + all_merged_embeds_.append(merged_embed) + + input_embeds_seqlen = [] + for i in all_merged_embeds_: + input_embeds_seqlen.append(i.shape[0]) + + # This will pad the embeds of each sequence to form a neat tensor, as we're about to feed it into the transformer + # We don't generate an attention mask here because we use right padding + input_embeds = pad_sequence(all_merged_embeds_, + batch_first=True) # List[ Tensor[seq_len_i, gpt_hidden_dim] ] -> Tensor[batch_size, seq_len_max, gpt_hidden_dim] + + # Determine the position of text in each data + text_ranges = [] + batch_size = input_embeds.shape[0] + for idx in range(batch_size): + start_idx = 0 + + # If hidden state is provided, we need to consider the length of the hidden state + if lm_last_hidden_states is not None: + start_idx += all_tts_condition_seq_len[idx] + + end_idx = start_idx + all_input_ids_seq_len[idx] + text_ranges.append((start_idx, end_idx)) + + if target_audio_features is not None: + # Make labels for audio codes + batch_size = input_embeds.shape[0] + seq_len_max = input_embeds.shape[1] + + # Here we construct a labels, only the positions of audio codes will be learned. [batch_size, seq_len, num_vqs] + labels = torch.zeros(batch_size, seq_len_max, self.num_vq, device=input_embeds.device, dtype=torch.long) + labels[:, :, :] = -100 + + # Determine the position of audio codes in each data + audio_codes_ranges = [] + for idx in range(batch_size): + start_idx = 0 + + # If hidden state is provided, we need to consider the length of the hidden state + if lm_last_hidden_states is not None: + start_idx += all_tts_condition_seq_len[idx] + + if self.use_text: + start_idx += all_input_ids_seq_len[idx] + + end_idx = start_idx + all_audio_codes_seq_len[idx] + audio_codes_ranges.append((start_idx, end_idx)) + + # Replace audio labels into labels + for idx, audio_codes_range in zip(range(batch_size), audio_codes_ranges): + start_idx = audio_codes_range[0] + end_idx = audio_codes_range[1] + labels[ + idx, start_idx: end_idx, : + ] = all_audio_codes[idx].permute(1, 0) + + # For REAL streaming ChatTTS setting, a simple way is to create a self-defined 4D attention mask to the model, then we can control which kv can be attended by which q. 
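# Sketch of the label layout built above: everything is ignored (-100) except the audio-code
# positions, so cross-entropy is only computed where the decoder should predict audio codes.
# Shapes are small dummies.
import torch

batch, seq_max, num_vq = 1, 12, 4
labels = torch.full((batch, seq_max, num_vq), -100, dtype=torch.long)

audio_codes = torch.randint(0, 626, (num_vq, 5))     # [num_vq, audio_len] for this sample
start, end = 7, 12                                   # audio region: after condition + text
labels[0, start:end, :] = audio_codes.permute(1, 0)  # -> [audio_len, num_vq]
# labels[0, :7] stays -100, so condition and text positions contribute no loss.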
+            # For the REAL streaming ChatTTS setting, a simple approach is to pass a self-defined 4D attention mask to the model, so we can control which kv can be attended to by which q.
+            # https://github.com/huggingface/transformers/blob/65bb28444849976f853063edb958b3ef3dd59d12/src/transformers/models/llama/modeling_llama.py#L59
+            # It says, `Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.`
+
+            if self.streaming and not dummy:
+                tts_attention_mask_4d = make_streaming_chunk_mask(
+                    input_embeds=input_embeds,  # input_embeds after merging text and audio
+                    tts_text_scopes=text_ranges,  # List[Tuple[int, int]]
+                    tts_audio_scopes=audio_codes_ranges,  # List[Tuple[int, int]]
+                    tts_text_masks=streaming_tts_text_masks,  # List[Tensor[seq_len_max]]
+                    min_chunk_num_token=self.streaming_text_chunk_min,
+                    max_chunk_num_token=self.streaming_text_chunk_max,
+                    streaming_audio_chunk_size=self.streaming_audio_chunk_size,
+                )  # [batch_size, 1, seq_len, seq_len]
+            else:
+                tts_attention_mask_4d = None
+
+            # Invoke the gpt forward pass, get the last hidden states, and predict audio codes.
+            # In the non-streaming case we don't pass an attention mask because we use right padding,
+            # and the labels already mark which positions should be learned.
+            outputs = self.model(  # self.model is a LlamaModel, not LlamaForCausalLM
+                inputs_embeds=input_embeds,
+                attention_mask=tts_attention_mask_4d,
+            )
+
+            tts_last_hidden_state = outputs.last_hidden_state  # [batch, seq_len_max, gpt_hidden_dim]
+
+            # Predict audio codes from last_hidden_state with the gpt TTS decoder heads
+            logits_all_vq_layers = []
+            for num_vq_iter in range(self.num_vq):
+                logits_i = self.head_code[num_vq_iter](tts_last_hidden_state)  # [batch, seq_len_max, audio_codebook_vocab]
+                logits_all_vq_layers.append(logits_i)
+            logits_all_vq_layers = torch.stack(logits_all_vq_layers, dim=0)  # [num_vq, batch_size, seq_len_max, audio_codebook_vocab]
+            logits_all_vq_layers = logits_all_vq_layers.permute(1, 2, 0, 3)  # [batch_size, seq_len_max, num_vq, audio_codebook_vocab]
+
+            # Compute model predictions
+            shift_logits = logits_all_vq_layers[:, :-1, :, :].contiguous()  # [batch_size, seq_len_max-1, num_vq, audio_codebook_vocab]
+            shift_labels = labels[:, 1:, :].contiguous()  # [batch_size, seq_len_max-1, num_vq]
+
+            # Compute the CE loss
+            if not self.aug_loss_weight:
+                loss_fct = nn.CrossEntropyLoss()
+                loss = loss_fct(
+                    shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device)
+                )
+            else:
+                loss_fct = nn.CrossEntropyLoss(reduction='none')
+                losses = loss_fct(
+                    shift_logits.view(-1, shift_logits.size(-1)),
+                    shift_labels.view(-1).to(shift_logits.device)
+                ).view(shift_labels.size())  # [batch_size, seq_len_max-1, num_vq]
+
+                valid_label_count = (shift_labels != -100).sum()
+
+                eos_token_id = int(self.dvae.emb_code[0].num_embeddings - 1)
+                eos_positions = (shift_labels == eos_token_id).nonzero()
+                for pos in eos_positions:
+                    seq_len = pos[1] + 1  # sequence length including the eos_token_id
+                    if seq_len < 400:  # shorter than 5s (150 text + 50 audio * 5)
+                        losses[pos[0], pos[1], pos[2]] *= 0.2
+                    elif seq_len > 650:  # longer than 15s (150 text + 50 audio * 15)
+                        losses[pos[0], pos[1], pos[2]] *= 2
+
+                loss = losses.sum() / valid_label_count
+
+            if dummy:
+                print("dummy loss", loss)
+                loss = loss * 0  # Avoid contributing invalid gradients
+
+        else:
+            loss = None
+
+        return loss
+
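As a stand-alone illustration of the loss above: shifted cross-entropy over the `num_vq` code heads with `-100` positions ignored, plus the optional length-based re-weighting at eos positions. The 400/650 thresholds mirror the code; the tensor sizes and random inputs are placeholders.

```python
import torch
import torch.nn as nn

batch_size, seq_len, num_vq, vocab = 2, 10, 4, 626
logits = torch.randn(batch_size, seq_len, num_vq, vocab)        # per-head logits
labels = torch.full((batch_size, seq_len, num_vq), -100)        # -100 = ignored position
labels[:, 4:, :] = torch.randint(0, vocab, (batch_size, seq_len - 4, num_vq))

shift_logits = logits[:, :-1].contiguous()                      # position t predicts token t+1
shift_labels = labels[:, 1:].contiguous()

loss_fct = nn.CrossEntropyLoss(reduction="none")                # ignore_index defaults to -100
losses = loss_fct(shift_logits.view(-1, vocab), shift_labels.view(-1)).view(shift_labels.size())

# Down-/up-weight sequences that end too early or too late
eos_token_id = vocab - 1
for pos in (shift_labels == eos_token_id).nonzero():
    seq_len_with_eos = pos[1] + 1
    if seq_len_with_eos < 400:
        losses[pos[0], pos[1], pos[2]] *= 0.2
    elif seq_len_with_eos > 650:
        losses[pos[0], pos[1], pos[2]] *= 2

loss = losses.sum() / (shift_labels != -100).sum()
```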
     @torch.inference_mode()
-    def merge_inputs_embeds(
-        self,
-        input_ids: torch.Tensor,
-        lm_spk_emb_last_hidden_states: Optional[torch.Tensor] = None,
+    def prepare_inputs_embeds(
+        self,
+        input_ids: torch.Tensor,
+        lm_spk_emb_last_hidden_states: Optional[torch.Tensor] = None,
+        lm_last_hidden_states: Optional[torch.Tensor] = None
     ):
-        """Merge `input_ids` and `lm_spk_emb_last_hidden_states` to `inputs_embeds`.
+        """Prepare `inputs_embeds` for the model in inference mode:
+        encode `input_ids` into embeddings, then merge in `lm_spk_emb_last_hidden_states` and `lm_last_hidden_states`.

         Args:
             input_ids (torch.Tensor): Input token IDs.
             lm_spk_emb_last_hidden_states (Optional[torch.Tensor], optional): Last hidden states of speaker embeddings from the language model. Defaults to None.
+            lm_last_hidden_states (Optional[torch.Tensor], optional): Last hidden states from the language model. Defaults to None.

         Raises:
             NotImplementedError: If speaker embedding is not used and language model hidden states are not implemented.
@@ -2747,23 +2014,26 @@ class ConditionalChatTTS(PreTrainedModel):
                 spk_emb=projected_spk_emb,
                 input_embeds=inputs_embeds,
                 spk_emb_token_id=self.spk_emb_token_id,
-                num_spk_embs=self.num_spk_embs,
+                num_spk_embs=self.num_spk_embs
             )
         else:
+            assert lm_last_hidden_states is not None
+            # TODO: Add projected language model hidden states to tts embedding space
             raise NotImplementedError

         return inputs_embeds

     @torch.inference_mode()
     def prefill_text(
-        self,
-        input_ids: torch.Tensor,
-        position_ids: torch.LongTensor,
-        past_key_values: List[Tuple[torch.Tensor, torch.Tensor]],
-        lm_spk_emb_last_hidden_states: Optional[torch.Tensor] = None,
+        self,
+        input_ids: torch.Tensor,
+        position_ids: torch.LongTensor,
+        past_key_values: List[Tuple[torch.Tensor, torch.Tensor]],
+        lm_spk_emb_last_hidden_states: Optional[torch.Tensor] = None,
+        lm_last_hidden_states: Optional[torch.Tensor] = None
     ):
         """Prefill a chunk of new text tokens in streaming setting.
-        Specifically speaking, update `past_key_values` using new text tokens, then the model will read the new text tokens.
+        Specifically speaking, update `past_key_values` using new text tokens.
         Args:
             input_ids (Tensor): Tensor of shape [batch_size, seq_len]
@@ -2777,10 +2047,11 @@ class ConditionalChatTTS(PreTrainedModel):
         assert input_ids.shape[0] == 1
         assert past_key_values is not None

-        # Merge text and LLM embeddings
-        inputs_embeds = self.merge_inputs_embeds(
+        # Merge text and embeddings from language model
+        inputs_embeds = self.prepare_inputs_embeds(
             input_ids=input_ids,
             lm_spk_emb_last_hidden_states=lm_spk_emb_last_hidden_states,
+            lm_last_hidden_states=lm_last_hidden_states,
         )

         # Clone KV Cache
@@ -2788,8 +2059,8 @@ class ConditionalChatTTS(PreTrainedModel):
         for i in range(len(past_key_values)):
             past_key_values_for_prefill.append(
                 (
-                    past_key_values[i][0][:, :, : position_ids[:, 0], :].clone(),
-                    past_key_values[i][1][:, :, : position_ids[:, 0], :].clone(),
+                    past_key_values[i][0][:, :, :position_ids[:, 0], :].clone(),
+                    past_key_values[i][1][:, :, :position_ids[:, 0], :].clone(),
                 )
             )

@@ -2807,20 +2078,14 @@ class ConditionalChatTTS(PreTrainedModel):
         # Get model updated KV Cache
         past_key_values_for_prefill_updated = outputs_prefill.past_key_values

-        # Update generated KV Cache to input `past_key_values`
+        # Update generated KV Cache to input past_key_values
        for layer_idx in range(len(past_key_values)):
             # Update keys
-            past_key_values[layer_idx][0][:, :, position_ids[:, 0] : position_ids[:, -1] + 1, :] = (
-                past_key_values_for_prefill_updated[layer_idx][0][
-                    :, :, position_ids[:, 0] : position_ids[:, -1] + 1
-                ].clone()
-            )
+            past_key_values[layer_idx][0][:, :, position_ids[:, 0]:position_ids[:, -1] + 1, :] = \
+                past_key_values_for_prefill_updated[layer_idx][0][:, :, position_ids[:, 0]:position_ids[:, -1] + 1].clone()
             # Update values
-            past_key_values[layer_idx][1][:, :, position_ids[:, 0] : position_ids[:, -1] + 1, :] = (
-                past_key_values_for_prefill_updated[layer_idx][1][
-                    :, :, position_ids[:, 0] : position_ids[:, -1] + 1
-                ].clone()
-            )
+            past_key_values[layer_idx][1][:, :, position_ids[:, 0]:position_ids[:, -1] + 1, :] = \
+                past_key_values_for_prefill_updated[layer_idx][1][:, :, position_ids[:, 0]:position_ids[:, -1] + 1].clone()

         # TODO: del past_key_values_for_prefill_updated recursively
         # TODO: del outputs_prefill recursively

@@ -2828,81 +2093,71 @@ class ConditionalChatTTS(PreTrainedModel):
         return past_key_values

     @torch.inference_mode()
-    def prefill_audio_ids(
-        self,
-        input_ids: torch.Tensor,
-        past_key_values: List[Tuple[torch.Tensor, torch.Tensor]],
-        streaming_tts_text_mask=None,
-        add_audio_bos: bool = True,
+    def generate(
+        self,
+        input_ids: torch.Tensor,
+        past_key_values: List[Tuple[torch.Tensor, torch.Tensor]],
+        temperature: torch.Tensor,
+        eos_token: Union[int, torch.Tensor],
+        streaming_tts_text_mask=None,
+        force_no_stop=False,
+        min_new_token=10,
+        max_new_token=50,
+        logits_warpers: List[LogitsWarper] = [],
+        logits_processors: List[CustomRepetitionPenaltyLogitsProcessorRepeat] = [],
+        show_tqdm=False,
     ):
-        """Prefill a chunk of audio ids to the model. Used in sliding-window long audio generation.
-        Specifically, prefill many audio ids (typically from last window) to the model in the new window.
-
-        Args:
-            input_ids (torch.Tensor): (1, seq_len, num_vq) Audio input token ids.
-            past_key_values (List[Tuple[torch.Tensor, torch.Tensor]]): Past key values for attention mechanism.
-        """
-        assert input_ids.shape[0] == 1
-        assert past_key_values is not None
-
-        code_emb = [self.emb_code[i](input_ids[:, :, i]) for i in range(self.num_vq)]
-        inputs_embeds = torch.stack(code_emb, 3).sum(3)  # [1, seq_len, 768]
-        input_len = input_ids.shape[1]
+        """Generate audio codes in streaming setting.
+        Specifically speaking, generate audio codes when not all text tokens are prefilled.

-        if add_audio_bos:
-            narrowed_input_ids = torch.tensor([[self.audio_bos_token_id]], dtype=torch.long, device=self.device)
-            bos_inputs_embeds = self.emb_text(narrowed_input_ids)
-            inputs_embeds = torch.cat([bos_inputs_embeds, inputs_embeds], dim=1)
-            input_len += 1
+        Usage:
+        Always pass a non-empty `past_key_values` to the function. The function does not do `prefill` by itself. It relies on the `prefill_text` method to provide a valid `past_key_values`.

-        past_key_values_length = past_key_values[0][0].shape[2]
-        position_ids = torch.arange(
-            past_key_values_length, past_key_values_length + input_len, dtype=torch.long, device=self.device
-        ).unsqueeze(0)
+        1. Create an empty `past_key_values` with
+        ```python
+        initial_kv_cache_length = 1 + model.num_spk_embs + model.streaming_text_reserved_len
+        dtype = model.emb_text.weight.dtype
+        device = model.emb_text.weight.device
+        past_key_values = [
+            (
+                torch.zeros(1, model.config.num_attention_heads, initial_kv_cache_length, model.config.hidden_size // model.config.num_attention_heads, dtype=dtype, device=device),
+                torch.zeros(1, model.config.num_attention_heads, initial_kv_cache_length, model.config.hidden_size // model.config.num_attention_heads, dtype=dtype, device=device)
+            )
+            for _ in range(model.config.num_hidden_layers)
+        ]
+        ```

-        cache_position = position_ids.clone()
-        causal_mask = make_streaming_chunk_mask_generation(
-            inputs_embeds=inputs_embeds,
-            past_seen_tokens=past_key_values[0][0].shape[2],
-            streaming_tts_text_mask=streaming_tts_text_mask,
-            streaming_reserved_length=self.streaming_text_reserved_len,
-            streaming_text_chunk_size=self.streaming_text_chunk_size,
-        )  # [1, 1, 1, past_key_values_length + input_len]
+        2. Prefill some text tokens using the `prefill_text` method.
+        ```python
+        outputs = llm.generate(**kwargs)
+        # extract lm_spk_emb_last_hidden_states and/or lm_last_hidden_states from outputs.last_hidden_states
+        input_ids = tts_tokenizer.encode(llm_tokenizer.decode(llm_tokens))
+        position_ids = torch.arange(begin, end, dtype=torch.long, device=device)
+        past_key_values = self.prefill_text(
+            input_ids=input_ids,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            lm_spk_emb_last_hidden_states=lm_spk_emb_last_hidden_states,
+            lm_last_hidden_states=lm_last_hidden_states,
+        )
+        ```

-        # Model forward
-        outputs: BaseModelOutputWithPast = self.model(
-            attention_mask=causal_mask,
-            position_ids=position_ids,
-            past_key_values=past_key_values,
-            inputs_embeds=inputs_embeds,
-            use_cache=True,
-            output_attentions=False,
-            cache_position=cache_position,
-        )
-        past_key_values = outputs.past_key_values
-        return past_key_values
+        3. Generate audio codes using the `generate` method.
+        ```python
+        # initialize input_ids; this should only be done once
+        condition_length = 1 + model.num_spk_embs * model.use_speaker_embedding + model.streaming_text_reserved_len + 1
+        input_ids = torch.zeros(1, condition_length, model.num_vq, dtype=torch.long, device=device)

-    @torch.inference_mode()
-    def generate(
-        self,
-        input_ids: torch.Tensor,
-        past_key_values: List[Tuple[torch.Tensor, torch.Tensor]],
-        temperature: torch.Tensor,
-        eos_token: Union[int, torch.Tensor],
-        streaming_tts_text_mask=None,
-        force_no_stop=False,
-        min_new_token=10,
-        max_new_token=50,
-        logits_warpers: List[LogitsWarper] = [],
-        logits_processors: List[CustomRepetitionPenaltyLogitsProcessorRepeat] = [],
-        show_tqdm=False,
-    ):
-        """Generate audio codes in streaming setting or non-streaming setting.
-        Specifically speaking, generate audio codes when not all text tokens are prefilled.
+        outputs = self.generate(
+            input_ids=input_ids,
+            past_key_values=past_key_values,
+        )

-        Always pass a valid `past_key_values` to the method. The method does not do `prefill` by itself. It relies on `prefill_text` method to provide valid `past_key_values`. Please refer to docstring of this class for more details.
+        # update past_key_values and input_ids
+        past_key_values = outputs.past_key_values
+        input_ids = outputs.input_ids
+        ```

-        In this method, we borrowed a lot of codes from `https://github.com/2noise/ChatTTS/blob/main/ChatTTS/model/gpt.py`.
+        4. Repeat steps 2 and 3.

         Args:
             input_ids (torch.Tensor): Input token ids.
@@ -2914,7 +2169,8 @@ class ConditionalChatTTS(PreTrainedModel):
             logits_warpers (List[LogitsWarper], optional): List of logits warpers. Defaults to [].
             logits_processors (List[CustomRepetitionPenaltyLogitsProcessorRepeat], optional): List of logits processors. Defaults to [].
             show_tqdm (bool, optional): Whether to show progress bar. Defaults to True.

         Returns:
             GenerationOutputs: Generation outputs.
         """
@@ -2929,7 +2185,12 @@ class ConditionalChatTTS(PreTrainedModel):
         finish = torch.zeros(input_ids.shape[0], device=input_ids.device).bool()

-        temperature = temperature.unsqueeze(0).expand(input_ids.shape[0], -1).contiguous().view(-1, 1)
+        temperature = (
+            temperature.unsqueeze(0)
+            .expand(input_ids.shape[0], -1)
+            .contiguous()
+            .view(-1, 1)
+        )

         progress = input_ids.shape[1]

@@ -2942,7 +2203,7 @@ class ConditionalChatTTS(PreTrainedModel):
             device=input_ids.device,
         )

-        # Copy existing `input_ids` to `input_ids_buf`
+        # Copy existing input_ids to input_ids_buf
         input_ids_buf.narrow(1, 0, progress).copy_(input_ids)

         del input_ids

@@ -2961,41 +2222,51 @@ class ConditionalChatTTS(PreTrainedModel):
         for i in range(max_new_token):
             # Prepare generation inputs
             audio_bos = False
-
-            # If this is the first audio token, the case is SPECIAL
+            # If this is the first audio token, the case is special
             if progress == condition_length:
                 audio_bos = True

-            assert progress == (
-                past_key_values[0][0].shape[2] + 1
-            )  # If you are using according to the guidelines, this should be passed.
-
-            if audio_bos:
-                # Generate the first token, activate the model with `self.audio_bos_token_id`, the model will predict a new audio token. This is a special case because without the `audio bos token`, it is impossible to generate the first audio token in our streaming setting.
+            if audio_bos:
+                # Generate the first token: activate the model with `self.audio_bos_token_id`, and the model will predict a new audio token.
+                assert progress == (past_key_values[0][0].shape[2] + 1)
                 narrowed_input_ids = torch.tensor([[self.audio_bos_token_id]], dtype=torch.long, device=self.device)
                 inputs_embeds = self.emb_text(narrowed_input_ids)
                 del narrowed_input_ids
             else:
-                # Generate the following audio tokens, it is applicable to all other cases, including second and the following calling of `generate`.
+                # Generate the subsequent audio tokens; this applies to all other cases, including the second and later calls to `generate`
+                assert progress == (past_key_values[0][0].shape[2] + 1)
                 narrowed_input_ids = input_ids.narrow(dim=1, start=input_ids.shape[1] - 1, length=1)
-                code_emb = [self.emb_code[i](narrowed_input_ids[:, :, i]) for i in range(self.num_vq)]
+                code_emb = [
+                    self.emb_code[i](narrowed_input_ids[:, :, i])
+                    for i in range(self.num_vq)
+                ]
                 inputs_embeds = torch.stack(code_emb, 3).sum(3)

             position_ids = torch.tensor(
-                [past_key_values[0][0].shape[2] + 1], dtype=torch.long, device=self.device
+                [past_key_values[0][0].shape[2] + 1],
+                dtype=torch.long,
+                device=self.device
             ).unsqueeze(0)

             cache_position = position_ids.clone()
-
-            # Make causal mask
             causal_mask = make_streaming_chunk_mask_generation(
                 inputs_embeds=inputs_embeds,
                 past_seen_tokens=past_key_values[0][0].shape[2],
                 streaming_tts_text_mask=streaming_tts_text_mask,
                 streaming_reserved_length=self.streaming_text_reserved_len,
-                streaming_text_chunk_size=self.streaming_text_chunk_size,
+                streaming_text_chunk_size=self.streaming_text_chunk_size
             )

+            # debug = False
+            # if debug:
+            #     print(f"generation step {i}")
+            #     print(f"    position_ids {position_ids}")
+            #     if past_key_values is not None:
+            #         print(f"    past_key_values {past_key_values[0][0].shape}")
+            #     print(f"    inputs_embeds {inputs_embeds.shape}")
+            #     print(f"    cache_position {cache_position}")
+            #     print(f"    causal_mask {causal_mask.shape}")
+
             # Model forward
             outputs: BaseModelOutputWithPast = self.model(
                 attention_mask=causal_mask,
@@ -3069,7 +2340,8 @@ class ConditionalChatTTS(PreTrainedModel):
             scores = F.softmax(logits, dim=-1)

             del logits
-            idx_next = torch.multinomial(scores, num_samples=1)  # .to(finish.device)
+
+            idx_next = torch.multinomial(scores, num_samples=1).to(finish.device)

             del scores

@@ -3079,7 +2351,7 @@ class ConditionalChatTTS(PreTrainedModel):
             finish.logical_or_(finish_or)

             del finish_or
-            # Store new `token` into `input_ids_buf`
+            # Store the new `token` into `input_ids_buf`
             input_ids_buf.narrow(1, progress, 1).copy_(idx_next.unsqueeze_(1))

             if i == 0 and finish.any():
@@ -3101,7 +2373,9 @@ class ConditionalChatTTS(PreTrainedModel):
             if not finish.all():
                 if show_tqdm:
-                    logger.info(f"incomplete result. hit max_new_token: {max_new_token}")
+                    print(
+                        f"incomplete result. hit max_new_token: {max_new_token}"
+                    )

         del input_ids_buf
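The docstring of `generate` (steps 1-4 above) describes a streaming protocol: allocate a fixed-size KV cache, prefill text chunk by chunk with `prefill_text`, then call `generate` after each chunk. The sketch below condenses those steps into one loop. The temperature value, the eos code id, the chunking of `text_chunks`, the starting text position, and the omission of `streaming_tts_text_mask` and logits warpers/processors are all simplifying assumptions, not the project's exact pipeline.

```python
import torch

def streaming_tts_driver(model, text_chunks, spk_emb=None):
    """Toy driver following the generate() docstring; not the library's official API."""
    dtype = model.emb_text.weight.dtype
    device = model.emb_text.weight.device
    head_dim = model.config.hidden_size // model.config.num_attention_heads

    # Step 1: fixed-size KV cache covering bos + speaker embeds + reserved text region
    kv_len = 1 + model.num_spk_embs + model.streaming_text_reserved_len
    past_key_values = [
        (torch.zeros(1, model.config.num_attention_heads, kv_len, head_dim, dtype=dtype, device=device),
         torch.zeros(1, model.config.num_attention_heads, kv_len, head_dim, dtype=dtype, device=device))
        for _ in range(model.config.num_hidden_layers)
    ]

    # Audio-code buffer, created once (see step 3 of the docstring)
    condition_length = 1 + model.num_spk_embs * model.use_speaker_embedding + model.streaming_text_reserved_len + 1
    input_ids = torch.zeros(1, condition_length, model.num_vq, dtype=torch.long, device=device)

    begin = 1 + model.num_spk_embs  # assumed: text slots start right after bos + speaker embeds
    for chunk in text_chunks:  # each chunk: LongTensor[1, chunk_len] of TTS text token ids
        # Step 2: prefill this text chunk into the reserved region of the KV cache
        position_ids = torch.arange(begin, begin + chunk.shape[1], dtype=torch.long, device=device).unsqueeze(0)
        past_key_values = model.prefill_text(
            input_ids=chunk,
            position_ids=position_ids,
            past_key_values=past_key_values,
            lm_spk_emb_last_hidden_states=spk_emb,
        )
        begin += chunk.shape[1]

        # Step 3: generate a chunk of audio codes conditioned on the text prefilled so far
        outputs = model.generate(
            input_ids=input_ids,
            past_key_values=past_key_values,
            temperature=torch.tensor([0.1] * model.num_vq, dtype=torch.float, device=device),  # assumed value
            eos_token=torch.tensor([625], device=device),  # assumed eos code id
            max_new_token=50,
        )
        past_key_values = outputs.past_key_values
        input_ids = outputs.input_ids  # newly generated codes are appended at the end

    return input_ids[:, condition_length:, :]  # all generated audio codes
```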
@@ -3121,20 +2395,10 @@ class ConditionalChatTTS(PreTrainedModel):
     @torch.inference_mode()
     def decode_to_mel_specs(
-        self,
-        result_list: List[torch.Tensor],
+        self,
+        result_list: List[torch.Tensor],
+        use_decoder: bool = False,
     ):
-        """Decode discrete audio codes to mel spectrograms.
-
-        Borrowed from `https://github.com/2noise/ChatTTS/blob/main/ChatTTS/core.py`
-
-        Args:
-            result_list (List[torch.Tensor]): Audio codes output from `generate`.
-
-        Returns:
-            torch.Tensor: Mel spectrograms.
-        """
-        decoder = self.dvae
         max_x_len = -1
         if len(result_list) == 0:
@@ -3157,7 +2421,6 @@ class ConditionalChatTTS(PreTrainedModel):
         return mel_specs


-# Borrowed from `https://github.com/2noise/ChatTTS/blob/main/ChatTTS/model/processors.py`
 def gen_logits(
     num_code: int,
     top_P=0.7,
@@ -3172,88 +2435,10 @@ def gen_logits(
     logits_processors = []
     if repetition_penalty is not None and repetition_penalty != 1:
-        logits_processors.append(CustomRepetitionPenaltyLogitsProcessorRepeat(repetition_penalty, num_code, 16))
-
-    return logits_warpers, logits_processors
-
-
-# Copy and modified from transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation
-def prepare_inputs_for_generation(
-    self,
-    input_ids,
-    past_key_values=None,
-    attention_mask=None,
-    inputs_embeds=None,
-    cache_position=None,
-    position_ids=None,
-    use_cache=True,
-    **kwargs,
-):
-    if past_key_values is not None:
-        if isinstance(past_key_values, Cache):
-            cache_length = past_key_values.get_seq_length()
-            past_length = past_key_values.seen_tokens
-        else:
-            cache_length = past_length = past_key_values[0][0].shape[2]
-
-        # Keep only the unprocessed tokens:
-        # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
-        #     some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
-        #     input)
-        if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
-            input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
-        # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
-        #     input_ids based on the past_length.
-        elif past_length < input_ids.shape[1]:
-            input_ids = input_ids[:, past_length:]
-        # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
-
-    if attention_mask is not None and position_ids is None:
-        # create position_ids on the fly for batch generation
-        position_ids = attention_mask.long().cumsum(-1) - 1
-        position_ids.masked_fill_(attention_mask == 0, 1)
-        if past_key_values:
-            position_ids = position_ids[:, -input_ids.shape[1] :]
-
-            # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead"`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture.
-            position_ids = position_ids.clone(memory_format=torch.contiguous_format)
-
-    # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
-    if inputs_embeds is not None and cache_position[0] == 0:
-        model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
-    else:
-        # The clone here is for the same reason as for `position_ids`.
-        model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
-
-    if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
-        if model_inputs["inputs_embeds"] is not None:
-            batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
-            device = model_inputs["inputs_embeds"].device
-        else:
-            batch_size, sequence_length = model_inputs["input_ids"].shape
-            device = model_inputs["input_ids"].device
-
-        dtype = self.lm_head.weight.dtype
-        min_dtype = torch.finfo(dtype).min
-
-        attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
-            attention_mask,
-            sequence_length=sequence_length,
-            target_length=past_key_values.get_max_length(),
-            dtype=dtype,
-            device=device,
-            min_dtype=min_dtype,
-            cache_position=cache_position,
-            batch_size=batch_size,
+        logits_processors.append(
+            CustomRepetitionPenaltyLogitsProcessorRepeat(
+                repetition_penalty, num_code, 16
+            )
         )

-    model_inputs.update(
-        {
-            "position_ids": position_ids,
-            # "cache_position": cache_position,
-            "past_key_values": past_key_values,
-            "use_cache": use_cache,
-            "attention_mask": attention_mask,
-        }
-    )
-    return model_inputs
+    return logits_warpers, logits_processors
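Since `gen_logits` survives unchanged in the new file, a brief usage sketch may help: build the warpers/processors once, then apply them to one step of per-codebook logits before sampling, roughly as `generate` does internally. The vocabulary size, temperature, and random tensors below are illustrative, and the snippet assumes `gen_logits` and `CustomRepetitionPenaltyLogitsProcessorRepeat` can be imported from this module (the import path is an assumption).

```python
import torch

# assumed import path; gen_logits lives in this module
from modeling_minicpmo import gen_logits

logits_warpers, logits_processors = gen_logits(num_code=626, top_P=0.7, top_K=20, repetition_penalty=1.05)

num_vq, num_code = 4, 626
step_logits = torch.randn(num_vq, num_code)                 # one decoding step, one row per codebook
previous_codes = torch.randint(0, num_code, (num_vq, 12))   # codes generated so far (used by the repetition penalty)

step_logits = step_logits / 0.1                             # temperature, illustrative
for processor in logits_processors:
    step_logits = processor(previous_codes, step_logits)
for warper in logits_warpers:
    step_logits = warper(previous_codes, step_logits)

probs = torch.softmax(step_logits, dim=-1)
next_codes = torch.multinomial(probs, num_samples=1)        # [num_vq, 1]
```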