fish-audio-t / fish_speech /conversation.py
kiylu's picture
add project files
b128c76
from dataclasses import dataclass, field
from typing import Literal
import torch
from .tokenizer import MODALITY_TOKENS, FishTokenizer
CODEBOOK_PAD_TOKEN_ID = 0
@dataclass(kw_only=True)
class BasePart:
pass
@dataclass(kw_only=True)
class VQPart(BasePart):
codes: torch.Tensor
@dataclass(kw_only=True)
class TextPart(BasePart):
text: str
@dataclass(kw_only=True)
class EncodedMessage:
tokens: torch.Tensor
labels: torch.Tensor
vq_mask_tokens: torch.Tensor | None = None
vq_mask_labels: torch.Tensor | None = None
vq_parts: list[torch.Tensor]
vq_require_losses: torch.Tensor | None = None
@dataclass(kw_only=True)
class Message:
role: Literal["system", "user", "assistant"]
parts: list[VQPart | TextPart] = field(default_factory=list)
add_im_start: bool = True
add_im_end: bool = True
cal_loss: bool = False
modality: Literal["text", "voice", "interleave"] | None = None
# By default, ignore the loss of the auto-generated im_start token
ignore_im_start_loss: bool = True
def encode(
self: "Message",
tokenizer: FishTokenizer,
) -> EncodedMessage:
all_tokens = []
all_labels = []
# Multi-modal tokens
vq_parts = []
vq_masks = []
parts = self.parts.copy()
if self.add_im_start:
modality_token = MODALITY_TOKENS[self.modality] if self.modality else ""
parts.insert(0, TextPart(text=f"<|im_start|>{self.role}\n{modality_token}"))
if self.add_im_end:
parts.append(TextPart(text="<|im_end|>"))
for part in parts:
if isinstance(part, TextPart):
tokens = torch.tensor(
tokenizer.encode(part.text),
dtype=torch.int,
)
elif isinstance(part, VQPart):
curr_codes = part.codes.clone()
tokens = torch.tensor(
[
tokenizer.semantic_id_to_token_id[i.item()]
for i in curr_codes[0].int()
],
dtype=torch.int,
)
vq_parts.append(curr_codes)
else:
raise ValueError(f"Unsupported part type: {type(part)}")
all_tokens.append(tokens)
if isinstance(part, VQPart):
vq_masks.append(torch.ones_like(tokens, dtype=torch.bool))
else:
vq_masks.append(torch.zeros_like(tokens, dtype=torch.bool))
if self.cal_loss:
all_labels.append(tokens.clone())
else:
all_labels.append(torch.full_like(tokens, -100))
tokens = torch.cat(all_tokens, dim=0)
labels = torch.cat(all_labels, dim=0)
vq_masks = torch.cat(vq_masks, dim=0)
assert tokens.shape == labels.shape == vq_masks.shape
if self.ignore_im_start_loss and self.add_im_start:
labels[: len(all_tokens[0])] = -100
return EncodedMessage(
tokens=tokens,
labels=labels,
vq_parts=vq_parts,
vq_mask_tokens=vq_masks,
vq_mask_labels=vq_masks,
)
@dataclass
class Conversation:
messages: list[Message]
def __init__(self: "Conversation", messages: list[Message] | None = None):
self.messages = messages or []
def encode(
self: "Conversation",
tokenizer: FishTokenizer,
add_shift: bool = True,
ignore_loss_tokens: list[str] = [],
) -> EncodedMessage:
# Build the input_ids and labels
tokens = []
labels = []
vq_parts = []
vq_mask_tokens = []
vq_mask_labels = []
vq_require_losses = []
ignore_loss_token_ids = [tokenizer.get_token_id(i) for i in ignore_loss_tokens]
for message in self.messages:
encoded = message.encode(
tokenizer,
)
tokens.append(encoded.tokens)
labels.append(encoded.labels)
vq_parts.extend(encoded.vq_parts)
vq_mask_tokens.append(encoded.vq_mask_tokens)
vq_mask_labels.append(encoded.vq_mask_labels)
vq_require_losses.extend([message.cal_loss] * len(encoded.vq_parts))
tokens = torch.cat(tokens, dim=0)
labels = torch.cat(labels, dim=0)
vq_mask_tokens = torch.cat(vq_mask_tokens, dim=0)
vq_mask_labels = torch.cat(vq_mask_labels, dim=0)
vq_require_losses = torch.tensor(vq_require_losses, dtype=torch.bool)
if add_shift:
tokens = tokens[:-1]
labels = labels[1:]
vq_mask_tokens = vq_mask_tokens[:-1]
vq_mask_labels = vq_mask_labels[1:]
for i in ignore_loss_token_ids:
assert i != -100 and i is not None
labels[labels == i] = -100
assert tokens.dtype in [
torch.int,
torch.long,
], f"Invalid dtype: {tokens.dtype}, conv: {conversation}"
return EncodedMessage(
tokens=tokens,
labels=labels,
vq_parts=vq_parts,
vq_mask_tokens=vq_mask_tokens,
vq_mask_labels=vq_mask_labels,
vq_require_losses=vq_require_losses,
)
def encode_for_inference(
self: "Conversation",
tokenizer: FishTokenizer,
num_codebooks: int,
) -> EncodedMessage:
# self.visualize(tokenizer)
encoded = self.encode(tokenizer, add_shift=False)
tokens = encoded.tokens
values = torch.zeros((num_codebooks + 1, len(tokens)), dtype=torch.int)
values[0] = tokens
if encoded.vq_parts is None or len(encoded.vq_parts) == 0:
return values
vq_parts = encoded.vq_parts
vq_parts = [part.to(values.device) for part in vq_parts]
vq_parts = torch.cat(vq_parts, dim=1)
values[0, encoded.vq_mask_tokens] = vq_parts[0] + tokenizer.semantic_begin_id
values[1:, encoded.vq_mask_tokens] = vq_parts
return values
def visualize(
self: "Conversation",
tokenizer: FishTokenizer,
ignore_loss_tokens: list[str] = [],
):
encoded = self.encode(
tokenizer, add_shift=False, ignore_loss_tokens=ignore_loss_tokens
)
# Colors for alternating tokens
colors = {
"blue": "\033[94m", # Light blue
"cyan": "\033[96m", # Cyan
"green": "\033[92m", # Light green
"dark_green": "\033[32m", # Dark green
}
blue_idx = 0
green_idx = 0
def print_in_blue(x):
nonlocal blue_idx
color = colors["blue"] if blue_idx % 2 == 0 else colors["cyan"]
print(f"{color}{x}\033[0m", end="")
blue_idx += 1
def print_in_green(x):
nonlocal green_idx
color = colors["green"] if green_idx % 2 == 0 else colors["dark_green"]
print(f"{color}{x}\033[0m", end="")
green_idx += 1
for tok, lab in zip(encoded.tokens, encoded.labels):
val = tokenizer.decode([tok])
if lab == -100:
print_in_green(val)
else:
print_in_blue(val)
print()
def append(self: "Conversation", message: Message):
self.messages.append(message)
if __name__ == "__main__":
message0 = Message(
role="user",
parts=[
TextPart(text="Hello, how are you?"),
VQPart(codes=torch.zeros((4, 10))),
],
cal_loss=False,
)
message1 = Message(
role="assistant",
parts=[TextPart(text="I'm fine, thank you.")],
cal_loss=True,
)
conversation = Conversation([message0, message1])
tokenizer = FishTokenizer.from_pretrained("checkpoints/Qwen2-1.5B-Instruct")
conversation.visualize(tokenizer)
encoded = conversation.encode(tokenizer)
print(encoded)
print(tokenizer.batch_decode(encoded.tokens))