Spaces:
Running
Running
# Copyright (c) 2023 Amphion. | |
# | |
# This source code is licensed under the MIT license found in the | |
# LICENSE file in the root directory of this source tree. | |
from transformers import LlamaConfig, LlamaForCausalLM, LlamaModel | |
import torch | |
import torch.nn.functional as F | |
import numpy as np | |
import os | |
import torch.nn as nn | |
from typing import List, Optional, Tuple, Union | |
from transformers.models.llama.modeling_llama import LlamaDecoderLayer | |
# Total number of codec quantizer levels. The code currently assumes that the
# first level (index 0) is handled autoregressively, outside this module.
NUM_QUANTIZERS = 8
# First quantizer level predicted by the non-autoregressive model (inclusive).
START_QUANTIZATION_LAYER = 1
# Last quantizer level predicted by the non-autoregressive model (inclusive).
END_QUANTIZATION_LAYER = 7
class LlamaAdaptiveRMSNorm(nn.Module):
    """RMS normalisation whose gain is predicted from a conditioning vector.

    Instead of a learned, fixed scale vector, the scale is produced by a linear
    projection of ``cond_embedding`` and multiplied onto the RMS-normalised input.

    Args:
        hidden_size: Size of the last dimension of the normalised tensor.
        eps: Constant added to the mean square before the reciprocal square root.
        dim_cond: Size of the last dimension of the conditioning vector.
    """

    def __init__(self, hidden_size=1024, eps=1e-9, dim_cond=1024):
        super().__init__()
        # Projection from the conditioning vector to a per-channel gain.
        self.to_weight = nn.Linear(dim_cond, hidden_size)
        nn.init.normal_(self.to_weight.weight, mean=0.0, std=0.02)
        self.variance_epsilon = eps
        # Flag read by Hugging Face weight initialisation; set so the custom
        # initialisation above is not overwritten (disables automatic init).
        self._is_hf_initialized = True

    def forward(self, hidden_states, cond_embedding):
        """Normalise ``hidden_states`` and scale it by the conditioned gain.

        The mean square is computed in float32; the result is cast back to the
        dtype ``hidden_states`` had on entry.
        """
        original_dtype = hidden_states.dtype
        mean_square = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        inverse_rms = torch.rsqrt(mean_square + self.variance_epsilon)
        normalized = hidden_states * inverse_rms
        gain = self.to_weight(cond_embedding)
        return (gain * normalized).to(original_dtype)
class LlamaNARDecoderLayer(LlamaDecoderLayer):
    """Llama decoder layer whose two norms are conditioned on an extra embedding."""

    def __init__(self, config: LlamaConfig):
        """Build the stock layer, then swap both norms for adaptive RMS norms."""
        # The parent constructor creates the attention and MLP sub-modules.
        super().__init__(config=config, layer_idx=0)
        norm_kwargs = dict(eps=config.rms_norm_eps, dim_cond=config.hidden_size)
        self.input_layernorm = LlamaAdaptiveRMSNorm(config.hidden_size, **norm_kwargs)
        self.post_attention_layernorm = LlamaAdaptiveRMSNorm(
            config.hidden_size, **norm_kwargs
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        cond_embedding: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
    ) -> Tuple[
        torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
    ]:
        """Run self-attention and the MLP, each behind a conditioned norm.

        Args:
            hidden_states (`torch.FloatTensor`): layer input of shape
                `(batch, seq_len, embed_dim)`.
            cond_embedding (`torch.Tensor`): conditioning vector handed to both
                adaptive norms.
            attention_mask (`torch.FloatTensor`, *optional*): additive mask of size
                `(batch, 1, tgt_len, src_len)`; masked positions hold very large
                negative values.
            position_ids (`torch.LongTensor`, *optional*): forwarded to the
                attention module.
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached key and
                value projection states.
            output_attentions (`bool`, *optional*): when true, the attention
                weights are appended to the returned tuple.
            use_cache (`bool`, *optional*): when true, the present key/value state
                is appended to the returned tuple.

        Returns:
            Tuple starting with the new hidden states, optionally followed by the
            attention weights and then the present key/value state.
        """
        # Attention sub-block with a residual connection.
        attn_input = self.input_layernorm(hidden_states, cond_embedding=cond_embedding)
        attn_output, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=attn_input,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
        )
        hidden_states = hidden_states + attn_output

        # Feed-forward sub-block with a residual connection.
        mlp_input = self.post_attention_layernorm(
            hidden_states, cond_embedding=cond_embedding
        )
        hidden_states = hidden_states + self.mlp(mlp_input)

        outputs = (hidden_states,)
        if output_attentions:
            outputs = outputs + (self_attn_weights,)
        if use_cache:
            outputs = outputs + (present_key_value,)
        return outputs
from transformers.models.llama.modeling_llama import BaseModelOutputWithPast | |
class MultiEmbedding(nn.Module):
    """Embedding for multiple quantization layers, summing up the embeddings of each layer.

    Args:
        num_embeddings: Number of rows in each per-level embedding table.
        embedding_dim: Width of every embedding vector.
        num_quantization_layers: Number of embedding tables (one per level).
    """

    def __init__(
        self,
        num_embeddings=1034,
        embedding_dim=1024,
        num_quantization_layers=NUM_QUANTIZERS,
    ):
        super().__init__()
        self.embeddings = nn.ModuleList(
            [
                nn.Embedding(num_embeddings, embedding_dim)
                for _ in range(num_quantization_layers)
            ]
        )
        # Initialize every table with the same small normal distribution.
        for table in self.embeddings:
            table.weight.data.normal_(mean=0.0, std=0.02)
        self._is_hf_initialized = True  # disable automatic init

    def forward(self, input_ids):
        """Input: [num_quant, B, T] -> Output: [B, T, H].

        Level ``i`` of ``input_ids`` is looked up in table ``i`` and the results
        are summed. Only the first ``num_quant`` tables are used, so the input may
        carry fewer levels than the module owns.
        """
        num_quant, batch_size, seq_len = input_ids.shape
        reference_weight = self.embeddings[0].weight
        # Allocate the accumulator in the dtype of the tables so the output follows
        # the module's precision (e.g. after .half()) instead of always being
        # float32.
        summed_embeddings = torch.zeros(
            batch_size,
            seq_len,
            self.embeddings[0].embedding_dim,
            device=input_ids.device,
            dtype=reference_weight.dtype,
        )
        for level in range(num_quant):
            summed_embeddings += self.embeddings[level](input_ids[level])
        return summed_embeddings
class LlammaNARModel(LlamaModel):
    """Llama backbone adapted for non-autoregressive codec-token prediction.

    Differences from the parent that are visible in this class:
      * every decoder layer is a ``LlamaNARDecoderLayer`` and the final norm is a
        ``LlamaAdaptiveRMSNorm``, all conditioned on a quantizer-level embedding;
      * the attention mask only hides padding (no causal mask);
      * ``forward`` receives ready-made embeddings through ``input_ids``.
    """

    def __init__(self, config):
        """Adding adaptive layer norm, conditional embeddings, and multi-level input embeddings to the decoder layer"""
        super().__init__(config)
        self.layers = nn.ModuleList(
            [LlamaNARDecoderLayer(config) for _ in range(config.num_hidden_layers)]
        )
        self.norm = LlamaAdaptiveRMSNorm(
            config.hidden_size, eps=config.rms_norm_eps, dim_cond=config.hidden_size
        )
        # One conditioning vector per quantizer level (NUM_QUANTIZERS rows).
        self.embed_cond = nn.Embedding(NUM_QUANTIZERS, config.hidden_size)
        # NOTE: LlamaNARDecoderLayer.__init__ already installs adaptive norms, so
        # this loop is redundant. It is kept so that the sequence of random
        # initialisations (and therefore seeded results) stays unchanged.
        for layer in self.layers:
            layer.input_layernorm = LlamaAdaptiveRMSNorm(
                config.hidden_size, eps=config.rms_norm_eps, dim_cond=config.hidden_size
            )
            layer.post_attention_layernorm = LlamaAdaptiveRMSNorm(
                config.hidden_size, eps=config.rms_norm_eps, dim_cond=config.hidden_size
            )
        self.post_init()

    def _prepare_decoder_attention_mask(
        self, attention_mask, input_shape, inputs_embeds, past_key_values_length
    ):
        """Build a non-causal additive attention mask.

        Expands ``attention_mask`` from ``[bsz, src_len]`` to
        ``[bsz, 1, tgt_len, src_len]``. Positions whose mask value is 1 become 0,
        all other positions become the most negative value of the embedding dtype.

        Args:
            attention_mask: ``[bsz, src_len]`` mask, or ``None``.
            input_shape: Shape whose last entry is used as ``tgt_len``.
            inputs_embeds: Tensor that supplies the target dtype and device.
            past_key_values_length: Unused; kept for signature compatibility.

        Returns:
            The expanded additive mask, or ``None`` when no mask was given.
        """
        if attention_mask is None:
            return None

        dtype = inputs_embeds.dtype
        bsz, src_len = attention_mask.size()
        tgt_len = input_shape[-1] if input_shape[-1] is not None else src_len

        keep = attention_mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
        blocked = 1.0 - keep
        additive_mask = blocked.masked_fill(
            blocked.to(torch.bool), torch.finfo(dtype).min
        )
        return additive_mask.to(inputs_embeds.device)

    def forward(
        self,
        input_ids: torch.LongTensor = None,  # embeddings, [B, T, H]
        cond: torch.LongTensor = None,  # index for conditional embeddings
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Run the conditioned decoder stack.

        Args:
            input_ids: Despite the name, the already-embedded input of shape
                ``[B, T, H]``; it is used directly as the first hidden state.
            cond: Index into ``embed_cond`` selecting the conditioning vector.
                The original comment gives shape ``[B]``; ``ValleNAR`` in this
                file passes a 0-dim tensor.
            attention_mask: ``[B, T]`` padding mask; defaults to all ones.
            position_ids: Optional positions; defaults to a range over the
                sequence, offset by the cached length.
            past_key_values: Optional per-layer key/value cache.
            inputs_embeds: Ignored; any value passed is replaced by ``input_ids``.
            use_cache, output_attentions, output_hidden_states, return_dict:
                Fall back to the corresponding config values when ``None``.

        Returns:
            ``BaseModelOutputWithPast`` or, when ``return_dict`` is false, a tuple
            of the non-``None`` entries in the same order.

        Raises:
            ValueError: If ``input_ids`` is ``None``.
            NotImplementedError: If gradient checkpointing is enabled in training.
        """
        if input_ids is None:
            raise ValueError(
                "input_ids must be provided: this model consumes embeddings of "
                "shape [B, T, H] through the input_ids argument."
            )

        batch_size, seq_length, _ = input_ids.shape
        inputs_embeds = input_ids  # [B, T, H]
        cond_embedding = self.embed_cond(cond)

        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        past_key_values_length = 0
        if past_key_values is not None:
            past_key_values_length = past_key_values[0][0].shape[2]
        seq_length_with_past = seq_length + past_key_values_length

        if position_ids is None:
            position_ids = torch.arange(
                past_key_values_length,
                seq_length + past_key_values_length,
                dtype=torch.long,
                device=inputs_embeds.device,
            )
            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
        else:
            position_ids = position_ids.view(-1, seq_length).long()

        if attention_mask is None:
            attention_mask = torch.ones(
                (batch_size, seq_length_with_past),
                dtype=torch.bool,
                device=inputs_embeds.device,
            )
        attention_mask = self._prepare_decoder_attention_mask(
            attention_mask,
            (batch_size, seq_length),
            inputs_embeds,
            past_key_values_length,
        )

        # The original checkpointing branch raised before reaching any of its
        # code, so fail up front with an explanatory message instead.
        if self.gradient_checkpointing and self.training:
            raise NotImplementedError(
                "Gradient checkpointing is not implemented for LlammaNARModel."
            )

        hidden_states = inputs_embeds
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = () if use_cache else None

        for idx, decoder_layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            past_key_value = (
                past_key_values[idx] if past_key_values is not None else None
            )
            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cond_embedding=cond_embedding,  # using cond embed
            )
            hidden_states = layer_outputs[0]

            if use_cache:
                # The cache entry follows the attention weights when both are returned.
                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states, cond_embedding=cond_embedding)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
                if v is not None
            )
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
from transformers.models.llama.modeling_llama import LlamaPreTrainedModel | |
from transformers.models.llama.modeling_llama import CrossEntropyLoss | |
from easydict import EasyDict as edict | |
class LlamaForNARModeling(LlamaPreTrainedModel):
    """Conditioned NAR backbone with one output head per predicted quantizer level."""

    def __init__(self, config):
        super().__init__(config)
        self.model = LlammaNARModel(config)
        # One head for each level in [START_QUANTIZATION_LAYER, END_QUANTIZATION_LAYER].
        num_heads = END_QUANTIZATION_LAYER - START_QUANTIZATION_LAYER + 1
        self.lm_head = nn.ModuleList(
            [
                nn.Linear(config.hidden_size, config.vocab_size, bias=False)
                for _ in range(num_heads)
            ]
        )
        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        cond: torch.LongTensor,  # added
        prediction_target: torch.LongTensor = None,  # added. No shifting. -100 means no loss
        input_ids: torch.LongTensor = None,  # expect an embedding, [B, T, H]
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        # labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        """Predict tokens of quantizer level ``cond`` and optionally compute the loss.

        Args:
            cond: Quantizer level to predict. It conditions the backbone and
                selects the output head ``cond - START_QUANTIZATION_LAYER``; it
                must hold a single value.
            prediction_target: Labels of shape ``[B, T]``, aligned with the input
                without shifting; ``-100`` marks positions excluded from the loss.
            input_ids: Input embeddings of shape ``[B, T, H]``.
            attention_mask, position_ids, past_key_values, inputs_embeds,
            use_cache, output_attentions, output_hidden_states, return_dict:
                Forwarded to the backbone.

        Returns:
            An ``edict`` with ``loss`` (``None`` when ``prediction_target`` is not
            given) and ``logits``.

        Raises:
            ValueError: If ``cond`` does not map to an existing output head.
        """
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        # Validate before running the backbone: a level below
        # START_QUANTIZATION_LAYER would otherwise yield a negative index and
        # silently select a head from the end of the list.
        head_index = int(cond) - START_QUANTIZATION_LAYER
        if not 0 <= head_index < len(self.lm_head):
            raise ValueError(
                f"cond must be in [{START_QUANTIZATION_LAYER}, "
                f"{START_QUANTIZATION_LAYER + len(self.lm_head) - 1}], got {int(cond)}"
            )

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            cond=cond,  # added
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = outputs[0]
        logits = self.lm_head[head_index](hidden_states)

        loss = None
        if prediction_target is not None:
            # Flatten batch and time so every position is one classification sample.
            loss_fct = CrossEntropyLoss()
            flat_logits = logits.view(-1, logits.size(-1))
            flat_target = prediction_target.view(-1)
            loss = loss_fct(flat_logits, flat_target)

        return edict(
            loss=loss,
            logits=logits,
        )
class ValleNAR(nn.Module):
    """Non-autoregressive stage that predicts one codec quantizer level at a time.

    The input sequence is the concatenation of phone embeddings, an acoustic
    prompt (all quantizer levels summed) and the target part (levels below the
    one being predicted, summed).

    Args:
        phone_vocab_size: Number of phone ids.
        target_vocab_size: Number of codec token ids per quantizer level.
        hidden_size, intermediate_size, num_hidden_layers, num_attention_heads:
            Passed to ``LlamaConfig``.
        pad_token_id: Id written into padded phone positions.
        bos_target_id, eos_target_id, bos_phone_id, eos_phone_id, bos_prompt_id,
        eos_prompt_id: Special ids stored on the module; the target BOS/EOS are
            also passed to ``LlamaConfig``.
        use_input_embeds: When true, an extra ``emb_linear`` projection is created.
            It is not used by ``forward`` in this class.
        emb_dim: Input width of ``emb_linear``.
    """

    def __init__(
        self,
        phone_vocab_size=256,
        target_vocab_size=1024,
        hidden_size=1024,
        intermediate_size=4096,
        num_hidden_layers=12,
        num_attention_heads=16,
        pad_token_id=1024 + 256,
        bos_target_id=1282,
        eos_target_id=1283,
        bos_phone_id=1284,
        eos_phone_id=1285,
        bos_prompt_id=1286,
        eos_prompt_id=1287,
        use_input_embeds=False,
        emb_dim=256,
    ):
        super(ValleNAR, self).__init__()
        self.config = LlamaConfig(
            vocab_size=phone_vocab_size + target_vocab_size + 10,
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            num_hidden_layers=num_hidden_layers,
            num_attention_heads=num_attention_heads,
            pad_token_id=pad_token_id,
            bos_token_id=bos_target_id,
            eos_token_id=eos_target_id,
            use_cache=False,
        )
        self.phone_vocab_size = phone_vocab_size
        self.target_vocab_size = target_vocab_size
        self.pad_token_id = pad_token_id
        self.bos_target_id = bos_target_id
        self.eos_target_id = eos_target_id
        self.bos_phone_id = bos_phone_id
        self.eos_phone_id = eos_phone_id
        self.bos_prompt_id = bos_prompt_id
        self.eos_prompt_id = eos_prompt_id
        self.model = LlamaForNARModeling(self.config)
        self.use_input_embeds = use_input_embeds
        self.phone_embedder = nn.Embedding(
            self.phone_vocab_size + 10, hidden_size
        )  # use phone_embedder to embed all eos, bos tokens
        self.prompt_embedder = MultiEmbedding(
            num_embeddings=self.target_vocab_size,
            embedding_dim=hidden_size,
            num_quantization_layers=NUM_QUANTIZERS,
        )
        self.phone_embedder.weight.data.normal_(mean=0.0, std=0.02)
        # Schedule used to pick the quantizer level to train on when the caller
        # does not specify one: "uniform" (default), "linear" or "cosine".
        self.mask_layer_schedule = "uniform"
        # Optional projection for an external embedding (e.g. speaker
        # information per the original comment); not created by default.
        if self.use_input_embeds:
            self.emb_linear = nn.Linear(emb_dim, hidden_size)
            self.emb_linear.weight.data.normal_(mean=0.0, std=0.01)
            self.emb_linear.bias.data.zero_()

    def _sample_target_layer(self):
        """Draw a quantizer level in [START, END] according to ``mask_layer_schedule``."""
        schedule = self.mask_layer_schedule
        if schedule == "uniform":
            return np.random.randint(
                START_QUANTIZATION_LAYER, END_QUANTIZATION_LAYER + 1
            )

        levels = range(START_QUANTIZATION_LAYER, END_QUANTIZATION_LAYER + 1)
        if schedule == "linear":
            weights = torch.tensor([NUM_QUANTIZERS - i for i in levels])
        elif schedule == "cosine":
            weights = torch.tensor(
                [np.cos(i / NUM_QUANTIZERS * np.pi / 2) for i in levels]
            )
        else:
            raise ValueError(f"Unknown mask_layer_schedule: {schedule!r}")

        weights = weights / weights.sum()
        # multinomial returns an index into `levels`; shift it to a level number.
        sampled = torch.multinomial(weights, 1, replacement=True)
        return sampled.item() + START_QUANTIZATION_LAYER

    def forward(
        self,
        phone_ids,
        phone_mask,
        target_ids,
        target_mask,
        target_quantization_layer=None,
        prompt_len=None,
        dropout=0.0,
    ):
        """Predict one quantizer level of the target and compute loss and accuracy.

        Args:
            phone_ids: [B, T] phone ids.
            phone_mask: [B, T], 1 for valid phones and 0 for padding.
            target_ids: [Q, B, T] codec tokens for all quantizer levels; every
                value must be below ``target_vocab_size``.
            target_mask: [B, T], 1 for valid target frames and 0 for padding.
            target_quantization_layer: Level to predict; sampled with
                ``mask_layer_schedule`` when ``None``.
            prompt_len: Number of leading target frames used as the acoustic
                prompt. Must be given in eval mode and omitted in training mode,
                where it is drawn at random.
            dropout: rate of dropping out the target tokens (each input token of
                the target part is replaced by a mask id with this probability).

        Returns:
            The ``edict`` from ``LlamaForNARModeling`` (``loss``, ``logits``) with
            ``target_quantization_layer``, ``top1_acc``, ``top5_acc`` and
            ``top10_acc`` added. Accuracies cover the target part only.

        Raises:
            ValueError: If ``dropout`` is non-zero but the prompt embedder has no
                row for the mask id, or if the schedule name is unknown.
        """
        assert (
            target_ids < self.target_vocab_size
        ).all(), f"target_ids should be less than {self.target_vocab_size}"

        # Shift phones into the shared id space and overwrite padding with the pad id.
        phone_ids = phone_ids + self.target_vocab_size
        phone_ids = phone_ids * phone_mask + (1 - phone_mask) * self.pad_token_id
        # No loss on any phone position. The previous `-100 * (1 - phone_mask)`
        # produced label 0 on valid phones, so the loss trained those positions
        # to predict token 0.
        phone_label = torch.full_like(phone_mask, -100)
        # get phone embedding
        phone_embedding = self.phone_embedder(
            phone_ids - self.target_vocab_size
        )  # [B, T, H]

        if prompt_len is not None:
            assert not self.training  # inference stage fix prompt len to input
            num_prompt_tokens = prompt_len
        else:
            assert self.training  # randomize prompt len in training
            total_frames = target_ids.shape[-1]
            num_prompt_tokens = np.random.randint(
                min(total_frames // 4, 5), total_frames // 2
            )

        # Prompt: every quantizer level of the leading frames, never scored.
        prompt_tokens = target_ids[:, :, :num_prompt_tokens]  # [Q, B, T]
        prompt_mask = torch.ones_like(prompt_tokens[0])
        prompt_label = -100 * prompt_mask
        prompt_embedding = self.prompt_embedder(prompt_tokens)  # [B, T, H]

        # Quantizer levels are numbered 0..NUM_QUANTIZERS-1; pick the one to predict.
        if target_quantization_layer is None:
            target_quantization_layer = self._sample_target_layer()

        # Input for the target part: all levels below the predicted one.
        target_prompt_ids = target_ids[
            :target_quantization_layer, :, num_prompt_tokens:
        ]
        if dropout != 0.0:
            mask_token_id = self.target_vocab_size
            table_rows = self.prompt_embedder.embeddings[0].num_embeddings
            if mask_token_id >= table_rows:
                raise ValueError(
                    f"dropout needs mask id {mask_token_id}, but prompt_embedder "
                    f"only has {table_rows} rows per level"
                )
            dropped = torch.rand_like(target_prompt_ids, dtype=torch.float32) < dropout
            # masked_fill returns a new tensor, so target_ids is left untouched.
            target_prompt_ids = target_prompt_ids.masked_fill(dropped, mask_token_id)
        target_embedding = self.prompt_embedder(target_prompt_ids)

        # Labels of the target part: the predicted level, -100 on padded frames.
        target_mask = target_mask[:, num_prompt_tokens:]
        target_labels = target_ids[
            target_quantization_layer, :, num_prompt_tokens:
        ] * target_mask + (-100 * (1 - target_mask))

        input_embeddings = torch.cat(
            [phone_embedding, prompt_embedding, target_embedding], dim=1
        )
        input_mask = torch.cat([phone_mask, prompt_mask, target_mask], dim=1)  # [B, T]
        prediction_target = torch.cat(
            [phone_label, prompt_label, target_labels], dim=1
        )  # [B, T]

        out = self.model(
            cond=torch.tensor(
                target_quantization_layer,
                device=prediction_target.device,
                dtype=torch.long,
            ),
            input_ids=input_embeddings,
            prediction_target=prediction_target,
            attention_mask=input_mask,
            return_dict=True,
        )

        # Accuracy over the valid frames of the target part only.
        target_len = target_embedding.shape[1]
        logits = out.logits[:, -target_len:, :]
        targets = prediction_target[..., -target_len:]
        valid_frames = target_mask.sum()

        top1_hit = logits.argmax(-1) == targets
        top5_hit = (logits.topk(5, dim=-1).indices == targets.unsqueeze(-1)).any(-1)
        top10_hit = (logits.topk(10, dim=-1).indices == targets.unsqueeze(-1)).any(-1)

        out.target_quantization_layer = target_quantization_layer
        out.top1_acc = (top1_hit * target_mask).sum() / valid_frames
        out.top5_acc = (top5_hit * target_mask).sum() / valid_frames
        out.top10_acc = (top10_hit * target_mask).sum() / valid_frames
        return out

    def add_phone_eos_bos_label(
        self, phone_ids, phone_mask, phone_eos_id, phone_bos_id, pad_token_id
    ):
        """Wrap phone sequences with BOS/EOS and build all-ignored labels.

        Not called by ``forward`` in this class (the call there is commented out).

        Args:
            phone_ids: [B, T] phone ids.
            phone_mask: [B, T], 1 for valid positions.
            phone_eos_id, phone_bos_id, pad_token_id: Special ids to insert.

        Returns:
            ``(phone_ids, phone_mask, phone_label)``, each of length T + 2 along
            the last dimension; ``phone_label`` is -100 everywhere.
        """
        # phone_ids: [B, T]
        # phone_mask: [B, T]
        phone_ids = phone_ids + self.target_vocab_size * phone_mask
        phone_ids = phone_ids * phone_mask
        phone_ids = F.pad(phone_ids, (0, 1), value=0) + phone_eos_id * F.pad(
            1 - phone_mask, (0, 1), value=1
        )  # make pad token eos token, add eos token at the end
        # Prepending a 1 shifts the mask right by one, so for a mask whose ones
        # are at the start it also covers the first EOS after the valid phones.
        phone_mask = F.pad(phone_mask, (1, 0), value=1)  # add eos mask
        phone_ids = phone_ids * phone_mask + pad_token_id * (
            1 - phone_mask
        )  # restore pad token ids
        phone_ids = F.pad(phone_ids, (1, 0), value=phone_bos_id)  # add bos token
        phone_mask = F.pad(phone_mask, (1, 0), value=1)  # add bos mask
        phone_label = -100 * torch.ones_like(
            phone_ids
        )  # loss for entire phone is not computed (passed to llama)
        return phone_ids, phone_mask, phone_label

    @torch.no_grad()
    def sample_hf(
        self,
        phone_ids,  # [B, T]
        prompt_ids,  # [8, B, T]
        first_stage_ids,  # [B, T]
        top_k=50,
        top_p=1,
        temperature=1.1,
        first_stage_ids_gt=None,  # [Q, B, T]
        first_stage_ids_gt_end_layer=None,  # 2 to 8
    ):
        """Greedily fill in the remaining quantizer levels of the generated frames.

        Args:
            phone_ids: [B, T] phone ids (treated as fully valid, no padding).
            prompt_ids: [Q, B, T] acoustic prompt with all quantizer levels; at
                least 5 frames long.
            first_stage_ids: [B, T] result from first quant layer. Should be
                continuation of prompt_ids.
            top_k, top_p, temperature: Accepted for interface compatibility but
                unused; decoding is a plain argmax.
            first_stage_ids_gt: Optional [Q, B, T] ground-truth tokens written into
                the generated frames before decoding.
            first_stage_ids_gt_end_layer: Number of leading levels taken from
                ``first_stage_ids_gt``; decoding starts at this level.

        Returns:
            [Q, B, T_gen] tokens of all levels for the generated frames.
        """
        phone_mask = torch.ones_like(phone_ids, dtype=torch.long)
        assert prompt_ids.shape[-1] >= 5, "prompt_ids should have at least 5 tokens"

        # Seed every level of the generated frames with the first-stage tokens;
        # levels from the start level onward are overwritten below.
        target_ids = torch.cat(
            [prompt_ids, first_stage_ids.expand(prompt_ids.shape[0], -1, -1)], dim=-1
        )
        target_mask = torch.ones_like(target_ids[0], dtype=torch.long)

        if first_stage_ids_gt is not None:
            target_ids[
                :first_stage_ids_gt_end_layer, :, -first_stage_ids_gt.shape[-1] :
            ] = first_stage_ids_gt[:first_stage_ids_gt_end_layer]

        gen_len = first_stage_ids.shape[-1]
        start_qnt_layer = START_QUANTIZATION_LAYER
        if first_stage_ids_gt_end_layer is not None:
            start_qnt_layer = first_stage_ids_gt_end_layer

        for qnt_level in range(start_qnt_layer, NUM_QUANTIZERS):
            out = self.forward(
                phone_ids=phone_ids,
                phone_mask=phone_mask,
                target_ids=target_ids,
                target_mask=target_mask,
                target_quantization_layer=qnt_level,
                prompt_len=prompt_ids.shape[-1],
            )
            # Keep the batch dimension: [B, gen_len]. Flattening before slicing
            # only gave correct tokens for a batch of one.
            gen_tokens = out.logits.argmax(dim=-1)[:, -gen_len:]
            # overwrite the target_ids with the generated tokens
            target_ids[qnt_level, :, -gen_len:] = gen_tokens

        return target_ids[:, :, -gen_len:]
def test():
    """Smoke test: overfit ValleNAR on one random sample, then sample from it.

    Runs on CUDA when available and falls back to CPU otherwise. Prints the loss
    of each of the 200 training iterations, then the ground-truth tokens of the
    last frames, the sampled tokens, and whether the two match exactly.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    total_frames = 250
    gen_frames = 10  # trailing frames regenerated by sample_hf
    prompt_frames = total_frames - gen_frames

    model = ValleNAR().to(device)
    phone_ids = torch.LongTensor([1, 2, 3, 4, 5]).reshape(1, -1).to(device)
    phone_mask = torch.LongTensor([1, 1, 1, 1, 1]).reshape(1, -1).to(device)
    target_ids = torch.randint(
        high=1024, size=(8, 1, total_frames), dtype=torch.long
    ).to(device)
    target_mask = torch.ones(1, total_frames, dtype=torch.long).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
    for i in range(200):
        optimizer.zero_grad()
        out = model(
            phone_ids=phone_ids,
            phone_mask=phone_mask,
            target_ids=target_ids,
            target_mask=target_mask,
        )
        loss = out.loss
        loss.backward()
        optimizer.step()
        print(f"iter={i}, {loss}.")

    # Use the leading frames as the prompt and level 0 of the rest as the
    # first-stage tokens.
    target_ids_short = target_ids[:, :, :prompt_frames]
    model.eval()
    sampled = model.sample_hf(
        phone_ids,
        prompt_ids=target_ids_short,
        first_stage_ids=target_ids[0, :, prompt_frames:],
    )
    print(target_ids[:, :, -gen_frames:])
    print(sampled)
    print((sampled == target_ids[:, :, -gen_frames:]).all())
# Run the smoke test only when this file is executed as a script.
if __name__ == "__main__":
    test()