'''
Author: Qiguang Chen
Date: 2023-01-11 10:39:26
LastEditors: Qiguang Chen
LastEditTime: 2023-01-31 20:07:00
Description: Classifier (decoder) modules for OpenSLU: linear, autoregressive LSTM, and MLP based decoders.
'''
import random
import torch
import torch.nn.functional as F
from torch import nn
from torch.nn import CrossEntropyLoss
from model.decoder import decoder_utils
from torchcrf import CRF
from common.utils import HiddenData, OutputData, InputData, ClassifierOutputData, unpack_sequence, pack_sequence, \
instantiate
class BaseClassifier(nn.Module):
"""Base class for all classifier module
"""
def __init__(self, **config):
super().__init__()
self.config = config
if config.get("loss_fn"):
self.loss_fn = instantiate(config.get("loss_fn"))
else:
self.loss_fn = CrossEntropyLoss(ignore_index=self.config.get("ignore_index", -100))
def forward(self, *args, **kwargs):
raise NotImplementedError("forward() must be implemented by classifier subclasses.")
def decode(self, output: OutputData,
target: InputData = None,
return_list=True,
return_sentence_level=None):
"""decode output logits
Args:
output (OutputData): output logits data
target (InputData, optional): input data with attention mask. Defaults to None.
return_list (bool, optional): if True return list else return torch Tensor.. Defaults to True.
return_sentence_level (_type_, optional): if True decode sentence level intent else decode token level intent. Defaults to None.
Returns:
List or Tensor: decoded sequence ids
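Example (a minimal sketch; variable names are illustrative):
    output = classifier(hidden)
    pred_ids = classifier.decode(output, target=inputs, return_list=True)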
"""
if self.config.get("return_sentence_level") is not None and return_sentence_level is None:
return_sentence_level = self.config.get("return_sentence_level")
elif self.config.get("return_sentence_level") is None and return_sentence_level is None:
return_sentence_level = False
return decoder_utils.decode(output, target,
return_list=return_list,
return_sentence_level=return_sentence_level,
pred_type=self.config.get("mode"),
use_multi=self.config.get("use_multi"),
multi_threshold=self.config.get("multi_threshold"))
def compute_loss(self, pred: OutputData, target: InputData):
"""compute loss
Args:
pred (OutputData): output logits data
target (InputData): input golden data
Returns:
Tensor: loss result
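Example (a minimal sketch; variable names are illustrative):
    pred = classifier(hidden)
    loss = classifier.compute_loss(pred, target=inputs)
    loss.backward()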
"""
_CRF = None
if self.config.get("use_crf"):
_CRF = self.CRF
return decoder_utils.compute_loss(pred, target, criterion_type=self.config["mode"],
use_crf=_CRF is not None,
ignore_index=self.config["ignore_index"],
use_multi=self.config.get("use_multi"),
loss_fn=self.loss_fn,
CRF=_CRF)
class LinearClassifier(BaseClassifier):
"""
Decoder structure based on a linear layer.
"""
def __init__(self, **config):
"""Construction function for LinearClassifier
Args:
config (dict):
input_dim (int): hidden state dim.
use_slot (bool): whether to classify slot labels.
slot_label_num (int, optional): the number of slot labels. Required if use_slot is True.
use_intent (bool): whether to classify intent labels.
intent_label_num (int, optional): the number of intent labels. Required if use_intent is True.
use_crf (bool): whether to use a CRF for slot decoding.
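Example (a minimal sketch; the values are illustrative, not project defaults):
    classifier = LinearClassifier(
        mode="intent",
        use_intent=True,
        intent_label_num=22,
        use_slot=False,
        input_dim=768,
        ignore_index=-100,
    )
    output = classifier(hidden)  # hidden: HiddenData from an encoder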
"""
super().__init__(**config)
self.config = config
if config.get("use_slot"):
self.slot_classifier = nn.Linear(config["input_dim"], config["slot_label_num"])
if self.config.get("use_crf"):
self.CRF = CRF(num_tags=config["slot_label_num"], batch_first=True)
if config.get("use_intent"):
self.intent_classifier = nn.Linear(config["input_dim"], config["intent_label_num"])
def forward(self, hidden: HiddenData):
if self.config.get("use_intent"):
return ClassifierOutputData(self.intent_classifier(hidden.get_intent_hidden_state()))
if self.config.get("use_slot"):
return ClassifierOutputData(self.slot_classifier(hidden.get_slot_hidden_state()))
class AutoregressiveLSTMClassifier(BaseClassifier):
"""
Decoder structure based on an autoregressive LSTM.
"""
def __init__(self, **config):
""" Construction function for Decoder.
Args:
config (dict):
input_dim (int): input dimension of Decoder. In fact, it's encoder hidden size.
use_slot (bool): whether to classify slot label.
slot_label_num (int, optional): the number of slot label. Enabled if use_slot is True.
use_intent (bool): whether to classify intent label.
intent_label_num (int, optional): the number of intent label. Enabled if use_intent is True.
use_crf (bool): whether to use crf for slot.
hidden_dim (int): hidden dimension of iterative LSTM.
embedding_dim (int): if it's not None, the input and output are relevant.
dropout_rate (float): dropout rate of network which is only useful for embedding.
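Example (a minimal sketch; the values are illustrative, not project defaults):
    decoder = AutoregressiveLSTMClassifier(
        mode="slot",
        use_slot=True,
        slot_label_num=72,
        input_dim=256,
        hidden_dim=256,
        embedding_dim=32,
        dropout_rate=0.4,
        force_ratio=0.9,
        bidirectional=False,
        layer_num=1,
        ignore_index=-100,
    )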
"""
super(AutoregressiveLSTMClassifier, self).__init__(**config)
if config.get("use_slot") and config.get("use_crf"):
self.CRF = CRF(num_tags=config["slot_label_num"], batch_first=True)
self.input_dim = config["input_dim"]
self.hidden_dim = config["hidden_dim"]
if config.get("use_intent"):
self.output_dim = config["intent_label_num"]
if config.get("use_slot"):
self.output_dim = config["slot_label_num"]
self.dropout_rate = config["dropout_rate"]
self.embedding_dim = config.get("embedding_dim")
self.force_ratio = config.get("force_ratio")
self.config = config
self.ignore_index = config.get("ignore_index") if config.get("ignore_index") is not None else -100
# If embedding_dim is not None, the previous prediction is embedded
# and fed back as part of the next input (autoregressive feedback).
if self.embedding_dim is not None:
self.embedding_layer = nn.Embedding(self.output_dim, self.embedding_dim)
self.init_tensor = nn.Parameter(
torch.randn(1, self.embedding_dim),
requires_grad=True
)
# Determine the input dimension of the iterative LSTM.
if self.embedding_dim is not None:
lstm_input_dim = self.input_dim + self.embedding_dim
else:
lstm_input_dim = self.input_dim
# Network parameter definition.
self.dropout_layer = nn.Dropout(self.dropout_rate)
self.lstm_layer = nn.LSTM(
input_size=lstm_input_dim,
hidden_size=self.hidden_dim,
batch_first=True,
bidirectional=self.config["bidirectional"],
dropout=self.dropout_rate,
num_layers=self.config["layer_num"]
)
self.linear_layer = nn.Linear(
self.hidden_dim,
self.output_dim
)
def forward(self, hidden: HiddenData, internal_interaction=None, **interaction_args):
""" Forward process for decoder.
:param internal_interaction:
:param hidden:
:return: is distribution of prediction labels.
"""
input_tensor = hidden.slot_hidden
seq_lens = hidden.inputs.attention_mask.sum(-1).detach().cpu().tolist()
output_tensor_list, sent_start_pos = [], 0
input_tensor = pack_sequence(input_tensor, seq_lens)
forced_input = None
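# Teacher forcing: during training, with probability force_ratio the gold
# labels (rather than the model's own predictions) are fed back as the
# autoregressive input for the next step.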
if self.training:
if random.random() < self.force_ratio:
if self.config["mode"]=="slot":
forced_slot = pack_sequence(hidden.inputs.slot, seq_lens)
temp_slot = []
for index, x in enumerate(forced_slot):
if index == 0:
temp_slot.append(x.reshape(1))
elif x == self.ignore_index:
temp_slot.append(temp_slot[-1])
else:
temp_slot.append(x.reshape(1))
forced_input = torch.cat(temp_slot, 0)
if self.config["mode"]=="token-level-intent":
forced_intent = hidden.inputs.intent.unsqueeze(1).repeat(1, hidden.inputs.slot.shape[1])
forced_input = pack_sequence(forced_intent, seq_lens)
if self.embedding_dim is None or forced_input is not None:
for sent_i in range(0, len(seq_lens)):
sent_end_pos = sent_start_pos + seq_lens[sent_i]
# Segment input hidden tensors.
seg_hiddens = input_tensor[sent_start_pos: sent_end_pos, :]
if self.embedding_dim is not None and forced_input is not None:
if seq_lens[sent_i] > 1:
seg_forced_input = forced_input[sent_start_pos: sent_end_pos]
seg_forced_tensor = self.embedding_layer(seg_forced_input)[:-1]
seg_prev_tensor = torch.cat([self.init_tensor, seg_forced_tensor], dim=0)
else:
seg_prev_tensor = self.init_tensor
# Concatenate forced target tensor.
combined_input = torch.cat([seg_hiddens, seg_prev_tensor], dim=1)
else:
combined_input = seg_hiddens
dropout_input = self.dropout_layer(combined_input)
lstm_out, _ = self.lstm_layer(dropout_input.view(1, seq_lens[sent_i], -1))
if internal_interaction is not None:
interaction_args["sent_id"] = sent_i
lstm_out = internal_interaction(torch.transpose(lstm_out, 0, 1), **interaction_args)[:, 0]
linear_out = self.linear_layer(lstm_out.view(seq_lens[sent_i], -1))
output_tensor_list.append(linear_out)
sent_start_pos = sent_end_pos
else:
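# Autoregressive decoding without teacher forcing: run the LSTM one token
# at a time, embedding the previous top-1 prediction as part of the next
# input.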
for sent_i in range(0, len(seq_lens)):
prev_tensor = self.init_tensor
# It's necessary to remember h and c state
# when output prediction every single step.
last_h, last_c = None, None
sent_end_pos = sent_start_pos + seq_lens[sent_i]
for word_i in range(sent_start_pos, sent_end_pos):
seg_input = input_tensor[[word_i], :]
combined_input = torch.cat([seg_input, prev_tensor], dim=1)
dropout_input = self.dropout_layer(combined_input).view(1, 1, -1)
if last_h is None and last_c is None:
lstm_out, (last_h, last_c) = self.lstm_layer(dropout_input)
else:
lstm_out, (last_h, last_c) = self.lstm_layer(dropout_input, (last_h, last_c))
if internal_interaction is not None:
interaction_args["sent_id"] = sent_i
lstm_out = internal_interaction(lstm_out, **interaction_args)[:, 0]
lstm_out = self.linear_layer(lstm_out.view(1, -1))
output_tensor_list.append(lstm_out)
_, index = lstm_out.topk(1, dim=1)
prev_tensor = self.embedding_layer(index).view(1, -1)
sent_start_pos = sent_end_pos
seq_unpacked = unpack_sequence(torch.cat(output_tensor_list, dim=0), seq_lens)
# TODO: support softmax in all cases
if self.config.get("use_multi"):
pred_output = ClassifierOutputData(seq_unpacked)
else:
pred_output = ClassifierOutputData(F.log_softmax(seq_unpacked, dim=-1))
return pred_output
class MLPClassifier(BaseClassifier):
"""
Decoder structure based on an MLP.
"""
def __init__(self, **config):
""" Construction function for Decoder.
Args:
config (dict):
use_slot (bool): whether to classify slot label.
use_intent (bool): whether to classify intent label.
mlp (List):
- _model_target_: torch.nn.Linear
in_features (int): input feature dim
out_features (int): output feature dim
- _model_target_: torch.nn.LeakyReLU
negative_slope: 0.2
- ...
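Example (a minimal sketch; string values such as "{intent_label_num}" are
resolved against the top-level config at construction time):
    classifier = MLPClassifier(
        mode="intent",
        use_intent=True,
        intent_label_num=22,
        input_dim=768,
        ignore_index=-100,
        mlp=[
            {"_model_target_": "torch.nn.Linear",
             "in_features": "{input_dim}", "out_features": 256},
            {"_model_target_": "torch.nn.LeakyReLU", "negative_slope": 0.2},
            {"_model_target_": "torch.nn.Linear",
             "in_features": 256, "out_features": "{intent_label_num}"},
        ],
    )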
"""
super(MLPClassifier, self).__init__(**config)
self.config = config
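# MLP layer specs may reference top-level config keys as strings such as
# "{input_dim}"; the braces are stripped ([1:-1]) and the value is looked
# up in the config.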
for i, x in enumerate(config["mlp"]):
if isinstance(x.get("in_features"), str):
config["mlp"][i]["in_features"] = self.config[x["in_features"][1:-1]]
if isinstance(x.get("out_features"), str):
config["mlp"][i]["out_features"] = self.config[x["out_features"][1:-1]]
mlp = [instantiate(x) for x in config["mlp"]]
self.seq = nn.Sequential(*mlp)
def forward(self, hidden: HiddenData):
if self.config.get("use_intent"):
res = self.seq(hidden.intent_hidden)
else:
res = self.seq(hidden.slot_hidden)
return ClassifierOutputData(res)