# fasttext-classification / fasttext_fsc.py
from __future__ import annotations

import torch

from transformers.modeling_outputs import SequenceClassifierOutput

from .fasttext_jp_embedding import FastTextJpConfig, FastTextJpModel


class FastTextForSeuqenceClassification(FastTextJpModel):
    """Classifies text based on FastText vectors."""

    def __init__(self, config: FastTextJpConfig):
        super().__init__(config)

    def forward(self, **inputs) -> SequenceClassifierOutput:
        """Scores a candidate label against a sentence.

        Expects `input_ids`, `attention_mask`, and `token_type_ids`, with the
        sentence encoded as segment 0 and the candidate label as segment 1.

        Returns:
            SequenceClassifierOutput: logits of shape (batch, 3), built as
                ``[log(p), -inf, log(1 - p)]`` from the cosine similarity
                ``p`` of the mean sentence and candidate-label vectors.
        """
        input_ids = inputs["input_ids"]
        outputs = self.word_embeddings(input_ids)
        # Keep only non-padding tokens (attention_mask == 1) and split them
        # into sentence tokens (token_type_ids == 0) and candidate-label
        # tokens (token_type_ids == 1).
        sentence = outputs[torch.logical_and(inputs["attention_mask"] == 1,
                                             inputs["token_type_ids"] == 0)]
        candidate_label = outputs[torch.logical_and(
            inputs["attention_mask"] == 1, inputs["token_type_ids"] == 1)]
        # Mean-pool each group of word vectors into a single vector.
        sentence_mean = torch.mean(sentence, dim=-2, keepdim=True)
        candidate_label_mean = torch.mean(candidate_label,
                                          dim=-2,
                                          keepdim=True)
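        # Note: boolean indexing with a (batch, words) mask flattens the
        # result to (num_selected, dim) across the whole batch, so the means
        # above are always 2-D and mix examples when batch > 1; in practice
        # this model is driven with batch size 1 and the first branch below
        # is the one that runs.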
        if sentence_mean.dim() == 2:
            p = torch.nn.functional.cosine_similarity(sentence_mean,
                                                      candidate_label_mean,
                                                      dim=1)
            logits = [[torch.log(p), -torch.inf, torch.log(1 - p)]]
        else:
            # Batched input: score each example separately.
            logits = []
            for sm, clm in zip(sentence_mean, candidate_label_mean):
                p = torch.nn.functional.cosine_similarity(sm, clm, dim=1)
                logits.append([torch.log(p), -torch.inf, torch.log(1 - p)])
        logits = torch.FloatTensor(logits)
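        # Worked example (illustrative numbers): with cosine similarity
        # p = 0.8 the row is [log 0.8, -inf, log 0.2]; a softmax over it
        # recovers [0.8, 0.0, 0.2], so p acts as the probability of the
        # first class and 1 - p of the last, with the middle class masked
        # out. Note this assumes p in (0, 1): a negative cosine similarity
        # would yield NaN from torch.log.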
        return SequenceClassifierOutput(
            loss=None,
            logits=logits,  # type: ignore
            hidden_states=None,
            attentions=None,
        )


# Registration with AutoModel is required, but the recommended way of doing
# this seems to keep changing and is not settled yet. (2022/11/6)
# https://huggingface.co/docs/transformers/custom_models#sending-the-code-to-the-hub
FastTextForSeuqenceClassification.register_for_auto_class("AutoModel")
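

# A minimal usage sketch, under two assumptions: the placeholder repo id
# below must be replaced with the actual Hub repo, and the paired tokenizer
# is assumed to emit token_type_ids marking the sentence as segment 0 and
# the candidate label as segment 1 (e.g. via tokenizer(sentence, label)).
if __name__ == "__main__":
    from transformers import AutoModel, AutoTokenizer

    repo_id = "<your-username>/fasttext-classification"  # hypothetical id
    tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
    model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)

    # Encode the sentence and one candidate label as a pair.
    inputs = tokenizer("こんにちは、世界。", "挨拶", return_tensors="pt")
    out = model(**inputs)
    print(out.logits)  # shape (1, 3): [log(p), -inf, log(1 - p)]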