from __future__ import annotations
import torch
from .fasttext_jp_embedding import FastTextJpModel, FastTextJpConfig
from transformers.modeling_outputs import SequenceClassifierOutput


class FastTextForSeuqenceClassification(FastTextJpModel):
    """Performs classification based on FastText vectors.
    """

    def __init__(self, config: FastTextJpConfig):
        super().__init__(config)

    def forward(self, **inputs) -> SequenceClassifierOutput:
        """Embeds the inputs and scores the candidate label against the sentence.

        Returns:
            SequenceClassifierOutput: logits laid out NLI-style as
                [entailment, neutral, contradiction] for each example.
        """
        input_ids = inputs["input_ids"]
        outputs = self.word_embeddings(input_ids)
        # Boolean-mask indexing flattens the batch: both selections below have
        # shape [num_selected_tokens, embedding_dim].
        sentence = outputs[torch.logical_and(inputs["attention_mask"] == 1,
                                             inputs["token_type_ids"] == 0)]
        candidate_label = outputs[torch.logical_and(
            inputs["attention_mask"] == 1, inputs["token_type_ids"] == 1)]

        # Mean-pool the token vectors of the sentence and of the candidate label.
        sentence_mean = torch.mean(sentence, dim=-2, keepdim=True)
        candidate_label_mean = torch.mean(candidate_label,
                                          dim=-2,
                                          keepdim=True)
        eps = 1e-7
        if sentence_mean.dim() == 2:
            p = torch.nn.functional.cosine_similarity(sentence_mean,
                                                      candidate_label_mean,
                                                      dim=1)
            # Cosine similarity lies in [-1, 1]; clamp into (0, 1) so the
            # logarithms below stay finite.
            p = torch.clamp(p, min=eps, max=1 - eps)
            # NLI-style logits [entailment, neutral, contradiction]; the
            # neutral class is fixed at -inf so it is never predicted.
            logits = torch.stack(
                [torch.log(p), torch.full_like(p, -torch.inf),
                 torch.log(1 - p)], dim=1)
        else:
            rows = []
            # batch
            for sm, clm in zip(sentence_mean, candidate_label_mean):
                p = torch.nn.functional.cosine_similarity(sm, clm, dim=1)
                p = torch.clamp(p, min=eps, max=1 - eps)
                rows.append(torch.stack(
                    [torch.log(p), torch.full_like(p, -torch.inf),
                     torch.log(1 - p)], dim=1))
            logits = torch.cat(rows, dim=0)
        return SequenceClassifierOutput(
            loss=None,
            logits=logits,  # type: ignore
            hidden_states=None,
            attentions=None,
        )


# Registration with AutoModel is required, but the recommended way of doing it
# seems to keep changing and is not yet settled. (2022/11/6)
# https://huggingface.co/docs/transformers/custom_models#sending-the-code-to-the-hub
FastTextForSeuqenceClassification.register_for_auto_class("AutoModel")
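
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the model): one way the
# registered class could be loaded through the Auto API once the repository is
# on the Hub. "<repo-id>" is a placeholder, not the real repository name, and
# loading custom classes requires trust_remote_code=True.
# ---------------------------------------------------------------------------
# from transformers import AutoModel, AutoTokenizer
#
# tokenizer = AutoTokenizer.from_pretrained("<repo-id>", trust_remote_code=True)
# model = AutoModel.from_pretrained("<repo-id>", trust_remote_code=True)
# # The sentence and the candidate label are passed as a pair, so
# # token_type_ids distinguish them (0 = sentence, 1 = candidate label).
# enc = tokenizer("今日はいい天気です。", "天気", return_tensors="pt")
# out = model(**enc)
# print(out.logits)  # shape [1, 3]: [entailment, neutral, contradiction]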