{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "ZF2ruR7koa27" }, "source": [ "# Compare the eval scores" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "id": "ktOHD34Yoa25" }, "outputs": [], "source": [ "import os\n", "import json\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "from matplotlib import rc\n", "%config InlineBackend.figure_formats = ['svg']\n", "plt.style.use('seaborn-v0_8-darkgrid')\n", "rc('font',**{'family':'serif','size':10})\n", "rc('text', usetex=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Scores" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "ALL_SCORES = {\n", " 'Llama-58M': {\n", " 'qa_congruence_easy': 0.5,\n", " 'island_effects': 0.5728699551569507,\n", " 'irregular_forms': 0.7414758269720102,\n", " 'npi_licensing': 0.5109322805952019,\n", " 'turn_taking': 0.6392857142857142,\n", " 'argument_structure': 0.7233268671193016,\n", " 'hypernym': 0.4872093023255814,\n", " 'qa_congruence_tricky': 0.32727272727272727,\n", " 'determiner_noun_agreement': 0.8780164412622646,\n", " 'binding': 0.711635500148412,\n", " 'control_raising': 0.6752098983650022,\n", " 'subject_verb_agreement': 0.7300813008130081,\n", " 'subject_aux_inversion': 0.7740912417662844,\n", " 'anaphor_agreement': 0.8701431492842536,\n", " 'filler_gap': 0.7093059446000622,\n", " 'quantifiers': 0.6421947449768161,\n", " 'ellipsis': 0.6732101616628176,\n", " 'zeroshot_average': 0.6568388856827299,\n", " },\n", " 'Llama-360M': {\n", " 'argument_structure': 0.73545101842871,\n", " 'turn_taking': 0.6857142857142857,\n", " 'determiner_noun_agreement': 0.8959162025987801,\n", " 'hypernym': 0.4941860465116279,\n", " 'control_raising': 0.6743261157755193,\n", " 'anaphor_agreement': 0.876278118609407,\n", " 'qa_congruence_easy': 0.53125,\n", " 'subject_verb_agreement': 0.6971996386630533,\n", " 'island_effects': 0.5044843049327354,\n", " 'npi_licensing': 0.5734892195566352,\n", " 'quantifiers': 0.5904173106646059,\n", " 'filler_gap': 0.7060379707438531,\n", " 'irregular_forms': 0.689058524173028,\n", " 'ellipsis': 0.684757505773672,\n", " 'subject_aux_inversion': 0.8433764332764089,\n", " 'qa_congruence_tricky': 0.41818181818181815,\n", " 'binding': 0.7209854556248145,\n", " 'zeroshot_average': 0.665947645248762,\n", " },\n", " 'GPT2-705M': {\n", " 'subject_aux_inversion': 0.8172725054891437,\n", " 'island_effects': 0.515695067264574,\n", " 'turn_taking': 0.6571428571428571,\n", " 'argument_structure': 0.7348448108632396,\n", " 'hypernym': 0.49186046511627907,\n", " 'subject_verb_agreement': 0.6751580849141825,\n", " 'irregular_forms': 0.8305343511450382,\n", " 'qa_congruence_easy': 0.5625,\n", " 'determiner_noun_agreement': 0.8741713073455317,\n", " 'quantifiers': 0.6983513652756311,\n", " 'npi_licensing': 0.5050106286061342,\n", " 'binding': 0.7151973879489463,\n", " 'ellipsis': 0.6991916859122402,\n", " 'qa_congruence_tricky': 0.45454545454545453,\n", " 'filler_gap': 0.7023031434796141,\n", " 'control_raising': 0.6840477242598321,\n", " 'anaphor_agreement': 0.8962167689161554,\n", " 'zeroshot_average': 0.6772966828367561,\n", " },\n", " 'Baby-Llama-58M-distilled': {\n", " 'irregular_forms': 0.9307888040712469,\n", " 'quantifiers': 0.732612055641422,\n", " 'binding': 0.7273671712674384,\n", " 'hypernym': 0.4930232558139535,\n", " 'qa_congruence_tricky': 0.41818181818181815,\n", " 'island_effects': 0.5123318385650224,\n", " 'filler_gap': 0.7183317771553066,\n", " 'determiner_noun_agreement': 0.9081145584725537,\n", " 'npi_licensing': 0.5646826601882782,\n", " 'anaphor_agreement': 0.8982617586912065,\n", " 'turn_taking': 0.6607142857142857,\n", " 'ellipsis': 0.733256351039261,\n", " 'argument_structure': 0.7314500484966052,\n", " 'qa_congruence_easy': 0.515625,\n", " 'subject_verb_agreement': 0.7542908762420958,\n", " 'subject_aux_inversion': 0.8850939253476457,\n", " 'control_raising': 0.6749889527176315,\n", " 'zeroshot_average': 0.6975950080944572,\n", " 'cola': 0.7046123743057251,\n", " 'sst2': 0.8720472455024719,\n", " 'mrpc': 0.8200000000000001,\n", " 'qqp': 0.8295147660946076,\n", " 'mnli': 0.7285888195037842,\n", " 'mnli-mm': 0.7369509935379028,\n", " 'qnli': 0.8114610910415649,\n", " 'rte': 0.6161616444587708,\n", " 'boolq': 0.6721991896629333,\n", " 'multirc': 0.5892661809921265,\n", " 'wsc': 0.6144578456878662,\n", " 'main_verb_control': 0.9990825653076172,\n", " 'control_raising_control': 0.9362576603889465,\n", " 'syntactic_category_control': 0.9431824684143066,\n", " 'lexical_content_the_control': 0.9777472615242004,\n", " 'relative_position_control': 0.9976232647895813,\n", " 'main_verb_lexical_content_the': 0.6731526255607605,\n", " 'main_verb_relative_token_position': 0.6802076697349548,\n", " 'syntactic_category_lexical_content_the': 0.9469174146652222,\n", " 'syntactic_category_relative_position': 0.7421950697898865,\n", " 'control_raising_lexical_content_the': 0.7405232191085815,\n", " 'control_raising_relative_token_position': 0.671150803565979,\n", " },\n", " 'Ensemble-of-teachers': {\n", " 'npi_licensing': 0.5508654722137868,\n", " 'determiner_noun_agreement': 0.8976398833200743,\n", " 'control_raising': 0.6767565178965974,\n", " 'turn_taking': 0.6642857142857143,\n", " 'subject_aux_inversion': 0.8453281288119053,\n", " 'quantifiers': 0.6460587326120556,\n", " 'island_effects': 0.507473841554559,\n", " 'irregular_forms': 0.833587786259542,\n", " 'qa_congruence_tricky': 0.4303030303030303,\n", " 'binding': 0.7221727515583259,\n", " 'argument_structure': 0.7526673132880698,\n", " 'qa_congruence_easy': 0.5,\n", " 'subject_verb_agreement': 0.7029810298102981,\n", " 'hypernym': 0.5011627906976744,\n", " 'anaphor_agreement': 0.8957055214723927,\n", " 'ellipsis': 0.7165127020785219,\n", " 'filler_gap': 0.7108621226268285,\n", " 'zeroshot_average': 0.6796684316934927,\n", " }, \n", " 'OPT-125M': {\n", " 'anaphor_agreement': 0.638,\n", " 'argument_structure': 0.706,\n", " 'binding': 0.671,\n", " 'control_raising': 0.665,\n", " 'determiner_noun_agreement': 0.785,\n", " 'ellipsis': 0.620,\n", " 'filler_gap': 0.638,\n", " 'irregular_forms': 0.675,\n", " 'island_effects': 0.486,\n", " 'npi_licensing': 0.467,\n", " 'quantifiers': 0.596,\n", " 'subject_verb_agreement': 0.569,\n", " 'hypernym': 0.500,\n", " 'qa_congruence_easy': 0.547,\n", " 'qa_congruence_tricky': 0.315,\n", " 'subject_aux_inversion': 0.803,\n", " 'turn_taking': 0.571,\n", " 'zeroshot_average': 0.6030588235294118,\n", " 'cola': 0.646,\n", " 'sst2': 0.819,\n", " 'mrpc': 0.725,\n", " 'qqp': 0.604,\n", " 'mnli': 0.576,\n", " 'mnli-mm': 0.600,\n", " 'qnli': 0.615,\n", " 'rte': 0.600,\n", " 'boolq': 0.633,\n", " 'multirc': 0.552,\n", " 'wsc': 0.602,\n", " 'control_raising_control': 0.864,\n", " 'lexical_content_the_control': 0.861,\n", " 'main_verb_control': 0.998,\n", " 'relative_position_control': 1.000,\n", " 'syntactic_category_control': 0.943,\n", " 'control_raising_lexical_content_the': 0.665,\n", " 'control_raising_relative_token_position': 0.670,\n", " 'main_verb_lexical_content_the': 0.665,\n", " 'main_verb_relative_token_position': 0.676,\n", " 'syntactic_category_lexical_content_the': 0.802,\n", " 'syntactic_category_relative_position': 0.675,\n", " },\n", " 'RoBERTa-base': {\n", " 'anaphor_agreement': 0.815,\n", " 'argument_structure': 0.671,\n", " 'binding': 0.673,\n", " 'control_raising': 0.679,\n", " 'determiner_noun_agreement': 0.908,\n", " 'ellipsis': 0.764,\n", " 'filler_gap': 0.635,\n", " 'irregular_forms': 0.874,\n", " 'island_effects': 0.399,\n", " 'npi_licensing': 0.559,\n", " 'quantifiers': 0.705,\n", " 'subject_verb_agreement': 0.654,\n", " 'hypernym': 0.494,\n", " 'qa_congruence_easy': 0.313,\n", " 'qa_congruence_tricky': 0.321,\n", " 'subject_aux_inversion': 0.717,\n", " 'turn_taking': 0.532,\n", " 'zeroshot_average': 0.6301764705882353,\n", " 'cola': 0.708,\n", " 'sst2': 0.870,\n", " 'mrpc': 0.792,\n", " 'qqp': 0.737,\n", " 'mnli': 0.732,\n", " 'mnli-mm': 0.740,\n", " 'qnli': 0.770,\n", " 'rte': 0.616,\n", " 'boolq': 0.663,\n", " 'multirc': 0.614,\n", " 'wsc': 0.614,\n", " 'control_raising_control': 0.841,\n", " 'lexical_content_the_control': 1.000,\n", " 'main_verb_control': 0.994,\n", " 'relative_position_control': 0.935,\n", " 'syntactic_category_control': 0.964,\n", " 'control_raising_lexical_content_the': 0.677,\n", " 'control_raising_relative_token_position': 0.686,\n", " 'main_verb_lexical_content_the': 0.667,\n", " 'main_verb_relative_token_position': 0.686,\n", " 'syntactic_category_lexical_content_the': 0.842,\n", " 'syntactic_category_relative_position': 0.657,\n", " },\n", " 'T5-base': {\n", " 'anaphor_agreement': 0.689,\n", " 'argument_structure': 0.638,\n", " 'binding': 0.604,\n", " 'control_raising': 0.609,\n", " 'determiner_noun_agreement': 0.722,\n", " 'ellipsis': 0.344,\n", " 'filler_gap': 0.482,\n", " 'irregular_forms': 0.776,\n", " 'island_effects': 0.456,\n", " 'npi_licensing': 0.478,\n", " 'quantifiers': 0.612,\n", " 'subject_verb_agreement': 0.650,\n", " 'hypernym': 0.480,\n", " 'qa_congruence_easy': 0.406,\n", " 'qa_congruence_tricky': 0.212,\n", " 'subject_aux_inversion': 0.649,\n", " 'turn_taking': 0.450,\n", " 'zeroshot_average': 0.5445294117647057,\n", " 'cola': 0.612,\n", " 'sst2': 0.781,\n", " 'mrpc': 0.805,\n", " 'qqp': 0.662,\n", " 'mnli': 0.480,\n", " 'mnli-mm': 0.503,\n", " 'qnli': 0.620,\n", " 'rte': 0.494,\n", " 'boolq': 0.660,\n", " 'multirc': 0.471,\n", " 'wsc': 0.614,\n", " 'control_raising_control': 0.784,\n", " 'lexical_content_the_control': 1.000,\n", " 'main_verb_control': 0.727,\n", " 'relative_position_control': 0.955,\n", " 'syntactic_category_control': 0.944,\n", " 'control_raising_lexical_content_the': 0.667,\n", " 'control_raising_relative_token_position': 0.697,\n", " 'main_verb_lexical_content_the': 0.666,\n", " 'main_verb_relative_token_position': 0.669,\n", " 'syntactic_category_lexical_content_the': 0.736,\n", " 'syntactic_category_relative_position': 0.678,\n", " },\n", "}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Plots" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "BLIMP_EVAL_NAMES = {\n", " 'anaphor_agreement': 'Anaphor Agr.',\n", " 'argument_structure': 'Arg. Structure',\n", " 'binding': 'Binding',\n", " 'control_raising': 'Control/Raising',\n", " 'determiner_noun_agreement': 'Det.-Noun Agr.',\n", " 'ellipsis': 'Ellipsis',\n", " 'filler_gap': 'Filler-Gap',\n", " 'irregular_forms': 'Irregular Forms',\n", " 'island_effects': 'Island Effects',\n", " 'npi_licensing': 'NPI Licensing',\n", " 'quantifiers': 'Quantifiers',\n", " 'subject_verb_agreement': 'Subj.-Verb Agr.',\n", "}\n", "\n", "BLIMP_SUPPLEMENTARY_EVAL_NAMES = {\n", " 'hypernym': 'Hypernym',\n", " 'qa_congruence_easy': 'QA Congruence (easy)',\n", " 'qa_congruence_tricky': 'QA Congruence (tricky)',\n", " 'subject_aux_inversion': 'Subj.-Aux. Inversion',\n", " 'turn_taking': 'Turn Taking',\n", "}\n", "\n", "SUPERGLUE_EVAL_NAMES = {\n", " 'cola': 'CoLA',\n", " 'sst2': 'SST-2',\n", " 'mrpc': r'MRPC ($F_1$)',\n", " 'qqp': r'QQP ($F_1$)',\n", " 'mnli': 'MNLI',\n", " 'mnli-mm': 'MNLI-mm',\n", " 'qnli': 'QNLI',\n", " 'rte': 'RTE',\n", " 'boolq': 'BoolQ',\n", " 'multirc': 'MultiRC',\n", " 'wsc': 'WSC',\n", "}\n", "\n", "MSGS_EVAL_NAMES = {\n", " 'control_raising_control': 'CR (Control)',\n", " 'lexical_content_the_control': 'LC (Control)',\n", " 'main_verb_control': 'MV (Control)',\n", " 'relative_position_control': 'RP (Control)',\n", " 'syntactic_category_control': 'SC (Control)',\n", " 'control_raising_lexical_content_the': r'CR\\_LC',\n", " 'control_raising_relative_token_position': r'CR\\_RTP',\n", " 'main_verb_lexical_content_the': r'MV\\_LC',\n", " 'main_verb_relative_token_position': r'MV\\_RTP',\n", " 'syntactic_category_lexical_content_the': r'SC\\_LC',\n", " 'syntactic_category_relative_position': r'SC\\_RP',\n", "}\n", "\n", "MODEL_STYLES_DISTIL = {\n", " 'Llama-58M': {\n", " 'label': 'LLaMA (58M)',\n", " 'sty': {\n", " 'color': 'tab:cyan',\n", " 'linestyle': ':',\n", " },\n", " },\n", " 'Llama-360M': {\n", " 'label': 'LLaMA (360M)',\n", " 'sty': {\n", " 'color': 'tab:orange',\n", " 'linestyle': '--',\n", " },\n", " },\n", " 'GPT2-705M': {\n", " 'label': 'GPT-2 (705M)',\n", " 'sty': {\n", " 'color': 'tab:purple',\n", " 'linestyle': '-.',\n", " },\n", " },\n", " 'Ensemble-of-teachers': {\n", " 'label': 'Ensemble of teachers',\n", " 'sty': {\n", " 'color': 'tab:brown',\n", " 'linestyle': (0, (3, 1, 1, 1, 1, 1)),\n", " },\n", " },\n", " 'Baby-Llama-58M-distilled': {\n", " 'label': 'Baby Llama',\n", " 'sty': {\n", " 'color': 'black',\n", " 'linestyle': '-',\n", " },\n", " },\n", "}\n", "\n", "MODEL_STYLES_BASELINE = {\n", " 'OPT-125M': {\n", " 'label': 'OPT (125M)',\n", " 'sty': {\n", " 'color': 'tab:blue',\n", " 'linestyle': ':',\n", " },\n", " },\n", " 'RoBERTa-base': {\n", " 'label': 'RoBERTa (base)',\n", " 'sty': {\n", " 'color': 'tab:red',\n", " 'linestyle': '--',\n", " },\n", " },\n", " 'T5-base': {\n", " 'label': 'T5 (base)',\n", " 'sty': {\n", " 'color': 'tab:green',\n", " 'linestyle': '-.',\n", " },\n", " },\n", " 'Baby-Llama-58M-distilled': {\n", " 'label': 'Baby Llama',\n", " 'sty': {\n", " 'color': 'black',\n", " 'linestyle': '-',\n", " },\n", " },\n", "}" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "id": "JQXvlgHKoa27" }, "outputs": [], "source": [ "def plot_scores(ax, all_scores, eval_names, model_styles, ylim, title, ylab=None, legend=False, legend_kw=dict()):\n", "\n", " model_names = list(all_scores.keys())\n", " x = range(len(eval_names))\n", "\n", " best_scores = dict()\n", " for eval in eval_names.keys():\n", " best_scores[eval] = float('-inf')\n", " for model in model_styles.keys():\n", " best_scores[eval] = max(best_scores[eval], all_scores[model].get(eval, float('-inf')))\n", "\n", " sorted_evals = [first for (first, second) in sorted(best_scores.items(), key=lambda x: x[1], reverse=True)]\n", " sorted_eval_names = [eval_names[eval] for eval in sorted_evals]\n", "\n", " for model_name in model_styles.keys():\n", " scores = [all_scores[model_name].get(test, None) for test in sorted_evals]\n", " props = model_styles[model_name]\n", " ax.plot(x, scores, marker='_', label=props['label'],\n", " markersize=9, markeredgewidth=1.5, linewidth=1, clip_on=False, **props['sty'])\n", " \n", " ax.set_xticks(x, sorted_eval_names, rotation=60, ha='right', rotation_mode='anchor')\n", " ax.set_xlim((-0.5, len(x)-0.5))\n", " ax.set_ylim(ylim)\n", " ax.grid(True, which='both', axis='y') # Adds major and minor gridlines\n", " if ylab is not None:\n", " ax.set_ylabel(ylab)\n", " if legend:\n", " ax.legend(**legend_kw)\n", " if title is not None:\n", " ax.set_title(title)\n", "\n", "def plot_average_score(ax, all_scores, model_styles, ylim, which_average='zeroshot_average', label='Zero-shot average', title=None):\n", " \n", " for (i, model_name) in enumerate(model_styles.keys()):\n", " props = model_styles[model_name]\n", " offset = i / (len(all_scores) - 1) - 0.5\n", " ax.plot([offset*0.5], [all_scores[model_name].get(which_average, None)], marker='_',\n", " markersize=9, markeredgewidth=1.5, linewidth=1, clip_on=False, **props['sty'])\n", " \n", " ax.set_ylim(ylim)\n", " ax.grid(True, which='both', axis='y') # Adds major and minor gridlines\n", " ax.set_xlim((-0.5, +0.5))\n", " ax.set_xticks([0], [label], rotation=60, ha='right', rotation_mode='anchor')\n", " if title is not None:\n", " ax.set_title(title)\n", "\n", "def plot_zeroshot_scores(all_scores, model_styles, figsize=(8,3), ylim=(None,None), title=None, legend_kw=dict()):\n", "\n", " fig, axs = plt.subplots(1, 3, width_ratios=[len(BLIMP_EVAL_NAMES), len(BLIMP_SUPPLEMENTARY_EVAL_NAMES), 1],\n", " sharey=True, figsize=figsize)\n", " fig.subplots_adjust(wspace=0.05)\n", "\n", " plot_scores(axs[0], all_scores, BLIMP_EVAL_NAMES, model_styles, ylim=ylim, title='BLiMP', legend=True, ylab='Accuracy', legend_kw=legend_kw)\n", " plot_scores(axs[1], all_scores, BLIMP_SUPPLEMENTARY_EVAL_NAMES, model_styles, ylim=ylim, title='BLiMP Suppl.')\n", " plot_average_score(axs[2], all_scores, model_styles, ylim=ylim, title='Avg.')\n", "\n", " if title is not None:\n", " fig.suptitle(title, y=1.05)\n", "\n", " return fig, axs\n", "\n", "def plot_fine_tuning_scores(all_scores, model_styles, figsize=(8,3), ylim=(None,None), title=None, legend_kw=dict()):\n", "\n", " fig, axs = plt.subplots(1, 2, width_ratios=[len(SUPERGLUE_EVAL_NAMES), len(MSGS_EVAL_NAMES)],\n", " sharey=True, figsize=figsize)\n", " fig.subplots_adjust(wspace=0.05)\n", "\n", " plot_scores(axs[0], all_scores, SUPERGLUE_EVAL_NAMES, model_styles, ylim=ylim, title='(Super)GLUE', legend=True, ylab=r'Accuracy (or $F_1$)', legend_kw=legend_kw)\n", " plot_scores(axs[1], all_scores, MSGS_EVAL_NAMES, model_styles, ylim=ylim, title='MSGS')\n", "\n", " if title is not None:\n", " fig.suptitle(title, y=1.05)\n", "\n", " return fig, axs\n", "\n", "def plot_SuperGLUE_scores(all_scores, model_styles, figsize=(3.4,3), ylim=(None,None), legend_kw=dict()):\n", " fig, ax = plt.subplots(figsize=figsize)\n", " plot_scores(ax, all_scores, SUPERGLUE_EVAL_NAMES, model_styles, ylim=ylim, title='(Super)GLUE', legend=True, ylab=r'Accuracy (or $F_1$)', legend_kw=legend_kw)\n", " return fig, ax\n", "\n", "def plot_MSGS_scores(all_scores, model_styles, figsize=(3.4,3), ylim=(None,None), legend_kw=dict()):\n", " fig, ax = plt.subplots(figsize=figsize)\n", " plot_scores(ax, all_scores, MSGS_EVAL_NAMES, model_styles, ylim=ylim, title='MSGS', legend=True, ylab='Accuracy', legend_kw=legend_kw)\n", " return fig, ax" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Zero-shot" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Comparison with non-distilled and teacher models" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-08-02T19:06:39.482553\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.7.1, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "fig, axs = plot_zeroshot_scores(ALL_SCORES, MODEL_STYLES_DISTIL, figsize=(7,2.6), ylim=(0.3, 1),\n", " title=r'Zero-shot performance vs.\\ non-distilled and teacher models', legend_kw={'loc': 'lower left', 'ncols': 2})\n", "h, l = axs[0].get_legend_handles_labels()\n", "h.insert(3, plt.plot([], [], color=(0, 0, 0, 0), label=\" \")[0])\n", "l.insert(3, '')\n", "axs[0].legend(h, l, **{'loc': 'lower left', 'ncols': 2})" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "fig.savefig('../plots/zeroshot-distil.pdf', bbox_inches='tight')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Comparison with baselines" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-08-02T19:06:39.990176\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.7.1, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "fig, _ = plot_zeroshot_scores(ALL_SCORES, MODEL_STYLES_BASELINE, figsize=(7,3.3), ylim=(0.15, 1),\n", " title=r'Zero-shot performance vs.\\ baselines', legend_kw={'loc': 'lower center', 'ncols': 2})" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "fig.savefig('../plots/zeroshot-vs-baselines.pdf', bbox_inches='tight')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Fine-tuning" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### All scores" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-08-02T19:06:40.351876\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.7.1, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "fig, _ = plot_fine_tuning_scores(ALL_SCORES, MODEL_STYLES_BASELINE, figsize=(7,3), ylim=(0.4, 1.),\n", " title=r'Fine-tuning performance vs.\\ baselines')" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "fig.savefig('../plots/finetuning-all.pdf', bbox_inches='tight')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### (Super)GLUE" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-08-02T19:06:40.650507\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.7.1, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "fig, _ = plot_SuperGLUE_scores(ALL_SCORES, MODEL_STYLES_BASELINE, ylim=(0.4, 1))" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "fig.savefig('../plots/superglue.pdf', bbox_inches='tight')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### MSGS" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-08-02T19:06:40.895636\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.7.1, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "fig, _ = plot_MSGS_scores(ALL_SCORES, MODEL_STYLES_BASELINE, ylim=(0.45, 1), figsize=(3.4,2.75))" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "fig.savefig('../plots/msgs.pdf', bbox_inches='tight')" ] } ], "metadata": { "colab": { "provenance": [] }, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.3" } }, "nbformat": 4, "nbformat_minor": 4 }