from einops import rearrange from einops.layers.torch import Rearrange from torch import nn from transformers import PreTrainedModel import math import torch from .configuration_moonshine import MoonshineConfig class RotaryEmbedding(nn.Module): def __init__(self, dim, base=10000): super().__init__() inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) def forward(self, t): freqs = torch.einsum("i , j -> i j", t.type_as(self.inv_freq), self.inv_freq) freqs = torch.stack((freqs, freqs), dim=-1) return rearrange(freqs, "... d r -> ... (d r)") def rotate_half(x): x = rearrange(x, "... (d r) -> ... d r", r=2) x1, x2 = x.unbind(dim=-1) x = torch.stack((-x2, x1), dim=-1) return rearrange(x, "... d r -> ... (d r)") def apply_rotary_pos_emb(t, freqs): rot_dim, seq_len, orig_dtype = freqs.shape[-1], t.shape[-2], t.dtype freqs = freqs[-seq_len:, :] # partial rotary embeddings, Wang et al. GPT-J t, t_unrotated = t[..., :rot_dim], t[..., rot_dim:] t = t * freqs.cos() + rotate_half(t) * freqs.sin() out = torch.cat((t, t_unrotated), dim=-1) return out.type(orig_dtype) class MultiHeadAttention(nn.Module): def __init__(self, dim, inner_dim, n_head): super().__init__() self.n_head = n_head self.to_q = nn.Linear(dim, inner_dim, bias=False) self.to_k = nn.Linear(dim, inner_dim, bias=False) self.to_v = nn.Linear(dim, inner_dim, bias=False) self.to_out = nn.Linear(inner_dim, dim, bias=False) self.softmax = nn.Softmax(dim=-1) # Scaled dot product attention def sdp_attention(self, q, k_t, v, mask=None): d_tensor = v.shape[3] op = (q @ k_t) / math.sqrt(d_tensor) if mask is not None: op = op.masked_fill(mask, -torch.finfo(op.dtype).max) score = self.softmax(op) out = score @ v # concat and pass to linear layer out = rearrange(out, "b h n d -> b n (h d)") return self.to_out(out) def forward(self, q, k, v, rot_pos_emb=None, mask=None): # dot product with weight matrices q, k, v = self.to_q(q), self.to_k(k), self.to_v(v) q = rearrange(q, "b n (h d) -> b h n d", h=self.n_head) k = rearrange(k, "b n (h d) -> b h n d", h=self.n_head) v = rearrange(v, "b n (h d) -> b h n d", h=self.n_head) # apply RoPE if rot_pos_emb is not None: q = apply_rotary_pos_emb(q, rot_pos_emb) k = apply_rotary_pos_emb(k, rot_pos_emb) k_t = k.transpose(2, 3) return self.sdp_attention(q, k_t, v, mask), k_t, v class MultiHeadCausalSelfAttentionWithKVCache(MultiHeadAttention): def __init__(self, dim, inner_dim, n_head): super().__init__(dim, inner_dim, n_head) def forward(self, q, k, v, k_cache, v_cache, rot_pos_emb, mask): # dot product with weight matrices q, k, v = self.to_q(q), self.to_k(k), self.to_v(v) q = rearrange(q, "b n (h d) -> b h n d", h=self.n_head) k = rearrange(k, "b n (h d) -> b h n d", h=self.n_head) v = rearrange(v, "b n (h d) -> b h n d", h=self.n_head) # apply RoPE q = apply_rotary_pos_emb(q, rot_pos_emb) k = apply_rotary_pos_emb(k, rot_pos_emb) k_t = k.transpose(2, 3) # Append new rows to K and V caches. k_t = torch.concat((k_cache, k_t), dim=3) v = torch.concat((v_cache, v), dim=2) return super().sdp_attention(q, k_t, v, mask=mask), k_t, v class MultiHeadCrossAttentionWithKVCache(MultiHeadAttention): def __init__(self, dim, inner_dim, n_head): super().__init__(dim, inner_dim, n_head) def forward(self, q, k_cache, v_cache, mask): q = self.to_q(q) q = rearrange(q, "b n (h d) -> b h n d", h=self.n_head) return super().sdp_attention(q, k_cache, v_cache, mask=mask) class FFLinearGelu(nn.Module): def __init__(self, dim, ff_mult=4): super().__init__() self.ff = nn.Sequential( nn.Linear(dim, dim * ff_mult, bias=True), nn.GELU(), nn.Linear(dim * ff_mult, dim, bias=True), ) def forward(self, x): return self.ff(x) class FFSwiGLU(nn.Module): def __init__(self, dim, ff_mult=4): super().__init__() self.ff_proj = nn.Linear(dim, dim * ff_mult, bias=True) self.ff_noact = nn.Linear(dim, dim * ff_mult, bias=True) self.ff_act = nn.SiLU() self.ff_out = nn.Linear(dim * ff_mult, dim, bias=True) def forward(self, x): gate = self.ff_act(self.ff_proj(x)) x_noact = self.ff_noact(x) x = x_noact * gate return self.ff_out(x) class EncoderLayer(nn.Module): def __init__(self, dim, inner_dim, n_head, ff_swiglu, ff_mult=4): super().__init__() self.norm1 = nn.LayerNorm(dim, bias=False) self.attention = MultiHeadAttention(dim, inner_dim=inner_dim, n_head=n_head) self.norm2 = nn.LayerNorm(dim, bias=False) self.ff = FFSwiGLU(dim, ff_mult) if ff_swiglu else FFLinearGelu(dim, ff_mult) def forward(self, x, rot_pos_emb, mask): _x = x x = self.norm1(x) x, _, _ = self.attention(q=x, k=x, v=x, rot_pos_emb=rot_pos_emb, mask=mask) x = x + _x _x = x x = self.norm2(x) x = self.ff(x) x = x + _x return x class Encoder(nn.Module): def __init__(self, dim, inner_dim, n_head, n_layers, ff_swiglu): super().__init__() rot_embed_dim = max(inner_dim / n_head / 2, 32) self.rot_pos_emb = RotaryEmbedding(rot_embed_dim) self.layers = nn.ModuleList( [EncoderLayer(dim, inner_dim, n_head, ff_swiglu) for _ in range(n_layers)] ) self.post_norm = nn.LayerNorm(dim, bias=False) def forward(self, x, mask): pos = torch.arange(x.shape[-2], device=x.device) rot_pos_emb = self.rot_pos_emb(pos) for idx, layer in enumerate(self.layers): x = layer(x, rot_pos_emb=rot_pos_emb, mask=mask) return self.post_norm(x) class DecoderLayer(nn.Module): def __init__(self, dim, inner_dim, n_head, ff_swiglu, ff_mult=4): super().__init__() self.norm1 = nn.LayerNorm(dim, bias=False) self.self_attention = MultiHeadCausalSelfAttentionWithKVCache( dim, inner_dim=inner_dim, n_head=n_head ) self.norm2 = nn.LayerNorm(dim, bias=False) self.cross_attention = MultiHeadCrossAttentionWithKVCache( dim, inner_dim=inner_dim, n_head=n_head ) self.norm3 = nn.LayerNorm(dim, bias=False) self.ff = FFSwiGLU(dim, ff_mult) if ff_swiglu else FFLinearGelu(dim, ff_mult) def forward(self, x, k_cache, v_cache, x_attn_k_cache, x_attn_v_cache, rot_pos_emb, input_mask): dim = x.size()[1] causal_mask = torch.ones((dim, dim), dtype=torch.bool).triu(1).to(x.device) _x = x x = self.norm1(x) x, new_k_cache, new_v_cache = self.self_attention( q=x, k=x, v=x, k_cache=k_cache, v_cache=v_cache, rot_pos_emb=rot_pos_emb, mask=causal_mask, ) x = x + _x _x = x x = self.norm2(x) x = self.cross_attention(q=x, k_cache=x_attn_k_cache, v_cache=x_attn_v_cache, mask=input_mask) x = x + _x _x = x x = self.norm3(x) x = self.ff(x) x = x + _x return x, new_k_cache, new_v_cache class Decoder(nn.Module): def __init__(self, dim, inner_dim, n_head, n_layers, dec_voc_size, ff_swiglu): super().__init__() self.n_head = n_head self.d_head = inner_dim // n_head rot_embed_dim = max(inner_dim / n_head / 2, 32) self.rot_pos_emb = RotaryEmbedding(rot_embed_dim) self.layers = nn.ModuleList( [DecoderLayer(dim, inner_dim, n_head, ff_swiglu) for _ in range(n_layers)] ) self.final_norm = nn.LayerNorm(dim, bias=False) self.token_embedding = nn.Embedding(dec_voc_size, dim) def forward(self, x, input_mask, *args): pos = torch.arange(x.shape[1], device=x.device) rot_pos_emb = self.rot_pos_emb(pos) x = self.token_embedding(x) k_cache_new = [] v_cache_new = [] n_layer = len(self.layers) k_cache, v_cache, x_attn_k_cache, x_attn_v_cache = [ args[i : i + n_layer] for i in range(0, 4 * n_layer, n_layer) ] for idx, layer in enumerate(self.layers): x, new_k_line, new_v_line = layer( x[:, -1:], k_cache=k_cache[idx], v_cache=v_cache[idx], x_attn_k_cache=x_attn_k_cache[idx], x_attn_v_cache=x_attn_v_cache[idx], rot_pos_emb=rot_pos_emb, input_mask=input_mask, ) k_cache_new.append(new_k_line) v_cache_new.append(new_v_line) x = self.final_norm(x) return x @ self.token_embedding.weight.t(), *k_cache_new, *v_cache_new class InitialDecoderLayer(nn.Module): def __init__(self, dim, inner_dim, n_head, ff_swiglu, ff_mult=4): super().__init__() self.norm1 = nn.LayerNorm(dim, bias=False) self.self_attention = MultiHeadAttention( dim, inner_dim=inner_dim, n_head=n_head ) self.norm2 = nn.LayerNorm(dim, bias=False) self.cross_attention = MultiHeadAttention( dim, inner_dim=inner_dim, n_head=n_head ) self.norm3 = nn.LayerNorm(dim, bias=False) self.ff = FFSwiGLU(dim, ff_mult) if ff_swiglu else FFLinearGelu(dim, ff_mult) def forward(self, x, context, rot_pos_emb, input_mask): dim = x.size()[1] causal_mask = torch.ones((dim, dim), dtype=torch.bool).triu(1).to(x.device) _x = x x = self.norm1(x) x, new_k_cache, new_v_cache = self.self_attention( q=x, k=x, v=x, rot_pos_emb=rot_pos_emb, mask=causal_mask, ) x = x + _x _x = x x = self.norm2(x) x, x_attn_k_cache, x_attn_v_cache = self.cross_attention( q=x, k=context, v=context, mask=input_mask, ) x = x + _x _x = x x = self.norm3(x) x = self.ff(x) x = x + _x return x, new_k_cache, new_v_cache, x_attn_k_cache, x_attn_v_cache class DecoderInitial(Decoder): def __init__(self, dim, inner_dim, n_head, n_layers, dec_voc_size, ff_swiglu): super().__init__(dim, inner_dim, n_head, n_layers, dec_voc_size, ff_swiglu) self.layers = nn.ModuleList( [ InitialDecoderLayer(dim, inner_dim, n_head, ff_swiglu) for _ in range(n_layers) ] ) def forward(self, x, enc_src, input_mask): pos = torch.arange(x.shape[1], device=x.device) rot_pos_emb = self.rot_pos_emb(pos) x = self.token_embedding(x) # Shape [n_layers, batch_size, n_head, seq_len, inner_dim]. Cache K transposed. n_layer = len(self.layers) k_cache = [] v_cache = [] x_attn_k_cache = [] x_attn_v_cache = [] for idx, layer in enumerate(self.layers): x, new_k_line, new_v_line, new_x_attn_k_line, new_x_attn_v_line = layer( x, enc_src, rot_pos_emb, input_mask, ) k_cache.append(new_k_line) v_cache.append(new_v_line) x_attn_k_cache.append(new_x_attn_k_line) x_attn_v_cache.append(new_x_attn_v_line) x = self.final_norm(x) return ( x @ self.token_embedding.weight.t(), *k_cache, *v_cache, *x_attn_k_cache, *x_attn_v_cache, ) class AudioPreprocessor(nn.Module): def __init__(self, dim): super().__init__() self.audio_preprocess = nn.Sequential( nn.Conv1d(1, dim, 127, 64, bias=False), nn.Tanh(), nn.GroupNorm(1, dim), nn.Conv1d(dim, 2 * dim, 7, 3), nn.GELU(), nn.Conv1d(2 * dim, dim, 3, 2), nn.GELU(), Rearrange("... c s -> ... s c"), ) def forward(self, src): assert ( src.shape[-1] >= 1023 ), f"src shape[-1] {src.shape[-1]} should be at least 1023" src = src.reshape((-1, 1, src.shape[-1])) return self.audio_preprocess(src) class MoonshineModelTorch(nn.Module): def __init__( self, dim, inner_dim, enc_depth, dec_depth, n_head=8, dec_voc_size=32768, enc_ff_swiglu=False, dec_ff_swiglu=False, ): super().__init__() self.preprocessor = AudioPreprocessor(dim) self.encoder = Encoder( dim, inner_dim, n_head, enc_depth, ff_swiglu=enc_ff_swiglu ) self.decoder_initial = DecoderInitial( dim, inner_dim, n_head, dec_depth, dec_voc_size, ff_swiglu=dec_ff_swiglu ) self.decoder = Decoder( dim, inner_dim, n_head, dec_depth, dec_voc_size, ff_swiglu=dec_ff_swiglu ) self.dec_depth = dec_depth self.n_head = n_head self.d_head = inner_dim // n_head def generate(self, src, mask): preprocessed = self.preprocessor(src) batch_size = preprocessed.shape[0] # Get max sequence length based on number of unmasked inputs for each sample in batch. token_limit_factor = 6.5 / 16000.0 # Maximum of 6.5 tokens per second. if mask is not None: seq_lens = torch.sum(mask, dim=-1, keepdim=True) * token_limit_factor else: token_limit = torch.tensor([src.shape[-1] * token_limit_factor]) seq_lens = torch.stack([token_limit for _ in range(batch_size)]) seq_lens = seq_lens.to(torch.int32).to(src.device).squeeze() # Preprocess mask so that it matches preprocessed audio. if mask is not None: mask = mask[..., :-127:64][..., :-7:3][..., :-3:2].to(torch.bool) mask = ~mask.reshape((batch_size, 1, 1, -1)) mask = torch.nn.functional.pad(mask, (0, preprocessed.shape[-2] - mask.shape[-1])) enc = self.encoder(preprocessed, mask) sot_token = 1 eot_token = 2 sot_array = [[sot_token] for _ in range(batch_size)] seq = torch.as_tensor(sot_array).to(src.device) vals = self.decoder_initial(x=seq, enc_src=enc, input_mask=mask) logits = vals[0] k_cache, v_cache, x_attn_k_cache, x_attn_v_cache = [ vals[i : i + self.dec_depth] for i in range(1, 1 + self.dec_depth * 4, self.dec_depth) ] sample = logits[:, -1].argmax(dim=-1, keepdim=True) seq = torch.cat((seq, sample), dim=-1) eot_mask = torch.zeros((batch_size), dtype=torch.bool).to(src.device) while not torch.all(eot_mask): vals = self.decoder( seq, mask, *k_cache, *v_cache, *x_attn_k_cache, *x_attn_v_cache, ) logits = vals[0] k_cache = vals[1 : self.dec_depth + 1] v_cache = vals[self.dec_depth + 1 :] logits = logits[:, -1] # get last token sample = logits.argmax(dim=-1, keepdim=True) # For each sample in batch detect EOT or token limit reached. eot_mask = eot_mask | (sample.squeeze() == eot_token) eot_mask = eot_mask | (seq.shape[-1] >= seq_lens) sample = sample.masked_fill(eot_mask.reshape((-1, 1)), eot_token) seq = torch.cat((seq, sample), dim=-1) return seq class MoonshineModel(PreTrainedModel): config_class = MoonshineConfig def __init__(self, config): super().__init__(config) self.model = MoonshineModelTorch( dim = config.dim, inner_dim = config.inner_dim, enc_depth = config.enc_depth, dec_depth = config.dec_depth, n_head = config.n_head, dec_voc_size = config.dec_voc_size, enc_ff_swiglu = config.enc_ff_swiglu, dec_ff_swiglu = config.dec_ff_swiglu, ) def forward(self, tensor, input_mask=None): return self.model.generate(tensor, input_mask)