File size: 2,889 Bytes
9e4713f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
743cb9b
 
 
 
 
 
 
 
9e4713f
743cb9b
9e4713f
 
 
 
743cb9b
9e4713f
 
 
 
743cb9b
 
 
9e4713f
 
 
 
 
 
 
 
743cb9b
9e4713f
 
 
 
743cb9b
9e4713f
 
743cb9b
9e4713f
 
743cb9b
 
 
 
 
 
 
 
 
9e4713f
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
from collections import OrderedDict
from pathlib import Path

import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer

from shad_mlops_transformers.config import config

# example = ["Nader Jokhadar had given Syria the lead with a well-struck header in the seventh minute."]
# model_name = "bert-base-uncased"
# model_name = "Davlan/distilbert-base-multilingual-cased-ner-hrl"
# tokenizer = AutoTokenizer.from_pretrained(model_name)
# model = AutoModel.from_pretrained(model_name)
# nlp = pipeline("ner", model=model, tokenizer=tokenizer)
# toks = tokenizer(example, padding=True, truncation=True, return_tensors="pt")
# with torch.no_grad():
#     p = model(**toks)
# print(p)


class extract_tensor(nn.Module):
    """Adapter that forwards only the first item of a two-item layer output.

    Recurrent layers such as ``nn.LSTM`` return a pair rather than a single
    tensor, which ``nn.Sequential`` cannot feed to the next layer directly.
    This module sits after such a layer and drops the second item.
    """

    def forward(self, x):
        """Return the first element of the pair ``x``.

        Args:
            x: An iterable of exactly two items; the first must be a tensor
                with at least two dimensions.

        Returns:
            A full-range slice of the first item over its two leading axes,
            i.e. a tensor with the same shape and values (no reshaping occurs).
        """
        output, _state = x
        # Slicing every index of the first two axes leaves the shape unchanged.
        return output[:, :]


class DocumentClassifier(nn.Module):
    """Document classifier: a pretrained BERT encoder followed by a small trainable head.

    The encoder (``bert-base-uncased``) is run under ``torch.no_grad()``, so only
    the head stored in ``self.model`` receives gradients. The head is an LSTM whose
    hidden size equals the number of classes, followed by a softmax over classes.
    """

    def __init__(self, n_classes: int = 2, device: torch.device = torch.device("cpu")):
        """Build the tokenizer, encoder and classification head.

        Args:
            n_classes: Number of output classes (hidden size of the LSTM head).
            device: Device on which the module's weights and the tokenized
                inputs are placed.
        """
        super().__init__()
        self.model_name = "bert-base-uncased"
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        self.encoder = AutoModel.from_pretrained(self.model_name)
        self.device = device
        self.n_classes = n_classes
        self.model = nn.Sequential(
            OrderedDict(
                [
                    # ("fc", nn.Linear(in_features=self.encoder.pooler.dense.out_features, out_features=n_classes)),
                    # NOTE(review): forward() feeds this LSTM a 2-D pooled tensor. Recent PyTorch
                    # versions treat 2-D LSTM input as one unbatched sequence, which would make each
                    # document's prediction depend on the documents before it in the batch. Left as is
                    # so existing checkpoints stay loadable - confirm this is intended.
                    ("lstm", nn.LSTM(input_size=self.encoder.pooler.dense.out_features, hidden_size=n_classes)),
                    ("extract", extract_tensor()),
                    # Explicit class axis; for the 2-D tensor produced above this is the same axis
                    # the deprecated implicit-dim behaviour picked, minus the warning.
                    ("sm", nn.Softmax(dim=-1)),
                ]
            )
        )
        # Inputs are moved to `device` in ensure_device(), so the weights must live there too.
        self.to(device)
        # Materialised as a list: a bare `.parameters()` generator is empty after its first use.
        self.trainable_params = list(self.model.parameters())

    def forward(self, text):
        """Classify ``text`` and return the softmax output of the head.

        Args:
            text: Raw text input accepted by the tokenizer (a string or a list of strings).

        Returns:
            The output of ``self.model`` applied to the encoder's pooled output.
        """
        # Tokenization and encoding run without gradient tracking; the head runs outside
        # the no_grad block so that it stays trainable.
        with torch.no_grad():
            tok_info = self.tokenize(text)
            embeddings = self.encoder(**tok_info)["pooler_output"]
        return self.model(embeddings)

    def tokenize(self, x: "str | list[str]") -> dict:
        """Tokenize ``x`` with padding and truncation and move the tensors to ``self.device``."""
        return self.ensure_device(self.tokenizer(x, padding=True, truncation=True, return_tensors="pt"))

    def from_file(self, path: Path = config.weights_path) -> "DocumentClassifier":
        """Load a saved state dict from ``path`` into this module and return ``self``.

        Args:
            path: Location of the checkpoint; defaults to ``config.weights_path``.
        """
        # SECURITY: torch.load unpickles the file - only load checkpoints from trusted sources.
        # Tensors are mapped to CPU first; load_state_dict then copies them into the existing parameters.
        self.load_state_dict(torch.load(path, map_location=torch.device("cpu")))
        return self

    def ensure_device(self, tok_output):
        """Return a plain dict with every tensor of ``tok_output`` moved to ``self.device``.

        All entries produced by the tokenizer are forwarded, so tokenizers that do
        not emit ``token_type_ids`` no longer raise ``KeyError``.
        """
        return {name: tensor.to(self.device) for name, tensor in tok_output.items()}


if __name__ == "__main__":
    # Manual smoke run: build a 61-class classifier, restore its weights from the
    # default checkpoint path, and push a single sample sentence through it.
    classifier = DocumentClassifier(n_classes=61)
    classifier = classifier.from_file()
    sample_batch = ["This article describes machine learning"]
    classifier(sample_batch)