|
from argparse import ArgumentParser, Namespace |
|
from typing import List, Optional |
|
|
|
from model_api import SionicEmbeddingModel |
|
from mteb import MTEB |
|
|
|
# English MTEB task names that the evaluation loop treats as retrieval tasks:
# for these (and for any task whose name contains 'CQADupstack', matched
# separately) the model is given the --instruction string; every other task
# runs with instruction=None.
RETRIEVAL_TASKS: List[str] = [
    'ArguAna', 'ClimateFEVER', 'DBPedia', 'FEVER', 'FiQA2018',
    'HotpotQA', 'MSMARCO', 'NFCorpus', 'NQ', 'QuoraRetrieval',
    'SCIDOCS', 'SciFact', 'Touche2020', 'TRECCOVID',
]
|
|
|
|
|
def get_arguments() -> Namespace:
    """Parse this script's command-line options from ``sys.argv``.

    Returns:
        Namespace with the attributes ``url`` (embedding API server URL),
        ``instruction`` (query instruction string), ``batch_size`` (int),
        ``dimension`` (int) and ``output_dir`` (results folder path).
    """
    # (flag, add_argument keyword options); options without a 'help' entry
    # are registered without help text.
    option_specs = (
        ('--url', dict(type=str, default='https://api.sionic.ai/v2/embedding', help='api server url')),
        ('--instruction', dict(type=str, default='query: ', help='query instruction')),
        ('--batch_size', dict(type=int, default=128)),
        ('--dimension', dict(type=int, default=3072)),
        ('--output_dir', dict(type=str, default='./result/v2')),
    )

    cli = ArgumentParser()
    for flag, options in option_specs:
        cli.add_argument(flag, **options)
    return cli.parse_args()
|
|
|
|
|
def main() -> None:
    """Run every English MTEB task against the Sionic embedding model.

    Reads the CLI options via ``get_arguments()``, builds a single
    ``SionicEmbeddingModel`` and evaluates each task with its own ``MTEB``
    instance, passing ``--output_dir`` as the output folder.
    """
    # Tasks that MTEB lists but this script deliberately does not evaluate.
    skipped_tasks = ('MSMARCOv2',)
    # Tasks evaluated on the 'dev' split; every other task uses 'test'.
    dev_split_tasks = ('MSMARCO',)

    args = get_arguments()

    model = SionicEmbeddingModel(
        url=args.url,
        instruction=args.instruction,
        batch_size=args.batch_size,
        dimension=args.dimension,
    )

    # Names of all tasks MTEB selects for task_langs=['en'].
    task_names: List[str] = [
        t.description['name'] for t in MTEB(task_types=None, task_langs=['en']).tasks
    ]

    for task in task_names:
        if task in skipped_tasks:
            continue

        # The query instruction is applied only to retrieval tasks (the listed
        # ones plus every CQADupstack sub-task); all other tasks get None.
        # The same model object is reused across tasks, so the instruction is
        # reassigned on every iteration.
        is_retrieval = ('CQADupstack' in task) or (task in RETRIEVAL_TASKS)
        instruction: Optional[str] = args.instruction if is_retrieval else None
        model.instruction = instruction

        split = 'dev' if task in dev_split_tasks else 'test'
        evaluation = MTEB(
            tasks=[task],
            task_langs=['en'],
            eval_splits=[split],
        )
        evaluation.run(model, output_folder=args.output_dir)


if __name__ == '__main__':
    main()
|
|