import numpy as np
import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer

from pyserini.encode import QueryEncoder

class SpladeQueryEncoder(QueryEncoder):
    def __init__(self, model_name_or_path, tokenizer_name=None, device='cpu'):
        self.device = device
        self.model = AutoModelForMaskedLM.from_pretrained(model_name_or_path)
        self.model.to(self.device)
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name or model_name_or_path)
        # Map token ids back to token strings so sparse vectors can be keyed by term.
        self.reverse_voc = {v: k for k, v in self.tokenizer.vocab.items()}

    def encode(self, text, max_length=256, **kwargs):
        inputs = self.tokenizer([text], max_length=max_length, padding='longest',
                                truncation=True, add_special_tokens=True,
                                return_tensors='pt').to(self.device)
        input_ids = inputs['input_ids']
        input_attention = inputs['attention_mask']
        batch_logits = self.model(input_ids)['logits']
        # SPLADE aggregation: weight_t = max_i log(1 + relu(logit_{i,t})).
        # Saturate the MLM logits, zero out padding positions via the attention
        # mask, then max-pool over the sequence dimension to obtain one weight
        # per vocabulary term.
        batch_aggregated_logits, _ = torch.max(torch.log(1 + torch.relu(batch_logits))
                                               * input_attention.unsqueeze(-1), dim=1)
        batch_aggregated_logits = batch_aggregated_logits.cpu().detach().numpy()
        return self._output_to_weight_dicts(batch_aggregated_logits)[0]

    def _output_to_weight_dicts(self, batch_aggregated_logits):
        # Convert each dense, vocabulary-sized vector into a sparse
        # {token: weight} dict, keeping only the nonzero entries.
        to_return = []
        for aggregated_logits in batch_aggregated_logits:
            col = np.nonzero(aggregated_logits)[0]
            weights = aggregated_logits[col]
            d = {self.reverse_voc[k]: float(v) for k, v in zip(list(col), list(weights))}
            to_return.append(d)
        return to_return
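

# Minimal usage sketch: encode one query and print its strongest expansion
# terms. The checkpoint name below is an assumption (a commonly used SPLADE
# model on the Hugging Face Hub); any SPLADE-style masked-LM checkpoint
# should work in its place.
if __name__ == '__main__':
    encoder = SpladeQueryEncoder('naver/splade-cocondenser-ensembledistil')
    vector = encoder.encode('what is a lobster roll?')
    # Sort the sparse term -> weight dict to inspect the top expansions.
    for term, weight in sorted(vector.items(), key=lambda kv: kv[1], reverse=True)[:10]:
        print(f'{term}\t{weight:.4f}')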