Build error
Build error
File size: 5,337 Bytes
7f7285f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 |
#!/usr/bin/env python
import argparse
import itertools
import operator
import sys
from collections import OrderedDict
from run_eval import datetime_now, run_generate
from seq2seq_utils import ROUGE_KEYS
# A table of supported tasks and the list of scores in the order of importance to be sorted by.
# To add a new task, simply list the score names that `run_eval.run_generate()` returns
task_score_names = {
    "translation": ["bleu"],
    "summarization": ROUGE_KEYS,
}
def parse_search_arg(search):
    """
    Expand a ``--search`` string into the full matrix of argument combinations.

    Args:
        search: whitespace-separated ``name=v1:v2:...`` groups, e.g.
            ``"num_beams=5:10 length_penalty=0.8:1.0"``.

    Returns:
        A ``(matrix, entry_names)`` tuple. ``matrix`` is a list with one entry per
        combination (the cartesian product of all value lists); each entry is a list
        of ``"--name value"`` strings, one per hparam. ``entry_names`` lists the hparam
        names in the order they appeared in ``search``.

    Raises:
        ValueError: if a group contains no ``=``.
    """
    # split on the first "=" only, so that a value may itself contain "="
    entries = dict(group.split("=", 1) for group in search.split())
    entry_names = list(entries)
    sets = [[f"--{name} {value}" for value in values.split(":")] for name, values in entries.items()]
    matrix = [list(combo) for combo in itertools.product(*sets)]
    return matrix, entry_names
def run_search():
    """
    Run parametric search over the desired hparam space with help of ``run_eval.py``.

    All the arguments except ``--search`` are passed to ``run_eval.py`` as is. The values inside of
    "--search" are parsed, reformatted and fed to ``run_eval.py`` as additional args.

    The format for the ``--search`` value is a simple string with hparams and colon separated values
    to try, e.g.:

        --search "num_beams=5:10 length_penalty=0.8:1.0:1.2 early_stopping=true:false"

    which will generate ``12`` ``(2*3*2)`` searches for a product of each hparam. For example the
    example that was just used will invoke ``run_eval.py`` repeatedly with:

        --num_beams 5 --length_penalty 0.8 --early_stopping true
        --num_beams 5 --length_penalty 0.8 --early_stopping false
        [...]
        --num_beams 10 --length_penalty 1.2 --early_stopping false

    On completion, this function prints a markdown table of the results sorted by the task's scores
    (best first), followed by the winning arguments.

    Returns:
        The list of result rows sorted best first. Each row is an ``OrderedDict`` holding the
        task's scores first, followed by the hparams used for that run.
    """
    prog = sys.argv[0]

    parser = argparse.ArgumentParser(
        usage="\n\nImportant: this script accepts all arguments `run_eval.py` accepts and then a few extra, therefore refer to `run_eval.py -h` for the complete list."
    )
    parser.add_argument(
        "--search",
        type=str,
        required=False,
        help='param space to search, e.g. "num_beams=5:10 length_penalty=0.8:1.0:1.2"',
    )
    parser.add_argument(
        "--bs", type=int, default=8, required=False, help="initial batch size (may get reduced if it's too big)"
    )
    parser.add_argument("--task", type=str, help="used for task_specific_params + metrics")
    parser.add_argument(
        "--info",
        nargs="?",
        type=str,
        const=datetime_now(),
        help="add custom notes to be printed before the results table. If no value is passed, the current datetime string will be used.",
    )
    # everything this parser does not recognize is forwarded untouched to `run_generate()`
    args, args_main = parser.parse_known_args()
    # we share some of the args
    args_main.extend(["--task", args.task])
    args_normal = [prog] + args_main

    # to support variations like translation_en_to_de"
    task = "translation" if "translation" in args.task else "summarization"
    score_names = task_score_names[task]

    # an omitted --search degrades to a single run with no extra hparams
    matrix, col_names = parse_search_arg(args.search or "")
    col_names[0:0] = score_names  # score cols first
    col_widths = {col: len(str(col)) for col in col_names}
    results = []
    for combo in matrix:
        # each entry looks like "--name value"
        hparams = dict(x.replace("--", "").split() for x in combo)
        args_exp = " ".join(combo).split()
        args_exp.extend(["--bs", str(args.bs)])  # in case we need to reduce its size due to CUDA OOM
        # `run_generate()` takes its arguments from sys.argv
        sys.argv = args_normal + args_exp

        # XXX: need to trap CUDA OOM and lower args.bs if that happens and retry

        scores = run_generate(verbose=False)
        # make sure scores are first in the table
        result = OrderedDict((score, scores[score]) for score in score_names)
        result.update(hparams)
        results.append(result)

        # find widest entries
        for col, value in result.items():
            col_widths[col] = max(col_widths[col], len(str(value)))

    results_sorted = sorted(results, key=operator.itemgetter(*score_names), reverse=True)
    print(" | ".join([f"{col:{col_widths[col]}}" for col in col_names]))
    print(" | ".join([f"{'-'*col_widths[col]}" for col in col_names]))
    for row in results_sorted:
        print(" | ".join([f"{row[col]:{col_widths[col]}}" for col in col_names]))

    # Only the hparams of the best row are echoed back as arguments. The row itself is left
    # intact so that the returned results still carry their scores.
    best = results_sorted[0]
    best_args = [f"--{k} {v}" for k, v in best.items() if k not in score_names]
    dyn_args = ["--bs", str(args.bs)]
    if args.info:  # None when --info was not passed at all
        print(f"\nInfo: {args.info}")
    print("\nBest score args:")
    print(" ".join(args_main + best_args + dyn_args))

    return results_sorted
if __name__ == "__main__":
    # Usage:
    # [normal run_eval.py cmd plus] \
    # --search="num_beams=1:5:10 length_penalty=0.8:1:1.2 early_stopping=true:false"
    #
    # Example:
    # PYTHONPATH="src:examples/seq2seq" python examples/seq2seq/<this_script>.py $MODEL_NAME \
    # $DATA_DIR/val.source $SAVE_DIR/test_translations.txt --reference_path $DATA_DIR/val.target \
    # --score_path $SAVE_DIR/test_bleu.json --bs $BS --task translation \
    # --search="num_beams=1:5:10 length_penalty=0.8:1:1.2 early_stopping=true:false"
    run_search()