"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "%load_ext tensorboard\n",
+ "%tensorboard --logdir experiments/runs\n",
+ "%reload_ext tensorboard"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 25,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 1000
+ },
+ "id": "G0tXhYO5ea5l",
+ "outputId": "00c73bb3-0e4a-45d3-c235-9088b79477de",
+ "scrolled": true
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "You're using a LlamaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n",
+ "/opt/conda/lib/python3.10/site-packages/torch/utils/checkpoint.py:429: UserWarning: torch.utils.checkpoint: please pass in use_reentrant=True or use_reentrant=False explicitly. The default value of use_reentrant will be updated to be False in the future. To maintain current behavior, pass use_reentrant=True. It is recommended that you use use_reentrant=False. Refer to docs for more details on the differences between the two variants.\n",
+ " warnings.warn(\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [160/160 08:34, Epoch 0/1]\n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " Step | \n",
+ " Training Loss | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 1 | \n",
+ " 13.098800 | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " 13.048300 | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " 13.306200 | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " 13.201500 | \n",
+ "
\n",
+ " \n",
+ " 5 | \n",
+ " 12.833000 | \n",
+ "
\n",
+ " \n",
+ " 6 | \n",
+ " 11.263900 | \n",
+ "
\n",
+ " \n",
+ " 7 | \n",
+ " 11.254500 | \n",
+ "
\n",
+ " \n",
+ " 8 | \n",
+ " 9.762500 | \n",
+ "
\n",
+ " \n",
+ " 9 | \n",
+ " 8.756900 | \n",
+ "
\n",
+ " \n",
+ " 10 | \n",
+ " 7.434900 | \n",
+ "
\n",
+ " \n",
+ " 11 | \n",
+ " 6.682500 | \n",
+ "
\n",
+ " \n",
+ " 12 | \n",
+ " 5.917600 | \n",
+ "
\n",
+ " \n",
+ " 13 | \n",
+ " 5.254400 | \n",
+ "
\n",
+ " \n",
+ " 14 | \n",
+ " 5.121800 | \n",
+ "
\n",
+ " \n",
+ " 15 | \n",
+ " 4.021000 | \n",
+ "
\n",
+ " \n",
+ " 16 | \n",
+ " 3.738400 | \n",
+ "
\n",
+ " \n",
+ " 17 | \n",
+ " 3.626200 | \n",
+ "
\n",
+ " \n",
+ " 18 | \n",
+ " 3.464700 | \n",
+ "
\n",
+ " \n",
+ " 19 | \n",
+ " 3.121000 | \n",
+ "
\n",
+ " \n",
+ " 20 | \n",
+ " 2.591900 | \n",
+ "
\n",
+ " \n",
+ " 21 | \n",
+ " 2.441300 | \n",
+ "
\n",
+ " \n",
+ " 22 | \n",
+ " 2.239800 | \n",
+ "
\n",
+ " \n",
+ " 23 | \n",
+ " 1.953600 | \n",
+ "
\n",
+ " \n",
+ " 24 | \n",
+ " 1.947700 | \n",
+ "
\n",
+ " \n",
+ " 25 | \n",
+ " 2.155100 | \n",
+ "
\n",
+ " \n",
+ " 26 | \n",
+ " 1.642400 | \n",
+ "
\n",
+ " \n",
+ " 27 | \n",
+ " 1.526800 | \n",
+ "
\n",
+ " \n",
+ " 28 | \n",
+ " 1.712300 | \n",
+ "
\n",
+ " \n",
+ " 29 | \n",
+ " 1.177600 | \n",
+ "
\n",
+ " \n",
+ " 30 | \n",
+ " 1.286600 | \n",
+ "
\n",
+ " \n",
+ " 31 | \n",
+ " 1.244600 | \n",
+ "
\n",
+ " \n",
+ " 32 | \n",
+ " 1.238200 | \n",
+ "
\n",
+ " \n",
+ " 33 | \n",
+ " 1.431700 | \n",
+ "
\n",
+ " \n",
+ " 34 | \n",
+ " 1.178200 | \n",
+ "
\n",
+ " \n",
+ " 35 | \n",
+ " 1.132500 | \n",
+ "
\n",
+ " \n",
+ " 36 | \n",
+ " 0.744700 | \n",
+ "
\n",
+ " \n",
+ " 37 | \n",
+ " 1.574000 | \n",
+ "
\n",
+ " \n",
+ " 38 | \n",
+ " 1.171600 | \n",
+ "
\n",
+ " \n",
+ " 39 | \n",
+ " 1.036600 | \n",
+ "
\n",
+ " \n",
+ " 40 | \n",
+ " 1.207000 | \n",
+ "
\n",
+ " \n",
+ " 41 | \n",
+ " 1.098000 | \n",
+ "
\n",
+ " \n",
+ " 42 | \n",
+ " 0.682900 | \n",
+ "
\n",
+ " \n",
+ " 43 | \n",
+ " 1.237700 | \n",
+ "
\n",
+ " \n",
+ " 44 | \n",
+ " 1.030800 | \n",
+ "
\n",
+ " \n",
+ " 45 | \n",
+ " 1.009700 | \n",
+ "
\n",
+ " \n",
+ " 46 | \n",
+ " 0.952200 | \n",
+ "
\n",
+ " \n",
+ " 47 | \n",
+ " 1.284700 | \n",
+ "
\n",
+ " \n",
+ " 48 | \n",
+ " 0.950900 | \n",
+ "
\n",
+ " \n",
+ " 49 | \n",
+ " 1.146300 | \n",
+ "
\n",
+ " \n",
+ " 50 | \n",
+ " 1.043400 | \n",
+ "
\n",
+ " \n",
+ " 51 | \n",
+ " 1.105700 | \n",
+ "
\n",
+ " \n",
+ " 52 | \n",
+ " 1.178600 | \n",
+ "
\n",
+ " \n",
+ " 53 | \n",
+ " 0.713800 | \n",
+ "
\n",
+ " \n",
+ " 54 | \n",
+ " 1.203200 | \n",
+ "
\n",
+ " \n",
+ " 55 | \n",
+ " 0.783500 | \n",
+ "
\n",
+ " \n",
+ " 56 | \n",
+ " 0.894600 | \n",
+ "
\n",
+ " \n",
+ " 57 | \n",
+ " 1.049400 | \n",
+ "
\n",
+ " \n",
+ " 58 | \n",
+ " 0.886800 | \n",
+ "
\n",
+ " \n",
+ " 59 | \n",
+ " 0.774900 | \n",
+ "
\n",
+ " \n",
+ " 60 | \n",
+ " 0.735200 | \n",
+ "
\n",
+ " \n",
+ " 61 | \n",
+ " 0.829800 | \n",
+ "
\n",
+ " \n",
+ " 62 | \n",
+ " 0.985800 | \n",
+ "
\n",
+ " \n",
+ " 63 | \n",
+ " 0.860800 | \n",
+ "
\n",
+ " \n",
+ " 64 | \n",
+ " 1.090300 | \n",
+ "
\n",
+ " \n",
+ " 65 | \n",
+ " 0.968900 | \n",
+ "
\n",
+ " \n",
+ " 66 | \n",
+ " 0.766600 | \n",
+ "
\n",
+ " \n",
+ " 67 | \n",
+ " 1.027200 | \n",
+ "
\n",
+ " \n",
+ " 68 | \n",
+ " 1.426600 | \n",
+ "
\n",
+ " \n",
+ " 69 | \n",
+ " 0.817600 | \n",
+ "
\n",
+ " \n",
+ " 70 | \n",
+ " 0.954100 | \n",
+ "
\n",
+ " \n",
+ " 71 | \n",
+ " 0.943500 | \n",
+ "
\n",
+ " \n",
+ " 72 | \n",
+ " 1.138100 | \n",
+ "
\n",
+ " \n",
+ " 73 | \n",
+ " 0.955000 | \n",
+ "
\n",
+ " \n",
+ " 74 | \n",
+ " 1.066400 | \n",
+ "
\n",
+ " \n",
+ " 75 | \n",
+ " 0.714300 | \n",
+ "
\n",
+ " \n",
+ " 76 | \n",
+ " 1.025000 | \n",
+ "
\n",
+ " \n",
+ " 77 | \n",
+ " 0.689200 | \n",
+ "
\n",
+ " \n",
+ " 78 | \n",
+ " 0.783400 | \n",
+ "
\n",
+ " \n",
+ " 79 | \n",
+ " 0.781600 | \n",
+ "
\n",
+ " \n",
+ " 80 | \n",
+ " 0.838400 | \n",
+ "
\n",
+ " \n",
+ " 81 | \n",
+ " 0.731500 | \n",
+ "
\n",
+ " \n",
+ " 82 | \n",
+ " 0.901400 | \n",
+ "
\n",
+ " \n",
+ " 83 | \n",
+ " 0.802100 | \n",
+ "
\n",
+ " \n",
+ " 84 | \n",
+ " 1.113400 | \n",
+ "
\n",
+ " \n",
+ " 85 | \n",
+ " 0.755600 | \n",
+ "
\n",
+ " \n",
+ " 86 | \n",
+ " 0.845900 | \n",
+ "
\n",
+ " \n",
+ " 87 | \n",
+ " 1.089000 | \n",
+ "
\n",
+ " \n",
+ " 88 | \n",
+ " 1.094800 | \n",
+ "
\n",
+ " \n",
+ " 89 | \n",
+ " 1.035400 | \n",
+ "
\n",
+ " \n",
+ " 90 | \n",
+ " 0.824700 | \n",
+ "
\n",
+ " \n",
+ " 91 | \n",
+ " 0.899000 | \n",
+ "
\n",
+ " \n",
+ " 92 | \n",
+ " 0.897400 | \n",
+ "
\n",
+ " \n",
+ " 93 | \n",
+ " 1.172800 | \n",
+ "
\n",
+ " \n",
+ " 94 | \n",
+ " 1.036700 | \n",
+ "
\n",
+ " \n",
+ " 95 | \n",
+ " 0.878500 | \n",
+ "
\n",
+ " \n",
+ " 96 | \n",
+ " 0.890800 | \n",
+ "
\n",
+ " \n",
+ " 97 | \n",
+ " 0.872300 | \n",
+ "
\n",
+ " \n",
+ " 98 | \n",
+ " 0.854800 | \n",
+ "
\n",
+ " \n",
+ " 99 | \n",
+ " 1.054900 | \n",
+ "
\n",
+ " \n",
+ " 100 | \n",
+ " 1.077300 | \n",
+ "
\n",
+ " \n",
+ " 101 | \n",
+ " 0.820400 | \n",
+ "
\n",
+ " \n",
+ " 102 | \n",
+ " 1.013900 | \n",
+ "
\n",
+ " \n",
+ " 103 | \n",
+ " 0.903800 | \n",
+ "
\n",
+ " \n",
+ " 104 | \n",
+ " 1.006800 | \n",
+ "
\n",
+ " \n",
+ " 105 | \n",
+ " 1.019600 | \n",
+ "
\n",
+ " \n",
+ " 106 | \n",
+ " 0.922600 | \n",
+ "
\n",
+ " \n",
+ " 107 | \n",
+ " 0.787000 | \n",
+ "
\n",
+ " \n",
+ " 108 | \n",
+ " 0.783300 | \n",
+ "
\n",
+ " \n",
+ " 109 | \n",
+ " 0.599800 | \n",
+ "
\n",
+ " \n",
+ " 110 | \n",
+ " 0.853000 | \n",
+ "
\n",
+ " \n",
+ " 111 | \n",
+ " 0.788900 | \n",
+ "
\n",
+ " \n",
+ " 112 | \n",
+ " 0.677800 | \n",
+ "
\n",
+ " \n",
+ " 113 | \n",
+ " 1.074000 | \n",
+ "
\n",
+ " \n",
+ " 114 | \n",
+ " 0.813000 | \n",
+ "
\n",
+ " \n",
+ " 115 | \n",
+ " 1.067300 | \n",
+ "
\n",
+ " \n",
+ " 116 | \n",
+ " 0.730000 | \n",
+ "
\n",
+ " \n",
+ " 117 | \n",
+ " 0.936800 | \n",
+ "
\n",
+ " \n",
+ " 118 | \n",
+ " 0.852000 | \n",
+ "
\n",
+ " \n",
+ " 119 | \n",
+ " 0.769600 | \n",
+ "
\n",
+ " \n",
+ " 120 | \n",
+ " 0.876800 | \n",
+ "
\n",
+ " \n",
+ " 121 | \n",
+ " 0.965800 | \n",
+ "
\n",
+ " \n",
+ " 122 | \n",
+ " 1.302800 | \n",
+ "
\n",
+ " \n",
+ " 123 | \n",
+ " 0.890500 | \n",
+ "
\n",
+ " \n",
+ " 124 | \n",
+ " 0.844400 | \n",
+ "
\n",
+ " \n",
+ " 125 | \n",
+ " 0.914500 | \n",
+ "
\n",
+ " \n",
+ " 126 | \n",
+ " 0.831100 | \n",
+ "
\n",
+ " \n",
+ " 127 | \n",
+ " 1.012000 | \n",
+ "
\n",
+ " \n",
+ " 128 | \n",
+ " 0.933600 | \n",
+ "
\n",
+ " \n",
+ " 129 | \n",
+ " 0.750600 | \n",
+ "
\n",
+ " \n",
+ " 130 | \n",
+ " 0.695000 | \n",
+ "
\n",
+ " \n",
+ " 131 | \n",
+ " 0.944900 | \n",
+ "
\n",
+ " \n",
+ " 132 | \n",
+ " 0.907700 | \n",
+ "
\n",
+ " \n",
+ " 133 | \n",
+ " 0.851100 | \n",
+ "
\n",
+ " \n",
+ " 134 | \n",
+ " 0.911900 | \n",
+ "
\n",
+ " \n",
+ " 135 | \n",
+ " 0.914100 | \n",
+ "
\n",
+ " \n",
+ " 136 | \n",
+ " 0.922600 | \n",
+ "
\n",
+ " \n",
+ " 137 | \n",
+ " 0.921800 | \n",
+ "
\n",
+ " \n",
+ " 138 | \n",
+ " 0.580200 | \n",
+ "
\n",
+ " \n",
+ " 139 | \n",
+ " 0.880200 | \n",
+ "
\n",
+ " \n",
+ " 140 | \n",
+ " 0.689700 | \n",
+ "
\n",
+ " \n",
+ " 141 | \n",
+ " 0.662000 | \n",
+ "
\n",
+ " \n",
+ " 142 | \n",
+ " 1.098300 | \n",
+ "
\n",
+ " \n",
+ " 143 | \n",
+ " 0.895300 | \n",
+ "
\n",
+ " \n",
+ " 144 | \n",
+ " 0.875000 | \n",
+ "
\n",
+ " \n",
+ " 145 | \n",
+ " 1.023400 | \n",
+ "
\n",
+ " \n",
+ " 146 | \n",
+ " 0.606500 | \n",
+ "
\n",
+ " \n",
+ " 147 | \n",
+ " 0.867700 | \n",
+ "
\n",
+ " \n",
+ " 148 | \n",
+ " 1.160800 | \n",
+ "
\n",
+ " \n",
+ " 149 | \n",
+ " 0.899800 | \n",
+ "
\n",
+ " \n",
+ " 150 | \n",
+ " 0.679200 | \n",
+ "
\n",
+ " \n",
+ " 151 | \n",
+ " 0.701300 | \n",
+ "
\n",
+ " \n",
+ " 152 | \n",
+ " 1.261700 | \n",
+ "
\n",
+ " \n",
+ " 153 | \n",
+ " 0.837900 | \n",
+ "
\n",
+ " \n",
+ " 154 | \n",
+ " 0.955800 | \n",
+ "
\n",
+ " \n",
+ " 155 | \n",
+ " 1.104200 | \n",
+ "
\n",
+ " \n",
+ " 156 | \n",
+ " 0.684500 | \n",
+ "
\n",
+ " \n",
+ " 157 | \n",
+ " 0.828300 | \n",
+ "
\n",
+ " \n",
+ " 158 | \n",
+ " 0.778700 | \n",
+ "
\n",
+ " \n",
+ " 159 | \n",
+ " 0.596200 | \n",
+ "
\n",
+ " \n",
+ " 160 | \n",
+ " 0.691600 | \n",
+ "
\n",
+ " \n",
+ "
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/plain": [
+ "TrainOutput(global_step=160, training_loss=1.8560316052287817, metrics={'train_runtime': 518.0598, 'train_samples_per_second': 1.235, 'train_steps_per_second': 0.309, 'total_flos': 5380710900916224.0, 'train_loss': 1.8560316052287817, 'epoch': 0.71})"
+ ]
+ },
+ "execution_count": 25,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "training_args = transformers.TrainingArguments(\n",
+ " per_device_train_batch_size=1,\n",
+ " gradient_accumulation_steps=4,\n",
+ " num_train_epochs=1,\n",
+ " learning_rate=2e-4,\n",
+ " fp16=True,\n",
+ " save_total_limit=3,\n",
+ " logging_steps=1,\n",
+ " output_dir=OUTPUT_DIR,\n",
+ " max_steps=160,\n",
+ " optim=\"paged_adamw_8bit\",\n",
+ " lr_scheduler_type=\"cosine\",\n",
+ " warmup_ratio=0.05,\n",
+ " report_to=\"tensorboard\",\n",
+ ")\n",
+ "\n",
+ "trainer = transformers.Trainer(\n",
+ " model=model,\n",
+ " train_dataset=data,\n",
+ " args=training_args,\n",
+ " data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),\n",
+ ")\n",
+ "model.config.use_cache = False\n",
+ "trainer.train()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "XEMYxmFFZevu"
+ },
+ "source": [
+ "## Save Trained Model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 26,
+ "metadata": {
+ "id": "__D-YiF4i4E-"
+ },
+ "outputs": [],
+ "source": [
+ "model.save_pretrained(\"trained-model\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 27,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 240,
+ "referenced_widgets": [
+ "1765001f2e8a45789c303bf25802ec26",
+ "b9342727e5b04a31a44a2bb39eb9b972",
+ "fb9cd1d7b199467ca84fdee089c3e616",
+ "cc2c910395a14a1b876e76a9f1967159",
+ "e6870d171e2e4f49b29940a584f8234f",
+ "6a21b3db09d043a09364c677e380fd1a",
+ "b47756ecdc5544fb881978069650dff7",
+ "8623daa5b6104e01aa049cce796d57a5",
+ "6a369f47577d4a1ab89fc34e893295e3",
+ "ec6c022167c549949cc2d4261e6b388a",
+ "defa462b8d0c425b8a3c5c10d3824449",
+ "ed46a43aa3b74664b7be0c5393af55c7",
+ "ea756bf0eec44265a894ca2c007ecdd3",
+ "6e2ec6222dbc4dafbbe3379d89e74a33",
+ "05c0d40b5c01447790d0ceccebc03d0f",
+ "3e72185f6ba44316818d958ff112e6a5",
+ "535b91b499d64a19a74ad64b96ff520b",
+ "b20a32370ca94e4490ec2541fedf4629",
+ "c770c0437f13471092d8b7bd92169f27",
+ "19e71863a0384ec8851bde3abbb3fa68",
+ "d0d5e191b9ea48fc83e8d14eef28807e",
+ "55aa6981a6304872b253b5452dd0399f",
+ "7cf8792fe92346e4be71b1cc0707c40f",
+ "95ea93ef06164a6792df160fc5a856e1",
+ "2c8a23857bfb41a28a4dcbaf360c5152",
+ "d81c35bd6c964b8396a93ae5ce9b7035",
+ "5f5378a779524d05bd130689944face8",
+ "7a12995dd12f4aa8ac273ed26fe8c774",
+ "e5a51c252b4642828528866ef7a71f61",
+ "996b7bc4804747bc8e3d3bff8c09bc4e",
+ "ef942a8f706549fd9d0ff97019b23d07",
+ "1f5ec464af694baeb020c677c9804203",
+ "05b4e2dcdb774de3847a1af4328a2d1e"
+ ]
+ },
+ "id": "SBTcZs_EODfg",
+ "outputId": "4727fb36-a7a7-4f27-f6c7-ce93ced846d9"
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/opt/conda/lib/python3.10/site-packages/huggingface_hub/_commit_api.py:282: UserWarning: About to update multiple times the same file in the same commit: 'adapter_model.bin'. This can cause undesired inconsistencies in your repo.\n",
+ " warnings.warn(\n",
+ "/opt/conda/lib/python3.10/site-packages/huggingface_hub/_commit_api.py:282: UserWarning: About to update multiple times the same file in the same commit: 'adapter_config.json'. This can cause undesired inconsistencies in your repo.\n",
+ " warnings.warn(\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "f63eb70685e141d9844ff71b1ff5d54b",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Upload 2 LFS files: 0%| | 0/2 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "72e97297f30c41949fed27d0f783e03d",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "adapter_model.bin: 0%| | 0.00/33.6M [00:00, ?B/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "857ada0243b94ff3961189c61d87ef6a",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "adapter_model.bin: 0%| | 0.00/33.6M [00:00, ?B/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/plain": [
+ "CommitInfo(commit_url='https://huggingface.co/shanjay/mgc-ds/commit/2b44025091c4853cbff8c5cc1117465f7dc7b7cf', commit_message='Upload model', commit_description='', oid='2b44025091c4853cbff8c5cc1117465f7dc7b7cf', pr_url=None, pr_revision=None, pr_num=None)"
+ ]
+ },
+ "execution_count": 27,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "model.push_to_hub(\n",
+ "    \"shanjay/mgc-ds\", token=True\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "KIlYhwJhZgjb"
+ },
+ "source": [
+ "## Load Trained Model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 28,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 113,
+ "referenced_widgets": [
+ "66408bee12974f96b42c7ebef03f5fb1",
+ "a1b2b3007e4d45039abb88e2d867f7ea",
+ "8ba08878f4e34cf592c5022b4be1c45f",
+ "041ed7f21ecd4a7fbae240a5bd0c20a9",
+ "3a083454009c431181869fd7bfd4ac9a",
+ "fcb4d59661c44d27932e0194d6b0aa1c",
+ "eee58e8c4ff641559649c3b82961a4c3",
+ "b5887b1d286b4cdc9255f55c8334cfb3",
+ "8cec295bad7447d1ab45aaa6c56bbef7",
+ "fa64654a10814ad3a11aa11af4cf36ce",
+ "f82c5b1d7d8f49e3ac3708b6a91f71f2",
+ "9c1d89fdf66243548435317b75da3595",
+ "ba6f3f3d687c474c9bacb2aff2f6cbfa",
+ "da3010bd98e94beb9defdf825db3630d",
+ "57a3f9fdddd0418aa96c99fd894726c3",
+ "e9b09df570954dd78079b1c6e65429fa",
+ "26482d9ce0944ce8adb7aff5d2174e06",
+ "42a40156d8274b3195219f6811db3d92",
+ "dd3861e35f264c629476755db522ea58",
+ "ce1470c542724a97874b595d0002c2d7",
+ "d4337a0b3e244c938a26bf4768a03854",
+ "a2802ba1c11c4fe6b63f03a931ed7a67",
+ "77e56c383fbf44d9ae314136db39adaf",
+ "551736eb29504faaa884d6d00937e0d2",
+ "1af09897d7e446449269106afcb079a8",
+ "05cf083266ab4da6af75da41d784a973",
+ "198f3f9771734a72b2a660ffeb178709",
+ "2d68eaef329c417bad449df84e664463",
+ "21f3873dc2914858b80617074a79e71d",
+ "e79cca46e0a14d1c9d4e1ab17620862e",
+ "d5e7ccf8574540a2b60138653d856e5d",
+ "ab5655e1b7c948ddb07428d5d2e7d00f",
+ "956740f8080e4838b8c5a67096c6894d"
+ ]
+ },
+ "id": "owt5PmrcjG5C",
+ "outputId": "9b79908f-5510-476b-d086-626686305190",
+ "scrolled": true
+ },
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "b14b229f1c8441c68f2a034e44f85fb1",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "adapter_config.json: 0%| | 0.00/427 [00:00, ?B/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "edd588118cb046daa2cdb754eb89647a",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Loading checkpoint shards: 0%| | 0/6 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Some weights of LlamaForCausalLM were not initialized from the model checkpoint at ise-uiuc/Magicoder-S-DS-6.7B and are newly initialized: ['model.layers.14.self_attn.rotary_emb.inv_freq', 'model.layers.26.self_attn.rotary_emb.inv_freq', 'model.layers.5.self_attn.rotary_emb.inv_freq', 'model.layers.17.self_attn.rotary_emb.inv_freq', 'model.layers.18.self_attn.rotary_emb.inv_freq', 'model.layers.29.self_attn.rotary_emb.inv_freq', 'model.layers.11.self_attn.rotary_emb.inv_freq', 'model.layers.7.self_attn.rotary_emb.inv_freq', 'model.layers.2.self_attn.rotary_emb.inv_freq', 'model.layers.16.self_attn.rotary_emb.inv_freq', 'model.layers.31.self_attn.rotary_emb.inv_freq', 'model.layers.8.self_attn.rotary_emb.inv_freq', 'model.layers.19.self_attn.rotary_emb.inv_freq', 'model.layers.0.self_attn.rotary_emb.inv_freq', 'model.layers.9.self_attn.rotary_emb.inv_freq', 'model.layers.3.self_attn.rotary_emb.inv_freq', 'model.layers.30.self_attn.rotary_emb.inv_freq', 'model.layers.15.self_attn.rotary_emb.inv_freq', 'model.layers.25.self_attn.rotary_emb.inv_freq', 'model.layers.23.self_attn.rotary_emb.inv_freq', 'model.layers.6.self_attn.rotary_emb.inv_freq', 'model.layers.13.self_attn.rotary_emb.inv_freq', 'model.layers.1.self_attn.rotary_emb.inv_freq', 'model.layers.12.self_attn.rotary_emb.inv_freq', 'model.layers.10.self_attn.rotary_emb.inv_freq', 'model.layers.20.self_attn.rotary_emb.inv_freq', 'model.layers.27.self_attn.rotary_emb.inv_freq', 'model.layers.24.self_attn.rotary_emb.inv_freq', 'model.layers.22.self_attn.rotary_emb.inv_freq', 'model.layers.21.self_attn.rotary_emb.inv_freq', 'model.layers.4.self_attn.rotary_emb.inv_freq', 'model.layers.28.self_attn.rotary_emb.inv_freq']\n",
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "eb18d64297354b7d9201d54a1eb7a8f2",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "adapter_model.bin: 0%| | 0.00/33.6M [00:00, ?B/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "PEFT_MODEL = \"shanjay/mgc-ds\"\n",
+ "\n",
+ "config = PeftConfig.from_pretrained(PEFT_MODEL)\n",
+ "model = AutoModelForCausalLM.from_pretrained(\n",
+ " config.base_model_name_or_path,\n",
+ " return_dict=True,\n",
+ " quantization_config=bnb_config,\n",
+ " device_map=\"auto\",\n",
+ " trust_remote_code=True,\n",
+ ")\n",
+ "tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)\n",
+ "tokenizer.pad_token = tokenizer.eos_token\n",
+ "\n",
+ "model = PeftModel.from_pretrained(model, PEFT_MODEL)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "1XvuuaswZk4e"
+ },
+ "source": [
+ "## Inference"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 29,
+ "metadata": {
+ "id": "LAmXdHItPgQV"
+ },
+ "outputs": [],
+ "source": [
+ "generation_config = model.generation_config\n",
+ "generation_config.max_new_tokens = 200\n",
+ "generation_config.temperature = 0.7\n",
+ "generation_config.top_p = 0.7\n",
+ "generation_config.num_return_sequences = 1\n",
+ "generation_config.pad_token_id = tokenizer.eos_token_id\n",
+ "generation_config.eos_token_id = tokenizer.eos_token_id"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 30,
+ "metadata": {
+ "id": "qsS-RbE_UwJW"
+ },
+ "outputs": [],
+ "source": [
+ "DEVICE = \"cuda:0\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 31,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "ivZI5MCcTCRC",
+ "outputId": "1015d4d6-82ed-4650-d53e-b541bb8444e7"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ ": How can I create a dataframe?\n",
+ "