"
+ ]
+ },
+ "execution_count": 22,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "import IPython.display as ipd\n",
+ "import numpy as np\n",
+ "import random\n",
+ "\n",
+ "rand_int = random.randint(0, len(common_voice_train)-1)\n",
+ "\n",
+ "print(\"Target text:\", common_voice_train[rand_int][\"sentence\"])\n",
+ "print(\"Input array shape:\", common_voice_train[rand_int][\"audio\"][\"array\"].shape)\n",
+ "print(\"Sampling rate:\", common_voice_train[rand_int][\"audio\"][\"sampling_rate\"])\n",
+ "ipd.Audio(data=common_voice_train[rand_int][\"audio\"][\"array\"], autoplay=True, rate=16000)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 23,
+ "id": "54926718",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# This does not prepare the input for the Transformer model.\n",
+ "# This will resample the data and convert the sentence into indices\n",
+ "# Batch here is just for one entry (row)\n",
+ "def prepare_dataset(batch):\n",
+ " audio = batch[\"audio\"]\n",
+ " \n",
+ " # batched output is \"un-batched\"\n",
+ " batch[\"input_values\"] = processor(audio[\"array\"], sampling_rate=audio[\"sampling_rate\"]).input_values[0]\n",
+ " batch[\"input_length\"] = len(batch[\"input_values\"])\n",
+ " \n",
+ " with processor.as_target_processor():\n",
+ " batch[\"labels\"] = processor(batch[\"sentence\"]).input_ids\n",
+ " return batch"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 24,
+ "id": "0a348aa0",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Loading cached processed dataset at /workspace/.cache/huggingface/datasets/common_voice/tr/6.1.0/5693bfc0feeade582a78c2fb250bc88f52bd86f0a7f1bb22bfee67e715de30fd/cache-e3ff506f96ec6817.arrow\n",
+ "Loading cached processed dataset at /workspace/.cache/huggingface/datasets/common_voice/tr/6.1.0/5693bfc0feeade582a78c2fb250bc88f52bd86f0a7f1bb22bfee67e715de30fd/cache-00a0dacd1c387ee8.arrow\n",
+ "Loading cached processed dataset at /workspace/.cache/huggingface/datasets/common_voice/tr/6.1.0/5693bfc0feeade582a78c2fb250bc88f52bd86f0a7f1bb22bfee67e715de30fd/cache-89839f1a29958c06.arrow\n",
+ "Loading cached processed dataset at /workspace/.cache/huggingface/datasets/common_voice/tr/6.1.0/5693bfc0feeade582a78c2fb250bc88f52bd86f0a7f1bb22bfee67e715de30fd/cache-ea97d53e6e03248b.arrow\n",
+ "Loading cached processed dataset at /workspace/.cache/huggingface/datasets/common_voice/tr/6.1.0/5693bfc0feeade582a78c2fb250bc88f52bd86f0a7f1bb22bfee67e715de30fd/cache-74c31e1ede89718b.arrow\n",
+ "Loading cached processed dataset at /workspace/.cache/huggingface/datasets/common_voice/tr/6.1.0/5693bfc0feeade582a78c2fb250bc88f52bd86f0a7f1bb22bfee67e715de30fd/cache-b4485d5ec10af59a.arrow\n",
+ "Loading cached processed dataset at /workspace/.cache/huggingface/datasets/common_voice/tr/6.1.0/5693bfc0feeade582a78c2fb250bc88f52bd86f0a7f1bb22bfee67e715de30fd/cache-87741a8a8705e488.arrow\n",
+ "Loading cached processed dataset at /workspace/.cache/huggingface/datasets/common_voice/tr/6.1.0/5693bfc0feeade582a78c2fb250bc88f52bd86f0a7f1bb22bfee67e715de30fd/cache-2aa5c421e49dbb8a.arrow\n",
+ "Loading cached processed dataset at /workspace/.cache/huggingface/datasets/common_voice/tr/6.1.0/5693bfc0feeade582a78c2fb250bc88f52bd86f0a7f1bb22bfee67e715de30fd/cache-6fa3756abc090cb1.arrow\n",
+ "Loading cached processed dataset at /workspace/.cache/huggingface/datasets/common_voice/tr/6.1.0/5693bfc0feeade582a78c2fb250bc88f52bd86f0a7f1bb22bfee67e715de30fd/cache-7082faf01a7536d9.arrow\n",
+ "Loading cached processed dataset at /workspace/.cache/huggingface/datasets/common_voice/tr/6.1.0/5693bfc0feeade582a78c2fb250bc88f52bd86f0a7f1bb22bfee67e715de30fd/cache-dbf56923bad5550e.arrow\n",
+ "Loading cached processed dataset at /workspace/.cache/huggingface/datasets/common_voice/tr/6.1.0/5693bfc0feeade582a78c2fb250bc88f52bd86f0a7f1bb22bfee67e715de30fd/cache-cfa541d30ccf3270.arrow\n",
+ "Loading cached processed dataset at /workspace/.cache/huggingface/datasets/common_voice/tr/6.1.0/5693bfc0feeade582a78c2fb250bc88f52bd86f0a7f1bb22bfee67e715de30fd/cache-9f28af78c8d178d8.arrow\n",
+ "Loading cached processed dataset at /workspace/.cache/huggingface/datasets/common_voice/tr/6.1.0/5693bfc0feeade582a78c2fb250bc88f52bd86f0a7f1bb22bfee67e715de30fd/cache-4fc740b07e55a01b.arrow\n",
+ "Loading cached processed dataset at /workspace/.cache/huggingface/datasets/common_voice/tr/6.1.0/5693bfc0feeade582a78c2fb250bc88f52bd86f0a7f1bb22bfee67e715de30fd/cache-ec4bd65c3d0c2b80.arrow\n",
+ "Loading cached processed dataset at /workspace/.cache/huggingface/datasets/common_voice/tr/6.1.0/5693bfc0feeade582a78c2fb250bc88f52bd86f0a7f1bb22bfee67e715de30fd/cache-033c2e0fab0f0e8a.arrow\n"
+ ]
+ }
+ ],
+ "source": [
+ "common_voice_train = common_voice_train.map(prepare_dataset, remove_columns=common_voice_train.column_names, num_proc=16)\n",
+ "common_voice_test = common_voice_test.map(prepare_dataset, remove_columns=common_voice_test.column_names, num_proc=16)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 41,
+ "id": "142e5d79",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# In case the dataset is too long which can lead to OOM. We should filter them out.\n",
+ "# max_input_length_in_sec = 5.0\n",
+ "# common_voice_train = common_voice_train.filter(lambda x: x < max_input_length_in_sec * processor.feature_extractor.sampling_rate, input_columns=[\"input_length\"])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 25,
+ "id": "310cdbb1",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "\n",
+ "from dataclasses import dataclass, field\n",
+ "from typing import Any, Dict, List, Optional, Union\n",
+ "\n",
+ "@dataclass\n",
+ "class DataCollatorCTCWithPadding:\n",
+ " \"\"\"\n",
+ " Data collator that will dynamically pad the inputs received.\n",
+ " Args:\n",
+ " processor (:class:`~transformers.Wav2Vec2Processor`)\n",
+ " The processor used for proccessing the data.\n",
+ " padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):\n",
+ " Select a strategy to pad the returned sequences (according to the model's padding side and padding index)\n",
+ " among:\n",
+ " * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single\n",
+ " sequence if provided).\n",
+ " * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the\n",
+ " maximum acceptable input length for the model if that argument is not provided.\n",
+ " * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of\n",
+ " different lengths).\n",
+ " \"\"\"\n",
+ "\n",
+ " processor: Wav2Vec2Processor\n",
+ " padding: Union[bool, str] = True\n",
+ "\n",
+ " def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:\n",
+ " # split inputs and labels since they have to be of different lenghts and need\n",
+ " # different padding methods\n",
+ " input_features = [{\"input_values\": feature[\"input_values\"]} for feature in features]\n",
+ " label_features = [{\"input_ids\": feature[\"labels\"]} for feature in features]\n",
+ "\n",
+ " batch = self.processor.pad(\n",
+ " input_features,\n",
+ " padding=self.padding,\n",
+ " return_tensors=\"pt\",\n",
+ " )\n",
+ "\n",
+ " with self.processor.as_target_processor():\n",
+ " labels_batch = self.processor.pad(\n",
+ " label_features,\n",
+ " padding=self.padding,\n",
+ " return_tensors=\"pt\",\n",
+ " )\n",
+ "\n",
+ " # replace padding with -100 to ignore loss correctly\n",
+ " labels = labels_batch[\"input_ids\"].masked_fill(labels_batch.attention_mask.ne(1), -100)\n",
+ "\n",
+ " batch[\"labels\"] = labels\n",
+ "\n",
+ " return batch"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 26,
+ "id": "6cff622b",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 27,
+ "id": "df12cc5b",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "wer_metric = load_metric(\"wer\")\n",
+ "# cer_metric = load_metric(\"cer\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 28,
+ "id": "8b25005a",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def compute_metrics(pred):\n",
+ " pred_logits = pred.predictions\n",
+ " pred_ids = np.argmax(pred_logits, axis=-1)\n",
+ "\n",
+ " pred.label_ids[pred.label_ids == -100] = tokenizer.pad_token_id\n",
+ "\n",
+ " pred_str = tokenizer.batch_decode(pred_ids)\n",
+ " # we do not want to group tokens when computing the metrics\n",
+ " label_str = tokenizer.batch_decode(pred.label_ids, group_tokens=False)\n",
+ "\n",
+ " wer = wer_metric.compute(predictions=pred_str, references=label_str)\n",
+ "# cer = cer_metric.compute(predictions=pred_str, references=label_str)\n",
+ "\n",
+ " return {\"wer\": wer}\n",
+ "# return {\"cer\": cer}"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 29,
+ "id": "a7ac7d14",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Some weights of the model checkpoint at facebook/wav2vec2-xls-r-300m were not used when initializing Wav2Vec2ForCTC: ['quantizer.weight_proj.weight', 'project_q.bias', 'project_hid.bias', 'project_q.weight', 'quantizer.weight_proj.bias', 'quantizer.codevectors', 'project_hid.weight']\n",
+ "- This IS expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
+ "- This IS NOT expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
+ "Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-xls-r-300m and are newly initialized: ['lm_head.weight', 'lm_head.bias']\n",
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
+ ]
+ }
+ ],
+ "source": [
+ "from transformers import Wav2Vec2ForCTC\n",
+ "\n",
+ "model = Wav2Vec2ForCTC.from_pretrained(\n",
+ " \"facebook/wav2vec2-xls-r-300m\", \n",
+ " attention_dropout=0.0,\n",
+ " hidden_dropout=0.0,\n",
+ " feat_proj_dropout=0.0,\n",
+ " mask_time_prob=0.05,\n",
+ " layerdrop=0.0,\n",
+ " ctc_loss_reduction=\"mean\", \n",
+ " pad_token_id=tokenizer.pad_token_id,\n",
+ " vocab_size=len(processor.tokenizer),\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 30,
+ "id": "352fb742",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "model.freeze_feature_encoder()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 31,
+ "id": "ae38b1c1",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from transformers import TrainingArguments\n",
+ "\n",
+ "training_args = TrainingArguments(\n",
+ " output_dir='.',\n",
+ " group_by_length=True,\n",
+ " per_device_train_batch_size=8,\n",
+ " gradient_accumulation_steps=2,\n",
+ " evaluation_strategy=\"steps\",\n",
+ " gradient_checkpointing=True,\n",
+ " fp16=True,\n",
+ " num_train_epochs=25,\n",
+ " save_steps=500,\n",
+ " eval_steps=500,\n",
+ " logging_steps=100,\n",
+ " learning_rate=5e-5,\n",
+ " warmup_steps=1000,\n",
+ " save_total_limit=3\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 32,
+ "id": "d60948cc",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Using amp half precision backend\n"
+ ]
+ }
+ ],
+ "source": [
+ "from transformers import Trainer\n",
+ "\n",
+ "trainer = Trainer(\n",
+ " model=model,\n",
+ " data_collator=data_collator,\n",
+ " args=training_args,\n",
+ " compute_metrics=compute_metrics,\n",
+ " train_dataset=common_voice_train,\n",
+ " eval_dataset=common_voice_test,\n",
+ " tokenizer=processor.feature_extractor,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 33,
+ "id": "6b20f77c",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "The following columns in the training set don't have a corresponding argument in `Wav2Vec2ForCTC.forward` and have been ignored: input_length.\n",
+ "/opt/conda/lib/python3.8/site-packages/transformers/optimization.py:306: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use thePyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
+ " warnings.warn(\n",
+ "***** Running training *****\n",
+ " Num examples = 3478\n",
+ " Num Epochs = 25\n",
+ " Instantaneous batch size per device = 8\n",
+ " Total train batch size (w. parallel, distributed & accumulation) = 16\n",
+ " Gradient Accumulation steps = 2\n",
+ " Total optimization steps = 5425\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [5425/5425 1:31:08, Epoch 24/25]\n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " Step | \n",
+ " Training Loss | \n",
+ " Validation Loss | \n",
+ " Wer | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 500 | \n",
+ " 3.885900 | \n",
+ " 3.760785 | \n",
+ " 1.000000 | \n",
+ "
\n",
+ " \n",
+ " 1000 | \n",
+ " 1.819300 | \n",
+ " 1.530782 | \n",
+ " 1.000613 | \n",
+ "
\n",
+ " \n",
+ " 1500 | \n",
+ " 0.598600 | \n",
+ " 0.729536 | \n",
+ " 1.005616 | \n",
+ "
\n",
+ " \n",
+ " 2000 | \n",
+ " 0.399200 | \n",
+ " 0.618558 | \n",
+ " 1.013377 | \n",
+ "
\n",
+ " \n",
+ " 2500 | \n",
+ " 0.319900 | \n",
+ " 0.597245 | \n",
+ " 1.012254 | \n",
+ "
\n",
+ " \n",
+ " 3000 | \n",
+ " 0.238800 | \n",
+ " 0.555572 | \n",
+ " 1.010109 | \n",
+ "
\n",
+ " \n",
+ " 3500 | \n",
+ " 0.188200 | \n",
+ " 0.517281 | \n",
+ " 1.014092 | \n",
+ "
\n",
+ " \n",
+ " 4000 | \n",
+ " 0.160400 | \n",
+ " 0.517009 | \n",
+ " 1.018278 | \n",
+ "
\n",
+ " \n",
+ " 4500 | \n",
+ " 0.144300 | \n",
+ " 0.526738 | \n",
+ " 1.018380 | \n",
+ "
\n",
+ " \n",
+ " 5000 | \n",
+ " 0.140400 | \n",
+ " 0.536664 | \n",
+ " 1.016747 | \n",
+ "
\n",
+ " \n",
+ "
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "The following columns in the evaluation set don't have a corresponding argument in `Wav2Vec2ForCTC.forward` and have been ignored: input_length.\n",
+ "***** Running Evaluation *****\n",
+ " Num examples = 1647\n",
+ " Batch size = 8\n",
+ "Saving model checkpoint to ./checkpoint-500\n",
+ "Configuration saved in ./checkpoint-500/config.json\n",
+ "Model weights saved in ./checkpoint-500/pytorch_model.bin\n",
+ "Configuration saved in ./checkpoint-500/preprocessor_config.json\n",
+ "The following columns in the evaluation set don't have a corresponding argument in `Wav2Vec2ForCTC.forward` and have been ignored: input_length.\n",
+ "***** Running Evaluation *****\n",
+ " Num examples = 1647\n",
+ " Batch size = 8\n",
+ "Saving model checkpoint to ./checkpoint-1000\n",
+ "Configuration saved in ./checkpoint-1000/config.json\n",
+ "Model weights saved in ./checkpoint-1000/pytorch_model.bin\n",
+ "Configuration saved in ./checkpoint-1000/preprocessor_config.json\n",
+ "The following columns in the evaluation set don't have a corresponding argument in `Wav2Vec2ForCTC.forward` and have been ignored: input_length.\n",
+ "***** Running Evaluation *****\n",
+ " Num examples = 1647\n",
+ " Batch size = 8\n",
+ "Saving model checkpoint to ./checkpoint-1500\n",
+ "Configuration saved in ./checkpoint-1500/config.json\n",
+ "Model weights saved in ./checkpoint-1500/pytorch_model.bin\n",
+ "Configuration saved in ./checkpoint-1500/preprocessor_config.json\n",
+ "The following columns in the evaluation set don't have a corresponding argument in `Wav2Vec2ForCTC.forward` and have been ignored: input_length.\n",
+ "***** Running Evaluation *****\n",
+ " Num examples = 1647\n",
+ " Batch size = 8\n",
+ "Saving model checkpoint to ./checkpoint-2000\n",
+ "Configuration saved in ./checkpoint-2000/config.json\n",
+ "Model weights saved in ./checkpoint-2000/pytorch_model.bin\n",
+ "Configuration saved in ./checkpoint-2000/preprocessor_config.json\n",
+ "Deleting older checkpoint [checkpoint-500] due to args.save_total_limit\n",
+ "The following columns in the evaluation set don't have a corresponding argument in `Wav2Vec2ForCTC.forward` and have been ignored: input_length.\n",
+ "***** Running Evaluation *****\n",
+ " Num examples = 1647\n",
+ " Batch size = 8\n",
+ "Saving model checkpoint to ./checkpoint-2500\n",
+ "Configuration saved in ./checkpoint-2500/config.json\n",
+ "Model weights saved in ./checkpoint-2500/pytorch_model.bin\n",
+ "Configuration saved in ./checkpoint-2500/preprocessor_config.json\n",
+ "Deleting older checkpoint [checkpoint-1000] due to args.save_total_limit\n",
+ "The following columns in the evaluation set don't have a corresponding argument in `Wav2Vec2ForCTC.forward` and have been ignored: input_length.\n",
+ "***** Running Evaluation *****\n",
+ " Num examples = 1647\n",
+ " Batch size = 8\n",
+ "Saving model checkpoint to ./checkpoint-3000\n",
+ "Configuration saved in ./checkpoint-3000/config.json\n",
+ "Model weights saved in ./checkpoint-3000/pytorch_model.bin\n",
+ "Configuration saved in ./checkpoint-3000/preprocessor_config.json\n",
+ "Deleting older checkpoint [checkpoint-1500] due to args.save_total_limit\n",
+ "The following columns in the evaluation set don't have a corresponding argument in `Wav2Vec2ForCTC.forward` and have been ignored: input_length.\n",
+ "***** Running Evaluation *****\n",
+ " Num examples = 1647\n",
+ " Batch size = 8\n",
+ "Saving model checkpoint to ./checkpoint-3500\n",
+ "Configuration saved in ./checkpoint-3500/config.json\n",
+ "Model weights saved in ./checkpoint-3500/pytorch_model.bin\n",
+ "Configuration saved in ./checkpoint-3500/preprocessor_config.json\n",
+ "Deleting older checkpoint [checkpoint-2000] due to args.save_total_limit\n",
+ "The following columns in the evaluation set don't have a corresponding argument in `Wav2Vec2ForCTC.forward` and have been ignored: input_length.\n",
+ "***** Running Evaluation *****\n",
+ " Num examples = 1647\n",
+ " Batch size = 8\n",
+ "Saving model checkpoint to ./checkpoint-4000\n",
+ "Configuration saved in ./checkpoint-4000/config.json\n",
+ "Model weights saved in ./checkpoint-4000/pytorch_model.bin\n",
+ "Configuration saved in ./checkpoint-4000/preprocessor_config.json\n",
+ "Deleting older checkpoint [checkpoint-2500] due to args.save_total_limit\n",
+ "The following columns in the evaluation set don't have a corresponding argument in `Wav2Vec2ForCTC.forward` and have been ignored: input_length.\n",
+ "***** Running Evaluation *****\n",
+ " Num examples = 1647\n",
+ " Batch size = 8\n",
+ "Saving model checkpoint to ./checkpoint-4500\n",
+ "Configuration saved in ./checkpoint-4500/config.json\n",
+ "Model weights saved in ./checkpoint-4500/pytorch_model.bin\n",
+ "Configuration saved in ./checkpoint-4500/preprocessor_config.json\n",
+ "Deleting older checkpoint [checkpoint-3000] due to args.save_total_limit\n",
+ "The following columns in the evaluation set don't have a corresponding argument in `Wav2Vec2ForCTC.forward` and have been ignored: input_length.\n",
+ "***** Running Evaluation *****\n",
+ " Num examples = 1647\n",
+ " Batch size = 8\n",
+ "Saving model checkpoint to ./checkpoint-5000\n",
+ "Configuration saved in ./checkpoint-5000/config.json\n",
+ "Model weights saved in ./checkpoint-5000/pytorch_model.bin\n",
+ "Configuration saved in ./checkpoint-5000/preprocessor_config.json\n",
+ "Deleting older checkpoint [checkpoint-3500] due to args.save_total_limit\n",
+ "\n",
+ "\n",
+ "Training completed. Do not forget to share your model on huggingface.co/models =)\n",
+ "\n",
+ "\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "TrainOutput(global_step=5425, training_loss=1.241710463449153, metrics={'train_runtime': 5469.9405, 'train_samples_per_second': 15.896, 'train_steps_per_second': 0.992, 'total_flos': 1.0590512839529611e+19, 'train_loss': 1.241710463449153, 'epoch': 25.0})"
+ ]
+ },
+ "execution_count": 33,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "trainer.train()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "f580e49e",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.8.8"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}