from typing import Optional

import torch
import torch.nn as nn
from transformers import T5EncoderModel
from transformers.modeling_outputs import SequenceClassifierOutput


class IntentPredictModel(nn.Module):
    """Intent classifier: T5 encoder backbone plus a two-layer linear head.

    The encoder's token representations are mean-pooled into one vector per
    sequence, which is then projected to ``num_classes`` logits.
    """

    def __init__(self, pretrained_model_name_or_path: str, num_classes: int) -> None:
        """Load the pretrained encoder and build the classification head.

        Args:
            pretrained_model_name_or_path: Value forwarded unchanged to
                ``T5EncoderModel.from_pretrained``.
            num_classes: Number of output logits.
        """
        super().__init__()
        self.backbone = T5EncoderModel.from_pretrained(pretrained_model_name_or_path)
        # Read the hidden size from the public config instead of reaching into
        # the last block's feed-forward layer, whose attribute layout is an
        # implementation detail. 768 for t5-base, 1024 for t5-large.
        out_features = self.backbone.config.d_model
        # NOTE: there is no activation between fc1 and fc2, so the head is
        # mathematically a single affine map. Kept as-is so that existing
        # checkpoints load and behave the same.
        self.fc1 = nn.Linear(out_features, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> SequenceClassifierOutput:
        """Compute class logits for a batch of token id sequences.

        Args:
            input_ids: Token ids, passed to the backbone as ``input_ids``.
            attention_mask: Mask passed to the backbone as ``attention_mask``
                and used as pooling weights, so positions whose mask is 0 do
                not contribute to the sentence vector. If ``None``, every
                position is averaged.

        Returns:
            A ``SequenceClassifierOutput`` whose ``logits`` field has shape
            ``[batch_size, num_classes]``.
        """
        encoded = self.backbone(input_ids=input_ids, attention_mask=attention_mask)
        hidden = encoded.last_hidden_state  # [batch_size, seq_len, out_features]

        if attention_mask is None:
            pooled = hidden.mean(dim=1)  # [batch_size, out_features]
        else:
            # Weight each position by its mask so padding does not dilute the
            # average; otherwise the result depends on how much padding the
            # batch happened to need.
            weights = attention_mask.unsqueeze(-1).to(hidden.dtype)  # [batch_size, seq_len, 1]
            summed = (hidden * weights).sum(dim=1)  # [batch_size, out_features]
            # Clamp so a fully masked row divides by 1 rather than 0.
            counts = weights.sum(dim=1).clamp(min=1.0)  # [batch_size, 1]
            pooled = summed / counts  # [batch_size, out_features]

        projected = self.fc1(pooled)  # [batch_size, 128]
        logits = self.fc2(projected)  # [batch_size, num_classes]
        return SequenceClassifierOutput(logits=logits)