Spaces:
Runtime error
Runtime error
File size: 1,994 Bytes
37b9e99 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 |
'''
Author: Qiguang Chen
Date: 2023-01-11 10:39:26
LastEditors: Qiguang Chen
LastEditTime: 2023-01-26 17:18:22
Description: Root Model Module
'''
from torch import nn
from common.utils import OutputData, InputData
from model.decoder.base_decoder import BaseDecoder
from model.encoder.base_encoder import BaseEncoder
class OpenSLUModel(nn.Module):
    """Root model that chains an encoder and a decoder.

    ``forward`` feeds the encoder's output straight into the decoder; decoding
    and loss computation are delegated to the decoder.
    """

    def __init__(self, encoder: BaseEncoder, decoder: BaseDecoder, **config):
        """Create the model from an already-built encoder and decoder.

        Args:
            encoder (BaseEncoder): encoder created by config.
            decoder (BaseDecoder): decoder created by config.
            **config: any other keyword arguments; stored unchanged as a dict
                on ``self.config``.
        """
        # nn.Module.__init__ must run before submodules are assigned, otherwise
        # torch refuses the assignment and the submodules are not registered.
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.config = config

    def forward(self, inp: InputData) -> OutputData:
        """Run the encoder on ``inp``, then the decoder on the encoder's output.

        Args:
            inp (InputData): input ids and other information.

        Returns:
            OutputData: pred logits produced by the decoder.
        """
        return self.decoder(self.encoder(inp))

    def decode(self, output: OutputData, target: InputData = None):
        """Decode model output by delegating to ``self.decoder.decode``.

        Args:
            output (OutputData): pred logits data (presumably the result of
                ``forward``).
            target (InputData, optional): golden data. Defaults to None.

        Returns:
            Decoded ids, exactly as returned by ``self.decoder.decode``.
        """
        return self.decoder.decode(output, target)

    def compute_loss(self, pred: OutputData, target: InputData,
                     compute_intent_loss=True, compute_slot_loss=True):
        """Compute the training loss by delegating to ``self.decoder.compute_loss``.

        Args:
            pred (OutputData): pred logits data.
            target (InputData): golden data.
            compute_intent_loss (bool, optional): whether to compute intent
                loss. Defaults to True.
            compute_slot_loss (bool, optional): whether to compute slot loss.
                Defaults to True.

        Returns:
            Loss value, exactly as returned by ``self.decoder.compute_loss``.
        """
        return self.decoder.compute_loss(
            pred,
            target,
            compute_intent_loss=compute_intent_loss,
            compute_slot_loss=compute_slot_loss,
        )
|