Spaces:
Running
Running
File size: 29,700 Bytes
c968fc3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 |
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from transformers import LlamaConfig, LlamaForCausalLM, LlamaModel
import torch
import torch.nn.functional as F
import numpy as np
import os
import torch.nn as nn
from typing import List, Optional, Tuple, Union
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
# Codec layout shared by every class in this module.
NUM_QUANTIZERS: int = 8  # total quantization layers; layer 0 is assumed to come from the AR stage
START_QUANTIZATION_LAYER: int = 1  # first layer this NAR model predicts
END_QUANTIZATION_LAYER: int = 7  # last layer this NAR model predicts (inclusive)
class LlamaAdaptiveRMSNorm(nn.Module):
    """RMSNorm whose per-channel scale is predicted from a conditioning vector.

    The scale is ``to_weight(cond_embedding)`` instead of a fixed learned weight,
    so each condition (e.g. each target quantization layer) gets its own scale.

    Args:
        hidden_size: size of the last dimension of ``hidden_states``.
        eps: added to the mean square before ``rsqrt`` for numerical stability.
        dim_cond: size of the last dimension of ``cond_embedding``.
    """

    def __init__(self, hidden_size=1024, eps=1e-9, dim_cond=1024):
        super().__init__()
        self.to_weight = nn.Linear(dim_cond, hidden_size)
        nn.init.normal_(self.to_weight.weight, mean=0.0, std=0.02)
        self.variance_epsilon = eps
        # Prevent HF `post_init` from overwriting the custom init above.
        self._is_hf_initialized = True

    def forward(self, hidden_states, cond_embedding):
        """Normalize ``hidden_states`` and scale it by a condition-dependent weight.

        Args:
            hidden_states: tensor whose last dimension is ``hidden_size``
                (``[B, T, hidden_size]`` in this module).
            cond_embedding: either one condition shared by the whole batch,
                ``[dim_cond]``, or one condition per sample, ``[B, dim_cond]``.

        Returns:
            Tensor with the shape and dtype of ``hidden_states``.
        """
        input_dtype = hidden_states.dtype
        # Mean square is computed in float32 for stability.
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        weight = self.to_weight(cond_embedding)
        # A per-sample weight [B, H] must broadcast over the sequence axis;
        # without the unsqueeze, B would be aligned with T instead.
        if weight.dim() == 2 and hidden_states.dim() == 3:
            weight = weight.unsqueeze(1)
        return (weight * hidden_states).to(input_dtype)
class LlamaNARDecoderLayer(LlamaDecoderLayer):
    """Llama decoder layer with condition-adaptive RMSNorms.

    Attention and MLP come unchanged from ``LlamaDecoderLayer``; the two norms
    are replaced by ``LlamaAdaptiveRMSNorm`` and ``forward`` takes an extra
    ``cond_embedding`` that is fed to both of them.
    """

    def __init__(self, config: LlamaConfig):
        # NOTE(review): every instance is built with layer_idx=0; this only
        # matters when a KV cache is used — confirm before enabling use_cache.
        super().__init__(config=config, layer_idx=0)
        norm_size = config.hidden_size
        norm_eps = config.rms_norm_eps
        self.input_layernorm = LlamaAdaptiveRMSNorm(
            norm_size, eps=norm_eps, dim_cond=norm_size
        )
        self.post_attention_layernorm = LlamaAdaptiveRMSNorm(
            norm_size, eps=norm_eps, dim_cond=norm_size
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        cond_embedding: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
    ) -> Tuple[
        torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
    ]:
        """Pre-norm attention block followed by a pre-norm MLP block.

        Args:
            hidden_states: layer input of shape ``(batch, seq_len, embed_dim)``.
            cond_embedding: conditioning vector passed to both adaptive norms.
            attention_mask: additive mask of shape ``(batch, 1, tgt_len, src_len)``;
                masked positions hold very large negative values.
            position_ids: positions forwarded to the attention module.
            past_key_value: cached key/value states forwarded to attention.
            output_attentions: also return the attention weights.
            use_cache: also return the attention module's present key/value.

        Returns:
            ``(hidden_states,)`` optionally followed by the attention weights
            (if ``output_attentions``) and the present key/value (if ``use_cache``).
        """
        # --- attention sub-block (residual is the un-normalized input) ---
        attn_input = self.input_layernorm(hidden_states, cond_embedding=cond_embedding)
        attn_output, attn_weights, present_kv = self.self_attn(
            hidden_states=attn_input,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
        )
        hidden_states = hidden_states + attn_output

        # --- feed-forward sub-block ---
        mlp_input = self.post_attention_layernorm(
            hidden_states, cond_embedding=cond_embedding
        )
        hidden_states = hidden_states + self.mlp(mlp_input)

        result = (hidden_states,)
        if output_attentions:
            result = result + (attn_weights,)
        if use_cache:
            result = result + (present_kv,)
        return result
from transformers.models.llama.modeling_llama import BaseModelOutputWithPast
class MultiEmbedding(nn.Module):
    """Embeddings for multiple quantization layers, summed per time step.

    Each quantization layer has its own ``nn.Embedding`` table; the output for a
    position is the sum of that position's embeddings over the supplied layers.
    """

    def __init__(
        self,
        num_embeddings=1034,
        embedding_dim=1024,
        num_quantization_layers=NUM_QUANTIZERS,
    ):
        super().__init__()
        self.embeddings = nn.ModuleList(
            [
                nn.Embedding(num_embeddings, embedding_dim)
                for _ in range(num_quantization_layers)
            ]
        )
        for table in self.embeddings:
            table.weight.data.normal_(mean=0.0, std=0.02)
        # Prevent HF `post_init` from overwriting the custom init above.
        self._is_hf_initialized = True

    def forward(self, input_ids):
        """Sum the per-layer embeddings.

        Args:
            input_ids: LongTensor ``[num_quant, B, T]``. Only the first
                ``num_quant`` tables are used; ``num_quant`` larger than the
                number of tables raises ``IndexError``.

        Returns:
            Tensor ``[B, T, embedding_dim]`` with the dtype of the embedding
            weights, so half/bf16 models get half/bf16 activations.
        """
        num_quant, batch_size, seq_len = input_ids.shape
        first_table = self.embeddings[0]
        summed_embeddings = torch.zeros(
            batch_size,
            seq_len,
            first_table.embedding_dim,
            device=input_ids.device,
            dtype=first_table.weight.dtype,
        )
        for layer in range(num_quant):
            summed_embeddings += self.embeddings[layer](input_ids[layer])
        return summed_embeddings
class LlammaNARModel(LlamaModel):
    """Non-causal Llama backbone conditioned on a quantization-layer index.

    Differences from ``LlamaModel``:
      * every layer is a ``LlamaNARDecoderLayer`` (adaptive RMSNorms);
      * the final norm is a ``LlamaAdaptiveRMSNorm``;
      * ``cond`` is embedded by ``embed_cond`` and fed to every norm;
      * attention is bidirectional (padding mask only, no causal mask);
      * the input is pre-computed embeddings rather than token ids.
    """

    def __init__(self, config):
        super().__init__(config)
        self.layers = nn.ModuleList(
            [LlamaNARDecoderLayer(config) for _ in range(config.num_hidden_layers)]
        )
        self.norm = LlamaAdaptiveRMSNorm(
            config.hidden_size, eps=config.rms_norm_eps, dim_cond=config.hidden_size
        )
        # One conditioning vector per quantization layer index (NUM_QUANTIZERS rows).
        self.embed_cond = nn.Embedding(NUM_QUANTIZERS, config.hidden_size)
        # LlamaNARDecoderLayer already creates its adaptive input/post-attention
        # norms, so they are not rebuilt here.
        self.post_init()

    def _prepare_decoder_attention_mask(
        self, attention_mask, input_shape, inputs_embeds, past_key_values_length
    ):
        """Build a non-causal additive attention mask.

        Args:
            attention_mask: ``[bsz, src_len]`` with 1/True where keys may be
                attended to and 0/False for padding, or ``None``.
            input_shape: ``(bsz, tgt_len)`` of the current queries.
            inputs_embeds: used only for its dtype and device.
            past_key_values_length: unused; kept for signature compatibility.

        Returns:
            ``[bsz, 1, tgt_len, src_len]`` tensor holding 0 where attention is
            allowed and ``finfo(dtype).min`` at padded keys, or ``None`` when
            ``attention_mask`` is ``None``.
        """
        if attention_mask is None:
            return None
        dtype = inputs_embeds.dtype
        bsz, src_len = attention_mask.size()
        tgt_len = input_shape[-1]
        expanded_mask = (
            attention_mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
        )
        inverted_mask = 1.0 - expanded_mask
        additive_mask = inverted_mask.masked_fill(
            inverted_mask.to(torch.bool), torch.finfo(dtype).min
        )
        return additive_mask.to(inputs_embeds.device)

    def forward(
        self,
        input_ids: torch.LongTensor = None,  # pre-computed embeddings, [B, T, H]
        cond: torch.LongTensor = None,  # index into embed_cond
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Run the conditioned, non-causal decoder stack.

        Args:
            input_ids: despite its name, pre-computed embeddings ``[B, T, H]``
                (this is how ``LlamaForNARModeling`` passes them).
            cond: index tensor for ``embed_cond``; a 0-d tensor gives one
                condition shared by the whole batch.
            inputs_embeds: used only when ``input_ids`` is ``None``.
            attention_mask: ``[B, T + past_len]`` padding mask (1 = attend);
                defaults to all ones.
            Remaining arguments follow ``LlamaModel.forward``.

        Returns:
            ``BaseModelOutputWithPast``, or with ``return_dict=False`` a tuple of
            its non-``None`` fields.

        Raises:
            ValueError: if neither ``input_ids`` nor ``inputs_embeds`` is given.
            NotImplementedError: if gradient checkpointing is on while training.
        """
        if input_ids is not None:
            inputs_embeds = input_ids
        if inputs_embeds is None:
            raise ValueError(
                "Either input_ids (embeddings) or inputs_embeds must be provided."
            )
        batch_size, seq_length, _ = inputs_embeds.shape
        device = inputs_embeds.device

        cond_embedding = self.embed_cond(cond)

        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        past_key_values_length = 0
        if past_key_values is not None:
            past_key_values_length = past_key_values[0][0].shape[2]
        seq_length_with_past = seq_length + past_key_values_length

        if position_ids is None:
            position_ids = torch.arange(
                past_key_values_length,
                seq_length + past_key_values_length,
                dtype=torch.long,
                device=device,
            )
            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
        else:
            position_ids = position_ids.view(-1, seq_length).long()

        if attention_mask is None:
            attention_mask = torch.ones(
                (batch_size, seq_length_with_past), dtype=torch.bool, device=device
            )
        attention_mask = self._prepare_decoder_attention_mask(
            attention_mask,
            (batch_size, seq_length),
            inputs_embeds,
            past_key_values_length,
        )

        if self.gradient_checkpointing and self.training:
            # The layers need cond_embedding, which a plain checkpoint wrapper
            # would not forward; unsupported for now.
            raise NotImplementedError(
                "Gradient checkpointing is not supported by LlammaNARModel."
            )

        hidden_states = inputs_embeds
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = () if use_cache else None

        for idx, decoder_layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)
            past_key_value = (
                past_key_values[idx] if past_key_values is not None else None
            )
            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cond_embedding=cond_embedding,
            )
            hidden_states = layer_outputs[0]
            if use_cache:
                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states, cond_embedding=cond_embedding)
        # hidden states after the final norm count as the last layer's output
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
                if v is not None
            )
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
from transformers.models.llama.modeling_llama import LlamaPreTrainedModel
from transformers.models.llama.modeling_llama import CrossEntropyLoss
from easydict import EasyDict as edict
class LlamaForNARModeling(LlamaPreTrainedModel):
    """NAR backbone plus one LM head per predicted quantization layer.

    Head ``k`` predicts layer ``START_QUANTIZATION_LAYER + k``, so ``cond`` must
    lie in ``[START_QUANTIZATION_LAYER, END_QUANTIZATION_LAYER]``.
    """

    def __init__(self, config):
        super().__init__(config)
        self.model = LlammaNARModel(config)
        self.lm_head = nn.ModuleList(
            [
                nn.Linear(config.hidden_size, config.vocab_size, bias=False)
                for _ in range(END_QUANTIZATION_LAYER - START_QUANTIZATION_LAYER + 1)
            ]
        )
        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        cond: torch.LongTensor,
        prediction_target: torch.LongTensor = None,  # [B, T], no shifting; -100 = no loss
        input_ids: torch.LongTensor = None,  # pre-computed embeddings, [B, T, H]
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        """Predict the tokens of quantization layer ``cond``.

        Args:
            cond: single-element integer tensor with the target layer index.
            prediction_target: ``[B, T]`` labels aligned with the inputs (no
                shift); positions equal to -100 are ignored by the loss.
            input_ids: pre-computed input embeddings ``[B, T, H]``.
            Remaining arguments are forwarded to ``LlammaNARModel``.

        Returns:
            ``edict(loss=..., logits=...)``; ``loss`` is ``None`` when
            ``prediction_target`` is not given, ``logits`` is ``[B, T, vocab]``.

        Raises:
            ValueError: if ``cond`` is outside the range covered by ``lm_head``
                (e.g. ``cond=0`` would otherwise silently pick the last head).
        """
        target_layer = int(cond)
        if not START_QUANTIZATION_LAYER <= target_layer <= END_QUANTIZATION_LAYER:
            raise ValueError(
                f"cond must be in [{START_QUANTIZATION_LAYER}, "
                f"{END_QUANTIZATION_LAYER}], got {target_layer}"
            )

        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        outputs = self.model(
            cond=cond,
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = outputs[0]
        logits = self.lm_head[target_layer - START_QUANTIZATION_LAYER](hidden_states)

        loss = None
        if prediction_target is not None:
            # CrossEntropyLoss ignores label -100 by default.
            loss = CrossEntropyLoss()(
                logits.reshape(-1, logits.size(-1)), prediction_target.reshape(-1)
            )

        return edict(
            loss=loss,
            logits=logits,
        )
class ValleNAR(nn.Module):
    """VALL-E non-autoregressive stage: predicts quantization layers 1..7.

    The model input is the concatenation ``[phones | prompt | target]`` where
    the prompt carries all ``NUM_QUANTIZERS`` layers and the target segment
    carries the layers below the one being predicted. The predicted layer is
    passed to the backbone as ``cond``.
    """

    def __init__(
        self,
        phone_vocab_size=256,
        target_vocab_size=1024,
        hidden_size=1024,
        intermediate_size=4096,
        num_hidden_layers=12,
        num_attention_heads=16,
        pad_token_id=1024 + 256,
        bos_target_id=1282,
        eos_target_id=1283,
        bos_phone_id=1284,
        eos_phone_id=1285,
        bos_prompt_id=1286,
        eos_prompt_id=1287,
        use_input_embeds=False,
        emb_dim=256,
    ):
        super().__init__()
        self.config = LlamaConfig(
            vocab_size=phone_vocab_size + target_vocab_size + 10,
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            num_hidden_layers=num_hidden_layers,
            num_attention_heads=num_attention_heads,
            pad_token_id=pad_token_id,
            bos_token_id=bos_target_id,
            eos_token_id=eos_target_id,
            use_cache=False,
        )
        self.phone_vocab_size = phone_vocab_size
        self.target_vocab_size = target_vocab_size
        self.pad_token_id = pad_token_id
        self.bos_target_id = bos_target_id
        self.eos_target_id = eos_target_id
        self.bos_phone_id = bos_phone_id
        self.eos_phone_id = eos_phone_id
        self.bos_prompt_id = bos_prompt_id
        self.eos_prompt_id = eos_prompt_id
        self.model = LlamaForNARModeling(self.config)

        self.use_input_embeds = use_input_embeds

        # Phones plus 10 extra rows so the pad/bos/eos ids (after subtracting
        # target_vocab_size) are embedded by the same table.
        self.phone_embedder = nn.Embedding(self.phone_vocab_size + 10, hidden_size)
        self.prompt_embedder = MultiEmbedding(
            num_embeddings=self.target_vocab_size,
            embedding_dim=hidden_size,
            num_quantization_layers=NUM_QUANTIZERS,
        )
        self.phone_embedder.weight.data.normal_(mean=0.0, std=0.02)

        # How the target layer is drawn in training when none is given:
        # "uniform" (active), "linear" or "cosine".
        self.mask_layer_schedule = "uniform"

        # Optional projection of an external embedding; forward() does not use it.
        if self.use_input_embeds:
            self.emb_linear = nn.Linear(emb_dim, hidden_size)
            self.emb_linear.weight.data.normal_(mean=0.0, std=0.01)
            self.emb_linear.bias.data.zero_()

    def _sample_target_layer(self):
        """Draw a layer in [START, END] according to ``self.mask_layer_schedule``."""
        if self.mask_layer_schedule == "uniform":
            return np.random.randint(
                START_QUANTIZATION_LAYER, END_QUANTIZATION_LAYER + 1
            )
        layers = range(START_QUANTIZATION_LAYER, END_QUANTIZATION_LAYER + 1)
        if self.mask_layer_schedule == "linear":
            weights = torch.tensor([NUM_QUANTIZERS - i for i in layers])
        elif self.mask_layer_schedule == "cosine":
            weights = torch.tensor(
                [np.cos(i / NUM_QUANTIZERS * np.pi / 2) for i in layers]
            )
        else:
            raise ValueError(
                f"Unknown mask_layer_schedule: {self.mask_layer_schedule!r}"
            )
        weights = weights / weights.sum()
        # multinomial returns an index into `layers`; shift it to a layer id.
        return (
            torch.multinomial(weights, 1, replacement=True).item()
            + START_QUANTIZATION_LAYER
        )

    @staticmethod
    def _randomly_set_elements(tensor, fraction, value):
        """Return a copy of ``tensor`` where each element is replaced by ``value``
        independently with probability ``fraction``."""
        mask = torch.rand_like(tensor, dtype=torch.float32) < fraction
        result_tensor = tensor.clone()
        result_tensor[mask] = value
        return result_tensor

    def forward(
        self,
        phone_ids,
        phone_mask,
        target_ids,
        target_mask,
        target_quantization_layer=None,
        prompt_len=None,
        dropout=0.0,
    ):
        """Compute logits/loss for one target quantization layer.

        Args:
            phone_ids: [B, T_phone] phone ids (before the target-vocab offset).
            phone_mask: [B, T_phone], 1 for real phones and 0 for padding.
            target_ids: [NUM_QUANTIZERS, B, T] codec tokens of every layer.
            target_mask: [B, T], 1 for valid frames.
            target_quantization_layer: layer to predict; sampled when ``None``.
            prompt_len: number of leading frames used as the prompt. Must be
                given at inference (``not self.training``); sampled in training.
            dropout: fraction of lower-layer target tokens replaced by
                ``target_vocab_size`` before embedding.

        Returns:
            The backbone's ``edict`` (``loss``, ``logits``) extended with
            ``target_quantization_layer`` and ``top1_acc``/``top5_acc``/
            ``top10_acc`` measured on the target segment.
        """
        assert (
            target_ids < self.target_vocab_size
        ).all(), f"target_ids should be less than {self.target_vocab_size}"

        # Shift phones past the codec vocabulary and pad masked positions.
        phone_ids = phone_ids + self.target_vocab_size
        phone_ids = phone_ids * phone_mask + (1 - phone_mask) * self.pad_token_id
        # Phones are conditioning context only: never compute loss on them.
        phone_label = torch.full_like(phone_mask, -100, dtype=torch.long)
        phone_embedding = self.phone_embedder(
            phone_ids - self.target_vocab_size
        )  # [B, T, H]

        if prompt_len is not None:
            assert not self.training  # inference: caller fixes the prompt length
            num_prompt_tokens = prompt_len
        else:
            assert self.training  # training: randomize the prompt length
            total_len = target_ids.shape[-1]
            num_prompt_tokens = np.random.randint(
                min(total_len // 4, 5), total_len // 2
            )

        # All layers of the prompt are visible; no loss on the prompt.
        prompt_tokens = target_ids[:, :, :num_prompt_tokens]  # [Q, B, T]
        prompt_mask = torch.ones_like(prompt_tokens[0])
        prompt_label = -100 * prompt_mask
        prompt_embedding = self.prompt_embedder(prompt_tokens)  # [B, T, H]

        if target_quantization_layer is None:
            target_quantization_layer = self._sample_target_layer()

        # Target segment: only the layers below the predicted one are visible.
        target_prompt_ids = target_ids[
            :target_quantization_layer, :, num_prompt_tokens:
        ]
        if dropout != 0.0:
            # NOTE(review): prompt_embedder has exactly target_vocab_size rows,
            # so the replacement id target_vocab_size is out of range and any
            # replaced element will fail to embed — confirm the intended mask
            # id / table size before using dropout > 0.
            target_prompt_ids = self._randomly_set_elements(
                target_prompt_ids, dropout, self.target_vocab_size
            )
        target_embedding = self.prompt_embedder(target_prompt_ids)

        target_mask = target_mask[:, num_prompt_tokens:]
        target_labels = target_ids[
            target_quantization_layer, :, num_prompt_tokens:
        ] * target_mask + (-100 * (1 - target_mask))

        input_embeddings = torch.cat(
            [phone_embedding, prompt_embedding, target_embedding], dim=1
        )
        input_mask = torch.cat([phone_mask, prompt_mask, target_mask], dim=1)  # [B, T]
        prediction_target = torch.cat(
            [phone_label, prompt_label, target_labels], dim=1
        )  # [B, T]

        out = self.model(
            cond=torch.tensor(
                target_quantization_layer,
                device=prediction_target.device,
                dtype=torch.long,
            ),
            input_ids=input_embeddings,
            prediction_target=prediction_target,
            attention_mask=input_mask,
            return_dict=True,
        )

        # Accuracy metrics over the (unpadded) target segment only.
        target_len = target_embedding.shape[1]
        logits = out.logits[:, -target_len:, :]
        targets = prediction_target[..., -target_len:]
        num_valid = target_mask.sum()

        def topk_accuracy(k):
            hits = (logits.topk(k, dim=-1).indices == targets.unsqueeze(-1)).any(-1)
            return (hits * target_mask).sum() / num_valid

        top1_hits = logits.argmax(-1) == targets
        out.target_quantization_layer = target_quantization_layer
        out.top1_acc = (top1_hits * target_mask).sum() / num_valid
        out.top5_acc = topk_accuracy(5)
        out.top10_acc = topk_accuracy(10)
        return out

    def add_phone_eos_bos_label(
        self, phone_ids, phone_mask, phone_eos_id, phone_bos_id, pad_token_id
    ):
        """Offset phone ids and wrap each sequence with bos/eos tokens.

        Args:
            phone_ids: [B, T] phone ids.
            phone_mask: [B, T], 1 for real phones and 0 for padding.

        Returns:
            ``(phone_ids, phone_mask, phone_label)``, each [B, T + 2]; the labels
            are all -100 so no loss is computed on phones.
        """
        phone_ids = phone_ids + self.target_vocab_size * phone_mask
        phone_ids = phone_ids * phone_mask
        # Append a column and turn every padded position into eos.
        phone_ids = F.pad(phone_ids, (0, 1), value=0) + phone_eos_id * F.pad(
            1 - phone_mask, (0, 1), value=1
        )
        # Shifting the mask right by one marks the first padded slot (the eos) valid.
        phone_mask = F.pad(phone_mask, (1, 0), value=1)
        # Positions still masked get the pad token back.
        phone_ids = phone_ids * phone_mask + pad_token_id * (1 - phone_mask)
        phone_ids = F.pad(phone_ids, (1, 0), value=phone_bos_id)  # prepend bos
        phone_mask = F.pad(phone_mask, (1, 0), value=1)  # bos is valid
        phone_label = -100 * torch.ones_like(phone_ids)
        return phone_ids, phone_mask, phone_label

    @torch.no_grad()
    def sample_hf(
        self,
        phone_ids,  # [B, T]
        prompt_ids,  # [8, B, T]
        first_stage_ids,  # [B, T]
        top_k=50,
        top_p=1,
        temperature=1.1,
        first_stage_ids_gt=None,  # [Q, B, T]
        first_stage_ids_gt_end_layer=None,  # 2 to 8
    ):
        """Greedily generate layers 1..END_QUANTIZATION_LAYER for the continuation.

        Args:
            phone_ids: [B, T] phone ids (all treated as valid).
            prompt_ids: [NUM_QUANTIZERS, B, T_prompt] prompt tokens (>= 5 frames).
            first_stage_ids: [B, T_gen] layer-0 tokens of the continuation, e.g.
                from the AR stage.
            top_k, top_p, temperature: currently ignored — decoding is argmax.
            first_stage_ids_gt: optional ground-truth tokens whose first
                ``first_stage_ids_gt_end_layer`` layers overwrite the continuation.
            first_stage_ids_gt_end_layer: first layer to generate when
                ``first_stage_ids_gt`` is given.

        Returns:
            [NUM_QUANTIZERS, B, T_gen] tokens of the generated continuation.
        """
        phone_mask = torch.ones_like(phone_ids, dtype=torch.long)
        assert prompt_ids.shape[-1] >= 5, "prompt_ids should have at least 5 tokens"
        # Seed every layer of the continuation with layer 0; layers >= 1 are
        # overwritten one at a time below.
        target_ids = torch.cat(
            [prompt_ids, first_stage_ids.expand(prompt_ids.shape[0], -1, -1)], dim=-1
        )
        target_mask = torch.ones_like(target_ids[0], dtype=torch.long)
        if first_stage_ids_gt is not None:
            target_ids[
                :first_stage_ids_gt_end_layer, :, -first_stage_ids_gt.shape[-1] :
            ] = first_stage_ids_gt[:first_stage_ids_gt_end_layer]

        gen_len = first_stage_ids.shape[-1]
        start_qnt_layer = START_QUANTIZATION_LAYER
        if first_stage_ids_gt_end_layer is not None:
            start_qnt_layer = first_stage_ids_gt_end_layer

        for qnt_level in range(start_qnt_layer, END_QUANTIZATION_LAYER + 1):
            out = self.forward(
                phone_ids=phone_ids,
                phone_mask=phone_mask,
                target_ids=target_ids,
                target_mask=target_mask,
                target_quantization_layer=qnt_level,
                prompt_len=prompt_ids.shape[-1],
            )
            # [B, gen_len]: keep the batch axis so each sample gets its own tokens.
            gen_tokens = out.logits.argmax(dim=-1)[:, -gen_len:]
            target_ids[qnt_level, :, -gen_len:] = gen_tokens
        return target_ids[:, :, -gen_len:]
def test():
    """Smoke test: fit ValleNAR on one random utterance, then greedily sample.

    Runs on CUDA when available and falls back to CPU otherwise.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = ValleNAR().to(device)
    phone_ids = torch.LongTensor([1, 2, 3, 4, 5]).reshape(1, -1).to(device)
    phone_mask = torch.LongTensor([1, 1, 1, 1, 1]).reshape(1, -1).to(device)
    target_ids = torch.randint(high=1024, size=(8, 1, 250), dtype=torch.long).to(device)
    target_mask = torch.ones(1, 250, dtype=torch.long).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
    for i in range(200):
        optimizer.zero_grad()
        out = model(
            phone_ids=phone_ids,
            phone_mask=phone_mask,
            target_ids=target_ids,
            target_mask=target_mask,
        )
        loss = out.loss
        loss.backward()
        optimizer.step()
        print(f"iter={i}, {loss}.")

    # First 240 frames act as the prompt; the last 10 are generated.
    prompt_len = 240
    target_ids_short = target_ids[:, :, :prompt_len]
    model.eval()
    sampled = model.sample_hf(
        phone_ids,
        prompt_ids=target_ids_short,
        first_stage_ids=target_ids[0, :, prompt_len:],
    )
    print(target_ids[:, :, -10:])
    print(sampled)
    print((sampled == target_ids[:, :, -10:]).all())
if __name__ == "__main__":
    # Run the smoke test only when this file is executed directly.
    test()
|