from collections import OrderedDict
from pathlib import Path
from typing import List, Union

import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer

from shad_mlops_transformers.config import config

# Scratch example kept for reference (not executed):
# example = ["Nader Jokhadar had given Syria the lead with a well-struck header in the seventh minute."]
# model_name = "bert-base-uncased"
# model_name = "Davlan/distilbert-base-multilingual-cased-ner-hrl"
# tokenizer = AutoTokenizer.from_pretrained(model_name)
# model = AutoModel.from_pretrained(model_name)
# nlp = pipeline("ner", model=model, tokenizer=tokenizer)
# toks = tokenizer(example, padding=True, truncation=True, return_tensors="pt")
# with torch.no_grad():
#     p = model(**toks)
# print(p)


class extract_tensor(nn.Module):
    """Adapter that keeps only the output tensor of a preceding ``nn.LSTM``.

    ``nn.LSTM`` returns an ``(output, state)`` pair, while ``nn.Sequential``
    forwards a single value to the next layer; this module drops the state so
    the LSTM can be chained with ordinary layers.
    """

    def forward(self, x):
        """Return the first element of the ``(output, state)`` pair ``x``."""
        output, _ = x
        # The full slice selects every element (no reshaping happens here);
        # it only requires the output to have at least two dimensions.
        return output[:, :]


class DocumentClassifier(nn.Module):
    """Document classifier: pretrained BERT encoder followed by a small head.

    The encoder is evaluated under ``torch.no_grad()``, so only the head
    (``self.model``) receives gradients.

    Note that this class moves only the tokenizer outputs to ``device``; it
    does not move its own weights. Callers using a non-CPU device have to
    call ``.to(device)`` on the module themselves.

    Args:
        n_classes: Number of target classes (size of the head output).
        device: Device the tokenized inputs are moved to before encoding.
    """

    def __init__(self, n_classes: int = 2, device: torch.device = torch.device("cpu")):
        super().__init__()
        self.model_name = "bert-base-uncased"
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        self.encoder = AutoModel.from_pretrained(self.model_name)
        self.device = device
        self.n_classes = n_classes
        self.model = nn.Sequential(
            OrderedDict(
                [
                    # ("fc", nn.Linear(in_features=self.encoder.pooler.dense.out_features, out_features=n_classes)),
                    # NOTE(review): the head is fed the 2-D pooler output; recent PyTorch
                    # presumably treats a 2-D LSTM input as one unbatched sequence whose
                    # steps are the documents of the batch - TODO confirm this is intended.
                    ("lstm", nn.LSTM(input_size=self.encoder.pooler.dense.out_features, hidden_size=n_classes)),
                    ("extract", extract_tensor()),
                    # Explicit dim: relying on the implicit dimension is deprecated and
                    # emits a warning. For the 2-D head output the implicit choice was
                    # the last dimension as well, so the probabilities are unchanged.
                    ("sm", nn.Softmax(dim=-1)),
                ]
            )
        )
        # NOTE: ``parameters()`` returns a generator, so this attribute can be
        # iterated only once (e.g. handed to a single optimizer).
        self.trainable_params = self.model.parameters()

    def forward(self, text: Union[str, List[str]]):
        """Return class probabilities for ``text``.

        Args:
            text: A document or a list of documents to classify.

        Returns:
            The softmax output of the head, one row per document.
        """
        # The encoder is used as a fixed feature extractor: no gradients are
        # recorded for tokenization and encoding.
        with torch.no_grad():
            tok_info = self.tokenize(text)
            embeddings = self.encoder(**tok_info)["pooler_output"]
        # The head runs outside ``no_grad`` so that its parameters
        # (``self.trainable_params``) can be optimized.
        return self.model(embeddings)

    def tokenize(self, x: Union[str, List[str]]) -> dict:
        """Tokenize ``x`` into padded, truncated PyTorch tensors on ``self.device``."""
        encoded = self.tokenizer(x, padding=True, truncation=True, return_tensors="pt")
        return self.ensure_device(encoded)

    def from_file(self, path: Path = config.weights_path) -> "DocumentClassifier":
        """Load a saved state dict from ``path`` into this module and return ``self``.

        The weights are mapped to CPU while being read and then copied into
        the existing parameters.
        """
        # NOTE(security): ``torch.load`` unpickles the file; only load weight
        # files that come from a trusted source.
        state_dict = torch.load(path, map_location=torch.device("cpu"))
        self.load_state_dict(state_dict)
        return self

    def ensure_device(self, tok_output):
        """Return a plain dict with every tokenizer output tensor on ``self.device``.

        All entries produced by the tokenizer are forwarded instead of a fixed
        set of keys, so tokenizers that emit no ``token_type_ids`` are
        supported as well.
        """
        return {name: tensor.to(self.device) for name, tensor in tok_output.items()}


if __name__ == "__main__":
    data = ["This article describes machine learning"]
    model = DocumentClassifier(n_classes=61).from_file()
    model(data)