Samuel Mueller committed
Commit
e487255
0 Parent(s):
Files changed (47)
  1. .gitattributes +33 -0
  2. .gitmodules +0 -0
  3. README.md +12 -0
  4. TabPFN/PrepareDatasets.ipynb +373 -0
  5. TabPFN/README.md +23 -0
  6. TabPFN/SyntheticGPAblation.ipynb +392 -0
  7. TabPFN/TabPFNPredictionOnly.ipynb +253 -0
  8. TabPFN/TabularEvaluationVisualization.ipynb +0 -0
  9. TabPFN/TrainingTuningAndPrediction.ipynb +0 -0
  10. TabPFN/datasets/__init__.py +149 -0
  11. TabPFN/datasets/utils.py +8 -0
  12. TabPFN/decoders.py +30 -0
  13. TabPFN/differentiable_pfn_evaluation.py +345 -0
  14. TabPFN/encoders.py +225 -0
  15. TabPFN/initializers.py +9 -0
  16. TabPFN/layer.py +125 -0
  17. TabPFN/losses.py +41 -0
  18. TabPFN/model_builder.py +273 -0
  19. TabPFN/models_diff/gp_ablation_model.cpkt +3 -0
  20. TabPFN/models_diff/prior_diff_real_checkpoint_n_8x_lr0.0003_epoch_49.cpkt +3 -0
  21. TabPFN/notebook_utils.py +32 -0
  22. TabPFN/positional_encodings.py +70 -0
  23. TabPFN/prior_tuning_result.pkl +3 -0
  24. TabPFN/priors/__init__.py +4 -0
  25. TabPFN/priors/differentiable_prior.py +293 -0
  26. TabPFN/priors/fast_gp.py +144 -0
  27. TabPFN/priors/flexible_categorical.py +240 -0
  28. TabPFN/priors/mlp.py +173 -0
  29. TabPFN/priors/prior.py +12 -0
  30. TabPFN/priors/prior_bag.py +32 -0
  31. TabPFN/priors/utils.py +163 -0
  32. TabPFN/requirements.txt +15 -0
  33. TabPFN/scripts/baseline_prediction_interface.py +38 -0
  34. TabPFN/scripts/differentiable_pfn_evaluation.py +391 -0
  35. TabPFN/scripts/model_configs.py +210 -0
  36. TabPFN/scripts/tabular_baselines.py +421 -0
  37. TabPFN/scripts/tabular_evaluation.py +284 -0
  38. TabPFN/scripts/tabular_metrics.py +181 -0
  39. TabPFN/scripts/transformer_prediction_interface.py +357 -0
  40. TabPFN/tabular_evaluation.py +283 -0
  41. TabPFN/train.py +386 -0
  42. TabPFN/transformer.py +226 -0
  43. TabPFN/utils.py +236 -0
  44. app.py +96 -0
  45. balance-scale.arff +694 -0
  46. iris.csv +151 -0
  47. requirements.txt +16 -0
.gitattributes ADDED
@@ -0,0 +1,33 @@
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ftz filter=lfs diff=lfs merge=lfs -text
6
+ *.gz filter=lfs diff=lfs merge=lfs -text
7
+ *.h5 filter=lfs diff=lfs merge=lfs -text
8
+ *.joblib filter=lfs diff=lfs merge=lfs -text
9
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
10
+ *.model filter=lfs diff=lfs merge=lfs -text
11
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
12
+ *.npy filter=lfs diff=lfs merge=lfs -text
13
+ *.npz filter=lfs diff=lfs merge=lfs -text
14
+ *.onnx filter=lfs diff=lfs merge=lfs -text
15
+ *.ot filter=lfs diff=lfs merge=lfs -text
16
+ *.parquet filter=lfs diff=lfs merge=lfs -text
17
+ *.pickle filter=lfs diff=lfs merge=lfs -text
18
+ *.pkl filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pt filter=lfs diff=lfs merge=lfs -text
21
+ *.pth filter=lfs diff=lfs merge=lfs -text
22
+ *.rar filter=lfs diff=lfs merge=lfs -text
23
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
24
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
25
+ *.tflite filter=lfs diff=lfs merge=lfs -text
26
+ *.tgz filter=lfs diff=lfs merge=lfs -text
27
+ *.wasm filter=lfs diff=lfs merge=lfs -text
28
+ *.xz filter=lfs diff=lfs merge=lfs -text
29
+ *.zip filter=lfs diff=lfs merge=lfs -text
30
+ *.zstandard filter=lfs diff=lfs merge=lfs -text
31
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
32
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
33
+ *.cpkt filter=lfs diff=lfs merge=lfs -text
.gitmodules ADDED
File without changes
README.md ADDED
@@ -0,0 +1,12 @@
1
+ ---
2
+ title: TabPFN
3
+ emoji: 🐨
4
+ colorFrom: gray
5
+ colorTo: blue
6
+ sdk: gradio
7
+ sdk_version: 3.1.1
8
+ app_file: app.py
9
+ pinned: true
10
+ ---
11
+
12
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
TabPFN/PrepareDatasets.ipynb ADDED
@@ -0,0 +1,373 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "import numpy as np\n",
10
+ "\n",
11
+ "import openml\n",
12
+ "import pandas as pd"
13
+ ]
14
+ },
15
+ {
16
+ "cell_type": "code",
17
+ "execution_count": 2,
18
+ "metadata": {},
19
+ "outputs": [],
20
+ "source": [
21
+ "from tqdm import tqdm\n",
22
+ "\n",
23
+ "from datasets import load_openml_list, test_dids_classification, valid_large_classification, open_cc_dids, open_cc_valid_dids\n"
24
+ ]
25
+ },
26
+ {
27
+ "cell_type": "code",
28
+ "execution_count": 6,
29
+ "metadata": {},
30
+ "outputs": [
31
+ {
32
+ "name": "stdout",
33
+ "output_type": "stream",
34
+ "text": [
35
+ "The autoreload extension is already loaded. To reload it, use:\n",
36
+ " %reload_ext autoreload\n"
37
+ ]
38
+ }
39
+ ],
40
+ "source": [
41
+ "%load_ext autoreload\n",
42
+ "\n",
43
+ "%autoreload 2"
44
+ ]
45
+ },
46
+ {
47
+ "cell_type": "markdown",
48
+ "metadata": {
49
+ "tags": []
50
+ },
51
+ "source": [
52
+ "### Prepare test datasets"
53
+ ]
54
+ },
55
+ {
56
+ "cell_type": "code",
57
+ "execution_count": 7,
58
+ "metadata": {},
59
+ "outputs": [],
60
+ "source": [
61
+ "renamer = {'name': 'Name', 'NumberOfFeatures': '# Features', 'NumberOfSymbolicFeatures': '# Categorical Features', 'NumberOfInstances': '# Instances', 'NumberOfMissingValues': '# NaNs', 'NumberOfClasses': '# Classes', 'MinorityClassSize': 'Minority Class Size'}\n"
62
+ ]
63
+ },
64
+ {
65
+ "cell_type": "code",
66
+ "execution_count": 8,
67
+ "metadata": {},
68
+ "outputs": [
69
+ {
70
+ "data": {
71
+ "text/plain": [
72
+ "OrderedDict([(99,\n",
73
+ " {'id': 99,\n",
74
+ " 'alias': 'OpenML-CC18',\n",
75
+ " 'main_entity_type': 'task',\n",
76
+ " 'name': 'OpenML-CC18 Curated Classification benchmark',\n",
77
+ " 'status': 'active',\n",
78
+ " 'creation_date': '2019-02-21 18:47:13',\n",
79
+ " 'creator': 1}),\n",
80
+ " (225,\n",
81
+ " {'id': 225,\n",
82
+ " 'alias': 'OpenML-friendly',\n",
83
+ " 'main_entity_type': 'task',\n",
84
+ " 'name': 'OpenML100-friendly',\n",
85
+ " 'status': 'active',\n",
86
+ " 'creation_date': '2019-09-16 19:41:46',\n",
87
+ " 'creator': 1})])"
88
+ ]
89
+ },
90
+ "execution_count": 8,
91
+ "metadata": {},
92
+ "output_type": "execute_result"
93
+ }
94
+ ],
95
+ "source": [
96
+ "openml.study.list_suites()"
97
+ ]
98
+ },
99
+ {
100
+ "cell_type": "code",
101
+ "execution_count": 9,
102
+ "metadata": {},
103
+ "outputs": [],
104
+ "source": [
105
+ "suite = openml.study.get_suite(suite_id=99)\n",
106
+ "tasks = openml.tasks.list_tasks(output_format=\"dataframe\")"
107
+ ]
108
+ },
109
+ {
110
+ "cell_type": "code",
111
+ "execution_count": 10,
112
+ "metadata": {},
113
+ "outputs": [],
114
+ "source": [
115
+ "# Using ``@`` in `pd.DataFrame.query <\n",
116
+ "# https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.query.html>`_\n",
117
+ "# accesses variables outside of the current dataframe.\n",
118
+ "tasks = tasks.query(\"tid in @suite.tasks\")"
119
+ ]
120
+ },
121
+ {
122
+ "cell_type": "code",
123
+ "execution_count": 11,
124
+ "metadata": {},
125
+ "outputs": [],
126
+ "source": [
127
+ "tids = list(tasks[np.logical_and(np.logical_and((tasks.NumberOfInstances <= 2000), (tasks.NumberOfFeatures <= 100))\n",
128
+ " , (tasks.NumberOfClasses <= 10))].tid)"
129
+ ]
130
+ },
131
+ {
132
+ "cell_type": "code",
133
+ "execution_count": 12,
134
+ "metadata": {},
135
+ "outputs": [
136
+ {
137
+ "data": {
138
+ "text/plain": [
139
+ "30"
140
+ ]
141
+ },
142
+ "execution_count": 12,
143
+ "metadata": {},
144
+ "output_type": "execute_result"
145
+ }
146
+ ],
147
+ "source": [
148
+ "len(tids)"
149
+ ]
150
+ },
151
+ {
152
+ "cell_type": "code",
153
+ "execution_count": 13,
154
+ "metadata": {},
155
+ "outputs": [],
156
+ "source": [
157
+ "tids = list(tasks[tasks.NumberOfInstances <= 2000].tid)"
158
+ ]
159
+ },
160
+ {
161
+ "cell_type": "code",
162
+ "execution_count": 14,
163
+ "metadata": {},
164
+ "outputs": [],
165
+ "source": [
166
+ "open_cc_dids = [openml.tasks.get_task(task_id).get_dataset().id for task_id in tids]"
167
+ ]
168
+ },
169
+ {
170
+ "cell_type": "code",
171
+ "execution_count": null,
172
+ "outputs": [],
173
+ "source": [
174
+ "open_ml_datasets, open_ml_datasets_df = load_openml_list(test_dids_classification, multiclass=True, shuffled=True, filter_for_nan=False, max_samples = 100000, num_feats=100, return_capped=True)\n"
175
+ ],
176
+ "metadata": {
177
+ "collapsed": false,
178
+ "pycharm": {
179
+ "name": "#%%\n"
180
+ }
181
+ }
182
+ },
183
+ {
184
+ "cell_type": "code",
185
+ "execution_count": 16,
186
+ "metadata": {},
187
+ "outputs": [],
188
+ "source": [
189
+ "open_ml_datasets_df = open_ml_datasets_df[open_ml_datasets_df.NumberOfInstances > 10000]"
190
+ ]
191
+ },
192
+ {
193
+ "cell_type": "code",
194
+ "execution_count": 17,
195
+ "metadata": {},
196
+ "outputs": [
197
+ {
198
+ "name": "stdout",
199
+ "output_type": "stream",
200
+ "text": [
201
+ "\\begin{tabular}{lrrrrrrr}\n",
202
+ "\\toprule\n",
203
+ " Name & \\# Features & \\# Categorical Features & \\# Instances & \\# Classes & \\# NaNs & Minority Class Size & id \\\\\n",
204
+ "\\midrule\n",
205
+ " KDDCup09\\_appetency & 231 & 39 & 50000 & 2 & 8024152 & 890 & 1111 \\\\\n",
206
+ " airlines & 8 & 5 & 539383 & 2 & 0 & 240264 & 1169 \\\\\n",
207
+ " bank-marketing & 17 & 10 & 45211 & 2 & 0 & 5289 & 1461 \\\\\n",
208
+ " nomao & 119 & 30 & 34465 & 2 & 0 & 9844 & 1486 \\\\\n",
209
+ " adult & 15 & 9 & 48842 & 2 & 6465 & 11687 & 1590 \\\\\n",
210
+ " covertype & 55 & 45 & 581012 & 7 & 0 & 2747 & 1596 \\\\\n",
211
+ " numerai28.6 & 22 & 1 & 96320 & 2 & 0 & 47662 & 23517 \\\\\n",
212
+ " connect-4 & 43 & 43 & 67557 & 3 & 0 & 6449 & 40668 \\\\\n",
213
+ "jungle\\_chess\\_2pcs\\_raw\\_endgame\\_complete & 7 & 1 & 44819 & 3 & 0 & 4335 & 41027 \\\\\n",
214
+ " APSFailure & 171 & 1 & 76000 & 2 & 1078695 & 1375 & 41138 \\\\\n",
215
+ " albert & 79 & 53 & 425240 & 2 & 2734000 & 212620 & 41147 \\\\\n",
216
+ " MiniBooNE & 51 & 1 & 130064 & 2 & 0 & 36499 & 41150 \\\\\n",
217
+ " guillermo & 4297 & 1 & 20000 & 2 & 0 & 8003 & 41159 \\\\\n",
218
+ " riccardo & 4297 & 1 & 20000 & 2 & 0 & 5000 & 41161 \\\\\n",
219
+ " volkert & 181 & 1 & 58310 & 10 & 0 & 1361 & 41166 \\\\\n",
220
+ " dionis & 61 & 1 & 416188 & 355 & 0 & 878 & 41167 \\\\\n",
221
+ " jannis & 55 & 1 & 83733 & 4 & 0 & 1687 & 41168 \\\\\n",
222
+ " helena & 28 & 1 & 65196 & 100 & 0 & 111 & 41169 \\\\\n",
223
+ "\\bottomrule\n",
224
+ "\\end{tabular}\n",
225
+ "\n"
226
+ ]
227
+ }
228
+ ],
229
+ "source": [
230
+ "print_table = open_ml_datasets_df\n",
231
+ "print_table = print_table[['name', 'NumberOfFeatures', 'NumberOfSymbolicFeatures', 'NumberOfInstances', 'NumberOfClasses', 'NumberOfMissingValues', 'MinorityClassSize']].copy()\n",
232
+ "print_table['id'] = print_table.index\n",
233
+ "print_table[['NumberOfFeatures', 'NumberOfSymbolicFeatures', 'NumberOfInstances', 'NumberOfClasses', 'NumberOfMissingValues', 'MinorityClassSize']] = print_table[['NumberOfFeatures', 'NumberOfSymbolicFeatures', 'NumberOfInstances', 'NumberOfClasses', 'NumberOfMissingValues', 'MinorityClassSize']].astype(int)\n",
234
+ "print_table = print_table.rename(columns=renamer)\n",
235
+ "print(print_table.to_latex(index=False))"
236
+ ]
237
+ },
238
+ {
239
+ "cell_type": "markdown",
240
+ "metadata": {
241
+ "tags": []
242
+ },
243
+ "source": [
244
+ "### Prepare Validation datasets"
245
+ ]
246
+ },
247
+ {
248
+ "cell_type": "code",
249
+ "execution_count": null,
250
+ "outputs": [],
251
+ "source": [
252
+ "open_cc_datasets, open_cc_datasets_df = load_openml_list(open_cc_dids, multiclass=True, shuffled=True, filter_for_nan=False, max_samples = 2000, num_feats=100, return_capped=True)\n",
253
+ "\n",
254
+ "def extend_datasets(datasets, filtering = False):\n",
255
+ " extended_datasets = {}\n",
256
+ " i = 0\n",
257
+ " for d in tqdm(datasets):\n",
258
+ " if ((not 'NumberOfFeatures' in datasets[d])\n",
259
+ " or (not 'NumberOfClasses' in datasets[d])\n",
260
+ " or (not 'NumberOfInstances' in datasets[d])\n",
261
+ " # or datasets[d]['NumberOfFeatures'] >= num_feats\n",
262
+ " or datasets[d]['NumberOfClasses'] <= 0):\n",
263
+ " print(datasets[d])\n",
264
+ " continue\n",
265
+ " ds = openml.datasets.get_dataset(d, download_data=False)\n",
266
+ " if filtering and (datasets[d]['NumberOfInstances'] < 150\n",
267
+ " or datasets[d]['NumberOfInstances'] > 2000\n",
268
+ " or datasets[d]['NumberOfFeatures'] > 100\n",
269
+ " or datasets[d]['NumberOfClasses'] > 10):\n",
270
+ " continue\n",
271
+ " extended_datasets[d] = datasets[d]\n",
272
+ " extended_datasets[d].update(ds.qualities)\n",
273
+ " \n",
274
+ " return extended_datasets\n",
275
+ "\n",
276
+ "# All datasets\n",
277
+ "openml_list = openml.datasets.list_datasets()\n",
278
+ "openml_list = pd.DataFrame.from_dict(openml_list, orient=\"index\")\n",
279
+ "\n",
280
+ "# Select only classification\n",
281
+ "openml_list = openml_list[~openml_list['MajorityClassSize'].isna()]\n",
282
+ "\n",
283
+ "# Remove duplicated datasets\n",
284
+ "duplicated = openml_list.duplicated(subset=['MajorityClassSize', 'MaxNominalAttDistinctValues', 'MinorityClassSize',\n",
285
+ " 'NumberOfClasses', 'NumberOfFeatures', 'NumberOfInstances',\n",
286
+ " 'NumberOfInstancesWithMissingValues', 'NumberOfMissingValues',\n",
287
+ " 'NumberOfNumericFeatures', 'NumberOfSymbolicFeatures'], keep='first')\n",
288
+ "openml_list = openml_list[~duplicated]\n",
289
+ "\n",
290
+ "duplicated = openml_list.duplicated(subset=['name'], keep='first')\n",
291
+ "openml_list = openml_list[~duplicated]\n",
292
+ "\n",
293
+ "# Filter out datasets that don't have meta information or Don't fulfill other criteria\n",
294
+ "openml_list = openml_list.to_dict(orient='index')\n",
295
+ "openml_list = pd.DataFrame.from_dict(extend_datasets(openml_list, filtering=True), orient=\"index\")\n",
296
+ "\n",
297
+ "# Filter out datasets in Open CC\n",
298
+ "openml_list = openml_list[~openml_list.name.apply(lambda x: x in test_datasets_multiclass_df.name.values)]\n",
299
+ "openml_list['CFI'] = openml_list.apply(lambda x: str(x.NumberOfClasses) + '_' + str(x.NumberOfFeatures) + '_' + str(x.NumberOfInstances), axis = 1)\n",
300
+ "test_datasets_multiclass_df['CFI'] = test_datasets_multiclass_df.apply(lambda x: str(x.NumberOfClasses) + '_' + str(x.NumberOfFeatures) + '_' + str(x.NumberOfInstances), axis = 1)\n",
301
+ "openml_list = openml_list[~openml_list.CFI.apply(lambda x: x in test_datasets_multiclass_df.CFI.values)]\n",
302
+ "\n",
303
+ "# Remove time series and artificial data\n",
304
+ "openml_list = openml_list[~openml_list.name.apply(lambda x: 'autoUniv' in x)]\n",
305
+ "openml_list = openml_list[~openml_list.name.apply(lambda x: 'fri_' in x)]\n",
306
+ "openml_list = openml_list[~openml_list.name.apply(lambda x: 'FOREX' in x)]\n",
307
+ "\n",
308
+ "# Remove datasets that overlapped with Open CC closely by name\n",
309
+ "openml_list = openml_list[~openml_list.name.apply(lambda x: 'ilpd' in x)]\n",
310
+ "openml_list = openml_list[~openml_list.name.apply(lambda x: 'car' in x)]\n",
311
+ "openml_list = openml_list[~openml_list.name.apply(lambda x: 'pc1' in x)]\n",
312
+ "\n",
313
+ "# Remove datasets that didn't load\n",
314
+ "openml_list = openml_list[~openml_list.did.apply(lambda x: x in {1065, 40589, 41496, 770, 43097, 43148, 43255, 43595, 43786, 41701})]\n",
315
+ "\n",
316
+ "# Remove class skew\n",
317
+ "openml_list = openml_list[(openml_list.MinorityClassSize / openml_list.MajorityClassSize) > 0.05]\n",
318
+ "openml_list = openml_list[openml_list.AutoCorrelation != 1]\n",
319
+ "\n",
320
+ "# Remove too easy\n",
321
+ "openml_list = openml_list[openml_list.CfsSubsetEval_DecisionStumpAUC != 1]"
322
+ ],
323
+ "metadata": {
324
+ "collapsed": false,
325
+ "pycharm": {
326
+ "name": "#%%\n"
327
+ }
328
+ }
329
+ },
330
+ {
331
+ "cell_type": "code",
332
+ "execution_count": null,
333
+ "metadata": {},
334
+ "outputs": [],
335
+ "source": [
336
+ "print_table = openml_list\n",
337
+ "print_table = print_table[['name', 'NumberOfFeatures', 'NumberOfSymbolicFeatures', 'NumberOfInstances', 'NumberOfClasses', 'NumberOfMissingValues', 'MinorityClassSize']].copy()\n",
338
+ "print_table['id'] = print_table.index\n",
339
+ "print_table[['NumberOfFeatures', 'NumberOfSymbolicFeatures', 'NumberOfInstances', 'NumberOfClasses', 'NumberOfMissingValues', 'MinorityClassSize']] = print_table[['NumberOfFeatures', 'NumberOfSymbolicFeatures', 'NumberOfInstances', 'NumberOfClasses', 'NumberOfMissingValues', 'MinorityClassSize']].astype(int)\n",
340
+ "print_table = print_table.rename(columns=renamer)\n",
341
+ "print(print_table.to_latex(index=False))"
342
+ ]
343
+ },
344
+ {
345
+ "cell_type": "code",
346
+ "execution_count": null,
347
+ "metadata": {},
348
+ "outputs": [],
349
+ "source": []
350
+ }
351
+ ],
352
+ "metadata": {
353
+ "kernelspec": {
354
+ "display_name": "Python 3 (ipykernel)",
355
+ "language": "python",
356
+ "name": "python3"
357
+ },
358
+ "language_info": {
359
+ "codemirror_mode": {
360
+ "name": "ipython",
361
+ "version": 3
362
+ },
363
+ "file_extension": ".py",
364
+ "mimetype": "text/x-python",
365
+ "name": "python",
366
+ "nbconvert_exporter": "python",
367
+ "pygments_lexer": "ipython3",
368
+ "version": "3.7.13"
369
+ }
370
+ },
371
+ "nbformat": 4,
372
+ "nbformat_minor": 4
373
+ }
TabPFN/README.md ADDED
@@ -0,0 +1,23 @@
1
+ # TabPFN
2
+
3
+ ## Installation
4
+ ```
5
+ git clone git@github.com:automl/TabPFN.git
6
+ cd TabPFN
7
+ conda create -n TabPFN python=3.7
8
+ conda activate TabPFN
9
+ pip install -r requirements.txt
10
+ ```
11
+
12
+ To run the AutoGluon baseline, please create a separate environment and install autogluon==0.4.0; installing it in the same environment as our other baselines is not possible.
13
+
14
+ ## Usage
15
+ TrainingTuningAndPrediction: Train a TabPFN, prior-tune, and predict using a pretrained model.
16
+
17
+ TabularEvaluationVisualization: Run baselines and load baseline and TabPFN results for comparison and plotting.
18
+
19
+ PrepareDatasets: Notebook used to inspect datasets (not needed to run baselines / TabPFN).
20
+
21
+ SyntheticGPAblation: Ablation experiments for Gaussian process fitting with differentiable hyperparameters.
22
+
23
+
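For orientation, here is a minimal prediction sketch based on the scikit-learn-style interface demonstrated later in TabPFNPredictionOnly.ipynb. The synthetic data, split sizes, and the use of NumPy inputs are illustrative assumptions (the notebook itself passes torch tensors); `TabPFNClassifier` and its `fit`/`predict_proba` methods are the interface shown in that notebook.

```python
# Minimal usage sketch (assumptions: synthetic data, NumPy inputs accepted;
# the pretrained models support up to 100 features and 10 classes).
import numpy as np
from scripts.transformer_prediction_interface import TabPFNClassifier

rng = np.random.RandomState(0)
X = rng.normal(size=(200, 5)).astype(np.float32)   # 200 samples, 5 features
y = (X[:, 0] + X[:, 1] > 0).astype(np.int64)       # toy binary labels

train_xs, train_ys = X[:100], y[:100]
test_xs, test_ys = X[100:], y[100:]

classifier = TabPFNClassifier(device='cpu')
classifier.fit(train_xs, train_ys)                 # only stores the training data
prediction_ = classifier.predict_proba(test_xs)    # inference happens here
print(prediction_.shape)                           # (n_test_samples, n_classes)
```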
TabPFN/SyntheticGPAblation.ipynb ADDED
@@ -0,0 +1,392 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "%load_ext autoreload\n",
10
+ "\n",
11
+ "%autoreload 2"
12
+ ]
13
+ },
14
+ {
15
+ "cell_type": "code",
16
+ "execution_count": 2,
17
+ "metadata": {},
18
+ "outputs": [],
19
+ "source": [
20
+ "import os\n",
21
+ "import time\n",
22
+ "\n",
23
+ "import torch\n",
24
+ "\n",
25
+ "import numpy as np\n",
26
+ "\n",
27
+ "import matplotlib.pyplot as plt\n",
28
+ "\n",
29
+ "from model_builder import get_model, get_default_spec, save_model, load_model\n",
30
+ "\n",
31
+ "from scripts.model_configs import *"
32
+ ]
33
+ },
34
+ {
35
+ "cell_type": "markdown",
36
+ "metadata": {
37
+ "tags": []
38
+ },
39
+ "source": [
40
+ "# Setting params"
41
+ ]
42
+ },
43
+ {
44
+ "cell_type": "code",
45
+ "execution_count": 6,
46
+ "metadata": {},
47
+ "outputs": [],
48
+ "source": [
49
+ "device = 'cuda'\n",
50
+ "base_path = os.path.join('.')"
51
+ ]
52
+ },
53
+ {
54
+ "cell_type": "code",
55
+ "execution_count": 7,
56
+ "metadata": {},
57
+ "outputs": [],
58
+ "source": [
59
+ "def train_function(config_sample, i, add_name=''):\n",
60
+ " start_time = time.time()\n",
61
+ " N_epochs_to_save = 50\n",
62
+ " \n",
63
+ " def save_callback(model, epoch):\n",
64
+ " if not hasattr(model, 'last_saved_epoch'):\n",
65
+ " model.last_saved_epoch = 0\n",
66
+ " if ((time.time() - start_time) / (maximum_runtime * 60 / N_epochs_to_save)) > model.last_saved_epoch:\n",
67
+ " print('Saving model..')\n",
68
+ " config_sample['epoch_in_training'] = epoch\n",
69
+ " save_model(model, base_path, f'models_diff/prior_diff_real_checkpoint{add_name}_n_{i}_epoch_{model.last_saved_epoch}.cpkt',\n",
70
+ " config_sample)\n",
71
+ " model.last_saved_epoch = model.last_saved_epoch + 1 # TODO: Rename to checkpoint\n",
72
+ " \n",
73
+ " model = get_model(config_sample\n",
74
+ " , device\n",
75
+ " , should_train=True\n",
76
+ " , verbose=1\n",
77
+ " , epoch_callback = save_callback)\n",
78
+ " \n",
79
+ " return"
80
+ ]
81
+ },
82
+ {
83
+ "cell_type": "markdown",
84
+ "metadata": {
85
+ "heading_collapsed": true,
86
+ "tags": []
87
+ },
88
+ "source": [
89
+ "# Check synthetic data fitting"
90
+ ]
91
+ },
92
+ {
93
+ "cell_type": "markdown",
94
+ "metadata": {
95
+ "tags": []
96
+ },
97
+ "source": [
98
+ "#### Workflow functions"
99
+ ]
100
+ },
101
+ {
102
+ "cell_type": "code",
103
+ "execution_count": 8,
104
+ "metadata": {
105
+ "hidden": true,
106
+ "tags": []
107
+ },
108
+ "outputs": [],
109
+ "source": [
110
+ "def generate_test_data(test_gp_params):\n",
111
+ " # Generate test data\n",
112
+ " config = {**test_gp_params}\n",
113
+ "\n",
114
+ " config['verbose'] = False\n",
115
+ " config['differentiable'] = False\n",
116
+ " #config['bptt'] = config['bptt_in_training']\n",
117
+ "\n",
118
+ " model_test_data = get_model(config, device, should_train=False, verbose=True)\n",
119
+ " (hp_embedding, data, targets_), targets = next(iter(model_test_data[3]))\n",
120
+ " (hp_embedding, data, targets_), targets = (hp_embedding, data.to(device), targets_.to(device)), targets.to(device)\n",
121
+ " \n",
122
+ " return (hp_embedding, data, targets_), targets\n",
123
+ "\n",
124
+ "def evaluate_hp_range(model, hparam_true, vary_hparam_ind, data, targets, eval_pos, plot_step_size):\n",
125
+ " losses, hparams = [], []\n",
126
+ " for l in np.arange(-1.74, 1.74, plot_step_size):\n",
127
+ " hparam = [*hparam_true]\n",
128
+ " hparam[vary_hparam_ind] = l\n",
129
+ " hp_embedding_used = torch.tensor(hparam).to(device).float()\n",
130
+ " with torch.inference_mode():\n",
131
+ " outputs = torch.sigmoid(model[2]((hp_embedding_used.repeat(data.shape[1], 1), data, targets.float()), single_eval_pos=eval_pos)).squeeze(-1)\n",
132
+ " \n",
133
+ " loss = torch.nn.BCELoss()(outputs.flatten(), targets[eval_pos:].flatten()).detach().cpu()\n",
134
+ " losses += [loss]\n",
135
+ " hparam_real = [diff_hparams_f[i][1](hp) for i, hp in enumerate(hparam)]\n",
136
+ " hparams += [hparam_real]\n",
137
+ " \n",
138
+ " print(loss, hparam_real, hparam, outputs.shape)\n",
139
+ " return np.array(losses), np.array(hparams)"
140
+ ]
141
+ },
142
+ {
143
+ "cell_type": "code",
144
+ "execution_count": 9,
145
+ "metadata": {},
146
+ "outputs": [],
147
+ "source": [
148
+ "def differentiable_hparam_tuning_workflow(config_sample, hparam_label, batch_size=4, N_grad_steps=50, plot_step_size=0.1):\n",
149
+ " test_gp_params = {\n",
150
+ " \"lengthscale\": 1.0,\n",
151
+ " #\"lengthscale_mean\": true_lengthscale,\n",
152
+ " #\"lengthscale_std\": 0.5,\n",
153
+ " \"noise\": 0.2,\n",
154
+ " \"outputscale\": 1.0,\n",
155
+ " 'batch_size': batch_size\n",
156
+ " }\n",
157
+ " config_sample.update(test_gp_params)\n",
158
+ " (hp_embedding, data, targets_), targets = generate_test_data(config_sample)\n",
159
+ " hparam_true = [diff_hparams_f[i][0](test_gp_params[hp]) for i, hp in enumerate(diff_hparams_keys)]\n",
160
+ " #hparam_true = [test_gp_params[hp] for i, hp in enumerate(diff_hparams_keys)]\n",
161
+ "\n",
162
+ " for vary_hparam_ind, vary_hparam_name in hparam_label:\n",
163
+ " print(vary_hparam_name)\n",
164
+ "\n",
165
+ " losses, hparams = evaluate_hp_range(model, hparam_true, vary_hparam_ind, data, targets, eval_pos, plot_step_size=plot_step_size)\n",
166
+ "\n",
167
+ " # TODO: Make only one parameter diffable\n",
168
+ " hparam = torch.tensor([*hparam_true]).to(device).float()\n",
169
+ " hparam[vary_hparam_ind] = hparam[vary_hparam_ind] + 0.1 #random.random() * 2 - 1\n",
170
+ " hparam = torch.nn.Parameter(hparam, requires_grad=True)\n",
171
+ " hparam_grad_mask = torch.zeros_like(hparam)\n",
172
+ " hparam_grad_mask[vary_hparam_ind] = 1\n",
173
+ "\n",
174
+ " optimizer = torch.optim.Adam([hparam], lr=0.1)\n",
175
+ " \n",
176
+ " for t in range(N_grad_steps):\n",
177
+ " style = hparam.repeat(data.shape[1], 1)\n",
178
+ " outputs = torch.sigmoid(model[2]((style, data, targets.float()), single_eval_pos=eval_pos)).squeeze(-1)\n",
179
+ " loss = torch.nn.BCELoss()(outputs.flatten(), targets[eval_pos:].flatten())\n",
180
+ " optimizer.zero_grad()\n",
181
+ " loss.backward()\n",
182
+ " with torch.no_grad():\n",
183
+ " hparam.grad *= hparam_grad_mask\n",
184
+ " optimizer.step()\n",
185
+ " print('loss:', loss, 'hparams', diff_hparams_f[vary_hparam_ind][1](hparam[vary_hparam_ind]), 'true', diff_hparams_f[vary_hparam_ind][1](hparam_true[vary_hparam_ind]))\n",
186
+ " inferred_param = diff_hparams_f[vary_hparam_ind][1](hparam[vary_hparam_ind].cpu().detach().numpy())\n",
187
+ " return hparams, losses, inferred_param, vary_hparam_ind, hparam_true\n",
188
+ " "
189
+ ]
190
+ },
191
+ {
192
+ "cell_type": "markdown",
193
+ "metadata": {
194
+ "tags": []
195
+ },
196
+ "source": [
197
+ "#### Fitting a PFN with HP-Diffable GP Prior"
198
+ ]
199
+ },
200
+ {
201
+ "cell_type": "code",
202
+ "execution_count": 10,
203
+ "metadata": {
204
+ "hidden": true,
205
+ "tags": []
206
+ },
207
+ "outputs": [],
208
+ "source": [
209
+ "num_features = 5\n",
210
+ "bptt = 200\n",
211
+ "eval_positions = [100]\n",
212
+ "\n",
213
+ "config_general = get_general_config(num_features, bptt, eval_positions)\n",
214
+ "config_flexible_categorical = get_flexible_categorical_config(num_features)\n",
215
+ "\n",
216
+ "config_gp = {'noise': 0.2, \"lengthscale\": 1.0, \"outputscale\": 1.0}\n",
217
+ "config_diff_gp = {'differentiable_hyperparameters': {\n",
218
+ " 'outputscale': {'distribution': 'uniform', 'min': 0., 'max': 10.0},\n",
219
+ " 'lengthscale': {'distribution': 'uniform', 'min': 0., 'max': 10.0},\n",
220
+ " 'noise': {'distribution': 'uniform', 'min': 0.0000001, 'max': 0.5},\n",
221
+ " }\n",
222
+ "}\n",
223
+ "\n",
224
+ "config = {**config_general, **config_flexible_categorical, **config_diff_gp, **config_gp}\n",
225
+ "\n",
226
+ "config['prior_type'], config['differentiable'], config['flexible'] = 'gp', True, True\n",
227
+ "config['num_features'], config['num_features_used'] = num_features, num_features\n",
228
+ "config['epochs'], config['num_steps'], config['verbose'] = 500, 100, False\n",
229
+ "config[\"lr\"] = 0.00001\n",
230
+ "config[\"dropout\"] = 0\n",
231
+ "config[\"emsize\"] = 512\n",
232
+ "config[\"batch_size\"] = 128\n",
233
+ "config[\"aggregate_k_gradients\"] = 1\n",
234
+ "config['set_value_to_nan'] = 0.0\n",
235
+ "config['output_multiclass_ordered_p'] = 1.0\n",
236
+ "config['categorical_feature_p'] = 0.0\n",
237
+ "config['nan_prob_a_reason'] = 0.0\n",
238
+ "config['nan_prob_no_reason'] = 0.0\n",
239
+ "config['nan_prob_unknown_reason'] = 0.0\n",
240
+ "config[\"nlayers\"] = 8\n",
241
+ "\n",
242
+ "# TODO: This should not be sampled, but be one config\n",
243
+ "# TODO: This uses old hyperparam sampler throws error\n",
244
+ "config_sample = evaluate_hypers(config)"
245
+ ]
246
+ },
247
+ {
248
+ "cell_type": "code",
249
+ "execution_count": 11,
250
+ "metadata": {
251
+ "hidden": true,
252
+ "tags": []
253
+ },
254
+ "outputs": [
255
+ {
256
+ "name": "stdout",
257
+ "output_type": "stream",
258
+ "text": [
259
+ "Using style prior: True\n",
260
+ "Using cpu:0 device\n",
261
+ "Not using distributed\n",
262
+ "DataLoader.__dict__ {'num_steps': 100, 'fuse_x_y': False, 'get_batch_kwargs': {'batch_size': 128, 'seq_len': 200, 'seq_len_maximum': 200, 'device': 'cpu:0', 'num_features': 5, 'hyperparameters': {'lr': 1e-05, 'dropout': 0, 'emsize': 512, 'batch_size': 128, 'nlayers': 8, 'num_features': 5, 'nhead': 4, 'nhid_factor': 2, 'bptt': 200, 'eval_positions': None, 'seq_len_used': 200, 'sampling': 'normal', 'epochs': 500, 'num_steps': 100, 'verbose': False, 'pre_sample_causes': True, 'mix_activations': False, 'nan_prob_unknown_reason_reason_prior': 1.0, 'categorical_feature_p': 0.0, 'nan_prob_no_reason': 0.0, 'nan_prob_unknown_reason': 0.0, 'nan_prob_a_reason': 0.0, 'max_num_classes': 2, 'num_classes': 2, 'noise_type': 'Gaussian', 'balanced': True, 'normalize_to_ranking': False, 'set_value_to_nan': 0.0, 'normalize_by_used_features': True, 'num_features_used': 5, 'differentiable_hyperparameters': {'distribution': 'uniform', 'min': 0.0, 'max': 10.0}, 'noise': 0.2, 'lengthscale': 1.0, 'outputscale': 1.0, 'prior_type': 'gp', 'differentiable': True, 'flexible': True, 'aggregate_k_gradients': 1, 'output_multiclass_ordered_p': 1.0, 'recompute_attn': False}, 'num_outputs': 1, 'dynamic_batch_size': 2, 'get_batch': <function get_model.<locals>.make_get_batch.<locals>.<lambda> at 0x7f39ad8dcf80>, 'differentiable_hyperparameters': {'outputscale': {'distribution': 'uniform', 'min': 0.0, 'max': 10.0}, 'lengthscale': {'distribution': 'uniform', 'min': 0.0, 'max': 10.0}, 'noise': {'distribution': 'uniform', 'min': 1e-07, 'max': 0.5}}}, 'num_features': 5, 'num_outputs': 1}\n",
263
+ "Using a Transformer with 17.35 M parameters\n"
264
+ ]
265
+ }
266
+ ],
267
+ "source": [
268
+ "device = 'cuda'\n",
269
+ "train_function(config_sample, 0, add_name='gp_experiments_diff_with_noise_no_meta_new')"
270
+ ]
271
+ },
272
+ {
273
+ "cell_type": "markdown",
274
+ "metadata": {
275
+ "tags": []
276
+ },
277
+ "source": [
278
+ "#### Evaluating a PFN (with pretrained model)"
279
+ ]
280
+ },
281
+ {
282
+ "cell_type": "code",
283
+ "execution_count": 13,
284
+ "metadata": {
285
+ "hidden": true,
286
+ "tags": []
287
+ },
288
+ "outputs": [
289
+ {
290
+ "name": "stdout",
291
+ "output_type": "stream",
292
+ "text": [
293
+ "Using style prior: True\n",
294
+ "Using cpu:0 device\n",
295
+ "Not using distributed\n",
296
+ "DataLoader.__dict__ {'num_steps': 100, 'fuse_x_y': False, 'get_batch_kwargs': {'batch_size': 1, 'seq_len': 10, 'seq_len_maximum': 10, 'device': 'cpu:0', 'num_features': 5, 'hyperparameters': {'lr': 1e-05, 'dropout': 0, 'emsize': 512, 'batch_size': 1, 'nlayers': 8, 'num_features': 5, 'nhead': 4, 'nhid_factor': 2, 'bptt': 10, 'eval_positions': [190], 'seq_len_used': 200, 'sampling': 'normal', 'epochs': 500, 'num_steps': 100, 'verbose': False, 'pre_sample_causes': True, 'mix_activations': False, 'nan_prob_unknown_reason_reason_prior': 1.0, 'output_multiclass_ordered_p': 1.0, 'categorical_feature_p': 0.0, 'nan_prob_no_reason': 0.0, 'nan_prob_unknown_reason': 0.0, 'nan_prob_a_reason': 0.0, 'max_num_classes': 2, 'num_classes': 2, 'noise_type': 'Gaussian', 'balanced': True, 'multiclass_type': 'rank', 'normalize_to_ranking': False, 'set_value_to_nan': 0.0, 'normalize_by_used_features': True, 'num_features_used': <function load_model.<locals>.<lambda> at 0x7f39ad8534d0>, 'differentiable_hyperparameters': {'distribution': 'uniform', 'min': 0.0, 'max': 10.0}, 'noise': 0.03, 'lengthscale': 1.0, 'outputscale': 1.0, 'prior_type': 'gp', 'differentiable': True, 'flexible': True, 'aggregate_k_gradients': 1, 'recompute_attn': False, 'bptt_extra_samples': None, 'epoch_in_training': 0.998, 'categorical_features_sampler': <function load_model.<locals>.<lambda> at 0x7f39ad853680>, 'num_features_used_in_training': 5, 'num_classes_in_training': 2, 'batch_size_in_training': 128, 'bptt_in_training': 200, 'bptt_extra_samples_in_training': None}, 'num_outputs': 1, 'dynamic_batch_size': 2, 'get_batch': <function get_model.<locals>.make_get_batch.<locals>.<lambda> at 0x7f39ad81ab90>, 'differentiable_hyperparameters': {'outputscale': {'distribution': 'uniform', 'min': 0.0, 'max': 10.0}, 'lengthscale': {'distribution': 'uniform', 'min': 0.0, 'max': 10.0}, 'noise': {'distribution': 'uniform', 'min': 1e-07, 'max': 0.5}}}, 'num_features': 5, 'num_outputs': 1}\n",
297
+ "Using a Transformer with 17.35 M parameters\n"
298
+ ]
299
+ }
300
+ ],
301
+ "source": [
302
+ "device = 'cpu'\n",
303
+ "model, c = load_model(base_path, f'models_diff/gp_ablation_model.cpkt', device, eval_positions, verbose=False)"
304
+ ]
305
+ },
306
+ {
307
+ "cell_type": "code",
308
+ "execution_count": 14,
309
+ "metadata": {},
310
+ "outputs": [],
311
+ "source": [
312
+ "from priors.differentiable_prior import DifferentiableHyperparameterList\n",
313
+ "diff_list = DifferentiableHyperparameterList(c['differentiable_hyperparameters'], 512, device)\n",
314
+ "diff_hparams_keys, diff_hparams_f = diff_list.get_hyperparameter_info()"
315
+ ]
316
+ },
317
+ {
318
+ "cell_type": "code",
319
+ "execution_count": null,
320
+ "metadata": {
321
+ "tags": []
322
+ },
323
+ "outputs": [],
324
+ "source": [
325
+ "model[2].eval()\n",
326
+ "eval_pos = 100\n",
327
+ "\n",
328
+ "hparam_label = [(1, 'outputscale')]\n",
329
+ "hparam_label = [(0, 'lengthscale')]\n",
330
+ "hparam_label = [(2, 'noise')]\n",
331
+ "hparam_labels = [[(1, 'outputscale')], [(2, 'noise')], [(0, 'lengthscale')]]\n",
332
+ "#hparam_labels = [[(2, 'noise')]]\n",
333
+ "\n",
334
+ "hparams, losses, inferred_param, vary_hparam_ind, hparam_true = {}, {}, {}, {}, {}\n",
335
+ "\n",
336
+ "for hparam_label in hparam_labels:\n",
337
+ " (hparams[hparam_label[0][1]], losses[hparam_label[0][1]], inferred_param[hparam_label[0][1]], vary_hparam_ind[hparam_label[0][1]], \n",
338
+ " hparam_true[hparam_label[0][1]]) = differentiable_hparam_tuning_workflow(config_sample, \n",
339
+ " hparam_label=hparam_label, \n",
340
+ " batch_size=256, \n",
341
+ " N_grad_steps=50,\n",
342
+ " plot_step_size=0.05)\n"
343
+ ]
344
+ },
345
+ {
346
+ "cell_type": "code",
347
+ "execution_count": null,
348
+ "metadata": {},
349
+ "outputs": [],
350
+ "source": [
351
+ "label = 'lengthscale'\n",
352
+ "\n",
353
+ "#import tikzplotlib\n",
354
+ "\n",
355
+ "inferred = losses[label]\n",
356
+ "\n",
357
+ "plt.plot(hparams[label][:, vary_hparam_ind[label]], losses[label])\n",
358
+ "true = diff_hparams_f[vary_hparam_ind[label]][1](hparam_true[label][vary_hparam_ind[label]])\n",
359
+ "plt.axvline(x=inferred_param[label], linestyle='solid', color='red')\n",
360
+ "plt.axvline(x=true, linestyle='dashed')\n",
361
+ "\n",
362
+ "plt.ylabel('Cross entropy Loss')\n",
363
+ "plt.xlabel(label)\n",
364
+ "\n",
365
+ "#tikzplotlib.save(f'diff_inferred_params_{label}.tex', axis_height='5.2cm', axis_width='5.2cm', strict=True)\n",
366
+ "\n",
367
+ "plt.show()"
368
+ ]
369
+ }
370
+ ],
371
+ "metadata": {
372
+ "kernelspec": {
373
+ "display_name": "Python 3 (ipykernel)",
374
+ "language": "python",
375
+ "name": "python3"
376
+ },
377
+ "language_info": {
378
+ "codemirror_mode": {
379
+ "name": "ipython",
380
+ "version": 3
381
+ },
382
+ "file_extension": ".py",
383
+ "mimetype": "text/x-python",
384
+ "name": "python",
385
+ "nbconvert_exporter": "python",
386
+ "pygments_lexer": "ipython3",
387
+ "version": "3.7.13"
388
+ }
389
+ },
390
+ "nbformat": 4,
391
+ "nbformat_minor": 4
392
+ }
TabPFN/TabPFNPredictionOnly.ipynb ADDED
@@ -0,0 +1,253 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "This notebook shows how to use TabPFN for tabular prediction with a scikit learn wrapper.\n",
8
+ "\n",
9
+ "classifier = TabPFNClassifier(device='cpu')\n",
10
+ "classifier.fit(train_xs, train_ys)\n",
11
+ "prediction_ = classifier.predict(test_xs)\n",
12
+ "\n",
13
+ "The fit function does not perform any computations, but only saves the training data. Computations are only done at inference time, when calling predict.\n",
14
+ "Note that the presaved models were trained for up to 100 features, 10 classes and 1000 samples. While the model does not have a hard bound on the number of samples, the features and classes are restricted and larger sizes lead to an error."
15
+ ]
16
+ },
17
+ {
18
+ "cell_type": "markdown",
19
+ "metadata": {
20
+ "tags": []
21
+ },
22
+ "source": [
23
+ "### Setup"
24
+ ]
25
+ },
26
+ {
27
+ "cell_type": "code",
28
+ "execution_count": null,
29
+ "metadata": {},
30
+ "outputs": [],
31
+ "source": [
32
+ "%load_ext autoreload\n",
33
+ "\n",
34
+ "%autoreload 2"
35
+ ]
36
+ },
37
+ {
38
+ "cell_type": "code",
39
+ "execution_count": null,
40
+ "metadata": {},
41
+ "outputs": [],
42
+ "source": [
43
+ "import time\n",
44
+ "import torch\n",
45
+ "import numpy as np\n",
46
+ "import os\n",
47
+ "import random\n",
48
+ "\n",
49
+ "from model_builder import get_model, get_default_spec, save_model, load_model\n",
50
+ "from scripts.transformer_prediction_interface import transformer_predict, get_params_from_config, TabPFNClassifier\n",
51
+ "\n",
52
+ "from datasets import load_openml_list, open_cc_dids, open_cc_valid_dids\n",
53
+ "\n",
54
+ "from scripts import tabular_metrics"
55
+ ]
56
+ },
57
+ {
58
+ "cell_type": "code",
59
+ "execution_count": null,
60
+ "metadata": {},
61
+ "outputs": [],
62
+ "source": [
63
+ "base_path = '.'"
64
+ ]
65
+ },
66
+ {
67
+ "cell_type": "markdown",
68
+ "metadata": {
69
+ "tags": []
70
+ },
71
+ "source": [
72
+ "### Load datasets"
73
+ ]
74
+ },
75
+ {
76
+ "cell_type": "code",
77
+ "execution_count": null,
78
+ "metadata": {
79
+ "jupyter": {
80
+ "outputs_hidden": true
81
+ },
82
+ "tags": []
83
+ },
84
+ "outputs": [],
85
+ "source": [
86
+ "max_samples = 10000\n",
87
+ "bptt = 10000\n",
88
+ "\n",
89
+ "cc_test_datasets_multiclass, cc_test_datasets_multiclass_df = load_openml_list(open_cc_dids, multiclass=True, shuffled=True, filter_for_nan=False, max_samples = max_samples, num_feats=100, return_capped=True)\n",
90
+ "cc_valid_datasets_multiclass, cc_valid_datasets_multiclass_df = load_openml_list(open_cc_valid_dids, multiclass=True, shuffled=True, filter_for_nan=False, max_samples = max_samples, num_feats=100, return_capped=True)\n",
91
+ "\n",
92
+ "# Loading longer OpenML Datasets for generalization experiments (optional)\n",
93
+ "# test_datasets_multiclass, test_datasets_multiclass_df = load_openml_list(test_dids_classification, multiclass=True, shuffled=True, filter_for_nan=False, max_samples = 10000, num_feats=100, return_capped=True)\n",
94
+ "\n",
95
+ "random.seed(0)\n",
96
+ "random.shuffle(cc_valid_datasets_multiclass)"
97
+ ]
98
+ },
99
+ {
100
+ "cell_type": "code",
101
+ "execution_count": null,
102
+ "metadata": {},
103
+ "outputs": [],
104
+ "source": [
105
+ "from datasets import get_openml_classification"
106
+ ]
107
+ },
108
+ {
109
+ "cell_type": "code",
110
+ "execution_count": null,
111
+ "metadata": {},
112
+ "outputs": [],
113
+ "source": [
114
+ "dataset = openml.datasets.get_dataset(31)\n",
115
+ "X, y, categorical_indicator, attribute_names = dataset.get_data(\n",
116
+ " dataset_format=\"array\", target=dataset.default_target_attribute\n",
117
+ " )"
118
+ ]
119
+ },
120
+ {
121
+ "cell_type": "code",
122
+ "execution_count": null,
123
+ "metadata": {},
124
+ "outputs": [],
125
+ "source": [
126
+ "def get_datasets(selector, task_type, suite='cc'):\n",
127
+ " if task_type == 'binary':\n",
128
+ " ds = valid_datasets_binary if selector == 'valid' else test_datasets_binary\n",
129
+ " else:\n",
130
+ " if suite == 'openml':\n",
131
+ " ds = valid_datasets_multiclass if selector == 'valid' else test_datasets_multiclass\n",
132
+ " elif suite == 'cc':\n",
133
+ " ds = cc_valid_datasets_multiclass if selector == 'valid' else cc_test_datasets_multiclass\n",
134
+ " else:\n",
135
+ " raise Exception(\"Unknown suite\")\n",
136
+ " return ds"
137
+ ]
138
+ },
139
+ {
140
+ "cell_type": "code",
141
+ "execution_count": null,
142
+ "metadata": {},
143
+ "outputs": [],
144
+ "source": [
145
+ "model_string, longer, task_type = '', 1, 'multiclass'\n",
146
+ "eval_positions = [1000]\n",
147
+ "bptt = 2000\n",
148
+ " \n",
149
+ "test_datasets, valid_datasets = get_datasets('test', task_type, suite='cc'), get_datasets('valid', task_type, suite='cc')"
150
+ ]
151
+ },
152
+ {
153
+ "cell_type": "markdown",
154
+ "metadata": {
155
+ "jp-MarkdownHeadingCollapsed": true,
156
+ "tags": []
157
+ },
158
+ "source": [
159
+ "### Select a dataset for prediction"
160
+ ]
161
+ },
162
+ {
163
+ "cell_type": "code",
164
+ "execution_count": null,
165
+ "metadata": {},
166
+ "outputs": [],
167
+ "source": [
168
+ "[(i, test_datasets[i][0]) for i in range(len(test_datasets))]"
169
+ ]
170
+ },
171
+ {
172
+ "cell_type": "code",
173
+ "execution_count": null,
174
+ "metadata": {},
175
+ "outputs": [],
176
+ "source": [
177
+ "evaluation_dataset_index = 4 # Index of the dataset to predict\n",
178
+ "ds = test_datasets[evaluation_dataset_index]\n",
179
+ "print(f'Evaluation dataset name: {ds[0]} shape {ds[1].shape}')"
180
+ ]
181
+ },
182
+ {
183
+ "cell_type": "code",
184
+ "execution_count": null,
185
+ "metadata": {},
186
+ "outputs": [],
187
+ "source": [
188
+ "xs, ys = ds[1].clone(), ds[2].clone()\n",
189
+ "eval_position = xs.shape[0] // 2\n",
190
+ "train_xs, train_ys = xs[0:eval_position], ys[0:eval_position]\n",
191
+ "test_xs, test_ys = xs[eval_position:], ys[eval_position:]"
192
+ ]
193
+ },
194
+ {
195
+ "cell_type": "markdown",
196
+ "metadata": {
197
+ "tags": []
198
+ },
199
+ "source": [
200
+ "### Predict using a Fitted and Tuned Model"
201
+ ]
202
+ },
203
+ {
204
+ "cell_type": "code",
205
+ "execution_count": null,
206
+ "metadata": {},
207
+ "outputs": [],
208
+ "source": [
209
+ "classifier = TabPFNClassifier(device='cpu')\n",
210
+ "classifier.fit(train_xs, train_ys)\n",
211
+ "prediction_ = classifier.predict_proba(test_xs)"
212
+ ]
213
+ },
214
+ {
215
+ "cell_type": "code",
216
+ "execution_count": null,
217
+ "metadata": {},
218
+ "outputs": [],
219
+ "source": [
220
+ "roc, ce = tabular_metrics.auc_metric(test_ys, prediction_), tabular_metrics.cross_entropy(test_ys, prediction_)\n",
221
+ "'AUC', float(roc), 'Cross Entropy', float(ce)"
222
+ ]
223
+ },
224
+ {
225
+ "cell_type": "code",
226
+ "execution_count": null,
227
+ "metadata": {},
228
+ "outputs": [],
229
+ "source": []
230
+ }
231
+ ],
232
+ "metadata": {
233
+ "kernelspec": {
234
+ "display_name": "Python 3 (ipykernel)",
235
+ "language": "python",
236
+ "name": "python3"
237
+ },
238
+ "language_info": {
239
+ "codemirror_mode": {
240
+ "name": "ipython",
241
+ "version": 3
242
+ },
243
+ "file_extension": ".py",
244
+ "mimetype": "text/x-python",
245
+ "name": "python",
246
+ "nbconvert_exporter": "python",
247
+ "pygments_lexer": "ipython3",
248
+ "version": "3.7.13"
249
+ }
250
+ },
251
+ "nbformat": 4,
252
+ "nbformat_minor": 4
253
+ }
TabPFN/TabularEvaluationVisualization.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
TabPFN/TrainingTuningAndPrediction.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
TabPFN/datasets/__init__.py ADDED
@@ -0,0 +1,149 @@
1
+ import pandas as pd
2
+ import torch
3
+ import numpy as np
4
+ import openml
5
+
6
+
7
+ def get_openml_classification(did, max_samples, multiclass=True, shuffled=True):
8
+ dataset = openml.datasets.get_dataset(did)
9
+ X, y, categorical_indicator, attribute_names = dataset.get_data(
10
+ dataset_format="array", target=dataset.default_target_attribute
11
+ )
12
+
13
+ if not multiclass:
14
+ X = X[y < 2]
15
+ y = y[y < 2]
16
+
17
+ if multiclass and not shuffled:
18
+ raise NotImplementedError("This combination of multiclass and shuffling isn't implemented")
19
+
20
+ if not isinstance(X, np.ndarray) or not isinstance(y, np.ndarray):
21
+ print('Not a NP Array, skipping')
22
+ return None, None, None, None
23
+
24
+ if not shuffled:
25
+ sort = np.argsort(y) if y.mean() < 0.5 else np.argsort(-y)
26
+ pos = int(y.sum()) if y.mean() < 0.5 else int((1 - y).sum())
27
+ X, y = X[sort][-pos * 2:], y[sort][-pos * 2:]
28
+ y = torch.tensor(y).reshape(2, -1).transpose(0, 1).reshape(-1).flip([0]).float()
29
+ X = torch.tensor(X).reshape(2, -1, X.shape[1]).transpose(0, 1).reshape(-1, X.shape[1]).flip([0]).float()
30
+ else:
31
+ order = np.arange(y.shape[0])
32
+ np.random.seed(13)
33
+ np.random.shuffle(order)
34
+ X, y = torch.tensor(X[order]), torch.tensor(y[order])
35
+ if max_samples:
36
+ X, y = X[:max_samples], y[:max_samples]
37
+
38
+ return X, y, list(np.where(categorical_indicator)[0]), attribute_names
39
+
40
+ def load_openml_list(dids, filter_for_nan=False
41
+ , num_feats=100
42
+ , min_samples = 100
43
+ , max_samples=400
44
+ , multiclass=True
45
+ , max_num_classes=10
46
+ , shuffled=True
47
+ , return_capped = False):
48
+ datasets = []
49
+ openml_list = openml.datasets.list_datasets(dids)
50
+ print(f'Number of datasets: {len(openml_list)}')
51
+
52
+ datalist = pd.DataFrame.from_dict(openml_list, orient="index")
53
+ if filter_for_nan:
54
+ datalist = datalist[datalist['NumberOfInstancesWithMissingValues'] == 0]
55
+ print(f'Number of datasets after Nan and feature number filtering: {len(datalist)}')
56
+
57
+ for ds in datalist.index:
58
+ modifications = {'samples_capped': False, 'classes_capped': False, 'feats_capped': False}
59
+ entry = datalist.loc[ds]
60
+
61
+ print('Loading', entry['name'], entry.did, '..')
62
+
63
+ if entry['NumberOfClasses'] == 0.0:
64
+ raise Exception("Regression not supported")
65
+ #X, y, categorical_feats, attribute_names = get_openml_regression(int(entry.did), max_samples)
66
+ else:
67
+ X, y, categorical_feats, attribute_names = get_openml_classification(int(entry.did), max_samples
68
+ , multiclass=multiclass, shuffled=shuffled)
69
+ if X is None:
70
+ continue
71
+
72
+ if X.shape[1] > num_feats:
73
+ if return_capped:
74
+ X = X[:, 0:num_feats]
75
+ categorical_feats = [c for c in categorical_feats if c < num_feats]
76
+ modifications['feats_capped'] = True
77
+ else:
78
+ print('Too many features')
79
+ continue
80
+ if X.shape[0] == max_samples:
81
+ modifications['samples_capped'] = True
82
+
83
+ if X.shape[0] < min_samples:
84
+ print(f'Too few samples left')
85
+ continue
86
+
87
+ if len(np.unique(y)) > max_num_classes:
88
+ if return_capped:
89
+ X = X[y < np.unique(y)[10]]
90
+ y = y[y < np.unique(y)[10]]
91
+ modifications['classes_capped'] = True
92
+ else:
93
+ print(f'Too many classes')
94
+ continue
95
+
96
+ datasets += [[entry['name'], X, y, categorical_feats, attribute_names, modifications]]
97
+
98
+ return datasets, datalist
99
+
100
+
101
+ # Classification
102
+ valid_dids_classification = [13, 59, 4, 15, 40710, 43, 1498]
103
+ test_dids_classification = [973, 1596, 40981, 1468, 40984, 40975, 41163, 41147, 1111, 41164, 1169, 1486, 41143, 1461, 41167, 40668, 41146, 41169, 41027, 23517, 41165, 41161, 41159, 41138, 1590, 41166, 1464, 41168, 41150, 1489, 41142, 3, 12, 31, 54, 1067]
104
+ valid_large_classification = [ 943, 23512, 49, 838, 1131, 767, 1142, 748, 1112,
105
+ 1541, 384, 912, 1503, 796, 20, 30, 903, 4541,
106
+ 961, 805, 1000, 4135, 1442, 816, 1130, 906, 1511,
107
+ 184, 181, 137, 1452, 1481, 949, 449, 50, 913,
108
+ 1071, 831, 843, 9, 896, 1532, 311, 39, 451,
109
+ 463, 382, 778, 474, 737, 1162, 1538, 820, 188,
110
+ 452, 1156, 37, 957, 911, 1508, 1054, 745, 1220,
111
+ 763, 900, 25, 387, 38, 757, 1507, 396, 4153,
112
+ 806, 779, 746, 1037, 871, 717, 1480, 1010, 1016,
113
+ 981, 1547, 1002, 1126, 1459, 846, 837, 1042, 273,
114
+ 1524, 375, 1018, 1531, 1458, 6332, 1546, 1129, 679,
115
+ 389]
116
+
117
+ open_cc_dids = [11,
118
+ 14,
119
+ 15,
120
+ 16,
121
+ 18,
122
+ 22,
123
+ 23,
124
+ 29,
125
+ 31,
126
+ 37,
127
+ 50,
128
+ 54,
129
+ 188,
130
+ 458,
131
+ 469,
132
+ 1049,
133
+ 1050,
134
+ 1063,
135
+ 1068,
136
+ 1510,
137
+ 1494,
138
+ 1480,
139
+ 1462,
140
+ 1464,
141
+ 6332,
142
+ 23381,
143
+ 40966,
144
+ 40982,
145
+ 40994,
146
+ 40975]
147
+ # Filtered by N_samples < 2000, N feats < 100, N classes < 10
148
+
149
+ open_cc_valid_dids = [13,25,35,40,41,43,48,49,51,53,55,56,59,61,187,285,329,333,334,335,336,337,338,377,446,450,451,452,460,463,464,466,470,475,481,679,694,717,721,724,733,738,745,747,748,750,753,756,757,764,765,767,774,778,786,788,795,796,798,801,802,810,811,814,820,825,826,827,831,839,840,841,844,852,853,854,860,880,886,895,900,906,907,908,909,915,925,930,931,934,939,940,941,949,966,968,984,987,996,1048,1054,1071,1073,1100,1115,1412,1442,1443,1444,1446,1447,1448,1451,1453,1488,1490,1495,1498,1499,1506,1508,1511,1512,1520,1523,4153,23499,40496,40646,40663,40669,40680,40682,40686,40690,40693,40705,40706,40710,40711,40981,41430,41538,41919,41976,42172,42261,42544,42585,42638]
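As a quick orientation, the sketch below mirrors how the notebooks in this commit call `load_openml_list` on the `open_cc_dids` list defined above; it assumes the working directory is `TabPFN/` (so the `datasets` package resolves) and that OpenML is reachable over the network.

```python
from datasets import load_openml_list, open_cc_dids

# Load the curated test datasets, capping at 2000 samples and 100 features;
# capped datasets are flagged in `modifications` rather than dropped.
cc_test_datasets, cc_test_df = load_openml_list(
    open_cc_dids, multiclass=True, shuffled=True, filter_for_nan=False,
    max_samples=2000, num_feats=100, return_capped=True)

# Each entry is [name, X, y, categorical_feats, attribute_names, modifications].
name, X, y, categorical_feats, attribute_names, modifications = cc_test_datasets[0]
print(name, tuple(X.shape), modifications)
```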
TabPFN/datasets/utils.py ADDED
@@ -0,0 +1,8 @@
1
+ def normalize_data(eval_xs):
2
+ mean = eval_xs.mean(0)
3
+ std = eval_xs.std(0) + .000001
4
+ eval_xs = (eval_xs - mean) / std
5
+
6
+ return eval_xs
7
+
8
+
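A short usage sketch for `normalize_data`; the tensor shape (rows are samples, columns are features) is an illustrative assumption.

```python
import torch
from datasets.utils import normalize_data

x = torch.randn(100, 5) * 3.0 + 1.0   # 100 samples, 5 features
x_norm = normalize_data(x)            # per-feature zero mean, ~unit variance

print(x_norm.mean(0))                 # ~0 for every feature
print(x_norm.std(0))                  # ~1 for every feature
```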
TabPFN/decoders.py ADDED
@@ -0,0 +1,30 @@
1
+ import torch
2
+ from torch import nn
3
+ import random
4
+
5
+
6
+ class ScaledDecoder(nn.Module):
7
+ def __init__(self, ninp, nhid, nout):
8
+ super().__init__()
9
+ self.linear = nn.Linear(ninp, nhid)
10
+ self.linear1 = nn.Linear(nhid, nout)
11
+ self.linear2 = nn.Linear(nhid, 10)
12
+
13
+ def forward(self, x):
14
+ #return torch.cat([self.linear1(x), self.linear2(x)], -1)
15
+ x = self.linear(x)
16
+ x = nn.GELU()(x)
17
+ temps = self.linear2(x).softmax(-1) @ torch.tensor([1.,1.4,1.7,2.,5.,10.,20.,40.,80.,160.], device=x.device)
18
+ if random.random() > .99:
19
+ print(temps.shape,temps[:,:2])
20
+ return self.linear1(x) / temps.unsqueeze(-1)
21
+
22
+ class FixedScaledDecoder(nn.Module):
23
+ def __init__(self, ninp, nhid, nout):
24
+ super().__init__()
25
+ self.mapper = nn.Sequential(nn.Linear(ninp, nhid), nn.GELU(), nn.Linear(nhid, nout))
26
+ self.T = nn.Parameter(torch.ones(10000)/10000)
27
+
28
+ def forward(self, x):
29
+ return self.mapper(x)/self.T.sum()
30
+
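Both decoders above map transformer embeddings to class logits, scaled by a (learned) temperature. The sketch below only exercises their input/output shapes with random data; the embedding, hidden, and output sizes are illustrative assumptions.

```python
import torch
from decoders import ScaledDecoder, FixedScaledDecoder

emsize, nhid, n_out = 512, 1024, 10      # illustrative sizes
x = torch.randn(200, 8, emsize)          # (seq_len, batch, embedding)

scaled = ScaledDecoder(emsize, nhid, n_out)
fixed = FixedScaledDecoder(emsize, nhid, n_out)

print(scaled(x).shape)                   # torch.Size([200, 8, 10])
print(fixed(x).shape)                    # torch.Size([200, 8, 10])
```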
TabPFN/differentiable_pfn_evaluation.py ADDED
@@ -0,0 +1,345 @@
1
+ import os
2
+ import torch
3
+ import numpy as np
4
+ import time
5
+ import pickle
6
+ from scripts import tabular_metrics
7
+ from scripts.tabular_metrics import calculate_score_per_method
8
+ from scripts.tabular_evaluation import evaluate
9
+ from priors.differentiable_prior import draw_random_style
10
+ from tqdm import tqdm
11
+ import random
12
+ from scripts.transformer_prediction_interface import get_params_from_config, load_model_workflow
13
+
14
+ """
15
+ ===============================
16
+ PUBLIC FUNCTIONS FOR EVALUATION
17
+ ===============================
18
+ """
19
+
20
+
21
+ def eval_model_range(i_range, *args, **kwargs):
22
+ for i in i_range:
23
+ eval_model(i, *args, **kwargs)
24
+
25
+
26
+
27
+ def eval_model(i, e, valid_datasets, test_datasets, train_datasets, eval_positions_valid, eval_positions_test,
28
+ bptt_valid,
29
+ bptt_test, add_name, base_path, device='cpu', eval_addition='', **extra_tuning_args):
30
+ """
31
+ Differentiable model evaluation workflow. Evaluates and saves results to disk.
32
+
33
+ :param i:
34
+ :param e:
35
+ :param valid_datasets:
36
+ :param test_datasets:
37
+ :param train_datasets:
38
+ :param eval_positions_valid:
39
+ :param eval_positions_test:
40
+ :param bptt_valid:
41
+ :param bptt_test:
42
+ :param add_name:
43
+ :param base_path:
44
+ :param device:
45
+ :param eval_addition:
46
+ :param extra_tuning_args:
47
+ :return:
48
+ """
49
+ model, c, results_file = load_model_workflow(i, e, add_name, base_path, device, eval_addition)
50
+ params = {'bptt': bptt_valid
51
+ , 'bptt_final': bptt_test
52
+ , 'eval_positions': eval_positions_valid
53
+ , 'eval_positions_test': eval_positions_test
54
+ , 'valid_datasets': valid_datasets
55
+ , 'test_datasets': test_datasets
56
+ , 'train_datasets': train_datasets
57
+ , 'verbose': True
58
+ , 'device': device
59
+ }
60
+
61
+ params.update(get_params_from_config(c))
62
+
63
+ start = time.time()
64
+ metrics, metrics_valid, style, temperature, optimization_route = evaluate_differentiable_model(model, **params,
65
+ **extra_tuning_args)
66
+ print('Evaluation time: ', time.time() - start)
67
+
68
+ print(results_file)
69
+ r = [c.copy(), metrics, metrics_valid, style.to('cpu'), temperature.to('cpu'), optimization_route]
70
+ with open(results_file, 'wb') as output:
71
+ del r[0]['num_features_used']
72
+ del r[0]['categorical_features_sampler']
73
+ pickle.dump(r, output)
74
+
75
+ _, _, _, style, temperature, _ = r
76
+
77
+ return r, model
78
+
79
+ """
80
+ ===============================
81
+ INTERNAL HELPER FUNCTIONS
82
+ ===============================
83
+ """
84
+
85
+ def evaluate_differentiable_model(model
86
+ , valid_datasets
87
+ , test_datasets
88
+ , train_datasets
89
+ , N_draws=100
90
+ , N_grad_steps=10
91
+ , eval_positions=None
92
+ , eval_positions_test=None
93
+ , bptt=100
94
+ , bptt_final=200
95
+ , style=None
96
+ , n_parallel_configurations=1
97
+ , device='cpu'
98
+ , selection_metric='auc'
99
+ , final_splits=[1, 2, 3, 4, 5]
100
+ , N_ensemble_configurations_list=[1, 5, 10, 20, 50, 100]
101
+ , **kwargs):
102
+ """
103
+ Evaluation function for diffable model evaluation. Returns a list of results.
104
+
105
+ :param model:
106
+ :param valid_datasets:
107
+ :param test_datasets:
108
+ :param train_datasets:
109
+ :param N_draws:
110
+ :param N_grad_steps:
111
+ :param eval_positions:
112
+ :param eval_positions_test:
113
+ :param bptt:
114
+ :param bptt_final:
115
+ :param style:
116
+ :param n_parallel_configurations:
117
+ :param device:
118
+ :param selection_metric:
119
+ :param final_splits:
120
+ :param N_ensemble_configurations_list:
121
+ :param kwargs:
122
+ :return:
123
+ """
124
+ torch.manual_seed(0)
125
+ np.random.seed(0)
126
+ random.seed(0)
127
+
128
+ diffable_metric = tabular_metrics.cross_entropy
129
+ evaluation_metric = tabular_metrics.auc_metric
130
+ if selection_metric in ('auc', 'roc'):
131
+ selection_metric_min_max = 'max'
132
+ selection_metric = tabular_metrics.auc_metric
133
+ evaluation_metric = selection_metric
134
+ elif selection_metric in ('ce', 'selection_metric'):
135
+ selection_metric_min_max = 'min'
136
+ selection_metric = tabular_metrics.cross_entropy
137
+ evaluation_metric = selection_metric
138
+
139
+ print('Diffable metric', diffable_metric, ' Selection metric', selection_metric, ' Evaluation metric',
140
+ evaluation_metric)
141
+ print('N PARALLEL CONFIGURATIONS', n_parallel_configurations)
142
+ print('eval_positions', eval_positions)
143
+
144
+ def evaluate_valid(style, softmax_temperature, results, results_tracked):
145
+ result_valid = eval_step(valid_datasets, style, softmax_temperature=softmax_temperature,
146
+ return_tensor=False, inference_mode=True, selection_metric=selection_metric,
147
+ evaluation_metric=evaluation_metric, eval_positions=eval_positions, bptt=bptt, model=model[2])
148
+ result_valid = [float(result_valid[f'mean_select_at_{pos}']) for pos in eval_positions]
149
+ results += [result_valid]
150
+ results_tracked += [np.nanmean(result_valid)]
151
+
152
+ model[2].to(device)
153
+ model[2].eval()
154
+
155
+ results_on_valid, results_on_valid_tracked = [], []
156
+ best_style, best_softmax_temperature = style, torch.cat(
157
+ [torch.tensor([0.0]).to(device) for n in range(0, n_parallel_configurations)], 0)
158
+ optimization_routes = []
159
+
160
+ best_style = torch.cat([draw_random_style(model[3], device).detach() for n in range(0, n_parallel_configurations)],
161
+ 0)
162
+ best_softmax_temperature = torch.cat([torch.tensor([0.0]).to(device) for n in range(0, n_parallel_configurations)],
163
+ 0)
164
+
165
+
166
+ for _ in tqdm(range(0, N_draws), desc='Iterate over Optimization initializations'): # Evaluates N hparam draws
167
+ style = torch.cat([draw_random_style(model[3], device).detach() for n in range(0, n_parallel_configurations)],
168
+ 0)
169
+ softmax_temperature = torch.cat([torch.tensor([0.0]).to(device) for n in range(0, n_parallel_configurations)],
170
+ 0)
171
+
172
+ evaluate_valid(style, softmax_temperature, results_on_valid, results_on_valid_tracked)
173
+
174
+ print(f'Draw --> Valid Selection metric: {results_on_valid[-1]}')
175
+
176
+ if N_grad_steps > 0:
177
+ gradient_optimize_result = gradient_optimize_style(model, style, N_grad_steps
178
+ , softmax_temperature=softmax_temperature
179
+ , model=model[2]
180
+ , train_datasets=train_datasets
181
+ , valid_datasets=valid_datasets
182
+ , selection_metric_min_max=selection_metric_min_max
183
+ , **kwargs)
184
+ optimization_routes += [gradient_optimize_result['optimization_route']]
185
+
186
+ evaluate_valid(gradient_optimize_result['best_style']
187
+ , gradient_optimize_result['best_temperature']
188
+ , results_on_valid, results_on_valid_tracked)
189
+
190
+ print(f'After diff --> Valid Selection metric: {results_on_valid[-1]}')
191
+
192
+ if selection_metric_min_max == 'min':
193
+ is_best = (results_on_valid_tracked[-1] <= min(results_on_valid_tracked))
194
+ else:
195
+ is_best = (results_on_valid_tracked[-1] >= max(results_on_valid_tracked))
196
+
197
+ if is_best or best_style is None:
198
+ best_style = gradient_optimize_result['best_style'].clone()
199
+ best_softmax_temperature = gradient_optimize_result['best_temperature'].clone()
200
+ torch.cuda.empty_cache()
201
+
202
+ def final_evaluation():
203
+ print('Running eval dataset with final params (no gradients)..')
204
+ print(best_style, best_softmax_temperature)
205
+ result_test = []
206
+ for N_ensemble_configurations in N_ensemble_configurations_list:
207
+ print(f'Running with {N_ensemble_configurations} ensemble_configurations')
208
+ kwargs['N_ensemble_configurations'] = N_ensemble_configurations
209
+ splits = []
210
+ for split in final_splits:
211
+ splits += [eval_step(test_datasets, best_style, softmax_temperature=best_softmax_temperature
212
+ , return_tensor=False, eval_positions=eval_positions_test,
213
+ bptt=bptt_final, inference_mode=True, split_number=split, model=model[2]
214
+ , selection_metric=selection_metric, evaluation_metric=evaluation_metric)]
215
+ result_test += [splits]
216
+
217
+ print('Running valid dataset with final params (no gradients)..')
218
+ result_valid = eval_step(valid_datasets, best_style, softmax_temperature=best_softmax_temperature
219
+ , return_tensor=False, eval_positions=eval_positions_test,
220
+ bptt=bptt_final, inference_mode=True, model=model[2]
221
+ , selection_metric=selection_metric, evaluation_metric=evaluation_metric)
222
+
223
+ return result_test, result_valid
224
+
225
+ result_test, result_valid = final_evaluation()
226
+
227
+ return result_test, result_valid, best_style, best_softmax_temperature, optimization_routes
228
+
229
+
230
+ def eval_step(ds, used_style, selection_metric, evaluation_metric, eval_positions, return_tensor=True, **kwargs):
231
+ def step():
232
+ return evaluate(datasets=ds,
233
+ method='transformer'
234
+ , overwrite=True
235
+ , style=used_style
236
+ , eval_positions=eval_positions
237
+ , metric_used=selection_metric
238
+ , save=False
239
+ , path_interfix=None
240
+ , base_path=None
241
+ , verbose=True
242
+ , **kwargs)
243
+
244
+ if return_tensor:
245
+ r = step()
246
+ else:
247
+ with torch.no_grad():
248
+ r = step()
249
+
250
+ calculate_score_per_method(selection_metric, 'select', r, ds, eval_positions, aggregator='mean')
251
+ calculate_score_per_method(evaluation_metric, 'eval', r, ds, eval_positions, aggregator='mean')
252
+
253
+ return r
254
+
255
+
256
+ def gradient_optimize_style(model, init_style, steps, softmax_temperature, train_datasets, valid_datasets, learning_rate=0.03, optimize_all=False,
257
+ limit_style=True, N_datasets_sampled=90, optimize_softmax_temperature=True, selection_metric_min_max='max', **kwargs):
258
+ """
259
+ Uses gradient-based methods to optimize 'style' on 'train_datasets' and selects the best step based on 'valid_datasets'.
260
+
261
+ :param model: model tuple; index 2 is the transformer that is evaluated
262
+ :param init_style: initial style vector the optimization starts from
263
+ :param steps: number of gradient steps
264
+ :param learning_rate: Adam learning rate
265
+ :param softmax_temperature: initial softmax temperature, optimized jointly if optimize_softmax_temperature is set
266
+ :param train_datasets: datasets the differentiable loss is backpropagated on
267
+ :param valid_datasets: datasets used to track the best style found so far
268
+ :param optimize_all: also optimize the transformer weights, not only style and temperature
269
+ :param limit_style: clamp the style to [-1.74, 1.74] after each step
270
+ :param N_datasets_sampled: number of training datasets sampled per step
271
+ :param optimize_softmax_temperature: whether the softmax temperature receives gradients
272
+ :param selection_metric_min_max: 'min' or 'max', the direction in which the selection metric improves
273
+ :param kwargs: forwarded to eval_step
274
+ :return:
275
+ """
276
+ grad_style = torch.nn.Parameter(init_style.detach(), requires_grad=True)
277
+
278
+ best_style, best_temperature, best_selection_metric, best_diffable_metric = grad_style.detach(), softmax_temperature.detach(), None, None
279
+ softmax_temperature = torch.nn.Parameter(softmax_temperature.detach(), requires_grad=optimize_softmax_temperature)
280
+ variables_to_optimize = model[2].parameters() if optimize_all else [grad_style, softmax_temperature]
281
+ optimizer = torch.optim.Adam(variables_to_optimize, lr=learning_rate)
282
+
283
+ optimization_route_selection, optimization_route_diffable = [], []
284
+ optimization_route_selection_valid, optimization_route_diffable_valid = [], []
285
+
286
+ def eval_opt(ds, return_tensor=True, inference_mode=False):
287
+ result = eval_step(ds, grad_style, softmax_temperature=softmax_temperature, return_tensor=return_tensor
288
+ , inference_mode=inference_mode, model=model[2], **kwargs)
289
+
290
+ diffable_metric = result['mean_metric']
291
+ selection_metric = result['mean_select']
292
+
293
+ return diffable_metric, selection_metric
294
+
295
+ def eval_all_datasets(datasets, propagate=True):
296
+ selection_metrics_this_step, diffable_metrics_this_step = [], []
297
+ for ds in datasets:
298
+ diffable_metric_train, selection_metric_train = eval_opt([ds], inference_mode=(not propagate))
299
+ if not torch.isnan(diffable_metric_train).any():
300
+ if propagate and diffable_metric_train.requires_grad == True:
301
+ diffable_metric_train.backward()
302
+ selection_metrics_this_step += [selection_metric_train]
303
+ diffable_metrics_this_step += [float(diffable_metric_train.detach().cpu().numpy())]
304
+ diffable_metric_train = np.nanmean(diffable_metrics_this_step)
305
+ selection_metric_train = np.nanmean(selection_metrics_this_step)
306
+
307
+ return diffable_metric_train, selection_metric_train
308
+
309
+ for t in tqdm(range(steps), desc='Iterate over Optimization steps'):
310
+ optimizer.zero_grad()
311
+
312
+ # Select subset of datasets
313
+ random.seed(t)
314
+ train_datasets_ = random.sample(train_datasets, N_datasets_sampled)
315
+
316
+ # Get score on train
317
+ diffable_metric_train, selection_metric_train = eval_all_datasets(train_datasets_, propagate=True)
318
+ optimization_route_selection += [float(selection_metric_train)]
319
+ optimization_route_diffable += [float(diffable_metric_train)]
320
+
321
+ # Get score on valid
322
+ diffable_metric_valid, selection_metric_valid = eval_all_datasets(valid_datasets, propagate=False)
323
+ optimization_route_selection_valid += [float(selection_metric_valid)]
324
+ optimization_route_diffable_valid += [float(diffable_metric_valid)]
325
+
326
+ is_best = (best_selection_metric is not None and selection_metric_min_max == 'min' and best_selection_metric > selection_metric_valid)
327
+ is_best = is_best or (best_selection_metric is not None and selection_metric_min_max == 'max' and best_selection_metric < selection_metric_valid)
328
+ if (best_selection_metric is None) or (not np.isnan(selection_metric_valid) and is_best):
329
+ print('New best', best_selection_metric, selection_metric_valid)
330
+ best_style = grad_style.detach().clone()
331
+ best_temperature = softmax_temperature.detach().clone()
332
+ best_selection_metric, best_diffable_metric = selection_metric_valid, diffable_metric_valid
333
+
334
+ optimizer.step()
335
+
336
+ if limit_style:
337
+ grad_style = grad_style.detach().clamp(-1.74, 1.74)
338
+
339
+ print(f'Valid: Diffable metric={diffable_metric_valid} Selection metric={selection_metric_valid};' +
340
+ f'Train: Diffable metric={diffable_metric_train} Selection metric={selection_metric_train}')
341
+
342
+ print(f'Return best:{best_style} {best_selection_metric}')
343
+ return {'best_style': best_style, 'best_temperature': best_temperature
344
+ , 'optimization_route': {'select': optimization_route_selection, 'loss': optimization_route_diffable,
345
+ 'test_select': optimization_route_selection_valid, 'test_loss': optimization_route_diffable_valid}}
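For orientation, eval_model above pickles a six-element list; a minimal sketch of reading it back (the path is a placeholder, not a file in this commit):

import pickle

with open('results_file.pkl', 'rb') as f:  # placeholder for the results_file path returned by load_model_workflow
    config, metrics_test, metrics_valid, style, temperature, optimization_route = pickle.load(f)
print(metrics_valid)  # validation results; style and temperature are the tuned CPU tensors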
TabPFN/encoders.py ADDED
@@ -0,0 +1,225 @@
1
+ import math
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from utils import normalize_data
6
+ import torch.nn.functional as F
7
+ from torch.nn import TransformerEncoder, TransformerEncoderLayer
8
+
9
+
10
+ class StyleEncoder(nn.Module):
11
+ def __init__(self, em_size, hyperparameter_definitions):
12
+ super().__init__()
13
+ # self.embeddings = {}
14
+ self.em_size = em_size
15
+ # self.hyperparameter_definitions = {}
16
+ # for hp in hyperparameter_definitions:
17
+ # self.embeddings[hp] = nn.Linear(1, self.em_size)
18
+ # self.embeddings = nn.ModuleDict(self.embeddings)
19
+ self.embedding = nn.Linear(hyperparameter_definitions.shape[0], self.em_size)
20
+
21
+ def forward(self, hyperparameters): # T x B x num_features
22
+ # Make faster by using matrices
23
+ # sampled_embeddings = [torch.stack([
24
+ # self.embeddings[hp](torch.tensor([batch[hp]], device=self.embeddings[hp].weight.device, dtype=torch.float))
25
+ # for hp in batch
26
+ # ], -1).sum(-1) for batch in hyperparameters]
27
+ # return torch.stack(sampled_embeddings, 0)
28
+ return self.embedding(hyperparameters)
29
+
30
+
31
+ class _PositionalEncoding(nn.Module):
32
+ def __init__(self, d_model, dropout=0.):
33
+ super().__init__()
34
+ self.dropout = nn.Dropout(p=dropout)
35
+ self.d_model = d_model
36
+ self.device_test_tensor = nn.Parameter(torch.tensor(1.))
37
+
38
+ def forward(self, x):# T x B x num_features
39
+ assert self.d_model % x.shape[-1]*2 == 0
40
+ d_per_feature = self.d_model // x.shape[-1]
41
+ pe = torch.zeros(*x.shape, d_per_feature, device=self.device_test_tensor.device)
42
+ #position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
43
+ interval_size = 10
44
+ div_term = (1./interval_size) * 2*math.pi*torch.exp(torch.arange(0, d_per_feature, 2, device=self.device_test_tensor.device).float()*math.log(math.sqrt(2)))
45
+ #print(div_term/2/math.pi)
46
+ pe[..., 0::2] = torch.sin(x.unsqueeze(-1) * div_term)
47
+ pe[..., 1::2] = torch.cos(x.unsqueeze(-1) * div_term)
48
+ return self.dropout(pe).view(x.shape[0],x.shape[1],self.d_model)
49
+
50
+
51
+ Positional = lambda _, emsize: _PositionalEncoding(d_model=emsize)
52
+
53
+ class EmbeddingEncoder(nn.Module):
54
+ def __init__(self, num_features, em_size, num_embs=100):
55
+ super().__init__()
56
+ self.num_embs = num_embs
57
+ self.embeddings = nn.Embedding(num_embs * num_features, em_size, max_norm=True)
58
+ self.init_weights(.1)
59
+ self.min_max = (-2,+2)
60
+
61
+ @property
62
+ def width(self):
63
+ return self.min_max[1] - self.min_max[0]
64
+
65
+ def init_weights(self, initrange):
66
+ self.embeddings.weight.data.uniform_(-initrange, initrange)
67
+
68
+ def discretize(self, x):
69
+ split_size = self.width / self.num_embs
70
+ return ((x - self.min_max[0]) / split_size).int().clamp(0, self.num_embs - 1)
71
+
72
+ def forward(self, x): # T x B x num_features
73
+ x_idxs = self.discretize(x)
74
+ x_idxs += torch.arange(x.shape[-1], device=x.device).view(1, 1, -1) * self.num_embs
75
+ # print(x_idxs,self.embeddings.weight.shape)
76
+ return self.embeddings(x_idxs).mean(-2)
77
+
78
+
79
+ class Normalize(nn.Module):
80
+ def __init__(self, mean, std):
81
+ super().__init__()
82
+ self.mean = mean
83
+ self.std = std
84
+
85
+ def forward(self, x):
86
+ return (x-self.mean)/self.std
87
+
88
+
89
+ def get_normalized_uniform_encoder(encoder_creator):
90
+ """
91
+ This can be used to wrap an encoder that is fed uniform samples in [0,1] and normalizes these to 0 mean and 1 std.
92
+ For example, it can be used as `encoder_creator = get_normalized_uniform_encoder(encoders.Linear)`, now this can
93
+ be initialized with `encoder_creator(feature_dim, in_dim)`.
94
+ :param encoder_creator: callable mapping (in_dim, out_dim) to an encoder module
95
+ :return:
96
+ """
97
+ return lambda in_dim, out_dim: nn.Sequential(Normalize(.5, math.sqrt(1/12)), encoder_creator(in_dim, out_dim))
98
+
99
+
100
+ Linear = nn.Linear
101
+ MLP = lambda num_features, emsize: nn.Sequential(nn.Linear(num_features+1,emsize*2),
102
+ nn.ReLU(),
103
+ nn.Linear(emsize*2,emsize))
104
+
105
+ class NanHandlingEncoder(nn.Module):
106
+ def __init__(self, num_features, emsize, keep_nans=True):
107
+ super().__init__()
108
+ self.num_features = 2 * num_features if keep_nans else num_features
109
+ self.emsize = emsize
110
+ self.keep_nans = keep_nans
111
+ self.layer = nn.Linear(self.num_features, self.emsize)
112
+
113
+ def forward(self, x):
114
+ if self.keep_nans:
115
+ x = torch.cat([torch.nan_to_num(x, nan=0.0), normalize_data(torch.isnan(x) * -1
116
+ + torch.logical_and(torch.isinf(x), torch.sign(x) == 1) * 1
117
+ + torch.logical_and(torch.isinf(x), torch.sign(x) == -1) * 2
118
+ )], -1)
119
+ else:
120
+ x = torch.nan_to_num(x, nan=0.0)
121
+ return self.layer(x)
122
+
123
+ class Linear(nn.Linear):
124
+ def __init__(self, num_features, emsize):
125
+ super().__init__(num_features, emsize)
126
+ self.num_features = num_features
127
+ self.emsize = emsize
128
+
129
+ def forward(self, x):
130
+ x = torch.nan_to_num(x, nan=0.0)
131
+ return super().forward(x)
132
+
133
+ class SequenceSpanningEncoder(nn.Module):
134
+ # Regular Encoder transforms Seq_len, B, S -> Seq_len, B, E attending only to last dimension
135
+ # This Encoder accesses the Seq_Len dimension additionally
136
+
137
+ # Why would we want this? We can learn normalization and embedding of features
138
+ # , this might be more important for e.g. categorical, ordinal feats, nan detection
139
+ # However maybe this can be easily learned through transformer as well?
140
+ # A problem is to make this work across any sequence length and be independent of ordering
141
+
142
+ # We could use average and maximum pooling and use those with a linear layer
143
+
144
+
145
+ # Another idea !! Similar to this we would like to encode features so that their number is variable
146
+ # We would like to embed features, also using knowledge of the features in the entire sequence
147
+
148
+ # We could use convolution or another transformer
149
+ # Convolution:
150
+
151
+ # Transformer/Conv across sequence dimension that encodes and normalizes features
152
+ # -> Transformer across feature dimension that encodes features to a constant size
153
+
154
+ # Conv with flexible features but no sequence info: S,B,F -(reshape)-> S*B,1,F
155
+ # -(Conv1d)-> S*B,N,F -(AvgPool,MaxPool)-> S*B,N,1 -> S,B,N
156
+ # This probably won't work since it's missing a way to recognize which feature is encoded
157
+
158
+ # Transformer with flexible features: S,B,F -> F,B*S,1 -> F2,B*S,1 -> S,B,F2
159
+
160
+ def __init__(self, num_features, em_size):
161
+ super().__init__()
162
+
163
+ raise NotImplementedError()
164
+ # Seq_len, B, S -> Seq_len, B, E
165
+ #
166
+ self.convs = torch.nn.ModuleList([nn.Conv1d(64 if i else 1, 64, 3) for i in range(5)])
167
+ # self.linear = nn.Linear(64, emsize)
168
+
169
+ class TransformerBasedFeatureEncoder(nn.Module):
170
+ def __init__(self, num_features, emsize):
171
+ super().__init__()
172
+
173
+ hidden_emsize = emsize
174
+ encoder = Linear(1, hidden_emsize)
175
+ n_out = emsize
176
+ nhid = 2*emsize
177
+ dropout =0.0
178
+ nhead=4
179
+ nlayers=4
180
+ model = nn.Transformer(nhead=nhead, num_encoder_layers=4, num_decoder_layers=4, d_model=1)
181
+
182
+ def forward(self, *input):
183
+ # S,B,F -> F,S*B,1 -> F2,S*B,1 -> S,B,F2
184
+ input = input.transpose()
185
+ self.model(input)
186
+
187
+ class Conv(nn.Module):
188
+ def __init__(self, input_size, emsize):
189
+ super().__init__()
190
+ self.convs = torch.nn.ModuleList([nn.Conv2d(64 if i else 1, 64, 3) for i in range(5)])
191
+ self.linear = nn.Linear(64,emsize)
192
+
193
+
194
+ def forward(self, x):
195
+ size = math.isqrt(x.shape[-1])
196
+ assert size*size == x.shape[-1]
197
+ x = x.reshape(*x.shape[:-1], 1, size, size)
198
+ for conv in self.convs:
199
+ if x.shape[-1] < 4:
200
+ break
201
+ x = conv(x)
202
+ x.relu_()
203
+ x = nn.AdaptiveAvgPool2d((1,1))(x).squeeze(-1).squeeze(-1)
204
+ return self.linear(x)
205
+
206
+
207
+
208
+
209
+ class CanEmb(nn.Embedding):
210
+ def __init__(self, num_features, num_embeddings: int, embedding_dim: int, *args, **kwargs):
211
+ assert embedding_dim % num_features == 0
212
+ embedding_dim = embedding_dim // num_features
213
+ super().__init__(num_embeddings, embedding_dim, *args, **kwargs)
214
+
215
+ def forward(self, x):
216
+ lx = x.long()
217
+ assert (lx == x).all(), "CanEmb only works with tensors of whole numbers"
218
+ x = super().forward(lx)
219
+ return x.view(*x.shape[:-2], -1)
220
+
221
+ def get_Canonical(num_classes):
222
+ return lambda num_features, emsize: CanEmb(num_features, num_classes, emsize)
223
+
224
+ def get_Embedding(num_embs_per_feature=100):
225
+ return lambda num_features, emsize: EmbeddingEncoder(num_features, emsize, num_embs=num_embs_per_feature)
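As a quick orientation for the shapes these encoders expect, a minimal sketch (assuming the TabPFN directory is on sys.path; the sizes are arbitrary):

import torch
from encoders import Linear, get_normalized_uniform_encoder

num_features, emsize = 10, 512
enc = get_normalized_uniform_encoder(Linear)(num_features, emsize)
x = torch.rand(100, 8, num_features)  # (seq_len, batch, num_features), uniform in [0, 1]
emb = enc(x)                          # -> (100, 8, emsize)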
TabPFN/initializers.py ADDED
@@ -0,0 +1,9 @@
1
+ import torch
2
+ from torch import nn
3
+
4
+ def get_NormalInitializer(std):
5
+ def initializer(m):
6
+ if isinstance(m, nn.Linear):
7
+ nn.init.normal_(m.weight, 0, std)
8
+ nn.init.normal_(m.bias, 0, std)
9
+ return initializer
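A minimal usage sketch for the initializer above (the module and std are arbitrary); the returned function is meant to be passed to nn.Module.apply:

import torch.nn as nn
from initializers import get_NormalInitializer

mlp = nn.Sequential(nn.Linear(4, 16), nn.ReLU(), nn.Linear(16, 1))
mlp.apply(get_NormalInitializer(0.02))  # re-initializes every nn.Linear weight and bias with N(0, 0.02)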
TabPFN/layer.py ADDED
@@ -0,0 +1,125 @@
1
+ from functools import partial
2
+
3
+ from torch import nn
4
+ from torch.nn.modules.transformer import *
5
+ from torch.nn.modules.transformer import _get_activation_fn
6
+
7
+ from torch.utils.checkpoint import checkpoint
8
+
9
+
10
+ class TransformerEncoderLayer(Module):
11
+ r"""TransformerEncoderLayer is made up of self-attn and feedforward network.
12
+ This standard encoder layer is based on the paper "Attention Is All You Need".
13
+ Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
14
+ Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
15
+ Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
16
+ in a different way during application.
17
+
18
+ Args:
19
+ d_model: the number of expected features in the input (required).
20
+ nhead: the number of heads in the multiheadattention models (required).
21
+ dim_feedforward: the dimension of the feedforward network model (default=2048).
22
+ dropout: the dropout value (default=0.1).
23
+ activation: the activation function of intermediate layer, relu or gelu (default=relu).
24
+ layer_norm_eps: the eps value in layer normalization components (default=1e-5).
25
+ batch_first: If ``True``, then the input and output tensors are provided
26
+ as (batch, seq, feature). Default: ``False``.
27
+
28
+ Examples::
29
+ >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
30
+ >>> src = torch.rand(10, 32, 512)
31
+ >>> out = encoder_layer(src)
32
+
33
+ Alternatively, when ``batch_first`` is ``True``:
34
+ >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True)
35
+ >>> src = torch.rand(32, 10, 512)
36
+ >>> out = encoder_layer(src)
37
+ """
38
+ __constants__ = ['batch_first']
39
+
40
+ def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu",
41
+ layer_norm_eps=1e-5, batch_first=False, pre_norm=False,
42
+ device=None, dtype=None, recompute_attn=False) -> None:
43
+ factory_kwargs = {'device': device, 'dtype': dtype}
44
+ super().__init__()
45
+ self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first,
46
+ **factory_kwargs)
47
+ # Implementation of Feedforward model
48
+ self.linear1 = Linear(d_model, dim_feedforward, **factory_kwargs)
49
+ self.dropout = Dropout(dropout)
50
+ self.linear2 = Linear(dim_feedforward, d_model, **factory_kwargs)
51
+
52
+ self.norm1 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
53
+ self.norm2 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
54
+ self.dropout1 = Dropout(dropout)
55
+ self.dropout2 = Dropout(dropout)
56
+ self.pre_norm = pre_norm
57
+ self.recompute_attn = recompute_attn
58
+
59
+ self.activation = _get_activation_fn(activation)
60
+
61
+ def __setstate__(self, state):
62
+ if 'activation' not in state:
63
+ state['activation'] = F.relu
64
+ super().__setstate__(state)
65
+
66
+ def forward(self, src: Tensor, src_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None) -> Tensor:
67
+ r"""Pass the input through the encoder layer.
68
+
69
+ Args:
70
+ src: the sequence to the encoder layer (required).
71
+ src_mask: the mask for the src sequence (optional).
72
+ src_key_padding_mask: the mask for the src keys per batch (optional).
73
+
74
+ Shape:
75
+ see the docs in Transformer class.
76
+ """
77
+ if self.pre_norm:
78
+ src_ = self.norm1(src)
79
+ else:
80
+ src_ = src
81
+ if isinstance(src_mask, tuple):
82
+ # global attention setup
83
+ assert not self.self_attn.batch_first
84
+ assert src_key_padding_mask is None
85
+
86
+ global_src_mask, trainset_src_mask, valset_src_mask = src_mask
87
+
88
+ num_global_tokens = global_src_mask.shape[0]
89
+ num_train_tokens = trainset_src_mask.shape[0]
90
+
91
+ global_tokens_src = src_[:num_global_tokens]
92
+ train_tokens_src = src_[num_global_tokens:num_global_tokens+num_train_tokens]
93
+ global_and_train_tokens_src = src_[:num_global_tokens+num_train_tokens]
94
+ eval_tokens_src = src_[num_global_tokens+num_train_tokens:]
95
+
96
+
97
+ attn = partial(checkpoint, self.self_attn) if self.recompute_attn else self.self_attn
98
+
99
+ global_tokens_src2 = attn(global_tokens_src, global_and_train_tokens_src, global_and_train_tokens_src, None, True, global_src_mask)[0]
100
+ train_tokens_src2 = attn(train_tokens_src, global_tokens_src, global_tokens_src, None, True, trainset_src_mask)[0]
101
+ eval_tokens_src2 = attn(eval_tokens_src, src_, src_,
102
+ None, True, valset_src_mask)[0]
103
+
104
+ src2 = torch.cat([global_tokens_src2, train_tokens_src2, eval_tokens_src2], dim=0)
105
+
106
+ else:
107
+ if self.recompute_attn:
108
+ src2 = checkpoint(self.self_attn, src_, src_, src_, src_key_padding_mask, True, src_mask)[0]
109
+ else:
110
+ src2 = self.self_attn(src_, src_, src_, attn_mask=src_mask,
111
+ key_padding_mask=src_key_padding_mask)[0]
112
+ src = src + self.dropout1(src2)
113
+ if not self.pre_norm:
114
+ src = self.norm1(src)
115
+
116
+ if self.pre_norm:
117
+ src_ = self.norm2(src)
118
+ else:
119
+ src_ = src
120
+ src2 = self.linear2(self.dropout(self.activation(self.linear1(src_))))
121
+ src = src + self.dropout2(src2)
122
+
123
+ if not self.pre_norm:
124
+ src = self.norm2(src)
125
+ return src
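A minimal sketch of this layer on the plain (non-tuple) src_mask path, where it behaves like the standard PyTorch encoder layer with optional pre-norm (sizes are arbitrary, assuming TabPFN/ is on sys.path):

import torch
from layer import TransformerEncoderLayer

layer = TransformerEncoderLayer(d_model=512, nhead=8, pre_norm=True)
src = torch.rand(10, 32, 512)  # (seq_len, batch, d_model), batch_first=False
out = layer(src)               # -> (10, 32, 512)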
TabPFN/losses.py ADDED
@@ -0,0 +1,41 @@
1
+ import torch
2
+ from torch import nn
3
+
4
+ class CrossEntropyForMulticlassLoss(torch.nn.CrossEntropyLoss):
5
+ # This loss applies cross entropy after reducing the number of prediction
6
+ # dimensions to the number of classes in the target
7
+
8
+ # TODO: loss.item() doesn't work so the displayed losses are Nans
9
+ def __init__(self, num_classes, weight=None, size_average=None, ignore_index: int = -100,
10
+ reduce=None, reduction: str = 'mean', label_smoothing: float = 0.0) -> None:
11
+ super().__init__(size_average=size_average, reduce=reduce, reduction=reduction, ignore_index=ignore_index)
12
+ self.num_classes = num_classes
13
+
14
+ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
15
+ loss = torch.zeros_like(input[:, :, 0])
16
+ for b in range(target.shape[1]):
17
+ l = super().forward(input[:, b, 0:len(torch.unique(target[:, b]))], target[:, b])
18
+ loss[:, b] += l
19
+ return loss.flatten()
20
+
21
+ def JointBCELossWithLogits(output, target):
22
+ # output shape: (S, B, NS) with NS = Number of sequences
23
+ # target shape: (S, B, SL)
24
+ # Loss = -log(mean_NS(prod_SL(p(target_SL, output_NS))))
25
+ # Here at the moment NS = SL
26
+ output = output.unsqueeze(-1).repeat(1, 1, 1, target.shape[-1]) # (S, B, NS, SL)
27
+ output = output.permute(2, 0, 1, 3) # (NS, S, B, SL)
28
+ print(target.shape, output.shape)
29
+ loss = (target * torch.sigmoid(output)) + ((1-target) * (1-torch.sigmoid(output)))
30
+ loss = loss.prod(-1)
31
+ loss = loss.mean(0)
32
+ loss = -torch.log(loss)
33
+ loss = loss.mean()
34
+ return loss
35
+
36
+ class ScaledSoftmaxCE(nn.Module):
37
+ def forward(self, x, label):
38
+ logits = x[..., :-10]
39
+ temp_scales = x[..., -10:]
40
+
41
+ logprobs = logits.softmax(-1)
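A small sketch of the multiclass loss above; shapes follow the (seq, batch, classes) convention used throughout, and the targets are built so that every batch column contains exactly the classes 0..2:

import torch
from losses import CrossEntropyForMulticlassLoss

S, B, C = 20, 4, 10
logits = torch.randn(S, B, C)
targets = torch.arange(S).remainder(3).unsqueeze(1).repeat(1, B)  # each column uses classes {0, 1, 2}
loss = CrossEntropyForMulticlassLoss(num_classes=C)(logits, targets)  # -> per-position losses, shape (S*B,)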
TabPFN/model_builder.py ADDED
@@ -0,0 +1,273 @@
1
+ from train import train, Losses
2
+ import priors
3
+ import encoders
4
+
5
+ from collections import defaultdict
6
+
7
+ from priors.utils import trunc_norm_sampler_f, gamma_sampler_f
8
+ from utils import get_uniform_single_eval_pos_sampler
9
+ import torch
10
+ import math
11
+
12
+ def save_model(model, path, filename, config_sample):
13
+ config_sample = {**config_sample}
14
+
15
+ def make_serializable(config_sample):
16
+ if isinstance(config_sample, dict):
17
+ config_sample = {k: make_serializable(config_sample[k]) for k in config_sample}
18
+ if isinstance(config_sample, list):
19
+ config_sample = [make_serializable(v) for v in config_sample]
20
+ if callable(config_sample):
21
+ config_sample = str(config_sample)
22
+ return config_sample
23
+
24
+ #if 'num_features_used' in config_sample:
25
+ # del config_sample['num_features_used']
26
+
27
+ #config_sample['num_classes_as_str'] = str(config_sample['num_classes'])
28
+ #del config_sample['num_classes']
29
+
30
+ config_sample = make_serializable(config_sample)
31
+
32
+ torch.save((model.state_dict(), None, config_sample), os.path.join(path, filename))
33
+
34
+
35
+ import subprocess as sp
36
+ import os
37
+
38
+ def get_gpu_memory():
39
+ command = "nvidia-smi"
40
+ memory_free_info = sp.check_output(command.split()).decode('ascii')
41
+ return memory_free_info
42
+
43
+
44
+ def load_model(path, filename, device, eval_positions, verbose):
45
+ # TODO: This function only restores evaluation functionality but training can't be continued. It is also not flexible.
46
+
47
+ model_state, optimizer_state, config_sample = torch.load(
48
+ os.path.join(path, filename), map_location='cpu')
49
+ if ('differentiable_hyperparameters' in config_sample
50
+ and 'prior_mlp_activations' in config_sample['differentiable_hyperparameters']):
51
+ config_sample['differentiable_hyperparameters']['prior_mlp_activations']['choice_values_used'] = config_sample[
52
+ 'differentiable_hyperparameters'][
53
+ 'prior_mlp_activations'][
54
+ 'choice_values']
55
+ config_sample['differentiable_hyperparameters']['prior_mlp_activations']['choice_values'] = [
56
+ torch.nn.Tanh for k in config_sample['differentiable_hyperparameters']['prior_mlp_activations']['choice_values']]
57
+
58
+ config_sample['categorical_features_sampler'] = lambda: lambda x: ([], [], [])
59
+ config_sample['num_features_used_in_training'] = config_sample['num_features_used']
60
+ config_sample['num_features_used'] = lambda: config_sample['num_features']
61
+ config_sample['num_classes_in_training'] = config_sample['num_classes']
62
+ config_sample['num_classes'] = 2
63
+ config_sample['batch_size_in_training'] = config_sample['batch_size']
64
+ config_sample['batch_size'] = 1
65
+ config_sample['bptt_in_training'] = config_sample['bptt']
66
+ config_sample['bptt'] = 10
67
+ config_sample['bptt_extra_samples_in_training'] = config_sample['bptt_extra_samples']
68
+ config_sample['bptt_extra_samples'] = None
69
+
70
+ #print('Memory', str(get_gpu_memory()))
71
+
72
+ model = get_model(config_sample, device=device, should_train=False, verbose=verbose)
73
+ module_prefix = 'module.'
74
+ model_state = {k.replace(module_prefix, ''): v for k, v in model_state.items()}
75
+ model[2].load_state_dict(model_state)
76
+ model[2].to(device)
77
+
78
+ return model, config_sample
79
+
80
+ def fix_loaded_config_sample(loaded_config_sample, config):
81
+ def copy_to_sample(*k):
82
+ t,s = loaded_config_sample, config
83
+ for k_ in k[:-1]:
84
+ t = t[k_]
85
+ s = s[k_]
86
+ t[k[-1]] = s[k[-1]]
87
+ copy_to_sample('num_features_used')
88
+ copy_to_sample('num_classes')
89
+ copy_to_sample('differentiable_hyperparameters','prior_mlp_activations','choice_values')
90
+
91
+ def load_config_sample(path, template_config):
92
+ model_state, optimizer_state, loaded_config_sample = torch.load(path, map_location='cpu')
93
+ fix_loaded_config_sample(loaded_config_sample, template_config)
94
+ return loaded_config_sample
95
+
96
+ def get_default_spec(test_datasets, valid_datasets):
97
+ bptt = 10000
98
+ eval_positions = [1000, 2000, 3000, 4000, 5000] # list(2 ** np.array([4, 5, 6, 7, 8, 9, 10, 11, 12]))
99
+ max_features = max([X.shape[1] for (_, X, _, _, _, _) in test_datasets] + [X.shape[1] for (_, X, _, _, _, _) in valid_datasets])
100
+ max_splits = 5
101
+
102
+ return bptt, eval_positions, max_features, max_splits
103
+
104
+ def get_mlp_prior_hyperparameters(config):
105
+ config = {hp: (list(config[hp].values())[0]) if type(config[hp]) is dict else config[hp] for hp in config}
106
+
107
+ if "prior_sigma_gamma_k" in config:
108
+ sigma_sampler = gamma_sampler_f(config["prior_sigma_gamma_k"], config["prior_sigma_gamma_theta"])
109
+ config['init_std'] = sigma_sampler
110
+ if "prior_noise_std_gamma_k" in config:
111
+ noise_std_sampler = gamma_sampler_f(config["prior_noise_std_gamma_k"], config["prior_noise_std_gamma_theta"])
112
+ config['noise_std'] = noise_std_sampler
113
+
114
+ return config
115
+
116
+
117
+ def get_gp_mix_prior_hyperparameters(config):
118
+ return {'lengthscale_concentration': config["prior_lengthscale_concentration"],
119
+ 'nu': config["prior_nu"],
120
+ 'outputscale_concentration': config["prior_outputscale_concentration"],
121
+ 'categorical_data': config["prior_y_minmax_norm"],
122
+ 'y_minmax_norm': config["prior_lengthscale_concentration"],
123
+ 'noise_concentration': config["prior_noise_concentration"],
124
+ 'noise_rate': config["prior_noise_rate"]}
125
+
126
+ def get_gp_prior_hyperparameters(config):
127
+ return {hp: (list(config[hp].values())[0]) if type(config[hp]) is dict else config[hp] for hp in config}
128
+
129
+
130
+ def get_meta_gp_prior_hyperparameters(config):
131
+ config = {hp: (list(config[hp].values())[0]) if type(config[hp]) is dict else config[hp] for hp in config}
132
+
133
+ if "outputscale_mean" in config:
134
+ outputscale_sampler = trunc_norm_sampler_f(config["outputscale_mean"]
135
+ , config["outputscale_mean"] * config["outputscale_std_f"])
136
+ config['outputscale'] = outputscale_sampler
137
+ if "lengthscale_mean" in config:
138
+ lengthscale_sampler = trunc_norm_sampler_f(config["lengthscale_mean"],
139
+ config["lengthscale_mean"] * config["lengthscale_std_f"])
140
+ config['lengthscale'] = lengthscale_sampler
141
+
142
+ return config
143
+
144
+
145
+ def get_model(config, device, should_train=True, verbose=False, state_dict=None, epoch_callback=None):
146
+ extra_kwargs = {}
147
+ verbose_train, verbose_prior = verbose >= 1, verbose >= 2
148
+ config['verbose'] = verbose_prior
149
+
150
+ if 'aggregate_k_gradients' not in config or config['aggregate_k_gradients'] is None:
151
+ config['aggregate_k_gradients'] = math.ceil(config['batch_size'] * ((config['nlayers'] * config['emsize'] * config['bptt'] * config['bptt']) / 10824640000))
152
+
153
+ config['num_steps'] = math.ceil(config['num_steps'] * config['aggregate_k_gradients'])
154
+ config['batch_size'] = math.ceil(config['batch_size'] / config['aggregate_k_gradients'])
155
+ config['recompute_attn'] = config['recompute_attn'] if 'recompute_attn' in config else False
156
+
157
+ def make_get_batch(model_proto, **extra_kwargs):
158
+ extra_kwargs = defaultdict(lambda: None, **extra_kwargs)
159
+ return (lambda batch_size, seq_len, num_features, hyperparameters
160
+ , device, model_proto=model_proto, get_batch=extra_kwargs['get_batch']
161
+ , prior_bag_priors=extra_kwargs['prior_bag_priors']: model_proto.get_batch(
162
+ batch_size=batch_size
163
+ , seq_len=seq_len
164
+ , device=device
165
+ , get_batch=get_batch
166
+ , hyperparameters=hyperparameters
167
+ , num_features=num_features))
168
+
169
+ if config['prior_type'] == 'prior_bag':
170
+ # Prior bag combines priors
171
+ get_batch_gp = make_get_batch(priors.fast_gp)
172
+ get_batch_mlp = make_get_batch(priors.mlp)
173
+ if 'flexible' in config and config['flexible']:
174
+ get_batch_gp = make_get_batch(priors.flexible_categorical, **{'get_batch': get_batch_gp})
175
+ get_batch_mlp = make_get_batch(priors.flexible_categorical, **{'get_batch': get_batch_mlp})
176
+ prior_bag_hyperparameters = {'prior_bag_get_batch': (get_batch_gp, get_batch_mlp)
177
+ , 'prior_bag_exp_weights_1': 2.0}
178
+ prior_hyperparameters = {**get_mlp_prior_hyperparameters(config), **get_gp_prior_hyperparameters(config)
179
+ , **prior_bag_hyperparameters}
180
+ model_proto = priors.prior_bag
181
+ else:
182
+ if config['prior_type'] == 'mlp':
183
+ prior_hyperparameters = get_mlp_prior_hyperparameters(config)
184
+ model_proto = priors.mlp
185
+ elif config['prior_type'] == 'gp':
186
+ prior_hyperparameters = get_gp_prior_hyperparameters(config)
187
+ model_proto = priors.fast_gp
188
+ elif config['prior_type'] == 'gp_mix':
189
+ prior_hyperparameters = get_gp_mix_prior_hyperparameters(config)
190
+ model_proto = priors.fast_gp_mix
191
+ else:
192
+ raise Exception()
193
+
194
+ if 'flexible' in config and config['flexible']:
195
+ get_batch_base = make_get_batch(model_proto)
196
+ extra_kwargs['get_batch'] = get_batch_base
197
+ model_proto = priors.flexible_categorical
198
+
199
+ use_style = False
200
+
201
+ if 'differentiable' in config and config['differentiable']:
202
+ get_batch_base = make_get_batch(model_proto, **extra_kwargs)
203
+ extra_kwargs = {'get_batch': get_batch_base, 'differentiable_hyperparameters': config['differentiable_hyperparameters']}
204
+ model_proto = priors.differentiable_prior
205
+ use_style = True
206
+ print(f"Using style prior: {use_style}")
207
+
208
+ if (('nan_prob_no_reason' in config and config['nan_prob_no_reason'] > 0.0) or
209
+ ('nan_prob_a_reason' in config and config['nan_prob_a_reason'] > 0.0) or
210
+ ('nan_prob_unknown_reason' in config and config['nan_prob_unknown_reason'] > 0.0)):
211
+ encoder = encoders.NanHandlingEncoder
212
+ else:
213
+ encoder = encoders.Linear
214
+
215
+ num_outputs = config['num_outputs'] if 'num_outputs' in config else 1
216
+ if config['max_num_classes'] == 2:
217
+ if 'joint_loss' in config and config['joint_loss']:
218
+ loss = JointBCELossWithLogits
219
+ else:
220
+ loss = Losses.bce
221
+ elif config['max_num_classes'] > 2:
222
+ loss = Losses.ce(torch.ones((config['max_num_classes'])))
223
+ else:
224
+ loss = BarDistribution(borders=get_bucket_limits(500, full_range=(-10, 10)))
225
+
226
+ aggregate_k_gradients = 1 if 'aggregate_k_gradients' not in config else config['aggregate_k_gradients']
227
+ check_is_compatible = False if 'multiclass_loss_type' not in config else (config['multiclass_loss_type'] == 'compatible')
228
+ config['multiclass_type'] = config['multiclass_type'] if 'multiclass_type' in config else 'rank'
229
+ config['mix_activations'] = config['mix_activations'] if 'mix_activations' in config else False
230
+
231
+ config['bptt_extra_samples'] = config['bptt_extra_samples'] if 'bptt_extra_samples' in config else None
232
+ config['eval_positions'] = [int(config['bptt'] * 0.95)] if config['bptt_extra_samples'] is None else [int(config['bptt'])]
233
+
234
+ epochs = 0 if not should_train else config['epochs']
235
+ model = train(model_proto.DataLoader
236
+ , loss
237
+ , encoder
238
+ , style_encoder_generator = encoders.StyleEncoder if use_style else None
239
+ , emsize=config['emsize']
240
+ , nhead=config['nhead']
241
+ , y_encoder_generator= encoders.get_Canonical(config['max_num_classes']) if config.get('canonical_y_encoder', False) else encoders.Linear
242
+ , pos_encoder_generator=None
243
+ , batch_size=config['batch_size']
244
+ , nlayers=config['nlayers']
245
+ , nhid=config['emsize'] * config['nhid_factor']
246
+ , epochs=epochs
247
+ , total_available_time_in_s=config.get('total_available_time_in_s', None)
248
+ , warmup_epochs=20
249
+ , bptt=config['bptt']
250
+ , gpu_device=device
251
+ , dropout=config['dropout']
252
+ , steps_per_epoch=config['num_steps']
253
+ , single_eval_pos_gen=get_uniform_single_eval_pos_sampler(config['bptt'])
254
+ , load_weights_from_this_state_dict=state_dict
255
+ , aggregate_k_gradients=aggregate_k_gradients
256
+ , check_is_compatible=check_is_compatible
257
+ , recompute_attn=config['recompute_attn']
258
+ , epoch_callback=epoch_callback
259
+ , bptt_extra_samples = config['bptt_extra_samples']
260
+ , extra_prior_kwargs_dict={
261
+ 'num_features': config['num_features']
262
+ , 'fuse_x_y': False
263
+ , 'hyperparameters': prior_hyperparameters
264
+ , 'num_outputs':num_outputs
265
+ , 'dynamic_batch_size': 1 if ('num_global_att_tokens' in config and config['num_global_att_tokens']) else 2
266
+ , **extra_kwargs
267
+ }
268
+ , lr=config['lr']
269
+ , verbose=verbose_train,
270
+ weight_decay=config.get('weight_decay', 0.0),
271
+ normalize_labels=True)
272
+
273
+ return model
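A hedged sketch of the intended loading path; the filename is one of the checkpoints added in this commit, and the relative 'models_diff' path is an assumption about the working directory:

from model_builder import load_model

model, config_sample = load_model(path='models_diff',
                                  filename='prior_diff_real_checkpoint_n_8x_lr0.0003_epoch_49.cpkt',
                                  device='cpu', eval_positions=[1000], verbose=0)
transformer = model[2]  # the trained transformer module, indexed the same way elsewhere in this repo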
TabPFN/models_diff/gp_ablation_model.cpkt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c7b0c8febc553cca3fdee265b5a1cd7567dbf83da855969940be4707a9218ffb
3
+ size 69460013
TabPFN/models_diff/prior_diff_real_checkpoint_n_8x_lr0.0003_epoch_49.cpkt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dae97f45bd53d719fc2b23fac4ec55eab16d63892196d939b1bb1c3b408be242
3
+ size 103616779
TabPFN/notebook_utils.py ADDED
@@ -0,0 +1,32 @@
1
+ import os
2
+ from pathlib import Path
3
+
4
+ import io
5
+ import torch
6
+ import pickle
7
+
8
+ def print_models(base_path, model_string):
9
+ print(model_string)
10
+
11
+ for i in range(80):
12
+ for e in range(50):
13
+ exists = Path(os.path.join(base_path, f'models_diff/prior_diff_real_checkpoint{model_string}_n_{i}_epoch_{e}.cpkt')).is_file()
14
+ if exists:
15
+ print(os.path.join(base_path, f'models_diff/prior_diff_real_checkpoint{model_string}_n_{i}_epoch_{e}.cpkt'))
16
+ print()
17
+
18
+ class CustomUnpickler(pickle.Unpickler):
19
+ def find_class(self, module, name):
20
+ if name == 'Manager':
21
+ from settings import Manager
22
+ return Manager
23
+ try:
24
+ return self.find_class_cpu(module, name)
25
+ except:
26
+ return None
27
+
28
+ def find_class_cpu(self, module, name):
29
+ if module == 'torch.storage' and name == '_load_from_bytes':
30
+ return lambda b: torch.load(io.BytesIO(b), map_location='cpu')
31
+ else:
32
+ return super().find_class(module, name)
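A short usage sketch (paths are placeholders and assume the working directory is TabPFN/):

from notebook_utils import print_models, CustomUnpickler

print_models(base_path='.', model_string='')      # lists checkpoint files matching the naming scheme above
with open('prior_tuning_result.pkl', 'rb') as f:  # e.g. the pickle shipped in this commit
    result = CustomUnpickler(f).load()            # remaps CUDA-pickled tensors onto the CPU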
TabPFN/positional_encodings.py ADDED
@@ -0,0 +1,70 @@
1
+ import math
2
+
3
+ import torch
4
+ from torch import nn
5
+
6
+
7
+ # Protocol for positional encodings.
8
+ # __init__(d_model, max_len=..[, more optionals])
9
+ # forward(x: (seq_len, bs, d_model)) -> Tensor of shape (*x.shape[:2],d_model) containing pos. embeddings
10
+
11
+
12
+ class NoPositionalEncoding(nn.Module):
13
+ def __init__(self, d_model, max_len=None):
14
+ super(NoPositionalEncoding, self).__init__()
15
+ pass
16
+
17
+ def forward(self, x):
18
+ return x #* math.sqrt(x.shape[-1])
19
+
20
+
21
+ class PositionalEncoding(nn.Module):
22
+ def __init__(self, d_model, max_len=5000):
23
+ super(PositionalEncoding, self).__init__()
24
+ pe = torch.zeros(max_len, d_model)
25
+ position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
26
+ div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
27
+ pe[:, 0::2] = torch.sin(position * div_term)
28
+ pe[:, 1::2] = torch.cos(position * div_term)
29
+ pe = pe.unsqueeze(0).transpose(0, 1)
30
+ self.register_buffer('pe', pe)
31
+
32
+ def forward(self, x):
33
+ x = self.pe[:x.size(0), :] + x # * math.sqrt(x.shape[-1])
34
+ return x
35
+
36
+
37
+ class LearnedPositionalEncoding(nn.Module):
38
+ def __init__(self, d_model, max_len=5000):
39
+ super(LearnedPositionalEncoding, self).__init__()
40
+ self.max_seq_len = max_len
41
+ #self.positional_embeddings = nn.Embedding(max_len, d_model)
42
+ self.positional_embeddings = nn.Parameter(torch.empty(max_len, d_model))
43
+ nn.init.normal_(self.positional_embeddings, mean=0, std=d_model ** -0.5)
44
+
45
+ def forward(self, x):
46
+ seq_len, bs, d_model = x.shape
47
+ assert seq_len <= len(self.positional_embeddings), 'seq_len can be at most max_len.'
48
+ pos_emb = self.positional_embeddings[:seq_len]
49
+ return pos_emb.unsqueeze(1).expand(seq_len, bs, d_model) + x #* math.sqrt(x.shape[-1])
50
+
51
+
52
+ class PairedScrambledPositionalEncodings(LearnedPositionalEncoding):
53
+ # TODO check whether it is a problem to use the same perm. for full batch
54
+ def forward(self, x):
55
+ seq_len, bs, d_model = x.shape
56
+ assert seq_len <= len(self.positional_embeddings), 'seq_len can be at most max_len.'
57
+ assert len(self.positional_embeddings) % 2 == 0, 'Please specify an even max_len.'
58
+
59
+ paired_embs = self.positional_embeddings.view(len(self.positional_embeddings), -1, 2)
60
+ pos_emb = paired_embs[torch.randperm(len(paired_embs))].view(*self.positional_embeddings.shape)[:seq_len]
61
+
62
+ return pos_emb.unsqueeze(1).expand(seq_len, bs, d_model) + x #* math.sqrt(x.shape[-1])
63
+
64
+
65
+
66
+
67
+
68
+
69
+
70
+
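A minimal sketch of the protocol described at the top of this file (sizes are arbitrary, assuming TabPFN/ is on sys.path):

import torch
from positional_encodings import PositionalEncoding, LearnedPositionalEncoding

x = torch.zeros(10, 32, 512)  # (seq_len, batch, d_model)
out = PositionalEncoding(512, max_len=5000)(x)         # adds sinusoidal embeddings, same shape
out = LearnedPositionalEncoding(512, max_len=5000)(x)  # adds learned embeddings, same shape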
TabPFN/prior_tuning_result.pkl ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:24d2189bbc836aeea888cf6c540f2c1b45b5351822931189e8bf10a0bc80a0b6
3
+ size 18668851
TabPFN/priors/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ from . import fast_gp, mlp, flexible_categorical, differentiable_prior, prior_bag
2
+
3
+
4
+
TabPFN/priors/differentiable_prior.py ADDED
@@ -0,0 +1,293 @@
1
+ import torch
2
+ from torch import nn
3
+ import math
4
+
5
+ from .utils import get_batch_to_dataloader
6
+ from utils import default_device
7
+ from .utils import order_by_y, normalize_by_used_features_f
8
+
9
+ from .utils import trunc_norm_sampler_f, beta_sampler_f, gamma_sampler_f, uniform_sampler_f, zipf_sampler_f, scaled_beta_sampler_f, uniform_int_sampler_f
10
+
11
+
12
+ def unpack_dict_of_tuples(d):
13
+ # Returns list of dicts where each dict i contains values of tuple position i
14
+ # {'a': (1,2), 'b': (3,4)} -> [{'a': 1, 'b': 3}, {'a': 2, 'b': 4}]
15
+ return [dict(zip(d.keys(), v)) for v in list(zip(*list(d.values())))]
16
+
17
+ class DifferentiableHyperparameter(nn.Module):
18
+ ## We can sample this and get a hyperparameter value and a normalized hyperparameter indicator
19
+ def __init__(self, distribution, embedding_dim, device, **args):
20
+ super(DifferentiableHyperparameter, self).__init__()
21
+
22
+ self.distribution = distribution
23
+ self.embedding_dim = embedding_dim
24
+ self.device=device
25
+ for key in args:
26
+ setattr(self, key, args[key])
27
+
28
+ def get_sampler():
29
+ #if self.distribution == "beta":
30
+ # return beta_sampler_f(self.a, self.b), 0, 1
31
+ #elif self.distribution == "gamma":
32
+ # return gamma_sampler_f(self.a, self.b), 0, 1
33
+ #elif self.distribution == "beta_int":
34
+ # return scaled_beta_sampler_f(self.a, self.b, self.scale, self.min), self.scale + self.min, self.min, self.a / (self.a + self.b)
35
+ if self.distribution == "uniform":
36
+ if not hasattr(self, 'sample'):
37
+ return uniform_sampler_f(self.min, self.max), self.min, self.max, (self.max+self.min) / 2, math.sqrt(1/12*(self.max-self.min)*(self.max-self.min))
38
+ else:
39
+ return lambda: self.sample, self.min, self.max, None, None
40
+ elif self.distribution == "uniform_int":
41
+ return uniform_int_sampler_f(self.min, self.max), self.min, self.max, (self.max+self.min) / 2, math.sqrt(1/12*(self.max-self.min)*(self.max-self.min))
42
+
43
+ if self.distribution.startswith("meta"):
44
+ self.hparams = {}
45
+ def sample_meta(f):
46
+ indicators, passed = unpack_dict_of_tuples({hp: self.hparams[hp]() for hp in self.hparams})
47
+ # sampled_embeddings = list(itertools.chain.from_iterable([sampled_embeddings[k] for k in sampled_embeddings]))
48
+ meta_passed = f(**passed)
49
+ return indicators, meta_passed
50
+
51
+ args_passed = {'device': device, 'embedding_dim': embedding_dim}
52
+ if self.distribution == "meta_beta":
53
+ ## Truncated normal where std and mean are drawn randomly logarithmically scaled
54
+ if hasattr(self, 'b') and hasattr(self, 'k'):
55
+ self.hparams = {'b': lambda: (None, self.b), 'k': lambda: (None, self.k)}
56
+ else:
57
+ self.hparams = {"b": DifferentiableHyperparameter(distribution="uniform", min=self.min
58
+ , max=self.max, **args_passed)
59
+ , "k": DifferentiableHyperparameter(distribution="uniform", min=self.min
60
+ , max=self.max, **args_passed)}
61
+ def make_beta(b, k):
62
+ return lambda b=b, k=k: self.scale * beta_sampler_f(b, k)()
63
+ self.sampler = lambda make_beta=make_beta : sample_meta(make_beta)
64
+ elif self.distribution == "meta_trunc_norm_log_scaled":
65
+ # these choices are copied down below, don't change these without changing `replace_differentiable_distributions`
66
+ self.min_std = self.min_std if hasattr(self, 'min_std') else 0.001
67
+ self.max_std = self.max_std if hasattr(self, 'max_std') else self.max_mean
68
+ ## Truncated normal where std and mean are drawn randomly logarithmically scaled
69
+ if not hasattr(self, 'log_mean'):
70
+ self.hparams = {"log_mean": DifferentiableHyperparameter(distribution="uniform", min=math.log(self.min_mean)
71
+ , max=math.log(self.max_mean), **args_passed)
72
+ , "log_std": DifferentiableHyperparameter(distribution="uniform", min=math.log(self.min_std)
73
+ , max=math.log(self.max_std), **args_passed)}
74
+ else:
75
+ self.hparams = {'log_mean': lambda: (None, self.log_mean), 'log_std': lambda: (None, self.log_std)}
76
+ def make_trunc_norm(log_mean, log_std):
77
+ return ((lambda : self.lower_bound + round(trunc_norm_sampler_f(math.exp(log_mean), math.exp(log_std))())) if self.round
78
+ else (lambda: self.lower_bound + trunc_norm_sampler_f(math.exp(log_mean), math.exp(log_std))()))
79
+
80
+ self.sampler = lambda make_trunc_norm=make_trunc_norm: sample_meta(make_trunc_norm)
81
+ elif self.distribution == "meta_trunc_norm":
82
+ self.min_std = self.min_std if hasattr(self, 'min_std') else 0
83
+ self.max_std = self.max_std if hasattr(self, 'max_std') else self.max_mean
84
+ self.hparams = {"mean": DifferentiableHyperparameter(distribution="uniform", min=self.min_mean
85
+ , max=self.max_mean, **args_passed)
86
+ , "std": DifferentiableHyperparameter(distribution="uniform", min=self.min_std
87
+ , max=self.max_std, **args_passed)}
88
+ def make_trunc_norm(mean, std):
89
+ return ((lambda: self.lower_bound + round(
90
+ trunc_norm_sampler_f(math.exp(mean), math.exp(std))())) if self.round
91
+ else (
92
+ lambda make_trunc_norm=make_trunc_norm: self.lower_bound + trunc_norm_sampler_f(math.exp(mean), math.exp(std))()))
93
+ self.sampler = lambda : sample_meta(make_trunc_norm)
94
+ elif self.distribution == "meta_choice":
95
+ if hasattr(self, 'choice_1_weight'):
96
+ self.hparams = {f'choice_{i}_weight': lambda: (None, getattr(self, f'choice_{i}_weight')) for i in range(1, len(self.choice_values))}
97
+ else:
98
+ self.hparams = {f"choice_{i}_weight": DifferentiableHyperparameter(distribution="uniform", min=-5.0
99
+ , max=6.0, **args_passed) for i in range(1, len(self.choice_values))}
100
+ def make_choice(**choices):
101
+ weights = torch.softmax(torch.tensor([1.0] + [choices[i] for i in choices], dtype=torch.float), 0) # create a tensor of weights
102
+ sample = torch.multinomial(weights, 1, replacement=True).numpy()[0]
103
+ return self.choice_values[sample]
104
+
105
+ self.sampler = lambda make_choice=make_choice: sample_meta(make_choice)
106
+ elif self.distribution == "meta_choice_mixed":
107
+ if hasattr(self, 'choice_1_weight'):
108
+ self.hparams = {f'choice_{i}_weight': lambda: (None, getattr(self, f'choice_{i}_weight')) for i in range(1, len(self.choice_values))}
109
+ else:
110
+ self.hparams = {f"choice_{i}_weight": DifferentiableHyperparameter(distribution="uniform", min=-5.0
111
+ , max=6.0, **args_passed) for i in range(1, len(self.choice_values))}
112
+ def make_choice(**choices):
113
+ weights = torch.softmax(torch.tensor([1.0] + [choices[i] for i in choices], dtype=torch.float), 0) # create a tensor of weights
114
+ def sample():
115
+ s = torch.multinomial(weights, 1, replacement=True).numpy()[0]
116
+ return self.choice_values[s]()
117
+ return lambda: sample
118
+
119
+ self.sampler = lambda make_choice=make_choice: sample_meta(make_choice)
120
+ else:
121
+ def return_two(x, min, max, mean, std):
122
+ # Returns (a hyperparameter value, and an indicator value passed to the model)
123
+ if mean is not None:
124
+ ind = (x-mean)/std#(2 * (x-min) / (max-min) - 1)
125
+ else:
126
+ ind = None
127
+ return ind, x # normalize indicator to [-1, 1]
128
+ # def sample_standard(sampler_f, embedding):
129
+ # s = torch.tensor([sampler_f()], device = self.device)
130
+ # return s, embedding(s)
131
+ self.sampler_f, self.sampler_min, self.sampler_max, self.sampler_mean, self.sampler_std = get_sampler()
132
+ self.sampler = lambda : return_two(self.sampler_f(), min=self.sampler_min, max=self.sampler_max
133
+ , mean=self.sampler_mean, std=self.sampler_std)
134
+ # self.embedding_layer = nn.Linear(1, self.embedding_dim, device=self.device)
135
+ # self.embed = lambda x : self.embedding_layer(
136
+ # (x - self.sampler_min) / (self.sampler_max - self.sampler_min))
137
+ #self.sampler = lambda : sample_standard(self.sampler_f, self.embedding)
138
+
139
+
140
+ def forward(self):
141
+ s, s_passed = self.sampler()
142
+ return s, s_passed
143
+
144
+
145
+
146
+ class DifferentiableHyperparameterList(nn.Module):
147
+ def __init__(self, hyperparameters, embedding_dim, device):
148
+ super().__init__()
149
+
150
+ self.device = device
151
+ hyperparameters = {k: v for (k, v) in hyperparameters.items() if v}
152
+ self.hyperparameters = nn.ModuleDict({hp: DifferentiableHyperparameter(embedding_dim = embedding_dim
153
+ , name = hp
154
+ , device = device, **hyperparameters[hp]) for hp in hyperparameters})
155
+ def get_hyperparameter_info(self):
156
+ sampled_hyperparameters_f, sampled_hyperparameters_keys = [], []
157
+ def append_hp(hp_key, hp_val):
158
+ sampled_hyperparameters_keys.append(hp_key)
159
+ # Function remaps hyperparameters from [-1, 1] range to true value
160
+ s_min, s_max, s_mean, s_std = hp_val.sampler_min, hp_val.sampler_max, hp_val.sampler_mean, hp_val.sampler_std
161
+ sampled_hyperparameters_f.append((lambda x: (x-s_mean)/s_std, lambda y : (y * s_std)+s_mean))
162
+ #sampled_hyperparameters_f.append(((lambda x: ((x - s_min) / (s_max - s_min) * (2) - 1)
163
+ # , (lambda y: ((y + 1) * (1 / 2) * (s_max - s_min) + s_min))))
164
+ for hp in self.hyperparameters:
165
+ hp_val = self.hyperparameters[hp]
166
+ if hasattr(hp_val, 'hparams'):
167
+ for hp_ in hp_val.hparams:
168
+ append_hp(f'{hp}_{hp_}', hp_val.hparams[hp_])
169
+ else:
170
+ append_hp(hp, hp_val)
171
+
172
+
173
+ return sampled_hyperparameters_keys, sampled_hyperparameters_f
174
+
175
+ def sample_parameter_object(self):
176
+ sampled_hyperparameters, s_passed = {}, {}
177
+ for hp in self.hyperparameters:
178
+ sampled_hyperparameters_, s_passed_ = self.hyperparameters[hp]()
179
+ s_passed[hp] = s_passed_
180
+ if isinstance(sampled_hyperparameters_, dict):
181
+ sampled_hyperparameters_ = {hp + '_' + str(key): val for key, val in sampled_hyperparameters_.items()}
182
+ sampled_hyperparameters.update(sampled_hyperparameters_)
183
+ else:
184
+ sampled_hyperparameters[hp] = sampled_hyperparameters_
185
+
186
+ # s_passed contains the values passed to the get_batch function
187
+ # sampled_hyperparameters contains the indicator of the sampled value, i.e. only number that describe the sampled object
188
+ return s_passed, sampled_hyperparameters#self.pack_parameter_object(sampled_embeddings)
189
+
190
+ class DifferentiablePrior(torch.nn.Module):
191
+ def __init__(self, get_batch, hyperparameters, differentiable_hyperparameters, args):
192
+ super(DifferentiablePrior, self).__init__()
193
+
194
+ self.h = hyperparameters
195
+ self.args = args
196
+ self.get_batch = get_batch
197
+ self.differentiable_hyperparameters = DifferentiableHyperparameterList(differentiable_hyperparameters
198
+ , embedding_dim=self.h['emsize']
199
+ , device=self.args['device'])
200
+
201
+ def forward(self):
202
+ # Sample hyperparameters
203
+ sampled_hyperparameters_passed, sampled_hyperparameters_indicators = self.differentiable_hyperparameters.sample_parameter_object()
204
+
205
+ hyperparameters = {**self.h, **sampled_hyperparameters_passed}
206
+ x, y, y_ = self.get_batch(hyperparameters=hyperparameters, **self.args)
207
+
208
+ return x, y, y_, sampled_hyperparameters_indicators
209
+
210
+
211
+ # TODO: Make this a class that keeps objects
212
+ @torch.no_grad()
213
+ def get_batch(batch_size, seq_len, num_features, get_batch
214
+ , device=default_device, differentiable_hyperparameters={}
215
+ , hyperparameters=None, batch_size_per_gp_sample=None, **kwargs):
216
+ batch_size_per_gp_sample = batch_size_per_gp_sample or (min(64, batch_size))
217
+ num_models = batch_size // batch_size_per_gp_sample
218
+ assert num_models * batch_size_per_gp_sample == batch_size, f'Batch size ({batch_size}) not divisible by batch_size_per_gp_sample ({batch_size_per_gp_sample})'
219
+
220
+ args = {'device': device, 'seq_len': seq_len, 'num_features': num_features, 'batch_size': batch_size_per_gp_sample}
221
+
222
+ models = [DifferentiablePrior(get_batch, hyperparameters, differentiable_hyperparameters, args) for _ in range(num_models)]
223
+ sample = sum([[model()] for model in models], [])
224
+
225
+ x, y, y_, hyperparameter_dict = zip(*sample)
226
+
227
+ if 'verbose' in hyperparameters and hyperparameters['verbose']:
228
+ print('Hparams', hyperparameter_dict[0].keys())
229
+
230
+ hyperparameter_matrix = []
231
+ for batch in hyperparameter_dict:
232
+ hyperparameter_matrix.append([batch[hp] for hp in batch])
233
+
234
+ transposed_hyperparameter_matrix = list(zip(*hyperparameter_matrix))
235
+ assert all([all([hp is None for hp in hp_]) or all([hp is not None for hp in hp_]) for hp_ in transposed_hyperparameter_matrix]), 'if a hyperparameter is None for one model in the batch, it must be None for all of them'
236
+ # we remove columns that are only None (i.e. not sampled)
237
+ hyperparameter_matrix = [[hp for hp in hp_ if hp is not None] for hp_ in hyperparameter_matrix]
238
+ if len(hyperparameter_matrix[0]) > 0:
239
+ packed_hyperparameters = torch.tensor(hyperparameter_matrix)
240
+ packed_hyperparameters = torch.repeat_interleave(packed_hyperparameters, repeats=batch_size_per_gp_sample, dim=0).detach()
241
+ else:
242
+ packed_hyperparameters = None
243
+
244
+ x, y, y_, packed_hyperparameters = (torch.cat(x, 1).detach()
245
+ , torch.cat(y, 1).detach()
246
+ , torch.cat(y_, 1).detach()
247
+ , packed_hyperparameters)#list(itertools.chain.from_iterable(itertools.repeat(x, batch_size_per_gp_sample) for x in packed_hyperparameters)))#torch.repeat_interleave(torch.stack(packed_hyperparameters, 0).detach(), repeats=batch_size_per_gp_sample, dim=0))
248
+
249
+ return x, y, y_, packed_hyperparameters
250
+
251
+ DataLoader = get_batch_to_dataloader(get_batch)
252
+ DataLoader.num_outputs = 1
253
+ #DataLoader.validate = lambda : 0
254
+
255
+ def draw_random_style(dl, device):
256
+ (hp_embedding, data, targets_), targets = next(iter(dl))
257
+ return hp_embedding.to(device)[0:1, :]
258
+
259
+ def merge_style_with_info(diff_hparams_keys, diff_hparams_f, style, transform=True):
260
+ params = dict(zip(diff_hparams_keys, zip(diff_hparams_f, style.detach().cpu().numpy().tolist()[0])))
261
+ def t(v):
262
+ if transform:
263
+ return v[0][1](v[1])
264
+ else:
265
+ return v[1]
266
+ return {k : t(v) for k, v in params.items()}
267
+
268
+
269
+ import ConfigSpace.hyperparameters as CSH
270
+
271
+ def replace_differentiable_distributions(config):
272
+ diff_config = config['differentiable_hyperparameters']
273
+ for name, diff_hp_dict in diff_config.items():
274
+ distribution = diff_hp_dict['distribution']
275
+ if distribution == 'uniform':
276
+ diff_hp_dict['sample'] = CSH.UniformFloatHyperparameter(name, diff_hp_dict['min'], diff_hp_dict['max'])
277
+ elif distribution == 'meta_beta':
278
+ diff_hp_dict['k'] = CSH.UniformFloatHyperparameter(name+'_k', diff_hp_dict['min'], diff_hp_dict['max'])
279
+ diff_hp_dict['b'] = CSH.UniformFloatHyperparameter(name+'_b', diff_hp_dict['min'], diff_hp_dict['max'])
280
+ elif distribution == 'meta_choice':
281
+ for i in range(1, len(diff_hp_dict['choice_values'])):
282
+ diff_hp_dict[f'choice_{i}_weight'] = CSH.UniformFloatHyperparameter(name+f'choice_{i}_weight', -5.0, 6.0)
283
+ elif distribution == 'meta_choice_mixed':
284
+ for i in range(1, len(diff_hp_dict['choice_values'])):
285
+ diff_hp_dict[f'choice_{i}_weight'] = CSH.UniformFloatHyperparameter(name+f'choice_{i}_weight', -5.0, 6.0)
286
+ elif distribution == 'meta_trunc_norm_log_scaled':
287
+ diff_hp_dict['log_mean'] = CSH.UniformFloatHyperparameter(name+'_log_mean', math.log(diff_hp_dict['min_mean']), math.log(diff_hp_dict['max_mean']))
288
+ min_std = diff_hp_dict['min_std'] if 'min_std' in diff_hp_dict else 0.001
289
+ max_std = diff_hp_dict['max_std'] if 'max_std' in diff_hp_dict else diff_hp_dict['max_mean']
290
+ diff_hp_dict['log_std'] = CSH.UniformFloatHyperparameter(name+'_log_std', math.log(min_std), math.log(max_std))
291
+ else:
292
+ raise ValueError(f'Unknown distribution {distribution}')
293
+
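
A minimal standalone sketch (not from the repo) of the paired remapping closures built in get_hyperparameter_info above: one lambda standardizes a hyperparameter into the style space, the other inverts it back to the raw value.

def make_codec(mean, std):
    # mirrors the (encode, decode) pair appended per hyperparameter above
    encode = lambda x: (x - mean) / std   # raw value -> standardized style entry
    decode = lambda y: y * std + mean     # standardized style entry -> raw value
    return encode, decode

encode, decode = make_codec(mean=0.5, std=0.1)
raw = 0.62
assert abs(decode(encode(raw)) - raw) < 1e-12
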
TabPFN/priors/fast_gp.py ADDED
@@ -0,0 +1,144 @@
1
+ import time
2
+
3
+ import torch
4
+ from torch import nn
5
+ import gpytorch
6
+
7
+ from .utils import get_batch_to_dataloader
8
+ from utils import default_device
9
+
10
+
11
+ # We will use the simplest form of GP model, exact inference
12
+ class ExactGPModel(gpytorch.models.ExactGP):
13
+ def __init__(self, train_x, train_y, likelihood):
14
+ super(ExactGPModel, self).__init__(train_x, train_y, likelihood)
15
+ self.mean_module = gpytorch.means.ConstantMean()
16
+ self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())
17
+
18
+ def forward(self, x):
19
+ mean_x = self.mean_module(x)
20
+ covar_x = self.covar_module(x)
21
+ return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)
22
+
23
+
24
+ def get_model(x, y, hyperparameters):
25
+ likelihood = gpytorch.likelihoods.GaussianLikelihood(noise_constraint=gpytorch.constraints.GreaterThan(1.e-9))
26
+ model = ExactGPModel(x, y, likelihood)
27
+ model.likelihood.noise = torch.ones_like(model.likelihood.noise) * hyperparameters["noise"]
28
+ model.covar_module.outputscale = torch.ones_like(model.covar_module.outputscale) * hyperparameters["outputscale"]
29
+ model.covar_module.base_kernel.lengthscale = torch.ones_like(model.covar_module.base_kernel.lengthscale) * \
30
+ hyperparameters["lengthscale"]
31
+ return model, likelihood
32
+
33
+
34
+ @torch.no_grad()
35
+ def get_batch(batch_size, seq_len, num_features, device=default_device, hyperparameters=None,
36
+ equidistant_x=False, fix_x=None, **kwargs):
37
+ if isinstance(hyperparameters, (tuple, list)):
38
+ hyperparameters = {"noise": hyperparameters[0]
39
+ , "outputscale": hyperparameters[1]
40
+ , "lengthscale": hyperparameters[2]
41
+ , "is_binary_classification": hyperparameters[3]
42
+ # , "num_features_used": hyperparameters[4]
43
+ , "normalize_by_used_features": hyperparameters[5]
44
+ , "order_y": hyperparameters[6]
45
+ , "sampling": hyperparameters[7]
46
+ }
47
+ elif hyperparameters is None:
48
+ hyperparameters = {"noise": .1, "outputscale": .1, "lengthscale": .1}
49
+
50
+ if 'verbose' in hyperparameters and hyperparameters['verbose']:
51
+ print({"noise": hyperparameters['noise'], "outputscale": hyperparameters['outputscale']
52
+ , "lengthscale": hyperparameters['lengthscale'], 'batch_size': batch_size, 'sampling': hyperparameters['sampling']})
53
+
54
+ # hyperparameters = {k: hyperparameters[k]() if callable(hyperparameters[k]) else hyperparameters[k] for k in
55
+ # hyperparameters.keys()}
56
+ assert not (equidistant_x and (fix_x is not None))
57
+
58
+ with gpytorch.settings.fast_computations(*hyperparameters.get('fast_computations', (True, True, True))):
59
+ if equidistant_x:
60
+ assert num_features == 1
61
+ x = torch.linspace(0, 1., seq_len).unsqueeze(0).repeat(batch_size, 1).unsqueeze(-1)
62
+ elif fix_x is not None:
63
+ assert fix_x.shape == (seq_len, num_features)
64
+ x = fix_x.unsqueeze(0).repeat(batch_size, 1, 1).to(device)
65
+ else:
66
+ if hyperparameters.get('sampling','uniform') == 'uniform':
67
+ x = torch.rand(batch_size, seq_len, num_features, device=device)
68
+ else:
69
+ x = torch.randn(batch_size, seq_len, num_features, device=device)
70
+ model, likelihood = get_model(x, torch.Tensor(), hyperparameters)
71
+ model.to(device)
72
+ # trained_model = ExactGPModel(train_x, train_y, likelihood).cuda()
73
+ # trained_model.eval()
74
+ is_fitted = False
75
+ while not is_fitted:
76
+ try:
77
+ with gpytorch.settings.prior_mode(True):
78
+ model, likelihood = get_model(x, torch.Tensor(), hyperparameters)
79
+ model.to(device)
80
+
81
+ d = model(x)
82
+ d = likelihood(d)
83
+ sample = d.sample().transpose(0, 1)
84
+ is_fitted = True
85
+ except RuntimeError: # This can happen when torch.linalg.eigh fails. Restarting with a new init resolves it.
86
+ print('GP Fitting unsuccessful, retrying.. ')
87
+ print(x)
88
+ print(hyperparameters)
89
+
90
+ if bool(torch.any(torch.isnan(x)).detach().cpu().numpy()):
91
+ print({"noise": hyperparameters['noise'], "outputscale": hyperparameters['outputscale']
92
+ , "lengthscale": hyperparameters['lengthscale'], 'batch_size': batch_size})
93
+
94
+ # TODO: Multi output
95
+ return x.transpose(0, 1), sample, sample # x.shape = (T,B,H)
96
+
97
+ DataLoader = get_batch_to_dataloader(get_batch)
98
+ DataLoader.num_outputs = 1
99
+
100
+ def get_model_on_device(x,y,hyperparameters,device):
101
+ model, likelihood = get_model(x, y, hyperparameters)
102
+ model.to(device)
103
+ return model, likelihood
104
+
105
+
106
+ @torch.no_grad()
107
+ def evaluate(x, y, y_non_noisy, use_mse=False, hyperparameters={}, get_model_on_device=get_model_on_device, device=default_device, step_size=1, start_pos=0):
108
+ start_time = time.time()
109
+ losses_after_t = [.0] if start_pos == 0 else []
110
+ all_losses_after_t = []
111
+
112
+ with gpytorch.settings.fast_computations(*hyperparameters.get('fast_computations',(True,True,True))), gpytorch.settings.fast_pred_var(False):
113
+ for t in range(max(start_pos, 1), len(x), step_size):
114
+ loss_sum = 0.
115
+ model, likelihood = get_model_on_device(x[:t].transpose(0, 1), y[:t].transpose(0, 1), hyperparameters, device)
116
+
117
+
118
+ model.eval()
119
+ # print([t.shape for t in model.train_inputs])
120
+ # print(x[:t].transpose(0,1).shape, x[t].unsqueeze(1).shape, y[:t].transpose(0,1).shape)
121
+ f = model(x[t].unsqueeze(1))
122
+ l = likelihood(f)
123
+ means = l.mean.squeeze()
124
+ varis = l.covariance_matrix.squeeze()
125
+ # print(l.variance.squeeze(), l.mean.squeeze(), y[t])
126
+
127
+ assert len(means.shape) == len(varis.shape) == 1
128
+ assert len(means) == len(varis) == x.shape[1]
129
+
130
+ if use_mse:
131
+ c = nn.MSELoss(reduction='none')
132
+ ls = c(means, y[t])
133
+ else:
134
+ ls = -l.log_prob(y[t].unsqueeze(1))
135
+
136
+ losses_after_t.append(ls.mean())
137
+ all_losses_after_t.append(ls.flatten())
138
+ return torch.stack(all_losses_after_t).to('cpu'), torch.tensor(losses_after_t).to('cpu'), time.time() - start_time
139
+
140
+ if __name__ == '__main__':
141
+ hps = (.1,.1,.1)
142
+ for redo_idx in range(1):
143
+ print(
144
+ evaluate(*get_batch(1000, 10, hyperparameters=hps, num_features=10), use_mse=False, hyperparameters=hps))
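
A usage sketch for the GP prior above (assumptions: the TabPFN directory is on the import path so priors.fast_gp resolves, and the module-level default_device works on this machine). x comes back as (seq_len, batch_size, num_features), y as (seq_len, batch_size); the third return value is the same tensor as y here.

from priors import fast_gp

hps = {"noise": .1, "outputscale": 1., "lengthscale": .5, "sampling": "uniform"}
x, y, y_ = fast_gp.get_batch(batch_size=4, seq_len=50, num_features=3, hyperparameters=hps)
print(x.shape, y.shape)  # torch.Size([50, 4, 3]) torch.Size([50, 4])
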
TabPFN/priors/flexible_categorical.py ADDED
@@ -0,0 +1,240 @@
1
+ import time
2
+ import random
3
+
4
+ import torch
5
+ from torch import nn
6
+
7
+ from .utils import get_batch_to_dataloader
8
+ from utils import normalize_data, nan_handling_missing_for_unknown_reason_value, nan_handling_missing_for_no_reason_value, nan_handling_missing_for_a_reason_value, to_ranking_low_mem, remove_outliers
9
+ from .utils import normalize_by_used_features_f, randomize_classes, CategoricalActivation
10
+ from .utils import uniform_int_sampler_f
11
+
12
+ time_it = False
13
+
14
+ class BalancedBinarize(nn.Module):
15
+ def __init__(self):
16
+ super().__init__()
17
+
18
+ def forward(self, x):
19
+ return (x > torch.median(x)).float()
20
+
21
+ def class_sampler_f(min_, max_):
22
+ def s():
23
+ if random.random() > 0.5:
24
+ return uniform_int_sampler_f(min_, max_)()
25
+ return 2
26
+ return s
27
+
28
+ class MulticlassRank(nn.Module):
29
+ def __init__(self, num_classes, ordered_p=0.5):
30
+ super().__init__()
31
+ self.num_classes = class_sampler_f(2, num_classes)()
32
+ self.ordered_p = ordered_p
33
+
34
+ def forward(self, x):
35
+ # x has shape (T,B,H)
36
+
37
+ # CAUTION: This samples the same idx in sequence for each class boundary in a batch
38
+ class_boundaries = torch.randint(0, x.shape[0], (self.num_classes - 1,))
39
+ class_boundaries = x[class_boundaries].unsqueeze(1)
40
+
41
+ d = (x > class_boundaries).sum(axis=0)
42
+
43
+ randomized_classes = torch.rand((d.shape[1], )) > self.ordered_p
44
+ d[:, randomized_classes] = randomize_classes(d[:, randomized_classes], self.num_classes)
45
+ reverse_classes = torch.rand((d.shape[1],)) > 0.5
46
+ d[:, reverse_classes] = self.num_classes - 1 - d[:, reverse_classes]
47
+ return d
48
+
49
+ class MulticlassValue(nn.Module):
50
+ def __init__(self, num_classes, ordered_p=0.5):
51
+ super().__init__()
52
+ self.num_classes = class_sampler_f(2, num_classes)()
53
+ self.classes = nn.Parameter(torch.randn(num_classes-1), requires_grad=False)
54
+ self.ordered_p = ordered_p
55
+
56
+ def forward(self, x):
57
+ # x has shape (T,B,H)
58
+ d = (x > (self.classes.unsqueeze(-1).unsqueeze(-1))).sum(axis=0)
59
+
60
+ randomized_classes = torch.rand((d.shape[1],)) > self.ordered_p
61
+ d[:, randomized_classes] = randomize_classes(d[:, randomized_classes], self.num_classes)
62
+ reverse_classes = torch.rand((d.shape[1],)) > 0.5
63
+ d[:, reverse_classes] = self.num_classes - 1 - d[:, reverse_classes]
64
+ return d
65
+
66
+ class MulticlassMultiNode(nn.Module):
67
+ def __init__(self, num_classes, ordered_p=0.5):
68
+ super().__init__()
69
+ self.num_classes = class_sampler_f(2, num_classes)()
70
+ self.classes = nn.Parameter(torch.randn(num_classes-1), requires_grad=False)
71
+ self.alt_multi_class = MulticlassValue(num_classes, ordered_p)
72
+
73
+ def forward(self, x):
74
+ # x has shape T, B, H
75
+ if len(x.shape) == 2:
76
+ return self.alt_multi_class(x)
77
+ T = 3
78
+ x[torch.isnan(x)] = 0.00001
79
+ d = torch.multinomial(torch.pow(0.00001+torch.sigmoid(x[:, :, 0:self.num_classes]).reshape(-1, self.num_classes), T), 1, replacement=True).reshape(x.shape[0], x.shape[1]).float()
80
+ return d
81
+
82
+
83
+ class FlexibleCategorical(torch.nn.Module):
84
+ def __init__(self, get_batch, hyperparameters, args):
85
+ super(FlexibleCategorical, self).__init__()
86
+
87
+ self.h = {k: hyperparameters[k]() if callable(hyperparameters[k]) else hyperparameters[k] for k in
88
+ hyperparameters.keys()}
89
+ self.args = args
90
+ self.args_passed = {**self.args}
91
+ self.args_passed.update({'num_features': self.h['num_features_used']})
92
+ self.get_batch = get_batch
93
+
94
+ if self.h['num_classes'] > 1 and not self.h['balanced']:
95
+ if self.h['multiclass_type'] == 'rank':
96
+ self.class_assigner = MulticlassRank(self.h['num_classes']
97
+ , ordered_p=self.h['output_multiclass_ordered_p']
98
+ )
99
+ elif self.h['multiclass_type'] == 'value':
100
+ self.class_assigner = MulticlassValue(self.h['num_classes']
101
+ , ordered_p=self.h['output_multiclass_ordered_p']
102
+ )
103
+ elif self.h['multiclass_type'] == 'multi_node':
104
+ self.class_assigner = MulticlassMultiNode(self.h['num_classes'])
105
+ else:
106
+ raise ValueError("Unknown multiclass type")
107
+ elif self.h['num_classes'] == 2 and self.h['balanced']:
108
+ self.class_assigner = BalancedBinarize()
109
+ elif self.h['num_classes'] > 2 and self.h['balanced']:
110
+ raise NotImplementedError("Balanced multiclass training is not possible")
111
+ else:
112
+ self.class_assigner = lambda x:x # Regression
113
+
114
+ def drop_for_reason(self, x, v):
115
+ nan_prob_sampler = CategoricalActivation(ordered_p=0.0
116
+ , categorical_p=1.0
117
+ , keep_activation_size=False,
118
+ num_classes_sampler=lambda: 20)
119
+ d = nan_prob_sampler(x)
120
+ # TODO: Make a different ordering for each activation
121
+ x[d < torch.rand((1,), device=x.device) * 20 * self.h['nan_prob_no_reason'] * random.random()] = v
122
+ return x
123
+
124
+ def drop_for_no_reason(self, x, v):
125
+ x[torch.rand(x.shape, device=self.args['device']) < self.h['nan_prob_no_reason']] = v
126
+ return x
127
+
128
+ def forward(self, batch_size):
129
+ start = time.time()
130
+ x, y, y_ = self.get_batch(hyperparameters=self.h, **self.args_passed)
131
+ if time_it:
132
+ print('Flex Forward Block 1', round(time.time() - start, 3))
133
+
134
+ start = time.time()
135
+
136
+ if self.h['nan_prob_no_reason']+self.h['nan_prob_a_reason']+self.h['nan_prob_unknown_reason'] > 0 and random.random() > 0.5: # Only one out of two datasets should have nans
137
+ if self.h['nan_prob_no_reason'] > 0 and random.random() > 0.5: # Missing for no reason
138
+ x = self.drop_for_no_reason(x, nan_handling_missing_for_no_reason_value(self.h['set_value_to_nan']))
139
+
140
+ if self.h['nan_prob_a_reason'] > 0 and random.random() > 0.5: # Missing for a reason
141
+ x = self.drop_for_reason(x, nan_handling_missing_for_a_reason_value(self.h['set_value_to_nan']))
142
+
143
+ if self.h['nan_prob_unknown_reason'] > 0: # Missing for unknown reason and random.random() > 0.5
144
+ if random.random() < self.h['nan_prob_unknown_reason_reason_prior']:
145
+ x = self.drop_for_no_reason(x, nan_handling_missing_for_unknown_reason_value(self.h['set_value_to_nan']))
146
+ else:
147
+ x = self.drop_for_reason(x, nan_handling_missing_for_unknown_reason_value(self.h['set_value_to_nan']))
148
+
149
+ # Categorical features
150
+ if 'categorical_feature_p' in self.h and random.random() > 1 - self.h['categorical_feature_p']:
151
+ p = random.random()
152
+ for col in range(x.shape[2]):
153
+ m = MulticlassRank(10, ordered_p=0.3)
154
+ if random.random() > p:
155
+ x[:, :, col] = m(x[:, :, col])
156
+
157
+ if time_it:
158
+ print('Flex Forward Block 2', round(time.time() - start, 3))
159
+ start = time.time()
160
+
161
+ if self.h['normalize_to_ranking']:
162
+ x = to_ranking_low_mem(x)
163
+ else:
164
+ x = remove_outliers(x)
165
+ x, y = normalize_data(x), normalize_data(y)
166
+
167
+ if time_it:
168
+ print('Flex Forward Block 3', round(time.time() - start, 3))
169
+ start = time.time()
170
+
171
+ # Cast to classification if enabled
172
+ y = self.class_assigner(y).float()
173
+
174
+ if time_it:
175
+ print('Flex Forward Block 4', round(time.time() - start, 3))
176
+ start = time.time()
177
+ if self.h['normalize_by_used_features']:
178
+ x = normalize_by_used_features_f(x, self.h['num_features_used'], self.args['num_features'], normalize_with_sqrt=self.h.get('normalize_with_sqrt',False))
179
+ if time_it:
180
+ print('Flex Forward Block 5', round(time.time() - start, 3))
181
+
182
+ start = time.time()
183
+ # Append empty features if enabled
184
+ x = torch.cat(
185
+ [x, torch.zeros((x.shape[0], x.shape[1], self.args['num_features'] - self.h['num_features_used']),
186
+ device=self.args['device'])], -1)
187
+ if time_it:
188
+ print('Flex Forward Block 6', round(time.time() - start, 3))
189
+
190
+ return x, y, y # x.shape = (T,B,H)
191
+
192
+ import torch.cuda as cutorch
193
+
194
+ @torch.no_grad()
195
+ def get_batch(batch_size, seq_len, num_features, get_batch, device, hyperparameters=None, batch_size_per_gp_sample=None, **kwargs):
196
+ batch_size_per_gp_sample = batch_size_per_gp_sample or (min(32, batch_size))
197
+ num_models = batch_size // batch_size_per_gp_sample
198
+ assert num_models > 0, f'Batch size ({batch_size}) is too small for batch_size_per_gp_sample ({batch_size_per_gp_sample})'
199
+ assert num_models * batch_size_per_gp_sample == batch_size, f'Batch size ({batch_size}) not divisible by batch_size_per_gp_sample ({batch_size_per_gp_sample})'
200
+
201
+ # Sample one seq_len for entire batch
202
+ seq_len = hyperparameters['seq_len_used']() if callable(hyperparameters['seq_len_used']) else seq_len
203
+
204
+ args = {'device': device, 'seq_len': seq_len, 'num_features': num_features, 'batch_size': batch_size_per_gp_sample}
205
+
206
+ models = [FlexibleCategorical(get_batch, hyperparameters, args).to(device) for _ in range(num_models)]
207
+
208
+ start = time.time()
209
+ sample = sum([[model(batch_size=batch_size_per_gp_sample)] for model in models], [])
210
+ #print('sample', time.time() - start)
211
+
212
+ x, y, y_ = zip(*sample)
213
+ x, y, y_ = torch.cat(x, 1).detach(), torch.cat(y, 1).detach(), torch.cat(y_, 1).detach()
214
+
215
+ # # TODO: Reintegrate this code (Doesn't work on batch dim), could be applied to each batch sample individually
216
+ # if hyperparameters['is_binary_classification'] and hyperparameters['order_y']:
217
+ # x, y = order_by_y(x, y)
218
+
219
+ return x, y, y_
220
+
221
+ # num_features_used = num_features_used_sampler()
222
+ # prior_outputscale = prior_outputscale_sampler()
223
+ # prior_lengthscale = prior_lengthscale_sampler()
224
+ #
225
+ # x, sample = normalize_data(x), normalize_data(sample)
226
+ #
227
+ # if is_binary_classification:
228
+ # sample = (sample > torch.median(sample, dim=0)[0]).float()
229
+ #
230
+ # if normalize_by_used_features:
231
+ # x = normalize_by_used_features_f(x, num_features_used, num_features)
232
+ #
233
+ # # # if is_binary_classification and order_y:
234
+ # # # x, sample = order_by_y(x, sample)
235
+ # #
236
+ # # Append empty features if enabled
237
+ # x = torch.cat([x, torch.zeros((x.shape[0], x.shape[1], num_features - num_features_used), device=device)], -1)
238
+
239
+ DataLoader = get_batch_to_dataloader(get_batch)
240
+ DataLoader.num_outputs = 1
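
A standalone sketch (illustrative only) of the rank-based discretization used by MulticlassRank above: pick num_classes - 1 boundary values from the sequence itself and label each target by how many boundaries it exceeds.

import torch

T, B, num_classes = 100, 2, 4
y = torch.randn(T, B)                                                   # continuous targets, shape (T, B)
boundaries = y[torch.randint(0, T, (num_classes - 1,))].unsqueeze(1)    # (num_classes-1, 1, B)
labels = (y > boundaries).sum(dim=0)                                    # (T, B), integer labels in [0, num_classes-1]
print(labels.unique())
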
TabPFN/priors/mlp.py ADDED
@@ -0,0 +1,173 @@
1
+ import random
2
+ import math
3
+
4
+ import torch
5
+ from torch import nn
6
+ import numpy as np
7
+
8
+ from utils import default_device
9
+ from .utils import get_batch_to_dataloader
10
+
11
+ class GaussianNoise(nn.Module):
12
+ def __init__(self, std, device):
13
+ super().__init__()
14
+ self.std = std
15
+ self.device=device
16
+
17
+ def forward(self, x):
18
+ return x + torch.normal(torch.zeros_like(x), self.std)
19
+
20
+
21
+ def causes_sampler_f(num_causes):
22
+ means = np.random.normal(0, 1, (num_causes))
23
+ std = np.abs(np.random.normal(0, 1, (num_causes)) * means)
24
+ return means, std
25
+
26
+ def get_batch(batch_size, seq_len, num_features, hyperparameters, device=default_device, num_outputs=1, sampling='normal', **kwargs):
27
+ if ('mix_activations' in hyperparameters) and hyperparameters['mix_activations']:
28
+ s = hyperparameters['prior_mlp_activations']()
29
+ hyperparameters['prior_mlp_activations'] = lambda : s
30
+
31
+ class MLP(torch.nn.Module):
32
+ def __init__(self, hyperparameters):
33
+ super(MLP, self).__init__()
34
+
35
+ with torch.no_grad():
36
+
37
+ for key in hyperparameters:
38
+ setattr(self, key, hyperparameters[key])
39
+
40
+ assert (self.num_layers >= 2)
41
+
42
+ if 'verbose' in hyperparameters and self.verbose:
43
+ print({k : hyperparameters[k] for k in ['is_causal', 'num_causes', 'prior_mlp_hidden_dim'
44
+ , 'num_layers', 'noise_std', 'y_is_effect', 'pre_sample_weights', 'prior_mlp_dropout_prob'
45
+ , 'pre_sample_causes']})
46
+
47
+ if self.is_causal:
48
+ self.prior_mlp_hidden_dim = max(self.prior_mlp_hidden_dim, num_outputs + 2 * num_features)
49
+ else:
50
+ self.num_causes = num_features
51
+
52
+ # This means that the mean and standard deviation of each cause is determined in advance
53
+ if self.pre_sample_causes:
54
+ self.causes_mean, self.causes_std = causes_sampler_f(self.num_causes)
55
+ self.causes_mean = torch.tensor(self.causes_mean, device=device).unsqueeze(0).unsqueeze(0).tile(
56
+ (seq_len, 1, 1))
57
+ self.causes_std = torch.tensor(self.causes_std, device=device).unsqueeze(0).unsqueeze(0).tile(
58
+ (seq_len, 1, 1))
59
+
60
+ def generate_module(layer_idx, out_dim):
61
+ # Determine the std of each noise term at initialization, so that it is shared across runs
62
+ # torch.abs(torch.normal(torch.zeros((out_dim)), self.noise_std)) - Change std for each dimension?
63
+ noise = (GaussianNoise(torch.abs(torch.normal(torch.zeros(size=(1, out_dim), device=device), float(self.noise_std))), device=device)
64
+ if self.pre_sample_weights else GaussianNoise(float(self.noise_std), device=device))
65
+ return [
66
+ nn.Sequential(*[self.prior_mlp_activations()
67
+ , nn.Linear(self.prior_mlp_hidden_dim, out_dim)
68
+ , noise])
69
+ ]
70
+
71
+ self.layers = [nn.Linear(self.num_causes, self.prior_mlp_hidden_dim, device=device)]
72
+ self.layers += [module for layer_idx in range(self.num_layers-1) for module in generate_module(layer_idx, self.prior_mlp_hidden_dim)]
73
+ if not self.is_causal:
74
+ self.layers += generate_module(-1, num_outputs)
75
+ self.layers = nn.Sequential(*self.layers)
76
+
77
+ # Initialize Model parameters
78
+ for i, (n, p) in enumerate(self.layers.named_parameters()):
79
+ if self.block_wise_dropout:
80
+ if len(p.shape) == 2: # Only apply to weight matrices and not bias
81
+ nn.init.zeros_(p)
82
+ # TODO: N blocks should be a setting
83
+ n_blocks = random.randint(1, math.ceil(math.sqrt(min(p.shape[0], p.shape[1]))))
84
+ w, h = p.shape[0] // n_blocks, p.shape[1] // n_blocks
85
+ keep_prob = (n_blocks*w*h) / p.numel()
86
+ for block in range(0, n_blocks):
87
+ nn.init.normal_(p[w * block: w * (block+1), h * block: h * (block+1)], std=self.init_std / keep_prob**(1/2))
88
+ else:
89
+ if len(p.shape) == 2: # Only apply to weight matrices and not bias
90
+ dropout_prob = self.prior_mlp_dropout_prob if i > 0 else 0.0 # Don't apply dropout in first layer
91
+ dropout_prob = min(dropout_prob, 0.99)
92
+ nn.init.normal_(p, std=self.init_std / (1. - dropout_prob)**(1/2))
93
+ p *= torch.bernoulli(torch.zeros_like(p) + 1. - dropout_prob)
94
+
95
+ def forward(self):
96
+ def sample_normal():
97
+ if self.pre_sample_causes:
98
+ causes = torch.normal(self.causes_mean, self.causes_std.abs()).float()
99
+ else:
100
+ causes = torch.normal(0., 1., (seq_len, 1, self.num_causes), device=device).float()
101
+ return causes
102
+
103
+ if self.sampling == 'normal':
104
+ causes = sample_normal()
105
+ elif self.sampling == 'mixed':
106
+ zipf_p, multi_p, normal_p = random.random() * 0.66, random.random() * 0.66, random.random() * 0.66
107
+ def sample_cause(n):
108
+ if random.random() > normal_p:
109
+ if self.pre_sample_causes:
110
+ return torch.normal(self.causes_mean[:, :, n], self.causes_std[:, :, n].abs()).float()
111
+ else:
112
+ return torch.normal(0., 1., (seq_len, 1), device=device).float()
113
+ elif random.random() > multi_p:
114
+ x = torch.multinomial(torch.rand((random.randint(2, 10))), seq_len, replacement=True).to(device).unsqueeze(-1).float()
115
+ x = (x - torch.mean(x)) / torch.std(x)
116
+ return x
117
+ else:
118
+ x = torch.minimum(torch.tensor(np.random.zipf(2.0 + random.random() * 2, size=(seq_len)),
119
+ device=device).unsqueeze(-1).float(), torch.tensor(10.0, device=device))
120
+ return x - torch.mean(x)
121
+ causes = torch.cat([sample_cause(n).unsqueeze(-1) for n in range(self.num_causes)], -1)
122
+ elif self.sampling == 'uniform':
123
+ causes = torch.rand((seq_len, 1, self.num_causes), device=device)
124
+ else:
125
+ raise ValueError(f'Sampling is set to invalid setting: {self.sampling}.')
126
+
127
+ outputs = [causes]
128
+ for layer in self.layers:
129
+ outputs.append(layer(outputs[-1]))
130
+ outputs = outputs[2:]
131
+
132
+ if self.is_causal:
133
+ ## Sample nodes from graph if model is causal
134
+ outputs_flat = torch.cat(outputs, -1)
135
+
136
+ if self.in_clique:
137
+ random_perm = random.randint(0, outputs_flat.shape[-1] - num_outputs - num_features) + torch.randperm(num_outputs + num_features, device=device)
138
+ else:
139
+ random_perm = torch.randperm(outputs_flat.shape[-1]-1, device=device)
140
+
141
+ random_idx_y = list(range(-num_outputs, -0)) if self.y_is_effect else random_perm[0:num_outputs]
142
+ random_idx = random_perm[num_outputs:num_outputs + num_features]
143
+
144
+ if self.sort_features:
145
+ random_idx, _ = torch.sort(random_idx)
146
+ y = outputs_flat[:, :, random_idx_y]
147
+
148
+ x = outputs_flat[:, :, random_idx]
149
+ else:
150
+ y = outputs[-1][:, :, :]
151
+ x = causes
152
+
153
+ if bool(torch.any(torch.isnan(x)).detach().cpu().numpy()) or bool(torch.any(torch.isnan(y)).detach().cpu().numpy()):
154
+ x[:] = 0.0
155
+ y[:] = 1.0
156
+
157
+ return x, y
158
+
159
+ model = MLP(hyperparameters).to(device)
160
+
161
+ sample = sum([[model()] for _ in range(0, batch_size)], [])
162
+
163
+ x, y = zip(*sample)
164
+ y = torch.cat(y, 1).detach().squeeze(2)
165
+ x = torch.cat(x, 1).detach()
166
+ x = x[..., torch.randperm(x.shape[-1])]
167
+
168
+ return x, y, y
169
+
170
+
171
+ DataLoader = get_batch_to_dataloader(get_batch)
172
+ DataLoader.num_outputs = 1
173
+
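
A standalone sketch (illustrative only) of the block-wise initialization used in the MLP prior above: zero the weight matrix, then fill a random number of diagonal blocks with Gaussian weights rescaled by 1/sqrt(keep_prob) so the overall activation scale matches a dense init.

import math, random
import torch
from torch import nn

p = torch.empty(16, 12)
nn.init.zeros_(p)
init_std = 1.0
n_blocks = random.randint(1, math.ceil(math.sqrt(min(p.shape))))
w, h = p.shape[0] // n_blocks, p.shape[1] // n_blocks
keep_prob = (n_blocks * w * h) / p.numel()
for b in range(n_blocks):
    nn.init.normal_(p[w * b: w * (b + 1), h * b: h * (b + 1)], std=init_std / keep_prob ** 0.5)
print((p != 0).float().mean())  # fraction of non-zero entries, roughly keep_prob
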
TabPFN/priors/prior.py ADDED
@@ -0,0 +1,12 @@
1
+ from torch.utils.data import DataLoader
2
+
3
+
4
+ class PriorDataLoader(DataLoader):
5
+ pass
6
+ # init accepts num_steps as first argument
7
+
8
+ # has these attributes set on class or object level:
9
+ # num_features: int
10
+ # num_outputs: int
11
+ # fuse_x_y: bool
12
+ # Optional: validate function that accepts a transformer model
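
A minimal sketch (illustrative only, a plain Python iterable rather than the torch DataLoader subclass used here) of a loader satisfying the interface described above: num_steps comes first in __init__, the class exposes num_features, num_outputs and fuse_x_y, and iteration yields ((style, x, y), target) batches.

import torch

class ToyPriorDataLoader:
    num_features = 3
    num_outputs = 1
    fuse_x_y = False

    def __init__(self, num_steps, batch_size=4, seq_len=10):
        self.num_steps, self.batch_size, self.seq_len = num_steps, batch_size, seq_len

    def __len__(self):
        return self.num_steps

    def __iter__(self):
        for _ in range(self.num_steps):
            x = torch.randn(self.seq_len, self.batch_size, self.num_features)
            y = torch.randn(self.seq_len, self.batch_size)
            yield (None, x, y), y

for (style, x, y), target in ToyPriorDataLoader(num_steps=2):
    print(x.shape, target.shape)
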
TabPFN/priors/prior_bag.py ADDED
@@ -0,0 +1,32 @@
1
+ import torch
2
+
3
+ from .utils import get_batch_to_dataloader
4
+ from utils import default_device
5
+
6
+ def get_batch(batch_size, seq_len, num_features, device=default_device
7
+ , hyperparameters=None, batch_size_per_gp_sample=None, **kwargs):
8
+ batch_size_per_gp_sample = batch_size_per_gp_sample or (min(64, batch_size))
9
+ num_models = batch_size // batch_size_per_gp_sample
10
+ assert num_models * batch_size_per_gp_sample == batch_size, f'Batch size ({batch_size}) not divisible by batch_size_per_gp_sample ({batch_size_per_gp_sample})'
11
+
12
+ args = {'device': device, 'seq_len': seq_len, 'num_features': num_features, 'batch_size': batch_size_per_gp_sample}
13
+
14
+ prior_bag_priors_get_batch = hyperparameters['prior_bag_get_batch']
15
+ prior_bag_priors_p = [1.0] + [hyperparameters[f'prior_bag_exp_weights_{i}'] for i in range(1, len(prior_bag_priors_get_batch))]
16
+
17
+ weights = torch.tensor(prior_bag_priors_p, dtype=torch.float) # create a tensor of weights
18
+ batch_assignments = torch.multinomial(torch.softmax(weights, 0), num_models, replacement=True).numpy()
19
+
20
+ if 'verbose' in hyperparameters and hyperparameters['verbose']:
21
+ print('PRIOR_BAG:', weights, batch_assignments)
22
+
23
+ sample = sum([[prior_bag_priors_get_batch[int(prior_idx)](hyperparameters=hyperparameters, **args)] for prior_idx in batch_assignments], [])
24
+
25
+ x, y, y_ = zip(*sample)
26
+ x, y, y_ = (torch.cat(x, 1).detach()
27
+ , torch.cat(y, 1).detach()
28
+ , torch.cat(y_, 1).detach())
29
+ return x, y, y_
30
+
31
+ DataLoader = get_batch_to_dataloader(get_batch)
32
+ DataLoader.num_outputs = 1
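
A standalone sketch (illustrative only) of the mixture step in the prior bag above: one weight per prior, a softmax over the weights, and a multinomial draw assigning each sub-batch to one of the get_batch functions.

import torch

weights = torch.tensor([1.0, 2.0, 0.5])                    # first entry is fixed to 1.0 above
probs = torch.softmax(weights, 0)
batch_assignments = torch.multinomial(probs, 8, replacement=True)
print(probs, batch_assignments)                            # e.g. tensor([1, 1, 0, 2, ...])
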
TabPFN/priors/utils.py ADDED
@@ -0,0 +1,163 @@
1
+ import random
2
+
3
+ import torch
4
+
5
+ from utils import set_locals_in_self
6
+ from .prior import PriorDataLoader
7
+ from torch import nn
8
+ import numpy as np
9
+ import matplotlib.pyplot as plt
10
+ import matplotlib.gridspec as gridspec
11
+ import scipy.stats as stats
12
+ import math
13
+
14
+ def get_batch_to_dataloader(get_batch_method_):
15
+ class DL(PriorDataLoader):
16
+ get_batch_method = get_batch_method_
17
+
18
+ # Caution, you might need to set self.num_features manually if it is not part of the args.
19
+ def __init__(self, num_steps, fuse_x_y=False, **get_batch_kwargs):
20
+ set_locals_in_self(locals())
21
+ # The fallback after the 'or' is a class attribute set before instantiation.
22
+ self.num_features = get_batch_kwargs.get('num_features') or self.num_features
23
+ self.num_outputs = get_batch_kwargs.get('num_outputs') or self.num_outputs
24
+ print('DataLoader.__dict__', self.__dict__)
25
+
26
+ @staticmethod
27
+ def gbm(*args, fuse_x_y=True, **kwargs):
28
+ dynamic_seq_len = callable(kwargs['seq_len'])
29
+ kwargs['seq_len'] = kwargs['seq_len']() if dynamic_seq_len else kwargs['seq_len']
30
+ # Scales the batch size dynamically with the power of 'dynamic_batch_size'.
31
+ # A transformer with quadratic memory usage in the seq len would need a power of 2 to keep memory constant.
32
+ if dynamic_seq_len and 'dynamic_batch_size' in kwargs and kwargs['dynamic_batch_size'] > 0:
33
+ kwargs['batch_size'] = kwargs['batch_size'] * math.floor(math.pow(kwargs['seq_len_maximum'], kwargs['dynamic_batch_size']) / math.pow(kwargs['seq_len'], kwargs['dynamic_batch_size']))
34
+ batch = get_batch_method_(*args, **kwargs)
35
+ x, y, target_y, style = batch if len(batch) == 4 else (batch[0], batch[1], batch[2], None)
36
+ if fuse_x_y:
37
+ return torch.cat([x, torch.cat([torch.zeros_like(y[:1]), y[:-1]], 0).unsqueeze(-1).float()],
38
+ -1), target_y
39
+ else:
40
+ return (style, x, y), target_y
41
+
42
+ def __len__(self):
43
+ return self.num_steps
44
+
45
+ def __iter__(self):
46
+ return iter(self.gbm(**self.get_batch_kwargs, fuse_x_y=self.fuse_x_y) for _ in range(self.num_steps))
47
+
48
+
49
+ return DL
50
+
51
+ import seaborn as sns
52
+ def plot_features(data, targets, fig=None):
53
+ if torch.is_tensor(data):
54
+ data = data.detach().cpu().numpy()
55
+ targets = targets.detach().cpu().numpy()
56
+ #data = np.concatenate([data, data[:, -1:]], -1)
57
+ #df = pd.DataFrame(data, columns=list(range(0, data.shape[1])))
58
+ #g = sns.pairplot(df, hue=data.shape[1]-1, palette="Set2", diag_kind="kde", height=2.5)
59
+ #plt.legend([], [], frameon=False)
60
+ #g._legend.remove()
61
+ #g = sns.PairGrid(df, hue=data.shape[1]-1)
62
+ #g.map_diag(sns.histplot)
63
+ #g.map_offdiag(sns.scatterplot)
64
+ #g._legend.remove()
65
+
66
+ fig2 = fig if fig else plt.figure(figsize=(8, 8))
67
+ spec2 = gridspec.GridSpec(ncols=data.shape[1], nrows=data.shape[1], figure=fig2)
68
+ for d in range(0, data.shape[1]):
69
+ for d2 in range(0, data.shape[1]):
70
+ sub_ax = fig2.add_subplot(spec2[d, d2])
71
+ if d == d2:
72
+ sns.kdeplot(data[:, d],hue=targets[:],ax=sub_ax,legend=False, palette="deep")
73
+ sub_ax.set(ylabel=None)
74
+ else:
75
+ sns.scatterplot(x=data[:, d], y=data[:, d2],
76
+ hue=targets[:],legend=False, palette="deep")
77
+ #plt.scatter(data[:, d], data[:, d2],
78
+ # c=targets[:])
79
+ sub_ax.get_xaxis().set_ticks([])
80
+ sub_ax.get_yaxis().set_ticks([])
81
+ plt.subplots_adjust(wspace=0.05, hspace=0.05)
82
+ fig2.show()
83
+
84
+
85
+ def plot_prior(prior):
86
+ s = np.array([prior() for _ in range(0, 1000)])
87
+ count, bins, ignored = plt.hist(s, 50, density=True)
88
+ print(s.min())
89
+ plt.show()
90
+
91
+ trunc_norm_sampler_f = lambda mu, sigma : lambda: stats.truncnorm((0 - mu) / sigma, (1000000 - mu) / sigma, loc=mu, scale=sigma).rvs(1)[0]
92
+ beta_sampler_f = lambda a, b : lambda : np.random.beta(a, b)
93
+ gamma_sampler_f = lambda a, b : lambda : np.random.gamma(a, b)
94
+ uniform_sampler_f = lambda a, b : lambda : np.random.uniform(a, b)
95
+ uniform_int_sampler_f = lambda a, b : lambda : round(np.random.uniform(a, b))
96
+ def zipf_sampler_f(a, b, c):
97
+ x = np.arange(b, c)
98
+ weights = x ** (-a)
99
+ weights /= weights.sum()
100
+ return lambda : stats.rv_discrete(name='bounded_zipf', values=(x, weights)).rvs(1)
101
+ scaled_beta_sampler_f = lambda a, b, scale, minimum : lambda : minimum + round(beta_sampler_f(a, b)() * (scale - minimum))
102
+
103
+
104
+ def normalize_by_used_features_f(x, num_features_used, num_features, normalize_with_sqrt=False):
105
+ if normalize_with_sqrt:
106
+ return x / (num_features_used / num_features)**(1 / 2)
107
+ return x / (num_features_used / num_features)
108
+
109
+
110
+ def order_by_y(x, y):
111
+ order = torch.argsort(y if random.randint(0, 1) else -y, dim=0)[:, 0, 0]
112
+ order = order.reshape(2, -1).transpose(0, 1).reshape(-1)#.reshape(seq_len)
113
+ x = x[order] # .reshape(2, -1).transpose(0, 1).reshape(-1).flip([0]).reshape(seq_len, 1, -1)
114
+ y = y[order] # .reshape(2, -1).transpose(0, 1).reshape(-1).reshape(seq_len, 1, -1)
115
+
116
+ return x, y
117
+
118
+ def randomize_classes(x, num_classes):
119
+ classes = torch.arange(0, num_classes, device=x.device)
120
+ random_classes = torch.randperm(num_classes, device=x.device).type(x.type())
121
+ x = ((x.unsqueeze(-1) == classes) * random_classes).sum(-1)
122
+ return x
123
+
124
+
125
+ class CategoricalActivation(nn.Module):
126
+ def __init__(self, categorical_p=0.1, ordered_p=0.7
127
+ , keep_activation_size=False
128
+ , num_classes_sampler=zipf_sampler_f(0.8, 1, 10)):
129
+ self.categorical_p = categorical_p
130
+ self.ordered_p = ordered_p
131
+ self.keep_activation_size = keep_activation_size
132
+ self.num_classes_sampler = num_classes_sampler
133
+
134
+ super().__init__()
135
+
136
+ def forward(self, x):
137
+ # x shape: T, B, H
138
+
139
+ x = nn.Softsign()(x)
140
+
141
+ num_classes = self.num_classes_sampler()
142
+ hid_strength = torch.abs(x).mean(0).unsqueeze(0) if self.keep_activation_size else None
143
+
144
+ categorical_classes = torch.rand((x.shape[1], x.shape[2])) < self.categorical_p
145
+ class_boundaries = torch.zeros((num_classes - 1, x.shape[1], x.shape[2]), device=x.device, dtype=x.dtype)
146
+ # Sample a different index for each hidden dimension, but shared for all batches
147
+ for b in range(x.shape[1]):
148
+ for h in range(x.shape[2]):
149
+ ind = torch.randint(0, x.shape[0], (num_classes - 1,))
150
+ class_boundaries[:, b, h] = x[ind, b, h]
151
+
152
+ for b in range(x.shape[1]):
153
+ x_rel = x[:, b, categorical_classes[b]]
154
+ boundaries_rel = class_boundaries[:, b, categorical_classes[b]].unsqueeze(1)
155
+ x[:, b, categorical_classes[b]] = (x_rel > boundaries_rel).sum(dim=0).float() - num_classes / 2
156
+
157
+ ordered_classes = torch.rand((x.shape[1],x.shape[2])) < self.ordered_p
158
+ ordered_classes = torch.logical_and(ordered_classes, categorical_classes)
159
+ x[:, ordered_classes] = randomize_classes(x[:, ordered_classes], num_classes)
160
+
161
+ x = x * hid_strength if self.keep_activation_size else x
162
+
163
+ return x
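
A standalone sketch (illustrative only) of normalize_by_used_features_f above: when only num_features_used of num_features columns carry signal and the rest are zero-padded, the used columns are scaled up so the overall input magnitude stays comparable.

import torch

num_features_used, num_features = 3, 10
x = torch.randn(50, 1, num_features_used)
x_scaled = x / (num_features_used / num_features)              # default variant
x_scaled_sqrt = x / (num_features_used / num_features) ** 0.5  # normalize_with_sqrt variant
x_padded = torch.cat([x_scaled, torch.zeros(50, 1, num_features - num_features_used)], -1)
print(x_padded.shape)  # torch.Size([50, 1, 10])
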
TabPFN/requirements.txt ADDED
@@ -0,0 +1,15 @@
1
+ # Please use Python 3.7 to be compatible with all packages
2
+ gpytorch==1.5.0
3
+ torch==1.9.0
4
+ scikit-learn==0.24.2
5
+ pyyaml==5.4.1
6
+ seaborn==0.11.2
7
+ xgboost==1.4.0
8
+ tqdm==4.62.1
9
+ numpy==1.21.2
10
+ openml==0.12.2
11
+ catboost==0.26.1
12
+ auto-sklearn==0.14.5
13
+ hyperopt==0.2.5
14
+ configspace==0.4.21
15
+ # autogluon==0.4.0
TabPFN/scripts/baseline_prediction_interface.py ADDED
@@ -0,0 +1,38 @@
1
+ import tqdm
2
+ import numpy as np
3
+
4
+ def baseline_predict(metric_function, eval_xs, eval_ys, categorical_feats, metric_used=None, eval_pos=2, max_time=300, **kwargs):
5
+ """
6
+ Baseline prediction interface.
7
+ :param metric_function:
8
+ :param eval_xs:
9
+ :param eval_ys:
10
+ :param categorical_feats:
11
+ :param metric_used:
12
+ :param eval_pos:
13
+ :param max_time: Scheduled maximum time
14
+ :param kwargs:
15
+ :return: list [np.array(metrics), np.array(outputs), best_configs] or [None, None, None] if failed
16
+ """
17
+
18
+ metrics = []
19
+ outputs = []
20
+ best_configs = []
21
+ eval_splits = list(zip(eval_xs.transpose(0, 1), eval_ys.transpose(0, 1)))
22
+ for eval_x, eval_y in tqdm.tqdm(eval_splits, desc='Calculating splits'+str(metric_function)+' '+str(eval_pos)):
23
+ try:
24
+ metric, output, best_config = metric_function(eval_x[:eval_pos],
25
+ eval_y[:eval_pos],
26
+ eval_x[eval_pos:],
27
+ eval_y[eval_pos:],
28
+ categorical_feats,
29
+ metric_used=metric_used
30
+ , max_time=max_time)
31
+ metrics += [metric]
32
+ outputs += [output]
33
+ best_configs += [best_config]
34
+ return np.array(metrics), np.array(outputs), best_configs
35
+ except Exception as e:
36
+ print(f'There was an exception in {metric_function}')
37
+ print(e)
38
+ return None, None, None
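
A usage sketch with a made-up metric_function (my_knn_metric is hypothetical, not part of the repo): baseline_predict expects a callable with this signature that returns (metric, outputs, best_config) for one train/test split.

import torch

def my_knn_metric(train_x, train_y, test_x, test_y, categorical_feats, metric_used=None, max_time=300):
    # toy 1-nearest-neighbour baseline standing in for a real method
    preds = train_y[torch.cdist(test_x, train_x).argmin(dim=1)]
    acc = (preds == test_y).float().mean().item()
    return acc, preds, {}

tr_x, tr_y = torch.randn(20, 4), torch.randint(0, 2, (20,))
te_x, te_y = torch.randn(10, 4), torch.randint(0, 2, (10,))
print(my_knn_metric(tr_x, tr_y, te_x, te_y, categorical_feats=[]))
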
TabPFN/scripts/differentiable_pfn_evaluation.py ADDED
@@ -0,0 +1,391 @@
1
+ import os
2
+ import torch
3
+ import numpy as np
4
+ import time
5
+ import pickle
6
+ from scripts import tabular_metrics
7
+ from scripts.tabular_metrics import calculate_score_per_method
8
+ from scripts.tabular_evaluation import evaluate
9
+ from priors.differentiable_prior import draw_random_style
10
+ from tqdm import tqdm
11
+ from pathlib import Path
12
+ import random
13
+ from model_builder import load_model
14
+ from scripts.transformer_prediction_interface import get_params_from_config
15
+
16
+ """
17
+ ===============================
18
+ PUBLIC FUNCTIONS FOR EVALUATION
19
+ ===============================
20
+ """
21
+
22
+
23
+ def eval_model_range(i_range, *args, **kwargs):
24
+ for i in i_range:
25
+ eval_model(i, *args, **kwargs)
26
+
27
+
28
+ def load_model_workflow(i, e, add_name, base_path, device='cpu', eval_addition=''):
29
+ """
30
+ Workflow for loading a model and setting appropriate parameters for diffable hparam tuning.
31
+
32
+ :param i:
33
+ :param e:
34
+ :param eval_positions_valid:
35
+ :param add_name:
36
+ :param base_path:
37
+ :param device:
38
+ :param eval_addition:
39
+ :return:
40
+ """
41
+ def check_file(e):
42
+ model_file = f'models_diff/prior_diff_real_checkpoint{add_name}_n_{i}_epoch_{e}.cpkt'
43
+ model_path = os.path.join(base_path, model_file)
44
+ # print('Evaluate ', model_path)
45
+ results_file = os.path.join(base_path,
46
+ f'models_diff/prior_diff_real_results{add_name}_n_{i}_epoch_{e}_{eval_addition}.pkl')
47
+ if not Path(model_path).is_file(): # or Path(results_file).is_file():
48
+ return None, None, None
49
+ return model_file, model_path, results_file
50
+
51
+ model_file = None
52
+ if e == -1:
53
+ for e_ in range(100, -1, -1):
54
+ model_file_, model_path_, results_file_ = check_file(e_)
55
+ if model_file_ is not None:
56
+ e = e_
57
+ model_file, model_path, results_file = model_file_, model_path_, results_file_
58
+ break
59
+ else:
60
+ model_file, model_path, results_file = check_file(e)
61
+
62
+ if model_file is None:
63
+ print('No checkpoint found')
64
+ return None
65
+
66
+ print(f'Loading {model_file}')
67
+
68
+ model, c = load_model(base_path, model_file, device, eval_positions=[], verbose=False)
69
+
70
+ return model, c, results_file
71
+
72
+
73
+ def eval_model(i, e, valid_datasets, test_datasets, train_datasets, eval_positions_valid, eval_positions_test,
74
+ bptt_valid,
75
+ bptt_test, add_name, base_path, device='cpu', eval_addition='', **extra_tuning_args):
76
+ """
77
+ Differentiable model evaluation workflow. Evaluates and saves results to disk.
78
+
79
+ :param i:
80
+ :param e:
81
+ :param valid_datasets:
82
+ :param test_datasets:
83
+ :param train_datasets:
84
+ :param eval_positions_valid:
85
+ :param eval_positions_test:
86
+ :param bptt_valid:
87
+ :param bptt_test:
88
+ :param add_name:
89
+ :param base_path:
90
+ :param device:
91
+ :param eval_addition:
92
+ :param extra_tuning_args:
93
+ :return:
94
+ """
95
+ model, c, results_file = load_model_workflow(i, e, add_name, base_path, device, eval_addition)
96
+ params = {'bptt': bptt_valid
97
+ , 'bptt_final': bptt_test
98
+ , 'eval_positions': eval_positions_valid
99
+ , 'eval_positions_test': eval_positions_test
100
+ , 'valid_datasets': valid_datasets
101
+ , 'test_datasets': test_datasets
102
+ , 'train_datasets': train_datasets
103
+ , 'verbose': True
104
+ , 'device': device
105
+ }
106
+
107
+ params.update(get_params_from_config(c))
108
+
109
+ start = time.time()
110
+ metrics, metrics_valid, style, temperature, optimization_route = evaluate_differentiable_model(model, **params,
111
+ **extra_tuning_args)
112
+ print('Evaluation time: ', time.time() - start)
113
+
114
+ print(results_file)
115
+ r = [c.copy(), metrics, metrics_valid, style.to('cpu'), temperature.to('cpu'), optimization_route]
116
+ with open(results_file, 'wb') as output:
117
+ del r[0]['num_features_used']
118
+ del r[0]['categorical_features_sampler']
119
+ pickle.dump(r, output)
120
+
121
+ _, _, _, style, temperature, _ = r
122
+
123
+ return r, model
124
+
125
+ """
126
+ ===============================
127
+ INTERNAL HELPER FUNCTIONS
128
+ ===============================
129
+ """
130
+
131
+ def evaluate_differentiable_model(model
132
+ , valid_datasets
133
+ , test_datasets
134
+ , train_datasets
135
+ , N_draws=100
136
+ , N_grad_steps=10
137
+ , eval_positions=None
138
+ , eval_positions_test=None
139
+ , bptt=100
140
+ , bptt_final=200
141
+ , style=None
142
+ , n_parallel_configurations=1
143
+ , device='cpu'
144
+ , selection_metric='auc'
145
+ , final_splits=[1, 2, 3, 4, 5]
146
+ , N_ensemble_configurations_list=[1, 5, 10, 20, 50, 100]
147
+ , **kwargs):
148
+ """
149
+ Evaluation function for diffable model evaluation. Returns a list of results.
150
+
151
+ :param model:
152
+ :param valid_datasets:
153
+ :param test_datasets:
154
+ :param train_datasets:
155
+ :param N_draws:
156
+ :param N_grad_steps:
157
+ :param eval_positions:
158
+ :param eval_positions_test:
159
+ :param bptt:
160
+ :param bptt_final:
161
+ :param style:
162
+ :param n_parallel_configurations:
163
+ :param device:
164
+ :param selection_metric:
165
+ :param final_splits:
166
+ :param N_ensemble_configurations_list:
167
+ :param kwargs:
168
+ :return:
169
+ """
170
+ torch.manual_seed(0)
171
+ np.random.seed(0)
172
+ random.seed(0)
173
+
174
+ diffable_metric = tabular_metrics.cross_entropy
175
+ evaluation_metric = tabular_metrics.auc_metric
176
+ if selection_metric in ('auc', 'roc'):
177
+ selection_metric_min_max = 'max'
178
+ selection_metric = tabular_metrics.auc_metric
179
+ evaluation_metric = selection_metric
180
+ elif selection_metric in ('ce', 'selection_metric'):
181
+ selection_metric_min_max = 'min'
182
+ selection_metric = tabular_metrics.cross_entropy
183
+ evaluation_metric = selection_metric
184
+
185
+ print('Diffable metric', diffable_metric, ' Selection metric', selection_metric, ' Evaluation metric',
186
+ evaluation_metric)
187
+ print('N PARALLEL CONFIGURATIONS', n_parallel_configurations)
188
+ print('eval_positions', eval_positions)
189
+
190
+ def evaluate_valid(style, softmax_temperature, results, results_tracked):
191
+ result_valid = eval_step(valid_datasets, style, softmax_temperature=softmax_temperature,
192
+ return_tensor=False, inference_mode=True, selection_metric=selection_metric,
193
+ evaluation_metric=evaluation_metric, eval_positions=eval_positions, bptt=bptt, model=model[2])
194
+ result_valid = [float(result_valid[f'mean_select_at_{pos}']) for pos in eval_positions]
195
+ results += [result_valid]
196
+ results_tracked += [np.nanmean(result_valid)]
197
+
198
+ model[2].to(device)
199
+ model[2].eval()
200
+
201
+ results_on_valid, results_on_valid_tracked = [], []
202
+ best_style, best_softmax_temperature = style, torch.cat(
203
+ [torch.tensor([0.0]).to(device) for n in range(0, n_parallel_configurations)], 0)
204
+ optimization_routes = []
205
+
206
+ best_style = torch.cat([draw_random_style(model[3], device).detach() for n in range(0, n_parallel_configurations)],
207
+ 0)
208
+ best_softmax_temperature = torch.cat([torch.tensor([0.0]).to(device) for n in range(0, n_parallel_configurations)],
209
+ 0)
210
+
211
+
212
+ for _ in tqdm(range(0, N_draws), desc='Iterate over Optimization initializations'): # Evaluates N hparam draws
213
+ style = torch.cat([draw_random_style(model[3], device).detach() for n in range(0, n_parallel_configurations)],
214
+ 0)
215
+ softmax_temperature = torch.cat([torch.tensor([0.0]).to(device) for n in range(0, n_parallel_configurations)],
216
+ 0)
217
+
218
+ evaluate_valid(style, softmax_temperature, results_on_valid, results_on_valid_tracked)
219
+
220
+ print(f'Draw --> Valid Selection metric: {results_on_valid[-1]}')
221
+
222
+ if N_grad_steps > 0:
223
+ gradient_optimize_result = gradient_optimize_style(model, style, N_grad_steps
224
+ , softmax_temperature=softmax_temperature
225
+ , model=model[2]
226
+ , train_datasets=train_datasets
227
+ , valid_datasets=valid_datasets
228
+ , selection_metric_min_max=selection_metric_min_max
229
+ , **kwargs)
230
+ optimization_routes += [gradient_optimize_result['optimization_route']]
231
+
232
+ evaluate_valid(gradient_optimize_result['best_style']
233
+ , gradient_optimize_result['best_temperature']
234
+ , results_on_valid, results_on_valid_tracked)
235
+
236
+ print(f'After diff --> Valid Selection metric: {results_on_valid[-1]}')
237
+
238
+ if selection_metric_min_max == 'min':
239
+ is_best = (results_on_valid_tracked[-1] <= min(results_on_valid_tracked))
240
+ else:
241
+ is_best = (results_on_valid_tracked[-1] >= max(results_on_valid_tracked))
242
+
243
+ if is_best or best_style is None:
244
+ best_style = gradient_optimize_result['best_style'].clone()
245
+ best_softmax_temperature = gradient_optimize_result['best_temperature'].clone()
246
+ torch.cuda.empty_cache()
247
+
248
+ def final_evaluation():
249
+ print('Running eval dataset with final params (no gradients)..')
250
+ print(best_style, best_softmax_temperature)
251
+ result_test = []
252
+ for N_ensemble_configurations in N_ensemble_configurations_list:
253
+ print(f'Running with {N_ensemble_configurations} ensemble_configurations')
254
+ kwargs['N_ensemble_configurations'] = N_ensemble_configurations
255
+ splits = []
256
+ for split in final_splits:
257
+ splits += [eval_step(test_datasets, best_style, softmax_temperature=best_softmax_temperature
258
+ , return_tensor=False, eval_positions=eval_positions_test,
259
+ bptt=bptt_final, inference_mode=True, split_number=split, model=model[2]
260
+ , selection_metric=selection_metric, evaluation_metric=evaluation_metric)]
261
+ result_test += [splits]
262
+
263
+ print('Running valid dataset with final params (no gradients)..')
264
+ result_valid = eval_step(valid_datasets, best_style, softmax_temperature=best_softmax_temperature
265
+ , return_tensor=False, eval_positions=eval_positions_test,
266
+ bptt=bptt_final, inference_mode=True, model=model[2]
267
+ , selection_metric=selection_metric, evaluation_metric=evaluation_metric)
268
+
269
+ return result_test, result_valid
270
+
271
+ result_test, result_valid = final_evaluation()
272
+
273
+ return result_test, result_valid, best_style, best_softmax_temperature, optimization_routes
274
+
275
+
276
+ def eval_step(ds, used_style, selection_metric, evaluation_metric, eval_positions, return_tensor=True, **kwargs):
277
+ def step():
278
+ return evaluate(datasets=ds,
279
+ method='transformer'
280
+ , overwrite=True
281
+ , style=used_style
282
+ , eval_positions=eval_positions
283
+ , metric_used=selection_metric
284
+ , save=False
285
+ , path_interfix=None
286
+ , base_path=None
287
+ , verbose=True
288
+ , **kwargs)
289
+
290
+ if return_tensor:
291
+ r = step()
292
+ else:
293
+ with torch.no_grad():
294
+ r = step()
295
+
296
+ calculate_score_per_method(selection_metric, 'select', r, ds, eval_positions, aggregator='mean')
297
+ calculate_score_per_method(evaluation_metric, 'eval', r, ds, eval_positions, aggregator='mean')
298
+
299
+ return r
300
+
301
+
302
+ def gradient_optimize_style(model, init_style, steps, softmax_temperature, train_datasets, valid_datasets, learning_rate=0.03, optimize_all=False,
303
+ limit_style=True, N_datasets_sampled=90, optimize_softmax_temperature=True, selection_metric_min_max='max', **kwargs):
304
+ """
305
+ Uses gradient based methods to optimize 'style' on the 'train_datasets' and uses stopping with 'valid_datasets'.
306
+
307
+ :param model:
308
+ :param init_style:
309
+ :param steps:
310
+ :param learning_rate:
311
+ :param softmax_temperature:
312
+ :param train_datasets:
313
+ :param valid_datasets:
314
+ :param optimize_all:
315
+ :param limit_style:
316
+ :param N_datasets_sampled:
317
+ :param optimize_softmax_temperature:
318
+ :param selection_metric_min_max:
319
+ :param kwargs:
320
+ :return:
321
+ """
322
+ grad_style = torch.nn.Parameter(init_style.detach(), requires_grad=True)
323
+
324
+ best_style, best_temperature, best_selection_metric, best_diffable_metric = grad_style.detach(), softmax_temperature.detach(), None, None
325
+ softmax_temperature = torch.nn.Parameter(softmax_temperature.detach(), requires_grad=optimize_softmax_temperature)
326
+ variables_to_optimize = model[2].parameters() if optimize_all else [grad_style, softmax_temperature]
327
+ optimizer = torch.optim.Adam(variables_to_optimize, lr=learning_rate)
328
+
329
+ optimization_route_selection, optimization_route_diffable = [], []
330
+ optimization_route_selection_valid, optimization_route_diffable_valid = [], []
331
+
332
+ def eval_opt(ds, return_tensor=True, inference_mode=False):
333
+ result = eval_step(ds, grad_style, softmax_temperature=softmax_temperature, return_tensor=return_tensor
334
+ , inference_mode=inference_mode, model=model[2], **kwargs)
335
+
336
+ diffable_metric = result['mean_metric']
337
+ selection_metric = result['mean_select']
338
+
339
+ return diffable_metric, selection_metric
340
+
341
+ def eval_all_datasets(datasets, propagate=True):
342
+ selection_metrics_this_step, diffable_metrics_this_step = [], []
343
+ for ds in datasets:
344
+ diffable_metric_train, selection_metric_train = eval_opt([ds], inference_mode=(not propagate))
345
+ if not torch.isnan(diffable_metric_train).any():
346
+ if propagate and diffable_metric_train.requires_grad == True:
347
+ diffable_metric_train.backward()
348
+ selection_metrics_this_step += [selection_metric_train]
349
+ diffable_metrics_this_step += [float(diffable_metric_train.detach().cpu().numpy())]
350
+ diffable_metric_train = np.nanmean(diffable_metrics_this_step)
351
+ selection_metric_train = np.nanmean(selection_metrics_this_step)
352
+
353
+ return diffable_metric_train, selection_metric_train
354
+
355
+ for t in tqdm(range(steps), desc='Iterate over Optimization steps'):
356
+ optimizer.zero_grad()
357
+
358
+ # Select subset of datasets
359
+ random.seed(t)
360
+ train_datasets_ = random.sample(train_datasets, N_datasets_sampled)
361
+
362
+ # Get score on train
363
+ diffable_metric_train, selection_metric_train = eval_all_datasets(train_datasets_, propagate=True)
364
+ optimization_route_selection += [float(selection_metric_train)]
365
+ optimization_route_diffable += [float(diffable_metric_train)]
366
+
367
+ # Get score on valid
368
+ diffable_metric_valid, selection_metric_valid = eval_all_datasets(valid_datasets, propagate=False)
369
+ optimization_route_selection_valid += [float(selection_metric_valid)]
370
+ optimization_route_diffable_valid += [float(diffable_metric_valid)]
371
+
372
+ is_best = (best_selection_metric is not None) and (selection_metric_min_max == 'min') and (best_selection_metric > selection_metric_valid)
373
+ is_best = is_best or ((best_selection_metric is not None) and (selection_metric_min_max == 'max') and (best_selection_metric < selection_metric_valid))
374
+ if (best_selection_metric is None) or (not np.isnan(selection_metric_valid) and is_best):
375
+ print('New best', best_selection_metric, selection_metric_valid)
376
+ best_style = grad_style.detach().clone()
377
+ best_temperature = softmax_temperature.detach().clone()
378
+ best_selection_metric, best_diffable_metric = selection_metric_valid, diffable_metric_valid
379
+
380
+ optimizer.step()
381
+
382
+ if limit_style:
383
+ grad_style = grad_style.detach().clamp(-1.74, 1.74)
384
+
385
+ print(f'Valid: Diffable metric={diffable_metric_valid} Selection metric={selection_metric_valid}; ' +
386
+ f'Train: Diffable metric={diffable_metric_train} Selection metric={selection_metric_train}')
387
+
388
+ print(f'Return best:{best_style} {best_selection_metric}')
389
+ return {'best_style': best_style, 'best_temperature': best_temperature
390
+ , 'optimization_route': {'select': optimization_route_selection, 'loss': optimization_route_diffable,
391
+ 'test_select': optimization_route_selection_valid, 'test_loss': optimization_route_diffable_valid}}
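A minimal usage sketch of gradient_optimize_style, assuming `model` is the usual triple whose third element (`model[2]`) is the trained transformer, that `train_datasets`/`valid_datasets` are lists in the dataset format expected by eval_step, and that any further eval_step keyword arguments are forwarded through **kwargs; the style width of 1 is only a placeholder:

import torch
init_style = torch.zeros(1, 1)  # placeholder style width, depends on the trained model
softmax_temperature = torch.log(torch.tensor(0.8))
result = gradient_optimize_style(model, init_style, steps=50,
                                 softmax_temperature=softmax_temperature,
                                 train_datasets=train_datasets,
                                 valid_datasets=valid_datasets,
                                 learning_rate=0.03,
                                 N_datasets_sampled=min(90, len(train_datasets)))
best_style, best_temperature = result['best_style'], result['best_temperature']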
TabPFN/scripts/model_configs.py ADDED
@@ -0,0 +1,210 @@
1
+ from copy import deepcopy
2
+ from priors.utils import uniform_int_sampler_f
3
+ from priors.differentiable_prior import DifferentiableHyperparameter
4
+ from ConfigSpace import hyperparameters as CSH
5
+ import torch
6
+ from priors.differentiable_prior import replace_differentiable_distributions
7
+
8
+ import ConfigSpace as CS
9
+
10
+ def get_general_config(max_features, bptt, eval_positions=None):
11
+ """
12
+ Returns the general PFN training hyperparameters.
13
+ """
14
+ config_general = {
15
+ "lr": CSH.UniformFloatHyperparameter('lr', lower=0.00002, upper=0.0002, log=True),
16
+ "dropout": CSH.CategoricalHyperparameter('dropout', [0.0]),
17
+ "emsize": CSH.CategoricalHyperparameter('emsize', [2 ** i for i in range(8, 9)]), # range upper bound is exclusive, so this samples only 2**8 = 256
18
+ "batch_size": CSH.CategoricalHyperparameter('batch_size', [2 ** i for i in range(8, 9)]),
19
+ "nlayers": CSH.CategoricalHyperparameter('nlayers', [12]),
20
+ "num_features": max_features,
21
+ "nhead": CSH.CategoricalHyperparameter('nhead', [4]),
22
+ "nhid_factor": 2,
23
+ "bptt": bptt,
24
+ "eval_positions": None,
25
+ "seq_len_used": bptt,
26
+ "sampling": 'normal',#hp.choice('sampling', ['mixed', 'normal']), # uniform
27
+ "epochs": 80,
28
+ "num_steps": 100,
29
+ "verbose": False,
30
+ "pre_sample_causes": True, # This is MLP
31
+ "mix_activations": False,#hp.choice('mix_activations', [True, False]),
32
+ }
33
+
34
+ return config_general
35
+
36
+ def get_flexible_categorical_config(max_features):
37
+ """
38
+ Returns the configuration parameters for the tabular multiclass wrapper.
39
+ """
40
+ config_flexible_categorical = {
41
+ "nan_prob_unknown_reason_reason_prior": CSH.CategoricalHyperparameter('nan_prob_unknown_reason_reason_prior', [1.0]),
42
+ "categorical_feature_p": CSH.CategoricalHyperparameter('categorical_feature_p', [0.0]),
43
+ "nan_prob_no_reason": CSH.CategoricalHyperparameter('nan_prob_no_reason', [0.0, 0.1, 0.2]),
44
+ "nan_prob_unknown_reason": CSH.CategoricalHyperparameter('nan_prob_unknown_reason', [0.0]),
45
+ "nan_prob_a_reason": CSH.CategoricalHyperparameter('nan_prob_a_reason', [0.0]),
46
+ # "num_classes": lambda : random.randint(2, 10), "balanced": False,
47
+ "max_num_classes": 2,
48
+ "num_classes": 2,
49
+ "noise_type": CSH.CategoricalHyperparameter('noise_type', ["Gaussian"]), # NN
50
+ "balanced": True,
51
+ "normalize_to_ranking": CSH.CategoricalHyperparameter('normalize_to_ranking', [False]),
52
+ "set_value_to_nan": CSH.CategoricalHyperparameter('set_value_to_nan', [0.5, 0.2, 0.0]),
53
+ "normalize_by_used_features": True,
54
+ "num_features_used":
55
+ {'uniform_int_sampler_f(3,max_features)': uniform_int_sampler_f(1, max_features)}
56
+ # hp.choice('conv_activation', [{'distribution': 'uniform', 'min': 2.0, 'max': 8.0}, None]),
57
+ }
58
+ return config_flexible_categorical
59
+
60
+ def get_diff_flex():
61
+ """
62
+ Returns the configuration parameters for a differentiable wrapper around the tabular multiclass wrapper.
63
+ """
64
+ diff_flex = {
65
+ # "ordinal_pct": {'distribution': 'uniform', 'min': 0.0, 'max': 0.5},
66
+ # "num_categorical_features_sampler_a": hp.choice('num_categorical_features_sampler_a',
67
+ # [{'distribution': 'uniform', 'min': 0.3, 'max': 0.9}, None]),
68
+ # "num_categorical_features_sampler_b": {'distribution': 'uniform', 'min': 0.3, 'max': 0.9},
69
+ "output_multiclass_ordered_p": {'distribution': 'uniform', 'min': 0.0, 'max': 0.5}, #CSH.CategoricalHyperparameter('output_multiclass_ordered_p', [0.0, 0.1, 0.2]),
70
+ "multiclass_type": {'distribution': 'meta_choice', 'choice_values': ['value', 'rank']},
71
+ }
72
+
73
+ return diff_flex
74
+
75
+ def get_diff_gp():
76
+ """
77
+ Returns the configuration parameters for a differentiable wrapper around GP.
78
+ """
79
+ diff_gp = {
80
+ 'outputscale': {'distribution': 'meta_trunc_norm_log_scaled', 'max_mean': 10., 'min_mean': 0.00001, 'round': False,
81
+ 'lower_bound': 0},
82
+ 'lengthscale': {'distribution': 'meta_trunc_norm_log_scaled', 'max_mean': 10., 'min_mean': 0.00001, 'round': False,
83
+ 'lower_bound': 0},
84
+ 'noise': {'distribution': 'meta_choice', 'choice_values': [0.00001, 0.0001, 0.01]}
85
+ }
86
+
87
+ return diff_gp
88
+
89
+ def get_diff_causal():
90
+ """
91
+ Returns the configuration parameters for a differentiable wrapper around MLP / Causal mixture.
92
+ """
93
+ diff_causal = {
94
+ "num_layers": {'distribution': 'meta_trunc_norm_log_scaled', 'max_mean': 6, 'min_mean': 1, 'round': True,
95
+ 'lower_bound': 2},
96
+ # Better beta?
97
+ "prior_mlp_hidden_dim": {'distribution': 'meta_trunc_norm_log_scaled', 'max_mean': 130, 'min_mean': 5,
98
+ 'round': True, 'lower_bound': 4},
99
+
100
+ "prior_mlp_dropout_prob": {'distribution': 'meta_beta', 'scale': 0.9, 'min': 0.1, 'max': 5.0},
101
+ # This mustn't be too high since activations get too large otherwise
102
+
103
+ "noise_std": {'distribution': 'meta_trunc_norm_log_scaled', 'max_mean': .3, 'min_mean': 0.0001, 'round': False,
104
+ 'lower_bound': 0.0},
105
+ "init_std": {'distribution': 'meta_trunc_norm_log_scaled', 'max_mean': 10.0, 'min_mean': 0.01, 'round': False,
106
+ 'lower_bound': 0.0},
107
+ "num_causes": {'distribution': 'meta_trunc_norm_log_scaled', 'max_mean': 12, 'min_mean': 1, 'round': True,
108
+ 'lower_bound': 1},
109
+ "is_causal": {'distribution': 'meta_choice', 'choice_values': [True, False]},
110
+ "pre_sample_weights": {'distribution': 'meta_choice', 'choice_values': [True, False]},
111
+ "y_is_effect": {'distribution': 'meta_choice', 'choice_values': [True, False]},
112
+ "prior_mlp_activations": {'distribution': 'meta_choice_mixed', 'choice_values': [
113
+ torch.nn.Tanh
114
+ , torch.nn.ReLU
115
+ , torch.nn.Identity
116
+ , lambda : torch.nn.LeakyReLU(negative_slope=0.1)
117
+ , torch.nn.ELU
118
+ ]},
119
+ "block_wise_dropout": {'distribution': 'meta_choice', 'choice_values': [True, False]},
120
+ "sort_features": {'distribution': 'meta_choice', 'choice_values': [True, False]},
121
+ "in_clique": {'distribution': 'meta_choice', 'choice_values': [True, False]},
122
+ }
123
+
124
+ return diff_causal
125
+
126
+ def get_diff_prior_bag():
127
+ """
128
+ Returns the configuration parameters for a GP and MLP / Causal mixture.
129
+ """
130
+ diff_prior_bag = {
131
+ 'prior_bag_exp_weights_1': {'distribution': 'uniform', 'min': 100000., 'max': 100001.},
132
+ # MLP Weight (Biased, since MLP works better, 1.0 is weight for prior number 0)
133
+ }
134
+
135
+ return diff_prior_bag
136
+
137
+ def get_diff_config():
138
+ """
139
+ Returns the configuration parameters for a differentiable wrapper around GP and MLP / Causal mixture priors.
140
+ """
141
+ diff_prior_bag = get_diff_prior_bag()
142
+ diff_causal = get_diff_causal()
143
+ diff_gp = get_diff_gp()
144
+ diff_flex = get_diff_flex()
145
+
146
+ config_diff = {'differentiable_hyperparameters': {**diff_prior_bag, **diff_causal, **diff_gp, **diff_flex}}
147
+
148
+ return config_diff
149
+
150
+
151
+ def sample_differentiable(config):
152
+ """
153
+ Returns sampled hyperparameters from a differentiable wrapper, that is it makes a non-differentiable out of
154
+ differentiable.
155
+ """
156
+ # config is a dict of dicts, dicts that have a 'distribution' key are treated as distributions to be sampled
157
+ result = deepcopy(config)
158
+ del result['differentiable_hyperparameters']
159
+
160
+ for k, v in config['differentiable_hyperparameters'].items():
161
+ s_indicator, s_hp = DifferentiableHyperparameter(**v, embedding_dim=None,
162
+ device=None)() # both of these are actually not used to the best of my knowledge
163
+ result[k] = s_hp
164
+
165
+ return result
166
+
167
+ def list_all_hps_in_nested(config):
168
+ """
169
+ Returns a list of hyperparameters from a nested dict of hyperparameters.
170
+ """
171
+
172
+ if isinstance(config, CSH.Hyperparameter):
173
+ return [config]
174
+ elif isinstance(config, dict):
175
+ result = []
176
+ for k, v in config.items():
177
+ result += list_all_hps_in_nested(v)
178
+ return result
179
+ else:
180
+ return []
181
+
182
+ def create_configspace_from_hierarchical(config):
183
+ cs = CS.ConfigurationSpace()
184
+ for hp in list_all_hps_in_nested(config):
185
+ cs.add_hyperparameter(hp)
186
+ return cs
187
+
188
+ def fill_in_configsample(config, configsample):
189
+ # config is our dict that defines config distribution
190
+ # configsample is a CS.Configuration
191
+ hierarchical_configsample = deepcopy(config)
192
+ for k, v in config.items():
193
+ if isinstance(v, CSH.Hyperparameter):
194
+ hierarchical_configsample[k] = configsample[v.name]
195
+ elif isinstance(v, dict):
196
+ hierarchical_configsample[k] = fill_in_configsample(v, configsample)
197
+ return hierarchical_configsample
198
+
199
+
200
+ def evaluate_hypers(config, sample_diff_hps=False):
201
+ """
202
+ Samples a hyperparameter configuration from a sampleable configuration (can be used in HP search).
203
+ """
204
+ if sample_diff_hps:
205
+ # I do a deepcopy here, such that the config stays the same and can still be used with diff. hps
206
+ config = deepcopy(config)
207
+ replace_differentiable_distributions(config)
208
+ cs = create_configspace_from_hierarchical(config)
209
+ cs_sample = cs.sample_configuration()
210
+ return fill_in_configsample(config, cs_sample)
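A minimal sketch of how the helpers above compose into one sampleable configuration; the exact composition used for training lives elsewhere in the repository, so this is only an illustration with placeholder values for max_features and bptt:

config = {**get_general_config(max_features=100, bptt=1024),
          **get_flexible_categorical_config(max_features=100),
          **get_diff_config()}
config_sample = evaluate_hypers(config)  # resolves the ConfigSpace hyperparameters to concrete values
print(config_sample['lr'], config_sample['nlayers'], config_sample['batch_size'])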
TabPFN/scripts/tabular_baselines.py ADDED
@@ -0,0 +1,421 @@
1
+ from catboost import CatBoostClassifier, Pool
2
+
3
+ import math
4
+
5
+ from sklearn.impute import SimpleImputer
6
+
7
+ import xgboost as xgb
8
+ from sklearn import neighbors
9
+ from sklearn.gaussian_process import GaussianProcessClassifier
10
+ from sklearn.gaussian_process.kernels import RBF
11
+ import numpy as np
12
+
13
+ from scripts import tabular_metrics
14
+ import pandas as pd
15
+
16
+ from sklearn.linear_model import LogisticRegression
17
+ from sklearn.model_selection import cross_val_score
18
+ import time
19
+
20
+ from hyperopt import fmin, tpe, hp, STATUS_OK, Trials , space_eval, rand
21
+ from sklearn.compose import ColumnTransformer
22
+ from sklearn.preprocessing import OneHotEncoder
23
+ from sklearn.preprocessing import MinMaxScaler
24
+
25
+ import autosklearn.classification
26
+
27
+ CV = 5
28
+ MULTITHREAD = 1 # Number of threads baselines are able to use at most
29
+ param_grid, param_grid_hyperopt = {}, {}
30
+
31
+ def get_scoring_direction(metric_used):
32
+ # Not needed
33
+ if metric_used == tabular_metrics.auc_metric:
34
+ return -1
35
+ elif metric_used == tabular_metrics.cross_entropy:
36
+ return 1
37
+ else:
38
+ raise Exception('No scoring string found for metric')
39
+
40
+ def get_scoring_string(metric_used, multiclass=True, usage="sklearn_cv"):
41
+ if metric_used == tabular_metrics.auc_metric:
42
+ if usage == 'sklearn_cv':
43
+ return 'roc_auc_ovo'
44
+ elif usage == 'autogluon':
45
+ return 'log_loss' # AutoGluon crashes when using 'roc_auc' on some datasets; using log_loss gives better scores.
46
+ # We might be able to fix this, but it doesn't work out of the box.
47
+ # File bug report? Error happens with dataset robert and fabert
48
+ if multiclass:
49
+ return 'roc_auc_ovo_macro'
50
+ else:
51
+ return 'roc_auc'
52
+ elif usage == 'autosklearn':
53
+ if multiclass:
54
+ return autosklearn.metrics.log_loss # roc_auc only works for binary, use logloss instead
55
+ else:
56
+ return autosklearn.metrics.roc_auc
57
+ elif usage == 'catboost':
58
+ return 'MultiClass' # Effectively LogLoss, ROC not available
59
+ elif usage == 'xgb':
60
+ return 'logloss'
61
+ return 'roc_auc'
62
+ elif metric_used == tabular_metrics.cross_entropy:
63
+ if usage == 'sklearn_cv':
64
+ return 'neg_log_loss'
65
+ elif usage == 'autogluon':
66
+ return 'log_loss'
67
+ elif usage == 'autosklearn':
68
+ return autosklearn.metrics.log_loss
69
+ elif usage == 'catboost':
70
+ return 'MultiClass' # Effectively LogLoss
71
+ return 'logloss'
72
+ else:
73
+ raise Exception('No scoring string found for metric')
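For illustration, a few mappings this function produces (values follow the branches above):

get_scoring_string(tabular_metrics.cross_entropy, usage='sklearn_cv')  # 'neg_log_loss'
get_scoring_string(tabular_metrics.cross_entropy, usage='catboost')    # 'MultiClass'
get_scoring_string(tabular_metrics.auc_metric, usage='sklearn_cv')     # 'roc_auc_ovo'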
74
+
75
+ def eval_f(params, clf_, x, y, metric_used, start_time, max_time):
76
+ if time.time() - start_time > max_time:
77
+ return np.nan
78
+ scores = cross_val_score(clf_(**params), x, y, cv=CV, scoring=get_scoring_string(metric_used))
79
+
80
+ return -np.nanmean(scores)
81
+
82
+ def preprocess_impute(x, y, test_x, test_y, impute, one_hot, standardize, cat_features=[]):
83
+ import warnings
84
+ def warn(*args, **kwargs):
85
+ pass
86
+
87
+ warnings.warn = warn
88
+
89
+ x, y, test_x, test_y = x.cpu().numpy(), y.cpu().long().numpy(), test_x.cpu().numpy(), test_y.cpu().long().numpy()
90
+
91
+ if impute:
92
+ imp_mean = SimpleImputer(missing_values=np.nan, strategy='mean')
93
+ imp_mean.fit(x)
94
+ x, test_x = imp_mean.transform(x), imp_mean.transform(test_x)
95
+
96
+ if one_hot:
97
+ def make_pd_from_np(x):
98
+ data = pd.DataFrame(x)
99
+ for c in cat_features:
100
+ data.iloc[:, c] = data.iloc[:, c].astype('int')
101
+ return data
102
+ x, test_x = make_pd_from_np(x), make_pd_from_np(test_x)
103
+ transformer = ColumnTransformer(transformers=[('cat', OneHotEncoder(handle_unknown='ignore', sparse=False), cat_features)], remainder="passthrough")
104
+ transformer.fit(x)
105
+ x, test_x = transformer.transform(x), transformer.transform(test_x)
106
+
107
+ if standardize:
108
+ scaler = MinMaxScaler()
109
+ scaler.fit(x)
110
+ x, test_x = scaler.transform(x), scaler.transform(test_x)
111
+
112
+ return x, y, test_x, test_y
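A small example of the expected input/output types (torch tensors in, numpy arrays out); the values are made up and the module's optional baseline dependencies are assumed to be installed:

import torch
x_train = torch.tensor([[0.0, 1.0], [1.0, 0.0], [0.5, 1.0]])
y_train = torch.tensor([0, 1, 0])
x_test, y_test = torch.tensor([[0.2, 0.0]]), torch.tensor([1])
x_np, y_np, test_x_np, test_y_np = preprocess_impute(
    x_train, y_train, x_test, y_test,
    impute=True, one_hot=False, standardize=True, cat_features=[])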
113
+
114
+ ## Auto Gluon
115
+ def autogluon_metric(x, y, test_x, test_y, cat_features, metric_used, max_time=300):
116
+ from autogluon.tabular import TabularPredictor # Imported inside the function so the rest of the module can be used without AutoGluon installed
117
+ x, y, test_x, test_y = preprocess_impute(x, y, test_x, test_y
118
+ , one_hot=False
119
+ , cat_features=cat_features
120
+ , impute=False
121
+ , standardize=False)
122
+ train_data = pd.DataFrame(np.concatenate([x, y[:, np.newaxis]], 1))
123
+ test_data = pd.DataFrame(np.concatenate([test_x, test_y[:, np.newaxis]], 1))
124
+
125
+ # AutoGluon automatically infers datatypes, we don't specify the categorical labels
126
+ predictor = TabularPredictor(
127
+ label=train_data.columns[-1],
128
+ eval_metric=get_scoring_string(metric_used, usage='autogluon', multiclass=(len(np.unique(y)) > 2)),
129
+ problem_type='multiclass' if len(np.unique(y)) > 2 else 'binary'
130
+ ## TabularPredictor does not accept a seed argument, so seed=int(y[:].sum()) cannot be passed here
131
+ ).fit(
132
+ train_data=train_data,
133
+ time_limit=max_time,
134
+ presets=['best_quality']
135
+ # The seed is deterministic but varies for each dataset and each split of it
136
+ )
137
+
138
+ pred = predictor.predict_proba(test_data, as_multiclass=True).values
139
+
140
+ metric = metric_used(test_y, pred)
141
+
142
+ return metric, pred, predictor.fit_summary()
143
+
144
+ ## AUTO Sklearn
145
+ def autosklearn_metric(x, y, test_x, test_y, cat_features, metric_used, max_time=300):
146
+ return autosklearn2_metric(x, y, test_x, test_y, cat_features, metric_used, max_time=max_time, version=1)
147
+
148
+ from autosklearn.experimental.askl2 import AutoSklearn2Classifier
149
+ from autosklearn.classification import AutoSklearnClassifier
150
+ def autosklearn2_metric(x, y, test_x, test_y, cat_features, metric_used, max_time=300, version=2):
151
+ x, y, test_x, test_y = preprocess_impute(x, y, test_x, test_y
152
+ , one_hot=False
153
+ , cat_features=cat_features
154
+ , impute=False
155
+ , standardize=False)
156
+
157
+ def make_pd_from_np(x):
158
+ data = pd.DataFrame(x)
159
+ for c in cat_features:
160
+ data.iloc[:, c] = data.iloc[:, c].astype('category')
161
+ return data
162
+
163
+ x = make_pd_from_np(x)
164
+ test_x = make_pd_from_np(test_x)
165
+
166
+ clf_ = AutoSklearn2Classifier if version == 2 else AutoSklearnClassifier
167
+ clf = clf_(time_left_for_this_task=max_time,
168
+ memory_limit=4000,
169
+ n_jobs=MULTITHREAD,
170
+ seed=int(y[:].sum()),
171
+ # The seed is deterministic but varies for each dataset and each split of it
172
+ metric=get_scoring_string(metric_used, usage='autosklearn', multiclass=len(np.unique(y)) > 2))
173
+
174
+ # fit model to data
175
+ clf.fit(x, y)
176
+
177
+ pred = clf.predict_proba(test_x)
178
+ metric = metric_used(test_y, pred)
179
+
180
+ return metric, pred, None
181
+
182
+ param_grid_hyperopt['logistic'] = {
183
+ 'penalty': hp.choice('penalty', ['l1', 'l2', 'none'])
184
+ , 'max_iter': hp.randint('max_iter', 50, 500)
185
+ , 'fit_intercept': hp.choice('fit_intercept', [True, False])
186
+ , 'C': hp.loguniform('C', -5, math.log(5.0))} # 'normalize': [False],
187
+
188
+ def logistic_metric(x, y, test_x, test_y, cat_features, metric_used, max_time=300):
189
+ x, y, test_x, test_y = preprocess_impute(x, y, test_x, test_y
190
+ , one_hot=True, impute=True, standardize=True
191
+ , cat_features=cat_features)
192
+
193
+ def clf_(**params):
194
+ return LogisticRegression(solver='saga', tol=1e-4, n_jobs=1, **params)
195
+
196
+ start_time = time.time()
197
+
198
+ def stop(trial):
199
+ return time.time() - start_time > max_time, []
200
+
201
+ best = fmin(
202
+ fn=lambda params: eval_f(params, clf_, x, y, metric_used, start_time, max_time),
203
+ space=param_grid_hyperopt['logistic'],
204
+ algo=rand.suggest,
205
+ rstate=np.random.RandomState(int(y[:].sum())),
206
+ early_stop_fn=stop,
207
+ # The seed is deterministic but varies for each dataset and each split of it
208
+ max_evals=10000)
209
+ best = space_eval(param_grid_hyperopt['logistic'], best)
210
+
211
+ clf = clf_(**best)
212
+ clf.fit(x, y)
213
+
214
+ pred = clf.predict_proba(test_x)
215
+ metric = metric_used(test_y, pred)
216
+
217
+ return metric, pred, best
218
+
219
+ ## KNN
220
+ param_grid_hyperopt['knn'] = {'n_neighbors': hp.randint('n_neighbors', 1,16)
221
+ }
222
+ def knn_metric(x, y, test_x, test_y, cat_features, metric_used, max_time=300):
223
+ x, y, test_x, test_y = preprocess_impute(x, y, test_x, test_y,
224
+ one_hot=True, impute=True, standardize=True,
225
+ cat_features=cat_features)
226
+
227
+ def clf_(**params):
228
+ return neighbors.KNeighborsClassifier(n_jobs=1, **params)
229
+
230
+ start_time = time.time()
231
+
232
+ def stop(trial):
233
+ return time.time() - start_time > max_time, []
234
+
235
+ best = fmin(
236
+ fn=lambda params: eval_f(params, clf_, x, y, metric_used, start_time, max_time),
237
+ space=param_grid_hyperopt['knn'],
238
+ algo=rand.suggest,
239
+ rstate=np.random.RandomState(int(y[:].sum())),
240
+ early_stop_fn=stop,
241
+ # The seed is deterministic but varies for each dataset and each split of it
242
+ max_evals=10000)
243
+ best = space_eval(param_grid_hyperopt['knn'], best)
244
+
245
+ clf = clf_(**best)
246
+ clf.fit(x, y)
247
+
248
+ pred = clf.predict_proba(test_x)
249
+ metric = metric_used(test_y, pred)
250
+
251
+ return metric, pred, best
252
+
253
+ ## GP
254
+ param_grid_hyperopt['gp'] = {
255
+ 'params_y_scale': hp.loguniform('params_y_scale', math.log(0.05), math.log(5.0)),
256
+ 'params_length_scale': hp.loguniform('params_length_scale', math.log(0.1), math.log(1.0)),
257
+ 'n_jobs': hp.choice('njobs', [1])
258
+ }
259
+ def gp_metric(x, y, test_x, test_y, cat_features, metric_used, max_time=300):
260
+ x, y, test_x, test_y = preprocess_impute(x, y, test_x, test_y,
261
+ one_hot=True, impute=True, standardize=True,
262
+ cat_features=cat_features)
263
+
264
+ def clf_(params_y_scale,params_length_scale, **params):
265
+ return GaussianProcessClassifier(kernel= params_y_scale * RBF(params_length_scale), **params)
266
+
267
+ start_time = time.time()
268
+ def stop(trial):
269
+ return time.time() - start_time > max_time, []
270
+
271
+
272
+ best = fmin(
273
+ fn=lambda params: eval_f(params, clf_, x, y, metric_used, start_time, max_time),
274
+ space=param_grid_hyperopt['gp'],
275
+ algo=rand.suggest,
276
+ rstate=np.random.RandomState(int(y[:].sum())),
277
+ early_stop_fn=stop,
278
+ # The seed is deterministic but varies for each dataset and each split of it
279
+ max_evals=1000)
280
+ best = space_eval(param_grid_hyperopt['gp'], best)
281
+
282
+ clf = clf_(**best)
283
+ clf.fit(x, y)
284
+
285
+ pred = clf.predict_proba(test_x)
286
+ metric = metric_used(test_y, pred)
287
+
288
+ return metric, pred, best
289
+
290
+
291
+ # Catboost
292
+ # Hyperparameter space: https://arxiv.org/pdf/2106.03253.pdf
293
+
294
+ param_grid_hyperopt['catboost'] = {
295
+ 'learning_rate': hp.loguniform('learning_rate', math.log(math.pow(math.e, -5)), math.log(1)),
296
+ 'random_strength': hp.randint('random_strength', 1, 20),
297
+ 'l2_leaf_reg': hp.loguniform('l2_leaf_reg', math.log(1), math.log(10)),
298
+ 'bagging_temperature': hp.uniform('bagging_temperature', 0., 1),
299
+ 'leaf_estimation_iterations': hp.randint('leaf_estimation_iterations', 1, 20),
300
+ 'iterations': hp.randint('iterations', 100, 4000), # This is smaller than in the paper; 4000 already leads to RAM overuse
301
+ }
302
+
303
+ def catboost_metric(x, y, test_x, test_y, cat_features, metric_used, max_time=300):
304
+ print(x)
305
+
306
+ x, y, test_x, test_y = preprocess_impute(x, y, test_x, test_y
307
+ , one_hot=False
308
+ , cat_features=cat_features
309
+ , impute=False
310
+ , standardize=False)
311
+
312
+ # Nans in categorical features must be encoded as separate class
313
+ x[:, cat_features], test_x[:, cat_features] = np.nan_to_num(x[:, cat_features], nan=-1), np.nan_to_num(
314
+ test_x[:, cat_features], nan=-1)
315
+
316
+ def make_pd_from_np(x):
317
+ data = pd.DataFrame(x)
318
+ for c in cat_features:
319
+ data.iloc[:, c] = data.iloc[:, c].astype('int')
320
+ return data
321
+
322
+ x = make_pd_from_np(x)
323
+ test_x = make_pd_from_np(test_x)
324
+
325
+ def clf_(**params):
326
+ return CatBoostClassifier(
327
+ loss_function=get_scoring_string(metric_used, usage='catboost'),
328
+ thread_count = MULTITHREAD,
329
+ used_ram_limit='4gb',
330
+ random_seed=int(y[:].sum()),
331
+ logging_level='Silent',
332
+ cat_features=cat_features,
333
+ **params)
334
+
335
+ start_time = time.time()
336
+ def stop(trial):
337
+ return time.time() - start_time > max_time, []
338
+
339
+ best = fmin(
340
+ fn=lambda params: eval_f(params, clf_, x, y, metric_used, start_time, max_time),
341
+ space=param_grid_hyperopt['catboost'],
342
+ algo=rand.suggest,
343
+ rstate=np.random.RandomState(int(y[:].sum())),
344
+ early_stop_fn=stop,
345
+ # The seed is deterministic but varies for each dataset and each split of it
346
+ max_evals=1000)
347
+ best = space_eval(param_grid_hyperopt['catboost'], best)
348
+
349
+ clf = clf_(**best)
350
+ clf.fit(x, y)
351
+
352
+ pred = clf.predict_proba(test_x)
353
+ metric = metric_used(test_y, pred)
354
+
355
+ return metric, pred, best
356
+
357
+
358
+ # XGBoost
359
+ # Hyperparameter space: https://arxiv.org/pdf/2106.03253.pdf
360
+ param_grid_hyperopt['xgb'] = {
361
+ 'learning_rate': hp.loguniform('learning_rate', -7, math.log(1)),
362
+ 'max_depth': hp.randint('max_depth', 1, 10),
363
+ 'subsample': hp.uniform('subsample', 0.2, 1),
364
+ 'colsample_bytree': hp.uniform('colsample_bytree', 0.2, 1),
365
+ 'colsample_bylevel': hp.uniform('colsample_bylevel', 0.2, 1),
366
+ 'min_child_weight': hp.loguniform('min_child_weight', -16, 5),
367
+ 'alpha': hp.loguniform('alpha', -16, 2),
368
+ 'lambda': hp.loguniform('lambda', -16, 2),
369
+ 'gamma': hp.loguniform('gamma', -16, 2),
370
+ 'n_estimators': hp.randint('n_estimators', 100, 4000), # This is smaller than in paper
371
+ }
372
+
373
+ def xgb_metric(x, y, test_x, test_y, cat_features, metric_used, max_time=300):
374
+ # XGB Documentation:
375
+ # XGB handles categorical data appropriately without using One Hot Encoding; categorical features are experimental
376
+ # XGB handles missing values appropriately without imputation
377
+
378
+ x, y, test_x, test_y = preprocess_impute(x, y, test_x, test_y
379
+ , one_hot=False
380
+ , cat_features=cat_features
381
+ , impute=False
382
+ , standardize=False)
383
+
384
+ def clf_(**params):
385
+ return xgb.XGBClassifier(use_label_encoder=False
386
+ , nthread=1
387
+ , **params
388
+ , eval_metric=get_scoring_string(metric_used, usage='xgb') # AUC not implemented
389
+ )
390
+
391
+ start_time = time.time()
392
+ def stop(trial):
393
+ return time.time() - start_time > max_time, []
394
+
395
+ best = fmin(
396
+ fn=lambda params: eval_f(params, clf_, x, y, metric_used, start_time, max_time),
397
+ space=param_grid_hyperopt['xgb'],
398
+ algo=rand.suggest,
399
+ rstate=np.random.RandomState(int(y[:].sum())),
400
+ early_stop_fn=stop,
401
+ # The seed is deterministic but varies for each dataset and each split of it
402
+ max_evals=1000)
403
+ best = space_eval(param_grid_hyperopt['xgb'], best)
404
+
405
+ clf = clf_(**best)
406
+ clf.fit(x, y)
407
+
408
+ pred = clf.predict_proba(test_x)
409
+ metric = metric_used(test_y, pred)
410
+
411
+ return metric, pred, best
412
+
413
+
414
+ clf_dict = {'gp': gp_metric
415
+ , 'knn': knn_metric
416
+ , 'catboost': catboost_metric
417
+ , 'xgb': xgb_metric
418
+ , 'logistic': logistic_metric
419
+ , 'autosklearn': autosklearn_metric
420
+ , 'autosklearn2': autosklearn2_metric
421
+ , 'autogluon': autogluon_metric}
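Since all baselines share one call signature, a method can be dispatched by name; a sketch, assuming x/y/test_x/test_y are torch tensors as expected by preprocess_impute:

baseline = clf_dict['knn']
metric, pred, best_hp = baseline(x, y, test_x, test_y, cat_features=[],
                                 metric_used=tabular_metrics.auc_metric, max_time=30)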
TabPFN/scripts/tabular_evaluation.py ADDED
@@ -0,0 +1,284 @@
1
+ import time
2
+ import os
3
+ from pathlib import Path
4
+
5
+ from tqdm import tqdm
6
+ import random
7
+ import numpy as np
8
+
9
+ from torch import nn
10
+
11
+ from utils import torch_nanmean
12
+ from datasets import *
13
+ from model_builder import load_model
14
+ from scripts.tabular_baselines import get_scoring_string
15
+ from scripts import tabular_metrics
16
+ from scripts.transformer_prediction_interface import *
17
+ from scripts.baseline_prediction_interface import *
18
+ """
19
+ ===============================
20
+ PUBLIC FUNCTIONS FOR EVALUATION
21
+ ===============================
22
+ """
23
+
24
+
25
+ def eval_model(i, e, valid_datasets, test_datasets, eval_positions, bptt, add_name, base_path, device='cpu', eval_addition='', **kwargs):
26
+ metrics_test, config_sample, model_path = eval_model_on_ds(i, e, test_datasets, eval_positions, bptt, add_name, base_path, device=device, eval_addition=eval_addition, **kwargs)
27
+ metrics_valid, _, _ = eval_model_on_ds(i, e, valid_datasets, eval_positions, bptt, add_name, base_path, device=device, eval_addition=eval_addition, **kwargs)
28
+ return {'mean_auc_test': metrics_test['mean_roc_at_1000'], 'mean_auc_valid': metrics_valid['mean_roc_at_1000'], 'mean_ce_test': metrics_test['mean_ce_at_1000'], 'mean_ce_valid': metrics_valid['mean_ce_at_1000'], 'config_sample': config_sample, 'model_path': model_path}
29
+
30
+ def eval_model_on_ds(i, e, valid_datasets, eval_positions, bptt, add_name, base_path, device='cpu', eval_addition='', **kwargs):
31
+
32
+ # How to use: evaluate_without_fitting(i,0,valid_datasets, [1024], 100000, add_name=model_string, base_path=base_path,)
33
+ def check_file(e):
34
+ model_file = f'models_diff/prior_diff_real_checkpoint{add_name}_n_{i}_epoch_{e}.cpkt'
35
+ model_path = os.path.join(base_path, model_file)
36
+ # print('Evaluate ', model_path)
37
+ results_file = os.path.join(base_path,
38
+ f'models_diff/prior_diff_real_results{add_name}_n_{i}_epoch_{e}_{eval_addition}.pkl')
39
+ if not Path(model_path).is_file(): # or Path(results_file).is_file():
40
+ # print('checkpoint exists: ', Path(model_file).is_file(), ', results are written:', Path(results_file).is_file())
41
+ return None, None, None
42
+ return model_file, model_path, results_file
43
+
44
+ if e == -1: # use last checkpoint, if e == -1
45
+ for e_ in range(100, -1, -1):
46
+ model_file_, model_path_, results_file_ = check_file(e_)
47
+ if model_file_ is not None:
48
+ e = e_
49
+ model_file, model_path, results_file = model_file_, model_path_, results_file_
50
+ break
51
+ else:
52
+ model_file, model_path, results_file = check_file(e)
53
+
54
+ model, config_sample = load_model(base_path, model_file, device, None, verbose=False)
55
+ print(model[2].style_encoder)
56
+
57
+ params = {'max_features': config_sample['num_features']
58
+ , 'rescale_features': config_sample["normalize_by_used_features"]
59
+ , 'normalize_to_ranking': config_sample["normalize_to_ranking"]
60
+ , 'normalize_with_sqrt': config_sample.get("normalize_with_sqrt", False)
61
+ }
62
+ metrics_valid = evaluate(datasets=valid_datasets, model=model[2], method='transformer', device=device, overwrite=True,
63
+ extend_features=True
64
+ # just removed the style keyword but transformer is trained with style, just empty
65
+ , save=False
66
+ , metric_used=tabular_metrics.cross_entropy
67
+ , return_tensor=True
68
+ , verbose=False
69
+ , eval_positions=eval_positions
70
+ , bptt=bptt
71
+ , base_path=None
72
+ , inference_mode=True
73
+ , **params
74
+ , **kwargs)
75
+
76
+ tabular_metrics.calculate_score_per_method(tabular_metrics.auc_metric, 'roc', metrics_valid, valid_datasets, eval_positions)
77
+ tabular_metrics.calculate_score_per_method(tabular_metrics.cross_entropy, 'ce', metrics_valid, valid_datasets, eval_positions)
78
+
79
+ return metrics_valid, config_sample, model_path
80
+
81
+
82
+ def evaluate(datasets, bptt, eval_positions, metric_used, model
83
+ , verbose=False
84
+ , return_tensor=False
85
+ , **kwargs):
86
+ """
87
+ Evaluates a list of datasets for a model function.
88
+
89
+ :param datasets: List of datasets
90
+ :param bptt: maximum sequence length
91
+ :param eval_positions: List of positions where to evaluate models
92
+ :param verbose: If True, prints progress information.
93
+ :param metric_used: Which metric is optimized for.
94
+ :param return_tensor: Whether to return results as a torch.Tensor or as numpy; this is only relevant for the transformer.
95
+ :param kwargs:
96
+ :return:
97
+ """
98
+ overall_result = {'metric_used': get_scoring_string(metric_used)
99
+ , 'bptt': bptt
100
+ , 'eval_positions': eval_positions}
101
+
102
+ aggregated_metric_datasets, num_datasets = torch.tensor(0.0), 0
103
+
104
+ # For each dataset
105
+ for [ds_name, X, y, categorical_feats, _, _] in (tqdm(datasets, desc='Iterate over datasets') if verbose else datasets):
106
+ dataset_bptt = min(len(X), bptt)
107
+ # if verbose and dataset_bptt < bptt:
108
+ # print(f'Dataset too small for given sequence length, reducing to {len(X)} ({bptt})')
109
+
110
+ aggregated_metric, num = torch.tensor(0.0), 0
111
+ ds_result = {}
112
+
113
+ for eval_position in eval_positions:
114
+ eval_position_real = int(dataset_bptt * 0.5) if 2 * eval_position > dataset_bptt else eval_position
115
+ eval_position_bptt = int(eval_position_real * 2.0)
116
+
117
+ r = evaluate_position(X, y, model=model
118
+ , num_classes=len(torch.unique(y))
119
+ , categorical_feats = categorical_feats
120
+ , bptt = eval_position_bptt
121
+ , ds_name=ds_name
122
+ , eval_position = eval_position_real
123
+ , metric_used = metric_used
124
+ ,**kwargs)
125
+
126
+ if r is None:
127
+ continue
128
+
129
+ _, outputs, ys, best_configs, time_used = r
130
+
131
+ if torch.is_tensor(outputs):
132
+ outputs = outputs.to(outputs.device)
133
+ ys = ys.to(outputs.device)
134
+
135
+ ys = ys.T
136
+ ds_result[f'{ds_name}_best_configs_at_{eval_position}'] = best_configs
137
+ ds_result[f'{ds_name}_outputs_at_{eval_position}'] = outputs
138
+ ds_result[f'{ds_name}_ys_at_{eval_position}'] = ys
139
+ ds_result[f'{ds_name}_time_at_{eval_position}'] = time_used
140
+
141
+ new_metric = torch_nanmean(torch.stack([metric_used(ys[i], outputs[i]) for i in range(ys.shape[0])]))
142
+
143
+ if not return_tensor:
144
+ make_scalar = lambda x: float(x.detach().cpu().numpy()) if (torch.is_tensor(x) and (len(x.shape) == 0)) else x
145
+ new_metric = make_scalar(new_metric)
146
+ ds_result = {k: make_scalar(ds_result[k]) for k in ds_result.keys()}
147
+
148
+ lib = torch if return_tensor else np
149
+ if not lib.isnan(new_metric).any():
150
+ aggregated_metric, num = aggregated_metric + new_metric, num + 1
151
+
152
+ overall_result.update(ds_result)
153
+ if num > 0:
154
+ aggregated_metric_datasets, num_datasets = (aggregated_metric_datasets + (aggregated_metric / num)), num_datasets + 1
155
+
156
+ overall_result['mean_metric'] = aggregated_metric_datasets / num_datasets
157
+
158
+ return overall_result
159
+
160
+ """
161
+ ===============================
162
+ INTERNAL HELPER FUNCTIONS
163
+ ===============================
164
+ """
165
+
166
+ def check_file_exists(path):
167
+ """Checks if a pickle file exists. Returns None if not, else returns the unpickled file."""
168
+ if (os.path.isfile(path)):
169
+ print(f'loading results from {path}')
170
+ with open(path, 'rb') as f:
171
+ return np.load(f, allow_pickle=True).tolist()
172
+ return None
173
+
174
+ def generate_valid_split(X, y, bptt, eval_position, split_number=1):
175
+ """Generates a deterministic train-(test/valid) split. Both splits must contain the same classes and all classes in
176
+ the entire dataset. If no such split can be sampled in 8 passes, returns None.
177
+
178
+ :param X: torch tensor, feature values
179
+ :param y: torch tensor, class values
180
+ :param bptt: Number of samples in train + test
181
+ :param eval_position: Number of samples in train, i.e. from which index values are in test
182
+ :param split_number: The split id
183
+ :return:
184
+ """
185
+ done, seed = False, 13
186
+
187
+ torch.manual_seed(split_number)
188
+ perm = torch.randperm(X.shape[0]) if split_number > 1 else torch.arange(0, X.shape[0])
189
+ X, y = X[perm], y[perm]
190
+
191
+ while not done:
192
+ if seed > 20:
193
+ return None, None # No split could be generated in 8 passes, return None
194
+ random.seed(seed)
195
+ i = random.randint(0, len(X) - bptt) if len(X) - bptt > 0 else 0
196
+ y_ = y[i:i + bptt]
197
+
198
+ # Checks that all classes from the dataset are contained and that train and test contain the
199
+ # same classes.
200
+ done = len(torch.unique(y_)) == len(torch.unique(y))
201
+ done = done and torch.all(torch.unique(y_) == torch.unique(y))
202
+ done = done and len(torch.unique(y_[:eval_position])) == len(torch.unique(y_[eval_position:]))
203
+ done = done and torch.all(torch.unique(y_[:eval_position]) == torch.unique(y_[eval_position:]))
204
+ seed = seed + 1
205
+
206
+ eval_xs = torch.stack([X[i:i + bptt].clone()], 1)
207
+ eval_ys = torch.stack([y[i:i + bptt].clone()], 1)
208
+
209
+ return eval_xs, eval_ys
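On success the split carries an explicit batch dimension of one; a sketch, assuming X is a 2-D feature tensor with at least bptt rows and y the matching 1-D label tensor:

eval_xs, eval_ys = generate_valid_split(X, y, bptt=100, eval_position=50)
if eval_xs is not None:
    assert eval_xs.shape == (100, 1, X.shape[1]) and eval_ys.shape == (100, 1)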
210
+
211
+
212
+ def evaluate_position(X, y, categorical_feats, model, bptt
213
+ , eval_position, overwrite, save, base_path, path_interfix, method, ds_name, fetch_only=False
214
+ , max_time=300, split_number=1
215
+ , per_step_normalization=False, **kwargs):
216
+ """
217
+ Evaluates a dataset with a 'bptt' number of training samples.
218
+
219
+ :param X: Dataset X
220
+ :param y: Dataset labels
221
+ :param categorical_feats: Indices of categorical features.
222
+ :param model: Model function
223
+ :param bptt: Sequence length.
224
+ :param eval_position: Number of training samples.
225
+ :param overwrite: If True, results on disk are overwritten.
227
+ :param save:
228
+ :param path_interfix: Used for constructing path to write on disk.
229
+ :param method: Model name.
230
+ :param ds_name: Dataset name.
231
+ :param fetch_only: Whether to only fetch stored results instead of computing them.
232
+ :param per_step_normalization:
233
+ :param kwargs:
234
+ :return:
235
+ """
236
+
237
+ if save:
238
+ path = os.path.join(base_path, f'results/tabular/{path_interfix}/results_{method}_{ds_name}_{eval_position}_{bptt}_{split_number}.npy')
239
+ #log_path =
240
+
241
+ ## Load results if on disk
242
+ if not overwrite:
243
+ result = check_file_exists(path)
244
+ if result is not None:
245
+ if not fetch_only:
246
+ print(f'Loaded saved result for {path}')
247
+ return result
248
+ elif fetch_only:
249
+ print(f'Could not load saved result for {path}')
250
+ return None
251
+
252
+ ## Generate data splits
253
+ eval_xs, eval_ys = generate_valid_split(X, y, bptt, eval_position, split_number=split_number)
254
+ if eval_xs is None:
255
+ print(f"No dataset could be generated {ds_name} {bptt}")
256
+ return None
257
+
258
+ eval_ys = (eval_ys > torch.unique(eval_ys).unsqueeze(0)).sum(axis=1).unsqueeze(-1)
259
+
260
+ start_time = time.time()
261
+
262
+ if isinstance(model, nn.Module): # Two separate predict interfaces for transformer and baselines
263
+ outputs, best_configs = transformer_predict(model, eval_xs, eval_ys, eval_position, categorical_feats=categorical_feats, **kwargs), None
264
+ else:
265
+ _, outputs, best_configs = baseline_predict(model, eval_xs, eval_ys, categorical_feats
266
+ , eval_pos=eval_position
267
+ , max_time=max_time, **kwargs)
268
+
269
+ eval_ys = eval_ys[eval_position:]
270
+ if outputs is None:
271
+ return None
272
+
273
+ if torch.is_tensor(outputs): # Transfers data to cpu for saving
274
+ outputs = outputs.cpu()
275
+ eval_ys = eval_ys.cpu()
276
+
277
+ ds_result = None, outputs, eval_ys, best_configs, time.time() - start_time
278
+
279
+ if save:
280
+ with open(path, 'wb') as f:
281
+ np.save(f, ds_result)
282
+ print(f'saved results to {path}')
283
+
284
+ return ds_result
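A sketch of the intended entry point, mirroring the "How to use" note in eval_model_on_ds above; valid_datasets and base_path are assumed to exist in the usual repository format, and the model/epoch ids are those of the bundled checkpoint:

metrics_valid, config_sample, model_path = eval_model_on_ds(
    '8x_lr0.0003', -1, valid_datasets, eval_positions=[1000], bptt=2000,
    add_name='', base_path=base_path, device='cpu')
print(metrics_valid['mean_roc_at_1000'], metrics_valid['mean_ce_at_1000'])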
TabPFN/scripts/tabular_metrics.py ADDED
@@ -0,0 +1,181 @@
1
+ """
2
+ ===============================
3
+ Metrics calculation
4
+ ===============================
5
+ Includes a few metrics as well as functions for composing metrics on results files.
6
+
7
+ """
8
+
9
+
10
+
11
+ import numpy as np
12
+ import torch
13
+ from sklearn.metrics import roc_auc_score, accuracy_score, balanced_accuracy_score, average_precision_score
14
+ from scipy.stats import rankdata
15
+ import pandas as pd
16
+
17
+ """
18
+ ===============================
19
+ Metrics calculation
20
+ ===============================
21
+ """
22
+ def auc_metric(target, pred, multi_class='ovo', numpy=False):
23
+ lib = np if numpy else torch
24
+ try:
25
+ if not numpy:
26
+ target = torch.tensor(target) if not torch.is_tensor(target) else target
27
+ pred = torch.tensor(pred) if not torch.is_tensor(pred) else pred
28
+ if len(lib.unique(target)) > 2:
29
+ if not numpy:
30
+ return torch.tensor(roc_auc_score(target, pred, multi_class=multi_class))
31
+ return roc_auc_score(target, pred, multi_class=multi_class)
32
+ else:
33
+ if len(pred.shape) == 2:
34
+ pred = pred[:, 1]
35
+ if not numpy:
36
+ return torch.tensor(roc_auc_score(target, pred))
37
+ return roc_auc_score(target, pred)
38
+ except ValueError as e:
39
+ print(e)
40
+ return np.nan
41
+
42
+ def accuracy_metric(target, pred):
43
+ target = torch.tensor(target) if not torch.is_tensor(target) else target
44
+ pred = torch.tensor(pred) if not torch.is_tensor(pred) else pred
45
+ if len(torch.unique(target)) > 2:
46
+ return torch.tensor(accuracy_score(target, torch.argmax(pred, -1)))
47
+ else:
48
+ return torch.tensor(accuracy_score(target, pred[:, 1] > 0.5))
49
+
50
+ def average_precision_metric(target, pred):
51
+ target = torch.tensor(target) if not torch.is_tensor(target) else target
52
+ pred = torch.tensor(pred) if not torch.is_tensor(pred) else pred
53
+ if len(torch.unique(target)) > 2:
54
+ return torch.tensor(average_precision_score(target, torch.argmax(pred, -1)))
55
+ else:
56
+ return torch.tensor(average_precision_score(target, pred[:, 1] > 0.5))
57
+
58
+ def balanced_accuracy_metric(target, pred):
59
+ target = torch.tensor(target) if not torch.is_tensor(target) else target
60
+ pred = torch.tensor(pred) if not torch.is_tensor(pred) else pred
61
+ if len(torch.unique(target)) > 2:
62
+ return torch.tensor(balanced_accuracy_score(target, torch.argmax(pred, -1)))
63
+ else:
64
+ return torch.tensor(balanced_accuracy_score(target, pred[:, 1] > 0.5))
65
+
66
+ def cross_entropy(target, pred):
67
+ target = torch.tensor(target) if not torch.is_tensor(target) else target
68
+ pred = torch.tensor(pred) if not torch.is_tensor(pred) else pred
69
+ if len(torch.unique(target)) > 2:
70
+ ce = torch.nn.CrossEntropyLoss()
71
+ return ce(pred.float(), target.long())
72
+ else:
73
+ bce = torch.nn.BCELoss()
74
+ return bce(pred[:, 1].float(), target.float())
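A toy example of the expected shapes (predictions are class-probability rows, targets are class indices); the values are made up:

import torch
target = torch.tensor([0, 1, 1, 0])
pred = torch.tensor([[0.8, 0.2], [0.3, 0.7], [0.4, 0.6], [0.9, 0.1]])
auc = auc_metric(target, pred)    # ROC AUC computed on the positive-class column
ce = cross_entropy(target, pred)  # binary cross entropy on pred[:, 1]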
75
+
76
+ def time_metric():
77
+ """
78
+ Dummy function, will just be used as a handler.
79
+ """
80
+ pass
81
+
82
+ def count_metric(x, y):
83
+ """
84
+ Dummy function, returns one count per dataset.
85
+ """
86
+ return 1
87
+
88
+ """
89
+ ===============================
90
+ Metrics composition
91
+ ===============================
92
+ """
93
+ def calculate_score_per_method(metric, name:str, global_results:dict, ds:list, eval_positions:list, aggregator:str='mean'):
94
+ """
95
+ Calculates the metric given by 'metric' and saves it under 'name' in the 'global_results'
96
+
97
+ :param metric: Metric function
98
+ :param name: Name of metric in 'global_results'
99
+ :param global_results: Dictionary containing the results of the current method for a collection of datasets
100
+ :param ds: Dataset to calculate metrics on, a list of dataset properties
101
+ :param eval_positions: List of positions to calculate metrics on
102
+ :param aggregator: Specifies way to aggregate results across evaluation positions
103
+ :return:
104
+ """
105
+ aggregator_f = np.nanmean if aggregator == 'mean' else np.nansum
106
+ for pos in eval_positions:
107
+ valid_positions = 0
108
+ for d in ds:
109
+ if f'{d[0]}_outputs_at_{pos}' in global_results:
110
+ preds = global_results[f'{d[0]}_outputs_at_{pos}']
111
+ y = global_results[f'{d[0]}_ys_at_{pos}']
112
+
113
+ preds, y = preds.detach().cpu().numpy() if torch.is_tensor(
114
+ preds) else preds, y.detach().cpu().numpy() if torch.is_tensor(y) else y
115
+
116
+ try:
117
+ if metric == time_metric:
118
+ global_results[f'{d[0]}_{name}_at_{pos}'] = global_results[f'{d[0]}_time_at_{pos}']
119
+ valid_positions = valid_positions + 1
120
+ else:
121
+ global_results[f'{d[0]}_{name}_at_{pos}'] = aggregator_f(
122
+ [metric(y[split], preds[split]) for split in range(y.shape[0])])
123
+ valid_positions = valid_positions + 1
124
+ except Exception as err:
125
+ print(f'Error calculating metric with {err}, {type(err)} at {d[0]} {pos} {name}')
126
+ global_results[f'{d[0]}_{name}_at_{pos}'] = np.nan
127
+ else:
128
+ global_results[f'{d[0]}_{name}_at_{pos}'] = np.nan
129
+
130
+ if valid_positions > 0:
131
+ global_results[f'{aggregator}_{name}_at_{pos}'] = aggregator_f([global_results[f'{d[0]}_{name}_at_{pos}'] for d in ds])
132
+ else:
133
+ global_results[f'{aggregator}_{name}_at_{pos}'] = np.nan
134
+
135
+ for d in ds:
136
+ metrics = [global_results[f'{d[0]}_{name}_at_{pos}'] for pos in eval_positions]
137
+ metrics = [m for m in metrics if not np.isnan(m)]
138
+ global_results[f'{d[0]}_{aggregator}_{name}'] = aggregator_f(metrics) if len(metrics) > 0 else np.nan
139
+
140
+ metrics = [global_results[f'{aggregator}_{name}_at_{pos}'] for pos in eval_positions]
141
+ metrics = [m for m in metrics if not np.isnan(m)]
142
+ global_results[f'{aggregator}_{name}'] = aggregator_f(metrics) if len(metrics) > 0 else np.nan
143
+
144
+
145
+ def calculate_score(metric, name, global_results, ds, eval_positions, aggregator='mean', limit_to=''):
146
+ """
147
+ Calls calculate_score_per_method for a range of methods. See the arguments of that function.
148
+ :param limit_to: Only methods whose name contains this string get metric calculations.
149
+ """
150
+ for m in global_results:
151
+ if limit_to not in m:
152
+ continue
153
+ calculate_score_per_method(metric, name, global_results[m], ds, eval_positions, aggregator=aggregator)
154
+
155
+
156
+ def make_metric_matrix(global_results, methods, pos, name, ds):
157
+ result = []
158
+ for m in global_results:
159
+ result += [[global_results[m][d[0] + '_' + name + '_at_' + str(pos)] for d in ds]]
160
+ result = np.array(result)
161
+ result = pd.DataFrame(result.T, index=[d[0] for d in ds], columns=[k[:-8] for k in list(global_results.keys())])
162
+
163
+ matrix_means, matrix_stds = [], []
164
+
165
+ for method in methods:
166
+ matrix_means += [result.iloc[:, [(method) in c for c in result.columns]].mean(axis=1)]
167
+ matrix_stds += [result.iloc[:, [(method) in c for c in result.columns]].std(axis=1)]
168
+
169
+ matrix_means = pd.DataFrame(matrix_means, index=methods).T
170
+ matrix_stds = pd.DataFrame(matrix_stds, index=methods).T
171
+
172
+ return matrix_means, matrix_stds
173
+
174
+
175
+ def make_ranks_and_wins_table(matrix):
176
+ for dss in matrix.T:
177
+ matrix.loc[dss] = rankdata(-matrix.round(3).loc[dss])
178
+ ranks_acc = matrix.mean()
179
+ wins_acc = (matrix == 1).sum()
180
+
181
+ return ranks_acc, wins_acc
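A sketch of the result-dictionary convention these helpers expect: per-dataset predictions live under '{ds_name}_outputs_at_{pos}' and labels under '{ds_name}_ys_at_{pos}', both with a leading split dimension; the names and values below are made up:

import numpy as np
ds = [('toy_ds',)]  # only d[0], the dataset name, is used here
global_results = {
    'toy_ds_outputs_at_2': np.array([[[0.8, 0.2], [0.1, 0.9]]]),  # (splits, samples, classes)
    'toy_ds_ys_at_2': np.array([[0, 1]]),                          # (splits, samples)
}
calculate_score_per_method(auc_metric, 'roc', global_results, ds, eval_positions=[2])
print(global_results['toy_ds_roc_at_2'], global_results['mean_roc_at_2'])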
TabPFN/scripts/transformer_prediction_interface.py ADDED
@@ -0,0 +1,357 @@
1
+ import torch
2
+ import random
3
+
4
+ from torch.utils.checkpoint import checkpoint
5
+
6
+ from utils import normalize_data, to_ranking_low_mem, remove_outliers
7
+ from priors.utils import normalize_by_used_features_f
8
+ from utils import NOP
9
+
10
+ from sklearn.preprocessing import PowerTransformer, QuantileTransformer, RobustScaler
11
+
12
+ from notebook_utils import CustomUnpickler
13
+
14
+ import numpy as np
15
+ from sklearn.base import BaseEstimator, ClassifierMixin
16
+ from sklearn.utils.validation import check_X_y, check_array, check_is_fitted
17
+ from sklearn.utils.multiclass import check_classification_targets
18
+ from sklearn.utils import column_or_1d
19
+ from pathlib import Path
20
+ from model_builder import load_model
21
+ import os
22
+
23
+ def load_model_workflow(i, e, add_name, base_path, device='cpu', eval_addition=''):
24
+ """
25
+ Workflow for loading a model and setting appropriate parameters for diffable hparam tuning.
26
+
27
+ :param i:
28
+ :param e:
29
+ :param eval_positions_valid:
30
+ :param add_name:
31
+ :param base_path:
32
+ :param device:
33
+ :param eval_addition:
34
+ :return:
35
+ """
36
+ def check_file(e):
37
+ model_file = f'models_diff/prior_diff_real_checkpoint{add_name}_n_{i}_epoch_{e}.cpkt'
38
+ model_path = os.path.join(base_path, model_file)
39
+ # print('Evaluate ', model_path)
40
+ results_file = os.path.join(base_path,
41
+ f'models_diff/prior_diff_real_results{add_name}_n_{i}_epoch_{e}_{eval_addition}.pkl')
42
+ if not Path(model_path).is_file(): # or Path(results_file).is_file():
43
+ return None, None, None
44
+ return model_file, model_path, results_file
45
+
46
+ model_file = None
47
+ if e == -1:
48
+ for e_ in range(100, -1, -1):
49
+ model_file_, model_path_, results_file_ = check_file(e_)
50
+ if model_file_ is not None:
51
+ e = e_
52
+ model_file, model_path, results_file = model_file_, model_path_, results_file_
53
+ break
54
+ else:
55
+ model_file, model_path, results_file = check_file(e)
56
+
57
+ if model_file is None:
58
+ print('No checkpoint found')
59
+ return None
60
+
61
+ print(f'Loading {model_file}')
62
+
63
+ model, c = load_model(base_path, model_file, device, eval_positions=[], verbose=False)
64
+
65
+ return model, c, results_file
66
+
67
+
68
+ class TabPFNClassifier(BaseEstimator, ClassifierMixin):
69
+
70
+ def __init__(self, device='cpu', base_path='.'):
71
+ # Model file specification (Model name, Epoch)
72
+ model_string = ''
73
+ i, e = '8x_lr0.0003', -1
74
+
75
+ # File which contains result of hyperparameter tuning run: style (i.e. hyperparameters) and a dataframe with results.
76
+ style_file = 'prior_tuning_result.pkl'
77
+
78
+ model, c, results_file = load_model_workflow(i, e, add_name=model_string, base_path=base_path, device=device,
79
+ eval_addition='')
80
+ style, temperature = self.load_result_minimal(style_file, i, e, base_path=base_path)
81
+
82
+ self.device = device
83
+ self.base_path = base_path
84
+ self.model = model
85
+ self.c = c
86
+ self.style = style
87
+ self.temperature = temperature
88
+
89
+ self.max_num_features = self.c['num_features']
90
+ self.max_num_classes = self.c['max_num_classes']
91
+
92
+ def load_result_minimal(self, path, i, e, base_path='.'):
93
+ with open(os.path.join(base_path,path), 'rb') as output:
94
+ _, _, _, style, temperature, optimization_route = CustomUnpickler(output).load()
95
+
96
+ return style, temperature
97
+
98
+ def fit(self, X, y):
99
+ # Check that X and y have correct shape
100
+ X, y = check_X_y(X, y)
101
+ y = self._validate_targets(y)
102
+
103
+ self.X_ = X
104
+ self.y_ = y
105
+
106
+ if X.shape[1] > self.max_num_features:
107
+ raise ValueError("The number of features for this classifier is restricted to ", self.max_num_features)
108
+ if len(np.unique(y)) > self.max_num_classes:
109
+ raise ValueError("The number of classes for this classifier is restricted to ", self.max_num_classes)
110
+
111
+ # Return the classifier
112
+ return self
113
+
114
+ def _validate_targets(self, y):
115
+ y_ = column_or_1d(y, warn=True)
116
+ check_classification_targets(y)
117
+ cls, y = np.unique(y_, return_inverse=True)
118
+ if len(cls) < 2:
119
+ raise ValueError(
120
+ "The number of classes has to be greater than one; got %d class"
121
+ % len(cls)
122
+ )
123
+
124
+ self.classes_ = cls
125
+
126
+ return np.asarray(y, dtype=np.float64, order="C")
127
+
128
+ def predict_proba(self, X):
129
+ # Check is fit had been called
130
+ check_is_fitted(self)
131
+
132
+ # Input validation
133
+ X = check_array(X)
134
+
135
+ X_full = np.concatenate([self.X_, X], axis=0)
136
+ X_full = torch.tensor(X_full, device=self.device).float().unsqueeze(1)
137
+ y_full = np.concatenate([self.y_, self.y_[0] + np.zeros_like(X[:, 0])], axis=0)
138
+ y_full = torch.tensor(y_full, device=self.device).float().unsqueeze(1)
139
+
140
+ eval_pos = self.X_.shape[0]
141
+
142
+ prediction = transformer_predict(self.model[2], X_full, y_full, eval_pos,
143
+ device=self.device,
144
+ style=self.style,
145
+ inference_mode=True,
146
+ N_ensemble_configurations=10,
147
+ softmax_temperature=self.temperature
148
+ , **get_params_from_config(self.c))
149
+ prediction_ = prediction.squeeze(0)
150
+
151
+ return prediction_.detach().cpu().numpy()
152
+
153
+ def predict(self, X, return_winning_probability=False):
154
+ p = self.predict_proba(X)
155
+ y = np.argmax(p, axis=-1)
156
+ y = self.classes_.take(np.asarray(y, dtype=np.intp))
157
+ if return_winning_probability:
158
+ return y, p.max(axis=-1)
159
+ return y
160
+
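A usage sketch of the scikit-learn-style wrapper above, assuming the bundled checkpoint and prior_tuning_result.pkl are available under base_path as in this repository:

import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

classifier = TabPFNClassifier(device='cpu', base_path='.')
classifier.fit(X_train, y_train)
y_pred, p_max = classifier.predict(X_test, return_winning_probability=True)
print('accuracy', np.mean(y_pred == y_test))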
161
+ def transformer_predict(model, eval_xs, eval_ys, eval_position,
162
+ device='cpu',
163
+ max_features=100,
164
+ style=None,
165
+ inference_mode=False,
166
+ num_classes=2,
167
+ extend_features=True,
168
+ normalize_to_ranking=False,
169
+ softmax_temperature=0.0,
170
+ multiclass_decoder='permutation',
171
+ preprocess_transform='mix',
172
+ categorical_feats=[],
173
+ feature_shift_decoder=True,
174
+ N_ensemble_configurations=10,
175
+ average_logits=True,
176
+ normalize_with_sqrt=False, **kwargs):
177
+ """
178
+
179
+ :param model:
180
+ :param eval_xs:
181
+ :param eval_ys: should be classes that are 0-indexed and every class until num_classes-1 is present
182
+ :param eval_position:
183
+ :param rescale_features:
184
+ :param device:
185
+ :param max_features:
186
+ :param style:
187
+ :param inference_mode:
188
+ :param num_classes:
189
+ :param extend_features:
190
+ :param normalize_to_ranking:
191
+ :param softmax_temperature:
192
+ :param multiclass_decoder:
193
+ :param preprocess_transform:
194
+ :param categorical_feats:
195
+ :param feature_shift_decoder:
196
+ :param N_ensemble_configurations:
197
+ :param average_logits:
198
+ :param normalize_with_sqrt:
199
+ :param metric_used:
200
+ :return:
201
+ """
202
+ num_classes = len(torch.unique(eval_ys))
203
+
204
+ def predict(eval_xs, eval_ys, used_style, softmax_temperature, return_logits):
205
+ # Initialize results array size S, B, Classes
206
+
207
+ inference_mode_call = torch.inference_mode() if inference_mode else NOP()
208
+ with inference_mode_call:
209
+ output = model(
210
+ (used_style.repeat(eval_xs.shape[1], 1) if used_style is not None else None, eval_xs, eval_ys.float()),
211
+ single_eval_pos=eval_position)[:, :, 0:num_classes]
212
+
213
+ output = output[:, :, 0:num_classes] / torch.exp(softmax_temperature)
214
+ if not return_logits:
215
+ output = torch.nn.functional.softmax(output, dim=-1)
216
+ #else:
217
+ # output[:, :, 1] = model((style.repeat(eval_xs.shape[1], 1) if style is not None else None, eval_xs, eval_ys.float()),
218
+ # single_eval_pos=eval_position)
219
+
220
+ # output[:, :, 1] = torch.sigmoid(output[:, :, 1]).squeeze(-1)
221
+ # output[:, :, 0] = 1 - output[:, :, 1]
222
+
223
+ #print('RESULTS', eval_ys.shape, torch.unique(eval_ys, return_counts=True), output.mean(axis=0))
224
+
225
+ return output
226
+
227
+ def preprocess_input(eval_xs, preprocess_transform):
228
+ import warnings
229
+
230
+ if eval_xs.shape[1] > 1:
231
+ raise Exception("Transforms only allow one batch dim - TODO")
232
+ if preprocess_transform != 'none':
233
+ if preprocess_transform == 'power' or preprocess_transform == 'power_all':
234
+ pt = PowerTransformer(standardize=True)
235
+ elif preprocess_transform == 'quantile' or preprocess_transform == 'quantile_all':
236
+ pt = QuantileTransformer(output_distribution='normal')
237
+ elif preprocess_transform == 'robust' or preprocess_transform == 'robust_all':
238
+ pt = RobustScaler(unit_variance=True)
239
+
240
+ # eval_xs, eval_ys = normalize_data(eval_xs), normalize_data(eval_ys)
241
+ eval_xs = normalize_data(eval_xs)
242
+
243
+ # Removing empty features
244
+ eval_xs = eval_xs[:, 0, :].cpu().numpy()
245
+ sel = [len(np.unique(eval_xs[0:eval_ys.shape[0], col])) > 1 for col in range(eval_xs.shape[1])]
246
+ eval_xs = np.array(eval_xs[:, sel])
247
+
248
+ warnings.simplefilter('error')
249
+ if preprocess_transform != 'none':
250
+ feats = set(range(eval_xs.shape[1])) if 'all' in preprocess_transform else set(
251
+ range(eval_xs.shape[1])) - set(categorical_feats)
252
+ for col in feats:
253
+ try:
254
+ pt.fit(eval_xs[0:eval_ys.shape[0], col:col + 1])
255
+ trans = pt.transform(eval_xs[:, col:col + 1])
256
+ # print(scipy.stats.spearmanr(trans[~np.isnan(eval_xs[:, col:col+1])], eval_xs[:, col:col+1][~np.isnan(eval_xs[:, col:col+1])]))
257
+ eval_xs[:, col:col + 1] = trans
258
+ except:
259
+ pass
260
+ warnings.simplefilter('default')
261
+
262
+ eval_xs = torch.tensor(eval_xs).float().unsqueeze(1).to(device)
263
+
264
+ # eval_xs = normalize_data(eval_xs)
265
+
266
+ # TODO: Caution, there is information leakage when to_ranking is used; we should not use it
267
+ eval_xs = remove_outliers(eval_xs) if not normalize_to_ranking else normalize_data(to_ranking_low_mem(eval_xs))
268
+
269
+ # Rescale X
270
+ eval_xs = normalize_by_used_features_f(eval_xs, eval_xs.shape[-1], max_features,
271
+ normalize_with_sqrt=normalize_with_sqrt)
272
+ return eval_xs.detach()
273
+
274
+ eval_xs, eval_ys = eval_xs.to(device), eval_ys.to(device)
275
+ eval_ys = eval_ys[:eval_position]
276
+
277
+ model.to(device)
278
+ style = style.to(device)
279
+
280
+ model.eval()
281
+
282
+ import itertools
283
+ style = style.unsqueeze(0) if len(style.shape) == 1 else style
284
+ num_styles = style.shape[0]
285
+ styles_configurations = range(0, num_styles)
286
+ preprocess_transform_configurations = [preprocess_transform if i % 2 == 0 else 'none' for i in range(0, num_styles)]
287
+ if preprocess_transform == 'mix':
288
+ def get_preprocess(i):
289
+ if i == 0:
290
+ return 'power_all'
291
+ if i == 1:
292
+ return 'robust_all'
293
+ if i == 2:
294
+ return 'none'
295
+ preprocess_transform_configurations = [get_preprocess(i) for i in range(0, num_styles)]
296
+ styles_configurations = zip(styles_configurations, preprocess_transform_configurations)
297
+
298
+ feature_shift_configurations = range(0, eval_xs.shape[2]) if feature_shift_decoder else [0]
299
+ class_shift_configurations = range(0, len(torch.unique(eval_ys))) if multiclass_decoder == 'permutation' else [0]
300
+
301
+ ensemble_configurations = list(itertools.product(styles_configurations, feature_shift_configurations, class_shift_configurations))
302
+ random.shuffle(ensemble_configurations)
303
+ ensemble_configurations = ensemble_configurations[0:N_ensemble_configurations]
304
+
305
+ output = None
306
+
307
+ eval_xs_transformed = {}
308
+ for ensemble_configuration in ensemble_configurations:
309
+ (styles_configuration, preprocess_transform_configuration), feature_shift_configuration, class_shift_configuration = ensemble_configuration
310
+
311
+ style_ = style[styles_configuration:styles_configuration+1, :]
312
+ softmax_temperature_ = softmax_temperature[styles_configuration]
313
+
314
+ eval_xs_, eval_ys_ = eval_xs.clone(), eval_ys.clone()
315
+
316
+ if preprocess_transform_configuration in eval_xs_transformed:
317
+ eval_xs_ = eval_xs_transformed[preprocess_transform_configuration].clone()
318
+ else:
319
+ eval_xs_ = preprocess_input(eval_xs_, preprocess_transform=preprocess_transform_configuration)
320
+ eval_xs_transformed[preprocess_transform_configuration] = eval_xs_
321
+
322
+ eval_ys_ = ((eval_ys_ + class_shift_configuration) % num_classes).float()
323
+
324
+ eval_xs_ = torch.cat([eval_xs_[..., feature_shift_configuration:],eval_xs_[..., :feature_shift_configuration]],dim=-1)
325
+
326
+ # Extend X
327
+ if extend_features:
328
+ eval_xs_ = torch.cat(
329
+ [eval_xs_,
330
+ torch.zeros((eval_xs_.shape[0], eval_xs_.shape[1], max_features - eval_xs_.shape[2])).to(device)], -1)
331
+
332
+ #preprocess_transform_ = preprocess_transform if styles_configuration % 2 == 0 else 'none'
333
+ import warnings
334
+ with warnings.catch_warnings():
335
+ warnings.filterwarnings("ignore", message="None of the inputs have requires_grad=True. Gradients will be None")
336
+ output_ = checkpoint(predict, eval_xs_, eval_ys_, style_, softmax_temperature_, True)
337
+ output_ = torch.cat([output_[..., class_shift_configuration:],output_[..., :class_shift_configuration]],dim=-1)
338
+
339
+ #output_ = predict(eval_xs, eval_ys, style_, preprocess_transform_)
340
+ if not average_logits:
341
+ output_ = torch.nn.functional.softmax(output_, dim=-1)
342
+ output = output_ if output is None else output + output_
343
+
344
+ output = output / len(ensemble_configurations)
345
+ if average_logits:
346
+ output = torch.nn.functional.softmax(output, dim=-1)
347
+
348
+ output = torch.transpose(output, 0, 1)
349
+
350
+ return output
351
+
352
+ def get_params_from_config(c):
353
+ return {'max_features': c['num_features']
354
+ , 'rescale_features': c["normalize_by_used_features"]
355
+ , 'normalize_to_ranking': c["normalize_to_ranking"]
356
+ , 'normalize_with_sqrt': c.get("normalize_with_sqrt", False)
357
+ }
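A minimal usage sketch (not part of the commit) of the scikit-learn-style interface defined above. The class name and import path follow app.py; the iris data, device choice, and constructor arguments are illustrative assumptions, not taken from this file.

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

from scripts.transformer_prediction_interface import TabPFNClassifier  # assumes TabPFN/ is on sys.path

X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=0)

classifier = TabPFNClassifier(device='cpu')   # constructor arguments are an assumption
classifier.fit(X_train, y_train)              # stores X_, y_ and classes_ for in-context prediction
proba = classifier.predict_proba(X_test)      # array of shape (n_test, n_classes)
pred, p_win = classifier.predict(X_test, return_winning_probability=True)
print(proba.shape, (pred == y_test).mean(), p_win.mean())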
TabPFN/tabular_evaluation.py ADDED
@@ -0,0 +1,283 @@
1
+ import time
2
+ import os
3
+ from pathlib import Path
4
+
5
+ from tqdm import tqdm
6
+ import random
7
+ import numpy as np
8
+
9
+ from torch import nn
10
+
11
+ from utils import torch_nanmean
12
+ from datasets import *
13
+ from model_builder import load_model
14
+ from scripts.tabular_baselines import get_scoring_string
15
+ from scripts import tabular_metrics
16
+ from scripts.transformer_prediction_interface import *
17
+ from scripts.baseline_prediction_interface import *
18
+ """
19
+ ===============================
20
+ PUBLIC FUNCTIONS FOR EVALUATION
21
+ ===============================
22
+ """
23
+
24
+
25
+ def eval_model(i, e, valid_datasets, test_datasets, eval_positions, bptt, add_name, base_path, device='cpu', eval_addition='', **kwargs):
26
+ metrics_test, config_sample, model_path = eval_model_on_ds(i, e, test_datasets, eval_positions, bptt, add_name, base_path, device=device, eval_addition=eval_addition, **kwargs)
27
+ metrics_valid, _, _ = eval_model_on_ds(i, e, valid_datasets, eval_positions, bptt, add_name, base_path, device=device, eval_addition=eval_addition, **kwargs)
28
+ return {'mean_auc_test': metrics_test['mean_roc_at_1000'], 'mean_auc_valid': metrics_valid['mean_roc_at_1000'], 'mean_ce_test': metrics_test['mean_ce_at_1000'], 'mean_ce_valid': metrics_valid['mean_ce_at_1000'], 'config_sample': config_sample, 'model_path': model_path}
29
+
30
+ def eval_model_on_ds(i, e, valid_datasets, eval_positions, bptt, add_name, base_path, device='cpu', eval_addition='', **kwargs):
31
+
32
+ # Example usage: eval_model_on_ds(i, 0, valid_datasets, [1024], 100000, add_name=model_string, base_path=base_path)
33
+ def check_file(e):
34
+ model_file = f'models_diff/prior_diff_real_checkpoint{add_name}_n_{i}_epoch_{e}.cpkt'
35
+ model_path = os.path.join(base_path, model_file)
36
+ # print('Evaluate ', model_path)
37
+ results_file = os.path.join(base_path,
38
+ f'models_diff/prior_diff_real_results{add_name}_n_{i}_epoch_{e}_{eval_addition}.pkl')
39
+ if not Path(model_path).is_file(): # or Path(results_file).is_file():
40
+ # print('checkpoint exists: ', Path(model_file).is_file(), ', results are written:', Path(results_file).is_file())
41
+ return None, None, None
42
+ return model_file, model_path, results_file
43
+
44
+ if e == -1: # use last checkpoint, if e == -1
45
+ for e_ in range(100, -1, -1):
46
+ model_file_, model_path_, results_file_ = check_file(e_)
47
+ if model_file_ is not None:
48
+ e = e_
49
+ model_file, model_path, results_file = model_file_, model_path_, results_file_
50
+ break
51
+ else:
52
+ model_file, model_path, results_file = check_file(e)
53
+
54
+ model, config_sample = load_model(base_path, model_file, device, None, verbose=False)
55
+
56
+ params = {'max_features': config_sample['num_features']
57
+ , 'rescale_features': config_sample["normalize_by_used_features"]
58
+ , 'normalize_to_ranking': config_sample["normalize_to_ranking"]
59
+ , 'normalize_with_sqrt': config_sample.get("normalize_with_sqrt", False)
60
+ }
61
+ metrics_valid = evaluate(datasets=valid_datasets, model=model[2], method='transformer', device=device, overwrite=True,
62
+ extend_features=True
63
+ # just removed the style keyword but transformer is trained with style, just empty
64
+ , save=False
65
+ , metric_used=tabular_metrics.cross_entropy
66
+ , return_tensor=True
67
+ , verbose=False
68
+ , eval_positions=eval_positions
69
+ , bptt=bptt
70
+ , base_path=None
71
+ , inference_mode=True
72
+ , **params
73
+ , **kwargs)
74
+
75
+ tabular_metrics.calculate_score_per_method(tabular_metrics.auc_metric, 'roc', metrics_valid, valid_datasets, eval_positions)
76
+ tabular_metrics.calculate_score_per_method(tabular_metrics.cross_entropy, 'ce', metrics_valid, valid_datasets, eval_positions)
77
+
78
+ return metrics_valid, config_sample, model_path
79
+
80
+
81
+ def evaluate(datasets, bptt, eval_positions, metric_used, model
82
+ , verbose=False
83
+ , return_tensor=False
84
+ , **kwargs):
85
+ """
86
+ Evaluates a list of datasets for a model function.
87
+
88
+ :param datasets: List of datasets
89
+ :param bptt: maximum sequence length
90
+ :param eval_positions: List of positions where to evaluate models
91
+ :param verbose: If True, prints progress information.
92
+ :param metric_used: Which metric is optimized for.
93
+ :param return_tensor: Whether to return results as torch tensors instead of numpy/scalars; only relevant for the transformer.
94
+ :param kwargs:
95
+ :return:
96
+ """
97
+ overall_result = {'metric_used': get_scoring_string(metric_used)
98
+ , 'bptt': bptt
99
+ , 'eval_positions': eval_positions}
100
+
101
+ aggregated_metric_datasets, num_datasets = torch.tensor(0.0), 0
102
+
103
+ # For each dataset
104
+ for [ds_name, X, y, categorical_feats, _, _] in tqdm(datasets, desc='Iterate over datasets') if verbose else datasets:
105
+ dataset_bptt = min(len(X), bptt)
106
+ # if verbose and dataset_bptt < bptt:
107
+ # print(f'Dataset too small for given sequence length, reducing to {len(X)} ({bptt})')
108
+
109
+ aggregated_metric, num = torch.tensor(0.0), 0
110
+ ds_result = {}
111
+
112
+ for eval_position in eval_positions:
113
+ eval_position_real = int(dataset_bptt * 0.5) if 2 * eval_position > dataset_bptt else eval_position
114
+ eval_position_bptt = int(eval_position_real * 2.0)
115
+
116
+ r = evaluate_position(X, y, model=model
117
+ , num_classes=len(torch.unique(y))
118
+ , categorical_feats = categorical_feats
119
+ , bptt = eval_position_bptt
120
+ , ds_name=ds_name
121
+ , eval_position = eval_position_real
122
+ , metric_used = metric_used
123
+ ,**kwargs)
124
+
125
+ if r is None:
126
+ continue
127
+
128
+ _, outputs, ys, best_configs, time_used = r
129
+
130
+ if torch.is_tensor(outputs):
131
+ outputs = outputs.to(outputs.device)
132
+ ys = ys.to(outputs.device)
133
+
134
+ ys = ys.T
135
+ ds_result[f'{ds_name}_best_configs_at_{eval_position}'] = best_configs
136
+ ds_result[f'{ds_name}_outputs_at_{eval_position}'] = outputs
137
+ ds_result[f'{ds_name}_ys_at_{eval_position}'] = ys
138
+ ds_result[f'{ds_name}_time_at_{eval_position}'] = time_used
139
+
140
+ new_metric = torch_nanmean(torch.stack([metric_used(ys[i], outputs[i]) for i in range(ys.shape[0])]))
141
+
142
+ if not return_tensor:
143
+ make_scalar = lambda x: float(x.detach().cpu().numpy()) if (torch.is_tensor(x) and (len(x.shape) == 0)) else x
144
+ new_metric = make_scalar(new_metric)
145
+ ds_result = {k: make_scalar(ds_result[k]) for k in ds_result.keys()}
146
+
147
+ lib = torch if return_tensor else np
148
+ if not lib.isnan(new_metric).any():
149
+ aggregated_metric, num = aggregated_metric + new_metric, num + 1
150
+
151
+ overall_result.update(ds_result)
152
+ if num > 0:
153
+ aggregated_metric_datasets, num_datasets = (aggregated_metric_datasets + (aggregated_metric / num)), num_datasets + 1
154
+
155
+ overall_result['mean_metric'] = aggregated_metric_datasets / num_datasets
156
+
157
+ return overall_result
158
+
159
+ """
160
+ ===============================
161
+ INTERNAL HELPER FUNCTIONS
162
+ ===============================
163
+ """
164
+
165
+ def check_file_exists(path):
166
+ """Checks if a pickle file exists. Returns None if not, else returns the unpickled file."""
167
+ if (os.path.isfile(path)):
168
+ print(f'loading results from {path}')
169
+ with open(path, 'rb') as f:
170
+ return np.load(f, allow_pickle=True).tolist()
171
+ return None
172
+
173
+ def generate_valid_split(X, y, bptt, eval_position, split_number=1):
174
+ """Generates a deteministic train-(test/valid) split. Both splits must contain the same classes and all classes in
175
+ the entire datasets. If no such split can be sampled in 7 passes, returns None.
176
+
177
+ :param X: torch tensor, feature values
178
+ :param y: torch tensor, class values
179
+ :param bptt: Number of samples in train + test
180
+ :param eval_position: Number of samples in train, i.e. from which index values are in test
181
+ :param split_number: The split id
182
+ :return:
183
+ """
184
+ done, seed = False, 13
185
+
186
+ torch.manual_seed(split_number)
187
+ perm = torch.randperm(X.shape[0]) if split_number > 1 else torch.arange(0, X.shape[0])
188
+ X, y = X[perm], y[perm]
189
+
190
+ while not done:
191
+ if seed > 20:
192
+ return None, None # No valid split could be generated within the allowed number of tries
193
+ random.seed(seed)
194
+ i = random.randint(0, len(X) - bptt) if len(X) - bptt > 0 else 0
195
+ y_ = y[i:i + bptt]
196
+
197
+ # Checks if all classes from dataset are contained and classes in train and test are equal (contain same
198
+ # classes).
199
+ done = len(torch.unique(y_)) == len(torch.unique(y))
200
+ done = done and torch.all(torch.unique(y_) == torch.unique(y))
201
+ done = done and len(torch.unique(y_[:eval_position])) == len(torch.unique(y_[eval_position:]))
202
+ done = done and torch.all(torch.unique(y_[:eval_position]) == torch.unique(y_[eval_position:]))
203
+ seed = seed + 1
204
+
205
+ eval_xs = torch.stack([X[i:i + bptt].clone()], 1)
206
+ eval_ys = torch.stack([y[i:i + bptt].clone()], 1)
207
+
208
+ return eval_xs, eval_ys
209
+
210
+
211
+ def evaluate_position(X, y, categorical_feats, model, bptt
212
+ , eval_position, overwrite, save, base_path, path_interfix, method, ds_name, fetch_only=False
213
+ , max_time=300, split_number=1
214
+ , per_step_normalization=False, **kwargs):
215
+ """
216
+ Evaluates a dataset with a 'bptt' number of training samples.
217
+
218
+ :param X: Dataset X
219
+ :param y: Dataset labels
220
+ :param categorical_feats: Indices of categorical features.
221
+ :param model: Model function
222
+ :param bptt: Sequence length.
223
+ :param eval_position: Number of training samples.
224
+ :param overwrite: If True, results on disk are overwritten.
226
+ :param save: If True, results are written to disk.
227
+ :param path_interfix: Used for constructing path to write on disk.
228
+ :param method: Model name.
229
+ :param ds_name: Dataset name.
230
+ :param fetch_only: If True, only fetch cached results from disk; do not compute anything new.
231
+ :param per_step_normalization:
232
+ :param kwargs:
233
+ :return:
234
+ """
235
+
236
+ if save:
237
+ path = os.path.join(base_path, f'results/tabular/{path_interfix}/results_{method}_{ds_name}_{eval_position}_{bptt}_{split_number}.npy')
238
+ #log_path =
239
+
240
+ ## Load results if on disk
241
+ if not overwrite:
242
+ result = check_file_exists(path)
243
+ if result is not None:
244
+ if not fetch_only:
245
+ print(f'Loaded saved result for {path}')
246
+ return result
247
+ elif fetch_only:
248
+ print(f'Could not load saved result for {path}')
249
+ return None
250
+
251
+ ## Generate data splits
252
+ eval_xs, eval_ys = generate_valid_split(X, y, bptt, eval_position, split_number=split_number)
253
+ if eval_xs is None:
254
+ print(f"No dataset could be generated {ds_name} {bptt}")
255
+ return None
256
+
257
+ eval_ys = (eval_ys > torch.unique(eval_ys).unsqueeze(0)).sum(axis=1).unsqueeze(-1)
258
+
259
+ start_time = time.time()
260
+
261
+ if isinstance(model, nn.Module): # Two separate predict interfaces for transformer and baselines
262
+ outputs, best_configs = transformer_predict(model, eval_xs, eval_ys, eval_position, categorical_feats=categorical_feats, **kwargs), None
263
+ else:
264
+ _, outputs, best_configs = baseline_predict(model, eval_xs, eval_ys, categorical_feats
265
+ , eval_pos=eval_position
266
+ , max_time=max_time, **kwargs)
267
+
268
+ eval_ys = eval_ys[eval_position:]
269
+ if outputs is None:
270
+ return None
271
+
272
+ if torch.is_tensor(outputs): # Transfers data to cpu for saving
273
+ outputs = outputs.cpu()
274
+ eval_ys = eval_ys.cpu()
275
+
276
+ ds_result = None, outputs, eval_ys, best_configs, time.time() - start_time
277
+
278
+ if save:
279
+ with open(path, 'wb') as f:
280
+ np.save(f, ds_result)
281
+ print(f'saved results to {path}')
282
+
283
+ return ds_result
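For orientation, a rough sketch (not part of the commit) showing how evaluate() above could be driven for the transformer path. The dataset list, model object, and most keyword values are assumptions; extra keywords are forwarded through evaluate_position() to transformer_predict().

from scripts import tabular_metrics

# my_datasets: list of [ds_name, X, y, categorical_feats, _, _] with torch tensors X, y (assumed to exist)
# model: the transformer nn.Module loaded elsewhere, e.g. load_model(...)[2] (assumed to exist)
results = evaluate(
    datasets=my_datasets,
    bptt=1024,                          # maximum sequence length (train + test samples)
    eval_positions=[512],               # number of training samples per evaluation
    metric_used=tabular_metrics.cross_entropy,
    model=model,                        # an nn.Module routes to transformer_predict()
    method='transformer', device='cpu',
    overwrite=True, save=False, base_path=None, path_interfix='',
    max_features=100, extend_features=True, inference_mode=True,
)
print(results['mean_metric'])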
TabPFN/train.py ADDED
@@ -0,0 +1,386 @@
1
+ import os
2
+ import itertools
3
+ import argparse
4
+ import time
5
+ import datetime
6
+ import yaml
7
+ from contextlib import nullcontext
8
+
9
+
10
+ import torch
11
+ from torch import nn
12
+
13
+ import utils
14
+ from transformer import TransformerModel
15
+ from utils import get_cosine_schedule_with_warmup, get_openai_lr, StoreDictKeyPair, get_weighted_single_eval_pos_sampler, get_uniform_single_eval_pos_sampler
16
+ import priors
17
+ import encoders
18
+ import positional_encodings
19
+ from utils import init_dist
20
+ from torch.cuda.amp import autocast
21
+
22
+ class Losses():
23
+ gaussian = nn.GaussianNLLLoss(full=True, reduction='none')
24
+ mse = nn.MSELoss(reduction='none')
25
+ ce = lambda weight : nn.CrossEntropyLoss(reduction='none', weight=weight)
26
+ bce = nn.BCEWithLogitsLoss(reduction='none')
27
+
28
+
29
+ def train(priordataloader_class, criterion, encoder_generator, emsize=200, nhid=200, nlayers=6, nhead=2, dropout=0.2,
30
+ epochs=10, steps_per_epoch=100, batch_size=200, bptt=10, lr=None, weight_decay=0.0, warmup_epochs=10, input_normalization=False,
31
+ y_encoder_generator=None, pos_encoder_generator=None, decoder=None, extra_prior_kwargs_dict={}, scheduler=get_cosine_schedule_with_warmup,
32
+ load_weights_from_this_state_dict=None, validation_period=10, single_eval_pos_gen=None, bptt_extra_samples=None, gpu_device='cuda:0',
33
+ aggregate_k_gradients=1, verbose=True, style_encoder_generator=None, check_is_compatible=True, epoch_callback=None,
34
+ initializer=None, initialize_with_model=None, train_mixed_precision=False, total_available_time_in_s=None, normalize_labels=True, **model_extra_args
35
+ ):
36
+ assert (epochs is None) != (total_available_time_in_s is None)
37
+ start_of_training = time.time()
38
+ device = gpu_device if torch.cuda.is_available() else 'cpu:0'
39
+ print(f'Using {device} device')
40
+ using_dist, rank, device = init_dist(device)
41
+ bptt_sampler = (lambda : single_eval_pos_gen() + bptt_extra_samples if callable(single_eval_pos_gen) else single_eval_pos_gen + bptt_extra_samples) if bptt_extra_samples is not None else bptt
42
+ dl = priordataloader_class(num_steps=steps_per_epoch, batch_size=batch_size, seq_len=bptt_sampler, seq_len_maximum=bptt+(bptt_extra_samples if bptt_extra_samples else 0), device=device, **extra_prior_kwargs_dict)
43
+ if dl.fuse_x_y:
44
+ raise Exception("Illegal parameter")
45
+
46
+ encoder = encoder_generator(dl.num_features+1 if dl.fuse_x_y else dl.num_features,emsize)
47
+ style_def = next(iter(dl))[0][0] # a batch is ((style, x, y), targets); take the style element to infer its definition
48
+
49
+ style_encoder = style_encoder_generator(hyperparameter_definitions=style_def[0], em_size=emsize) if (style_def is not None) else None
50
+ n_out = dl.num_outputs
51
+ if isinstance(criterion, nn.GaussianNLLLoss):
52
+ n_out *= 2
53
+ elif isinstance(criterion, nn.CrossEntropyLoss):
54
+ n_out *= criterion.weight.shape[0]
55
+ model = TransformerModel(encoder, n_out, emsize, nhead, nhid, nlayers, dropout, style_encoder=style_encoder,
56
+ y_encoder=y_encoder_generator(dl.num_outputs, emsize), input_normalization=input_normalization,
57
+ pos_encoder=(pos_encoder_generator or positional_encodings.NoPositionalEncoding)(emsize, bptt*2),
58
+ decoder=decoder, init_method=initializer, **model_extra_args
59
+ )
60
+ model.criterion = criterion
61
+ if load_weights_from_this_state_dict is not None:
62
+ model.load_state_dict(load_weights_from_this_state_dict)
63
+ if initialize_with_model is not None:
64
+ model.init_from_small_model(initialize_with_model)
65
+
66
+ print(f"Using a Transformer with {sum(p.numel() for p in model.parameters())/1000/1000:.{2}f} M parameters")
67
+
68
+ try:
69
+ for (k, v), (k2, v2) in zip(model.state_dict().items(), initialize_with_model.state_dict().items()):
70
+ print(k, ((v - v2) / v).abs().mean(), v.shape)
71
+ except Exception:
72
+ pass
73
+
74
+ model.to(device)
75
+ if using_dist:
76
+ print("Distributed training")
77
+ model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[rank], output_device=rank, broadcast_buffers=False)
78
+
79
+
80
+ # learning rate
81
+ if lr is None:
82
+ lr = get_openai_lr(model)
83
+ print(f"Using OpenAI max lr of {lr}.")
84
+ optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
85
+ scheduler = scheduler(optimizer, warmup_epochs, epochs if epochs is not None else 100) # when training for fixed time lr schedule takes 100 steps
86
+
87
+ def train_step():
88
+ model.train() # Turn on the train mode
89
+ total_loss = 0.
90
+ total_positional_losses = 0.
91
+ total_positional_losses_recorded = 0
92
+ before_get_batch = time.time()
93
+ assert len(dl) % aggregate_k_gradients == 0, 'Please set the number of steps per epoch s.t. `aggregate_k_gradients` divides it.'
94
+ valid_batch_steps = 0.0
95
+ for batch, (data, targets) in enumerate(dl):
96
+ if using_dist and not (batch % aggregate_k_gradients == aggregate_k_gradients - 1):
97
+ cm = model.no_sync()
98
+ #print(f'p={rank}, no_sync', force=True)
99
+ else:
100
+ cm = nullcontext()
101
+ #print(f'p={rank}, sync', force=True)
102
+ with cm:
103
+ time_to_get_batch = time.time() - before_get_batch
104
+ before_forward = time.time()
105
+ if bptt_extra_samples is None:
106
+ single_eval_pos = single_eval_pos_gen() if callable(single_eval_pos_gen) else single_eval_pos_gen
107
+ else:
108
+ single_eval_pos = targets.shape[0] - bptt_extra_samples
109
+
110
+ is_compatible = torch.ones((targets.shape[1])).bool()
111
+ if check_is_compatible or normalize_labels:
112
+ for b in range(targets.shape[1]):
113
+ targets_in_train = torch.unique(targets[:single_eval_pos, b], sorted=True)
114
+ targets_in_eval = torch.unique(targets[single_eval_pos:, b], sorted=True)
115
+
116
+ if check_is_compatible:
117
+ is_compatible[b] = len(targets_in_train) == len(targets_in_eval) and (targets_in_train == targets_in_eval).all()
118
+ is_compatible[b] = is_compatible[b] and len(targets_in_train) > 1
119
+
120
+ # Set targets to range starting from 0 (e.g. targets 0, 2, 5, 2 will be converted to 0, 1, 2, 1)
121
+ if normalize_labels:
122
+ targets[:, b] = (targets[:, b] > torch.unique(targets[:, b]).unsqueeze(1)).sum(axis=0).unsqueeze(0)
123
+ valid_batch_steps += is_compatible.float().mean()
124
+ is_compatible = is_compatible.to(device)
125
+ #if using_dist and check_is_compatible:
126
+ # print('step share before reduce',curr_step_share, force=True)
127
+ # curr_step_share = curr_step_share.to(device)
128
+ # torch.distributed.all_reduce_multigpu([curr_step_share], op=torch.distributed.ReduceOp.SUM)
129
+ # curr_step_share = curr_step_share.cpu() / torch.distributed.get_world_size()
130
+ # print('step share after reduce',curr_step_share, torch.distributed.get_world_size(), force=True)
131
+
132
+ # If style is set to None, it should not be transferred to device
133
+ output = model(tuple(e.to(device) if torch.is_tensor(e) else e for e in data) if isinstance(data, tuple) else data.to(device)
134
+ , single_eval_pos=single_eval_pos)
135
+
136
+ forward_time = time.time() - before_forward
137
+
138
+ #output, targets = output[:, is_compatible], targets[:, is_compatible]
139
+
140
+ if single_eval_pos is not None:
141
+ targets = targets[single_eval_pos:]
142
+ if isinstance(criterion, nn.GaussianNLLLoss):
143
+ assert output.shape[-1] == 2, \
144
+ 'need to write a little bit of code to handle multiple regression targets at once'
145
+
146
+ mean_pred = output[..., 0]
147
+ var_pred = output[..., 1].abs()
148
+ losses = criterion(mean_pred.flatten(), targets.to(device).flatten(), var=var_pred.flatten())
149
+ elif isinstance(criterion, (nn.MSELoss, nn.BCEWithLogitsLoss)):
150
+ losses = criterion(output.flatten(), targets.to(device).flatten())
151
+ elif isinstance(criterion, (nn.CrossEntropyLoss)):
152
+ #print(n_out, targets.min(), targets.max(), force=True)
153
+ losses = criterion(output.reshape(-1, n_out), targets.to(device).long().flatten())
154
+ else:
155
+ losses = criterion(output.reshape(-1, n_out), targets.to(device).flatten())
156
+ losses = losses.view(*output.shape[0:2])
157
+ loss = losses.mean(0) @ is_compatible.float() / losses.shape[1]
158
+ #loss = torch_nanmean(losses, axis=[0, 1]) * is_compatible.float().mean()
159
+ # not sure whether we can go without the nan checks.
160
+
161
+ loss.backward()
162
+
163
+ if ((batch % aggregate_k_gradients == aggregate_k_gradients - 1) and (not check_is_compatible or using_dist))\
164
+ or (valid_batch_steps >= aggregate_k_gradients and (check_is_compatible and not using_dist)):
165
+ with torch.no_grad():
166
+ for p in model.parameters():
167
+ if p.grad is not None:
168
+ p.grad.div_(valid_batch_steps)
169
+ torch.nn.utils.clip_grad_norm_(model.parameters(), 1.)
170
+ try:
171
+ optimizer.step()
172
+ except:
173
+ print("Invalid optimization step encountered")
174
+ optimizer.zero_grad()
175
+ valid_batch_steps = 0.0
176
+
177
+ step_time = time.time() - before_forward
178
+
179
+ if not torch.isnan(loss):
180
+ total_loss += loss.item()
181
+ total_positional_losses += losses.mean(1).cpu().detach() if single_eval_pos is None else \
182
+ nn.functional.one_hot(torch.tensor(single_eval_pos), bptt)*loss.cpu().detach()
183
+
184
+ total_positional_losses_recorded += torch.ones(bptt) if single_eval_pos is None else \
185
+ nn.functional.one_hot(torch.tensor(single_eval_pos), bptt)
186
+
187
+ before_get_batch = time.time()
188
+ return total_loss / steps_per_epoch, (
189
+ total_positional_losses / total_positional_losses_recorded).tolist(), time_to_get_batch, forward_time, step_time
190
+
191
+ best_val_loss = float("inf")
192
+ best_model = None
193
+ total_loss = float('inf')
194
+ total_positional_losses = float('inf')
195
+ try:
196
+ for epoch in (range(1, epochs + 1) if epochs is not None else itertools.count(1)):
197
+
198
+ epoch_start_time = time.time()
199
+ if train_mixed_precision:
200
+ with autocast():
201
+ total_loss, total_positional_losses, time_to_get_batch, forward_time, step_time = train_step()
202
+ else:
203
+ total_loss, total_positional_losses, time_to_get_batch, forward_time, step_time = train_step()
204
+ if hasattr(dl, 'validate') and epoch % validation_period == 0:
205
+ with torch.no_grad():
206
+ val_score = dl.validate(model)
207
+ else:
208
+ val_score = None
209
+
210
+ if verbose:
211
+ print('-' * 89)
212
+ print(
213
+ f'| end of epoch {epoch:3d} | time: {(time.time() - epoch_start_time):5.2f}s | mean loss {total_loss:5.2f} | '
214
+ f"pos losses {','.join([f'{l:5.2f}' for l in total_positional_losses])}, lr {scheduler.get_last_lr()[0]}"
215
+ f' data time {time_to_get_batch:5.2f} step time {step_time:5.2f}'
216
+ f' forward time {forward_time:5.2f}' + (f'val score {val_score}' if val_score is not None else ''))
217
+ print('-' * 89)
218
+
219
+ # stepping with wallclock time based scheduler
220
+ current_time = time.time()
221
+ if epoch_callback is not None and rank == 0:
222
+ epoch_callback(model, epoch / epochs if total_available_time_in_s is None else # noqa
223
+ (current_time - start_of_training) / total_available_time_in_s # noqa
224
+ )
225
+ if epochs is None and (current_time - start_of_training) > total_available_time_in_s: # noqa
226
+ break
227
+ if epochs is None:
228
+ scheduler.step((current_time - epoch_start_time) / total_available_time_in_s * 100)
229
+ else:
230
+ scheduler.step()
231
+ except KeyboardInterrupt:
232
+ pass
233
+
234
+ return total_loss, total_positional_losses, model.to('cpu'), dl
235
+
236
+ def _parse_args(config_parser, parser):
237
+ # Do we have a config file to parse?
238
+ args_config, remaining = config_parser.parse_known_args()
239
+ if args_config.config:
240
+ with open(args_config.config, 'r') as f:
241
+ cfg = yaml.safe_load(f)
242
+ parser.set_defaults(**cfg)
243
+
244
+ # The main arg parser parses the rest of the args, the usual
245
+ # defaults will have been overridden if config file specified.
246
+ args = parser.parse_args(remaining)
247
+
248
+ # Cache the args as a text string to save them in the output dir later
249
+ args_text = yaml.safe_dump(args.__dict__, default_flow_style=False)
250
+ return args, args_text
251
+
252
+
253
+ if __name__ == '__main__':
254
+ config_parser = argparse.ArgumentParser(description='Only used as a first parser for the config file path.')
255
+ config_parser.add_argument('--config')
256
+ parser = argparse.ArgumentParser()
257
+ parser.add_argument('prior')
258
+ parser.add_argument('--loss_function', default='barnll')
259
+ # Optional Arg's for `--loss_function barnll`
260
+ parser.add_argument('--min_y', type=float, help='barnll can only model y in strict ranges, this is the minimum y can take.')
261
+ parser.add_argument('--max_y', type=float, help='barnll can only model y in strict ranges, this is the maximum y can take.')
262
+ parser.add_argument('--num_buckets', default=100, type=int)
263
+ #parser.add_argument('--num_features', default=None, type=int, help='Specify depending on the prior.')
264
+ parser.add_argument("--extra_prior_kwargs_dict", default={'fuse_x_y': False}, dest="extra_prior_kwargs_dict", action=StoreDictKeyPair, nargs="+", metavar="KEY=VAL", help='Specify depending on the prior.')
265
+ parser.add_argument('--encoder', default='linear', type=str, help='Specify depending on the prior.')
266
+ parser.add_argument('--y_encoder', default='linear', type=str, help='Specify depending on the prior. You should specify this if you do not fuse x and y.')
267
+ parser.add_argument('--pos_encoder', default='sinus', type=str, help='Specify depending on the prior.')
268
+ parser.add_argument('--bptt', default=10, type=int)
269
+ parser.add_argument('--epochs', default=200, type=int)
270
+ parser.add_argument('--warmup_epochs', default=50, type=int)
271
+ parser.add_argument('--validation_period', default=10, type=int)
272
+ parser.add_argument('--permutation_invariant_max_eval_pos', default=None, type=int, help='Set this to an int to train a permutation-invariant model with single evaluation positions sampled up to this value.')
273
+ parser.add_argument('--permutation_invariant_sampling', default='weighted', help="Only relevant if --permutation_invariant_max_eval_pos is set.")
274
+
275
+ # these can likely be mostly left at defaults
276
+ parser.add_argument('--emsize', default=512, type=int) # sometimes even larger is better e.g. 1024
277
+ parser.add_argument('--nlayers', default=6, type=int)
278
+ parser.add_argument('--nhid', default=None, type=int) # 2*emsize is the default
279
+ parser.add_argument('--nhead', default=4, type=int) # nhead = emsize / 64 in the original paper
280
+ parser.add_argument('--dropout', default=.0, type=float)
281
+ parser.add_argument('--steps_per_epoch', default=10, type=int)
282
+ parser.add_argument('--batch_size', default=1000, type=int)
283
+ parser.add_argument('--lr', '--learning_rate', default=.001, type=float) # try also .0003, .0001, go lower with lower batch size
284
+
285
+ args, _ = _parse_args(config_parser, parser)
286
+
287
+ if args.nhid is None:
288
+ args.nhid = 2*args.emsize
289
+
290
+ prior = args.__dict__.pop('prior')
291
+
292
+ if prior == 'gp':
293
+ prior = priors.fast_gp.DataLoader
294
+ elif prior == 'ridge':
295
+ prior = priors.ridge.DataLoader
296
+ elif prior == 'stroke':
297
+ prior = priors.stroke.DataLoader
298
+ elif prior == 'mix_gp':
299
+ prior = priors.fast_gp_mix.DataLoader
300
+ else:
301
+ raise NotImplementedError(f'Prior == {prior}.')
302
+
303
+ loss_function = args.__dict__.pop('loss_function')
304
+
305
+ criterion = nn.GaussianNLLLoss(reduction='none', full=True)
306
+ classification_criterion = nn.CrossEntropyLoss(reduction='none')
307
+ num_buckets = args.__dict__.pop('num_buckets')
308
+ max_y = args.__dict__.pop('max_y')
309
+ min_y = args.__dict__.pop('min_y')
310
+ # criterion = nn.MSELoss(reduction='none')
311
+
312
+ def get_y_sample():
313
+ dl = prior(num_steps=1, batch_size=args.batch_size * args.steps_per_epoch, seq_len=args.bptt, device=device,
314
+ **args.extra_prior_kwargs_dict)
315
+ y_sample = next(iter(dl))[-1]
316
+ print(f'Creating Bar distribution with borders from y sample of size {y_sample.numel()}')
317
+ return y_sample
318
+
319
+ if loss_function == 'ce':
320
+ criterion = nn.CrossEntropyLoss(reduction='none')
321
+ elif loss_function == 'gaussnll':
322
+ criterion = nn.GaussianNLLLoss(reduction='none', full=True)
323
+ elif loss_function == 'mse':
324
+ criterion = nn.MSELoss(reduction='none')
325
+ elif loss_function == 'barnll':
326
+ criterion = BarDistribution(borders=get_bucket_limits(num_buckets, full_range=(min_y,max_y)))
327
+ elif loss_function == 'adaptivebarnll':
328
+ borders = get_bucket_limits(num_buckets, ys=get_y_sample(), full_range=(min_y,max_y))
329
+ criterion = BarDistribution(borders=borders)
330
+ elif loss_function == 'adaptivefullsupportbarnll':
331
+ assert min_y is None and max_y is None, "Please do not specify `min_y` and `max_y` with `unboundedadaptivebarnll`."
332
+ borders = get_bucket_limits(num_buckets, ys=get_y_sample())
333
+ criterion = FullSupportBarDistribution(borders=borders)
334
+ else:
335
+ raise NotImplementedError(f'loss_function == {loss_function}.')
336
+
337
+
338
+
339
+ encoder = args.__dict__.pop('encoder')
340
+ y_encoder = args.__dict__.pop('y_encoder')
341
+
342
+ def get_encoder_generator(encoder):
343
+ if encoder == 'linear':
344
+ encoder_generator = encoders.Linear
345
+ elif encoder == 'mlp':
346
+ encoder_generator = encoders.MLP
347
+ elif encoder == 'positional':
348
+ encoder_generator = encoders.Positional
349
+ else:
350
+ raise NotImplementedError(f'A {encoder} encoder is not valid.')
351
+ return encoder_generator
352
+
353
+ encoder_generator = get_encoder_generator(encoder)
354
+ y_encoder_generator = get_encoder_generator(y_encoder)
355
+
356
+ pos_encoder = args.__dict__.pop('pos_encoder')
357
+
358
+ if pos_encoder == 'none':
359
+ pos_encoder_generator = None
360
+ elif pos_encoder == 'sinus':
361
+ pos_encoder_generator = positional_encodings.PositionalEncoding
362
+ elif pos_encoder == 'learned':
363
+ pos_encoder_generator = positional_encodings.LearnedPositionalEncoding
364
+ elif pos_encoder == 'paired_scrambled_learned':
365
+ pos_encoder_generator = positional_encodings.PairedScrambledPositionalEncodings
366
+ else:
367
+ raise NotImplementedError(f'pos_encoder == {pos_encoder} is not valid.')
368
+
369
+ permutation_invariant_max_eval_pos = args.__dict__.pop('permutation_invariant_max_eval_pos')
370
+ permutation_invariant_sampling = args.__dict__.pop('permutation_invariant_sampling')
371
+ if permutation_invariant_max_eval_pos is not None:
372
+ if permutation_invariant_sampling == 'weighted':
373
+ get_sampler = get_weighted_single_eval_pos_sampler
374
+ elif permutation_invariant_sampling == 'uniform':
375
+ get_sampler = get_uniform_single_eval_pos_sampler
376
+ else:
377
+ raise ValueError()
378
+ args.__dict__['single_eval_pos_gen'] = get_sampler(permutation_invariant_max_eval_pos)
379
+
380
+
381
+ print("ARGS for `train`:", args.__dict__)
382
+
383
+ train(prior, criterion, encoder_generator,
384
+ y_encoder_generator=y_encoder_generator, pos_encoder_generator=pos_encoder_generator,
385
+ **args.__dict__)
386
+
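A small illustrative sketch (not part of the commit) of how train() above sizes the transformer's output head n_out from the chosen criterion; the ten-class setup is made up.

import torch
from torch import nn
from train import Losses  # assumes this file is importable as train

num_outputs = 1                          # plays the role of dl.num_outputs inside train()
weight = torch.ones(10)                  # uniform class weights for a hypothetical 10-class problem
criterion = Losses.ce(weight)            # nn.CrossEntropyLoss(reduction='none', weight=weight)

n_out = num_outputs
if isinstance(criterion, nn.GaussianNLLLoss):
    n_out *= 2                           # one mean and one variance per regression target
elif isinstance(criterion, nn.CrossEntropyLoss):
    n_out *= criterion.weight.shape[0]   # one logit per class
assert n_out == 10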
TabPFN/transformer.py ADDED
@@ -0,0 +1,226 @@
1
+ import math
2
+ from typing import Optional
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ from torch import Tensor
7
+ from torch.nn import Module, TransformerEncoder
8
+
9
+ from layer import TransformerEncoderLayer, _get_activation_fn
10
+ from utils import SeqBN, bool_mask_to_att_mask
11
+
12
+
13
+
14
+ class TransformerModel(nn.Module):
15
+ def __init__(self, encoder, n_out, ninp, nhead, nhid, nlayers, dropout=0.0, style_encoder=None, y_encoder=None,
16
+ pos_encoder=None, decoder=None, input_normalization=False, init_method=None, pre_norm=False,
17
+ activation='gelu', recompute_attn=False, num_global_att_tokens=0, full_attention=False,
18
+ all_layers_same_init=True):
19
+ super().__init__()
20
+ self.model_type = 'Transformer'
21
+ encoder_layer_creator = lambda: TransformerEncoderLayer(ninp, nhead, nhid, dropout, activation=activation,
22
+ pre_norm=pre_norm, recompute_attn=recompute_attn)
23
+ self.transformer_encoder = TransformerEncoder(encoder_layer_creator(), nlayers)\
24
+ if all_layers_same_init else TransformerEncoderDiffInit(encoder_layer_creator, nlayers)
25
+ self.ninp = ninp
26
+ self.encoder = encoder
27
+ self.y_encoder = y_encoder
28
+ self.pos_encoder = pos_encoder
29
+ self.decoder = decoder(ninp, nhid, n_out) if decoder is not None else nn.Sequential(nn.Linear(ninp, nhid), nn.GELU(), nn.Linear(nhid, n_out))
30
+ self.input_ln = SeqBN(ninp) if input_normalization else None
31
+ self.style_encoder = style_encoder
32
+ self.init_method = init_method
33
+ if num_global_att_tokens is not None:
34
+ assert not full_attention
35
+ self.global_att_embeddings = nn.Embedding(num_global_att_tokens, ninp) if num_global_att_tokens else None
36
+ self.full_attention = full_attention
37
+
38
+ self.n_out = n_out
39
+ self.nhid = nhid
40
+
41
+ self.init_weights()
42
+
43
+ @staticmethod
44
+ def generate_square_subsequent_mask(sz):
45
+ mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
46
+ return bool_mask_to_att_mask(mask)
47
+
48
+ @staticmethod
49
+ def generate_D_q_matrix(sz, query_size):
50
+ train_size = sz-query_size
51
+ mask = torch.zeros(sz,sz) == 0
52
+ mask[:,train_size:].zero_()
53
+ mask |= torch.eye(sz) == 1
54
+ return bool_mask_to_att_mask(mask)
55
+
56
+ @staticmethod
57
+ def generate_global_att_query_matrix(num_global_att_tokens, seq_len, num_query_tokens):
58
+ train_size = seq_len + num_global_att_tokens - num_query_tokens
59
+ sz = seq_len + num_global_att_tokens
60
+ mask = torch.zeros(num_query_tokens, sz) == 0
61
+ mask[:,train_size:].zero_()
62
+ mask[:,train_size:] |= torch.eye(num_query_tokens) == 1
63
+ return bool_mask_to_att_mask(mask)
64
+
65
+ @staticmethod
66
+ def generate_global_att_trainset_matrix(num_global_att_tokens, seq_len, num_query_tokens):
67
+ train_size = seq_len + num_global_att_tokens - num_query_tokens
68
+ trainset_size = seq_len - num_query_tokens
69
+ mask = torch.zeros(trainset_size, num_global_att_tokens) == 0
70
+ #mask[:,num_global_att_tokens:].zero_()
71
+ #mask[:,num_global_att_tokens:] |= torch.eye(trainset_size) == 1
72
+ return bool_mask_to_att_mask(mask)
73
+
74
+ @staticmethod
75
+ def generate_global_att_globaltokens_matrix(num_global_att_tokens, seq_len, num_query_tokens):
76
+ mask = torch.zeros(num_global_att_tokens, num_global_att_tokens+seq_len-num_query_tokens) == 0
77
+ return bool_mask_to_att_mask(mask)
78
+
79
+ def init_weights(self):
80
+ initrange = 1.
81
+ # if isinstance(self.encoder,EmbeddingEncoder):
82
+ # self.encoder.weight.data.uniform_(-initrange, initrange)
83
+ # self.decoder.bias.data.zero_()
84
+ # self.decoder.weight.data.uniform_(-initrange, initrange)
85
+ if self.init_method is not None:
86
+ self.apply(self.init_method)
87
+ for layer in self.transformer_encoder.layers:
88
+ nn.init.zeros_(layer.linear2.weight)
89
+ nn.init.zeros_(layer.linear2.bias)
90
+ attns = layer.self_attn if isinstance(layer.self_attn, nn.ModuleList) else [layer.self_attn]
91
+ for attn in attns:
92
+ nn.init.zeros_(attn.out_proj.weight)
93
+ nn.init.zeros_(attn.out_proj.bias)
94
+
95
+ def forward(self, src, src_mask=None, single_eval_pos=None):
96
+ assert isinstance(src, tuple), 'fuse_x_y is forbidden, that is inputs have to be given as (x,y) or (style,x,y)'
97
+
98
+ if len(src) == 2:
99
+ src = (None,) + src
100
+
101
+ style_src, style_src_size = (src[0], (0 if (src[0] is None) else 1))
102
+ if src_mask is not None: assert self.global_att_embeddings is None or isinstance(src_mask, tuple)
103
+ if src_mask is None:
104
+ x_src = src[1]
105
+ if self.global_att_embeddings is None:
106
+ full_len = len(x_src) + style_src_size
107
+ if self.full_attention:
108
+ src_mask = bool_mask_to_att_mask(torch.ones((full_len, full_len), dtype=torch.bool)).to(x_src.device)
109
+ else:
110
+ src_mask = self.generate_D_q_matrix(len(x_src) + style_src_size, len(x_src) + style_src_size -single_eval_pos).to(x_src.device)
111
+ else:
112
+ src_mask_args = (self.global_att_embeddings.num_embeddings,
113
+ len(x_src) + style_src_size,
114
+ len(x_src) + style_src_size - single_eval_pos)
115
+ src_mask = (self.generate_global_att_globaltokens_matrix(*src_mask_args).to(x_src.device),
116
+ self.generate_global_att_trainset_matrix(*src_mask_args).to(x_src.device),
117
+ self.generate_global_att_query_matrix(*src_mask_args).to(x_src.device))
118
+
119
+ style_src, x_src, y_src = src
120
+ x_src = self.encoder(x_src)
121
+ y_src = self.y_encoder(y_src.unsqueeze(-1) if len(y_src.shape) < len(x_src.shape) else y_src)
122
+ style_src = self.style_encoder(style_src).unsqueeze(0) if self.style_encoder else torch.tensor([], device=x_src.device)
123
+ global_src = torch.tensor([], device=x_src.device) if self.global_att_embeddings is None else \
124
+ self.global_att_embeddings.weight.unsqueeze(1).repeat(1, x_src.shape[1], 1)
125
+ train_x = x_src[:single_eval_pos] + y_src[:single_eval_pos]
126
+ src = torch.cat([global_src, style_src, train_x, x_src[single_eval_pos:]], 0)
127
+
128
+ if self.input_ln is not None:
129
+ src = self.input_ln(src)
130
+
131
+ if self.pos_encoder is not None:
132
+ src = self.pos_encoder(src)
133
+
134
+ # If we have style input, drop its output
135
+ output = self.transformer_encoder(src, src_mask)[style_src_size:]
136
+ output = self.decoder(output)
137
+ return output[single_eval_pos+(self.global_att_embeddings.num_embeddings if self.global_att_embeddings else 0):]
138
+
139
+ @torch.no_grad()
140
+ def init_from_small_model(self, small_model):
141
+ assert isinstance(self.decoder, nn.Linear) and isinstance(self.encoder, (nn.Linear, nn.Sequential)) \
142
+ and isinstance(self.y_encoder, (nn.Linear, nn.Sequential))
143
+
144
+ def set_encoder_weights(my_encoder, small_model_encoder):
145
+ my_encoder_linear, small_encoder_linear = (my_encoder, small_model_encoder) \
146
+ if isinstance(my_encoder, nn.Linear) else (my_encoder[-1], small_model_encoder[-1])
147
+ small_in_dim = small_encoder_linear.out_features
148
+ my_encoder_linear.weight.zero_()
149
+ my_encoder_linear.bias.zero_()
150
+ my_encoder_linear.weight[:small_in_dim] = small_encoder_linear.weight
151
+ my_encoder_linear.bias[:small_in_dim] = small_encoder_linear.bias
152
+
153
+ set_encoder_weights(self.encoder, small_model.encoder)
154
+ set_encoder_weights(self.y_encoder, small_model.y_encoder)
155
+
156
+ small_in_dim = small_model.decoder.in_features
157
+
158
+ self.decoder.weight[:, :small_in_dim] = small_model.decoder.weight
159
+ self.decoder.bias = small_model.decoder.bias
160
+
161
+ for my_layer, small_layer in zip(self.transformer_encoder.layers, small_model.transformer_encoder.layers):
162
+ small_hid_dim = small_layer.linear1.out_features
163
+ my_in_dim = my_layer.linear1.in_features
164
+
165
+ # packed along q,k,v order in first dim
166
+ my_in_proj_w = my_layer.self_attn.in_proj_weight
167
+ small_in_proj_w = small_layer.self_attn.in_proj_weight
168
+
169
+ my_in_proj_w.view(3, my_in_dim, my_in_dim)[:, :small_in_dim, :small_in_dim] = small_in_proj_w.view(3,
170
+ small_in_dim,
171
+ small_in_dim)
172
+ my_layer.self_attn.in_proj_bias.view(3, my_in_dim)[:,
173
+ :small_in_dim] = small_layer.self_attn.in_proj_bias.view(3, small_in_dim)
174
+
175
+ my_layer.self_attn.out_proj.weight[:small_in_dim, :small_in_dim] = small_layer.self_attn.out_proj.weight
176
+ my_layer.self_attn.out_proj.bias[:small_in_dim] = small_layer.self_attn.out_proj.bias
177
+
178
+ my_layer.linear1.weight[:small_hid_dim, :small_in_dim] = small_layer.linear1.weight
179
+ my_layer.linear1.bias[:small_hid_dim] = small_layer.linear1.bias
180
+
181
+ my_layer.linear2.weight[:small_in_dim, :small_hid_dim] = small_layer.linear2.weight
182
+ my_layer.linear2.bias[:small_in_dim] = small_layer.linear2.bias
183
+
184
+ my_layer.norm1.weight[:small_in_dim] = math.sqrt(small_in_dim / my_in_dim) * small_layer.norm1.weight
185
+ my_layer.norm2.weight[:small_in_dim] = math.sqrt(small_in_dim / my_in_dim) * small_layer.norm2.weight
186
+
187
+ my_layer.norm1.bias[:small_in_dim] = small_layer.norm1.bias
188
+ my_layer.norm2.bias[:small_in_dim] = small_layer.norm2.bias
189
+
190
+
191
+ class TransformerEncoderDiffInit(Module):
192
+ r"""TransformerEncoder is a stack of N encoder layers
193
+
194
+ Args:
195
+ encoder_layer_creator: a function generating objects of TransformerEncoderLayer class without args (required).
196
+ num_layers: the number of sub-encoder-layers in the encoder (required).
197
+ norm: the layer normalization component (optional).
198
+ """
199
+ __constants__ = ['norm']
200
+
201
+ def __init__(self, encoder_layer_creator, num_layers, norm=None):
202
+ super().__init__()
203
+ self.layers = nn.ModuleList([encoder_layer_creator() for _ in range(num_layers)])
204
+ self.num_layers = num_layers
205
+ self.norm = norm
206
+
207
+ def forward(self, src: Tensor, mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None) -> Tensor:
208
+ r"""Pass the input through the encoder layers in turn.
209
+
210
+ Args:
211
+ src: the sequence to the encoder (required).
212
+ mask: the mask for the src sequence (optional).
213
+ src_key_padding_mask: the mask for the src keys per batch (optional).
214
+
215
+ Shape:
216
+ see the docs in Transformer class.
217
+ """
218
+ output = src
219
+
220
+ for mod in self.layers:
221
+ output = mod(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask)
222
+
223
+ if self.norm is not None:
224
+ output = self.norm(output)
225
+
226
+ return output
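A brief sketch (not part of the commit) of the attention mask produced by TransformerModel.generate_D_q_matrix above: training positions attend to all training positions, while each query position attends to the training positions and to itself only. The sizes are arbitrary.

from transformer import TransformerModel  # assumes the TabPFN directory is on sys.path

sz, query_size = 5, 2                      # 3 training rows followed by 2 query rows
mask = TransformerModel.generate_D_q_matrix(sz, query_size)
print(mask)
# Expected pattern (0.0 = may attend, -inf = blocked):
# rows 0-2 (train): columns 0-2 open, columns 3-4 blocked
# row 3 (query):    columns 0-2 and 3 open
# row 4 (query):    columns 0-2 and 4 open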
TabPFN/utils.py ADDED
@@ -0,0 +1,236 @@
1
+ import os
2
+ import math
3
+ import argparse
4
+ import random
5
+ import datetime
6
+
7
+ import torch
8
+ from torch import nn
9
+ from torch.optim.lr_scheduler import LambdaLR
10
+ import numpy as np
11
+
12
+ # copied from huggingface
13
+ def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, num_cycles=0.5, last_epoch=-1):
14
+ """ Create a schedule with a learning rate that decreases following the
15
+ values of the cosine function between 0 and `pi * cycles` after a warmup
16
+ period during which it increases linearly between 0 and 1.
17
+ """
18
+
19
+ def lr_lambda(current_step):
20
+ if current_step < num_warmup_steps:
21
+ return float(current_step) / float(max(1, num_warmup_steps))
22
+ progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
23
+ return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))
24
+
25
+ return LambdaLR(optimizer, lr_lambda, last_epoch)
26
+
27
+ # copied from huggingface
28
+ def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):
29
+ """
30
+ Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after
31
+ a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer.
32
+
33
+ Args:
34
+ optimizer (:class:`~torch.optim.Optimizer`):
35
+ The optimizer for which to schedule the learning rate.
36
+ num_warmup_steps (:obj:`int`):
37
+ The number of steps for the warmup phase.
38
+ num_training_steps (:obj:`int`):
39
+ The total number of training steps.
40
+ last_epoch (:obj:`int`, `optional`, defaults to -1):
41
+ The index of the last epoch when resuming training.
42
+
43
+ Return:
44
+ :obj:`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
45
+ """
46
+
47
+ def lr_lambda(current_step: int):
48
+ if current_step < num_warmup_steps:
49
+ return float(current_step) / float(max(1, num_warmup_steps))
50
+ return max(
51
+ 0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps))
52
+ )
53
+
54
+ return LambdaLR(optimizer, lr_lambda, last_epoch)
55
+
56
+
57
+ def get_openai_lr(transformer_model):
58
+ num_params = sum(p.numel() for p in transformer_model.parameters())
59
+ return 0.003239 - 0.0001395 * math.log(num_params)
60
+
61
+
62
+ def get_weighted_single_eval_pos_sampler(max_len):
63
+ """
64
+ This gives a sampler that can be used for `single_eval_pos` which yields good performance for all positions p,
65
+ where p <= `max_len`. At most `max_len` - 1 examples are shown to the Transformer.
66
+ :return: Sampler that can be fed to `train()` as `single_eval_pos_gen`.
67
+ """
68
+ return lambda: random.choices(range(max_len), [1 / (max_len - i) for i in range(max_len)])[0]
69
+
70
+
71
+ def get_uniform_single_eval_pos_sampler(max_len, min_len=0):
72
+ """
73
+ Just sample any evaluation position with the same weight
74
+ :return: Sampler that can be fed to `train()` as `single_eval_pos_gen`.
75
+ """
76
+ return lambda: random.choices(range(min_len, max_len))[0]
77
+
78
+
79
+ class SeqBN(nn.Module):
80
+ def __init__(self, d_model):
81
+ super().__init__()
82
+ self.bn = nn.BatchNorm1d(d_model)
83
+ self.d_model = d_model
84
+
85
+ def forward(self, x):
86
+ assert self.d_model == x.shape[-1]
87
+ flat_x = x.view(-1, self.d_model)
88
+ flat_x = self.bn(flat_x)
89
+ return flat_x.view(*x.shape)
90
+
91
+
92
+ def set_locals_in_self(locals):
93
+ self = locals['self']
94
+ for var_name, val in locals.items():
95
+ if var_name != 'self': setattr(self, var_name, val)
96
+
97
+
98
+ default_device = 'cuda:0' if torch.cuda.is_available() else 'cpu:0'
99
+
100
+
101
+ # Copied from StackOverflow, but we do an eval on the values additionally
102
+ class StoreDictKeyPair(argparse.Action):
103
+ def __init__(self, option_strings, dest, nargs=None, **kwargs):
104
+ self._nargs = nargs
105
+ super(StoreDictKeyPair, self).__init__(option_strings, dest, nargs=nargs, **kwargs)
106
+
107
+ def __call__(self, parser, namespace, values, option_string=None):
108
+ my_dict = {}
109
+ for kv in values:
110
+ k, v = kv.split("=")
111
+ try:
112
+ my_dict[k] = eval(v)
113
+ except NameError:
114
+ my_dict[k] = v
115
+ setattr(namespace, self.dest, my_dict)
116
+ print("dict values: {}".format(my_dict))
117
+
118
+ def get_nan_value(v, set_value_to_nan=0.0):
119
+ if random.random() < set_value_to_nan:
120
+ return v
121
+ else:
122
+ return random.choice([-999, 0, 1, 999])
123
+
124
+ def to_ranking(data):
125
+ x = (data >= data.unsqueeze(-3))
126
+ x = x.sum(0)
127
+ return x
128
+ # TODO: Is there a better way to do this?
129
+ # 1. Comparing to unique elements: When all values are different we still get quadratic blowup
130
+ # 2. Argsort(Argsort()) returns ranking, but with duplicate values there is an ordering which is problematic
131
+ # 3. Argsort(Argsort(Unique))->Scatter seems a bit complicated, doesn't have quadratic blowup, but how fast?
132
+ def to_ranking_low_mem(data):
133
+ x = torch.zeros_like(data)
134
+ for col in range(data.shape[-1]):
135
+ x_ = (data[:, :, col] >= data[:, :, col].unsqueeze(-2))
136
+ x_ = x_.sum(0)
137
+ x[:, :, col] = x_
138
+ return x
139
+
140
+ def nan_handling_missing_for_unknown_reason_value(set_value_to_nan=0.0):
141
+ return get_nan_value(float('nan'), set_value_to_nan)
142
+
143
+ def nan_handling_missing_for_no_reason_value(set_value_to_nan=0.0):
144
+ return get_nan_value(float('-inf'), set_value_to_nan)
145
+
146
+ def nan_handling_missing_for_a_reason_value(set_value_to_nan=0.0):
147
+ return get_nan_value(float('inf'), set_value_to_nan)
148
+
149
+ def torch_nanmean(x, axis=0):
150
+ num = torch.where(torch.isnan(x), torch.full_like(x, 0), torch.full_like(x, 1)).sum(axis=axis)
151
+ value = torch.where(torch.isnan(x), torch.full_like(x, 0), x).sum(axis=axis)
152
+ return value / num
153
+
154
+ def torch_nanstd(x, axis=0):
155
+ num = torch.where(torch.isnan(x), torch.full_like(x, 0), torch.full_like(x, 1)).sum(axis=axis)
156
+ value = torch.where(torch.isnan(x), torch.full_like(x, 0), x).sum(axis=axis)
157
+ mean = value / num
158
+ mean_broadcast = torch.repeat_interleave(mean.unsqueeze(axis), x.shape[axis], dim=axis)
159
+ return torch.sqrt(torch.nansum(torch.square(mean_broadcast - x), axis=axis) / (num - 1))
160
+
161
+ def normalize_data(data, normalize_positions=-1):
162
+ if normalize_positions > 0:
163
+ mean = torch_nanmean(data[:normalize_positions], axis=0)
164
+ std = torch_nanstd(data[:normalize_positions], axis=0) + .000001
165
+ else:
166
+ mean = torch_nanmean(data, axis=0)
167
+ std = torch_nanstd(data, axis=0) + .000001
168
+ data = (data - mean) / std
169
+ data = torch.clip(data, min=-100, max=100)
170
+
171
+ return data
172
+
173
+ def remove_outliers(X, n_sigma=4):
174
+ # Expects T, B, H
175
+ assert len(X.shape) == 3, "X must be T,B,H"
176
+ #for b in range(X.shape[1]):
177
+ #for col in range(X.shape[2]):
178
+ data = X
179
+ data_mean, data_std = torch_nanmean(data, axis=0), torch_nanstd(data, axis=0)
180
+ cut_off = data_std * n_sigma
181
+ lower, upper = data_mean - cut_off, data_mean + cut_off
182
+
183
+ data_clean = X[:].clone()
184
+ data_clean[torch.logical_or(data > upper, data < lower)] = np.nan
185
+ data_mean, data_std = torch_nanmean(data_clean, axis=0), torch_nanstd(data_clean, axis=0)
186
+ cut_off = data_std * n_sigma
187
+ lower, upper = data_mean - cut_off, data_mean + cut_off
188
+
189
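+ # Soft clipping instead of a hard cut-off: each value ends up bounded to roughly
+ # [lower - log(1+|x|), upper + log(1+|x|)], so extreme outliers are pulled back towards
+ # the n_sigma bounds while values inside the bounds are left unchanged.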
+ X = torch.maximum(-torch.log(1+torch.abs(X)) + lower, X)
190
+ X = torch.minimum(torch.log(1+torch.abs(X)) + upper, X)
191
+ # print(ds[1][data < lower, col], ds[1][data > upper, col], ds[1][~np.isnan(data), col].shape, data_mean, data_std)
192
+ return X
193
+
194
+ def bool_mask_to_att_mask(mask):
195
+ return mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
196
+
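A tiny illustration of bool_mask_to_att_mask (my own toy mask): True (may attend) becomes 0.0 and False becomes -inf, i.e. an additive attention mask.

    import torch

    allowed = torch.tensor([[True, False],
                            [True, True]])
    print(bool_mask_to_att_mask(allowed))
    # tensor([[0., -inf],
    #         [0., 0.]])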
197
+ def print_on_master_only(is_master):
198
+ import builtins as __builtin__
199
+
200
+ builtin_print = __builtin__.print
201
+
202
+ def print(*args, **kwargs):
203
+ force = kwargs.pop("force", False)
204
+ if is_master or force:
205
+ builtin_print(*args, **kwargs)
206
+
207
+ __builtin__.print = print
208
+
209
+ def init_dist(device):
210
+ if 'SLURM_PROCID' in os.environ and torch.cuda.device_count() > 1:
211
+ assert device != 'cpu:0'
212
+ rank = int(os.environ['SLURM_PROCID'])
213
+ os.environ['MASTER_ADDR'] = 'localhost'
214
+ os.environ['MASTER_PORT'] = '12355'
215
+ torch.cuda.set_device(rank)
216
+ os.environ['CUDA_VISIBLE_DEVICES'] = str(rank)
217
+ torch.distributed.init_process_group(backend="nccl", init_method="env://", timeout=datetime.timedelta(seconds=20),
218
+ world_size=torch.cuda.device_count(), rank=rank)
219
+ torch.distributed.barrier()
220
+ print_on_master_only(rank == 0)
221
+ print(f"Distributed training on {torch.cuda.device_count()} GPUs, this is rank {rank}, "
222
+ "only I can print, but when using print(..., force=True) it will print on all ranks.")
223
+
224
+ return True, rank, f'cuda:{rank}'
225
+ else:
226
+ print('Not using distributed')
227
+ # does not change the behavior of print, but allows passing force=True to print calls
228
+ print_on_master_only(True)
229
+ return False, 0, device
230
+
231
+ # No-op context manager for python with statements (x = NOP(); with x:)
232
+ class NOP():
233
+ def __enter__(self):
234
+ pass
235
+ def __exit__(self, type, value, traceback):
236
+ pass
app.py ADDED
@@ -0,0 +1,96 @@
1
+ import sys
2
+ tabpfn_path = 'TabPFN'
3
+ sys.path.insert(0, tabpfn_path) # our submodule of the TabPFN repo (at 045c8400203ebd062346970b4f2c0ccda5a40618)
4
+ from TabPFN.scripts.transformer_prediction_interface import TabPFNClassifier
5
+
6
+ import numpy as np
7
+ import pandas as pd
8
+ import torch
9
+ import gradio as gr
10
+ import openml
11
+
12
+
13
+ def compute(table: np.ndarray):
14
+ vfunc = np.vectorize(lambda s: len(s))
15
+ non_empty_row_mask = (vfunc(table).sum(1) != 0)
16
+ table = table[non_empty_row_mask]
17
+ empty_mask = table == ''
18
+ empty_inds = np.where(empty_mask)
19
+ if not len(empty_inds[0]):
20
+ return "**Please leave at least one field blank for prediction.**", None
21
+ if not np.all(empty_inds[1][0] == empty_inds[1]):
22
+ return "**Please only leave fields of one column blank for prediction.**", None
23
+ y_column = empty_inds[1][0]
24
+ eval_lines = empty_inds[0]
25
+
26
+ train_table = np.delete(table, eval_lines, axis=0)
27
+ eval_table = table[eval_lines]
28
+
29
+ try:
30
+ x_train = torch.tensor(np.delete(train_table, y_column, axis=1).astype(np.float32))
31
+ x_eval = torch.tensor(np.delete(eval_table, y_column, axis=1).astype(np.float32))
32
+
33
+ y_train = train_table[:, y_column]
34
+ except ValueError:
35
+ return "**Please only add numbers (to the inputs) or leave fields empty.**", None
36
+
37
+ classifier = TabPFNClassifier(base_path=tabpfn_path, device='cpu')
38
+ classifier.fit(x_train, y_train)
39
+ y_eval, p_eval = classifier.predict(x_eval, return_winning_probability=True)
40
+
41
+ # print(file, type(file))
42
+ out_table = table.copy().astype(str)
43
+ out_table[eval_lines, y_column] = [f"{y_e} (p={p_e:.2f})" for y_e, p_e in zip(y_eval, p_eval)]
44
+ return None, out_table
45
+
46
+
47
+ def upload_file(file):
48
+ if file.name.endswith('.arff'):
49
+ dataset = openml.datasets.OpenMLDataset('t', 'test', data_file=file.name)
50
+ X_, _, categorical_indicator_, attribute_names_ = dataset.get_data(
51
+ dataset_format="array"
52
+ )
53
+ df = pd.DataFrame(X_, columns=attribute_names_)
54
+ return df
55
+ elif file.name.endswith('.csv') or file.name.endswith('.data'):
56
+ df = pd.read_csv(file.name, header=None)
57
+ df.columns = np.arange(len(df.columns))
58
+ print(df)
59
+ return df
60
+
61
+
62
+ example = \
63
+ [
64
+ [1, 2, 1],
65
+ [2, 1, 1],
66
+ [1, 1, 1],
67
+ [2, 2, 2],
68
+ [3, 4, 2],
69
+ [3, 2, 2],
70
+ [2, 3, '']
71
+ ]
72
+
73
+ with gr.Blocks() as demo:
74
+ gr.Markdown("""This demo allows you to play with **TabPFN**.
75
+ You can either edit the table manually or upload a .csv/.arff file below (we have pre-filled it with a toy benchmark: rows whose feature values sum to at most 3 get label 1, all others get label 2).
76
+ The network predicts the fields you leave empty; only one column may contain empty entries to be predicted.
77
+ Please provide everything except the label column as numeric values; it is fine to encode classes as integers.
78
+ """)
79
+ inp_table = gr.DataFrame(type='numpy', value=example, headers=[''] * 3)
80
+ inp_file = gr.File(
81
+ label='Drop either a .csv (without header, only numeric values for all but the labels) or a .arff file.')
82
+ examples = gr.Examples(examples=['iris.csv', 'balance-scale.arff'],
83
+ inputs=[inp_file],
84
+ outputs=[inp_table],
85
+ fn=upload_file,
86
+ cache_examples=True)
87
+ btn = gr.Button("Predict Empty Table Cells")
88
+
89
+ inp_file.change(fn=upload_file, inputs=inp_file, outputs=inp_table)
90
+
91
+ out_text = gr.Markdown()
92
+ out_table = gr.DataFrame()
93
+
94
+ btn.click(fn=compute, inputs=inp_table, outputs=[out_text, out_table])
95
+
96
+ demo.launch()
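A minimal sketch of driving compute() without the Gradio UI, using the toy table from the demo (my own invocation; the Gradio DataFrame normally hands compute() a string-typed numpy array, which is what is mimicked here, and the TabPFN checkpoint from the repo still needs to be present):

    import numpy as np

    toy = np.array([[1, 2, 1],
                    [2, 1, 1],
                    [1, 1, 1],
                    [2, 2, 2],
                    [3, 4, 2],
                    [3, 2, 2],
                    [2, 3, '']], dtype=str)  # the empty label in the last row is what gets predicted
    message, table_out = compute(toy)
    print(message or table_out)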
balance-scale.arff ADDED
@@ -0,0 +1,694 @@
1
+ %1. Title: Balance Scale Weight & Distance Database
2
+ %
3
+ %2. Source Information:
4
+ % (a) Source: Generated to model psychological experiments reported
5
+ % by Siegler, R. S. (1976). Three Aspects of Cognitive
6
+ % Development. Cognitive Psychology, 8, 481-520.
7
+ % (b) Donor: Tim Hume ([email protected])
8
+ % (c) Date: 22 April 1994
9
+ %
10
+ %3. Past Usage: (possibly different formats of this data)
11
+ % - Publications
12
+ % 1. Klahr, D., & Siegler, R.S. (1978). The Representation of
13
+ % Children's Knowledge. In H. W. Reese & L. P. Lipsitt (Eds.),
14
+ % Advances in Child Development and Behavior, pp. 61-116. New
15
+ % York: Academic Press
16
+ % 2. Langley,P. (1987). A General Theory of Discrimination
17
+ % Learning. In D. Klahr, P. Langley, & R. Neches (Eds.),
18
+ % Production System Models of Learning and Development, pp.
19
+ % 99-161. Cambridge, MA: MIT Press
20
+ % 3. Newell, A. (1990). Unified Theories of Cognition.
21
+ % Cambridge, MA: Harvard University Press
22
+ % 4. McClelland, J.L. (1988). Parallel Distributed Processing:
23
+ % Implications for Cognition and Development. Technical
24
+ % Report AIP-47, Department of Psychology, Carnegie-Mellon
25
+ % University
26
+ % 5. Shultz, T., Mareschal, D., & Schmidt, W. (1994). Modeling
27
+ % Cognitive Development on Balance Scale Phenomena. Machine
28
+ % Learning, Vol. 16, pp. 59-88.
29
+ %
30
+ %4. Relevant Information:
31
+ % This data set was generated to model psychological
32
+ % experimental results. Each example is classified as having the
33
+ % balance scale tip to the right, tip to the left, or be
34
+ % balanced. The attributes are the left weight, the left
35
+ % distance, the right weight, and the right distance. The
36
+ % correct way to find the class is the greater of
37
+ % (left-distance * left-weight) and (right-distance *
38
+ % right-weight). If they are equal, it is balanced.
39
+ %
40
+ %5. Number of Instances: 625 (49 balanced, 288 left, 288 right)
41
+ %
42
+ %6. Number of Attributes: 4 (numeric) + class name = 5
43
+ %
44
+ %7. Attribute Information:
45
+ % 1. Class Name: 3 (L, B, R)
46
+ % 2. Left-Weight: 5 (1, 2, 3, 4, 5)
47
+ % 3. Left-Distance: 5 (1, 2, 3, 4, 5)
48
+ % 4. Right-Weight: 5 (1, 2, 3, 4, 5)
49
+ % 5. Right-Distance: 5 (1, 2, 3, 4, 5)
50
+ %
51
+ %8. Missing Attribute Values:
52
+ % none
53
+ %
54
+ %9. Class Distribution:
55
+ % 1. 46.08 percent are L
56
+ % 2. 07.84 percent are B
57
+ % 3. 46.08 percent are R
58
+ %
59
+
60
+ @relation balance-scale
61
+ @attribute 'left-weight' real
62
+ @attribute 'left-distance' real
63
+ @attribute 'right-weight' real
64
+ @attribute 'right-distance' real
65
+ @attribute 'class' { L, B, R}
66
+ @data
67
+ 1,1,1,1,B
68
+ 1,1,1,2,R
69
+ 1,1,1,3,R
70
+ 1,1,1,4,R
71
+ 1,1,1,5,R
72
+ 1,1,2,1,R
73
+ 1,1,2,2,R
74
+ 1,1,2,3,R
75
+ 1,1,2,4,R
76
+ 1,1,2,5,R
77
+ 1,1,3,1,R
78
+ 1,1,3,2,R
79
+ 1,1,3,3,R
80
+ 1,1,3,4,R
81
+ 1,1,3,5,R
82
+ 1,1,4,1,R
83
+ 1,1,4,2,R
84
+ 1,1,4,3,R
85
+ 1,1,4,4,R
86
+ 1,1,4,5,R
87
+ 1,1,5,1,R
88
+ 1,1,5,2,R
89
+ 1,1,5,3,R
90
+ 1,1,5,4,R
91
+ 1,1,5,5,R
92
+ 1,2,1,1,L
93
+ 1,2,1,2,B
94
+ 1,2,1,3,R
95
+ 1,2,1,4,R
96
+ 1,2,1,5,R
97
+ 1,2,2,1,B
98
+ 1,2,2,2,R
99
+ 1,2,2,3,R
100
+ 1,2,2,4,R
101
+ 1,2,2,5,R
102
+ 1,2,3,1,R
103
+ 1,2,3,2,R
104
+ 1,2,3,3,R
105
+ 1,2,3,4,R
106
+ 1,2,3,5,R
107
+ 1,2,4,1,R
108
+ 1,2,4,2,R
109
+ 1,2,4,3,R
110
+ 1,2,4,4,R
111
+ 1,2,4,5,R
112
+ 1,2,5,1,R
113
+ 1,2,5,2,R
114
+ 1,2,5,3,R
115
+ 1,2,5,4,R
116
+ 1,2,5,5,R
117
+ 1,3,1,1,L
118
+ 1,3,1,2,L
119
+ 1,3,1,3,B
120
+ 1,3,1,4,R
121
+ 1,3,1,5,R
122
+ 1,3,2,1,L
123
+ 1,3,2,2,R
124
+ 1,3,2,3,R
125
+ 1,3,2,4,R
126
+ 1,3,2,5,R
127
+ 1,3,3,1,B
128
+ 1,3,3,2,R
129
+ 1,3,3,3,R
130
+ 1,3,3,4,R
131
+ 1,3,3,5,R
132
+ 1,3,4,1,R
133
+ 1,3,4,2,R
134
+ 1,3,4,3,R
135
+ 1,3,4,4,R
136
+ 1,3,4,5,R
137
+ 1,3,5,1,R
138
+ 1,3,5,2,R
139
+ 1,3,5,3,R
140
+ 1,3,5,4,R
141
+ 1,3,5,5,R
142
+ 1,4,1,1,L
143
+ 1,4,1,2,L
144
+ 1,4,1,3,L
145
+ 1,4,1,4,B
146
+ 1,4,1,5,R
147
+ 1,4,2,1,L
148
+ 1,4,2,2,B
149
+ 1,4,2,3,R
150
+ 1,4,2,4,R
151
+ 1,4,2,5,R
152
+ 1,4,3,1,L
153
+ 1,4,3,2,R
154
+ 1,4,3,3,R
155
+ 1,4,3,4,R
156
+ 1,4,3,5,R
157
+ 1,4,4,1,B
158
+ 1,4,4,2,R
159
+ 1,4,4,3,R
160
+ 1,4,4,4,R
161
+ 1,4,4,5,R
162
+ 1,4,5,1,R
163
+ 1,4,5,2,R
164
+ 1,4,5,3,R
165
+ 1,4,5,4,R
166
+ 1,4,5,5,R
167
+ 1,5,1,1,L
168
+ 1,5,1,2,L
169
+ 1,5,1,3,L
170
+ 1,5,1,4,L
171
+ 1,5,1,5,B
172
+ 1,5,2,1,L
173
+ 1,5,2,2,L
174
+ 1,5,2,3,R
175
+ 1,5,2,4,R
176
+ 1,5,2,5,R
177
+ 1,5,3,1,L
178
+ 1,5,3,2,R
179
+ 1,5,3,3,R
180
+ 1,5,3,4,R
181
+ 1,5,3,5,R
182
+ 1,5,4,1,L
183
+ 1,5,4,2,R
184
+ 1,5,4,3,R
185
+ 1,5,4,4,R
186
+ 1,5,4,5,R
187
+ 1,5,5,1,B
188
+ 1,5,5,2,R
189
+ 1,5,5,3,R
190
+ 1,5,5,4,R
191
+ 1,5,5,5,R
192
+ 2,1,1,1,L
193
+ 2,1,1,2,B
194
+ 2,1,1,3,R
195
+ 2,1,1,4,R
196
+ 2,1,1,5,R
197
+ 2,1,2,1,B
198
+ 2,1,2,2,R
199
+ 2,1,2,3,R
200
+ 2,1,2,4,R
201
+ 2,1,2,5,R
202
+ 2,1,3,1,R
203
+ 2,1,3,2,R
204
+ 2,1,3,3,R
205
+ 2,1,3,4,R
206
+ 2,1,3,5,R
207
+ 2,1,4,1,R
208
+ 2,1,4,2,R
209
+ 2,1,4,3,R
210
+ 2,1,4,4,R
211
+ 2,1,4,5,R
212
+ 2,1,5,1,R
213
+ 2,1,5,2,R
214
+ 2,1,5,3,R
215
+ 2,1,5,4,R
216
+ 2,1,5,5,R
217
+ 2,2,1,1,L
218
+ 2,2,1,2,L
219
+ 2,2,1,3,L
220
+ 2,2,1,4,B
221
+ 2,2,1,5,R
222
+ 2,2,2,1,L
223
+ 2,2,2,2,B
224
+ 2,2,2,3,R
225
+ 2,2,2,4,R
226
+ 2,2,2,5,R
227
+ 2,2,3,1,L
228
+ 2,2,3,2,R
229
+ 2,2,3,3,R
230
+ 2,2,3,4,R
231
+ 2,2,3,5,R
232
+ 2,2,4,1,B
233
+ 2,2,4,2,R
234
+ 2,2,4,3,R
235
+ 2,2,4,4,R
236
+ 2,2,4,5,R
237
+ 2,2,5,1,R
238
+ 2,2,5,2,R
239
+ 2,2,5,3,R
240
+ 2,2,5,4,R
241
+ 2,2,5,5,R
242
+ 2,3,1,1,L
243
+ 2,3,1,2,L
244
+ 2,3,1,3,L
245
+ 2,3,1,4,L
246
+ 2,3,1,5,L
247
+ 2,3,2,1,L
248
+ 2,3,2,2,L
249
+ 2,3,2,3,B
250
+ 2,3,2,4,R
251
+ 2,3,2,5,R
252
+ 2,3,3,1,L
253
+ 2,3,3,2,B
254
+ 2,3,3,3,R
255
+ 2,3,3,4,R
256
+ 2,3,3,5,R
257
+ 2,3,4,1,L
258
+ 2,3,4,2,R
259
+ 2,3,4,3,R
260
+ 2,3,4,4,R
261
+ 2,3,4,5,R
262
+ 2,3,5,1,L
263
+ 2,3,5,2,R
264
+ 2,3,5,3,R
265
+ 2,3,5,4,R
266
+ 2,3,5,5,R
267
+ 2,4,1,1,L
268
+ 2,4,1,2,L
269
+ 2,4,1,3,L
270
+ 2,4,1,4,L
271
+ 2,4,1,5,L
272
+ 2,4,2,1,L
273
+ 2,4,2,2,L
274
+ 2,4,2,3,L
275
+ 2,4,2,4,B
276
+ 2,4,2,5,R
277
+ 2,4,3,1,L
278
+ 2,4,3,2,L
279
+ 2,4,3,3,R
280
+ 2,4,3,4,R
281
+ 2,4,3,5,R
282
+ 2,4,4,1,L
283
+ 2,4,4,2,B
284
+ 2,4,4,3,R
285
+ 2,4,4,4,R
286
+ 2,4,4,5,R
287
+ 2,4,5,1,L
288
+ 2,4,5,2,R
289
+ 2,4,5,3,R
290
+ 2,4,5,4,R
291
+ 2,4,5,5,R
292
+ 2,5,1,1,L
293
+ 2,5,1,2,L
294
+ 2,5,1,3,L
295
+ 2,5,1,4,L
296
+ 2,5,1,5,L
297
+ 2,5,2,1,L
298
+ 2,5,2,2,L
299
+ 2,5,2,3,L
300
+ 2,5,2,4,L
301
+ 2,5,2,5,B
302
+ 2,5,3,1,L
303
+ 2,5,3,2,L
304
+ 2,5,3,3,L
305
+ 2,5,3,4,R
306
+ 2,5,3,5,R
307
+ 2,5,4,1,L
308
+ 2,5,4,2,L
309
+ 2,5,4,3,R
310
+ 2,5,4,4,R
311
+ 2,5,4,5,R
312
+ 2,5,5,1,L
313
+ 2,5,5,2,B
314
+ 2,5,5,3,R
315
+ 2,5,5,4,R
316
+ 2,5,5,5,R
317
+ 3,1,1,1,L
318
+ 3,1,1,2,L
319
+ 3,1,1,3,B
320
+ 3,1,1,4,R
321
+ 3,1,1,5,R
322
+ 3,1,2,1,L
323
+ 3,1,2,2,R
324
+ 3,1,2,3,R
325
+ 3,1,2,4,R
326
+ 3,1,2,5,R
327
+ 3,1,3,1,B
328
+ 3,1,3,2,R
329
+ 3,1,3,3,R
330
+ 3,1,3,4,R
331
+ 3,1,3,5,R
332
+ 3,1,4,1,R
333
+ 3,1,4,2,R
334
+ 3,1,4,3,R
335
+ 3,1,4,4,R
336
+ 3,1,4,5,R
337
+ 3,1,5,1,R
338
+ 3,1,5,2,R
339
+ 3,1,5,3,R
340
+ 3,1,5,4,R
341
+ 3,1,5,5,R
342
+ 3,2,1,1,L
343
+ 3,2,1,2,L
344
+ 3,2,1,3,L
345
+ 3,2,1,4,L
346
+ 3,2,1,5,L
347
+ 3,2,2,1,L
348
+ 3,2,2,2,L
349
+ 3,2,2,3,B
350
+ 3,2,2,4,R
351
+ 3,2,2,5,R
352
+ 3,2,3,1,L
353
+ 3,2,3,2,B
354
+ 3,2,3,3,R
355
+ 3,2,3,4,R
356
+ 3,2,3,5,R
357
+ 3,2,4,1,L
358
+ 3,2,4,2,R
359
+ 3,2,4,3,R
360
+ 3,2,4,4,R
361
+ 3,2,4,5,R
362
+ 3,2,5,1,L
363
+ 3,2,5,2,R
364
+ 3,2,5,3,R
365
+ 3,2,5,4,R
366
+ 3,2,5,5,R
367
+ 3,3,1,1,L
368
+ 3,3,1,2,L
369
+ 3,3,1,3,L
370
+ 3,3,1,4,L
371
+ 3,3,1,5,L
372
+ 3,3,2,1,L
373
+ 3,3,2,2,L
374
+ 3,3,2,3,L
375
+ 3,3,2,4,L
376
+ 3,3,2,5,R
377
+ 3,3,3,1,L
378
+ 3,3,3,2,L
379
+ 3,3,3,3,B
380
+ 3,3,3,4,R
381
+ 3,3,3,5,R
382
+ 3,3,4,1,L
383
+ 3,3,4,2,L
384
+ 3,3,4,3,R
385
+ 3,3,4,4,R
386
+ 3,3,4,5,R
387
+ 3,3,5,1,L
388
+ 3,3,5,2,R
389
+ 3,3,5,3,R
390
+ 3,3,5,4,R
391
+ 3,3,5,5,R
392
+ 3,4,1,1,L
393
+ 3,4,1,2,L
394
+ 3,4,1,3,L
395
+ 3,4,1,4,L
396
+ 3,4,1,5,L
397
+ 3,4,2,1,L
398
+ 3,4,2,2,L
399
+ 3,4,2,3,L
400
+ 3,4,2,4,L
401
+ 3,4,2,5,L
402
+ 3,4,3,1,L
403
+ 3,4,3,2,L
404
+ 3,4,3,3,L
405
+ 3,4,3,4,B
406
+ 3,4,3,5,R
407
+ 3,4,4,1,L
408
+ 3,4,4,2,L
409
+ 3,4,4,3,B
410
+ 3,4,4,4,R
411
+ 3,4,4,5,R
412
+ 3,4,5,1,L
413
+ 3,4,5,2,L
414
+ 3,4,5,3,R
415
+ 3,4,5,4,R
416
+ 3,4,5,5,R
417
+ 3,5,1,1,L
418
+ 3,5,1,2,L
419
+ 3,5,1,3,L
420
+ 3,5,1,4,L
421
+ 3,5,1,5,L
422
+ 3,5,2,1,L
423
+ 3,5,2,2,L
424
+ 3,5,2,3,L
425
+ 3,5,2,4,L
426
+ 3,5,2,5,L
427
+ 3,5,3,1,L
428
+ 3,5,3,2,L
429
+ 3,5,3,3,L
430
+ 3,5,3,4,L
431
+ 3,5,3,5,B
432
+ 3,5,4,1,L
433
+ 3,5,4,2,L
434
+ 3,5,4,3,L
435
+ 3,5,4,4,R
436
+ 3,5,4,5,R
437
+ 3,5,5,1,L
438
+ 3,5,5,2,L
439
+ 3,5,5,3,B
440
+ 3,5,5,4,R
441
+ 3,5,5,5,R
442
+ 4,1,1,1,L
443
+ 4,1,1,2,L
444
+ 4,1,1,3,L
445
+ 4,1,1,4,B
446
+ 4,1,1,5,R
447
+ 4,1,2,1,L
448
+ 4,1,2,2,B
449
+ 4,1,2,3,R
450
+ 4,1,2,4,R
451
+ 4,1,2,5,R
452
+ 4,1,3,1,L
453
+ 4,1,3,2,R
454
+ 4,1,3,3,R
455
+ 4,1,3,4,R
456
+ 4,1,3,5,R
457
+ 4,1,4,1,B
458
+ 4,1,4,2,R
459
+ 4,1,4,3,R
460
+ 4,1,4,4,R
461
+ 4,1,4,5,R
462
+ 4,1,5,1,R
463
+ 4,1,5,2,R
464
+ 4,1,5,3,R
465
+ 4,1,5,4,R
466
+ 4,1,5,5,R
467
+ 4,2,1,1,L
468
+ 4,2,1,2,L
469
+ 4,2,1,3,L
470
+ 4,2,1,4,L
471
+ 4,2,1,5,L
472
+ 4,2,2,1,L
473
+ 4,2,2,2,L
474
+ 4,2,2,3,L
475
+ 4,2,2,4,B
476
+ 4,2,2,5,R
477
+ 4,2,3,1,L
478
+ 4,2,3,2,L
479
+ 4,2,3,3,R
480
+ 4,2,3,4,R
481
+ 4,2,3,5,R
482
+ 4,2,4,1,L
483
+ 4,2,4,2,B
484
+ 4,2,4,3,R
485
+ 4,2,4,4,R
486
+ 4,2,4,5,R
487
+ 4,2,5,1,L
488
+ 4,2,5,2,R
489
+ 4,2,5,3,R
490
+ 4,2,5,4,R
491
+ 4,2,5,5,R
492
+ 4,3,1,1,L
493
+ 4,3,1,2,L
494
+ 4,3,1,3,L
495
+ 4,3,1,4,L
496
+ 4,3,1,5,L
497
+ 4,3,2,1,L
498
+ 4,3,2,2,L
499
+ 4,3,2,3,L
500
+ 4,3,2,4,L
501
+ 4,3,2,5,L
502
+ 4,3,3,1,L
503
+ 4,3,3,2,L
504
+ 4,3,3,3,L
505
+ 4,3,3,4,B
506
+ 4,3,3,5,R
507
+ 4,3,4,1,L
508
+ 4,3,4,2,L
509
+ 4,3,4,3,B
510
+ 4,3,4,4,R
511
+ 4,3,4,5,R
512
+ 4,3,5,1,L
513
+ 4,3,5,2,L
514
+ 4,3,5,3,R
515
+ 4,3,5,4,R
516
+ 4,3,5,5,R
517
+ 4,4,1,1,L
518
+ 4,4,1,2,L
519
+ 4,4,1,3,L
520
+ 4,4,1,4,L
521
+ 4,4,1,5,L
522
+ 4,4,2,1,L
523
+ 4,4,2,2,L
524
+ 4,4,2,3,L
525
+ 4,4,2,4,L
526
+ 4,4,2,5,L
527
+ 4,4,3,1,L
528
+ 4,4,3,2,L
529
+ 4,4,3,3,L
530
+ 4,4,3,4,L
531
+ 4,4,3,5,L
532
+ 4,4,4,1,L
533
+ 4,4,4,2,L
534
+ 4,4,4,3,L
535
+ 4,4,4,4,B
536
+ 4,4,4,5,R
537
+ 4,4,5,1,L
538
+ 4,4,5,2,L
539
+ 4,4,5,3,L
540
+ 4,4,5,4,R
541
+ 4,4,5,5,R
542
+ 4,5,1,1,L
543
+ 4,5,1,2,L
544
+ 4,5,1,3,L
545
+ 4,5,1,4,L
546
+ 4,5,1,5,L
547
+ 4,5,2,1,L
548
+ 4,5,2,2,L
549
+ 4,5,2,3,L
550
+ 4,5,2,4,L
551
+ 4,5,2,5,L
552
+ 4,5,3,1,L
553
+ 4,5,3,2,L
554
+ 4,5,3,3,L
555
+ 4,5,3,4,L
556
+ 4,5,3,5,L
557
+ 4,5,4,1,L
558
+ 4,5,4,2,L
559
+ 4,5,4,3,L
560
+ 4,5,4,4,L
561
+ 4,5,4,5,B
562
+ 4,5,5,1,L
563
+ 4,5,5,2,L
564
+ 4,5,5,3,L
565
+ 4,5,5,4,B
566
+ 4,5,5,5,R
567
+ 5,1,1,1,L
568
+ 5,1,1,2,L
569
+ 5,1,1,3,L
570
+ 5,1,1,4,L
571
+ 5,1,1,5,B
572
+ 5,1,2,1,L
573
+ 5,1,2,2,L
574
+ 5,1,2,3,R
575
+ 5,1,2,4,R
576
+ 5,1,2,5,R
577
+ 5,1,3,1,L
578
+ 5,1,3,2,R
579
+ 5,1,3,3,R
580
+ 5,1,3,4,R
581
+ 5,1,3,5,R
582
+ 5,1,4,1,L
583
+ 5,1,4,2,R
584
+ 5,1,4,3,R
585
+ 5,1,4,4,R
586
+ 5,1,4,5,R
587
+ 5,1,5,1,B
588
+ 5,1,5,2,R
589
+ 5,1,5,3,R
590
+ 5,1,5,4,R
591
+ 5,1,5,5,R
592
+ 5,2,1,1,L
593
+ 5,2,1,2,L
594
+ 5,2,1,3,L
595
+ 5,2,1,4,L
596
+ 5,2,1,5,L
597
+ 5,2,2,1,L
598
+ 5,2,2,2,L
599
+ 5,2,2,3,L
600
+ 5,2,2,4,L
601
+ 5,2,2,5,B
602
+ 5,2,3,1,L
603
+ 5,2,3,2,L
604
+ 5,2,3,3,L
605
+ 5,2,3,4,R
606
+ 5,2,3,5,R
607
+ 5,2,4,1,L
608
+ 5,2,4,2,L
609
+ 5,2,4,3,R
610
+ 5,2,4,4,R
611
+ 5,2,4,5,R
612
+ 5,2,5,1,L
613
+ 5,2,5,2,B
614
+ 5,2,5,3,R
615
+ 5,2,5,4,R
616
+ 5,2,5,5,R
617
+ 5,3,1,1,L
618
+ 5,3,1,2,L
619
+ 5,3,1,3,L
620
+ 5,3,1,4,L
621
+ 5,3,1,5,L
622
+ 5,3,2,1,L
623
+ 5,3,2,2,L
624
+ 5,3,2,3,L
625
+ 5,3,2,4,L
626
+ 5,3,2,5,L
627
+ 5,3,3,1,L
628
+ 5,3,3,2,L
629
+ 5,3,3,3,L
630
+ 5,3,3,4,L
631
+ 5,3,3,5,B
632
+ 5,3,4,1,L
633
+ 5,3,4,2,L
634
+ 5,3,4,3,L
635
+ 5,3,4,4,R
636
+ 5,3,4,5,R
637
+ 5,3,5,1,L
638
+ 5,3,5,2,L
639
+ 5,3,5,3,B
640
+ 5,3,5,4,R
641
+ 5,3,5,5,R
642
+ 5,4,1,1,L
643
+ 5,4,1,2,L
644
+ 5,4,1,3,L
645
+ 5,4,1,4,L
646
+ 5,4,1,5,L
647
+ 5,4,2,1,L
648
+ 5,4,2,2,L
649
+ 5,4,2,3,L
650
+ 5,4,2,4,L
651
+ 5,4,2,5,L
652
+ 5,4,3,1,L
653
+ 5,4,3,2,L
654
+ 5,4,3,3,L
655
+ 5,4,3,4,L
656
+ 5,4,3,5,L
657
+ 5,4,4,1,L
658
+ 5,4,4,2,L
659
+ 5,4,4,3,L
660
+ 5,4,4,4,L
661
+ 5,4,4,5,B
662
+ 5,4,5,1,L
663
+ 5,4,5,2,L
664
+ 5,4,5,3,L
665
+ 5,4,5,4,B
666
+ 5,4,5,5,R
667
+ 5,5,1,1,L
668
+ 5,5,1,2,L
669
+ 5,5,1,3,L
670
+ 5,5,1,4,L
671
+ 5,5,1,5,L
672
+ 5,5,2,1,L
673
+ 5,5,2,2,L
674
+ 5,5,2,3,L
675
+ 5,5,2,4,L
676
+ 5,5,2,5,L
677
+ 5,5,3,1,L
678
+ 5,5,3,2,L
679
+ 5,5,3,3,L
680
+ 5,5,3,4,L
681
+ 5,5,3,5,L
682
+ 5,5,4,1,L
683
+ 5,5,4,2,L
684
+ 5,5,4,3,L
685
+ 5,5,4,4,L
686
+ 5,5,4,5,L
687
+ 5,5,5,1,L
688
+ 5,5,5,2,L
689
+ 5,5,5,3,L
690
+ 5,5,5,4,L
691
+ 5,5,5,5,B
692
+ %
693
+ %
694
+ %
iris.csv ADDED
@@ -0,0 +1,151 @@
1
+ 5.1,3.5,1.4,0.2,Iris-setosa
2
+ 4.9,3.0,1.4,0.2,Iris-setosa
3
+ 4.7,3.2,1.3,0.2,Iris-setosa
4
+ 4.6,3.1,1.5,0.2,Iris-setosa
5
+ 5.0,3.6,1.4,0.2,Iris-setosa
6
+ 5.4,3.9,1.7,0.4,Iris-setosa
7
+ 4.6,3.4,1.4,0.3,Iris-setosa
8
+ 5.0,3.4,1.5,0.2,Iris-setosa
9
+ 4.4,2.9,1.4,0.2,Iris-setosa
10
+ 4.9,3.1,1.5,0.1,Iris-setosa
11
+ 5.4,3.7,1.5,0.2,Iris-setosa
12
+ 4.8,3.4,1.6,0.2,Iris-setosa
13
+ 4.8,3.0,1.4,0.1,Iris-setosa
14
+ 4.3,3.0,1.1,0.1,Iris-setosa
15
+ 5.8,4.0,1.2,0.2,Iris-setosa
16
+ 5.7,4.4,1.5,0.4,Iris-setosa
17
+ 5.4,3.9,1.3,0.4,Iris-setosa
18
+ 5.1,3.5,1.4,0.3,Iris-setosa
19
+ 5.7,3.8,1.7,0.3,Iris-setosa
20
+ 5.1,3.8,1.5,0.3,Iris-setosa
21
+ 5.4,3.4,1.7,0.2,Iris-setosa
22
+ 5.1,3.7,1.5,0.4,Iris-setosa
23
+ 4.6,3.6,1.0,0.2,Iris-setosa
24
+ 5.1,3.3,1.7,0.5,Iris-setosa
25
+ 4.8,3.4,1.9,0.2,Iris-setosa
26
+ 5.0,3.0,1.6,0.2,Iris-setosa
27
+ 5.0,3.4,1.6,0.4,Iris-setosa
28
+ 5.2,3.5,1.5,0.2,Iris-setosa
29
+ 5.2,3.4,1.4,0.2,Iris-setosa
30
+ 4.7,3.2,1.6,0.2,Iris-setosa
31
+ 4.8,3.1,1.6,0.2,Iris-setosa
32
+ 5.4,3.4,1.5,0.4,Iris-setosa
33
+ 5.2,4.1,1.5,0.1,Iris-setosa
34
+ 5.5,4.2,1.4,0.2,Iris-setosa
35
+ 4.9,3.1,1.5,0.1,Iris-setosa
36
+ 5.0,3.2,1.2,0.2,Iris-setosa
37
+ 5.5,3.5,1.3,0.2,Iris-setosa
38
+ 4.9,3.1,1.5,0.1,Iris-setosa
39
+ 4.4,3.0,1.3,0.2,Iris-setosa
40
+ 5.1,3.4,1.5,0.2,Iris-setosa
41
+ 5.0,3.5,1.3,0.3,Iris-setosa
42
+ 4.5,2.3,1.3,0.3,Iris-setosa
43
+ 4.4,3.2,1.3,0.2,Iris-setosa
44
+ 5.0,3.5,1.6,0.6,Iris-setosa
45
+ 5.1,3.8,1.9,0.4,Iris-setosa
46
+ 4.8,3.0,1.4,0.3,Iris-setosa
47
+ 5.1,3.8,1.6,0.2,Iris-setosa
48
+ 4.6,3.2,1.4,0.2,Iris-setosa
49
+ 5.3,3.7,1.5,0.2,Iris-setosa
50
+ 5.0,3.3,1.4,0.2,Iris-setosa
51
+ 7.0,3.2,4.7,1.4,Iris-versicolor
52
+ 6.4,3.2,4.5,1.5,Iris-versicolor
53
+ 6.9,3.1,4.9,1.5,Iris-versicolor
54
+ 5.5,2.3,4.0,1.3,Iris-versicolor
55
+ 6.5,2.8,4.6,1.5,Iris-versicolor
56
+ 5.7,2.8,4.5,1.3,Iris-versicolor
57
+ 6.3,3.3,4.7,1.6,Iris-versicolor
58
+ 4.9,2.4,3.3,1.0,Iris-versicolor
59
+ 6.6,2.9,4.6,1.3,Iris-versicolor
60
+ 5.2,2.7,3.9,1.4,Iris-versicolor
61
+ 5.0,2.0,3.5,1.0,Iris-versicolor
62
+ 5.9,3.0,4.2,1.5,Iris-versicolor
63
+ 6.0,2.2,4.0,1.0,Iris-versicolor
64
+ 6.1,2.9,4.7,1.4,Iris-versicolor
65
+ 5.6,2.9,3.6,1.3,Iris-versicolor
66
+ 6.7,3.1,4.4,1.4,Iris-versicolor
67
+ 5.6,3.0,4.5,1.5,Iris-versicolor
68
+ 5.8,2.7,4.1,1.0,Iris-versicolor
69
+ 6.2,2.2,4.5,1.5,Iris-versicolor
70
+ 5.6,2.5,3.9,1.1,Iris-versicolor
71
+ 5.9,3.2,4.8,1.8,Iris-versicolor
72
+ 6.1,2.8,4.0,1.3,Iris-versicolor
73
+ 6.3,2.5,4.9,1.5,Iris-versicolor
74
+ 6.1,2.8,4.7,1.2,Iris-versicolor
75
+ 6.4,2.9,4.3,1.3,Iris-versicolor
76
+ 6.6,3.0,4.4,1.4,Iris-versicolor
77
+ 6.8,2.8,4.8,1.4,Iris-versicolor
78
+ 6.7,3.0,5.0,1.7,Iris-versicolor
79
+ 6.0,2.9,4.5,1.5,Iris-versicolor
80
+ 5.7,2.6,3.5,1.0,Iris-versicolor
81
+ 5.5,2.4,3.8,1.1,Iris-versicolor
82
+ 5.5,2.4,3.7,1.0,Iris-versicolor
83
+ 5.8,2.7,3.9,1.2,Iris-versicolor
84
+ 6.0,2.7,5.1,1.6,Iris-versicolor
85
+ 5.4,3.0,4.5,1.5,Iris-versicolor
86
+ 6.0,3.4,4.5,1.6,Iris-versicolor
87
+ 6.7,3.1,4.7,1.5,Iris-versicolor
88
+ 6.3,2.3,4.4,1.3,Iris-versicolor
89
+ 5.6,3.0,4.1,1.3,Iris-versicolor
90
+ 5.5,2.5,4.0,1.3,Iris-versicolor
91
+ 5.5,2.6,4.4,1.2,Iris-versicolor
92
+ 6.1,3.0,4.6,1.4,Iris-versicolor
93
+ 5.8,2.6,4.0,1.2,Iris-versicolor
94
+ 5.0,2.3,3.3,1.0,Iris-versicolor
95
+ 5.6,2.7,4.2,1.3,Iris-versicolor
96
+ 5.7,3.0,4.2,1.2,Iris-versicolor
97
+ 5.7,2.9,4.2,1.3,Iris-versicolor
98
+ 6.2,2.9,4.3,1.3,Iris-versicolor
99
+ 5.1,2.5,3.0,1.1,Iris-versicolor
100
+ 5.7,2.8,4.1,1.3,Iris-versicolor
101
+ 6.3,3.3,6.0,2.5,Iris-virginica
102
+ 5.8,2.7,5.1,1.9,Iris-virginica
103
+ 7.1,3.0,5.9,2.1,Iris-virginica
104
+ 6.3,2.9,5.6,1.8,Iris-virginica
105
+ 6.5,3.0,5.8,2.2,Iris-virginica
106
+ 7.6,3.0,6.6,2.1,Iris-virginica
107
+ 4.9,2.5,4.5,1.7,Iris-virginica
108
+ 7.3,2.9,6.3,1.8,Iris-virginica
109
+ 6.7,2.5,5.8,1.8,Iris-virginica
110
+ 7.2,3.6,6.1,2.5,Iris-virginica
111
+ 6.5,3.2,5.1,2.0,Iris-virginica
112
+ 6.4,2.7,5.3,1.9,Iris-virginica
113
+ 6.8,3.0,5.5,2.1,Iris-virginica
114
+ 5.7,2.5,5.0,2.0,Iris-virginica
115
+ 5.8,2.8,5.1,2.4,Iris-virginica
116
+ 6.4,3.2,5.3,2.3,Iris-virginica
117
+ 6.5,3.0,5.5,1.8,Iris-virginica
118
+ 7.7,3.8,6.7,2.2,Iris-virginica
119
+ 7.7,2.6,6.9,2.3,Iris-virginica
120
+ 6.0,2.2,5.0,1.5,Iris-virginica
121
+ 6.9,3.2,5.7,2.3,Iris-virginica
122
+ 5.6,2.8,4.9,2.0,Iris-virginica
123
+ 7.7,2.8,6.7,2.0,Iris-virginica
124
+ 6.3,2.7,4.9,1.8,Iris-virginica
125
+ 6.7,3.3,5.7,2.1,Iris-virginica
126
+ 7.2,3.2,6.0,1.8,Iris-virginica
127
+ 6.2,2.8,4.8,1.8,Iris-virginica
128
+ 6.1,3.0,4.9,1.8,Iris-virginica
129
+ 6.4,2.8,5.6,2.1,Iris-virginica
130
+ 7.2,3.0,5.8,1.6,Iris-virginica
131
+ 7.4,2.8,6.1,1.9,Iris-virginica
132
+ 7.9,3.8,6.4,2.0,Iris-virginica
133
+ 6.4,2.8,5.6,2.2,Iris-virginica
134
+ 6.3,2.8,5.1,1.5,Iris-virginica
135
+ 6.1,2.6,5.6,1.4,Iris-virginica
136
+ 7.7,3.0,6.1,2.3,Iris-virginica
137
+ 6.3,3.4,5.6,2.4,Iris-virginica
138
+ 6.4,3.1,5.5,1.8,Iris-virginica
139
+ 6.0,3.0,4.8,1.8,Iris-virginica
140
+ 6.9,3.1,5.4,2.1,Iris-virginica
141
+ 6.7,3.1,5.6,2.4,Iris-virginica
142
+ 6.9,3.1,5.1,2.3,Iris-virginica
143
+ 5.8,2.7,5.1,1.9,Iris-virginica
144
+ 6.8,3.2,5.9,2.3,Iris-virginica
145
+ 6.7,3.3,5.7,2.5,Iris-virginica
146
+ 6.7,3.0,5.2,2.3,Iris-virginica
147
+ 6.3,2.5,5.0,1.9,Iris-virginica
148
+ 6.5,3.0,5.2,2.0,Iris-virginica
149
+ 6.2,3.4,5.4,2.3,Iris-virginica
150
+ 5.9,3.0,5.1,1.8,Iris-virginica
151
+
requirements.txt ADDED
@@ -0,0 +1,16 @@
1
+ # Please use Python 3.7 to be compatible with all packages
2
+ gpytorch==1.5.0
3
+ torch==1.9.0
4
+ scikit-learn==0.24.2
5
+ pyyaml==5.4.1
6
+ seaborn==0.11.2
7
+ xgboost==1.4.0
8
+ tqdm==4.62.1
9
+ numpy==1.21.2
10
+ openml==0.12.2
11
+ catboost==0.26.1
12
+ auto-sklearn==0.14.5
13
+ hyperopt==0.2.5
14
+ configspace==0.4.21
15
+ # autogluon==0.4.0
16
+ gradio==3.1.1