|
## Usage
|
|
|
```python
|
from transformers import AutoModel, AutoTokenizer |
|
import torch |
|
from torch import nn |
|
|
|
|
|
class Classifier(nn.Module):
    """Three-layer MLP head: 1024 -> 4096 -> 512 -> 2 logits.

    LeakyReLU (slope 0.01) and dropout (p=0.1) follow each of the two hidden
    layers. The attribute names ``fc1``/``fc2``/``fc3`` are the keys a saved
    ``state_dict`` is loaded into, so they must not be renamed.
    """

    def __init__(self):
        super().__init__()
        # Creation order is kept fc1 -> fc2 -> fc3 so random init is unchanged.
        self.fc1 = nn.Linear(1024, 4096)
        self.fc2 = nn.Linear(4096, 512)
        self.fc3 = nn.Linear(512, 2)
        self.dropout = nn.Dropout(p=0.1)
        self.leaky_relu = nn.LeakyReLU(negative_slope=0.01)

    def forward(self, x):
        """Map features whose last dimension is 1024 to 2 logits."""
        hidden = x
        # Each hidden layer: linear -> LeakyReLU -> dropout.
        for layer in (self.fc1, self.fc2):
            hidden = self.dropout(self.leaky_relu(layer(hidden)))
        # Output layer: raw logits, no activation.
        return self.fc3(hidden)
|
|
|
|
|
class CombinedModel(nn.Module):
    """Hugging Face encoder followed by the ``Classifier`` head, fed raw text.

    Args:
        transformer_model_name: Model id or local path passed to both
            ``AutoModel.from_pretrained`` and ``AutoTokenizer.from_pretrained``.
        classifier_checkpoint_path: Path to a ``state_dict`` saved from a
            ``Classifier`` instance.
    """

    def __init__(self, transformer_model_name, classifier_checkpoint_path):
        super().__init__()
        self.transformer = AutoModel.from_pretrained(transformer_model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(transformer_model_name)
        self.classifier = Classifier()
        # Load onto CPU. load_state_dict copies the tensors into the freshly
        # built (CPU) Classifier parameters anyway, so mapping to 'mps'/'cuda'
        # only adds a transfer and crashes on machines lacking that backend.
        # To run on an accelerator, call `.to(device)` on the whole model.
        # NOTE(review): torch.load unpickles the file -- only load trusted
        # checkpoints (or pass weights_only=True where the torch version allows).
        classifier_checkpoint = torch.load(
            classifier_checkpoint_path, map_location=torch.device('cpu')
        )
        self.classifier.load_state_dict(classifier_checkpoint)

    def forward(self, text):
        """Tokenize ``text``, encode it, and return the classifier logits.

        Args:
            text: Input text (a string, or a list of strings for a batch);
                padding and truncation are enabled.

        Returns:
            Logits from ``Classifier`` computed on the encoder's
            ``pooler_output``, one row per input text.
        """
        encoded = self.tokenizer(
            text, return_tensors='pt', truncation=True, padding=True
        )
        # Put inputs on the same device as the weights so inference still
        # works after `model.to('cuda')` / `model.to('mps')`.
        encoded = encoded.to(self.transformer.device)
        transformer_outputs = self.transformer(**encoded)
        pooled_output = transformer_outputs.pooler_output
        return self.classifier(pooled_output)
|
|
|
|
|
# Build the encoder + classifier and switch to inference mode (eval() turns
# dropout off and returns the module itself).
# 'path/to/best_model.pt' is a placeholder: point it at your trained weights.
model = CombinedModel(
    'intfloat/multilingual-e5-large',
    'path/to/best_model.pt',
).eval()
|
|
|
def get_label(query, doc): |
|
text = f"Запрос: {query} /Документ: {doc}" |
|
logits = model(text) |
|
return torch.softmax(logits, dim=1).detach().numpy() |
|
```