Spaces:
Sleeping
Sleeping
from collections import OrderedDict
from pathlib import Path
from typing import List, Optional, Union

import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer

from shad_mlops_transformers.config import config
# example = ["Nader Jokhadar had given Syria the lead with a well-struck header in the seventh minute."] | |
# model_name = "bert-base-uncased" | |
# model_name = "Davlan/distilbert-base-multilingual-cased-ner-hrl" | |
# tokenizer = AutoTokenizer.from_pretrained(model_name) | |
# model = AutoModel.from_pretrained(model_name) | |
# nlp = pipeline("ner", model=model, tokenizer=tokenizer) | |
# toks = tokenizer(example, padding=True, truncation=True, return_tensors="pt") | |
# with torch.no_grad(): | |
# p = model(**toks) | |
# print(p) | |
class extract_tensor(nn.Module):
    """Keep only the output tensor of an ``nn.LSTM`` call inside ``nn.Sequential``.

    ``nn.Sequential`` hands the whole ``(output, state)`` pair returned by the
    preceding ``nn.LSTM`` to the next layer; this module unpacks that pair,
    discards the state, and returns ``output`` unchanged (no reshape is done).
    """

    def forward(self, x):
        output, _state = x
        return output
class DocumentClassifier(nn.Module):
    """Classify documents with a frozen BERT encoder followed by an LSTM head.

    Text is tokenized with the ``bert-base-uncased`` tokenizer and encoded by the
    matching ``AutoModel`` without gradient tracking. The encoder's
    ``pooler_output`` goes through an LSTM whose hidden size is ``n_classes``,
    and then through a softmax over the class dimension.

    Args:
        n_classes: Number of output classes (used as the LSTM hidden size).
        device: Device holding the encoder, the head and the tokenized inputs.
    """

    def __init__(self, n_classes: int = 2, device: torch.device = torch.device("cpu")):
        super().__init__()
        self.model_name = "bert-base-uncased"
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        self.encoder = AutoModel.from_pretrained(self.model_name)
        self.device = device
        self.n_classes = n_classes
        self.model = nn.Sequential(
            OrderedDict(
                [
                    ("lstm", nn.LSTM(input_size=self.encoder.pooler.dense.out_features, hidden_size=n_classes)),
                    ("extract", extract_tensor()),
                    # Explicit dim: implicit-dim Softmax is deprecated and would pick
                    # dim 0 (the sequence axis) for the 3-D tensor built in forward().
                    ("sm", nn.Softmax(dim=-1)),
                ]
            )
        )
        # parameters() is a one-shot generator; a list can be iterated repeatedly
        # (the generator was empty after its first consumer, e.g. an optimizer).
        self.trainable_params = list(self.model.parameters())
        # ensure_device() moves the tokens to self.device, so the weights must live
        # there too, otherwise forward() fails with a device mismatch off-CPU.
        self.to(device)

    def forward(self, text):
        """Return class probabilities of shape (batch, n_classes).

        Args:
            text: A single string or a list of strings.
        """
        # NOTE(review): the encoder is frozen via no_grad but still follows
        # train()/eval(), so its dropout is active in train mode — confirm intended.
        with torch.no_grad():
            tok_info = self.tokenize(text)
            embeddings = self.encoder(**tok_info)["pooler_output"]
        # nn.LSTM reads a 2-D tensor as ONE unbatched sequence whose length is the
        # batch size, which let each document's prediction depend on the documents
        # before it. A length-1 sequence axis classifies every document independently
        # (results for a single document are unchanged).
        probs = self.model(embeddings.unsqueeze(0))
        return probs.squeeze(0)

    def tokenize(self, x: Union[str, List[str]]) -> dict:
        """Tokenize one text or a batch (padded, truncated) and move it to ``self.device``."""
        return self.ensure_device(self.tokenizer(x, padding=True, truncation=True, return_tensors="pt"))

    def from_file(self, path: Optional[Path] = None) -> "DocumentClassifier":
        """Load a saved state dict into this model and return ``self`` for chaining.

        Args:
            path: Checkpoint file. Defaults to ``config.weights_path``, resolved at
                call time rather than when the class is defined.
        """
        if path is None:
            path = config.weights_path
        # NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
        self.load_state_dict(torch.load(path, map_location=torch.device("cpu")))
        return self

    def ensure_device(self, tok_output) -> dict:
        """Return a dict with every tokenizer output moved to ``self.device``.

        Iterates over whatever keys the tokenizer produced, so tokenizers that do
        not emit ``token_type_ids`` no longer raise ``KeyError``.
        """
        return {key: value.to(self.device) for key, value in tok_output.items()}
if __name__ == "__main__":
    data = ["This article describes machine learning"]
    # eval() disables the encoder's dropout so the prediction is deterministic;
    # no_grad() avoids building an autograd graph for the trainable head.
    model = DocumentClassifier(n_classes=61).from_file().eval()
    with torch.no_grad():
        probs = model(data)
    # Report the predicted class index for each input document.
    print(probs.argmax(dim=-1).tolist())