Spaces:
Runtime error
Runtime error
File size: 2,378 Bytes
58627fa |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 |
import torch
import tqdm
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from colbert.infra.launcher import Launcher
from colbert.infra import Run, RunConfig
from colbert.modeling.reranker.electra import ElectraReranker
from colbert.utils.utils import flatten
DEFAULT_MODEL = 'cross-encoder/ms-marco-MiniLM-L-6-v2'
class Scorer:
def __init__(self, queries, collection, model=DEFAULT_MODEL, maxlen=180, bsize=256):
self.queries = queries
self.collection = collection
self.model = model
self.maxlen = maxlen
self.bsize = bsize
def launch(self, qids, pids):
launcher = Launcher(self._score_pairs_process, return_all=True)
outputs = launcher.launch(Run().config, qids, pids)
return flatten(outputs)
def _score_pairs_process(self, config, qids, pids):
assert len(qids) == len(pids), (len(qids), len(pids))
share = 1 + len(qids) // config.nranks
offset = config.rank * share
endpos = (1 + config.rank) * share
return self._score_pairs(qids[offset:endpos], pids[offset:endpos], show_progress=(config.rank < 1))
def _score_pairs(self, qids, pids, show_progress=False):
tokenizer = AutoTokenizer.from_pretrained(self.model)
model = AutoModelForSequenceClassification.from_pretrained(self.model).cuda()
assert len(qids) == len(pids), (len(qids), len(pids))
scores = []
model.eval()
with torch.inference_mode():
with torch.cuda.amp.autocast():
for offset in tqdm.tqdm(range(0, len(qids), self.bsize), disable=(not show_progress)):
endpos = offset + self.bsize
queries_ = [self.queries[qid] for qid in qids[offset:endpos]]
passages_ = [self.collection[pid] for pid in pids[offset:endpos]]
features = tokenizer(queries_, passages_, padding='longest', truncation=True,
return_tensors='pt', max_length=self.maxlen).to(model.device)
scores.append(model(**features).logits.flatten())
scores = torch.cat(scores)
scores = scores.tolist()
Run().print(f'Returning with {len(scores)} scores')
return scores
# LONG-TERM TODO: This can be sped up by sorting by length in advance.
|