Cielciel/aift-model-review-multiple-label-classification
- .ipynb_checkpoints/Aift-review-multiple-label-classification-workflow-checkpoint.ipynb +6 -0
- Aift-review-multiple-label-classification-workflow.ipynb +1613 -0
- README.md +68 -0
- config.json +42 -0
- custom_container/Dockerfile +21 -0
- custom_container/README.md +66 -0
- custom_container/scripts/train-cloud.sh +80 -0
- model.safetensors +3 -0
- python_package/README.md +58 -0
- python_package/dist/trainer-0.1.tar.gz +3 -0
- python_package/scripts/train-cloud.sh +70 -0
- python_package/setup.py +24 -0
- python_package/trainer.egg-info/PKG-INFO +8 -0
- python_package/trainer.egg-info/SOURCES.txt +13 -0
- python_package/trainer.egg-info/dependency_links.txt +1 -0
- python_package/trainer.egg-info/requires.txt +4 -0
- python_package/trainer.egg-info/top_level.txt +1 -0
- python_package/trainer/__init__.py +0 -0
- python_package/trainer/experiment.py +137 -0
- python_package/trainer/metadata.py +31 -0
- python_package/trainer/model.py +31 -0
- python_package/trainer/task.py +104 -0
- python_package/trainer/utils.py +99 -0
- runs/Jan08_04-05-34_aift-review-classification-multiple-label/events.out.tfevents.1704686768.aift-review-classification-multiple-label +3 -0
- runs/Jan08_04-07-17_aift-review-classification-multiple-label/events.out.tfevents.1704686842.aift-review-classification-multiple-label +3 -0
- runs/Jan08_04-07-17_aift-review-classification-multiple-label/events.out.tfevents.1704687081.aift-review-classification-multiple-label +3 -0
- special_tokens_map.json +7 -0
- tokenizer.json +0 -0
- tokenizer_config.json +55 -0
- training_args.bin +3 -0
- vocab.txt +0 -0
.ipynb_checkpoints/Aift-review-multiple-label-classification-workflow-checkpoint.ipynb
ADDED
@@ -0,0 +1,6 @@
{
 "cells": [],
 "metadata": {},
 "nbformat": 4,
 "nbformat_minor": 5
}
Aift-review-multiple-label-classification-workflow.ipynb
ADDED
@@ -0,0 +1,1613 @@
# Setup

## Install additional packages

In [1]:
import os

# The Google Cloud Notebook product has specific requirements
IS_GOOGLE_CLOUD_NOTEBOOK = os.path.exists("/opt/deeplearning/metadata/env_version")

# Google Cloud Notebook requires dependencies to be installed with '--user'
USER_FLAG = ""
if IS_GOOGLE_CLOUD_NOTEBOOK:
    USER_FLAG = "--user"

In [2]:
%%capture
!pip -q install {USER_FLAG} --upgrade transformers
!pip -q install {USER_FLAG} --upgrade datasets
!pip -q install {USER_FLAG} --upgrade tqdm
!pip -q install {USER_FLAG} --upgrade cloudml-hypertune

In [3]:
%%capture
!pip -q install {USER_FLAG} --upgrade google-cloud-aiplatform

In [4]:
# Automatically restart kernel after installs
import os

if not os.getenv("IS_TESTING"):
    # Automatically restart kernel after installs
    import IPython

    app = IPython.Application.instance()
    app.kernel.do_shutdown(True)

In [1]:
%%capture
!pip install git+https://github.com/huggingface/transformers.git datasets pandas torch
!pip install transformers[torch]
!pip install accelerate -U
## Set Project ID

In [2]:
PROJECT_ID = "iKame-gem-ai-research" # <---CHANGE THIS TO YOUR PROJECT

import os

# Get your Google Cloud project ID using google.auth
if not os.getenv("IS_TESTING"):
    import google.auth

    _, PROJECT_ID = google.auth.default()
    print("Project ID: ", PROJECT_ID)

# validate PROJECT_ID
if PROJECT_ID == "" or PROJECT_ID is None or PROJECT_ID == "iKame-gem-ai-research":
    print(
        f"Please set your project id before proceeding to next step. Currently it's set as {PROJECT_ID}"
    )

Output:
Project ID: ikame-gem-ai-research

In [3]:
from datetime import datetime


def get_timestamp():
    return datetime.now().strftime("%Y%m%d%H%M%S")


TIMESTAMP = get_timestamp()
print(f"TIMESTAMP = {TIMESTAMP}")

Output:
TIMESTAMP = 20240108040502
## Create Cloud Storage bucket

In [4]:
BUCKET_NAME = "gs://iKame-gem-ai-research" # <---CHANGE THIS TO YOUR BUCKET
REGION = "us-central1" # @param {type:"string"}

In [5]:
if BUCKET_NAME == "" or BUCKET_NAME is None or BUCKET_NAME == "gs://iKame-gem-ai-research":
    BUCKET_NAME = f"gs://{PROJECT_ID}-bucket-review"

In [6]:
print(f"PROJECT_ID = {PROJECT_ID}")
print(f"BUCKET_NAME = {BUCKET_NAME}")
print(f"REGION = {REGION}")

Output:
PROJECT_ID = ikame-gem-ai-research
BUCKET_NAME = gs://ikame-gem-ai-research-bucket-review
REGION = us-central1

In [7]:
# ! gsutil mb -l $REGION $BUCKET_NAME

In [8]:
! gsutil ls -al $BUCKET_NAME # validate access to your Cloud Storage bucket

Output:
      3078  2024-01-05T01:42:25Z  gs://ikame-gem-ai-research-bucket-review/batch_examples.csv#1704418945853255  metageneration=1
            gs://ikame-gem-ai-research-bucket-review/pipeline_root/
TOTAL: 1 objects, 3078 bytes (3.01 KiB)
## Install libraries

In [9]:
import base64
import json
import os
import random
import sys

import google.auth
from google.cloud import aiplatform
from google.cloud.aiplatform import gapic as aip
from google.cloud.aiplatform import hyperparameter_tuning as hpt
from google.protobuf.json_format import MessageToDict

In [10]:
from IPython.display import HTML, display

In [11]:
import datasets
from datasets import Dataset, DatasetDict
import numpy as np
import pandas as pd
import torch
import transformers
from datasets import ClassLabel, Sequence, load_dataset
from transformers import (AutoModelForSequenceClassification, AutoTokenizer, BertForSequenceClassification,
                          EvalPrediction, Trainer, TrainingArguments, PreTrainedModel, BertModel,
                          default_data_collator)

In [12]:
from google.cloud import bigquery
from google.cloud import storage

client = bigquery.Client()
storage_client = storage.Client()

In [13]:
print(f"Notebook runtime: {'GPU' if torch.cuda.is_available() else 'CPU'}")
print(f"PyTorch version : {torch.__version__}")
print(f"Transformers version : {transformers.__version__}")
print(f"Datasets version : {datasets.__version__}")

Output:
Notebook runtime: GPU
PyTorch version : 2.0.0+cu118
Transformers version : 4.37.0.dev0
Datasets version : 2.16.1

In [15]:
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [14]:
APP_NAME = "aift-review-classificatio-multiple-label"

In [ ]:
!cd aift-model-review-multiple-label-classification
# Training

## Preprocess data

In [16]:
model_ckpt = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_ckpt)

def tokenize_and_encode(examples):
    return tokenizer(examples["review"], truncation=True)

In [17]:
sql = f"""
SELECT * FROM `ikame-gem-ai-research.AIFT.reviews_multi_label_training`
"""
data = client.query(sql).to_dataframe()
data = data.fillna('0')
for i in data.columns:
    if i != 'review':
        data[i] = data[i].astype(int)

data = Dataset.from_pandas(data).train_test_split(test_size=0.2, shuffle=True, seed=0)
cols = data["train"].column_names
data = data.map(lambda x: {"labels": [x[c] for c in cols if c != "review"]})

# Tokenize and encode
dataset = data.map(tokenize_and_encode, batched=True, remove_columns=cols)

Output (progress widgets):
Map:   0%|          | 0/556 [00:00<?, ? examples/s]
Map:   0%|          | 0/140 [00:00<?, ? examples/s]
Map:   0%|          | 0/556 [00:00<?, ? examples/s]
Map:   0%|          | 0/140 [00:00<?, ? examples/s]

In [18]:
labels = [label for label in data['train'].features.keys() if label not in ['review', 'labels']]
id2label = {idx: label for idx, label in enumerate(labels)}
label2id = {label: idx for idx, label in enumerate(labels)}
labels

Output:
['ads', 'bugs', 'positive', 'negative', 'graphic', 'gameplay', 'request']

In [19]:
dataset["train"][0]

Output:
{'labels': [0, 1, 0, 0, 0, 1, 0],
 'input_ids': [101, 8795, 11100, 2024, 10599, 2030, 11829, 5999, 1010, 2437, 14967, 25198, 1012, 102],
 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}
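Note on the preprocessing above: each row carries a multi-hot `labels` vector whose positions follow the order of the label columns. A minimal, self-contained sketch (toy review text, not the BigQuery data) of how such a vector decodes back to label names:

```python
# Toy illustration of the multi-hot encoding produced above (made-up review text).
labels = ['ads', 'bugs', 'positive', 'negative', 'graphic', 'gameplay', 'request']
id2label = {idx: label for idx, label in enumerate(labels)}

row = {"review": "keeps crashing, which ruins the gameplay",  # placeholder text
       "labels": [0, 1, 0, 0, 0, 1, 0]}                       # same shape as dataset["train"][0]

# positions set to 1 map back to label names through id2label
active = [id2label[i] for i, flag in enumerate(row["labels"]) if flag == 1]
print(active)  # ['bugs', 'gameplay']
```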
## Fine-tuning

In [20]:
from transformers.modeling_outputs import SequenceClassifierOutput  # import needed by the return_dict branch below

class BertForMultilabelSequenceClassification(BertForSequenceClassification):
    def __init__(self, config):
        super().__init__(config)

    def forward(self,
                input_ids=None,
                attention_mask=None,
                token_type_ids=None,
                position_ids=None,
                head_mask=None,
                inputs_embeds=None,
                labels=None,
                output_attentions=None,
                output_hidden_states=None,
                return_dict=None):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.bert(input_ids,
                            attention_mask=attention_mask,
                            token_type_ids=token_type_ids,
                            position_ids=position_ids,
                            head_mask=head_mask,
                            inputs_embeds=inputs_embeds,
                            output_attentions=output_attentions,
                            output_hidden_states=output_hidden_states,
                            return_dict=return_dict)

        pooled_output = outputs[1]
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            loss_fct = torch.nn.BCEWithLogitsLoss()
            loss = loss_fct(logits.view(-1, self.num_labels),
                            labels.float().view(-1, self.num_labels))

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(loss=loss,
                                        logits=logits,
                                        hidden_states=outputs.hidden_states,
                                        attentions=outputs.attentions)

In [21]:
num_labels = 7
model = BertForMultilabelSequenceClassification.from_pretrained(model_ckpt, num_labels=num_labels).to('cuda')

Output (warnings):
You are using a model of type distilbert to instantiate a model of type bert. This is not supported for all configurations of models and can yield errors.
Some weights of BertForMultilabelSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['encoder.layer.11.attention.self.key.bias', 'encoder.layer.6.attention.output.LayerNorm.bias', ... (long list of embedding, encoder, pooler and classifier weights omitted)]
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
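The subclass above only swaps the loss: `BCEWithLogitsLoss` treats the 7 labels as independent sigmoid outputs instead of one softmax. Note that a DistilBERT checkpoint is being loaded into a BERT-shaped class, which is why the warning reports a long list of freshly initialized encoder weights. For comparison, a minimal sketch (assuming a reasonably recent transformers release) that gets the same multi-label loss without a custom class, and without the architecture mismatch, by setting `problem_type`:

```python
# Sketch only: transformers' built-in multi-label support
# (problem_type="multi_label_classification" makes the model use BCEWithLogitsLoss).
from transformers import AutoModelForSequenceClassification

model_ckpt = "distilbert-base-uncased"
labels = ['ads', 'bugs', 'positive', 'negative', 'graphic', 'gameplay', 'request']

model = AutoModelForSequenceClassification.from_pretrained(
    model_ckpt,
    num_labels=len(labels),
    problem_type="multi_label_classification",
    id2label={i: l for i, l in enumerate(labels)},
    label2id={l: i for i, l in enumerate(labels)},
)
```

Because this keeps the DistilBERT architecture, only the classification head is newly initialized.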
In [22]:
def accuracy_thresh(y_pred, y_true, thresh=0.5, sigmoid=True):
    y_pred = torch.from_numpy(y_pred)
    y_true = torch.from_numpy(y_true)
    if sigmoid:
        y_pred = y_pred.sigmoid()
    return ((y_pred > thresh) == y_true.bool()).float().mean().item()

In [23]:
def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    return {'accuracy_thresh': accuracy_thresh(predictions, labels)}

In [24]:
class MultilabelTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        logits = outputs.logits
        loss_fct = torch.nn.BCEWithLogitsLoss()
        loss = loss_fct(logits.view(-1, self.model.config.num_labels),
                        labels.float().view(-1, self.model.config.num_labels))
        return (loss, outputs) if return_outputs else loss
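To make the metric concrete: `accuracy_thresh` sigmoids the logits, thresholds them, and reports the element-wise agreement across all label slots (not exact-match per review). A small self-contained check with toy logits, reusing the function exactly as defined above:

```python
import numpy as np
import torch

def accuracy_thresh(y_pred, y_true, thresh=0.5, sigmoid=True):
    # copied from the cell above
    y_pred = torch.from_numpy(y_pred)
    y_true = torch.from_numpy(y_true)
    if sigmoid:
        y_pred = y_pred.sigmoid()
    return ((y_pred > thresh) == y_true.bool()).float().mean().item()

logits = np.array([[ 2.0, -3.0,  0.1],    # sigmoid ~ [0.88, 0.05, 0.52]
                   [-1.0,  1.5, -2.0]])   # sigmoid ~ [0.27, 0.82, 0.12]
truth  = np.array([[1.0, 0.0, 0.0],
                   [0.0, 1.0, 0.0]])
print(accuracy_thresh(logits, truth))     # 5 of 6 slots agree -> ~0.833
```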
In [32]:
batch_size = 8

args = TrainingArguments(
    output_dir="aift-model-review-multiple-label-classification",
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=10,
    weight_decay=0.01,
    use_cpu=False
)

In [33]:
model = AutoModelForSequenceClassification.from_pretrained(model_ckpt, num_labels=num_labels).to('cuda')

Output (warnings):
Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.

In [34]:
trainer = MultilabelTrainer(
    model,
    args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    compute_metrics=compute_metrics,
    tokenizer=tokenizer)
In [35]:
trainer.evaluate()

Output:
[18/18 00:06]
{'eval_loss': 0.7062913179397583,
 'eval_accuracy_thresh': 0.4561224579811096,
 'eval_runtime': 0.2818,
 'eval_samples_per_second': 496.847,
 'eval_steps_per_second': 63.88}

In [36]:
trainer.train()

Output:
[700/700 00:47, Epoch 10/10]

Epoch | Training Loss | Validation Loss | Accuracy Thresh
    1 | No log        | 0.415191        | 0.868367
    2 | No log        | 0.302631        | 0.901020
    3 | No log        | 0.240627        | 0.928571
    4 | No log        | 0.217601        | 0.931633
    5 | No log        | 0.203845        | 0.924490
    6 | No log        | 0.192444        | 0.929592
    7 | No log        | 0.190031        | 0.926531
    8 | 0.265200      | 0.186760        | 0.928571
    9 | 0.265200      | 0.180436        | 0.936735
   10 | 0.265200      | 0.179821        | 0.934694

Checkpoint destination directory aift-model-review-multiple-label-classification/checkpoint-500 already exists and is non-empty.Saving will proceed but saved results may be invalid.

TrainOutput(global_step=700, training_loss=0.22303315843854632, metrics={'train_runtime': 47.1667, 'train_samples_per_second': 117.88, 'train_steps_per_second': 14.841, 'total_flos': 55632988457664.0, 'train_loss': 0.22303315843854632, 'epoch': 10.0})
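The per-epoch numbers in the table above are also kept on the trainer itself; a short sketch (assuming the `trainer` object from the cells above) that pulls the evaluation history back out of `trainer.state.log_history`:

```python
# Assumes `trainer` from the cells above; log_history holds one dict per logged step/epoch.
eval_rows = [e for e in trainer.state.log_history if "eval_loss" in e]
for row in eval_rows:
    print(f"epoch {row['epoch']}: "
          f"loss={row['eval_loss']:.4f}  "
          f"accuracy_thresh={row['eval_accuracy_thresh']:.4f}")
```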
In [104]:
saved_model_local_path = "./models"
# !mkdir ./aift-model-review-multiple-label-classification/models

Output:
mkdir: cannot create directory ‘./models’: File exists

In [39]:
trainer.save_model(saved_model_local_path)

In [69]:
history = trainer.evaluate()

In [70]:
history

Output:
{'eval_loss': 0.1798214465379715,
 'eval_accuracy_thresh': 0.9346938729286194,
 'eval_runtime': 0.2965,
 'eval_samples_per_second': 472.249,
 'eval_steps_per_second': 60.718,
 'epoch': 10.0}

In [110]:
from huggingface_hub import notebook_login

notebook_login()

Output:
(Hugging Face Hub login widget)
## Predict

In [41]:
def predict(text, threshold):
    encoding = tokenizer(text, return_tensors="pt")
    encoding = {k: v.to(trainer.model.device) for k, v in encoding.items()}

    outputs = trainer.model(**encoding)
    logits = outputs.logits
    # apply sigmoid + threshold
    sigmoid = torch.nn.Sigmoid()
    probs = sigmoid(logits.squeeze().cpu())
    predictions = np.zeros(probs.shape)
    print(predictions)
    print(probs)
    predictions[np.where(probs >= threshold)] = 1
    # turn predicted id's into actual label names
    predicted_labels = [id2label[idx] for idx, label in enumerate(predictions) if label == 1.0]
    print(predicted_labels)

In [57]:
text = "a lot of ads"
predict(text, 0.4)

Output:
[0. 0. 0. 0. 0. 0. 0.]
tensor([0.9740, 0.0251, 0.1409, 0.7609, 0.0359, 0.0374, 0.0321],
       grad_fn=<SigmoidBackward0>)
['ads', 'negative']

In [60]:
label_text = id2label
model_name_or_path = model_ckpt
saved_model_path = saved_model_local_path


def predict_(input_text, saved_model_path, threshold):
    # initialize tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True)

    # preprocess and encode input text
    tokenizer_args = (input_text,)
    predict_input = tokenizer(
        *tokenizer_args,
        padding="max_length",
        max_length=128,
        truncation=True,
        return_tensors="pt",
    )

    # load trained model
    loaded_model = AutoModelForSequenceClassification.from_pretrained(saved_model_path)

    # get predictions
    output = loaded_model(predict_input["input_ids"])

    # apply sigmoid + threshold to the logits
    logits = output.logits
    sigmoid = torch.nn.Sigmoid()
    probs = sigmoid(logits.squeeze().cpu())
    predictions = np.zeros(probs.shape)
    print(predictions)
    print(probs)
    predictions[np.where(probs >= threshold)] = 1
    # turn predicted id's into actual label names
    predicted_labels = [id2label[idx] for idx, label in enumerate(predictions) if label == 1.0]
    print(predicted_labels)

In [62]:
text = 'ew a lot of ads'
predict_(text, saved_model_path, 0.4)

Output:
[0. 0. 0. 0. 0. 0. 0.]
tensor([0.5107, 0.1010, 0.5961, 0.2481, 0.2118, 0.1907, 0.1010],
       grad_fn=<SigmoidBackward0>)
['ads', 'positive']
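Both predict helpers above score one string at a time and print intermediate arrays. A minimal batched variant, sketched under the assumption that the `trainer`, `tokenizer` and `id2label` objects defined earlier are in scope, that returns label lists instead of printing:

```python
import torch

def predict_batch(texts, threshold=0.4):
    # assumes trainer, tokenizer and id2label from the cells above
    enc = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
    enc = {k: v.to(trainer.model.device) for k, v in enc.items()}
    with torch.no_grad():
        logits = trainer.model(**enc).logits
    probs = torch.sigmoid(logits).cpu().numpy()
    return [[id2label[i] for i, p in enumerate(row) if p >= threshold] for row in probs]

# e.g. predict_batch(["a lot of ads", "great graphics but it keeps crashing"])
```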
# Custom training

In [99]:
PRE_BUILT_TRAINING_CONTAINER_IMAGE_URI = (
    "us-docker.pkg.dev/vertex-ai/training/pytorch-gpu.1-7:latest"
)

PYTHON_PACKAGE_APPLICATION_DIR = "python_package"

source_package_file_name = f"pipeline/aift-model-review-multiple-label-classification/{PYTHON_PACKAGE_APPLICATION_DIR}/dist/trainer-0.1.tar.gz"
python_package_gcs_uri = (
    f"{BUCKET_NAME}/pytorch-on-gcp/{APP_NAME}/train/python_package/trainer-0.1.tar.gz"
)
python_module_name = "trainer.task"

In [100]:
# !mkdir ./python_package

In [108]:
%%writefile ./aift-model-review-multiple-label-classification/{PYTHON_PACKAGE_APPLICATION_DIR}/setup.py

from setuptools import find_packages
from setuptools import setup
import setuptools

from distutils.command.build import build as _build
import subprocess


REQUIRED_PACKAGES = [
    'transformers',
    'datasets',
    'tqdm',
    'cloudml-hypertune'
]

setup(
    name='trainer',
    version='0.1',
    install_requires=REQUIRED_PACKAGES,
    packages=find_packages(),
    include_package_data=True,
    description='Vertex AI | Training | PyTorch | Text Classification | Python Package'
)

Output:
Overwriting ./aift-model-review-multiple-label-classification/python_package/setup.py

In [109]:
!cd aift-model-review-multiple-label-classification/{PYTHON_PACKAGE_APPLICATION_DIR} && python3 setup.py sdist --formats=gztar

Output:
running sdist
running egg_info
creating trainer.egg-info
writing trainer.egg-info/PKG-INFO
writing dependency_links to trainer.egg-info/dependency_links.txt
writing requirements to trainer.egg-info/requires.txt
writing top-level names to trainer.egg-info/top_level.txt
writing manifest file 'trainer.egg-info/SOURCES.txt'
reading manifest file 'trainer.egg-info/SOURCES.txt'
writing manifest file 'trainer.egg-info/SOURCES.txt'
running check
creating trainer-0.1
creating trainer-0.1/trainer
creating trainer-0.1/trainer.egg-info
copying files to trainer-0.1...
copying README.md -> trainer-0.1
copying setup.py -> trainer-0.1
copying trainer/__init__.py -> trainer-0.1/trainer
copying trainer/experiment.py -> trainer-0.1/trainer
copying trainer/metadata.py -> trainer-0.1/trainer
copying trainer/model.py -> trainer-0.1/trainer
copying trainer/task.py -> trainer-0.1/trainer
copying trainer/utils.py -> trainer-0.1/trainer
copying trainer.egg-info/PKG-INFO -> trainer-0.1/trainer.egg-info
copying trainer.egg-info/SOURCES.txt -> trainer-0.1/trainer.egg-info
copying trainer.egg-info/dependency_links.txt -> trainer-0.1/trainer.egg-info
copying trainer.egg-info/requires.txt -> trainer-0.1/trainer.egg-info
copying trainer.egg-info/top_level.txt -> trainer-0.1/trainer.egg-info
Writing trainer-0.1/setup.cfg
creating dist
Creating tar archive
removing 'trainer-0.1' (and everything under it)

In [82]:
!gsutil cp {source_package_file_name} {python_package_gcs_uri}

Output:
Copying file://python_package/dist/trainer-0.1.tar.gz [Content-Type=application/x-tar]...
/ [1 files][  916.0 B/  916.0 B]
Operation completed over 1 objects/916.0 B.
|
1314 |
+
]
|
1315 |
+
},
|
1316 |
+
{
|
1317 |
+
"cell_type": "code",
|
1318 |
+
"execution_count": 83,
|
1319 |
+
"id": "087fcdaa-0d99-4104-8e61-74455d4bf734",
|
1320 |
+
"metadata": {
|
1321 |
+
"tags": []
|
1322 |
+
},
|
1323 |
+
"outputs": [
|
1324 |
+
{
|
1325 |
+
"name": "stdout",
|
1326 |
+
"output_type": "stream",
|
1327 |
+
"text": [
|
1328 |
+
" 916 2024-01-08T07:48:19Z gs://ikame-gem-ai-research-bucket-review/pytorch-on-gcp/aift-review-classificatio-multiple-label/train/python_package/trainer-0.1.tar.gz\n",
|
1329 |
+
"TOTAL: 1 objects, 916 bytes (916 B)\n"
|
1330 |
+
]
|
1331 |
+
}
|
1332 |
+
],
|
1333 |
+
"source": [
|
1334 |
+
"!gsutil ls -l {python_package_gcs_uri}"
|
1335 |
+
]
|
1336 |
+
},
|
1337 |
+
{
|
1338 |
+
"cell_type": "code",
|
1339 |
+
"execution_count": 85,
|
1340 |
+
"id": "4dce414a-063a-4952-8197-75586909e098",
|
1341 |
+
"metadata": {
|
1342 |
+
"tags": []
|
1343 |
+
},
|
1344 |
+
"outputs": [],
|
1345 |
+
"source": [
|
1346 |
+
"# !cd {PYTHON_PACKAGE_APPLICATION_DIR} && python -m trainer.task"
|
1347 |
+
]
|
1348 |
+
},
|
1349 |
+
{
|
1350 |
+
"cell_type": "code",
|
1351 |
+
"execution_count": 86,
|
1352 |
+
"id": "a7698349-f5f4-4032-a9b2-1fc659f4b022",
|
1353 |
+
"metadata": {
|
1354 |
+
"tags": []
|
1355 |
+
},
|
1356 |
+
"outputs": [],
|
1357 |
+
"source": [
|
1358 |
+
"aiplatform.init(project=PROJECT_ID, staging_bucket=BUCKET_NAME)"
|
1359 |
+
]
|
1360 |
+
},
|
1361 |
+
{
|
1362 |
+
"cell_type": "code",
|
1363 |
+
"execution_count": 87,
|
1364 |
+
"id": "112e1b67-5bb0-444a-94c6-a2f010e24fe9",
|
1365 |
+
"metadata": {
|
1366 |
+
"tags": []
|
1367 |
+
},
|
1368 |
+
"outputs": [
|
1369 |
+
{
|
1370 |
+
"name": "stdout",
|
1371 |
+
"output_type": "stream",
|
1372 |
+
"text": [
|
1373 |
+
"APP_NAME=aift-review-classificatio-multiple-label\n",
|
1374 |
+
"PRE_BUILT_TRAINING_CONTAINER_IMAGE_URI=us-docker.pkg.dev/vertex-ai/training/pytorch-gpu.1-7:latest\n",
|
1375 |
+
"python_package_gcs_uri=gs://ikame-gem-ai-research-bucket-review/pytorch-on-gcp/aift-review-classificatio-multiple-label/train/python_package/trainer-0.1.tar.gz\n",
|
1376 |
+
"python_module_name=trainer.task\n"
|
1377 |
+
]
|
1378 |
+
}
|
1379 |
+
],
|
1380 |
+
"source": [
|
1381 |
+
"print(f\"APP_NAME={APP_NAME}\")\n",
|
1382 |
+
"print(\n",
|
1383 |
+
" f\"PRE_BUILT_TRAINING_CONTAINER_IMAGE_URI={PRE_BUILT_TRAINING_CONTAINER_IMAGE_URI}\"\n",
|
1384 |
+
")\n",
|
1385 |
+
"print(f\"python_package_gcs_uri={python_package_gcs_uri}\")\n",
|
1386 |
+
"print(f\"python_module_name={python_module_name}\")"
|
1387 |
+
]
|
1388 |
+
},
|
1389 |
+
{
|
1390 |
+
"cell_type": "code",
|
1391 |
+
"execution_count": 88,
|
1392 |
+
"id": "c0fa20a0-0831-49ab-9fce-423016e98db6",
|
1393 |
+
"metadata": {
|
1394 |
+
"tags": []
|
1395 |
+
},
|
1396 |
+
"outputs": [
|
1397 |
+
{
|
1398 |
+
"name": "stdout",
|
1399 |
+
"output_type": "stream",
|
1400 |
+
"text": [
|
1401 |
+
"JOB_NAME=aift-review-classificatio-multiple-label-pytorch-pkg-ar-20240108075109\n"
|
1402 |
+
]
|
1403 |
+
}
|
1404 |
+
],
|
1405 |
+
"source": [
|
1406 |
+
"JOB_NAME = f\"{APP_NAME}-pytorch-pkg-ar-{get_timestamp()}\"\n",
|
1407 |
+
"print(f\"JOB_NAME={JOB_NAME}\")"
|
1408 |
+
]
|
1409 |
+
},
|
1410 |
+
{
|
1411 |
+
"cell_type": "code",
|
1412 |
+
"execution_count": 89,
|
1413 |
+
"id": "86922169-8509-48ff-acc9-c06bc9a4ecd1",
|
1414 |
+
"metadata": {
|
1415 |
+
"tags": []
|
1416 |
+
},
|
1417 |
+
"outputs": [],
|
1418 |
+
"source": [
|
1419 |
+
"job = aiplatform.CustomPythonPackageTrainingJob(\n",
|
1420 |
+
" display_name=f\"{JOB_NAME}\",\n",
|
1421 |
+
" python_package_gcs_uri=python_package_gcs_uri,\n",
|
1422 |
+
" python_module_name=python_module_name,\n",
|
1423 |
+
" container_uri=PRE_BUILT_TRAINING_CONTAINER_IMAGE_URI,\n",
|
1424 |
+
")"
|
1425 |
+
]
|
1426 |
+
},
|
1427 |
+
{
|
1428 |
+
"cell_type": "code",
|
1429 |
+
"execution_count": 90,
|
1430 |
+
"id": "a7909b64-fedb-4da8-bc61-80b4806117d3",
|
1431 |
+
"metadata": {
|
1432 |
+
"tags": []
|
1433 |
+
},
|
1434 |
+
"outputs": [
|
1435 |
+
{
|
1436 |
+
"name": "stdout",
|
1437 |
+
"output_type": "stream",
|
1438 |
+
"text": [
|
1439 |
+
"Training Output directory:\n",
|
1440 |
+
"gs://ikame-gem-ai-research-bucket-review/aiplatform-custom-training-2024-01-08-07:51:20.301 \n"
|
1441 |
+
]
|
1442 |
+
},
|
1443 |
+
{
|
1444 |
+
"name": "stderr",
|
1445 |
+
"output_type": "stream",
|
1446 |
+
"text": [
|
1447 |
+
"INFO:google.cloud.aiplatform.training_jobs:Training Output directory:\n",
|
1448 |
+
"gs://ikame-gem-ai-research-bucket-review/aiplatform-custom-training-2024-01-08-07:51:20.301 \n"
|
1449 |
+
]
|
1450 |
+
},
|
1451 |
+
{
|
1452 |
+
"name": "stdout",
|
1453 |
+
"output_type": "stream",
|
1454 |
+
"text": [
|
1455 |
+
"View Training:\n",
|
1456 |
+
"https://console.cloud.google.com/ai/platform/locations/us-central1/training/2282426366479564800?project=763889829809\n"
|
1457 |
+
]
|
1458 |
+
},
|
1459 |
+
{
|
1460 |
+
"name": "stderr",
|
1461 |
+
"output_type": "stream",
|
1462 |
+
"text": [
|
1463 |
+
"INFO:google.cloud.aiplatform.training_jobs:View Training:\n",
|
1464 |
+
"https://console.cloud.google.com/ai/platform/locations/us-central1/training/2282426366479564800?project=763889829809\n"
|
1465 |
+
]
|
1466 |
+
},
|
1467 |
+
{
|
1468 |
+
"name": "stdout",
|
1469 |
+
"output_type": "stream",
|
1470 |
+
"text": [
|
1471 |
+
"CustomPythonPackageTrainingJob projects/763889829809/locations/us-central1/trainingPipelines/2282426366479564800 current state:\n",
|
1472 |
+
"PipelineState.PIPELINE_STATE_RUNNING\n"
|
1473 |
+
]
|
1474 |
+
},
|
1475 |
+
{
|
1476 |
+
"name": "stderr",
|
1477 |
+
"output_type": "stream",
|
1478 |
+
"text": [
|
1479 |
+
"INFO:google.cloud.aiplatform.training_jobs:CustomPythonPackageTrainingJob projects/763889829809/locations/us-central1/trainingPipelines/2282426366479564800 current state:\n",
|
1480 |
+
"PipelineState.PIPELINE_STATE_RUNNING\n"
|
1481 |
+
]
|
1482 |
+
},
|
1483 |
+
{
|
1484 |
+
"name": "stdout",
|
1485 |
+
"output_type": "stream",
|
1486 |
+
"text": [
|
1487 |
+
"View backing custom job:\n",
|
1488 |
+
"https://console.cloud.google.com/ai/platform/locations/us-central1/training/7832101356516147200?project=763889829809\n"
|
1489 |
+
]
|
1490 |
+
},
|
1491 |
+
{
|
1492 |
+
"name": "stderr",
|
1493 |
+
"output_type": "stream",
|
1494 |
+
"text": [
|
1495 |
+
"INFO:google.cloud.aiplatform.training_jobs:View backing custom job:\n",
|
1496 |
+
"https://console.cloud.google.com/ai/platform/locations/us-central1/training/7832101356516147200?project=763889829809\n"
|
1497 |
+
]
|
1498 |
+
},
|
1499 |
+
{
|
1500 |
+
"name": "stdout",
|
1501 |
+
"output_type": "stream",
|
1502 |
+
"text": [
|
1503 |
+
"CustomPythonPackageTrainingJob projects/763889829809/locations/us-central1/trainingPipelines/2282426366479564800 current state:\n",
|
1504 |
+
"PipelineState.PIPELINE_STATE_RUNNING\n"
|
1505 |
+
]
|
1506 |
+
},
|
1507 |
+
{
|
1508 |
+
"name": "stderr",
|
1509 |
+
"output_type": "stream",
|
1510 |
+
"text": [
|
1511 |
+
"INFO:google.cloud.aiplatform.training_jobs:CustomPythonPackageTrainingJob projects/763889829809/locations/us-central1/trainingPipelines/2282426366479564800 current state:\n",
|
1512 |
+
"PipelineState.PIPELINE_STATE_RUNNING\n"
|
1513 |
+
]
|
1514 |
+
},
|
1515 |
+
{
|
1516 |
+
"name": "stdout",
|
1517 |
+
"output_type": "stream",
|
1518 |
+
"text": [
|
1519 |
+
"CustomPythonPackageTrainingJob projects/763889829809/locations/us-central1/trainingPipelines/2282426366479564800 current state:\n",
|
1520 |
+
"PipelineState.PIPELINE_STATE_RUNNING\n"
|
1521 |
+
]
|
1522 |
+
},
|
1523 |
+
{
|
1524 |
+
"name": "stderr",
|
1525 |
+
"output_type": "stream",
|
1526 |
+
"text": [
|
1527 |
+
"INFO:google.cloud.aiplatform.training_jobs:CustomPythonPackageTrainingJob projects/763889829809/locations/us-central1/trainingPipelines/2282426366479564800 current state:\n",
|
1528 |
+
"PipelineState.PIPELINE_STATE_RUNNING\n"
|
1529 |
+
]
|
1530 |
+
},
|
1531 |
+
{
|
1532 |
+
"name": "stdout",
|
1533 |
+
"output_type": "stream",
|
1534 |
+
"text": [
|
1535 |
+
"CustomPythonPackageTrainingJob projects/763889829809/locations/us-central1/trainingPipelines/2282426366479564800 current state:\n",
|
1536 |
+
"PipelineState.PIPELINE_STATE_RUNNING\n"
|
1537 |
+
]
|
1538 |
+
},
|
1539 |
+
{
|
1540 |
+
"name": "stderr",
|
1541 |
+
"output_type": "stream",
|
1542 |
+
"text": [
|
1543 |
+
"INFO:google.cloud.aiplatform.training_jobs:CustomPythonPackageTrainingJob projects/763889829809/locations/us-central1/trainingPipelines/2282426366479564800 current state:\n",
|
1544 |
+
"PipelineState.PIPELINE_STATE_RUNNING\n"
|
1545 |
+
]
|
1546 |
+
},
|
1547 |
+
{
|
1548 |
+
"name": "stdout",
|
1549 |
+
"output_type": "stream",
|
1550 |
+
"text": [
|
1551 |
+
"CustomPythonPackageTrainingJob projects/763889829809/locations/us-central1/trainingPipelines/2282426366479564800 current state:\n",
|
1552 |
+
"PipelineState.PIPELINE_STATE_RUNNING\n"
|
1553 |
+
]
|
1554 |
+
},
|
1555 |
+
{
|
1556 |
+
"name": "stderr",
|
1557 |
+
"output_type": "stream",
|
1558 |
+
"text": [
|
1559 |
+
"INFO:google.cloud.aiplatform.training_jobs:CustomPythonPackageTrainingJob projects/763889829809/locations/us-central1/trainingPipelines/2282426366479564800 current state:\n",
|
1560 |
+
"PipelineState.PIPELINE_STATE_RUNNING\n"
|
1561 |
+
]
|
1562 |
+
}
|
1563 |
+
],
|
1564 |
+
"source": [
|
1565 |
+
"training_args = [\"--num-epochs\", \"2\", \"--model-name\", \"finetuned-bert-classifier\"]\n",
|
1566 |
+
"\n",
|
1567 |
+
"model = job.run(\n",
|
1568 |
+
" replica_count=1,\n",
|
1569 |
+
" machine_type=\"n1-standard-8\",\n",
|
1570 |
+
" accelerator_type=\"NVIDIA_TESLA_V100\",\n",
|
1571 |
+
" accelerator_count=1,\n",
|
1572 |
+
" args=training_args,\n",
|
1573 |
+
" sync=False,\n",
|
1574 |
+
")"
|
1575 |
+
]
|
1576 |
+
},
|
1577 |
+
{
|
1578 |
+
"cell_type": "code",
|
1579 |
+
"execution_count": null,
|
1580 |
+
"id": "1e681913-680e-4664-9c6a-083f350915bc",
|
1581 |
+
"metadata": {},
|
1582 |
+
"outputs": [],
|
1583 |
+
"source": []
|
1584 |
+
}
|
1585 |
+
],
|
1586 |
+
"metadata": {
|
1587 |
+
"environment": {
|
1588 |
+
"kernel": "python3",
|
1589 |
+
"name": ".m114",
|
1590 |
+
"type": "gcloud",
|
1591 |
+
"uri": "gcr.io/deeplearning-platform-release/:m114"
|
1592 |
+
},
|
1593 |
+
"kernelspec": {
|
1594 |
+
"display_name": "Python 3 (ipykernel)",
|
1595 |
+
"language": "python",
|
1596 |
+
"name": "python3"
|
1597 |
+
},
|
1598 |
+
"language_info": {
|
1599 |
+
"codemirror_mode": {
|
1600 |
+
"name": "ipython",
|
1601 |
+
"version": 3
|
1602 |
+
},
|
1603 |
+
"file_extension": ".py",
|
1604 |
+
"mimetype": "text/x-python",
|
1605 |
+
"name": "python",
|
1606 |
+
"nbconvert_exporter": "python",
|
1607 |
+
"pygments_lexer": "ipython3",
|
1608 |
+
"version": "3.10.13"
|
1609 |
+
}
|
1610 |
+
},
|
1611 |
+
"nbformat": 4,
|
1612 |
+
"nbformat_minor": 5
|
1613 |
+
}
|
README.md
ADDED
@@ -0,0 +1,68 @@
1 |
+
---
|
2 |
+
license: apache-2.0
|
3 |
+
base_model: distilbert-base-uncased
|
4 |
+
tags:
|
5 |
+
- generated_from_trainer
|
6 |
+
model-index:
|
7 |
+
- name: aift-model-review-multiple-label-classification
|
8 |
+
results: []
|
9 |
+
---
|
10 |
+
|
11 |
+
<!-- This model card has been generated automatically according to the information the Trainer had access to. You
|
12 |
+
should probably proofread and complete it, then remove this comment. -->
|
13 |
+
|
14 |
+
# aift-model-review-multiple-label-classification
|
15 |
+
|
16 |
+
This model is a fine-tuned version of [distilbert-base-uncased](https://huggingface.co/distilbert-base-uncased) on an unspecified dataset.
|
17 |
+
It achieves the following results on the evaluation set:
|
18 |
+
- Loss: 0.1798
|
19 |
+
- Accuracy Thresh: 0.9347
|
20 |
+
|
21 |
+
## Model description
|
22 |
+
|
23 |
+
More information needed
|
24 |
+
|
25 |
+
## Intended uses & limitations
|
26 |
+
|
27 |
+
More information needed
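
As a starting point, a minimal multi-label inference sketch is shown below. It assumes the checkpoint in this repository and reuses the 0.4 sigmoid threshold from the training notebook; the generic `LABEL_0`…`LABEL_6` names in `config.json` stand in for the actual review labels (e.g. `ads`, `positive`) used during training.

```
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Assumed model id; a local path to this checkpoint works as well.
model_id = "Cielciel/aift-model-review-multiple-label-classification"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForSequenceClassification.from_pretrained(model_id)

inputs = tokenizer("ew a lot of ads", return_tensors="pt", truncation=True)
with torch.no_grad():
    logits = model(**inputs).logits.squeeze(0)

# Multi-label: one sigmoid per label, keep every label above the threshold.
probs = torch.sigmoid(logits)
threshold = 0.4  # value used in the training notebook
predicted = [model.config.id2label[i] for i, p in enumerate(probs) if p >= threshold]
print(predicted)
```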
|
28 |
+
|
29 |
+
## Training and evaluation data
|
30 |
+
|
31 |
+
More information needed
|
32 |
+
|
33 |
+
## Training procedure
|
34 |
+
|
35 |
+
### Training hyperparameters
|
36 |
+
|
37 |
+
The following hyperparameters were used during training:
|
38 |
+
- learning_rate: 2e-05
|
39 |
+
- train_batch_size: 8
|
40 |
+
- eval_batch_size: 8
|
41 |
+
- seed: 42
|
42 |
+
- distributed_type: tpu
|
43 |
+
- optimizer: Adam with betas=(0.9,0.999) and epsilon=1e-08
|
44 |
+
- lr_scheduler_type: linear
|
45 |
+
- num_epochs: 10
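
These settings map onto `transformers.TrainingArguments` roughly as in the sketch below (a reconstruction, not the exact training script; `output_dir` is a placeholder and the TPU setup behind `distributed_type: tpu` is not shown).

```
from transformers import TrainingArguments

# Reconstruction of the listed hyperparameters.
training_args = TrainingArguments(
    output_dir="aift-model-review-multiple-label-classification",
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    seed=42,
    lr_scheduler_type="linear",
    num_train_epochs=10,
    evaluation_strategy="epoch",
)
```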
|
46 |
+
|
47 |
+
### Training results
|
48 |
+
|
49 |
+
| Training Loss | Epoch | Step | Validation Loss | Accuracy Thresh |
|
50 |
+
|:-------------:|:-----:|:----:|:---------------:|:---------------:|
|
51 |
+
| No log | 1.0 | 70 | 0.4152 | 0.8684 |
|
52 |
+
| No log | 2.0 | 140 | 0.3026 | 0.9010 |
|
53 |
+
| No log | 3.0 | 210 | 0.2406 | 0.9286 |
|
54 |
+
| No log | 4.0 | 280 | 0.2176 | 0.9316 |
|
55 |
+
| No log | 5.0 | 350 | 0.2038 | 0.9245 |
|
56 |
+
| No log | 6.0 | 420 | 0.1924 | 0.9296 |
|
57 |
+
| No log | 7.0 | 490 | 0.1900 | 0.9265 |
|
58 |
+
| 0.2652 | 8.0 | 560 | 0.1868 | 0.9286 |
|
59 |
+
| 0.2652 | 9.0 | 630 | 0.1804 | 0.9367 |
|
60 |
+
| 0.2652 | 10.0 | 700 | 0.1798 | 0.9347 |
|
61 |
+
|
62 |
+
|
63 |
+
### Framework versions
|
64 |
+
|
65 |
+
- Transformers 4.37.0.dev0
|
66 |
+
- Pytorch 2.0.0+cu118
|
67 |
+
- Datasets 2.16.1
|
68 |
+
- Tokenizers 0.15.0
|
config.json
ADDED
@@ -0,0 +1,42 @@
1 |
+
{
|
2 |
+
"_name_or_path": "distilbert-base-uncased",
|
3 |
+
"activation": "gelu",
|
4 |
+
"architectures": [
|
5 |
+
"DistilBertForSequenceClassification"
|
6 |
+
],
|
7 |
+
"attention_dropout": 0.1,
|
8 |
+
"dim": 768,
|
9 |
+
"dropout": 0.1,
|
10 |
+
"hidden_dim": 3072,
|
11 |
+
"id2label": {
|
12 |
+
"0": "LABEL_0",
|
13 |
+
"1": "LABEL_1",
|
14 |
+
"2": "LABEL_2",
|
15 |
+
"3": "LABEL_3",
|
16 |
+
"4": "LABEL_4",
|
17 |
+
"5": "LABEL_5",
|
18 |
+
"6": "LABEL_6"
|
19 |
+
},
|
20 |
+
"initializer_range": 0.02,
|
21 |
+
"label2id": {
|
22 |
+
"LABEL_0": 0,
|
23 |
+
"LABEL_1": 1,
|
24 |
+
"LABEL_2": 2,
|
25 |
+
"LABEL_3": 3,
|
26 |
+
"LABEL_4": 4,
|
27 |
+
"LABEL_5": 5,
|
28 |
+
"LABEL_6": 6
|
29 |
+
},
|
30 |
+
"max_position_embeddings": 512,
|
31 |
+
"model_type": "distilbert",
|
32 |
+
"n_heads": 12,
|
33 |
+
"n_layers": 6,
|
34 |
+
"pad_token_id": 0,
|
35 |
+
"qa_dropout": 0.1,
|
36 |
+
"seq_classif_dropout": 0.2,
|
37 |
+
"sinusoidal_pos_embds": false,
|
38 |
+
"tie_weights_": true,
|
39 |
+
"torch_dtype": "float32",
|
40 |
+
"transformers_version": "4.37.0.dev0",
|
41 |
+
"vocab_size": 30522
|
42 |
+
}
|
custom_container/Dockerfile
ADDED
@@ -0,0 +1,21 @@
1 |
+
|
2 |
+
# Use pytorch GPU base image
|
3 |
+
# FROM gcr.io/cloud-aiplatform/training/pytorch-gpu.1-7
|
4 |
+
FROM us-docker.pkg.dev/vertex-ai/training/pytorch-gpu.1-10:latest
|
5 |
+
|
6 |
+
# set working directory
|
7 |
+
WORKDIR /app
|
8 |
+
|
9 |
+
# Install required packages
|
10 |
+
RUN pip install google-cloud-storage transformers datasets tqdm cloudml-hypertune
|
11 |
+
|
12 |
+
# Copies the trainer code to the docker image.
|
13 |
+
COPY ./trainer/__init__.py /app/trainer/__init__.py
|
14 |
+
COPY ./trainer/experiment.py /app/trainer/experiment.py
|
15 |
+
COPY ./trainer/utils.py /app/trainer/utils.py
|
16 |
+
COPY ./trainer/metadata.py /app/trainer/metadata.py
|
17 |
+
COPY ./trainer/model.py /app/trainer/model.py
|
18 |
+
COPY ./trainer/task.py /app/trainer/task.py
|
19 |
+
|
20 |
+
# Set up the entry point to invoke the trainer.
|
21 |
+
ENTRYPOINT ["python", "-m", "trainer.task"]
|
custom_container/README.md
ADDED
@@ -0,0 +1,66 @@
1 |
+
# PyTorch Custom Containers GPU Template
|
2 |
+
|
3 |
+
## Overview
|
4 |
+
|
5 |
+
This directory provides code to fine-tune a transformer model ([BERT-base](https://huggingface.co/bert-base-cased)) from the Hugging Face Transformers library for a sentiment analysis task. [BERT](https://ai.googleblog.com/2018/11/open-sourcing-bert-state-of-art-pre.html) (Bidirectional Encoder Representations from Transformers) is a transformer model pre-trained on a large corpus of unlabeled text in a self-supervised fashion. In this sample, we use the [IMDB sentiment classification dataset](https://huggingface.co/datasets/imdb) for the task. We show how to package a PyTorch training application, submit it to Vertex AI using pre-built PyTorch containers, and handle Python dependencies with [Vertex Training custom containers](https://cloud.google.com/vertex-ai/docs/training/create-custom-container?hl=hr).
|
6 |
+
|
7 |
+
## Prerequisites
|
8 |
+
|
9 |
+
* Setup your project by following the instructions from [documentation](https://cloud.google.com/vertex-ai/docs/start/cloud-environment)
|
10 |
+
* [Setup docker with Cloud Container Registry](https://cloud.google.com/container-registry/docs/pushing-and-pulling)
|
11 |
+
* Change the directory to this sample and run the commands below
|
12 |
+
|
13 |
+
`Note:` These instructions are used for local testing. When you submit a training job, no code will be executed on your local machine.
|
14 |
+
|
15 |
+
|
16 |
+
## Directory Structure
|
17 |
+
|
18 |
+
* `trainer` directory: all Python modules to train the model.
|
19 |
+
* `scripts` directory: command-line scripts to train the model on Vertex AI.
|
20 |
+
* `setup.py`: specifies the Python dependencies required for the training job. Vertex Training uses pip to install the package on the training instances allocated for the job.
|
21 |
+
|
22 |
+
### Trainer Modules
|
23 |
+
| File Name | Purpose |
|
24 |
+
| :-------- | :------ |
|
25 |
+
| [metadata.py](trainer/metadata.py) | Defines: metadata for classification task such as predefined model dataset name, target labels. |
|
26 |
+
| [utils.py](trainer/utils.py) | Includes: utility functions such as data input functions to read data, save model to GCS bucket. |
|
27 |
+
| [model.py](trainer/model.py) | Includes: function to create model with a sequence classification head from a pretrained model. |
|
28 |
+
| [experiment.py](trainer/experiment.py) | Runs the model training and evaluation experiment, and exports the final model. |
|
29 |
+
| [task.py](trainer/task.py) | Includes: 1) Initialize and parse task arguments (hyper parameters), and 2) Entry point to the trainer. |
|
30 |
+
|
31 |
+
### Scripts
|
32 |
+
|
33 |
+
* [train-cloud.sh](scripts/train-cloud.sh) This script builds your Docker image locally, pushes the image to Container Registry and submits a custom container training job to Vertex AI.
|
34 |
+
|
35 |
+
Please read the [documentation](https://cloud.google.com/vertex-ai/docs/training/containers-overview?hl=hr) on Vertex Training with Custom Containers for more details.
|
36 |
+
|
37 |
+
## How to run
|
38 |
+
|
39 |
+
Once the prerequisites are satisfied, you may:
|
40 |
+
|
41 |
+
1. For local testing, run (refer to the [notebook](../pytorch-text-classification-vertex-ai-train-tune-deploy.ipynb) for instructions):
|
42 |
+
```
|
43 |
+
CUSTOM_TRAIN_IMAGE_URI='gcr.io/{PROJECT_ID}/pytorch_gpu_train_{APP_NAME}'
|
44 |
+
cd ./custom_container/ && docker build -f Dockerfile -t $CUSTOM_TRAIN_IMAGE_URI ../python_package
|
45 |
+
docker run --gpus all -it --rm $CUSTOM_TRAIN_IMAGE_URI
|
46 |
+
```
|
47 |
+
2. For cloud testing, run:
|
48 |
+
```
|
49 |
+
source ./scripts/train-cloud.sh
|
50 |
+
```
|
51 |
+
|
52 |
+
## Run on GPU
|
53 |
+
The provided trainer code runs on a GPU if one is available, including data loading and model creation.
|
54 |
+
|
55 |
+
To run the trainer code on a different GPU configuration or a newer PyTorch pre-built container image, make the following changes to the training script:
|
56 |
+
* Update the PyTorch image URI to one of [PyTorch pre-built containers](https://cloud.google.com/vertex-ai/docs/training/pre-built-containers#available_container_images)
|
57 |
+
* Update the [`worker-pool-spec`](https://cloud.google.com/vertex-ai/docs/training/configure-compute?hl=hr) in the gcloud command to include a GPU (a Python SDK alternative is sketched below)
|
58 |
+
|
59 |
+
Then, run the script to submit a Custom Job to Vertex AI Training:
|
60 |
+
```
|
61 |
+
source ./scripts/train-cloud.sh
|
62 |
+
```
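
If you prefer to configure the accelerator from Python rather than editing the shell script, a minimal sketch using the Vertex AI SDK is shown below. The project, bucket, and image URI are placeholders; the worker pool mirrors the one defined in `scripts/train-cloud.sh`.

```
from google.cloud import aiplatform

# Placeholders: adjust the project, region, bucket and the image pushed by the script.
aiplatform.init(project="your-project-id", location="us-central1",
                staging_bucket="gs://your-bucket-name")

job = aiplatform.CustomContainerTrainingJob(
    display_name="finetuned-bert-classifier-pytorch-cstm-cntr",
    container_uri="gcr.io/your-project-id/pytorch_gpu_train_finetuned-bert-classifier",
)

# One replica with a single V100, as in the gcloud worker-pool-spec.
job.run(
    replica_count=1,
    machine_type="n1-standard-8",
    accelerator_type="NVIDIA_TESLA_V100",  # change to the GPU you want
    accelerator_count=1,
    args=["--model-name", "finetuned-bert-classifier",
          "--job-dir", "gs://your-bucket-name/finetuned-bert-classifier"],
)
```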
|
63 |
+
|
64 |
+
### Versions
|
65 |
+
This script uses the pre-built PyTorch containers for PyTorch 1.7.
|
66 |
+
* `us-docker.pkg.dev/vertex-ai/training/pytorch-gpu.1-7:latest`
|
custom_container/scripts/train-cloud.sh
ADDED
@@ -0,0 +1,80 @@
1 |
+
#!/bin/bash
|
2 |
+
# Copyright 2019 Google LLC
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
# ==============================================================================
|
16 |
+
# This script performs cloud training for a PyTorch model.
|
17 |
+
|
18 |
+
echo "Submitting PyTorch model training job to Vertex AI"
|
19 |
+
|
20 |
+
# PROJECT_ID: Change to your project id
|
21 |
+
PROJECT_ID=$(gcloud config list --format 'value(core.project)')
|
22 |
+
|
23 |
+
# BUCKET_NAME: Change to your bucket name.
|
24 |
+
BUCKET_NAME="[your-bucket-name]" # <-- CHANGE TO YOUR BUCKET NAME
|
25 |
+
|
26 |
+
# validate bucket name
|
27 |
+
if [ "${BUCKET_NAME}" = "[your-bucket-name]" ]
|
28 |
+
then
|
29 |
+
echo "[ERROR] INVALID VALUE: Please update the variable BUCKET_NAME with valid Cloud Storage bucket name. Exiting the script..."
|
30 |
+
exit 1
|
31 |
+
fi
|
32 |
+
|
33 |
+
# JOB_NAME: the name of your job running on AI Platform.
|
34 |
+
JOB_PREFIX="finetuned-bert-classifier-pytorch-cstm-cntr"
|
35 |
+
JOB_NAME=${JOB_PREFIX}-$(date +%Y%m%d%H%M%S)-custom-job
|
36 |
+
|
37 |
+
# REGION: select a region from https://cloud.google.com/vertex-ai/docs/general/locations#available_regions
|
38 |
+
# or use the default '`us-central1`'. The region is where the job will be run.
|
39 |
+
REGION="us-central1"
|
40 |
+
|
41 |
+
# JOB_DIR: Where to store prepared package and upload output model.
|
42 |
+
JOB_DIR=gs://${BUCKET_NAME}/${JOB_PREFIX}/models/${JOB_NAME}
|
43 |
+
|
44 |
+
# IMAGE_REPO_NAME: set a local repo name to distinguish our image
|
45 |
+
IMAGE_REPO_NAME=pytorch_gpu_train_finetuned-bert-classifier
|
46 |
+
|
47 |
+
# IMAGE_URI: the complete URI location for Cloud Container Registry
|
48 |
+
CUSTOM_TRAIN_IMAGE_URI=gcr.io/${PROJECT_ID}/${IMAGE_REPO_NAME}
|
49 |
+
|
50 |
+
# Build the docker image
|
51 |
+
docker build --no-cache -f Dockerfile -t $CUSTOM_TRAIN_IMAGE_URI ../python_package
|
52 |
+
|
53 |
+
# Deploy the docker image to Cloud Container Registry
|
54 |
+
docker push ${CUSTOM_TRAIN_IMAGE_URI}
|
55 |
+
|
56 |
+
# worker pool spec
|
57 |
+
worker_pool_spec="\
|
58 |
+
replica-count=1,\
|
59 |
+
machine-type=n1-standard-8,\
|
60 |
+
accelerator-type=NVIDIA_TESLA_V100,\
|
61 |
+
accelerator-count=1,\
|
62 |
+
container-image-uri=${CUSTOM_TRAIN_IMAGE_URI}"
|
63 |
+
|
64 |
+
# Submit Custom Job to Vertex AI
|
65 |
+
gcloud beta ai custom-jobs create \
|
66 |
+
--display-name=${JOB_NAME} \
|
67 |
+
--region ${REGION} \
|
68 |
+
--worker-pool-spec="${worker_pool_spec}" \
|
69 |
+
--args="--model-name","finetuned-bert-classifier","--job-dir",$JOB_DIR
|
70 |
+
|
71 |
+
echo "After the job is completed successfully, model files will be saved at $JOB_DIR/"
|
72 |
+
|
73 |
+
# uncomment following lines to monitor the job progress by streaming logs
|
74 |
+
|
75 |
+
# Stream the logs from the job
|
76 |
+
# gcloud ai custom-jobs stream-logs $(gcloud ai custom-jobs list --region=$REGION --filter="displayName:"$JOB_NAME --format="get(name)")
|
77 |
+
|
78 |
+
# # Verify the model was exported
|
79 |
+
# echo "Verify the model was exported:"
|
80 |
+
# gsutil ls ${JOB_DIR}/
|
model.safetensors
ADDED
@@ -0,0 +1,3 @@
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:1e0601c465c2b89e2c72137736ee8a82835a3651c432dbf1e6017523f91d3b7f
|
3 |
+
size 267847948
|
python_package/README.md
ADDED
@@ -0,0 +1,58 @@
1 |
+
# PyTorch - Python Package Training
|
2 |
+
|
3 |
+
## Overview
|
4 |
+
|
5 |
+
This directory provides code to fine-tune a transformer model ([BERT-base](https://huggingface.co/bert-base-cased)) from the Hugging Face Transformers library for a sentiment analysis task. [BERT](https://ai.googleblog.com/2018/11/open-sourcing-bert-state-of-art-pre.html) (Bidirectional Encoder Representations from Transformers) is a transformer model pre-trained on a large corpus of unlabeled text in a self-supervised fashion. In this sample, we use the [IMDB sentiment classification dataset](https://huggingface.co/datasets/imdb) for the task. We show how to package a PyTorch training application, submit it to Vertex AI using pre-built PyTorch containers, and handle Python dependencies through a Python build script (`setup.py`).
|
6 |
+
|
7 |
+
## Prerequisites
|
8 |
+
* Setup your project by following the instructions from [documentation](https://cloud.google.com/vertex-ai/docs/start/cloud-environment)
|
9 |
+
* Change directories to this sample.
|
10 |
+
|
11 |
+
## Directory Structure
|
12 |
+
|
13 |
+
* `trainer` directory: all Python modules to train the model.
|
14 |
+
* `scripts` directory: command-line scripts to train the model on Vertex AI.
|
15 |
+
* `setup.py`: specifies the Python dependencies required for the training job. Vertex Training uses pip to install the package on the training instances allocated for the job.
|
16 |
+
|
17 |
+
### Trainer Modules
|
18 |
+
| File Name | Purpose |
|
19 |
+
| :-------- | :------ |
|
20 |
+
| [metadata.py](trainer/metadata.py) | Defines: metadata for classification task such as predefined model dataset name, target labels. |
|
21 |
+
| [utils.py](trainer/utils.py) | Includes: utility functions such as data input functions to read data, save model to GCS bucket. |
|
22 |
+
| [model.py](trainer/model.py) | Includes: function to create model with a sequence classification head from a pretrained model. |
|
23 |
+
| [experiment.py](trainer/experiment.py) | Runs the model training and evaluation experiment, and exports the final model. |
|
24 |
+
| [task.py](trainer/task.py) | Includes: 1) Initialize and parse task arguments (hyper parameters), and 2) Entry point to the trainer. |
|
25 |
+
|
26 |
+
### Scripts
|
27 |
+
|
28 |
+
* [train-cloud.sh](scripts/train-cloud.sh) This script submits a training job to Vertex AI
|
29 |
+
|
30 |
+
## How to run
|
31 |
+
For local testing, run:
|
32 |
+
```
|
33 |
+
!cd python_package && python -m trainer.task
|
34 |
+
```
|
35 |
+
|
36 |
+
For cloud training, once the prerequisites are satisfied, update the
|
37 |
+
`BUCKET_NAME` environment variable in `scripts/train-cloud.sh`. You may then
|
38 |
+
run the following script to submit a Vertex AI training job:
|
39 |
+
```
|
40 |
+
source ./python_package/scripts/train-cloud.sh
|
41 |
+
```
|
42 |
+
|
43 |
+
## Run on GPU
|
44 |
+
The provided trainer code runs on a GPU if one is available, including data loading and model creation.
|
45 |
+
|
46 |
+
To run the trainer code on a different GPU configuration or a newer PyTorch pre-built container image, make the following changes to the training script:
|
47 |
+
* Update the PyTorch image URI to one of [PyTorch pre-built containers](https://cloud.google.com/vertex-ai/docs/training/pre-built-containers#available_container_images)
|
48 |
+
* Update the [`worker-pool-spec`](https://cloud.google.com/vertex-ai/docs/training/configure-compute?hl=hr) in the gcloud command to include a GPU (a Python SDK alternative is sketched below)
|
49 |
+
|
50 |
+
Then, run the script to submit a Custom Job to Vertex AI Training:
|
51 |
+
```
|
52 |
+
source ./scripts/train-cloud.sh
|
53 |
+
```
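
The same job can also be submitted from Python with the Vertex AI SDK, as done in the accompanying workflow notebook. In the sketch below, the project, bucket, package URI, and accelerator are placeholders to adjust.

```
from google.cloud import aiplatform

aiplatform.init(project="your-project-id", location="us-central1",
                staging_bucket="gs://your-bucket-name")

# Points at the sdist built from this directory and the pre-built PyTorch image.
job = aiplatform.CustomPythonPackageTrainingJob(
    display_name="finetuned-bert-classifier-pytorch-pkg",
    python_package_gcs_uri="gs://your-bucket-name/python_package/trainer-0.1.tar.gz",
    python_module_name="trainer.task",
    container_uri="us-docker.pkg.dev/vertex-ai/training/pytorch-gpu.1-7:latest",
)

job.run(
    replica_count=1,
    machine_type="n1-standard-8",
    accelerator_type="NVIDIA_TESLA_V100",  # change to the GPU you want
    accelerator_count=1,
    args=["--num-epochs", "2", "--model-name", "finetuned-bert-classifier"],
)
```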
|
54 |
+
|
55 |
+
### Versions
|
56 |
+
This script uses the pre-built PyTorch containers for PyTorch 1.7.
|
57 |
+
* `us-docker.pkg.dev/vertex-ai/training/pytorch-gpu.1-7:latest`
|
58 |
+
|
python_package/dist/trainer-0.1.tar.gz
ADDED
@@ -0,0 +1,3 @@
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:de00247994c728d30322eec83cc5a3976137bc7394e4213c2775c517af5163e6
|
3 |
+
size 6337
|
python_package/scripts/train-cloud.sh
ADDED
@@ -0,0 +1,70 @@
1 |
+
#!/bin/bash
|
2 |
+
# Copyright 2019 Google LLC
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
# ==============================================================================
|
16 |
+
# This script performs cloud training for a PyTorch model.
|
17 |
+
|
18 |
+
echo "Submitting Custom Job to Vertex AI to train PyTorch model"
|
19 |
+
|
20 |
+
# BUCKET_NAME: Change to your bucket name
|
21 |
+
BUCKET_NAME="[your-bucket-name]" # <-- CHANGE TO YOUR BUCKET NAME
|
22 |
+
|
23 |
+
# validate bucket name
|
24 |
+
if [ "${BUCKET_NAME}" = "[your-bucket-name]" ]
|
25 |
+
then
|
26 |
+
echo "[ERROR] INVALID VALUE: Please update the variable BUCKET_NAME with valid Cloud Storage bucket name. Exiting the script..."
|
27 |
+
exit 1
|
28 |
+
fi
|
29 |
+
|
30 |
+
# The PyTorch image provided by Vertex AI Training.
|
31 |
+
IMAGE_URI="us-docker.pkg.dev/vertex-ai/training/pytorch-gpu.1-7:latest"
|
32 |
+
|
33 |
+
# JOB_NAME: the name of your job running on Vertex AI.
|
34 |
+
JOB_PREFIX="finetuned-bert-classifier-pytorch-pkg-ar"
|
35 |
+
JOB_NAME=${JOB_PREFIX}-$(date +%Y%m%d%H%M%S)-custom-job
|
36 |
+
|
37 |
+
# REGION: select a region from https://cloud.google.com/vertex-ai/docs/general/locations#available_regions
|
38 |
+
# or use the default '`us-central1`'. The region is where the job will be run.
|
39 |
+
REGION="us-central1"
|
40 |
+
|
41 |
+
# JOB_DIR: Where to store prepared package and upload output model.
|
42 |
+
JOB_DIR=gs://${BUCKET_NAME}/${JOB_PREFIX}/model/${JOB_NAME}
|
43 |
+
|
44 |
+
# worker pool spec
|
45 |
+
worker_pool_spec="\
|
46 |
+
replica-count=1,\
|
47 |
+
machine-type=n1-standard-8,\
|
48 |
+
accelerator-type=NVIDIA_TESLA_V100,\
|
49 |
+
accelerator-count=1,\
|
50 |
+
executor-image-uri=${IMAGE_URI},\
|
51 |
+
python-module=trainer.task,\
|
52 |
+
local-package-path=../python_package/"
|
53 |
+
|
54 |
+
# Submit Custom Job to Vertex AI
|
55 |
+
gcloud beta ai custom-jobs create \
|
56 |
+
--display-name=${JOB_NAME} \
|
57 |
+
--region ${REGION} \
|
58 |
+
--worker-pool-spec="${worker_pool_spec}" \
|
59 |
+
--args="--model-name","finetuned-bert-classifier","--job-dir",$JOB_DIR
|
60 |
+
|
61 |
+
echo "After the job is completed successfully, model files will be saved at $JOB_DIR/"
|
62 |
+
|
63 |
+
# uncomment following lines to monitor the job progress by streaming logs
|
64 |
+
|
65 |
+
# Stream the logs from the job
|
66 |
+
# gcloud ai custom-jobs stream-logs $(gcloud ai custom-jobs list --region=$REGION --filter="displayName:"$JOB_NAME --format="get(name)")
|
67 |
+
|
68 |
+
# # Verify the model was exported
|
69 |
+
# echo "Verify the model was exported:"
|
70 |
+
# gsutil ls ${JOB_DIR}/
|
python_package/setup.py
ADDED
@@ -0,0 +1,24 @@
1 |
+
|
2 |
+
from setuptools import find_packages
|
3 |
+
from setuptools import setup
|
4 |
+
import setuptools
|
5 |
+
|
6 |
+
from distutils.command.build import build as _build
|
7 |
+
import subprocess
|
8 |
+
|
9 |
+
|
10 |
+
REQUIRED_PACKAGES = [
|
11 |
+
'transformers',
|
12 |
+
'datasets',
|
13 |
+
'tqdm',
|
14 |
+
'cloudml-hypertune'
|
15 |
+
]
|
16 |
+
|
17 |
+
setup(
|
18 |
+
name='trainer',
|
19 |
+
version='0.1',
|
20 |
+
install_requires=REQUIRED_PACKAGES,
|
21 |
+
packages=find_packages(),
|
22 |
+
include_package_data=True,
|
23 |
+
description='Vertex AI | Training | PyTorch | Text Classification | Python Package'
|
24 |
+
)
|
python_package/trainer.egg-info/PKG-INFO
ADDED
@@ -0,0 +1,8 @@
1 |
+
Metadata-Version: 2.1
|
2 |
+
Name: trainer
|
3 |
+
Version: 0.1
|
4 |
+
Summary: Vertex AI | Training | PyTorch | Text Classification | Python Package
|
5 |
+
Requires-Dist: transformers
|
6 |
+
Requires-Dist: datasets
|
7 |
+
Requires-Dist: tqdm
|
8 |
+
Requires-Dist: cloudml-hypertune
|
python_package/trainer.egg-info/SOURCES.txt
ADDED
@@ -0,0 +1,13 @@
1 |
+
README.md
|
2 |
+
setup.py
|
3 |
+
trainer/__init__.py
|
4 |
+
trainer/experiment.py
|
5 |
+
trainer/metadata.py
|
6 |
+
trainer/model.py
|
7 |
+
trainer/task.py
|
8 |
+
trainer/utils.py
|
9 |
+
trainer.egg-info/PKG-INFO
|
10 |
+
trainer.egg-info/SOURCES.txt
|
11 |
+
trainer.egg-info/dependency_links.txt
|
12 |
+
trainer.egg-info/requires.txt
|
13 |
+
trainer.egg-info/top_level.txt
|
python_package/trainer.egg-info/dependency_links.txt
ADDED
@@ -0,0 +1 @@
1 |
+
|
python_package/trainer.egg-info/requires.txt
ADDED
@@ -0,0 +1,4 @@
1 |
+
transformers
|
2 |
+
datasets
|
3 |
+
tqdm
|
4 |
+
cloudml-hypertune
|
python_package/trainer.egg-info/top_level.txt
ADDED
@@ -0,0 +1 @@
1 |
+
trainer
|
python_package/trainer/__init__.py
ADDED
File without changes
|
python_package/trainer/experiment.py
ADDED
@@ -0,0 +1,137 @@
1 |
+
# Copyright 2019 Google LLC
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import os
|
16 |
+
import numpy as np
|
17 |
+
import hypertune
|
18 |
+
|
19 |
+
from transformers import (
|
20 |
+
AutoTokenizer,
|
21 |
+
EvalPrediction,
|
22 |
+
Trainer,
|
23 |
+
TrainingArguments,
|
24 |
+
default_data_collator,
|
25 |
+
TrainerCallback
|
26 |
+
)
|
27 |
+
|
28 |
+
from trainer import model, metadata, utils
|
29 |
+
|
30 |
+
|
31 |
+
class HPTuneCallback(TrainerCallback):
|
32 |
+
"""
|
33 |
+
A custom callback class that reports a metric to hypertuner
|
34 |
+
at the end of each epoch.
|
35 |
+
"""
|
36 |
+
|
37 |
+
def __init__(self, metric_tag, metric_value):
|
38 |
+
super(HPTuneCallback, self).__init__()
|
39 |
+
self.metric_tag = metric_tag
|
40 |
+
self.metric_value = metric_value
|
41 |
+
self.hpt = hypertune.HyperTune()
|
42 |
+
|
43 |
+
def on_evaluate(self, args, state, control, **kwargs):
|
44 |
+
print(f"HP metric {self.metric_tag}={kwargs['metrics'][self.metric_value]}")
|
45 |
+
self.hpt.report_hyperparameter_tuning_metric(
|
46 |
+
hyperparameter_metric_tag=self.metric_tag,
|
47 |
+
metric_value=kwargs['metrics'][self.metric_value],
|
48 |
+
global_step=state.epoch)
|
49 |
+
|
50 |
+
|
51 |
+
def compute_metrics(p: EvalPrediction):
|
52 |
+
preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
|
53 |
+
preds = np.argmax(preds, axis=1)
|
54 |
+
return {"accuracy": (preds == p.label_ids).astype(np.float32).mean().item()}
|
55 |
+
|
56 |
+
|
57 |
+
def train(args, model, train_dataset, test_dataset):
|
58 |
+
"""Create the training loop to load pretrained model and tokenizer and
|
59 |
+
start the training process
|
60 |
+
|
61 |
+
Args:
|
62 |
+
args: read arguments from the runner to set training hyperparameters
|
63 |
+
model: The neural network that you are training
|
64 |
+
train_dataset: The training dataset
|
65 |
+
test_dataset: The test dataset for evaluation
|
66 |
+
"""
|
67 |
+
|
68 |
+
# initialize the tokenizer
|
69 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
70 |
+
metadata.PRETRAINED_MODEL_NAME,
|
71 |
+
use_fast=True,
|
72 |
+
)
|
73 |
+
|
74 |
+
# set training arguments
|
75 |
+
training_args = TrainingArguments(
|
76 |
+
evaluation_strategy="epoch",
|
77 |
+
learning_rate=args.learning_rate,
|
78 |
+
per_device_train_batch_size=args.batch_size,
|
79 |
+
per_device_eval_batch_size=args.batch_size,
|
80 |
+
num_train_epochs=args.num_epochs,
|
81 |
+
weight_decay=args.weight_decay,
|
82 |
+
output_dir=os.path.join("/tmp", args.model_name)
|
83 |
+
)
|
84 |
+
|
85 |
+
# initialize our Trainer
|
86 |
+
trainer = Trainer(
|
87 |
+
model,
|
88 |
+
training_args,
|
89 |
+
train_dataset=train_dataset,
|
90 |
+
eval_dataset=test_dataset,
|
91 |
+
data_collator=default_data_collator,
|
92 |
+
tokenizer=tokenizer,
|
93 |
+
compute_metrics=compute_metrics
|
94 |
+
)
|
95 |
+
|
96 |
+
# add hyperparameter tuning callback to report metrics when enabled
|
97 |
+
if args.hp_tune == "y":
|
98 |
+
trainer.add_callback(HPTuneCallback("accuracy", "eval_accuracy"))
|
99 |
+
|
100 |
+
# training
|
101 |
+
trainer.train()
|
102 |
+
|
103 |
+
return trainer
|
104 |
+
|
105 |
+
|
106 |
+
def run(args):
|
107 |
+
"""Load the data, train, evaluate, and export the model for serving and
|
108 |
+
evaluating.
|
109 |
+
|
110 |
+
Args:
|
111 |
+
args: experiment parameters.
|
112 |
+
"""
|
113 |
+
# Open our dataset
|
114 |
+
train_dataset, test_dataset = utils.load_data(args)
|
115 |
+
|
116 |
+
label_list = train_dataset.unique("label")
|
117 |
+
num_labels = len(label_list)
|
118 |
+
|
119 |
+
# Create the model, loss function, and optimizer
|
120 |
+
text_classifier = model.create(num_labels=num_labels)
|
121 |
+
|
122 |
+
# Train / Test the model
|
123 |
+
trainer = train(args, text_classifier, train_dataset, test_dataset)
|
124 |
+
|
125 |
+
metrics = trainer.evaluate(eval_dataset=test_dataset)
|
126 |
+
trainer.save_metrics("all", metrics)
|
127 |
+
|
128 |
+
# Export the trained model
|
129 |
+
trainer.save_model(os.path.join("/tmp", args.model_name))
|
130 |
+
|
131 |
+
# Save the model to GCS
|
132 |
+
if args.job_dir:
|
133 |
+
utils.save_model(args)
|
134 |
+
else:
|
135 |
+
print(f"Saved model files at {os.path.join('/tmp', args.model_name)}")
|
136 |
+
print(f"To save model files in GCS bucket, please specify job_dir starting with gs://")
|
137 |
+
|
python_package/trainer/metadata.py
ADDED
@@ -0,0 +1,31 @@
1 |
+
#!/usr/bin/env python
|
2 |
+
# Copyright 2019 Google LLC
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
# Task type can be either 'classification', 'regression', or 'custom'.
|
17 |
+
# This is based on the target feature in the dataset.
|
18 |
+
TASK_TYPE = 'classification'
|
19 |
+
|
20 |
+
# Dataset name
|
21 |
+
DATASET_NAME = 'imdb'
|
22 |
+
|
23 |
+
# pre-trained model name
|
24 |
+
PRETRAINED_MODEL_NAME = 'bert-base-cased'
|
25 |
+
|
26 |
+
# List of the class values (labels) in a classification dataset.
|
27 |
+
TARGET_LABELS = {1:1, 0:0, -1:0}
|
28 |
+
|
29 |
+
|
30 |
+
# maximum sequence length
|
31 |
+
MAX_SEQ_LENGTH = 128
|
python_package/trainer/model.py
ADDED
@@ -0,0 +1,31 @@
1 |
+
# Copyright 2019 Google LLC
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from transformers import AutoModelForSequenceClassification
|
16 |
+
from trainer import metadata
|
17 |
+
|
18 |
+
def create(num_labels):
|
19 |
+
"""create the model by loading a pretrained model or define your
|
20 |
+
own
|
21 |
+
|
22 |
+
Args:
|
23 |
+
num_labels: number of target labels
|
24 |
+
"""
|
25 |
+
# Create the model, loss function, and optimizer
|
26 |
+
model = AutoModelForSequenceClassification.from_pretrained(
|
27 |
+
metadata.PRETRAINED_MODEL_NAME,
|
28 |
+
num_labels=num_labels
|
29 |
+
)
|
30 |
+
|
31 |
+
return model
|
python_package/trainer/task.py
ADDED
@@ -0,0 +1,104 @@
1 |
+
# Copyright 2019 Google LLC
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import argparse
|
16 |
+
import os
|
17 |
+
|
18 |
+
from trainer import experiment
|
19 |
+
|
20 |
+
|
21 |
+
def get_args():
|
22 |
+
"""Define the task arguments with the default values.
|
23 |
+
|
24 |
+
Returns:
|
25 |
+
experiment parameters
|
26 |
+
"""
|
27 |
+
args_parser = argparse.ArgumentParser()
|
28 |
+
|
29 |
+
|
30 |
+
# Experiment arguments
|
31 |
+
args_parser.add_argument(
|
32 |
+
'--batch-size',
|
33 |
+
help='Batch size for each training and evaluation step.',
|
34 |
+
type=int,
|
35 |
+
default=16)
|
36 |
+
args_parser.add_argument(
|
37 |
+
'--num-epochs',
|
38 |
+
help="""\
|
39 |
+
Maximum number of training data epochs on which to train.
|
40 |
+
If both --train-size and --num-epochs are specified,
|
41 |
+
--train-steps will be: (train-size/train-batch-size) * num-epochs.\
|
42 |
+
""",
|
43 |
+
default=1,
|
44 |
+
type=int,
|
45 |
+
)
|
46 |
+
args_parser.add_argument(
|
47 |
+
'--seed',
|
48 |
+
help='Random seed (default: 42)',
|
49 |
+
type=int,
|
50 |
+
default=42,
|
51 |
+
)
|
52 |
+
|
53 |
+
# Estimator arguments
|
54 |
+
args_parser.add_argument(
|
55 |
+
'--learning-rate',
|
56 |
+
help='Learning rate value for the optimizers.',
|
57 |
+
default=2e-5,
|
58 |
+
type=float)
|
59 |
+
args_parser.add_argument(
|
60 |
+
'--weight-decay',
|
61 |
+
help="""
|
62 |
+
The factor by which the learning rate should decay by the end of the
|
63 |
+
training.
|
64 |
+
|
65 |
+
decayed_learning_rate =
|
66 |
+
learning_rate * decay_rate ^ (global_step / decay_steps)
|
67 |
+
|
68 |
+
If set to 0 (default), then no decay will occur.
|
69 |
+
If set to 0.5, then the learning rate should reach 0.5 of its original
|
70 |
+
value at the end of the training.
|
71 |
+
Note that decay_steps is set to train_steps.
|
72 |
+
""",
|
73 |
+
default=0.01,
|
74 |
+
type=float)
|
75 |
+
|
76 |
+
# Enable hyperparameter tuning
|
77 |
+
args_parser.add_argument(
|
78 |
+
'--hp-tune',
|
79 |
+
default="n",
|
80 |
+
help='Enable hyperparameter tuning. Valid values are: "y" - enable, "n" - disable')
|
81 |
+
|
82 |
+
# Saved model arguments
|
83 |
+
args_parser.add_argument(
|
84 |
+
'--job-dir',
|
85 |
+
default=os.getenv('AIP_MODEL_DIR'),
|
86 |
+
help='GCS location to export models')
|
87 |
+
args_parser.add_argument(
|
88 |
+
'--model-name',
|
89 |
+
default="finetuned-bert-classifier",
|
90 |
+
help='The name of your saved model')
|
91 |
+
|
92 |
+
return args_parser.parse_args()
|
93 |
+
|
94 |
+
|
95 |
+
def main():
|
96 |
+
"""Setup / Start the experiment
|
97 |
+
"""
|
98 |
+
args = get_args()
|
99 |
+
print(args)
|
100 |
+
experiment.run(args)
|
101 |
+
|
102 |
+
|
103 |
+
if __name__ == '__main__':
|
104 |
+
main()
|
python_package/trainer/utils.py
ADDED
@@ -0,0 +1,99 @@
1 |
+
# Copyright 2019 Google LLC
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import os
|
16 |
+
import datetime
|
17 |
+
|
18 |
+
from google.cloud import storage
|
19 |
+
|
20 |
+
from transformers import AutoTokenizer
|
21 |
+
from datasets import load_dataset, load_metric, ReadInstruction
|
22 |
+
from trainer import metadata
|
23 |
+
|
24 |
+
|
25 |
+
def preprocess_function(examples):
|
26 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
27 |
+
metadata.PRETRAINED_MODEL_NAME,
|
28 |
+
use_fast=True,
|
29 |
+
)
|
30 |
+
|
31 |
+
# Tokenize the texts
|
32 |
+
tokenizer_args = (
|
33 |
+
(examples['text'],)
|
34 |
+
)
|
35 |
+
result = tokenizer(*tokenizer_args,
|
36 |
+
padding='max_length',
|
37 |
+
max_length=metadata.MAX_SEQ_LENGTH,
|
38 |
+
truncation=True)
|
39 |
+
|
40 |
+
# TEMP: We can extract this automatically but Unique method of the dataset
|
41 |
+
# is not reporting the label -1 which shows up in the pre-processing
|
42 |
+
# Hence the additional -1 term in the dictionary
|
43 |
+
label_to_id = metadata.TARGET_LABELS
|
44 |
+
|
45 |
+
# Map labels to IDs (not necessary for GLUE tasks)
|
46 |
+
if label_to_id is not None and "label" in examples:
|
47 |
+
result["label"] = [label_to_id[l] for l in examples["label"]]
|
48 |
+
|
49 |
+
return result
|
50 |
+
|
51 |
+
|
52 |
+
def load_data(args):
|
53 |
+
"""Loads the data into two different data loaders. (Train, Test)
|
54 |
+
|
55 |
+
Args:
|
56 |
+
args: arguments passed to the python script
|
57 |
+
"""
|
58 |
+
# Dataset loading repeated here to make this cell idempotent
|
59 |
+
# Since we are over-writing datasets variable
|
60 |
+
dataset = load_dataset(metadata.DATASET_NAME)
|
61 |
+
|
62 |
+
dataset = dataset.map(preprocess_function,
|
63 |
+
batched=True,
|
64 |
+
load_from_cache_file=True)
|
65 |
+
|
66 |
+
train_dataset, test_dataset = dataset["train"], dataset["test"]
|
67 |
+
|
68 |
+
return train_dataset, test_dataset
|
69 |
+
|
70 |
+
|
71 |
+
def save_model(args):
|
72 |
+
"""Saves the model to Google Cloud Storage or local file system
|
73 |
+
|
74 |
+
Args:
|
75 |
+
args: contains name for saved model.
|
76 |
+
"""
|
77 |
+
scheme = 'gs://'
|
78 |
+
if args.job_dir.startswith(scheme):
|
79 |
+
job_dir = args.job_dir.split("/")
|
80 |
+
bucket_name = job_dir[2]
|
81 |
+
object_prefix = "/".join(job_dir[3:]).rstrip("/")
|
82 |
+
|
83 |
+
if object_prefix:
|
84 |
+
model_path = '{}/{}'.format(object_prefix, args.model_name)
|
85 |
+
else:
|
86 |
+
model_path = '{}'.format(args.model_name)
|
87 |
+
|
88 |
+
bucket = storage.Client().bucket(bucket_name)
|
89 |
+
local_path = os.path.join("/tmp", args.model_name)
|
90 |
+
files = [f for f in os.listdir(local_path) if os.path.isfile(os.path.join(local_path, f))]
|
91 |
+
for file in files:
|
92 |
+
local_file = os.path.join(local_path, file)
|
93 |
+
blob = bucket.blob("/".join([model_path, file]))
|
94 |
+
blob.upload_from_filename(local_file)
|
95 |
+
print(f"Saved model files in gs://{bucket_name}/{model_path}")
|
96 |
+
else:
|
97 |
+
print(f"Saved model files at {os.path.join('/tmp', args.model_name)}")
|
98 |
+
print(f"To save model files in GCS bucket, please specify job_dir starting with gs://")
|
99 |
+
|
runs/Jan08_04-05-34_aift-review-classification-multiple-label/events.out.tfevents.1704686768.aift-review-classification-multiple-label
ADDED
@@ -0,0 +1,3 @@
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:11bbe9bc3fd1a688e8527b9b8f2d874ffb7e68110df30c66f62ac9ad81315a29
|
3 |
+
size 8687
|
runs/Jan08_04-07-17_aift-review-classification-multiple-label/events.out.tfevents.1704686842.aift-review-classification-multiple-label
ADDED
@@ -0,0 +1,3 @@
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:4fc40e3f3954e05c2bc3f1492277fb35266ad0c86df183b923fb5d4b8e3340da
|
3 |
+
size 8687
|
runs/Jan08_04-07-17_aift-review-classification-multiple-label/events.out.tfevents.1704687081.aift-review-classification-multiple-label
ADDED
@@ -0,0 +1,3 @@
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:1f63a8177014d20b0ecfd65c4aa43af47460832c3c1acfa40cf70b5ba7475951
|
3 |
+
size 700
|
special_tokens_map.json
ADDED
@@ -0,0 +1,7 @@
1 |
+
{
|
2 |
+
"cls_token": "[CLS]",
|
3 |
+
"mask_token": "[MASK]",
|
4 |
+
"pad_token": "[PAD]",
|
5 |
+
"sep_token": "[SEP]",
|
6 |
+
"unk_token": "[UNK]"
|
7 |
+
}
|
tokenizer.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
tokenizer_config.json
ADDED
@@ -0,0 +1,55 @@
1 |
+
{
|
2 |
+
"added_tokens_decoder": {
|
3 |
+
"0": {
|
4 |
+
"content": "[PAD]",
|
5 |
+
"lstrip": false,
|
6 |
+
"normalized": false,
|
7 |
+
"rstrip": false,
|
8 |
+
"single_word": false,
|
9 |
+
"special": true
|
10 |
+
},
|
11 |
+
"100": {
|
12 |
+
"content": "[UNK]",
|
13 |
+
"lstrip": false,
|
14 |
+
"normalized": false,
|
15 |
+
"rstrip": false,
|
16 |
+
"single_word": false,
|
17 |
+
"special": true
|
18 |
+
},
|
19 |
+
"101": {
|
20 |
+
"content": "[CLS]",
|
21 |
+
"lstrip": false,
|
22 |
+
"normalized": false,
|
23 |
+
"rstrip": false,
|
24 |
+
"single_word": false,
|
25 |
+
"special": true
|
26 |
+
},
|
27 |
+
"102": {
|
28 |
+
"content": "[SEP]",
|
29 |
+
"lstrip": false,
|
30 |
+
"normalized": false,
|
31 |
+
"rstrip": false,
|
32 |
+
"single_word": false,
|
33 |
+
"special": true
|
34 |
+
},
|
35 |
+
"103": {
|
36 |
+
"content": "[MASK]",
|
37 |
+
"lstrip": false,
|
38 |
+
"normalized": false,
|
39 |
+
"rstrip": false,
|
40 |
+
"single_word": false,
|
41 |
+
"special": true
|
42 |
+
}
|
43 |
+
},
|
44 |
+
"clean_up_tokenization_spaces": true,
|
45 |
+
"cls_token": "[CLS]",
|
46 |
+
"do_lower_case": true,
|
47 |
+
"mask_token": "[MASK]",
|
48 |
+
"model_max_length": 512,
|
49 |
+
"pad_token": "[PAD]",
|
50 |
+
"sep_token": "[SEP]",
|
51 |
+
"strip_accents": null,
|
52 |
+
"tokenize_chinese_chars": true,
|
53 |
+
"tokenizer_class": "DistilBertTokenizer",
|
54 |
+
"unk_token": "[UNK]"
|
55 |
+
}
|
training_args.bin
ADDED
@@ -0,0 +1,3 @@
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:f6f3045d2e9afe29c006bc0a52f01855c1a37d7525699533b28f06d6396c303f
|
3 |
+
size 4347
|
vocab.txt
ADDED
The diff for this file is too large to render.
See raw diff
|
|