# NOTE: "Spaces: Runtime error / Runtime error" was page-scrape residue, not part of this module.
'''
Author: Qiguang Chen
Date: 2023-01-11 10:39:26
LastEditors: Qiguang Chen
LastEditTime: 2023-01-31 20:07:00
Description: Classifier (decoder head) modules: linear, autoregressive LSTM and MLP.
'''
import random | |
import torch | |
import torch.nn.functional as F | |
from torch import nn | |
from torch.nn import CrossEntropyLoss | |
from model.decoder import decoder_utils | |
from torchcrf import CRF | |
from common.utils import HiddenData, OutputData, InputData, ClassifierOutputData, unpack_sequence, pack_sequence, \ | |
instantiate | |
class BaseClassifier(nn.Module):
    """Base class for all classifier modules.

    Subclasses implement :meth:`forward`. This class builds the loss
    function and forwards decoding / loss computation to ``decoder_utils``,
    all driven by the keyword configuration.

    Config keys read here:
        loss_fn: optional spec handed to ``instantiate`` to build the loss.
            When absent or falsy, ``CrossEntropyLoss`` is used.
        ignore_index (int, optional): label id ignored by the loss. Falls
            back to ``-100`` when missing or ``None``.
        mode: forwarded as ``pred_type`` / ``criterion_type``.
        use_multi, multi_threshold: forwarded unchanged.
        use_crf: when truthy, ``self.CRF`` (expected to be created by the
            subclass) is passed to ``decoder_utils.compute_loss``.
        return_sentence_level: default decoding granularity for
            :meth:`decode`.
    """

    # Fallback when the config does not provide an ignore index; this is
    # also the default ``ignore_index`` of ``torch.nn.CrossEntropyLoss``.
    DEFAULT_IGNORE_INDEX = -100

    def __init__(self, **config):
        super().__init__()
        self.config = config
        loss_spec = config.get("loss_fn")
        if loss_spec:
            self.loss_fn = instantiate(loss_spec)
        else:
            # Passing ``None`` as ignore_index would only fail later, on the
            # first loss call, so resolve the fallback up front.
            self.loss_fn = CrossEntropyLoss(ignore_index=self._resolve_ignore_index())

    def _resolve_ignore_index(self):
        """Return the configured ``ignore_index`` or the ``-100`` fallback."""
        configured = self.config.get("ignore_index")
        return self.DEFAULT_IGNORE_INDEX if configured is None else configured

    def forward(self, *args, **kwargs):
        """Abstract; concrete classifiers must override this method.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError("No implemented classifier.")

    def decode(self, output: "OutputData",
               target: "InputData" = None,
               return_list=True,
               return_sentence_level=None):
        """Decode output logits.

        Args:
            output (OutputData): output logits data.
            target (InputData, optional): input data with attention mask.
                Defaults to None.
            return_list (bool, optional): if True return list, else return
                torch Tensor. Defaults to True.
            return_sentence_level (bool, optional): if True decode sentence
                level intent, else decode token level intent. When None,
                the config value ``return_sentence_level`` is used, and
                False if that is unset as well.

        Returns:
            List or Tensor: decoded sequence ids.
        """
        if return_sentence_level is None:
            # An explicit argument wins; otherwise config, then False.
            configured = self.config.get("return_sentence_level")
            return_sentence_level = False if configured is None else configured
        return decoder_utils.decode(output, target,
                                    return_list=return_list,
                                    return_sentence_level=return_sentence_level,
                                    pred_type=self.config.get("mode"),
                                    use_multi=self.config.get("use_multi"),
                                    multi_threshold=self.config.get("multi_threshold"))

    def compute_loss(self, pred: "OutputData", target: "InputData"):
        """Compute the loss.

        Args:
            pred (OutputData): output logits data.
            target (InputData): input golden data.

        Returns:
            Tensor: loss result.

        Raises:
            KeyError: if the config has no ``mode`` entry.
        """
        # The CRF layer is only handed over when the config enables it.
        crf_layer = self.CRF if self.config.get("use_crf") else None
        return decoder_utils.compute_loss(pred, target, criterion_type=self.config["mode"],
                                          use_crf=crf_layer is not None,
                                          ignore_index=self._resolve_ignore_index(),
                                          use_multi=self.config.get("use_multi"),
                                          loss_fn=self.loss_fn,
                                          CRF=crf_layer)
class LinearClassifier(BaseClassifier):
    """Classifier head made of a single linear projection per task."""

    def __init__(self, **config):
        """Build the linear heads requested by the configuration.

        Args:
            config (dict):
                input_dim (int): hidden state dim.
                use_slot (bool): whether to classify slot label.
                slot_label_num (int, optional): the number of slot labels.
                    Read only if use_slot is truthy.
                use_intent (bool): whether to classify intent label.
                intent_label_num (int, optional): the number of intent
                    labels. Read only if use_intent is truthy.
                use_crf (bool): whether to create a CRF layer for slots.
                    Only consulted when use_slot is truthy.
        """
        super().__init__(**config)
        self.config = config

        wants_slot = config.get("use_slot")
        wants_intent = config.get("use_intent")

        if wants_slot:
            self.slot_classifier = nn.Linear(config["input_dim"], config["slot_label_num"])
            if self.config.get("use_crf"):
                # The CRF works over the same tag set as the slot head.
                self.CRF = CRF(num_tags=config["slot_label_num"], batch_first=True)

        if wants_intent:
            self.intent_classifier = nn.Linear(config["input_dim"], config["intent_label_num"])

    def forward(self, hidden: HiddenData):
        """Project the hidden state to label logits.

        The intent head takes precedence when both ``use_intent`` and
        ``use_slot`` are set. When neither is set, nothing is returned.

        Args:
            hidden (HiddenData): encoder output holding the hidden states.

        Returns:
            ClassifierOutputData or None: wrapped logits of the active head.
        """
        if self.config.get("use_intent"):
            logits = self.intent_classifier(hidden.get_intent_hidden_state())
        elif self.config.get("use_slot"):
            logits = self.slot_classifier(hidden.get_slot_hidden_state())
        else:
            return None
        return ClassifierOutputData(logits)
class AutoregressiveLSTMClassifier(BaseClassifier):
    """Decoder that runs an LSTM over each sentence's hidden states.

    When ``embedding_dim`` is configured, the embedding of the previous
    label is concatenated to every step's input: the gold label under
    teacher forcing, otherwise the model's own top-1 prediction.
    """

    def __init__(self, **config):
        """Construction function for Decoder.

        Args:
            config (dict):
                input_dim (int): input dimension of Decoder. In fact, it's
                    encoder hidden size.
                use_slot (bool): whether to classify slot label.
                slot_label_num (int, optional): the number of slot labels.
                    Read only if use_slot is truthy.
                use_intent (bool): whether to classify intent label.
                intent_label_num (int, optional): the number of intent
                    labels. Read only if use_intent is truthy.
                use_crf (bool): whether to use crf for slot.
                hidden_dim (int): hidden dimension of iterative LSTM.
                embedding_dim (int, optional): if it's not None, the input
                    and output are relevant (previous label is fed back).
                dropout_rate (float): dropout rate, used for the input
                    dropout layer and the LSTM's ``dropout`` argument.
                force_ratio (float, optional): probability of teacher
                    forcing per training forward pass. A missing value
                    disables teacher forcing.
                bidirectional (bool): passed to ``nn.LSTM``.
                layer_num (int): number of LSTM layers.
                ignore_index (int, optional): label id treated as "no
                    label". Defaults to -100.

        Raises:
            ValueError: if neither ``use_intent`` nor ``use_slot`` is set,
                so the output dimension cannot be determined.
        """
        super(AutoregressiveLSTMClassifier, self).__init__(**config)
        if config.get("use_slot") and config.get("use_crf"):
            self.CRF = CRF(num_tags=config["slot_label_num"], batch_first=True)
        self.input_dim = config["input_dim"]
        self.hidden_dim = config["hidden_dim"]

        # The slot label count wins when both tasks are enabled.
        output_dim = None
        if config.get("use_intent"):
            output_dim = config["intent_label_num"]
        if config.get("use_slot"):
            output_dim = config["slot_label_num"]
        if output_dim is None:
            raise ValueError(
                "AutoregressiveLSTMClassifier needs `use_intent` or `use_slot` "
                "to determine its output dimension."
            )
        self.output_dim = output_dim

        self.dropout_rate = config["dropout_rate"]
        self.embedding_dim = config.get("embedding_dim")
        self.force_ratio = config.get("force_ratio")
        self.config = config
        self.ignore_index = config.get("ignore_index") if config.get("ignore_index") is not None else -100

        # If embedding_dim is not None, the output and input
        # of this structure is relevant.
        if self.embedding_dim is not None:
            self.embedding_layer = nn.Embedding(self.output_dim, self.embedding_dim)
            # Learned stand-in for the "previous label" of a sentence's
            # first token.
            self.init_tensor = nn.Parameter(
                torch.randn(1, self.embedding_dim),
                requires_grad=True
            )

        # Make sure the input dimension of iterative LSTM.
        if self.embedding_dim is not None:
            lstm_input_dim = self.input_dim + self.embedding_dim
        else:
            lstm_input_dim = self.input_dim

        # Network parameter definition.
        self.dropout_layer = nn.Dropout(self.dropout_rate)
        self.lstm_layer = nn.LSTM(
            input_size=lstm_input_dim,
            hidden_size=self.hidden_dim,
            batch_first=True,
            bidirectional=self.config["bidirectional"],
            dropout=self.dropout_rate,
            num_layers=self.config["layer_num"]
        )
        # NOTE(review): the projection expects `hidden_dim` features, while a
        # bidirectional LSTM emits 2 * hidden_dim -- presumably configs keep
        # `bidirectional` False or `internal_interaction` restores the
        # width; confirm before enabling bidirectional.
        self.linear_layer = nn.Linear(
            self.hidden_dim,
            self.output_dim
        )

    def forward(self, hidden: HiddenData, internal_interaction=None, **interaction_args):
        """Forward process for decoder.

        Args:
            hidden (HiddenData): provides ``slot_hidden`` as decoder input
                and ``inputs`` (attention mask, and gold labels for
                teacher forcing).
            internal_interaction (callable, optional): applied to the LSTM
                output before the linear projection. It is called with
                ``**interaction_args`` plus ``sent_id``.
            **interaction_args: extra keyword arguments for
                ``internal_interaction``.

        Returns:
            ClassifierOutputData: raw logits when ``use_multi`` is truthy,
            otherwise log-softmax over the last dimension.
        """
        input_tensor = hidden.slot_hidden
        seq_lens = hidden.inputs.attention_mask.sum(-1).detach().cpu().tolist()
        input_tensor = pack_sequence(input_tensor, seq_lens)

        forced_input = None
        if self.training:
            forced_input = self._teacher_forcing_labels(hidden, seq_lens)

        if self.embedding_dim is None or forced_input is not None:
            # No dependency on the model's own predictions: each sentence
            # can go through the LSTM in a single call.
            output_tensor_list = self._decode_whole_sentences(
                input_tensor, seq_lens, forced_input, internal_interaction, interaction_args)
        else:
            output_tensor_list = self._decode_step_by_step(
                input_tensor, seq_lens, internal_interaction, interaction_args)

        seq_unpacked = unpack_sequence(torch.cat(output_tensor_list, dim=0), seq_lens)
        # TODO: support softmax in every mode
        if self.config.get("use_multi"):
            pred_output = ClassifierOutputData(seq_unpacked)
        else:
            pred_output = ClassifierOutputData(F.log_softmax(seq_unpacked, dim=-1))
        return pred_output

    def _teacher_forcing_labels(self, hidden, seq_lens):
        """Return packed gold labels for teacher forcing, or None to skip it."""
        # A missing force_ratio used to raise TypeError in the comparison
        # below; it now simply means "never teacher-force".
        if self.force_ratio is None:
            return None
        if not random.random() < self.force_ratio:
            return None

        forced_input = None
        mode = self.config["mode"]
        if mode == "slot":
            forced_slot = pack_sequence(hidden.inputs.slot, seq_lens)
            filled = []
            for position, label in enumerate(forced_slot):
                # Ignored positions reuse the previous usable label. The
                # very first packed position is always kept as-is.
                if position > 0 and label == self.ignore_index:
                    filled.append(filled[-1])
                else:
                    filled.append(label.reshape(1))
            forced_input = torch.cat(filled, 0)
        elif mode == "token-level-intent":
            # Repeat the sentence intent along the slot sequence length.
            forced_intent = hidden.inputs.intent.unsqueeze(1).repeat(1, hidden.inputs.slot.shape[1])
            forced_input = pack_sequence(forced_intent, seq_lens)
        return forced_input

    def _decode_whole_sentences(self, input_tensor, seq_lens, forced_input,
                                internal_interaction, interaction_args):
        """Run the LSTM once per sentence and return per-sentence logits."""
        output_tensor_list = []
        sent_start_pos = 0
        for sent_i, sent_len in enumerate(seq_lens):
            sent_end_pos = sent_start_pos + sent_len
            # Segment input hidden tensors.
            seg_hiddens = input_tensor[sent_start_pos: sent_end_pos, :]

            if self.embedding_dim is not None and forced_input is not None:
                if sent_len > 1:
                    seg_forced_input = forced_input[sent_start_pos: sent_end_pos]
                    # Shift right by one: step t sees the gold label of
                    # step t-1, and step 0 sees the learned init tensor.
                    seg_forced_tensor = self.embedding_layer(seg_forced_input)[:-1]
                    seg_prev_tensor = torch.cat([self.init_tensor, seg_forced_tensor], dim=0)
                else:
                    seg_prev_tensor = self.init_tensor
                # Concatenate forced target tensor.
                combined_input = torch.cat([seg_hiddens, seg_prev_tensor], dim=1)
            else:
                combined_input = seg_hiddens

            dropout_input = self.dropout_layer(combined_input)
            lstm_out, _ = self.lstm_layer(dropout_input.view(1, sent_len, -1))
            if internal_interaction is not None:
                interaction_args["sent_id"] = sent_i
                lstm_out = internal_interaction(torch.transpose(lstm_out, 0, 1), **interaction_args)[:, 0]
            linear_out = self.linear_layer(lstm_out.view(sent_len, -1))
            output_tensor_list.append(linear_out)
            sent_start_pos = sent_end_pos
        return output_tensor_list

    def _decode_step_by_step(self, input_tensor, seq_lens,
                             internal_interaction, interaction_args):
        """Decode token by token, feeding back the top-1 prediction."""
        output_tensor_list = []
        sent_start_pos = 0
        for sent_i, sent_len in enumerate(seq_lens):
            prev_tensor = self.init_tensor
            # It's necessary to remember h and c state
            # when output prediction every single step.
            last_h, last_c = None, None

            sent_end_pos = sent_start_pos + sent_len
            for word_i in range(sent_start_pos, sent_end_pos):
                seg_input = input_tensor[[word_i], :]
                combined_input = torch.cat([seg_input, prev_tensor], dim=1)
                dropout_input = self.dropout_layer(combined_input).view(1, 1, -1)
                if last_h is None and last_c is None:
                    lstm_out, (last_h, last_c) = self.lstm_layer(dropout_input)
                else:
                    lstm_out, (last_h, last_c) = self.lstm_layer(dropout_input, (last_h, last_c))

                if internal_interaction is not None:
                    interaction_args["sent_id"] = sent_i
                    lstm_out = internal_interaction(lstm_out, **interaction_args)[:, 0]

                step_logits = self.linear_layer(lstm_out.view(1, -1))
                output_tensor_list.append(step_logits)

                # The embedded top-1 label is the next step's extra input.
                _, index = step_logits.topk(1, dim=1)
                prev_tensor = self.embedding_layer(index).view(1, -1)
            sent_start_pos = sent_end_pos
        return output_tensor_list
class MLPClassifier(BaseClassifier):
    """Classifier head assembled from a configurable list of layers."""

    def __init__(self, **config):
        """Build the sequential MLP described by ``config["mlp"]``.

        Args:
            config (dict):
                use_slot (bool): whether to classify slot label.
                use_intent (bool): whether to classify intent label.
                mlp (List): layer specs handed to ``instantiate``, e.g.
                    - _model_target_: torch.nn.Linear
                      in_features (int): input feature dim
                      out_features (int): output feature dim
                    - _model_target_: torch.nn.LeakyReLU
                      negative_slope: 0.2
                    - ...

        A string-valued ``in_features`` / ``out_features`` is treated as a
        placeholder: its first and last characters are stripped
        (presumably the braces of ``{key}`` -- confirm against the config
        files) and the remainder is looked up in ``config``. The resolved
        value is written back into the spec, so ``config["mlp"]`` is
        modified in place.
        """
        super(MLPClassifier, self).__init__(**config)
        self.config = config

        layer_specs = config["mlp"]
        for spec in layer_specs:
            for dim_field in ("in_features", "out_features"):
                raw_value = spec.get(dim_field)
                if isinstance(raw_value, str):
                    # Drop the surrounding delimiter characters to get the
                    # config key that holds the actual dimension.
                    spec[dim_field] = self.config[raw_value[1:-1]]

        layers = [instantiate(spec) for spec in layer_specs]
        self.seq = nn.Sequential(*layers)

    def forward(self, hidden: HiddenData):
        """Run the MLP on the intent or slot hidden state.

        Args:
            hidden (HiddenData): supplies ``intent_hidden`` when
                ``use_intent`` is truthy, otherwise ``slot_hidden``.

        Returns:
            ClassifierOutputData: wrapped output of the sequential MLP.
        """
        features = hidden.intent_hidden if self.config.get("use_intent") else hidden.slot_hidden
        return ClassifierOutputData(self.seq(features))