cross-encoder-ru / README.md
Winst's picture
Upload 6 files
53197c7 verified
|
raw
history blame
1.74 kB
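## Usage

The snippet below loads the `intfloat/multilingual-e5-large` encoder together with the trained classifier head, then exposes `get_label(query, doc)`, which returns softmax probabilities over the two classes for a query–document pair.
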
```python
from transformers import AutoModel, AutoTokenizer
import torch
from torch import nn
class Classifier(nn.Module):
    """Feed-forward classification head: 1024 input features -> 2 logits.

    Layer stack::

        Linear(1024, 4096) -> LeakyReLU(0.01) -> Dropout(0.1)
        -> Linear(4096, 512) -> LeakyReLU(0.01) -> Dropout(0.1)
        -> Linear(512, 2)

    The attribute names ``fc1``/``fc2``/``fc3`` are also the state-dict key
    prefixes, so they must not be renamed if saved checkpoints are to load.
    """

    def __init__(self):
        """Create the three linear layers plus the shared activation/dropout."""
        super().__init__()
        self.fc1 = nn.Linear(1024, 4096)
        self.fc2 = nn.Linear(4096, 512)
        self.fc3 = nn.Linear(512, 2)
        # Stateless modules: one instance each is reused after fc1 and fc2.
        self.dropout = nn.Dropout(p=0.1)
        self.leaky_relu = nn.LeakyReLU(negative_slope=0.01)

    def forward(self, x):
        """Return raw (un-normalised) logits for ``x``.

        Args:
            x: Float tensor whose last dimension is 1024.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of 2. No softmax is applied here.
        """
        x = self.leaky_relu(self.fc1(x))
        x = self.dropout(x)
        x = self.leaky_relu(self.fc2(x))
        x = self.dropout(x)
        return self.fc3(x)
class CombinedModel(nn.Module):
    """Transformer encoder plus a ``Classifier`` head loaded from a checkpoint.

    ``forward`` takes raw text, tokenizes it, runs the encoder, and feeds the
    encoder's ``pooler_output`` to the classifier head.
    """

    def __init__(self, transformer_model_name, classifier_checkpoint_path,
                 map_location="cpu"):
        """Load the encoder, its tokenizer and the classifier weights.

        Args:
            transformer_model_name: Name or path passed to
                ``AutoModel.from_pretrained`` and
                ``AutoTokenizer.from_pretrained``.
            classifier_checkpoint_path: Path to a file holding the
                classifier's ``state_dict``.
            map_location: Forwarded to ``torch.load``. Defaults to ``"cpu"``
                so the checkpoint loads on any machine; the weights are then
                copied into the classifier, so this only controls where the
                temporary checkpoint tensors live. Move the whole model
                afterwards with ``model.to(device)`` if needed.
        """
        super().__init__()
        self.transformer = AutoModel.from_pretrained(transformer_model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(transformer_model_name)
        self.classifier = Classifier()
        # NOTE(security): torch.load unpickles the file; only load checkpoints
        # from a source you trust.
        classifier_checkpoint = torch.load(
            classifier_checkpoint_path, map_location=map_location
        )
        self.classifier.load_state_dict(classifier_checkpoint)

    def forward(self, text):
        """Return classifier logits for ``text``.

        Args:
            text: A string (or batch of strings) accepted by the tokenizer.

        Returns:
            The classifier's raw logits, one row per input text.
        """
        encoded = self.tokenizer(
            text, return_tensors='pt', truncation=True, padding=True
        )
        # The tokenizer always produces CPU tensors; move them to wherever the
        # model's weights live so the model also works after ``.to(device)``.
        device = next(self.parameters()).device
        encoded = {key: value.to(device) for key, value in encoded.items()}
        transformer_outputs = self.transformer(**encoded)
        pooled_output = transformer_outputs.pooler_output
        return self.classifier(pooled_output)
# Build the encoder + classifier once at import time and switch it to
# evaluation mode (``eval()`` returns the module itself, so it can be chained).
# Replace the checkpoint path with the location of the downloaded weights.
model = CombinedModel(
    transformer_model_name='intfloat/multilingual-e5-large',
    classifier_checkpoint_path='path/to/best_model.pt',
).eval()
def get_label(query, doc):
    """Score a query/document pair with the module-level ``model``.

    The pair is rendered into the single prompt string the model is fed
    (``"Запрос: <query> /Документ: <doc>"``), passed through ``model``, and
    the logits are turned into probabilities with a softmax over dim 1.

    Args:
        query: Search query text.
        doc: Document text to score against the query.

    Returns:
        A NumPy array of softmax probabilities, one row per input text and
        one column per class. Which column corresponds to which label is not
        visible from this snippet — confirm against the training setup.
    """
    text = f"Запрос: {query} /Документ: {doc}"
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        logits = model(text)
    # ``.cpu()`` makes the conversion work when the model lives on GPU/MPS;
    # it is a no-op for tensors already on the CPU.
    return torch.softmax(logits, dim=1).cpu().numpy()
```