Add script

run_flax_glue.py (ADDED, +526 -0)
#!/usr/bin/env python
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Finetuning a 🤗 Flax Transformers model for sequence classification on GLUE."""
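# Example invocation (illustrative only; the model name, task and output directory below are
# placeholders, not values required by this script):
#
#   python run_flax_glue.py \
#     --model_name_or_path bert-base-cased \
#     --task_name mrpc \
#     --max_length 128 \
#     --learning_rate 2e-5 \
#     --num_train_epochs 3 \
#     --per_device_train_batch_size 4 \
#     --output_dir /tmp/mrpc/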
import argparse
import logging
import os
import random
import time
from itertools import chain
from typing import Any, Callable, Dict, Tuple

import datasets
from datasets import load_dataset, load_metric

import jax
import jax.numpy as jnp
import optax
import transformers
from flax import struct, traverse_util
from flax.jax_utils import replicate, unreplicate
from flax.metrics import tensorboard
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard
from transformers import AutoConfig, AutoTokenizer, FlaxAutoModelForSequenceClassification, PretrainedConfig


logger = logging.getLogger(__name__)

Array = Any
Dataset = datasets.arrow_dataset.Dataset
PRNGKey = Any


task_to_keys = {
    "cola": ("sentence", None),
    "mnli": ("premise", "hypothesis"),
    "mrpc": ("sentence1", "sentence2"),
    "qnli": ("question", "sentence"),
    "qqp": ("question1", "question2"),
    "rte": ("sentence1", "sentence2"),
    "sst2": ("sentence", None),
    "swahili_news": ("text", None),
    "stsb": ("sentence1", "sentence2"),
    "wnli": ("sentence1", "sentence2"),
}

def parse_args():
    parser = argparse.ArgumentParser(description="Finetune a transformers model on a text classification task")
    parser.add_argument(
        "--task_name",
        type=str,
        default=None,
        help="The name of the glue task to train on.",
        choices=list(task_to_keys.keys()),
    )
    parser.add_argument(
        "--train_file", type=str, default=None, help="A csv or a json file containing the training data."
    )
    parser.add_argument(
        "--validation_file", type=str, default=None, help="A csv or a json file containing the validation data."
    )
    parser.add_argument(
        "--max_length",
        type=int,
        default=128,
        help=(
            "The maximum total input sequence length after tokenization. Sequences longer than this will be truncated,"
            " sequences shorter will be padded."
        ),
    )
    parser.add_argument(
        "--model_name_or_path",
        type=str,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
        required=True,
    )
    parser.add_argument(
        "--use_slow_tokenizer",
        action="store_true",
        help="If passed, will use a slow tokenizer (not backed by the 🤗 Tokenizers library).",
    )
    parser.add_argument(
        "--per_device_train_batch_size",
        type=int,
        default=8,
        help="Batch size (per device) for the training dataloader.",
    )
    parser.add_argument(
        "--per_device_eval_batch_size",
        type=int,
        default=8,
        help="Batch size (per device) for the evaluation dataloader.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=5e-5,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.")
    parser.add_argument("--num_train_epochs", type=int, default=3, help="Total number of training epochs to perform.")
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=None,
        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
    )
    parser.add_argument(
        "--num_warmup_steps", type=int, default=0, help="Number of steps for the warmup in the lr scheduler."
    )
    parser.add_argument("--output_dir", type=str, default=None, help="Where to store the final model.")
    parser.add_argument("--seed", type=int, default=3, help="A seed for reproducible training.")
    parser.add_argument(
        "--push_to_hub",
        action="store_true",
        help="If passed, model checkpoints and tensorboard logs will be pushed to the hub",
    )
    args = parser.parse_args()

    # Sanity checks
    if args.task_name is None and args.train_file is None and args.validation_file is None:
        raise ValueError("Need either a task name or a training/validation file.")
    else:
        if args.train_file is not None:
            extension = args.train_file.split(".")[-1]
            assert extension in ["csv", "json"], "`train_file` should be a csv or a json file."
        if args.validation_file is not None:
            extension = args.validation_file.split(".")[-1]
            assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."

    if args.output_dir is not None:
        os.makedirs(args.output_dir, exist_ok=True)

    return args

def create_train_state(
    model: FlaxAutoModelForSequenceClassification,
    learning_rate_fn: Callable[[int], float],
    is_regression: bool,
    num_labels: int,
    weight_decay: float,
) -> train_state.TrainState:
    """Create initial training state."""

    class TrainState(train_state.TrainState):
        """Train state with an Optax optimizer.

        The two functions below differ depending on whether the task is classification
        or regression.

        Args:
          logits_fn: Applied to last layer to obtain the logits.
          loss_fn: Function to compute the loss.
        """

        logits_fn: Callable = struct.field(pytree_node=False)
        loss_fn: Callable = struct.field(pytree_node=False)

    # We use Optax's "masking" functionality to not apply weight decay
    # to bias and LayerNorm scale parameters. decay_mask_fn returns a
    # mask boolean with the same structure as the parameters.
    # The mask is True for parameters that should be decayed.
    def decay_mask_fn(params):
        flat_params = traverse_util.flatten_dict(params)
        flat_mask = {path: (path[-1] != "bias" and path[-2:] != ("LayerNorm", "scale")) for path in flat_params}
        return traverse_util.unflatten_dict(flat_mask)
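    # Illustrative example (not executed): for a parameter tree like
    #   {"classifier": {"kernel": w, "bias": b}, "LayerNorm": {"scale": s, "bias": b2}}
    # decay_mask_fn returns
    #   {"classifier": {"kernel": True, "bias": False}, "LayerNorm": {"scale": False, "bias": False}},
    # i.e. only the kernel is weight-decayed by the AdamW optimizer below.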

    tx = optax.adamw(
        learning_rate=learning_rate_fn, b1=0.9, b2=0.999, eps=1e-6, weight_decay=weight_decay, mask=decay_mask_fn
    )

    if is_regression:

        def mse_loss(logits, labels):
            return jnp.mean((logits[..., 0] - labels) ** 2)

        return TrainState.create(
            apply_fn=model.__call__,
            params=model.params,
            tx=tx,
            logits_fn=lambda logits: logits[..., 0],
            loss_fn=mse_loss,
        )
    else:  # Classification.

        def cross_entropy_loss(logits, labels):
            xentropy = optax.softmax_cross_entropy(logits, onehot(labels, num_classes=num_labels))
            return jnp.mean(xentropy)

        return TrainState.create(
            apply_fn=model.__call__,
            params=model.params,
            tx=tx,
            logits_fn=lambda logits: logits.argmax(-1),
            loss_fn=cross_entropy_loss,
        )


def create_learning_rate_fn(
    train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
) -> Callable[[int], jnp.array]:
    """Returns a linear warmup, linear_decay learning rate function."""
    steps_per_epoch = train_ds_size // train_batch_size
    num_train_steps = steps_per_epoch * num_train_epochs
    warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
    decay_fn = optax.linear_schedule(
        init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
    )
    schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
    return schedule_fn
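# Worked example (illustrative numbers, not taken from this script): with 1000 total training
# steps, num_warmup_steps=100 and learning_rate=2e-5, schedule_fn returns 0.0 at step 0, rises
# linearly to 2e-5 at step 100, then decays linearly back to 0.0 at step 1000.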


def glue_train_data_collator(rng: PRNGKey, dataset: Dataset, batch_size: int):
    """Returns shuffled batches of size `batch_size` from truncated `train dataset`, sharded over all local devices."""
    steps_per_epoch = len(dataset) // batch_size
    perms = jax.random.permutation(rng, len(dataset))
    perms = perms[: steps_per_epoch * batch_size]  # Skip incomplete batch.
    perms = perms.reshape((steps_per_epoch, batch_size))

    for perm in perms:
        batch = dataset[perm]
        batch = {k: jnp.array(v) for k, v in batch.items()}
        batch = shard(batch)

        yield batch
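# Shape sketch (assuming, for example, 8 local devices and a total batch size of 32): `shard`
# reshapes every array in the batch from (32, ...) to (8, 4, ...), i.e. one leading entry per
# device, which is the layout `jax.pmap` expects.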


def glue_eval_data_collator(dataset: Dataset, batch_size: int):
    """Returns batches of size `batch_size` from `eval dataset`, sharded over all local devices."""
    for i in range(len(dataset) // batch_size):
        batch = dataset[i * batch_size : (i + 1) * batch_size]
        batch = {k: jnp.array(v) for k, v in batch.items()}
        batch = shard(batch)

        yield batch
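# Note: this collator drops the trailing partial batch (len(dataset) % batch_size examples);
# main() below evaluates those leftover examples separately on a single device so the reported
# metrics still cover the full validation set.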


def main():
    args = parse_args()

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    # Setup logging, we only want one process per machine to log things on the screen.
    logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
    if jax.process_index() == 0:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()

    # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below)
    # or specify a GLUE benchmark task (the dataset will be downloaded automatically from the datasets Hub).

    # For CSV/JSON files, this script will use as labels the column called 'label' and as pair of sentences the
    # sentences in columns called 'sentence1' and 'sentence2' if such columns exist, or the first two columns not
    # named label if at least two columns are provided.

    # If the CSVs/JSONs contain only one non-label column, the script does single sentence classification on this
    # single column. You can easily tweak this behavior (see below).

    # In distributed training, the load_dataset function guarantees that only one local process can concurrently
    # download the dataset.
    if args.task_name == "swahili_news":
        raw_datasets = load_dataset("swahili_news")
        valid_test_split = 10
        raw_datasets["validation"] = load_dataset(
            "swahili_news",
            split=f"train[:{valid_test_split}%]"
        )
        raw_datasets["train"] = load_dataset(
            "swahili_news",
            split=f"train[{valid_test_split}%:]"
        )
        print(f"train: {len(raw_datasets['train'])}, validation: {len(raw_datasets['validation'])},")
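        # Note: "swahili_news" is not a GLUE task. It is a single-sentence news classification
        # dataset handled specially here: the first 10% of its train split is held out as the
        # validation set, and the GLUE "sst2" metric (plain accuracy) is reused for evaluation below.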
    elif args.task_name is not None:
        # Downloading and loading a dataset from the hub.
        raw_datasets = load_dataset("glue", args.task_name)
    else:
        # Loading the dataset from local csv or json file.
        data_files = {}
        if args.train_file is not None:
            data_files["train"] = args.train_file
        if args.validation_file is not None:
            data_files["validation"] = args.validation_file
        extension = (args.train_file if args.train_file is not None else args.validation_file).split(".")[-1]
        raw_datasets = load_dataset(extension, data_files=data_files)
    # See more about loading any type of standard or custom dataset at
    # https://huggingface.co/docs/datasets/loading_datasets.html.

    # Labels
    if args.task_name is not None:
        is_regression = args.task_name == "stsb"
        if not is_regression:
            label_list = raw_datasets["train"].features["label"].names
            num_labels = len(label_list)
        else:
            num_labels = 1
    else:
        # Trying to have good defaults here, don't hesitate to tweak to your needs.
        is_regression = raw_datasets["train"].features["label"].dtype in ["float32", "float64"]
        if is_regression:
            num_labels = 1
        else:
            # A useful fast method:
            # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.unique
            label_list = raw_datasets["train"].unique("label")
            label_list.sort()  # Let's sort it for determinism
            num_labels = len(label_list)

    # Load pretrained model and tokenizer
    config = AutoConfig.from_pretrained(args.model_name_or_path, num_labels=num_labels, finetuning_task=args.task_name)
    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=not args.use_slow_tokenizer)
    model = FlaxAutoModelForSequenceClassification.from_pretrained(args.model_name_or_path, config=config)

    # Preprocessing the datasets
    if args.task_name is not None:
        sentence1_key, sentence2_key = task_to_keys[args.task_name]
    else:
        # Again, we try to have some nice defaults but don't hesitate to tweak to your use case.
        non_label_column_names = [name for name in raw_datasets["train"].column_names if name != "label"]
        if "sentence1" in non_label_column_names and "sentence2" in non_label_column_names:
            sentence1_key, sentence2_key = "sentence1", "sentence2"
        else:
            if len(non_label_column_names) >= 2:
                sentence1_key, sentence2_key = non_label_column_names[:2]
            else:
                sentence1_key, sentence2_key = non_label_column_names[0], None

    # Some models have set the order of the labels to use, so let's make sure we do use it.
    label_to_id = None
    if (
        model.config.label2id != PretrainedConfig(num_labels=num_labels).label2id
        and args.task_name is not None
        and not is_regression
    ):
        # Some have all caps in their config, some don't.
        label_name_to_id = {k.lower(): v for k, v in model.config.label2id.items()}
        if list(sorted(label_name_to_id.keys())) == list(sorted(label_list)):
            logger.info(
                f"The configuration of the model provided the following label correspondence: {label_name_to_id}. "
                "Using it!"
            )
            label_to_id = {i: label_name_to_id[label_list[i]] for i in range(num_labels)}
        else:
            logger.warning(
                "Your model seems to have been trained with labels, but they don't match the dataset: "
                f"model labels: {list(sorted(label_name_to_id.keys()))}, dataset labels: {list(sorted(label_list))}."
                "\nIgnoring the model labels as a result."
            )
    elif args.task_name is None:
        label_to_id = {v: i for i, v in enumerate(label_list)}

    def preprocess_function(examples):
        # Tokenize the texts
        texts = (
            (examples[sentence1_key],) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key])
        )
        result = tokenizer(*texts, padding="max_length", max_length=args.max_length, truncation=True)

        if "label" in examples:
            if label_to_id is not None:
                # Map labels to IDs (not necessary for GLUE tasks)
                result["labels"] = [label_to_id[l] for l in examples["label"]]
            else:
                # In all cases, rename the column to labels because the model will expect that.
                result["labels"] = examples["label"]
        return result

    processed_datasets = raw_datasets.map(
        preprocess_function, batched=True, remove_columns=raw_datasets["train"].column_names
    )
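    # After this map, each example contains only the tokenizer outputs (input_ids, attention_mask
    # and, for some models, token_type_ids) plus the "labels" field; the original text columns
    # were removed above.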

    train_dataset = processed_datasets["train"]
    eval_dataset = processed_datasets["validation_matched" if args.task_name == "mnli" else "validation"]

    # Log a few random samples from the training set:
    for index in random.sample(range(len(train_dataset)), 3):
        logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")

    # Define a summary writer
    summary_writer = tensorboard.SummaryWriter(args.output_dir)
    summary_writer.hparams(vars(args))

    def write_metric(train_metrics, eval_metrics, train_time, step):
        summary_writer.scalar("train_time", train_time, step)

        train_metrics = get_metrics(train_metrics)
        for key, vals in train_metrics.items():
            tag = f"train_{key}"
            for i, val in enumerate(vals):
                summary_writer.scalar(tag, val, step - len(vals) + i + 1)

        for metric_name, value in eval_metrics.items():
            summary_writer.scalar(f"eval_{metric_name}", value, step)

    num_epochs = int(args.num_train_epochs)
    rng = jax.random.PRNGKey(args.seed)
    dropout_rngs = jax.random.split(rng, jax.local_device_count())

    train_batch_size = args.per_device_train_batch_size * jax.local_device_count()
    eval_batch_size = args.per_device_eval_batch_size * jax.local_device_count()

    learning_rate_fn = create_learning_rate_fn(
        len(train_dataset), train_batch_size, args.num_train_epochs, args.num_warmup_steps, args.learning_rate
    )

    state = create_train_state(
        model, learning_rate_fn, is_regression, num_labels=num_labels, weight_decay=args.weight_decay
    )

    # define step functions
    def train_step(
        state: train_state.TrainState, batch: Dict[str, Array], dropout_rng: PRNGKey
    ) -> Tuple[train_state.TrainState, Dict[str, Array], PRNGKey]:
        """Trains model with an optimizer (both in `state`) on `batch`, returning the new state, train metrics and a fresh dropout PRNG key."""
        dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
        targets = batch.pop("labels")

        def loss_fn(params):
            logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
            loss = state.loss_fn(logits, targets)
            return loss

        grad_fn = jax.value_and_grad(loss_fn)
        loss, grad = grad_fn(state.params)
        grad = jax.lax.pmean(grad, "batch")
        new_state = state.apply_gradients(grads=grad)
        metrics = jax.lax.pmean({"loss": loss, "learning_rate": learning_rate_fn(state.step)}, axis_name="batch")
        return new_state, metrics, new_dropout_rng

    p_train_step = jax.pmap(train_step, axis_name="batch", donate_argnums=(0,))
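    # p_train_step runs train_step in parallel over the device axis: it expects the replicated
    # TrainState, a batch sharded by `shard` (leading axis = local device count) and one dropout
    # PRNG key per device; donate_argnums=(0,) lets XLA reuse the old state's buffers in place.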

    def eval_step(state, batch):
        logits = state.apply_fn(**batch, params=state.params, train=False)[0]
        return state.logits_fn(logits)

    p_eval_step = jax.pmap(eval_step, axis_name="batch")

    if args.task_name == "swahili_news":
        metric = load_metric("glue", "sst2")
    elif args.task_name is not None:
        metric = load_metric("glue", args.task_name)
    else:
        metric = load_metric("accuracy")
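    # For GLUE tasks, load_metric("glue", <task>) picks the task-specific metric (e.g. Matthews
    # correlation for cola, Pearson/Spearman correlation for stsb, accuracy/F1 for mrpc and qqp,
    # plain accuracy otherwise); swahili_news reuses the "sst2" configuration, which is plain accuracy.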

    logger.info(f"===== Starting training ({num_epochs} epochs) =====")
    train_time = 0

    # make sure weights are replicated on each device
    state = replicate(state)

    for epoch in range(1, num_epochs + 1):
        logger.info(f"Epoch {epoch}")
        logger.info("  Training...")

        train_start = time.time()
        train_metrics = []
        rng, input_rng = jax.random.split(rng)

        # train
        for batch in glue_train_data_collator(input_rng, train_dataset, train_batch_size):
            state, metrics, dropout_rngs = p_train_step(state, batch, dropout_rngs)
            train_metrics.append(metrics)
        train_time += time.time() - train_start
        logger.info(f"    Done! Training metrics: {unreplicate(metrics)}")

        logger.info("  Evaluating...")

        # evaluate
        for batch in glue_eval_data_collator(eval_dataset, eval_batch_size):
            labels = batch.pop("labels")
            predictions = p_eval_step(state, batch)
            metric.add_batch(predictions=chain(*predictions), references=chain(*labels))

        # evaluate also on leftover examples (not divisible by batch_size)
        num_leftover_samples = len(eval_dataset) % eval_batch_size

        # make sure leftover batch is evaluated on one device
        if num_leftover_samples > 0 and jax.process_index() == 0:
            # take leftover samples
            batch = eval_dataset[-num_leftover_samples:]
            batch = {k: jnp.array(v) for k, v in batch.items()}

            labels = batch.pop("labels")
            predictions = eval_step(unreplicate(state), batch)
            metric.add_batch(predictions=predictions, references=labels)

        eval_metric = metric.compute()
        logger.info(f"    Done! Eval metrics: {eval_metric}")

        cur_step = epoch * (len(train_dataset) // train_batch_size)
        write_metric(train_metrics, eval_metric, train_time, cur_step)

        # save checkpoint after each epoch and push checkpoint to the hub
        if jax.process_index() == 0:
            params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
            model.save_pretrained(
                args.output_dir,
                params=params,
                push_to_hub=args.push_to_hub,
                commit_message=f"Saving weights and logs of epoch {epoch}",
            )
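            # The directory written above is a standard 🤗 checkpoint; it can later be reloaded
            # with, for example, FlaxAutoModelForSequenceClassification.from_pretrained(args.output_dir).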


if __name__ == "__main__":
    main()