|
|
|
|
|
|
|
|
|
import dataclasses |
|
import pprint |
|
from functools import partial |
|
import os |
|
from tqdm import tqdm, trange |
|
import numpy as np |
|
import mlxu |
|
|
|
from flax.traverse_util import flatten_dict |
|
from lm_eval import evaluator, tasks |
|
from lm_eval.base import LM |
|
|
|
from EasyLM.serving import LMClient |
|
|
|
|
|
# Command-line flags and their defaults. `tasks` is a comma-separated list of
# lm-eval-harness task names (split on ',' in main), `shots` is the few-shot
# count handed to the evaluator, and the two nested entries are the default
# configs of the serving client and the W&B logger respectively.
_flag_defaults = dict(
    tasks='wsc,piqa,winogrande,openbookqa,logiqa',
    shots=0,
    lm_client=LMClient.get_default_config(),
    logger=mlxu.WandBLogger.get_default_config(),
)
FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(**_flag_defaults)
|
|
|
|
|
class LMEvalHarnessInterface(LM):
    """Adapter exposing an EasyLM ``LMClient`` through lm-eval-harness's ``LM`` API.

    Each request method unpacks the harness's batch of request tuples into
    parallel sequences, forwards them to the same-named method on the wrapped
    client in one call, and repackages the client's output for the harness.
    """

    def __init__(self, lm_client):
        """Wrap ``lm_client``.

        Args:
            lm_client: object exposing ``greedy_until``, ``loglikelihood`` and
                ``loglikelihood_rolling`` (an ``LMClient`` in this script) that
                serves the actual model requests.
        """
        # Run the base-class initializer so any state the harness's LM sets up
        # for itself is present (in lm-eval 0.3 this is ``cache_hook`` --
        # verify for other harness versions).
        super().__init__()
        self.lm_client = lm_client

    def greedy_until(self, inputs):
        """Greedily generate a continuation for each ``(prefix, until)`` request.

        Args:
            inputs: sequence of ``(prefix, until)`` pairs supplied by the harness.

        Returns:
            The result of ``lm_client.greedy_until`` for the whole batch
            (presumably one generated string per request), or ``[]`` when
            ``inputs`` is empty.
        """
        if not inputs:
            # zip(*[]) yields nothing, so the two-name unpacking below would raise.
            return []
        prefix, until = zip(*inputs)
        return self.lm_client.greedy_until(prefix, until)

    def loglikelihood_rolling(self, inputs):
        """Score each input with the client's rolling log-likelihood.

        Args:
            inputs: the harness's request batch, forwarded to the client unchanged.

        Returns:
            A list of ``(loglikelihood, is_greedy)`` pairs, one per request, built
            by zipping the two sequences the client returns; ``[]`` for no input.
        """
        if not inputs:
            return []
        loglikelihood, is_greedy = self.lm_client.loglikelihood_rolling(inputs)
        return list(zip(loglikelihood, is_greedy))

    def loglikelihood(self, inputs):
        """Score ``text`` as a continuation of ``prefix`` for each request.

        Args:
            inputs: sequence of ``(prefix, text)`` pairs supplied by the harness.

        Returns:
            A list of ``(loglikelihood, is_greedy)`` pairs, one per request, built
            by zipping the two sequences the client returns; ``[]`` for no input.
        """
        if not inputs:
            # zip(*[]) yields nothing, so the two-name unpacking below would raise.
            return []
        prefix, text = zip(*inputs)
        loglikelihood, is_greedy = self.lm_client.loglikelihood(prefix, text)
        return list(zip(loglikelihood, is_greedy))
|
|
|
|
|
def main(argv):
    """Evaluate the served model on the configured harness tasks and log results.

    Builds a W&B logger and an ``LMClient``-backed harness model from the
    command-line flags, runs ``evaluator.evaluate`` over ``--tasks`` with
    ``--shots`` few-shot examples, logs the flattened per-task metrics and
    pretty-prints the full result dict.

    Args:
        argv: positional command-line arguments passed in by ``mlxu.run``; unused.

    Raises:
        ValueError: if ``--tasks`` contains no task names.
    """
    logger = mlxu.WandBLogger(
        config=FLAGS.logger, variant=mlxu.get_user_flags(FLAGS, FLAGS_DEF)
    )
    model = LMEvalHarnessInterface(LMClient(FLAGS.lm_client))

    # Tolerate spaces and stray commas in --tasks (e.g. "wsc, piqa,") so they
    # don't turn into task names like " piqa" or "" that the harness rejects.
    task_list = [name.strip() for name in FLAGS.tasks.split(',') if name.strip()]
    if not task_list:
        raise ValueError('--tasks must name at least one lm-eval-harness task')

    results = evaluator.evaluate(
        model,
        tasks.get_task_dict(task_list),
        provide_description=False,
        num_fewshot=FLAGS.shots,
        limit=None,  # evaluate every document of each task
    )
    # Nested {task: {metric: value}} -> flat "task/metric" keys for the logger.
    logger.log(flatten_dict(results['results'], sep='/'))
    pprint.pprint(results)
|
|
|
|
|
if __name__ == '__main__':
    # Hand control to mlxu's runner, which invokes main(argv); presumably it
    # parses the flags defined above first (absl-style) -- confirm in mlxu.
    mlxu.run(main)
|
|