import copy
from dataclasses import dataclass
from typing import Optional, Tuple, Union

import torch
import torch.nn.functional as F
from torch import nn
from transformers import (
    AutoConfig,
    AutoModel,
    AutoModelForMaskedLM,
    AutoModelForSequenceClassification,
    AutoModelForTokenClassification,
    AutoTokenizer,
)
from transformers.modeling_outputs import (
    BaseModelOutput,
    MaskedLMOutput,
    ModelOutput,
    SequenceClassifierOutput,
    TokenClassifierOutput,
)
from transformers.modeling_utils import PreTrainedModel

from .configuration_hlm import HLMConfig, HLMEncoderConfig
from .tokenization_hlm import HLMTokenizer


@dataclass
class HLMBaseModelOutput(ModelOutput):
    """Output of :class:`HLMModel`.

    Besides ``last_hidden_state`` / ``hidden_states`` it carries the intermediate
    tensors that :class:`HLMLMPredictionHead` needs to rebuild the character-level
    sequence.
    """

    # Inter-word encoder output, or [initial word embeds ; contextual word embeds]
    # concatenated on the last axis when ``combined_word_embeddings=True``.
    last_hidden_state: torch.FloatTensor = None
    # Per-layer outputs of the inter-word encoder (only when requested).
    hidden_states: Tuple[torch.FloatTensor] = None
    attentions: Tuple[torch.FloatTensor] = None  # Not currently supported
    # Intra-word encoder output, shape (batch * num_words, num_chars, intra hidden size).
    initial_embeds: torch.FloatTensor = None
    # [WORD_CLS] vectors (plus word-type embeddings), projected to the inter-word
    # hidden size; (batch * num_words, H), or (batch, num_words, H) when combined.
    initial_word_embeds: torch.FloatTensor = None
    # Character mask reshaped to (batch * num_words, num_chars).
    intra_word_mask: torch.LongTensor = None
    # Character embeddings fed to the intra-word encoder (floating point).
    char_embeds: torch.FloatTensor = None
    # (batch_size, num_words, num_chars, inter-word hidden size)
    input_shape: Tuple[int, int, int, int] = None


class HLMEncoder(nn.Module):
    """Stack of :class:`TransformerBlock` layers that all share one attention mask.

    When ``config.sandwich_size > 0``, ``sandwich_size`` layers at odd offsets from
    ``num_hidden_layers // 2 - sandwich_size`` are built with biases and share their
    weight matrices with the layer directly below them (see
    :meth:`TransformerBlock.make_sandwich`).
    """

    def __init__(self, config) -> None:
        super().__init__()
        if config.sandwich_size > 0:
            sandwich_start_index = config.num_hidden_layers // 2 - config.sandwich_size
            sandwich_indices = [sandwich_start_index + i * 2 + 1 for i in range(config.sandwich_size)]
            self.layers = nn.ModuleList(
                [
                    TransformerBlock(config, bias=i in sandwich_indices)
                    for i in range(config.num_hidden_layers)
                ]
            )
            # Each biased layer reuses the weight matrices of its predecessor.
            for idx in sandwich_indices:
                self.layers[idx].make_sandwich(self.layers[idx - 1])
        else:
            self.layers = nn.ModuleList(
                [TransformerBlock(config) for _ in range(config.num_hidden_layers)]
            )

    def _get_attention_mask(self, attn_mask, dtype):
        """Turn a 0/1 mask into an additive float mask for SDPA.

        Positions where the mask is 1 get 0 and positions where it is 0 get the
        most negative value of ``dtype``. A 2-D ``(batch, seq)`` padding mask becomes
        ``(batch, 1, seq, seq)``, requiring both query and key to be valid. A 3-D mask
        gets a head axis; anything else is used as given.
        """
        if attn_mask.dim() <= 2:
            extended_mask = attn_mask.unsqueeze(1).unsqueeze(2)
            extended_mask = extended_mask * extended_mask.squeeze(-2).unsqueeze(-1)
        elif attn_mask.dim() == 3:
            extended_mask = attn_mask.unsqueeze(1)
        else:
            extended_mask = attn_mask
        # Convert to float to avoid zero in denominator of softmax in SDPA, resulting in NaNs
        min_dtype = torch.finfo(dtype).min
        extended_mask = (1.0 - extended_mask.float()) * min_dtype
        # SDPA returns NaNs for fully masked rows, so attend to all tokens instead
        fully_masked = torch.all(extended_mask == min_dtype, dim=-1, keepdim=True)
        return extended_mask.mul(~fully_masked)

    def forward(
        self,
        hidden_states,
        attention_mask,
        freqs_cos,
        freqs_sin,
        return_dict=True,
        output_hidden_states=False,
    ):
        """Run every layer over ``hidden_states`` (batch, seq, hidden).

        Returns a ``BaseModelOutput``, or, when ``return_dict`` is False, the tuple
        ``(last_hidden_state,)`` / ``(last_hidden_state, all_hidden_states)``.
        """
        attn_mask = self._get_attention_mask(attention_mask, hidden_states.dtype)
        all_hidden_states = [] if output_hidden_states else None
        for layer in self.layers:
            hidden_states = layer(hidden_states, attn_mask, freqs_cos, freqs_sin)
            if all_hidden_states is not None:
                all_hidden_states.append(hidden_states)
        if all_hidden_states is not None:
            all_hidden_states = tuple(all_hidden_states)

        if return_dict:
            return BaseModelOutput(
                last_hidden_state=hidden_states,
                hidden_states=all_hidden_states,
                attentions=None,
            )
        if output_hidden_states:
            return (hidden_states, all_hidden_states)
        return (hidden_states,)


class HLMPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models.
    """

    config_class = HLMConfig
    base_model_prefix = "hlm"
    _keys_to_ignore_on_load_unexpected = []
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialize the weights."""
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()


class HLMModel(HLMPreTrainedModel):
    """Hierarchical encoder: characters -> per-word vectors -> contextual word vectors.

    ``input_ids`` are (batch, num_words, num_chars). Each word is encoded on its own by
    the intra-word encoder. Its first position ([WORD_CLS]) becomes the word vector,
    which the inter-word encoder then contextualises across the sequence.
    """

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        intra_config = config.intra_word_encoder
        inter_config = config.inter_word_encoder

        # Factorized embeddings: when embedding_size is set and differs from the
        # intra-word hidden size, characters are embedded at embedding_size and then
        # projected up, so the embedding table must have embedding_size columns.
        use_char_projection = (
            config.embedding_size != -1 and config.embedding_size != intra_config.hidden_size
        )
        char_embedding_dim = config.embedding_size if use_char_projection else intra_config.hidden_size
        self.char_embeddings = nn.Embedding(config.vocab_size, char_embedding_dim, padding_idx=0)
        self.char_embedding_dropout = nn.Dropout(intra_config.dropout_prob)
        if use_char_projection:
            self.char_embedding_project = nn.Linear(
                config.embedding_size, intra_config.hidden_size, bias=False
            )

        # RoPE tables sized for the intra-word head dimension.
        freqs_cos, freqs_sin = precompute_freqs_cis(
            intra_config.hidden_size // intra_config.num_attention_heads, config.max_seq_length
        )
        self.register_buffer("freqs_cos", freqs_cos)
        self.register_buffer("freqs_sin", freqs_sin)

        self.word_type_embeddings = nn.Embedding(config.type_vocab_size, intra_config.hidden_size)
        self.intra_word_encoder = HLMEncoder(intra_config)
        if intra_config.hidden_size != inter_config.hidden_size:
            self.intra_word_project = nn.Linear(
                intra_config.hidden_size, inter_config.hidden_size, bias=False
            )
        self.inter_word_encoder = HLMEncoder(inter_config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.char_embeddings

    def set_input_embeddings(self, new_embeddings):
        self.char_embeddings = new_embeddings

    def _inter_word_rope(self, num_word):
        """Return (cos, sin) RoPE tables of length ``num_word`` for the inter-word head size."""
        intra = self.config.intra_word_encoder
        inter = self.config.inter_word_encoder
        intra_head_dim = intra.hidden_size // intra.num_attention_heads
        inter_head_dim = inter.hidden_size // inter.num_attention_heads
        if inter_head_dim == intra_head_dim:
            return self.freqs_cos[:num_word], self.freqs_sin[:num_word]
        # Head sizes differ: the registered tables have the wrong width, so build
        # matching ones (cheap: only num_word rows) on the buffers' device/dtype.
        freqs_cos, freqs_sin = precompute_freqs_cis(inter_head_dim, num_word)
        return freqs_cos.to(self.freqs_cos), freqs_sin.to(self.freqs_sin)

    def forward(
        self,
        input_ids,
        char_input_mask,
        word_input_mask,
        word_type_ids=None,
        combined_word_embeddings: Optional[bool] = False,
        output_hidden_states: Optional[bool] = False,
        return_dict: Optional[bool] = True,
    ):
        """Encode (batch, num_words, num_chars) character ids into word representations.

        Args:
            input_ids: character ids, (batch, num_words, num_chars).
            char_input_mask: per-character mask, reshapeable to (batch * num_words, num_chars).
            word_input_mask: per-word mask for the inter-word encoder.
            word_type_ids: optional segment ids, one per word.
            combined_word_embeddings: if True, ``last_hidden_state`` is the concatenation
                of initial and contextual word vectors (last dim = 2 * inter hidden size).
        """
        intra_hidden = self.config.intra_word_encoder.hidden_size
        inter_hidden = self.config.inter_word_encoder.hidden_size

        input_embeds = self.char_embedding_dropout(self.char_embeddings(input_ids))
        if hasattr(self, "char_embedding_project"):
            input_embeds = self.char_embedding_project(input_embeds)
        batch_size, num_word, num_char, _ = input_embeds.shape

        # reshape to attend to intra-word tokens rather than full sequence
        input_embeds = input_embeds.reshape(batch_size * num_word, num_char, intra_hidden)
        intra_word_mask = char_input_mask.reshape(batch_size * num_word, num_char)
        intra_word_output = self.intra_word_encoder(
            input_embeds,
            intra_word_mask,
            self.freqs_cos[:num_char],
            self.freqs_sin[:num_char],
            output_hidden_states=False,
            return_dict=True,
        )
        initial_embeds = intra_word_output.last_hidden_state

        # extract [WORD_CLS] embeddings, which are always at the beginning of each word
        initial_word_embeds = initial_embeds[:, 0, :]
        if word_type_ids is not None:
            word_type_embeds = self.word_type_embeddings(word_type_ids).reshape(
                batch_size * num_word, intra_hidden
            )
            initial_word_embeds = initial_word_embeds + word_type_embeds
        # Only the word vectors enter the inter-word encoder, so only they are
        # projected; initial_embeds stays at the intra-word size for the LM head.
        if hasattr(self, "intra_word_project"):
            initial_word_embeds = self.intra_word_project(initial_word_embeds)

        # reshape and extract contextualized inter-word representation
        word_embeds = initial_word_embeds.reshape(batch_size, num_word, inter_hidden)
        inter_cos, inter_sin = self._inter_word_rope(num_word)
        inter_word_output = self.inter_word_encoder(
            word_embeds,
            word_input_mask,
            inter_cos,
            inter_sin,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )
        contextual_word_embeds = inter_word_output.last_hidden_state

        if combined_word_embeddings:
            initial_word_embeds = word_embeds
            last_hidden_state = torch.cat([initial_word_embeds, contextual_word_embeds], dim=2)
        else:
            last_hidden_state = contextual_word_embeds

        hidden_states = inter_word_output.hidden_states if output_hidden_states else None
        input_shape = (batch_size, num_word, num_char, inter_hidden)
        if return_dict:
            return HLMBaseModelOutput(
                last_hidden_state=last_hidden_state,
                hidden_states=hidden_states,
                initial_embeds=initial_embeds,
                initial_word_embeds=initial_word_embeds,
                intra_word_mask=intra_word_mask,
                char_embeds=input_embeds,
                input_shape=input_shape,
            )
        return (
            last_hidden_state,
            hidden_states,
            initial_embeds,
            initial_word_embeds,
            intra_word_mask,
            input_embeds,
            input_shape,
        )


def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    """View a (seq, dim) table so it broadcasts against ``x`` on axes 1 and -1."""
    ndim = x.ndim
    assert 1 < ndim
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(*shape)


def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cos: torch.Tensor,
    freqs_sin: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Rotate adjacent feature pairs of 4-D ``xq``/``xk`` (batch, seq, heads, head_dim).

    The rotation is done in float32 and cast back to the inputs' dtype.
    """
    # reshape xq and xk to match the complex representation
    xq_r, xq_i = xq.float().reshape(*xq.shape[:-1], -1, 2).unbind(-1)
    xk_r, xk_i = xk.float().reshape(*xk.shape[:-1], -1, 2).unbind(-1)

    # reshape freqs_cos and freqs_sin for broadcasting
    freqs_cos = reshape_for_broadcast(freqs_cos, xq_r)
    freqs_sin = reshape_for_broadcast(freqs_sin, xq_r)

    # apply rotation using real numbers
    xq_out_r = xq_r * freqs_cos - xq_i * freqs_sin
    xq_out_i = xq_r * freqs_sin + xq_i * freqs_cos
    xk_out_r = xk_r * freqs_cos - xk_i * freqs_sin
    xk_out_i = xk_r * freqs_sin + xk_i * freqs_cos

    # flatten last two dimensions
    xq_out = torch.stack([xq_out_r, xq_out_i], dim=-1).flatten(3)
    xk_out = torch.stack([xk_out_r, xk_out_i], dim=-1).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)


def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    """Return RoPE (cos, sin) tables of shape (end, dim // 2)."""
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)  # type: ignore
    freqs = torch.outer(t, freqs).float()  # type: ignore
    freqs_cos = torch.cos(freqs)  # real part
    freqs_sin = torch.sin(freqs)  # imaginary part
    return freqs_cos, freqs_sin


class RMSNorm(torch.nn.Module):
    """Root-mean-square norm with a learnable per-feature scale (computed in float32)."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        output = self._norm(x.float()).type_as(x)
        return output * self.weight


class TransformerBlock(nn.Module):
    """Pre-norm block: RoPE self-attention (SDPA) followed by a SwiGLU feed-forward."""

    def __init__(self, config: HLMEncoderConfig, bias: bool = False):
        super().__init__()
        self.pad_id = config.pad_token_id
        self.drop_p = config.dropout_prob
        self.n_heads = config.num_attention_heads
        self.d_head = config.hidden_size // config.num_attention_heads
        self.has_bias = bias
        dim = config.hidden_size

        # Attention
        self.q = nn.Linear(in_features=dim, out_features=dim, bias=bias)
        self.k = nn.Linear(in_features=dim, out_features=dim, bias=bias)
        self.v = nn.Linear(in_features=dim, out_features=dim, bias=bias)
        self.att_proj_linear = nn.Linear(in_features=dim, out_features=dim, bias=bias)
        self.resid_dropout = nn.Dropout(self.drop_p)

        # Feedforward layer
        self.ff_dropout = nn.Dropout(self.drop_p)
        self.ff_linear_1 = nn.Linear(in_features=dim, out_features=config.intermediate_size, bias=bias)
        self.ff_linear_2 = nn.Linear(in_features=config.intermediate_size, out_features=dim, bias=bias)
        self.ff_linear_3 = nn.Linear(in_features=dim, out_features=config.intermediate_size, bias=bias)

        # Pre-layer norms
        self.attn_norm = RMSNorm(dim, eps=config.layer_norm_eps)
        self.ff_norm = RMSNorm(dim, eps=config.layer_norm_eps)

    def make_sandwich(self, other):
        """Share all weight matrices (not biases) with the bias-free block ``other``."""
        assert self.has_bias
        assert not other.has_bias
        self.q.weight = other.q.weight
        self.k.weight = other.k.weight
        self.v.weight = other.v.weight
        self.att_proj_linear.weight = other.att_proj_linear.weight
        self.ff_linear_1.weight = other.ff_linear_1.weight
        self.ff_linear_2.weight = other.ff_linear_2.weight
        self.ff_linear_3.weight = other.ff_linear_3.weight

    def forward(self, x: torch.Tensor, pad_mask: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor):
        x = x + self._attention_block(self.attn_norm(x), pad_mask, freqs_cos, freqs_sin)
        x = x + self._feedforward_block(self.ff_norm(x))
        return x

    def _attention_block(self, x: torch.Tensor, attn_mask: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor):
        batch_size, seq_len, _ = x.shape
        xq, xk, xv = self.q(x), self.k(x), self.v(x)

        # Reshape for rotary embeddings
        xq = xq.view(batch_size, seq_len, self.n_heads, self.d_head)
        xk = xk.view(batch_size, seq_len, self.n_heads, self.d_head)
        xv = xv.view(batch_size, seq_len, self.n_heads, self.d_head)
        xq, xk = apply_rotary_emb(xq, xk, freqs_cos, freqs_sin)

        # Reshape for attention calculation: (b_sz, n_head, s_len, d_head)
        xq = xq.transpose(1, 2)
        xk = xk.transpose(1, 2)
        xv = xv.transpose(1, 2)

        att = F.scaled_dot_product_attention(
            query=xq,
            key=xk,
            value=xv,
            attn_mask=attn_mask,
            dropout_p=self.drop_p if self.training else 0.0,
            is_causal=False,
        )

        # Shape (b_sz, s_len, n_head, d_head)
        out = att.transpose(1, 2).contiguous()
        out = out.view(batch_size, seq_len, self.n_heads * self.d_head)

        return self.resid_dropout(self.att_proj_linear(out))

    def _feedforward_block(self, x: torch.Tensor):
        # SWiGLU activation
        x = self.ff_linear_2(F.silu(self.ff_linear_1(x)) * self.ff_linear_3(x))
        x = self.ff_dropout(x)
        return x


class HLMForMaskedLM(HLMPreTrainedModel):
    """HLM with a character-level masked-language-modeling head."""

    _tied_weights_keys = ["cls.decoder.weight", "cls.decoder.bias"]

    def __init__(self, config):
        super().__init__(config)
        # NOTE: This property name must match "base_model_prefix" in the base class
        self.hlm = HLMModel(config)
        self.cls = HLMLMPredictionHead(config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        return self.cls.decoder

    def set_output_embeddings(self, new_embeddings):
        self.cls.decoder = new_embeddings

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        char_input_mask: Optional[torch.Tensor] = None,
        word_input_mask: Optional[torch.Tensor] = None,
        word_type_ids: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = True,
    ) -> Union[Tuple, MaskedLMOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, num_words, max_chars_per_word)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
            config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
            loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
        """
        # The prediction head needs the structured output, so always request it and
        # only convert to a tuple at the very end.
        outputs = self.hlm(
            input_ids,
            char_input_mask=char_input_mask,
            word_input_mask=word_input_mask,
            word_type_ids=word_type_ids,
            output_hidden_states=output_hidden_states,
            return_dict=True,
            combined_word_embeddings=False,
        )
        num_char = outputs.input_shape[2]
        prediction_scores = self.cls(
            outputs,
            freqs_cos=self.hlm.freqs_cos[:num_char],
            freqs_sin=self.hlm.freqs_sin[:num_char],
        )

        masked_lm_loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()  # -100 index = padding token
            masked_lm_loss = loss_fct(
                prediction_scores.reshape(-1, self.config.vocab_size), labels.reshape(-1)
            )

        if not return_dict:
            output = (prediction_scores,) + outputs[1:]
            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
        return MaskedLMOutput(
            loss=masked_lm_loss,
            logits=prediction_scores,
            hidden_states=outputs.hidden_states,
        )


class HLMLMPredictionHead(nn.Module):
    """Predict characters from contextual word vectors.

    Each word's [WORD_CLS] slot in the intra-word encoder output is replaced by its
    contextual word vector. A single-layer intra-word encoder re-encodes the
    characters, and the result is decoded to vocabulary logits.
    """

    def __init__(self, config):
        super().__init__()
        intra_word_encoder_config = copy.copy(config.intra_word_encoder)
        intra_word_encoder_config.num_hidden_layers = 1
        intra_word_encoder_config.sandwich_size = 0
        self.intra_word_encoder = HLMEncoder(intra_word_encoder_config)
        self.residual_word_embedding = getattr(config, "residual_word_embedding", False)
        self.config = config

        if self.config.intra_word_encoder.hidden_size != self.config.inter_word_encoder.hidden_size:
            self.inter_word_project = nn.Linear(
                config.inter_word_encoder.hidden_size,
                self.config.intra_word_encoder.hidden_size,
                bias=False,
            )

        if getattr(config, "tie_word_embeddings", True):
            # The output weights are the same as the input embeddings, but there is
            # an output-only bias for each token.
            # NOTE(review): tying only works when the char embedding width equals the
            # intra-word hidden size (i.e. no factorized char embedding projection).
            self.decoder = nn.Linear(config.intra_word_encoder.hidden_size, config.vocab_size, bias=False)
            self.bias = nn.Parameter(torch.zeros(config.vocab_size))
            # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
            self.decoder.bias = self.bias
        else:
            self.decoder = nn.Linear(config.intra_word_encoder.hidden_size, config.vocab_size)

    def forward(self, base_model_output: HLMBaseModelOutput, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor):
        """Return character logits of shape (batch, num_words * num_chars, vocab_size)."""
        batch_size, num_word, num_char, _ = base_model_output.input_shape
        word_embeds = base_model_output.last_hidden_state.reshape(
            batch_size * num_word, 1, self.config.inter_word_encoder.hidden_size
        )
        if self.residual_word_embedding:
            # residual connection between initial word embeddings and contextual word embeddings as mentioned in the paper (section A.3)
            # Out-of-place: an in-place add would modify the base model's
            # last_hidden_state (and the returned hidden_states) through the view.
            word_embeds = word_embeds + base_model_output.initial_word_embeds.unsqueeze(1)
        if hasattr(self, "inter_word_project"):
            word_embeds = self.inter_word_project(word_embeds)

        # concatenate to restore the character-level token sequence
        char_embeds = torch.cat([word_embeds, base_model_output.initial_embeds[:, 1:, :]], dim=1)
        intra_word_output = self.intra_word_encoder(
            char_embeds,
            base_model_output.intra_word_mask,
            freqs_cos,
            freqs_sin,
            output_hidden_states=False,
            return_dict=True,
        )
        char_logits = self.decoder(intra_word_output.last_hidden_state)
        return char_logits.reshape(batch_size, num_word * num_char, -1)


class HLMForTokenClassification(HLMPreTrainedModel):
    """HLM with a per-word linear classifier over [initial ; contextual] word vectors."""

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.hlm = HLMModel(config)
        # Input is the concatenation of initial and contextual word vectors.
        self.cls = nn.Linear(config.inter_word_encoder.hidden_size * 2, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        char_input_mask: Optional[torch.Tensor] = None,
        word_input_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        word_type_ids: Optional[torch.Tensor] = None,
    ) -> Union[Tuple, TokenClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
        word_type_ids (`torch.LongTensor`, *optional*):
            Optional per-word segment ids, forwarded to the base model.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.hlm(
            input_ids,
            char_input_mask=char_input_mask,
            word_input_mask=word_input_mask,
            word_type_ids=word_type_ids,
            output_hidden_states=output_hidden_states,
            return_dict=True,
            combined_word_embeddings=True,
        )
        logits = self.cls(outputs.last_hidden_state)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
        )


class HLMForSequenceClassification(HLMPreTrainedModel):
    """HLM with a dense + GELU + dropout + linear head on the first word's vector."""

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.num_labels = getattr(config, "num_labels", 2)
        self.hlm = HLMModel(config)
        self.dense = nn.Linear(config.inter_word_encoder.hidden_size, config.inter_word_encoder.hidden_size)
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(config.inter_word_encoder.hidden_size, config.num_labels)
        self.activation = nn.GELU()

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        char_input_mask: Optional[torch.Tensor] = None,
        word_input_mask: Optional[torch.Tensor] = None,
        word_type_ids: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutput]:
        r"""
        labels (`torch.Tensor`, *optional*):
            Labels for computing the sequence classification / regression loss, one per example
            (class indices in `[0, ..., config.num_labels - 1]`, or float targets for regression).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.hlm(
            input_ids,
            char_input_mask=char_input_mask,
            word_input_mask=word_input_mask,
            word_type_ids=word_type_ids,
            output_hidden_states=output_hidden_states,
            return_dict=True,
            combined_word_embeddings=False,
        )

        # Pool with the first word's contextual vector.
        emb = outputs.last_hidden_state[:, 0]
        emb = self.dense(emb)
        emb = self.activation(emb)
        emb = self.dropout(emb)
        logits = self.classifier(emb)

        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    # regression task
                    loss_fn = nn.MSELoss()
                    logits = logits.view(-1).to(labels.dtype)
                    loss = loss_fn(logits, labels.view(-1))
                elif labels.dim() == 1 or labels.size(-1) == 1:
                    # Only examples with a non-negative label contribute to the loss.
                    label_index = (labels >= 0).nonzero()
                    labels = labels.long()
                    if label_index.size(0) > 0:
                        labeled_logits = torch.gather(
                            logits, 0, label_index.expand(label_index.size(0), logits.size(1))
                        )
                        labels = torch.gather(labels, 0, label_index.view(-1))
                        loss_fct = nn.CrossEntropyLoss()
                        loss = loss_fct(labeled_logits.view(-1, self.num_labels).float(), labels.view(-1))
                    else:
                        loss = torch.tensor(0).to(logits)
                else:
                    # Soft labels: cross-entropy against a label distribution.
                    log_softmax = nn.LogSoftmax(-1)
                    loss = -((log_softmax(logits) * labels).sum(-1)).mean()
            elif self.config.problem_type == "regression":
                loss_fct = nn.MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = nn.CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = nn.BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(loss=loss, logits=logits, hidden_states=outputs.hidden_states)


AutoConfig.register("hlm", HLMConfig)
AutoModel.register(HLMConfig, HLMModel)
AutoModelForTokenClassification.register(HLMConfig, HLMForTokenClassification)
AutoModelForSequenceClassification.register(HLMConfig, HLMForSequenceClassification)
AutoModelForMaskedLM.register(HLMConfig, HLMForMaskedLM)
AutoTokenizer.register(HLMConfig, HLMTokenizer)

HLMConfig.register_for_auto_class()
HLMModel.register_for_auto_class("AutoModel")
HLMForMaskedLM.register_for_auto_class("AutoModelForMaskedLM")
HLMForSequenceClassification.register_for_auto_class("AutoModelForSequenceClassification")
HLMForTokenClassification.register_for_auto_class("AutoModelForTokenClassification")