import torch
import torch.nn as nn
import numpy as np
from torch.autograd import Function
from transformers import PreTrainedModel
from transformers.models.roberta.modeling_roberta import (
    RobertaModel,
    RobertaClassificationHead,
)
from typing import List, Union, Tuple, Optional
from transformers.modeling_outputs import (
    SequenceClassifierOutput,
    MultipleChoiceModelOutput,
    QuestionAnsweringModelOutput,
)
from transformers.utils import ModelOutput

from .configuration_pure_roberta import PureRobertaConfig


def _valid_lengths(attention_mask, batch_size, seq_len) -> List[int]:
    """Return the number of valid (unmasked) positions of every sequence.

    Callers keep ``x[:length]`` of each sequence, so the mask is assumed to be
    right-padded (ones followed by zeros) -- NOTE(review): confirm that no
    caller feeds left-padded masks.

    Args:
        attention_mask: ``(batch_size, seq_len)`` mask, or ``None`` when every
            position of every sequence is valid.
        batch_size: number of sequences in the batch.
        seq_len: padded sequence length.

    Returns:
        A list of ``batch_size`` Python ints.
    """
    if attention_mask is None:
        return [seq_len] * batch_size
    # int() truncates, matching the original per-sample ``int(mask.sum().item())``.
    return [int(n) for n in attention_mask.sum(dim=1).tolist()]


class CovarianceFunction(Function):
    """Autograd function computing per-sample (biased) covariance matrices.

    For an input of shape ``(b, c, h, w)`` the ``h * w = m`` trailing entries
    are treated as observations of ``c`` channels, and the result is
    ``x @ I_hat @ x^T`` of shape ``(b, c, c)`` with
    ``I_hat = (1/m) * (I - (1/m) * ones(m, m))``.
    """

    @staticmethod
    def forward(ctx, inputs):
        x = inputs
        b, c, h, w = x.data.shape
        m = h * w
        x = x.view(b, c, m)
        # Centring matrix scaled by 1/m: (1/m) * I - (1/m^2) * ones.
        I_hat = (-1.0 / m / m) * torch.ones(m, m, device=x.device) + (
            1.0 / m
        ) * torch.eye(m, m, device=x.device)
        I_hat = I_hat.view(1, m, m).repeat(b, 1, 1).type(x.dtype)
        y = x @ I_hat @ x.transpose(-1, -2)
        ctx.save_for_backward(inputs, I_hat)
        return y

    @staticmethod
    def backward(ctx, grad_output):
        inputs, I_hat = ctx.saved_tensors
        x = inputs
        b, c, h, w = x.data.shape
        m = h * w
        x = x.view(b, c, m)
        # d(x I x^T)/dx = (G + G^T) x I, using that I_hat is symmetric.
        grad_input = grad_output + grad_output.transpose(1, 2)
        grad_input = grad_input @ x @ I_hat
        grad_input = grad_input.reshape(b, c, h, w)
        return grad_input


class Covariance(nn.Module):
    """Module wrapper around :class:`CovarianceFunction`."""

    def __init__(self):
        super(Covariance, self).__init__()

    def _covariance(self, x):
        return CovarianceFunction.apply(x)

    def forward(self, x):
        """Covariance of the columns of a 2-D input.

        Args:
            x: ``(T, F)`` tensor -- only 2-D input is supported, since the
                tensor is reshaped to ``(1, F, T, 1)`` below.

        Returns:
            ``(F, F)`` covariance of the ``F`` features over the ``T`` rows.
        """
        if x.dim() == 2:
            x = x.transpose(-1, -2)
        C = self._covariance(x[None, :, :, None])
        C = C.squeeze(dim=0)
        return C


class PFSA(torch.nn.Module):
    """
    Parameter-free self-attention: https://openreview.net/pdf?id=isodM5jTA7h

    Every token is scaled by ``(1 - sigmoid(corr)) ** alpha`` where ``corr`` is
    the Pearson correlation, over the feature axis, between the token and the
    mean of the valid tokens of its sequence.
    """

    def __init__(self, input_dim, alpha=1):
        super(PFSA, self).__init__()
        self.input_dim = input_dim
        self.alpha = alpha

    def forward_one_sample(self, x):
        """x: [B, T, F] (called per sequence with B == 1) -> [B, T, F]."""
        x = x.transpose(1, 2)[..., None]  # [B, F, T, 1]
        k = torch.mean(x, dim=[-1, -2], keepdim=True)  # [B, F, 1, 1] mean token
        k_centered = k - k.mean(dim=1, keepdim=True)
        x_centered = x - x.mean(dim=1, keepdim=True)
        kd = torch.sqrt(k_centered.pow(2).sum(dim=1, keepdim=True))  # [B, 1, 1, 1]
        qd = torch.sqrt(x_centered.pow(2).sum(dim=1, keepdim=True))  # [B, 1, T, 1]
        C_qk = (x_centered * k_centered).sum(dim=1, keepdim=True) / (qd * kd)
        A = (1 - torch.sigmoid(C_qk)) ** self.alpha
        out = x * A
        out = out.squeeze(dim=-1).transpose(1, 2)
        return out

    def forward(self, input_values, attention_mask=None):
        """
        Args:
            input_values: [B, T, F].
            attention_mask: [B, T] right-padded mask, or ``None`` when every
                position is valid.

        Returns:
            [B, T, F]; positions beyond each sequence's valid length are zero.
        """
        b, t, f = input_values.shape
        lengths = _valid_lengths(attention_mask, b, t)
        out = []
        for x, length in zip(input_values, lengths):
            x = x.view(1, t, f)
            x_out = self.forward_one_sample(x[:, :length, :])
            x_expanded = torch.zeros_like(x)
            x_expanded[:, :x_out.shape[-2], :x_out.shape[-1]] = x_out
            out.append(x_expanded)
        return torch.vstack(out).view(b, t, f)


class PURE(torch.nn.Module):
    """Principal-component removal (PCR) followed by PFSA re-weighting.

    Either stage can be disabled via ``disable_pcr`` / ``disable_pfsa``. When
    ``disable_covariance`` is False, principal components are taken from the
    per-sequence feature covariance rather than from the token matrix itself.
    """

    def __init__(
        self,
        in_dim,
        svd_rank=16,
        num_pc_to_remove=1,
        center=False,
        num_iters=2,
        alpha=1,
        disable_pcr=False,
        disable_pfsa=False,
        disable_covariance=True,
        *args,
        **kwargs,
    ):
        super().__init__()
        self.in_dim = in_dim
        self.svd_rank = svd_rank
        self.num_pc_to_remove = num_pc_to_remove
        self.center = center
        self.num_iters = num_iters
        self.do_pcr = not disable_pcr
        self.do_pfsa = not disable_pfsa
        self.do_covariance = not disable_covariance
        self.attention = PFSA(in_dim, alpha=alpha)

    def _compute_pc(self, X, attention_mask):
        """Top principal directions of every sequence's valid tokens.

        Args:
            X: (B, T, F).
            attention_mask: (B, T) right-padded mask, or ``None`` (all valid).

        Returns:
            A list of B tensors of shape (K, F), K = ``num_pc_to_remove``.
        """
        pcs = []
        bs, seqlen, _ = X.shape
        lengths = _valid_lengths(attention_mask, bs, seqlen)
        covariance = Covariance()
        for x, length in zip(X, lengths):
            x = x[:length, :]
            if self.do_covariance:
                x = covariance(x)  # (F, F)
                q = self.svd_rank
            else:
                q = min(self.svd_rank, length)
            _, _, V = torch.pca_lowrank(x, q=q, center=self.center, niter=self.num_iters)
            pcs.append(V.transpose(0, 1)[:self.num_pc_to_remove, :])  # [K, F]
        return pcs

    def _remove_pc(self, X, pcs):
        """Project the given principal directions out of every token.

        Args:
            X: (B, T, F).
            pcs: sequence of B tensors of shape (K, F).

        Returns:
            (B, T, F) tensor ``x - x @ pc^T @ pc`` per sequence.
        """
        out = [(x - x @ pc.transpose(0, 1) @ pc)[None, ...] for x, pc in zip(X, pcs)]
        return torch.vstack(out)

    def forward(self, input_values, attention_mask=None, *args, **kwargs):
        """
        PCR -> Attention

        Args:
            input_values: (B, T, F).
            attention_mask: (B, T) right-padded mask, or ``None`` (all valid).
        """
        x = input_values
        if self.do_pcr:
            pc = self._compute_pc(x, attention_mask)  # B x [K, F]
            xx = self._remove_pc(x, pc)  # [B, T, F]
        else:
            xx = x
        if self.do_pfsa:
            xx = self.attention(xx, attention_mask)
        return xx


class StatisticsPooling(torch.nn.Module):
    """Mean and/or standard-deviation pooling over the time axis."""

    def __init__(self, return_mean=True, return_std=True):
        super().__init__()

        # Small value for GaussNoise
        self.eps = 1e-5
        self.return_mean = return_mean
        self.return_std = return_std
        if not (self.return_mean or self.return_std):
            raise ValueError(
                "both of statistics are equal to False \n"
                "consider enabling mean and/or std statistic pooling"
            )

    def forward(self, input_values, attention_mask=None):
        """Calculates mean and std for a batch (input tensor).

        Arguments
        ---------
        input_values : torch.Tensor
            (B, T, F) tensor for a mini-batch.
        attention_mask : torch.Tensor, optional
            (B, T) right-padded mask; when given, only the first
            ``attention_mask[i].sum()`` positions of sequence ``i`` are pooled.

        Returns
        -------
        (B, 1, F) for a single statistic, (B, 1, 2F) for mean and std.
        """
        x = input_values
        if attention_mask is None:
            if self.return_mean:
                mean = x.mean(dim=1)
            if self.return_std:
                std = x.std(dim=1)
        else:
            mean = []
            std = []
            # Absolute valid lengths: a length relative to the longest sequence
            # would count padding whenever the batch is padded past it.
            lengths = _valid_lengths(attention_mask, x.shape[0], x.shape[1])
            for snt_id, actual_size in enumerate(lengths):
                # Avoiding padded time steps
                if self.return_mean:
                    mean.append(torch.mean(x[snt_id, 0:actual_size, ...], dim=0))
                if self.return_std:
                    std.append(torch.std(x[snt_id, 0:actual_size, ...], dim=0))
            if self.return_mean:
                mean = torch.stack(mean)
            if self.return_std:
                std = torch.stack(std)

        if self.return_mean:
            # Tiny noise in [eps, 9 * eps]; applied in both train and eval mode.
            gnoise = self._get_gauss_noise(mean.size(), device=mean.device)
            mean += gnoise
        if self.return_std:
            std = std + self.eps

        # Append mean and std of the batch
        if self.return_mean and self.return_std:
            pooled_stats = torch.cat((mean, std), dim=1)
            pooled_stats = pooled_stats.unsqueeze(1)
        elif self.return_mean:
            pooled_stats = mean.unsqueeze(1)
        elif self.return_std:
            pooled_stats = std.unsqueeze(1)

        return pooled_stats

    def _get_gauss_noise(self, shape_of_tensor, device="cpu"):
        """Returns a tensor of epsilon Gaussian noise.

        Arguments
        ---------
        shape_of_tensor : tensor
            It represents the size of tensor for generating Gaussian noise.

        Returns
        -------
        Noise min-max rescaled into the range [eps, 9 * eps].
        """
        gnoise = torch.randn(shape_of_tensor, device=device)
        gnoise -= torch.min(gnoise)
        gnoise /= torch.max(gnoise)
        gnoise = self.eps * ((1 - 9) * gnoise + 9)

        return gnoise


class PureRobertaPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = PureRobertaConfig
    base_model_prefix = "pure_roberta"
    supports_gradient_checkpointing = True
    _no_split_modules = ["RobertaEmbeddings", "RobertaSelfAttention", "RobertaSdpaSelfAttention"]
    _supports_sdpa = True

    # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights
    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, nn.Linear):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)


class PureRobertaForSequenceClassification(PureRobertaPreTrainedModel):
    """RoBERTa encoder -> PURE refinement -> masked mean pooling -> classifier."""

    def __init__(
        self,
        config,
        label_smoothing=0.0,
    ):
        super().__init__(config)
        self.label_smoothing = label_smoothing
        self.num_labels = config.num_labels
        self.config = config

        self.roberta = RobertaModel(config, add_pooling_layer=False)
        self.pure = PURE(
            in_dim=config.hidden_size,
            svd_rank=config.svd_rank,
            num_pc_to_remove=config.num_pc_to_remove,
            center=config.center,
            num_iters=config.num_iters,
            alpha=config.alpha,
            disable_pcr=config.disable_pcr,
            disable_pfsa=config.disable_pfsa,
            disable_covariance=config.disable_covariance,
        )
        self.mean = StatisticsPooling(return_mean=True, return_std=False)
        self.classifier = RobertaClassificationHead(config)

        # Initialize weights and apply final processing
        self.post_init()

    def forward_pure_embeddings(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
        r"""
        Return the PURE-refined token embeddings as ``ModelOutput(last_hidden_state=...)``.

        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Accepted for signature parity with `forward`; unused here.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.roberta(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Index 0 works for both ModelOutput and the tuple returned when return_dict=False.
        token_embeddings = outputs[0]
        token_embeddings = self.pure(token_embeddings, attention_mask)

        return ModelOutput(
            last_hidden_state=token_embeddings,
        )

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.roberta(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Index 0 works for both ModelOutput and the tuple returned when return_dict=False.
        token_embeddings = outputs[0]
        token_embeddings = self.pure(token_embeddings, attention_mask)
        # Pool only valid tokens so the result does not depend on padding length.
        # Keep the (B, 1, H) shape: RobertaClassificationHead reads features[:, 0, :].
        pooled_output = self.mean(token_embeddings, attention_mask)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = nn.MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = nn.CrossEntropyLoss(label_smoothing=self.label_smoothing)
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = nn.BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )