sappho192 commited on
Commit
adf30ec
·
verified ·
1 Parent(s): 8499b8b

Update training.ipynb

Browse files
Files changed (1) hide show
  1. training.ipynb +477 -233
training.ipynb CHANGED
@@ -1,278 +1,522 @@
1
  {
2
- "cells": [
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  {
4
- "attachments": {},
5
- "cell_type": "markdown",
6
- "metadata": {},
7
- "source": [
8
- "The primary codes below are based on [akpe12/JP-KR-ocr-translator-for-travel](https://github.com/akpe12/JP-KR-ocr-translator-for-travel)."
9
  ]
10
- },
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  {
12
- "cell_type": "markdown",
13
- "metadata": {
14
- "id": "TrHlPFqwFAgj"
15
- },
16
- "source": [
17
- "## Import"
18
  ]
19
- },
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  {
21
- "cell_type": "code",
22
- "execution_count": null,
23
- "metadata": {
24
- "id": "t-jXeSJKE1WM"
25
- },
26
- "outputs": [],
27
- "source": [
28
- "\n",
29
- "from typing import Dict, List\n",
30
- "import csv\n",
31
- "import torch\n",
32
- "from transformers import (\n",
33
- " EncoderDecoderModel,\n",
34
- " GPT2Tokenizer as BaseGPT2Tokenizer,\n",
35
- " PreTrainedTokenizer, BertTokenizerFast,\n",
36
- " PreTrainedTokenizerFast,\n",
37
- " DataCollatorForSeq2Seq,\n",
38
- " Seq2SeqTrainingArguments,\n",
39
- " AutoTokenizer,\n",
40
- " XLMRobertaTokenizerFast,\n",
41
- " BertJapaneseTokenizer,\n",
42
- " Trainer\n",
43
- ")\n",
44
- "from torch.utils.data import DataLoader\n",
45
- "from transformers.models.encoder_decoder.modeling_encoder_decoder import EncoderDecoderModel\n",
46
- "\n",
47
- "# encoder_model_name = \"xlm-roberta-base\"\n",
48
- "encoder_model_name = \"cl-tohoku/bert-base-japanese-v2\"\n",
49
- "decoder_model_name = \"skt/kogpt2-base-v2\""
50
- ]
51
- },
52
  {
53
- "cell_type": "code",
54
- "execution_count": null,
55
- "metadata": {
56
- "id": "nEW5trBtbykK"
57
- },
58
- "outputs": [],
59
- "source": [
60
- "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
61
- "# device = torch.device(\"cpu\")\n",
62
- "device, torch.cuda.device_count()"
63
  ]
 
 
 
64
  },
65
  {
66
- "cell_type": "code",
67
- "execution_count": null,
68
- "metadata": {
69
- "id": "5ic7pUUBFU_v"
 
70
  },
71
- "outputs": [],
72
- "source": [
73
- "class GPT2Tokenizer(PreTrainedTokenizerFast):\n",
74
- " def build_inputs_with_special_tokens(self, token_ids: List[int]) -> List[int]:\n",
75
- " return token_ids + [self.eos_token_id] \n",
76
- "\n",
77
- "src_tokenizer = BertJapaneseTokenizer.from_pretrained(encoder_model_name)\n",
78
- "trg_tokenizer = GPT2Tokenizer.from_pretrained(decoder_model_name, bos_token='</s>', eos_token='</s>', unk_token='<unk>',\n",
79
- " pad_token='<pad>', mask_token='<mask>')"
80
  ]
 
 
 
81
  },
82
  {
83
- "cell_type": "markdown",
84
- "metadata": {
85
- "id": "DTf4U1fmFQFh"
86
- },
87
- "source": [
88
- "## Data"
89
  ]
 
 
 
90
  },
91
  {
92
- "cell_type": "code",
93
- "execution_count": null,
94
- "metadata": {
95
- "id": "65L4O1c5FLKt"
96
- },
97
- "outputs": [],
98
- "source": [
99
- "class PairedDataset:\n",
100
- " def __init__(self, \n",
101
- " src_tokenizer: PreTrainedTokenizerFast, tgt_tokenizer: PreTrainedTokenizerFast,\n",
102
- " file_path: str\n",
103
- " ):\n",
104
- " self.src_tokenizer = src_tokenizer\n",
105
- " self.trg_tokenizer = tgt_tokenizer\n",
106
- " with open(file_path, 'r') as fd:\n",
107
- " reader = csv.reader(fd)\n",
108
- " next(reader)\n",
109
- " self.data = [row for row in reader]\n",
110
- "\n",
111
- " def __getitem__(self, index: int) -> Dict[str, torch.Tensor]:\n",
112
- "# with open('train_log.txt', 'a+') as log_file:\n",
113
- "# log_file.write(f'reading data[{index}] {self.data[index]}\\n')\n",
114
- " src, trg = self.data[index]\n",
115
- " embeddings = self.src_tokenizer(src, return_attention_mask=False, return_token_type_ids=False)\n",
116
- " embeddings['labels'] = self.trg_tokenizer.build_inputs_with_special_tokens(self.trg_tokenizer(trg, return_attention_mask=False)['input_ids'])\n",
117
- "\n",
118
- " return embeddings\n",
119
- "\n",
120
- " def __len__(self):\n",
121
- " return len(self.data)\n",
122
- " \n",
123
- "DATA_ROOT = './output'\n",
124
- "FILE_FFAC_FULL = 'ffac_full.csv'\n",
125
- "FILE_FFAC_TEST = 'ffac_test.csv'\n",
126
- "FILE_JA_KO_TRAIN = 'ja_ko_train.csv'\n",
127
- "FILE_JA_KO_TEST = 'ja_ko_test.csv'\n",
128
- "\n",
129
- "# train_dataset = PairedDataset(src_tokenizer, trg_tokenizer, f'{DATA_ROOT}/{FILE_FFAC_FULL}')\n",
130
- "# eval_dataset = PairedDataset(src_tokenizer, trg_tokenizer, f'{DATA_ROOT}/{FILE_FFAC_TEST}') \n",
131
- "train_dataset = PairedDataset(src_tokenizer, trg_tokenizer, f'{DATA_ROOT}/{FILE_JA_KO_TRAIN}')\n",
132
- "eval_dataset = PairedDataset(src_tokenizer, trg_tokenizer, f'{DATA_ROOT}/{FILE_JA_KO_TEST}') "
133
  ]
 
 
 
134
  },
135
  {
136
- "cell_type": "code",
137
- "execution_count": null,
138
- "metadata": {},
139
- "outputs": [],
140
- "source": [
141
- "# be sure to check the column count of each dataset if you encounter \"ValueError: too many values to unpack (expected 2)\"\n",
142
- "# at the `src, trg = self.data[index]`\n",
143
- "# The `cat ffac_full.csv tteb_train.csv > ja_ko_train.csv` command may be the reason.\n",
144
- "# the last row of first csv and first row of second csv is merged and that's why 3rd column is created (which arouse ValueError)\n",
145
- "# debug_data = train_dataset.data\n"
146
  ]
 
 
 
147
  },
148
  {
149
- "cell_type": "markdown",
150
- "metadata": {
151
- "id": "uCBiLouSFiZY"
 
 
152
  },
153
- "source": [
154
- "## Model"
155
  ]
 
 
 
156
  },
157
  {
158
- "cell_type": "code",
159
- "execution_count": null,
160
- "metadata": {
161
- "id": "I7uFbFYJFje8"
162
- },
163
- "outputs": [],
164
- "source": [
165
- "model = EncoderDecoderModel.from_encoder_decoder_pretrained(\n",
166
- " encoder_model_name,\n",
167
- " decoder_model_name,\n",
168
- " pad_token_id=trg_tokenizer.bos_token_id,\n",
169
- ")\n",
170
- "model.config.decoder_start_token_id = trg_tokenizer.bos_token_id"
171
  ]
 
 
 
172
  },
173
  {
174
- "cell_type": "code",
175
- "execution_count": null,
176
- "metadata": {
177
- "id": "YFq2GyOAUV0W"
178
- },
179
- "outputs": [],
180
- "source": [
181
- "# for Trainer\n",
182
- "import wandb\n",
183
- "\n",
184
- "collate_fn = DataCollatorForSeq2Seq(src_tokenizer, model)\n",
185
- "wandb.init(project=\"fftr-poc1\", name='jbert+kogpt2')\n",
186
- "\n",
187
- "arguments = Seq2SeqTrainingArguments(\n",
188
- " output_dir='dump',\n",
189
- " do_train=True,\n",
190
- " do_eval=True,\n",
191
- " evaluation_strategy=\"epoch\",\n",
192
- " save_strategy=\"epoch\",\n",
193
- " num_train_epochs=3,\n",
194
- " # num_train_epochs=25,\n",
195
- " per_device_train_batch_size=30,\n",
196
- " # per_device_train_batch_size=64,\n",
197
- " per_device_eval_batch_size=30,\n",
198
- " # per_device_eval_batch_size=64,\n",
199
- " warmup_ratio=0.1,\n",
200
- " gradient_accumulation_steps=4,\n",
201
- " save_total_limit=5,\n",
202
- " dataloader_num_workers=1,\n",
203
- " fp16=True,\n",
204
- " load_best_model_at_end=True,\n",
205
- " report_to='wandb'\n",
206
- ")\n",
207
- "\n",
208
- "trainer = Trainer(\n",
209
- " model,\n",
210
- " arguments,\n",
211
- " data_collator=collate_fn,\n",
212
- " train_dataset=train_dataset,\n",
213
- " eval_dataset=eval_dataset\n",
214
- ")"
215
  ]
 
 
 
216
  },
217
  {
218
- "cell_type": "markdown",
219
- "metadata": {
220
- "id": "pPsjDHO5Vc3y"
221
- },
222
- "source": [
223
- "## Training"
224
  ]
 
 
 
225
  },
226
  {
227
- "cell_type": "code",
228
- "execution_count": null,
229
- "metadata": {
230
- "id": "_T4P4XunmK-C"
231
- },
232
- "outputs": [],
233
- "source": [
234
- "# model = EncoderDecoderModel.from_encoder_decoder_pretrained(\"xlm-roberta-base\", \"skt/kogpt2-base-v2\")"
235
  ]
 
 
 
236
  },
237
  {
238
- "cell_type": "code",
239
- "execution_count": null,
240
- "metadata": {
241
- "id": "7vTqAgW6Ve3J"
242
- },
243
- "outputs": [],
244
- "source": [
245
- "trainer.train()\n",
246
- "\n",
247
- "model.save_pretrained(\"dump/best_model\")\n",
248
- "src_tokenizer.save_pretrained(\"dump/best_model/src_tokenizer\")\n",
249
- "trg_tokenizer.save_pretrained(\"dump/best_model/trg_tokenizer\")"
250
  ]
 
 
 
251
  }
252
- ],
253
- "metadata": {
254
- "colab": {
255
- "machine_shape": "hm",
256
- "provenance": []
257
- },
258
- "gpuClass": "premium",
259
- "kernelspec": {
260
- "display_name": "Python 3",
261
- "name": "python3"
262
- },
263
- "language_info": {
264
- "codemirror_mode": {
265
- "name": "ipython",
266
- "version": 3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
267
  },
268
- "file_extension": ".py",
269
- "mimetype": "text/x-python",
270
- "name": "python",
271
- "nbconvert_exporter": "python",
272
- "pygments_lexer": "ipython3",
273
- "version": "3.8.10"
274
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
275
  },
276
- "nbformat": 4,
277
- "nbformat_minor": 0
 
 
 
 
 
 
 
 
 
 
 
 
 
278
  }
 
1
  {
2
+ "cells": [
3
+ {
4
+ "attachments": {},
5
+ "cell_type": "markdown",
6
+ "metadata": {},
7
+ "source": [
8
+ "The primary codes below are based on [akpe12/JP-KR-ocr-translator-for-travel](https://github.com/akpe12/JP-KR-ocr-translator-for-travel)."
9
+ ]
10
+ },
11
+ {
12
+ "cell_type": "markdown",
13
+ "metadata": {
14
+ "id": "TrHlPFqwFAgj"
15
+ },
16
+ "source": [
17
+ "## Import"
18
+ ]
19
+ },
20
+ {
21
+ "cell_type": "code",
22
+ "execution_count": 1,
23
+ "metadata": {
24
+ "id": "t-jXeSJKE1WM"
25
+ },
26
+ "outputs": [],
27
+ "source": [
28
+ "from typing import Dict, List\n",
29
+ "import csv\n",
30
+ "\n",
31
+ "import datasets\n",
32
+ "import torch\n",
33
+ "from transformers import (\n",
34
+ " PreTrainedTokenizerFast,\n",
35
+ " DataCollatorForSeq2Seq,\n",
36
+ " Seq2SeqTrainingArguments,\n",
37
+ " BertJapaneseTokenizer,\n",
38
+ " Trainer\n",
39
+ ")\n",
40
+ "from transformers.models.encoder_decoder.modeling_encoder_decoder import EncoderDecoderModel\n",
41
+ "\n",
42
+ "from datasets import load_dataset\n",
43
+ "\n",
44
+ "# encoder_model_name = \"xlm-roberta-base\"\n",
45
+ "encoder_model_name = \"cl-tohoku/bert-base-japanese-v2\"\n",
46
+ "decoder_model_name = \"skt/kogpt2-base-v2\""
47
+ ]
48
+ },
49
+ {
50
+ "cell_type": "code",
51
+ "execution_count": 2,
52
+ "metadata": {
53
+ "id": "nEW5trBtbykK"
54
+ },
55
+ "outputs": [
56
  {
57
+ "data": {
58
+ "text/plain": [
59
+ "(device(type='cpu'), 0)"
 
 
60
  ]
61
+ },
62
+ "execution_count": 2,
63
+ "metadata": {},
64
+ "output_type": "execute_result"
65
+ }
66
+ ],
67
+ "source": [
68
+ "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
69
+ "# device = torch.device(\"cpu\")\n",
70
+ "device, torch.cuda.device_count()"
71
+ ]
72
+ },
73
+ {
74
+ "cell_type": "code",
75
+ "execution_count": 3,
76
+ "metadata": {
77
+ "id": "5ic7pUUBFU_v"
78
+ },
79
+ "outputs": [],
80
+ "source": [
81
+ "class GPT2Tokenizer(PreTrainedTokenizerFast):\n",
82
+ " def build_inputs_with_special_tokens(self, token_ids: List[int]) -> List[int]:\n",
83
+ " return token_ids + [self.eos_token_id] \n",
84
+ "\n",
85
+ "src_tokenizer = BertJapaneseTokenizer.from_pretrained(encoder_model_name)\n",
86
+ "trg_tokenizer = GPT2Tokenizer.from_pretrained(decoder_model_name, bos_token='</s>', eos_token='</s>', unk_token='<unk>',\n",
87
+ " pad_token='<pad>', mask_token='<mask>')"
88
+ ]
89
+ },
90
+ {
91
+ "cell_type": "markdown",
92
+ "metadata": {
93
+ "id": "DTf4U1fmFQFh"
94
+ },
95
+ "source": [
96
+ "## Data"
97
+ ]
98
+ },
99
+ {
100
+ "cell_type": "code",
101
+ "execution_count": 4,
102
+ "metadata": {
103
+ "collapsed": false
104
+ },
105
+ "outputs": [],
106
+ "source": [
107
+ "dataset = load_dataset(\"sappho192/Tatoeba-Challenge-jpn-kor\")\n",
108
+ "# dataset = load_dataset(\"D:\\\\REPO\\\\Tatoeba-Challenge-jpn-kor\")\n",
109
+ "\n",
110
+ "train_dataset = dataset['train']\n",
111
+ "test_dataset = dataset['test']\n",
112
+ "\n",
113
+ "train_first_row = train_dataset[0]\n",
114
+ "test_first_row = test_dataset[0]"
115
+ ]
116
+ },
117
+ {
118
+ "cell_type": "code",
119
+ "execution_count": 5,
120
+ "metadata": {
121
+ "id": "65L4O1c5FLKt"
122
+ },
123
+ "outputs": [],
124
+ "source": [
125
+ "class PairedDataset:\n",
126
+ " def __init__(self, \n",
127
+ " source_tokenizer: PreTrainedTokenizerFast, target_tokenizer: PreTrainedTokenizerFast,\n",
128
+ " file_path: str = None,\n",
129
+ " dataset_raw: datasets.Dataset = None\n",
130
+ " ):\n",
131
+ " self.src_tokenizer = source_tokenizer\n",
132
+ " self.trg_tokenizer = target_tokenizer\n",
133
+ " \n",
134
+ " if file_path is not None:\n",
135
+ " with open(file_path, 'r') as fd:\n",
136
+ " reader = csv.reader(fd)\n",
137
+ " next(reader)\n",
138
+ " self.data = [row for row in reader]\n",
139
+ " elif dataset_raw is not None:\n",
140
+ " self.data = dataset_raw\n",
141
+ " else:\n",
142
+ " raise ValueError('file_path or dataset_raw must be specified')\n",
143
+ "\n",
144
+ " def __getitem__(self, index: int) -> Dict[str, torch.Tensor]:\n",
145
+ "# with open('train_log.txt', 'a+') as log_file:\n",
146
+ "# log_file.write(f'reading data[{index}] {self.data[index]}\\n')\n",
147
+ " if isinstance(self.data, datasets.Dataset):\n",
148
+ " src, trg = self.data[index]['sourceString'], self.data[index]['targetString']\n",
149
+ " else:\n",
150
+ " src, trg = self.data[index]\n",
151
+ " embeddings = self.src_tokenizer(src, return_attention_mask=False, return_token_type_ids=False)\n",
152
+ " embeddings['labels'] = self.trg_tokenizer.build_inputs_with_special_tokens(self.trg_tokenizer(trg, return_attention_mask=False)['input_ids'])\n",
153
+ "\n",
154
+ " return embeddings\n",
155
+ "\n",
156
+ " def __len__(self):\n",
157
+ " return len(self.data)"
158
+ ]
159
+ },
160
+ {
161
+ "cell_type": "code",
162
+ "execution_count": 6,
163
+ "metadata": {
164
+ "collapsed": false
165
+ },
166
+ "outputs": [],
167
+ "source": [
168
+ "DATA_ROOT = './output'\n",
169
+ "FILE_FFAC_FULL = 'ffac_full.csv'\n",
170
+ "FILE_FFAC_TEST = 'ffac_test.csv'\n",
171
+ "FILE_JA_KO_TRAIN = 'ja_ko_train.csv'\n",
172
+ "FILE_JA_KO_TEST = 'ja_ko_test.csv'\n",
173
+ "\n",
174
+ "# train_dataset = PairedDataset(src_tokenizer, trg_tokenizer, file_path=f'{DATA_ROOT}/{FILE_FFAC_FULL}')\n",
175
+ "# eval_dataset = PairedDataset(src_tokenizer, trg_tokenizer, file_path=f'{DATA_ROOT}/{FILE_FFAC_TEST}') \n",
176
+ "\n",
177
+ "# train_dataset = PairedDataset(src_tokenizer, trg_tokenizer, file_path=f'{DATA_ROOT}/{FILE_JA_KO_TRAIN}')\n",
178
+ "# eval_dataset = PairedDataset(src_tokenizer, trg_tokenizer, file_path=f'{DATA_ROOT}/{FILE_JA_KO_TEST}')"
179
+ ]
180
+ },
181
+ {
182
+ "cell_type": "code",
183
+ "execution_count": 7,
184
+ "metadata": {
185
+ "collapsed": false
186
+ },
187
+ "outputs": [
188
  {
189
+ "data": {
190
+ "text/plain": [
191
+ "{'input_ids': [2, 33, 2181, 1402, 893, 15200, 893, 13507, 881, 933, 882, 829, 3], 'labels': [9085, 10936, 10993, 23363, 9134, 18368, 8006, 389, 1]}"
 
 
 
192
  ]
193
+ },
194
+ "execution_count": 7,
195
+ "metadata": {},
196
+ "output_type": "execute_result"
197
+ }
198
+ ],
199
+ "source": [
200
+ "train_dataset = PairedDataset(src_tokenizer, trg_tokenizer, dataset_raw=train_dataset)\n",
201
+ "eval_dataset = PairedDataset(src_tokenizer, trg_tokenizer, dataset_raw=test_dataset)\n",
202
+ "eval_dataset[0]"
203
+ ]
204
+ },
205
+ {
206
+ "cell_type": "code",
207
+ "execution_count": 8,
208
+ "metadata": {},
209
+ "outputs": [],
210
+ "source": [
211
+ "# be sure to check the column count of each dataset if you encounter \"ValueError: too many values to unpack (expected 2)\"\n",
212
+ "# at the `src, trg = self.data[index]`\n",
213
+ "# The `cat ffac_full.csv tteb_train.csv > ja_ko_train.csv` command may be the reason.\n",
214
+ "# the last row of first csv and first row of second csv is merged and that's why 3rd column is created (which arouse ValueError)\n",
215
+ "# debug_data = train_dataset.data\n"
216
+ ]
217
+ },
218
+ {
219
+ "cell_type": "markdown",
220
+ "metadata": {
221
+ "id": "uCBiLouSFiZY"
222
+ },
223
+ "source": [
224
+ "## Model"
225
+ ]
226
+ },
227
+ {
228
+ "cell_type": "code",
229
+ "execution_count": 9,
230
+ "metadata": {
231
+ "id": "I7uFbFYJFje8"
232
+ },
233
+ "outputs": [
234
  {
235
+ "name": "stderr",
236
+ "output_type": "stream",
237
+ "text": [
238
+ "Some weights of GPT2LMHeadModel were not initialized from the model checkpoint at skt/kogpt2-base-v2 and are newly initialized: ['transformer.h.0.crossattention.c_attn.bias', 'transformer.h.0.crossattention.c_attn.weight', 'transformer.h.0.crossattention.c_proj.bias', 'transformer.h.0.crossattention.c_proj.weight', 'transformer.h.0.crossattention.q_attn.bias', 'transformer.h.0.crossattention.q_attn.weight', 'transformer.h.0.ln_cross_attn.bias', 'transformer.h.0.ln_cross_attn.weight', 'transformer.h.1.crossattention.c_attn.bias', 'transformer.h.1.crossattention.c_attn.weight', 'transformer.h.1.crossattention.c_proj.bias', 'transformer.h.1.crossattention.c_proj.weight', 'transformer.h.1.crossattention.q_attn.bias', 'transformer.h.1.crossattention.q_attn.weight', 'transformer.h.1.ln_cross_attn.bias', 'transformer.h.1.ln_cross_attn.weight', 'transformer.h.10.crossattention.c_attn.bias', 'transformer.h.10.crossattention.c_attn.weight', 'transformer.h.10.crossattention.c_proj.bias', 'transformer.h.10.crossattention.c_proj.weight', 'transformer.h.10.crossattention.q_attn.bias', 'transformer.h.10.crossattention.q_attn.weight', 'transformer.h.10.ln_cross_attn.bias', 'transformer.h.10.ln_cross_attn.weight', 'transformer.h.11.crossattention.c_attn.bias', 'transformer.h.11.crossattention.c_attn.weight', 'transformer.h.11.crossattention.c_proj.bias', 'transformer.h.11.crossattention.c_proj.weight', 'transformer.h.11.crossattention.q_attn.bias', 'transformer.h.11.crossattention.q_attn.weight', 'transformer.h.11.ln_cross_attn.bias', 'transformer.h.11.ln_cross_attn.weight', 'transformer.h.2.crossattention.c_attn.bias', 'transformer.h.2.crossattention.c_attn.weight', 'transformer.h.2.crossattention.c_proj.bias', 'transformer.h.2.crossattention.c_proj.weight', 'transformer.h.2.crossattention.q_attn.bias', 'transformer.h.2.crossattention.q_attn.weight', 'transformer.h.2.ln_cross_attn.bias', 'transformer.h.2.ln_cross_attn.weight', 'transformer.h.3.crossattention.c_attn.bias', 'transformer.h.3.crossattention.c_attn.weight', 'transformer.h.3.crossattention.c_proj.bias', 'transformer.h.3.crossattention.c_proj.weight', 'transformer.h.3.crossattention.q_attn.bias', 'transformer.h.3.crossattention.q_attn.weight', 'transformer.h.3.ln_cross_attn.bias', 'transformer.h.3.ln_cross_attn.weight', 'transformer.h.4.crossattention.c_attn.bias', 'transformer.h.4.crossattention.c_attn.weight', 'transformer.h.4.crossattention.c_proj.bias', 'transformer.h.4.crossattention.c_proj.weight', 'transformer.h.4.crossattention.q_attn.bias', 'transformer.h.4.crossattention.q_attn.weight', 'transformer.h.4.ln_cross_attn.bias', 'transformer.h.4.ln_cross_attn.weight', 'transformer.h.5.crossattention.c_attn.bias', 'transformer.h.5.crossattention.c_attn.weight', 'transformer.h.5.crossattention.c_proj.bias', 'transformer.h.5.crossattention.c_proj.weight', 'transformer.h.5.crossattention.q_attn.bias', 'transformer.h.5.crossattention.q_attn.weight', 'transformer.h.5.ln_cross_attn.bias', 'transformer.h.5.ln_cross_attn.weight', 'transformer.h.6.crossattention.c_attn.bias', 'transformer.h.6.crossattention.c_attn.weight', 'transformer.h.6.crossattention.c_proj.bias', 'transformer.h.6.crossattention.c_proj.weight', 'transformer.h.6.crossattention.q_attn.bias', 'transformer.h.6.crossattention.q_attn.weight', 'transformer.h.6.ln_cross_attn.bias', 'transformer.h.6.ln_cross_attn.weight', 'transformer.h.7.crossattention.c_attn.bias', 'transformer.h.7.crossattention.c_attn.weight', 'transformer.h.7.crossattention.c_proj.bias', 'transformer.h.7.crossattention.c_proj.weight', 'transformer.h.7.crossattention.q_attn.bias', 'transformer.h.7.crossattention.q_attn.weight', 'transformer.h.7.ln_cross_attn.bias', 'transformer.h.7.ln_cross_attn.weight', 'transformer.h.8.crossattention.c_attn.bias', 'transformer.h.8.crossattention.c_attn.weight', 'transformer.h.8.crossattention.c_proj.bias', 'transformer.h.8.crossattention.c_proj.weight', 'transformer.h.8.crossattention.q_attn.bias', 'transformer.h.8.crossattention.q_attn.weight', 'transformer.h.8.ln_cross_attn.bias', 'transformer.h.8.ln_cross_attn.weight', 'transformer.h.9.crossattention.c_attn.bias', 'transformer.h.9.crossattention.c_attn.weight', 'transformer.h.9.crossattention.c_proj.bias', 'transformer.h.9.crossattention.c_proj.weight', 'transformer.h.9.crossattention.q_attn.bias', 'transformer.h.9.crossattention.q_attn.weight', 'transformer.h.9.ln_cross_attn.bias', 'transformer.h.9.ln_cross_attn.weight']\n",
239
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
240
+ ]
241
+ }
242
+ ],
243
+ "source": [
244
+ "model = EncoderDecoderModel.from_encoder_decoder_pretrained(\n",
245
+ " encoder_model_name,\n",
246
+ " decoder_model_name,\n",
247
+ " pad_token_id=trg_tokenizer.bos_token_id,\n",
248
+ ")\n",
249
+ "model.config.decoder_start_token_id = trg_tokenizer.bos_token_id"
250
+ ]
251
+ },
252
+ {
253
+ "cell_type": "code",
254
+ "execution_count": 11,
255
+ "metadata": {
256
+ "id": "YFq2GyOAUV0W"
257
+ },
258
+ "outputs": [
 
 
 
 
 
 
 
259
  {
260
+ "data": {
261
+ "text/html": [
262
+ "Finishing last run (ID:1vwqqxps) before initializing another..."
263
+ ],
264
+ "text/plain": [
265
+ "<IPython.core.display.HTML object>"
 
 
 
 
266
  ]
267
+ },
268
+ "metadata": {},
269
+ "output_type": "display_data"
270
  },
271
  {
272
+ "data": {
273
+ "application/vnd.jupyter.widget-view+json": {
274
+ "model_id": "a82aa19a250b43f28d7ecc72eeebc88d",
275
+ "version_major": 2,
276
+ "version_minor": 0
277
  },
278
+ "text/plain": [
279
+ "VBox(children=(Label(value='0.001 MB of 0.010 MB uploaded\\r'), FloatProgress(value=0.10972568578553615, max=1.…"
 
 
 
 
 
 
 
280
  ]
281
+ },
282
+ "metadata": {},
283
+ "output_type": "display_data"
284
  },
285
  {
286
+ "data": {
287
+ "text/html": [
288
+ " View run <strong style=\"color:#cdcd00\">jbert+kogpt2</strong> at: <a href='https://wandb.ai/sappho192/fftr-poc1/runs/1vwqqxps' target=\"_blank\">https://wandb.ai/sappho192/fftr-poc1/runs/1vwqqxps</a><br/>Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)"
289
+ ],
290
+ "text/plain": [
291
+ "<IPython.core.display.HTML object>"
292
  ]
293
+ },
294
+ "metadata": {},
295
+ "output_type": "display_data"
296
  },
297
  {
298
+ "data": {
299
+ "text/html": [
300
+ "Find logs at: <code>.\\wandb\\run-20240131_135356-1vwqqxps\\logs</code>"
301
+ ],
302
+ "text/plain": [
303
+ "<IPython.core.display.HTML object>"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
304
  ]
305
+ },
306
+ "metadata": {},
307
+ "output_type": "display_data"
308
  },
309
  {
310
+ "data": {
311
+ "text/html": [
312
+ "Successfully finished last run (ID:1vwqqxps). Initializing new run:<br/>"
313
+ ],
314
+ "text/plain": [
315
+ "<IPython.core.display.HTML object>"
 
 
 
 
316
  ]
317
+ },
318
+ "metadata": {},
319
+ "output_type": "display_data"
320
  },
321
  {
322
+ "data": {
323
+ "application/vnd.jupyter.widget-view+json": {
324
+ "model_id": "c2cd7f6fb5b1428b98b80a3cc82ec303",
325
+ "version_major": 2,
326
+ "version_minor": 0
327
  },
328
+ "text/plain": [
329
+ "VBox(children=(Label(value='Waiting for wandb.init()...\\r'), FloatProgress(value=0.011288888888884685, max=1.0…"
330
  ]
331
+ },
332
+ "metadata": {},
333
+ "output_type": "display_data"
334
  },
335
  {
336
+ "data": {
337
+ "text/html": [
338
+ "Tracking run with wandb version 0.16.2"
339
+ ],
340
+ "text/plain": [
341
+ "<IPython.core.display.HTML object>"
 
 
 
 
 
 
 
342
  ]
343
+ },
344
+ "metadata": {},
345
+ "output_type": "display_data"
346
  },
347
  {
348
+ "data": {
349
+ "text/html": [
350
+ "Run data is saved locally in <code>d:\\REPO\\ffxiv-ja-ko-translator\\wandb\\run-20240131_135421-etxsdxw2</code>"
351
+ ],
352
+ "text/plain": [
353
+ "<IPython.core.display.HTML object>"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
354
  ]
355
+ },
356
+ "metadata": {},
357
+ "output_type": "display_data"
358
  },
359
  {
360
+ "data": {
361
+ "text/html": [
362
+ "Syncing run <strong><a href='https://wandb.ai/sappho192/fftr-poc1/runs/etxsdxw2' target=\"_blank\">jbert+kogpt2</a></strong> to <a href='https://wandb.ai/sappho192/fftr-poc1' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/run' target=\"_blank\">docs</a>)<br/>"
363
+ ],
364
+ "text/plain": [
365
+ "<IPython.core.display.HTML object>"
366
  ]
367
+ },
368
+ "metadata": {},
369
+ "output_type": "display_data"
370
  },
371
  {
372
+ "data": {
373
+ "text/html": [
374
+ " View project at <a href='https://wandb.ai/sappho192/fftr-poc1' target=\"_blank\">https://wandb.ai/sappho192/fftr-poc1</a>"
375
+ ],
376
+ "text/plain": [
377
+ "<IPython.core.display.HTML object>"
 
 
378
  ]
379
+ },
380
+ "metadata": {},
381
+ "output_type": "display_data"
382
  },
383
  {
384
+ "data": {
385
+ "text/html": [
386
+ " View run at <a href='https://wandb.ai/sappho192/fftr-poc1/runs/etxsdxw2' target=\"_blank\">https://wandb.ai/sappho192/fftr-poc1/runs/etxsdxw2</a>"
387
+ ],
388
+ "text/plain": [
389
+ "<IPython.core.display.HTML object>"
 
 
 
 
 
 
390
  ]
391
+ },
392
+ "metadata": {},
393
+ "output_type": "display_data"
394
  }
395
+ ],
396
+ "source": [
397
+ "# for Trainer\n",
398
+ "import wandb\n",
399
+ "\n",
400
+ "collate_fn = DataCollatorForSeq2Seq(src_tokenizer, model)\n",
401
+ "wandb.init(project=\"fftr-poc1\", name='jbert+kogpt2')\n",
402
+ "\n",
403
+ "arguments = Seq2SeqTrainingArguments(\n",
404
+ " output_dir='dump',\n",
405
+ " do_train=True,\n",
406
+ " do_eval=True,\n",
407
+ " evaluation_strategy=\"epoch\",\n",
408
+ " save_strategy=\"epoch\",\n",
409
+ " num_train_epochs=3,\n",
410
+ " # num_train_epochs=25,\n",
411
+ " per_device_train_batch_size=1,\n",
412
+ " # per_device_train_batch_size=30, # takes 40GB\n",
413
+ " # per_device_train_batch_size=64,\n",
414
+ " per_device_eval_batch_size=1,\n",
415
+ " # per_device_eval_batch_size=30,\n",
416
+ " # per_device_eval_batch_size=64,\n",
417
+ " warmup_ratio=0.1,\n",
418
+ " gradient_accumulation_steps=4,\n",
419
+ " save_total_limit=5,\n",
420
+ " dataloader_num_workers=1,\n",
421
+ " # fp16=True, # ENABLE if CUDA is enabled\n",
422
+ " load_best_model_at_end=True,\n",
423
+ " report_to='wandb'\n",
424
+ ")\n",
425
+ "\n",
426
+ "trainer = Trainer(\n",
427
+ " model,\n",
428
+ " arguments,\n",
429
+ " data_collator=collate_fn,\n",
430
+ " train_dataset=train_dataset,\n",
431
+ " eval_dataset=eval_dataset\n",
432
+ ")"
433
+ ]
434
+ },
435
+ {
436
+ "cell_type": "markdown",
437
+ "metadata": {
438
+ "id": "pPsjDHO5Vc3y"
439
+ },
440
+ "source": [
441
+ "## Training"
442
+ ]
443
+ },
444
+ {
445
+ "cell_type": "code",
446
+ "execution_count": null,
447
+ "metadata": {
448
+ "id": "_T4P4XunmK-C"
449
+ },
450
+ "outputs": [],
451
+ "source": [
452
+ "# model = EncoderDecoderModel.from_encoder_decoder_pretrained(\"xlm-roberta-base\", \"skt/kogpt2-base-v2\")"
453
+ ]
454
+ },
455
+ {
456
+ "cell_type": "code",
457
+ "execution_count": 12,
458
+ "metadata": {
459
+ "id": "7vTqAgW6Ve3J"
460
+ },
461
+ "outputs": [
462
+ {
463
+ "data": {
464
+ "application/vnd.jupyter.widget-view+json": {
465
+ "model_id": "0afe460e9f614d9a90379cf99fcf8af3",
466
+ "version_major": 2,
467
+ "version_minor": 0
468
  },
469
+ "text/plain": [
470
+ " 0%| | 0/9671328 [00:00<?, ?it/s]"
471
+ ]
472
+ },
473
+ "metadata": {},
474
+ "output_type": "display_data"
475
  }
476
+ ],
477
+ "source": [
478
+ "trainer.train()\n",
479
+ "\n",
480
+ "model.save_pretrained(\"dump/best_model\")\n",
481
+ "src_tokenizer.save_pretrained(\"dump/best_model/src_tokenizer\")\n",
482
+ "trg_tokenizer.save_pretrained(\"dump/best_model/trg_tokenizer\")"
483
+ ]
484
+ },
485
+ {
486
+ "cell_type": "code",
487
+ "execution_count": 2,
488
+ "metadata": {},
489
+ "outputs": [],
490
+ "source": [
491
+ "# import wandb\n",
492
+ "# wandb.finish()"
493
+ ]
494
+ }
495
+ ],
496
+ "metadata": {
497
+ "colab": {
498
+ "machine_shape": "hm",
499
+ "provenance": []
500
+ },
501
+ "gpuClass": "premium",
502
+ "kernelspec": {
503
+ "display_name": "Python 3 (ipykernel)",
504
+ "language": "python",
505
+ "name": "python3"
506
  },
507
+ "language_info": {
508
+ "codemirror_mode": {
509
+ "name": "ipython",
510
+ "version": 3
511
+ },
512
+ "file_extension": ".py",
513
+ "mimetype": "text/x-python",
514
+ "name": "python",
515
+ "nbconvert_exporter": "python",
516
+ "pygments_lexer": "ipython3",
517
+ "version": "3.10.13"
518
+ }
519
+ },
520
+ "nbformat": 4,
521
+ "nbformat_minor": 0
522
  }