Granther commited on
Commit
da79090
·
verified ·
1 Parent(s): 8288c7e

Upload gemma_2_gptq_inference.ipynb with huggingface_hub

Browse files
Files changed (1) hide show
  1. gemma_2_gptq_inference.ipynb +304 -0
gemma_2_gptq_inference.ipynb ADDED
@@ -0,0 +1,304 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "id": "abba1d18-4a4d-4978-9766-82f3940e411e",
7
+ "metadata": {
8
+ "scrolled": true
9
+ },
10
+ "outputs": [
11
+ {
12
+ "name": "stdout",
13
+ "output_type": "stream",
14
+ "text": [
15
+ "Requirement already satisfied: auto-gptq in /opt/conda/lib/python3.10/site-packages (0.7.1)\n",
16
+ "Requirement already satisfied: accelerate>=0.26.0 in /opt/conda/lib/python3.10/site-packages (from auto-gptq) (0.31.0)\n",
17
+ "Requirement already satisfied: datasets in /opt/conda/lib/python3.10/site-packages (from auto-gptq) (2.20.0)\n",
18
+ "Requirement already satisfied: sentencepiece in /opt/conda/lib/python3.10/site-packages (from auto-gptq) (0.2.0)\n",
19
+ "Requirement already satisfied: numpy in /opt/conda/lib/python3.10/site-packages (from auto-gptq) (1.26.3)\n",
20
+ "Requirement already satisfied: rouge in /opt/conda/lib/python3.10/site-packages (from auto-gptq) (1.0.1)\n",
21
+ "Requirement already satisfied: gekko in /opt/conda/lib/python3.10/site-packages (from auto-gptq) (1.1.3)\n",
22
+ "Requirement already satisfied: torch>=1.13.0 in /opt/conda/lib/python3.10/site-packages (from auto-gptq) (2.2.0)\n",
23
+ "Requirement already satisfied: safetensors in /opt/conda/lib/python3.10/site-packages (from auto-gptq) (0.4.3)\n",
24
+ "Requirement already satisfied: transformers>=4.31.0 in /opt/conda/lib/python3.10/site-packages (from auto-gptq) (4.43.0.dev0)\n",
25
+ "Requirement already satisfied: peft>=0.5.0 in /opt/conda/lib/python3.10/site-packages (from auto-gptq) (0.11.1)\n",
26
+ "Requirement already satisfied: tqdm in /opt/conda/lib/python3.10/site-packages (from auto-gptq) (4.66.4)\n",
27
+ "Requirement already satisfied: packaging>=20.0 in /opt/conda/lib/python3.10/site-packages (from accelerate>=0.26.0->auto-gptq) (23.1)\n",
28
+ "Requirement already satisfied: psutil in /opt/conda/lib/python3.10/site-packages (from accelerate>=0.26.0->auto-gptq) (5.9.0)\n",
29
+ "Requirement already satisfied: pyyaml in /opt/conda/lib/python3.10/site-packages (from accelerate>=0.26.0->auto-gptq) (6.0.1)\n",
30
+ "Requirement already satisfied: huggingface-hub in /opt/conda/lib/python3.10/site-packages (from accelerate>=0.26.0->auto-gptq) (0.23.4)\n",
31
+ "Requirement already satisfied: filelock in /opt/conda/lib/python3.10/site-packages (from torch>=1.13.0->auto-gptq) (3.13.1)\n",
32
+ "Requirement already satisfied: typing-extensions>=4.8.0 in /opt/conda/lib/python3.10/site-packages (from torch>=1.13.0->auto-gptq) (4.9.0)\n",
33
+ "Requirement already satisfied: sympy in /opt/conda/lib/python3.10/site-packages (from torch>=1.13.0->auto-gptq) (1.12)\n",
34
+ "Requirement already satisfied: networkx in /opt/conda/lib/python3.10/site-packages (from torch>=1.13.0->auto-gptq) (3.1)\n",
35
+ "Requirement already satisfied: jinja2 in /opt/conda/lib/python3.10/site-packages (from torch>=1.13.0->auto-gptq) (3.1.2)\n",
36
+ "Requirement already satisfied: fsspec in /opt/conda/lib/python3.10/site-packages (from torch>=1.13.0->auto-gptq) (2023.12.2)\n",
37
+ "Requirement already satisfied: regex!=2019.12.17 in /opt/conda/lib/python3.10/site-packages (from transformers>=4.31.0->auto-gptq) (2024.5.15)\n",
38
+ "Requirement already satisfied: requests in /opt/conda/lib/python3.10/site-packages (from transformers>=4.31.0->auto-gptq) (2.32.3)\n",
39
+ "Requirement already satisfied: tokenizers<0.20,>=0.19 in /opt/conda/lib/python3.10/site-packages (from transformers>=4.31.0->auto-gptq) (0.19.1)\n",
40
+ "Requirement already satisfied: pyarrow>=15.0.0 in /opt/conda/lib/python3.10/site-packages (from datasets->auto-gptq) (16.1.0)\n",
41
+ "Requirement already satisfied: pyarrow-hotfix in /opt/conda/lib/python3.10/site-packages (from datasets->auto-gptq) (0.6)\n",
42
+ "Requirement already satisfied: dill<0.3.9,>=0.3.0 in /opt/conda/lib/python3.10/site-packages (from datasets->auto-gptq) (0.3.8)\n",
43
+ "Requirement already satisfied: pandas in /opt/conda/lib/python3.10/site-packages (from datasets->auto-gptq) (2.2.2)\n",
44
+ "Requirement already satisfied: xxhash in /opt/conda/lib/python3.10/site-packages (from datasets->auto-gptq) (3.4.1)\n",
45
+ "Requirement already satisfied: multiprocess in /opt/conda/lib/python3.10/site-packages (from datasets->auto-gptq) (0.70.16)\n",
46
+ "Requirement already satisfied: aiohttp in /opt/conda/lib/python3.10/site-packages (from datasets->auto-gptq) (3.9.5)\n",
47
+ "Requirement already satisfied: six in /opt/conda/lib/python3.10/site-packages (from rouge->auto-gptq) (1.16.0)\n",
48
+ "Requirement already satisfied: aiosignal>=1.1.2 in /opt/conda/lib/python3.10/site-packages (from aiohttp->datasets->auto-gptq) (1.3.1)\n",
49
+ "Requirement already satisfied: attrs>=17.3.0 in /opt/conda/lib/python3.10/site-packages (from aiohttp->datasets->auto-gptq) (23.1.0)\n",
50
+ "Requirement already satisfied: frozenlist>=1.1.1 in /opt/conda/lib/python3.10/site-packages (from aiohttp->datasets->auto-gptq) (1.4.1)\n",
51
+ "Requirement already satisfied: multidict<7.0,>=4.5 in /opt/conda/lib/python3.10/site-packages (from aiohttp->datasets->auto-gptq) (6.0.5)\n",
52
+ "Requirement already satisfied: yarl<2.0,>=1.0 in /opt/conda/lib/python3.10/site-packages (from aiohttp->datasets->auto-gptq) (1.9.4)\n",
53
+ "Requirement already satisfied: async-timeout<5.0,>=4.0 in /opt/conda/lib/python3.10/site-packages (from aiohttp->datasets->auto-gptq) (4.0.3)\n",
54
+ "Requirement already satisfied: charset-normalizer<4,>=2 in /opt/conda/lib/python3.10/site-packages (from requests->transformers>=4.31.0->auto-gptq) (2.0.4)\n",
55
+ "Requirement already satisfied: idna<4,>=2.5 in /opt/conda/lib/python3.10/site-packages (from requests->transformers>=4.31.0->auto-gptq) (3.4)\n",
56
+ "Requirement already satisfied: urllib3<3,>=1.21.1 in /opt/conda/lib/python3.10/site-packages (from requests->transformers>=4.31.0->auto-gptq) (1.26.18)\n",
57
+ "Requirement already satisfied: certifi>=2017.4.17 in /opt/conda/lib/python3.10/site-packages (from requests->transformers>=4.31.0->auto-gptq) (2023.11.17)\n",
58
+ "Requirement already satisfied: MarkupSafe>=2.0 in /opt/conda/lib/python3.10/site-packages (from jinja2->torch>=1.13.0->auto-gptq) (2.1.3)\n",
59
+ "Requirement already satisfied: python-dateutil>=2.8.2 in /opt/conda/lib/python3.10/site-packages (from pandas->datasets->auto-gptq) (2.9.0.post0)\n",
60
+ "Requirement already satisfied: pytz>=2020.1 in /opt/conda/lib/python3.10/site-packages (from pandas->datasets->auto-gptq) (2023.3.post1)\n",
61
+ "Requirement already satisfied: tzdata>=2022.7 in /opt/conda/lib/python3.10/site-packages (from pandas->datasets->auto-gptq) (2024.1)\n",
62
+ "Requirement already satisfied: mpmath>=0.19 in /opt/conda/lib/python3.10/site-packages (from sympy->torch>=1.13.0->auto-gptq) (1.3.0)\n",
63
+ "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
64
+ "\u001b[0mERROR: unknown command \"insall\" - maybe you meant \"install\"\n"
65
+ ]
66
+ }
67
+ ],
68
+ "source": [
69
+ "!pip install auto-gptq\n",
70
+ "!pip insall --upgrade transformers optimum accelerate"
71
+ ]
72
+ },
73
+ {
74
+ "cell_type": "code",
75
+ "execution_count": null,
76
+ "id": "6cefbcaa-9c28-4d68-b15e-9889f73332dd",
77
+ "metadata": {},
78
+ "outputs": [],
79
+ "source": [
80
+ "!pip install parallelformers"
81
+ ]
82
+ },
83
+ {
84
+ "cell_type": "code",
85
+ "execution_count": 38,
86
+ "id": "bf65bf24-e67b-4af7-8890-fba4182584e0",
87
+ "metadata": {},
88
+ "outputs": [
89
+ {
90
+ "data": {
91
+ "application/vnd.jupyter.widget-view+json": {
92
+ "model_id": "c7abb0ecf7aa42b3a5ade8aa6644cbbc",
93
+ "version_major": 2,
94
+ "version_minor": 0
95
+ },
96
+ "text/plain": [
97
+ "VBox(children=(HTML(value='<center> <img\\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…"
98
+ ]
99
+ },
100
+ "metadata": {},
101
+ "output_type": "display_data"
102
+ }
103
+ ],
104
+ "source": [
105
+ "from huggingface_hub import HfApi, notebook_login\n",
106
+ "\n",
107
+ "notebook_login()"
108
+ ]
109
+ },
110
+ {
111
+ "cell_type": "code",
112
+ "execution_count": 4,
113
+ "id": "fe3c29b7-1946-4f44-9d07-e9798e656bc2",
114
+ "metadata": {},
115
+ "outputs": [],
116
+ "source": [
117
+ "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
118
+ "import torch"
119
+ ]
120
+ },
121
+ {
122
+ "cell_type": "code",
123
+ "execution_count": 11,
124
+ "id": "1bff45ce-9a64-4ceb-97e0-0e2ab5f15c45",
125
+ "metadata": {},
126
+ "outputs": [],
127
+ "source": [
128
+ "# Hyperparams\n",
129
+ "\n",
130
+ "quant_model_repo = \"Granther/Gemma-2-9B-Instruct-4Bit-GPTQ\""
131
+ ]
132
+ },
133
+ {
134
+ "cell_type": "code",
135
+ "execution_count": 15,
136
+ "id": "5b060dc1-4a95-476f-8707-412a4452dd89",
137
+ "metadata": {},
138
+ "outputs": [
139
+ {
140
+ "name": "stderr",
141
+ "output_type": "stream",
142
+ "text": [
143
+ "/opt/conda/lib/python3.10/site-packages/transformers/modeling_utils.py:4565: FutureWarning: `_is_quantized_training_enabled` is going to be deprecated in transformers 4.39.0. Please use `model.hf_quantizer.is_trainable` instead\n",
144
+ " warnings.warn(\n"
145
+ ]
146
+ },
147
+ {
148
+ "data": {
149
+ "application/vnd.jupyter.widget-view+json": {
150
+ "model_id": "fceb82427c58488da421d6415e0df516",
151
+ "version_major": 2,
152
+ "version_minor": 0
153
+ },
154
+ "text/plain": [
155
+ "Loading checkpoint shards: 0%| | 0/2 [00:00<?, ?it/s]"
156
+ ]
157
+ },
158
+ "metadata": {},
159
+ "output_type": "display_data"
160
+ }
161
+ ],
162
+ "source": [
163
+ "from optimum.gptq import load_quantized_model\n",
164
+ "\n",
165
+ "model = AutoModelForCausalLM.from_pretrained(quant_model_repo, device_map='auto')\n",
166
+ "tokenizer = AutoTokenizer.from_pretrained(quant_model_repo)\n",
167
+ "\n",
168
+ "#model = load_quantized_model(model, device_map=\"auto\")"
169
+ ]
170
+ },
171
+ {
172
+ "cell_type": "code",
173
+ "execution_count": 36,
174
+ "id": "440a9d11-8fb0-4ded-885f-52a226e326d1",
175
+ "metadata": {},
176
+ "outputs": [
177
+ {
178
+ "data": {
179
+ "text/plain": [
180
+ "tensor([[ 2, 106, 1645, 108, 27445, 682, 1105, 20731, 107, 108,\n",
181
+ " 106, 2516, 108]])"
182
+ ]
183
+ },
184
+ "execution_count": 36,
185
+ "metadata": {},
186
+ "output_type": "execute_result"
187
+ }
188
+ ],
189
+ "source": [
190
+ "input_text = [\n",
191
+ " {\"role\": \"user\", \"content\": \"Tell me about paris\"},\n",
192
+ "]\n",
193
+ "\n",
194
+ "prompt = tokenizer.apply_chat_template(input_text, tokenize=True, add_generation_prompt=True, return_tensors='pt')\n",
195
+ "# >>> '<bos><start_of_turn>user\\nTell me about paris<end_of_turn>\\n<start_of_turn>model\\n' \n",
196
+ "\n",
197
+ "prompt\n",
198
+ "\n",
199
+ "#prompt = {key: value.to(model.device) for key, value in prompt}\n",
200
+ "\n",
201
+ "# with torch.no_grad():\n",
202
+ "# outputs = model.generate(**inputs, max_length=100, num_return_sequences=1)"
203
+ ]
204
+ },
205
+ {
206
+ "cell_type": "code",
207
+ "execution_count": 26,
208
+ "id": "9ed28a1b-2c65-4762-8b7c-4338195729af",
209
+ "metadata": {},
210
+ "outputs": [
211
+ {
212
+ "data": {
213
+ "text/plain": [
214
+ "\"Tell me about python \\n\\nLet's talk about Python!\\n\\n**What is Python?**\\n\\nPython is a high-level, interpreted, general-purpose programming language. Let's break down what that means:\\n\\n* **High-level:** Python is designed to be easy for humans to read and write. It uses words and symbols that are closer to natural language than the low-level instructions computers directly understand.\\n* **Interpreted:** Python code is executed line\""
215
+ ]
216
+ },
217
+ "execution_count": 26,
218
+ "metadata": {},
219
+ "output_type": "execute_result"
220
+ }
221
+ ],
222
+ "source": [
223
+ "outputs = tokenizer.decode(outputs[0], skip_special_tokens=True)\n",
224
+ "outputs"
225
+ ]
226
+ },
227
+ {
228
+ "cell_type": "code",
229
+ "execution_count": null,
230
+ "id": "8f040240-5bb0-41a1-a823-678f5ffb3dc2",
231
+ "metadata": {},
232
+ "outputs": [],
233
+ "source": [
234
+ "\n",
235
+ "chat = [\n",
236
+ " { \"role\": \"user\", \"content\": \"Write a hello world program\" },\n",
237
+ "]\n",
238
+ "prompt = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)\n"
239
+ ]
240
+ },
241
+ {
242
+ "cell_type": "code",
243
+ "execution_count": 12,
244
+ "id": "9cf56d96-bf6c-43cd-a062-aa8368b3406b",
245
+ "metadata": {},
246
+ "outputs": [],
247
+ "source": [
248
+ "from transformers import pipeline\n",
249
+ "\n",
250
+ "pipe = pipeline('text-generation', model=model, tokenizer=quant_model_repo)"
251
+ ]
252
+ },
253
+ {
254
+ "cell_type": "code",
255
+ "execution_count": 14,
256
+ "id": "eb34f0bc-9a4b-4f8b-baf9-9869e2d1bef1",
257
+ "metadata": {},
258
+ "outputs": [
259
+ {
260
+ "data": {
261
+ "text/plain": [
262
+ "[{'generated_text': 'hello, how are you?\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n'}]"
263
+ ]
264
+ },
265
+ "execution_count": 14,
266
+ "metadata": {},
267
+ "output_type": "execute_result"
268
+ }
269
+ ],
270
+ "source": [
271
+ "pipe(\"hello, how are you\")"
272
+ ]
273
+ },
274
+ {
275
+ "cell_type": "code",
276
+ "execution_count": null,
277
+ "id": "52a00f3b-e601-4800-9708-5f656c54ffc8",
278
+ "metadata": {},
279
+ "outputs": [],
280
+ "source": []
281
+ }
282
+ ],
283
+ "metadata": {
284
+ "kernelspec": {
285
+ "display_name": "Python 3 (ipykernel)",
286
+ "language": "python",
287
+ "name": "python3"
288
+ },
289
+ "language_info": {
290
+ "codemirror_mode": {
291
+ "name": "ipython",
292
+ "version": 3
293
+ },
294
+ "file_extension": ".py",
295
+ "mimetype": "text/x-python",
296
+ "name": "python",
297
+ "nbconvert_exporter": "python",
298
+ "pygments_lexer": "ipython3",
299
+ "version": "3.10.13"
300
+ }
301
+ },
302
+ "nbformat": 4,
303
+ "nbformat_minor": 5
304
+ }