ElisonSherton
commited on
Commit
·
af50acb
1
Parent(s):
91bf3a6
Initial Commit
Browse files- .gitignore +164 -0
- README.md +6 -0
- chapter1/1-transformers-what-can-they-do.ipynb +819 -0
- chapter2/2-behind-the-pipeline.ipynb +340 -0
- chapter2/2-handling-multiple-sequences.ipynb +384 -0
- chapter2/2-tokenizers.ipynb +231 -0
- chapter3/3-a-full-training.ipynb +393 -0
- chapter3/3-fine-tuning-a-model-with-the-Trainer-API.ipynb +685 -0
- chapter3/3-processing-the-data.ipynb +1719 -0
.gitignore
ADDED
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Byte-compiled / optimized / DLL files
|
2 |
+
__pycache__/
|
3 |
+
*.py[cod]
|
4 |
+
*$py.class
|
5 |
+
artifacts/
|
6 |
+
test-trainer/
|
7 |
+
|
8 |
+
# C extensions
|
9 |
+
*.so
|
10 |
+
|
11 |
+
# Distribution / packaging
|
12 |
+
.Python
|
13 |
+
build/
|
14 |
+
develop-eggs/
|
15 |
+
dist/
|
16 |
+
downloads/
|
17 |
+
eggs/
|
18 |
+
.eggs/
|
19 |
+
lib/
|
20 |
+
lib64/
|
21 |
+
parts/
|
22 |
+
sdist/
|
23 |
+
var/
|
24 |
+
wheels/
|
25 |
+
share/python-wheels/
|
26 |
+
*.egg-info/
|
27 |
+
.installed.cfg
|
28 |
+
*.egg
|
29 |
+
MANIFEST
|
30 |
+
|
31 |
+
# PyInstaller
|
32 |
+
# Usually these files are written by a python script from a template
|
33 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
34 |
+
*.manifest
|
35 |
+
*.spec
|
36 |
+
|
37 |
+
# Installer logs
|
38 |
+
pip-log.txt
|
39 |
+
pip-delete-this-directory.txt
|
40 |
+
|
41 |
+
# Unit test / coverage reports
|
42 |
+
htmlcov/
|
43 |
+
.tox/
|
44 |
+
.nox/
|
45 |
+
.coverage
|
46 |
+
.coverage.*
|
47 |
+
.cache
|
48 |
+
nosetests.xml
|
49 |
+
coverage.xml
|
50 |
+
*.cover
|
51 |
+
*.py,cover
|
52 |
+
.hypothesis/
|
53 |
+
.pytest_cache/
|
54 |
+
cover/
|
55 |
+
|
56 |
+
# Translations
|
57 |
+
*.mo
|
58 |
+
*.pot
|
59 |
+
|
60 |
+
# Django stuff:
|
61 |
+
*.log
|
62 |
+
local_settings.py
|
63 |
+
db.sqlite3
|
64 |
+
db.sqlite3-journal
|
65 |
+
|
66 |
+
# Flask stuff:
|
67 |
+
instance/
|
68 |
+
.webassets-cache
|
69 |
+
|
70 |
+
# Scrapy stuff:
|
71 |
+
.scrapy
|
72 |
+
|
73 |
+
# Sphinx documentation
|
74 |
+
docs/_build/
|
75 |
+
|
76 |
+
# PyBuilder
|
77 |
+
.pybuilder/
|
78 |
+
target/
|
79 |
+
|
80 |
+
# Jupyter Notebook
|
81 |
+
.ipynb_checkpoints
|
82 |
+
|
83 |
+
# IPython
|
84 |
+
profile_default/
|
85 |
+
ipython_config.py
|
86 |
+
|
87 |
+
# pyenv
|
88 |
+
# For a library or package, you might want to ignore these files since the code is
|
89 |
+
# intended to run in multiple environments; otherwise, check them in:
|
90 |
+
# .python-version
|
91 |
+
|
92 |
+
# pipenv
|
93 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
94 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
95 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
96 |
+
# install all needed dependencies.
|
97 |
+
#Pipfile.lock
|
98 |
+
|
99 |
+
# poetry
|
100 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
101 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
102 |
+
# commonly ignored for libraries.
|
103 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
104 |
+
#poetry.lock
|
105 |
+
|
106 |
+
# pdm
|
107 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
108 |
+
#pdm.lock
|
109 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
110 |
+
# in version control.
|
111 |
+
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
|
112 |
+
.pdm.toml
|
113 |
+
.pdm-python
|
114 |
+
.pdm-build/
|
115 |
+
|
116 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
117 |
+
__pypackages__/
|
118 |
+
|
119 |
+
# Celery stuff
|
120 |
+
celerybeat-schedule
|
121 |
+
celerybeat.pid
|
122 |
+
|
123 |
+
# SageMath parsed files
|
124 |
+
*.sage.py
|
125 |
+
|
126 |
+
# Environments
|
127 |
+
.env
|
128 |
+
.venv
|
129 |
+
env/
|
130 |
+
venv/
|
131 |
+
ENV/
|
132 |
+
env.bak/
|
133 |
+
venv.bak/
|
134 |
+
|
135 |
+
# Spyder project settings
|
136 |
+
.spyderproject
|
137 |
+
.spyproject
|
138 |
+
|
139 |
+
# Rope project settings
|
140 |
+
.ropeproject
|
141 |
+
|
142 |
+
# mkdocs documentation
|
143 |
+
/site
|
144 |
+
|
145 |
+
# mypy
|
146 |
+
.mypy_cache/
|
147 |
+
.dmypy.json
|
148 |
+
dmypy.json
|
149 |
+
|
150 |
+
# Pyre type checker
|
151 |
+
.pyre/
|
152 |
+
|
153 |
+
# pytype static type analyzer
|
154 |
+
.pytype/
|
155 |
+
|
156 |
+
# Cython debug symbols
|
157 |
+
cython_debug/
|
158 |
+
|
159 |
+
# PyCharm
|
160 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
161 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
162 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
163 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
164 |
+
#.idea/
|
README.md
CHANGED
@@ -1,3 +1,9 @@
|
|
1 |
---
|
2 |
license: mit
|
3 |
---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
---
|
2 |
license: mit
|
3 |
---
|
4 |
+
|
5 |
+
# Huggingface Course
|
6 |
+
|
7 |
+
This is my first model repo.
|
8 |
+
|
9 |
+
This aims to document my progress in going through the huggingface course and my understanding of different libraries provided by huggingface.
|
chapter1/1-transformers-what-can-they-do.ipynb
ADDED
@@ -0,0 +1,819 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "markdown",
|
5 |
+
"metadata": {},
|
6 |
+
"source": [
|
7 |
+
"# Pipeline\n",
|
8 |
+
"\n",
|
9 |
+
"This is the most basic object in huggingface transformers libray. It is a one-stop object for doing everything under the hood and abstracting away a lot of the complexity away from the task at hand like `tokenization`, `preprocessing`, `postprocessing` etc."
|
10 |
+
]
|
11 |
+
},
|
12 |
+
{
|
13 |
+
"cell_type": "code",
|
14 |
+
"execution_count": 1,
|
15 |
+
"metadata": {},
|
16 |
+
"outputs": [
|
17 |
+
{
|
18 |
+
"name": "stderr",
|
19 |
+
"output_type": "stream",
|
20 |
+
"text": [
|
21 |
+
"/home/huggingface/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
|
22 |
+
" from .autonotebook import tqdm as notebook_tqdm\n",
|
23 |
+
"No model was supplied, defaulted to distilbert-base-uncased-finetuned-sst-2-english and revision af0f99b (https://huggingface.co/distilbert-base-uncased-finetuned-sst-2-english).\n",
|
24 |
+
"Using a pipeline without specifying a model name and revision in production is not recommended.\n",
|
25 |
+
"/home/huggingface/lib/python3.10/site-packages/huggingface_hub/file_download.py:1150: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
|
26 |
+
" warnings.warn(\n"
|
27 |
+
]
|
28 |
+
}
|
29 |
+
],
|
30 |
+
"source": [
|
31 |
+
"from transformers import pipeline\n",
|
32 |
+
"classifier = pipeline(task = \"sentiment-analysis\")"
|
33 |
+
]
|
34 |
+
},
|
35 |
+
{
|
36 |
+
"cell_type": "code",
|
37 |
+
"execution_count": 2,
|
38 |
+
"metadata": {},
|
39 |
+
"outputs": [],
|
40 |
+
"source": [
|
41 |
+
"sentences = [\n",
|
42 |
+
" \"I have been sleeping a lot lately. Wish I could do more and procrastinate less\",\n",
|
43 |
+
" \"It is a wonderful day today\",\n",
|
44 |
+
" \"What the heck, this software sucks!!\"\n",
|
45 |
+
"]"
|
46 |
+
]
|
47 |
+
},
|
48 |
+
{
|
49 |
+
"cell_type": "code",
|
50 |
+
"execution_count": 3,
|
51 |
+
"metadata": {},
|
52 |
+
"outputs": [
|
53 |
+
{
|
54 |
+
"data": {
|
55 |
+
"text/plain": [
|
56 |
+
"[{'label': 'NEGATIVE', 'score': 0.9991617202758789},\n",
|
57 |
+
" {'label': 'POSITIVE', 'score': 0.999890923500061},\n",
|
58 |
+
" {'label': 'NEGATIVE', 'score': 0.9995805621147156}]"
|
59 |
+
]
|
60 |
+
},
|
61 |
+
"execution_count": 3,
|
62 |
+
"metadata": {},
|
63 |
+
"output_type": "execute_result"
|
64 |
+
}
|
65 |
+
],
|
66 |
+
"source": [
|
67 |
+
"classifier(sentences)"
|
68 |
+
]
|
69 |
+
},
|
70 |
+
{
|
71 |
+
"cell_type": "markdown",
|
72 |
+
"metadata": {},
|
73 |
+
"source": [
|
74 |
+
"## Zero Shot Classification"
|
75 |
+
]
|
76 |
+
},
|
77 |
+
{
|
78 |
+
"cell_type": "code",
|
79 |
+
"execution_count": 4,
|
80 |
+
"metadata": {},
|
81 |
+
"outputs": [],
|
82 |
+
"source": [
|
83 |
+
"sentences = [\n",
|
84 |
+
" \"Rahul Dravid was a great coach and led India to win the world cup in 2024\",\n",
|
85 |
+
" \"What is a transformer? It is a black box neural network model which can be used to do stuff with sequences\",\n",
|
86 |
+
" \"How can one understand the meaning of life? It is not so simple\",\n",
|
87 |
+
" \"Shaun had a great insight right in the middle of a surgery\"\n",
|
88 |
+
"]\n",
|
89 |
+
"\n",
|
90 |
+
"labels = [\"Sports\", \"Education\", \"Other\"]"
|
91 |
+
]
|
92 |
+
},
|
93 |
+
{
|
94 |
+
"cell_type": "code",
|
95 |
+
"execution_count": 5,
|
96 |
+
"metadata": {},
|
97 |
+
"outputs": [
|
98 |
+
{
|
99 |
+
"name": "stderr",
|
100 |
+
"output_type": "stream",
|
101 |
+
"text": [
|
102 |
+
"No model was supplied, defaulted to facebook/bart-large-mnli and revision c626438 (https://huggingface.co/facebook/bart-large-mnli).\n",
|
103 |
+
"Using a pipeline without specifying a model name and revision in production is not recommended.\n",
|
104 |
+
"/home/huggingface/lib/python3.10/site-packages/huggingface_hub/file_download.py:1150: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
|
105 |
+
" warnings.warn(\n"
|
106 |
+
]
|
107 |
+
}
|
108 |
+
],
|
109 |
+
"source": [
|
110 |
+
"classifier = pipeline(\"zero-shot-classification\")"
|
111 |
+
]
|
112 |
+
},
|
113 |
+
{
|
114 |
+
"cell_type": "code",
|
115 |
+
"execution_count": 7,
|
116 |
+
"metadata": {},
|
117 |
+
"outputs": [
|
118 |
+
{
|
119 |
+
"data": {
|
120 |
+
"text/plain": [
|
121 |
+
"[{'sequence': 'Rahul Dravid was a great coach and led India to win the world cup in 2024',\n",
|
122 |
+
" 'labels': ['Sports', 'Other', 'Education'],\n",
|
123 |
+
" 'scores': [0.967433512210846, 0.025695420801639557, 0.006871006917208433]},\n",
|
124 |
+
" {'sequence': 'What is a transformer? It is a black box neural network model which can be used to do stuff with sequences',\n",
|
125 |
+
" 'labels': ['Other', 'Education', 'Sports'],\n",
|
126 |
+
" 'scores': [0.776347279548645, 0.11728236079216003, 0.10637037456035614]},\n",
|
127 |
+
" {'sequence': 'How can one understand the meaning of life? It is not so simple',\n",
|
128 |
+
" 'labels': ['Other', 'Education', 'Sports'],\n",
|
129 |
+
" 'scores': [0.8647233247756958, 0.08910410851240158, 0.046172577887773514]},\n",
|
130 |
+
" {'sequence': 'Shaun had a great insight right in the middle of a surgery',\n",
|
131 |
+
" 'labels': ['Other', 'Sports', 'Education'],\n",
|
132 |
+
" 'scores': [0.7419394850730896, 0.18247079849243164, 0.07558975368738174]}]"
|
133 |
+
]
|
134 |
+
},
|
135 |
+
"execution_count": 7,
|
136 |
+
"metadata": {},
|
137 |
+
"output_type": "execute_result"
|
138 |
+
}
|
139 |
+
],
|
140 |
+
"source": [
|
141 |
+
"classifier(sequences = sentences, candidate_labels = labels)"
|
142 |
+
]
|
143 |
+
},
|
144 |
+
{
|
145 |
+
"cell_type": "markdown",
|
146 |
+
"metadata": {},
|
147 |
+
"source": [
|
148 |
+
"## Text Generation"
|
149 |
+
]
|
150 |
+
},
|
151 |
+
{
|
152 |
+
"cell_type": "markdown",
|
153 |
+
"metadata": {},
|
154 |
+
"source": [
|
155 |
+
"### Using default model"
|
156 |
+
]
|
157 |
+
},
|
158 |
+
{
|
159 |
+
"cell_type": "code",
|
160 |
+
"execution_count": 8,
|
161 |
+
"metadata": {},
|
162 |
+
"outputs": [
|
163 |
+
{
|
164 |
+
"name": "stderr",
|
165 |
+
"output_type": "stream",
|
166 |
+
"text": [
|
167 |
+
"No model was supplied, defaulted to gpt2 and revision 6c0e608 (https://huggingface.co/gpt2).\n",
|
168 |
+
"Using a pipeline without specifying a model name and revision in production is not recommended.\n",
|
169 |
+
"/home/huggingface/lib/python3.10/site-packages/huggingface_hub/file_download.py:1150: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
|
170 |
+
" warnings.warn(\n"
|
171 |
+
]
|
172 |
+
}
|
173 |
+
],
|
174 |
+
"source": [
|
175 |
+
"generator = pipeline(task = \"text-generation\")"
|
176 |
+
]
|
177 |
+
},
|
178 |
+
{
|
179 |
+
"cell_type": "code",
|
180 |
+
"execution_count": 9,
|
181 |
+
"metadata": {},
|
182 |
+
"outputs": [],
|
183 |
+
"source": [
|
184 |
+
"seed_text = \"Dhoni finishes off in style and the entire Indian team\""
|
185 |
+
]
|
186 |
+
},
|
187 |
+
{
|
188 |
+
"cell_type": "code",
|
189 |
+
"execution_count": 11,
|
190 |
+
"metadata": {},
|
191 |
+
"outputs": [
|
192 |
+
{
|
193 |
+
"name": "stderr",
|
194 |
+
"output_type": "stream",
|
195 |
+
"text": [
|
196 |
+
"/home/huggingface/lib/python3.10/site-packages/transformers/generation/utils.py:1201: UserWarning: You have modified the pretrained model configuration to control generation. This is a deprecated strategy to control generation and will be removed soon, in a future version. Please use a generation configuration file (see https://huggingface.co/docs/transformers/main_classes/text_generation)\n",
|
197 |
+
" warnings.warn(\n",
|
198 |
+
"Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n",
|
199 |
+
"/home/huggingface/lib/python3.10/site-packages/transformers/generation/utils.py:1288: UserWarning: Using `max_length`'s default (50) to control the generation length. This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we recommend using `max_new_tokens` to control the maximum length of the generation.\n",
|
200 |
+
" warnings.warn(\n"
|
201 |
+
]
|
202 |
+
},
|
203 |
+
{
|
204 |
+
"data": {
|
205 |
+
"text/plain": [
|
206 |
+
"[{'generated_text': 'Dhoni finishes off in style and the entire Indian team look forward to meeting him at home to continue their efforts towards an unbeaten run in this World Cup.'}]"
|
207 |
+
]
|
208 |
+
},
|
209 |
+
"execution_count": 11,
|
210 |
+
"metadata": {},
|
211 |
+
"output_type": "execute_result"
|
212 |
+
}
|
213 |
+
],
|
214 |
+
"source": [
|
215 |
+
"generator(text_inputs = seed_text)"
|
216 |
+
]
|
217 |
+
},
|
218 |
+
{
|
219 |
+
"cell_type": "code",
|
220 |
+
"execution_count": 12,
|
221 |
+
"metadata": {},
|
222 |
+
"outputs": [
|
223 |
+
{
|
224 |
+
"name": "stderr",
|
225 |
+
"output_type": "stream",
|
226 |
+
"text": [
|
227 |
+
"Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n"
|
228 |
+
]
|
229 |
+
},
|
230 |
+
{
|
231 |
+
"data": {
|
232 |
+
"text/plain": [
|
233 |
+
"[{'generated_text': \"Dhoni finishes off in style and the entire Indian team is delighted with his victory\\n\\nIndia have failed to impress Pakistan's Ranji Trophy winner\"},\n",
|
234 |
+
" {'generated_text': \"Dhoni finishes off in style and the entire Indian team goes to great lengths to make him comfortable. It's a very important decision for the first\"},\n",
|
235 |
+
" {'generated_text': 'Dhoni finishes off in style and the entire Indian team is immediately in a good position to secure victory.\\n\\nA few weeks from now,'}]"
|
236 |
+
]
|
237 |
+
},
|
238 |
+
"execution_count": 12,
|
239 |
+
"metadata": {},
|
240 |
+
"output_type": "execute_result"
|
241 |
+
}
|
242 |
+
],
|
243 |
+
"source": [
|
244 |
+
"generator(text_inputs = seed_text, num_return_sequences = 3, max_length = 30)"
|
245 |
+
]
|
246 |
+
},
|
247 |
+
{
|
248 |
+
"cell_type": "markdown",
|
249 |
+
"metadata": {},
|
250 |
+
"source": [
|
251 |
+
"### Using specific model from huggingface hub"
|
252 |
+
]
|
253 |
+
},
|
254 |
+
{
|
255 |
+
"cell_type": "code",
|
256 |
+
"execution_count": 13,
|
257 |
+
"metadata": {},
|
258 |
+
"outputs": [
|
259 |
+
{
|
260 |
+
"name": "stderr",
|
261 |
+
"output_type": "stream",
|
262 |
+
"text": [
|
263 |
+
"/home/huggingface/lib/python3.10/site-packages/huggingface_hub/file_download.py:1150: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
|
264 |
+
" warnings.warn(\n",
|
265 |
+
"/home/huggingface/lib/python3.10/site-packages/transformers/generation/utils.py:1201: UserWarning: You have modified the pretrained model configuration to control generation. This is a deprecated strategy to control generation and will be removed soon, in a future version. Please use a generation configuration file (see https://huggingface.co/docs/transformers/main_classes/text_generation)\n",
|
266 |
+
" warnings.warn(\n",
|
267 |
+
"Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n"
|
268 |
+
]
|
269 |
+
},
|
270 |
+
{
|
271 |
+
"data": {
|
272 |
+
"text/plain": [
|
273 |
+
"[{'generated_text': 'Dhoni finishes off in style and the entire Indian team has their legs.\\n\\n\\nThe match between the West Indian and the Americans was the'},\n",
|
274 |
+
" {'generated_text': 'Dhoni finishes off in style and the entire Indian team is preparing to compete on October 31st.\\n\\nThe squad of India is made up'},\n",
|
275 |
+
" {'generated_text': 'Dhoni finishes off in style and the entire Indian team looks happy to be back as usual this term,\" he added.'}]"
|
276 |
+
]
|
277 |
+
},
|
278 |
+
"execution_count": 13,
|
279 |
+
"metadata": {},
|
280 |
+
"output_type": "execute_result"
|
281 |
+
}
|
282 |
+
],
|
283 |
+
"source": [
|
284 |
+
"generator = pipeline(\"text-generation\", model = \"distilgpt2\")\n",
|
285 |
+
"\n",
|
286 |
+
"generator(text_inputs= seed_text, num_return_sequences = 3, max_length = 30)"
|
287 |
+
]
|
288 |
+
},
|
289 |
+
{
|
290 |
+
"cell_type": "markdown",
|
291 |
+
"metadata": {},
|
292 |
+
"source": [
|
293 |
+
"## Mask Filling"
|
294 |
+
]
|
295 |
+
},
|
296 |
+
{
|
297 |
+
"cell_type": "code",
|
298 |
+
"execution_count": 14,
|
299 |
+
"metadata": {},
|
300 |
+
"outputs": [
|
301 |
+
{
|
302 |
+
"name": "stderr",
|
303 |
+
"output_type": "stream",
|
304 |
+
"text": [
|
305 |
+
"No model was supplied, defaulted to distilroberta-base and revision ec58a5b (https://huggingface.co/distilroberta-base).\n",
|
306 |
+
"Using a pipeline without specifying a model name and revision in production is not recommended.\n",
|
307 |
+
"/home/huggingface/lib/python3.10/site-packages/huggingface_hub/file_download.py:1150: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
|
308 |
+
" warnings.warn(\n"
|
309 |
+
]
|
310 |
+
}
|
311 |
+
],
|
312 |
+
"source": [
|
313 |
+
"filler = pipeline(\"fill-mask\")"
|
314 |
+
]
|
315 |
+
},
|
316 |
+
{
|
317 |
+
"cell_type": "code",
|
318 |
+
"execution_count": 16,
|
319 |
+
"metadata": {},
|
320 |
+
"outputs": [
|
321 |
+
{
|
322 |
+
"data": {
|
323 |
+
"text/plain": [
|
324 |
+
"[{'score': 0.07598453760147095,\n",
|
325 |
+
" 'token': 6943,\n",
|
326 |
+
" 'token_str': ' depression',\n",
|
327 |
+
" 'sequence': 'How deep is your depression?'},\n",
|
328 |
+
" {'score': 0.035246096551418304,\n",
|
329 |
+
" 'token': 12172,\n",
|
330 |
+
" 'token_str': ' bubble',\n",
|
331 |
+
" 'sequence': 'How deep is your bubble?'},\n",
|
332 |
+
" {'score': 0.027820784598588943,\n",
|
333 |
+
" 'token': 7530,\n",
|
334 |
+
" 'token_str': ' addiction',\n",
|
335 |
+
" 'sequence': 'How deep is your addiction?'},\n",
|
336 |
+
" {'score': 0.014877567999064922,\n",
|
337 |
+
" 'token': 4683,\n",
|
338 |
+
" 'token_str': ' hole',\n",
|
339 |
+
" 'sequence': 'How deep is your hole?'},\n",
|
340 |
+
" {'score': 0.013593271374702454,\n",
|
341 |
+
" 'token': 1144,\n",
|
342 |
+
" 'token_str': ' heart',\n",
|
343 |
+
" 'sequence': 'How deep is your heart?'}]"
|
344 |
+
]
|
345 |
+
},
|
346 |
+
"execution_count": 16,
|
347 |
+
"metadata": {},
|
348 |
+
"output_type": "execute_result"
|
349 |
+
}
|
350 |
+
],
|
351 |
+
"source": [
|
352 |
+
"filler(\"How deep is your <mask>?\", top_k = 5)"
|
353 |
+
]
|
354 |
+
},
|
355 |
+
{
|
356 |
+
"cell_type": "code",
|
357 |
+
"execution_count": 17,
|
358 |
+
"metadata": {},
|
359 |
+
"outputs": [
|
360 |
+
{
|
361 |
+
"name": "stderr",
|
362 |
+
"output_type": "stream",
|
363 |
+
"text": [
|
364 |
+
"/home/huggingface/lib/python3.10/site-packages/huggingface_hub/file_download.py:1150: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
|
365 |
+
" warnings.warn(\n",
|
366 |
+
"Some weights of the model checkpoint at bert-base-cased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']\n",
|
367 |
+
"- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
|
368 |
+
"- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
|
369 |
+
]
|
370 |
+
}
|
371 |
+
],
|
372 |
+
"source": [
|
373 |
+
"filler = pipeline(\"fill-mask\", model = \"bert-base-cased\")"
|
374 |
+
]
|
375 |
+
},
|
376 |
+
{
|
377 |
+
"cell_type": "code",
|
378 |
+
"execution_count": 18,
|
379 |
+
"metadata": {},
|
380 |
+
"outputs": [
|
381 |
+
{
|
382 |
+
"data": {
|
383 |
+
"text/plain": [
|
384 |
+
"[{'score': 0.0551474466919899,\n",
|
385 |
+
" 'token': 1762,\n",
|
386 |
+
" 'token_str': 'heart',\n",
|
387 |
+
" 'sequence': 'How deep is your heart?'},\n",
|
388 |
+
" {'score': 0.04252220690250397,\n",
|
389 |
+
" 'token': 5785,\n",
|
390 |
+
" 'token_str': 'wound',\n",
|
391 |
+
" 'sequence': 'How deep is your wound?'},\n",
|
392 |
+
" {'score': 0.038988541811704636,\n",
|
393 |
+
" 'token': 3960,\n",
|
394 |
+
" 'token_str': 'soul',\n",
|
395 |
+
" 'sequence': 'How deep is your soul?'},\n",
|
396 |
+
" {'score': 0.03589598089456558,\n",
|
397 |
+
" 'token': 2922,\n",
|
398 |
+
" 'token_str': 'throat',\n",
|
399 |
+
" 'sequence': 'How deep is your throat?'},\n",
|
400 |
+
" {'score': 0.0302369873970747,\n",
|
401 |
+
" 'token': 1567,\n",
|
402 |
+
" 'token_str': 'love',\n",
|
403 |
+
" 'sequence': 'How deep is your love?'}]"
|
404 |
+
]
|
405 |
+
},
|
406 |
+
"execution_count": 18,
|
407 |
+
"metadata": {},
|
408 |
+
"output_type": "execute_result"
|
409 |
+
}
|
410 |
+
],
|
411 |
+
"source": [
|
412 |
+
"filler(\"How deep is your [MASK]?\", top_k = 5)"
|
413 |
+
]
|
414 |
+
},
|
415 |
+
{
|
416 |
+
"cell_type": "markdown",
|
417 |
+
"metadata": {},
|
418 |
+
"source": [
|
419 |
+
"## Named Entity Recognition (NER)"
|
420 |
+
]
|
421 |
+
},
|
422 |
+
{
|
423 |
+
"cell_type": "code",
|
424 |
+
"execution_count": 19,
|
425 |
+
"metadata": {},
|
426 |
+
"outputs": [
|
427 |
+
{
|
428 |
+
"name": "stderr",
|
429 |
+
"output_type": "stream",
|
430 |
+
"text": [
|
431 |
+
"No model was supplied, defaulted to dbmdz/bert-large-cased-finetuned-conll03-english and revision f2482bf (https://huggingface.co/dbmdz/bert-large-cased-finetuned-conll03-english).\n",
|
432 |
+
"Using a pipeline without specifying a model name and revision in production is not recommended.\n",
|
433 |
+
"/home/huggingface/lib/python3.10/site-packages/huggingface_hub/file_download.py:1150: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
|
434 |
+
" warnings.warn(\n",
|
435 |
+
"/home/huggingface/lib/python3.10/site-packages/transformers/pipelines/token_classification.py:157: UserWarning: `grouped_entities` is deprecated and will be removed in version v5.0.0, defaulted to `aggregation_strategy=\"simple\"` instead.\n",
|
436 |
+
" warnings.warn(\n"
|
437 |
+
]
|
438 |
+
}
|
439 |
+
],
|
440 |
+
"source": [
|
441 |
+
"ner = pipeline(task = \"ner\", grouped_entities = True)"
|
442 |
+
]
|
443 |
+
},
|
444 |
+
{
|
445 |
+
"cell_type": "code",
|
446 |
+
"execution_count": 20,
|
447 |
+
"metadata": {},
|
448 |
+
"outputs": [
|
449 |
+
{
|
450 |
+
"data": {
|
451 |
+
"text/plain": [
|
452 |
+
"[{'entity_group': 'PER',\n",
|
453 |
+
" 'score': 0.9884488,\n",
|
454 |
+
" 'word': 'Sachin Tendulkar',\n",
|
455 |
+
" 'start': 63,\n",
|
456 |
+
" 'end': 79},\n",
|
457 |
+
" {'entity_group': 'ORG',\n",
|
458 |
+
" 'score': 0.9564063,\n",
|
459 |
+
" 'word': 'Indian Cricket Team',\n",
|
460 |
+
" 'start': 89,\n",
|
461 |
+
" 'end': 108}]"
|
462 |
+
]
|
463 |
+
},
|
464 |
+
"execution_count": 20,
|
465 |
+
"metadata": {},
|
466 |
+
"output_type": "execute_result"
|
467 |
+
}
|
468 |
+
],
|
469 |
+
"source": [
|
470 |
+
"ner(\"Hey everyone, please welcome, the chief guest for tonight: Mr. Sachin Tendulkar from the Indian Cricket Team\")"
|
471 |
+
]
|
472 |
+
},
|
473 |
+
{
|
474 |
+
"cell_type": "code",
|
475 |
+
"execution_count": 21,
|
476 |
+
"metadata": {},
|
477 |
+
"outputs": [
|
478 |
+
{
|
479 |
+
"name": "stderr",
|
480 |
+
"output_type": "stream",
|
481 |
+
"text": [
|
482 |
+
"No model was supplied, defaulted to dbmdz/bert-large-cased-finetuned-conll03-english and revision f2482bf (https://huggingface.co/dbmdz/bert-large-cased-finetuned-conll03-english).\n",
|
483 |
+
"Using a pipeline without specifying a model name and revision in production is not recommended.\n",
|
484 |
+
"/home/huggingface/lib/python3.10/site-packages/huggingface_hub/file_download.py:1150: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
|
485 |
+
" warnings.warn(\n",
|
486 |
+
"/home/huggingface/lib/python3.10/site-packages/transformers/pipelines/token_classification.py:157: UserWarning: `grouped_entities` is deprecated and will be removed in version v5.0.0, defaulted to `aggregation_strategy=\"none\"` instead.\n",
|
487 |
+
" warnings.warn(\n"
|
488 |
+
]
|
489 |
+
},
|
490 |
+
{
|
491 |
+
"data": {
|
492 |
+
"text/plain": [
|
493 |
+
"[{'entity': 'I-PER',\n",
|
494 |
+
" 'score': 0.9995166,\n",
|
495 |
+
" 'index': 15,\n",
|
496 |
+
" 'word': 'Sa',\n",
|
497 |
+
" 'start': 63,\n",
|
498 |
+
" 'end': 65},\n",
|
499 |
+
" {'entity': 'I-PER',\n",
|
500 |
+
" 'score': 0.9992397,\n",
|
501 |
+
" 'index': 16,\n",
|
502 |
+
" 'word': '##chin',\n",
|
503 |
+
" 'start': 65,\n",
|
504 |
+
" 'end': 69},\n",
|
505 |
+
" {'entity': 'I-PER',\n",
|
506 |
+
" 'score': 0.99916065,\n",
|
507 |
+
" 'index': 17,\n",
|
508 |
+
" 'word': 'Ten',\n",
|
509 |
+
" 'start': 70,\n",
|
510 |
+
" 'end': 73},\n",
|
511 |
+
" {'entity': 'I-PER',\n",
|
512 |
+
" 'score': 0.9957129,\n",
|
513 |
+
" 'index': 18,\n",
|
514 |
+
" 'word': '##du',\n",
|
515 |
+
" 'start': 73,\n",
|
516 |
+
" 'end': 75},\n",
|
517 |
+
" {'entity': 'I-PER',\n",
|
518 |
+
" 'score': 0.9410511,\n",
|
519 |
+
" 'index': 19,\n",
|
520 |
+
" 'word': '##lk',\n",
|
521 |
+
" 'start': 75,\n",
|
522 |
+
" 'end': 77},\n",
|
523 |
+
" {'entity': 'I-PER',\n",
|
524 |
+
" 'score': 0.99601185,\n",
|
525 |
+
" 'index': 20,\n",
|
526 |
+
" 'word': '##ar',\n",
|
527 |
+
" 'start': 77,\n",
|
528 |
+
" 'end': 79},\n",
|
529 |
+
" {'entity': 'I-ORG',\n",
|
530 |
+
" 'score': 0.9637556,\n",
|
531 |
+
" 'index': 23,\n",
|
532 |
+
" 'word': 'Indian',\n",
|
533 |
+
" 'start': 89,\n",
|
534 |
+
" 'end': 95},\n",
|
535 |
+
" {'entity': 'I-ORG',\n",
|
536 |
+
" 'score': 0.9248884,\n",
|
537 |
+
" 'index': 24,\n",
|
538 |
+
" 'word': 'Cricket',\n",
|
539 |
+
" 'start': 96,\n",
|
540 |
+
" 'end': 103},\n",
|
541 |
+
" {'entity': 'I-ORG',\n",
|
542 |
+
" 'score': 0.98057497,\n",
|
543 |
+
" 'index': 25,\n",
|
544 |
+
" 'word': 'Team',\n",
|
545 |
+
" 'start': 104,\n",
|
546 |
+
" 'end': 108}]"
|
547 |
+
]
|
548 |
+
},
|
549 |
+
"execution_count": 21,
|
550 |
+
"metadata": {},
|
551 |
+
"output_type": "execute_result"
|
552 |
+
}
|
553 |
+
],
|
554 |
+
"source": [
|
555 |
+
"ner = pipeline(task = \"ner\", grouped_entities = False)\n",
|
556 |
+
"ner(\"Hey everyone, please welcome, the chief guest for tonight: Mr. Sachin Tendulkar from the Indian Cricket Team\")"
|
557 |
+
]
|
558 |
+
},
|
559 |
+
{
|
560 |
+
"cell_type": "code",
|
561 |
+
"execution_count": 25,
|
562 |
+
"metadata": {},
|
563 |
+
"outputs": [
|
564 |
+
{
|
565 |
+
"name": "stderr",
|
566 |
+
"output_type": "stream",
|
567 |
+
"text": [
|
568 |
+
"No model was supplied, defaulted to dbmdz/bert-large-cased-finetuned-conll03-english and revision f2482bf (https://huggingface.co/dbmdz/bert-large-cased-finetuned-conll03-english).\n",
|
569 |
+
"Using a pipeline without specifying a model name and revision in production is not recommended.\n",
|
570 |
+
"/home/huggingface/lib/python3.10/site-packages/huggingface_hub/file_download.py:1150: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
|
571 |
+
" warnings.warn(\n"
|
572 |
+
]
|
573 |
+
}
|
574 |
+
],
|
575 |
+
"source": [
|
576 |
+
"pos = pipeline(task = \"token-classification\")"
|
577 |
+
]
|
578 |
+
},
|
579 |
+
{
|
580 |
+
"cell_type": "code",
|
581 |
+
"execution_count": 28,
|
582 |
+
"metadata": {},
|
583 |
+
"outputs": [
|
584 |
+
{
|
585 |
+
"data": {
|
586 |
+
"text/plain": [
|
587 |
+
"[{'entity': 'I-PER',\n",
|
588 |
+
" 'score': 0.99938285,\n",
|
589 |
+
" 'index': 4,\n",
|
590 |
+
" 'word': 'S',\n",
|
591 |
+
" 'start': 11,\n",
|
592 |
+
" 'end': 12},\n",
|
593 |
+
" {'entity': 'I-PER',\n",
|
594 |
+
" 'score': 0.99815494,\n",
|
595 |
+
" 'index': 5,\n",
|
596 |
+
" 'word': '##yl',\n",
|
597 |
+
" 'start': 12,\n",
|
598 |
+
" 'end': 14},\n",
|
599 |
+
" {'entity': 'I-PER',\n",
|
600 |
+
" 'score': 0.9959072,\n",
|
601 |
+
" 'index': 6,\n",
|
602 |
+
" 'word': '##va',\n",
|
603 |
+
" 'start': 14,\n",
|
604 |
+
" 'end': 16},\n",
|
605 |
+
" {'entity': 'I-PER',\n",
|
606 |
+
" 'score': 0.99923277,\n",
|
607 |
+
" 'index': 7,\n",
|
608 |
+
" 'word': '##in',\n",
|
609 |
+
" 'start': 16,\n",
|
610 |
+
" 'end': 18},\n",
|
611 |
+
" {'entity': 'I-ORG',\n",
|
612 |
+
" 'score': 0.9738931,\n",
|
613 |
+
" 'index': 12,\n",
|
614 |
+
" 'word': 'Hu',\n",
|
615 |
+
" 'start': 33,\n",
|
616 |
+
" 'end': 35},\n",
|
617 |
+
" {'entity': 'I-ORG',\n",
|
618 |
+
" 'score': 0.97611505,\n",
|
619 |
+
" 'index': 13,\n",
|
620 |
+
" 'word': '##gging',\n",
|
621 |
+
" 'start': 35,\n",
|
622 |
+
" 'end': 40},\n",
|
623 |
+
" {'entity': 'I-ORG',\n",
|
624 |
+
" 'score': 0.9887976,\n",
|
625 |
+
" 'index': 14,\n",
|
626 |
+
" 'word': 'Face',\n",
|
627 |
+
" 'start': 41,\n",
|
628 |
+
" 'end': 45},\n",
|
629 |
+
" {'entity': 'I-LOC',\n",
|
630 |
+
" 'score': 0.9932106,\n",
|
631 |
+
" 'index': 16,\n",
|
632 |
+
" 'word': 'Brooklyn',\n",
|
633 |
+
" 'start': 49,\n",
|
634 |
+
" 'end': 57}]"
|
635 |
+
]
|
636 |
+
},
|
637 |
+
"execution_count": 28,
|
638 |
+
"metadata": {},
|
639 |
+
"output_type": "execute_result"
|
640 |
+
}
|
641 |
+
],
|
642 |
+
"source": [
|
643 |
+
"pos(\"My name is Sylvain and I work at Hugging Face in Brooklyn.\")"
|
644 |
+
]
|
645 |
+
},
|
646 |
+
{
|
647 |
+
"cell_type": "markdown",
|
648 |
+
"metadata": {},
|
649 |
+
"source": [
|
650 |
+
"## Question Answering"
|
651 |
+
]
|
652 |
+
},
|
653 |
+
{
|
654 |
+
"cell_type": "code",
|
655 |
+
"execution_count": 30,
|
656 |
+
"metadata": {},
|
657 |
+
"outputs": [
|
658 |
+
{
|
659 |
+
"name": "stderr",
|
660 |
+
"output_type": "stream",
|
661 |
+
"text": [
|
662 |
+
"No model was supplied, defaulted to distilbert-base-cased-distilled-squad and revision 626af31 (https://huggingface.co/distilbert-base-cased-distilled-squad).\n",
|
663 |
+
"Using a pipeline without specifying a model name and revision in production is not recommended.\n",
|
664 |
+
"/home/huggingface/lib/python3.10/site-packages/huggingface_hub/file_download.py:1150: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
|
665 |
+
" warnings.warn(\n"
|
666 |
+
]
|
667 |
+
},
|
668 |
+
{
|
669 |
+
"data": {
|
670 |
+
"text/plain": [
|
671 |
+
"{'score': 0.21678458154201508,\n",
|
672 |
+
" 'start': 48,\n",
|
673 |
+
" 'end': 76,\n",
|
674 |
+
" 'answer': 'I wish I could get some rest'}"
|
675 |
+
]
|
676 |
+
},
|
677 |
+
"execution_count": 30,
|
678 |
+
"metadata": {},
|
679 |
+
"output_type": "execute_result"
|
680 |
+
}
|
681 |
+
],
|
682 |
+
"source": [
|
683 |
+
"bot = pipeline(\"question-answering\")\n",
|
684 |
+
"bot(\n",
|
685 |
+
" question = \"How am I doing?\",\n",
|
686 |
+
" context = \"I have just came back from a very busy trip and I wish I could get some rest.\"\n",
|
687 |
+
")"
|
688 |
+
]
|
689 |
+
},
|
690 |
+
{
|
691 |
+
"cell_type": "markdown",
|
692 |
+
"metadata": {},
|
693 |
+
"source": [
|
694 |
+
"This is a model which is meant to extract the phrases from the given text which could be the answer and does not generate the answer."
|
695 |
+
]
|
696 |
+
},
|
697 |
+
{
|
698 |
+
"cell_type": "markdown",
|
699 |
+
"metadata": {},
|
700 |
+
"source": [
|
701 |
+
"## Summarization"
|
702 |
+
]
|
703 |
+
},
|
704 |
+
{
|
705 |
+
"cell_type": "code",
|
706 |
+
"execution_count": 31,
|
707 |
+
"metadata": {},
|
708 |
+
"outputs": [
|
709 |
+
{
|
710 |
+
"name": "stderr",
|
711 |
+
"output_type": "stream",
|
712 |
+
"text": [
|
713 |
+
"No model was supplied, defaulted to sshleifer/distilbart-cnn-12-6 and revision a4f8f3e (https://huggingface.co/sshleifer/distilbart-cnn-12-6).\n",
|
714 |
+
"Using a pipeline without specifying a model name and revision in production is not recommended.\n",
|
715 |
+
"/home/huggingface/lib/python3.10/site-packages/huggingface_hub/file_download.py:1150: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
|
716 |
+
" warnings.warn(\n"
|
717 |
+
]
|
718 |
+
},
|
719 |
+
{
|
720 |
+
"data": {
|
721 |
+
"text/plain": [
|
722 |
+
"[{'summary_text': ' America has changed dramatically during recent years . The number of engineering graduates in the U.S. has declined in traditional engineering disciplines such as mechanical, civil, electrical, chemical, and aeronautical engineering . Rapidly developing economies such as China and India continue to encourage and advance the teaching of engineering .'}]"
|
723 |
+
]
|
724 |
+
},
|
725 |
+
"execution_count": 31,
|
726 |
+
"metadata": {},
|
727 |
+
"output_type": "execute_result"
|
728 |
+
}
|
729 |
+
],
|
730 |
+
"source": [
|
731 |
+
"summary = pipeline(\"summarization\")\n",
|
732 |
+
"\n",
|
733 |
+
"summary(\n",
|
734 |
+
"\"\"\"\n",
|
735 |
+
" America has changed dramatically during recent years. Not only has the number of \n",
|
736 |
+
" graduates in traditional engineering disciplines such as mechanical, civil, \n",
|
737 |
+
" electrical, chemical, and aeronautical engineering declined, but in most of \n",
|
738 |
+
" the premier American universities engineering curricula now concentrate on \n",
|
739 |
+
" and encourage largely the study of engineering science. As a result, there \n",
|
740 |
+
" are declining offerings in engineering subjects dealing with infrastructure, \n",
|
741 |
+
" the environment, and related issues, and greater concentration on high \n",
|
742 |
+
" technology subjects, largely supporting increasingly complex scientific \n",
|
743 |
+
" developments. While the latter is important, it should not be at the expense \n",
|
744 |
+
" of more traditional engineering.\n",
|
745 |
+
"\n",
|
746 |
+
" Rapidly developing economies such as China and India, as well as other \n",
|
747 |
+
" industrial countries in Europe and Asia, continue to encourage and advance \n",
|
748 |
+
" the teaching of engineering. Both China and India, respectively, graduate \n",
|
749 |
+
" six and eight times as many traditional engineers as does the United States. \n",
|
750 |
+
" Other industrial countries at minimum maintain their output, while America \n",
|
751 |
+
" suffers an increasingly serious decline in the number of engineering graduates \n",
|
752 |
+
" and a lack of well-educated engineers.\n",
|
753 |
+
"\"\"\"\n",
|
754 |
+
")"
|
755 |
+
]
|
756 |
+
},
|
757 |
+
{
|
758 |
+
"cell_type": "markdown",
|
759 |
+
"metadata": {},
|
760 |
+
"source": [
|
761 |
+
"## Translation"
|
762 |
+
]
|
763 |
+
},
|
764 |
+
{
|
765 |
+
"cell_type": "code",
|
766 |
+
"execution_count": 34,
|
767 |
+
"metadata": {},
|
768 |
+
"outputs": [
|
769 |
+
{
|
770 |
+
"ename": "KeyError",
|
771 |
+
"evalue": "'translation'",
|
772 |
+
"output_type": "error",
|
773 |
+
"traceback": [
|
774 |
+
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
775 |
+
"\u001b[0;31mKeyError\u001b[0m Traceback (most recent call last)",
|
776 |
+
"Cell \u001b[0;32mIn[34], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m translator \u001b[38;5;241m=\u001b[39m \u001b[43mpipeline\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mtranslation\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mHariSekhar/Eng_Marathi_translation\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n",
|
777 |
+
"File \u001b[0;32m~/huggingface/lib/python3.10/site-packages/transformers/pipelines/__init__.py:692\u001b[0m, in \u001b[0;36mpipeline\u001b[0;34m(task, model, config, tokenizer, feature_extractor, image_processor, framework, revision, use_fast, use_auth_token, device, device_map, torch_dtype, trust_remote_code, model_kwargs, pipeline_class, **kwargs)\u001b[0m\n\u001b[1;32m 690\u001b[0m hub_kwargs[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m_commit_hash\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m config\u001b[38;5;241m.\u001b[39m_commit_hash\n\u001b[1;32m 691\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m config \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(model, \u001b[38;5;28mstr\u001b[39m):\n\u001b[0;32m--> 692\u001b[0m config \u001b[38;5;241m=\u001b[39m \u001b[43mAutoConfig\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfrom_pretrained\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m_from_pipeline\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtask\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mhub_kwargs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mmodel_kwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 693\u001b[0m hub_kwargs[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m_commit_hash\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m config\u001b[38;5;241m.\u001b[39m_commit_hash\n\u001b[1;32m 695\u001b[0m custom_tasks \u001b[38;5;241m=\u001b[39m {}\n",
|
778 |
+
"File \u001b[0;32m~/huggingface/lib/python3.10/site-packages/transformers/models/auto/configuration_auto.py:917\u001b[0m, in \u001b[0;36mAutoConfig.from_pretrained\u001b[0;34m(cls, pretrained_model_name_or_path, **kwargs)\u001b[0m\n\u001b[1;32m 915\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m config_class\u001b[38;5;241m.\u001b[39mfrom_pretrained(pretrained_model_name_or_path, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[1;32m 916\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmodel_type\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;129;01min\u001b[39;00m config_dict:\n\u001b[0;32m--> 917\u001b[0m config_class \u001b[38;5;241m=\u001b[39m \u001b[43mCONFIG_MAPPING\u001b[49m\u001b[43m[\u001b[49m\u001b[43mconfig_dict\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mmodel_type\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m]\u001b[49m\n\u001b[1;32m 918\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m config_class\u001b[38;5;241m.\u001b[39mfrom_dict(config_dict, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39munused_kwargs)\n\u001b[1;32m 919\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 920\u001b[0m \u001b[38;5;66;03m# Fallback: use pattern matching on the string.\u001b[39;00m\n\u001b[1;32m 921\u001b[0m \u001b[38;5;66;03m# We go from longer names to shorter names to catch roberta before bert (for instance)\u001b[39;00m\n",
|
779 |
+
"File \u001b[0;32m~/huggingface/lib/python3.10/site-packages/transformers/models/auto/configuration_auto.py:623\u001b[0m, in \u001b[0;36m_LazyConfigMapping.__getitem__\u001b[0;34m(self, key)\u001b[0m\n\u001b[1;32m 621\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_extra_content[key]\n\u001b[1;32m 622\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m key \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_mapping:\n\u001b[0;32m--> 623\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mKeyError\u001b[39;00m(key)\n\u001b[1;32m 624\u001b[0m value \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_mapping[key]\n\u001b[1;32m 625\u001b[0m module_name \u001b[38;5;241m=\u001b[39m model_type_to_module_name(key)\n",
|
780 |
+
"\u001b[0;31mKeyError\u001b[0m: 'translation'"
|
781 |
+
]
|
782 |
+
}
|
783 |
+
],
|
784 |
+
"source": [
|
785 |
+
"translator = pipeline(\"translation\", model = \"HariSekhar/Eng_Marathi_translation\")"
|
786 |
+
]
|
787 |
+
},
|
788 |
+
{
|
789 |
+
"cell_type": "code",
|
790 |
+
"execution_count": null,
|
791 |
+
"metadata": {},
|
792 |
+
"outputs": [],
|
793 |
+
"source": [
|
794 |
+
"translator(\"\")"
|
795 |
+
]
|
796 |
+
}
|
797 |
+
],
|
798 |
+
"metadata": {
|
799 |
+
"kernelspec": {
|
800 |
+
"display_name": "Python 3",
|
801 |
+
"language": "python",
|
802 |
+
"name": "python3"
|
803 |
+
},
|
804 |
+
"language_info": {
|
805 |
+
"codemirror_mode": {
|
806 |
+
"name": "ipython",
|
807 |
+
"version": 3
|
808 |
+
},
|
809 |
+
"file_extension": ".py",
|
810 |
+
"mimetype": "text/x-python",
|
811 |
+
"name": "python",
|
812 |
+
"nbconvert_exporter": "python",
|
813 |
+
"pygments_lexer": "ipython3",
|
814 |
+
"version": "3.10.14"
|
815 |
+
}
|
816 |
+
},
|
817 |
+
"nbformat": 4,
|
818 |
+
"nbformat_minor": 2
|
819 |
+
}
|
chapter2/2-behind-the-pipeline.ipynb
ADDED
@@ -0,0 +1,340 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "markdown",
|
5 |
+
"metadata": {},
|
6 |
+
"source": [
|
7 |
+
"# Using the pipeline function"
|
8 |
+
]
|
9 |
+
},
|
10 |
+
{
|
11 |
+
"cell_type": "code",
|
12 |
+
"execution_count": 4,
|
13 |
+
"metadata": {},
|
14 |
+
"outputs": [
|
15 |
+
{
|
16 |
+
"name": "stderr",
|
17 |
+
"output_type": "stream",
|
18 |
+
"text": [
|
19 |
+
"No model was supplied, defaulted to distilbert-base-uncased-finetuned-sst-2-english and revision af0f99b (https://huggingface.co/distilbert-base-uncased-finetuned-sst-2-english).\n",
|
20 |
+
"Using a pipeline without specifying a model name and revision in production is not recommended.\n",
|
21 |
+
"/home/huggingface/lib/python3.10/site-packages/huggingface_hub/file_download.py:1150: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
|
22 |
+
" warnings.warn(\n"
|
23 |
+
]
|
24 |
+
}
|
25 |
+
],
|
26 |
+
"source": [
|
27 |
+
"from transformers import pipeline\n",
|
28 |
+
"\n",
|
29 |
+
"classifier = pipeline(task=\"sentiment-analysis\")\n",
|
30 |
+
"\n",
|
31 |
+
"inputs = [\"This was so bad I couldn´t finish it. The actresses are so bad at acting it feels like a bad comedy from minute one. The high rated reviews is obviously from friend/family and is pure BS.\",\n",
|
32 |
+
" \"I thought the cast was great. Brianna and Emma were exceptionaly talented in thier characters. Fun film.\"]\n",
|
33 |
+
"\n",
|
34 |
+
"outputs = classifier(inputs)"
|
35 |
+
]
|
36 |
+
},
|
37 |
+
{
|
38 |
+
"cell_type": "code",
|
39 |
+
"execution_count": 2,
|
40 |
+
"metadata": {},
|
41 |
+
"outputs": [
|
42 |
+
{
|
43 |
+
"data": {
|
44 |
+
"text/plain": [
|
45 |
+
"[{'label': 'NEGATIVE', 'score': 0.9995231628417969},\n",
|
46 |
+
" {'label': 'POSITIVE', 'score': 0.9998352527618408}]"
|
47 |
+
]
|
48 |
+
},
|
49 |
+
"execution_count": 2,
|
50 |
+
"metadata": {},
|
51 |
+
"output_type": "execute_result"
|
52 |
+
}
|
53 |
+
],
|
54 |
+
"source": [
|
55 |
+
"outputs"
|
56 |
+
]
|
57 |
+
},
|
58 |
+
{
|
59 |
+
"cell_type": "markdown",
|
60 |
+
"metadata": {},
|
61 |
+
"source": [
|
62 |
+
"# Defining tokenizer and model manually"
|
63 |
+
]
|
64 |
+
},
|
65 |
+
{
|
66 |
+
"cell_type": "markdown",
|
67 |
+
"metadata": {},
|
68 |
+
"source": [
|
69 |
+
"## Tokenizer"
|
70 |
+
]
|
71 |
+
},
|
72 |
+
{
|
73 |
+
"cell_type": "code",
|
74 |
+
"execution_count": 3,
|
75 |
+
"metadata": {},
|
76 |
+
"outputs": [
|
77 |
+
{
|
78 |
+
"name": "stderr",
|
79 |
+
"output_type": "stream",
|
80 |
+
"text": [
|
81 |
+
"/home/huggingface/lib/python3.10/site-packages/huggingface_hub/file_download.py:1150: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
|
82 |
+
" warnings.warn(\n"
|
83 |
+
]
|
84 |
+
}
|
85 |
+
],
|
86 |
+
"source": [
|
87 |
+
"from transformers import AutoTokenizer\n",
|
88 |
+
"\n",
|
89 |
+
"checkpoint = \"distilbert/distilbert-base-uncased-finetuned-sst-2-english\"\n",
|
90 |
+
"tokenizer = AutoTokenizer.from_pretrained(checkpoint)"
|
91 |
+
]
|
92 |
+
},
|
93 |
+
{
|
94 |
+
"cell_type": "code",
|
95 |
+
"execution_count": 23,
|
96 |
+
"metadata": {},
|
97 |
+
"outputs": [],
|
98 |
+
"source": [
|
99 |
+
"from pprint import pprint\n",
|
100 |
+
"tokenized_inputs = tokenizer(\n",
|
101 |
+
" inputs, padding=True, truncation=True, return_tensors=\"pt\")"
|
102 |
+
]
|
103 |
+
},
|
104 |
+
{
|
105 |
+
"cell_type": "code",
|
106 |
+
"execution_count": 24,
|
107 |
+
"metadata": {},
|
108 |
+
"outputs": [
|
109 |
+
{
|
110 |
+
"name": "stdout",
|
111 |
+
"output_type": "stream",
|
112 |
+
"text": [
|
113 |
+
"tensor([ 101, 2023, 2001, 2061, 2919, 1045, 2481, 29658, 2102, 3926,\n",
|
114 |
+
" 2009, 1012, 1996, 19910, 2024, 2061, 2919, 2012, 3772, 2009,\n",
|
115 |
+
" 5683, 2066, 1037, 2919, 4038, 2013, 3371, 2028, 1012, 1996,\n",
|
116 |
+
" 2152, 6758, 4391, 2003, 5525, 2013, 2767, 1013, 2155, 1998,\n",
|
117 |
+
" 2003, 5760, 18667, 1012, 102])\n",
|
118 |
+
"tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
|
119 |
+
" 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])\n"
|
120 |
+
]
|
121 |
+
}
|
122 |
+
],
|
123 |
+
"source": [
|
124 |
+
"print(tokenized_inputs[\"input_ids\"][0], tokenized_inputs[\"attention_mask\"][0], sep = \"\\n\")"
|
125 |
+
]
|
126 |
+
},
|
127 |
+
{
|
128 |
+
"cell_type": "code",
|
129 |
+
"execution_count": 25,
|
130 |
+
"metadata": {},
|
131 |
+
"outputs": [
|
132 |
+
{
|
133 |
+
"name": "stdout",
|
134 |
+
"output_type": "stream",
|
135 |
+
"text": [
|
136 |
+
"tensor([ 101, 1045, 2245, 1996, 3459, 2001, 2307, 1012, 25558, 1998,\n",
|
137 |
+
" 5616, 2020, 11813, 2100, 10904, 1999, 16215, 3771, 3494, 1012,\n",
|
138 |
+
" 4569, 2143, 1012, 102, 0, 0, 0, 0, 0, 0,\n",
|
139 |
+
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
140 |
+
" 0, 0, 0, 0, 0])\n",
|
141 |
+
"tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
|
142 |
+
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])\n"
|
143 |
+
]
|
144 |
+
}
|
145 |
+
],
|
146 |
+
"source": [
|
147 |
+
"print(tokenized_inputs[\"input_ids\"][1], tokenized_inputs[\"attention_mask\"][1], sep = \"\\n\")"
|
148 |
+
]
|
149 |
+
},
|
150 |
+
{
|
151 |
+
"cell_type": "code",
|
152 |
+
"execution_count": 26,
|
153 |
+
"metadata": {},
|
154 |
+
"outputs": [
|
155 |
+
{
|
156 |
+
"data": {
|
157 |
+
"text/plain": [
|
158 |
+
"(45, 45)"
|
159 |
+
]
|
160 |
+
},
|
161 |
+
"execution_count": 26,
|
162 |
+
"metadata": {},
|
163 |
+
"output_type": "execute_result"
|
164 |
+
}
|
165 |
+
],
|
166 |
+
"source": [
|
167 |
+
"len(tokenized_inputs[\"input_ids\"][0]), len(tokenized_inputs[\"input_ids\"][1])"
|
168 |
+
]
|
169 |
+
},
|
170 |
+
{
|
171 |
+
"cell_type": "markdown",
|
172 |
+
"metadata": {},
|
173 |
+
"source": [
|
174 |
+
"## Model"
|
175 |
+
]
|
176 |
+
},
|
177 |
+
{
|
178 |
+
"cell_type": "code",
|
179 |
+
"execution_count": 56,
|
180 |
+
"metadata": {},
|
181 |
+
"outputs": [
|
182 |
+
{
|
183 |
+
"name": "stderr",
|
184 |
+
"output_type": "stream",
|
185 |
+
"text": [
|
186 |
+
"/home/huggingface/lib/python3.10/site-packages/huggingface_hub/file_download.py:1150: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
|
187 |
+
" warnings.warn(\n"
|
188 |
+
]
|
189 |
+
}
|
190 |
+
],
|
191 |
+
"source": [
|
192 |
+
"from transformers import AutoModelForSequenceClassification\n",
|
193 |
+
"import torch\n",
|
194 |
+
"model = AutoModelForSequenceClassification.from_pretrained(checkpoint)\n",
|
195 |
+
"model.eval();"
|
196 |
+
]
|
197 |
+
},
|
198 |
+
{
|
199 |
+
"cell_type": "code",
|
200 |
+
"execution_count": 57,
|
201 |
+
"metadata": {},
|
202 |
+
"outputs": [],
|
203 |
+
"source": [
|
204 |
+
"with torch.no_grad():\n",
|
205 |
+
" outputs = model(**tokenized_inputs)"
|
206 |
+
]
|
207 |
+
},
|
208 |
+
{
|
209 |
+
"cell_type": "code",
|
210 |
+
"execution_count": 58,
|
211 |
+
"metadata": {},
|
212 |
+
"outputs": [
|
213 |
+
{
|
214 |
+
"name": "stdout",
|
215 |
+
"output_type": "stream",
|
216 |
+
"text": [
|
217 |
+
"['__annotations__', '__class__', '__class_getitem__', '__contains__', '__dataclass_fields__', '__dataclass_params__', '__delattr__', '__delitem__', '__dict__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattribute__', '__getitem__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__ior__', '__iter__', '__le__', '__len__', '__lt__', '__match_args__', '__module__', '__ne__', '__new__', '__or__', '__post_init__', '__reduce__', '__reduce_ex__', '__repr__', '__reversed__', '__ror__', '__setattr__', '__setitem__', '__sizeof__', '__str__', '__subclasshook__', 'attentions', 'clear', 'copy', 'fromkeys', 'get', 'hidden_states', 'items', 'keys', 'logits', 'loss', 'move_to_end', 'pop', 'popitem', 'setdefault', 'to_tuple', 'update', 'values']\n"
|
218 |
+
]
|
219 |
+
}
|
220 |
+
],
|
221 |
+
"source": [
|
222 |
+
"print(dir(outputs))"
|
223 |
+
]
|
224 |
+
},
|
225 |
+
{
|
226 |
+
"cell_type": "code",
|
227 |
+
"execution_count": 59,
|
228 |
+
"metadata": {},
|
229 |
+
"outputs": [
|
230 |
+
{
|
231 |
+
"data": {
|
232 |
+
"text/plain": [
|
233 |
+
"tensor([[ 4.2415, -3.4063],\n",
|
234 |
+
" [-4.1783, 4.5328]])"
|
235 |
+
]
|
236 |
+
},
|
237 |
+
"execution_count": 59,
|
238 |
+
"metadata": {},
|
239 |
+
"output_type": "execute_result"
|
240 |
+
}
|
241 |
+
],
|
242 |
+
"source": [
|
243 |
+
"outputs.logits"
|
244 |
+
]
|
245 |
+
},
|
246 |
+
{
|
247 |
+
"cell_type": "code",
|
248 |
+
"execution_count": 60,
|
249 |
+
"metadata": {},
|
250 |
+
"outputs": [
|
251 |
+
{
|
252 |
+
"data": {
|
253 |
+
"text/plain": [
|
254 |
+
"tensor([[9.9952e-01, 4.7686e-04],\n",
|
255 |
+
" [1.6471e-04, 9.9984e-01]])"
|
256 |
+
]
|
257 |
+
},
|
258 |
+
"execution_count": 60,
|
259 |
+
"metadata": {},
|
260 |
+
"output_type": "execute_result"
|
261 |
+
}
|
262 |
+
],
|
263 |
+
"source": [
|
264 |
+
"import torch.nn.functional as F\n",
|
265 |
+
"F.softmax(outputs.logits, dim = -1)"
|
266 |
+
]
|
267 |
+
},
|
268 |
+
{
|
269 |
+
"cell_type": "code",
|
270 |
+
"execution_count": 66,
|
271 |
+
"metadata": {},
|
272 |
+
"outputs": [],
|
273 |
+
"source": [
|
274 |
+
"predictions = outputs.logits.argmax(dim = -1)\n",
|
275 |
+
"pred_probas = F.softmax(outputs.logits, dim = -1).max(dim = -1).values\n",
|
276 |
+
"\n",
|
277 |
+
"preds = []\n",
|
278 |
+
"for p, pp in zip(predictions, pred_probas):\n",
|
279 |
+
" preds.append({'label': model.config.id2label[p.item()], 'score': pp.item()})"
|
280 |
+
]
|
281 |
+
},
|
282 |
+
{
|
283 |
+
"cell_type": "code",
|
284 |
+
"execution_count": 67,
|
285 |
+
"metadata": {},
|
286 |
+
"outputs": [
|
287 |
+
{
|
288 |
+
"data": {
|
289 |
+
"text/plain": [
|
290 |
+
"[{'label': 'NEGATIVE', 'score': 0.9995231628417969},\n",
|
291 |
+
" {'label': 'POSITIVE', 'score': 0.9998352527618408}]"
|
292 |
+
]
|
293 |
+
},
|
294 |
+
"execution_count": 67,
|
295 |
+
"metadata": {},
|
296 |
+
"output_type": "execute_result"
|
297 |
+
}
|
298 |
+
],
|
299 |
+
"source": [
|
300 |
+
"preds"
|
301 |
+
]
|
302 |
+
},
|
303 |
+
{
|
304 |
+
"cell_type": "markdown",
|
305 |
+
"metadata": {},
|
306 |
+
"source": [
|
307 |
+
"```\n",
|
308 |
+
"\n",
|
309 |
+
"Reference Output\n",
|
310 |
+
"\n",
|
311 |
+
"---\n",
|
312 |
+
"\n",
|
313 |
+
"[{'label': 'NEGATIVE', 'score': 0.9995231628417969},\n",
|
314 |
+
" {'label': 'POSITIVE', 'score': 0.9998352527618408}]\n",
|
315 |
+
"````"
|
316 |
+
]
|
317 |
+
}
|
318 |
+
],
|
319 |
+
"metadata": {
|
320 |
+
"kernelspec": {
|
321 |
+
"display_name": "Python 3",
|
322 |
+
"language": "python",
|
323 |
+
"name": "python3"
|
324 |
+
},
|
325 |
+
"language_info": {
|
326 |
+
"codemirror_mode": {
|
327 |
+
"name": "ipython",
|
328 |
+
"version": 3
|
329 |
+
},
|
330 |
+
"file_extension": ".py",
|
331 |
+
"mimetype": "text/x-python",
|
332 |
+
"name": "python",
|
333 |
+
"nbconvert_exporter": "python",
|
334 |
+
"pygments_lexer": "ipython3",
|
335 |
+
"version": "3.10.14"
|
336 |
+
}
|
337 |
+
},
|
338 |
+
"nbformat": 4,
|
339 |
+
"nbformat_minor": 2
|
340 |
+
}
|
chapter2/2-handling-multiple-sequences.ipynb
ADDED
@@ -0,0 +1,384 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": 1,
|
6 |
+
"metadata": {},
|
7 |
+
"outputs": [
|
8 |
+
{
|
9 |
+
"name": "stderr",
|
10 |
+
"output_type": "stream",
|
11 |
+
"text": [
|
12 |
+
"/home/huggingface/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
|
13 |
+
" from .autonotebook import tqdm as notebook_tqdm\n",
|
14 |
+
"/home/huggingface/lib/python3.10/site-packages/huggingface_hub/file_download.py:1150: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
|
15 |
+
" warnings.warn(\n"
|
16 |
+
]
|
17 |
+
}
|
18 |
+
],
|
19 |
+
"source": [
|
20 |
+
"import torch\n",
|
21 |
+
"from transformers import AutoTokenizer, AutoModelForSequenceClassification\n",
|
22 |
+
"\n",
|
23 |
+
"checkpoint = \"distilbert-base-uncased-finetuned-sst-2-english\"\n",
|
24 |
+
"tokenizer = AutoTokenizer.from_pretrained(checkpoint)\n",
|
25 |
+
"model = AutoModelForSequenceClassification.from_pretrained(checkpoint)\n",
|
26 |
+
"model.eval();"
|
27 |
+
]
|
28 |
+
},
|
29 |
+
{
|
30 |
+
"cell_type": "markdown",
|
31 |
+
"metadata": {},
|
32 |
+
"source": [
|
33 |
+
"# Try forward pass on single Example"
|
34 |
+
]
|
35 |
+
},
|
36 |
+
{
|
37 |
+
"cell_type": "code",
|
38 |
+
"execution_count": 2,
|
39 |
+
"metadata": {},
|
40 |
+
"outputs": [
|
41 |
+
{
|
42 |
+
"data": {
|
43 |
+
"text/plain": [
|
44 |
+
"tensor([2057, 2342, 2062, 3737, 7435, 1010, 6145, 1998, 9559, 1999, 2256, 3842,\n",
|
45 |
+
" 1012])"
|
46 |
+
]
|
47 |
+
},
|
48 |
+
"execution_count": 2,
|
49 |
+
"metadata": {},
|
50 |
+
"output_type": "execute_result"
|
51 |
+
}
|
52 |
+
],
|
53 |
+
"source": [
|
54 |
+
"sequence = \"We need more quality doctors, engineers and lawyers in our nation.\"\n",
|
55 |
+
"token_ids = torch.tensor(tokenizer.convert_tokens_to_ids(tokenizer.tokenize(sequence)))\n",
|
56 |
+
"token_ids"
|
57 |
+
]
|
58 |
+
},
|
59 |
+
{
|
60 |
+
"cell_type": "code",
|
61 |
+
"execution_count": 3,
|
62 |
+
"metadata": {},
|
63 |
+
"outputs": [
|
64 |
+
{
|
65 |
+
"ename": "RuntimeError",
|
66 |
+
"evalue": "The size of tensor a (13) must match the size of tensor b (512) at non-singleton dimension 1",
|
67 |
+
"output_type": "error",
|
68 |
+
"traceback": [
|
69 |
+
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
70 |
+
"\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
|
71 |
+
"Cell \u001b[0;32mIn[3], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mno_grad():\n\u001b[0;32m----> 2\u001b[0m \u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtoken_ids\u001b[49m\u001b[43m)\u001b[49m\n",
|
72 |
+
"File \u001b[0;32m~/huggingface/lib/python3.10/site-packages/torch/nn/modules/module.py:1532\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1530\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1531\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1532\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
|
73 |
+
"File \u001b[0;32m~/huggingface/lib/python3.10/site-packages/torch/nn/modules/module.py:1541\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1536\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1537\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1538\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1539\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1540\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1541\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1543\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1544\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
|
74 |
+
"File \u001b[0;32m~/huggingface/lib/python3.10/site-packages/transformers/models/distilbert/modeling_distilbert.py:763\u001b[0m, in \u001b[0;36mDistilBertForSequenceClassification.forward\u001b[0;34m(self, input_ids, attention_mask, head_mask, inputs_embeds, labels, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[1;32m 755\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124mr\u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 756\u001b[0m \u001b[38;5;124;03mlabels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):\u001b[39;00m\n\u001b[1;32m 757\u001b[0m \u001b[38;5;124;03m Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,\u001b[39;00m\n\u001b[1;32m 758\u001b[0m \u001b[38;5;124;03m config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If\u001b[39;00m\n\u001b[1;32m 759\u001b[0m \u001b[38;5;124;03m `config.num_labels > 1` a classification loss is computed (Cross-Entropy).\u001b[39;00m\n\u001b[1;32m 760\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 761\u001b[0m return_dict \u001b[38;5;241m=\u001b[39m return_dict \u001b[38;5;28;01mif\u001b[39;00m return_dict \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconfig\u001b[38;5;241m.\u001b[39muse_return_dict\n\u001b[0;32m--> 763\u001b[0m distilbert_output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdistilbert\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 764\u001b[0m \u001b[43m \u001b[49m\u001b[43minput_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minput_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 765\u001b[0m \u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 766\u001b[0m \u001b[43m \u001b[49m\u001b[43mhead_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mhead_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 767\u001b[0m \u001b[43m \u001b[49m\u001b[43minputs_embeds\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minputs_embeds\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 768\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 769\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_hidden_states\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_hidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 770\u001b[0m \u001b[43m \u001b[49m\u001b[43mreturn_dict\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mreturn_dict\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 771\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 772\u001b[0m hidden_state \u001b[38;5;241m=\u001b[39m distilbert_output[\u001b[38;5;241m0\u001b[39m] \u001b[38;5;66;03m# (bs, seq_len, dim)\u001b[39;00m\n\u001b[1;32m 773\u001b[0m pooled_output \u001b[38;5;241m=\u001b[39m hidden_state[:, \u001b[38;5;241m0\u001b[39m] \u001b[38;5;66;03m# (bs, dim)\u001b[39;00m\n",
|
75 |
+
"File \u001b[0;32m~/huggingface/lib/python3.10/site-packages/torch/nn/modules/module.py:1532\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1530\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1531\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1532\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
|
76 |
+
"File \u001b[0;32m~/huggingface/lib/python3.10/site-packages/torch/nn/modules/module.py:1541\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1536\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1537\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1538\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1539\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1540\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1541\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1543\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1544\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
|
77 |
+
"File \u001b[0;32m~/huggingface/lib/python3.10/site-packages/transformers/models/distilbert/modeling_distilbert.py:581\u001b[0m, in \u001b[0;36mDistilBertModel.forward\u001b[0;34m(self, input_ids, attention_mask, head_mask, inputs_embeds, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[1;32m 578\u001b[0m \u001b[38;5;66;03m# Prepare head mask if needed\u001b[39;00m\n\u001b[1;32m 579\u001b[0m head_mask \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mget_head_mask(head_mask, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconfig\u001b[38;5;241m.\u001b[39mnum_hidden_layers)\n\u001b[0;32m--> 581\u001b[0m embeddings \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43membeddings\u001b[49m\u001b[43m(\u001b[49m\u001b[43minput_ids\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs_embeds\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;66;03m# (bs, seq_length, dim)\u001b[39;00m\n\u001b[1;32m 583\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtransformer(\n\u001b[1;32m 584\u001b[0m x\u001b[38;5;241m=\u001b[39membeddings,\n\u001b[1;32m 585\u001b[0m attn_mask\u001b[38;5;241m=\u001b[39mattention_mask,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 589\u001b[0m return_dict\u001b[38;5;241m=\u001b[39mreturn_dict,\n\u001b[1;32m 590\u001b[0m )\n",
|
78 |
+
"File \u001b[0;32m~/huggingface/lib/python3.10/site-packages/torch/nn/modules/module.py:1532\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1530\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1531\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1532\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
|
79 |
+
"File \u001b[0;32m~/huggingface/lib/python3.10/site-packages/torch/nn/modules/module.py:1541\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1536\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1537\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1538\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1539\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1540\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1541\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1543\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1544\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
|
80 |
+
"File \u001b[0;32m~/huggingface/lib/python3.10/site-packages/transformers/models/distilbert/modeling_distilbert.py:135\u001b[0m, in \u001b[0;36mEmbeddings.forward\u001b[0;34m(self, input_ids, input_embeds)\u001b[0m\n\u001b[1;32m 131\u001b[0m position_ids \u001b[38;5;241m=\u001b[39m position_ids\u001b[38;5;241m.\u001b[39munsqueeze(\u001b[38;5;241m0\u001b[39m)\u001b[38;5;241m.\u001b[39mexpand_as(input_ids) \u001b[38;5;66;03m# (bs, max_seq_length)\u001b[39;00m\n\u001b[1;32m 133\u001b[0m position_embeddings \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mposition_embeddings(position_ids) \u001b[38;5;66;03m# (bs, max_seq_length, dim)\u001b[39;00m\n\u001b[0;32m--> 135\u001b[0m embeddings \u001b[38;5;241m=\u001b[39m \u001b[43minput_embeds\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m+\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mposition_embeddings\u001b[49m \u001b[38;5;66;03m# (bs, max_seq_length, dim)\u001b[39;00m\n\u001b[1;32m 136\u001b[0m embeddings \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mLayerNorm(embeddings) \u001b[38;5;66;03m# (bs, max_seq_length, dim)\u001b[39;00m\n\u001b[1;32m 137\u001b[0m embeddings \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdropout(embeddings) \u001b[38;5;66;03m# (bs, max_seq_length, dim)\u001b[39;00m\n",
|
81 |
+
"\u001b[0;31mRuntimeError\u001b[0m: The size of tensor a (13) must match the size of tensor b (512) at non-singleton dimension 1"
|
82 |
+
]
|
83 |
+
}
|
84 |
+
],
|
85 |
+
"source": [
|
86 |
+
"with torch.no_grad():\n",
|
87 |
+
" model(token_ids)"
|
88 |
+
]
|
89 |
+
},
|
90 |
+
{
|
91 |
+
"cell_type": "code",
|
92 |
+
"execution_count": null,
|
93 |
+
"metadata": {},
|
94 |
+
"outputs": [
|
95 |
+
{
|
96 |
+
"data": {
|
97 |
+
"text/plain": [
|
98 |
+
"tensor([2057, 2342, 2062, 3737, 7435, 1010, 6145, 1998, 9559, 1999, 2256, 3842,\n",
|
99 |
+
" 1012])"
|
100 |
+
]
|
101 |
+
},
|
102 |
+
"execution_count": 4,
|
103 |
+
"metadata": {},
|
104 |
+
"output_type": "execute_result"
|
105 |
+
}
|
106 |
+
],
|
107 |
+
"source": [
|
108 |
+
"token_ids"
|
109 |
+
]
|
110 |
+
},
|
111 |
+
{
|
112 |
+
"cell_type": "markdown",
|
113 |
+
"metadata": {},
|
114 |
+
"source": [
|
115 |
+
"As seen above our model does not have a batch dimension because of which we are seeing this issue. Let's add a batch dimension and then pass our sequence through the model"
|
116 |
+
]
|
117 |
+
},
|
118 |
+
{
|
119 |
+
"cell_type": "code",
|
120 |
+
"execution_count": null,
|
121 |
+
"metadata": {},
|
122 |
+
"outputs": [
|
123 |
+
{
|
124 |
+
"data": {
|
125 |
+
"text/plain": [
|
126 |
+
"SequenceClassifierOutput(loss=None, logits=tensor([[ 1.2781, -1.0656]]), hidden_states=None, attentions=None)"
|
127 |
+
]
|
128 |
+
},
|
129 |
+
"execution_count": 5,
|
130 |
+
"metadata": {},
|
131 |
+
"output_type": "execute_result"
|
132 |
+
}
|
133 |
+
],
|
134 |
+
"source": [
|
135 |
+
"with torch.no_grad():\n",
|
136 |
+
" out = model(token_ids.unsqueeze(0))\n",
|
137 |
+
"out"
|
138 |
+
]
|
139 |
+
},
|
140 |
+
{
|
141 |
+
"cell_type": "markdown",
|
142 |
+
"metadata": {},
|
143 |
+
"source": [
|
144 |
+
"Let's try by duplicating the input if we get the same logits"
|
145 |
+
]
|
146 |
+
},
|
147 |
+
{
|
148 |
+
"cell_type": "code",
|
149 |
+
"execution_count": null,
|
150 |
+
"metadata": {},
|
151 |
+
"outputs": [
|
152 |
+
{
|
153 |
+
"data": {
|
154 |
+
"text/plain": [
|
155 |
+
"tensor([[ 1.2781, -1.0656],\n",
|
156 |
+
" [ 1.2781, -1.0656]])"
|
157 |
+
]
|
158 |
+
},
|
159 |
+
"execution_count": 6,
|
160 |
+
"metadata": {},
|
161 |
+
"output_type": "execute_result"
|
162 |
+
}
|
163 |
+
],
|
164 |
+
"source": [
|
165 |
+
"with torch.no_grad():\n",
|
166 |
+
" inp = torch.cat([token_ids.unsqueeze(0), token_ids.unsqueeze(0)], dim = 0)\n",
|
167 |
+
" out = model(inp)\n",
|
168 |
+
"out.logits"
|
169 |
+
]
|
170 |
+
},
|
171 |
+
{
|
172 |
+
"cell_type": "markdown",
|
173 |
+
"metadata": {},
|
174 |
+
"source": [
|
175 |
+
"# Input padding"
|
176 |
+
]
|
177 |
+
},
|
178 |
+
{
|
179 |
+
"cell_type": "code",
|
180 |
+
"execution_count": null,
|
181 |
+
"metadata": {},
|
182 |
+
"outputs": [
|
183 |
+
{
|
184 |
+
"name": "stdout",
|
185 |
+
"output_type": "stream",
|
186 |
+
"text": [
|
187 |
+
"tensor([[ 1.5694, -1.3895]], grad_fn=<AddmmBackward0>)\n",
|
188 |
+
"tensor([[ 0.5803, -0.4125]], grad_fn=<AddmmBackward0>)\n",
|
189 |
+
"tensor([[ 1.5694, -1.3895],\n",
|
190 |
+
" [ 0.9907, -0.9139]], grad_fn=<AddmmBackward0>)\n"
|
191 |
+
]
|
192 |
+
}
|
193 |
+
],
|
194 |
+
"source": [
|
195 |
+
"padding_id = 100\n",
|
196 |
+
"\n",
|
197 |
+
"batched_ids = [\n",
|
198 |
+
" [200, 200, 200],\n",
|
199 |
+
" [200, 200, padding_id],\n",
|
200 |
+
"]\n",
|
201 |
+
"\n",
|
202 |
+
"print(model(torch.tensor([batched_ids[0]])).logits)\n",
|
203 |
+
"print(model(torch.tensor([batched_ids[1][:2]])).logits)\n",
|
204 |
+
"print(model(torch.tensor(batched_ids)).logits)"
|
205 |
+
]
|
206 |
+
},
|
207 |
+
{
|
208 |
+
"cell_type": "markdown",
|
209 |
+
"metadata": {},
|
210 |
+
"source": [
|
211 |
+
"There’s something wrong with the logits in our batched predictions: the second row should be the same as the logits for the second sentence, but we’ve got completely different values!\n",
|
212 |
+
"\n",
|
213 |
+
"This is because when we add padding, we need to make sure we nullify it's impact during the attention matrix computation step. This is why we need a mask so that we can explicily shut these tokens from the attention calculation."
|
214 |
+
]
|
215 |
+
},
|
216 |
+
{
|
217 |
+
"cell_type": "markdown",
|
218 |
+
"metadata": {},
|
219 |
+
"source": [
|
220 |
+
"# Cross checking the working of attention masks"
|
221 |
+
]
|
222 |
+
},
|
223 |
+
{
|
224 |
+
"cell_type": "code",
|
225 |
+
"execution_count": null,
|
226 |
+
"metadata": {},
|
227 |
+
"outputs": [
|
228 |
+
{
|
229 |
+
"data": {
|
230 |
+
"text/plain": [
|
231 |
+
"{'input_ids': tensor([[ 101, 1045, 1521, 2310, 2042, 3403, 2005, 1037, 17662, 12172,\n",
|
232 |
+
" 2607, 2026, 2878, 2166, 1012, 102],\n",
|
233 |
+
" [ 101, 1045, 5223, 2023, 2061, 2172, 999, 102, 0, 0,\n",
|
234 |
+
" 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],\n",
|
235 |
+
" [1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0]])}"
|
236 |
+
]
|
237 |
+
},
|
238 |
+
"execution_count": 11,
|
239 |
+
"metadata": {},
|
240 |
+
"output_type": "execute_result"
|
241 |
+
}
|
242 |
+
],
|
243 |
+
"source": [
|
244 |
+
"tokens"
|
245 |
+
]
|
246 |
+
},
|
247 |
+
{
|
248 |
+
"cell_type": "code",
|
249 |
+
"execution_count": null,
|
250 |
+
"metadata": {},
|
251 |
+
"outputs": [],
|
252 |
+
"source": [
|
253 |
+
"sentences = [\"I’ve been waiting for a HuggingFace course my whole life.\",\n",
|
254 |
+
" \"I hate this so much!\"]\n",
|
255 |
+
"tokens = tokenizer(sentences, padding=True, return_tensors=\"pt\")\n",
|
256 |
+
"with torch.no_grad():\n",
|
257 |
+
" out = model(**tokens)"
|
258 |
+
]
|
259 |
+
},
|
260 |
+
{
|
261 |
+
"cell_type": "code",
|
262 |
+
"execution_count": null,
|
263 |
+
"metadata": {},
|
264 |
+
"outputs": [
|
265 |
+
{
|
266 |
+
"name": "stdout",
|
267 |
+
"output_type": "stream",
|
268 |
+
"text": [
|
269 |
+
"{'input_ids': tensor([[ 101, 1045, 1521, 2310, 2042, 3403, 2005, 1037, 17662, 12172,\n",
|
270 |
+
" 2607, 2026, 2878, 2166, 1012, 102],\n",
|
271 |
+
" [ 101, 1045, 5223, 2023, 2061, 2172, 999, 102, 0, 0,\n",
|
272 |
+
" 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],\n",
|
273 |
+
" [1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0]])}\n"
|
274 |
+
]
|
275 |
+
}
|
276 |
+
],
|
277 |
+
"source": [
|
278 |
+
"print(tokens)"
|
279 |
+
]
|
280 |
+
},
|
281 |
+
{
|
282 |
+
"cell_type": "code",
|
283 |
+
"execution_count": null,
|
284 |
+
"metadata": {},
|
285 |
+
"outputs": [
|
286 |
+
{
|
287 |
+
"data": {
|
288 |
+
"text/plain": [
|
289 |
+
"tensor([[-1.5979, 1.6390],\n",
|
290 |
+
" [ 4.1692, -3.3464]])"
|
291 |
+
]
|
292 |
+
},
|
293 |
+
"execution_count": 32,
|
294 |
+
"metadata": {},
|
295 |
+
"output_type": "execute_result"
|
296 |
+
}
|
297 |
+
],
|
298 |
+
"source": [
|
299 |
+
"out.logits"
|
300 |
+
]
|
301 |
+
},
|
302 |
+
{
|
303 |
+
"cell_type": "code",
|
304 |
+
"execution_count": null,
|
305 |
+
"metadata": {},
|
306 |
+
"outputs": [
|
307 |
+
{
|
308 |
+
"name": "stdout",
|
309 |
+
"output_type": "stream",
|
310 |
+
"text": [
|
311 |
+
"tensor([[-1.5979, 1.6390]])\n"
|
312 |
+
]
|
313 |
+
}
|
314 |
+
],
|
315 |
+
"source": [
|
316 |
+
"# Do the entire forward pass manually for sentence 1\n",
|
317 |
+
"\n",
|
318 |
+
"# Tokenize the sentence and get the tokenids\n",
|
319 |
+
"token_ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(sentences[0]))\n",
|
320 |
+
"\n",
|
321 |
+
"# Add the special token CLS and SEP at the start and end of the token rspectively\n",
|
322 |
+
"token_ids = [101] + token_ids + [102]\n",
|
323 |
+
"\n",
|
324 |
+
"# Perform the forward pass and print the logits\n",
|
325 |
+
"with torch.no_grad():\n",
|
326 |
+
" print(model(torch.tensor([token_ids])).logits)"
|
327 |
+
]
|
328 |
+
},
|
329 |
+
{
|
330 |
+
"cell_type": "code",
|
331 |
+
"execution_count": null,
|
332 |
+
"metadata": {},
|
333 |
+
"outputs": [
|
334 |
+
{
|
335 |
+
"name": "stdout",
|
336 |
+
"output_type": "stream",
|
337 |
+
"text": [
|
338 |
+
"tensor([[ 4.1692, -3.3464]])\n"
|
339 |
+
]
|
340 |
+
}
|
341 |
+
],
|
342 |
+
"source": [
|
343 |
+
"# Do the entire forward pass manually for sentence 2\n",
|
344 |
+
"\n",
|
345 |
+
"# Tokenize the sentence and get the tokenids\n",
|
346 |
+
"s0_ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(sentences[0]))\n",
|
347 |
+
"token_ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(sentences[1]))\n",
|
348 |
+
"s1_tokens = len(token_ids)\n",
|
349 |
+
"additional_ids = len(s0_ids) - len(token_ids)\n",
|
350 |
+
"\n",
|
351 |
+
"# Add the special token CLS and SEP at the start and end of the token repectively\n",
|
352 |
+
"# Also create an attention mask here to stop the attention from considering additional padding tokens\n",
|
353 |
+
"token_ids = [101] + token_ids + [102] + [0 for _ in range(additional_ids)]\n",
|
354 |
+
"attention_mask = [1 for _ in range(s1_tokens + 2)] + [0 for _ in range(additional_ids)]\n",
|
355 |
+
"\n",
|
356 |
+
"# Perform the forward pass and print the logits\n",
|
357 |
+
"with torch.no_grad():\n",
|
358 |
+
" print(model(input_ids = torch.tensor([token_ids]),\n",
|
359 |
+
" attention_mask = torch.tensor([attention_mask])).logits)"
|
360 |
+
]
|
361 |
+
}
|
362 |
+
],
|
363 |
+
"metadata": {
|
364 |
+
"kernelspec": {
|
365 |
+
"display_name": "Python 3",
|
366 |
+
"language": "python",
|
367 |
+
"name": "python3"
|
368 |
+
},
|
369 |
+
"language_info": {
|
370 |
+
"codemirror_mode": {
|
371 |
+
"name": "ipython",
|
372 |
+
"version": 3
|
373 |
+
},
|
374 |
+
"file_extension": ".py",
|
375 |
+
"mimetype": "text/x-python",
|
376 |
+
"name": "python",
|
377 |
+
"nbconvert_exporter": "python",
|
378 |
+
"pygments_lexer": "ipython3",
|
379 |
+
"version": "3.10.14"
|
380 |
+
}
|
381 |
+
},
|
382 |
+
"nbformat": 4,
|
383 |
+
"nbformat_minor": 2
|
384 |
+
}
|
chapter2/2-tokenizers.ipynb
ADDED
@@ -0,0 +1,231 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": 3,
|
6 |
+
"metadata": {},
|
7 |
+
"outputs": [
|
8 |
+
{
|
9 |
+
"name": "stderr",
|
10 |
+
"output_type": "stream",
|
11 |
+
"text": [
|
12 |
+
"/home/huggingface/lib/python3.10/site-packages/huggingface_hub/file_download.py:1150: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
|
13 |
+
" warnings.warn(\n"
|
14 |
+
]
|
15 |
+
}
|
16 |
+
],
|
17 |
+
"source": [
|
18 |
+
"from transformers import BertTokenizer\n",
|
19 |
+
"from pprint import pprint\n",
|
20 |
+
"tokenizer = BertTokenizer.from_pretrained(\"bert-base-cased\")"
|
21 |
+
]
|
22 |
+
},
|
23 |
+
{
|
24 |
+
"cell_type": "code",
|
25 |
+
"execution_count": 10,
|
26 |
+
"metadata": {},
|
27 |
+
"outputs": [
|
28 |
+
{
|
29 |
+
"name": "stdout",
|
30 |
+
"output_type": "stream",
|
31 |
+
"text": [
|
32 |
+
"{'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],\n",
|
33 |
+
" 'input_ids': [101,\n",
|
34 |
+
" 1109,\n",
|
35 |
+
" 20164,\n",
|
36 |
+
" 10932,\n",
|
37 |
+
" 2271,\n",
|
38 |
+
" 7954,\n",
|
39 |
+
" 10176,\n",
|
40 |
+
" 1110,\n",
|
41 |
+
" 2385,\n",
|
42 |
+
" 1107,\n",
|
43 |
+
" 7926,\n",
|
44 |
+
" 8588,\n",
|
45 |
+
" 102],\n",
|
46 |
+
" 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]}\n"
|
47 |
+
]
|
48 |
+
}
|
49 |
+
],
|
50 |
+
"source": [
|
51 |
+
"pprint(tokenizer(\"The HuggingFace Course is quite intuitive\"))"
|
52 |
+
]
|
53 |
+
},
|
54 |
+
{
|
55 |
+
"cell_type": "code",
|
56 |
+
"execution_count": 11,
|
57 |
+
"metadata": {},
|
58 |
+
"outputs": [
|
59 |
+
{
|
60 |
+
"data": {
|
61 |
+
"text/plain": [
|
62 |
+
"('./artifacts/tokenizer_config.json',\n",
|
63 |
+
" './artifacts/special_tokens_map.json',\n",
|
64 |
+
" './artifacts/vocab.txt',\n",
|
65 |
+
" './artifacts/added_tokens.json')"
|
66 |
+
]
|
67 |
+
},
|
68 |
+
"execution_count": 11,
|
69 |
+
"metadata": {},
|
70 |
+
"output_type": "execute_result"
|
71 |
+
}
|
72 |
+
],
|
73 |
+
"source": [
|
74 |
+
"tokenizer.save_pretrained(save_directory=\"./artifacts/\")"
|
75 |
+
]
|
76 |
+
},
|
77 |
+
{
|
78 |
+
"cell_type": "markdown",
|
79 |
+
"metadata": {},
|
80 |
+
"source": [
|
81 |
+
"# Breaking it down"
|
82 |
+
]
|
83 |
+
},
|
84 |
+
{
|
85 |
+
"cell_type": "code",
|
86 |
+
"execution_count": 12,
|
87 |
+
"metadata": {},
|
88 |
+
"outputs": [],
|
89 |
+
"source": [
|
90 |
+
"sequence = \"The HuggingFace Course is quite intuitive\""
|
91 |
+
]
|
92 |
+
},
|
93 |
+
{
|
94 |
+
"cell_type": "code",
|
95 |
+
"execution_count": 14,
|
96 |
+
"metadata": {},
|
97 |
+
"outputs": [
|
98 |
+
{
|
99 |
+
"name": "stdout",
|
100 |
+
"output_type": "stream",
|
101 |
+
"text": [
|
102 |
+
"['The', 'Hu', '##gging', '##F', '##ace', 'Course', 'is', 'quite', 'in', '##tu', '##itive']\n"
|
103 |
+
]
|
104 |
+
}
|
105 |
+
],
|
106 |
+
"source": [
|
107 |
+
"tokens = tokenizer.tokenize(sequence)\n",
|
108 |
+
"print(tokens)"
|
109 |
+
]
|
110 |
+
},
|
111 |
+
{
|
112 |
+
"cell_type": "code",
|
113 |
+
"execution_count": 15,
|
114 |
+
"metadata": {},
|
115 |
+
"outputs": [
|
116 |
+
{
|
117 |
+
"data": {
|
118 |
+
"text/plain": [
|
119 |
+
"[1109, 20164, 10932, 2271, 7954, 10176, 1110, 2385, 1107, 7926, 8588]"
|
120 |
+
]
|
121 |
+
},
|
122 |
+
"execution_count": 15,
|
123 |
+
"metadata": {},
|
124 |
+
"output_type": "execute_result"
|
125 |
+
}
|
126 |
+
],
|
127 |
+
"source": [
|
128 |
+
"token_ids = tokenizer.convert_tokens_to_ids(tokens)\n",
|
129 |
+
"token_ids"
|
130 |
+
]
|
131 |
+
},
|
132 |
+
{
|
133 |
+
"cell_type": "markdown",
|
134 |
+
"metadata": {},
|
135 |
+
"source": [
|
136 |
+
"Try tokenization using tokenize method and the __call__ method of the tokenizer object and confirm the outputs"
|
137 |
+
]
|
138 |
+
},
|
139 |
+
{
|
140 |
+
"cell_type": "code",
|
141 |
+
"execution_count": 24,
|
142 |
+
"metadata": {},
|
143 |
+
"outputs": [
|
144 |
+
{
|
145 |
+
"name": "stdout",
|
146 |
+
"output_type": "stream",
|
147 |
+
"text": [
|
148 |
+
"[101, 146, 787, 1396, 1151, 2613, 1111, 170, 20164, 10932, 2271, 7954, 1736, 1139, 2006, 1297, 119, 102]\n",
|
149 |
+
"[CLS] I ’ ve been waiting for a HuggingFace course my whole life. [SEP]\n",
|
150 |
+
"\n",
|
151 |
+
"[146, 787, 1396, 1151, 2613, 1111, 170, 20164, 10932, 2271, 7954, 1736, 1139, 2006, 1297, 119]\n",
|
152 |
+
"I ’ ve been waiting for a HuggingFace course my whole life.\n",
|
153 |
+
"====================================================================================================\n",
|
154 |
+
"[101, 146, 4819, 1142, 1177, 1277, 106, 102]\n",
|
155 |
+
"[CLS] I hate this so much! [SEP]\n",
|
156 |
+
"\n",
|
157 |
+
"[146, 4819, 1142, 1177, 1277, 106]\n",
|
158 |
+
"I hate this so much!\n",
|
159 |
+
"====================================================================================================\n"
|
160 |
+
]
|
161 |
+
}
|
162 |
+
],
|
163 |
+
"source": [
|
164 |
+
"sentences = [\"I’ve been waiting for a HuggingFace course my whole life.\", \"I hate this so much!\"]\n",
|
165 |
+
"\n",
|
166 |
+
"for sentence in sentences:\n",
|
167 |
+
" # 1: Perform tokenization using the default call method\n",
|
168 |
+
" token_ids = tokenizer(sentence)[\"input_ids\"]\n",
|
169 |
+
" print(token_ids)\n",
|
170 |
+
" print(tokenizer.decode(token_ids))\n",
|
171 |
+
" print()\n",
|
172 |
+
"\n",
|
173 |
+
" # 2: First tokenize and then convert to ids\n",
|
174 |
+
" tokens = tokenizer.tokenize(sentence)\n",
|
175 |
+
" token_ids = tokenizer.convert_tokens_to_ids(tokens)\n",
|
176 |
+
" print(token_ids)\n",
|
177 |
+
" print(tokenizer.decode(token_ids))\n",
|
178 |
+
"\n",
|
179 |
+
" print(\"=\"*100)"
|
180 |
+
]
|
181 |
+
},
|
182 |
+
{
|
183 |
+
"cell_type": "code",
|
184 |
+
"execution_count": 22,
|
185 |
+
"metadata": {},
|
186 |
+
"outputs": [
|
187 |
+
{
|
188 |
+
"data": {
|
189 |
+
"text/plain": [
|
190 |
+
"'[CLS] [SEP]'"
|
191 |
+
]
|
192 |
+
},
|
193 |
+
"execution_count": 22,
|
194 |
+
"metadata": {},
|
195 |
+
"output_type": "execute_result"
|
196 |
+
}
|
197 |
+
],
|
198 |
+
"source": [
|
199 |
+
"tokenizer.decode([101, 102])"
|
200 |
+
]
|
201 |
+
},
|
202 |
+
{
|
203 |
+
"cell_type": "markdown",
|
204 |
+
"metadata": {},
|
205 |
+
"source": [
|
206 |
+
"The difference in the first and last token values is because of the introduction of special tokens which is proposed in the BERT paper otherwise all the tokens are exactly the same."
|
207 |
+
]
|
208 |
+
}
|
209 |
+
],
|
210 |
+
"metadata": {
|
211 |
+
"kernelspec": {
|
212 |
+
"display_name": "Python 3",
|
213 |
+
"language": "python",
|
214 |
+
"name": "python3"
|
215 |
+
},
|
216 |
+
"language_info": {
|
217 |
+
"codemirror_mode": {
|
218 |
+
"name": "ipython",
|
219 |
+
"version": 3
|
220 |
+
},
|
221 |
+
"file_extension": ".py",
|
222 |
+
"mimetype": "text/x-python",
|
223 |
+
"name": "python",
|
224 |
+
"nbconvert_exporter": "python",
|
225 |
+
"pygments_lexer": "ipython3",
|
226 |
+
"version": "3.10.14"
|
227 |
+
}
|
228 |
+
},
|
229 |
+
"nbformat": 4,
|
230 |
+
"nbformat_minor": 2
|
231 |
+
}
|
chapter3/3-a-full-training.ipynb
ADDED
@@ -0,0 +1,393 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "markdown",
|
5 |
+
"metadata": {},
|
6 |
+
"source": [
|
7 |
+
"# Train with Pytorch"
|
8 |
+
]
|
9 |
+
},
|
10 |
+
{
|
11 |
+
"cell_type": "code",
|
12 |
+
"execution_count": 1,
|
13 |
+
"metadata": {},
|
14 |
+
"outputs": [
|
15 |
+
{
|
16 |
+
"name": "stderr",
|
17 |
+
"output_type": "stream",
|
18 |
+
"text": [
|
19 |
+
"/home/huggingface/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
|
20 |
+
" from .autonotebook import tqdm as notebook_tqdm\n",
|
21 |
+
"/home/huggingface/lib/python3.10/site-packages/huggingface_hub/file_download.py:1150: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
|
22 |
+
" warnings.warn(\n",
|
23 |
+
"Map: 100%|██████████| 872/872 [00:00<00:00, 15492.15 examples/s]\n"
|
24 |
+
]
|
25 |
+
}
|
26 |
+
],
|
27 |
+
"source": [
|
28 |
+
"from datasets import load_dataset\n",
|
29 |
+
"from transformers import AutoTokenizer, DataCollatorWithPadding, AutoModelForSequenceClassification\n",
|
30 |
+
"\n",
|
31 |
+
"raw_dataset = load_dataset(\"glue\", \"sst2\")\n",
|
32 |
+
"checkpoint = \"bert-base-uncased\"\n",
|
33 |
+
"\n",
|
34 |
+
"tokenizer = AutoTokenizer.from_pretrained(checkpoint)\n",
|
35 |
+
"\n",
|
36 |
+
"# # For MRPC\n",
|
37 |
+
"# def tokenize_function(sample):\n",
|
38 |
+
"# return tokenizer(sample[\"sentence1\"], sample[\"sentence2\"], truncation = True)\n",
|
39 |
+
"\n",
|
40 |
+
"# For SST2\n",
|
41 |
+
"def tokenize_function(sample):\n",
|
42 |
+
" return tokenizer(sample[\"sentence\"], truncation = True)\n",
|
43 |
+
"\n",
|
44 |
+
"\n",
|
45 |
+
"tokenized_dataset = raw_dataset.map(tokenize_function, batched = True)\n",
|
46 |
+
"data_collator = DataCollatorWithPadding(tokenizer = tokenizer)"
|
47 |
+
]
|
48 |
+
},
|
49 |
+
{
|
50 |
+
"cell_type": "markdown",
|
51 |
+
"metadata": {},
|
52 |
+
"source": [
|
53 |
+
"# Preprocess the dataset "
|
54 |
+
]
|
55 |
+
},
|
56 |
+
{
|
57 |
+
"cell_type": "code",
|
58 |
+
"execution_count": 2,
|
59 |
+
"metadata": {},
|
60 |
+
"outputs": [
|
61 |
+
{
|
62 |
+
"data": {
|
63 |
+
"text/plain": [
|
64 |
+
"{'train': ['labels', 'input_ids', 'token_type_ids', 'attention_mask'],\n",
|
65 |
+
" 'validation': ['labels', 'input_ids', 'token_type_ids', 'attention_mask'],\n",
|
66 |
+
" 'test': ['labels', 'input_ids', 'token_type_ids', 'attention_mask']}"
|
67 |
+
]
|
68 |
+
},
|
69 |
+
"execution_count": 2,
|
70 |
+
"metadata": {},
|
71 |
+
"output_type": "execute_result"
|
72 |
+
}
|
73 |
+
],
|
74 |
+
"source": [
|
75 |
+
"# Remove unwanted columns which are not to be uitilized during pytorch dataloading\n",
|
76 |
+
"# # For MRPC\n",
|
77 |
+
"# tokenized_dataset = tokenized_dataset.remove_columns([\"sentence1\", \"sentence2\", \"idx\"])\n",
|
78 |
+
"\n",
|
79 |
+
"# For SST2\n",
|
80 |
+
"tokenized_dataset = tokenized_dataset.remove_columns([\"sentence\", \"idx\"])\n",
|
81 |
+
"\n",
|
82 |
+
"# Rename the target column appropriately\n",
|
83 |
+
"tokenized_dataset = tokenized_dataset.rename_column(\"label\", \"labels\")\n",
|
84 |
+
"\n",
|
85 |
+
"# Set the format to return tensors instead of lists\n",
|
86 |
+
"tokenized_dataset.set_format(\"torch\")\n",
|
87 |
+
"\n",
|
88 |
+
"tokenized_dataset.column_names"
|
89 |
+
]
|
90 |
+
},
|
91 |
+
{
|
92 |
+
"cell_type": "code",
|
93 |
+
"execution_count": 3,
|
94 |
+
"metadata": {},
|
95 |
+
"outputs": [],
|
96 |
+
"source": [
|
97 |
+
"from torch.utils.data import DataLoader\n",
|
98 |
+
"\n",
|
99 |
+
"train_dataloader = DataLoader(tokenized_dataset[\"train\"], shuffle = True, batch_size = 64, collate_fn = data_collator)\n",
|
100 |
+
"eval_dataloader = DataLoader(tokenized_dataset[\"validation\"], batch_size = 64, collate_fn= data_collator)"
|
101 |
+
]
|
102 |
+
},
|
103 |
+
{
|
104 |
+
"cell_type": "code",
|
105 |
+
"execution_count": 4,
|
106 |
+
"metadata": {},
|
107 |
+
"outputs": [
|
108 |
+
{
|
109 |
+
"name": "stderr",
|
110 |
+
"output_type": "stream",
|
111 |
+
"text": [
|
112 |
+
"You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n"
|
113 |
+
]
|
114 |
+
},
|
115 |
+
{
|
116 |
+
"data": {
|
117 |
+
"text/plain": [
|
118 |
+
"{'labels': torch.Size([64]),\n",
|
119 |
+
" 'input_ids': torch.Size([64, 41]),\n",
|
120 |
+
" 'token_type_ids': torch.Size([64, 41]),\n",
|
121 |
+
" 'attention_mask': torch.Size([64, 41])}"
|
122 |
+
]
|
123 |
+
},
|
124 |
+
"execution_count": 4,
|
125 |
+
"metadata": {},
|
126 |
+
"output_type": "execute_result"
|
127 |
+
}
|
128 |
+
],
|
129 |
+
"source": [
|
130 |
+
"one_batch = next(iter(train_dataloader))\n",
|
131 |
+
"{k: v.shape for k, v in one_batch.items()}"
|
132 |
+
]
|
133 |
+
},
|
134 |
+
{
|
135 |
+
"cell_type": "markdown",
|
136 |
+
"metadata": {},
|
137 |
+
"source": [
|
138 |
+
"# Define the model and start training"
|
139 |
+
]
|
140 |
+
},
|
141 |
+
{
|
142 |
+
"cell_type": "code",
|
143 |
+
"execution_count": 5,
|
144 |
+
"metadata": {},
|
145 |
+
"outputs": [
|
146 |
+
{
|
147 |
+
"name": "stderr",
|
148 |
+
"output_type": "stream",
|
149 |
+
"text": [
|
150 |
+
"Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias']\n",
|
151 |
+
"- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
|
152 |
+
"- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
|
153 |
+
"Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.weight', 'classifier.bias']\n",
|
154 |
+
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
|
155 |
+
]
|
156 |
+
}
|
157 |
+
],
|
158 |
+
"source": [
|
159 |
+
"model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels = 2)"
|
160 |
+
]
|
161 |
+
},
|
162 |
+
{
|
163 |
+
"cell_type": "code",
|
164 |
+
"execution_count": 6,
|
165 |
+
"metadata": {},
|
166 |
+
"outputs": [
|
167 |
+
{
|
168 |
+
"name": "stdout",
|
169 |
+
"output_type": "stream",
|
170 |
+
"text": [
|
171 |
+
"SequenceClassifierOutput(loss=tensor(0.7528), logits=tensor([[-0.4735, 0.2345],\n",
|
172 |
+
" [-0.5462, 0.2849],\n",
|
173 |
+
" [-0.8623, 0.6073],\n",
|
174 |
+
" [-0.6334, 0.3747],\n",
|
175 |
+
" [-0.5882, 0.4656],\n",
|
176 |
+
" [-0.1711, 0.1957],\n",
|
177 |
+
" [-0.4656, 0.2387],\n",
|
178 |
+
" [-0.8434, 0.6939],\n",
|
179 |
+
" [-0.4384, 0.2810],\n",
|
180 |
+
" [-0.5239, 0.2832],\n",
|
181 |
+
" [-0.4431, 0.2877],\n",
|
182 |
+
" [-0.5974, 0.2958],\n",
|
183 |
+
" [-0.7655, 0.6273],\n",
|
184 |
+
" [-0.7656, 0.6703],\n",
|
185 |
+
" [-0.7001, 0.4183],\n",
|
186 |
+
" [-0.3617, 0.2145],\n",
|
187 |
+
" [-0.6250, 0.3684],\n",
|
188 |
+
" [-0.5722, 0.4677],\n",
|
189 |
+
" [-0.1536, 0.1978],\n",
|
190 |
+
" [-0.5606, 0.3755],\n",
|
191 |
+
" [-0.6292, 0.3662],\n",
|
192 |
+
" [-0.7420, 0.3527],\n",
|
193 |
+
" [-0.4581, 0.2733],\n",
|
194 |
+
" [-0.6560, 0.4098],\n",
|
195 |
+
" [-0.2436, 0.1589],\n",
|
196 |
+
" [-0.5316, 0.2916],\n",
|
197 |
+
" [-0.6136, 0.3340],\n",
|
198 |
+
" [-0.6650, 0.3447],\n",
|
199 |
+
" [-0.6319, 0.4982],\n",
|
200 |
+
" [-0.7093, 0.4292],\n",
|
201 |
+
" [-0.3495, 0.2136],\n",
|
202 |
+
" [-0.5344, 0.2056],\n",
|
203 |
+
" [-0.2243, 0.2376],\n",
|
204 |
+
" [-0.2150, 0.2638],\n",
|
205 |
+
" [-0.6236, 0.4449],\n",
|
206 |
+
" [-0.3363, 0.2330],\n",
|
207 |
+
" [-0.7103, 0.5592],\n",
|
208 |
+
" [-0.6709, 0.4674],\n",
|
209 |
+
" [-0.6250, 0.4823],\n",
|
210 |
+
" [-0.8934, 0.8637],\n",
|
211 |
+
" [-0.7147, 0.4695],\n",
|
212 |
+
" [-0.4029, 0.2238],\n",
|
213 |
+
" [-0.6455, 0.4327],\n",
|
214 |
+
" [-0.2547, 0.2432],\n",
|
215 |
+
" [-0.3518, 0.3581],\n",
|
216 |
+
" [-0.1312, 0.1507],\n",
|
217 |
+
" [-0.5558, 0.4219],\n",
|
218 |
+
" [-0.4881, 0.3416],\n",
|
219 |
+
" [-0.6623, 0.4497],\n",
|
220 |
+
" [-0.5963, 0.4848],\n",
|
221 |
+
" [-0.5053, 0.3500],\n",
|
222 |
+
" [-0.1152, 0.1482],\n",
|
223 |
+
" [-0.6302, 0.3531],\n",
|
224 |
+
" [-0.6268, 0.4978],\n",
|
225 |
+
" [-0.4811, 0.2927],\n",
|
226 |
+
" [ 0.0057, 0.1694],\n",
|
227 |
+
" [-0.6268, 0.3306],\n",
|
228 |
+
" [-0.5859, 0.4029],\n",
|
229 |
+
" [-0.3552, 0.2425],\n",
|
230 |
+
" [-0.5622, 0.4161],\n",
|
231 |
+
" [-0.7670, 0.5203],\n",
|
232 |
+
" [-0.6624, 0.5146],\n",
|
233 |
+
" [-0.6089, 0.4091],\n",
|
234 |
+
" [-0.4992, 0.2702]]), hidden_states=None, attentions=None)\n"
|
235 |
+
]
|
236 |
+
}
|
237 |
+
],
|
238 |
+
"source": [
|
239 |
+
"import torch\n",
|
240 |
+
"model.eval()\n",
|
241 |
+
"with torch.no_grad():\n",
|
242 |
+
" print(model(**one_batch))"
|
243 |
+
]
|
244 |
+
},
|
245 |
+
{
|
246 |
+
"cell_type": "code",
|
247 |
+
"execution_count": 7,
|
248 |
+
"metadata": {},
|
249 |
+
"outputs": [
|
250 |
+
{
|
251 |
+
"name": "stderr",
|
252 |
+
"output_type": "stream",
|
253 |
+
"text": [
|
254 |
+
"/home/huggingface/lib/python3.10/site-packages/transformers/optimization.py:391: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
|
255 |
+
" warnings.warn(\n"
|
256 |
+
]
|
257 |
+
},
|
258 |
+
{
|
259 |
+
"name": "stdout",
|
260 |
+
"output_type": "stream",
|
261 |
+
"text": [
|
262 |
+
"2106\n"
|
263 |
+
]
|
264 |
+
}
|
265 |
+
],
|
266 |
+
"source": [
|
267 |
+
"from transformers import AdamW\n",
|
268 |
+
"from transformers import get_scheduler\n",
|
269 |
+
"\n",
|
270 |
+
"# Define the optimizer here\n",
|
271 |
+
"optimizer = AdamW(model.parameters(), lr = 5e-5)\n",
|
272 |
+
"\n",
|
273 |
+
"# Define the learning rate scheduler here\n",
|
274 |
+
"num_epochs = 2\n",
|
275 |
+
"num_training_steps = num_epochs * len(train_dataloader)\n",
|
276 |
+
"lr_scheduler = get_scheduler(\n",
|
277 |
+
" \"linear\",\n",
|
278 |
+
" optimizer=optimizer,\n",
|
279 |
+
" num_warmup_steps=0,\n",
|
280 |
+
" num_training_steps=num_training_steps,\n",
|
281 |
+
")\n",
|
282 |
+
"print(num_training_steps)\n"
|
283 |
+
]
|
284 |
+
},
|
285 |
+
{
|
286 |
+
"cell_type": "code",
|
287 |
+
"execution_count": 8,
|
288 |
+
"metadata": {},
|
289 |
+
"outputs": [],
|
290 |
+
"source": [
|
291 |
+
"# Use GPU if available\n",
|
292 |
+
"device = torch.device(\"cuda:0\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
|
293 |
+
"model.to(device);"
|
294 |
+
]
|
295 |
+
},
|
296 |
+
{
|
297 |
+
"cell_type": "code",
|
298 |
+
"execution_count": 9,
|
299 |
+
"metadata": {},
|
300 |
+
"outputs": [
|
301 |
+
{
|
302 |
+
"name": "stderr",
|
303 |
+
"output_type": "stream",
|
304 |
+
"text": [
|
305 |
+
" 50%|█████ | 1054/2106 [03:48<13:25, 1.31it/s]"
|
306 |
+
]
|
307 |
+
},
|
308 |
+
{
|
309 |
+
"name": "stdout",
|
310 |
+
"output_type": "stream",
|
311 |
+
"text": [
|
312 |
+
"Metrics at end of epoch 0:\n",
|
313 |
+
"{'accuracy': 0.9288990825688074}\n"
|
314 |
+
]
|
315 |
+
},
|
316 |
+
{
|
317 |
+
"name": "stderr",
|
318 |
+
"output_type": "stream",
|
319 |
+
"text": [
|
320 |
+
"100%|█████████▉| 2105/2106 [07:35<00:00, 4.98it/s]"
|
321 |
+
]
|
322 |
+
},
|
323 |
+
{
|
324 |
+
"name": "stdout",
|
325 |
+
"output_type": "stream",
|
326 |
+
"text": [
|
327 |
+
"Metrics at end of epoch 1:\n",
|
328 |
+
"{'accuracy': 0.926605504587156}\n"
|
329 |
+
]
|
330 |
+
}
|
331 |
+
],
|
332 |
+
"source": [
|
333 |
+
"from tqdm.auto import tqdm\n",
|
334 |
+
"import evaluate\n",
|
335 |
+
"progress_bar = tqdm(range(num_training_steps))\n",
|
336 |
+
"\n",
|
337 |
+
"for epoch_id in range(num_epochs):\n",
|
338 |
+
"\n",
|
339 |
+
" # Train for one epoch\n",
|
340 |
+
" model.train()\n",
|
341 |
+
" for batch in train_dataloader:\n",
|
342 |
+
" batch = {k: v.to(device) for k, v in batch.items()}\n",
|
343 |
+
" outputs = model(**batch)\n",
|
344 |
+
" outputs.loss.backward()\n",
|
345 |
+
"\n",
|
346 |
+
" optimizer.step()\n",
|
347 |
+
" lr_scheduler.step()\n",
|
348 |
+
" optimizer.zero_grad()\n",
|
349 |
+
" progress_bar.update(1)\n",
|
350 |
+
"\n",
|
351 |
+
" # Evaluate at the end of epoch\n",
|
352 |
+
" model.eval()\n",
|
353 |
+
" # # For MRPC\n",
|
354 |
+
" # metric = evaluate.load(\"glue\", \"mrpc\")\n",
|
355 |
+
"\n",
|
356 |
+
" # For SST2\n",
|
357 |
+
" metric = evaluate.load(\"glue\", \"sst2\")\n",
|
358 |
+
"\n",
|
359 |
+
" with torch.no_grad():\n",
|
360 |
+
" for batch in eval_dataloader:\n",
|
361 |
+
" batch = {k: v.to(device) for k, v in batch.items()}\n",
|
362 |
+
" outputs = model(**batch)\n",
|
363 |
+
" logits = outputs.logits\n",
|
364 |
+
" predictions = logits.argmax(dim = -1)\n",
|
365 |
+
" metric.add_batch(predictions = predictions, references = batch[\"labels\"])\n",
|
366 |
+
" m = metric.compute()\n",
|
367 |
+
"\n",
|
368 |
+
" print(f\"Metrics at end of epoch {epoch_id}:\\n{m}\")\n"
|
369 |
+
]
|
370 |
+
}
|
371 |
+
],
|
372 |
+
"metadata": {
|
373 |
+
"kernelspec": {
|
374 |
+
"display_name": "Python 3",
|
375 |
+
"language": "python",
|
376 |
+
"name": "python3"
|
377 |
+
},
|
378 |
+
"language_info": {
|
379 |
+
"codemirror_mode": {
|
380 |
+
"name": "ipython",
|
381 |
+
"version": 3
|
382 |
+
},
|
383 |
+
"file_extension": ".py",
|
384 |
+
"mimetype": "text/x-python",
|
385 |
+
"name": "python",
|
386 |
+
"nbconvert_exporter": "python",
|
387 |
+
"pygments_lexer": "ipython3",
|
388 |
+
"version": "3.10.14"
|
389 |
+
}
|
390 |
+
},
|
391 |
+
"nbformat": 4,
|
392 |
+
"nbformat_minor": 2
|
393 |
+
}
|
chapter3/3-fine-tuning-a-model-with-the-Trainer-API.ipynb
ADDED
@@ -0,0 +1,685 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "markdown",
|
5 |
+
"metadata": {},
|
6 |
+
"source": [
|
7 |
+
"# Fine tuning bert base uncased for paraphrasing identification task"
|
8 |
+
]
|
9 |
+
},
|
10 |
+
{
|
11 |
+
"cell_type": "code",
|
12 |
+
"execution_count": 1,
|
13 |
+
"metadata": {},
|
14 |
+
"outputs": [
|
15 |
+
{
|
16 |
+
"name": "stderr",
|
17 |
+
"output_type": "stream",
|
18 |
+
"text": [
|
19 |
+
"/home/huggingface/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
|
20 |
+
" from .autonotebook import tqdm as notebook_tqdm\n",
|
21 |
+
"/home/huggingface/lib/python3.10/site-packages/huggingface_hub/file_download.py:1150: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
|
22 |
+
" warnings.warn(\n",
|
23 |
+
"Map: 100%|██████████| 408/408 [00:00<00:00, 8416.87 examples/s]\n"
|
24 |
+
]
|
25 |
+
}
|
26 |
+
],
|
27 |
+
"source": [
|
28 |
+
"from datasets import load_dataset\n",
|
29 |
+
"from transformers import AutoTokenizer, DataCollatorWithPadding\n",
|
30 |
+
"\n",
|
31 |
+
"raw_datasets = load_dataset(\"glue\", \"mrpc\")\n",
|
32 |
+
"checkpoint = \"bert-base-uncased\"\n",
|
33 |
+
"tokenizer = AutoTokenizer.from_pretrained(checkpoint)\n",
|
34 |
+
"\n",
|
35 |
+
"\n",
|
36 |
+
"def tokenize_function(example):\n",
|
37 |
+
" return tokenizer(example[\"sentence1\"], example[\"sentence2\"], truncation=True)\n",
|
38 |
+
"\n",
|
39 |
+
"\n",
|
40 |
+
"tokenized_datasets = raw_datasets.map(tokenize_function, batched=True)\n",
|
41 |
+
"data_collator = DataCollatorWithPadding(tokenizer=tokenizer)"
|
42 |
+
]
|
43 |
+
},
|
44 |
+
{
|
45 |
+
"cell_type": "code",
|
46 |
+
"execution_count": 11,
|
47 |
+
"metadata": {},
|
48 |
+
"outputs": [
|
49 |
+
{
|
50 |
+
"data": {
|
51 |
+
"text/plain": [
|
52 |
+
"3.0"
|
53 |
+
]
|
54 |
+
},
|
55 |
+
"execution_count": 11,
|
56 |
+
"metadata": {},
|
57 |
+
"output_type": "execute_result"
|
58 |
+
}
|
59 |
+
],
|
60 |
+
"source": [
|
61 |
+
"training_args.num_train_epochs = 1"
|
62 |
+
]
|
63 |
+
},
|
64 |
+
{
|
65 |
+
"cell_type": "code",
|
66 |
+
"execution_count": 24,
|
67 |
+
"metadata": {},
|
68 |
+
"outputs": [
|
69 |
+
{
|
70 |
+
"data": {
|
71 |
+
"text/plain": [
|
72 |
+
"TrainingArguments(\n",
|
73 |
+
"_n_gpu=1,\n",
|
74 |
+
"adafactor=False,\n",
|
75 |
+
"adam_beta1=0.9,\n",
|
76 |
+
"adam_beta2=0.999,\n",
|
77 |
+
"adam_epsilon=1e-08,\n",
|
78 |
+
"auto_find_batch_size=False,\n",
|
79 |
+
"bf16=False,\n",
|
80 |
+
"bf16_full_eval=False,\n",
|
81 |
+
"data_seed=None,\n",
|
82 |
+
"dataloader_drop_last=False,\n",
|
83 |
+
"dataloader_num_workers=0,\n",
|
84 |
+
"dataloader_pin_memory=True,\n",
|
85 |
+
"ddp_bucket_cap_mb=None,\n",
|
86 |
+
"ddp_find_unused_parameters=None,\n",
|
87 |
+
"ddp_timeout=1800,\n",
|
88 |
+
"debug=[],\n",
|
89 |
+
"deepspeed=None,\n",
|
90 |
+
"disable_tqdm=False,\n",
|
91 |
+
"do_eval=False,\n",
|
92 |
+
"do_predict=False,\n",
|
93 |
+
"do_train=False,\n",
|
94 |
+
"eval_accumulation_steps=None,\n",
|
95 |
+
"eval_delay=0,\n",
|
96 |
+
"eval_steps=None,\n",
|
97 |
+
"evaluation_strategy=no,\n",
|
98 |
+
"fp16=False,\n",
|
99 |
+
"fp16_backend=auto,\n",
|
100 |
+
"fp16_full_eval=False,\n",
|
101 |
+
"fp16_opt_level=O1,\n",
|
102 |
+
"fsdp=[],\n",
|
103 |
+
"fsdp_config={'fsdp_min_num_params': 0, 'xla': False, 'xla_fsdp_grad_ckpt': False},\n",
|
104 |
+
"fsdp_min_num_params=0,\n",
|
105 |
+
"fsdp_transformer_layer_cls_to_wrap=None,\n",
|
106 |
+
"full_determinism=False,\n",
|
107 |
+
"gradient_accumulation_steps=1,\n",
|
108 |
+
"gradient_checkpointing=False,\n",
|
109 |
+
"greater_is_better=None,\n",
|
110 |
+
"group_by_length=False,\n",
|
111 |
+
"half_precision_backend=auto,\n",
|
112 |
+
"hub_model_id=None,\n",
|
113 |
+
"hub_private_repo=False,\n",
|
114 |
+
"hub_strategy=every_save,\n",
|
115 |
+
"hub_token=<HUB_TOKEN>,\n",
|
116 |
+
"ignore_data_skip=False,\n",
|
117 |
+
"include_inputs_for_metrics=False,\n",
|
118 |
+
"jit_mode_eval=False,\n",
|
119 |
+
"label_names=None,\n",
|
120 |
+
"label_smoothing_factor=0.0,\n",
|
121 |
+
"learning_rate=5e-05,\n",
|
122 |
+
"length_column_name=length,\n",
|
123 |
+
"load_best_model_at_end=False,\n",
|
124 |
+
"local_rank=-1,\n",
|
125 |
+
"log_level=passive,\n",
|
126 |
+
"log_level_replica=warning,\n",
|
127 |
+
"log_on_each_node=True,\n",
|
128 |
+
"logging_dir=test-trainer/runs/Jul21_12-03-08_602d65b93b25,\n",
|
129 |
+
"logging_first_step=False,\n",
|
130 |
+
"logging_nan_inf_filter=True,\n",
|
131 |
+
"logging_steps=500,\n",
|
132 |
+
"logging_strategy=steps,\n",
|
133 |
+
"lr_scheduler_type=linear,\n",
|
134 |
+
"max_grad_norm=1.0,\n",
|
135 |
+
"max_steps=-1,\n",
|
136 |
+
"metric_for_best_model=None,\n",
|
137 |
+
"mp_parameters=,\n",
|
138 |
+
"no_cuda=False,\n",
|
139 |
+
"num_train_epochs=2,\n",
|
140 |
+
"optim=adamw_hf,\n",
|
141 |
+
"optim_args=None,\n",
|
142 |
+
"output_dir=test-trainer,\n",
|
143 |
+
"overwrite_output_dir=False,\n",
|
144 |
+
"past_index=-1,\n",
|
145 |
+
"per_device_eval_batch_size=8,\n",
|
146 |
+
"per_device_train_batch_size=8,\n",
|
147 |
+
"prediction_loss_only=False,\n",
|
148 |
+
"push_to_hub=False,\n",
|
149 |
+
"push_to_hub_model_id=None,\n",
|
150 |
+
"push_to_hub_organization=None,\n",
|
151 |
+
"push_to_hub_token=<PUSH_TO_HUB_TOKEN>,\n",
|
152 |
+
"ray_scope=last,\n",
|
153 |
+
"remove_unused_columns=True,\n",
|
154 |
+
"report_to=[],\n",
|
155 |
+
"resume_from_checkpoint=None,\n",
|
156 |
+
"run_name=test-trainer,\n",
|
157 |
+
"save_on_each_node=False,\n",
|
158 |
+
"save_steps=500,\n",
|
159 |
+
"save_strategy=steps,\n",
|
160 |
+
"save_total_limit=None,\n",
|
161 |
+
"seed=42,\n",
|
162 |
+
"sharded_ddp=[],\n",
|
163 |
+
"skip_memory_metrics=True,\n",
|
164 |
+
"tf32=None,\n",
|
165 |
+
"torch_compile=False,\n",
|
166 |
+
"torch_compile_backend=None,\n",
|
167 |
+
"torch_compile_mode=None,\n",
|
168 |
+
"torchdynamo=None,\n",
|
169 |
+
"tpu_metrics_debug=False,\n",
|
170 |
+
"tpu_num_cores=None,\n",
|
171 |
+
"use_ipex=False,\n",
|
172 |
+
"use_legacy_prediction_loop=False,\n",
|
173 |
+
"use_mps_device=False,\n",
|
174 |
+
"warmup_ratio=0.0,\n",
|
175 |
+
"warmup_steps=0,\n",
|
176 |
+
"weight_decay=0.0,\n",
|
177 |
+
"xpu_backend=None,\n",
|
178 |
+
")"
|
179 |
+
]
|
180 |
+
},
|
181 |
+
"execution_count": 24,
|
182 |
+
"metadata": {},
|
183 |
+
"output_type": "execute_result"
|
184 |
+
}
|
185 |
+
],
|
186 |
+
"source": [
|
187 |
+
"from transformers import TrainingArguments\n",
|
188 |
+
"training_args = TrainingArguments(\"test-trainer\")\n",
|
189 |
+
"training_args.num_train_epochs = 2\n",
|
190 |
+
"training_args"
|
191 |
+
]
|
192 |
+
},
|
193 |
+
{
|
194 |
+
"cell_type": "code",
|
195 |
+
"execution_count": 13,
|
196 |
+
"metadata": {},
|
197 |
+
"outputs": [
|
198 |
+
{
|
199 |
+
"name": "stderr",
|
200 |
+
"output_type": "stream",
|
201 |
+
"text": [
|
202 |
+
"/home/huggingface/lib/python3.10/site-packages/huggingface_hub/file_download.py:1150: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
|
203 |
+
" warnings.warn(\n",
|
204 |
+
"Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.seq_relationship.weight']\n",
|
205 |
+
"- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
|
206 |
+
"- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
|
207 |
+
"Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
|
208 |
+
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
|
209 |
+
]
|
210 |
+
}
|
211 |
+
],
|
212 |
+
"source": [
|
213 |
+
"from transformers import AutoModelForSequenceClassification\n",
|
214 |
+
"model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=2)"
|
215 |
+
]
|
216 |
+
},
|
217 |
+
{
|
218 |
+
"cell_type": "code",
|
219 |
+
"execution_count": 14,
|
220 |
+
"metadata": {},
|
221 |
+
"outputs": [
|
222 |
+
{
|
223 |
+
"data": {
|
224 |
+
"text/plain": [
|
225 |
+
"Linear(in_features=768, out_features=2, bias=True)"
|
226 |
+
]
|
227 |
+
},
|
228 |
+
"execution_count": 14,
|
229 |
+
"metadata": {},
|
230 |
+
"output_type": "execute_result"
|
231 |
+
}
|
232 |
+
],
|
233 |
+
"source": [
|
234 |
+
"model.classifier"
|
235 |
+
]
|
236 |
+
},
|
237 |
+
{
|
238 |
+
"cell_type": "code",
|
239 |
+
"execution_count": 25,
|
240 |
+
"metadata": {},
|
241 |
+
"outputs": [],
|
242 |
+
"source": [
|
243 |
+
"from transformers import Trainer\n",
|
244 |
+
"trainer = Trainer(\n",
|
245 |
+
" model,\n",
|
246 |
+
" training_args,\n",
|
247 |
+
" train_dataset=tokenized_datasets[\"train\"],\n",
|
248 |
+
" eval_dataset=tokenized_datasets[\"validation\"],\n",
|
249 |
+
" data_collator=data_collator,\n",
|
250 |
+
" tokenizer=tokenizer,\n",
|
251 |
+
")"
|
252 |
+
]
|
253 |
+
},
|
254 |
+
{
|
255 |
+
"cell_type": "code",
|
256 |
+
"execution_count": 26,
|
257 |
+
"metadata": {},
|
258 |
+
"outputs": [
|
259 |
+
{
|
260 |
+
"name": "stderr",
|
261 |
+
"output_type": "stream",
|
262 |
+
"text": [
|
263 |
+
"/home/huggingface/lib/python3.10/site-packages/transformers/optimization.py:391: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
|
264 |
+
" warnings.warn(\n"
|
265 |
+
]
|
266 |
+
},
|
267 |
+
{
|
268 |
+
"data": {
|
269 |
+
"text/html": [
|
270 |
+
"\n",
|
271 |
+
" <div>\n",
|
272 |
+
" \n",
|
273 |
+
" <progress value='918' max='918' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
|
274 |
+
" [918/918 01:11, Epoch 2/2]\n",
|
275 |
+
" </div>\n",
|
276 |
+
" <table border=\"1\" class=\"dataframe\">\n",
|
277 |
+
" <thead>\n",
|
278 |
+
" <tr style=\"text-align: left;\">\n",
|
279 |
+
" <th>Step</th>\n",
|
280 |
+
" <th>Training Loss</th>\n",
|
281 |
+
" </tr>\n",
|
282 |
+
" </thead>\n",
|
283 |
+
" <tbody>\n",
|
284 |
+
" <tr>\n",
|
285 |
+
" <td>500</td>\n",
|
286 |
+
" <td>0.323900</td>\n",
|
287 |
+
" </tr>\n",
|
288 |
+
" </tbody>\n",
|
289 |
+
"</table><p>"
|
290 |
+
],
|
291 |
+
"text/plain": [
|
292 |
+
"<IPython.core.display.HTML object>"
|
293 |
+
]
|
294 |
+
},
|
295 |
+
"metadata": {},
|
296 |
+
"output_type": "display_data"
|
297 |
+
},
|
298 |
+
{
|
299 |
+
"data": {
|
300 |
+
"text/plain": [
|
301 |
+
"TrainOutput(global_step=918, training_loss=0.26239446669102756, metrics={'train_runtime': 72.0933, 'train_samples_per_second': 101.757, 'train_steps_per_second': 12.733, 'total_flos': 270693998197680.0, 'train_loss': 0.26239446669102756, 'epoch': 2.0})"
|
302 |
+
]
|
303 |
+
},
|
304 |
+
"execution_count": 26,
|
305 |
+
"metadata": {},
|
306 |
+
"output_type": "execute_result"
|
307 |
+
}
|
308 |
+
],
|
309 |
+
"source": [
|
310 |
+
"# Plain training\n",
|
311 |
+
"trainer.train()"
|
312 |
+
]
|
313 |
+
},
|
314 |
+
{
|
315 |
+
"cell_type": "code",
|
316 |
+
"execution_count": 27,
|
317 |
+
"metadata": {},
|
318 |
+
"outputs": [
|
319 |
+
{
|
320 |
+
"data": {
|
321 |
+
"text/html": [
|
322 |
+
"\n",
|
323 |
+
" <div>\n",
|
324 |
+
" \n",
|
325 |
+
" <progress value='6' max='51' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
|
326 |
+
" [ 6/51 00:00 < 00:00, 50.51 it/s]\n",
|
327 |
+
" </div>\n",
|
328 |
+
" "
|
329 |
+
],
|
330 |
+
"text/plain": [
|
331 |
+
"<IPython.core.display.HTML object>"
|
332 |
+
]
|
333 |
+
},
|
334 |
+
"metadata": {},
|
335 |
+
"output_type": "display_data"
|
336 |
+
},
|
337 |
+
{
|
338 |
+
"name": "stdout",
|
339 |
+
"output_type": "stream",
|
340 |
+
"text": [
|
341 |
+
"(408, 2) (408,)\n"
|
342 |
+
]
|
343 |
+
}
|
344 |
+
],
|
345 |
+
"source": [
|
346 |
+
"predictions = trainer.predict(tokenized_datasets[\"validation\"])\n",
|
347 |
+
"print(predictions.predictions.shape, predictions.label_ids.shape)"
|
348 |
+
]
|
349 |
+
},
|
350 |
+
{
|
351 |
+
"cell_type": "code",
|
352 |
+
"execution_count": 28,
|
353 |
+
"metadata": {},
|
354 |
+
"outputs": [],
|
355 |
+
"source": [
|
356 |
+
"import numpy as np\n",
|
357 |
+
"preds = np.argmax(predictions.predictions, axis=-1)"
|
358 |
+
]
|
359 |
+
},
|
360 |
+
{
|
361 |
+
"cell_type": "code",
|
362 |
+
"execution_count": 29,
|
363 |
+
"metadata": {},
|
364 |
+
"outputs": [
|
365 |
+
{
|
366 |
+
"data": {
|
367 |
+
"text/plain": [
|
368 |
+
"{'accuracy': 0.8553921568627451, 'f1': 0.8963093145869947}"
|
369 |
+
]
|
370 |
+
},
|
371 |
+
"execution_count": 29,
|
372 |
+
"metadata": {},
|
373 |
+
"output_type": "execute_result"
|
374 |
+
}
|
375 |
+
],
|
376 |
+
"source": [
|
377 |
+
"import evaluate\n",
|
378 |
+
"metric = evaluate.load(\"glue\", \"mrpc\")\n",
|
379 |
+
"metric.compute(predictions=preds, references=predictions.label_ids)"
|
380 |
+
]
|
381 |
+
},
|
382 |
+
{
|
383 |
+
"cell_type": "code",
|
384 |
+
"execution_count": 30,
|
385 |
+
"metadata": {},
|
386 |
+
"outputs": [],
|
387 |
+
"source": [
|
388 |
+
"def compute_metrics(eval_preds):\n",
|
389 |
+
" metric = evaluate.load(\"glue\", \"mrpc\")\n",
|
390 |
+
" logits, labels = eval_preds\n",
|
391 |
+
" predictions = np.argmax(logits, axis=-1)\n",
|
392 |
+
" return metric.compute(predictions=predictions, references=labels)"
|
393 |
+
]
|
394 |
+
},
|
395 |
+
{
|
396 |
+
"cell_type": "code",
|
397 |
+
"execution_count": 32,
|
398 |
+
"metadata": {},
|
399 |
+
"outputs": [
|
400 |
+
{
|
401 |
+
"name": "stderr",
|
402 |
+
"output_type": "stream",
|
403 |
+
"text": [
|
404 |
+
"Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.seq_relationship.weight']\n",
|
405 |
+
"- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
|
406 |
+
"- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
|
407 |
+
"Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
|
408 |
+
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
|
409 |
+
]
|
410 |
+
}
|
411 |
+
],
|
412 |
+
"source": [
|
413 |
+
"training_args = TrainingArguments(\"test-trainer\", evaluation_strategy=\"epoch\", num_train_epochs = 2)\n",
|
414 |
+
"model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=2)\n",
|
415 |
+
"\n",
|
416 |
+
"trainer = Trainer(\n",
|
417 |
+
" model,\n",
|
418 |
+
" training_args,\n",
|
419 |
+
" train_dataset=tokenized_datasets[\"train\"],\n",
|
420 |
+
" eval_dataset=tokenized_datasets[\"validation\"],\n",
|
421 |
+
" data_collator=data_collator,\n",
|
422 |
+
" tokenizer=tokenizer,\n",
|
423 |
+
" compute_metrics=compute_metrics,\n",
|
424 |
+
")"
|
425 |
+
]
|
426 |
+
},
|
427 |
+
{
|
428 |
+
"cell_type": "code",
|
429 |
+
"execution_count": 33,
|
430 |
+
"metadata": {},
|
431 |
+
"outputs": [
|
432 |
+
{
|
433 |
+
"name": "stderr",
|
434 |
+
"output_type": "stream",
|
435 |
+
"text": [
|
436 |
+
"/home/huggingface/lib/python3.10/site-packages/transformers/optimization.py:391: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
|
437 |
+
" warnings.warn(\n"
|
438 |
+
]
|
439 |
+
},
|
440 |
+
{
|
441 |
+
"data": {
|
442 |
+
"text/html": [
|
443 |
+
"\n",
|
444 |
+
" <div>\n",
|
445 |
+
" \n",
|
446 |
+
" <progress value='918' max='918' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
|
447 |
+
" [918/918 01:21, Epoch 2/2]\n",
|
448 |
+
" </div>\n",
|
449 |
+
" <table border=\"1\" class=\"dataframe\">\n",
|
450 |
+
" <thead>\n",
|
451 |
+
" <tr style=\"text-align: left;\">\n",
|
452 |
+
" <th>Epoch</th>\n",
|
453 |
+
" <th>Training Loss</th>\n",
|
454 |
+
" <th>Validation Loss</th>\n",
|
455 |
+
" <th>Accuracy</th>\n",
|
456 |
+
" <th>F1</th>\n",
|
457 |
+
" </tr>\n",
|
458 |
+
" </thead>\n",
|
459 |
+
" <tbody>\n",
|
460 |
+
" <tr>\n",
|
461 |
+
" <td>1</td>\n",
|
462 |
+
" <td>No log</td>\n",
|
463 |
+
" <td>0.418620</td>\n",
|
464 |
+
" <td>0.830882</td>\n",
|
465 |
+
" <td>0.883249</td>\n",
|
466 |
+
" </tr>\n",
|
467 |
+
" <tr>\n",
|
468 |
+
" <td>2</td>\n",
|
469 |
+
" <td>0.498100</td>\n",
|
470 |
+
" <td>0.485925</td>\n",
|
471 |
+
" <td>0.860294</td>\n",
|
472 |
+
" <td>0.903226</td>\n",
|
473 |
+
" </tr>\n",
|
474 |
+
" </tbody>\n",
|
475 |
+
"</table><p>"
|
476 |
+
],
|
477 |
+
"text/plain": [
|
478 |
+
"<IPython.core.display.HTML object>"
|
479 |
+
]
|
480 |
+
},
|
481 |
+
"metadata": {},
|
482 |
+
"output_type": "display_data"
|
483 |
+
},
|
484 |
+
{
|
485 |
+
"data": {
|
486 |
+
"text/plain": [
|
487 |
+
"TrainOutput(global_step=918, training_loss=0.39945665579735584, metrics={'train_runtime': 82.0502, 'train_samples_per_second': 89.409, 'train_steps_per_second': 11.188, 'total_flos': 270693998197680.0, 'train_loss': 0.39945665579735584, 'epoch': 2.0})"
|
488 |
+
]
|
489 |
+
},
|
490 |
+
"execution_count": 33,
|
491 |
+
"metadata": {},
|
492 |
+
"output_type": "execute_result"
|
493 |
+
}
|
494 |
+
],
|
495 |
+
"source": [
|
496 |
+
"trainer.train()"
|
497 |
+
]
|
498 |
+
},
|
499 |
+
{
|
500 |
+
"cell_type": "markdown",
|
501 |
+
"metadata": {},
|
502 |
+
"source": [
|
503 |
+
"# Finetuning on GLUE-SST-2"
|
504 |
+
]
|
505 |
+
},
|
506 |
+
{
|
507 |
+
"cell_type": "code",
|
508 |
+
"execution_count": 37,
|
509 |
+
"metadata": {},
|
510 |
+
"outputs": [
|
511 |
+
{
|
512 |
+
"name": "stderr",
|
513 |
+
"output_type": "stream",
|
514 |
+
"text": [
|
515 |
+
"/home/huggingface/lib/python3.10/site-packages/huggingface_hub/file_download.py:1150: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
|
516 |
+
" warnings.warn(\n",
|
517 |
+
"Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.seq_relationship.weight']\n",
|
518 |
+
"- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
|
519 |
+
"- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
|
520 |
+
"Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
|
521 |
+
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
|
522 |
+
"Map: 100%|██████████| 1821/1821 [00:00<00:00, 9257.22 examples/s]\n"
|
523 |
+
]
|
524 |
+
}
|
525 |
+
],
|
526 |
+
"source": [
|
527 |
+
"from datasets import load_dataset\n",
|
528 |
+
"raw_dataset = load_dataset(\"glue\", \"sst2\")\n",
|
529 |
+
"\n",
|
530 |
+
"from transformers import AutoTokenizer\n",
|
531 |
+
"from transformers import AutoModelForSequenceClassification\n",
|
532 |
+
"\n",
|
533 |
+
"checkpoint = \"bert-base-uncased\"\n",
|
534 |
+
"tokenizer = AutoTokenizer.from_pretrained(checkpoint)\n",
|
535 |
+
"model = AutoModelForSequenceClassification.from_pretrained(checkpoint)\n",
|
536 |
+
"\n",
|
537 |
+
"def tokenize_function(sequence):\n",
|
538 |
+
" return tokenizer(sequence[\"sentence\"], padding = True, truncation = True, return_tensors=\"pt\")\n",
|
539 |
+
"\n",
|
540 |
+
"tokenized_dataset = raw_dataset.map(tokenize_function, batched = True)\n",
|
541 |
+
"\n",
|
542 |
+
"from transformers import DataCollatorWithPadding\n",
|
543 |
+
"dc = DataCollatorWithPadding(tokenizer = tokenizer, padding = True)"
|
544 |
+
]
|
545 |
+
},
|
546 |
+
{
|
547 |
+
"cell_type": "code",
|
548 |
+
"execution_count": 38,
|
549 |
+
"metadata": {},
|
550 |
+
"outputs": [],
|
551 |
+
"source": [
|
552 |
+
"def compute_metrics(eval_preds):\n",
|
553 |
+
" metric = evaluate.load(\"glue\", \"sst2\")\n",
|
554 |
+
" logits, labels = eval_preds\n",
|
555 |
+
" predictions = np.argmax(logits, axis=-1)\n",
|
556 |
+
" return metric.compute(predictions=predictions, references=labels)"
|
557 |
+
]
|
558 |
+
},
|
559 |
+
{
|
560 |
+
"cell_type": "code",
|
561 |
+
"execution_count": 52,
|
562 |
+
"metadata": {},
|
563 |
+
"outputs": [
|
564 |
+
{
|
565 |
+
"name": "stderr",
|
566 |
+
"output_type": "stream",
|
567 |
+
"text": [
|
568 |
+
"/home/huggingface/lib/python3.10/site-packages/huggingface_hub/file_download.py:1150: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
|
569 |
+
" warnings.warn(\n",
|
570 |
+
"Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.seq_relationship.weight']\n",
|
571 |
+
"- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
|
572 |
+
"- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
|
573 |
+
"Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
|
574 |
+
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
|
575 |
+
]
|
576 |
+
}
|
577 |
+
],
|
578 |
+
"source": [
|
579 |
+
"training_args = TrainingArguments(\"test-trainer\", evaluation_strategy=\"epoch\", num_train_epochs = 2,\n",
|
580 |
+
" per_device_eval_batch_size = 32, per_device_train_batch_size = 64)\n",
|
581 |
+
"model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=2)\n",
|
582 |
+
"\n",
|
583 |
+
"trainer = Trainer(\n",
|
584 |
+
" model,\n",
|
585 |
+
" training_args,\n",
|
586 |
+
" train_dataset=tokenized_dataset[\"train\"],\n",
|
587 |
+
" eval_dataset=tokenized_dataset[\"validation\"],\n",
|
588 |
+
" data_collator=data_collator,\n",
|
589 |
+
" tokenizer=tokenizer,\n",
|
590 |
+
" compute_metrics=compute_metrics,\n",
|
591 |
+
")"
|
592 |
+
]
|
593 |
+
},
|
594 |
+
{
|
595 |
+
"cell_type": "code",
|
596 |
+
"execution_count": 53,
|
597 |
+
"metadata": {},
|
598 |
+
"outputs": [
|
599 |
+
{
|
600 |
+
"name": "stderr",
|
601 |
+
"output_type": "stream",
|
602 |
+
"text": [
|
603 |
+
"/home/huggingface/lib/python3.10/site-packages/transformers/optimization.py:391: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
|
604 |
+
" warnings.warn(\n"
|
605 |
+
]
|
606 |
+
},
|
607 |
+
{
|
608 |
+
"data": {
|
609 |
+
"text/html": [
|
610 |
+
"\n",
|
611 |
+
" <div>\n",
|
612 |
+
" \n",
|
613 |
+
" <progress value='2106' max='2106' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
|
614 |
+
" [2106/2106 11:38, Epoch 2/2]\n",
|
615 |
+
" </div>\n",
|
616 |
+
" <table border=\"1\" class=\"dataframe\">\n",
|
617 |
+
" <thead>\n",
|
618 |
+
" <tr style=\"text-align: left;\">\n",
|
619 |
+
" <th>Epoch</th>\n",
|
620 |
+
" <th>Training Loss</th>\n",
|
621 |
+
" <th>Validation Loss</th>\n",
|
622 |
+
" <th>Accuracy</th>\n",
|
623 |
+
" </tr>\n",
|
624 |
+
" </thead>\n",
|
625 |
+
" <tbody>\n",
|
626 |
+
" <tr>\n",
|
627 |
+
" <td>1</td>\n",
|
628 |
+
" <td>0.166400</td>\n",
|
629 |
+
" <td>0.236536</td>\n",
|
630 |
+
" <td>0.918578</td>\n",
|
631 |
+
" </tr>\n",
|
632 |
+
" <tr>\n",
|
633 |
+
" <td>2</td>\n",
|
634 |
+
" <td>0.089700</td>\n",
|
635 |
+
" <td>0.238962</td>\n",
|
636 |
+
" <td>0.925459</td>\n",
|
637 |
+
" </tr>\n",
|
638 |
+
" </tbody>\n",
|
639 |
+
"</table><p>"
|
640 |
+
],
|
641 |
+
"text/plain": [
|
642 |
+
"<IPython.core.display.HTML object>"
|
643 |
+
]
|
644 |
+
},
|
645 |
+
"metadata": {},
|
646 |
+
"output_type": "display_data"
|
647 |
+
},
|
648 |
+
{
|
649 |
+
"data": {
|
650 |
+
"text/plain": [
|
651 |
+
"TrainOutput(global_step=2106, training_loss=0.1493011535289507, metrics={'train_runtime': 698.8876, 'train_samples_per_second': 192.732, 'train_steps_per_second': 3.013, 'total_flos': 4556217062352120.0, 'train_loss': 0.1493011535289507, 'epoch': 2.0})"
|
652 |
+
]
|
653 |
+
},
|
654 |
+
"execution_count": 53,
|
655 |
+
"metadata": {},
|
656 |
+
"output_type": "execute_result"
|
657 |
+
}
|
658 |
+
],
|
659 |
+
"source": [
|
660 |
+
"trainer.train()"
|
661 |
+
]
|
662 |
+
}
|
663 |
+
],
|
664 |
+
"metadata": {
|
665 |
+
"kernelspec": {
|
666 |
+
"display_name": "Python 3",
|
667 |
+
"language": "python",
|
668 |
+
"name": "python3"
|
669 |
+
},
|
670 |
+
"language_info": {
|
671 |
+
"codemirror_mode": {
|
672 |
+
"name": "ipython",
|
673 |
+
"version": 3
|
674 |
+
},
|
675 |
+
"file_extension": ".py",
|
676 |
+
"mimetype": "text/x-python",
|
677 |
+
"name": "python",
|
678 |
+
"nbconvert_exporter": "python",
|
679 |
+
"pygments_lexer": "ipython3",
|
680 |
+
"version": "3.10.14"
|
681 |
+
}
|
682 |
+
},
|
683 |
+
"nbformat": 4,
|
684 |
+
"nbformat_minor": 2
|
685 |
+
}
|
chapter3/3-processing-the-data.ipynb
ADDED
@@ -0,0 +1,1719 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "markdown",
|
5 |
+
"metadata": {},
|
6 |
+
"source": [
|
7 |
+
"# Exploring the GLUE - MRPC dataset"
|
8 |
+
]
|
9 |
+
},
|
10 |
+
{
|
11 |
+
"cell_type": "code",
|
12 |
+
"execution_count": 12,
|
13 |
+
"metadata": {},
|
14 |
+
"outputs": [],
|
15 |
+
"source": [
|
16 |
+
"from datasets import load_dataset\n",
|
17 |
+
"from pprint import pprint\n",
|
18 |
+
"\n",
|
19 |
+
"raw_dataset = load_dataset(path = \"glue\", name = \"mrpc\")"
|
20 |
+
]
|
21 |
+
},
|
22 |
+
{
|
23 |
+
"cell_type": "code",
|
24 |
+
"execution_count": 7,
|
25 |
+
"metadata": {},
|
26 |
+
"outputs": [
|
27 |
+
{
|
28 |
+
"name": "stdout",
|
29 |
+
"output_type": "stream",
|
30 |
+
"text": [
|
31 |
+
"_home_.cache_huggingface_datasets_glue_mrpc_0.0.0_bcdcba79d07bc864c1c254ccfcedcce55bcc9a8c.lock\n",
|
32 |
+
"downloads\n",
|
33 |
+
"glue\n"
|
34 |
+
]
|
35 |
+
}
|
36 |
+
],
|
37 |
+
"source": [
|
38 |
+
"!ls ~/.cache/huggingface/datasets/"
|
39 |
+
]
|
40 |
+
},
|
41 |
+
{
|
42 |
+
"cell_type": "code",
|
43 |
+
"execution_count": 8,
|
44 |
+
"metadata": {},
|
45 |
+
"outputs": [
|
46 |
+
{
|
47 |
+
"data": {
|
48 |
+
"text/plain": [
|
49 |
+
"DatasetDict({\n",
|
50 |
+
" train: Dataset({\n",
|
51 |
+
" features: ['sentence1', 'sentence2', 'label', 'idx'],\n",
|
52 |
+
" num_rows: 3668\n",
|
53 |
+
" })\n",
|
54 |
+
" validation: Dataset({\n",
|
55 |
+
" features: ['sentence1', 'sentence2', 'label', 'idx'],\n",
|
56 |
+
" num_rows: 408\n",
|
57 |
+
" })\n",
|
58 |
+
" test: Dataset({\n",
|
59 |
+
" features: ['sentence1', 'sentence2', 'label', 'idx'],\n",
|
60 |
+
" num_rows: 1725\n",
|
61 |
+
" })\n",
|
62 |
+
"})"
|
63 |
+
]
|
64 |
+
},
|
65 |
+
"execution_count": 8,
|
66 |
+
"metadata": {},
|
67 |
+
"output_type": "execute_result"
|
68 |
+
}
|
69 |
+
],
|
70 |
+
"source": [
|
71 |
+
"raw_dataset"
|
72 |
+
]
|
73 |
+
},
|
74 |
+
{
|
75 |
+
"cell_type": "code",
|
76 |
+
"execution_count": 9,
|
77 |
+
"metadata": {},
|
78 |
+
"outputs": [
|
79 |
+
{
|
80 |
+
"data": {
|
81 |
+
"text/plain": [
|
82 |
+
"{'sentence1': 'Amrozi accused his brother , whom he called \" the witness \" , of deliberately distorting his evidence .',\n",
|
83 |
+
" 'sentence2': 'Referring to him as only \" the witness \" , Amrozi accused his brother of deliberately distorting his evidence .',\n",
|
84 |
+
" 'label': 1,\n",
|
85 |
+
" 'idx': 0}"
|
86 |
+
]
|
87 |
+
},
|
88 |
+
"execution_count": 9,
|
89 |
+
"metadata": {},
|
90 |
+
"output_type": "execute_result"
|
91 |
+
}
|
92 |
+
],
|
93 |
+
"source": [
|
94 |
+
"raw_dataset[\"train\"][0]"
|
95 |
+
]
|
96 |
+
},
|
97 |
+
{
|
98 |
+
"cell_type": "code",
|
99 |
+
"execution_count": 10,
|
100 |
+
"metadata": {},
|
101 |
+
"outputs": [
|
102 |
+
{
|
103 |
+
"data": {
|
104 |
+
"text/plain": [
|
105 |
+
"{'sentence1': Value(dtype='string', id=None),\n",
|
106 |
+
" 'sentence2': Value(dtype='string', id=None),\n",
|
107 |
+
" 'label': ClassLabel(names=['not_equivalent', 'equivalent'], id=None),\n",
|
108 |
+
" 'idx': Value(dtype='int32', id=None)}"
|
109 |
+
]
|
110 |
+
},
|
111 |
+
"execution_count": 10,
|
112 |
+
"metadata": {},
|
113 |
+
"output_type": "execute_result"
|
114 |
+
}
|
115 |
+
],
|
116 |
+
"source": [
|
117 |
+
"raw_dataset[\"train\"].features"
|
118 |
+
]
|
119 |
+
},
|
120 |
+
{
|
121 |
+
"cell_type": "code",
|
122 |
+
"execution_count": 15,
|
123 |
+
"metadata": {},
|
124 |
+
"outputs": [
|
125 |
+
{
|
126 |
+
"name": "stdout",
|
127 |
+
"output_type": "stream",
|
128 |
+
"text": [
|
129 |
+
"{'idx': 16,\n",
|
130 |
+
" 'label': 0,\n",
|
131 |
+
" 'sentence1': 'Rudder was most recently senior vice president for the '\n",
|
132 |
+
" 'Developer & Platform Evangelism Business .',\n",
|
133 |
+
" 'sentence2': 'Senior Vice President Eric Rudder , formerly head of the '\n",
|
134 |
+
" 'Developer and Platform Evangelism unit , will lead the new '\n",
|
135 |
+
" 'entity .'}\n",
|
136 |
+
"{'idx': 812,\n",
|
137 |
+
" 'label': 0,\n",
|
138 |
+
" 'sentence1': 'However , EPA officials would not confirm the 20 percent figure '\n",
|
139 |
+
" '.',\n",
|
140 |
+
" 'sentence2': 'Only in the past few weeks have officials settled on the 20 '\n",
|
141 |
+
" 'percent figure .'}\n"
|
142 |
+
]
|
143 |
+
}
|
144 |
+
],
|
145 |
+
"source": [
|
146 |
+
"# Look at the 15th and 87th element of the train and validation datasets respectively\n",
|
147 |
+
"pprint(raw_dataset[\"train\"][15])\n",
|
148 |
+
"pprint(raw_dataset[\"validation\"][87])"
|
149 |
+
]
|
150 |
+
},
|
151 |
+
{
|
152 |
+
"cell_type": "markdown",
|
153 |
+
"metadata": {},
|
154 |
+
"source": [
|
155 |
+
"# Tokenizer for pair processing"
|
156 |
+
]
|
157 |
+
},
|
158 |
+
{
|
159 |
+
"cell_type": "code",
|
160 |
+
"execution_count": 16,
|
161 |
+
"metadata": {},
|
162 |
+
"outputs": [
|
163 |
+
{
|
164 |
+
"name": "stderr",
|
165 |
+
"output_type": "stream",
|
166 |
+
"text": [
|
167 |
+
"/home/huggingface/lib/python3.10/site-packages/huggingface_hub/file_download.py:1150: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
|
168 |
+
" warnings.warn(\n"
|
169 |
+
]
|
170 |
+
}
|
171 |
+
],
|
172 |
+
"source": [
|
173 |
+
"from transformers import AutoTokenizer\n",
|
174 |
+
"\n",
|
175 |
+
"checkpoint = \"bert-base-uncased\"\n",
|
176 |
+
"tokenizer = AutoTokenizer.from_pretrained(checkpoint)"
|
177 |
+
]
|
178 |
+
},
|
179 |
+
{
|
180 |
+
"cell_type": "code",
|
181 |
+
"execution_count": 18,
|
182 |
+
"metadata": {},
|
183 |
+
"outputs": [
|
184 |
+
{
|
185 |
+
"name": "stdout",
|
186 |
+
"output_type": "stream",
|
187 |
+
"text": [
|
188 |
+
"{'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],\n",
|
189 |
+
" 'input_ids': [101,\n",
|
190 |
+
" 2023,\n",
|
191 |
+
" 2003,\n",
|
192 |
+
" 1996,\n",
|
193 |
+
" 2034,\n",
|
194 |
+
" 6251,\n",
|
195 |
+
" 1012,\n",
|
196 |
+
" 102,\n",
|
197 |
+
" 2023,\n",
|
198 |
+
" 2003,\n",
|
199 |
+
" 1996,\n",
|
200 |
+
" 2117,\n",
|
201 |
+
" 2028,\n",
|
202 |
+
" 1012,\n",
|
203 |
+
" 102],\n",
|
204 |
+
" 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1]}\n"
|
205 |
+
]
|
206 |
+
}
|
207 |
+
],
|
208 |
+
"source": [
|
209 |
+
"inputs = tokenizer(\"This is the first sentence.\", \"This is the second one.\")\n",
|
210 |
+
"pprint(inputs)"
|
211 |
+
]
|
212 |
+
},
|
213 |
+
{
|
214 |
+
"cell_type": "code",
|
215 |
+
"execution_count": 19,
|
216 |
+
"metadata": {},
|
217 |
+
"outputs": [
|
218 |
+
{
|
219 |
+
"data": {
|
220 |
+
"text/plain": [
|
221 |
+
"'[CLS] this is the first sentence. [SEP] this is the second one. [SEP]'"
|
222 |
+
]
|
223 |
+
},
|
224 |
+
"execution_count": 19,
|
225 |
+
"metadata": {},
|
226 |
+
"output_type": "execute_result"
|
227 |
+
}
|
228 |
+
],
|
229 |
+
"source": [
|
230 |
+
"tokenizer.decode(inputs['input_ids'])"
|
231 |
+
]
|
232 |
+
},
|
233 |
+
{
|
234 |
+
"cell_type": "markdown",
|
235 |
+
"metadata": {},
|
236 |
+
"source": [
|
237 |
+
"Here we can see that the tokenizer has appended the two sentences together and introduced `[CLS]` and `[SEP]` tokens specially because that's how bert was trained for next sentence prediction task."
|
238 |
+
]
|
239 |
+
},
|
240 |
+
{
|
241 |
+
"cell_type": "code",
|
242 |
+
"execution_count": 25,
|
243 |
+
"metadata": {},
|
244 |
+
"outputs": [
|
245 |
+
{
|
246 |
+
"name": "stdout",
|
247 |
+
"output_type": "stream",
|
248 |
+
"text": [
|
249 |
+
"{'input_ids': [101, 24049, 2001, 2087, 3728, 3026, 3580, 2343, 2005, 1996, 9722, 1004, 4132, 9340, 12439, 2964, 2449, 1012, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}\n"
|
250 |
+
]
|
251 |
+
}
|
252 |
+
],
|
253 |
+
"source": [
|
254 |
+
"print(tokenizer(raw_dataset[\"train\"][15][\"sentence1\"]))"
|
255 |
+
]
|
256 |
+
},
|
257 |
+
{
|
258 |
+
"cell_type": "code",
|
259 |
+
"execution_count": 24,
|
260 |
+
"metadata": {},
|
261 |
+
"outputs": [
|
262 |
+
{
|
263 |
+
"name": "stdout",
|
264 |
+
"output_type": "stream",
|
265 |
+
"text": [
|
266 |
+
"{'input_ids': [101, 3026, 3580, 2343, 4388, 24049, 1010, 3839, 2132, 1997, 1996, 9722, 1998, 4132, 9340, 12439, 2964, 3131, 1010, 2097, 2599, 1996, 2047, 9178, 1012, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}\n"
|
267 |
+
]
|
268 |
+
}
|
269 |
+
],
|
270 |
+
"source": [
|
271 |
+
"print(tokenizer(raw_dataset[\"train\"][15][\"sentence2\"]))"
|
272 |
+
]
|
273 |
+
},
|
274 |
+
{
|
275 |
+
"cell_type": "code",
|
276 |
+
"execution_count": 26,
|
277 |
+
"metadata": {},
|
278 |
+
"outputs": [
|
279 |
+
{
|
280 |
+
"name": "stdout",
|
281 |
+
"output_type": "stream",
|
282 |
+
"text": [
|
283 |
+
"{'input_ids': [101, 24049, 2001, 2087, 3728, 3026, 3580, 2343, 2005, 1996, 9722, 1004, 4132, 9340, 12439, 2964, 2449, 1012, 102, 3026, 3580, 2343, 4388, 24049, 1010, 3839, 2132, 1997, 1996, 9722, 1998, 4132, 9340, 12439, 2964, 3131, 1010, 2097, 2599, 1996, 2047, 9178, 1012, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}\n"
|
284 |
+
]
|
285 |
+
}
|
286 |
+
],
|
287 |
+
"source": [
|
288 |
+
"print(tokenizer(raw_dataset[\"train\"][15][\"sentence1\"], raw_dataset[\"train\"][15][\"sentence2\"]))"
|
289 |
+
]
|
290 |
+
},
|
291 |
+
{
|
292 |
+
"cell_type": "markdown",
|
293 |
+
"metadata": {},
|
294 |
+
"source": [
|
295 |
+
"Here we need to observe the `token_type_ids` field. It is different if we encode the two sentences at the same time vs if we do them independently. Also the `[CLS]` and `[SEP]` tokens are added differently in the two cases."
|
296 |
+
]
|
297 |
+
},
|
298 |
+
{
|
299 |
+
"cell_type": "markdown",
|
300 |
+
"metadata": {},
|
301 |
+
"source": [
|
302 |
+
"# Dataset Map to create new datasets"
|
303 |
+
]
|
304 |
+
},
|
305 |
+
{
|
306 |
+
"cell_type": "code",
|
307 |
+
"execution_count": 28,
|
308 |
+
"metadata": {},
|
309 |
+
"outputs": [],
|
310 |
+
"source": [
|
311 |
+
"def tokenize_function(example):\n",
|
312 |
+
" return tokenizer(example[\"sentence1\"], example[\"sentence2\"], truncation=True)"
|
313 |
+
]
|
314 |
+
},
|
315 |
+
{
|
316 |
+
"cell_type": "code",
|
317 |
+
"execution_count": 29,
|
318 |
+
"metadata": {},
|
319 |
+
"outputs": [
|
320 |
+
{
|
321 |
+
"name": "stderr",
|
322 |
+
"output_type": "stream",
|
323 |
+
"text": [
|
324 |
+
"Map: 100%|██████████| 3668/3668 [00:00<00:00, 9953.91 examples/s] \n",
|
325 |
+
"Map: 100%|██████████| 408/408 [00:00<00:00, 9044.46 examples/s]\n",
|
326 |
+
"Map: 100%|██████████| 1725/1725 [00:00<00:00, 9891.51 examples/s] \n"
|
327 |
+
]
|
328 |
+
},
|
329 |
+
{
|
330 |
+
"data": {
|
331 |
+
"text/plain": [
|
332 |
+
"DatasetDict({\n",
|
333 |
+
" train: Dataset({\n",
|
334 |
+
" features: ['sentence1', 'sentence2', 'label', 'idx', 'input_ids', 'token_type_ids', 'attention_mask'],\n",
|
335 |
+
" num_rows: 3668\n",
|
336 |
+
" })\n",
|
337 |
+
" validation: Dataset({\n",
|
338 |
+
" features: ['sentence1', 'sentence2', 'label', 'idx', 'input_ids', 'token_type_ids', 'attention_mask'],\n",
|
339 |
+
" num_rows: 408\n",
|
340 |
+
" })\n",
|
341 |
+
" test: Dataset({\n",
|
342 |
+
" features: ['sentence1', 'sentence2', 'label', 'idx', 'input_ids', 'token_type_ids', 'attention_mask'],\n",
|
343 |
+
" num_rows: 1725\n",
|
344 |
+
" })\n",
|
345 |
+
"})"
|
346 |
+
]
|
347 |
+
},
|
348 |
+
"execution_count": 29,
|
349 |
+
"metadata": {},
|
350 |
+
"output_type": "execute_result"
|
351 |
+
}
|
352 |
+
],
|
353 |
+
"source": [
|
354 |
+
"tokenized_datasets = raw_dataset.map(tokenize_function, batched=True)\n",
|
355 |
+
"tokenized_datasets"
|
356 |
+
]
|
357 |
+
},
|
358 |
+
{
|
359 |
+
"cell_type": "code",
|
360 |
+
"execution_count": 30,
|
361 |
+
"metadata": {},
|
362 |
+
"outputs": [
|
363 |
+
{
|
364 |
+
"data": {
|
365 |
+
"text/plain": [
|
366 |
+
"DatasetDict({\n",
|
367 |
+
" train: Dataset({\n",
|
368 |
+
" features: ['sentence1', 'sentence2', 'label', 'idx'],\n",
|
369 |
+
" num_rows: 3668\n",
|
370 |
+
" })\n",
|
371 |
+
" validation: Dataset({\n",
|
372 |
+
" features: ['sentence1', 'sentence2', 'label', 'idx'],\n",
|
373 |
+
" num_rows: 408\n",
|
374 |
+
" })\n",
|
375 |
+
" test: Dataset({\n",
|
376 |
+
" features: ['sentence1', 'sentence2', 'label', 'idx'],\n",
|
377 |
+
" num_rows: 1725\n",
|
378 |
+
" })\n",
|
379 |
+
"})"
|
380 |
+
]
|
381 |
+
},
|
382 |
+
"execution_count": 30,
|
383 |
+
"metadata": {},
|
384 |
+
"output_type": "execute_result"
|
385 |
+
}
|
386 |
+
],
|
387 |
+
"source": [
|
388 |
+
"raw_dataset"
|
389 |
+
]
|
390 |
+
},
|
391 |
+
{
|
392 |
+
"cell_type": "markdown",
|
393 |
+
"metadata": {},
|
394 |
+
"source": [
|
395 |
+
"Here we see that as out tokenize functions returns new keys of `'input_ids', 'token_type_ids', 'attention_mask'`, those simply get added to the new tokenized_dataset Dataset and rest remains the same."
|
396 |
+
]
|
397 |
+
},
|
398 |
+
{
|
399 |
+
"cell_type": "code",
|
400 |
+
"execution_count": 38,
|
401 |
+
"metadata": {},
|
402 |
+
"outputs": [],
|
403 |
+
"source": [
|
404 |
+
"from transformers import DataCollatorWithPadding\n",
|
405 |
+
"data_collator = DataCollatorWithPadding(tokenizer=tokenizer)"
|
406 |
+
]
|
407 |
+
},
|
408 |
+
{
|
409 |
+
"cell_type": "markdown",
|
410 |
+
"metadata": {},
|
411 |
+
"source": [
|
412 |
+
"This `DataCollatorWithPadding` is meant to do dynamic collation of batches in the dataset based on the max length from among all the sequences in the batch."
|
413 |
+
]
|
414 |
+
},
|
415 |
+
{
|
416 |
+
"cell_type": "code",
|
417 |
+
"execution_count": 42,
|
418 |
+
"metadata": {},
|
419 |
+
"outputs": [
|
420 |
+
{
|
421 |
+
"data": {
|
422 |
+
"text/plain": [
|
423 |
+
"[50, 59, 47]"
|
424 |
+
]
|
425 |
+
},
|
426 |
+
"execution_count": 42,
|
427 |
+
"metadata": {},
|
428 |
+
"output_type": "execute_result"
|
429 |
+
}
|
430 |
+
],
|
431 |
+
"source": [
|
432 |
+
"samples = tokenized_datasets[\"train\"][:3]\n",
|
433 |
+
"samples = {k: v for k, v in samples.items()}\n",
|
434 |
+
"[len(x) for x in samples[\"input_ids\"]]"
|
435 |
+
]
|
436 |
+
},
|
437 |
+
{
|
438 |
+
"cell_type": "code",
|
439 |
+
"execution_count": 43,
|
440 |
+
"metadata": {},
|
441 |
+
"outputs": [
|
442 |
+
{
|
443 |
+
"data": {
|
444 |
+
"text/plain": [
|
445 |
+
"{'sentence1': ['Amrozi accused his brother , whom he called \" the witness \" , of deliberately distorting his evidence .',\n",
|
446 |
+
" \"Yucaipa owned Dominick 's before selling the chain to Safeway in 1998 for $ 2.5 billion .\",\n",
|
447 |
+
" 'They had published an advertisement on the Internet on June 10 , offering the cargo for sale , he added .'],\n",
|
448 |
+
" 'sentence2': ['Referring to him as only \" the witness \" , Amrozi accused his brother of deliberately distorting his evidence .',\n",
|
449 |
+
" \"Yucaipa bought Dominick 's in 1995 for $ 693 million and sold it to Safeway for $ 1.8 billion in 1998 .\",\n",
|
450 |
+
" \"On June 10 , the ship 's owners had published an advertisement on the Internet , offering the explosives for sale .\"],\n",
|
451 |
+
" 'label': [1, 0, 1],\n",
|
452 |
+
" 'idx': [0, 1, 2],\n",
|
453 |
+
" 'input_ids': [[101,\n",
|
454 |
+
" 2572,\n",
|
455 |
+
" 3217,\n",
|
456 |
+
" 5831,\n",
|
457 |
+
" 5496,\n",
|
458 |
+
" 2010,\n",
|
459 |
+
" 2567,\n",
|
460 |
+
" 1010,\n",
|
461 |
+
" 3183,\n",
|
462 |
+
" 2002,\n",
|
463 |
+
" 2170,\n",
|
464 |
+
" 1000,\n",
|
465 |
+
" 1996,\n",
|
466 |
+
" 7409,\n",
|
467 |
+
" 1000,\n",
|
468 |
+
" 1010,\n",
|
469 |
+
" 1997,\n",
|
470 |
+
" 9969,\n",
|
471 |
+
" 4487,\n",
|
472 |
+
" 23809,\n",
|
473 |
+
" 3436,\n",
|
474 |
+
" 2010,\n",
|
475 |
+
" 3350,\n",
|
476 |
+
" 1012,\n",
|
477 |
+
" 102,\n",
|
478 |
+
" 7727,\n",
|
479 |
+
" 2000,\n",
|
480 |
+
" 2032,\n",
|
481 |
+
" 2004,\n",
|
482 |
+
" 2069,\n",
|
483 |
+
" 1000,\n",
|
484 |
+
" 1996,\n",
|
485 |
+
" 7409,\n",
|
486 |
+
" 1000,\n",
|
487 |
+
" 1010,\n",
|
488 |
+
" 2572,\n",
|
489 |
+
" 3217,\n",
|
490 |
+
" 5831,\n",
|
491 |
+
" 5496,\n",
|
492 |
+
" 2010,\n",
|
493 |
+
" 2567,\n",
|
494 |
+
" 1997,\n",
|
495 |
+
" 9969,\n",
|
496 |
+
" 4487,\n",
|
497 |
+
" 23809,\n",
|
498 |
+
" 3436,\n",
|
499 |
+
" 2010,\n",
|
500 |
+
" 3350,\n",
|
501 |
+
" 1012,\n",
|
502 |
+
" 102],\n",
|
503 |
+
" [101,\n",
|
504 |
+
" 9805,\n",
|
505 |
+
" 3540,\n",
|
506 |
+
" 11514,\n",
|
507 |
+
" 2050,\n",
|
508 |
+
" 3079,\n",
|
509 |
+
" 11282,\n",
|
510 |
+
" 2243,\n",
|
511 |
+
" 1005,\n",
|
512 |
+
" 1055,\n",
|
513 |
+
" 2077,\n",
|
514 |
+
" 4855,\n",
|
515 |
+
" 1996,\n",
|
516 |
+
" 4677,\n",
|
517 |
+
" 2000,\n",
|
518 |
+
" 3647,\n",
|
519 |
+
" 4576,\n",
|
520 |
+
" 1999,\n",
|
521 |
+
" 2687,\n",
|
522 |
+
" 2005,\n",
|
523 |
+
" 1002,\n",
|
524 |
+
" 1016,\n",
|
525 |
+
" 1012,\n",
|
526 |
+
" 1019,\n",
|
527 |
+
" 4551,\n",
|
528 |
+
" 1012,\n",
|
529 |
+
" 102,\n",
|
530 |
+
" 9805,\n",
|
531 |
+
" 3540,\n",
|
532 |
+
" 11514,\n",
|
533 |
+
" 2050,\n",
|
534 |
+
" 4149,\n",
|
535 |
+
" 11282,\n",
|
536 |
+
" 2243,\n",
|
537 |
+
" 1005,\n",
|
538 |
+
" 1055,\n",
|
539 |
+
" 1999,\n",
|
540 |
+
" 2786,\n",
|
541 |
+
" 2005,\n",
|
542 |
+
" 1002,\n",
|
543 |
+
" 6353,\n",
|
544 |
+
" 2509,\n",
|
545 |
+
" 2454,\n",
|
546 |
+
" 1998,\n",
|
547 |
+
" 2853,\n",
|
548 |
+
" 2009,\n",
|
549 |
+
" 2000,\n",
|
550 |
+
" 3647,\n",
|
551 |
+
" 4576,\n",
|
552 |
+
" 2005,\n",
|
553 |
+
" 1002,\n",
|
554 |
+
" 1015,\n",
|
555 |
+
" 1012,\n",
|
556 |
+
" 1022,\n",
|
557 |
+
" 4551,\n",
|
558 |
+
" 1999,\n",
|
559 |
+
" 2687,\n",
|
560 |
+
" 1012,\n",
|
561 |
+
" 102],\n",
|
562 |
+
" [101,\n",
|
563 |
+
" 2027,\n",
|
564 |
+
" 2018,\n",
|
565 |
+
" 2405,\n",
|
566 |
+
" 2019,\n",
|
567 |
+
" 15147,\n",
|
568 |
+
" 2006,\n",
|
569 |
+
" 1996,\n",
|
570 |
+
" 4274,\n",
|
571 |
+
" 2006,\n",
|
572 |
+
" 2238,\n",
|
573 |
+
" 2184,\n",
|
574 |
+
" 1010,\n",
|
575 |
+
" 5378,\n",
|
576 |
+
" 1996,\n",
|
577 |
+
" 6636,\n",
|
578 |
+
" 2005,\n",
|
579 |
+
" 5096,\n",
|
580 |
+
" 1010,\n",
|
581 |
+
" 2002,\n",
|
582 |
+
" 2794,\n",
|
583 |
+
" 1012,\n",
|
584 |
+
" 102,\n",
|
585 |
+
" 2006,\n",
|
586 |
+
" 2238,\n",
|
587 |
+
" 2184,\n",
|
588 |
+
" 1010,\n",
|
589 |
+
" 1996,\n",
|
590 |
+
" 2911,\n",
|
591 |
+
" 1005,\n",
|
592 |
+
" 1055,\n",
|
593 |
+
" 5608,\n",
|
594 |
+
" 2018,\n",
|
595 |
+
" 2405,\n",
|
596 |
+
" 2019,\n",
|
597 |
+
" 15147,\n",
|
598 |
+
" 2006,\n",
|
599 |
+
" 1996,\n",
|
600 |
+
" 4274,\n",
|
601 |
+
" 1010,\n",
|
602 |
+
" 5378,\n",
|
603 |
+
" 1996,\n",
|
604 |
+
" 14792,\n",
|
605 |
+
" 2005,\n",
|
606 |
+
" 5096,\n",
|
607 |
+
" 1012,\n",
|
608 |
+
" 102]],\n",
|
609 |
+
" 'token_type_ids': [[0,\n",
|
610 |
+
" 0,\n",
|
611 |
+
" 0,\n",
|
612 |
+
" 0,\n",
|
613 |
+
" 0,\n",
|
614 |
+
" 0,\n",
|
615 |
+
" 0,\n",
|
616 |
+
" 0,\n",
|
617 |
+
" 0,\n",
|
618 |
+
" 0,\n",
|
619 |
+
" 0,\n",
|
620 |
+
" 0,\n",
|
621 |
+
" 0,\n",
|
622 |
+
" 0,\n",
|
623 |
+
" 0,\n",
|
624 |
+
" 0,\n",
|
625 |
+
" 0,\n",
|
626 |
+
" 0,\n",
|
627 |
+
" 0,\n",
|
628 |
+
" 0,\n",
|
629 |
+
" 0,\n",
|
630 |
+
" 0,\n",
|
631 |
+
" 0,\n",
|
632 |
+
" 0,\n",
|
633 |
+
" 0,\n",
|
634 |
+
" 1,\n",
|
635 |
+
" 1,\n",
|
636 |
+
" 1,\n",
|
637 |
+
" 1,\n",
|
638 |
+
" 1,\n",
|
639 |
+
" 1,\n",
|
640 |
+
" 1,\n",
|
641 |
+
" 1,\n",
|
642 |
+
" 1,\n",
|
643 |
+
" 1,\n",
|
644 |
+
" 1,\n",
|
645 |
+
" 1,\n",
|
646 |
+
" 1,\n",
|
647 |
+
" 1,\n",
|
648 |
+
" 1,\n",
|
649 |
+
" 1,\n",
|
650 |
+
" 1,\n",
|
651 |
+
" 1,\n",
|
652 |
+
" 1,\n",
|
653 |
+
" 1,\n",
|
654 |
+
" 1,\n",
|
655 |
+
" 1,\n",
|
656 |
+
" 1,\n",
|
657 |
+
" 1,\n",
|
658 |
+
" 1],\n",
|
659 |
+
" [0,\n",
|
660 |
+
" 0,\n",
|
661 |
+
" 0,\n",
|
662 |
+
" 0,\n",
|
663 |
+
" 0,\n",
|
664 |
+
" 0,\n",
|
665 |
+
" 0,\n",
|
666 |
+
" 0,\n",
|
667 |
+
" 0,\n",
|
668 |
+
" 0,\n",
|
669 |
+
" 0,\n",
|
670 |
+
" 0,\n",
|
671 |
+
" 0,\n",
|
672 |
+
" 0,\n",
|
673 |
+
" 0,\n",
|
674 |
+
" 0,\n",
|
675 |
+
" 0,\n",
|
676 |
+
" 0,\n",
|
677 |
+
" 0,\n",
|
678 |
+
" 0,\n",
|
679 |
+
" 0,\n",
|
680 |
+
" 0,\n",
|
681 |
+
" 0,\n",
|
682 |
+
" 0,\n",
|
683 |
+
" 0,\n",
|
684 |
+
" 0,\n",
|
685 |
+
" 0,\n",
|
686 |
+
" 1,\n",
|
687 |
+
" 1,\n",
|
688 |
+
" 1,\n",
|
689 |
+
" 1,\n",
|
690 |
+
" 1,\n",
|
691 |
+
" 1,\n",
|
692 |
+
" 1,\n",
|
693 |
+
" 1,\n",
|
694 |
+
" 1,\n",
|
695 |
+
" 1,\n",
|
696 |
+
" 1,\n",
|
697 |
+
" 1,\n",
|
698 |
+
" 1,\n",
|
699 |
+
" 1,\n",
|
700 |
+
" 1,\n",
|
701 |
+
" 1,\n",
|
702 |
+
" 1,\n",
|
703 |
+
" 1,\n",
|
704 |
+
" 1,\n",
|
705 |
+
" 1,\n",
|
706 |
+
" 1,\n",
|
707 |
+
" 1,\n",
|
708 |
+
" 1,\n",
|
709 |
+
" 1,\n",
|
710 |
+
" 1,\n",
|
711 |
+
" 1,\n",
|
712 |
+
" 1,\n",
|
713 |
+
" 1,\n",
|
714 |
+
" 1,\n",
|
715 |
+
" 1,\n",
|
716 |
+
" 1,\n",
|
717 |
+
" 1],\n",
|
718 |
+
" [0,\n",
|
719 |
+
" 0,\n",
|
720 |
+
" 0,\n",
|
721 |
+
" 0,\n",
|
722 |
+
" 0,\n",
|
723 |
+
" 0,\n",
|
724 |
+
" 0,\n",
|
725 |
+
" 0,\n",
|
726 |
+
" 0,\n",
|
727 |
+
" 0,\n",
|
728 |
+
" 0,\n",
|
729 |
+
" 0,\n",
|
730 |
+
" 0,\n",
|
731 |
+
" 0,\n",
|
732 |
+
" 0,\n",
|
733 |
+
" 0,\n",
|
734 |
+
" 0,\n",
|
735 |
+
" 0,\n",
|
736 |
+
" 0,\n",
|
737 |
+
" 0,\n",
|
738 |
+
" 0,\n",
|
739 |
+
" 0,\n",
|
740 |
+
" 0,\n",
|
741 |
+
" 1,\n",
|
742 |
+
" 1,\n",
|
743 |
+
" 1,\n",
|
744 |
+
" 1,\n",
|
745 |
+
" 1,\n",
|
746 |
+
" 1,\n",
|
747 |
+
" 1,\n",
|
748 |
+
" 1,\n",
|
749 |
+
" 1,\n",
|
750 |
+
" 1,\n",
|
751 |
+
" 1,\n",
|
752 |
+
" 1,\n",
|
753 |
+
" 1,\n",
|
754 |
+
" 1,\n",
|
755 |
+
" 1,\n",
|
756 |
+
" 1,\n",
|
757 |
+
" 1,\n",
|
758 |
+
" 1,\n",
|
759 |
+
" 1,\n",
|
760 |
+
" 1,\n",
|
761 |
+
" 1,\n",
|
762 |
+
" 1,\n",
|
763 |
+
" 1,\n",
|
764 |
+
" 1]],\n",
|
765 |
+
" 'attention_mask': [[1,\n",
|
766 |
+
" 1,\n",
|
767 |
+
" 1,\n",
|
768 |
+
" 1,\n",
|
769 |
+
" 1,\n",
|
770 |
+
" 1,\n",
|
771 |
+
" 1,\n",
|
772 |
+
" 1,\n",
|
773 |
+
" 1,\n",
|
774 |
+
" 1,\n",
|
775 |
+
" 1,\n",
|
776 |
+
" 1,\n",
|
777 |
+
" 1,\n",
|
778 |
+
" 1,\n",
|
779 |
+
" 1,\n",
|
780 |
+
" 1,\n",
|
781 |
+
" 1,\n",
|
782 |
+
" 1,\n",
|
783 |
+
" 1,\n",
|
784 |
+
" 1,\n",
|
785 |
+
" 1,\n",
|
786 |
+
" 1,\n",
|
787 |
+
" 1,\n",
|
788 |
+
" 1,\n",
|
789 |
+
" 1,\n",
|
790 |
+
" 1,\n",
|
791 |
+
" 1,\n",
|
792 |
+
" 1,\n",
|
793 |
+
" 1,\n",
|
794 |
+
" 1,\n",
|
795 |
+
" 1,\n",
|
796 |
+
" 1,\n",
|
797 |
+
" 1,\n",
|
798 |
+
" 1,\n",
|
799 |
+
" 1,\n",
|
800 |
+
" 1,\n",
|
801 |
+
" 1,\n",
|
802 |
+
" 1,\n",
|
803 |
+
" 1,\n",
|
804 |
+
" 1,\n",
|
805 |
+
" 1,\n",
|
806 |
+
" 1,\n",
|
807 |
+
" 1,\n",
|
808 |
+
" 1,\n",
|
809 |
+
" 1,\n",
|
810 |
+
" 1,\n",
|
811 |
+
" 1,\n",
|
812 |
+
" 1,\n",
|
813 |
+
" 1,\n",
|
814 |
+
" 1],\n",
|
815 |
+
" [1,\n",
|
816 |
+
" 1,\n",
|
817 |
+
" 1,\n",
|
818 |
+
" 1,\n",
|
819 |
+
" 1,\n",
|
820 |
+
" 1,\n",
|
821 |
+
" 1,\n",
|
822 |
+
" 1,\n",
|
823 |
+
" 1,\n",
|
824 |
+
" 1,\n",
|
825 |
+
" 1,\n",
|
826 |
+
" 1,\n",
|
827 |
+
" 1,\n",
|
828 |
+
" 1,\n",
|
829 |
+
" 1,\n",
|
830 |
+
" 1,\n",
|
831 |
+
" 1,\n",
|
832 |
+
" 1,\n",
|
833 |
+
" 1,\n",
|
834 |
+
" 1,\n",
|
835 |
+
" 1,\n",
|
836 |
+
" 1,\n",
|
837 |
+
" 1,\n",
|
838 |
+
" 1,\n",
|
839 |
+
" 1,\n",
|
840 |
+
" 1,\n",
|
841 |
+
" 1,\n",
|
842 |
+
" 1,\n",
|
843 |
+
" 1,\n",
|
844 |
+
" 1,\n",
|
845 |
+
" 1,\n",
|
846 |
+
" 1,\n",
|
847 |
+
" 1,\n",
|
848 |
+
" 1,\n",
|
849 |
+
" 1,\n",
|
850 |
+
" 1,\n",
|
851 |
+
" 1,\n",
|
852 |
+
" 1,\n",
|
853 |
+
" 1,\n",
|
854 |
+
" 1,\n",
|
855 |
+
" 1,\n",
|
856 |
+
" 1,\n",
|
857 |
+
" 1,\n",
|
858 |
+
" 1,\n",
|
859 |
+
" 1,\n",
|
860 |
+
" 1,\n",
|
861 |
+
" 1,\n",
|
862 |
+
" 1,\n",
|
863 |
+
" 1,\n",
|
864 |
+
" 1,\n",
|
865 |
+
" 1,\n",
|
866 |
+
" 1,\n",
|
867 |
+
" 1,\n",
|
868 |
+
" 1,\n",
|
869 |
+
" 1,\n",
|
870 |
+
" 1,\n",
|
871 |
+
" 1,\n",
|
872 |
+
" 1,\n",
|
873 |
+
" 1],\n",
|
874 |
+
" [1,\n",
|
875 |
+
" 1,\n",
|
876 |
+
" 1,\n",
|
877 |
+
" 1,\n",
|
878 |
+
" 1,\n",
|
879 |
+
" 1,\n",
|
880 |
+
" 1,\n",
|
881 |
+
" 1,\n",
|
882 |
+
" 1,\n",
|
883 |
+
" 1,\n",
|
884 |
+
" 1,\n",
|
885 |
+
" 1,\n",
|
886 |
+
" 1,\n",
|
887 |
+
" 1,\n",
|
888 |
+
" 1,\n",
|
889 |
+
" 1,\n",
|
890 |
+
" 1,\n",
|
891 |
+
" 1,\n",
|
892 |
+
" 1,\n",
|
893 |
+
" 1,\n",
|
894 |
+
" 1,\n",
|
895 |
+
" 1,\n",
|
896 |
+
" 1,\n",
|
897 |
+
" 1,\n",
|
898 |
+
" 1,\n",
|
899 |
+
" 1,\n",
|
900 |
+
" 1,\n",
|
901 |
+
" 1,\n",
|
902 |
+
" 1,\n",
|
903 |
+
" 1,\n",
|
904 |
+
" 1,\n",
|
905 |
+
" 1,\n",
|
906 |
+
" 1,\n",
|
907 |
+
" 1,\n",
|
908 |
+
" 1,\n",
|
909 |
+
" 1,\n",
|
910 |
+
" 1,\n",
|
911 |
+
" 1,\n",
|
912 |
+
" 1,\n",
|
913 |
+
" 1,\n",
|
914 |
+
" 1,\n",
|
915 |
+
" 1,\n",
|
916 |
+
" 1,\n",
|
917 |
+
" 1,\n",
|
918 |
+
" 1,\n",
|
919 |
+
" 1,\n",
|
920 |
+
" 1]]}"
|
921 |
+
]
|
922 |
+
},
|
923 |
+
"execution_count": 43,
|
924 |
+
"metadata": {},
|
925 |
+
"output_type": "execute_result"
|
926 |
+
}
|
927 |
+
],
|
928 |
+
"source": [
|
929 |
+
"samples"
|
930 |
+
]
|
931 |
+
},
|
932 |
+
{
|
933 |
+
"cell_type": "code",
|
934 |
+
"execution_count": 40,
|
935 |
+
"metadata": {},
|
936 |
+
"outputs": [],
|
937 |
+
"source": [
|
938 |
+
"samples_to_collate = tokenized_datasets[\"train\"][:3]\n",
|
939 |
+
"samples_to_collate.pop(\"sentence1\"); samples_to_collate.pop(\"sentence2\"); samples_to_collate.pop(\"idx\");"
|
940 |
+
]
|
941 |
+
},
|
942 |
+
{
|
943 |
+
"cell_type": "code",
|
944 |
+
"execution_count": 41,
|
945 |
+
"metadata": {},
|
946 |
+
"outputs": [
|
947 |
+
{
|
948 |
+
"name": "stderr",
|
949 |
+
"output_type": "stream",
|
950 |
+
"text": [
|
951 |
+
"You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n"
|
952 |
+
]
|
953 |
+
},
|
954 |
+
{
|
955 |
+
"data": {
|
956 |
+
"text/plain": [
|
957 |
+
"{'input_ids': torch.Size([3, 59]),\n",
|
958 |
+
" 'token_type_ids': torch.Size([3, 59]),\n",
|
959 |
+
" 'attention_mask': torch.Size([3, 59]),\n",
|
960 |
+
" 'labels': torch.Size([3])}"
|
961 |
+
]
|
962 |
+
},
|
963 |
+
"execution_count": 41,
|
964 |
+
"metadata": {},
|
965 |
+
"output_type": "execute_result"
|
966 |
+
}
|
967 |
+
],
|
968 |
+
"source": [
|
969 |
+
"batch = data_collator(samples_to_collate)\n",
|
970 |
+
"{k: v.shape for k, v in batch.items()}"
|
971 |
+
]
|
972 |
+
},
|
973 |
+
{
|
974 |
+
"cell_type": "markdown",
|
975 |
+
"metadata": {},
|
976 |
+
"source": [
|
977 |
+
"# Replication of the above preprocessing on GLUE-SST2 dataset"
|
978 |
+
]
|
979 |
+
},
|
980 |
+
{
|
981 |
+
"cell_type": "code",
|
982 |
+
"execution_count": 45,
|
983 |
+
"metadata": {},
|
984 |
+
"outputs": [
|
985 |
+
{
|
986 |
+
"name": "stderr",
|
987 |
+
"output_type": "stream",
|
988 |
+
"text": [
|
989 |
+
"Downloading data: 100%|██████████| 3.11M/3.11M [00:00<00:00, 4.89MB/s]\n",
|
990 |
+
"Downloading data: 100%|██████████| 72.8k/72.8k [00:00<00:00, 128kB/s]\n",
|
991 |
+
"Downloading data: 100%|██████████| 148k/148k [00:00<00:00, 260kB/s]\n",
|
992 |
+
"Generating train split: 100%|██████████| 67349/67349 [00:00<00:00, 467302.76 examples/s]\n",
|
993 |
+
"Generating validation split: 100%|████���█████| 872/872 [00:00<00:00, 137580.24 examples/s]\n",
|
994 |
+
"Generating test split: 100%|██████████| 1821/1821 [00:00<00:00, 205588.75 examples/s]\n"
|
995 |
+
]
|
996 |
+
}
|
997 |
+
],
|
998 |
+
"source": [
|
999 |
+
"from datasets import load_dataset\n",
|
1000 |
+
"\n",
|
1001 |
+
"raw_dataset = load_dataset(\"glue\", \"sst2\")"
|
1002 |
+
]
|
1003 |
+
},
|
1004 |
+
{
|
1005 |
+
"cell_type": "code",
|
1006 |
+
"execution_count": 50,
|
1007 |
+
"metadata": {},
|
1008 |
+
"outputs": [
|
1009 |
+
{
|
1010 |
+
"data": {
|
1011 |
+
"text/plain": [
|
1012 |
+
"{'sentence': 'hide new secretions from the parental units ',\n",
|
1013 |
+
" 'label': 0,\n",
|
1014 |
+
" 'idx': 0}"
|
1015 |
+
]
|
1016 |
+
},
|
1017 |
+
"execution_count": 50,
|
1018 |
+
"metadata": {},
|
1019 |
+
"output_type": "execute_result"
|
1020 |
+
}
|
1021 |
+
],
|
1022 |
+
"source": [
|
1023 |
+
"raw_dataset[\"train\"][0]"
|
1024 |
+
]
|
1025 |
+
},
|
1026 |
+
{
|
1027 |
+
"cell_type": "code",
|
1028 |
+
"execution_count": 46,
|
1029 |
+
"metadata": {},
|
1030 |
+
"outputs": [
|
1031 |
+
{
|
1032 |
+
"name": "stderr",
|
1033 |
+
"output_type": "stream",
|
1034 |
+
"text": [
|
1035 |
+
"/home/huggingface/lib/python3.10/site-packages/huggingface_hub/file_download.py:1150: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
|
1036 |
+
" warnings.warn(\n",
|
1037 |
+
"/home/huggingface/lib/python3.10/site-packages/huggingface_hub/file_download.py:1150: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
|
1038 |
+
" warnings.warn(\n",
|
1039 |
+
"Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight']\n",
|
1040 |
+
"- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
|
1041 |
+
"- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
|
1042 |
+
"Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
|
1043 |
+
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
|
1044 |
+
]
|
1045 |
+
}
|
1046 |
+
],
|
1047 |
+
"source": [
|
1048 |
+
"from transformers import AutoTokenizer\n",
|
1049 |
+
"from transformers import AutoModelForSequenceClassification\n",
|
1050 |
+
"\n",
|
1051 |
+
"checkpoint = \"bert-base-uncased\"\n",
|
1052 |
+
"tokenizer = AutoTokenizer.from_pretrained(checkpoint)\n",
|
1053 |
+
"model = AutoModelForSequenceClassification.from_pretrained(checkpoint)"
|
1054 |
+
]
|
1055 |
+
},
|
1056 |
+
{
|
1057 |
+
"cell_type": "code",
|
1058 |
+
"execution_count": 53,
|
1059 |
+
"metadata": {},
|
1060 |
+
"outputs": [],
|
1061 |
+
"source": [
|
1062 |
+
"def tokenize_function(sequence):\n",
|
1063 |
+
" return tokenizer(sequence[\"sentence\"], padding = True, truncation = True, return_tensors=\"pt\")"
|
1064 |
+
]
|
1065 |
+
},
|
1066 |
+
{
|
1067 |
+
"cell_type": "code",
|
1068 |
+
"execution_count": 54,
|
1069 |
+
"metadata": {},
|
1070 |
+
"outputs": [
|
1071 |
+
{
|
1072 |
+
"name": "stderr",
|
1073 |
+
"output_type": "stream",
|
1074 |
+
"text": [
|
1075 |
+
"Map: 0%| | 0/67349 [00:00<?, ? examples/s]"
|
1076 |
+
]
|
1077 |
+
},
|
1078 |
+
{
|
1079 |
+
"name": "stderr",
|
1080 |
+
"output_type": "stream",
|
1081 |
+
"text": [
|
1082 |
+
"Map: 100%|██████████| 67349/67349 [00:06<00:00, 11164.26 examples/s]\n",
|
1083 |
+
"Map: 100%|██████████| 872/872 [00:00<00:00, 10952.43 examples/s]\n",
|
1084 |
+
"Map: 100%|██████████| 1821/1821 [00:00<00:00, 11315.74 examples/s]\n"
|
1085 |
+
]
|
1086 |
+
}
|
1087 |
+
],
|
1088 |
+
"source": [
|
1089 |
+
"tokenized_dataset = raw_dataset.map(tokenize_function, batched = True)"
|
1090 |
+
]
|
1091 |
+
},
|
1092 |
+
{
|
1093 |
+
"cell_type": "code",
|
1094 |
+
"execution_count": 55,
|
1095 |
+
"metadata": {},
|
1096 |
+
"outputs": [
|
1097 |
+
{
|
1098 |
+
"data": {
|
1099 |
+
"text/plain": [
|
1100 |
+
"DatasetDict({\n",
|
1101 |
+
" train: Dataset({\n",
|
1102 |
+
" features: ['sentence', 'label', 'idx', 'input_ids', 'token_type_ids', 'attention_mask'],\n",
|
1103 |
+
" num_rows: 67349\n",
|
1104 |
+
" })\n",
|
1105 |
+
" validation: Dataset({\n",
|
1106 |
+
" features: ['sentence', 'label', 'idx', 'input_ids', 'token_type_ids', 'attention_mask'],\n",
|
1107 |
+
" num_rows: 872\n",
|
1108 |
+
" })\n",
|
1109 |
+
" test: Dataset({\n",
|
1110 |
+
" features: ['sentence', 'label', 'idx', 'input_ids', 'token_type_ids', 'attention_mask'],\n",
|
1111 |
+
" num_rows: 1821\n",
|
1112 |
+
" })\n",
|
1113 |
+
"})"
|
1114 |
+
]
|
1115 |
+
},
|
1116 |
+
"execution_count": 55,
|
1117 |
+
"metadata": {},
|
1118 |
+
"output_type": "execute_result"
|
1119 |
+
}
|
1120 |
+
],
|
1121 |
+
"source": [
|
1122 |
+
"tokenized_dataset"
|
1123 |
+
]
|
1124 |
+
},
|
1125 |
+
{
|
1126 |
+
"cell_type": "code",
|
1127 |
+
"execution_count": 56,
|
1128 |
+
"metadata": {},
|
1129 |
+
"outputs": [],
|
1130 |
+
"source": [
|
1131 |
+
"from transformers import DataCollatorWithPadding\n",
|
1132 |
+
"dc = DataCollatorWithPadding(tokenizer = tokenizer, padding = True)"
|
1133 |
+
]
|
1134 |
+
},
|
1135 |
+
{
|
1136 |
+
"cell_type": "code",
|
1137 |
+
"execution_count": 59,
|
1138 |
+
"metadata": {},
|
1139 |
+
"outputs": [],
|
1140 |
+
"source": [
|
1141 |
+
"samples = tokenized_dataset[\"train\"][:3]\n",
|
1142 |
+
"samples = {k: v for k,v in samples.items() if k not in [\"sentence\", \"ids\"]}"
|
1143 |
+
]
|
1144 |
+
},
|
1145 |
+
{
|
1146 |
+
"cell_type": "code",
|
1147 |
+
"execution_count": 61,
|
1148 |
+
"metadata": {},
|
1149 |
+
"outputs": [
|
1150 |
+
{
|
1151 |
+
"data": {
|
1152 |
+
"text/plain": [
|
1153 |
+
"dict_keys(['label', 'idx', 'input_ids', 'token_type_ids', 'attention_mask'])"
|
1154 |
+
]
|
1155 |
+
},
|
1156 |
+
"execution_count": 61,
|
1157 |
+
"metadata": {},
|
1158 |
+
"output_type": "execute_result"
|
1159 |
+
}
|
1160 |
+
],
|
1161 |
+
"source": [
|
1162 |
+
"samples.keys()"
|
1163 |
+
]
|
1164 |
+
},
|
1165 |
+
{
|
1166 |
+
"cell_type": "code",
|
1167 |
+
"execution_count": 62,
|
1168 |
+
"metadata": {},
|
1169 |
+
"outputs": [
|
1170 |
+
{
|
1171 |
+
"data": {
|
1172 |
+
"text/plain": [
|
1173 |
+
"{'label': [0, 0, 1],\n",
|
1174 |
+
" 'idx': [0, 1, 2],\n",
|
1175 |
+
" 'input_ids': [[101,\n",
|
1176 |
+
" 5342,\n",
|
1177 |
+
" 2047,\n",
|
1178 |
+
" 3595,\n",
|
1179 |
+
" 8496,\n",
|
1180 |
+
" 2013,\n",
|
1181 |
+
" 1996,\n",
|
1182 |
+
" 18643,\n",
|
1183 |
+
" 3197,\n",
|
1184 |
+
" 102,\n",
|
1185 |
+
" 0,\n",
|
1186 |
+
" 0,\n",
|
1187 |
+
" 0,\n",
|
1188 |
+
" 0,\n",
|
1189 |
+
" 0,\n",
|
1190 |
+
" 0,\n",
|
1191 |
+
" 0,\n",
|
1192 |
+
" 0,\n",
|
1193 |
+
" 0,\n",
|
1194 |
+
" 0,\n",
|
1195 |
+
" 0,\n",
|
1196 |
+
" 0,\n",
|
1197 |
+
" 0,\n",
|
1198 |
+
" 0,\n",
|
1199 |
+
" 0,\n",
|
1200 |
+
" 0,\n",
|
1201 |
+
" 0,\n",
|
1202 |
+
" 0,\n",
|
1203 |
+
" 0,\n",
|
1204 |
+
" 0,\n",
|
1205 |
+
" 0,\n",
|
1206 |
+
" 0,\n",
|
1207 |
+
" 0,\n",
|
1208 |
+
" 0,\n",
|
1209 |
+
" 0,\n",
|
1210 |
+
" 0,\n",
|
1211 |
+
" 0,\n",
|
1212 |
+
" 0,\n",
|
1213 |
+
" 0,\n",
|
1214 |
+
" 0,\n",
|
1215 |
+
" 0,\n",
|
1216 |
+
" 0,\n",
|
1217 |
+
" 0,\n",
|
1218 |
+
" 0,\n",
|
1219 |
+
" 0,\n",
|
1220 |
+
" 0,\n",
|
1221 |
+
" 0,\n",
|
1222 |
+
" 0,\n",
|
1223 |
+
" 0,\n",
|
1224 |
+
" 0,\n",
|
1225 |
+
" 0,\n",
|
1226 |
+
" 0,\n",
|
1227 |
+
" 0,\n",
|
1228 |
+
" 0,\n",
|
1229 |
+
" 0,\n",
|
1230 |
+
" 0],\n",
|
1231 |
+
" [101,\n",
|
1232 |
+
" 3397,\n",
|
1233 |
+
" 2053,\n",
|
1234 |
+
" 15966,\n",
|
1235 |
+
" 1010,\n",
|
1236 |
+
" 2069,\n",
|
1237 |
+
" 4450,\n",
|
1238 |
+
" 2098,\n",
|
1239 |
+
" 18201,\n",
|
1240 |
+
" 2015,\n",
|
1241 |
+
" 102,\n",
|
1242 |
+
" 0,\n",
|
1243 |
+
" 0,\n",
|
1244 |
+
" 0,\n",
|
1245 |
+
" 0,\n",
|
1246 |
+
" 0,\n",
|
1247 |
+
" 0,\n",
|
1248 |
+
" 0,\n",
|
1249 |
+
" 0,\n",
|
1250 |
+
" 0,\n",
|
1251 |
+
" 0,\n",
|
1252 |
+
" 0,\n",
|
1253 |
+
" 0,\n",
|
1254 |
+
" 0,\n",
|
1255 |
+
" 0,\n",
|
1256 |
+
" 0,\n",
|
1257 |
+
" 0,\n",
|
1258 |
+
" 0,\n",
|
1259 |
+
" 0,\n",
|
1260 |
+
" 0,\n",
|
1261 |
+
" 0,\n",
|
1262 |
+
" 0,\n",
|
1263 |
+
" 0,\n",
|
1264 |
+
" 0,\n",
|
1265 |
+
" 0,\n",
|
1266 |
+
" 0,\n",
|
1267 |
+
" 0,\n",
|
1268 |
+
" 0,\n",
|
1269 |
+
" 0,\n",
|
1270 |
+
" 0,\n",
|
1271 |
+
" 0,\n",
|
1272 |
+
" 0,\n",
|
1273 |
+
" 0,\n",
|
1274 |
+
" 0,\n",
|
1275 |
+
" 0,\n",
|
1276 |
+
" 0,\n",
|
1277 |
+
" 0,\n",
|
1278 |
+
" 0,\n",
|
1279 |
+
" 0,\n",
|
1280 |
+
" 0,\n",
|
1281 |
+
" 0,\n",
|
1282 |
+
" 0,\n",
|
1283 |
+
" 0,\n",
|
1284 |
+
" 0,\n",
|
1285 |
+
" 0,\n",
|
1286 |
+
" 0],\n",
|
1287 |
+
" [101,\n",
|
1288 |
+
" 2008,\n",
|
1289 |
+
" 7459,\n",
|
1290 |
+
" 2049,\n",
|
1291 |
+
" 3494,\n",
|
1292 |
+
" 1998,\n",
|
1293 |
+
" 10639,\n",
|
1294 |
+
" 2015,\n",
|
1295 |
+
" 2242,\n",
|
1296 |
+
" 2738,\n",
|
1297 |
+
" 3376,\n",
|
1298 |
+
" 2055,\n",
|
1299 |
+
" 2529,\n",
|
1300 |
+
" 3267,\n",
|
1301 |
+
" 102,\n",
|
1302 |
+
" 0,\n",
|
1303 |
+
" 0,\n",
|
1304 |
+
" 0,\n",
|
1305 |
+
" 0,\n",
|
1306 |
+
" 0,\n",
|
1307 |
+
" 0,\n",
|
1308 |
+
" 0,\n",
|
1309 |
+
" 0,\n",
|
1310 |
+
" 0,\n",
|
1311 |
+
" 0,\n",
|
1312 |
+
" 0,\n",
|
1313 |
+
" 0,\n",
|
1314 |
+
" 0,\n",
|
1315 |
+
" 0,\n",
|
1316 |
+
" 0,\n",
|
1317 |
+
" 0,\n",
|
1318 |
+
" 0,\n",
|
1319 |
+
" 0,\n",
|
1320 |
+
" 0,\n",
|
1321 |
+
" 0,\n",
|
1322 |
+
" 0,\n",
|
1323 |
+
" 0,\n",
|
1324 |
+
" 0,\n",
|
1325 |
+
" 0,\n",
|
1326 |
+
" 0,\n",
|
1327 |
+
" 0,\n",
|
1328 |
+
" 0,\n",
|
1329 |
+
" 0,\n",
|
1330 |
+
" 0,\n",
|
1331 |
+
" 0,\n",
|
1332 |
+
" 0,\n",
|
1333 |
+
" 0,\n",
|
1334 |
+
" 0,\n",
|
1335 |
+
" 0,\n",
|
1336 |
+
" 0,\n",
|
1337 |
+
" 0,\n",
|
1338 |
+
" 0,\n",
|
1339 |
+
" 0,\n",
|
1340 |
+
" 0,\n",
|
1341 |
+
" 0,\n",
|
1342 |
+
" 0]],\n",
|
1343 |
+
" 'token_type_ids': [[0,\n",
|
1344 |
+
" 0,\n",
|
1345 |
+
" 0,\n",
|
1346 |
+
" 0,\n",
|
1347 |
+
" 0,\n",
|
1348 |
+
" 0,\n",
|
1349 |
+
" 0,\n",
|
1350 |
+
" 0,\n",
|
1351 |
+
" 0,\n",
|
1352 |
+
" 0,\n",
|
1353 |
+
" 0,\n",
|
1354 |
+
" 0,\n",
|
1355 |
+
" 0,\n",
|
1356 |
+
" 0,\n",
|
1357 |
+
" 0,\n",
|
1358 |
+
" 0,\n",
|
1359 |
+
" 0,\n",
|
1360 |
+
" 0,\n",
|
1361 |
+
" 0,\n",
|
1362 |
+
" 0,\n",
|
1363 |
+
" 0,\n",
|
1364 |
+
" 0,\n",
|
1365 |
+
" 0,\n",
|
1366 |
+
" 0,\n",
|
1367 |
+
" 0,\n",
|
1368 |
+
" 0,\n",
|
1369 |
+
" 0,\n",
|
1370 |
+
" 0,\n",
|
1371 |
+
" 0,\n",
|
1372 |
+
" 0,\n",
|
1373 |
+
" 0,\n",
|
1374 |
+
" 0,\n",
|
1375 |
+
" 0,\n",
|
1376 |
+
" 0,\n",
|
1377 |
+
" 0,\n",
|
1378 |
+
" 0,\n",
|
1379 |
+
" 0,\n",
|
1380 |
+
" 0,\n",
|
1381 |
+
" 0,\n",
|
1382 |
+
" 0,\n",
|
1383 |
+
" 0,\n",
|
1384 |
+
" 0,\n",
|
1385 |
+
" 0,\n",
|
1386 |
+
" 0,\n",
|
1387 |
+
" 0,\n",
|
1388 |
+
" 0,\n",
|
1389 |
+
" 0,\n",
|
1390 |
+
" 0,\n",
|
1391 |
+
" 0,\n",
|
1392 |
+
" 0,\n",
|
1393 |
+
" 0,\n",
|
1394 |
+
" 0,\n",
|
1395 |
+
" 0,\n",
|
1396 |
+
" 0,\n",
|
1397 |
+
" 0,\n",
|
1398 |
+
" 0],\n",
|
1399 |
+
" [0,\n",
|
1400 |
+
" 0,\n",
|
1401 |
+
" 0,\n",
|
1402 |
+
" 0,\n",
|
1403 |
+
" 0,\n",
|
1404 |
+
" 0,\n",
|
1405 |
+
" 0,\n",
|
1406 |
+
" 0,\n",
|
1407 |
+
" 0,\n",
|
1408 |
+
" 0,\n",
|
1409 |
+
" 0,\n",
|
1410 |
+
" 0,\n",
|
1411 |
+
" 0,\n",
|
1412 |
+
" 0,\n",
|
1413 |
+
" 0,\n",
|
1414 |
+
" 0,\n",
|
1415 |
+
" 0,\n",
|
1416 |
+
" 0,\n",
|
1417 |
+
" 0,\n",
|
1418 |
+
" 0,\n",
|
1419 |
+
" 0,\n",
|
1420 |
+
" 0,\n",
|
1421 |
+
" 0,\n",
|
1422 |
+
" 0,\n",
|
1423 |
+
" 0,\n",
|
1424 |
+
" 0,\n",
|
1425 |
+
" 0,\n",
|
1426 |
+
" 0,\n",
|
1427 |
+
" 0,\n",
|
1428 |
+
" 0,\n",
|
1429 |
+
" 0,\n",
|
1430 |
+
" 0,\n",
|
1431 |
+
" 0,\n",
|
1432 |
+
" 0,\n",
|
1433 |
+
" 0,\n",
|
1434 |
+
" 0,\n",
|
1435 |
+
" 0,\n",
|
1436 |
+
" 0,\n",
|
1437 |
+
" 0,\n",
|
1438 |
+
" 0,\n",
|
1439 |
+
" 0,\n",
|
1440 |
+
" 0,\n",
|
1441 |
+
" 0,\n",
|
1442 |
+
" 0,\n",
|
1443 |
+
" 0,\n",
|
1444 |
+
" 0,\n",
|
1445 |
+
" 0,\n",
|
1446 |
+
" 0,\n",
|
1447 |
+
" 0,\n",
|
1448 |
+
" 0,\n",
|
1449 |
+
" 0,\n",
|
1450 |
+
" 0,\n",
|
1451 |
+
" 0,\n",
|
1452 |
+
" 0,\n",
|
1453 |
+
" 0,\n",
|
1454 |
+
" 0],\n",
|
1455 |
+
" [0,\n",
|
1456 |
+
" 0,\n",
|
1457 |
+
" 0,\n",
|
1458 |
+
" 0,\n",
|
1459 |
+
" 0,\n",
|
1460 |
+
" 0,\n",
|
1461 |
+
" 0,\n",
|
1462 |
+
" 0,\n",
|
1463 |
+
" 0,\n",
|
1464 |
+
" 0,\n",
|
1465 |
+
" 0,\n",
|
1466 |
+
" 0,\n",
|
1467 |
+
" 0,\n",
|
1468 |
+
" 0,\n",
|
1469 |
+
" 0,\n",
|
1470 |
+
" 0,\n",
|
1471 |
+
" 0,\n",
|
1472 |
+
" 0,\n",
|
1473 |
+
" 0,\n",
|
1474 |
+
" 0,\n",
|
1475 |
+
" 0,\n",
|
1476 |
+
" 0,\n",
|
1477 |
+
" 0,\n",
|
1478 |
+
" 0,\n",
|
1479 |
+
" 0,\n",
|
1480 |
+
" 0,\n",
|
1481 |
+
" 0,\n",
|
1482 |
+
" 0,\n",
|
1483 |
+
" 0,\n",
|
1484 |
+
" 0,\n",
|
1485 |
+
" 0,\n",
|
1486 |
+
" 0,\n",
|
1487 |
+
" 0,\n",
|
1488 |
+
" 0,\n",
|
1489 |
+
" 0,\n",
|
1490 |
+
" 0,\n",
|
1491 |
+
" 0,\n",
|
1492 |
+
" 0,\n",
|
1493 |
+
" 0,\n",
|
1494 |
+
" 0,\n",
|
1495 |
+
" 0,\n",
|
1496 |
+
" 0,\n",
|
1497 |
+
" 0,\n",
|
1498 |
+
" 0,\n",
|
1499 |
+
" 0,\n",
|
1500 |
+
" 0,\n",
|
1501 |
+
" 0,\n",
|
1502 |
+
" 0,\n",
|
1503 |
+
" 0,\n",
|
1504 |
+
" 0,\n",
|
1505 |
+
" 0,\n",
|
1506 |
+
" 0,\n",
|
1507 |
+
" 0,\n",
|
1508 |
+
" 0,\n",
|
1509 |
+
" 0,\n",
|
1510 |
+
" 0]],\n",
|
1511 |
+
" 'attention_mask': [[1,\n",
|
1512 |
+
" 1,\n",
|
1513 |
+
" 1,\n",
|
1514 |
+
" 1,\n",
|
1515 |
+
" 1,\n",
|
1516 |
+
" 1,\n",
|
1517 |
+
" 1,\n",
|
1518 |
+
" 1,\n",
|
1519 |
+
" 1,\n",
|
1520 |
+
" 1,\n",
|
1521 |
+
" 0,\n",
|
1522 |
+
" 0,\n",
|
1523 |
+
" 0,\n",
|
1524 |
+
" 0,\n",
|
1525 |
+
" 0,\n",
|
1526 |
+
" 0,\n",
|
1527 |
+
" 0,\n",
|
1528 |
+
" 0,\n",
|
1529 |
+
" 0,\n",
|
1530 |
+
" 0,\n",
|
1531 |
+
" 0,\n",
|
1532 |
+
" 0,\n",
|
1533 |
+
" 0,\n",
|
1534 |
+
" 0,\n",
|
1535 |
+
" 0,\n",
|
1536 |
+
" 0,\n",
|
1537 |
+
" 0,\n",
|
1538 |
+
" 0,\n",
|
1539 |
+
" 0,\n",
|
1540 |
+
" 0,\n",
|
1541 |
+
" 0,\n",
|
1542 |
+
" 0,\n",
|
1543 |
+
" 0,\n",
|
1544 |
+
" 0,\n",
|
1545 |
+
" 0,\n",
|
1546 |
+
" 0,\n",
|
1547 |
+
" 0,\n",
|
1548 |
+
" 0,\n",
|
1549 |
+
" 0,\n",
|
1550 |
+
" 0,\n",
|
1551 |
+
" 0,\n",
|
1552 |
+
" 0,\n",
|
1553 |
+
" 0,\n",
|
1554 |
+
" 0,\n",
|
1555 |
+
" 0,\n",
|
1556 |
+
" 0,\n",
|
1557 |
+
" 0,\n",
|
1558 |
+
" 0,\n",
|
1559 |
+
" 0,\n",
|
1560 |
+
" 0,\n",
|
1561 |
+
" 0,\n",
|
1562 |
+
" 0,\n",
|
1563 |
+
" 0,\n",
|
1564 |
+
" 0,\n",
|
1565 |
+
" 0,\n",
|
1566 |
+
" 0],\n",
|
1567 |
+
" [1,\n",
|
1568 |
+
" 1,\n",
|
1569 |
+
" 1,\n",
|
1570 |
+
" 1,\n",
|
1571 |
+
" 1,\n",
|
1572 |
+
" 1,\n",
|
1573 |
+
" 1,\n",
|
1574 |
+
" 1,\n",
|
1575 |
+
" 1,\n",
|
1576 |
+
" 1,\n",
|
1577 |
+
" 1,\n",
|
1578 |
+
" 0,\n",
|
1579 |
+
" 0,\n",
|
1580 |
+
" 0,\n",
|
1581 |
+
" 0,\n",
|
1582 |
+
" 0,\n",
|
1583 |
+
" 0,\n",
|
1584 |
+
" 0,\n",
|
1585 |
+
" 0,\n",
|
1586 |
+
" 0,\n",
|
1587 |
+
" 0,\n",
|
1588 |
+
" 0,\n",
|
1589 |
+
" 0,\n",
|
1590 |
+
" 0,\n",
|
1591 |
+
" 0,\n",
|
1592 |
+
" 0,\n",
|
1593 |
+
" 0,\n",
|
1594 |
+
" 0,\n",
|
1595 |
+
" 0,\n",
|
1596 |
+
" 0,\n",
|
1597 |
+
" 0,\n",
|
1598 |
+
" 0,\n",
|
1599 |
+
" 0,\n",
|
1600 |
+
" 0,\n",
|
1601 |
+
" 0,\n",
|
1602 |
+
" 0,\n",
|
1603 |
+
" 0,\n",
|
1604 |
+
" 0,\n",
|
1605 |
+
" 0,\n",
|
1606 |
+
" 0,\n",
|
1607 |
+
" 0,\n",
|
1608 |
+
" 0,\n",
|
1609 |
+
" 0,\n",
|
1610 |
+
" 0,\n",
|
1611 |
+
" 0,\n",
|
1612 |
+
" 0,\n",
|
1613 |
+
" 0,\n",
|
1614 |
+
" 0,\n",
|
1615 |
+
" 0,\n",
|
1616 |
+
" 0,\n",
|
1617 |
+
" 0,\n",
|
1618 |
+
" 0,\n",
|
1619 |
+
" 0,\n",
|
1620 |
+
" 0,\n",
|
1621 |
+
" 0,\n",
|
1622 |
+
" 0],\n",
|
1623 |
+
" [1,\n",
|
1624 |
+
" 1,\n",
|
1625 |
+
" 1,\n",
|
1626 |
+
" 1,\n",
|
1627 |
+
" 1,\n",
|
1628 |
+
" 1,\n",
|
1629 |
+
" 1,\n",
|
1630 |
+
" 1,\n",
|
1631 |
+
" 1,\n",
|
1632 |
+
" 1,\n",
|
1633 |
+
" 1,\n",
|
1634 |
+
" 1,\n",
|
1635 |
+
" 1,\n",
|
1636 |
+
" 1,\n",
|
1637 |
+
" 1,\n",
|
1638 |
+
" 0,\n",
|
1639 |
+
" 0,\n",
|
1640 |
+
" 0,\n",
|
1641 |
+
" 0,\n",
|
1642 |
+
" 0,\n",
|
1643 |
+
" 0,\n",
|
1644 |
+
" 0,\n",
|
1645 |
+
" 0,\n",
|
1646 |
+
" 0,\n",
|
1647 |
+
" 0,\n",
|
1648 |
+
" 0,\n",
|
1649 |
+
" 0,\n",
|
1650 |
+
" 0,\n",
|
1651 |
+
" 0,\n",
|
1652 |
+
" 0,\n",
|
1653 |
+
" 0,\n",
|
1654 |
+
" 0,\n",
|
1655 |
+
" 0,\n",
|
1656 |
+
" 0,\n",
|
1657 |
+
" 0,\n",
|
1658 |
+
" 0,\n",
|
1659 |
+
" 0,\n",
|
1660 |
+
" 0,\n",
|
1661 |
+
" 0,\n",
|
1662 |
+
" 0,\n",
|
1663 |
+
" 0,\n",
|
1664 |
+
" 0,\n",
|
1665 |
+
" 0,\n",
|
1666 |
+
" 0,\n",
|
1667 |
+
" 0,\n",
|
1668 |
+
" 0,\n",
|
1669 |
+
" 0,\n",
|
1670 |
+
" 0,\n",
|
1671 |
+
" 0,\n",
|
1672 |
+
" 0,\n",
|
1673 |
+
" 0,\n",
|
1674 |
+
" 0,\n",
|
1675 |
+
" 0,\n",
|
1676 |
+
" 0,\n",
|
1677 |
+
" 0,\n",
|
1678 |
+
" 0]]}"
|
1679 |
+
]
|
1680 |
+
},
|
1681 |
+
"execution_count": 62,
|
1682 |
+
"metadata": {},
|
1683 |
+
"output_type": "execute_result"
|
1684 |
+
}
|
1685 |
+
],
|
1686 |
+
"source": [
|
1687 |
+
"samples"
|
1688 |
+
]
|
1689 |
+
},
|
1690 |
+
{
|
1691 |
+
"cell_type": "markdown",
|
1692 |
+
"metadata": {},
|
1693 |
+
"source": [
|
1694 |
+
"The datasets library is pretty intuitive in the way it is structured. We just need to make sure before collating, we have the necessary fields and drop the unnecessary fields from the dataset. And that we do dynamic padding based on a batch of data and not on the model dim or the max sequence length of the entire corpus. It will be economical in terms of computation and also help training."
|
1695 |
+
]
|
1696 |
+
}
|
1697 |
+
],
|
1698 |
+
"metadata": {
|
1699 |
+
"kernelspec": {
|
1700 |
+
"display_name": "Python 3",
|
1701 |
+
"language": "python",
|
1702 |
+
"name": "python3"
|
1703 |
+
},
|
1704 |
+
"language_info": {
|
1705 |
+
"codemirror_mode": {
|
1706 |
+
"name": "ipython",
|
1707 |
+
"version": 3
|
1708 |
+
},
|
1709 |
+
"file_extension": ".py",
|
1710 |
+
"mimetype": "text/x-python",
|
1711 |
+
"name": "python",
|
1712 |
+
"nbconvert_exporter": "python",
|
1713 |
+
"pygments_lexer": "ipython3",
|
1714 |
+
"version": "3.10.14"
|
1715 |
+
}
|
1716 |
+
},
|
1717 |
+
"nbformat": 4,
|
1718 |
+
"nbformat_minor": 2
|
1719 |
+
}
|