File size: 1,368 Bytes
550665c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 |
from typing import Optional, Tuple
import torch
from dataclasses import dataclass
from transformers.file_utils import ModelOutput
from transformers.modeling_outputs import QuestionAnsweringModelOutput
@dataclass
class QuestionAnsweringNaModelOutput(ModelOutput):
    """
    Output container for a question answering model that, in addition to the
    span start/end logits, carries a ``has_logits`` tensor.

    Every field defaults to ``None``, so a model may populate only the
    entries it actually computes. Field order is part of the interface
    (it determines positional/tuple ordering) and must not be changed.

    Args:
        loss (:obj:`torch.FloatTensor`, `optional`):
            Loss of the output.
        start_logits (:obj:`torch.FloatTensor`, `optional`):
            Span start logits.
        end_logits (:obj:`torch.FloatTensor`, `optional`):
            Span end logits.
        has_logits (:obj:`torch.FloatTensor`, `optional`):
            Additional "has" logits tensor.
            NOTE(review): presumably answerability ("has an answer") scores,
            judging by the class name -- confirm against the producing model.
        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`):
            Hidden states of the model at the output of each layer plus the
            initial embedding outputs.
        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`):
            Attentions weights after the attention softmax, used to compute
            the weighted average in the self-attention heads.
    """

    # All defaults are None, so every annotation is Optional[...].
    loss: Optional[torch.FloatTensor] = None
    start_logits: Optional[torch.FloatTensor] = None
    end_logits: Optional[torch.FloatTensor] = None
    has_logits: Optional[torch.FloatTensor] = None
    # Variable-length tuples: one tensor per layer, hence the trailing `...`.
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
|