# cross-encoder-ru

README.md — revision 53197c7, uploaded by Winst.

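## Usage

Download the classifier-head checkpoint and replace `path/to/best_model.pt` in the example below with its local path.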

from transformers import AutoModel, AutoTokenizer
import torch
from torch import nn


class Classifier(nn.Module):
    """MLP classification head applied to a 1024-dimensional pooled embedding.

    Layout: Linear(1024 -> 4096) -> LeakyReLU -> Dropout
            -> Linear(4096 -> 512) -> LeakyReLU -> Dropout
            -> Linear(512 -> 2)

    The attribute names ``fc1``/``fc2``/``fc3`` form the state_dict keys, so
    they must not change or previously saved checkpoints will stop loading.
    """

    def __init__(self):
        super().__init__()
        # Created in this exact order so seeded parameter initialisation and
        # state_dict key order stay the same.
        self.fc1 = nn.Linear(1024, 4096)
        self.fc2 = nn.Linear(4096, 512)
        self.fc3 = nn.Linear(512, 2)
        # Parameter-free modules shared by both hidden layers.
        self.dropout = nn.Dropout(p=0.1)
        self.leaky_relu = nn.LeakyReLU(negative_slope=0.01)

    def forward(self, x):
        """Map features of size 1024 in the last dimension to 2 logits."""
        hidden = x
        # Each hidden layer: linear -> activation -> dropout.
        for layer in (self.fc1, self.fc2):
            hidden = self.dropout(self.leaky_relu(layer(hidden)))
        # Output layer: raw logits, no activation.
        return self.fc3(hidden)


class CombinedModel(nn.Module):
    """Cross-encoder: a Hugging Face encoder followed by the ``Classifier`` head.

    Args:
        transformer_model_name: Hub id or local path passed to both
            ``AutoModel.from_pretrained`` and ``AutoTokenizer.from_pretrained``.
        classifier_checkpoint_path: path to a saved ``state_dict`` for
            ``Classifier``.
        map_location: device the checkpoint tensors are mapped to while
            loading. Defaults to ``'cpu'``, which works on every machine
            (a hard-coded ``'mps'`` fails without an Apple-silicon PyTorch
            build). Pass ``'cuda'`` or ``'mps'`` if desired.
    """

    def __init__(self, transformer_model_name, classifier_checkpoint_path, map_location='cpu'):
        super(CombinedModel, self).__init__()
        self.transformer = AutoModel.from_pretrained(transformer_model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(transformer_model_name)
        self.classifier = Classifier()
        # NOTE(review): torch.load unpickles the file -- only load trusted
        # checkpoints. On recent PyTorch consider weights_only=True.
        classifier_checkpoint = torch.load(classifier_checkpoint_path, map_location=torch.device(map_location))
        # load_state_dict copies values into the head's existing parameters,
        # so the head stays wherever it was constructed (CPU) until .to(...).
        self.classifier.load_state_dict(classifier_checkpoint)

    def forward(self, text):
        """Return classifier logits for ``text`` (a string or list of strings).

        The output has shape ``(batch, 2)``: one row per input text and the
        two logits produced by ``Classifier.fc3``.
        """
        encoded = self.tokenizer(text, return_tensors='pt', truncation=True, padding=True)
        # Tokenizer output is created on CPU. Move it to the encoder's device
        # so the model still works after model.to('cuda') / model.to('mps').
        encoded = encoded.to(self.transformer.device)
        transformer_outputs = self.transformer(**encoded)
        pooled_output = transformer_outputs.pooler_output
        return self.classifier(pooled_output)


# Build the cross-encoder and switch it to inference mode (disables dropout).
# Replace 'path/to/best_model.pt' with the real classifier checkpoint path.
model = CombinedModel('intfloat/multilingual-e5-large', 'path/to/best_model.pt').eval()

def get_label(query, doc):
    """Score a (query, document) pair with the module-level ``model``.

    Joins the pair into a single prompt string, runs one forward pass with
    gradient tracking disabled, and returns the softmax over the two logits
    as a NumPy array of shape ``(1, 2)``.

    NOTE(review): this snippet does not say which column means "relevant";
    confirm against the training labels.
    """
    # Prompt template; presumably matches the format used in training, so
    # keep it byte-for-byte unchanged.
    text = f"Запрос: {query} /Документ: {doc}"
    # Pure inference: skip building the autograd graph.
    with torch.no_grad():
        logits = model(text)
    # .cpu() is needed before .numpy() when the model runs on CUDA/MPS.
    return torch.softmax(logits, dim=1).cpu().numpy()