"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# %reload_ext loads the extension on first run and reloads it on re-runs,\n",
"# so the cell stays idempotent (no trailing reload needed after launching).\n",
"# NOTE(review): Trainer's default logging_dir is <output_dir>/runs/...;\n",
"# this assumes OUTPUT_DIR == \"experiments\" — confirm.\n",
"%reload_ext tensorboard\n",
"%tensorboard --logdir experiments/runs"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1000
},
"id": "G0tXhYO5ea5l",
"outputId": "00c73bb3-0e4a-45d3-c235-9088b79477de",
"scrolled": true
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"You're using a LlamaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n",
"/opt/conda/lib/python3.10/site-packages/torch/utils/checkpoint.py:429: UserWarning: torch.utils.checkpoint: please pass in use_reentrant=True or use_reentrant=False explicitly. The default value of use_reentrant will be updated to be False in the future. To maintain current behavior, pass use_reentrant=True. It is recommended that you use use_reentrant=False. Refer to docs for more details on the differences between the two variants.\n",
" warnings.warn(\n"
]
},
{
"data": {
"text/html": [
"\n",
" \n",
" \n",
"
\n",
" [448/448 24:50, Epoch 1/2]\n",
"
\n",
" \n",
" \n",
" \n",
" Step | \n",
" Training Loss | \n",
"
\n",
" \n",
" \n",
" \n",
" 1 | \n",
" 12.814700 | \n",
"
\n",
" \n",
" 2 | \n",
" 13.697400 | \n",
"
\n",
" \n",
" 3 | \n",
" 12.960800 | \n",
"
\n",
" \n",
" 4 | \n",
" 13.333700 | \n",
"
\n",
" \n",
" 5 | \n",
" 12.766400 | \n",
"
\n",
" \n",
" 6 | \n",
" 13.581100 | \n",
"
\n",
" \n",
" 7 | \n",
" 12.539900 | \n",
"
\n",
" \n",
" 8 | \n",
" 12.319500 | \n",
"
\n",
" \n",
" 9 | \n",
" 11.367000 | \n",
"
\n",
" \n",
" 10 | \n",
" 10.986600 | \n",
"
\n",
" \n",
" 11 | \n",
" 10.852100 | \n",
"
\n",
" \n",
" 12 | \n",
" 10.185600 | \n",
"
\n",
" \n",
" 13 | \n",
" 9.537600 | \n",
"
\n",
" \n",
" 14 | \n",
" 8.335100 | \n",
"
\n",
" \n",
" 15 | \n",
" 7.010200 | \n",
"
\n",
" \n",
" 16 | \n",
" 6.907400 | \n",
"
\n",
" \n",
" 17 | \n",
" 5.939900 | \n",
"
\n",
" \n",
" 18 | \n",
" 5.495400 | \n",
"
\n",
" \n",
" 19 | \n",
" 4.819800 | \n",
"
\n",
" \n",
" 20 | \n",
" 4.209600 | \n",
"
\n",
" \n",
" 21 | \n",
" 3.982900 | \n",
"
\n",
" \n",
" 22 | \n",
" 3.795100 | \n",
"
\n",
" \n",
" 23 | \n",
" 3.108400 | \n",
"
\n",
" \n",
" 24 | \n",
" 2.894000 | \n",
"
\n",
" \n",
" 25 | \n",
" 2.691100 | \n",
"
\n",
" \n",
" 26 | \n",
" 2.841600 | \n",
"
\n",
" \n",
" 27 | \n",
" 2.394400 | \n",
"
\n",
" \n",
" 28 | \n",
" 2.180700 | \n",
"
\n",
" \n",
" 29 | \n",
" 2.338900 | \n",
"
\n",
" \n",
" 30 | \n",
" 1.686200 | \n",
"
\n",
" \n",
" 31 | \n",
" 1.859100 | \n",
"
\n",
" \n",
" 32 | \n",
" 1.520700 | \n",
"
\n",
" \n",
" 33 | \n",
" 1.582500 | \n",
"
\n",
" \n",
" 34 | \n",
" 1.144400 | \n",
"
\n",
" \n",
" 35 | \n",
" 1.368300 | \n",
"
\n",
" \n",
" 36 | \n",
" 0.990500 | \n",
"
\n",
" \n",
" 37 | \n",
" 1.010200 | \n",
"
\n",
" \n",
" 38 | \n",
" 1.229700 | \n",
"
\n",
" \n",
" 39 | \n",
" 1.153800 | \n",
"
\n",
" \n",
" 40 | \n",
" 0.891700 | \n",
"
\n",
" \n",
" 41 | \n",
" 1.138300 | \n",
"
\n",
" \n",
" 42 | \n",
" 0.996700 | \n",
"
\n",
" \n",
" 43 | \n",
" 1.284100 | \n",
"
\n",
" \n",
" 44 | \n",
" 1.055400 | \n",
"
\n",
" \n",
" 45 | \n",
" 1.098900 | \n",
"
\n",
" \n",
" 46 | \n",
" 1.184600 | \n",
"
\n",
" \n",
" 47 | \n",
" 1.129900 | \n",
"
\n",
" \n",
" 48 | \n",
" 1.223300 | \n",
"
\n",
" \n",
" 49 | \n",
" 1.037800 | \n",
"
\n",
" \n",
" 50 | \n",
" 0.783300 | \n",
"
\n",
" \n",
" 51 | \n",
" 0.861600 | \n",
"
\n",
" \n",
" 52 | \n",
" 0.918200 | \n",
"
\n",
" \n",
" 53 | \n",
" 1.042900 | \n",
"
\n",
" \n",
" 54 | \n",
" 1.283900 | \n",
"
\n",
" \n",
" 55 | \n",
" 0.664500 | \n",
"
\n",
" \n",
" 56 | \n",
" 0.998500 | \n",
"
\n",
" \n",
" 57 | \n",
" 0.890900 | \n",
"
\n",
" \n",
" 58 | \n",
" 1.372800 | \n",
"
\n",
" \n",
" 59 | \n",
" 0.764400 | \n",
"
\n",
" \n",
" 60 | \n",
" 1.194000 | \n",
"
\n",
" \n",
" 61 | \n",
" 1.111300 | \n",
"
\n",
" \n",
" 62 | \n",
" 1.113600 | \n",
"
\n",
" \n",
" 63 | \n",
" 1.198000 | \n",
"
\n",
" \n",
" 64 | \n",
" 0.987000 | \n",
"
\n",
" \n",
" 65 | \n",
" 1.191600 | \n",
"
\n",
" \n",
" 66 | \n",
" 0.791000 | \n",
"
\n",
" \n",
" 67 | \n",
" 0.851000 | \n",
"
\n",
" \n",
" 68 | \n",
" 1.068800 | \n",
"
\n",
" \n",
" 69 | \n",
" 0.987200 | \n",
"
\n",
" \n",
" 70 | \n",
" 0.840300 | \n",
"
\n",
" \n",
" 71 | \n",
" 0.995800 | \n",
"
\n",
" \n",
" 72 | \n",
" 0.947700 | \n",
"
\n",
" \n",
" 73 | \n",
" 0.633000 | \n",
"
\n",
" \n",
" 74 | \n",
" 0.747000 | \n",
"
\n",
" \n",
" 75 | \n",
" 0.995600 | \n",
"
\n",
" \n",
" 76 | \n",
" 1.260800 | \n",
"
\n",
" \n",
" 77 | \n",
" 1.063600 | \n",
"
\n",
" \n",
" 78 | \n",
" 1.100100 | \n",
"
\n",
" \n",
" 79 | \n",
" 0.863500 | \n",
"
\n",
" \n",
" 80 | \n",
" 0.942200 | \n",
"
\n",
" \n",
" 81 | \n",
" 1.008200 | \n",
"
\n",
" \n",
" 82 | \n",
" 1.013100 | \n",
"
\n",
" \n",
" 83 | \n",
" 1.201100 | \n",
"
\n",
" \n",
" 84 | \n",
" 0.996700 | \n",
"
\n",
" \n",
" 85 | \n",
" 0.986900 | \n",
"
\n",
" \n",
" 86 | \n",
" 0.387100 | \n",
"
\n",
" \n",
" 87 | \n",
" 0.920500 | \n",
"
\n",
" \n",
" 88 | \n",
" 0.916600 | \n",
"
\n",
" \n",
" 89 | \n",
" 1.078900 | \n",
"
\n",
" \n",
" 90 | \n",
" 0.968400 | \n",
"
\n",
" \n",
" 91 | \n",
" 0.620100 | \n",
"
\n",
" \n",
" 92 | \n",
" 0.244100 | \n",
"
\n",
" \n",
" 93 | \n",
" 1.460300 | \n",
"
\n",
" \n",
" 94 | \n",
" 0.766300 | \n",
"
\n",
" \n",
" 95 | \n",
" 1.270700 | \n",
"
\n",
" \n",
" 96 | \n",
" 0.853500 | \n",
"
\n",
" \n",
" 97 | \n",
" 1.294200 | \n",
"
\n",
" \n",
" 98 | \n",
" 0.685400 | \n",
"
\n",
" \n",
" 99 | \n",
" 0.773200 | \n",
"
\n",
" \n",
" 100 | \n",
" 0.830100 | \n",
"
\n",
" \n",
" 101 | \n",
" 1.346000 | \n",
"
\n",
" \n",
" 102 | \n",
" 0.662700 | \n",
"
\n",
" \n",
" 103 | \n",
" 1.044100 | \n",
"
\n",
" \n",
" 104 | \n",
" 1.011200 | \n",
"
\n",
" \n",
" 105 | \n",
" 0.709300 | \n",
"
\n",
" \n",
" 106 | \n",
" 0.690500 | \n",
"
\n",
" \n",
" 107 | \n",
" 0.506500 | \n",
"
\n",
" \n",
" 108 | \n",
" 0.757100 | \n",
"
\n",
" \n",
" 109 | \n",
" 1.117300 | \n",
"
\n",
" \n",
" 110 | \n",
" 1.001700 | \n",
"
\n",
" \n",
" 111 | \n",
" 1.274600 | \n",
"
\n",
" \n",
" 112 | \n",
" 1.047600 | \n",
"
\n",
" \n",
" 113 | \n",
" 0.978200 | \n",
"
\n",
" \n",
" 114 | \n",
" 0.664500 | \n",
"
\n",
" \n",
" 115 | \n",
" 0.613600 | \n",
"
\n",
" \n",
" 116 | \n",
" 0.807700 | \n",
"
\n",
" \n",
" 117 | \n",
" 1.063500 | \n",
"
\n",
" \n",
" 118 | \n",
" 2.378700 | \n",
"
\n",
" \n",
" 119 | \n",
" 0.762800 | \n",
"
\n",
" \n",
" 120 | \n",
" 0.669000 | \n",
"
\n",
" \n",
" 121 | \n",
" 0.161600 | \n",
"
\n",
" \n",
" 122 | \n",
" 2.231600 | \n",
"
\n",
" \n",
" 123 | \n",
" 1.073700 | \n",
"
\n",
" \n",
" 124 | \n",
" 1.289300 | \n",
"
\n",
" \n",
" 125 | \n",
" 2.160400 | \n",
"
\n",
" \n",
" 126 | \n",
" 1.919500 | \n",
"
\n",
" \n",
" 127 | \n",
" 0.471700 | \n",
"
\n",
" \n",
" 128 | \n",
" 1.039600 | \n",
"
\n",
" \n",
" 129 | \n",
" 0.448000 | \n",
"
\n",
" \n",
" 130 | \n",
" 0.000000 | \n",
"
\n",
" \n",
" 131 | \n",
" 0.782300 | \n",
"
\n",
" \n",
" 132 | \n",
" 0.861400 | \n",
"
\n",
" \n",
" 133 | \n",
" 0.236500 | \n",
"
\n",
" \n",
" 134 | \n",
" 0.274500 | \n",
"
\n",
" \n",
" 135 | \n",
" 1.382100 | \n",
"
\n",
" \n",
" 136 | \n",
" 0.295800 | \n",
"
\n",
" \n",
" 137 | \n",
" 0.817400 | \n",
"
\n",
" \n",
" 138 | \n",
" 1.276000 | \n",
"
\n",
" \n",
" 139 | \n",
" 2.482400 | \n",
"
\n",
" \n",
" 140 | \n",
" 0.080000 | \n",
"
\n",
" \n",
" 141 | \n",
" 1.378000 | \n",
"
\n",
" \n",
" 142 | \n",
" 0.585800 | \n",
"
\n",
" \n",
" 143 | \n",
" 0.409700 | \n",
"
\n",
" \n",
" 144 | \n",
" 0.641100 | \n",
"
\n",
" \n",
" 145 | \n",
" 0.709000 | \n",
"
\n",
" \n",
" 146 | \n",
" 0.618000 | \n",
"
\n",
" \n",
" 147 | \n",
" 0.000000 | \n",
"
\n",
" \n",
" 148 | \n",
" 0.888400 | \n",
"
\n",
" \n",
" 149 | \n",
" 2.182600 | \n",
"
\n",
" \n",
" 150 | \n",
" 0.875900 | \n",
"
\n",
" \n",
" 151 | \n",
" 0.226900 | \n",
"
\n",
" \n",
" 152 | \n",
" 0.839500 | \n",
"
\n",
" \n",
" 153 | \n",
" 1.874100 | \n",
"
\n",
" \n",
" 154 | \n",
" 0.512100 | \n",
"
\n",
" \n",
" 155 | \n",
" 1.362000 | \n",
"
\n",
" \n",
" 156 | \n",
" 0.158800 | \n",
"
\n",
" \n",
" 157 | \n",
" 0.866400 | \n",
"
\n",
" \n",
" 158 | \n",
" 1.023500 | \n",
"
\n",
" \n",
" 159 | \n",
" 1.205100 | \n",
"
\n",
" \n",
" 160 | \n",
" 1.205600 | \n",
"
\n",
" \n",
" 161 | \n",
" 0.486500 | \n",
"
\n",
" \n",
" 162 | \n",
" 1.571700 | \n",
"
\n",
" \n",
" 163 | \n",
" 0.694100 | \n",
"
\n",
" \n",
" 164 | \n",
" 1.097500 | \n",
"
\n",
" \n",
" 165 | \n",
" 0.334000 | \n",
"
\n",
" \n",
" 166 | \n",
" 0.779400 | \n",
"
\n",
" \n",
" 167 | \n",
" 0.842800 | \n",
"
\n",
" \n",
" 168 | \n",
" 0.987100 | \n",
"
\n",
" \n",
" 169 | \n",
" 0.417900 | \n",
"
\n",
" \n",
" 170 | \n",
" 0.000000 | \n",
"
\n",
" \n",
" 171 | \n",
" 0.884600 | \n",
"
\n",
" \n",
" 172 | \n",
" 1.102700 | \n",
"
\n",
" \n",
" 173 | \n",
" 0.686900 | \n",
"
\n",
" \n",
" 174 | \n",
" 0.847400 | \n",
"
\n",
" \n",
" 175 | \n",
" 0.597300 | \n",
"
\n",
" \n",
" 176 | \n",
" 0.238700 | \n",
"
\n",
" \n",
" 177 | \n",
" 0.609500 | \n",
"
\n",
" \n",
" 178 | \n",
" 0.000000 | \n",
"
\n",
" \n",
" 179 | \n",
" 0.489100 | \n",
"
\n",
" \n",
" 180 | \n",
" 0.519500 | \n",
"
\n",
" \n",
" 181 | \n",
" 1.168700 | \n",
"
\n",
" \n",
" 182 | \n",
" 1.765000 | \n",
"
\n",
" \n",
" 183 | \n",
" 0.428500 | \n",
"
\n",
" \n",
" 184 | \n",
" 0.675400 | \n",
"
\n",
" \n",
" 185 | \n",
" 1.213000 | \n",
"
\n",
" \n",
" 186 | \n",
" 0.514700 | \n",
"
\n",
" \n",
" 187 | \n",
" 0.789700 | \n",
"
\n",
" \n",
" 188 | \n",
" 1.128900 | \n",
"
\n",
" \n",
" 189 | \n",
" 0.805000 | \n",
"
\n",
" \n",
" 190 | \n",
" 1.411600 | \n",
"
\n",
" \n",
" 191 | \n",
" 0.417500 | \n",
"
\n",
" \n",
" 192 | \n",
" 0.406900 | \n",
"
\n",
" \n",
" 193 | \n",
" 1.131900 | \n",
"
\n",
" \n",
" 194 | \n",
" 0.701900 | \n",
"
\n",
" \n",
" 195 | \n",
" 0.441700 | \n",
"
\n",
" \n",
" 196 | \n",
" 0.855900 | \n",
"
\n",
" \n",
" 197 | \n",
" 1.314900 | \n",
"
\n",
" \n",
" 198 | \n",
" 0.000000 | \n",
"
\n",
" \n",
" 199 | \n",
" 0.471100 | \n",
"
\n",
" \n",
" 200 | \n",
" 1.746400 | \n",
"
\n",
" \n",
" 201 | \n",
" 0.863600 | \n",
"
\n",
" \n",
" 202 | \n",
" 0.662000 | \n",
"
\n",
" \n",
" 203 | \n",
" 0.558200 | \n",
"
\n",
" \n",
" 204 | \n",
" 1.235200 | \n",
"
\n",
" \n",
" 205 | \n",
" 0.697600 | \n",
"
\n",
" \n",
" 206 | \n",
" 1.531700 | \n",
"
\n",
" \n",
" 207 | \n",
" 2.750600 | \n",
"
\n",
" \n",
" 208 | \n",
" 2.015200 | \n",
"
\n",
" \n",
" 209 | \n",
" 0.366600 | \n",
"
\n",
" \n",
" 210 | \n",
" 0.991000 | \n",
"
\n",
" \n",
" 211 | \n",
" 0.455500 | \n",
"
\n",
" \n",
" 212 | \n",
" 1.305800 | \n",
"
\n",
" \n",
" 213 | \n",
" 0.089700 | \n",
"
\n",
" \n",
" 214 | \n",
" 0.514500 | \n",
"
\n",
" \n",
" 215 | \n",
" 0.298500 | \n",
"
\n",
" \n",
" 216 | \n",
" 1.117000 | \n",
"
\n",
" \n",
" 217 | \n",
" 0.613800 | \n",
"
\n",
" \n",
" 218 | \n",
" 1.231100 | \n",
"
\n",
" \n",
" 219 | \n",
" 2.003700 | \n",
"
\n",
" \n",
" 220 | \n",
" 0.390000 | \n",
"
\n",
" \n",
" 221 | \n",
" 0.508700 | \n",
"
\n",
" \n",
" 222 | \n",
" 0.949100 | \n",
"
\n",
" \n",
" 223 | \n",
" 0.000000 | \n",
"
\n",
" \n",
" 224 | \n",
" 2.724900 | \n",
"
\n",
" \n",
" 225 | \n",
" 0.997100 | \n",
"
\n",
" \n",
" 226 | \n",
" 1.501400 | \n",
"
\n",
" \n",
" 227 | \n",
" 0.938600 | \n",
"
\n",
" \n",
" 228 | \n",
" 0.620900 | \n",
"
\n",
" \n",
" 229 | \n",
" 0.000000 | \n",
"
\n",
" \n",
" 230 | \n",
" 1.187500 | \n",
"
\n",
" \n",
" 231 | \n",
" 0.299500 | \n",
"
\n",
" \n",
" 232 | \n",
" 2.285400 | \n",
"
\n",
" \n",
" 233 | \n",
" 1.262700 | \n",
"
\n",
" \n",
" 234 | \n",
" 0.273800 | \n",
"
\n",
" \n",
" 235 | \n",
" 1.746800 | \n",
"
\n",
" \n",
" 236 | \n",
" 0.538200 | \n",
"
\n",
" \n",
" 237 | \n",
" 0.767000 | \n",
"
\n",
" \n",
" 238 | \n",
" 1.133400 | \n",
"
\n",
" \n",
" 239 | \n",
" 1.064600 | \n",
"
\n",
" \n",
" 240 | \n",
" 0.591700 | \n",
"
\n",
" \n",
" 241 | \n",
" 0.981300 | \n",
"
\n",
" \n",
" 242 | \n",
" 0.894800 | \n",
"
\n",
" \n",
" 243 | \n",
" 0.532700 | \n",
"
\n",
" \n",
" 244 | \n",
" 1.896300 | \n",
"
\n",
" \n",
" 245 | \n",
" 0.157000 | \n",
"
\n",
" \n",
" 246 | \n",
" 0.331900 | \n",
"
\n",
" \n",
" 247 | \n",
" 0.458000 | \n",
"
\n",
" \n",
" 248 | \n",
" 0.309100 | \n",
"
\n",
" \n",
" 249 | \n",
" 1.084300 | \n",
"
\n",
" \n",
" 250 | \n",
" 0.640100 | \n",
"
\n",
" \n",
" 251 | \n",
" 0.937000 | \n",
"
\n",
" \n",
" 252 | \n",
" 0.799200 | \n",
"
\n",
" \n",
" 253 | \n",
" 0.879600 | \n",
"
\n",
" \n",
" 254 | \n",
" 1.224200 | \n",
"
\n",
" \n",
" 255 | \n",
" 1.198200 | \n",
"
\n",
" \n",
" 256 | \n",
" 1.216500 | \n",
"
\n",
" \n",
" 257 | \n",
" 1.556000 | \n",
"
\n",
" \n",
" 258 | \n",
" 0.000000 | \n",
"
\n",
" \n",
" 259 | \n",
" 0.000000 | \n",
"
\n",
" \n",
" 260 | \n",
" 0.417000 | \n",
"
\n",
" \n",
" 261 | \n",
" 0.547500 | \n",
"
\n",
" \n",
" 262 | \n",
" 0.000000 | \n",
"
\n",
" \n",
" 263 | \n",
" 0.483800 | \n",
"
\n",
" \n",
" 264 | \n",
" 0.590700 | \n",
"
\n",
" \n",
" 265 | \n",
" 1.072300 | \n",
"
\n",
" \n",
" 266 | \n",
" 1.545200 | \n",
"
\n",
" \n",
" 267 | \n",
" 0.627900 | \n",
"
\n",
" \n",
" 268 | \n",
" 0.000000 | \n",
"
\n",
" \n",
" 269 | \n",
" 1.594000 | \n",
"
\n",
" \n",
" 270 | \n",
" 0.303900 | \n",
"
\n",
" \n",
" 271 | \n",
" 1.394500 | \n",
"
\n",
" \n",
" 272 | \n",
" 1.145200 | \n",
"
\n",
" \n",
" 273 | \n",
" 0.000000 | \n",
"
\n",
" \n",
" 274 | \n",
" 1.038000 | \n",
"
\n",
" \n",
" 275 | \n",
" 1.045600 | \n",
"
\n",
" \n",
" 276 | \n",
" 0.793800 | \n",
"
\n",
" \n",
" 277 | \n",
" 0.000000 | \n",
"
\n",
" \n",
" 278 | \n",
" 0.959600 | \n",
"
\n",
" \n",
" 279 | \n",
" 0.893700 | \n",
"
\n",
" \n",
" 280 | \n",
" 0.618400 | \n",
"
\n",
" \n",
" 281 | \n",
" 1.273100 | \n",
"
\n",
" \n",
" 282 | \n",
" 0.741300 | \n",
"
\n",
" \n",
" 283 | \n",
" 0.342800 | \n",
"
\n",
" \n",
" 284 | \n",
" 0.579900 | \n",
"
\n",
" \n",
" 285 | \n",
" 0.095800 | \n",
"
\n",
" \n",
" 286 | \n",
" 2.692000 | \n",
"
\n",
" \n",
" 287 | \n",
" 0.909500 | \n",
"
\n",
" \n",
" 288 | \n",
" 1.106500 | \n",
"
\n",
" \n",
" 289 | \n",
" 0.725700 | \n",
"
\n",
" \n",
" 290 | \n",
" 0.092800 | \n",
"
\n",
" \n",
" 291 | \n",
" 0.445400 | \n",
"
\n",
" \n",
" 292 | \n",
" 0.638700 | \n",
"
\n",
" \n",
" 293 | \n",
" 1.117200 | \n",
"
\n",
" \n",
" 294 | \n",
" 1.338200 | \n",
"
\n",
" \n",
" 295 | \n",
" 0.570200 | \n",
"
\n",
" \n",
" 296 | \n",
" 0.000000 | \n",
"
\n",
" \n",
" 297 | \n",
" 1.295400 | \n",
"
\n",
" \n",
" 298 | \n",
" 0.000000 | \n",
"
\n",
" \n",
" 299 | \n",
" 0.464700 | \n",
"
\n",
" \n",
" 300 | \n",
" 2.450600 | \n",
"
\n",
" \n",
" 301 | \n",
" 0.426900 | \n",
"
\n",
" \n",
" 302 | \n",
" 0.830400 | \n",
"
\n",
" \n",
" 303 | \n",
" 0.663800 | \n",
"
\n",
" \n",
" 304 | \n",
" 0.590300 | \n",
"
\n",
" \n",
" 305 | \n",
" 0.301900 | \n",
"
\n",
" \n",
" 306 | \n",
" 2.835100 | \n",
"
\n",
" \n",
" 307 | \n",
" 1.676100 | \n",
"
\n",
" \n",
" 308 | \n",
" 0.993700 | \n",
"
\n",
" \n",
" 309 | \n",
" 1.127800 | \n",
"
\n",
" \n",
" 310 | \n",
" 0.846100 | \n",
"
\n",
" \n",
" 311 | \n",
" 0.636300 | \n",
"
\n",
" \n",
" 312 | \n",
" 0.589500 | \n",
"
\n",
" \n",
" 313 | \n",
" 1.250700 | \n",
"
\n",
" \n",
" 314 | \n",
" 0.781500 | \n",
"
\n",
" \n",
" 315 | \n",
" 0.000000 | \n",
"
\n",
" \n",
" 316 | \n",
" 0.659700 | \n",
"
\n",
" \n",
" 317 | \n",
" 0.514100 | \n",
"
\n",
" \n",
" 318 | \n",
" 1.730400 | \n",
"
\n",
" \n",
" 319 | \n",
" 0.320600 | \n",
"
\n",
" \n",
" 320 | \n",
" 0.274600 | \n",
"
\n",
" \n",
" 321 | \n",
" 0.923600 | \n",
"
\n",
" \n",
" 322 | \n",
" 1.327900 | \n",
"
\n",
" \n",
" 323 | \n",
" 0.779100 | \n",
"
\n",
" \n",
" 324 | \n",
" 0.700900 | \n",
"
\n",
" \n",
" 325 | \n",
" 0.865400 | \n",
"
\n",
" \n",
" 326 | \n",
" 1.052100 | \n",
"
\n",
" \n",
" 327 | \n",
" 0.574700 | \n",
"
\n",
" \n",
" 328 | \n",
" 0.338800 | \n",
"
\n",
" \n",
" 329 | \n",
" 0.188400 | \n",
"
\n",
" \n",
" 330 | \n",
" 0.181300 | \n",
"
\n",
" \n",
" 331 | \n",
" 0.901900 | \n",
"
\n",
" \n",
" 332 | \n",
" 0.618300 | \n",
"
\n",
" \n",
" 333 | \n",
" 1.031900 | \n",
"
\n",
" \n",
" 334 | \n",
" 1.050200 | \n",
"
\n",
" \n",
" 335 | \n",
" 0.754700 | \n",
"
\n",
" \n",
" 336 | \n",
" 0.852700 | \n",
"
\n",
" \n",
" 337 | \n",
" 2.501500 | \n",
"
\n",
" \n",
" 338 | \n",
" 0.495700 | \n",
"
\n",
" \n",
" 339 | \n",
" 2.482400 | \n",
"
\n",
" \n",
" 340 | \n",
" 0.407100 | \n",
"
\n",
" \n",
" 341 | \n",
" 1.133500 | \n",
"
\n",
" \n",
" 342 | \n",
" 1.867900 | \n",
"
\n",
" \n",
" 343 | \n",
" 1.067900 | \n",
"
\n",
" \n",
" 344 | \n",
" 0.716000 | \n",
"
\n",
" \n",
" 345 | \n",
" 0.185300 | \n",
"
\n",
" \n",
" 346 | \n",
" 0.342500 | \n",
"
\n",
" \n",
" 347 | \n",
" 1.124800 | \n",
"
\n",
" \n",
" 348 | \n",
" 1.322300 | \n",
"
\n",
" \n",
" 349 | \n",
" 0.510300 | \n",
"
\n",
" \n",
" 350 | \n",
" 0.533400 | \n",
"
\n",
" \n",
" 351 | \n",
" 0.000000 | \n",
"
\n",
" \n",
" 352 | \n",
" 0.821600 | \n",
"
\n",
" \n",
" 353 | \n",
" 0.289800 | \n",
"
\n",
" \n",
" 354 | \n",
" 1.929100 | \n",
"
\n",
" \n",
" 355 | \n",
" 0.078900 | \n",
"
\n",
" \n",
" 356 | \n",
" 0.333200 | \n",
"
\n",
" \n",
" 357 | \n",
" 0.721300 | \n",
"
\n",
" \n",
" 358 | \n",
" 1.022400 | \n",
"
\n",
" \n",
" 359 | \n",
" 0.097000 | \n",
"
\n",
" \n",
" 360 | \n",
" 0.522800 | \n",
"
\n",
" \n",
" 361 | \n",
" 0.190700 | \n",
"
\n",
" \n",
" 362 | \n",
" 0.382600 | \n",
"
\n",
" \n",
" 363 | \n",
" 0.905100 | \n",
"
\n",
" \n",
" 364 | \n",
" 0.482900 | \n",
"
\n",
" \n",
" 365 | \n",
" 1.020900 | \n",
"
\n",
" \n",
" 366 | \n",
" 1.468100 | \n",
"
\n",
" \n",
" 367 | \n",
" 1.571100 | \n",
"
\n",
" \n",
" 368 | \n",
" 2.384400 | \n",
"
\n",
" \n",
" 369 | \n",
" 0.539300 | \n",
"
\n",
" \n",
" 370 | \n",
" 0.187700 | \n",
"
\n",
" \n",
" 371 | \n",
" 0.098300 | \n",
"
\n",
" \n",
" 372 | \n",
" 0.000000 | \n",
"
\n",
" \n",
" 373 | \n",
" 0.232400 | \n",
"
\n",
" \n",
" 374 | \n",
" 0.696600 | \n",
"
\n",
" \n",
" 375 | \n",
" 0.360600 | \n",
"
\n",
" \n",
" 376 | \n",
" 0.121300 | \n",
"
\n",
" \n",
" 377 | \n",
" 1.139000 | \n",
"
\n",
" \n",
" 378 | \n",
" 0.851800 | \n",
"
\n",
" \n",
" 379 | \n",
" 1.320600 | \n",
"
\n",
" \n",
" 380 | \n",
" 1.886100 | \n",
"
\n",
" \n",
" 381 | \n",
" 0.105100 | \n",
"
\n",
" \n",
" 382 | \n",
" 1.109300 | \n",
"
\n",
" \n",
" 383 | \n",
" 2.284200 | \n",
"
\n",
" \n",
" 384 | \n",
" 0.817400 | \n",
"
\n",
" \n",
" 385 | \n",
" 0.154600 | \n",
"
\n",
" \n",
" 386 | \n",
" 0.115300 | \n",
"
\n",
" \n",
" 387 | \n",
" 2.456400 | \n",
"
\n",
" \n",
" 388 | \n",
" 0.565700 | \n",
"
\n",
" \n",
" 389 | \n",
" 0.345400 | \n",
"
\n",
" \n",
" 390 | \n",
" 0.935200 | \n",
"
\n",
" \n",
" 391 | \n",
" 1.356900 | \n",
"
\n",
" \n",
" 392 | \n",
" 0.776300 | \n",
"
\n",
" \n",
" 393 | \n",
" 1.266200 | \n",
"
\n",
" \n",
" 394 | \n",
" 0.145500 | \n",
"
\n",
" \n",
" 395 | \n",
" 0.566400 | \n",
"
\n",
" \n",
" 396 | \n",
" 0.083300 | \n",
"
\n",
" \n",
" 397 | \n",
" 1.588300 | \n",
"
\n",
" \n",
" 398 | \n",
" 0.107100 | \n",
"
\n",
" \n",
" 399 | \n",
" 0.658600 | \n",
"
\n",
" \n",
" 400 | \n",
" 0.313300 | \n",
"
\n",
" \n",
" 401 | \n",
" 0.505500 | \n",
"
\n",
" \n",
" 402 | \n",
" 0.634500 | \n",
"
\n",
" \n",
" 403 | \n",
" 0.785100 | \n",
"
\n",
" \n",
" 404 | \n",
" 0.244700 | \n",
"
\n",
" \n",
" 405 | \n",
" 0.811900 | \n",
"
\n",
" \n",
" 406 | \n",
" 1.668400 | \n",
"
\n",
" \n",
" 407 | \n",
" 0.600500 | \n",
"
\n",
" \n",
" 408 | \n",
" 2.731800 | \n",
"
\n",
" \n",
" 409 | \n",
" 0.893100 | \n",
"
\n",
" \n",
" 410 | \n",
" 0.973200 | \n",
"
\n",
" \n",
" 411 | \n",
" 2.966500 | \n",
"
\n",
" \n",
" 412 | \n",
" 0.673500 | \n",
"
\n",
" \n",
" 413 | \n",
" 0.639700 | \n",
"
\n",
" \n",
" 414 | \n",
" 0.234400 | \n",
"
\n",
" \n",
" 415 | \n",
" 0.000000 | \n",
"
\n",
" \n",
" 416 | \n",
" 0.285100 | \n",
"
\n",
" \n",
" 417 | \n",
" 1.262300 | \n",
"
\n",
" \n",
" 418 | \n",
" 0.000000 | \n",
"
\n",
" \n",
" 419 | \n",
" 1.122000 | \n",
"
\n",
" \n",
" 420 | \n",
" 0.414500 | \n",
"
\n",
" \n",
" 421 | \n",
" 1.577000 | \n",
"
\n",
" \n",
" 422 | \n",
" 1.034600 | \n",
"
\n",
" \n",
" 423 | \n",
" 0.704000 | \n",
"
\n",
" \n",
" 424 | \n",
" 0.000000 | \n",
"
\n",
" \n",
" 425 | \n",
" 2.343700 | \n",
"
\n",
" \n",
" 426 | \n",
" 1.488700 | \n",
"
\n",
" \n",
" 427 | \n",
" 0.875100 | \n",
"
\n",
" \n",
" 428 | \n",
" 0.331800 | \n",
"
\n",
" \n",
" 429 | \n",
" 0.412300 | \n",
"
\n",
" \n",
" 430 | \n",
" 1.658200 | \n",
"
\n",
" \n",
" 431 | \n",
" 1.794600 | \n",
"
\n",
" \n",
" 432 | \n",
" 1.353900 | \n",
"
\n",
" \n",
" 433 | \n",
" 0.273500 | \n",
"
\n",
" \n",
" 434 | \n",
" 1.162000 | \n",
"
\n",
" \n",
" 435 | \n",
" 0.561000 | \n",
"
\n",
" \n",
" 436 | \n",
" 0.929300 | \n",
"
\n",
" \n",
" 437 | \n",
" 2.154800 | \n",
"
\n",
" \n",
" 438 | \n",
" 0.343600 | \n",
"
\n",
" \n",
" 439 | \n",
" 0.094500 | \n",
"
\n",
" \n",
" 440 | \n",
" 0.377800 | \n",
"
\n",
" \n",
" 441 | \n",
" 1.108300 | \n",
"
\n",
" \n",
" 442 | \n",
" 1.419600 | \n",
"
\n",
" \n",
" 443 | \n",
" 1.209500 | \n",
"
\n",
" \n",
" 444 | \n",
" 0.643400 | \n",
"
\n",
" \n",
" 445 | \n",
" 2.111500 | \n",
"
\n",
" \n",
" 446 | \n",
" 0.670900 | \n",
"
\n",
" \n",
" 447 | \n",
" 0.807400 | \n",
"
\n",
" \n",
" 448 | \n",
" 0.720900 | \n",
"
\n",
" \n",
"
"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"TrainOutput(global_step=448, training_loss=1.3365338097126889, metrics={'train_runtime': 1494.2716, 'train_samples_per_second': 1.201, 'train_steps_per_second': 0.3, 'total_flos': 1.4862886764208128e+16, 'train_loss': 1.3365338097126889, 'epoch': 2.0})"
]
},
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Effective batch size = per_device_train_batch_size * gradient_accumulation_steps = 4.\n",
"training_args = transformers.TrainingArguments(\n",
"    output_dir=OUTPUT_DIR,\n",
"    num_train_epochs=2,\n",
"    per_device_train_batch_size=1,\n",
"    gradient_accumulation_steps=4,\n",
"    # Optimisation schedule: cosine decay after a 5% linear warm-up.\n",
"    learning_rate=2e-4,\n",
"    lr_scheduler_type=\"cosine\",\n",
"    warmup_ratio=0.05,\n",
"    optim=\"paged_adamw_8bit\",\n",
"    fp16=True,\n",
"    # Logging / checkpointing.\n",
"    logging_steps=1,\n",
"    report_to=\"tensorboard\",\n",
"    save_total_limit=3,\n",
")\n",
"\n",
"# Causal-LM collator (mlm=False): labels are a copy of input_ids with pad positions set to -100.\n",
"# NOTE(review): if pad_token == eos_token, EOS positions are masked as well, so the\n",
"# model may never be trained to emit EOS — confirm this is intended.\n",
"trainer = transformers.Trainer(\n",
"    model=model,\n",
"    args=training_args,\n",
"    train_dataset=data,\n",
"    data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),\n",
")\n",
"# The KV cache is only useful for generation and is incompatible with gradient checkpointing.\n",
"model.config.use_cache = False\n",
"trainer.train()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "XEMYxmFFZevu"
},
"source": [
"## Save Trained Model"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {
"id": "__D-YiF4i4E-"
},
"outputs": [],
"source": [
"# Write the trained adapter weights/config to a local directory.\n",
"model.save_pretrained(save_directory=\"trained-model\")"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 240,
"referenced_widgets": [
"1765001f2e8a45789c303bf25802ec26",
"b9342727e5b04a31a44a2bb39eb9b972",
"fb9cd1d7b199467ca84fdee089c3e616",
"cc2c910395a14a1b876e76a9f1967159",
"e6870d171e2e4f49b29940a584f8234f",
"6a21b3db09d043a09364c677e380fd1a",
"b47756ecdc5544fb881978069650dff7",
"8623daa5b6104e01aa049cce796d57a5",
"6a369f47577d4a1ab89fc34e893295e3",
"ec6c022167c549949cc2d4261e6b388a",
"defa462b8d0c425b8a3c5c10d3824449",
"ed46a43aa3b74664b7be0c5393af55c7",
"ea756bf0eec44265a894ca2c007ecdd3",
"6e2ec6222dbc4dafbbe3379d89e74a33",
"05c0d40b5c01447790d0ceccebc03d0f",
"3e72185f6ba44316818d958ff112e6a5",
"535b91b499d64a19a74ad64b96ff520b",
"b20a32370ca94e4490ec2541fedf4629",
"c770c0437f13471092d8b7bd92169f27",
"19e71863a0384ec8851bde3abbb3fa68",
"d0d5e191b9ea48fc83e8d14eef28807e",
"55aa6981a6304872b253b5452dd0399f",
"7cf8792fe92346e4be71b1cc0707c40f",
"95ea93ef06164a6792df160fc5a856e1",
"2c8a23857bfb41a28a4dcbaf360c5152",
"d81c35bd6c964b8396a93ae5ce9b7035",
"5f5378a779524d05bd130689944face8",
"7a12995dd12f4aa8ac273ed26fe8c774",
"e5a51c252b4642828528866ef7a71f61",
"996b7bc4804747bc8e3d3bff8c09bc4e",
"ef942a8f706549fd9d0ff97019b23d07",
"1f5ec464af694baeb020c677c9804203",
"05b4e2dcdb774de3847a1af4328a2d1e"
]
},
"id": "SBTcZs_EODfg",
"outputId": "4727fb36-a7a7-4f27-f6c7-ce93ced846d9"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/opt/conda/lib/python3.10/site-packages/huggingface_hub/_commit_api.py:282: UserWarning: About to update multiple times the same file in the same commit: 'adapter_model.bin'. This can cause undesired inconsistencies in your repo.\n",
" warnings.warn(\n",
"/opt/conda/lib/python3.10/site-packages/huggingface_hub/_commit_api.py:282: UserWarning: About to update multiple times the same file in the same commit: 'adapter_config.json'. This can cause undesired inconsistencies in your repo.\n",
" warnings.warn(\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "e4969cfac16f4abcb364da41e7c5843f",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Upload 2 LFS files: 0%| | 0/2 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "bd936b2e528e47cc965fea9166493713",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"adapter_model.bin: 0%| | 0.00/33.6M [00:00, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "3f8e124bf625448583b4eeb455cd4834",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"adapter_model.bin: 0%| | 0.00/33.6M [00:00, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"CommitInfo(commit_url='https://huggingface.co/shanjay/mgc-ds-v2/commit/d8976fe26b00e962b17bd0315671e5c500054518', commit_message='Upload model', commit_description='', oid='d8976fe26b00e962b17bd0315671e5c500054518', pr_url=None, pr_revision=None, pr_num=None)"
]
},
"execution_count": 26,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Upload the adapter to the Hub. Authentication uses the token cached by\n",
"# `huggingface-cli login` / `notebook_login()`; the deprecated\n",
"# `use_auth_token=True` argument is therefore unnecessary.\n",
"model.push_to_hub(\"shanjay/mgc-ds-v2\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "KIlYhwJhZgjb"
},
"source": [
"## Load Trained Model"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 113,
"referenced_widgets": [
"66408bee12974f96b42c7ebef03f5fb1",
"a1b2b3007e4d45039abb88e2d867f7ea",
"8ba08878f4e34cf592c5022b4be1c45f",
"041ed7f21ecd4a7fbae240a5bd0c20a9",
"3a083454009c431181869fd7bfd4ac9a",
"fcb4d59661c44d27932e0194d6b0aa1c",
"eee58e8c4ff641559649c3b82961a4c3",
"b5887b1d286b4cdc9255f55c8334cfb3",
"8cec295bad7447d1ab45aaa6c56bbef7",
"fa64654a10814ad3a11aa11af4cf36ce",
"f82c5b1d7d8f49e3ac3708b6a91f71f2",
"9c1d89fdf66243548435317b75da3595",
"ba6f3f3d687c474c9bacb2aff2f6cbfa",
"da3010bd98e94beb9defdf825db3630d",
"57a3f9fdddd0418aa96c99fd894726c3",
"e9b09df570954dd78079b1c6e65429fa",
"26482d9ce0944ce8adb7aff5d2174e06",
"42a40156d8274b3195219f6811db3d92",
"dd3861e35f264c629476755db522ea58",
"ce1470c542724a97874b595d0002c2d7",
"d4337a0b3e244c938a26bf4768a03854",
"a2802ba1c11c4fe6b63f03a931ed7a67",
"77e56c383fbf44d9ae314136db39adaf",
"551736eb29504faaa884d6d00937e0d2",
"1af09897d7e446449269106afcb079a8",
"05cf083266ab4da6af75da41d784a973",
"198f3f9771734a72b2a660ffeb178709",
"2d68eaef329c417bad449df84e664463",
"21f3873dc2914858b80617074a79e71d",
"e79cca46e0a14d1c9d4e1ab17620862e",
"d5e7ccf8574540a2b60138653d856e5d",
"ab5655e1b7c948ddb07428d5d2e7d00f",
"956740f8080e4838b8c5a67096c6894d"
]
},
"id": "owt5PmrcjG5C",
"outputId": "9b79908f-5510-476b-d086-626686305190",
"scrolled": true
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "9d2a9ab1e69049948b5b521f9bfae877",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"adapter_config.json: 0%| | 0.00/427 [00:00, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "c5358ac3a70a4219b1acbf9d0f95c7ba",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Loading checkpoint shards: 0%| | 0/6 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Some weights of LlamaForCausalLM were not initialized from the model checkpoint at ise-uiuc/Magicoder-S-DS-6.7B and are newly initialized: ['model.layers.2.self_attn.rotary_emb.inv_freq', 'model.layers.29.self_attn.rotary_emb.inv_freq', 'model.layers.27.self_attn.rotary_emb.inv_freq', 'model.layers.15.self_attn.rotary_emb.inv_freq', 'model.layers.4.self_attn.rotary_emb.inv_freq', 'model.layers.22.self_attn.rotary_emb.inv_freq', 'model.layers.11.self_attn.rotary_emb.inv_freq', 'model.layers.25.self_attn.rotary_emb.inv_freq', 'model.layers.7.self_attn.rotary_emb.inv_freq', 'model.layers.10.self_attn.rotary_emb.inv_freq', 'model.layers.3.self_attn.rotary_emb.inv_freq', 'model.layers.30.self_attn.rotary_emb.inv_freq', 'model.layers.13.self_attn.rotary_emb.inv_freq', 'model.layers.23.self_attn.rotary_emb.inv_freq', 'model.layers.16.self_attn.rotary_emb.inv_freq', 'model.layers.20.self_attn.rotary_emb.inv_freq', 'model.layers.17.self_attn.rotary_emb.inv_freq', 'model.layers.24.self_attn.rotary_emb.inv_freq', 'model.layers.0.self_attn.rotary_emb.inv_freq', 'model.layers.1.self_attn.rotary_emb.inv_freq', 'model.layers.12.self_attn.rotary_emb.inv_freq', 'model.layers.31.self_attn.rotary_emb.inv_freq', 'model.layers.26.self_attn.rotary_emb.inv_freq', 'model.layers.6.self_attn.rotary_emb.inv_freq', 'model.layers.5.self_attn.rotary_emb.inv_freq', 'model.layers.18.self_attn.rotary_emb.inv_freq', 'model.layers.8.self_attn.rotary_emb.inv_freq', 'model.layers.21.self_attn.rotary_emb.inv_freq', 'model.layers.9.self_attn.rotary_emb.inv_freq', 'model.layers.19.self_attn.rotary_emb.inv_freq', 'model.layers.14.self_attn.rotary_emb.inv_freq', 'model.layers.28.self_attn.rotary_emb.inv_freq']\n",
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "987356df5dd5403f95247c8e3622ab3f",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"adapter_model.bin: 0%| | 0.00/33.6M [00:00, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"PEFT_MODEL = \"shanjay/mgc-ds-v2\"\n",
"\n",
"# The adapter config records which base checkpoint the adapter was trained on.\n",
"config = PeftConfig.from_pretrained(PEFT_MODEL)\n",
"base_model = AutoModelForCausalLM.from_pretrained(\n",
"    config.base_model_name_or_path,\n",
"    return_dict=True,\n",
"    quantization_config=bnb_config,\n",
"    device_map=\"auto\",\n",
"    # NOTE(review): trust_remote_code=True executes any Python shipped in the Hub repo;\n",
"    # the base loads as LlamaForCausalLM, so this is likely unnecessary — confirm and drop.\n",
"    trust_remote_code=True,\n",
")\n",
"tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)\n",
"# Reuse the EOS token for padding.\n",
"tokenizer.pad_token = tokenizer.eos_token\n",
"\n",
"# Wrap the quantized base with the trained adapter; `model` is the adapter-wrapped model.\n",
"model = PeftModel.from_pretrained(base_model, PEFT_MODEL)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "1XvuuaswZk4e"
},
"source": [
"## Inference"
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {
"id": "LAmXdHItPgQV"
},
"outputs": [],
"source": [
"# NOTE: this aliases (and therefore mutates) the model's own GenerationConfig.\n",
"generation_config = model.generation_config\n",
"generation_config.max_new_tokens = 400\n",
"# temperature/top_p only take effect in sampling mode; with the default\n",
"# do_sample=False, generate() decodes greedily and ignores them.\n",
"generation_config.do_sample = True\n",
"generation_config.temperature = 0.7\n",
"generation_config.top_p = 0.7\n",
"generation_config.num_return_sequences = 1\n",
"generation_config.pad_token_id = tokenizer.eos_token_id\n",
"generation_config.eos_token_id = tokenizer.eos_token_id"
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {
"id": "qsS-RbE_UwJW"
},
"outputs": [],
"source": [
"# Device that tokenized inputs are moved to before generation.\n",
"# NOTE(review): with device_map=\"auto\" this assumes the input embeddings sit on GPU 0;\n",
"# `model.device` would track the actual placement — confirm.\n",
"DEVICE = \"cuda:0\""
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "ivZI5MCcTCRC",
"outputId": "1015d4d6-82ed-4650-d53e-b541bb8444e7"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
": How can I create a dataframe?\n",
"