Spaces:
Build error
Build error
Samuel Mueller
commited on
Commit
•
e487255
0
Parent(s):
init
Browse files- .gitattributes +33 -0
- .gitmodules +0 -0
- README.md +12 -0
- TabPFN/PrepareDatasets.ipynb +373 -0
- TabPFN/README.md +23 -0
- TabPFN/SyntheticGPAblation.ipynb +392 -0
- TabPFN/TabPFNPredictionOnly.ipynb +253 -0
- TabPFN/TabularEvaluationVisualization.ipynb +0 -0
- TabPFN/TrainingTuningAndPrediction.ipynb +0 -0
- TabPFN/datasets/__init__.py +149 -0
- TabPFN/datasets/utils.py +8 -0
- TabPFN/decoders.py +30 -0
- TabPFN/differentiable_pfn_evaluation.py +345 -0
- TabPFN/encoders.py +225 -0
- TabPFN/initializers.py +9 -0
- TabPFN/layer.py +125 -0
- TabPFN/losses.py +41 -0
- TabPFN/model_builder.py +273 -0
- TabPFN/models_diff/gp_ablation_model.cpkt +3 -0
- TabPFN/models_diff/prior_diff_real_checkpoint_n_8x_lr0.0003_epoch_49.cpkt +3 -0
- TabPFN/notebook_utils.py +32 -0
- TabPFN/positional_encodings.py +70 -0
- TabPFN/prior_tuning_result.pkl +3 -0
- TabPFN/priors/__init__.py +4 -0
- TabPFN/priors/differentiable_prior.py +293 -0
- TabPFN/priors/fast_gp.py +144 -0
- TabPFN/priors/flexible_categorical.py +240 -0
- TabPFN/priors/mlp.py +173 -0
- TabPFN/priors/prior.py +12 -0
- TabPFN/priors/prior_bag.py +32 -0
- TabPFN/priors/utils.py +163 -0
- TabPFN/requirements.txt +15 -0
- TabPFN/scripts/baseline_prediction_interface.py +38 -0
- TabPFN/scripts/differentiable_pfn_evaluation.py +391 -0
- TabPFN/scripts/model_configs.py +210 -0
- TabPFN/scripts/tabular_baselines.py +421 -0
- TabPFN/scripts/tabular_evaluation.py +284 -0
- TabPFN/scripts/tabular_metrics.py +181 -0
- TabPFN/scripts/transformer_prediction_interface.py +357 -0
- TabPFN/tabular_evaluation.py +283 -0
- TabPFN/train.py +386 -0
- TabPFN/transformer.py +226 -0
- TabPFN/utils.py +236 -0
- app.py +96 -0
- balance-scale.arff +694 -0
- iris.csv +151 -0
- requirements.txt +16 -0
.gitattributes
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
6 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
7 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
8 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
9 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
10 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
11 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
12 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
13 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
14 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
15 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
16 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
17 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
18 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
21 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
22 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
23 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
24 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
25 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
26 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
27 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
28 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
29 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
30 |
+
*.zstandard filter=lfs diff=lfs merge=lfs -text
|
31 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
32 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
33 |
+
*.cpkt filter=lfs diff=lfs merge=lfs -text
|
.gitmodules
ADDED
File without changes
|
README.md
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
title: TabPFN
|
3 |
+
emoji: 🐨
|
4 |
+
colorFrom: gray
|
5 |
+
colorTo: blue
|
6 |
+
sdk: gradio
|
7 |
+
sdk_version: 3.1.1
|
8 |
+
app_file: app.py
|
9 |
+
pinned: true
|
10 |
+
---
|
11 |
+
|
12 |
+
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
TabPFN/PrepareDatasets.ipynb
ADDED
@@ -0,0 +1,373 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": 1,
|
6 |
+
"metadata": {},
|
7 |
+
"outputs": [],
|
8 |
+
"source": [
|
9 |
+
"import numpy as np\n",
|
10 |
+
"\n",
|
11 |
+
"import openml\n",
|
12 |
+
"import pandas as pd"
|
13 |
+
]
|
14 |
+
},
|
15 |
+
{
|
16 |
+
"cell_type": "code",
|
17 |
+
"execution_count": 2,
|
18 |
+
"metadata": {},
|
19 |
+
"outputs": [],
|
20 |
+
"source": [
|
21 |
+
"from tqdm import tqdm\n",
|
22 |
+
"\n",
|
23 |
+
"from datasets import load_openml_list, test_dids_classification, valid_large_classification, open_cc_dids, open_cc_valid_dids\n"
|
24 |
+
]
|
25 |
+
},
|
26 |
+
{
|
27 |
+
"cell_type": "code",
|
28 |
+
"execution_count": 6,
|
29 |
+
"metadata": {},
|
30 |
+
"outputs": [
|
31 |
+
{
|
32 |
+
"name": "stdout",
|
33 |
+
"output_type": "stream",
|
34 |
+
"text": [
|
35 |
+
"The autoreload extension is already loaded. To reload it, use:\n",
|
36 |
+
" %reload_ext autoreload\n"
|
37 |
+
]
|
38 |
+
}
|
39 |
+
],
|
40 |
+
"source": [
|
41 |
+
"%load_ext autoreload\n",
|
42 |
+
"\n",
|
43 |
+
"%autoreload 2"
|
44 |
+
]
|
45 |
+
},
|
46 |
+
{
|
47 |
+
"cell_type": "markdown",
|
48 |
+
"metadata": {
|
49 |
+
"tags": []
|
50 |
+
},
|
51 |
+
"source": [
|
52 |
+
"### Prepare test datasets"
|
53 |
+
]
|
54 |
+
},
|
55 |
+
{
|
56 |
+
"cell_type": "code",
|
57 |
+
"execution_count": 7,
|
58 |
+
"metadata": {},
|
59 |
+
"outputs": [],
|
60 |
+
"source": [
|
61 |
+
"renamer = {'name': 'Name', 'NumberOfFeatures': '# Features', 'NumberOfSymbolicFeatures': '# Categorical Features', 'NumberOfInstances': '# Instances', 'NumberOfMissingValues': '# NaNs', 'NumberOfClasses': '# Classes', 'MinorityClassSize': 'Minority Class Size'}\n"
|
62 |
+
]
|
63 |
+
},
|
64 |
+
{
|
65 |
+
"cell_type": "code",
|
66 |
+
"execution_count": 8,
|
67 |
+
"metadata": {},
|
68 |
+
"outputs": [
|
69 |
+
{
|
70 |
+
"data": {
|
71 |
+
"text/plain": [
|
72 |
+
"OrderedDict([(99,\n",
|
73 |
+
" {'id': 99,\n",
|
74 |
+
" 'alias': 'OpenML-CC18',\n",
|
75 |
+
" 'main_entity_type': 'task',\n",
|
76 |
+
" 'name': 'OpenML-CC18 Curated Classification benchmark',\n",
|
77 |
+
" 'status': 'active',\n",
|
78 |
+
" 'creation_date': '2019-02-21 18:47:13',\n",
|
79 |
+
" 'creator': 1}),\n",
|
80 |
+
" (225,\n",
|
81 |
+
" {'id': 225,\n",
|
82 |
+
" 'alias': 'OpenML-friendly',\n",
|
83 |
+
" 'main_entity_type': 'task',\n",
|
84 |
+
" 'name': 'OpenML100-friendly',\n",
|
85 |
+
" 'status': 'active',\n",
|
86 |
+
" 'creation_date': '2019-09-16 19:41:46',\n",
|
87 |
+
" 'creator': 1})])"
|
88 |
+
]
|
89 |
+
},
|
90 |
+
"execution_count": 8,
|
91 |
+
"metadata": {},
|
92 |
+
"output_type": "execute_result"
|
93 |
+
}
|
94 |
+
],
|
95 |
+
"source": [
|
96 |
+
"openml.study.list_suites()"
|
97 |
+
]
|
98 |
+
},
|
99 |
+
{
|
100 |
+
"cell_type": "code",
|
101 |
+
"execution_count": 9,
|
102 |
+
"metadata": {},
|
103 |
+
"outputs": [],
|
104 |
+
"source": [
|
105 |
+
"suite = openml.study.get_suite(suite_id=99)\n",
|
106 |
+
"tasks = openml.tasks.list_tasks(output_format=\"dataframe\")"
|
107 |
+
]
|
108 |
+
},
|
109 |
+
{
|
110 |
+
"cell_type": "code",
|
111 |
+
"execution_count": 10,
|
112 |
+
"metadata": {},
|
113 |
+
"outputs": [],
|
114 |
+
"source": [
|
115 |
+
"# Using ``@`` in `pd.DataFrame.query <\n",
|
116 |
+
"# https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.query.html>`_\n",
|
117 |
+
"# accesses variables outside of the current dataframe.\n",
|
118 |
+
"tasks = tasks.query(\"tid in @suite.tasks\")"
|
119 |
+
]
|
120 |
+
},
|
121 |
+
{
|
122 |
+
"cell_type": "code",
|
123 |
+
"execution_count": 11,
|
124 |
+
"metadata": {},
|
125 |
+
"outputs": [],
|
126 |
+
"source": [
|
127 |
+
"tids = list(tasks[np.logical_and(np.logical_and((tasks.NumberOfInstances <= 2000), (tasks.NumberOfFeatures <= 100))\n",
|
128 |
+
" , (tasks.NumberOfClasses <= 10))].tid)"
|
129 |
+
]
|
130 |
+
},
|
131 |
+
{
|
132 |
+
"cell_type": "code",
|
133 |
+
"execution_count": 12,
|
134 |
+
"metadata": {},
|
135 |
+
"outputs": [
|
136 |
+
{
|
137 |
+
"data": {
|
138 |
+
"text/plain": [
|
139 |
+
"30"
|
140 |
+
]
|
141 |
+
},
|
142 |
+
"execution_count": 12,
|
143 |
+
"metadata": {},
|
144 |
+
"output_type": "execute_result"
|
145 |
+
}
|
146 |
+
],
|
147 |
+
"source": [
|
148 |
+
"len(tids)"
|
149 |
+
]
|
150 |
+
},
|
151 |
+
{
|
152 |
+
"cell_type": "code",
|
153 |
+
"execution_count": 13,
|
154 |
+
"metadata": {},
|
155 |
+
"outputs": [],
|
156 |
+
"source": [
|
157 |
+
"tids = list(tasks[tasks.NumberOfInstances <= 2000].tid)"
|
158 |
+
]
|
159 |
+
},
|
160 |
+
{
|
161 |
+
"cell_type": "code",
|
162 |
+
"execution_count": 14,
|
163 |
+
"metadata": {},
|
164 |
+
"outputs": [],
|
165 |
+
"source": [
|
166 |
+
"open_cc_dids = [openml.tasks.get_task(task_id).get_dataset().id for task_id in tids]"
|
167 |
+
]
|
168 |
+
},
|
169 |
+
{
|
170 |
+
"cell_type": "code",
|
171 |
+
"execution_count": null,
|
172 |
+
"outputs": [],
|
173 |
+
"source": [
|
174 |
+
"open_ml_datasets, open_ml_datasets_df = load_openml_list(test_dids_classification, multiclass=True, shuffled=True, filter_for_nan=False, max_samples = 100000, num_feats=100, return_capped=True)\n"
|
175 |
+
],
|
176 |
+
"metadata": {
|
177 |
+
"collapsed": false,
|
178 |
+
"pycharm": {
|
179 |
+
"name": "#%%\n"
|
180 |
+
}
|
181 |
+
}
|
182 |
+
},
|
183 |
+
{
|
184 |
+
"cell_type": "code",
|
185 |
+
"execution_count": 16,
|
186 |
+
"metadata": {},
|
187 |
+
"outputs": [],
|
188 |
+
"source": [
|
189 |
+
"open_ml_datasets_df = open_ml_datasets_df[open_ml_datasets_df.NumberOfInstances > 10000]"
|
190 |
+
]
|
191 |
+
},
|
192 |
+
{
|
193 |
+
"cell_type": "code",
|
194 |
+
"execution_count": 17,
|
195 |
+
"metadata": {},
|
196 |
+
"outputs": [
|
197 |
+
{
|
198 |
+
"name": "stdout",
|
199 |
+
"output_type": "stream",
|
200 |
+
"text": [
|
201 |
+
"\\begin{tabular}{lrrrrrrr}\n",
|
202 |
+
"\\toprule\n",
|
203 |
+
" Name & \\# Features & \\# Categorical Features & \\# Instances & \\# Classes & \\# NaNs & Minority Class Size & id \\\\\n",
|
204 |
+
"\\midrule\n",
|
205 |
+
" KDDCup09\\_appetency & 231 & 39 & 50000 & 2 & 8024152 & 890 & 1111 \\\\\n",
|
206 |
+
" airlines & 8 & 5 & 539383 & 2 & 0 & 240264 & 1169 \\\\\n",
|
207 |
+
" bank-marketing & 17 & 10 & 45211 & 2 & 0 & 5289 & 1461 \\\\\n",
|
208 |
+
" nomao & 119 & 30 & 34465 & 2 & 0 & 9844 & 1486 \\\\\n",
|
209 |
+
" adult & 15 & 9 & 48842 & 2 & 6465 & 11687 & 1590 \\\\\n",
|
210 |
+
" covertype & 55 & 45 & 581012 & 7 & 0 & 2747 & 1596 \\\\\n",
|
211 |
+
" numerai28.6 & 22 & 1 & 96320 & 2 & 0 & 47662 & 23517 \\\\\n",
|
212 |
+
" connect-4 & 43 & 43 & 67557 & 3 & 0 & 6449 & 40668 \\\\\n",
|
213 |
+
"jungle\\_chess\\_2pcs\\_raw\\_endgame\\_complete & 7 & 1 & 44819 & 3 & 0 & 4335 & 41027 \\\\\n",
|
214 |
+
" APSFailure & 171 & 1 & 76000 & 2 & 1078695 & 1375 & 41138 \\\\\n",
|
215 |
+
" albert & 79 & 53 & 425240 & 2 & 2734000 & 212620 & 41147 \\\\\n",
|
216 |
+
" MiniBooNE & 51 & 1 & 130064 & 2 & 0 & 36499 & 41150 \\\\\n",
|
217 |
+
" guillermo & 4297 & 1 & 20000 & 2 & 0 & 8003 & 41159 \\\\\n",
|
218 |
+
" riccardo & 4297 & 1 & 20000 & 2 & 0 & 5000 & 41161 \\\\\n",
|
219 |
+
" volkert & 181 & 1 & 58310 & 10 & 0 & 1361 & 41166 \\\\\n",
|
220 |
+
" dionis & 61 & 1 & 416188 & 355 & 0 & 878 & 41167 \\\\\n",
|
221 |
+
" jannis & 55 & 1 & 83733 & 4 & 0 & 1687 & 41168 \\\\\n",
|
222 |
+
" helena & 28 & 1 & 65196 & 100 & 0 & 111 & 41169 \\\\\n",
|
223 |
+
"\\bottomrule\n",
|
224 |
+
"\\end{tabular}\n",
|
225 |
+
"\n"
|
226 |
+
]
|
227 |
+
}
|
228 |
+
],
|
229 |
+
"source": [
|
230 |
+
"print_table = open_ml_datasets_df\n",
|
231 |
+
"print_table = print_table[['name', 'NumberOfFeatures', 'NumberOfSymbolicFeatures', 'NumberOfInstances', 'NumberOfClasses', 'NumberOfMissingValues', 'MinorityClassSize']].copy()\n",
|
232 |
+
"print_table['id'] = print_table.index\n",
|
233 |
+
"print_table[['NumberOfFeatures', 'NumberOfSymbolicFeatures', 'NumberOfInstances', 'NumberOfClasses', 'NumberOfMissingValues', 'MinorityClassSize']] = print_table[['NumberOfFeatures', 'NumberOfSymbolicFeatures', 'NumberOfInstances', 'NumberOfClasses', 'NumberOfMissingValues', 'MinorityClassSize']].astype(int)\n",
|
234 |
+
"print_table = print_table.rename(columns=renamer)\n",
|
235 |
+
"print(print_table.to_latex(index=False))"
|
236 |
+
]
|
237 |
+
},
|
238 |
+
{
|
239 |
+
"cell_type": "markdown",
|
240 |
+
"metadata": {
|
241 |
+
"tags": []
|
242 |
+
},
|
243 |
+
"source": [
|
244 |
+
"### Prepare Validation datasets"
|
245 |
+
]
|
246 |
+
},
|
247 |
+
{
|
248 |
+
"cell_type": "code",
|
249 |
+
"execution_count": null,
|
250 |
+
"outputs": [],
|
251 |
+
"source": [
|
252 |
+
"open_cc_datasets, open_cc_datasets_df = load_openml_list(open_cc_dids, multiclass=True, shuffled=True, filter_for_nan=False, max_samples = 2000, num_feats=100, return_capped=True)\n",
|
253 |
+
"\n",
|
254 |
+
"def extend_datasets(datasets, filtering = False):\n",
|
255 |
+
" extended_datasets = {}\n",
|
256 |
+
" i = 0\n",
|
257 |
+
" for d in tqdm(datasets):\n",
|
258 |
+
" if ((not 'NumberOfFeatures' in datasets[d])\n",
|
259 |
+
" or (not 'NumberOfClasses' in datasets[d])\n",
|
260 |
+
" or (not 'NumberOfInstances' in datasets[d])\n",
|
261 |
+
" # or datasets[d]['NumberOfFeatures'] >= num_feats\n",
|
262 |
+
" or datasets[d]['NumberOfClasses'] <= 0):\n",
|
263 |
+
" print(datasets[d])\n",
|
264 |
+
" continue\n",
|
265 |
+
" ds = openml.datasets.get_dataset(d, download_data=False)\n",
|
266 |
+
" if filtering and (datasets[d]['NumberOfInstances'] < 150\n",
|
267 |
+
" or datasets[d]['NumberOfInstances'] > 2000\n",
|
268 |
+
" or datasets[d]['NumberOfFeatures'] > 100\n",
|
269 |
+
" or datasets[d]['NumberOfClasses'] > 10):\n",
|
270 |
+
" continue\n",
|
271 |
+
" extended_datasets[d] = datasets[d]\n",
|
272 |
+
" extended_datasets[d].update(ds.qualities)\n",
|
273 |
+
" \n",
|
274 |
+
" return extended_datasets\n",
|
275 |
+
"\n",
|
276 |
+
"# All datasets\n",
|
277 |
+
"openml_list = openml.datasets.list_datasets()\n",
|
278 |
+
"openml_list = pd.DataFrame.from_dict(openml_list, orient=\"index\")\n",
|
279 |
+
"\n",
|
280 |
+
"# Select only classification\n",
|
281 |
+
"openml_list = openml_list[~openml_list['MajorityClassSize'].isna()]\n",
|
282 |
+
"\n",
|
283 |
+
"# Remove duplicated datasets\n",
|
284 |
+
"duplicated = openml_list.duplicated(subset=['MajorityClassSize', 'MaxNominalAttDistinctValues', 'MinorityClassSize',\n",
|
285 |
+
" 'NumberOfClasses', 'NumberOfFeatures', 'NumberOfInstances',\n",
|
286 |
+
" 'NumberOfInstancesWithMissingValues', 'NumberOfMissingValues',\n",
|
287 |
+
" 'NumberOfNumericFeatures', 'NumberOfSymbolicFeatures'], keep='first')\n",
|
288 |
+
"openml_list = openml_list[~duplicated]\n",
|
289 |
+
"\n",
|
290 |
+
"duplicated = openml_list.duplicated(subset=['name'], keep='first')\n",
|
291 |
+
"openml_list = openml_list[~duplicated]\n",
|
292 |
+
"\n",
|
293 |
+
"# Filter out datasets that don't have meta information or Don't fulfill other criteria\n",
|
294 |
+
"openml_list = openml_list.to_dict(orient='index')\n",
|
295 |
+
"openml_list = pd.DataFrame.from_dict(extend_datasets(openml_list, filtering=True), orient=\"index\")\n",
|
296 |
+
"\n",
|
297 |
+
"# Filter out datasets in Open CC\n",
|
298 |
+
"openml_list = openml_list[~openml_list.name.apply(lambda x: x in test_datasets_multiclass_df.name.values)]\n",
|
299 |
+
"openml_list['CFI'] = openml_list.apply(lambda x: str(x.NumberOfClasses) + '_' + str(x.NumberOfFeatures) + '_' + str(x.NumberOfInstances), axis = 1)\n",
|
300 |
+
"test_datasets_multiclass_df['CFI'] = test_datasets_multiclass_df.apply(lambda x: str(x.NumberOfClasses) + '_' + str(x.NumberOfFeatures) + '_' + str(x.NumberOfInstances), axis = 1)\n",
|
301 |
+
"openml_list = openml_list[~openml_list.CFI.apply(lambda x: x in test_datasets_multiclass_df.CFI.values)]\n",
|
302 |
+
"\n",
|
303 |
+
"# Remove time series and artificial data\n",
|
304 |
+
"openml_list = openml_list[~openml_list.name.apply(lambda x: 'autoUniv' in x)]\n",
|
305 |
+
"openml_list = openml_list[~openml_list.name.apply(lambda x: 'fri_' in x)]\n",
|
306 |
+
"openml_list = openml_list[~openml_list.name.apply(lambda x: 'FOREX' in x)]\n",
|
307 |
+
"\n",
|
308 |
+
"# Remove datasets that overlapped with Open CC closely by name\n",
|
309 |
+
"openml_list = openml_list[~openml_list.name.apply(lambda x: 'ilpd' in x)]\n",
|
310 |
+
"openml_list = openml_list[~openml_list.name.apply(lambda x: 'car' in x)]\n",
|
311 |
+
"openml_list = openml_list[~openml_list.name.apply(lambda x: 'pc1' in x)]\n",
|
312 |
+
"\n",
|
313 |
+
"# Remove datasets that didn't load\n",
|
314 |
+
"openml_list = openml_list[~openml_list.did.apply(lambda x: x in {1065, 40589, 41496, 770, 43097, 43148, 43255, 43595, 43786, 41701})]\n",
|
315 |
+
"\n",
|
316 |
+
"# Remove class skew\n",
|
317 |
+
"openml_list = openml_list[(openml_list.MinorityClassSize / openml_list.MajorityClassSize) > 0.05]\n",
|
318 |
+
"openml_list = openml_list[openml_list.AutoCorrelation != 1]\n",
|
319 |
+
"\n",
|
320 |
+
"# Remove too easy\n",
|
321 |
+
"openml_list = openml_list[openml_list.CfsSubsetEval_DecisionStumpAUC != 1]"
|
322 |
+
],
|
323 |
+
"metadata": {
|
324 |
+
"collapsed": false,
|
325 |
+
"pycharm": {
|
326 |
+
"name": "#%%\n"
|
327 |
+
}
|
328 |
+
}
|
329 |
+
},
|
330 |
+
{
|
331 |
+
"cell_type": "code",
|
332 |
+
"execution_count": null,
|
333 |
+
"metadata": {},
|
334 |
+
"outputs": [],
|
335 |
+
"source": [
|
336 |
+
"print_table = openml_list\n",
|
337 |
+
"print_table = print_table[['name', 'NumberOfFeatures', 'NumberOfSymbolicFeatures', 'NumberOfInstances', 'NumberOfClasses', 'NumberOfMissingValues', 'MinorityClassSize']].copy()\n",
|
338 |
+
"print_table['id'] = print_table.index\n",
|
339 |
+
"print_table[['NumberOfFeatures', 'NumberOfSymbolicFeatures', 'NumberOfInstances', 'NumberOfClasses', 'NumberOfMissingValues', 'MinorityClassSize']] = print_table[['NumberOfFeatures', 'NumberOfSymbolicFeatures', 'NumberOfInstances', 'NumberOfClasses', 'NumberOfMissingValues', 'MinorityClassSize']].astype(int)\n",
|
340 |
+
"print_table = print_table.rename(columns=renamer)\n",
|
341 |
+
"print(print_table.to_latex(index=False))"
|
342 |
+
]
|
343 |
+
},
|
344 |
+
{
|
345 |
+
"cell_type": "code",
|
346 |
+
"execution_count": null,
|
347 |
+
"metadata": {},
|
348 |
+
"outputs": [],
|
349 |
+
"source": []
|
350 |
+
}
|
351 |
+
],
|
352 |
+
"metadata": {
|
353 |
+
"kernelspec": {
|
354 |
+
"display_name": "Python 3 (ipykernel)",
|
355 |
+
"language": "python",
|
356 |
+
"name": "python3"
|
357 |
+
},
|
358 |
+
"language_info": {
|
359 |
+
"codemirror_mode": {
|
360 |
+
"name": "ipython",
|
361 |
+
"version": 3
|
362 |
+
},
|
363 |
+
"file_extension": ".py",
|
364 |
+
"mimetype": "text/x-python",
|
365 |
+
"name": "python",
|
366 |
+
"nbconvert_exporter": "python",
|
367 |
+
"pygments_lexer": "ipython3",
|
368 |
+
"version": "3.7.13"
|
369 |
+
}
|
370 |
+
},
|
371 |
+
"nbformat": 4,
|
372 |
+
"nbformat_minor": 4
|
373 |
+
}
|
TabPFN/README.md
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# TabPFN
|
2 |
+
|
3 |
+
## Installation
|
4 |
+
```
|
5 |
+
git clone [email protected]:automl/TabPFN.git
|
6 |
+
cd TabPFN
|
7 |
+
conda create -n TabPFN python=3.7
|
8 |
+
conda activate TabPFN
|
9 |
+
pip install -r requirements.txt
|
10 |
+
```
|
11 |
+
|
12 |
+
To run the autogluon baseline please create a separate environment and install autogluon==0.4.0, installation in the same environment as our other baselines is not possible.
|
13 |
+
|
14 |
+
## Usage
|
15 |
+
TrainingTuningAndPrediction: Train a TabPFN, Prior Tune and predict using a pretrained model.
|
16 |
+
|
17 |
+
TabularEvaluationVisualization: Run Baselines and load Baseline and TabPFN Results for comparison and plotting.
|
18 |
+
|
19 |
+
PrepareDatasets: Notebook used to inspect Datasets (Not needed to run baselines / TabPFN).
|
20 |
+
|
21 |
+
SytheticGPAblation: Ablation experiments for Gaussian Process fitting with differentiable Hyper Parameters.
|
22 |
+
|
23 |
+
|
TabPFN/SyntheticGPAblation.ipynb
ADDED
@@ -0,0 +1,392 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": 1,
|
6 |
+
"metadata": {},
|
7 |
+
"outputs": [],
|
8 |
+
"source": [
|
9 |
+
"%load_ext autoreload\n",
|
10 |
+
"\n",
|
11 |
+
"%autoreload 2"
|
12 |
+
]
|
13 |
+
},
|
14 |
+
{
|
15 |
+
"cell_type": "code",
|
16 |
+
"execution_count": 2,
|
17 |
+
"metadata": {},
|
18 |
+
"outputs": [],
|
19 |
+
"source": [
|
20 |
+
"import os\n",
|
21 |
+
"import time\n",
|
22 |
+
"\n",
|
23 |
+
"import torch\n",
|
24 |
+
"\n",
|
25 |
+
"import numpy as np\n",
|
26 |
+
"\n",
|
27 |
+
"import matplotlib.pyplot as plt\n",
|
28 |
+
"\n",
|
29 |
+
"from model_builder import get_model, get_default_spec, save_model, load_model\n",
|
30 |
+
"\n",
|
31 |
+
"from scripts.model_configs import *"
|
32 |
+
]
|
33 |
+
},
|
34 |
+
{
|
35 |
+
"cell_type": "markdown",
|
36 |
+
"metadata": {
|
37 |
+
"tags": []
|
38 |
+
},
|
39 |
+
"source": [
|
40 |
+
"# Setting params"
|
41 |
+
]
|
42 |
+
},
|
43 |
+
{
|
44 |
+
"cell_type": "code",
|
45 |
+
"execution_count": 6,
|
46 |
+
"metadata": {},
|
47 |
+
"outputs": [],
|
48 |
+
"source": [
|
49 |
+
"device = 'cuda'\n",
|
50 |
+
"base_path = os.path.join('.')"
|
51 |
+
]
|
52 |
+
},
|
53 |
+
{
|
54 |
+
"cell_type": "code",
|
55 |
+
"execution_count": 7,
|
56 |
+
"metadata": {},
|
57 |
+
"outputs": [],
|
58 |
+
"source": [
|
59 |
+
"def train_function(config_sample, i, add_name=''):\n",
|
60 |
+
" start_time = time.time()\n",
|
61 |
+
" N_epochs_to_save = 50\n",
|
62 |
+
" \n",
|
63 |
+
" def save_callback(model, epoch):\n",
|
64 |
+
" if not hasattr(model, 'last_saved_epoch'):\n",
|
65 |
+
" model.last_saved_epoch = 0\n",
|
66 |
+
" if ((time.time() - start_time) / (maximum_runtime * 60 / N_epochs_to_save)) > model.last_saved_epoch:\n",
|
67 |
+
" print('Saving model..')\n",
|
68 |
+
" config_sample['epoch_in_training'] = epoch\n",
|
69 |
+
" save_model(model, base_path, f'models_diff/prior_diff_real_checkpoint{add_name}_n_{i}_epoch_{model.last_saved_epoch}.cpkt',\n",
|
70 |
+
" config_sample)\n",
|
71 |
+
" model.last_saved_epoch = model.last_saved_epoch + 1 # TODO: Rename to checkpoint\n",
|
72 |
+
" \n",
|
73 |
+
" model = get_model(config_sample\n",
|
74 |
+
" , device\n",
|
75 |
+
" , should_train=True\n",
|
76 |
+
" , verbose=1\n",
|
77 |
+
" , epoch_callback = save_callback)\n",
|
78 |
+
" \n",
|
79 |
+
" return"
|
80 |
+
]
|
81 |
+
},
|
82 |
+
{
|
83 |
+
"cell_type": "markdown",
|
84 |
+
"metadata": {
|
85 |
+
"heading_collapsed": true,
|
86 |
+
"tags": []
|
87 |
+
},
|
88 |
+
"source": [
|
89 |
+
"# Check synthetic data fitting"
|
90 |
+
]
|
91 |
+
},
|
92 |
+
{
|
93 |
+
"cell_type": "markdown",
|
94 |
+
"metadata": {
|
95 |
+
"tags": []
|
96 |
+
},
|
97 |
+
"source": [
|
98 |
+
"#### Workflow functions"
|
99 |
+
]
|
100 |
+
},
|
101 |
+
{
|
102 |
+
"cell_type": "code",
|
103 |
+
"execution_count": 8,
|
104 |
+
"metadata": {
|
105 |
+
"hidden": true,
|
106 |
+
"tags": []
|
107 |
+
},
|
108 |
+
"outputs": [],
|
109 |
+
"source": [
|
110 |
+
"def generate_test_data(test_gp_params):\n",
|
111 |
+
" # Generate test data\n",
|
112 |
+
" config = {**test_gp_params}\n",
|
113 |
+
"\n",
|
114 |
+
" config['verbose'] = False\n",
|
115 |
+
" config['differentiable'] = False\n",
|
116 |
+
" #config['bptt'] = config['bptt_in_training']\n",
|
117 |
+
"\n",
|
118 |
+
" model_test_data = get_model(config, device, should_train=False, verbose=True)\n",
|
119 |
+
" (hp_embedding, data, targets_), targets = next(iter(model_test_data[3]))\n",
|
120 |
+
" (hp_embedding, data, targets_), targets = (hp_embedding, data.to(device), targets_.to(device)), targets.to(device)\n",
|
121 |
+
" \n",
|
122 |
+
" return (hp_embedding, data, targets_), targets\n",
|
123 |
+
"\n",
|
124 |
+
"def evaluate_hp_range(model, hparam_true, vary_hparam_ind, data, targets, eval_pos, plot_step_size):\n",
|
125 |
+
" losses, hparams = [], []\n",
|
126 |
+
" for l in np.arange(-1.74, 1.74, plot_step_size):\n",
|
127 |
+
" hparam = [*hparam_true]\n",
|
128 |
+
" hparam[vary_hparam_ind] = l\n",
|
129 |
+
" hp_embedding_used = torch.tensor(hparam).to(device).float()\n",
|
130 |
+
" with torch.inference_mode():\n",
|
131 |
+
" outputs = torch.sigmoid(model[2]((hp_embedding_used.repeat(data.shape[1], 1), data, targets.float()), single_eval_pos=eval_pos)).squeeze(-1)\n",
|
132 |
+
" \n",
|
133 |
+
" loss = torch.nn.BCELoss()(outputs.flatten(), targets[eval_pos:].flatten()).detach().cpu()\n",
|
134 |
+
" losses += [loss]\n",
|
135 |
+
" hparam_real = [diff_hparams_f[i][1](hp) for i, hp in enumerate(hparam)]\n",
|
136 |
+
" hparams += [hparam_real]\n",
|
137 |
+
" \n",
|
138 |
+
" print(loss, hparam_real, hparam, outputs.shape)\n",
|
139 |
+
" return np.array(losses), np.array(hparams)"
|
140 |
+
]
|
141 |
+
},
|
142 |
+
{
|
143 |
+
"cell_type": "code",
|
144 |
+
"execution_count": 9,
|
145 |
+
"metadata": {},
|
146 |
+
"outputs": [],
|
147 |
+
"source": [
|
148 |
+
"def differentiable_hparam_tuning_workflow(config_sample, hparam_label, batch_size=4, N_grad_steps=50, plot_step_size=0.1):\n",
|
149 |
+
" test_gp_params = {\n",
|
150 |
+
" \"lengthscale\": 1.0,\n",
|
151 |
+
" #\"lengthscale_mean\": true_lengthscale,\n",
|
152 |
+
" #\"lengthscale_std\": 0.5,\n",
|
153 |
+
" \"noise\": 0.2,\n",
|
154 |
+
" \"outputscale\": 1.0,\n",
|
155 |
+
" 'batch_size': batch_size\n",
|
156 |
+
" }\n",
|
157 |
+
" config_sample.update(test_gp_params)\n",
|
158 |
+
" (hp_embedding, data, targets_), targets = generate_test_data(config_sample)\n",
|
159 |
+
" hparam_true = [diff_hparams_f[i][0](test_gp_params[hp]) for i, hp in enumerate(diff_hparams_keys)]\n",
|
160 |
+
" #hparam_true = [test_gp_params[hp] for i, hp in enumerate(diff_hparams_keys)]\n",
|
161 |
+
"\n",
|
162 |
+
" for vary_hparam_ind, vary_hparam_name in hparam_label:\n",
|
163 |
+
" print(vary_hparam_name)\n",
|
164 |
+
"\n",
|
165 |
+
" losses, hparams = evaluate_hp_range(model, hparam_true, vary_hparam_ind, data, targets, eval_pos, plot_step_size=plot_step_size)\n",
|
166 |
+
"\n",
|
167 |
+
" # TODO: Make only one parameter diffable\n",
|
168 |
+
" hparam = torch.tensor([*hparam_true]).to(device).float()\n",
|
169 |
+
" hparam[vary_hparam_ind] = hparam[vary_hparam_ind] + 0.1 #random.random() * 2 - 1\n",
|
170 |
+
" hparam = torch.nn.Parameter(hparam, requires_grad=True)\n",
|
171 |
+
" hparam_grad_mask = torch.zeros_like(hparam)\n",
|
172 |
+
" hparam_grad_mask[vary_hparam_ind] = 1\n",
|
173 |
+
"\n",
|
174 |
+
" optimizer = torch.optim.Adam([hparam], lr=0.1)\n",
|
175 |
+
" \n",
|
176 |
+
" for t in range(N_grad_steps):\n",
|
177 |
+
" style = hparam.repeat(data.shape[1], 1)\n",
|
178 |
+
" outputs = torch.sigmoid(model[2]((style, data, targets.float()), single_eval_pos=eval_pos)).squeeze(-1)\n",
|
179 |
+
" loss = torch.nn.BCELoss()(outputs.flatten(), targets[eval_pos:].flatten())\n",
|
180 |
+
" optimizer.zero_grad()\n",
|
181 |
+
" loss.backward()\n",
|
182 |
+
" with torch.no_grad():\n",
|
183 |
+
" hparam.grad *= hparam_grad_mask\n",
|
184 |
+
" optimizer.step()\n",
|
185 |
+
" print('loss:', loss, 'hparams', diff_hparams_f[vary_hparam_ind][1](hparam[vary_hparam_ind]), 'true', diff_hparams_f[vary_hparam_ind][1](hparam_true[vary_hparam_ind]))\n",
|
186 |
+
" inferred_param = diff_hparams_f[vary_hparam_ind][1](hparam[vary_hparam_ind].cpu().detach().numpy())\n",
|
187 |
+
" return hparams, losses, inferred_param, vary_hparam_ind, hparam_true\n",
|
188 |
+
" "
|
189 |
+
]
|
190 |
+
},
|
191 |
+
{
|
192 |
+
"cell_type": "markdown",
|
193 |
+
"metadata": {
|
194 |
+
"tags": []
|
195 |
+
},
|
196 |
+
"source": [
|
197 |
+
"#### Fitting a PFN with HP-Diffable GP Prior"
|
198 |
+
]
|
199 |
+
},
|
200 |
+
{
|
201 |
+
"cell_type": "code",
|
202 |
+
"execution_count": 10,
|
203 |
+
"metadata": {
|
204 |
+
"hidden": true,
|
205 |
+
"tags": []
|
206 |
+
},
|
207 |
+
"outputs": [],
|
208 |
+
"source": [
|
209 |
+
"num_features = 5\n",
|
210 |
+
"bptt = 200\n",
|
211 |
+
"eval_positions = [100]\n",
|
212 |
+
"\n",
|
213 |
+
"config_general = get_general_config(num_features, bptt, eval_positions)\n",
|
214 |
+
"config_flexible_categorical = get_flexible_categorical_config(num_features)\n",
|
215 |
+
"\n",
|
216 |
+
"config_gp = {'noise': 0.2, \"lengthscale\": 1.0, \"outputscale\": 1.0}\n",
|
217 |
+
"config_diff_gp = {'differentiable_hyperparameters': {\n",
|
218 |
+
" 'outputscale': {'distribution': 'uniform', 'min': 0., 'max': 10.0},\n",
|
219 |
+
" 'lengthscale': {'distribution': 'uniform', 'min': 0., 'max': 10.0},\n",
|
220 |
+
" 'noise': {'distribution': 'uniform', 'min': 0.0000001, 'max': 0.5},\n",
|
221 |
+
" }\n",
|
222 |
+
"}\n",
|
223 |
+
"\n",
|
224 |
+
"config = {**config_general, **config_flexible_categorical, **config_diff_gp, **config_gp}\n",
|
225 |
+
"\n",
|
226 |
+
"config['prior_type'], config['differentiable'], config['flexible'] = 'gp', True, True\n",
|
227 |
+
"config['num_features'], config['num_features_used'] = num_features, num_features\n",
|
228 |
+
"config['epochs'], config['num_steps'], config['verbose'] = 500, 100, False\n",
|
229 |
+
"config[\"lr\"] = 0.00001\n",
|
230 |
+
"config[\"dropout\"] = 0\n",
|
231 |
+
"config[\"emsize\"] = 512\n",
|
232 |
+
"config[\"batch_size\"] = 128\n",
|
233 |
+
"config[\"aggregate_k_gradients\"] = 1\n",
|
234 |
+
"config['set_value_to_nan'] = 0.0\n",
|
235 |
+
"config['output_multiclass_ordered_p'] = 1.0\n",
|
236 |
+
"config['categorical_feature_p'] = 0.0\n",
|
237 |
+
"config['nan_prob_a_reason'] = 0.0\n",
|
238 |
+
"config['nan_prob_no_reason'] = 0.0\n",
|
239 |
+
"config['nan_prob_unknown_reason'] = 0.0\n",
|
240 |
+
"config[\"nlayers\"] = 8\n",
|
241 |
+
"\n",
|
242 |
+
"# TODO: This should not be sampled, but be one config\n",
|
243 |
+
"# TODO: This uses old hyperparam sampler throws error\n",
|
244 |
+
"config_sample = evaluate_hypers(config)"
|
245 |
+
]
|
246 |
+
},
|
247 |
+
{
|
248 |
+
"cell_type": "code",
|
249 |
+
"execution_count": 11,
|
250 |
+
"metadata": {
|
251 |
+
"hidden": true,
|
252 |
+
"tags": []
|
253 |
+
},
|
254 |
+
"outputs": [
|
255 |
+
{
|
256 |
+
"name": "stdout",
|
257 |
+
"output_type": "stream",
|
258 |
+
"text": [
|
259 |
+
"Using style prior: True\n",
|
260 |
+
"Using cpu:0 device\n",
|
261 |
+
"Not using distributed\n",
|
262 |
+
"DataLoader.__dict__ {'num_steps': 100, 'fuse_x_y': False, 'get_batch_kwargs': {'batch_size': 128, 'seq_len': 200, 'seq_len_maximum': 200, 'device': 'cpu:0', 'num_features': 5, 'hyperparameters': {'lr': 1e-05, 'dropout': 0, 'emsize': 512, 'batch_size': 128, 'nlayers': 8, 'num_features': 5, 'nhead': 4, 'nhid_factor': 2, 'bptt': 200, 'eval_positions': None, 'seq_len_used': 200, 'sampling': 'normal', 'epochs': 500, 'num_steps': 100, 'verbose': False, 'pre_sample_causes': True, 'mix_activations': False, 'nan_prob_unknown_reason_reason_prior': 1.0, 'categorical_feature_p': 0.0, 'nan_prob_no_reason': 0.0, 'nan_prob_unknown_reason': 0.0, 'nan_prob_a_reason': 0.0, 'max_num_classes': 2, 'num_classes': 2, 'noise_type': 'Gaussian', 'balanced': True, 'normalize_to_ranking': False, 'set_value_to_nan': 0.0, 'normalize_by_used_features': True, 'num_features_used': 5, 'differentiable_hyperparameters': {'distribution': 'uniform', 'min': 0.0, 'max': 10.0}, 'noise': 0.2, 'lengthscale': 1.0, 'outputscale': 1.0, 'prior_type': 'gp', 'differentiable': True, 'flexible': True, 'aggregate_k_gradients': 1, 'output_multiclass_ordered_p': 1.0, 'recompute_attn': False}, 'num_outputs': 1, 'dynamic_batch_size': 2, 'get_batch': <function get_model.<locals>.make_get_batch.<locals>.<lambda> at 0x7f39ad8dcf80>, 'differentiable_hyperparameters': {'outputscale': {'distribution': 'uniform', 'min': 0.0, 'max': 10.0}, 'lengthscale': {'distribution': 'uniform', 'min': 0.0, 'max': 10.0}, 'noise': {'distribution': 'uniform', 'min': 1e-07, 'max': 0.5}}}, 'num_features': 5, 'num_outputs': 1}\n",
|
263 |
+
"Using a Transformer with 17.35 M parameters\n"
|
264 |
+
]
|
265 |
+
}
|
266 |
+
],
|
267 |
+
"source": [
|
268 |
+
"device = 'cuda'\n",
|
269 |
+
"train_function(config_sample, 0, add_name='gp_experiments_diff_with_noise_no_meta_new')"
|
270 |
+
]
|
271 |
+
},
|
272 |
+
{
|
273 |
+
"cell_type": "markdown",
|
274 |
+
"metadata": {
|
275 |
+
"tags": []
|
276 |
+
},
|
277 |
+
"source": [
|
278 |
+
"#### Evaluating a PFN (with pretrained model)"
|
279 |
+
]
|
280 |
+
},
|
281 |
+
{
|
282 |
+
"cell_type": "code",
|
283 |
+
"execution_count": 13,
|
284 |
+
"metadata": {
|
285 |
+
"hidden": true,
|
286 |
+
"tags": []
|
287 |
+
},
|
288 |
+
"outputs": [
|
289 |
+
{
|
290 |
+
"name": "stdout",
|
291 |
+
"output_type": "stream",
|
292 |
+
"text": [
|
293 |
+
"Using style prior: True\n",
|
294 |
+
"Using cpu:0 device\n",
|
295 |
+
"Not using distributed\n",
|
296 |
+
"DataLoader.__dict__ {'num_steps': 100, 'fuse_x_y': False, 'get_batch_kwargs': {'batch_size': 1, 'seq_len': 10, 'seq_len_maximum': 10, 'device': 'cpu:0', 'num_features': 5, 'hyperparameters': {'lr': 1e-05, 'dropout': 0, 'emsize': 512, 'batch_size': 1, 'nlayers': 8, 'num_features': 5, 'nhead': 4, 'nhid_factor': 2, 'bptt': 10, 'eval_positions': [190], 'seq_len_used': 200, 'sampling': 'normal', 'epochs': 500, 'num_steps': 100, 'verbose': False, 'pre_sample_causes': True, 'mix_activations': False, 'nan_prob_unknown_reason_reason_prior': 1.0, 'output_multiclass_ordered_p': 1.0, 'categorical_feature_p': 0.0, 'nan_prob_no_reason': 0.0, 'nan_prob_unknown_reason': 0.0, 'nan_prob_a_reason': 0.0, 'max_num_classes': 2, 'num_classes': 2, 'noise_type': 'Gaussian', 'balanced': True, 'multiclass_type': 'rank', 'normalize_to_ranking': False, 'set_value_to_nan': 0.0, 'normalize_by_used_features': True, 'num_features_used': <function load_model.<locals>.<lambda> at 0x7f39ad8534d0>, 'differentiable_hyperparameters': {'distribution': 'uniform', 'min': 0.0, 'max': 10.0}, 'noise': 0.03, 'lengthscale': 1.0, 'outputscale': 1.0, 'prior_type': 'gp', 'differentiable': True, 'flexible': True, 'aggregate_k_gradients': 1, 'recompute_attn': False, 'bptt_extra_samples': None, 'epoch_in_training': 0.998, 'categorical_features_sampler': <function load_model.<locals>.<lambda> at 0x7f39ad853680>, 'num_features_used_in_training': 5, 'num_classes_in_training': 2, 'batch_size_in_training': 128, 'bptt_in_training': 200, 'bptt_extra_samples_in_training': None}, 'num_outputs': 1, 'dynamic_batch_size': 2, 'get_batch': <function get_model.<locals>.make_get_batch.<locals>.<lambda> at 0x7f39ad81ab90>, 'differentiable_hyperparameters': {'outputscale': {'distribution': 'uniform', 'min': 0.0, 'max': 10.0}, 'lengthscale': {'distribution': 'uniform', 'min': 0.0, 'max': 10.0}, 'noise': {'distribution': 'uniform', 'min': 1e-07, 'max': 0.5}}}, 'num_features': 5, 'num_outputs': 1}\n",
|
297 |
+
"Using a Transformer with 17.35 M parameters\n"
|
298 |
+
]
|
299 |
+
}
|
300 |
+
],
|
301 |
+
"source": [
|
302 |
+
"device = 'cpu'\n",
|
303 |
+
"model, c = load_model(base_path, f'models_diff/gp_ablation_model.cpkt', device, eval_positions, verbose=False)"
|
304 |
+
]
|
305 |
+
},
|
306 |
+
{
|
307 |
+
"cell_type": "code",
|
308 |
+
"execution_count": 14,
|
309 |
+
"metadata": {},
|
310 |
+
"outputs": [],
|
311 |
+
"source": [
|
312 |
+
"from priors.differentiable_prior import DifferentiableHyperparameterList\n",
|
313 |
+
"diff_list = DifferentiableHyperparameterList(c['differentiable_hyperparameters'], 512, device)\n",
|
314 |
+
"diff_hparams_keys, diff_hparams_f = diff_list.get_hyperparameter_info()"
|
315 |
+
]
|
316 |
+
},
|
317 |
+
{
|
318 |
+
"cell_type": "code",
|
319 |
+
"execution_count": null,
|
320 |
+
"metadata": {
|
321 |
+
"tags": []
|
322 |
+
},
|
323 |
+
"outputs": [],
|
324 |
+
"source": [
|
325 |
+
"model[2].eval()\n",
|
326 |
+
"eval_pos = 100\n",
|
327 |
+
"\n",
|
328 |
+
"hparam_label = [(1, 'outputscale')]\n",
|
329 |
+
"hparam_label = [(0, 'lengthscale')]\n",
|
330 |
+
"hparam_label = [(2, 'noise')]\n",
|
331 |
+
"hparam_labels = [[(1, 'outputscale')], [(2, 'noise')], [(0, 'lengthscale')]]\n",
|
332 |
+
"#hparam_labels = [[(2, 'noise')]]\n",
|
333 |
+
"\n",
|
334 |
+
"hparams, losses, inferred_param, vary_hparam_ind, hparam_true = {}, {}, {}, {}, {}\n",
|
335 |
+
"\n",
|
336 |
+
"for hparam_label in hparam_labels:\n",
|
337 |
+
" (hparams[hparam_label[0][1]], losses[hparam_label[0][1]], inferred_param[hparam_label[0][1]], vary_hparam_ind[hparam_label[0][1]], \n",
|
338 |
+
" hparam_true[hparam_label[0][1]]) = differentiable_hparam_tuning_workflow(config_sample, \n",
|
339 |
+
" hparam_label=hparam_label, \n",
|
340 |
+
" batch_size=256, \n",
|
341 |
+
" N_grad_steps=50,\n",
|
342 |
+
" plot_step_size=0.05)\n"
|
343 |
+
]
|
344 |
+
},
|
345 |
+
{
|
346 |
+
"cell_type": "code",
|
347 |
+
"execution_count": null,
|
348 |
+
"metadata": {},
|
349 |
+
"outputs": [],
|
350 |
+
"source": [
|
351 |
+
"label = 'lengthscale'\n",
|
352 |
+
"\n",
|
353 |
+
"#import tikzplotlib\n",
|
354 |
+
"\n",
|
355 |
+
"inferred = losses[label]\n",
|
356 |
+
"\n",
|
357 |
+
"plt.plot(hparams[label][:, vary_hparam_ind[label]], losses[label])\n",
|
358 |
+
"true = diff_hparams_f[vary_hparam_ind[label]][1](hparam_true[label][vary_hparam_ind[label]])\n",
|
359 |
+
"plt.axvline(x=inferred_param[label], linestyle='solid', color='red')\n",
|
360 |
+
"plt.axvline(x=true, linestyle='dashed')\n",
|
361 |
+
"\n",
|
362 |
+
"plt.ylabel('Cross entropy Loss')\n",
|
363 |
+
"plt.xlabel(label)\n",
|
364 |
+
"\n",
|
365 |
+
"#tikzplotlib.save(f'diff_inferred_params_{label}.tex', axis_height='5.2cm', axis_width='5.2cm', strict=True)\n",
|
366 |
+
"\n",
|
367 |
+
"plt.show()"
|
368 |
+
]
|
369 |
+
}
|
370 |
+
],
|
371 |
+
"metadata": {
|
372 |
+
"kernelspec": {
|
373 |
+
"display_name": "Python 3 (ipykernel)",
|
374 |
+
"language": "python",
|
375 |
+
"name": "python3"
|
376 |
+
},
|
377 |
+
"language_info": {
|
378 |
+
"codemirror_mode": {
|
379 |
+
"name": "ipython",
|
380 |
+
"version": 3
|
381 |
+
},
|
382 |
+
"file_extension": ".py",
|
383 |
+
"mimetype": "text/x-python",
|
384 |
+
"name": "python",
|
385 |
+
"nbconvert_exporter": "python",
|
386 |
+
"pygments_lexer": "ipython3",
|
387 |
+
"version": "3.7.13"
|
388 |
+
}
|
389 |
+
},
|
390 |
+
"nbformat": 4,
|
391 |
+
"nbformat_minor": 4
|
392 |
+
}
|
TabPFN/TabPFNPredictionOnly.ipynb
ADDED
@@ -0,0 +1,253 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "markdown",
|
5 |
+
"metadata": {},
|
6 |
+
"source": [
|
7 |
+
"This notebook shows how to use TabPFN for tabular prediction with a scikit learn wrapper.\n",
|
8 |
+
"\n",
|
9 |
+
"classifier = TabPFNClassifier(device='cpu')\n",
|
10 |
+
"classifier.fit(train_xs, train_ys)\n",
|
11 |
+
"prediction_ = classifier.predict(test_xs)\n",
|
12 |
+
"\n",
|
13 |
+
"The fit function does not perform any computations, but only saves the training data. Computations are only done at inference time, when calling predict.\n",
|
14 |
+
"Note that the presaved models were trained for up to 100 features, 10 classes and 1000 samples. While the model does not have a hard bound on the number of samples, the features and classes are restricted and larger sizes lead to an error."
|
15 |
+
]
|
16 |
+
},
|
17 |
+
{
|
18 |
+
"cell_type": "markdown",
|
19 |
+
"metadata": {
|
20 |
+
"tags": []
|
21 |
+
},
|
22 |
+
"source": [
|
23 |
+
"### Setup"
|
24 |
+
]
|
25 |
+
},
|
26 |
+
{
|
27 |
+
"cell_type": "code",
|
28 |
+
"execution_count": null,
|
29 |
+
"metadata": {},
|
30 |
+
"outputs": [],
|
31 |
+
"source": [
|
32 |
+
"%load_ext autoreload\n",
|
33 |
+
"\n",
|
34 |
+
"%autoreload 2"
|
35 |
+
]
|
36 |
+
},
|
37 |
+
{
|
38 |
+
"cell_type": "code",
|
39 |
+
"execution_count": null,
|
40 |
+
"metadata": {},
|
41 |
+
"outputs": [],
|
42 |
+
"source": [
|
43 |
+
"import time\n",
|
44 |
+
"import torch\n",
|
45 |
+
"import numpy as np\n",
|
46 |
+
"import os\n",
|
47 |
+
"import random\n",
|
48 |
+
"\n",
|
49 |
+
"from model_builder import get_model, get_default_spec, save_model, load_model\n",
|
50 |
+
"from scripts.transformer_prediction_interface import transformer_predict, get_params_from_config, TabPFNClassifier\n",
|
51 |
+
"\n",
|
52 |
+
"from datasets import load_openml_list, open_cc_dids, open_cc_valid_dids\n",
|
53 |
+
"\n",
|
54 |
+
"from scripts import tabular_metrics"
|
55 |
+
]
|
56 |
+
},
|
57 |
+
{
|
58 |
+
"cell_type": "code",
|
59 |
+
"execution_count": null,
|
60 |
+
"metadata": {},
|
61 |
+
"outputs": [],
|
62 |
+
"source": [
|
63 |
+
"base_path = '.'"
|
64 |
+
]
|
65 |
+
},
|
66 |
+
{
|
67 |
+
"cell_type": "markdown",
|
68 |
+
"metadata": {
|
69 |
+
"tags": []
|
70 |
+
},
|
71 |
+
"source": [
|
72 |
+
"### Load datasets"
|
73 |
+
]
|
74 |
+
},
|
75 |
+
{
|
76 |
+
"cell_type": "code",
|
77 |
+
"execution_count": null,
|
78 |
+
"metadata": {
|
79 |
+
"jupyter": {
|
80 |
+
"outputs_hidden": true
|
81 |
+
},
|
82 |
+
"tags": []
|
83 |
+
},
|
84 |
+
"outputs": [],
|
85 |
+
"source": [
|
86 |
+
"max_samples = 10000\n",
|
87 |
+
"bptt = 10000\n",
|
88 |
+
"\n",
|
89 |
+
"cc_test_datasets_multiclass, cc_test_datasets_multiclass_df = load_openml_list(open_cc_dids, multiclass=True, shuffled=True, filter_for_nan=False, max_samples = max_samples, num_feats=100, return_capped=True)\n",
|
90 |
+
"cc_valid_datasets_multiclass, cc_valid_datasets_multiclass_df = load_openml_list(open_cc_valid_dids, multiclass=True, shuffled=True, filter_for_nan=False, max_samples = max_samples, num_feats=100, return_capped=True)\n",
|
91 |
+
"\n",
|
92 |
+
"# Loading longer OpenML Datasets for generalization experiments (optional)\n",
|
93 |
+
"# test_datasets_multiclass, test_datasets_multiclass_df = load_openml_list(test_dids_classification, multiclass=True, shuffled=True, filter_for_nan=False, max_samples = 10000, num_feats=100, return_capped=True)\n",
|
94 |
+
"\n",
|
95 |
+
"random.seed(0)\n",
|
96 |
+
"random.shuffle(cc_valid_datasets_multiclass)"
|
97 |
+
]
|
98 |
+
},
|
99 |
+
{
|
100 |
+
"cell_type": "code",
|
101 |
+
"execution_count": null,
|
102 |
+
"metadata": {},
|
103 |
+
"outputs": [],
|
104 |
+
"source": [
|
105 |
+
"from datasets import get_openml_classification"
|
106 |
+
]
|
107 |
+
},
|
108 |
+
{
|
109 |
+
"cell_type": "code",
|
110 |
+
"execution_count": null,
|
111 |
+
"metadata": {},
|
112 |
+
"outputs": [],
|
113 |
+
"source": [
|
114 |
+
"dataset = openml.datasets.get_dataset(31)\n",
|
115 |
+
"X, y, categorical_indicator, attribute_names = dataset.get_data(\n",
|
116 |
+
" dataset_format=\"array\", target=dataset.default_target_attribute\n",
|
117 |
+
" )"
|
118 |
+
]
|
119 |
+
},
|
120 |
+
{
|
121 |
+
"cell_type": "code",
|
122 |
+
"execution_count": null,
|
123 |
+
"metadata": {},
|
124 |
+
"outputs": [],
|
125 |
+
"source": [
|
126 |
+
"def get_datasets(selector, task_type, suite='cc'):\n",
|
127 |
+
" if task_type == 'binary':\n",
|
128 |
+
" ds = valid_datasets_binary if selector == 'valid' else test_datasets_binary\n",
|
129 |
+
" else:\n",
|
130 |
+
" if suite == 'openml':\n",
|
131 |
+
" ds = valid_datasets_multiclass if selector == 'valid' else test_datasets_multiclass\n",
|
132 |
+
" elif suite == 'cc':\n",
|
133 |
+
" ds = cc_valid_datasets_multiclass if selector == 'valid' else cc_test_datasets_multiclass\n",
|
134 |
+
" else:\n",
|
135 |
+
" raise Exception(\"Unknown suite\")\n",
|
136 |
+
" return ds"
|
137 |
+
]
|
138 |
+
},
|
139 |
+
{
|
140 |
+
"cell_type": "code",
|
141 |
+
"execution_count": null,
|
142 |
+
"metadata": {},
|
143 |
+
"outputs": [],
|
144 |
+
"source": [
|
145 |
+
"model_string, longer, task_type = '', 1, 'multiclass'\n",
|
146 |
+
"eval_positions = [1000]\n",
|
147 |
+
"bptt = 2000\n",
|
148 |
+
" \n",
|
149 |
+
"test_datasets, valid_datasets = get_datasets('test', task_type, suite='cc'), get_datasets('valid', task_type, suite='cc')"
|
150 |
+
]
|
151 |
+
},
|
152 |
+
{
|
153 |
+
"cell_type": "markdown",
|
154 |
+
"metadata": {
|
155 |
+
"jp-MarkdownHeadingCollapsed": true,
|
156 |
+
"tags": []
|
157 |
+
},
|
158 |
+
"source": [
|
159 |
+
"### Select a dataset for prediction"
|
160 |
+
]
|
161 |
+
},
|
162 |
+
{
|
163 |
+
"cell_type": "code",
|
164 |
+
"execution_count": null,
|
165 |
+
"metadata": {},
|
166 |
+
"outputs": [],
|
167 |
+
"source": [
|
168 |
+
"[(i, test_datasets[i][0]) for i in range(len(test_datasets))]"
|
169 |
+
]
|
170 |
+
},
|
171 |
+
{
|
172 |
+
"cell_type": "code",
|
173 |
+
"execution_count": null,
|
174 |
+
"metadata": {},
|
175 |
+
"outputs": [],
|
176 |
+
"source": [
|
177 |
+
"evaluation_dataset_index = 4 # Index of the dataset to predict\n",
|
178 |
+
"ds = test_datasets[evaluation_dataset_index]\n",
|
179 |
+
"print(f'Evaluation dataset name: {ds[0]} shape {ds[1].shape}')"
|
180 |
+
]
|
181 |
+
},
|
182 |
+
{
|
183 |
+
"cell_type": "code",
|
184 |
+
"execution_count": null,
|
185 |
+
"metadata": {},
|
186 |
+
"outputs": [],
|
187 |
+
"source": [
|
188 |
+
"xs, ys = ds[1].clone(), ds[2].clone()\n",
|
189 |
+
"eval_position = xs.shape[0] // 2\n",
|
190 |
+
"train_xs, train_ys = xs[0:eval_position], ys[0:eval_position]\n",
|
191 |
+
"test_xs, test_ys = xs[eval_position:], ys[eval_position:]"
|
192 |
+
]
|
193 |
+
},
|
194 |
+
{
|
195 |
+
"cell_type": "markdown",
|
196 |
+
"metadata": {
|
197 |
+
"tags": []
|
198 |
+
},
|
199 |
+
"source": [
|
200 |
+
"### Predict using a Fitted and Tuned Model"
|
201 |
+
]
|
202 |
+
},
|
203 |
+
{
|
204 |
+
"cell_type": "code",
|
205 |
+
"execution_count": null,
|
206 |
+
"metadata": {},
|
207 |
+
"outputs": [],
|
208 |
+
"source": [
|
209 |
+
"classifier = TabPFNClassifier(device='cpu')\n",
|
210 |
+
"classifier.fit(train_xs, train_ys)\n",
|
211 |
+
"prediction_ = classifier.predict_proba(test_xs)"
|
212 |
+
]
|
213 |
+
},
|
214 |
+
{
|
215 |
+
"cell_type": "code",
|
216 |
+
"execution_count": null,
|
217 |
+
"metadata": {},
|
218 |
+
"outputs": [],
|
219 |
+
"source": [
|
220 |
+
"roc, ce = tabular_metrics.auc_metric(test_ys, prediction_), tabular_metrics.cross_entropy(test_ys, prediction_)\n",
|
221 |
+
"'AUC', float(roc), 'Cross Entropy', float(ce)"
|
222 |
+
]
|
223 |
+
},
|
224 |
+
{
|
225 |
+
"cell_type": "code",
|
226 |
+
"execution_count": null,
|
227 |
+
"metadata": {},
|
228 |
+
"outputs": [],
|
229 |
+
"source": []
|
230 |
+
}
|
231 |
+
],
|
232 |
+
"metadata": {
|
233 |
+
"kernelspec": {
|
234 |
+
"display_name": "Python 3 (ipykernel)",
|
235 |
+
"language": "python",
|
236 |
+
"name": "python3"
|
237 |
+
},
|
238 |
+
"language_info": {
|
239 |
+
"codemirror_mode": {
|
240 |
+
"name": "ipython",
|
241 |
+
"version": 3
|
242 |
+
},
|
243 |
+
"file_extension": ".py",
|
244 |
+
"mimetype": "text/x-python",
|
245 |
+
"name": "python",
|
246 |
+
"nbconvert_exporter": "python",
|
247 |
+
"pygments_lexer": "ipython3",
|
248 |
+
"version": "3.7.13"
|
249 |
+
}
|
250 |
+
},
|
251 |
+
"nbformat": 4,
|
252 |
+
"nbformat_minor": 4
|
253 |
+
}
|
TabPFN/TabularEvaluationVisualization.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|
TabPFN/TrainingTuningAndPrediction.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|
TabPFN/datasets/__init__.py
ADDED
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pandas as pd
|
2 |
+
import torch
|
3 |
+
import numpy as np
|
4 |
+
import openml
|
5 |
+
|
6 |
+
|
7 |
+
def get_openml_classification(did, max_samples, multiclass=True, shuffled=True):
|
8 |
+
dataset = openml.datasets.get_dataset(did)
|
9 |
+
X, y, categorical_indicator, attribute_names = dataset.get_data(
|
10 |
+
dataset_format="array", target=dataset.default_target_attribute
|
11 |
+
)
|
12 |
+
|
13 |
+
if not multiclass:
|
14 |
+
X = X[y < 2]
|
15 |
+
y = y[y < 2]
|
16 |
+
|
17 |
+
if multiclass and not shuffled:
|
18 |
+
raise NotImplementedError("This combination of multiclass and shuffling isn't implemented")
|
19 |
+
|
20 |
+
if not isinstance(X, np.ndarray) or not isinstance(y, np.ndarray):
|
21 |
+
print('Not a NP Array, skipping')
|
22 |
+
return None, None, None, None
|
23 |
+
|
24 |
+
if not shuffled:
|
25 |
+
sort = np.argsort(y) if y.mean() < 0.5 else np.argsort(-y)
|
26 |
+
pos = int(y.sum()) if y.mean() < 0.5 else int((1 - y).sum())
|
27 |
+
X, y = X[sort][-pos * 2:], y[sort][-pos * 2:]
|
28 |
+
y = torch.tensor(y).reshape(2, -1).transpose(0, 1).reshape(-1).flip([0]).float()
|
29 |
+
X = torch.tensor(X).reshape(2, -1, X.shape[1]).transpose(0, 1).reshape(-1, X.shape[1]).flip([0]).float()
|
30 |
+
else:
|
31 |
+
order = np.arange(y.shape[0])
|
32 |
+
np.random.seed(13)
|
33 |
+
np.random.shuffle(order)
|
34 |
+
X, y = torch.tensor(X[order]), torch.tensor(y[order])
|
35 |
+
if max_samples:
|
36 |
+
X, y = X[:max_samples], y[:max_samples]
|
37 |
+
|
38 |
+
return X, y, list(np.where(categorical_indicator)[0]), attribute_names
|
39 |
+
|
40 |
+
def load_openml_list(dids, filter_for_nan=False
|
41 |
+
, num_feats=100
|
42 |
+
, min_samples = 100
|
43 |
+
, max_samples=400
|
44 |
+
, multiclass=True
|
45 |
+
, max_num_classes=10
|
46 |
+
, shuffled=True
|
47 |
+
, return_capped = False):
|
48 |
+
datasets = []
|
49 |
+
openml_list = openml.datasets.list_datasets(dids)
|
50 |
+
print(f'Number of datasets: {len(openml_list)}')
|
51 |
+
|
52 |
+
datalist = pd.DataFrame.from_dict(openml_list, orient="index")
|
53 |
+
if filter_for_nan:
|
54 |
+
datalist = datalist[datalist['NumberOfInstancesWithMissingValues'] == 0]
|
55 |
+
print(f'Number of datasets after Nan and feature number filtering: {len(datalist)}')
|
56 |
+
|
57 |
+
for ds in datalist.index:
|
58 |
+
modifications = {'samples_capped': False, 'classes_capped': False, 'feats_capped': False}
|
59 |
+
entry = datalist.loc[ds]
|
60 |
+
|
61 |
+
print('Loading', entry['name'], entry.did, '..')
|
62 |
+
|
63 |
+
if entry['NumberOfClasses'] == 0.0:
|
64 |
+
raise Exception("Regression not supported")
|
65 |
+
#X, y, categorical_feats, attribute_names = get_openml_regression(int(entry.did), max_samples)
|
66 |
+
else:
|
67 |
+
X, y, categorical_feats, attribute_names = get_openml_classification(int(entry.did), max_samples
|
68 |
+
, multiclass=multiclass, shuffled=shuffled)
|
69 |
+
if X is None:
|
70 |
+
continue
|
71 |
+
|
72 |
+
if X.shape[1] > num_feats:
|
73 |
+
if return_capped:
|
74 |
+
X = X[:, 0:num_feats]
|
75 |
+
categorical_feats = [c for c in categorical_feats if c < num_feats]
|
76 |
+
modifications['feats_capped'] = True
|
77 |
+
else:
|
78 |
+
print('Too many features')
|
79 |
+
continue
|
80 |
+
if X.shape[0] == max_samples:
|
81 |
+
modifications['samples_capped'] = True
|
82 |
+
|
83 |
+
if X.shape[0] < min_samples:
|
84 |
+
print(f'Too few samples left')
|
85 |
+
continue
|
86 |
+
|
87 |
+
if len(np.unique(y)) > max_num_classes:
|
88 |
+
if return_capped:
|
89 |
+
X = X[y < np.unique(y)[10]]
|
90 |
+
y = y[y < np.unique(y)[10]]
|
91 |
+
modifications['classes_capped'] = True
|
92 |
+
else:
|
93 |
+
print(f'Too many classes')
|
94 |
+
continue
|
95 |
+
|
96 |
+
datasets += [[entry['name'], X, y, categorical_feats, attribute_names, modifications]]
|
97 |
+
|
98 |
+
return datasets, datalist
|
99 |
+
|
100 |
+
|
101 |
+
# Classification
|
102 |
+
valid_dids_classification = [13, 59, 4, 15, 40710, 43, 1498]
|
103 |
+
test_dids_classification = [973, 1596, 40981, 1468, 40984, 40975, 41163, 41147, 1111, 41164, 1169, 1486, 41143, 1461, 41167, 40668, 41146, 41169, 41027, 23517, 41165, 41161, 41159, 41138, 1590, 41166, 1464, 41168, 41150, 1489, 41142, 3, 12, 31, 54, 1067]
|
104 |
+
valid_large_classification = [ 943, 23512, 49, 838, 1131, 767, 1142, 748, 1112,
|
105 |
+
1541, 384, 912, 1503, 796, 20, 30, 903, 4541,
|
106 |
+
961, 805, 1000, 4135, 1442, 816, 1130, 906, 1511,
|
107 |
+
184, 181, 137, 1452, 1481, 949, 449, 50, 913,
|
108 |
+
1071, 831, 843, 9, 896, 1532, 311, 39, 451,
|
109 |
+
463, 382, 778, 474, 737, 1162, 1538, 820, 188,
|
110 |
+
452, 1156, 37, 957, 911, 1508, 1054, 745, 1220,
|
111 |
+
763, 900, 25, 387, 38, 757, 1507, 396, 4153,
|
112 |
+
806, 779, 746, 1037, 871, 717, 1480, 1010, 1016,
|
113 |
+
981, 1547, 1002, 1126, 1459, 846, 837, 1042, 273,
|
114 |
+
1524, 375, 1018, 1531, 1458, 6332, 1546, 1129, 679,
|
115 |
+
389]
|
116 |
+
|
117 |
+
open_cc_dids = [11,
|
118 |
+
14,
|
119 |
+
15,
|
120 |
+
16,
|
121 |
+
18,
|
122 |
+
22,
|
123 |
+
23,
|
124 |
+
29,
|
125 |
+
31,
|
126 |
+
37,
|
127 |
+
50,
|
128 |
+
54,
|
129 |
+
188,
|
130 |
+
458,
|
131 |
+
469,
|
132 |
+
1049,
|
133 |
+
1050,
|
134 |
+
1063,
|
135 |
+
1068,
|
136 |
+
1510,
|
137 |
+
1494,
|
138 |
+
1480,
|
139 |
+
1462,
|
140 |
+
1464,
|
141 |
+
6332,
|
142 |
+
23381,
|
143 |
+
40966,
|
144 |
+
40982,
|
145 |
+
40994,
|
146 |
+
40975]
|
147 |
+
# Filtered by N_samples < 2000, N feats < 100, N classes < 10
|
148 |
+
|
149 |
+
open_cc_valid_dids = [13,25,35,40,41,43,48,49,51,53,55,56,59,61,187,285,329,333,334,335,336,337,338,377,446,450,451,452,460,463,464,466,470,475,481,679,694,717,721,724,733,738,745,747,748,750,753,756,757,764,765,767,774,778,786,788,795,796,798,801,802,810,811,814,820,825,826,827,831,839,840,841,844,852,853,854,860,880,886,895,900,906,907,908,909,915,925,930,931,934,939,940,941,949,966,968,984,987,996,1048,1054,1071,1073,1100,1115,1412,1442,1443,1444,1446,1447,1448,1451,1453,1488,1490,1495,1498,1499,1506,1508,1511,1512,1520,1523,4153,23499,40496,40646,40663,40669,40680,40682,40686,40690,40693,40705,40706,40710,40711,40981,41430,41538,41919,41976,42172,42261,42544,42585,42638]
|
TabPFN/datasets/utils.py
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
def normalize_data(eval_xs):
|
2 |
+
mean = eval_xs.mean(0)
|
3 |
+
std = eval_xs.std(0) + .000001
|
4 |
+
eval_xs = (eval_xs - mean) / std
|
5 |
+
|
6 |
+
return eval_xs
|
7 |
+
|
8 |
+
|
TabPFN/decoders.py
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch import nn
|
3 |
+
import random
|
4 |
+
|
5 |
+
|
6 |
+
class ScaledDecoder(nn.Module):
|
7 |
+
def __init__(self, ninp, nhid, nout):
|
8 |
+
super().__init__()
|
9 |
+
self.linear = nn.Linear(ninp, nhid)
|
10 |
+
self.linear1 = nn.Linear(nhid, nout)
|
11 |
+
self.linear2 = nn.Linear(nhid, 10)
|
12 |
+
|
13 |
+
def forward(self, x):
|
14 |
+
#return torch.cat([self.linear1(x), self.linear2(x)], -1)
|
15 |
+
x = self.linear(x)
|
16 |
+
x = nn.GELU()(x)
|
17 |
+
temps = self.linear2(x).softmax(-1) @ torch.tensor([1.,1.4,1.7,2.,5.,10.,20.,40.,80.,160.], device=x.device)
|
18 |
+
if random.random() > .99:
|
19 |
+
print(temps.shape,temps[:,:2])
|
20 |
+
return self.linear1(x) / temps.unsqueeze(-1)
|
21 |
+
|
22 |
+
class FixedScaledDecoder(nn.Module):
|
23 |
+
def __init__(self, ninp, nhid, nout):
|
24 |
+
super().__init__()
|
25 |
+
self.mapper = nn.Sequential(nn.Linear(ninp, nhid), nn.GELU(), nn.Linear(nhid, nout))
|
26 |
+
self.T = nn.Parameter(torch.ones(10000)/10000)
|
27 |
+
|
28 |
+
def forward(self, x):
|
29 |
+
return self.mapper(x)/self.T.sum()
|
30 |
+
|
TabPFN/differentiable_pfn_evaluation.py
ADDED
@@ -0,0 +1,345 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
import numpy as np
|
4 |
+
import time
|
5 |
+
import pickle
|
6 |
+
from scripts import tabular_metrics
|
7 |
+
from scripts.tabular_metrics import calculate_score_per_method
|
8 |
+
from scripts.tabular_evaluation import evaluate
|
9 |
+
from priors.differentiable_prior import draw_random_style
|
10 |
+
from tqdm import tqdm
|
11 |
+
import random
|
12 |
+
from scripts.transformer_prediction_interface import get_params_from_config, load_model_workflow
|
13 |
+
|
14 |
+
"""
|
15 |
+
===============================
|
16 |
+
PUBLIC FUNCTIONS FOR EVALUATION
|
17 |
+
===============================
|
18 |
+
"""
|
19 |
+
|
20 |
+
|
21 |
+
def eval_model_range(i_range, *args, **kwargs):
|
22 |
+
for i in i_range:
|
23 |
+
eval_model(i, *args, **kwargs)
|
24 |
+
|
25 |
+
|
26 |
+
|
27 |
+
def eval_model(i, e, valid_datasets, test_datasets, train_datasets, eval_positions_valid, eval_positions_test,
|
28 |
+
bptt_valid,
|
29 |
+
bptt_test, add_name, base_path, device='cpu', eval_addition='', **extra_tuning_args):
|
30 |
+
"""
|
31 |
+
Differentiable model evaliation workflow. Evaluates and saves results to disk.
|
32 |
+
|
33 |
+
:param i:
|
34 |
+
:param e:
|
35 |
+
:param valid_datasets:
|
36 |
+
:param test_datasets:
|
37 |
+
:param train_datasets:
|
38 |
+
:param eval_positions_valid:
|
39 |
+
:param eval_positions_test:
|
40 |
+
:param bptt_valid:
|
41 |
+
:param bptt_test:
|
42 |
+
:param add_name:
|
43 |
+
:param base_path:
|
44 |
+
:param device:
|
45 |
+
:param eval_addition:
|
46 |
+
:param extra_tuning_args:
|
47 |
+
:return:
|
48 |
+
"""
|
49 |
+
model, c, results_file = load_model_workflow(i, e, add_name, base_path, device, eval_addition)
|
50 |
+
params = {'bptt': bptt_valid
|
51 |
+
, 'bptt_final': bptt_test
|
52 |
+
, 'eval_positions': eval_positions_valid
|
53 |
+
, 'eval_positions_test': eval_positions_test
|
54 |
+
, 'valid_datasets': valid_datasets
|
55 |
+
, 'test_datasets': test_datasets
|
56 |
+
, 'train_datasets': train_datasets
|
57 |
+
, 'verbose': True
|
58 |
+
, 'device': device
|
59 |
+
}
|
60 |
+
|
61 |
+
params.update(get_params_from_config(c))
|
62 |
+
|
63 |
+
start = time.time()
|
64 |
+
metrics, metrics_valid, style, temperature, optimization_route = evaluate_differentiable_model(model, **params,
|
65 |
+
**extra_tuning_args)
|
66 |
+
print('Evaluation time: ', time.time() - start)
|
67 |
+
|
68 |
+
print(results_file)
|
69 |
+
r = [c.copy(), metrics, metrics_valid, style.to('cpu'), temperature.to('cpu'), optimization_route]
|
70 |
+
with open(results_file, 'wb') as output:
|
71 |
+
del r[0]['num_features_used']
|
72 |
+
del r[0]['categorical_features_sampler']
|
73 |
+
pickle.dump(r, output)
|
74 |
+
|
75 |
+
_, _, _, style, temperature, _ = r
|
76 |
+
|
77 |
+
return r, model
|
78 |
+
|
79 |
+
"""
|
80 |
+
===============================
|
81 |
+
INTERNAL HELPER FUNCTIONS
|
82 |
+
===============================
|
83 |
+
"""
|
84 |
+
|
85 |
+
def evaluate_differentiable_model(model
|
86 |
+
, valid_datasets
|
87 |
+
, test_datasets
|
88 |
+
, train_datasets
|
89 |
+
, N_draws=100
|
90 |
+
, N_grad_steps=10
|
91 |
+
, eval_positions=None
|
92 |
+
, eval_positions_test=None
|
93 |
+
, bptt=100
|
94 |
+
, bptt_final=200
|
95 |
+
, style=None
|
96 |
+
, n_parallel_configurations=1
|
97 |
+
, device='cpu'
|
98 |
+
, selection_metric='auc'
|
99 |
+
, final_splits=[1, 2, 3, 4, 5]
|
100 |
+
, N_ensemble_configurations_list=[1, 5, 10, 20, 50, 100]
|
101 |
+
, **kwargs):
|
102 |
+
"""
|
103 |
+
Evaluation function for diffable model evaluation. Returns a list of results.
|
104 |
+
|
105 |
+
:param model:
|
106 |
+
:param valid_datasets:
|
107 |
+
:param test_datasets:
|
108 |
+
:param train_datasets:
|
109 |
+
:param N_draws:
|
110 |
+
:param N_grad_steps:
|
111 |
+
:param eval_positions:
|
112 |
+
:param eval_positions_test:
|
113 |
+
:param bptt:
|
114 |
+
:param bptt_final:
|
115 |
+
:param style:
|
116 |
+
:param n_parallel_configurations:
|
117 |
+
:param device:
|
118 |
+
:param selection_metric:
|
119 |
+
:param final_splits:
|
120 |
+
:param N_ensemble_configurations_list:
|
121 |
+
:param kwargs:
|
122 |
+
:return:
|
123 |
+
"""
|
124 |
+
torch.manual_seed(0)
|
125 |
+
np.random.seed(0)
|
126 |
+
random.seed(0)
|
127 |
+
|
128 |
+
diffable_metric = tabular_metrics.cross_entropy
|
129 |
+
evaluation_metric = tabular_metrics.auc_metric
|
130 |
+
if selection_metric in ('auc', 'roc'):
|
131 |
+
selection_metric_min_max = 'max'
|
132 |
+
selection_metric = tabular_metrics.auc_metric
|
133 |
+
evaluation_metric = selection_metric
|
134 |
+
elif selection_metric in ('ce', 'selection_metric'):
|
135 |
+
selection_metric_min_max = 'min'
|
136 |
+
selection_metric = tabular_metrics.cross_entropy
|
137 |
+
evaluation_metric = selection_metric
|
138 |
+
|
139 |
+
print('Diffable metric', diffable_metric, ' Selection metric', selection_metric, ' Evaluation metric',
|
140 |
+
evaluation_metric)
|
141 |
+
print('N PARALLEL CONFIGURATIONS', n_parallel_configurations)
|
142 |
+
print('eval_positions', eval_positions)
|
143 |
+
|
144 |
+
def evaluate_valid(style, softmax_temperature, results, results_tracked):
|
145 |
+
result_valid = eval_step(valid_datasets, style, softmax_temperature=softmax_temperature,
|
146 |
+
return_tensor=False, inference_mode=True, selection_metric=selection_metric,
|
147 |
+
evaluation_metric=evaluation_metric, eval_positions=eval_positions, bptt=bptt, model=model[2])
|
148 |
+
result_valid = [float(result_valid[f'mean_select_at_{pos}']) for pos in eval_positions]
|
149 |
+
results += [result_valid]
|
150 |
+
results_tracked += [np.nanmean(result_valid)]
|
151 |
+
|
152 |
+
model[2].to(device)
|
153 |
+
model[2].eval()
|
154 |
+
|
155 |
+
results_on_valid, results_on_valid_tracked = [], []
|
156 |
+
best_style, best_softmax_temperature = style, torch.cat(
|
157 |
+
[torch.tensor([0.0]).to(device) for n in range(0, n_parallel_configurations)], 0)
|
158 |
+
optimization_routes = []
|
159 |
+
|
160 |
+
best_style = torch.cat([draw_random_style(model[3], device).detach() for n in range(0, n_parallel_configurations)],
|
161 |
+
0)
|
162 |
+
best_softmax_temperature = torch.cat([torch.tensor([0.0]).to(device) for n in range(0, n_parallel_configurations)],
|
163 |
+
0)
|
164 |
+
|
165 |
+
|
166 |
+
for _ in tqdm(range(0, N_draws), desc='Iterate over Optimization initializations'): # Evaluates N hparam draws
|
167 |
+
style = torch.cat([draw_random_style(model[3], device).detach() for n in range(0, n_parallel_configurations)],
|
168 |
+
0)
|
169 |
+
softmax_temperature = torch.cat([torch.tensor([0.0]).to(device) for n in range(0, n_parallel_configurations)],
|
170 |
+
0)
|
171 |
+
|
172 |
+
evaluate_valid(style, softmax_temperature, results_on_valid, results_on_valid_tracked)
|
173 |
+
|
174 |
+
print(f'Draw --> Valid Selection metric: {results_on_valid[-1]}')
|
175 |
+
|
176 |
+
if N_grad_steps > 0:
|
177 |
+
gradient_optimize_result = gradient_optimize_style(model, style, N_grad_steps
|
178 |
+
, softmax_temperature=softmax_temperature
|
179 |
+
, model=model[2]
|
180 |
+
, train_datasets=train_datasets
|
181 |
+
, valid_datasets=valid_datasets
|
182 |
+
, selection_metric_min_max=selection_metric_min_max
|
183 |
+
, **kwargs)
|
184 |
+
optimization_routes += [gradient_optimize_result['optimization_route']]
|
185 |
+
|
186 |
+
evaluate_valid(gradient_optimize_result['best_style']
|
187 |
+
, gradient_optimize_result['best_temperature']
|
188 |
+
, results_on_valid, results_on_valid_tracked)
|
189 |
+
|
190 |
+
print(f'After diff --> Valid Selection metric: {results_on_valid[-1]}')
|
191 |
+
|
192 |
+
if selection_metric_min_max == 'min':
|
193 |
+
is_best = (results_on_valid_tracked[-1] <= min(results_on_valid_tracked))
|
194 |
+
else:
|
195 |
+
is_best = (results_on_valid_tracked[-1] >= max(results_on_valid_tracked))
|
196 |
+
|
197 |
+
if is_best or best_style is None:
|
198 |
+
best_style = gradient_optimize_result['best_style'].clone()
|
199 |
+
best_softmax_temperature = gradient_optimize_result['best_temperature'].clone()
|
200 |
+
torch.cuda.empty_cache()
|
201 |
+
|
202 |
+
def final_evaluation():
|
203 |
+
print('Running eval dataset with final params (no gradients)..')
|
204 |
+
print(best_style, best_softmax_temperature)
|
205 |
+
result_test = []
|
206 |
+
for N_ensemble_configurations in N_ensemble_configurations_list:
|
207 |
+
print(f'Running with {N_ensemble_configurations} ensemble_configurations')
|
208 |
+
kwargs['N_ensemble_configurations'] = N_ensemble_configurations
|
209 |
+
splits = []
|
210 |
+
for split in final_splits:
|
211 |
+
splits += [eval_step(test_datasets, best_style, softmax_temperature=best_softmax_temperature
|
212 |
+
, return_tensor=False, eval_positions=eval_positions_test,
|
213 |
+
bptt=bptt_final, inference_mode=True, split_number=split, model=model[2]
|
214 |
+
, selection_metric=selection_metric, evaluation_metric=evaluation_metric)]
|
215 |
+
result_test += [splits]
|
216 |
+
|
217 |
+
print('Running valid dataset with final params (no gradients)..')
|
218 |
+
result_valid = eval_step(valid_datasets, best_style, softmax_temperature=best_softmax_temperature
|
219 |
+
, return_tensor=False, eval_positions=eval_positions_test,
|
220 |
+
bptt=bptt_final, inference_mode=True, model=model[2]
|
221 |
+
, selection_metric=selection_metric, evaluation_metric=evaluation_metric)
|
222 |
+
|
223 |
+
return result_test, result_valid
|
224 |
+
|
225 |
+
result_test, result_valid = final_evaluation()
|
226 |
+
|
227 |
+
return result_test, result_valid, best_style, best_softmax_temperature, optimization_routes
|
228 |
+
|
229 |
+
|
230 |
+
def eval_step(ds, used_style, selection_metric, evaluation_metric, eval_positions, return_tensor=True, **kwargs):
|
231 |
+
def step():
|
232 |
+
return evaluate(datasets=ds,
|
233 |
+
method='transformer'
|
234 |
+
, overwrite=True
|
235 |
+
, style=used_style
|
236 |
+
, eval_positions=eval_positions
|
237 |
+
, metric_used=selection_metric
|
238 |
+
, save=False
|
239 |
+
, path_interfix=None
|
240 |
+
, base_path=None
|
241 |
+
, verbose=True
|
242 |
+
, **kwargs)
|
243 |
+
|
244 |
+
if return_tensor:
|
245 |
+
r = step()
|
246 |
+
else:
|
247 |
+
with torch.no_grad():
|
248 |
+
r = step()
|
249 |
+
|
250 |
+
calculate_score_per_method(selection_metric, 'select', r, ds, eval_positions, aggregator='mean')
|
251 |
+
calculate_score_per_method(evaluation_metric, 'eval', r, ds, eval_positions, aggregator='mean')
|
252 |
+
|
253 |
+
return r
|
254 |
+
|
255 |
+
|
256 |
+
def gradient_optimize_style(model, init_style, steps, softmax_temperature, train_datasets, valid_datasets, learning_rate=0.03, optimize_all=False,
|
257 |
+
limit_style=True, N_datasets_sampled=90, optimize_softmax_temperature=True, selection_metric_min_max='max', **kwargs):
|
258 |
+
"""
|
259 |
+
Uses gradient based methods to optimize 'style' on the 'train_datasets' and uses stopping with 'valid_datasets'.
|
260 |
+
|
261 |
+
:param model:
|
262 |
+
:param init_style:
|
263 |
+
:param steps:
|
264 |
+
:param learning_rate:
|
265 |
+
:param softmax_temperature:
|
266 |
+
:param train_datasets:
|
267 |
+
:param valid_datasets:
|
268 |
+
:param optimize_all:
|
269 |
+
:param limit_style:
|
270 |
+
:param N_datasets_sampled:
|
271 |
+
:param optimize_softmax_temperature:
|
272 |
+
:param selection_metric_min_max:
|
273 |
+
:param kwargs:
|
274 |
+
:return:
|
275 |
+
"""
|
276 |
+
grad_style = torch.nn.Parameter(init_style.detach(), requires_grad=True)
|
277 |
+
|
278 |
+
best_style, best_temperature, best_selection_metric, best_diffable_metric = grad_style.detach(), softmax_temperature.detach(), None, None
|
279 |
+
softmax_temperature = torch.nn.Parameter(softmax_temperature.detach(), requires_grad=optimize_softmax_temperature)
|
280 |
+
variables_to_optimize = model[2].parameters() if optimize_all else [grad_style, softmax_temperature]
|
281 |
+
optimizer = torch.optim.Adam(variables_to_optimize, lr=learning_rate)
|
282 |
+
|
283 |
+
optimization_route_selection, optimization_route_diffable = [], []
|
284 |
+
optimization_route_selection_valid, optimization_route_diffable_valid = [], []
|
285 |
+
|
286 |
+
def eval_opt(ds, return_tensor=True, inference_mode=False):
|
287 |
+
result = eval_step(ds, grad_style, softmax_temperature=softmax_temperature, return_tensor=return_tensor
|
288 |
+
, inference_mode=inference_mode, model=model[2], **kwargs)
|
289 |
+
|
290 |
+
diffable_metric = result['mean_metric']
|
291 |
+
selection_metric = result['mean_select']
|
292 |
+
|
293 |
+
return diffable_metric, selection_metric
|
294 |
+
|
295 |
+
def eval_all_datasets(datasets, propagate=True):
|
296 |
+
selection_metrics_this_step, diffable_metrics_this_step = [], []
|
297 |
+
for ds in datasets:
|
298 |
+
diffable_metric_train, selection_metric_train = eval_opt([ds], inference_mode=(not propagate))
|
299 |
+
if not torch.isnan(diffable_metric_train).any():
|
300 |
+
if propagate and diffable_metric_train.requires_grad == True:
|
301 |
+
diffable_metric_train.backward()
|
302 |
+
selection_metrics_this_step += [selection_metric_train]
|
303 |
+
diffable_metrics_this_step += [float(diffable_metric_train.detach().cpu().numpy())]
|
304 |
+
diffable_metric_train = np.nanmean(diffable_metrics_this_step)
|
305 |
+
selection_metric_train = np.nanmean(selection_metrics_this_step)
|
306 |
+
|
307 |
+
return diffable_metric_train, selection_metric_train
|
308 |
+
|
309 |
+
for t in tqdm(range(steps), desc='Iterate over Optimization steps'):
|
310 |
+
optimizer.zero_grad()
|
311 |
+
|
312 |
+
# Select subset of datasets
|
313 |
+
random.seed(t)
|
314 |
+
train_datasets_ = random.sample(train_datasets, N_datasets_sampled)
|
315 |
+
|
316 |
+
# Get score on train
|
317 |
+
diffable_metric_train, selection_metric_train = eval_all_datasets(train_datasets_, propagate=True)
|
318 |
+
optimization_route_selection += [float(selection_metric_train)]
|
319 |
+
optimization_route_diffable += [float(diffable_metric_train)]
|
320 |
+
|
321 |
+
# Get score on valid
|
322 |
+
diffable_metric_valid, selection_metric_valid = eval_all_datasets(valid_datasets, propagate=False)
|
323 |
+
optimization_route_selection_valid += [float(selection_metric_valid)]
|
324 |
+
optimization_route_diffable_valid += [float(diffable_metric_valid)]
|
325 |
+
|
326 |
+
is_best = (selection_metric_min_max == 'min' and best_selection_metric > selection_metric_valid)
|
327 |
+
is_best = is_best or (selection_metric_min_max == 'max' and best_selection_metric < selection_metric_valid)
|
328 |
+
if (best_selection_metric is None) or (not np.isnan(selection_metric_valid) and is_best):
|
329 |
+
print('New best', best_selection_metric, selection_metric_valid)
|
330 |
+
best_style = grad_style.detach().clone()
|
331 |
+
best_temperature = softmax_temperature.detach().clone()
|
332 |
+
best_selection_metric, best_diffable_metric = selection_metric_valid, diffable_metric_valid
|
333 |
+
|
334 |
+
optimizer.step()
|
335 |
+
|
336 |
+
if limit_style:
|
337 |
+
grad_style = grad_style.detach().clamp(-1.74, 1.74)
|
338 |
+
|
339 |
+
print(f'Valid: Diffable metric={diffable_metric_valid} Selection metric={selection_metric_valid};' +
|
340 |
+
f'Train: Diffable metric={diffable_metric_train} Selection metric={selection_metric_train}')
|
341 |
+
|
342 |
+
print(f'Return best:{best_style} {best_selection_metric}')
|
343 |
+
return {'best_style': best_style, 'best_temperature': best_temperature
|
344 |
+
, 'optimization_route': {'select': optimization_route_selection, 'loss': optimization_route_diffable,
|
345 |
+
'test_select': optimization_route_selection_valid, 'test_loss': optimization_route_diffable_valid}}
|
TabPFN/encoders.py
ADDED
@@ -0,0 +1,225 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
from utils import normalize_data
|
6 |
+
import torch.nn.functional as F
|
7 |
+
from torch.nn import TransformerEncoder, TransformerEncoderLayer
|
8 |
+
|
9 |
+
|
10 |
+
class StyleEncoder(nn.Module):
|
11 |
+
def __init__(self, em_size, hyperparameter_definitions):
|
12 |
+
super().__init__()
|
13 |
+
# self.embeddings = {}
|
14 |
+
self.em_size = em_size
|
15 |
+
# self.hyperparameter_definitions = {}
|
16 |
+
# for hp in hyperparameter_definitions:
|
17 |
+
# self.embeddings[hp] = nn.Linear(1, self.em_size)
|
18 |
+
# self.embeddings = nn.ModuleDict(self.embeddings)
|
19 |
+
self.embedding = nn.Linear(hyperparameter_definitions.shape[0], self.em_size)
|
20 |
+
|
21 |
+
def forward(self, hyperparameters): # T x B x num_features
|
22 |
+
# Make faster by using matrices
|
23 |
+
# sampled_embeddings = [torch.stack([
|
24 |
+
# self.embeddings[hp](torch.tensor([batch[hp]], device=self.embeddings[hp].weight.device, dtype=torch.float))
|
25 |
+
# for hp in batch
|
26 |
+
# ], -1).sum(-1) for batch in hyperparameters]
|
27 |
+
# return torch.stack(sampled_embeddings, 0)
|
28 |
+
return self.embedding(hyperparameters)
|
29 |
+
|
30 |
+
|
31 |
+
class _PositionalEncoding(nn.Module):
|
32 |
+
def __init__(self, d_model, dropout=0.):
|
33 |
+
super().__init__()
|
34 |
+
self.dropout = nn.Dropout(p=dropout)
|
35 |
+
self.d_model = d_model
|
36 |
+
self.device_test_tensor = nn.Parameter(torch.tensor(1.))
|
37 |
+
|
38 |
+
def forward(self, x):# T x B x num_features
|
39 |
+
assert self.d_model % x.shape[-1]*2 == 0
|
40 |
+
d_per_feature = self.d_model // x.shape[-1]
|
41 |
+
pe = torch.zeros(*x.shape, d_per_feature, device=self.device_test_tensor.device)
|
42 |
+
#position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
|
43 |
+
interval_size = 10
|
44 |
+
div_term = (1./interval_size) * 2*math.pi*torch.exp(torch.arange(0, d_per_feature, 2, device=self.device_test_tensor.device).float()*math.log(math.sqrt(2)))
|
45 |
+
#print(div_term/2/math.pi)
|
46 |
+
pe[..., 0::2] = torch.sin(x.unsqueeze(-1) * div_term)
|
47 |
+
pe[..., 1::2] = torch.cos(x.unsqueeze(-1) * div_term)
|
48 |
+
return self.dropout(pe).view(x.shape[0],x.shape[1],self.d_model)
|
49 |
+
|
50 |
+
|
51 |
+
Positional = lambda _, emsize: _PositionalEncoding(d_model=emsize)
|
52 |
+
|
53 |
+
class EmbeddingEncoder(nn.Module):
|
54 |
+
def __init__(self, num_features, em_size, num_embs=100):
|
55 |
+
super().__init__()
|
56 |
+
self.num_embs = num_embs
|
57 |
+
self.embeddings = nn.Embedding(num_embs * num_features, em_size, max_norm=True)
|
58 |
+
self.init_weights(.1)
|
59 |
+
self.min_max = (-2,+2)
|
60 |
+
|
61 |
+
@property
|
62 |
+
def width(self):
|
63 |
+
return self.min_max[1] - self.min_max[0]
|
64 |
+
|
65 |
+
def init_weights(self, initrange):
|
66 |
+
self.embeddings.weight.data.uniform_(-initrange, initrange)
|
67 |
+
|
68 |
+
def discretize(self, x):
|
69 |
+
split_size = self.width / self.num_embs
|
70 |
+
return (x - self.min_max[0] // split_size).int().clamp(0, self.num_embs - 1)
|
71 |
+
|
72 |
+
def forward(self, x): # T x B x num_features
|
73 |
+
x_idxs = self.discretize(x)
|
74 |
+
x_idxs += torch.arange(x.shape[-1], device=x.device).view(1, 1, -1) * self.num_embs
|
75 |
+
# print(x_idxs,self.embeddings.weight.shape)
|
76 |
+
return self.embeddings(x_idxs).mean(-2)
|
77 |
+
|
78 |
+
|
79 |
+
class Normalize(nn.Module):
|
80 |
+
def __init__(self, mean, std):
|
81 |
+
super().__init__()
|
82 |
+
self.mean = mean
|
83 |
+
self.std = std
|
84 |
+
|
85 |
+
def forward(self, x):
|
86 |
+
return (x-self.mean)/self.std
|
87 |
+
|
88 |
+
|
89 |
+
def get_normalized_uniform_encoder(encoder_creator):
|
90 |
+
"""
|
91 |
+
This can be used to wrap an encoder that is fed uniform samples in [0,1] and normalizes these to 0 mean and 1 std.
|
92 |
+
For example, it can be used as `encoder_creator = get_normalized_uniform_encoder(encoders.Linear)`, now this can
|
93 |
+
be initialized with `encoder_creator(feature_dim, in_dim)`.
|
94 |
+
:param encoder:
|
95 |
+
:return:
|
96 |
+
"""
|
97 |
+
return lambda in_dim, out_dim: nn.Sequential(Normalize(.5, math.sqrt(1/12)), encoder_creator(in_dim, out_dim))
|
98 |
+
|
99 |
+
|
100 |
+
Linear = nn.Linear
|
101 |
+
MLP = lambda num_features, emsize: nn.Sequential(nn.Linear(num_features+1,emsize*2),
|
102 |
+
nn.ReLU(),
|
103 |
+
nn.Linear(emsize*2,emsize))
|
104 |
+
|
105 |
+
class NanHandlingEncoder(nn.Module):
|
106 |
+
def __init__(self, num_features, emsize, keep_nans=True):
|
107 |
+
super().__init__()
|
108 |
+
self.num_features = 2 * num_features if keep_nans else num_features
|
109 |
+
self.emsize = emsize
|
110 |
+
self.keep_nans = keep_nans
|
111 |
+
self.layer = nn.Linear(self.num_features, self.emsize)
|
112 |
+
|
113 |
+
def forward(self, x):
|
114 |
+
if self.keep_nans:
|
115 |
+
x = torch.cat([torch.nan_to_num(x, nan=0.0), normalize_data(torch.isnan(x) * -1
|
116 |
+
+ torch.logical_and(torch.isinf(x), torch.sign(x) == 1) * 1
|
117 |
+
+ torch.logical_and(torch.isinf(x), torch.sign(x) == -1) * 2
|
118 |
+
)], -1)
|
119 |
+
else:
|
120 |
+
x = torch.nan_to_num(x, nan=0.0)
|
121 |
+
return self.layer(x)
|
122 |
+
|
123 |
+
class Linear(nn.Linear):
|
124 |
+
def __init__(self, num_features, emsize):
|
125 |
+
super().__init__(num_features, emsize)
|
126 |
+
self.num_features = num_features
|
127 |
+
self.emsize = emsize
|
128 |
+
|
129 |
+
def forward(self, x):
|
130 |
+
x = torch.nan_to_num(x, nan=0.0)
|
131 |
+
return super().forward(x)
|
132 |
+
|
133 |
+
class SequenceSpanningEncoder(nn.Module):
|
134 |
+
# Regular Encoder transforms Seq_len, B, S -> Seq_len, B, E attending only to last dimension
|
135 |
+
# This Encoder accesses the Seq_Len dimension additionally
|
136 |
+
|
137 |
+
# Why would we want this? We can learn normalization and embedding of features
|
138 |
+
# , this might be more important for e.g. categorical, ordinal feats, nan detection
|
139 |
+
# However maybe this can be easily learned through transformer as well?
|
140 |
+
# A problem is to make this work across any sequence length and be independent of ordering
|
141 |
+
|
142 |
+
# We could use average and maximum pooling and use those with a linear layer
|
143 |
+
|
144 |
+
|
145 |
+
# Another idea !! Similar to this we would like to encode features so that their number is variable
|
146 |
+
# We would like to embed features, also using knowledge of the features in the entire sequence
|
147 |
+
|
148 |
+
# We could use convolution or another transformer
|
149 |
+
# Convolution:
|
150 |
+
|
151 |
+
# Transformer/Conv across sequence dimension that encodes and normalizes features
|
152 |
+
# -> Transformer across feature dimension that encodes features to a constant size
|
153 |
+
|
154 |
+
# Conv with flexible features but no sequence info: S,B,F -(reshape)-> S*B,1,F
|
155 |
+
# -(Conv1d)-> S*B,N,F -(AvgPool,MaxPool)-> S*B,N,1 -> S,B,N
|
156 |
+
# This probably won't work since it's missing a way to recognize which feature is encoded
|
157 |
+
|
158 |
+
# Transformer with flexible features: S,B,F -> F,B*S,1 -> F2,B*S,1 -> S,B,F2
|
159 |
+
|
160 |
+
def __init__(self, num_features, em_size):
|
161 |
+
super().__init__()
|
162 |
+
|
163 |
+
raise NotImplementedError()
|
164 |
+
# Seq_len, B, S -> Seq_len, B, E
|
165 |
+
#
|
166 |
+
self.convs = torch.nn.ModuleList([nn.Conv1d(64 if i else 1, 64, 3) for i in range(5)])
|
167 |
+
# self.linear = nn.Linear(64, emsize)
|
168 |
+
|
169 |
+
class TransformerBasedFeatureEncoder(nn.Module):
|
170 |
+
def __init__(self, num_features, emsize):
|
171 |
+
super().__init__()
|
172 |
+
|
173 |
+
hidden_emsize = emsize
|
174 |
+
encoder = Linear(1, hidden_emsize)
|
175 |
+
n_out = emsize
|
176 |
+
nhid = 2*emsize
|
177 |
+
dropout =0.0
|
178 |
+
nhead=4
|
179 |
+
nlayers=4
|
180 |
+
model = nn.Transformer(nhead=nhead, num_encoder_layers=4, num_decoder_layers=4, d_model=1)
|
181 |
+
|
182 |
+
def forward(self, *input):
|
183 |
+
# S,B,F -> F,S*B,1 -> F2,S*B,1 -> S,B,F2
|
184 |
+
input = input.transpose()
|
185 |
+
self.model(input)
|
186 |
+
|
187 |
+
class Conv(nn.Module):
|
188 |
+
def __init__(self, input_size, emsize):
|
189 |
+
super().__init__()
|
190 |
+
self.convs = torch.nn.ModuleList([nn.Conv2d(64 if i else 1, 64, 3) for i in range(5)])
|
191 |
+
self.linear = nn.Linear(64,emsize)
|
192 |
+
|
193 |
+
|
194 |
+
def forward(self, x):
|
195 |
+
size = math.isqrt(x.shape[-1])
|
196 |
+
assert size*size == x.shape[-1]
|
197 |
+
x = x.reshape(*x.shape[:-1], 1, size, size)
|
198 |
+
for conv in self.convs:
|
199 |
+
if x.shape[-1] < 4:
|
200 |
+
break
|
201 |
+
x = conv(x)
|
202 |
+
x.relu_()
|
203 |
+
x = nn.AdaptiveAvgPool2d((1,1))(x).squeeze(-1).squeeze(-1)
|
204 |
+
return self.linear(x)
|
205 |
+
|
206 |
+
|
207 |
+
|
208 |
+
|
209 |
+
class CanEmb(nn.Embedding):
|
210 |
+
def __init__(self, num_features, num_embeddings: int, embedding_dim: int, *args, **kwargs):
|
211 |
+
assert embedding_dim % num_features == 0
|
212 |
+
embedding_dim = embedding_dim // num_features
|
213 |
+
super().__init__(num_embeddings, embedding_dim, *args, **kwargs)
|
214 |
+
|
215 |
+
def forward(self, x):
|
216 |
+
lx = x.long()
|
217 |
+
assert (lx == x).all(), "CanEmb only works with tensors of whole numbers"
|
218 |
+
x = super().forward(lx)
|
219 |
+
return x.view(*x.shape[:-2], -1)
|
220 |
+
|
221 |
+
def get_Canonical(num_classes):
|
222 |
+
return lambda num_features, emsize: CanEmb(num_features, num_classes, emsize)
|
223 |
+
|
224 |
+
def get_Embedding(num_embs_per_feature=100):
|
225 |
+
return lambda num_features, emsize: EmbeddingEncoder(num_features, emsize, num_embs=num_embs_per_feature)
|
TabPFN/initializers.py
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch import nn
|
3 |
+
|
4 |
+
def get_NormalInitializer(std):
|
5 |
+
def initializer(m):
|
6 |
+
if isinstance(m, nn.Linear):
|
7 |
+
nn.init.normal_(m.weight, 0, std)
|
8 |
+
nn.init.normal_(m.bias, 0, std)
|
9 |
+
return initializer
|
TabPFN/layer.py
ADDED
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from functools import partial
|
2 |
+
|
3 |
+
from torch import nn
|
4 |
+
from torch.nn.modules.transformer import *
|
5 |
+
from torch.nn.modules.transformer import _get_activation_fn
|
6 |
+
|
7 |
+
from torch.utils.checkpoint import checkpoint
|
8 |
+
|
9 |
+
|
10 |
+
class TransformerEncoderLayer(Module):
|
11 |
+
r"""TransformerEncoderLayer is made up of self-attn and feedforward network.
|
12 |
+
This standard encoder layer is based on the paper "Attention Is All You Need".
|
13 |
+
Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
|
14 |
+
Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
|
15 |
+
Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
|
16 |
+
in a different way during application.
|
17 |
+
|
18 |
+
Args:
|
19 |
+
d_model: the number of expected features in the input (required).
|
20 |
+
nhead: the number of heads in the multiheadattention models (required).
|
21 |
+
dim_feedforward: the dimension of the feedforward network model (default=2048).
|
22 |
+
dropout: the dropout value (default=0.1).
|
23 |
+
activation: the activation function of intermediate layer, relu or gelu (default=relu).
|
24 |
+
layer_norm_eps: the eps value in layer normalization components (default=1e-5).
|
25 |
+
batch_first: If ``True``, then the input and output tensors are provided
|
26 |
+
as (batch, seq, feature). Default: ``False``.
|
27 |
+
|
28 |
+
Examples::
|
29 |
+
>>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
|
30 |
+
>>> src = torch.rand(10, 32, 512)
|
31 |
+
>>> out = encoder_layer(src)
|
32 |
+
|
33 |
+
Alternatively, when ``batch_first`` is ``True``:
|
34 |
+
>>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True)
|
35 |
+
>>> src = torch.rand(32, 10, 512)
|
36 |
+
>>> out = encoder_layer(src)
|
37 |
+
"""
|
38 |
+
__constants__ = ['batch_first']
|
39 |
+
|
40 |
+
def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu",
|
41 |
+
layer_norm_eps=1e-5, batch_first=False, pre_norm=False,
|
42 |
+
device=None, dtype=None, recompute_attn=False) -> None:
|
43 |
+
factory_kwargs = {'device': device, 'dtype': dtype}
|
44 |
+
super().__init__()
|
45 |
+
self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first,
|
46 |
+
**factory_kwargs)
|
47 |
+
# Implementation of Feedforward model
|
48 |
+
self.linear1 = Linear(d_model, dim_feedforward, **factory_kwargs)
|
49 |
+
self.dropout = Dropout(dropout)
|
50 |
+
self.linear2 = Linear(dim_feedforward, d_model, **factory_kwargs)
|
51 |
+
|
52 |
+
self.norm1 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
|
53 |
+
self.norm2 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
|
54 |
+
self.dropout1 = Dropout(dropout)
|
55 |
+
self.dropout2 = Dropout(dropout)
|
56 |
+
self.pre_norm = pre_norm
|
57 |
+
self.recompute_attn = recompute_attn
|
58 |
+
|
59 |
+
self.activation = _get_activation_fn(activation)
|
60 |
+
|
61 |
+
def __setstate__(self, state):
|
62 |
+
if 'activation' not in state:
|
63 |
+
state['activation'] = F.relu
|
64 |
+
super().__setstate__(state)
|
65 |
+
|
66 |
+
def forward(self, src: Tensor, src_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None) -> Tensor:
|
67 |
+
r"""Pass the input through the encoder layer.
|
68 |
+
|
69 |
+
Args:
|
70 |
+
src: the sequence to the encoder layer (required).
|
71 |
+
src_mask: the mask for the src sequence (optional).
|
72 |
+
src_key_padding_mask: the mask for the src keys per batch (optional).
|
73 |
+
|
74 |
+
Shape:
|
75 |
+
see the docs in Transformer class.
|
76 |
+
"""
|
77 |
+
if self.pre_norm:
|
78 |
+
src_ = self.norm1(src)
|
79 |
+
else:
|
80 |
+
src_ = src
|
81 |
+
if isinstance(src_mask, tuple):
|
82 |
+
# global attention setup
|
83 |
+
assert not self.self_attn.batch_first
|
84 |
+
assert src_key_padding_mask is None
|
85 |
+
|
86 |
+
global_src_mask, trainset_src_mask, valset_src_mask = src_mask
|
87 |
+
|
88 |
+
num_global_tokens = global_src_mask.shape[0]
|
89 |
+
num_train_tokens = trainset_src_mask.shape[0]
|
90 |
+
|
91 |
+
global_tokens_src = src_[:num_global_tokens]
|
92 |
+
train_tokens_src = src_[num_global_tokens:num_global_tokens+num_train_tokens]
|
93 |
+
global_and_train_tokens_src = src_[:num_global_tokens+num_train_tokens]
|
94 |
+
eval_tokens_src = src_[num_global_tokens+num_train_tokens:]
|
95 |
+
|
96 |
+
|
97 |
+
attn = partial(checkpoint, self.self_attn) if self.recompute_attn else self.self_attn
|
98 |
+
|
99 |
+
global_tokens_src2 = attn(global_tokens_src, global_and_train_tokens_src, global_and_train_tokens_src, None, True, global_src_mask)[0]
|
100 |
+
train_tokens_src2 = attn(train_tokens_src, global_tokens_src, global_tokens_src, None, True, trainset_src_mask)[0]
|
101 |
+
eval_tokens_src2 = attn(eval_tokens_src, src_, src_,
|
102 |
+
None, True, valset_src_mask)[0]
|
103 |
+
|
104 |
+
src2 = torch.cat([global_tokens_src2, train_tokens_src2, eval_tokens_src2], dim=0)
|
105 |
+
|
106 |
+
else:
|
107 |
+
if self.recompute_attn:
|
108 |
+
src2 = checkpoint(self.self_attn, src_, src_, src_, src_key_padding_mask, True, src_mask)[0]
|
109 |
+
else:
|
110 |
+
src2 = self.self_attn(src_, src_, src_, attn_mask=src_mask,
|
111 |
+
key_padding_mask=src_key_padding_mask)[0]
|
112 |
+
src = src + self.dropout1(src2)
|
113 |
+
if not self.pre_norm:
|
114 |
+
src = self.norm1(src)
|
115 |
+
|
116 |
+
if self.pre_norm:
|
117 |
+
src_ = self.norm2(src)
|
118 |
+
else:
|
119 |
+
src_ = src
|
120 |
+
src2 = self.linear2(self.dropout(self.activation(self.linear1(src_))))
|
121 |
+
src = src + self.dropout2(src2)
|
122 |
+
|
123 |
+
if not self.pre_norm:
|
124 |
+
src = self.norm2(src)
|
125 |
+
return src
|
TabPFN/losses.py
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch import nn
|
3 |
+
|
4 |
+
class CrossEntropyForMulticlassLoss(torch.nn.CrossEntropyLoss):
|
5 |
+
# This loss applies cross entropy after reducing the number of prediction
|
6 |
+
# dimensions to the number of classes in the target
|
7 |
+
|
8 |
+
# TODO: loss.item() doesn't work so the displayed losses are Nans
|
9 |
+
def __init__(self, num_classes, weight=None, size_average=None, ignore_index: int = -100,
|
10 |
+
reduce=None, reduction: str = 'mean', label_smoothing: float = 0.0) -> None:
|
11 |
+
super().__init__(size_average=size_average, reduce=reduce, reduction=reduction, ignore_index=ignore_index)
|
12 |
+
self.num_classes = num_classes
|
13 |
+
|
14 |
+
def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
|
15 |
+
loss = torch.zeros_like(input[:, :, 0])
|
16 |
+
for b in range(target.shape[1]):
|
17 |
+
l = super().forward(input[:, b, 0:len(torch.unique(target[:, b]))], target[:, b])
|
18 |
+
loss[:, b] += l
|
19 |
+
return loss.flatten()
|
20 |
+
|
21 |
+
def JointBCELossWithLogits(output, target):
|
22 |
+
# output shape: (S, B, NS) with NS = Number of sequences
|
23 |
+
# target shape: (S, B, SL)
|
24 |
+
# Loss = -log(mean_NS(prod_SL(p(target_SL, output_NS))))
|
25 |
+
# Here at the moment NS = SL
|
26 |
+
output = output.unsqueeze(-1).repeat(1, 1, 1, target.shape[-1]) # (S, B, NS, SL)
|
27 |
+
output = output.permute(2, 0, 1, 3) # (NS, S, B, SL)
|
28 |
+
print(target.shape, output.shape)
|
29 |
+
loss = (target * torch.sigmoid(output)) + ((1-target) * (1-torch.sigmoid(output)))
|
30 |
+
loss = loss.prod(-1)
|
31 |
+
loss = loss.mean(0)
|
32 |
+
loss = -torch.log(loss)
|
33 |
+
loss = loss.mean()
|
34 |
+
return loss
|
35 |
+
|
36 |
+
class ScaledSoftmaxCE(nn.Module):
|
37 |
+
def forward(self, x, label):
|
38 |
+
logits = x[..., :-10]
|
39 |
+
temp_scales = x[..., -10:]
|
40 |
+
|
41 |
+
logprobs = logits.softmax(-1)
|
TabPFN/model_builder.py
ADDED
@@ -0,0 +1,273 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from train import train, Losses
|
2 |
+
import priors
|
3 |
+
import encoders
|
4 |
+
|
5 |
+
from collections import defaultdict
|
6 |
+
|
7 |
+
from priors.utils import trunc_norm_sampler_f, gamma_sampler_f
|
8 |
+
from utils import get_uniform_single_eval_pos_sampler
|
9 |
+
import torch
|
10 |
+
import math
|
11 |
+
|
12 |
+
def save_model(model, path, filename, config_sample):
|
13 |
+
config_sample = {**config_sample}
|
14 |
+
|
15 |
+
def make_serializable(config_sample):
|
16 |
+
if isinstance(config_sample, dict):
|
17 |
+
config_sample = {k: make_serializable(config_sample[k]) for k in config_sample}
|
18 |
+
if isinstance(config_sample, list):
|
19 |
+
config_sample = [make_serializable(v) for v in config_sample]
|
20 |
+
if callable(config_sample):
|
21 |
+
config_sample = str(config_sample)
|
22 |
+
return config_sample
|
23 |
+
|
24 |
+
#if 'num_features_used' in config_sample:
|
25 |
+
# del config_sample['num_features_used']
|
26 |
+
|
27 |
+
#config_sample['num_classes_as_str'] = str(config_sample['num_classes'])
|
28 |
+
#del config_sample['num_classes']
|
29 |
+
|
30 |
+
config_sample = make_serializable(config_sample)
|
31 |
+
|
32 |
+
torch.save((model.state_dict(), None, config_sample), os.path.join(path, filename))
|
33 |
+
|
34 |
+
|
35 |
+
import subprocess as sp
|
36 |
+
import os
|
37 |
+
|
38 |
+
def get_gpu_memory():
|
39 |
+
command = "nvidia-smi"
|
40 |
+
memory_free_info = sp.check_output(command.split()).decode('ascii')
|
41 |
+
return memory_free_info
|
42 |
+
|
43 |
+
|
44 |
+
def load_model(path, filename, device, eval_positions, verbose):
|
45 |
+
# TODO: This function only restores evaluation functionality but training canät be continued. It is also not flexible.
|
46 |
+
|
47 |
+
model_state, optimizer_state, config_sample = torch.load(
|
48 |
+
os.path.join(path, filename), map_location='cpu')
|
49 |
+
if ('differentiable_hyperparameters' in config_sample
|
50 |
+
and 'prior_mlp_activations' in config_sample['differentiable_hyperparameters']):
|
51 |
+
config_sample['differentiable_hyperparameters']['prior_mlp_activations']['choice_values_used'] = config_sample[
|
52 |
+
'differentiable_hyperparameters'][
|
53 |
+
'prior_mlp_activations'][
|
54 |
+
'choice_values']
|
55 |
+
config_sample['differentiable_hyperparameters']['prior_mlp_activations']['choice_values'] = [
|
56 |
+
torch.nn.Tanh for k in config_sample['differentiable_hyperparameters']['prior_mlp_activations']['choice_values']]
|
57 |
+
|
58 |
+
config_sample['categorical_features_sampler'] = lambda: lambda x: ([], [], [])
|
59 |
+
config_sample['num_features_used_in_training'] = config_sample['num_features_used']
|
60 |
+
config_sample['num_features_used'] = lambda: config_sample['num_features']
|
61 |
+
config_sample['num_classes_in_training'] = config_sample['num_classes']
|
62 |
+
config_sample['num_classes'] = 2
|
63 |
+
config_sample['batch_size_in_training'] = config_sample['batch_size']
|
64 |
+
config_sample['batch_size'] = 1
|
65 |
+
config_sample['bptt_in_training'] = config_sample['bptt']
|
66 |
+
config_sample['bptt'] = 10
|
67 |
+
config_sample['bptt_extra_samples_in_training'] = config_sample['bptt_extra_samples']
|
68 |
+
config_sample['bptt_extra_samples'] = None
|
69 |
+
|
70 |
+
#print('Memory', str(get_gpu_memory()))
|
71 |
+
|
72 |
+
model = get_model(config_sample, device=device, should_train=False, verbose=verbose)
|
73 |
+
module_prefix = 'module.'
|
74 |
+
model_state = {k.replace(module_prefix, ''): v for k, v in model_state.items()}
|
75 |
+
model[2].load_state_dict(model_state)
|
76 |
+
model[2].to(device)
|
77 |
+
|
78 |
+
return model, config_sample
|
79 |
+
|
80 |
+
def fix_loaded_config_sample(loaded_config_sample, config):
|
81 |
+
def copy_to_sample(*k):
|
82 |
+
t,s = loaded_config_sample, config
|
83 |
+
for k_ in k[:-1]:
|
84 |
+
t = t[k_]
|
85 |
+
s = s[k_]
|
86 |
+
t[k[-1]] = s[k[-1]]
|
87 |
+
copy_to_sample('num_features_used')
|
88 |
+
copy_to_sample('num_classes')
|
89 |
+
copy_to_sample('differentiable_hyperparameters','prior_mlp_activations','choice_values')
|
90 |
+
|
91 |
+
def load_config_sample(path, template_config):
|
92 |
+
model_state, optimizer_state, loaded_config_sample = torch.load(path, map_location='cpu')
|
93 |
+
fix_loaded_config_sample(loaded_config_sample, template_config)
|
94 |
+
return loaded_config_sample
|
95 |
+
|
96 |
+
def get_default_spec(test_datasets, valid_datasets):
|
97 |
+
bptt = 10000
|
98 |
+
eval_positions = [1000, 2000, 3000, 4000, 5000] # list(2 ** np.array([4, 5, 6, 7, 8, 9, 10, 11, 12]))
|
99 |
+
max_features = max([X.shape[1] for (_, X, _, _, _, _) in test_datasets] + [X.shape[1] for (_, X, _, _, _, _) in valid_datasets])
|
100 |
+
max_splits = 5
|
101 |
+
|
102 |
+
return bptt, eval_positions, max_features, max_splits
|
103 |
+
|
104 |
+
def get_mlp_prior_hyperparameters(config):
|
105 |
+
config = {hp: (list(config[hp].values())[0]) if type(config[hp]) is dict else config[hp] for hp in config}
|
106 |
+
|
107 |
+
if "prior_sigma_gamma_k" in config:
|
108 |
+
sigma_sampler = gamma_sampler_f(config["prior_sigma_gamma_k"], config["prior_sigma_gamma_theta"])
|
109 |
+
config['init_std'] = sigma_sampler
|
110 |
+
if "prior_noise_std_gamma_k" in config:
|
111 |
+
noise_std_sampler = gamma_sampler_f(config["prior_noise_std_gamma_k"], config["prior_noise_std_gamma_theta"])
|
112 |
+
config['noise_std'] = noise_std_sampler
|
113 |
+
|
114 |
+
return config
|
115 |
+
|
116 |
+
|
117 |
+
def get_gp_mix_prior_hyperparameters(config):
|
118 |
+
return {'lengthscale_concentration': config["prior_lengthscale_concentration"],
|
119 |
+
'nu': config["prior_nu"],
|
120 |
+
'outputscale_concentration': config["prior_outputscale_concentration"],
|
121 |
+
'categorical_data': config["prior_y_minmax_norm"],
|
122 |
+
'y_minmax_norm': config["prior_lengthscale_concentration"],
|
123 |
+
'noise_concentration': config["prior_noise_concentration"],
|
124 |
+
'noise_rate': config["prior_noise_rate"]}
|
125 |
+
|
126 |
+
def get_gp_prior_hyperparameters(config):
|
127 |
+
return {hp: (list(config[hp].values())[0]) if type(config[hp]) is dict else config[hp] for hp in config}
|
128 |
+
|
129 |
+
|
130 |
+
def get_meta_gp_prior_hyperparameters(config):
|
131 |
+
config = {hp: (list(config[hp].values())[0]) if type(config[hp]) is dict else config[hp] for hp in config}
|
132 |
+
|
133 |
+
if "outputscale_mean" in config:
|
134 |
+
outputscale_sampler = trunc_norm_sampler_f(config["outputscale_mean"]
|
135 |
+
, config["outputscale_mean"] * config["outputscale_std_f"])
|
136 |
+
config['outputscale'] = outputscale_sampler
|
137 |
+
if "lengthscale_mean" in config:
|
138 |
+
lengthscale_sampler = trunc_norm_sampler_f(config["lengthscale_mean"],
|
139 |
+
config["lengthscale_mean"] * config["lengthscale_std_f"])
|
140 |
+
config['lengthscale'] = lengthscale_sampler
|
141 |
+
|
142 |
+
return config
|
143 |
+
|
144 |
+
|
145 |
+
def get_model(config, device, should_train=True, verbose=False, state_dict=None, epoch_callback=None):
|
146 |
+
extra_kwargs = {}
|
147 |
+
verbose_train, verbose_prior = verbose >= 1, verbose >= 2
|
148 |
+
config['verbose'] = verbose_prior
|
149 |
+
|
150 |
+
if 'aggregate_k_gradients' not in config or config['aggregate_k_gradients'] is None:
|
151 |
+
config['aggregate_k_gradients'] = math.ceil(config['batch_size'] * ((config['nlayers'] * config['emsize'] * config['bptt'] * config['bptt']) / 10824640000))
|
152 |
+
|
153 |
+
config['num_steps'] = math.ceil(config['num_steps'] * config['aggregate_k_gradients'])
|
154 |
+
config['batch_size'] = math.ceil(config['batch_size'] / config['aggregate_k_gradients'])
|
155 |
+
config['recompute_attn'] = config['recompute_attn'] if 'recompute_attn' in config else False
|
156 |
+
|
157 |
+
def make_get_batch(model_proto, **extra_kwargs):
|
158 |
+
extra_kwargs = defaultdict(lambda: None, **extra_kwargs)
|
159 |
+
return (lambda batch_size, seq_len, num_features, hyperparameters
|
160 |
+
, device, model_proto=model_proto, get_batch=extra_kwargs['get_batch']
|
161 |
+
, prior_bag_priors=extra_kwargs['prior_bag_priors']: model_proto.get_batch(
|
162 |
+
batch_size=batch_size
|
163 |
+
, seq_len=seq_len
|
164 |
+
, device=device
|
165 |
+
, get_batch=get_batch
|
166 |
+
, hyperparameters=hyperparameters
|
167 |
+
, num_features=num_features))
|
168 |
+
|
169 |
+
if config['prior_type'] == 'prior_bag':
|
170 |
+
# Prior bag combines priors
|
171 |
+
get_batch_gp = make_get_batch(priors.fast_gp)
|
172 |
+
get_batch_mlp = make_get_batch(priors.mlp)
|
173 |
+
if 'flexible' in config and config['flexible']:
|
174 |
+
get_batch_gp = make_get_batch(priors.flexible_categorical, **{'get_batch': get_batch_gp})
|
175 |
+
get_batch_mlp = make_get_batch(priors.flexible_categorical, **{'get_batch': get_batch_mlp})
|
176 |
+
prior_bag_hyperparameters = {'prior_bag_get_batch': (get_batch_gp, get_batch_mlp)
|
177 |
+
, 'prior_bag_exp_weights_1': 2.0}
|
178 |
+
prior_hyperparameters = {**get_mlp_prior_hyperparameters(config), **get_gp_prior_hyperparameters(config)
|
179 |
+
, **prior_bag_hyperparameters}
|
180 |
+
model_proto = priors.prior_bag
|
181 |
+
else:
|
182 |
+
if config['prior_type'] == 'mlp':
|
183 |
+
prior_hyperparameters = get_mlp_prior_hyperparameters(config)
|
184 |
+
model_proto = priors.mlp
|
185 |
+
elif config['prior_type'] == 'gp':
|
186 |
+
prior_hyperparameters = get_gp_prior_hyperparameters(config)
|
187 |
+
model_proto = priors.fast_gp
|
188 |
+
elif config['prior_type'] == 'gp_mix':
|
189 |
+
prior_hyperparameters = get_gp_mix_prior_hyperparameters(config)
|
190 |
+
model_proto = priors.fast_gp_mix
|
191 |
+
else:
|
192 |
+
raise Exception()
|
193 |
+
|
194 |
+
if 'flexible' in config and config['flexible']:
|
195 |
+
get_batch_base = make_get_batch(model_proto)
|
196 |
+
extra_kwargs['get_batch'] = get_batch_base
|
197 |
+
model_proto = priors.flexible_categorical
|
198 |
+
|
199 |
+
use_style = False
|
200 |
+
|
201 |
+
if 'differentiable' in config and config['differentiable']:
|
202 |
+
get_batch_base = make_get_batch(model_proto, **extra_kwargs)
|
203 |
+
extra_kwargs = {'get_batch': get_batch_base, 'differentiable_hyperparameters': config['differentiable_hyperparameters']}
|
204 |
+
model_proto = priors.differentiable_prior
|
205 |
+
use_style = True
|
206 |
+
print(f"Using style prior: {use_style}")
|
207 |
+
|
208 |
+
if (('nan_prob_no_reason' in config and config['nan_prob_no_reason'] > 0.0) or
|
209 |
+
('nan_prob_a_reason' in config and config['nan_prob_a_reason'] > 0.0) or
|
210 |
+
('nan_prob_unknown_reason' in config and config['nan_prob_unknown_reason'] > 0.0)):
|
211 |
+
encoder = encoders.NanHandlingEncoder
|
212 |
+
else:
|
213 |
+
encoder = encoders.Linear
|
214 |
+
|
215 |
+
num_outputs = config['num_outputs'] if 'num_outputs' in config else 1
|
216 |
+
if config['max_num_classes'] == 2:
|
217 |
+
if 'joint_loss' in config and config['joint_loss']:
|
218 |
+
loss = JointBCELossWithLogits
|
219 |
+
else:
|
220 |
+
loss = Losses.bce
|
221 |
+
elif config['max_num_classes'] > 2:
|
222 |
+
loss = Losses.ce(torch.ones((config['max_num_classes'])))
|
223 |
+
else:
|
224 |
+
loss = BarDistribution(borders=get_bucket_limits(500, full_range=(-10, 10)))
|
225 |
+
|
226 |
+
aggregate_k_gradients = 1 if 'aggregate_k_gradients' not in config else config['aggregate_k_gradients']
|
227 |
+
check_is_compatible = False if 'multiclass_loss_type' not in config else (config['multiclass_loss_type'] == 'compatible')
|
228 |
+
config['multiclass_type'] = config['multiclass_type'] if 'multiclass_type' in config else 'rank'
|
229 |
+
config['mix_activations'] = config['mix_activations'] if 'mix_activations' in config else False
|
230 |
+
|
231 |
+
config['bptt_extra_samples'] = config['bptt_extra_samples'] if 'bptt_extra_samples' in config else None
|
232 |
+
config['eval_positions'] = [int(config['bptt'] * 0.95)] if config['bptt_extra_samples'] is None else [int(config['bptt'])]
|
233 |
+
|
234 |
+
epochs = 0 if not should_train else config['epochs']
|
235 |
+
model = train(model_proto.DataLoader
|
236 |
+
, loss
|
237 |
+
, encoder
|
238 |
+
, style_encoder_generator = encoders.StyleEncoder if use_style else None
|
239 |
+
, emsize=config['emsize']
|
240 |
+
, nhead=config['nhead']
|
241 |
+
, y_encoder_generator= encoders.get_Canonical(config['max_num_classes']) if config.get('canonical_y_encoder', False) else encoders.Linear
|
242 |
+
, pos_encoder_generator=None
|
243 |
+
, batch_size=config['batch_size']
|
244 |
+
, nlayers=config['nlayers']
|
245 |
+
, nhid=config['emsize'] * config['nhid_factor']
|
246 |
+
, epochs=epochs
|
247 |
+
, total_available_time_in_s=config.get('total_available_time_in_s', None)
|
248 |
+
, warmup_epochs=20
|
249 |
+
, bptt=config['bptt']
|
250 |
+
, gpu_device=device
|
251 |
+
, dropout=config['dropout']
|
252 |
+
, steps_per_epoch=config['num_steps']
|
253 |
+
, single_eval_pos_gen=get_uniform_single_eval_pos_sampler(config['bptt'])
|
254 |
+
, load_weights_from_this_state_dict=state_dict
|
255 |
+
, aggregate_k_gradients=aggregate_k_gradients
|
256 |
+
, check_is_compatible=check_is_compatible
|
257 |
+
, recompute_attn=config['recompute_attn']
|
258 |
+
, epoch_callback=epoch_callback
|
259 |
+
, bptt_extra_samples = config['bptt_extra_samples']
|
260 |
+
, extra_prior_kwargs_dict={
|
261 |
+
'num_features': config['num_features']
|
262 |
+
, 'fuse_x_y': False
|
263 |
+
, 'hyperparameters': prior_hyperparameters
|
264 |
+
, 'num_outputs':num_outputs
|
265 |
+
, 'dynamic_batch_size': 1 if ('num_global_att_tokens' in config and config['num_global_att_tokens']) else 2
|
266 |
+
, **extra_kwargs
|
267 |
+
}
|
268 |
+
, lr=config['lr']
|
269 |
+
, verbose=verbose_train,
|
270 |
+
weight_decay=config.get('weight_decay', 0.0),
|
271 |
+
normalize_labels=True)
|
272 |
+
|
273 |
+
return model
|
TabPFN/models_diff/gp_ablation_model.cpkt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:c7b0c8febc553cca3fdee265b5a1cd7567dbf83da855969940be4707a9218ffb
|
3 |
+
size 69460013
|
TabPFN/models_diff/prior_diff_real_checkpoint_n_8x_lr0.0003_epoch_49.cpkt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:dae97f45bd53d719fc2b23fac4ec55eab16d63892196d939b1bb1c3b408be242
|
3 |
+
size 103616779
|
TabPFN/notebook_utils.py
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from pathlib import Path
|
3 |
+
|
4 |
+
import io
|
5 |
+
import torch
|
6 |
+
import pickle
|
7 |
+
|
8 |
+
def print_models(base_path, model_string):
|
9 |
+
print(model_string)
|
10 |
+
|
11 |
+
for i in range(80):
|
12 |
+
for e in range(50):
|
13 |
+
exists = Path(os.path.join(base_path, f'models_diff/prior_diff_real_checkpoint{model_string}_n_{i}_epoch_{e}.cpkt')).is_file()
|
14 |
+
if exists:
|
15 |
+
print(os.path.join(base_path, f'models_diff/prior_diff_real_checkpoint{model_string}_n_{i}_epoch_{e}.cpkt'))
|
16 |
+
print()
|
17 |
+
|
18 |
+
class CustomUnpickler(pickle.Unpickler):
|
19 |
+
def find_class(self, module, name):
|
20 |
+
if name == 'Manager':
|
21 |
+
from settings import Manager
|
22 |
+
return Manager
|
23 |
+
try:
|
24 |
+
return self.find_class_cpu(module, name)
|
25 |
+
except:
|
26 |
+
return None
|
27 |
+
|
28 |
+
def find_class_cpu(self, module, name):
|
29 |
+
if module == 'torch.storage' and name == '_load_from_bytes':
|
30 |
+
return lambda b: torch.load(io.BytesIO(b), map_location='cpu')
|
31 |
+
else:
|
32 |
+
return super().find_class(module, name)
|
TabPFN/positional_encodings.py
ADDED
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from torch import nn
|
5 |
+
|
6 |
+
|
7 |
+
# Protocol for positonal encodings.
|
8 |
+
# __init__(d_model, max_len=..[, more optionals])
|
9 |
+
# forward(x: (seq_len, bs, d_model)) -> Tensor of shape (*x.shape[:2],d_model) containing pos. embeddings
|
10 |
+
|
11 |
+
|
12 |
+
class NoPositionalEncoding(nn.Module):
|
13 |
+
def __init__(self, d_model, max_len=None):
|
14 |
+
super(NoPositionalEncoding, self).__init__()
|
15 |
+
pass
|
16 |
+
|
17 |
+
def forward(self, x):
|
18 |
+
return x #* math.sqrt(x.shape[-1])
|
19 |
+
|
20 |
+
|
21 |
+
class PositionalEncoding(nn.Module):
|
22 |
+
def __init__(self, d_model, max_len=5000):
|
23 |
+
super(PositionalEncoding, self).__init__()
|
24 |
+
pe = torch.zeros(max_len, d_model)
|
25 |
+
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
|
26 |
+
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
|
27 |
+
pe[:, 0::2] = torch.sin(position * div_term)
|
28 |
+
pe[:, 1::2] = torch.cos(position * div_term)
|
29 |
+
pe = pe.unsqueeze(0).transpose(0, 1)
|
30 |
+
self.register_buffer('pe', pe)
|
31 |
+
|
32 |
+
def forward(self, x):
|
33 |
+
x = self.pe[:x.size(0), :] + x # * math.sqrt(x.shape[-1])
|
34 |
+
return x
|
35 |
+
|
36 |
+
|
37 |
+
class LearnedPositionalEncoding(nn.Module):
|
38 |
+
def __init__(self, d_model, max_len=5000):
|
39 |
+
super(LearnedPositionalEncoding, self).__init__()
|
40 |
+
self.max_seq_len = max_len
|
41 |
+
#self.positional_embeddings = nn.Embedding(max_len, d_model)
|
42 |
+
self.positional_embeddings = nn.Parameter(torch.empty(max_len, d_model))
|
43 |
+
nn.init.normal_(self.positional_embeddings, mean=0, std=d_model ** -0.5)
|
44 |
+
|
45 |
+
def forward(self, x):
|
46 |
+
seq_len, bs, d_model = x.shape
|
47 |
+
assert seq_len <= len(self.positional_embeddings), 'seq_len can be at most max_len.'
|
48 |
+
pos_emb = self.positional_embeddings[:seq_len]
|
49 |
+
return pos_emb.unsqueeze(1).expand(seq_len, bs, d_model) + x #* math.sqrt(x.shape[-1])
|
50 |
+
|
51 |
+
|
52 |
+
class PairedScrambledPositionalEncodings(LearnedPositionalEncoding):
|
53 |
+
# TODO check whether it is a problem to use the same perm. for full batch
|
54 |
+
def forward(self, x):
|
55 |
+
seq_len, bs, d_model = x.shape
|
56 |
+
assert seq_len <= len(self.positional_embeddings), 'seq_len can be at most max_len.'
|
57 |
+
assert len(self.positional_embeddings) % 2 == 0, 'Please specify an even max_len.'
|
58 |
+
|
59 |
+
paired_embs = self.positional_embeddings.view(len(self.positional_embeddings), -1, 2)
|
60 |
+
pos_emb = paired_embs[torch.randperm(len(paired_embs))].view(*self.positional_embeddings.shape)[:seq_len]
|
61 |
+
|
62 |
+
return pos_emb.unsqueeze(1).expand(seq_len, bs, d_model) + x #* math.sqrt(x.shape[-1])
|
63 |
+
|
64 |
+
|
65 |
+
|
66 |
+
|
67 |
+
|
68 |
+
|
69 |
+
|
70 |
+
|
TabPFN/prior_tuning_result.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:24d2189bbc836aeea888cf6c540f2c1b45b5351822931189e8bf10a0bc80a0b6
|
3 |
+
size 18668851
|
TabPFN/priors/__init__.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from . import fast_gp, mlp, flexible_categorical, differentiable_prior, prior_bag
|
2 |
+
|
3 |
+
|
4 |
+
|
TabPFN/priors/differentiable_prior.py
ADDED
@@ -0,0 +1,293 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch import nn
|
3 |
+
import math
|
4 |
+
|
5 |
+
from .utils import get_batch_to_dataloader
|
6 |
+
from utils import default_device
|
7 |
+
from .utils import order_by_y, normalize_by_used_features_f
|
8 |
+
|
9 |
+
from .utils import trunc_norm_sampler_f, beta_sampler_f, gamma_sampler_f, uniform_sampler_f, zipf_sampler_f, scaled_beta_sampler_f, uniform_int_sampler_f
|
10 |
+
|
11 |
+
|
12 |
+
def unpack_dict_of_tuples(d):
|
13 |
+
# Returns list of dicts where each dict i contains values of tuple position i
|
14 |
+
# {'a': (1,2), 'b': (3,4)} -> [{'a': 1, 'b': 3}, {'a': 2, 'b': 4}]
|
15 |
+
return [dict(zip(d.keys(), v)) for v in list(zip(*list(d.values())))]
|
16 |
+
|
17 |
+
class DifferentiableHyperparameter(nn.Module):
|
18 |
+
## We can sample this and get a hyperparameter value and a normalized hyperparameter indicator
|
19 |
+
def __init__(self, distribution, embedding_dim, device, **args):
|
20 |
+
super(DifferentiableHyperparameter, self).__init__()
|
21 |
+
|
22 |
+
self.distribution = distribution
|
23 |
+
self.embedding_dim = embedding_dim
|
24 |
+
self.device=device
|
25 |
+
for key in args:
|
26 |
+
setattr(self, key, args[key])
|
27 |
+
|
28 |
+
def get_sampler():
|
29 |
+
#if self.distribution == "beta":
|
30 |
+
# return beta_sampler_f(self.a, self.b), 0, 1
|
31 |
+
#elif self.distribution == "gamma":
|
32 |
+
# return gamma_sampler_f(self.a, self.b), 0, 1
|
33 |
+
#elif self.distribution == "beta_int":
|
34 |
+
# return scaled_beta_sampler_f(self.a, self.b, self.scale, self.min), self.scale + self.min, self.min, self.a / (self.a + self.b)
|
35 |
+
if self.distribution == "uniform":
|
36 |
+
if not hasattr(self, 'sample'):
|
37 |
+
return uniform_sampler_f(self.min, self.max), self.min, self.max, (self.max+self.min) / 2, math.sqrt(1/12*(self.max-self.min)*(self.max-self.min))
|
38 |
+
else:
|
39 |
+
return lambda: self.sample, self.min, self.max, None, None
|
40 |
+
elif self.distribution == "uniform_int":
|
41 |
+
return uniform_int_sampler_f(self.min, self.max), self.min, self.max, (self.max+self.min) / 2, math.sqrt(1/12*(self.max-self.min)*(self.max-self.min))
|
42 |
+
|
43 |
+
if self.distribution.startswith("meta"):
|
44 |
+
self.hparams = {}
|
45 |
+
def sample_meta(f):
|
46 |
+
indicators, passed = unpack_dict_of_tuples({hp: self.hparams[hp]() for hp in self.hparams})
|
47 |
+
# sampled_embeddings = list(itertools.chain.from_iterable([sampled_embeddings[k] for k in sampled_embeddings]))
|
48 |
+
meta_passed = f(**passed)
|
49 |
+
return indicators, meta_passed
|
50 |
+
|
51 |
+
args_passed = {'device': device, 'embedding_dim': embedding_dim}
|
52 |
+
if self.distribution == "meta_beta":
|
53 |
+
## Truncated normal where std and mean are drawn randomly logarithmically scaled
|
54 |
+
if hasattr(self, 'b') and hasattr(self, 'k'):
|
55 |
+
self.hparams = {'b': lambda: (None, self.b), 'k': lambda: (None, self.k)}
|
56 |
+
else:
|
57 |
+
self.hparams = {"b": DifferentiableHyperparameter(distribution="uniform", min=self.min
|
58 |
+
, max=self.max, **args_passed)
|
59 |
+
, "k": DifferentiableHyperparameter(distribution="uniform", min=self.min
|
60 |
+
, max=self.max, **args_passed)}
|
61 |
+
def make_beta(b, k):
|
62 |
+
return lambda b=b, k=k: self.scale * beta_sampler_f(b, k)()
|
63 |
+
self.sampler = lambda make_beta=make_beta : sample_meta(make_beta)
|
64 |
+
elif self.distribution == "meta_trunc_norm_log_scaled":
|
65 |
+
# these choices are copied down below, don't change these without changing `replace_differentiable_distributions`
|
66 |
+
self.min_std = self.min_std if hasattr(self, 'min_std') else 0.001
|
67 |
+
self.max_std = self.max_std if hasattr(self, 'max_std') else self.max_mean
|
68 |
+
## Truncated normal where std and mean are drawn randomly logarithmically scaled
|
69 |
+
if not hasattr(self, 'log_mean'):
|
70 |
+
self.hparams = {"log_mean": DifferentiableHyperparameter(distribution="uniform", min=math.log(self.min_mean)
|
71 |
+
, max=math.log(self.max_mean), **args_passed)
|
72 |
+
, "log_std": DifferentiableHyperparameter(distribution="uniform", min=math.log(self.min_std)
|
73 |
+
, max=math.log(self.max_std), **args_passed)}
|
74 |
+
else:
|
75 |
+
self.hparams = {'log_mean': lambda: (None, self.log_mean), 'log_std': lambda: (None, self.log_std)}
|
76 |
+
def make_trunc_norm(log_mean, log_std):
|
77 |
+
return ((lambda : self.lower_bound + round(trunc_norm_sampler_f(math.exp(log_mean), math.exp(log_std))())) if self.round
|
78 |
+
else (lambda: self.lower_bound + trunc_norm_sampler_f(math.exp(log_mean), math.exp(log_std))()))
|
79 |
+
|
80 |
+
self.sampler = lambda make_trunc_norm=make_trunc_norm: sample_meta(make_trunc_norm)
|
81 |
+
elif self.distribution == "meta_trunc_norm":
|
82 |
+
self.min_std = self.min_std if hasattr(self, 'min_std') else 0
|
83 |
+
self.max_std = self.max_std if hasattr(self, 'max_std') else self.max_mean
|
84 |
+
self.hparams = {"mean": DifferentiableHyperparameter(distribution="uniform", min=self.min_mean
|
85 |
+
, max=self.max_mean, **args_passed)
|
86 |
+
, "std": DifferentiableHyperparameter(distribution="uniform", min=self.min_std
|
87 |
+
, max=self.max_std, **args_passed)}
|
88 |
+
def make_trunc_norm(mean, std):
|
89 |
+
return ((lambda: self.lower_bound + round(
|
90 |
+
trunc_norm_sampler_f(math.exp(mean), math.exp(std))())) if self.round
|
91 |
+
else (
|
92 |
+
lambda make_trunc_norm=make_trunc_norm: self.lower_bound + trunc_norm_sampler_f(math.exp(mean), math.exp(std))()))
|
93 |
+
self.sampler = lambda : sample_meta(make_trunc_norm)
|
94 |
+
elif self.distribution == "meta_choice":
|
95 |
+
if hasattr(self, 'choice_1_weight'):
|
96 |
+
self.hparams = {f'choice_{i}_weight': lambda: (None, getattr(self, f'choice_{i}_weight')) for i in range(1, len(self.choice_values))}
|
97 |
+
else:
|
98 |
+
self.hparams = {f"choice_{i}_weight": DifferentiableHyperparameter(distribution="uniform", min=-5.0
|
99 |
+
, max=6.0, **args_passed) for i in range(1, len(self.choice_values))}
|
100 |
+
def make_choice(**choices):
|
101 |
+
weights = torch.softmax(torch.tensor([1.0] + [choices[i] for i in choices], dtype=torch.float), 0) # create a tensor of weights
|
102 |
+
sample = torch.multinomial(weights, 1, replacement=True).numpy()[0]
|
103 |
+
return self.choice_values[sample]
|
104 |
+
|
105 |
+
self.sampler = lambda make_choice=make_choice: sample_meta(make_choice)
|
106 |
+
elif self.distribution == "meta_choice_mixed":
|
107 |
+
if hasattr(self, 'choice_1_weight'):
|
108 |
+
self.hparams = {f'choice_{i}_weight': lambda: (None, getattr(self, f'choice_{i}_weight')) for i in range(1, len(self.choice_values))}
|
109 |
+
else:
|
110 |
+
self.hparams = {f"choice_{i}_weight": DifferentiableHyperparameter(distribution="uniform", min=-5.0
|
111 |
+
, max=6.0, **args_passed) for i in range(1, len(self.choice_values))}
|
112 |
+
def make_choice(**choices):
|
113 |
+
weights = torch.softmax(torch.tensor([1.0] + [choices[i] for i in choices], dtype=torch.float), 0) # create a tensor of weights
|
114 |
+
def sample():
|
115 |
+
s = torch.multinomial(weights, 1, replacement=True).numpy()[0]
|
116 |
+
return self.choice_values[s]()
|
117 |
+
return lambda: sample
|
118 |
+
|
119 |
+
self.sampler = lambda make_choice=make_choice: sample_meta(make_choice)
|
120 |
+
else:
|
121 |
+
def return_two(x, min, max, mean, std):
|
122 |
+
# Returns (a hyperparameter value, and an indicator value passed to the model)
|
123 |
+
if mean is not None:
|
124 |
+
ind = (x-mean)/std#(2 * (x-min) / (max-min) - 1)
|
125 |
+
else:
|
126 |
+
ind = None
|
127 |
+
return ind, x # normalize indicator to [-1, 1]
|
128 |
+
# def sample_standard(sampler_f, embedding):
|
129 |
+
# s = torch.tensor([sampler_f()], device = self.device)
|
130 |
+
# return s, embedding(s)
|
131 |
+
self.sampler_f, self.sampler_min, self.sampler_max, self.sampler_mean, self.sampler_std = get_sampler()
|
132 |
+
self.sampler = lambda : return_two(self.sampler_f(), min=self.sampler_min, max=self.sampler_max
|
133 |
+
, mean=self.sampler_mean, std=self.sampler_std)
|
134 |
+
# self.embedding_layer = nn.Linear(1, self.embedding_dim, device=self.device)
|
135 |
+
# self.embed = lambda x : self.embedding_layer(
|
136 |
+
# (x - self.sampler_min) / (self.sampler_max - self.sampler_min))
|
137 |
+
#self.sampler = lambda : sample_standard(self.sampler_f, self.embedding)
|
138 |
+
|
139 |
+
|
140 |
+
def forward(self):
|
141 |
+
s, s_passed = self.sampler()
|
142 |
+
return s, s_passed
|
143 |
+
|
144 |
+
|
145 |
+
|
146 |
+
class DifferentiableHyperparameterList(nn.Module):
|
147 |
+
def __init__(self, hyperparameters, embedding_dim, device):
|
148 |
+
super().__init__()
|
149 |
+
|
150 |
+
self.device = device
|
151 |
+
hyperparameters = {k: v for (k, v) in hyperparameters.items() if v}
|
152 |
+
self.hyperparameters = nn.ModuleDict({hp: DifferentiableHyperparameter(embedding_dim = embedding_dim
|
153 |
+
, name = hp
|
154 |
+
, device = device, **hyperparameters[hp]) for hp in hyperparameters})
|
155 |
+
def get_hyperparameter_info(self):
|
156 |
+
sampled_hyperparameters_f, sampled_hyperparameters_keys = [], []
|
157 |
+
def append_hp(hp_key, hp_val):
|
158 |
+
sampled_hyperparameters_keys.append(hp_key)
|
159 |
+
# Function remaps hyperparameters from [-1, 1] range to true value
|
160 |
+
s_min, s_max, s_mean, s_std = hp_val.sampler_min, hp_val.sampler_max, hp_val.sampler_mean, hp_val.sampler_std
|
161 |
+
sampled_hyperparameters_f.append((lambda x: (x-s_mean)/s_std, lambda y : (y * s_std)+s_mean))
|
162 |
+
#sampled_hyperparameters_f.append(((lambda x: ((x - s_min) / (s_max - s_min) * (2) - 1)
|
163 |
+
# , (lambda y: ((y + 1) * (1 / 2) * (s_max - s_min) + s_min))))
|
164 |
+
for hp in self.hyperparameters:
|
165 |
+
hp_val = self.hyperparameters[hp]
|
166 |
+
if hasattr(hp_val, 'hparams'):
|
167 |
+
for hp_ in hp_val.hparams:
|
168 |
+
append_hp(f'{hp}_{hp_}', hp_val.hparams[hp_])
|
169 |
+
else:
|
170 |
+
append_hp(hp, hp_val)
|
171 |
+
|
172 |
+
|
173 |
+
return sampled_hyperparameters_keys, sampled_hyperparameters_f
|
174 |
+
|
175 |
+
def sample_parameter_object(self):
|
176 |
+
sampled_hyperparameters, s_passed = {}, {}
|
177 |
+
for hp in self.hyperparameters:
|
178 |
+
sampled_hyperparameters_, s_passed_ = self.hyperparameters[hp]()
|
179 |
+
s_passed[hp] = s_passed_
|
180 |
+
if isinstance(sampled_hyperparameters_, dict):
|
181 |
+
sampled_hyperparameters_ = {hp + '_' + str(key): val for key, val in sampled_hyperparameters_.items()}
|
182 |
+
sampled_hyperparameters.update(sampled_hyperparameters_)
|
183 |
+
else:
|
184 |
+
sampled_hyperparameters[hp] = sampled_hyperparameters_
|
185 |
+
|
186 |
+
# s_passed contains the values passed to the get_batch function
|
187 |
+
# sampled_hyperparameters contains the indicator of the sampled value, i.e. only number that describe the sampled object
|
188 |
+
return s_passed, sampled_hyperparameters#self.pack_parameter_object(sampled_embeddings)
|
189 |
+
|
190 |
+
class DifferentiablePrior(torch.nn.Module):
|
191 |
+
def __init__(self, get_batch, hyperparameters, differentiable_hyperparameters, args):
|
192 |
+
super(DifferentiablePrior, self).__init__()
|
193 |
+
|
194 |
+
self.h = hyperparameters
|
195 |
+
self.args = args
|
196 |
+
self.get_batch = get_batch
|
197 |
+
self.differentiable_hyperparameters = DifferentiableHyperparameterList(differentiable_hyperparameters
|
198 |
+
, embedding_dim=self.h['emsize']
|
199 |
+
, device=self.args['device'])
|
200 |
+
|
201 |
+
def forward(self):
|
202 |
+
# Sample hyperparameters
|
203 |
+
sampled_hyperparameters_passed, sampled_hyperparameters_indicators = self.differentiable_hyperparameters.sample_parameter_object()
|
204 |
+
|
205 |
+
hyperparameters = {**self.h, **sampled_hyperparameters_passed}
|
206 |
+
x, y, y_ = self.get_batch(hyperparameters=hyperparameters, **self.args)
|
207 |
+
|
208 |
+
return x, y, y_, sampled_hyperparameters_indicators
|
209 |
+
|
210 |
+
|
211 |
+
# TODO: Make this a class that keeps objects
|
212 |
+
@torch.no_grad()
|
213 |
+
def get_batch(batch_size, seq_len, num_features, get_batch
|
214 |
+
, device=default_device, differentiable_hyperparameters={}
|
215 |
+
, hyperparameters=None, batch_size_per_gp_sample=None, **kwargs):
|
216 |
+
batch_size_per_gp_sample = batch_size_per_gp_sample or (min(64, batch_size))
|
217 |
+
num_models = batch_size // batch_size_per_gp_sample
|
218 |
+
assert num_models * batch_size_per_gp_sample == batch_size, f'Batch size ({batch_size}) not divisible by batch_size_per_gp_sample ({batch_size_per_gp_sample})'
|
219 |
+
|
220 |
+
args = {'device': device, 'seq_len': seq_len, 'num_features': num_features, 'batch_size': batch_size_per_gp_sample}
|
221 |
+
|
222 |
+
models = [DifferentiablePrior(get_batch, hyperparameters, differentiable_hyperparameters, args) for _ in range(num_models)]
|
223 |
+
sample = sum([[model()] for model in models], [])
|
224 |
+
|
225 |
+
x, y, y_, hyperparameter_dict = zip(*sample)
|
226 |
+
|
227 |
+
if 'verbose' in hyperparameters and hyperparameters['verbose']:
|
228 |
+
print('Hparams', hyperparameter_dict[0].keys())
|
229 |
+
|
230 |
+
hyperparameter_matrix = []
|
231 |
+
for batch in hyperparameter_dict:
|
232 |
+
hyperparameter_matrix.append([batch[hp] for hp in batch])
|
233 |
+
|
234 |
+
transposed_hyperparameter_matrix = list(zip(*hyperparameter_matrix))
|
235 |
+
assert all([all([hp is None for hp in hp_]) or all([hp is not None for hp in hp_]) for hp_ in transposed_hyperparameter_matrix]), 'it should always be the case that when a hyper-parameter is None, once it is always None'
|
236 |
+
# we remove columns that are only None (i.e. not sampled)
|
237 |
+
hyperparameter_matrix = [[hp for hp in hp_ if hp is not None] for hp_ in hyperparameter_matrix]
|
238 |
+
if len(hyperparameter_matrix[0]) > 0:
|
239 |
+
packed_hyperparameters = torch.tensor(hyperparameter_matrix)
|
240 |
+
packed_hyperparameters = torch.repeat_interleave(packed_hyperparameters, repeats=batch_size_per_gp_sample, dim=0).detach()
|
241 |
+
else:
|
242 |
+
packed_hyperparameters = None
|
243 |
+
|
244 |
+
x, y, y_, packed_hyperparameters = (torch.cat(x, 1).detach()
|
245 |
+
, torch.cat(y, 1).detach()
|
246 |
+
, torch.cat(y_, 1).detach()
|
247 |
+
, packed_hyperparameters)#list(itertools.chain.from_iterable(itertools.repeat(x, batch_size_per_gp_sample) for x in packed_hyperparameters)))#torch.repeat_interleave(torch.stack(packed_hyperparameters, 0).detach(), repeats=batch_size_per_gp_sample, dim=0))
|
248 |
+
|
249 |
+
return x, y, y_, packed_hyperparameters
|
250 |
+
|
251 |
+
DataLoader = get_batch_to_dataloader(get_batch)
|
252 |
+
DataLoader.num_outputs = 1
|
253 |
+
#DataLoader.validate = lambda : 0
|
254 |
+
|
255 |
+
def draw_random_style(dl, device):
|
256 |
+
(hp_embedding, data, targets_), targets = next(iter(dl))
|
257 |
+
return hp_embedding.to(device)[0:1, :]
|
258 |
+
|
259 |
+
def merge_style_with_info(diff_hparams_keys, diff_hparams_f, style, transform=True):
|
260 |
+
params = dict(zip(diff_hparams_keys, zip(diff_hparams_f, style.detach().cpu().numpy().tolist()[0])))
|
261 |
+
def t(v):
|
262 |
+
if transform:
|
263 |
+
return v[0][1](v[1])
|
264 |
+
else:
|
265 |
+
return v[1]
|
266 |
+
return {k : t(v) for k, v in params.items()}
|
267 |
+
|
268 |
+
|
269 |
+
import ConfigSpace.hyperparameters as CSH
|
270 |
+
|
271 |
+
def replace_differentiable_distributions(config):
|
272 |
+
diff_config = config['differentiable_hyperparameters']
|
273 |
+
for name, diff_hp_dict in diff_config.items():
|
274 |
+
distribution = diff_hp_dict['distribution']
|
275 |
+
if distribution == 'uniform':
|
276 |
+
diff_hp_dict['sample'] = CSH.UniformFloatHyperparameter(name, diff_hp_dict['min'], diff_hp_dict['max'])
|
277 |
+
elif distribution == 'meta_beta':
|
278 |
+
diff_hp_dict['k'] = CSH.UniformFloatHyperparameter(name+'_k', diff_hp_dict['min'], diff_hp_dict['max'])
|
279 |
+
diff_hp_dict['b'] = CSH.UniformFloatHyperparameter(name+'_b', diff_hp_dict['min'], diff_hp_dict['max'])
|
280 |
+
elif distribution == 'meta_choice':
|
281 |
+
for i in range(1, len(diff_hp_dict['choice_values'])):
|
282 |
+
diff_hp_dict[f'choice_{i}_weight'] = CSH.UniformFloatHyperparameter(name+f'choice_{i}_weight', -5.0, 6.0)
|
283 |
+
elif distribution == 'meta_choice_mixed':
|
284 |
+
for i in range(1, len(diff_hp_dict['choice_values'])):
|
285 |
+
diff_hp_dict[f'choice_{i}_weight'] = CSH.UniformFloatHyperparameter(name+f'choice_{i}_weight', -5.0, 6.0)
|
286 |
+
elif distribution == 'meta_trunc_norm_log_scaled':
|
287 |
+
diff_hp_dict['log_mean'] = CSH.UniformFloatHyperparameter(name+'_log_mean', math.log(diff_hp_dict['min_mean']), math.log(diff_hp_dict['max_mean']))
|
288 |
+
min_std = diff_hp_dict['min_std'] if 'min_std' in diff_hp_dict else 0.001
|
289 |
+
max_std = diff_hp_dict['max_std'] if 'max_std' in diff_hp_dict else diff_hp_dict['max_mean']
|
290 |
+
diff_hp_dict['log_std'] = CSH.UniformFloatHyperparameter(name+'_log_std', math.log(min_std), math.log(max_std))
|
291 |
+
else:
|
292 |
+
raise ValueError(f'Unknown distribution {distribution}')
|
293 |
+
|
TabPFN/priors/fast_gp.py
ADDED
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import time
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from torch import nn
|
5 |
+
import gpytorch
|
6 |
+
|
7 |
+
from .utils import get_batch_to_dataloader
|
8 |
+
from utils import default_device
|
9 |
+
|
10 |
+
|
11 |
+
# We will use the simplest form of GP model, exact inference
|
12 |
+
class ExactGPModel(gpytorch.models.ExactGP):
|
13 |
+
def __init__(self, train_x, train_y, likelihood):
|
14 |
+
super(ExactGPModel, self).__init__(train_x, train_y, likelihood)
|
15 |
+
self.mean_module = gpytorch.means.ConstantMean()
|
16 |
+
self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())
|
17 |
+
|
18 |
+
def forward(self, x):
|
19 |
+
mean_x = self.mean_module(x)
|
20 |
+
covar_x = self.covar_module(x)
|
21 |
+
return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)
|
22 |
+
|
23 |
+
|
24 |
+
def get_model(x, y, hyperparameters):
|
25 |
+
likelihood = gpytorch.likelihoods.GaussianLikelihood(noise_constraint=gpytorch.constraints.GreaterThan(1.e-9))
|
26 |
+
model = ExactGPModel(x, y, likelihood)
|
27 |
+
model.likelihood.noise = torch.ones_like(model.likelihood.noise) * hyperparameters["noise"]
|
28 |
+
model.covar_module.outputscale = torch.ones_like(model.covar_module.outputscale) * hyperparameters["outputscale"]
|
29 |
+
model.covar_module.base_kernel.lengthscale = torch.ones_like(model.covar_module.base_kernel.lengthscale) * \
|
30 |
+
hyperparameters["lengthscale"]
|
31 |
+
return model, likelihood
|
32 |
+
|
33 |
+
|
34 |
+
@torch.no_grad()
|
35 |
+
def get_batch(batch_size, seq_len, num_features, device=default_device, hyperparameters=None,
|
36 |
+
equidistant_x=False, fix_x=None, **kwargs):
|
37 |
+
if isinstance(hyperparameters, (tuple, list)):
|
38 |
+
hyperparameters = {"noise": hyperparameters[0]
|
39 |
+
, "outputscale": hyperparameters[1]
|
40 |
+
, "lengthscale": hyperparameters[2]
|
41 |
+
, "is_binary_classification": hyperparameters[3]
|
42 |
+
# , "num_features_used": hyperparameters[4]
|
43 |
+
, "normalize_by_used_features": hyperparameters[5]
|
44 |
+
, "order_y": hyperparameters[6]
|
45 |
+
, "sampling": hyperparameters[7]
|
46 |
+
}
|
47 |
+
elif hyperparameters is None:
|
48 |
+
hyperparameters = {"noise": .1, "outputscale": .1, "lengthscale": .1}
|
49 |
+
|
50 |
+
if 'verbose' in hyperparameters and hyperparameters['verbose']:
|
51 |
+
print({"noise": hyperparameters['noise'], "outputscale": hyperparameters['outputscale']
|
52 |
+
, "lengthscale": hyperparameters['lengthscale'], 'batch_size': batch_size, 'sampling': hyperparameters['sampling']})
|
53 |
+
|
54 |
+
# hyperparameters = {k: hyperparameters[k]() if callable(hyperparameters[k]) else hyperparameters[k] for k in
|
55 |
+
# hyperparameters.keys()}
|
56 |
+
assert not (equidistant_x and (fix_x is not None))
|
57 |
+
|
58 |
+
with gpytorch.settings.fast_computations(*hyperparameters.get('fast_computations', (True, True, True))):
|
59 |
+
if equidistant_x:
|
60 |
+
assert num_features == 1
|
61 |
+
x = torch.linspace(0, 1., seq_len).unsqueeze(0).repeat(batch_size, 1).unsqueeze(-1)
|
62 |
+
elif fix_x is not None:
|
63 |
+
assert fix_x.shape == (seq_len, num_features)
|
64 |
+
x = fix_x.unsqueeze(0).repeat(batch_size, 1, 1).to(device)
|
65 |
+
else:
|
66 |
+
if hyperparameters.get('sampling','uniform') == 'uniform':
|
67 |
+
x = torch.rand(batch_size, seq_len, num_features, device=device)
|
68 |
+
else:
|
69 |
+
x = torch.randn(batch_size, seq_len, num_features, device=device)
|
70 |
+
model, likelihood = get_model(x, torch.Tensor(), hyperparameters)
|
71 |
+
model.to(device)
|
72 |
+
# trained_model = ExactGPModel(train_x, train_y, likelihood).cuda()
|
73 |
+
# trained_model.eval()
|
74 |
+
is_fitted = False
|
75 |
+
while not is_fitted:
|
76 |
+
try:
|
77 |
+
with gpytorch.settings.prior_mode(True):
|
78 |
+
model, likelihood = get_model(x, torch.Tensor(), hyperparameters)
|
79 |
+
model.to(device)
|
80 |
+
|
81 |
+
d = model(x)
|
82 |
+
d = likelihood(d)
|
83 |
+
sample = d.sample().transpose(0, 1)
|
84 |
+
is_fitted = True
|
85 |
+
except RuntimeError: # This can happen when torch.linalg.eigh fails. Restart with new init resolves this.
|
86 |
+
print('GP Fitting unsuccessful, retrying.. ')
|
87 |
+
print(x)
|
88 |
+
print(hyperparameters)
|
89 |
+
|
90 |
+
if bool(torch.any(torch.isnan(x)).detach().cpu().numpy()):
|
91 |
+
print({"noise": hyperparameters['noise'], "outputscale": hyperparameters['outputscale']
|
92 |
+
, "lengthscale": hyperparameters['lengthscale'], 'batch_size': batch_size})
|
93 |
+
|
94 |
+
# TODO: Multi output
|
95 |
+
return x.transpose(0, 1), sample, sample # x.shape = (T,B,H)
|
96 |
+
|
97 |
+
DataLoader = get_batch_to_dataloader(get_batch)
|
98 |
+
DataLoader.num_outputs = 1
|
99 |
+
|
100 |
+
def get_model_on_device(x,y,hyperparameters,device):
|
101 |
+
model, likelihood = get_model(x, y, hyperparameters)
|
102 |
+
model.to(device)
|
103 |
+
return model, likelihood
|
104 |
+
|
105 |
+
|
106 |
+
@torch.no_grad()
|
107 |
+
def evaluate(x, y, y_non_noisy, use_mse=False, hyperparameters={}, get_model_on_device=get_model_on_device, device=default_device, step_size=1, start_pos=0):
|
108 |
+
start_time = time.time()
|
109 |
+
losses_after_t = [.0] if start_pos == 0 else []
|
110 |
+
all_losses_after_t = []
|
111 |
+
|
112 |
+
with gpytorch.settings.fast_computations(*hyperparameters.get('fast_computations',(True,True,True))), gpytorch.settings.fast_pred_var(False):
|
113 |
+
for t in range(max(start_pos, 1), len(x), step_size):
|
114 |
+
loss_sum = 0.
|
115 |
+
model, likelihood = get_model_on_device(x[:t].transpose(0, 1), y[:t].transpose(0, 1), hyperparameters, device)
|
116 |
+
|
117 |
+
|
118 |
+
model.eval()
|
119 |
+
# print([t.shape for t in model.train_inputs])
|
120 |
+
# print(x[:t].transpose(0,1).shape, x[t].unsqueeze(1).shape, y[:t].transpose(0,1).shape)
|
121 |
+
f = model(x[t].unsqueeze(1))
|
122 |
+
l = likelihood(f)
|
123 |
+
means = l.mean.squeeze()
|
124 |
+
varis = l.covariance_matrix.squeeze()
|
125 |
+
# print(l.variance.squeeze(), l.mean.squeeze(), y[t])
|
126 |
+
|
127 |
+
assert len(means.shape) == len(varis.shape) == 1
|
128 |
+
assert len(means) == len(varis) == x.shape[1]
|
129 |
+
|
130 |
+
if use_mse:
|
131 |
+
c = nn.MSELoss(reduction='none')
|
132 |
+
ls = c(means, y[t])
|
133 |
+
else:
|
134 |
+
ls = -l.log_prob(y[t].unsqueeze(1))
|
135 |
+
|
136 |
+
losses_after_t.append(ls.mean())
|
137 |
+
all_losses_after_t.append(ls.flatten())
|
138 |
+
return torch.stack(all_losses_after_t).to('cpu'), torch.tensor(losses_after_t).to('cpu'), time.time() - start_time
|
139 |
+
|
140 |
+
if __name__ == '__main__':
|
141 |
+
hps = (.1,.1,.1)
|
142 |
+
for redo_idx in range(1):
|
143 |
+
print(
|
144 |
+
evaluate(*get_batch(1000, 10, hyperparameters=hps, num_features=10), use_mse=False, hyperparameters=hps))
|
TabPFN/priors/flexible_categorical.py
ADDED
@@ -0,0 +1,240 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import time
|
2 |
+
import random
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from torch import nn
|
6 |
+
|
7 |
+
from .utils import get_batch_to_dataloader
|
8 |
+
from utils import normalize_data, nan_handling_missing_for_unknown_reason_value, nan_handling_missing_for_no_reason_value, nan_handling_missing_for_a_reason_value, to_ranking_low_mem, remove_outliers
|
9 |
+
from .utils import normalize_by_used_features_f, randomize_classes, CategoricalActivation
|
10 |
+
from .utils import uniform_int_sampler_f
|
11 |
+
|
12 |
+
time_it = False
|
13 |
+
|
14 |
+
class BalancedBinarize(nn.Module):
|
15 |
+
def __init__(self):
|
16 |
+
super().__init__()
|
17 |
+
|
18 |
+
def forward(self, x):
|
19 |
+
return (x > torch.median(x)).float()
|
20 |
+
|
21 |
+
def class_sampler_f(min_, max_):
|
22 |
+
def s():
|
23 |
+
if random.random() > 0.5:
|
24 |
+
return uniform_int_sampler_f(min_, max_)()
|
25 |
+
return 2
|
26 |
+
return s
|
27 |
+
|
28 |
+
class MulticlassRank(nn.Module):
|
29 |
+
def __init__(self, num_classes, ordered_p=0.5):
|
30 |
+
super().__init__()
|
31 |
+
self.num_classes = class_sampler_f(2, num_classes)()
|
32 |
+
self.ordered_p = ordered_p
|
33 |
+
|
34 |
+
def forward(self, x):
|
35 |
+
# x has shape (T,B,H)
|
36 |
+
|
37 |
+
# CAUTION: This samples the same idx in sequence for each class boundary in a batch
|
38 |
+
class_boundaries = torch.randint(0, x.shape[0], (self.num_classes - 1,))
|
39 |
+
class_boundaries = x[class_boundaries].unsqueeze(1)
|
40 |
+
|
41 |
+
d = (x > class_boundaries).sum(axis=0)
|
42 |
+
|
43 |
+
randomized_classes = torch.rand((d.shape[1], )) > self.ordered_p
|
44 |
+
d[:, randomized_classes] = randomize_classes(d[:, randomized_classes], self.num_classes)
|
45 |
+
reverse_classes = torch.rand((d.shape[1],)) > 0.5
|
46 |
+
d[:, reverse_classes] = self.num_classes - 1 - d[:, reverse_classes]
|
47 |
+
return d
|
48 |
+
|
49 |
+
class MulticlassValue(nn.Module):
|
50 |
+
def __init__(self, num_classes, ordered_p=0.5):
|
51 |
+
super().__init__()
|
52 |
+
self.num_classes = class_sampler_f(2, num_classes)()
|
53 |
+
self.classes = nn.Parameter(torch.randn(num_classes-1), requires_grad=False)
|
54 |
+
self.ordered_p = ordered_p
|
55 |
+
|
56 |
+
def forward(self, x):
|
57 |
+
# x has shape (T,B,H)
|
58 |
+
d = (x > (self.classes.unsqueeze(-1).unsqueeze(-1))).sum(axis=0)
|
59 |
+
|
60 |
+
randomized_classes = torch.rand((d.shape[1],)) > self.ordered_p
|
61 |
+
d[:, randomized_classes] = randomize_classes(d[:, randomized_classes], self.num_classes)
|
62 |
+
reverse_classes = torch.rand((d.shape[1],)) > 0.5
|
63 |
+
d[:, reverse_classes] = self.num_classes - 1 - d[:, reverse_classes]
|
64 |
+
return d
|
65 |
+
|
66 |
+
class MulticlassMultiNode(nn.Module):
|
67 |
+
def __init__(self, num_classes, ordered_p=0.5):
|
68 |
+
super().__init__()
|
69 |
+
self.num_classes = class_sampler_f(2, num_classes)()
|
70 |
+
self.classes = nn.Parameter(torch.randn(num_classes-1), requires_grad=False)
|
71 |
+
self.alt_multi_class = MulticlassValue(num_classes, ordered_p)
|
72 |
+
|
73 |
+
def forward(self, x):
|
74 |
+
# x has shape T, B, H
|
75 |
+
if len(x.shape) == 2:
|
76 |
+
return self.alt_multi_class(x)
|
77 |
+
T = 3
|
78 |
+
x[torch.isnan(x)] = 0.00001
|
79 |
+
d = torch.multinomial(torch.pow(0.00001+torch.sigmoid(x[:, :, 0:self.num_classes]).reshape(-1, self.num_classes), T), 1, replacement=True).reshape(x.shape[0], x.shape[1]).float()
|
80 |
+
return d
|
81 |
+
|
82 |
+
|
83 |
+
class FlexibleCategorical(torch.nn.Module):
|
84 |
+
def __init__(self, get_batch, hyperparameters, args):
|
85 |
+
super(FlexibleCategorical, self).__init__()
|
86 |
+
|
87 |
+
self.h = {k: hyperparameters[k]() if callable(hyperparameters[k]) else hyperparameters[k] for k in
|
88 |
+
hyperparameters.keys()}
|
89 |
+
self.args = args
|
90 |
+
self.args_passed = {**self.args}
|
91 |
+
self.args_passed.update({'num_features': self.h['num_features_used']})
|
92 |
+
self.get_batch = get_batch
|
93 |
+
|
94 |
+
if self.h['num_classes'] > 1 and not self.h['balanced']:
|
95 |
+
if self.h['multiclass_type'] == 'rank':
|
96 |
+
self.class_assigner = MulticlassRank(self.h['num_classes']
|
97 |
+
, ordered_p=self.h['output_multiclass_ordered_p']
|
98 |
+
)
|
99 |
+
elif self.h['multiclass_type'] == 'value':
|
100 |
+
self.class_assigner = MulticlassValue(self.h['num_classes']
|
101 |
+
, ordered_p=self.h['output_multiclass_ordered_p']
|
102 |
+
)
|
103 |
+
elif self.h['multiclass_type'] == 'multi_node':
|
104 |
+
self.class_assigner = MulticlassMultiNode(self.h['num_classes'])
|
105 |
+
else:
|
106 |
+
raise ValueError("Unknow Multiclass type")
|
107 |
+
elif self.h['num_classes'] == 2 and self.h['balanced']:
|
108 |
+
self.class_assigner = BalancedBinarize()
|
109 |
+
elif self.h['num_classes'] > 2 and self.h['balanced']:
|
110 |
+
raise NotImplementedError("Balanced multiclass training is not possible")
|
111 |
+
else:
|
112 |
+
self.class_assigner = lambda x:x # Regression
|
113 |
+
|
114 |
+
def drop_for_reason(self, x, v):
|
115 |
+
nan_prob_sampler = CategoricalActivation(ordered_p=0.0
|
116 |
+
, categorical_p=1.0
|
117 |
+
, keep_activation_size=False,
|
118 |
+
num_classes_sampler=lambda: 20)
|
119 |
+
d = nan_prob_sampler(x)
|
120 |
+
# TODO: Make a different ordering for each activation
|
121 |
+
x[d < torch.rand((1,), device=x.device) * 20 * self.h['nan_prob_no_reason'] * random.random()] = v
|
122 |
+
return x
|
123 |
+
|
124 |
+
def drop_for_no_reason(self, x, v):
|
125 |
+
x[torch.rand(x.shape, device=self.args['device']) < self.h['nan_prob_no_reason']] = v
|
126 |
+
return x
|
127 |
+
|
128 |
+
def forward(self, batch_size):
|
129 |
+
start = time.time()
|
130 |
+
x, y, y_ = self.get_batch(hyperparameters=self.h, **self.args_passed)
|
131 |
+
if time_it:
|
132 |
+
print('Flex Forward Block 1', round(time.time() - start, 3))
|
133 |
+
|
134 |
+
start = time.time()
|
135 |
+
|
136 |
+
if self.h['nan_prob_no_reason']+self.h['nan_prob_a_reason']+self.h['nan_prob_unknown_reason'] > 0 and random.random() > 0.5: # Only one out of two datasets should have nans
|
137 |
+
if self.h['nan_prob_no_reason'] > 0 and random.random() > 0.5: # Missing for no reason
|
138 |
+
x = self.drop_for_no_reason(x, nan_handling_missing_for_no_reason_value(self.h['set_value_to_nan']))
|
139 |
+
|
140 |
+
if self.h['nan_prob_a_reason'] > 0 and random.random() > 0.5: # Missing for a reason
|
141 |
+
x = self.drop_for_reason(x, nan_handling_missing_for_a_reason_value(self.h['set_value_to_nan']))
|
142 |
+
|
143 |
+
if self.h['nan_prob_unknown_reason'] > 0: # Missing for unknown reason and random.random() > 0.5
|
144 |
+
if random.random() < self.h['nan_prob_unknown_reason_reason_prior']:
|
145 |
+
x = self.drop_for_no_reason(x, nan_handling_missing_for_unknown_reason_value(self.h['set_value_to_nan']))
|
146 |
+
else:
|
147 |
+
x = self.drop_for_reason(x, nan_handling_missing_for_unknown_reason_value(self.h['set_value_to_nan']))
|
148 |
+
|
149 |
+
# Categorical features
|
150 |
+
if 'categorical_feature_p' in self.h and random.random() > 1 - self.h['categorical_feature_p']:
|
151 |
+
p = random.random()
|
152 |
+
for col in range(x.shape[2]):
|
153 |
+
m = MulticlassRank(10, ordered_p=0.3)
|
154 |
+
if random.random() > p:
|
155 |
+
x[:, :, col] = m(x[:, :, col])
|
156 |
+
|
157 |
+
if time_it:
|
158 |
+
print('Flex Forward Block 2', round(time.time() - start, 3))
|
159 |
+
start = time.time()
|
160 |
+
|
161 |
+
if self.h['normalize_to_ranking']:
|
162 |
+
x = to_ranking_low_mem(x)
|
163 |
+
else:
|
164 |
+
x = remove_outliers(x)
|
165 |
+
x, y = normalize_data(x), normalize_data(y)
|
166 |
+
|
167 |
+
if time_it:
|
168 |
+
print('Flex Forward Block 3', round(time.time() - start, 3))
|
169 |
+
start = time.time()
|
170 |
+
|
171 |
+
# Cast to classification if enabled
|
172 |
+
y = self.class_assigner(y).float()
|
173 |
+
|
174 |
+
if time_it:
|
175 |
+
print('Flex Forward Block 4', round(time.time() - start, 3))
|
176 |
+
start = time.time()
|
177 |
+
if self.h['normalize_by_used_features']:
|
178 |
+
x = normalize_by_used_features_f(x, self.h['num_features_used'], self.args['num_features'], normalize_with_sqrt=self.h.get('normalize_with_sqrt',False))
|
179 |
+
if time_it:
|
180 |
+
print('Flex Forward Block 5', round(time.time() - start, 3))
|
181 |
+
|
182 |
+
start = time.time()
|
183 |
+
# Append empty features if enabled
|
184 |
+
x = torch.cat(
|
185 |
+
[x, torch.zeros((x.shape[0], x.shape[1], self.args['num_features'] - self.h['num_features_used']),
|
186 |
+
device=self.args['device'])], -1)
|
187 |
+
if time_it:
|
188 |
+
print('Flex Forward Block 6', round(time.time() - start, 3))
|
189 |
+
|
190 |
+
return x, y, y # x.shape = (T,B,H)
|
191 |
+
|
192 |
+
import torch.cuda as cutorch
|
193 |
+
|
194 |
+
@torch.no_grad()
|
195 |
+
def get_batch(batch_size, seq_len, num_features, get_batch, device, hyperparameters=None, batch_size_per_gp_sample=None, **kwargs):
|
196 |
+
batch_size_per_gp_sample = batch_size_per_gp_sample or (min(32, batch_size))
|
197 |
+
num_models = batch_size // batch_size_per_gp_sample
|
198 |
+
assert num_models > 0, f'Batch size ({batch_size}) is too small for batch_size_per_gp_sample ({batch_size_per_gp_sample})'
|
199 |
+
assert num_models * batch_size_per_gp_sample == batch_size, f'Batch size ({batch_size}) not divisible by batch_size_per_gp_sample ({batch_size_per_gp_sample})'
|
200 |
+
|
201 |
+
# Sample one seq_len for entire batch
|
202 |
+
seq_len = hyperparameters['seq_len_used']() if callable(hyperparameters['seq_len_used']) else seq_len
|
203 |
+
|
204 |
+
args = {'device': device, 'seq_len': seq_len, 'num_features': num_features, 'batch_size': batch_size_per_gp_sample}
|
205 |
+
|
206 |
+
models = [FlexibleCategorical(get_batch, hyperparameters, args).to(device) for _ in range(num_models)]
|
207 |
+
|
208 |
+
start = time.time()
|
209 |
+
sample = sum([[model(batch_size=batch_size_per_gp_sample)] for model in models], [])
|
210 |
+
#print('sample', time.time() - start)
|
211 |
+
|
212 |
+
x, y, y_ = zip(*sample)
|
213 |
+
x, y, y_ = torch.cat(x, 1).detach(), torch.cat(y, 1).detach(), torch.cat(y_, 1).detach()
|
214 |
+
|
215 |
+
# # TODO: Reintegrate this code (Doesn't work on batch dim), could be applied to each batch sample individually
|
216 |
+
# if hyperparameters['is_binary_classification'] and hyperparameters['order_y']:
|
217 |
+
# x, y = order_by_y(x, y)
|
218 |
+
|
219 |
+
return x, y, y_
|
220 |
+
|
221 |
+
# num_features_used = num_features_used_sampler()
|
222 |
+
# prior_outputscale = prior_outputscale_sampler()
|
223 |
+
# prior_lengthscale = prior_lengthscale_sampler()
|
224 |
+
#
|
225 |
+
# x, sample = normalize_data(x), normalize_data(sample)
|
226 |
+
#
|
227 |
+
# if is_binary_classification:
|
228 |
+
# sample = (sample > torch.median(sample, dim=0)[0]).float()
|
229 |
+
#
|
230 |
+
# if normalize_by_used_features:
|
231 |
+
# x = normalize_by_used_features_f(x, num_features_used, num_features)
|
232 |
+
#
|
233 |
+
# # # if is_binary_classification and order_y:
|
234 |
+
# # # x, sample = order_by_y(x, sample)
|
235 |
+
# #
|
236 |
+
# # Append empty features if enabled
|
237 |
+
# x = torch.cat([x, torch.zeros((x.shape[0], x.shape[1], num_features - num_features_used), device=device)], -1)
|
238 |
+
|
239 |
+
DataLoader = get_batch_to_dataloader(get_batch)
|
240 |
+
DataLoader.num_outputs = 1
|
TabPFN/priors/mlp.py
ADDED
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
import math
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from torch import nn
|
6 |
+
import numpy as np
|
7 |
+
|
8 |
+
from utils import default_device
|
9 |
+
from .utils import get_batch_to_dataloader
|
10 |
+
|
11 |
+
class GaussianNoise(nn.Module):
|
12 |
+
def __init__(self, std, device):
|
13 |
+
super().__init__()
|
14 |
+
self.std = std
|
15 |
+
self.device=device
|
16 |
+
|
17 |
+
def forward(self, x):
|
18 |
+
return x + torch.normal(torch.zeros_like(x), self.std)
|
19 |
+
|
20 |
+
|
21 |
+
def causes_sampler_f(num_causes):
|
22 |
+
means = np.random.normal(0, 1, (num_causes))
|
23 |
+
std = np.abs(np.random.normal(0, 1, (num_causes)) * means)
|
24 |
+
return means, std
|
25 |
+
|
26 |
+
def get_batch(batch_size, seq_len, num_features, hyperparameters, device=default_device, num_outputs=1, sampling='normal', **kwargs):
|
27 |
+
if ('mix_activations' in hyperparameters) and hyperparameters['mix_activations']:
|
28 |
+
s = hyperparameters['prior_mlp_activations']()
|
29 |
+
hyperparameters['prior_mlp_activations'] = lambda : s
|
30 |
+
|
31 |
+
class MLP(torch.nn.Module):
|
32 |
+
def __init__(self, hyperparameters):
|
33 |
+
super(MLP, self).__init__()
|
34 |
+
|
35 |
+
with torch.no_grad():
|
36 |
+
|
37 |
+
for key in hyperparameters:
|
38 |
+
setattr(self, key, hyperparameters[key])
|
39 |
+
|
40 |
+
assert (self.num_layers >= 2)
|
41 |
+
|
42 |
+
if 'verbose' in hyperparameters and self.verbose:
|
43 |
+
print({k : hyperparameters[k] for k in ['is_causal', 'num_causes', 'prior_mlp_hidden_dim'
|
44 |
+
, 'num_layers', 'noise_std', 'y_is_effect', 'pre_sample_weights', 'prior_mlp_dropout_prob'
|
45 |
+
, 'pre_sample_causes']})
|
46 |
+
|
47 |
+
if self.is_causal:
|
48 |
+
self.prior_mlp_hidden_dim = max(self.prior_mlp_hidden_dim, num_outputs + 2 * num_features)
|
49 |
+
else:
|
50 |
+
self.num_causes = num_features
|
51 |
+
|
52 |
+
# This means that the mean and standard deviation of each cause is determined in advance
|
53 |
+
if self.pre_sample_causes:
|
54 |
+
self.causes_mean, self.causes_std = causes_sampler_f(self.num_causes)
|
55 |
+
self.causes_mean = torch.tensor(self.causes_mean, device=device).unsqueeze(0).unsqueeze(0).tile(
|
56 |
+
(seq_len, 1, 1))
|
57 |
+
self.causes_std = torch.tensor(self.causes_std, device=device).unsqueeze(0).unsqueeze(0).tile(
|
58 |
+
(seq_len, 1, 1))
|
59 |
+
|
60 |
+
def generate_module(layer_idx, out_dim):
|
61 |
+
# Determine std of each noise term in initialization, so that is shared in runs
|
62 |
+
# torch.abs(torch.normal(torch.zeros((out_dim)), self.noise_std)) - Change std for each dimension?
|
63 |
+
noise = (GaussianNoise(torch.abs(torch.normal(torch.zeros(size=(1, out_dim), device=device), float(self.noise_std))), device=device)
|
64 |
+
if self.pre_sample_weights else GaussianNoise(float(self.noise_std), device=device))
|
65 |
+
return [
|
66 |
+
nn.Sequential(*[self.prior_mlp_activations()
|
67 |
+
, nn.Linear(self.prior_mlp_hidden_dim, out_dim)
|
68 |
+
, noise])
|
69 |
+
]
|
70 |
+
|
71 |
+
self.layers = [nn.Linear(self.num_causes, self.prior_mlp_hidden_dim, device=device)]
|
72 |
+
self.layers += [module for layer_idx in range(self.num_layers-1) for module in generate_module(layer_idx, self.prior_mlp_hidden_dim)]
|
73 |
+
if not self.is_causal:
|
74 |
+
self.layers += generate_module(-1, num_outputs)
|
75 |
+
self.layers = nn.Sequential(*self.layers)
|
76 |
+
|
77 |
+
# Initialize Model parameters
|
78 |
+
for i, (n, p) in enumerate(self.layers.named_parameters()):
|
79 |
+
if self.block_wise_dropout:
|
80 |
+
if len(p.shape) == 2: # Only apply to weight matrices and not bias
|
81 |
+
nn.init.zeros_(p)
|
82 |
+
# TODO: N blocks should be a setting
|
83 |
+
n_blocks = random.randint(1, math.ceil(math.sqrt(min(p.shape[0], p.shape[1]))))
|
84 |
+
w, h = p.shape[0] // n_blocks, p.shape[1] // n_blocks
|
85 |
+
keep_prob = (n_blocks*w*h) / p.numel()
|
86 |
+
for block in range(0, n_blocks):
|
87 |
+
nn.init.normal_(p[w * block: w * (block+1), h * block: h * (block+1)], std=self.init_std / keep_prob**(1/2))
|
88 |
+
else:
|
89 |
+
if len(p.shape) == 2: # Only apply to weight matrices and not bias
|
90 |
+
dropout_prob = self.prior_mlp_dropout_prob if i > 0 else 0.0 # Don't apply dropout in first layer
|
91 |
+
dropout_prob = min(dropout_prob, 0.99)
|
92 |
+
nn.init.normal_(p, std=self.init_std / (1. - dropout_prob)**(1/2))
|
93 |
+
p *= torch.bernoulli(torch.zeros_like(p) + 1. - dropout_prob)
|
94 |
+
|
95 |
+
def forward(self):
|
96 |
+
def sample_normal():
|
97 |
+
if self.pre_sample_causes:
|
98 |
+
causes = torch.normal(self.causes_mean, self.causes_std.abs()).float()
|
99 |
+
else:
|
100 |
+
causes = torch.normal(0., 1., (seq_len, 1, self.num_causes), device=device).float()
|
101 |
+
return causes
|
102 |
+
|
103 |
+
if self.sampling == 'normal':
|
104 |
+
causes = sample_normal()
|
105 |
+
elif self.sampling == 'mixed':
|
106 |
+
zipf_p, multi_p, normal_p = random.random() * 0.66, random.random() * 0.66, random.random() * 0.66
|
107 |
+
def sample_cause(n):
|
108 |
+
if random.random() > normal_p:
|
109 |
+
if self.pre_sample_causes:
|
110 |
+
return torch.normal(self.causes_mean[:, :, n], self.causes_std[:, :, n].abs()).float()
|
111 |
+
else:
|
112 |
+
return torch.normal(0., 1., (seq_len, 1), device=device).float()
|
113 |
+
elif random.random() > multi_p:
|
114 |
+
x = torch.multinomial(torch.rand((random.randint(2, 10))), seq_len, replacement=True).to(device).unsqueeze(-1).float()
|
115 |
+
x = (x - torch.mean(x)) / torch.std(x)
|
116 |
+
return x
|
117 |
+
else:
|
118 |
+
x = torch.minimum(torch.tensor(np.random.zipf(2.0 + random.random() * 2, size=(seq_len)),
|
119 |
+
device=device).unsqueeze(-1).float(), torch.tensor(10.0, device=device))
|
120 |
+
return x - torch.mean(x)
|
121 |
+
causes = torch.cat([sample_cause(n).unsqueeze(-1) for n in range(self.num_causes)], -1)
|
122 |
+
elif self.sampling == 'uniform':
|
123 |
+
causes = torch.rand((seq_len, 1, self.num_causes), device=device)
|
124 |
+
else:
|
125 |
+
raise ValueError(f'Sampling is set to invalid setting: {sampling}.')
|
126 |
+
|
127 |
+
outputs = [causes]
|
128 |
+
for layer in self.layers:
|
129 |
+
outputs.append(layer(outputs[-1]))
|
130 |
+
outputs = outputs[2:]
|
131 |
+
|
132 |
+
if self.is_causal:
|
133 |
+
## Sample nodes from graph if model is causal
|
134 |
+
outputs_flat = torch.cat(outputs, -1)
|
135 |
+
|
136 |
+
if self.in_clique:
|
137 |
+
random_perm = random.randint(0, outputs_flat.shape[-1] - num_outputs - num_features) + torch.randperm(num_outputs + num_features, device=device)
|
138 |
+
else:
|
139 |
+
random_perm = torch.randperm(outputs_flat.shape[-1]-1, device=device)
|
140 |
+
|
141 |
+
random_idx_y = list(range(-num_outputs, -0)) if self.y_is_effect else random_perm[0:num_outputs]
|
142 |
+
random_idx = random_perm[num_outputs:num_outputs + num_features]
|
143 |
+
|
144 |
+
if self.sort_features:
|
145 |
+
random_idx, _ = torch.sort(random_idx)
|
146 |
+
y = outputs_flat[:, :, random_idx_y]
|
147 |
+
|
148 |
+
x = outputs_flat[:, :, random_idx]
|
149 |
+
else:
|
150 |
+
y = outputs[-1][:, :, :]
|
151 |
+
x = causes
|
152 |
+
|
153 |
+
if bool(torch.any(torch.isnan(x)).detach().cpu().numpy()) or bool(torch.any(torch.isnan(y)).detach().cpu().numpy()):
|
154 |
+
x[:] = 0.0
|
155 |
+
y[:] = 1.0
|
156 |
+
|
157 |
+
return x, y
|
158 |
+
|
159 |
+
model = MLP(hyperparameters).to(device)
|
160 |
+
|
161 |
+
sample = sum([[model()] for _ in range(0, batch_size)], [])
|
162 |
+
|
163 |
+
x, y = zip(*sample)
|
164 |
+
y = torch.cat(y, 1).detach().squeeze(2)
|
165 |
+
x = torch.cat(x, 1).detach()
|
166 |
+
x = x[..., torch.randperm(x.shape[-1])]
|
167 |
+
|
168 |
+
return x, y, y
|
169 |
+
|
170 |
+
|
171 |
+
DataLoader = get_batch_to_dataloader(get_batch)
|
172 |
+
DataLoader.num_outputs = 1
|
173 |
+
|
TabPFN/priors/prior.py
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from torch.utils.data import DataLoader
|
2 |
+
|
3 |
+
|
4 |
+
class PriorDataLoader(DataLoader):
|
5 |
+
pass
|
6 |
+
# init accepts num_steps as first argument
|
7 |
+
|
8 |
+
# has two attributes set on class or object level:
|
9 |
+
# num_features: int and
|
10 |
+
# num_outputs: int
|
11 |
+
# fuse_x_y: bool
|
12 |
+
# Optional: validate function that accepts a transformer model
|
TabPFN/priors/prior_bag.py
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
from .utils import get_batch_to_dataloader
|
4 |
+
from utils import default_device
|
5 |
+
|
6 |
+
def get_batch(batch_size, seq_len, num_features, device=default_device
|
7 |
+
, hyperparameters=None, batch_size_per_gp_sample=None, **kwargs):
|
8 |
+
batch_size_per_gp_sample = batch_size_per_gp_sample or (min(64, batch_size))
|
9 |
+
num_models = batch_size // batch_size_per_gp_sample
|
10 |
+
assert num_models * batch_size_per_gp_sample == batch_size, f'Batch size ({batch_size}) not divisible by batch_size_per_gp_sample ({batch_size_per_gp_sample})'
|
11 |
+
|
12 |
+
args = {'device': device, 'seq_len': seq_len, 'num_features': num_features, 'batch_size': batch_size_per_gp_sample}
|
13 |
+
|
14 |
+
prior_bag_priors_get_batch = hyperparameters['prior_bag_get_batch']
|
15 |
+
prior_bag_priors_p = [1.0] + [hyperparameters[f'prior_bag_exp_weights_{i}'] for i in range(1, len(prior_bag_priors_get_batch))]
|
16 |
+
|
17 |
+
weights = torch.tensor(prior_bag_priors_p, dtype=torch.float) # create a tensor of weights
|
18 |
+
batch_assignments = torch.multinomial(torch.softmax(weights, 0), num_models, replacement=True).numpy()
|
19 |
+
|
20 |
+
if 'verbose' in hyperparameters and hyperparameters['verbose']:
|
21 |
+
print('PRIOR_BAG:', weights, batch_assignments)
|
22 |
+
|
23 |
+
sample = sum([[prior_bag_priors_get_batch[int(prior_idx)](hyperparameters=hyperparameters, **args)] for prior_idx in batch_assignments], [])
|
24 |
+
|
25 |
+
x, y, y_ = zip(*sample)
|
26 |
+
x, y, y_ = (torch.cat(x, 1).detach()
|
27 |
+
, torch.cat(y, 1).detach()
|
28 |
+
, torch.cat(y_, 1).detach())
|
29 |
+
return x, y, y_
|
30 |
+
|
31 |
+
DataLoader = get_batch_to_dataloader(get_batch)
|
32 |
+
DataLoader.num_outputs = 1
|
TabPFN/priors/utils.py
ADDED
@@ -0,0 +1,163 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
|
3 |
+
import torch
|
4 |
+
|
5 |
+
from utils import set_locals_in_self
|
6 |
+
from .prior import PriorDataLoader
|
7 |
+
from torch import nn
|
8 |
+
import numpy as np
|
9 |
+
import matplotlib.pyplot as plt
|
10 |
+
import matplotlib.gridspec as gridspec
|
11 |
+
import scipy.stats as stats
|
12 |
+
import math
|
13 |
+
|
14 |
+
def get_batch_to_dataloader(get_batch_method_):
|
15 |
+
class DL(PriorDataLoader):
|
16 |
+
get_batch_method = get_batch_method_
|
17 |
+
|
18 |
+
# Caution, you might need to set self.num_features manually if it is not part of the args.
|
19 |
+
def __init__(self, num_steps, fuse_x_y=False, **get_batch_kwargs):
|
20 |
+
set_locals_in_self(locals())
|
21 |
+
# The stuff outside the or is set as class attribute before instantiation.
|
22 |
+
self.num_features = get_batch_kwargs.get('num_features') or self.num_features
|
23 |
+
self.num_outputs = get_batch_kwargs.get('num_outputs') or self.num_outputs
|
24 |
+
print('DataLoader.__dict__', self.__dict__)
|
25 |
+
|
26 |
+
@staticmethod
|
27 |
+
def gbm(*args, fuse_x_y=True, **kwargs):
|
28 |
+
dynamic_seq_len = callable(kwargs['seq_len'])
|
29 |
+
kwargs['seq_len'] = kwargs['seq_len']() if dynamic_seq_len else kwargs['seq_len']
|
30 |
+
# Scales the batch size dynamically with the power of 'dynamic_batch_size'.
|
31 |
+
# A transformer with quadratic memory usage in the seq len would need a power of 2 to keep memory constant.
|
32 |
+
if dynamic_seq_len and 'dynamic_batch_size' in kwargs and kwargs['dynamic_batch_size'] > 0:
|
33 |
+
kwargs['batch_size'] = kwargs['batch_size'] * math.floor(math.pow(kwargs['seq_len_maximum'], kwargs['dynamic_batch_size']) / math.pow(kwargs['seq_len'], kwargs['dynamic_batch_size']))
|
34 |
+
batch = get_batch_method_(*args, **kwargs)
|
35 |
+
x, y, target_y, style = batch if len(batch) == 4 else (batch[0], batch[1], batch[2], None)
|
36 |
+
if fuse_x_y:
|
37 |
+
return torch.cat([x, torch.cat([torch.zeros_like(y[:1]), y[:-1]], 0).unsqueeze(-1).float()],
|
38 |
+
-1), target_y
|
39 |
+
else:
|
40 |
+
return (style, x, y), target_y
|
41 |
+
|
42 |
+
def __len__(self):
|
43 |
+
return self.num_steps
|
44 |
+
|
45 |
+
def __iter__(self):
|
46 |
+
return iter(self.gbm(**self.get_batch_kwargs, fuse_x_y=self.fuse_x_y) for _ in range(self.num_steps))
|
47 |
+
|
48 |
+
|
49 |
+
return DL
|
50 |
+
|
51 |
+
import seaborn as sns
|
52 |
+
def plot_features(data, targets, fig=None):
|
53 |
+
if torch.is_tensor(data):
|
54 |
+
data = data.detach().cpu().numpy()
|
55 |
+
targets = targets.detach().cpu().numpy()
|
56 |
+
#data = np.concatenate([data, data[:, -1:]], -1)
|
57 |
+
#df = pd.DataFrame(data, columns=list(range(0, data.shape[1])))
|
58 |
+
#g = sns.pairplot(df, hue=data.shape[1]-1, palette="Set2", diag_kind="kde", height=2.5)
|
59 |
+
#plt.legend([], [], frameon=False)
|
60 |
+
#g._legend.remove()
|
61 |
+
#g = sns.PairGrid(df, hue=data.shape[1]-1)
|
62 |
+
#g.map_diag(sns.histplot)
|
63 |
+
#g.map_offdiag(sns.scatterplot)
|
64 |
+
#g._legend.remove()
|
65 |
+
|
66 |
+
fig2 = fig if fig else plt.figure(figsize=(8, 8))
|
67 |
+
spec2 = gridspec.GridSpec(ncols=data.shape[1], nrows=data.shape[1], figure=fig2)
|
68 |
+
for d in range(0, data.shape[1]):
|
69 |
+
for d2 in range(0, data.shape[1]):
|
70 |
+
sub_ax = fig2.add_subplot(spec2[d, d2])
|
71 |
+
if d == d2:
|
72 |
+
sns.kdeplot(data[:, d],hue=targets[:],ax=sub_ax,legend=False, palette="deep")
|
73 |
+
sub_ax.set(ylabel=None)
|
74 |
+
else:
|
75 |
+
sns.scatterplot(x=data[:, d], y=data[:, d2],
|
76 |
+
hue=targets[:],legend=False, palette="deep")
|
77 |
+
#plt.scatter(data[:, d], data[:, d2],
|
78 |
+
# c=targets[:])
|
79 |
+
sub_ax.get_xaxis().set_ticks([])
|
80 |
+
sub_ax.get_yaxis().set_ticks([])
|
81 |
+
plt.subplots_adjust(wspace=0.05, hspace=0.05)
|
82 |
+
fig2.show()
|
83 |
+
|
84 |
+
|
85 |
+
def plot_prior(prior):
|
86 |
+
s = np.array([prior() for _ in range(0, 1000)])
|
87 |
+
count, bins, ignored = plt.hist(s, 50, density=True)
|
88 |
+
print(s.min())
|
89 |
+
plt.show()
|
90 |
+
|
91 |
+
trunc_norm_sampler_f = lambda mu, sigma : lambda: stats.truncnorm((0 - mu) / sigma, (1000000 - mu) / sigma, loc=mu, scale=sigma).rvs(1)[0]
|
92 |
+
beta_sampler_f = lambda a, b : lambda : np.random.beta(a, b)
|
93 |
+
gamma_sampler_f = lambda a, b : lambda : np.random.gamma(a, b)
|
94 |
+
uniform_sampler_f = lambda a, b : lambda : np.random.uniform(a, b)
|
95 |
+
uniform_int_sampler_f = lambda a, b : lambda : round(np.random.uniform(a, b))
|
96 |
+
def zipf_sampler_f(a, b, c):
|
97 |
+
x = np.arange(b, c)
|
98 |
+
weights = x ** (-a)
|
99 |
+
weights /= weights.sum()
|
100 |
+
return lambda : stats.rv_discrete(name='bounded_zipf', values=(x, weights)).rvs(1)
|
101 |
+
scaled_beta_sampler_f = lambda a, b, scale, minimum : lambda : minimum + round(beta_sampler_f(a, b)() * (scale - minimum))
|
102 |
+
|
103 |
+
|
104 |
+
def normalize_by_used_features_f(x, num_features_used, num_features, normalize_with_sqrt=False):
|
105 |
+
if normalize_with_sqrt:
|
106 |
+
return x / (num_features_used / num_features)**(1 / 2)
|
107 |
+
return x / (num_features_used / num_features)
|
108 |
+
|
109 |
+
|
110 |
+
def order_by_y(x, y):
|
111 |
+
order = torch.argsort(y if random.randint(0, 1) else -y, dim=0)[:, 0, 0]
|
112 |
+
order = order.reshape(2, -1).transpose(0, 1).reshape(-1)#.reshape(seq_len)
|
113 |
+
x = x[order] # .reshape(2, -1).transpose(0, 1).reshape(-1).flip([0]).reshape(seq_len, 1, -1)
|
114 |
+
y = y[order] # .reshape(2, -1).transpose(0, 1).reshape(-1).reshape(seq_len, 1, -1)
|
115 |
+
|
116 |
+
return x, y
|
117 |
+
|
118 |
+
def randomize_classes(x, num_classes):
|
119 |
+
classes = torch.arange(0, num_classes, device=x.device)
|
120 |
+
random_classes = torch.randperm(num_classes, device=x.device).type(x.type())
|
121 |
+
x = ((x.unsqueeze(-1) == classes) * random_classes).sum(-1)
|
122 |
+
return x
|
123 |
+
|
124 |
+
|
125 |
+
class CategoricalActivation(nn.Module):
|
126 |
+
def __init__(self, categorical_p=0.1, ordered_p=0.7
|
127 |
+
, keep_activation_size=False
|
128 |
+
, num_classes_sampler=zipf_sampler_f(0.8, 1, 10)):
|
129 |
+
self.categorical_p = categorical_p
|
130 |
+
self.ordered_p = ordered_p
|
131 |
+
self.keep_activation_size = keep_activation_size
|
132 |
+
self.num_classes_sampler = num_classes_sampler
|
133 |
+
|
134 |
+
super().__init__()
|
135 |
+
|
136 |
+
def forward(self, x):
|
137 |
+
# x shape: T, B, H
|
138 |
+
|
139 |
+
x = nn.Softsign()(x)
|
140 |
+
|
141 |
+
num_classes = self.num_classes_sampler()
|
142 |
+
hid_strength = torch.abs(x).mean(0).unsqueeze(0) if self.keep_activation_size else None
|
143 |
+
|
144 |
+
categorical_classes = torch.rand((x.shape[1], x.shape[2])) < self.categorical_p
|
145 |
+
class_boundaries = torch.zeros((num_classes - 1, x.shape[1], x.shape[2]), device=x.device, dtype=x.dtype)
|
146 |
+
# Sample a different index for each hidden dimension, but shared for all batches
|
147 |
+
for b in range(x.shape[1]):
|
148 |
+
for h in range(x.shape[2]):
|
149 |
+
ind = torch.randint(0, x.shape[0], (num_classes - 1,))
|
150 |
+
class_boundaries[:, b, h] = x[ind, b, h]
|
151 |
+
|
152 |
+
for b in range(x.shape[1]):
|
153 |
+
x_rel = x[:, b, categorical_classes[b]]
|
154 |
+
boundaries_rel = class_boundaries[:, b, categorical_classes[b]].unsqueeze(1)
|
155 |
+
x[:, b, categorical_classes[b]] = (x_rel > boundaries_rel).sum(dim=0).float() - num_classes / 2
|
156 |
+
|
157 |
+
ordered_classes = torch.rand((x.shape[1],x.shape[2])) < self.ordered_p
|
158 |
+
ordered_classes = torch.logical_and(ordered_classes, categorical_classes)
|
159 |
+
x[:, ordered_classes] = randomize_classes(x[:, ordered_classes], num_classes)
|
160 |
+
|
161 |
+
x = x * hid_strength if self.keep_activation_size else x
|
162 |
+
|
163 |
+
return x
|
TabPFN/requirements.txt
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Please use python V 3.7 to be compatible with all packages
|
2 |
+
gpytorch==1.5.0
|
3 |
+
torch==1.9.0
|
4 |
+
scikit-learn==0.24.2
|
5 |
+
pyyaml==5.4.1
|
6 |
+
seaborn==0.11.2
|
7 |
+
xgboost==1.4.0
|
8 |
+
tqdm==4.62.1
|
9 |
+
numpy==1.21.2
|
10 |
+
openml==0.12.2
|
11 |
+
catboost==0.26.1
|
12 |
+
auto-sklearn==0.14.5
|
13 |
+
hyperopt==0.2.5
|
14 |
+
configspace==0.4.21
|
15 |
+
# autogluon==0.4.0
|
TabPFN/scripts/baseline_prediction_interface.py
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import tqdm
|
2 |
+
import numpy as np
|
3 |
+
|
4 |
+
def baseline_predict(metric_function, eval_xs, eval_ys, categorical_feats, metric_used=None, eval_pos=2, max_time=300, **kwargs):
|
5 |
+
"""
|
6 |
+
Baseline prediction interface.
|
7 |
+
:param metric_function:
|
8 |
+
:param eval_xs:
|
9 |
+
:param eval_ys:
|
10 |
+
:param categorical_feats:
|
11 |
+
:param metric_used:
|
12 |
+
:param eval_pos:
|
13 |
+
:param max_time: Scheduled maximum time
|
14 |
+
:param kwargs:
|
15 |
+
:return: list [np.array(metrics), np.array(outputs), best_configs] or [None, None, None] if failed
|
16 |
+
"""
|
17 |
+
|
18 |
+
metrics = []
|
19 |
+
outputs = []
|
20 |
+
best_configs = []
|
21 |
+
eval_splits = list(zip(eval_xs.transpose(0, 1), eval_ys.transpose(0, 1)))
|
22 |
+
for eval_x, eval_y in tqdm.tqdm(eval_splits, desc='Calculating splits'+str(metric_function)+' '+str(eval_pos)):
|
23 |
+
try:
|
24 |
+
metric, output, best_config = metric_function(eval_x[:eval_pos],
|
25 |
+
eval_y[:eval_pos],
|
26 |
+
eval_x[eval_pos:],
|
27 |
+
eval_y[eval_pos:],
|
28 |
+
categorical_feats,
|
29 |
+
metric_used=metric_used
|
30 |
+
, max_time=max_time)
|
31 |
+
metrics += [metric]
|
32 |
+
outputs += [output]
|
33 |
+
best_configs += [best_config]
|
34 |
+
return np.array(metrics), np.array(outputs), best_configs
|
35 |
+
except Exception as e:
|
36 |
+
print(f'There was an exception in {metric_function}')
|
37 |
+
print(e)
|
38 |
+
return None, None, None
|
TabPFN/scripts/differentiable_pfn_evaluation.py
ADDED
@@ -0,0 +1,391 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
import numpy as np
|
4 |
+
import time
|
5 |
+
import pickle
|
6 |
+
from scripts import tabular_metrics
|
7 |
+
from scripts.tabular_metrics import calculate_score_per_method
|
8 |
+
from scripts.tabular_evaluation import evaluate
|
9 |
+
from priors.differentiable_prior import draw_random_style
|
10 |
+
from tqdm import tqdm
|
11 |
+
from pathlib import Path
|
12 |
+
import random
|
13 |
+
from model_builder import load_model
|
14 |
+
from scripts.transformer_prediction_interface import get_params_from_config
|
15 |
+
|
16 |
+
"""
|
17 |
+
===============================
|
18 |
+
PUBLIC FUNCTIONS FOR EVALUATION
|
19 |
+
===============================
|
20 |
+
"""
|
21 |
+
|
22 |
+
|
23 |
+
def eval_model_range(i_range, *args, **kwargs):
|
24 |
+
for i in i_range:
|
25 |
+
eval_model(i, *args, **kwargs)
|
26 |
+
|
27 |
+
|
28 |
+
def load_model_workflow(i, e, add_name, base_path, device='cpu', eval_addition=''):
|
29 |
+
"""
|
30 |
+
Workflow for loading a model and setting appropriate parameters for diffable hparam tuning.
|
31 |
+
|
32 |
+
:param i:
|
33 |
+
:param e:
|
34 |
+
:param eval_positions_valid:
|
35 |
+
:param add_name:
|
36 |
+
:param base_path:
|
37 |
+
:param device:
|
38 |
+
:param eval_addition:
|
39 |
+
:return:
|
40 |
+
"""
|
41 |
+
def check_file(e):
|
42 |
+
model_file = f'models_diff/prior_diff_real_checkpoint{add_name}_n_{i}_epoch_{e}.cpkt'
|
43 |
+
model_path = os.path.join(base_path, model_file)
|
44 |
+
# print('Evaluate ', model_path)
|
45 |
+
results_file = os.path.join(base_path,
|
46 |
+
f'models_diff/prior_diff_real_results{add_name}_n_{i}_epoch_{e}_{eval_addition}.pkl')
|
47 |
+
if not Path(model_path).is_file(): # or Path(results_file).is_file():
|
48 |
+
return None, None, None
|
49 |
+
return model_file, model_path, results_file
|
50 |
+
|
51 |
+
model_file = None
|
52 |
+
if e == -1:
|
53 |
+
for e_ in range(100, -1, -1):
|
54 |
+
model_file_, model_path_, results_file_ = check_file(e_)
|
55 |
+
if model_file_ is not None:
|
56 |
+
e = e_
|
57 |
+
model_file, model_path, results_file = model_file_, model_path_, results_file_
|
58 |
+
break
|
59 |
+
else:
|
60 |
+
model_file, model_path, results_file = check_file(e)
|
61 |
+
|
62 |
+
if model_file is None:
|
63 |
+
print('No checkpoint found')
|
64 |
+
return None
|
65 |
+
|
66 |
+
print(f'Loading {model_file}')
|
67 |
+
|
68 |
+
model, c = load_model(base_path, model_file, device, eval_positions=[], verbose=False)
|
69 |
+
|
70 |
+
return model, c, results_file
|
71 |
+
|
72 |
+
|
73 |
+
def eval_model(i, e, valid_datasets, test_datasets, train_datasets, eval_positions_valid, eval_positions_test,
|
74 |
+
bptt_valid,
|
75 |
+
bptt_test, add_name, base_path, device='cpu', eval_addition='', **extra_tuning_args):
|
76 |
+
"""
|
77 |
+
Differentiable model evaliation workflow. Evaluates and saves results to disk.
|
78 |
+
|
79 |
+
:param i:
|
80 |
+
:param e:
|
81 |
+
:param valid_datasets:
|
82 |
+
:param test_datasets:
|
83 |
+
:param train_datasets:
|
84 |
+
:param eval_positions_valid:
|
85 |
+
:param eval_positions_test:
|
86 |
+
:param bptt_valid:
|
87 |
+
:param bptt_test:
|
88 |
+
:param add_name:
|
89 |
+
:param base_path:
|
90 |
+
:param device:
|
91 |
+
:param eval_addition:
|
92 |
+
:param extra_tuning_args:
|
93 |
+
:return:
|
94 |
+
"""
|
95 |
+
model, c, results_file = load_model_workflow(i, e, add_name, base_path, device, eval_addition)
|
96 |
+
params = {'bptt': bptt_valid
|
97 |
+
, 'bptt_final': bptt_test
|
98 |
+
, 'eval_positions': eval_positions_valid
|
99 |
+
, 'eval_positions_test': eval_positions_test
|
100 |
+
, 'valid_datasets': valid_datasets
|
101 |
+
, 'test_datasets': test_datasets
|
102 |
+
, 'train_datasets': train_datasets
|
103 |
+
, 'verbose': True
|
104 |
+
, 'device': device
|
105 |
+
}
|
106 |
+
|
107 |
+
params.update(get_params_from_config(c))
|
108 |
+
|
109 |
+
start = time.time()
|
110 |
+
metrics, metrics_valid, style, temperature, optimization_route = evaluate_differentiable_model(model, **params,
|
111 |
+
**extra_tuning_args)
|
112 |
+
print('Evaluation time: ', time.time() - start)
|
113 |
+
|
114 |
+
print(results_file)
|
115 |
+
r = [c.copy(), metrics, metrics_valid, style.to('cpu'), temperature.to('cpu'), optimization_route]
|
116 |
+
with open(results_file, 'wb') as output:
|
117 |
+
del r[0]['num_features_used']
|
118 |
+
del r[0]['categorical_features_sampler']
|
119 |
+
pickle.dump(r, output)
|
120 |
+
|
121 |
+
_, _, _, style, temperature, _ = r
|
122 |
+
|
123 |
+
return r, model
|
124 |
+
|
125 |
+
"""
|
126 |
+
===============================
|
127 |
+
INTERNAL HELPER FUNCTIONS
|
128 |
+
===============================
|
129 |
+
"""
|
130 |
+
|
131 |
+
def evaluate_differentiable_model(model
|
132 |
+
, valid_datasets
|
133 |
+
, test_datasets
|
134 |
+
, train_datasets
|
135 |
+
, N_draws=100
|
136 |
+
, N_grad_steps=10
|
137 |
+
, eval_positions=None
|
138 |
+
, eval_positions_test=None
|
139 |
+
, bptt=100
|
140 |
+
, bptt_final=200
|
141 |
+
, style=None
|
142 |
+
, n_parallel_configurations=1
|
143 |
+
, device='cpu'
|
144 |
+
, selection_metric='auc'
|
145 |
+
, final_splits=[1, 2, 3, 4, 5]
|
146 |
+
, N_ensemble_configurations_list=[1, 5, 10, 20, 50, 100]
|
147 |
+
, **kwargs):
|
148 |
+
"""
|
149 |
+
Evaluation function for diffable model evaluation. Returns a list of results.
|
150 |
+
|
151 |
+
:param model:
|
152 |
+
:param valid_datasets:
|
153 |
+
:param test_datasets:
|
154 |
+
:param train_datasets:
|
155 |
+
:param N_draws:
|
156 |
+
:param N_grad_steps:
|
157 |
+
:param eval_positions:
|
158 |
+
:param eval_positions_test:
|
159 |
+
:param bptt:
|
160 |
+
:param bptt_final:
|
161 |
+
:param style:
|
162 |
+
:param n_parallel_configurations:
|
163 |
+
:param device:
|
164 |
+
:param selection_metric:
|
165 |
+
:param final_splits:
|
166 |
+
:param N_ensemble_configurations_list:
|
167 |
+
:param kwargs:
|
168 |
+
:return:
|
169 |
+
"""
|
170 |
+
torch.manual_seed(0)
|
171 |
+
np.random.seed(0)
|
172 |
+
random.seed(0)
|
173 |
+
|
174 |
+
diffable_metric = tabular_metrics.cross_entropy
|
175 |
+
evaluation_metric = tabular_metrics.auc_metric
|
176 |
+
if selection_metric in ('auc', 'roc'):
|
177 |
+
selection_metric_min_max = 'max'
|
178 |
+
selection_metric = tabular_metrics.auc_metric
|
179 |
+
evaluation_metric = selection_metric
|
180 |
+
elif selection_metric in ('ce', 'selection_metric'):
|
181 |
+
selection_metric_min_max = 'min'
|
182 |
+
selection_metric = tabular_metrics.cross_entropy
|
183 |
+
evaluation_metric = selection_metric
|
184 |
+
|
185 |
+
print('Diffable metric', diffable_metric, ' Selection metric', selection_metric, ' Evaluation metric',
|
186 |
+
evaluation_metric)
|
187 |
+
print('N PARALLEL CONFIGURATIONS', n_parallel_configurations)
|
188 |
+
print('eval_positions', eval_positions)
|
189 |
+
|
190 |
+
def evaluate_valid(style, softmax_temperature, results, results_tracked):
|
191 |
+
result_valid = eval_step(valid_datasets, style, softmax_temperature=softmax_temperature,
|
192 |
+
return_tensor=False, inference_mode=True, selection_metric=selection_metric,
|
193 |
+
evaluation_metric=evaluation_metric, eval_positions=eval_positions, bptt=bptt, model=model[2])
|
194 |
+
result_valid = [float(result_valid[f'mean_select_at_{pos}']) for pos in eval_positions]
|
195 |
+
results += [result_valid]
|
196 |
+
results_tracked += [np.nanmean(result_valid)]
|
197 |
+
|
198 |
+
model[2].to(device)
|
199 |
+
model[2].eval()
|
200 |
+
|
201 |
+
results_on_valid, results_on_valid_tracked = [], []
|
202 |
+
best_style, best_softmax_temperature = style, torch.cat(
|
203 |
+
[torch.tensor([0.0]).to(device) for n in range(0, n_parallel_configurations)], 0)
|
204 |
+
optimization_routes = []
|
205 |
+
|
206 |
+
best_style = torch.cat([draw_random_style(model[3], device).detach() for n in range(0, n_parallel_configurations)],
|
207 |
+
0)
|
208 |
+
best_softmax_temperature = torch.cat([torch.tensor([0.0]).to(device) for n in range(0, n_parallel_configurations)],
|
209 |
+
0)
|
210 |
+
|
211 |
+
|
212 |
+
for _ in tqdm(range(0, N_draws), desc='Iterate over Optimization initializations'): # Evaluates N hparam draws
|
213 |
+
style = torch.cat([draw_random_style(model[3], device).detach() for n in range(0, n_parallel_configurations)],
|
214 |
+
0)
|
215 |
+
softmax_temperature = torch.cat([torch.tensor([0.0]).to(device) for n in range(0, n_parallel_configurations)],
|
216 |
+
0)
|
217 |
+
|
218 |
+
evaluate_valid(style, softmax_temperature, results_on_valid, results_on_valid_tracked)
|
219 |
+
|
220 |
+
print(f'Draw --> Valid Selection metric: {results_on_valid[-1]}')
|
221 |
+
|
222 |
+
if N_grad_steps > 0:
|
223 |
+
gradient_optimize_result = gradient_optimize_style(model, style, N_grad_steps
|
224 |
+
, softmax_temperature=softmax_temperature
|
225 |
+
, model=model[2]
|
226 |
+
, train_datasets=train_datasets
|
227 |
+
, valid_datasets=valid_datasets
|
228 |
+
, selection_metric_min_max=selection_metric_min_max
|
229 |
+
, **kwargs)
|
230 |
+
optimization_routes += [gradient_optimize_result['optimization_route']]
|
231 |
+
|
232 |
+
evaluate_valid(gradient_optimize_result['best_style']
|
233 |
+
, gradient_optimize_result['best_temperature']
|
234 |
+
, results_on_valid, results_on_valid_tracked)
|
235 |
+
|
236 |
+
print(f'After diff --> Valid Selection metric: {results_on_valid[-1]}')
|
237 |
+
|
238 |
+
if selection_metric_min_max == 'min':
|
239 |
+
is_best = (results_on_valid_tracked[-1] <= min(results_on_valid_tracked))
|
240 |
+
else:
|
241 |
+
is_best = (results_on_valid_tracked[-1] >= max(results_on_valid_tracked))
|
242 |
+
|
243 |
+
if is_best or best_style is None:
|
244 |
+
best_style = gradient_optimize_result['best_style'].clone()
|
245 |
+
best_softmax_temperature = gradient_optimize_result['best_temperature'].clone()
|
246 |
+
torch.cuda.empty_cache()
|
247 |
+
|
248 |
+
def final_evaluation():
|
249 |
+
print('Running eval dataset with final params (no gradients)..')
|
250 |
+
print(best_style, best_softmax_temperature)
|
251 |
+
result_test = []
|
252 |
+
for N_ensemble_configurations in N_ensemble_configurations_list:
|
253 |
+
print(f'Running with {N_ensemble_configurations} ensemble_configurations')
|
254 |
+
kwargs['N_ensemble_configurations'] = N_ensemble_configurations
|
255 |
+
splits = []
|
256 |
+
for split in final_splits:
|
257 |
+
splits += [eval_step(test_datasets, best_style, softmax_temperature=best_softmax_temperature
|
258 |
+
, return_tensor=False, eval_positions=eval_positions_test,
|
259 |
+
bptt=bptt_final, inference_mode=True, split_number=split, model=model[2]
|
260 |
+
, selection_metric=selection_metric, evaluation_metric=evaluation_metric)]
|
261 |
+
result_test += [splits]
|
262 |
+
|
263 |
+
print('Running valid dataset with final params (no gradients)..')
|
264 |
+
result_valid = eval_step(valid_datasets, best_style, softmax_temperature=best_softmax_temperature
|
265 |
+
, return_tensor=False, eval_positions=eval_positions_test,
|
266 |
+
bptt=bptt_final, inference_mode=True, model=model[2]
|
267 |
+
, selection_metric=selection_metric, evaluation_metric=evaluation_metric)
|
268 |
+
|
269 |
+
return result_test, result_valid
|
270 |
+
|
271 |
+
result_test, result_valid = final_evaluation()
|
272 |
+
|
273 |
+
return result_test, result_valid, best_style, best_softmax_temperature, optimization_routes
|
274 |
+
|
275 |
+
|
276 |
+
def eval_step(ds, used_style, selection_metric, evaluation_metric, eval_positions, return_tensor=True, **kwargs):
|
277 |
+
def step():
|
278 |
+
return evaluate(datasets=ds,
|
279 |
+
method='transformer'
|
280 |
+
, overwrite=True
|
281 |
+
, style=used_style
|
282 |
+
, eval_positions=eval_positions
|
283 |
+
, metric_used=selection_metric
|
284 |
+
, save=False
|
285 |
+
, path_interfix=None
|
286 |
+
, base_path=None
|
287 |
+
, verbose=True
|
288 |
+
, **kwargs)
|
289 |
+
|
290 |
+
if return_tensor:
|
291 |
+
r = step()
|
292 |
+
else:
|
293 |
+
with torch.no_grad():
|
294 |
+
r = step()
|
295 |
+
|
296 |
+
calculate_score_per_method(selection_metric, 'select', r, ds, eval_positions, aggregator='mean')
|
297 |
+
calculate_score_per_method(evaluation_metric, 'eval', r, ds, eval_positions, aggregator='mean')
|
298 |
+
|
299 |
+
return r
|
300 |
+
|
301 |
+
|
302 |
+
def gradient_optimize_style(model, init_style, steps, softmax_temperature, train_datasets, valid_datasets, learning_rate=0.03, optimize_all=False,
|
303 |
+
limit_style=True, N_datasets_sampled=90, optimize_softmax_temperature=True, selection_metric_min_max='max', **kwargs):
|
304 |
+
"""
|
305 |
+
Uses gradient based methods to optimize 'style' on the 'train_datasets' and uses stopping with 'valid_datasets'.
|
306 |
+
|
307 |
+
:param model:
|
308 |
+
:param init_style:
|
309 |
+
:param steps:
|
310 |
+
:param learning_rate:
|
311 |
+
:param softmax_temperature:
|
312 |
+
:param train_datasets:
|
313 |
+
:param valid_datasets:
|
314 |
+
:param optimize_all:
|
315 |
+
:param limit_style:
|
316 |
+
:param N_datasets_sampled:
|
317 |
+
:param optimize_softmax_temperature:
|
318 |
+
:param selection_metric_min_max:
|
319 |
+
:param kwargs:
|
320 |
+
:return:
|
321 |
+
"""
|
322 |
+
grad_style = torch.nn.Parameter(init_style.detach(), requires_grad=True)
|
323 |
+
|
324 |
+
best_style, best_temperature, best_selection_metric, best_diffable_metric = grad_style.detach(), softmax_temperature.detach(), None, None
|
325 |
+
softmax_temperature = torch.nn.Parameter(softmax_temperature.detach(), requires_grad=optimize_softmax_temperature)
|
326 |
+
variables_to_optimize = model[2].parameters() if optimize_all else [grad_style, softmax_temperature]
|
327 |
+
optimizer = torch.optim.Adam(variables_to_optimize, lr=learning_rate)
|
328 |
+
|
329 |
+
optimization_route_selection, optimization_route_diffable = [], []
|
330 |
+
optimization_route_selection_valid, optimization_route_diffable_valid = [], []
|
331 |
+
|
332 |
+
def eval_opt(ds, return_tensor=True, inference_mode=False):
|
333 |
+
result = eval_step(ds, grad_style, softmax_temperature=softmax_temperature, return_tensor=return_tensor
|
334 |
+
, inference_mode=inference_mode, model=model[2], **kwargs)
|
335 |
+
|
336 |
+
diffable_metric = result['mean_metric']
|
337 |
+
selection_metric = result['mean_select']
|
338 |
+
|
339 |
+
return diffable_metric, selection_metric
|
340 |
+
|
341 |
+
def eval_all_datasets(datasets, propagate=True):
|
342 |
+
selection_metrics_this_step, diffable_metrics_this_step = [], []
|
343 |
+
for ds in datasets:
|
344 |
+
diffable_metric_train, selection_metric_train = eval_opt([ds], inference_mode=(not propagate))
|
345 |
+
if not torch.isnan(diffable_metric_train).any():
|
346 |
+
if propagate and diffable_metric_train.requires_grad == True:
|
347 |
+
diffable_metric_train.backward()
|
348 |
+
selection_metrics_this_step += [selection_metric_train]
|
349 |
+
diffable_metrics_this_step += [float(diffable_metric_train.detach().cpu().numpy())]
|
350 |
+
diffable_metric_train = np.nanmean(diffable_metrics_this_step)
|
351 |
+
selection_metric_train = np.nanmean(selection_metrics_this_step)
|
352 |
+
|
353 |
+
return diffable_metric_train, selection_metric_train
|
354 |
+
|
355 |
+
for t in tqdm(range(steps), desc='Iterate over Optimization steps'):
|
356 |
+
optimizer.zero_grad()
|
357 |
+
|
358 |
+
# Select subset of datasets
|
359 |
+
random.seed(t)
|
360 |
+
train_datasets_ = random.sample(train_datasets, N_datasets_sampled)
|
361 |
+
|
362 |
+
# Get score on train
|
363 |
+
diffable_metric_train, selection_metric_train = eval_all_datasets(train_datasets_, propagate=True)
|
364 |
+
optimization_route_selection += [float(selection_metric_train)]
|
365 |
+
optimization_route_diffable += [float(diffable_metric_train)]
|
366 |
+
|
367 |
+
# Get score on valid
|
368 |
+
diffable_metric_valid, selection_metric_valid = eval_all_datasets(valid_datasets, propagate=False)
|
369 |
+
optimization_route_selection_valid += [float(selection_metric_valid)]
|
370 |
+
optimization_route_diffable_valid += [float(diffable_metric_valid)]
|
371 |
+
|
372 |
+
is_best = (selection_metric_min_max == 'min' and best_selection_metric > selection_metric_valid)
|
373 |
+
is_best = is_best or (selection_metric_min_max == 'max' and best_selection_metric < selection_metric_valid)
|
374 |
+
if (best_selection_metric is None) or (not np.isnan(selection_metric_valid) and is_best):
|
375 |
+
print('New best', best_selection_metric, selection_metric_valid)
|
376 |
+
best_style = grad_style.detach().clone()
|
377 |
+
best_temperature = softmax_temperature.detach().clone()
|
378 |
+
best_selection_metric, best_diffable_metric = selection_metric_valid, diffable_metric_valid
|
379 |
+
|
380 |
+
optimizer.step()
|
381 |
+
|
382 |
+
if limit_style:
|
383 |
+
grad_style = grad_style.detach().clamp(-1.74, 1.74)
|
384 |
+
|
385 |
+
print(f'Valid: Diffable metric={diffable_metric_valid} Selection metric={selection_metric_valid};' +
|
386 |
+
f'Train: Diffable metric={diffable_metric_train} Selection metric={selection_metric_train}')
|
387 |
+
|
388 |
+
print(f'Return best:{best_style} {best_selection_metric}')
|
389 |
+
return {'best_style': best_style, 'best_temperature': best_temperature
|
390 |
+
, 'optimization_route': {'select': optimization_route_selection, 'loss': optimization_route_diffable,
|
391 |
+
'test_select': optimization_route_selection_valid, 'test_loss': optimization_route_diffable_valid}}
|
TabPFN/scripts/model_configs.py
ADDED
@@ -0,0 +1,210 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from copy import deepcopy
|
2 |
+
from priors.utils import uniform_int_sampler_f
|
3 |
+
from priors.differentiable_prior import DifferentiableHyperparameter
|
4 |
+
from ConfigSpace import hyperparameters as CSH
|
5 |
+
import torch
|
6 |
+
from priors.differentiable_prior import replace_differentiable_distributions
|
7 |
+
|
8 |
+
import ConfigSpace as CS
|
9 |
+
|
10 |
+
def get_general_config(max_features, bptt, eval_positions=None):
|
11 |
+
""""
|
12 |
+
Returns the general PFN training hyperparameters.
|
13 |
+
"""
|
14 |
+
config_general = {
|
15 |
+
"lr": CSH.UniformFloatHyperparameter('lr', lower=0.00002, upper=0.0002, log=True),
|
16 |
+
"dropout": CSH.CategoricalHyperparameter('dropout', [0.0]),
|
17 |
+
"emsize": CSH.CategoricalHyperparameter('emsize', [2 ** i for i in range(8, 9)]), ## upper bound is -1
|
18 |
+
"batch_size": CSH.CategoricalHyperparameter('batch_size', [2 ** i for i in range(8, 9)]),
|
19 |
+
"nlayers": CSH.CategoricalHyperparameter('nlayers', [12]),
|
20 |
+
"num_features": max_features,
|
21 |
+
"nhead": CSH.CategoricalHyperparameter('nhead', [4]),
|
22 |
+
"nhid_factor": 2,
|
23 |
+
"bptt": bptt,
|
24 |
+
"eval_positions": None,
|
25 |
+
"seq_len_used": bptt,
|
26 |
+
"sampling": 'normal',#hp.choice('sampling', ['mixed', 'normal']), # uniform
|
27 |
+
"epochs": 80,
|
28 |
+
"num_steps": 100,
|
29 |
+
"verbose": False,
|
30 |
+
"pre_sample_causes": True, # This is MLP
|
31 |
+
"mix_activations": False,#hp.choice('mix_activations', [True, False]),
|
32 |
+
}
|
33 |
+
|
34 |
+
return config_general
|
35 |
+
|
36 |
+
def get_flexible_categorical_config(max_features):
|
37 |
+
""""
|
38 |
+
Returns the configuration parameters for the tabular multiclass wrapper.
|
39 |
+
"""
|
40 |
+
config_flexible_categorical = {
|
41 |
+
"nan_prob_unknown_reason_reason_prior": CSH.CategoricalHyperparameter('nan_prob_unknown_reason_reason_prior', [1.0]),
|
42 |
+
"categorical_feature_p": CSH.CategoricalHyperparameter('categorical_feature_p', [0.0]),
|
43 |
+
"nan_prob_no_reason": CSH.CategoricalHyperparameter('nan_prob_no_reason', [0.0, 0.1, 0.2]),
|
44 |
+
"nan_prob_unknown_reason": CSH.CategoricalHyperparameter('nan_prob_unknown_reason', [0.0]),
|
45 |
+
"nan_prob_a_reason": CSH.CategoricalHyperparameter('nan_prob_a_reason', [0.0]),
|
46 |
+
# "num_classes": lambda : random.randint(2, 10), "balanced": False,
|
47 |
+
"max_num_classes": 2,
|
48 |
+
"num_classes": 2,
|
49 |
+
"noise_type": CSH.CategoricalHyperparameter('noise_type', ["Gaussian"]), # NN
|
50 |
+
"balanced": True,
|
51 |
+
"normalize_to_ranking": CSH.CategoricalHyperparameter('normalize_to_ranking', [False]),
|
52 |
+
"set_value_to_nan": CSH.CategoricalHyperparameter('set_value_to_nan', [0.5, 0.2, 0.0]),
|
53 |
+
"normalize_by_used_features": True,
|
54 |
+
"num_features_used":
|
55 |
+
{'uniform_int_sampler_f(3,max_features)': uniform_int_sampler_f(1, max_features)}
|
56 |
+
# hp.choice('conv_activation', [{'distribution': 'uniform', 'min': 2.0, 'max': 8.0}, None]),
|
57 |
+
}
|
58 |
+
return config_flexible_categorical
|
59 |
+
|
60 |
+
def get_diff_flex():
|
61 |
+
""""
|
62 |
+
Returns the configuration parameters for a differentiable wrapper around the tabular multiclass wrapper.
|
63 |
+
"""
|
64 |
+
diff_flex = {
|
65 |
+
# "ordinal_pct": {'distribution': 'uniform', 'min': 0.0, 'max': 0.5},
|
66 |
+
# "num_categorical_features_sampler_a": hp.choice('num_categorical_features_sampler_a',
|
67 |
+
# [{'distribution': 'uniform', 'min': 0.3, 'max': 0.9}, None]),
|
68 |
+
# "num_categorical_features_sampler_b": {'distribution': 'uniform', 'min': 0.3, 'max': 0.9},
|
69 |
+
"output_multiclass_ordered_p": {'distribution': 'uniform', 'min': 0.0, 'max': 0.5}, #CSH.CategoricalHyperparameter('output_multiclass_ordered_p', [0.0, 0.1, 0.2]),
|
70 |
+
"multiclass_type": {'distribution': 'meta_choice', 'choice_values': ['value', 'rank']},
|
71 |
+
}
|
72 |
+
|
73 |
+
return diff_flex
|
74 |
+
|
75 |
+
def get_diff_gp():
|
76 |
+
""""
|
77 |
+
Returns the configuration parameters for a differentiable wrapper around GP.
|
78 |
+
"""
|
79 |
+
diff_gp = {
|
80 |
+
'outputscale': {'distribution': 'meta_trunc_norm_log_scaled', 'max_mean': 10., 'min_mean': 0.00001, 'round': False,
|
81 |
+
'lower_bound': 0},
|
82 |
+
'lengthscale': {'distribution': 'meta_trunc_norm_log_scaled', 'max_mean': 10., 'min_mean': 0.00001, 'round': False,
|
83 |
+
'lower_bound': 0},
|
84 |
+
'noise': {'distribution': 'meta_choice', 'choice_values': [0.00001, 0.0001, 0.01]}
|
85 |
+
}
|
86 |
+
|
87 |
+
return diff_gp
|
88 |
+
|
89 |
+
def get_diff_causal():
|
90 |
+
""""
|
91 |
+
Returns the configuration parameters for a differentiable wrapper around MLP / Causal mixture.
|
92 |
+
"""
|
93 |
+
diff_causal = {
|
94 |
+
"num_layers": {'distribution': 'meta_trunc_norm_log_scaled', 'max_mean': 6, 'min_mean': 1, 'round': True,
|
95 |
+
'lower_bound': 2},
|
96 |
+
# Better beta?
|
97 |
+
"prior_mlp_hidden_dim": {'distribution': 'meta_trunc_norm_log_scaled', 'max_mean': 130, 'min_mean': 5,
|
98 |
+
'round': True, 'lower_bound': 4},
|
99 |
+
|
100 |
+
"prior_mlp_dropout_prob": {'distribution': 'meta_beta', 'scale': 0.9, 'min': 0.1, 'max': 5.0},
|
101 |
+
# This mustn't be too high since activations get too large otherwise
|
102 |
+
|
103 |
+
"noise_std": {'distribution': 'meta_trunc_norm_log_scaled', 'max_mean': .3, 'min_mean': 0.0001, 'round': False,
|
104 |
+
'lower_bound': 0.0},
|
105 |
+
"init_std": {'distribution': 'meta_trunc_norm_log_scaled', 'max_mean': 10.0, 'min_mean': 0.01, 'round': False,
|
106 |
+
'lower_bound': 0.0},
|
107 |
+
"num_causes": {'distribution': 'meta_trunc_norm_log_scaled', 'max_mean': 12, 'min_mean': 1, 'round': True,
|
108 |
+
'lower_bound': 1},
|
109 |
+
"is_causal": {'distribution': 'meta_choice', 'choice_values': [True, False]},
|
110 |
+
"pre_sample_weights": {'distribution': 'meta_choice', 'choice_values': [True, False]},
|
111 |
+
"y_is_effect": {'distribution': 'meta_choice', 'choice_values': [True, False]},
|
112 |
+
"prior_mlp_activations": {'distribution': 'meta_choice_mixed', 'choice_values': [
|
113 |
+
torch.nn.Tanh
|
114 |
+
, torch.nn.ReLU
|
115 |
+
, torch.nn.Identity
|
116 |
+
, lambda : torch.nn.LeakyReLU(negative_slope=0.1)
|
117 |
+
, torch.nn.ELU
|
118 |
+
]},
|
119 |
+
"block_wise_dropout": {'distribution': 'meta_choice', 'choice_values': [True, False]},
|
120 |
+
"sort_features": {'distribution': 'meta_choice', 'choice_values': [True, False]},
|
121 |
+
"in_clique": {'distribution': 'meta_choice', 'choice_values': [True, False]},
|
122 |
+
}
|
123 |
+
|
124 |
+
return diff_causal
|
125 |
+
|
126 |
+
def get_diff_prior_bag():
|
127 |
+
""""
|
128 |
+
Returns the configuration parameters for a GP and MLP / Causal mixture.
|
129 |
+
"""
|
130 |
+
diff_prior_bag = {
|
131 |
+
'prior_bag_exp_weights_1': {'distribution': 'uniform', 'min': 100000., 'max': 100001.},
|
132 |
+
# MLP Weight (Biased, since MLP works better, 1.0 is weight for prior number 0)
|
133 |
+
}
|
134 |
+
|
135 |
+
return diff_prior_bag
|
136 |
+
|
137 |
+
def get_diff_config():
|
138 |
+
""""
|
139 |
+
Returns the configuration parameters for a differentiable wrapper around GP and MLP / Causal mixture priors.
|
140 |
+
"""
|
141 |
+
diff_prior_bag = get_diff_prior_bag()
|
142 |
+
diff_causal = get_diff_causal()
|
143 |
+
diff_gp = get_diff_gp()
|
144 |
+
diff_flex = get_diff_flex()
|
145 |
+
|
146 |
+
config_diff = {'differentiable_hyperparameters': {**diff_prior_bag, **diff_causal, **diff_gp, **diff_flex}}
|
147 |
+
|
148 |
+
return config_diff
|
149 |
+
|
150 |
+
|
151 |
+
def sample_differentiable(config):
|
152 |
+
""""
|
153 |
+
Returns sampled hyperparameters from a differentiable wrapper, that is it makes a non-differentiable out of
|
154 |
+
differentiable.
|
155 |
+
"""
|
156 |
+
# config is a dict of dicts, dicts that have a 'distribution' key are treated as distributions to be sampled
|
157 |
+
result = deepcopy(config)
|
158 |
+
del result['differentiable_hyperparameters']
|
159 |
+
|
160 |
+
for k, v in config['differentiable_hyperparameters'].items():
|
161 |
+
s_indicator, s_hp = DifferentiableHyperparameter(**v, embedding_dim=None,
|
162 |
+
device=None)() # both of these are actually not used to the best of my knowledge
|
163 |
+
result[k] = s_hp
|
164 |
+
|
165 |
+
return result
|
166 |
+
|
167 |
+
def list_all_hps_in_nested(config):
|
168 |
+
""""
|
169 |
+
Returns a list of hyperparameters from a neszed dict of hyperparameters.
|
170 |
+
"""
|
171 |
+
|
172 |
+
if isinstance(config, CSH.Hyperparameter):
|
173 |
+
return [config]
|
174 |
+
elif isinstance(config, dict):
|
175 |
+
result = []
|
176 |
+
for k, v in config.items():
|
177 |
+
result += list_all_hps_in_nested(v)
|
178 |
+
return result
|
179 |
+
else:
|
180 |
+
return []
|
181 |
+
|
182 |
+
def create_configspace_from_hierarchical(config):
|
183 |
+
cs = CS.ConfigurationSpace()
|
184 |
+
for hp in list_all_hps_in_nested(config):
|
185 |
+
cs.add_hyperparameter(hp)
|
186 |
+
return cs
|
187 |
+
|
188 |
+
def fill_in_configsample(config, configsample):
|
189 |
+
# config is our dict that defines config distribution
|
190 |
+
# configsample is a CS.Configuration
|
191 |
+
hierarchical_configsample = deepcopy(config)
|
192 |
+
for k, v in config.items():
|
193 |
+
if isinstance(v, CSH.Hyperparameter):
|
194 |
+
hierarchical_configsample[k] = configsample[v.name]
|
195 |
+
elif isinstance(v, dict):
|
196 |
+
hierarchical_configsample[k] = fill_in_configsample(v, configsample)
|
197 |
+
return hierarchical_configsample
|
198 |
+
|
199 |
+
|
200 |
+
def evaluate_hypers(config, sample_diff_hps=False):
|
201 |
+
""""
|
202 |
+
Samples a hyperparameter configuration from a sampleable configuration (can be used in HP search).
|
203 |
+
"""
|
204 |
+
if sample_diff_hps:
|
205 |
+
# I do a deepcopy here, such that the config stays the same and can still be used with diff. hps
|
206 |
+
config = deepcopy(config)
|
207 |
+
replace_differentiable_distributions(config)
|
208 |
+
cs = create_configspace_from_hierarchical(config)
|
209 |
+
cs_sample = cs.sample_configuration()
|
210 |
+
return fill_in_configsample(config, cs_sample)
|
TabPFN/scripts/tabular_baselines.py
ADDED
@@ -0,0 +1,421 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from catboost import CatBoostClassifier, Pool
|
2 |
+
|
3 |
+
import math
|
4 |
+
|
5 |
+
from sklearn.impute import SimpleImputer
|
6 |
+
|
7 |
+
import xgboost as xgb
|
8 |
+
from sklearn import neighbors
|
9 |
+
from sklearn.gaussian_process import GaussianProcessClassifier
|
10 |
+
from sklearn.gaussian_process.kernels import RBF
|
11 |
+
import numpy as np
|
12 |
+
|
13 |
+
from scripts import tabular_metrics
|
14 |
+
import pandas as pd
|
15 |
+
|
16 |
+
from sklearn.linear_model import LogisticRegression
|
17 |
+
from sklearn.model_selection import cross_val_score
|
18 |
+
import time
|
19 |
+
|
20 |
+
from hyperopt import fmin, tpe, hp, STATUS_OK, Trials , space_eval, rand
|
21 |
+
from sklearn.compose import ColumnTransformer
|
22 |
+
from sklearn.preprocessing import OneHotEncoder
|
23 |
+
from sklearn.preprocessing import MinMaxScaler
|
24 |
+
|
25 |
+
import autosklearn.classification
|
26 |
+
|
27 |
+
CV = 5
|
28 |
+
MULTITHREAD = 1 # Number of threads baselines are able to use at most
|
29 |
+
param_grid, param_grid_hyperopt = {}, {}
|
30 |
+
|
31 |
+
def get_scoring_direction(metric_used):
|
32 |
+
# Not needed
|
33 |
+
if metric_used == tabular_metrics.auc_metric:
|
34 |
+
return -1
|
35 |
+
elif metric_used == tabular_metrics.cross_entropy:
|
36 |
+
return 1
|
37 |
+
else:
|
38 |
+
raise Exception('No scoring string found for metric')
|
39 |
+
|
40 |
+
def get_scoring_string(metric_used, multiclass=True, usage="sklearn_cv"):
|
41 |
+
if metric_used == tabular_metrics.auc_metric:
|
42 |
+
if usage == 'sklearn_cv':
|
43 |
+
return 'roc_auc_ovo'
|
44 |
+
elif usage == 'autogluon':
|
45 |
+
return 'log_loss' # Autogluon crashes when using 'roc_auc' with some datasets usning logloss gives better scores;
|
46 |
+
# We might be able to fix this, but doesn't work out of box.
|
47 |
+
# File bug report? Error happens with dataset robert and fabert
|
48 |
+
if multiclass:
|
49 |
+
return 'roc_auc_ovo_macro'
|
50 |
+
else:
|
51 |
+
return 'roc_auc'
|
52 |
+
elif usage == 'autosklearn':
|
53 |
+
if multiclass:
|
54 |
+
return autosklearn.metrics.log_loss # roc_auc only works for binary, use logloss instead
|
55 |
+
else:
|
56 |
+
return autosklearn.metrics.roc_auc
|
57 |
+
elif usage == 'catboost':
|
58 |
+
return 'MultiClass' # Effectively LogLoss, ROC not available
|
59 |
+
elif usage == 'xgb':
|
60 |
+
return 'logloss'
|
61 |
+
return 'roc_auc'
|
62 |
+
elif metric_used == tabular_metrics.cross_entropy:
|
63 |
+
if usage == 'sklearn_cv':
|
64 |
+
return 'neg_log_loss'
|
65 |
+
elif usage == 'autogluon':
|
66 |
+
return 'log_loss'
|
67 |
+
elif usage == 'autosklearn':
|
68 |
+
return autosklearn.metrics.log_loss
|
69 |
+
elif usage == 'catboost':
|
70 |
+
return 'MultiClass' # Effectively LogLoss
|
71 |
+
return 'logloss'
|
72 |
+
else:
|
73 |
+
raise Exception('No scoring string found for metric')
|
74 |
+
|
75 |
+
def eval_f(params, clf_, x, y, metric_used, start_time, max_time):
|
76 |
+
if time.time() - start_time > max_time:
|
77 |
+
return np.nan
|
78 |
+
scores = cross_val_score(clf_(**params), x, y, cv=CV, scoring=get_scoring_string(metric_used))
|
79 |
+
|
80 |
+
return -np.nanmean(scores)
|
81 |
+
|
82 |
+
def preprocess_impute(x, y, test_x, test_y, impute, one_hot, standardize, cat_features=[]):
|
83 |
+
import warnings
|
84 |
+
def warn(*args, **kwargs):
|
85 |
+
pass
|
86 |
+
|
87 |
+
warnings.warn = warn
|
88 |
+
|
89 |
+
x, y, test_x, test_y = x.cpu().numpy(), y.cpu().long().numpy(), test_x.cpu().numpy(), test_y.cpu().long().numpy()
|
90 |
+
|
91 |
+
if impute:
|
92 |
+
imp_mean = SimpleImputer(missing_values=np.nan, strategy='mean')
|
93 |
+
imp_mean.fit(x)
|
94 |
+
x, test_x = imp_mean.transform(x), imp_mean.transform(test_x)
|
95 |
+
|
96 |
+
if one_hot:
|
97 |
+
def make_pd_from_np(x):
|
98 |
+
data = pd.DataFrame(x)
|
99 |
+
for c in cat_features:
|
100 |
+
data.iloc[:, c] = data.iloc[:, c].astype('int')
|
101 |
+
return data
|
102 |
+
x, test_x = make_pd_from_np(x), make_pd_from_np(test_x)
|
103 |
+
transformer = ColumnTransformer(transformers=[('cat', OneHotEncoder(handle_unknown='ignore', sparse=False), cat_features)], remainder="passthrough")
|
104 |
+
transformer.fit(x)
|
105 |
+
x, test_x = transformer.transform(x), transformer.transform(test_x)
|
106 |
+
|
107 |
+
if standardize:
|
108 |
+
scaler = MinMaxScaler()
|
109 |
+
scaler.fit(x)
|
110 |
+
x, test_x = scaler.transform(x), scaler.transform(test_x)
|
111 |
+
|
112 |
+
return x, y, test_x, test_y
|
113 |
+
|
114 |
+
## Auto Gluon
|
115 |
+
def autogluon_metric(x, y, test_x, test_y, cat_features, metric_used, max_time=300):
|
116 |
+
from autogluon.tabular import TabularPredictor # Inside function so package can be sued without installation
|
117 |
+
x, y, test_x, test_y = preprocess_impute(x, y, test_x, test_y
|
118 |
+
, one_hot=False
|
119 |
+
, cat_features=cat_features
|
120 |
+
, impute=False
|
121 |
+
, standardize=False)
|
122 |
+
train_data = pd.DataFrame(np.concatenate([x, y[:, np.newaxis]], 1))
|
123 |
+
test_data = pd.DataFrame(np.concatenate([test_x, test_y[:, np.newaxis]], 1))
|
124 |
+
|
125 |
+
# AutoGluon automatically infers datatypes, we don't specify the categorical labels
|
126 |
+
predictor = TabularPredictor(
|
127 |
+
label=train_data.columns[-1],
|
128 |
+
eval_metric=get_scoring_string(metric_used, usage='autogluon', multiclass=(len(np.unique(y)) > 2)),
|
129 |
+
problem_type='multiclass' if len(np.unique(y)) > 2 else 'binary'
|
130 |
+
## seed=int(y[:].sum()) doesn't accept seed
|
131 |
+
).fit(
|
132 |
+
train_data=train_data,
|
133 |
+
time_limit=max_time,
|
134 |
+
presets=['best_quality']
|
135 |
+
# The seed is deterministic but varies for each dataset and each split of it
|
136 |
+
)
|
137 |
+
|
138 |
+
pred = predictor.predict_proba(test_data, as_multiclass=True).values
|
139 |
+
|
140 |
+
metric = metric_used(test_y, pred)
|
141 |
+
|
142 |
+
return metric, pred, predictor.fit_summary()
|
143 |
+
|
144 |
+
## AUTO Sklearn
|
145 |
+
def autosklearn_metric(x, y, test_x, test_y, cat_features, metric_used, max_time=300):
|
146 |
+
return autosklearn2_metric(x, y, test_x, test_y, cat_features, metric_used, max_time=max_time, version=1)
|
147 |
+
|
148 |
+
from autosklearn.experimental.askl2 import AutoSklearn2Classifier
|
149 |
+
from autosklearn.classification import AutoSklearnClassifier
|
150 |
+
def autosklearn2_metric(x, y, test_x, test_y, cat_features, metric_used, max_time=300, version=2):
|
151 |
+
x, y, test_x, test_y = preprocess_impute(x, y, test_x, test_y
|
152 |
+
, one_hot=False
|
153 |
+
, cat_features=cat_features
|
154 |
+
, impute=False
|
155 |
+
, standardize=False)
|
156 |
+
|
157 |
+
def make_pd_from_np(x):
|
158 |
+
data = pd.DataFrame(x)
|
159 |
+
for c in cat_features:
|
160 |
+
data.iloc[:, c] = data.iloc[:, c].astype('category')
|
161 |
+
return data
|
162 |
+
|
163 |
+
x = make_pd_from_np(x)
|
164 |
+
test_x = make_pd_from_np(test_x)
|
165 |
+
|
166 |
+
clf_ = AutoSklearn2Classifier if version == 2 else AutoSklearnClassifier
|
167 |
+
clf = clf_(time_left_for_this_task=max_time,
|
168 |
+
memory_limit=4000,
|
169 |
+
n_jobs=MULTITHREAD,
|
170 |
+
seed=int(y[:].sum()),
|
171 |
+
# The seed is deterministic but varies for each dataset and each split of it
|
172 |
+
metric=get_scoring_string(metric_used, usage='autosklearn', multiclass=len(np.unique(y)) > 2))
|
173 |
+
|
174 |
+
# fit model to data
|
175 |
+
clf.fit(x, y)
|
176 |
+
|
177 |
+
pred = clf.predict_proba(test_x)
|
178 |
+
metric = metric_used(test_y, pred)
|
179 |
+
|
180 |
+
return metric, pred, None
|
181 |
+
|
182 |
+
param_grid_hyperopt['logistic'] = {
|
183 |
+
'penalty': hp.choice('penalty', ['l1', 'l2', 'none'])
|
184 |
+
, 'max_iter': hp.randint('max_iter', [50, 500])
|
185 |
+
, 'fit_intercept': hp.choice('fit_intercept', [True, False])
|
186 |
+
, 'C': hp.loguniform('C', -5, math.log(5.0))} # 'normalize': [False],
|
187 |
+
|
188 |
+
def logistic_metric(x, y, test_x, test_y, cat_features, metric_used, max_time=300):
|
189 |
+
x, y, test_x, test_y = preprocess_impute(x, y, test_x, test_y
|
190 |
+
, one_hot=True, impute=True, standardize=True
|
191 |
+
, cat_features=cat_features)
|
192 |
+
|
193 |
+
def clf_(**params):
|
194 |
+
return LogisticRegression(solver='saga', tol=1e-4, n_jobs=1, **params)
|
195 |
+
|
196 |
+
start_time = time.time()
|
197 |
+
|
198 |
+
def stop(trial):
|
199 |
+
return time.time() - start_time > max_time, []
|
200 |
+
|
201 |
+
best = fmin(
|
202 |
+
fn=lambda params: eval_f(params, clf_, x, y, metric_used, start_time, max_time),
|
203 |
+
space=param_grid_hyperopt['logistic'],
|
204 |
+
algo=rand.suggest,
|
205 |
+
rstate=np.random.RandomState(int(y[:].sum())),
|
206 |
+
early_stop_fn=stop,
|
207 |
+
# The seed is deterministic but varies for each dataset and each split of it
|
208 |
+
max_evals=10000)
|
209 |
+
best = space_eval(param_grid_hyperopt['logistic'], best)
|
210 |
+
|
211 |
+
clf = clf_(**best)
|
212 |
+
clf.fit(x, y)
|
213 |
+
|
214 |
+
pred = clf.predict_proba(test_x)
|
215 |
+
metric = metric_used(test_y, pred)
|
216 |
+
|
217 |
+
return metric, pred, best
|
218 |
+
|
219 |
+
## KNN
|
220 |
+
param_grid_hyperopt['knn'] = {'n_neighbors': hp.randint('n_neighbors', 1,16)
|
221 |
+
}
|
222 |
+
def knn_metric(x, y, test_x, test_y, cat_features, metric_used, max_time=300):
|
223 |
+
x, y, test_x, test_y = preprocess_impute(x, y, test_x, test_y,
|
224 |
+
one_hot=True, impute=True, standardize=True,
|
225 |
+
cat_features=cat_features)
|
226 |
+
|
227 |
+
def clf_(**params):
|
228 |
+
return neighbors.KNeighborsClassifier(n_jobs=1, **params)
|
229 |
+
|
230 |
+
start_time = time.time()
|
231 |
+
|
232 |
+
def stop(trial):
|
233 |
+
return time.time() - start_time > max_time, []
|
234 |
+
|
235 |
+
best = fmin(
|
236 |
+
fn=lambda params: eval_f(params, clf_, x, y, metric_used, start_time, max_time),
|
237 |
+
space=param_grid_hyperopt['knn'],
|
238 |
+
algo=rand.suggest,
|
239 |
+
rstate=np.random.RandomState(int(y[:].sum())),
|
240 |
+
early_stop_fn=stop,
|
241 |
+
# The seed is deterministic but varies for each dataset and each split of it
|
242 |
+
max_evals=10000)
|
243 |
+
best = space_eval(param_grid_hyperopt['knn'], best)
|
244 |
+
|
245 |
+
clf = clf_(**best)
|
246 |
+
clf.fit(x, y)
|
247 |
+
|
248 |
+
pred = clf.predict_proba(test_x)
|
249 |
+
metric = metric_used(test_y, pred)
|
250 |
+
|
251 |
+
return metric, pred, best
|
252 |
+
|
253 |
+
## GP
|
254 |
+
param_grid_hyperopt['gp'] = {
|
255 |
+
'params_y_scale': hp.loguniform('params_y_scale', math.log(0.05), math.log(5.0)),
|
256 |
+
'params_length_scale': hp.loguniform('params_length_scale', math.log(0.1), math.log(1.0)),
|
257 |
+
'n_jobs': hp.choice('njobs', [1])
|
258 |
+
}
|
259 |
+
def gp_metric(x, y, test_x, test_y, cat_features, metric_used, max_time=300):
|
260 |
+
x, y, test_x, test_y = preprocess_impute(x, y, test_x, test_y,
|
261 |
+
one_hot=True, impute=True, standardize=True,
|
262 |
+
cat_features=cat_features)
|
263 |
+
|
264 |
+
def clf_(params_y_scale,params_length_scale, **params):
|
265 |
+
return GaussianProcessClassifier(kernel= params_y_scale * RBF(params_length_scale), **params)
|
266 |
+
|
267 |
+
start_time = time.time()
|
268 |
+
def stop(trial):
|
269 |
+
return time.time() - start_time > max_time, []
|
270 |
+
|
271 |
+
|
272 |
+
best = fmin(
|
273 |
+
fn=lambda params: eval_f(params, clf_, x, y, metric_used, start_time, max_time),
|
274 |
+
space=param_grid_hyperopt['gp'],
|
275 |
+
algo=rand.suggest,
|
276 |
+
rstate=np.random.RandomState(int(y[:].sum())),
|
277 |
+
early_stop_fn=stop,
|
278 |
+
# The seed is deterministic but varies for each dataset and each split of it
|
279 |
+
max_evals=1000)
|
280 |
+
best = space_eval(param_grid_hyperopt['gp'], best)
|
281 |
+
|
282 |
+
clf = clf_(**best)
|
283 |
+
clf.fit(x, y)
|
284 |
+
|
285 |
+
pred = clf.predict_proba(test_x)
|
286 |
+
metric = metric_used(test_y, pred)
|
287 |
+
|
288 |
+
return metric, pred, best
|
289 |
+
|
290 |
+
|
291 |
+
# Catboost
|
292 |
+
# Hyperparameter space: https://arxiv.org/pdf/2106.03253.pdf
|
293 |
+
|
294 |
+
param_grid_hyperopt['catboost'] = {
|
295 |
+
'learning_rate': hp.loguniform('learning_rate', math.log(math.pow(math.e, -5)), math.log(1)),
|
296 |
+
'random_strength': hp.randint('random_strength', 1, 20),
|
297 |
+
'l2_leaf_reg': hp.loguniform('l2_leaf_reg', math.log(1), math.log(10)),
|
298 |
+
'bagging_temperature': hp.uniform('bagging_temperature', 0., 1),
|
299 |
+
'leaf_estimation_iterations': hp.randint('leaf_estimation_iterations', 1, 20),
|
300 |
+
'iterations': hp.randint('iterations', 100, 4000), # This is smaller than in paper, 4000 leads to ram overusage
|
301 |
+
}
|
302 |
+
|
303 |
+
def catboost_metric(x, y, test_x, test_y, cat_features, metric_used, max_time=300):
|
304 |
+
print(x)
|
305 |
+
|
306 |
+
x, y, test_x, test_y = preprocess_impute(x, y, test_x, test_y
|
307 |
+
, one_hot=False
|
308 |
+
, cat_features=cat_features
|
309 |
+
, impute=False
|
310 |
+
, standardize=False)
|
311 |
+
|
312 |
+
# Nans in categorical features must be encoded as separate class
|
313 |
+
x[:, cat_features], test_x[:, cat_features] = np.nan_to_num(x[:, cat_features], -1), np.nan_to_num(
|
314 |
+
test_x[:, cat_features], -1)
|
315 |
+
|
316 |
+
def make_pd_from_np(x):
|
317 |
+
data = pd.DataFrame(x)
|
318 |
+
for c in cat_features:
|
319 |
+
data.iloc[:, c] = data.iloc[:, c].astype('int')
|
320 |
+
return data
|
321 |
+
|
322 |
+
x = make_pd_from_np(x)
|
323 |
+
test_x = make_pd_from_np(test_x)
|
324 |
+
|
325 |
+
def clf_(**params):
|
326 |
+
return CatBoostClassifier(
|
327 |
+
loss_function=get_scoring_string(metric_used, usage='catboost'),
|
328 |
+
thread_count = MULTITHREAD,
|
329 |
+
used_ram_limit='4gb',
|
330 |
+
random_seed=int(y[:].sum()),
|
331 |
+
logging_level='Silent',
|
332 |
+
cat_features=cat_features,
|
333 |
+
**params)
|
334 |
+
|
335 |
+
start_time = time.time()
|
336 |
+
def stop(trial):
|
337 |
+
return time.time() - start_time > max_time, []
|
338 |
+
|
339 |
+
best = fmin(
|
340 |
+
fn=lambda params: eval_f(params, clf_, x, y, metric_used, start_time, max_time),
|
341 |
+
space=param_grid_hyperopt['catboost'],
|
342 |
+
algo=rand.suggest,
|
343 |
+
rstate=np.random.RandomState(int(y[:].sum())),
|
344 |
+
early_stop_fn=stop,
|
345 |
+
# The seed is deterministic but varies for each dataset and each split of it
|
346 |
+
max_evals=1000)
|
347 |
+
best = space_eval(param_grid_hyperopt['catboost'], best)
|
348 |
+
|
349 |
+
clf = clf_(**best)
|
350 |
+
clf.fit(x, y)
|
351 |
+
|
352 |
+
pred = clf.predict_proba(test_x)
|
353 |
+
metric = metric_used(test_y, pred)
|
354 |
+
|
355 |
+
return metric, pred, best
|
356 |
+
|
357 |
+
|
358 |
+
# XGBoost
|
359 |
+
# Hyperparameter space: https://arxiv.org/pdf/2106.03253.pdf
|
360 |
+
param_grid_hyperopt['xgb'] = {
|
361 |
+
'learning_rate': hp.loguniform('learning_rate', -7, math.log(1)),
|
362 |
+
'max_depth': hp.randint('max_depth', 1, 10),
|
363 |
+
'subsample': hp.uniform('subsample', 0.2, 1),
|
364 |
+
'colsample_bytree': hp.uniform('colsample_bytree', 0.2, 1),
|
365 |
+
'colsample_bylevel': hp.uniform('colsample_bylevel', 0.2, 1),
|
366 |
+
'min_child_weight': hp.loguniform('min_child_weight', -16, 5),
|
367 |
+
'alpha': hp.loguniform('alpha', -16, 2),
|
368 |
+
'lambda': hp.loguniform('lambda', -16, 2),
|
369 |
+
'gamma': hp.loguniform('gamma', -16, 2),
|
370 |
+
'n_estimators': hp.randint('n_estimators', 100, 4000), # This is smaller than in paper
|
371 |
+
}
|
372 |
+
|
373 |
+
def xgb_metric(x, y, test_x, test_y, cat_features, metric_used, max_time=300):
|
374 |
+
# XGB Documentation:
|
375 |
+
# XGB handles categorical data appropriately without using One Hot Encoding, categorical features are experimetal
|
376 |
+
# XGB handles missing values appropriately without imputation
|
377 |
+
|
378 |
+
x, y, test_x, test_y = preprocess_impute(x, y, test_x, test_y
|
379 |
+
, one_hot=False
|
380 |
+
, cat_features=cat_features
|
381 |
+
, impute=False
|
382 |
+
, standardize=False)
|
383 |
+
|
384 |
+
def clf_(**params):
|
385 |
+
return xgb.XGBClassifier(use_label_encoder=False
|
386 |
+
, nthread=1
|
387 |
+
, **params
|
388 |
+
, eval_metric=get_scoring_string(metric_used, usage='xgb') # AUC not implemented
|
389 |
+
)
|
390 |
+
|
391 |
+
start_time = time.time()
|
392 |
+
def stop(trial):
|
393 |
+
return time.time() - start_time > max_time, []
|
394 |
+
|
395 |
+
best = fmin(
|
396 |
+
fn=lambda params: eval_f(params, clf_, x, y, metric_used, start_time, max_time),
|
397 |
+
space=param_grid_hyperopt['xgb'],
|
398 |
+
algo=rand.suggest,
|
399 |
+
rstate=np.random.RandomState(int(y[:].sum())),
|
400 |
+
early_stop_fn=stop,
|
401 |
+
# The seed is deterministic but varies for each dataset and each split of it
|
402 |
+
max_evals=1000)
|
403 |
+
best = space_eval(param_grid_hyperopt['xgb'], best)
|
404 |
+
|
405 |
+
clf = clf_(**best)
|
406 |
+
clf.fit(x, y)
|
407 |
+
|
408 |
+
pred = clf.predict_proba(test_x)
|
409 |
+
metric = metric_used(test_y, pred)
|
410 |
+
|
411 |
+
return metric, pred, best
|
412 |
+
|
413 |
+
|
414 |
+
clf_dict = {'gp': gp_metric
|
415 |
+
, 'knn': knn_metric
|
416 |
+
, 'catboost': catboost_metric
|
417 |
+
, 'xgb': xgb_metric
|
418 |
+
, 'logistic': logistic_metric
|
419 |
+
, 'autosklearn': autosklearn_metric
|
420 |
+
, 'autosklearn2': autosklearn2_metric
|
421 |
+
, 'autogluon': autogluon_metric}
|
TabPFN/scripts/tabular_evaluation.py
ADDED
@@ -0,0 +1,284 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import time
|
2 |
+
import os
|
3 |
+
from pathlib import Path
|
4 |
+
|
5 |
+
from tqdm import tqdm
|
6 |
+
import random
|
7 |
+
import numpy as np
|
8 |
+
|
9 |
+
from torch import nn
|
10 |
+
|
11 |
+
from utils import torch_nanmean
|
12 |
+
from datasets import *
|
13 |
+
from model_builder import load_model
|
14 |
+
from scripts.tabular_baselines import get_scoring_string
|
15 |
+
from scripts import tabular_metrics
|
16 |
+
from scripts.transformer_prediction_interface import *
|
17 |
+
from scripts.baseline_prediction_interface import *
|
18 |
+
"""
|
19 |
+
===============================
|
20 |
+
PUBLIC FUNCTIONS FOR EVALUATION
|
21 |
+
===============================
|
22 |
+
"""
|
23 |
+
|
24 |
+
|
25 |
+
def eval_model(i, e, valid_datasets, test_datasets, eval_positions, bptt, add_name, base_path, device='cpu', eval_addition='', **kwargs):
|
26 |
+
metrics_test, config_sample, model_path = eval_model_on_ds(i, e, test_datasets, eval_positions, bptt, add_name, base_path, device=device, eval_addition=eval_addition, **kwargs)
|
27 |
+
metrics_valid, _, _ = eval_model_on_ds(i, e, valid_datasets, eval_positions, bptt, add_name, base_path, device=device, eval_addition=eval_addition, **kwargs)
|
28 |
+
return {'mean_auc_test': metrics_test['mean_roc_at_1000'], 'mean_auc_valid': metrics_valid['mean_roc_at_1000'], 'mean_ce_test': metrics_test['mean_ce_at_1000'], 'mean_ce_valid': metrics_valid['mean_ce_at_1000'], 'config_sample': config_sample, 'model_path': model_path}
|
29 |
+
|
30 |
+
def eval_model_on_ds(i, e, valid_datasets, eval_positions, bptt, add_name, base_path, device='cpu', eval_addition='', **kwargs):
|
31 |
+
|
32 |
+
# How to use: evaluate_without_fitting(i,0,valid_datasets, [1024], 100000, add_name=model_string, base_path=base_path,)
|
33 |
+
def check_file(e):
|
34 |
+
model_file = f'models_diff/prior_diff_real_checkpoint{add_name}_n_{i}_epoch_{e}.cpkt'
|
35 |
+
model_path = os.path.join(base_path, model_file)
|
36 |
+
# print('Evaluate ', model_path)
|
37 |
+
results_file = os.path.join(base_path,
|
38 |
+
f'models_diff/prior_diff_real_results{add_name}_n_{i}_epoch_{e}_{eval_addition}.pkl')
|
39 |
+
if not Path(model_path).is_file(): # or Path(results_file).is_file():
|
40 |
+
# print('checkpoint exists: ', Path(model_file).is_file(), ', results are written:', Path(results_file).is_file())
|
41 |
+
return None, None, None
|
42 |
+
return model_file, model_path, results_file
|
43 |
+
|
44 |
+
if e == -1: # use last checkpoint, if e == -1
|
45 |
+
for e_ in range(100, -1, -1):
|
46 |
+
model_file_, model_path_, results_file_ = check_file(e_)
|
47 |
+
if model_file_ is not None:
|
48 |
+
e = e_
|
49 |
+
model_file, model_path, results_file = model_file_, model_path_, results_file_
|
50 |
+
break
|
51 |
+
else:
|
52 |
+
model_file, model_path, results_file = check_file(e)
|
53 |
+
|
54 |
+
model, config_sample = load_model(base_path, model_file, device, None, verbose=False)
|
55 |
+
print(model[2].style_encoder)
|
56 |
+
|
57 |
+
params = {'max_features': config_sample['num_features']
|
58 |
+
, 'rescale_features': config_sample["normalize_by_used_features"]
|
59 |
+
, 'normalize_to_ranking': config_sample["normalize_to_ranking"]
|
60 |
+
, 'normalize_with_sqrt': config_sample.get("normalize_with_sqrt", False)
|
61 |
+
}
|
62 |
+
metrics_valid = evaluate(datasets=valid_datasets, model=model[2], method='transformer', device=device, overwrite=True,
|
63 |
+
extend_features=True
|
64 |
+
# just removed the style keyword but transformer is trained with style, just empty
|
65 |
+
, save=False
|
66 |
+
, metric_used=tabular_metrics.cross_entropy
|
67 |
+
, return_tensor=True
|
68 |
+
, verbose=False
|
69 |
+
, eval_positions=eval_positions
|
70 |
+
, bptt=bptt
|
71 |
+
, base_path=None
|
72 |
+
, inference_mode=True
|
73 |
+
, **params
|
74 |
+
, **kwargs)
|
75 |
+
|
76 |
+
tabular_metrics.calculate_score_per_method(tabular_metrics.auc_metric, 'roc', metrics_valid, valid_datasets, eval_positions)
|
77 |
+
tabular_metrics.calculate_score_per_method(tabular_metrics.cross_entropy, 'ce', metrics_valid, valid_datasets, eval_positions)
|
78 |
+
|
79 |
+
return metrics_valid, config_sample, model_path
|
80 |
+
|
81 |
+
|
82 |
+
def evaluate(datasets, bptt, eval_positions, metric_used, model
|
83 |
+
, verbose=False
|
84 |
+
, return_tensor=False
|
85 |
+
, **kwargs):
|
86 |
+
"""
|
87 |
+
Evaluates a list of datasets for a model function.
|
88 |
+
|
89 |
+
:param datasets: List of datasets
|
90 |
+
:param bptt: maximum sequence length
|
91 |
+
:param eval_positions: List of positions where to evaluate models
|
92 |
+
:param verbose: If True, is verbose.
|
93 |
+
:param metric_used: Which metric is optimized for.
|
94 |
+
:param return_tensor: Wheater to return results as a pytorch.tensor or numpy, this is only relevant for transformer.
|
95 |
+
:param kwargs:
|
96 |
+
:return:
|
97 |
+
"""
|
98 |
+
overall_result = {'metric_used': get_scoring_string(metric_used)
|
99 |
+
, 'bptt': bptt
|
100 |
+
, 'eval_positions': eval_positions}
|
101 |
+
|
102 |
+
aggregated_metric_datasets, num_datasets = torch.tensor(0.0), 0
|
103 |
+
|
104 |
+
# For each dataset
|
105 |
+
for [ds_name, X, y, categorical_feats, _, _] in tqdm.tqdm(datasets, desc='Iterate over datasets') if verbose else datasets:
|
106 |
+
dataset_bptt = min(len(X), bptt)
|
107 |
+
# if verbose and dataset_bptt < bptt:
|
108 |
+
# print(f'Dataset too small for given sequence length, reducing to {len(X)} ({bptt})')
|
109 |
+
|
110 |
+
aggregated_metric, num = torch.tensor(0.0), 0
|
111 |
+
ds_result = {}
|
112 |
+
|
113 |
+
for eval_position in (eval_positions if verbose else eval_positions):
|
114 |
+
eval_position_real = int(dataset_bptt * 0.5) if 2 * eval_position > dataset_bptt else eval_position
|
115 |
+
eval_position_bptt = int(eval_position_real * 2.0)
|
116 |
+
|
117 |
+
r = evaluate_position(X, y, model=model
|
118 |
+
, num_classes=len(torch.unique(y))
|
119 |
+
, categorical_feats = categorical_feats
|
120 |
+
, bptt = eval_position_bptt
|
121 |
+
, ds_name=ds_name
|
122 |
+
, eval_position = eval_position_real
|
123 |
+
, metric_used = metric_used
|
124 |
+
,**kwargs)
|
125 |
+
|
126 |
+
if r is None:
|
127 |
+
continue
|
128 |
+
|
129 |
+
_, outputs, ys, best_configs, time_used = r
|
130 |
+
|
131 |
+
if torch.is_tensor(outputs):
|
132 |
+
outputs = outputs.to(outputs.device)
|
133 |
+
ys = ys.to(outputs.device)
|
134 |
+
|
135 |
+
ys = ys.T
|
136 |
+
ds_result[f'{ds_name}_best_configs_at_{eval_position}'] = best_configs
|
137 |
+
ds_result[f'{ds_name}_outputs_at_{eval_position}'] = outputs
|
138 |
+
ds_result[f'{ds_name}_ys_at_{eval_position}'] = ys
|
139 |
+
ds_result[f'{ds_name}_time_at_{eval_position}'] = time_used
|
140 |
+
|
141 |
+
new_metric = torch_nanmean(torch.stack([metric_used(ys[i], outputs[i]) for i in range(ys.shape[0])]))
|
142 |
+
|
143 |
+
if not return_tensor:
|
144 |
+
make_scalar = lambda x: float(x.detach().cpu().numpy()) if (torch.is_tensor(x) and (len(x.shape) == 0)) else x
|
145 |
+
new_metric = make_scalar(new_metric)
|
146 |
+
ds_result = {k: make_scalar(ds_result[k]) for k in ds_result.keys()}
|
147 |
+
|
148 |
+
lib = torch if return_tensor else np
|
149 |
+
if not lib.isnan(new_metric).any():
|
150 |
+
aggregated_metric, num = aggregated_metric + new_metric, num + 1
|
151 |
+
|
152 |
+
overall_result.update(ds_result)
|
153 |
+
if num > 0:
|
154 |
+
aggregated_metric_datasets, num_datasets = (aggregated_metric_datasets + (aggregated_metric / num)), num_datasets + 1
|
155 |
+
|
156 |
+
overall_result['mean_metric'] = aggregated_metric_datasets / num_datasets
|
157 |
+
|
158 |
+
return overall_result
|
159 |
+
|
160 |
+
"""
|
161 |
+
===============================
|
162 |
+
INTERNAL HELPER FUNCTIONS
|
163 |
+
===============================
|
164 |
+
"""
|
165 |
+
|
166 |
+
def check_file_exists(path):
|
167 |
+
"""Checks if a pickle file exists. Returns None if not, else returns the unpickled file."""
|
168 |
+
if (os.path.isfile(path)):
|
169 |
+
print(f'loading results from {path}')
|
170 |
+
with open(path, 'rb') as f:
|
171 |
+
return np.load(f, allow_pickle=True).tolist()
|
172 |
+
return None
|
173 |
+
|
174 |
+
def generate_valid_split(X, y, bptt, eval_position, split_number=1):
|
175 |
+
"""Generates a deteministic train-(test/valid) split. Both splits must contain the same classes and all classes in
|
176 |
+
the entire datasets. If no such split can be sampled in 7 passes, returns None.
|
177 |
+
|
178 |
+
:param X: torch tensor, feature values
|
179 |
+
:param y: torch tensor, class values
|
180 |
+
:param bptt: Number of samples in train + test
|
181 |
+
:param eval_position: Number of samples in train, i.e. from which index values are in test
|
182 |
+
:param split_number: The split id
|
183 |
+
:return:
|
184 |
+
"""
|
185 |
+
done, seed = False, 13
|
186 |
+
|
187 |
+
torch.manual_seed(split_number)
|
188 |
+
perm = torch.randperm(X.shape[0]) if split_number > 1 else torch.arange(0, X.shape[0])
|
189 |
+
X, y = X[perm], y[perm]
|
190 |
+
|
191 |
+
while not done:
|
192 |
+
if seed > 20:
|
193 |
+
return None, None # No split could be generated in 7 passes, return None
|
194 |
+
random.seed(seed)
|
195 |
+
i = random.randint(0, len(X) - bptt) if len(X) - bptt > 0 else 0
|
196 |
+
y_ = y[i:i + bptt]
|
197 |
+
|
198 |
+
# Checks if all classes from dataset are contained and classes in train and test are equal (contain same
|
199 |
+
# classes) and
|
200 |
+
done = len(torch.unique(y_)) == len(torch.unique(y))
|
201 |
+
done = done and torch.all(torch.unique(y_) == torch.unique(y))
|
202 |
+
done = done and len(torch.unique(y_[:eval_position])) == len(torch.unique(y_[eval_position:]))
|
203 |
+
done = done and torch.all(torch.unique(y_[:eval_position]) == torch.unique(y_[eval_position:]))
|
204 |
+
seed = seed + 1
|
205 |
+
|
206 |
+
eval_xs = torch.stack([X[i:i + bptt].clone()], 1)
|
207 |
+
eval_ys = torch.stack([y[i:i + bptt].clone()], 1)
|
208 |
+
|
209 |
+
return eval_xs, eval_ys
|
210 |
+
|
211 |
+
|
212 |
+
def evaluate_position(X, y, categorical_feats, model, bptt
|
213 |
+
, eval_position, overwrite, save, base_path, path_interfix, method, ds_name, fetch_only=False
|
214 |
+
, max_time=300, split_number=1
|
215 |
+
, per_step_normalization=False, **kwargs):
|
216 |
+
"""
|
217 |
+
Evaluates a dataset with a 'bptt' number of training samples.
|
218 |
+
|
219 |
+
:param X: Dataset X
|
220 |
+
:param y: Dataset labels
|
221 |
+
:param categorical_feats: Indices of categorical features.
|
222 |
+
:param model: Model function
|
223 |
+
:param bptt: Sequence length.
|
224 |
+
:param eval_position: Number of training samples.
|
225 |
+
:param overwrite: Wheater to ove
|
226 |
+
:param overwrite: If True, results on disk are overwritten.
|
227 |
+
:param save:
|
228 |
+
:param path_interfix: Used for constructing path to write on disk.
|
229 |
+
:param method: Model name.
|
230 |
+
:param ds_name: Datset name.
|
231 |
+
:param fetch_only: Wheater to calculate or only fetch results.
|
232 |
+
:param per_step_normalization:
|
233 |
+
:param kwargs:
|
234 |
+
:return:
|
235 |
+
"""
|
236 |
+
|
237 |
+
if save:
|
238 |
+
path = os.path.join(base_path, f'results/tabular/{path_interfix}/results_{method}_{ds_name}_{eval_position}_{bptt}_{split_number}.npy')
|
239 |
+
#log_path =
|
240 |
+
|
241 |
+
## Load results if on disk
|
242 |
+
if not overwrite:
|
243 |
+
result = check_file_exists(path)
|
244 |
+
if result is not None:
|
245 |
+
if not fetch_only:
|
246 |
+
print(f'Loaded saved result for {path}')
|
247 |
+
return result
|
248 |
+
elif fetch_only:
|
249 |
+
print(f'Could not load saved result for {path}')
|
250 |
+
return None
|
251 |
+
|
252 |
+
## Generate data splits
|
253 |
+
eval_xs, eval_ys = generate_valid_split(X, y, bptt, eval_position, split_number=split_number)
|
254 |
+
if eval_xs is None:
|
255 |
+
return None
|
256 |
+
print(f"No dataset could be generated {ds_name} {bptt}")
|
257 |
+
|
258 |
+
eval_ys = (eval_ys > torch.unique(eval_ys).unsqueeze(0)).sum(axis=1).unsqueeze(-1)
|
259 |
+
|
260 |
+
start_time = time.time()
|
261 |
+
|
262 |
+
if isinstance(model, nn.Module): # Two separate predict interfaces for transformer and baselines
|
263 |
+
outputs, best_configs = transformer_predict(model, eval_xs, eval_ys, eval_position, categorical_feats=categorical_feats, **kwargs), None
|
264 |
+
else:
|
265 |
+
_, outputs, best_configs = baseline_predict(model, eval_xs, eval_ys, categorical_feats
|
266 |
+
, eval_pos=eval_position
|
267 |
+
, max_time=max_time, **kwargs)
|
268 |
+
|
269 |
+
eval_ys = eval_ys[eval_position:]
|
270 |
+
if outputs is None:
|
271 |
+
return None
|
272 |
+
|
273 |
+
if torch.is_tensor(outputs): # Transfers data to cpu for saving
|
274 |
+
outputs = outputs.cpu()
|
275 |
+
eval_ys = eval_ys.cpu()
|
276 |
+
|
277 |
+
ds_result = None, outputs, eval_ys, best_configs, time.time() - start_time
|
278 |
+
|
279 |
+
if save:
|
280 |
+
with open(path, 'wb') as f:
|
281 |
+
np.save(f, ds_result)
|
282 |
+
print(f'saved results to {path}')
|
283 |
+
|
284 |
+
return ds_result
|
TabPFN/scripts/tabular_metrics.py
ADDED
@@ -0,0 +1,181 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
===============================
|
3 |
+
Metrics calculation
|
4 |
+
===============================
|
5 |
+
Includes a few metric as well as functions composing metrics on results files.
|
6 |
+
|
7 |
+
"""
|
8 |
+
|
9 |
+
|
10 |
+
|
11 |
+
import numpy as np
|
12 |
+
import torch
|
13 |
+
from sklearn.metrics import roc_auc_score, accuracy_score, balanced_accuracy_score, average_precision_score
|
14 |
+
from scipy.stats import rankdata
|
15 |
+
import pandas as pd
|
16 |
+
|
17 |
+
"""
|
18 |
+
===============================
|
19 |
+
Metrics calculation
|
20 |
+
===============================
|
21 |
+
"""
|
22 |
+
def auc_metric(target, pred, multi_class='ovo', numpy=False):
|
23 |
+
lib = np if numpy else torch
|
24 |
+
try:
|
25 |
+
if not numpy:
|
26 |
+
target = torch.tensor(target) if not torch.is_tensor(target) else target
|
27 |
+
pred = torch.tensor(pred) if not torch.is_tensor(pred) else pred
|
28 |
+
if len(lib.unique(target)) > 2:
|
29 |
+
if not numpy:
|
30 |
+
return torch.tensor(roc_auc_score(target, pred, multi_class=multi_class))
|
31 |
+
return roc_auc_score(target, pred, multi_class=multi_class)
|
32 |
+
else:
|
33 |
+
if len(pred.shape) == 2:
|
34 |
+
pred = pred[:, 1]
|
35 |
+
if not numpy:
|
36 |
+
return torch.tensor(roc_auc_score(target, pred))
|
37 |
+
return roc_auc_score(target, pred)
|
38 |
+
except ValueError as e:
|
39 |
+
print(e)
|
40 |
+
return np.nan
|
41 |
+
|
42 |
+
def accuracy_metric(target, pred):
|
43 |
+
target = torch.tensor(target) if not torch.is_tensor(target) else target
|
44 |
+
pred = torch.tensor(pred) if not torch.is_tensor(pred) else pred
|
45 |
+
if len(torch.unique(target)) > 2:
|
46 |
+
return torch.tensor(accuracy_score(target, torch.argmax(pred, -1)))
|
47 |
+
else:
|
48 |
+
return torch.tensor(accuracy_score(target, pred[:, 1] > 0.5))
|
49 |
+
|
50 |
+
def average_precision_metric(target, pred):
|
51 |
+
target = torch.tensor(target) if not torch.is_tensor(target) else target
|
52 |
+
pred = torch.tensor(pred) if not torch.is_tensor(pred) else pred
|
53 |
+
if len(torch.unique(target)) > 2:
|
54 |
+
return torch.tensor(average_precision_score(target, torch.argmax(pred, -1)))
|
55 |
+
else:
|
56 |
+
return torch.tensor(average_precision_score(target, pred[:, 1] > 0.5))
|
57 |
+
|
58 |
+
def balanced_accuracy_metric(target, pred):
|
59 |
+
target = torch.tensor(target) if not torch.is_tensor(target) else target
|
60 |
+
pred = torch.tensor(pred) if not torch.is_tensor(pred) else pred
|
61 |
+
if len(torch.unique(target)) > 2:
|
62 |
+
return torch.tensor(balanced_accuracy_score(target, torch.argmax(pred, -1)))
|
63 |
+
else:
|
64 |
+
return torch.tensor(balanced_accuracy_score(target, pred[:, 1] > 0.5))
|
65 |
+
|
66 |
+
def cross_entropy(target, pred):
|
67 |
+
target = torch.tensor(target) if not torch.is_tensor(target) else target
|
68 |
+
pred = torch.tensor(pred) if not torch.is_tensor(pred) else pred
|
69 |
+
if len(torch.unique(target)) > 2:
|
70 |
+
ce = torch.nn.CrossEntropyLoss()
|
71 |
+
return ce(pred.float(), target.long())
|
72 |
+
else:
|
73 |
+
bce = torch.nn.BCELoss()
|
74 |
+
return bce(pred[:, 1].float(), target.float())
|
75 |
+
|
76 |
+
def time_metric():
|
77 |
+
"""
|
78 |
+
Dummy function, will just be used as a handler.
|
79 |
+
"""
|
80 |
+
pass
|
81 |
+
|
82 |
+
def count_metric(x, y):
|
83 |
+
"""
|
84 |
+
Dummy function, returns one count per dataset.
|
85 |
+
"""
|
86 |
+
return 1
|
87 |
+
|
88 |
+
"""
|
89 |
+
===============================
|
90 |
+
Metrics composition
|
91 |
+
===============================
|
92 |
+
"""
|
93 |
+
def calculate_score_per_method(metric, name:str, global_results:dict, ds:list, eval_positions:list, aggregator:str='mean'):
|
94 |
+
"""
|
95 |
+
Calculates the metric given by 'metric' and saves it under 'name' in the 'global_results'
|
96 |
+
|
97 |
+
:param metric: Metric function
|
98 |
+
:param name: Name of metric in 'global_results'
|
99 |
+
:param global_results: Dicrtonary containing the results for current method for a collection of datasets
|
100 |
+
:param ds: Dataset to calculate metrics on, a list of dataset properties
|
101 |
+
:param eval_positions: List of positions to calculate metrics on
|
102 |
+
:param aggregator: Specifies way to aggregate results across evaluation positions
|
103 |
+
:return:
|
104 |
+
"""
|
105 |
+
aggregator_f = np.nanmean if aggregator == 'mean' else np.nansum
|
106 |
+
for pos in eval_positions:
|
107 |
+
valid_positions = 0
|
108 |
+
for d in ds:
|
109 |
+
if f'{d[0]}_outputs_at_{pos}' in global_results:
|
110 |
+
preds = global_results[f'{d[0]}_outputs_at_{pos}']
|
111 |
+
y = global_results[f'{d[0]}_ys_at_{pos}']
|
112 |
+
|
113 |
+
preds, y = preds.detach().cpu().numpy() if torch.is_tensor(
|
114 |
+
preds) else preds, y.detach().cpu().numpy() if torch.is_tensor(y) else y
|
115 |
+
|
116 |
+
try:
|
117 |
+
if metric == time_metric:
|
118 |
+
global_results[f'{d[0]}_{name}_at_{pos}'] = global_results[f'{d[0]}_time_at_{pos}']
|
119 |
+
valid_positions = valid_positions + 1
|
120 |
+
else:
|
121 |
+
global_results[f'{d[0]}_{name}_at_{pos}'] = aggregator_f(
|
122 |
+
[metric(y[split], preds[split]) for split in range(y.shape[0])])
|
123 |
+
valid_positions = valid_positions + 1
|
124 |
+
except Exception as err:
|
125 |
+
print(f'Error calculating metric with {err}, {type(err)} at {d[0]} {pos} {name}')
|
126 |
+
global_results[f'{d[0]}_{name}_at_{pos}'] = np.nan
|
127 |
+
else:
|
128 |
+
global_results[f'{d[0]}_{name}_at_{pos}'] = np.nan
|
129 |
+
|
130 |
+
if valid_positions > 0:
|
131 |
+
global_results[f'{aggregator}_{name}_at_{pos}'] = aggregator_f([global_results[f'{d[0]}_{name}_at_{pos}'] for d in ds])
|
132 |
+
else:
|
133 |
+
global_results[f'{aggregator}_{name}_at_{pos}'] = np.nan
|
134 |
+
|
135 |
+
for d in ds:
|
136 |
+
metrics = [global_results[f'{d[0]}_{name}_at_{pos}'] for pos in eval_positions]
|
137 |
+
metrics = [m for m in metrics if not np.isnan(m)]
|
138 |
+
global_results[f'{d[0]}_{aggregator}_{name}'] = aggregator_f(metrics) if len(metrics) > 0 else np.nan
|
139 |
+
|
140 |
+
metrics = [global_results[f'{aggregator}_{name}_at_{pos}'] for pos in eval_positions]
|
141 |
+
metrics = [m for m in metrics if not np.isnan(m)]
|
142 |
+
global_results[f'{aggregator}_{name}'] = aggregator_f(metrics) if len(metrics) > 0 else np.nan
|
143 |
+
|
144 |
+
|
145 |
+
def calculate_score(metric, name, global_results, ds, eval_positions, aggregator='mean', limit_to=''):
|
146 |
+
"""
|
147 |
+
Calls calculate_metrics_by_method with a range of methods. See arguments of that method.
|
148 |
+
:param limit_to: This method will not get metric calculations.
|
149 |
+
"""
|
150 |
+
for m in global_results:
|
151 |
+
if limit_to not in m:
|
152 |
+
continue
|
153 |
+
calculate_score_per_method(metric, name, global_results[m], ds, eval_positions, aggregator=aggregator)
|
154 |
+
|
155 |
+
|
156 |
+
def make_metric_matrix(global_results, methods, pos, name, ds):
|
157 |
+
result = []
|
158 |
+
for m in global_results:
|
159 |
+
result += [[global_results[m][d[0] + '_' + name + '_at_' + str(pos)] for d in ds]]
|
160 |
+
result = np.array(result)
|
161 |
+
result = pd.DataFrame(result.T, index=[d[0] for d in ds], columns=[k[:-8] for k in list(global_results.keys())])
|
162 |
+
|
163 |
+
matrix_means, matrix_stds = [], []
|
164 |
+
|
165 |
+
for method in methods:
|
166 |
+
matrix_means += [result.iloc[:, [(method) in c for c in result.columns]].mean(axis=1)]
|
167 |
+
matrix_stds += [result.iloc[:, [(method) in c for c in result.columns]].std(axis=1)]
|
168 |
+
|
169 |
+
matrix_means = pd.DataFrame(matrix_means, index=methods).T
|
170 |
+
matrix_stds = pd.DataFrame(matrix_stds, index=methods).T
|
171 |
+
|
172 |
+
return matrix_means, matrix_stds
|
173 |
+
|
174 |
+
|
175 |
+
def make_ranks_and_wins_table(matrix):
|
176 |
+
for dss in matrix.T:
|
177 |
+
matrix.loc[dss] = rankdata(-matrix.round(3).loc[dss])
|
178 |
+
ranks_acc = matrix.mean()
|
179 |
+
wins_acc = (matrix == 1).sum()
|
180 |
+
|
181 |
+
return ranks_acc, wins_acc
|
TabPFN/scripts/transformer_prediction_interface.py
ADDED
@@ -0,0 +1,357 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import random
|
3 |
+
|
4 |
+
from torch.utils.checkpoint import checkpoint
|
5 |
+
|
6 |
+
from utils import normalize_data, to_ranking_low_mem, remove_outliers
|
7 |
+
from priors.utils import normalize_by_used_features_f
|
8 |
+
from utils import NOP
|
9 |
+
|
10 |
+
from sklearn.preprocessing import PowerTransformer, QuantileTransformer, RobustScaler
|
11 |
+
|
12 |
+
from notebook_utils import CustomUnpickler
|
13 |
+
|
14 |
+
import numpy as np
|
15 |
+
from sklearn.base import BaseEstimator, ClassifierMixin
|
16 |
+
from sklearn.utils.validation import check_X_y, check_array, check_is_fitted
|
17 |
+
from sklearn.utils.multiclass import check_classification_targets
|
18 |
+
from sklearn.utils import column_or_1d
|
19 |
+
from pathlib import Path
|
20 |
+
from model_builder import load_model
|
21 |
+
import os
|
22 |
+
|
23 |
+
def load_model_workflow(i, e, add_name, base_path, device='cpu', eval_addition=''):
|
24 |
+
"""
|
25 |
+
Workflow for loading a model and setting appropriate parameters for diffable hparam tuning.
|
26 |
+
|
27 |
+
:param i:
|
28 |
+
:param e:
|
29 |
+
:param eval_positions_valid:
|
30 |
+
:param add_name:
|
31 |
+
:param base_path:
|
32 |
+
:param device:
|
33 |
+
:param eval_addition:
|
34 |
+
:return:
|
35 |
+
"""
|
36 |
+
def check_file(e):
|
37 |
+
model_file = f'models_diff/prior_diff_real_checkpoint{add_name}_n_{i}_epoch_{e}.cpkt'
|
38 |
+
model_path = os.path.join(base_path, model_file)
|
39 |
+
# print('Evaluate ', model_path)
|
40 |
+
results_file = os.path.join(base_path,
|
41 |
+
f'models_diff/prior_diff_real_results{add_name}_n_{i}_epoch_{e}_{eval_addition}.pkl')
|
42 |
+
if not Path(model_path).is_file(): # or Path(results_file).is_file():
|
43 |
+
return None, None, None
|
44 |
+
return model_file, model_path, results_file
|
45 |
+
|
46 |
+
model_file = None
|
47 |
+
if e == -1:
|
48 |
+
for e_ in range(100, -1, -1):
|
49 |
+
model_file_, model_path_, results_file_ = check_file(e_)
|
50 |
+
if model_file_ is not None:
|
51 |
+
e = e_
|
52 |
+
model_file, model_path, results_file = model_file_, model_path_, results_file_
|
53 |
+
break
|
54 |
+
else:
|
55 |
+
model_file, model_path, results_file = check_file(e)
|
56 |
+
|
57 |
+
if model_file is None:
|
58 |
+
print('No checkpoint found')
|
59 |
+
return None
|
60 |
+
|
61 |
+
print(f'Loading {model_file}')
|
62 |
+
|
63 |
+
model, c = load_model(base_path, model_file, device, eval_positions=[], verbose=False)
|
64 |
+
|
65 |
+
return model, c, results_file
|
66 |
+
|
67 |
+
|
68 |
+
class TabPFNClassifier(BaseEstimator, ClassifierMixin):
|
69 |
+
|
70 |
+
def __init__(self, device='cpu', base_path='.'):
|
71 |
+
# Model file specification (Model name, Epoch)
|
72 |
+
model_string = ''
|
73 |
+
i, e = '8x_lr0.0003', -1
|
74 |
+
|
75 |
+
# File which contains result of hyperparameter tuning run: style (i.e. hyperparameters) and a dataframe with results.
|
76 |
+
style_file = 'prior_tuning_result.pkl'
|
77 |
+
|
78 |
+
model, c, results_file = load_model_workflow(i, e, add_name=model_string, base_path=base_path, device=device,
|
79 |
+
eval_addition='')
|
80 |
+
style, temperature = self.load_result_minimal(style_file, i, e, base_path=base_path)
|
81 |
+
|
82 |
+
self.device = device
|
83 |
+
self.base_path = base_path
|
84 |
+
self.model = model
|
85 |
+
self.c = c
|
86 |
+
self.style = style
|
87 |
+
self.temperature = temperature
|
88 |
+
|
89 |
+
self.max_num_features = self.c['num_features']
|
90 |
+
self.max_num_classes = self.c['max_num_classes']
|
91 |
+
|
92 |
+
def load_result_minimal(self, path, i, e, base_path='.'):
|
93 |
+
with open(os.path.join(base_path,path), 'rb') as output:
|
94 |
+
_, _, _, style, temperature, optimization_route = CustomUnpickler(output).load()
|
95 |
+
|
96 |
+
return style, temperature
|
97 |
+
|
98 |
+
def fit(self, X, y):
|
99 |
+
# Check that X and y have correct shape
|
100 |
+
X, y = check_X_y(X, y)
|
101 |
+
y = self._validate_targets(y)
|
102 |
+
|
103 |
+
self.X_ = X
|
104 |
+
self.y_ = y
|
105 |
+
|
106 |
+
if X.shape[1] > self.max_num_features:
|
107 |
+
raise ValueError("The number of features for this classifier is restricted to ", self.max_num_features)
|
108 |
+
if len(np.unique(y)) > self.max_num_classes:
|
109 |
+
raise ValueError("The number of classes for this classifier is restricted to ", self.max_num_classes)
|
110 |
+
|
111 |
+
# Return the classifier
|
112 |
+
return self
|
113 |
+
|
114 |
+
def _validate_targets(self, y):
|
115 |
+
y_ = column_or_1d(y, warn=True)
|
116 |
+
check_classification_targets(y)
|
117 |
+
cls, y = np.unique(y_, return_inverse=True)
|
118 |
+
if len(cls) < 2:
|
119 |
+
raise ValueError(
|
120 |
+
"The number of classes has to be greater than one; got %d class"
|
121 |
+
% len(cls)
|
122 |
+
)
|
123 |
+
|
124 |
+
self.classes_ = cls
|
125 |
+
|
126 |
+
return np.asarray(y, dtype=np.float64, order="C")
|
127 |
+
|
128 |
+
def predict_proba(self, X):
|
129 |
+
# Check is fit had been called
|
130 |
+
check_is_fitted(self)
|
131 |
+
|
132 |
+
# Input validation
|
133 |
+
X = check_array(X)
|
134 |
+
|
135 |
+
X_full = np.concatenate([self.X_, X], axis=0)
|
136 |
+
X_full = torch.tensor(X_full, device=self.device).float().unsqueeze(1)
|
137 |
+
y_full = np.concatenate([self.y_, self.y_[0] + np.zeros_like(X[:, 0])], axis=0)
|
138 |
+
y_full = torch.tensor(y_full, device=self.device).float().unsqueeze(1)
|
139 |
+
|
140 |
+
eval_pos = self.X_.shape[0]
|
141 |
+
|
142 |
+
prediction = transformer_predict(self.model[2], X_full, y_full, eval_pos,
|
143 |
+
device=self.device,
|
144 |
+
style=self.style,
|
145 |
+
inference_mode=True,
|
146 |
+
N_ensemble_configurations=10,
|
147 |
+
softmax_temperature=self.temperature
|
148 |
+
, **get_params_from_config(self.c))
|
149 |
+
prediction_ = prediction.squeeze(0)
|
150 |
+
|
151 |
+
return prediction_.detach().cpu().numpy()
|
152 |
+
|
153 |
+
def predict(self, X, return_winning_probability=False):
|
154 |
+
p = self.predict_proba(X)
|
155 |
+
y = np.argmax(self.predict_proba(X), axis=-1)
|
156 |
+
y = self.classes_.take(np.asarray(y, dtype=np.intp))
|
157 |
+
if return_winning_probability:
|
158 |
+
return y, p.max(axis=-1)
|
159 |
+
return y
|
160 |
+
|
161 |
+
def transformer_predict(model, eval_xs, eval_ys, eval_position,
|
162 |
+
device='cpu',
|
163 |
+
max_features=100,
|
164 |
+
style=None,
|
165 |
+
inference_mode=False,
|
166 |
+
num_classes=2,
|
167 |
+
extend_features=True,
|
168 |
+
normalize_to_ranking=False,
|
169 |
+
softmax_temperature=0.0,
|
170 |
+
multiclass_decoder='permutation',
|
171 |
+
preprocess_transform='mix',
|
172 |
+
categorical_feats=[],
|
173 |
+
feature_shift_decoder=True,
|
174 |
+
N_ensemble_configurations=10,
|
175 |
+
average_logits=True,
|
176 |
+
normalize_with_sqrt=False, **kwargs):
|
177 |
+
"""
|
178 |
+
|
179 |
+
:param model:
|
180 |
+
:param eval_xs:
|
181 |
+
:param eval_ys: should be classes that are 0-indexed and every class until num_classes-1 is present
|
182 |
+
:param eval_position:
|
183 |
+
:param rescale_features:
|
184 |
+
:param device:
|
185 |
+
:param max_features:
|
186 |
+
:param style:
|
187 |
+
:param inference_mode:
|
188 |
+
:param num_classes:
|
189 |
+
:param extend_features:
|
190 |
+
:param normalize_to_ranking:
|
191 |
+
:param softmax_temperature:
|
192 |
+
:param multiclass_decoder:
|
193 |
+
:param preprocess_transform:
|
194 |
+
:param categorical_feats:
|
195 |
+
:param feature_shift_decoder:
|
196 |
+
:param N_ensemble_configurations:
|
197 |
+
:param average_logits:
|
198 |
+
:param normalize_with_sqrt:
|
199 |
+
:param metric_used:
|
200 |
+
:return:
|
201 |
+
"""
|
202 |
+
num_classes = len(torch.unique(eval_ys))
|
203 |
+
|
204 |
+
def predict(eval_xs, eval_ys, used_style, softmax_temperature, return_logits):
|
205 |
+
# Initialize results array size S, B, Classes
|
206 |
+
|
207 |
+
inference_mode_call = torch.inference_mode() if inference_mode else NOP()
|
208 |
+
with inference_mode_call:
|
209 |
+
output = model(
|
210 |
+
(used_style.repeat(eval_xs.shape[1], 1) if used_style is not None else None, eval_xs, eval_ys.float()),
|
211 |
+
single_eval_pos=eval_position)[:, :, 0:num_classes]
|
212 |
+
|
213 |
+
output = output[:, :, 0:num_classes] / torch.exp(softmax_temperature)
|
214 |
+
if not return_logits:
|
215 |
+
output = torch.nn.functional.softmax(output, dim=-1)
|
216 |
+
#else:
|
217 |
+
# output[:, :, 1] = model((style.repeat(eval_xs.shape[1], 1) if style is not None else None, eval_xs, eval_ys.float()),
|
218 |
+
# single_eval_pos=eval_position)
|
219 |
+
|
220 |
+
# output[:, :, 1] = torch.sigmoid(output[:, :, 1]).squeeze(-1)
|
221 |
+
# output[:, :, 0] = 1 - output[:, :, 1]
|
222 |
+
|
223 |
+
#print('RESULTS', eval_ys.shape, torch.unique(eval_ys, return_counts=True), output.mean(axis=0))
|
224 |
+
|
225 |
+
return output
|
226 |
+
|
227 |
+
def preprocess_input(eval_xs, preprocess_transform):
|
228 |
+
import warnings
|
229 |
+
|
230 |
+
if eval_xs.shape[1] > 1:
|
231 |
+
raise Exception("Transforms only allow one batch dim - TODO")
|
232 |
+
if preprocess_transform != 'none':
|
233 |
+
if preprocess_transform == 'power' or preprocess_transform == 'power_all':
|
234 |
+
pt = PowerTransformer(standardize=True)
|
235 |
+
elif preprocess_transform == 'quantile' or preprocess_transform == 'quantile_all':
|
236 |
+
pt = QuantileTransformer(output_distribution='normal')
|
237 |
+
elif preprocess_transform == 'robust' or preprocess_transform == 'robust_all':
|
238 |
+
pt = RobustScaler(unit_variance=True)
|
239 |
+
|
240 |
+
# eval_xs, eval_ys = normalize_data(eval_xs), normalize_data(eval_ys)
|
241 |
+
eval_xs = normalize_data(eval_xs)
|
242 |
+
|
243 |
+
# Removing empty features
|
244 |
+
eval_xs = eval_xs[:, 0, :].cpu().numpy()
|
245 |
+
sel = [len(np.unique(eval_xs[0:eval_ys.shape[0], col])) > 1 for col in range(eval_xs.shape[1])]
|
246 |
+
eval_xs = np.array(eval_xs[:, sel])
|
247 |
+
|
248 |
+
warnings.simplefilter('error')
|
249 |
+
if preprocess_transform != 'none':
|
250 |
+
feats = set(range(eval_xs.shape[1])) if 'all' in preprocess_transform else set(
|
251 |
+
range(eval_xs.shape[1])) - set(categorical_feats)
|
252 |
+
for col in feats:
|
253 |
+
try:
|
254 |
+
pt.fit(eval_xs[0:eval_ys.shape[0], col:col + 1])
|
255 |
+
trans = pt.transform(eval_xs[:, col:col + 1])
|
256 |
+
# print(scipy.stats.spearmanr(trans[~np.isnan(eval_xs[:, col:col+1])], eval_xs[:, col:col+1][~np.isnan(eval_xs[:, col:col+1])]))
|
257 |
+
eval_xs[:, col:col + 1] = trans
|
258 |
+
except:
|
259 |
+
pass
|
260 |
+
warnings.simplefilter('default')
|
261 |
+
|
262 |
+
eval_xs = torch.tensor(eval_xs).float().unsqueeze(1).to(device)
|
263 |
+
|
264 |
+
# eval_xs = normalize_data(eval_xs)
|
265 |
+
|
266 |
+
# TODO: Cautian there is information leakage when to_ranking is used, we should not use it
|
267 |
+
eval_xs = remove_outliers(eval_xs) if not normalize_to_ranking else normalize_data(to_ranking_low_mem(eval_xs))
|
268 |
+
|
269 |
+
# Rescale X
|
270 |
+
eval_xs = normalize_by_used_features_f(eval_xs, eval_xs.shape[-1], max_features,
|
271 |
+
normalize_with_sqrt=normalize_with_sqrt)
|
272 |
+
return eval_xs.detach()
|
273 |
+
|
274 |
+
eval_xs, eval_ys = eval_xs.to(device), eval_ys.to(device)
|
275 |
+
eval_ys = eval_ys[:eval_position]
|
276 |
+
|
277 |
+
model.to(device)
|
278 |
+
style = style.to(device)
|
279 |
+
|
280 |
+
model.eval()
|
281 |
+
|
282 |
+
import itertools
|
283 |
+
style = style.unsqueeze(0) if len(style.shape) == 1 else style
|
284 |
+
num_styles = style.shape[0]
|
285 |
+
styles_configurations = range(0, num_styles)
|
286 |
+
preprocess_transform_configurations = [preprocess_transform if i % 2 == 0 else 'none' for i in range(0, num_styles)]
|
287 |
+
if preprocess_transform == 'mix':
|
288 |
+
def get_preprocess(i):
|
289 |
+
if i == 0:
|
290 |
+
return 'power_all'
|
291 |
+
if i == 1:
|
292 |
+
return 'robust_all'
|
293 |
+
if i == 2:
|
294 |
+
return 'none'
|
295 |
+
preprocess_transform_configurations = [get_preprocess(i) for i in range(0, num_styles)]
|
296 |
+
styles_configurations = zip(styles_configurations, preprocess_transform_configurations)
|
297 |
+
|
298 |
+
feature_shift_configurations = range(0, eval_xs.shape[2]) if feature_shift_decoder else [0]
|
299 |
+
class_shift_configurations = range(0, len(torch.unique(eval_ys))) if multiclass_decoder == 'permutation' else [0]
|
300 |
+
|
301 |
+
ensemble_configurations = list(itertools.product(styles_configurations, feature_shift_configurations, class_shift_configurations))
|
302 |
+
random.shuffle(ensemble_configurations)
|
303 |
+
ensemble_configurations = ensemble_configurations[0:N_ensemble_configurations]
|
304 |
+
|
305 |
+
output = None
|
306 |
+
|
307 |
+
eval_xs_transformed = {}
|
308 |
+
for ensemble_configuration in ensemble_configurations:
|
309 |
+
(styles_configuration, preprocess_transform_configuration), feature_shift_configuration, class_shift_configuration = ensemble_configuration
|
310 |
+
|
311 |
+
style_ = style[styles_configuration:styles_configuration+1, :]
|
312 |
+
softmax_temperature_ = softmax_temperature[styles_configuration]
|
313 |
+
|
314 |
+
eval_xs_, eval_ys_ = eval_xs.clone(), eval_ys.clone()
|
315 |
+
|
316 |
+
if preprocess_transform_configuration in eval_xs_transformed:
|
317 |
+
eval_xs_ = eval_xs_transformed['preprocess_transform_configuration'].clone()
|
318 |
+
else:
|
319 |
+
eval_xs_ = preprocess_input(eval_xs_, preprocess_transform=preprocess_transform_configuration)
|
320 |
+
eval_xs_transformed['preprocess_transform_configuration'] = eval_xs_
|
321 |
+
|
322 |
+
eval_ys_ = ((eval_ys_ + class_shift_configuration) % num_classes).float()
|
323 |
+
|
324 |
+
eval_xs_ = torch.cat([eval_xs_[..., feature_shift_configuration:],eval_xs_[..., :feature_shift_configuration]],dim=-1)
|
325 |
+
|
326 |
+
# Extend X
|
327 |
+
if extend_features:
|
328 |
+
eval_xs_ = torch.cat(
|
329 |
+
[eval_xs_,
|
330 |
+
torch.zeros((eval_xs_.shape[0], eval_xs_.shape[1], max_features - eval_xs_.shape[2])).to(device)], -1)
|
331 |
+
|
332 |
+
#preprocess_transform_ = preprocess_transform if styles_configuration % 2 == 0 else 'none'
|
333 |
+
import warnings
|
334 |
+
with warnings.catch_warnings():
|
335 |
+
warnings.filterwarnings("ignore", message="None of the inputs have requires_grad=True. Gradients will be None")
|
336 |
+
output_ = checkpoint(predict, eval_xs_, eval_ys_, style_, softmax_temperature_, True)
|
337 |
+
output_ = torch.cat([output_[..., class_shift_configuration:],output_[..., :class_shift_configuration]],dim=-1)
|
338 |
+
|
339 |
+
#output_ = predict(eval_xs, eval_ys, style_, preprocess_transform_)
|
340 |
+
if not average_logits:
|
341 |
+
output_ = torch.nn.functional.softmax(output_, dim=-1)
|
342 |
+
output = output_ if output is None else output + output_
|
343 |
+
|
344 |
+
output = output / len(ensemble_configurations)
|
345 |
+
if average_logits:
|
346 |
+
output = torch.nn.functional.softmax(output, dim=-1)
|
347 |
+
|
348 |
+
output = torch.transpose(output, 0, 1)
|
349 |
+
|
350 |
+
return output
|
351 |
+
|
352 |
+
def get_params_from_config(c):
|
353 |
+
return {'max_features': c['num_features']
|
354 |
+
, 'rescale_features': c["normalize_by_used_features"]
|
355 |
+
, 'normalize_to_ranking': c["normalize_to_ranking"]
|
356 |
+
, 'normalize_with_sqrt': c.get("normalize_with_sqrt", False)
|
357 |
+
}
|
TabPFN/tabular_evaluation.py
ADDED
@@ -0,0 +1,283 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import time
|
2 |
+
import os
|
3 |
+
from pathlib import Path
|
4 |
+
|
5 |
+
from tqdm import tqdm
|
6 |
+
import random
|
7 |
+
import numpy as np
|
8 |
+
|
9 |
+
from torch import nn
|
10 |
+
|
11 |
+
from utils import torch_nanmean
|
12 |
+
from datasets import *
|
13 |
+
from model_builder import load_model
|
14 |
+
from scripts.tabular_baselines import get_scoring_string
|
15 |
+
from scripts import tabular_metrics
|
16 |
+
from scripts.transformer_prediction_interface import *
|
17 |
+
from scripts.baseline_prediction_interface import *
|
18 |
+
"""
|
19 |
+
===============================
|
20 |
+
PUBLIC FUNCTIONS FOR EVALUATION
|
21 |
+
===============================
|
22 |
+
"""
|
23 |
+
|
24 |
+
|
25 |
+
def eval_model(i, e, valid_datasets, test_datasets, eval_positions, bptt, add_name, base_path, device='cpu', eval_addition='', **kwargs):
|
26 |
+
metrics_test, config_sample, model_path = eval_model_on_ds(i, e, test_datasets, eval_positions, bptt, add_name, base_path, device=device, eval_addition=eval_addition, **kwargs)
|
27 |
+
metrics_valid, _, _ = eval_model_on_ds(i, e, valid_datasets, eval_positions, bptt, add_name, base_path, device=device, eval_addition=eval_addition, **kwargs)
|
28 |
+
return {'mean_auc_test': metrics_test['mean_roc_at_1000'], 'mean_auc_valid': metrics_valid['mean_roc_at_1000'], 'mean_ce_test': metrics_test['mean_ce_at_1000'], 'mean_ce_valid': metrics_valid['mean_ce_at_1000'], 'config_sample': config_sample, 'model_path': model_path}
|
29 |
+
|
30 |
+
def eval_model_on_ds(i, e, valid_datasets, eval_positions, bptt, add_name, base_path, device='cpu', eval_addition='', **kwargs):
|
31 |
+
|
32 |
+
# How to use: evaluate_without_fitting(i,0,valid_datasets, [1024], 100000, add_name=model_string, base_path=base_path,)
|
33 |
+
def check_file(e):
|
34 |
+
model_file = f'models_diff/prior_diff_real_checkpoint{add_name}_n_{i}_epoch_{e}.cpkt'
|
35 |
+
model_path = os.path.join(base_path, model_file)
|
36 |
+
# print('Evaluate ', model_path)
|
37 |
+
results_file = os.path.join(base_path,
|
38 |
+
f'models_diff/prior_diff_real_results{add_name}_n_{i}_epoch_{e}_{eval_addition}.pkl')
|
39 |
+
if not Path(model_path).is_file(): # or Path(results_file).is_file():
|
40 |
+
# print('checkpoint exists: ', Path(model_file).is_file(), ', results are written:', Path(results_file).is_file())
|
41 |
+
return None, None, None
|
42 |
+
return model_file, model_path, results_file
|
43 |
+
|
44 |
+
if e == -1: # use last checkpoint, if e == -1
|
45 |
+
for e_ in range(100, -1, -1):
|
46 |
+
model_file_, model_path_, results_file_ = check_file(e_)
|
47 |
+
if model_file_ is not None:
|
48 |
+
e = e_
|
49 |
+
model_file, model_path, results_file = model_file_, model_path_, results_file_
|
50 |
+
break
|
51 |
+
else:
|
52 |
+
model_file, model_path, results_file = check_file(e)
|
53 |
+
|
54 |
+
model, config_sample = load_model(base_path, model_file, device, None, verbose=False)
|
55 |
+
|
56 |
+
params = {'max_features': config_sample['num_features']
|
57 |
+
, 'rescale_features': config_sample["normalize_by_used_features"]
|
58 |
+
, 'normalize_to_ranking': config_sample["normalize_to_ranking"]
|
59 |
+
, 'normalize_with_sqrt': config_sample.get("normalize_with_sqrt", False)
|
60 |
+
}
|
61 |
+
metrics_valid = evaluate(datasets=valid_datasets, model=model[2], method='transformer', device=device, overwrite=True,
|
62 |
+
extend_features=True
|
63 |
+
# just removed the style keyword but transformer is trained with style, just empty
|
64 |
+
, save=False
|
65 |
+
, metric_used=tabular_metrics.cross_entropy
|
66 |
+
, return_tensor=True
|
67 |
+
, verbose=False
|
68 |
+
, eval_positions=eval_positions
|
69 |
+
, bptt=bptt
|
70 |
+
, base_path=None
|
71 |
+
, inference_mode=True
|
72 |
+
, **params
|
73 |
+
, **kwargs)
|
74 |
+
|
75 |
+
tabular_metrics.calculate_score_per_method(tabular_metrics.auc_metric, 'roc', metrics_valid, valid_datasets, eval_positions)
|
76 |
+
tabular_metrics.calculate_score_per_method(tabular_metrics.cross_entropy, 'ce', metrics_valid, valid_datasets, eval_positions)
|
77 |
+
|
78 |
+
return metrics_valid, config_sample, model_path
|
79 |
+
|
80 |
+
|
81 |
+
def evaluate(datasets, bptt, eval_positions, metric_used, model
|
82 |
+
, verbose=False
|
83 |
+
, return_tensor=False
|
84 |
+
, **kwargs):
|
85 |
+
"""
|
86 |
+
Evaluates a list of datasets for a model function.
|
87 |
+
|
88 |
+
:param datasets: List of datasets
|
89 |
+
:param bptt: maximum sequence length
|
90 |
+
:param eval_positions: List of positions where to evaluate models
|
91 |
+
:param verbose: If True, is verbose.
|
92 |
+
:param metric_used: Which metric is optimized for.
|
93 |
+
:param return_tensor: Wheater to return results as a pytorch.tensor or numpy, this is only relevant for transformer.
|
94 |
+
:param kwargs:
|
95 |
+
:return:
|
96 |
+
"""
|
97 |
+
overall_result = {'metric_used': get_scoring_string(metric_used)
|
98 |
+
, 'bptt': bptt
|
99 |
+
, 'eval_positions': eval_positions}
|
100 |
+
|
101 |
+
aggregated_metric_datasets, num_datasets = torch.tensor(0.0), 0
|
102 |
+
|
103 |
+
# For each dataset
|
104 |
+
for [ds_name, X, y, categorical_feats, _, _] in tqdm.tqdm(datasets, desc='Iterate over datasets') if verbose else datasets:
|
105 |
+
dataset_bptt = min(len(X), bptt)
|
106 |
+
# if verbose and dataset_bptt < bptt:
|
107 |
+
# print(f'Dataset too small for given sequence length, reducing to {len(X)} ({bptt})')
|
108 |
+
|
109 |
+
aggregated_metric, num = torch.tensor(0.0), 0
|
110 |
+
ds_result = {}
|
111 |
+
|
112 |
+
for eval_position in (eval_positions if verbose else eval_positions):
|
113 |
+
eval_position_real = int(dataset_bptt * 0.5) if 2 * eval_position > dataset_bptt else eval_position
|
114 |
+
eval_position_bptt = int(eval_position_real * 2.0)
|
115 |
+
|
116 |
+
r = evaluate_position(X, y, model=model
|
117 |
+
, num_classes=len(torch.unique(y))
|
118 |
+
, categorical_feats = categorical_feats
|
119 |
+
, bptt = eval_position_bptt
|
120 |
+
, ds_name=ds_name
|
121 |
+
, eval_position = eval_position_real
|
122 |
+
, metric_used = metric_used
|
123 |
+
,**kwargs)
|
124 |
+
|
125 |
+
if r is None:
|
126 |
+
continue
|
127 |
+
|
128 |
+
_, outputs, ys, best_configs, time_used = r
|
129 |
+
|
130 |
+
if torch.is_tensor(outputs):
|
131 |
+
outputs = outputs.to(outputs.device)
|
132 |
+
ys = ys.to(outputs.device)
|
133 |
+
|
134 |
+
ys = ys.T
|
135 |
+
ds_result[f'{ds_name}_best_configs_at_{eval_position}'] = best_configs
|
136 |
+
ds_result[f'{ds_name}_outputs_at_{eval_position}'] = outputs
|
137 |
+
ds_result[f'{ds_name}_ys_at_{eval_position}'] = ys
|
138 |
+
ds_result[f'{ds_name}_time_at_{eval_position}'] = time_used
|
139 |
+
|
140 |
+
new_metric = torch_nanmean(torch.stack([metric_used(ys[i], outputs[i]) for i in range(ys.shape[0])]))
|
141 |
+
|
142 |
+
if not return_tensor:
|
143 |
+
make_scalar = lambda x: float(x.detach().cpu().numpy()) if (torch.is_tensor(x) and (len(x.shape) == 0)) else x
|
144 |
+
new_metric = make_scalar(new_metric)
|
145 |
+
ds_result = {k: make_scalar(ds_result[k]) for k in ds_result.keys()}
|
146 |
+
|
147 |
+
lib = torch if return_tensor else np
|
148 |
+
if not lib.isnan(new_metric).any():
|
149 |
+
aggregated_metric, num = aggregated_metric + new_metric, num + 1
|
150 |
+
|
151 |
+
overall_result.update(ds_result)
|
152 |
+
if num > 0:
|
153 |
+
aggregated_metric_datasets, num_datasets = (aggregated_metric_datasets + (aggregated_metric / num)), num_datasets + 1
|
154 |
+
|
155 |
+
overall_result['mean_metric'] = aggregated_metric_datasets / num_datasets
|
156 |
+
|
157 |
+
return overall_result
|
158 |
+
|
159 |
+
"""
|
160 |
+
===============================
|
161 |
+
INTERNAL HELPER FUNCTIONS
|
162 |
+
===============================
|
163 |
+
"""
|
164 |
+
|
165 |
+
def check_file_exists(path):
|
166 |
+
"""Checks if a pickle file exists. Returns None if not, else returns the unpickled file."""
|
167 |
+
if (os.path.isfile(path)):
|
168 |
+
print(f'loading results from {path}')
|
169 |
+
with open(path, 'rb') as f:
|
170 |
+
return np.load(f, allow_pickle=True).tolist()
|
171 |
+
return None
|
172 |
+
|
173 |
+
def generate_valid_split(X, y, bptt, eval_position, split_number=1):
|
174 |
+
"""Generates a deteministic train-(test/valid) split. Both splits must contain the same classes and all classes in
|
175 |
+
the entire datasets. If no such split can be sampled in 7 passes, returns None.
|
176 |
+
|
177 |
+
:param X: torch tensor, feature values
|
178 |
+
:param y: torch tensor, class values
|
179 |
+
:param bptt: Number of samples in train + test
|
180 |
+
:param eval_position: Number of samples in train, i.e. from which index values are in test
|
181 |
+
:param split_number: The split id
|
182 |
+
:return:
|
183 |
+
"""
|
184 |
+
done, seed = False, 13
|
185 |
+
|
186 |
+
torch.manual_seed(split_number)
|
187 |
+
perm = torch.randperm(X.shape[0]) if split_number > 1 else torch.arange(0, X.shape[0])
|
188 |
+
X, y = X[perm], y[perm]
|
189 |
+
|
190 |
+
while not done:
|
191 |
+
if seed > 20:
|
192 |
+
return None, None # No split could be generated in 7 passes, return None
|
193 |
+
random.seed(seed)
|
194 |
+
i = random.randint(0, len(X) - bptt) if len(X) - bptt > 0 else 0
|
195 |
+
y_ = y[i:i + bptt]
|
196 |
+
|
197 |
+
# Checks if all classes from dataset are contained and classes in train and test are equal (contain same
|
198 |
+
# classes) and
|
199 |
+
done = len(torch.unique(y_)) == len(torch.unique(y))
|
200 |
+
done = done and torch.all(torch.unique(y_) == torch.unique(y))
|
201 |
+
done = done and len(torch.unique(y_[:eval_position])) == len(torch.unique(y_[eval_position:]))
|
202 |
+
done = done and torch.all(torch.unique(y_[:eval_position]) == torch.unique(y_[eval_position:]))
|
203 |
+
seed = seed + 1
|
204 |
+
|
205 |
+
eval_xs = torch.stack([X[i:i + bptt].clone()], 1)
|
206 |
+
eval_ys = torch.stack([y[i:i + bptt].clone()], 1)
|
207 |
+
|
208 |
+
return eval_xs, eval_ys
|
209 |
+
|
210 |
+
|
211 |
+
def evaluate_position(X, y, categorical_feats, model, bptt
|
212 |
+
, eval_position, overwrite, save, base_path, path_interfix, method, ds_name, fetch_only=False
|
213 |
+
, max_time=300, split_number=1
|
214 |
+
, per_step_normalization=False, **kwargs):
|
215 |
+
"""
|
216 |
+
Evaluates a dataset with a 'bptt' number of training samples.
|
217 |
+
|
218 |
+
:param X: Dataset X
|
219 |
+
:param y: Dataset labels
|
220 |
+
:param categorical_feats: Indices of categorical features.
|
221 |
+
:param model: Model function
|
222 |
+
:param bptt: Sequence length.
|
223 |
+
:param eval_position: Number of training samples.
|
224 |
+
:param overwrite: Wheater to ove
|
225 |
+
:param overwrite: If True, results on disk are overwritten.
|
226 |
+
:param save:
|
227 |
+
:param path_interfix: Used for constructing path to write on disk.
|
228 |
+
:param method: Model name.
|
229 |
+
:param ds_name: Datset name.
|
230 |
+
:param fetch_only: Wheater to calculate or only fetch results.
|
231 |
+
:param per_step_normalization:
|
232 |
+
:param kwargs:
|
233 |
+
:return:
|
234 |
+
"""
|
235 |
+
|
236 |
+
if save:
|
237 |
+
path = os.path.join(base_path, f'results/tabular/{path_interfix}/results_{method}_{ds_name}_{eval_position}_{bptt}_{split_number}.npy')
|
238 |
+
#log_path =
|
239 |
+
|
240 |
+
## Load results if on disk
|
241 |
+
if not overwrite:
|
242 |
+
result = check_file_exists(path)
|
243 |
+
if result is not None:
|
244 |
+
if not fetch_only:
|
245 |
+
print(f'Loaded saved result for {path}')
|
246 |
+
return result
|
247 |
+
elif fetch_only:
|
248 |
+
print(f'Could not load saved result for {path}')
|
249 |
+
return None
|
250 |
+
|
251 |
+
## Generate data splits
|
252 |
+
eval_xs, eval_ys = generate_valid_split(X, y, bptt, eval_position, split_number=split_number)
|
253 |
+
if eval_xs is None:
|
254 |
+
return None
|
255 |
+
print(f"No dataset could be generated {ds_name} {bptt}")
|
256 |
+
|
257 |
+
eval_ys = (eval_ys > torch.unique(eval_ys).unsqueeze(0)).sum(axis=1).unsqueeze(-1)
|
258 |
+
|
259 |
+
start_time = time.time()
|
260 |
+
|
261 |
+
if isinstance(model, nn.Module): # Two separate predict interfaces for transformer and baselines
|
262 |
+
outputs, best_configs = transformer_predict(model, eval_xs, eval_ys, eval_position, categorical_feats=categorical_feats, **kwargs), None
|
263 |
+
else:
|
264 |
+
_, outputs, best_configs = baseline_predict(model, eval_xs, eval_ys, categorical_feats
|
265 |
+
, eval_pos=eval_position
|
266 |
+
, max_time=max_time, **kwargs)
|
267 |
+
|
268 |
+
eval_ys = eval_ys[eval_position:]
|
269 |
+
if outputs is None:
|
270 |
+
return None
|
271 |
+
|
272 |
+
if torch.is_tensor(outputs): # Transfers data to cpu for saving
|
273 |
+
outputs = outputs.cpu()
|
274 |
+
eval_ys = eval_ys.cpu()
|
275 |
+
|
276 |
+
ds_result = None, outputs, eval_ys, best_configs, time.time() - start_time
|
277 |
+
|
278 |
+
if save:
|
279 |
+
with open(path, 'wb') as f:
|
280 |
+
np.save(f, ds_result)
|
281 |
+
print(f'saved results to {path}')
|
282 |
+
|
283 |
+
return ds_result
|
TabPFN/train.py
ADDED
@@ -0,0 +1,386 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import itertools
|
3 |
+
import argparse
|
4 |
+
import time
|
5 |
+
import datetime
|
6 |
+
import yaml
|
7 |
+
from contextlib import nullcontext
|
8 |
+
|
9 |
+
|
10 |
+
import torch
|
11 |
+
from torch import nn
|
12 |
+
|
13 |
+
import utils
|
14 |
+
from transformer import TransformerModel
|
15 |
+
from utils import get_cosine_schedule_with_warmup, get_openai_lr, StoreDictKeyPair, get_weighted_single_eval_pos_sampler, get_uniform_single_eval_pos_sampler
|
16 |
+
import priors
|
17 |
+
import encoders
|
18 |
+
import positional_encodings
|
19 |
+
from utils import init_dist
|
20 |
+
from torch.cuda.amp import autocast
|
21 |
+
|
22 |
+
class Losses():
|
23 |
+
gaussian = nn.GaussianNLLLoss(full=True, reduction='none')
|
24 |
+
mse = nn.MSELoss(reduction='none')
|
25 |
+
ce = lambda weight : nn.CrossEntropyLoss(reduction='none', weight=weight)
|
26 |
+
bce = nn.BCEWithLogitsLoss(reduction='none')
|
27 |
+
|
28 |
+
|
29 |
+
def train(priordataloader_class, criterion, encoder_generator, emsize=200, nhid=200, nlayers=6, nhead=2, dropout=0.2,
|
30 |
+
epochs=10, steps_per_epoch=100, batch_size=200, bptt=10, lr=None, weight_decay=0.0, warmup_epochs=10, input_normalization=False,
|
31 |
+
y_encoder_generator=None, pos_encoder_generator=None, decoder=None, extra_prior_kwargs_dict={}, scheduler=get_cosine_schedule_with_warmup,
|
32 |
+
load_weights_from_this_state_dict=None, validation_period=10, single_eval_pos_gen=None, bptt_extra_samples=None, gpu_device='cuda:0',
|
33 |
+
aggregate_k_gradients=1, verbose=True, style_encoder_generator=None, check_is_compatible=True, epoch_callback=None,
|
34 |
+
initializer=None, initialize_with_model=None, train_mixed_precision=False, total_available_time_in_s=None, normalize_labels=True, **model_extra_args
|
35 |
+
):
|
36 |
+
assert (epochs is None) != (total_available_time_in_s is None)
|
37 |
+
start_of_training = time.time()
|
38 |
+
device = gpu_device if torch.cuda.is_available() else 'cpu:0'
|
39 |
+
print(f'Using {device} device')
|
40 |
+
using_dist, rank, device = init_dist(device)
|
41 |
+
bptt_sampler = (lambda : single_eval_pos_gen() + bptt_extra_samples if callable(single_eval_pos_gen) else single_eval_pos_gen + bptt_extra_samples) if bptt_extra_samples is not None else bptt
|
42 |
+
dl = priordataloader_class(num_steps=steps_per_epoch, batch_size=batch_size, seq_len=bptt_sampler, seq_len_maximum=bptt+(bptt_extra_samples if bptt_extra_samples else 0), device=device, **extra_prior_kwargs_dict)
|
43 |
+
if dl.fuse_x_y:
|
44 |
+
raise Exception("Illegal parameter")
|
45 |
+
|
46 |
+
encoder = encoder_generator(dl.num_features+1 if dl.fuse_x_y else dl.num_features,emsize)
|
47 |
+
style_def = next(iter(dl))[0][0] # This is (style, x, y), target with x and y with batch size
|
48 |
+
|
49 |
+
style_encoder = style_encoder_generator(hyperparameter_definitions=style_def[0], em_size=emsize) if (style_def is not None) else None
|
50 |
+
n_out = dl.num_outputs
|
51 |
+
if isinstance(criterion, nn.GaussianNLLLoss):
|
52 |
+
n_out *= 2
|
53 |
+
elif isinstance(criterion, nn.CrossEntropyLoss):
|
54 |
+
n_out *= criterion.weight.shape[0]
|
55 |
+
model = TransformerModel(encoder, n_out, emsize, nhead, nhid, nlayers, dropout, style_encoder=style_encoder,
|
56 |
+
y_encoder=y_encoder_generator(dl.num_outputs, emsize), input_normalization=input_normalization,
|
57 |
+
pos_encoder=(pos_encoder_generator or positional_encodings.NoPositionalEncoding)(emsize, bptt*2),
|
58 |
+
decoder=decoder, init_method=initializer, **model_extra_args
|
59 |
+
)
|
60 |
+
model.criterion = criterion
|
61 |
+
if load_weights_from_this_state_dict is not None:
|
62 |
+
model.load_state_dict(load_weights_from_this_state_dict)
|
63 |
+
if initialize_with_model is not None:
|
64 |
+
model.init_from_small_model(initialize_with_model)
|
65 |
+
|
66 |
+
print(f"Using a Transformer with {sum(p.numel() for p in model.parameters())/1000/1000:.{2}f} M parameters")
|
67 |
+
|
68 |
+
try:
|
69 |
+
for (k, v), (k2, v2) in zip(model.state_dict().items(), initialize_with_model.state_dict().items()):
|
70 |
+
print(k, ((v - v2) / v).abs().mean(), v.shape)
|
71 |
+
except Exception:
|
72 |
+
pass
|
73 |
+
|
74 |
+
model.to(device)
|
75 |
+
if using_dist:
|
76 |
+
print("Distributed training")
|
77 |
+
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[rank], output_device=rank, broadcast_buffers=False)
|
78 |
+
|
79 |
+
|
80 |
+
# learning rate
|
81 |
+
if lr is None:
|
82 |
+
lr = get_openai_lr(model)
|
83 |
+
print(f"Using OpenAI max lr of {lr}.")
|
84 |
+
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
|
85 |
+
scheduler = scheduler(optimizer, warmup_epochs, epochs if epochs is not None else 100) # when training for fixed time lr schedule takes 100 steps
|
86 |
+
|
87 |
+
def train_step():
|
88 |
+
model.train() # Turn on the train mode
|
89 |
+
total_loss = 0.
|
90 |
+
total_positional_losses = 0.
|
91 |
+
total_positional_losses_recorded = 0
|
92 |
+
before_get_batch = time.time()
|
93 |
+
assert len(dl) % aggregate_k_gradients == 0, 'Please set the number of steps per epoch s.t. `aggregate_k_gradients` divides it.'
|
94 |
+
valid_batch_steps = 0.0
|
95 |
+
for batch, (data, targets) in enumerate(dl):
|
96 |
+
if using_dist and not (batch % aggregate_k_gradients == aggregate_k_gradients - 1):
|
97 |
+
cm = model.no_sync()
|
98 |
+
#print(f'p={rank}, no_sync', force=True)
|
99 |
+
else:
|
100 |
+
cm = nullcontext()
|
101 |
+
#print(f'p={rank}, sync', force=True)
|
102 |
+
with cm:
|
103 |
+
time_to_get_batch = time.time() - before_get_batch
|
104 |
+
before_forward = time.time()
|
105 |
+
if bptt_extra_samples is None:
|
106 |
+
single_eval_pos = single_eval_pos_gen() if callable(single_eval_pos_gen) else single_eval_pos_gen
|
107 |
+
else:
|
108 |
+
single_eval_pos = targets.shape[0] - bptt_extra_samples
|
109 |
+
|
110 |
+
is_compatible = torch.ones((targets.shape[1])).bool()
|
111 |
+
if check_is_compatible or normalize_labels:
|
112 |
+
for b in range(targets.shape[1]):
|
113 |
+
targets_in_train = torch.unique(targets[:single_eval_pos, b], sorted=True)
|
114 |
+
targets_in_eval = torch.unique(targets[single_eval_pos:, b], sorted=True)
|
115 |
+
|
116 |
+
if check_is_compatible:
|
117 |
+
is_compatible[b] = len(targets_in_train) == len(targets_in_eval) and (targets_in_train == targets_in_eval).all()
|
118 |
+
is_compatible[b] = is_compatible[b] and len(targets_in_train) > 1
|
119 |
+
|
120 |
+
# Set targets to range starting from 0 (e.g. targets 0, 2, 5, 2 will be converted to 0, 1, 2, 1)
|
121 |
+
if normalize_labels:
|
122 |
+
targets[:, b] = (targets[:, b] > torch.unique(targets[:, b]).unsqueeze(1)).sum(axis=0).unsqueeze(0)
|
123 |
+
valid_batch_steps += is_compatible.float().mean()
|
124 |
+
is_compatible = is_compatible.to(device)
|
125 |
+
#if using_dist and check_is_compatible:
|
126 |
+
# print('step share before reduce',curr_step_share, force=True)
|
127 |
+
# curr_step_share = curr_step_share.to(device)
|
128 |
+
# torch.distributed.all_reduce_multigpu([curr_step_share], op=torch.distributed.ReduceOp.SUM)
|
129 |
+
# curr_step_share = curr_step_share.cpu() / torch.distributed.get_world_size()
|
130 |
+
# print('step share after reduce',curr_step_share, torch.distributed.get_world_size(), force=True)
|
131 |
+
|
132 |
+
# If style is set to None, it should not be transferred to device
|
133 |
+
output = model(tuple(e.to(device) if torch.is_tensor(e) else e for e in data) if isinstance(data, tuple) else data.to(device)
|
134 |
+
, single_eval_pos=single_eval_pos)
|
135 |
+
|
136 |
+
forward_time = time.time() - before_forward
|
137 |
+
|
138 |
+
#output, targets = output[:, is_compatible], targets[:, is_compatible]
|
139 |
+
|
140 |
+
if single_eval_pos is not None:
|
141 |
+
targets = targets[single_eval_pos:]
|
142 |
+
if isinstance(criterion, nn.GaussianNLLLoss):
|
143 |
+
assert output.shape[-1] == 2, \
|
144 |
+
'need to write a little bit of code to handle multiple regression targets at once'
|
145 |
+
|
146 |
+
mean_pred = output[..., 0]
|
147 |
+
var_pred = output[..., 1].abs()
|
148 |
+
losses = criterion(mean_pred.flatten(), targets.to(device).flatten(), var=var_pred.flatten())
|
149 |
+
elif isinstance(criterion, (nn.MSELoss, nn.BCEWithLogitsLoss)):
|
150 |
+
losses = criterion(output.flatten(), targets.to(device).flatten())
|
151 |
+
elif isinstance(criterion, (nn.CrossEntropyLoss)):
|
152 |
+
#print(n_out, targets.min(), targets.max(), force=True)
|
153 |
+
losses = criterion(output.reshape(-1, n_out), targets.to(device).long().flatten())
|
154 |
+
else:
|
155 |
+
losses = criterion(output.reshape(-1, n_out), targets.to(device).flatten())
|
156 |
+
losses = losses.view(*output.shape[0:2])
|
157 |
+
loss = losses.mean(0) @ is_compatible.float() / losses.shape[1]
|
158 |
+
#loss = torch_nanmean(losses, axis=[0, 1]) * is_compatible.float().mean()
|
159 |
+
# not sure whether we can go without the nan checks.
|
160 |
+
|
161 |
+
loss.backward()
|
162 |
+
|
163 |
+
if ((batch % aggregate_k_gradients == aggregate_k_gradients - 1) and (not check_is_compatible or using_dist))\
|
164 |
+
or (valid_batch_steps >= aggregate_k_gradients and (check_is_compatible and not using_dist)):
|
165 |
+
with torch.no_grad():
|
166 |
+
for p in model.parameters():
|
167 |
+
if p.grad is not None:
|
168 |
+
p.grad.div_(valid_batch_steps)
|
169 |
+
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.)
|
170 |
+
try:
|
171 |
+
optimizer.step()
|
172 |
+
except:
|
173 |
+
print("Invalid optimization step encountered")
|
174 |
+
optimizer.zero_grad()
|
175 |
+
valid_batch_steps = 0.0
|
176 |
+
|
177 |
+
step_time = time.time() - before_forward
|
178 |
+
|
179 |
+
if not torch.isnan(loss):
|
180 |
+
total_loss += loss.item()
|
181 |
+
total_positional_losses += losses.mean(1).cpu().detach() if single_eval_pos is None else \
|
182 |
+
nn.functional.one_hot(torch.tensor(single_eval_pos), bptt)*loss.cpu().detach()
|
183 |
+
|
184 |
+
total_positional_losses_recorded += torch.ones(bptt) if single_eval_pos is None else \
|
185 |
+
nn.functional.one_hot(torch.tensor(single_eval_pos), bptt)
|
186 |
+
|
187 |
+
before_get_batch = time.time()
|
188 |
+
return total_loss / steps_per_epoch, (
|
189 |
+
total_positional_losses / total_positional_losses_recorded).tolist(), time_to_get_batch, forward_time, step_time
|
190 |
+
|
191 |
+
best_val_loss = float("inf")
|
192 |
+
best_model = None
|
193 |
+
total_loss = float('inf')
|
194 |
+
total_positional_losses = float('inf')
|
195 |
+
try:
|
196 |
+
for epoch in (range(1, epochs + 1) if epochs is not None else itertools.count(1)):
|
197 |
+
|
198 |
+
epoch_start_time = time.time()
|
199 |
+
if train_mixed_precision:
|
200 |
+
with autocast():
|
201 |
+
total_loss, total_positional_losses, time_to_get_batch, forward_time, step_time = train_step()
|
202 |
+
else:
|
203 |
+
total_loss, total_positional_losses, time_to_get_batch, forward_time, step_time = train_step()
|
204 |
+
if hasattr(dl, 'validate') and epoch % validation_period == 0:
|
205 |
+
with torch.no_grad():
|
206 |
+
val_score = dl.validate(model)
|
207 |
+
else:
|
208 |
+
val_score = None
|
209 |
+
|
210 |
+
if verbose:
|
211 |
+
print('-' * 89)
|
212 |
+
print(
|
213 |
+
f'| end of epoch {epoch:3d} | time: {(time.time() - epoch_start_time):5.2f}s | mean loss {total_loss:5.2f} | '
|
214 |
+
f"pos losses {','.join([f'{l:5.2f}' for l in total_positional_losses])}, lr {scheduler.get_last_lr()[0]}"
|
215 |
+
f' data time {time_to_get_batch:5.2f} step time {step_time:5.2f}'
|
216 |
+
f' forward time {forward_time:5.2f}' + (f'val score {val_score}' if val_score is not None else ''))
|
217 |
+
print('-' * 89)
|
218 |
+
|
219 |
+
# stepping with wallclock time based scheduler
|
220 |
+
current_time = time.time()
|
221 |
+
if epoch_callback is not None and rank == 0:
|
222 |
+
epoch_callback(model, epoch / epochs if total_available_time_in_s is None else # noqa
|
223 |
+
(current_time - start_of_training) / total_available_time_in_s # noqa
|
224 |
+
)
|
225 |
+
if epochs is None and (current_time - start_of_training) > total_available_time_in_s: # noqa
|
226 |
+
break
|
227 |
+
if epochs is None:
|
228 |
+
scheduler.step((current_time - epoch_start_time) / total_available_time_in_s * 100)
|
229 |
+
else:
|
230 |
+
scheduler.step()
|
231 |
+
except KeyboardInterrupt:
|
232 |
+
pass
|
233 |
+
|
234 |
+
return total_loss, total_positional_losses, model.to('cpu'), dl
|
235 |
+
|
236 |
+
def _parse_args(config_parser, parser):
|
237 |
+
# Do we have a config file to parse?
|
238 |
+
args_config, remaining = config_parser.parse_known_args()
|
239 |
+
if args_config.config:
|
240 |
+
with open(args_config.config, 'r') as f:
|
241 |
+
cfg = yaml.safe_load(f)
|
242 |
+
parser.set_defaults(**cfg)
|
243 |
+
|
244 |
+
# The main arg parser parses the rest of the args, the usual
|
245 |
+
# defaults will have been overridden if config file specified.
|
246 |
+
args = parser.parse_args(remaining)
|
247 |
+
|
248 |
+
# Cache the args as a text string to save them in the output dir later
|
249 |
+
args_text = yaml.safe_dump(args.__dict__, default_flow_style=False)
|
250 |
+
return args, args_text
|
251 |
+
|
252 |
+
|
253 |
+
if __name__ == '__main__':
|
254 |
+
config_parser = argparse.ArgumentParser(description='Only used as a first parser for the config file path.')
|
255 |
+
config_parser.add_argument('--config')
|
256 |
+
parser = argparse.ArgumentParser()
|
257 |
+
parser.add_argument('prior')
|
258 |
+
parser.add_argument('--loss_function', default='barnll')
|
259 |
+
# Optional Arg's for `--loss_function barnll`
|
260 |
+
parser.add_argument('--min_y', type=float, help='barnll can only model y in strict ranges, this is the minimum y can take.')
|
261 |
+
parser.add_argument('--max_y', type=float, help='barnll can only model y in strict ranges, this is the maximum y can take.')
|
262 |
+
parser.add_argument('--num_buckets', default=100, type=int)
|
263 |
+
#parser.add_argument('--num_features', default=None, type=int, help='Specify depending on the prior.')
|
264 |
+
parser.add_argument("--extra_prior_kwargs_dict", default={'fuse_x_y': False}, dest="extra_prior_kwargs_dict", action=StoreDictKeyPair, nargs="+", metavar="KEY=VAL", help='Specify depending on the prior.')
|
265 |
+
parser.add_argument('--encoder', default='linear', type=str, help='Specify depending on the prior.')
|
266 |
+
parser.add_argument('--y_encoder', default='linear', type=str, help='Specify depending on the prior. You should specify this if you do not fuse x and y.')
|
267 |
+
parser.add_argument('--pos_encoder', default='sinus', type=str, help='Specify depending on the prior.')
|
268 |
+
parser.add_argument('--bptt', default=10, type=int)
|
269 |
+
parser.add_argument('--epochs', default=200, type=int)
|
270 |
+
parser.add_argument('--warmup_epochs', default=50, type=int)
|
271 |
+
parser.add_argument('--validation_period', default=10, type=int)
|
272 |
+
parser.add_argument('--permutation_invariant_max_eval_pos', default=None, type=int, help='Set this to an int to ')
|
273 |
+
parser.add_argument('--permutation_invariant_sampling', default='weighted', help="Only relevant if --permutation_invariant_max_eval_pos is set.")
|
274 |
+
|
275 |
+
# these can likely be mostly left at defaults
|
276 |
+
parser.add_argument('--emsize', default=512, type=int) # sometimes even larger is better e.g. 1024
|
277 |
+
parser.add_argument('--nlayers', default=6, type=int)
|
278 |
+
parser.add_argument('--nhid', default=None, type=int) # 2*emsize is the default
|
279 |
+
parser.add_argument('--nhead', default=4, type=int) # nhead = emsize / 64 in the original paper
|
280 |
+
parser.add_argument('--dropout', default=.0, type=float)
|
281 |
+
parser.add_argument('--steps_per_epoch', default=10, type=int)
|
282 |
+
parser.add_argument('--batch_size', default=1000, type=int)
|
283 |
+
parser.add_argument('--lr', '--learning_rate', default=.001, type=float) # try also .0003, .0001, go lower with lower batch size
|
284 |
+
|
285 |
+
args, _ = _parse_args(config_parser, parser)
|
286 |
+
|
287 |
+
if args.nhid is None:
|
288 |
+
args.nhid = 2*args.emsize
|
289 |
+
|
290 |
+
prior = args.__dict__.pop('prior')
|
291 |
+
|
292 |
+
if prior == 'gp':
|
293 |
+
prior = priors.fast_gp.DataLoader
|
294 |
+
elif prior == 'ridge':
|
295 |
+
prior = priors.ridge.DataLoader
|
296 |
+
elif prior == 'stroke':
|
297 |
+
prior = priors.stroke.DataLoader
|
298 |
+
elif prior == 'mix_gp':
|
299 |
+
prior = priors.fast_gp_mix.DataLoader
|
300 |
+
else:
|
301 |
+
raise NotImplementedError(f'Prior == {prior}.')
|
302 |
+
|
303 |
+
loss_function = args.__dict__.pop('loss_function')
|
304 |
+
|
305 |
+
criterion = nn.GaussianNLLLoss(reduction='none', full=True)
|
306 |
+
classificiation_criterion = nn.CrossEntropyLoss(reduction='none')
|
307 |
+
num_buckets = args.__dict__.pop('num_buckets')
|
308 |
+
max_y = args.__dict__.pop('max_y')
|
309 |
+
min_y = args.__dict__.pop('min_y')
|
310 |
+
# criterion = nn.MSELoss(reduction='none')
|
311 |
+
|
312 |
+
def get_y_sample():
|
313 |
+
dl = prior(num_steps=1, batch_size=args.batch_size * args.steps_per_epoch, seq_len=args.bptt, device=device,
|
314 |
+
**args.extra_prior_kwargs_dict)
|
315 |
+
y_sample = next(iter(dl))[-1]
|
316 |
+
print(f'Creating Bar distribution with borders from y sample of size {y_sample.numel()}')
|
317 |
+
return y_sample
|
318 |
+
|
319 |
+
if loss_function == 'ce':
|
320 |
+
criterion = nn.CrossEntropyLoss(reduction='none')
|
321 |
+
elif loss_function == 'gaussnll':
|
322 |
+
criterion = nn.GaussianNLLLoss(reduction='none', full=True)
|
323 |
+
elif loss_function == 'mse':
|
324 |
+
criterion = nn.MSELoss(reduction='none')
|
325 |
+
elif loss_function == 'barnll':
|
326 |
+
criterion = BarDistribution(borders=get_bucket_limits(num_buckets, full_range=(min_y,max_y)))
|
327 |
+
elif loss_function == 'adaptivebarnll':
|
328 |
+
borders = get_bucket_limits(num_buckets, ys=get_y_sample(), full_range=(min_y,max_y))
|
329 |
+
criterion = BarDistribution(borders=borders)
|
330 |
+
elif loss_function == 'adaptivefullsupportbarnll':
|
331 |
+
assert min_y is None and max_y is None, "Please do not specify `min_y` and `max_y` with `unboundedadaptivebarnll`."
|
332 |
+
borders = get_bucket_limits(num_buckets, ys=get_y_sample())
|
333 |
+
criterion = FullSupportBarDistribution(borders=borders)
|
334 |
+
else:
|
335 |
+
raise NotImplementedError(f'loss_function == {loss_function}.')
|
336 |
+
|
337 |
+
|
338 |
+
|
339 |
+
encoder = args.__dict__.pop('encoder')
|
340 |
+
y_encoder = args.__dict__.pop('y_encoder')
|
341 |
+
|
342 |
+
def get_encoder_generator(encoder):
|
343 |
+
if encoder == 'linear':
|
344 |
+
encoder_generator = encoders.Linear
|
345 |
+
elif encoder == 'mlp':
|
346 |
+
encoder_generator = encoders.MLP
|
347 |
+
elif encoder == 'positional':
|
348 |
+
encoder_generator = encoders.Positional
|
349 |
+
else:
|
350 |
+
raise NotImplementedError(f'A {encoder} encoder is not valid.')
|
351 |
+
return encoder_generator
|
352 |
+
|
353 |
+
encoder_generator = get_encoder_generator(encoder)
|
354 |
+
y_encoder_generator = get_encoder_generator(y_encoder)
|
355 |
+
|
356 |
+
pos_encoder = args.__dict__.pop('pos_encoder')
|
357 |
+
|
358 |
+
if pos_encoder == 'none':
|
359 |
+
pos_encoder_generator = None
|
360 |
+
elif pos_encoder == 'sinus':
|
361 |
+
pos_encoder_generator = positional_encodings.PositionalEncoding
|
362 |
+
elif pos_encoder == 'learned':
|
363 |
+
pos_encoder_generator = positional_encodings.LearnedPositionalEncoding
|
364 |
+
elif pos_encoder == 'paired_scrambled_learned':
|
365 |
+
pos_encoder_generator = positional_encodings.PairedScrambledPositionalEncodings
|
366 |
+
else:
|
367 |
+
raise NotImplementedError(f'pos_encoer == {pos_encoder} is not valid.')
|
368 |
+
|
369 |
+
permutation_invariant_max_eval_pos = args.__dict__.pop('permutation_invariant_max_eval_pos')
|
370 |
+
permutation_invariant_sampling = args.__dict__.pop('permutation_invariant_sampling')
|
371 |
+
if permutation_invariant_max_eval_pos is not None:
|
372 |
+
if permutation_invariant_sampling == 'weighted':
|
373 |
+
get_sampler = get_weighted_single_eval_pos_sampler
|
374 |
+
elif permutation_invariant_sampling == 'uniform':
|
375 |
+
get_sampler = get_uniform_single_eval_pos_sampler
|
376 |
+
else:
|
377 |
+
raise ValueError()
|
378 |
+
args.__dict__['single_eval_pos_gen'] = get_sampler(permutation_invariant_max_eval_pos)
|
379 |
+
|
380 |
+
|
381 |
+
print("ARGS for `train`:", args.__dict__)
|
382 |
+
|
383 |
+
train(prior, criterion, encoder_generator,
|
384 |
+
y_encoder_generator=y_encoder_generator, pos_encoder_generator=pos_encoder_generator,
|
385 |
+
**args.__dict__)
|
386 |
+
|
TabPFN/transformer.py
ADDED
@@ -0,0 +1,226 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
from typing import Optional
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
from torch import Tensor
|
7 |
+
from torch.nn import Module, TransformerEncoder
|
8 |
+
|
9 |
+
from layer import TransformerEncoderLayer, _get_activation_fn
|
10 |
+
from utils import SeqBN, bool_mask_to_att_mask
|
11 |
+
|
12 |
+
|
13 |
+
|
14 |
+
class TransformerModel(nn.Module):
|
15 |
+
def __init__(self, encoder, n_out, ninp, nhead, nhid, nlayers, dropout=0.0, style_encoder=None, y_encoder=None,
|
16 |
+
pos_encoder=None, decoder=None, input_normalization=False, init_method=None, pre_norm=False,
|
17 |
+
activation='gelu', recompute_attn=False, num_global_att_tokens=0, full_attention=False,
|
18 |
+
all_layers_same_init=True):
|
19 |
+
super().__init__()
|
20 |
+
self.model_type = 'Transformer'
|
21 |
+
encoder_layer_creator = lambda: TransformerEncoderLayer(ninp, nhead, nhid, dropout, activation=activation,
|
22 |
+
pre_norm=pre_norm, recompute_attn=recompute_attn)
|
23 |
+
self.transformer_encoder = TransformerEncoder(encoder_layer_creator(), nlayers)\
|
24 |
+
if all_layers_same_init else TransformerEncoderDiffInit(encoder_layer_creator, nlayers)
|
25 |
+
self.ninp = ninp
|
26 |
+
self.encoder = encoder
|
27 |
+
self.y_encoder = y_encoder
|
28 |
+
self.pos_encoder = pos_encoder
|
29 |
+
self.decoder = decoder(ninp, nhid, n_out) if decoder is not None else nn.Sequential(nn.Linear(ninp, nhid), nn.GELU(), nn.Linear(nhid, n_out))
|
30 |
+
self.input_ln = SeqBN(ninp) if input_normalization else None
|
31 |
+
self.style_encoder = style_encoder
|
32 |
+
self.init_method = init_method
|
33 |
+
if num_global_att_tokens is not None:
|
34 |
+
assert not full_attention
|
35 |
+
self.global_att_embeddings = nn.Embedding(num_global_att_tokens, ninp) if num_global_att_tokens else None
|
36 |
+
self.full_attention = full_attention
|
37 |
+
|
38 |
+
self.n_out = n_out
|
39 |
+
self.nhid = nhid
|
40 |
+
|
41 |
+
self.init_weights()
|
42 |
+
|
43 |
+
@staticmethod
|
44 |
+
def generate_square_subsequent_mask(sz):
|
45 |
+
mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
|
46 |
+
return bool_mask_to_att_mask(mask)
|
47 |
+
|
48 |
+
@staticmethod
|
49 |
+
def generate_D_q_matrix(sz, query_size):
|
50 |
+
train_size = sz-query_size
|
51 |
+
mask = torch.zeros(sz,sz) == 0
|
52 |
+
mask[:,train_size:].zero_()
|
53 |
+
mask |= torch.eye(sz) == 1
|
54 |
+
return bool_mask_to_att_mask(mask)
|
55 |
+
|
56 |
+
@staticmethod
|
57 |
+
def generate_global_att_query_matrix(num_global_att_tokens, seq_len, num_query_tokens):
|
58 |
+
train_size = seq_len + num_global_att_tokens - num_query_tokens
|
59 |
+
sz = seq_len + num_global_att_tokens
|
60 |
+
mask = torch.zeros(num_query_tokens, sz) == 0
|
61 |
+
mask[:,train_size:].zero_()
|
62 |
+
mask[:,train_size:] |= torch.eye(num_query_tokens) == 1
|
63 |
+
return bool_mask_to_att_mask(mask)
|
64 |
+
|
65 |
+
@staticmethod
|
66 |
+
def generate_global_att_trainset_matrix(num_global_att_tokens, seq_len, num_query_tokens):
|
67 |
+
train_size = seq_len + num_global_att_tokens - num_query_tokens
|
68 |
+
trainset_size = seq_len - num_query_tokens
|
69 |
+
mask = torch.zeros(trainset_size, num_global_att_tokens) == 0
|
70 |
+
#mask[:,num_global_att_tokens:].zero_()
|
71 |
+
#mask[:,num_global_att_tokens:] |= torch.eye(trainset_size) == 1
|
72 |
+
return bool_mask_to_att_mask(mask)
|
73 |
+
|
74 |
+
@staticmethod
|
75 |
+
def generate_global_att_globaltokens_matrix(num_global_att_tokens, seq_len, num_query_tokens):
|
76 |
+
mask = torch.zeros(num_global_att_tokens, num_global_att_tokens+seq_len-num_query_tokens) == 0
|
77 |
+
return bool_mask_to_att_mask(mask)
|
78 |
+
|
79 |
+
def init_weights(self):
|
80 |
+
initrange = 1.
|
81 |
+
# if isinstance(self.encoder,EmbeddingEncoder):
|
82 |
+
# self.encoder.weight.data.uniform_(-initrange, initrange)
|
83 |
+
# self.decoder.bias.data.zero_()
|
84 |
+
# self.decoder.weight.data.uniform_(-initrange, initrange)
|
85 |
+
if self.init_method is not None:
|
86 |
+
self.apply(self.init_method)
|
87 |
+
for layer in self.transformer_encoder.layers:
|
88 |
+
nn.init.zeros_(layer.linear2.weight)
|
89 |
+
nn.init.zeros_(layer.linear2.bias)
|
90 |
+
attns = layer.self_attn if isinstance(layer.self_attn, nn.ModuleList) else [layer.self_attn]
|
91 |
+
for attn in attns:
|
92 |
+
nn.init.zeros_(attn.out_proj.weight)
|
93 |
+
nn.init.zeros_(attn.out_proj.bias)
|
94 |
+
|
95 |
+
def forward(self, src, src_mask=None, single_eval_pos=None):
|
96 |
+
assert isinstance(src, tuple), 'fuse_x_y is forbidden, that is inputs have to be given as (x,y) or (style,x,y)'
|
97 |
+
|
98 |
+
if len(src) == 2:
|
99 |
+
src = (None,) + src
|
100 |
+
|
101 |
+
style_src, style_src_size = (src[0], (0 if (src[0] is None) else 1))
|
102 |
+
if src_mask is not None: assert self.global_att_embeddings is None or isinstance(src_mask, tuple)
|
103 |
+
if src_mask is None:
|
104 |
+
x_src = src[1]
|
105 |
+
if self.global_att_embeddings is None:
|
106 |
+
full_len = len(x_src) + style_src_size
|
107 |
+
if self.full_attention:
|
108 |
+
src_mask = bool_mask_to_att_mask(torch.ones((full_len, full_len), dtype=torch.bool)).to(x_src.device)
|
109 |
+
else:
|
110 |
+
src_mask = self.generate_D_q_matrix(len(x_src) + style_src_size, len(x_src) + style_src_size -single_eval_pos).to(x_src.device)
|
111 |
+
else:
|
112 |
+
src_mask_args = (self.global_att_embeddings.num_embeddings,
|
113 |
+
len(x_src) + style_src_size,
|
114 |
+
len(x_src) + style_src_size - single_eval_pos)
|
115 |
+
src_mask = (self.generate_global_att_globaltokens_matrix(*src_mask_args).to(x_src.device),
|
116 |
+
self.generate_global_att_trainset_matrix(*src_mask_args).to(x_src.device),
|
117 |
+
self.generate_global_att_query_matrix(*src_mask_args).to(x_src.device))
|
118 |
+
|
119 |
+
style_src, x_src, y_src = src
|
120 |
+
x_src = self.encoder(x_src)
|
121 |
+
y_src = self.y_encoder(y_src.unsqueeze(-1) if len(y_src.shape) < len(x_src.shape) else y_src)
|
122 |
+
style_src = self.style_encoder(style_src).unsqueeze(0) if self.style_encoder else torch.tensor([], device=x_src.device)
|
123 |
+
global_src = torch.tensor([], device=x_src.device) if self.global_att_embeddings is None else \
|
124 |
+
self.global_att_embeddings.weight.unsqueeze(1).repeat(1, x_src.shape[1], 1)
|
125 |
+
train_x = x_src[:single_eval_pos] + y_src[:single_eval_pos]
|
126 |
+
src = torch.cat([global_src, style_src, train_x, x_src[single_eval_pos:]], 0)
|
127 |
+
|
128 |
+
if self.input_ln is not None:
|
129 |
+
src = self.input_ln(src)
|
130 |
+
|
131 |
+
if self.pos_encoder is not None:
|
132 |
+
src = self.pos_encoder(src)
|
133 |
+
|
134 |
+
# If we have style input, drop its output
|
135 |
+
output = self.transformer_encoder(src, src_mask)[style_src_size:]
|
136 |
+
output = self.decoder(output)
|
137 |
+
return output[single_eval_pos+(self.global_att_embeddings.num_embeddings if self.global_att_embeddings else 0):]
|
138 |
+
|
139 |
+
@torch.no_grad()
|
140 |
+
def init_from_small_model(self, small_model):
|
141 |
+
assert isinstance(self.decoder, nn.Linear) and isinstance(self.encoder, (nn.Linear, nn.Sequential)) \
|
142 |
+
and isinstance(self.y_encoder, (nn.Linear, nn.Sequential))
|
143 |
+
|
144 |
+
def set_encoder_weights(my_encoder, small_model_encoder):
|
145 |
+
my_encoder_linear, small_encoder_linear = (my_encoder, small_model_encoder) \
|
146 |
+
if isinstance(my_encoder, nn.Linear) else (my_encoder[-1], small_model_encoder[-1])
|
147 |
+
small_in_dim = small_encoder_linear.out_features
|
148 |
+
my_encoder_linear.weight.zero_()
|
149 |
+
my_encoder_linear.bias.zero_()
|
150 |
+
my_encoder_linear.weight[:small_in_dim] = small_encoder_linear.weight
|
151 |
+
my_encoder_linear.bias[:small_in_dim] = small_encoder_linear.bias
|
152 |
+
|
153 |
+
set_encoder_weights(self.encoder, small_model.encoder)
|
154 |
+
set_encoder_weights(self.y_encoder, small_model.y_encoder)
|
155 |
+
|
156 |
+
small_in_dim = small_model.decoder.in_features
|
157 |
+
|
158 |
+
self.decoder.weight[:, :small_in_dim] = small_model.decoder.weight
|
159 |
+
self.decoder.bias = small_model.decoder.bias
|
160 |
+
|
161 |
+
for my_layer, small_layer in zip(self.transformer_encoder.layers, small_model.transformer_encoder.layers):
|
162 |
+
small_hid_dim = small_layer.linear1.out_features
|
163 |
+
my_in_dim = my_layer.linear1.in_features
|
164 |
+
|
165 |
+
# packed along q,k,v order in first dim
|
166 |
+
my_in_proj_w = my_layer.self_attn.in_proj_weight
|
167 |
+
small_in_proj_w = small_layer.self_attn.in_proj_weight
|
168 |
+
|
169 |
+
my_in_proj_w.view(3, my_in_dim, my_in_dim)[:, :small_in_dim, :small_in_dim] = small_in_proj_w.view(3,
|
170 |
+
small_in_dim,
|
171 |
+
small_in_dim)
|
172 |
+
my_layer.self_attn.in_proj_bias.view(3, my_in_dim)[:,
|
173 |
+
:small_in_dim] = small_layer.self_attn.in_proj_bias.view(3, small_in_dim)
|
174 |
+
|
175 |
+
my_layer.self_attn.out_proj.weight[:small_in_dim, :small_in_dim] = small_layer.self_attn.out_proj.weight
|
176 |
+
my_layer.self_attn.out_proj.bias[:small_in_dim] = small_layer.self_attn.out_proj.bias
|
177 |
+
|
178 |
+
my_layer.linear1.weight[:small_hid_dim, :small_in_dim] = small_layer.linear1.weight
|
179 |
+
my_layer.linear1.bias[:small_hid_dim] = small_layer.linear1.bias
|
180 |
+
|
181 |
+
my_layer.linear2.weight[:small_in_dim, :small_hid_dim] = small_layer.linear2.weight
|
182 |
+
my_layer.linear2.bias[:small_in_dim] = small_layer.linear2.bias
|
183 |
+
|
184 |
+
my_layer.norm1.weight[:small_in_dim] = math.sqrt(small_in_dim / my_in_dim) * small_layer.norm1.weight
|
185 |
+
my_layer.norm2.weight[:small_in_dim] = math.sqrt(small_in_dim / my_in_dim) * small_layer.norm2.weight
|
186 |
+
|
187 |
+
my_layer.norm1.bias[:small_in_dim] = small_layer.norm1.bias
|
188 |
+
my_layer.norm2.bias[:small_in_dim] = small_layer.norm2.bias
|
189 |
+
|
190 |
+
|
191 |
+
class TransformerEncoderDiffInit(Module):
|
192 |
+
r"""TransformerEncoder is a stack of N encoder layers
|
193 |
+
|
194 |
+
Args:
|
195 |
+
encoder_layer_creator: a function generating objects of TransformerEncoderLayer class without args (required).
|
196 |
+
num_layers: the number of sub-encoder-layers in the encoder (required).
|
197 |
+
norm: the layer normalization component (optional).
|
198 |
+
"""
|
199 |
+
__constants__ = ['norm']
|
200 |
+
|
201 |
+
def __init__(self, encoder_layer_creator, num_layers, norm=None):
|
202 |
+
super().__init__()
|
203 |
+
self.layers = nn.ModuleList([encoder_layer_creator() for _ in range(num_layers)])
|
204 |
+
self.num_layers = num_layers
|
205 |
+
self.norm = norm
|
206 |
+
|
207 |
+
def forward(self, src: Tensor, mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None) -> Tensor:
|
208 |
+
r"""Pass the input through the encoder layers in turn.
|
209 |
+
|
210 |
+
Args:
|
211 |
+
src: the sequence to the encoder (required).
|
212 |
+
mask: the mask for the src sequence (optional).
|
213 |
+
src_key_padding_mask: the mask for the src keys per batch (optional).
|
214 |
+
|
215 |
+
Shape:
|
216 |
+
see the docs in Transformer class.
|
217 |
+
"""
|
218 |
+
output = src
|
219 |
+
|
220 |
+
for mod in self.layers:
|
221 |
+
output = mod(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask)
|
222 |
+
|
223 |
+
if self.norm is not None:
|
224 |
+
output = self.norm(output)
|
225 |
+
|
226 |
+
return output
|
TabPFN/utils.py
ADDED
@@ -0,0 +1,236 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import math
|
3 |
+
import argparse
|
4 |
+
import random
|
5 |
+
import datetime
|
6 |
+
|
7 |
+
import torch
|
8 |
+
from torch import nn
|
9 |
+
from torch.optim.lr_scheduler import LambdaLR
|
10 |
+
import numpy as np
|
11 |
+
|
12 |
+
# copied from huggingface
|
13 |
+
def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, num_cycles=0.5, last_epoch=-1):
|
14 |
+
""" Create a schedule with a learning rate that decreases following the
|
15 |
+
values of the cosine function between 0 and `pi * cycles` after a warmup
|
16 |
+
period during which it increases linearly between 0 and 1.
|
17 |
+
"""
|
18 |
+
|
19 |
+
def lr_lambda(current_step):
|
20 |
+
if current_step < num_warmup_steps:
|
21 |
+
return float(current_step) / float(max(1, num_warmup_steps))
|
22 |
+
progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
|
23 |
+
return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))
|
24 |
+
|
25 |
+
return LambdaLR(optimizer, lr_lambda, last_epoch)
|
26 |
+
|
27 |
+
# copied from huggingface
|
28 |
+
def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):
|
29 |
+
"""
|
30 |
+
Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after
|
31 |
+
a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer.
|
32 |
+
|
33 |
+
Args:
|
34 |
+
optimizer (:class:`~torch.optim.Optimizer`):
|
35 |
+
The optimizer for which to schedule the learning rate.
|
36 |
+
num_warmup_steps (:obj:`int`):
|
37 |
+
The number of steps for the warmup phase.
|
38 |
+
num_training_steps (:obj:`int`):
|
39 |
+
The total number of training steps.
|
40 |
+
last_epoch (:obj:`int`, `optional`, defaults to -1):
|
41 |
+
The index of the last epoch when resuming training.
|
42 |
+
|
43 |
+
Return:
|
44 |
+
:obj:`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
|
45 |
+
"""
|
46 |
+
|
47 |
+
def lr_lambda(current_step: int):
|
48 |
+
if current_step < num_warmup_steps:
|
49 |
+
return float(current_step) / float(max(1, num_warmup_steps))
|
50 |
+
return max(
|
51 |
+
0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps))
|
52 |
+
)
|
53 |
+
|
54 |
+
return LambdaLR(optimizer, lr_lambda, last_epoch)
|
55 |
+
|
56 |
+
|
57 |
+
def get_openai_lr(transformer_model):
|
58 |
+
num_params = sum(p.numel() for p in transformer_model.parameters())
|
59 |
+
return 0.003239 - 0.0001395 * math.log(num_params)
|
60 |
+
|
61 |
+
|
62 |
+
def get_weighted_single_eval_pos_sampler(max_len):
|
63 |
+
"""
|
64 |
+
This gives a sampler that can be used for `single_eval_pos` which yields good performance for all positions p,
|
65 |
+
where p <= `max_len`. At most `max_len` - 1 examples are shown to the Transformer.
|
66 |
+
:return: Sampler that can be fed to `train()` as `single_eval_pos_gen`.
|
67 |
+
"""
|
68 |
+
return lambda: random.choices(range(max_len), [1 / (max_len - i) for i in range(max_len)])[0]
|
69 |
+
|
70 |
+
|
71 |
+
def get_uniform_single_eval_pos_sampler(max_len, min_len=0):
|
72 |
+
"""
|
73 |
+
Just sample any evaluation position with the same weight
|
74 |
+
:return: Sampler that can be fed to `train()` as `single_eval_pos_gen`.
|
75 |
+
"""
|
76 |
+
return lambda: random.choices(range(min_len, max_len))[0]
|
77 |
+
|
78 |
+
|
79 |
+
class SeqBN(nn.Module):
|
80 |
+
def __init__(self, d_model):
|
81 |
+
super().__init__()
|
82 |
+
self.bn = nn.BatchNorm1d(d_model)
|
83 |
+
self.d_model = d_model
|
84 |
+
|
85 |
+
def forward(self, x):
|
86 |
+
assert self.d_model == x.shape[-1]
|
87 |
+
flat_x = x.view(-1, self.d_model)
|
88 |
+
flat_x = self.bn(flat_x)
|
89 |
+
return flat_x.view(*x.shape)
|
90 |
+
|
91 |
+
|
92 |
+
def set_locals_in_self(locals):
|
93 |
+
self = locals['self']
|
94 |
+
for var_name, val in locals.items():
|
95 |
+
if var_name != 'self': setattr(self, var_name, val)
|
96 |
+
|
97 |
+
|
98 |
+
default_device = 'cuda:0' if torch.cuda.is_available() else 'cpu:0'
|
99 |
+
|
100 |
+
|
101 |
+
# Copied from StackOverflow, but we do an eval on the values additionally
|
102 |
+
class StoreDictKeyPair(argparse.Action):
|
103 |
+
def __init__(self, option_strings, dest, nargs=None, **kwargs):
|
104 |
+
self._nargs = nargs
|
105 |
+
super(StoreDictKeyPair, self).__init__(option_strings, dest, nargs=nargs, **kwargs)
|
106 |
+
|
107 |
+
def __call__(self, parser, namespace, values, option_string=None):
|
108 |
+
my_dict = {}
|
109 |
+
for kv in values:
|
110 |
+
k, v = kv.split("=")
|
111 |
+
try:
|
112 |
+
my_dict[k] = eval(v)
|
113 |
+
except NameError:
|
114 |
+
my_dict[k] = v
|
115 |
+
setattr(namespace, self.dest, my_dict)
|
116 |
+
print("dict values: {}".format(my_dict))
|
117 |
+
|
118 |
+
def get_nan_value(v, set_value_to_nan=0.0):
|
119 |
+
if random.random() < set_value_to_nan:
|
120 |
+
return v
|
121 |
+
else:
|
122 |
+
return random.choice([-999, 0, 1, 999])
|
123 |
+
|
124 |
+
def to_ranking(data):
|
125 |
+
x = (data >= data.unsqueeze(-3))
|
126 |
+
x = x.sum(0)
|
127 |
+
return x
|
128 |
+
# TODO: Is there a better way to do this?
|
129 |
+
# 1. Cmparing to unique elements: When all values are different we still get quadratic blowup
|
130 |
+
# 2. Argsort(Argsort()) returns ranking, but with duplicate values there is an ordering which is problematic
|
131 |
+
# 3. Argsort(Argsort(Unique))->Scatter seems a bit complicated, doesn't have quadratic blowup, but how fast?
|
132 |
+
def to_ranking_low_mem(data):
|
133 |
+
x = torch.zeros_like(data)
|
134 |
+
for col in range(data.shape[-1]):
|
135 |
+
x_ = (data[:, :, col] >= data[:, :, col].unsqueeze(-2))
|
136 |
+
x_ = x_.sum(0)
|
137 |
+
x[:, :, col] = x_
|
138 |
+
return x
|
139 |
+
|
140 |
+
def nan_handling_missing_for_unknown_reason_value(set_value_to_nan=0.0):
|
141 |
+
return get_nan_value(float('nan'), set_value_to_nan)
|
142 |
+
|
143 |
+
def nan_handling_missing_for_no_reason_value(set_value_to_nan=0.0):
|
144 |
+
return get_nan_value(float('-inf'), set_value_to_nan)
|
145 |
+
|
146 |
+
def nan_handling_missing_for_a_reason_value(set_value_to_nan=0.0):
|
147 |
+
return get_nan_value(float('inf'), set_value_to_nan)
|
148 |
+
|
149 |
+
def torch_nanmean(x, axis=0):
|
150 |
+
num = torch.where(torch.isnan(x), torch.full_like(x, 0), torch.full_like(x, 1)).sum(axis=axis)
|
151 |
+
value = torch.where(torch.isnan(x), torch.full_like(x, 0), x).sum(axis=axis)
|
152 |
+
return value / num
|
153 |
+
|
154 |
+
def torch_nanstd(x, axis=0):
|
155 |
+
num = torch.where(torch.isnan(x), torch.full_like(x, 0), torch.full_like(x, 1)).sum(axis=axis)
|
156 |
+
value = torch.where(torch.isnan(x), torch.full_like(x, 0), x).sum(axis=axis)
|
157 |
+
mean = value / num
|
158 |
+
mean_broadcast = torch.repeat_interleave(mean.unsqueeze(axis), x.shape[axis], dim=axis)
|
159 |
+
return torch.sqrt(torch.nansum(torch.square(mean_broadcast - x), axis=axis) / (num - 1))
|
160 |
+
|
161 |
+
def normalize_data(data, normalize_positions=-1):
|
162 |
+
if normalize_positions > 0:
|
163 |
+
mean = torch_nanmean(data[:normalize_positions], axis=0)
|
164 |
+
std = torch_nanstd(data[:normalize_positions], axis=0) + .000001
|
165 |
+
else:
|
166 |
+
mean = torch_nanmean(data, axis=0)
|
167 |
+
std = torch_nanstd(data, axis=0) + .000001
|
168 |
+
data = (data - mean) / std
|
169 |
+
data = torch.clip(data, min=-100, max=100)
|
170 |
+
|
171 |
+
return data
|
172 |
+
|
173 |
+
def remove_outliers(X, n_sigma=4):
|
174 |
+
# Expects T, B, H
|
175 |
+
assert len(X.shape) == 3, "X must be T,B,H"
|
176 |
+
#for b in range(X.shape[1]):
|
177 |
+
#for col in range(X.shape[2]):
|
178 |
+
data = X
|
179 |
+
data_mean, data_std = torch_nanmean(data, axis=0), torch_nanstd(data, axis=0)
|
180 |
+
cut_off = data_std * n_sigma
|
181 |
+
lower, upper = data_mean - cut_off, data_mean + cut_off
|
182 |
+
|
183 |
+
data_clean = X[:].clone()
|
184 |
+
data_clean[torch.logical_or(data > upper, data < lower)] = np.nan
|
185 |
+
data_mean, data_std = torch_nanmean(data_clean, axis=0), torch_nanstd(data_clean, axis=0)
|
186 |
+
cut_off = data_std * n_sigma
|
187 |
+
lower, upper = data_mean - cut_off, data_mean + cut_off
|
188 |
+
|
189 |
+
X = torch.maximum(-torch.log(1+torch.abs(X)) + lower, X)
|
190 |
+
X = torch.minimum(torch.log(1+torch.abs(X)) + upper, X)
|
191 |
+
# print(ds[1][data < lower, col], ds[1][data > upper, col], ds[1][~np.isnan(data), col].shape, data_mean, data_std)
|
192 |
+
return X
|
193 |
+
|
194 |
+
def bool_mask_to_att_mask(mask):
|
195 |
+
return mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
|
196 |
+
|
197 |
+
def print_on_master_only(is_master):
|
198 |
+
import builtins as __builtin__
|
199 |
+
|
200 |
+
builtin_print = __builtin__.print
|
201 |
+
|
202 |
+
def print(*args, **kwargs):
|
203 |
+
force = kwargs.pop("force", False)
|
204 |
+
if is_master or force:
|
205 |
+
builtin_print(*args, **kwargs)
|
206 |
+
|
207 |
+
__builtin__.print = print
|
208 |
+
|
209 |
+
def init_dist(device):
|
210 |
+
if 'SLURM_PROCID' in os.environ and torch.cuda.device_count() > 1:
|
211 |
+
assert device != 'cpu:0'
|
212 |
+
rank = int(os.environ['SLURM_PROCID'])
|
213 |
+
os.environ['MASTER_ADDR'] = 'localhost'
|
214 |
+
os.environ['MASTER_PORT'] = '12355'
|
215 |
+
torch.cuda.set_device(rank)
|
216 |
+
os.environ['CUDA_VISIBLE_DEVICES'] = str(rank)
|
217 |
+
torch.distributed.init_process_group(backend="nccl", init_method="env://", timeout=datetime.timedelta(seconds=20),
|
218 |
+
world_size=torch.cuda.device_count(), rank=rank)
|
219 |
+
torch.distributed.barrier()
|
220 |
+
print_on_master_only(rank == 0)
|
221 |
+
print(f"Distributed training on {torch.cuda.device_count()} GPUs, this is rank {rank}, "
|
222 |
+
"only I can print, but when using print(..., force=True) it will print on all ranks.")
|
223 |
+
|
224 |
+
return True, rank, f'cuda:{rank}'
|
225 |
+
else:
|
226 |
+
print('Not using distributed')
|
227 |
+
# will not change any of the behavior of print, but allows putting the force=True in the print calls
|
228 |
+
print_on_master_only(True)
|
229 |
+
return False, 0, device
|
230 |
+
|
231 |
+
# NOP function for python with statements (x = NOP(); with x:)
|
232 |
+
class NOP():
|
233 |
+
def __enter__(self):
|
234 |
+
pass
|
235 |
+
def __exit__(self, type, value, traceback):
|
236 |
+
pass
|
app.py
ADDED
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
tabpfn_path = 'TabPFN'
|
3 |
+
sys.path.insert(0, tabpfn_path) # our submodule of the TabPFN repo (at 045c8400203ebd062346970b4f2c0ccda5a40618)
|
4 |
+
from TabPFN.scripts.transformer_prediction_interface import TabPFNClassifier
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
import pandas as pd
|
8 |
+
import torch
|
9 |
+
import gradio as gr
|
10 |
+
import openml
|
11 |
+
|
12 |
+
|
13 |
+
def compute(table: np.array):
|
14 |
+
vfunc = np.vectorize(lambda s: len(s))
|
15 |
+
non_empty_row_mask = (vfunc(table).sum(1) != 0)
|
16 |
+
table = table[non_empty_row_mask]
|
17 |
+
empty_mask = table == ''
|
18 |
+
empty_inds = np.where(empty_mask)
|
19 |
+
if not len(empty_inds[0]):
|
20 |
+
return "**Please leave at least one field blank for prediction.**", None
|
21 |
+
if not np.all(empty_inds[1][0] == empty_inds[1]):
|
22 |
+
return "**Please only leave fields of one column blank for prediction.**", None
|
23 |
+
y_column = empty_inds[1][0]
|
24 |
+
eval_lines = empty_inds[0]
|
25 |
+
|
26 |
+
train_table = np.delete(table, eval_lines, axis=0)
|
27 |
+
eval_table = table[eval_lines]
|
28 |
+
|
29 |
+
try:
|
30 |
+
x_train = torch.tensor(np.delete(train_table, y_column, axis=1).astype(np.float32))
|
31 |
+
x_eval = torch.tensor(np.delete(eval_table, y_column, axis=1).astype(np.float32))
|
32 |
+
|
33 |
+
y_train = train_table[:, y_column]
|
34 |
+
except ValueError:
|
35 |
+
return "**Please only add numbers (to the inputs) or leave fields empty.**", None
|
36 |
+
|
37 |
+
classifier = TabPFNClassifier(base_path=tabpfn_path, device='cpu')
|
38 |
+
classifier.fit(x_train, y_train)
|
39 |
+
y_eval, p_eval = classifier.predict(x_eval, return_winning_probability=True)
|
40 |
+
|
41 |
+
# print(file, type(file))
|
42 |
+
out_table = table.copy().astype(str)
|
43 |
+
out_table[eval_lines, y_column] = [f"{y_e} (p={p_e:.2f})" for y_e, p_e in zip(y_eval, p_eval)]
|
44 |
+
return None, out_table
|
45 |
+
|
46 |
+
|
47 |
+
def upload_file(file):
|
48 |
+
if file.name.endswith('.arff'):
|
49 |
+
dataset = openml.datasets.OpenMLDataset('t', 'test', data_file=file.name)
|
50 |
+
X_, _, categorical_indicator_, attribute_names_ = dataset.get_data(
|
51 |
+
dataset_format="array"
|
52 |
+
)
|
53 |
+
df = pd.DataFrame(X_, columns=attribute_names_)
|
54 |
+
return df
|
55 |
+
elif file.name.endswith('.csv') or file.name.endswith('.data'):
|
56 |
+
df = pd.read_csv(file.name, header=None)
|
57 |
+
df.columns = np.arange(len(df.columns))
|
58 |
+
print(df)
|
59 |
+
return df
|
60 |
+
|
61 |
+
|
62 |
+
example = \
|
63 |
+
[
|
64 |
+
[1, 2, 1],
|
65 |
+
[2, 1, 1],
|
66 |
+
[1, 1, 1],
|
67 |
+
[2, 2, 2],
|
68 |
+
[3, 4, 2],
|
69 |
+
[3, 2, 2],
|
70 |
+
[2, 3, '']
|
71 |
+
]
|
72 |
+
|
73 |
+
with gr.Blocks() as demo:
|
74 |
+
gr.Markdown("""This demo allows you to play with the **TabPFN**.
|
75 |
+
You can either change the table manually (we have filled it with a toy benchmark, sum up to 3 has label 1 and over that label 2).
|
76 |
+
The network predicts fields you leave empty. Only one column can have empty entries that are predicted.
|
77 |
+
Please, provide everything but the label column as numeric values. It is ok to encode classes as integers.
|
78 |
+
""")
|
79 |
+
inp_table = gr.DataFrame(type='numpy', value=example, headers=[''] * 3)
|
80 |
+
inp_file = gr.File(
|
81 |
+
label='Drop either a .csv (without header, only numeric values for all but the labels) or a .arff file.')
|
82 |
+
examples = gr.Examples(examples=['iris.csv', 'balance-scale.arff'],
|
83 |
+
inputs=[inp_file],
|
84 |
+
outputs=[inp_table],
|
85 |
+
fn=upload_file,
|
86 |
+
cache_examples=True)
|
87 |
+
btn = gr.Button("Predict Empty Table Cells")
|
88 |
+
|
89 |
+
inp_file.change(fn=upload_file, inputs=inp_file, outputs=inp_table)
|
90 |
+
|
91 |
+
out_text = gr.Markdown()
|
92 |
+
out_table = gr.DataFrame()
|
93 |
+
|
94 |
+
btn.click(fn=compute, inputs=inp_table, outputs=[out_text, out_table])
|
95 |
+
|
96 |
+
demo.launch()
|
balance-scale.arff
ADDED
@@ -0,0 +1,694 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
%1. Title: Balance Scale Weight & Distance Database
|
2 |
+
%
|
3 |
+
%2. Source Information:
|
4 |
+
% (a) Source: Generated to model psychological experiments reported
|
5 |
+
% by Siegler, R. S. (1976). Three Aspects of Cognitive
|
6 |
+
% Development. Cognitive Psychology, 8, 481-520.
|
7 |
+
% (b) Donor: Tim Hume ([email protected])
|
8 |
+
% (c) Date: 22 April 1994
|
9 |
+
%
|
10 |
+
%3. Past Usage: (possibly different formats of this data)
|
11 |
+
% - Publications
|
12 |
+
% 1. Klahr, D., & Siegler, R.S. (1978). The Representation of
|
13 |
+
% Children's Knowledge. In H. W. Reese & L. P. Lipsitt (Eds.),
|
14 |
+
% Advances in Child Development and Behavior, pp. 61-116. New
|
15 |
+
% York: Academic Press
|
16 |
+
% 2. Langley,P. (1987). A General Theory of Discrimination
|
17 |
+
% Learning. In D. Klahr, P. Langley, & R. Neches (Eds.),
|
18 |
+
% Production System Models of Learning and Development, pp.
|
19 |
+
% 99-161. Cambridge, MA: MIT Press
|
20 |
+
% 3. Newell, A. (1990). Unified Theories of Cognition.
|
21 |
+
% Cambridge, MA: Harvard University Press
|
22 |
+
% 4. McClelland, J.L. (1988). Parallel Distibuted Processing:
|
23 |
+
% Implications for Cognition and Development. Technical
|
24 |
+
% Report AIP-47, Department of Psychology, Carnegie-Mellon
|
25 |
+
% University
|
26 |
+
% 5. Shultz, T., Mareschal, D., & Schmidt, W. (1994). Modeling
|
27 |
+
% Cognitive Development on Balance Scale Phenomena. Machine
|
28 |
+
% Learning, Vol. 16, pp. 59-88.
|
29 |
+
%
|
30 |
+
%4. Relevant Information:
|
31 |
+
% This data set was generated to model psychological
|
32 |
+
% experimental results. Each example is classified as having the
|
33 |
+
% balance scale tip to the right, tip to the left, or be
|
34 |
+
% balanced. The attributes are the left weight, the left
|
35 |
+
% distance, the right weight, and the right distance. The
|
36 |
+
% correct way to find the class is the greater of
|
37 |
+
% (left-distance * left-weight) and (right-distance *
|
38 |
+
% right-weight). If they are equal, it is balanced.
|
39 |
+
%
|
40 |
+
%5. Number of Instances: 625 (49 balanced, 288 left, 288 right)
|
41 |
+
%
|
42 |
+
%6. Number of Attributes: 4 (numeric) + class name = 5
|
43 |
+
%
|
44 |
+
%7. Attribute Information:
|
45 |
+
% 1. Class Name: 3 (L, B, R)
|
46 |
+
% 2. Left-Weight: 5 (1, 2, 3, 4, 5)
|
47 |
+
% 3. Left-Distance: 5 (1, 2, 3, 4, 5)
|
48 |
+
% 4. Right-Weight: 5 (1, 2, 3, 4, 5)
|
49 |
+
% 5. Right-Distance: 5 (1, 2, 3, 4, 5)
|
50 |
+
%
|
51 |
+
%8. Missing Attribute Values:
|
52 |
+
% none
|
53 |
+
%
|
54 |
+
%9. Class Distribution:
|
55 |
+
% 1. 46.08 percent are L
|
56 |
+
% 2. 07.84 percent are B
|
57 |
+
% 3. 46.08 percent are R
|
58 |
+
%
|
59 |
+
|
60 |
+
@relation balance-scale
|
61 |
+
@attribute 'left-weight' real
|
62 |
+
@attribute 'left-distance' real
|
63 |
+
@attribute 'right-weight' real
|
64 |
+
@attribute 'right-distance' real
|
65 |
+
@attribute 'class' { L, B, R}
|
66 |
+
@data
|
67 |
+
1,1,1,1,B
|
68 |
+
1,1,1,2,R
|
69 |
+
1,1,1,3,R
|
70 |
+
1,1,1,4,R
|
71 |
+
1,1,1,5,R
|
72 |
+
1,1,2,1,R
|
73 |
+
1,1,2,2,R
|
74 |
+
1,1,2,3,R
|
75 |
+
1,1,2,4,R
|
76 |
+
1,1,2,5,R
|
77 |
+
1,1,3,1,R
|
78 |
+
1,1,3,2,R
|
79 |
+
1,1,3,3,R
|
80 |
+
1,1,3,4,R
|
81 |
+
1,1,3,5,R
|
82 |
+
1,1,4,1,R
|
83 |
+
1,1,4,2,R
|
84 |
+
1,1,4,3,R
|
85 |
+
1,1,4,4,R
|
86 |
+
1,1,4,5,R
|
87 |
+
1,1,5,1,R
|
88 |
+
1,1,5,2,R
|
89 |
+
1,1,5,3,R
|
90 |
+
1,1,5,4,R
|
91 |
+
1,1,5,5,R
|
92 |
+
1,2,1,1,L
|
93 |
+
1,2,1,2,B
|
94 |
+
1,2,1,3,R
|
95 |
+
1,2,1,4,R
|
96 |
+
1,2,1,5,R
|
97 |
+
1,2,2,1,B
|
98 |
+
1,2,2,2,R
|
99 |
+
1,2,2,3,R
|
100 |
+
1,2,2,4,R
|
101 |
+
1,2,2,5,R
|
102 |
+
1,2,3,1,R
|
103 |
+
1,2,3,2,R
|
104 |
+
1,2,3,3,R
|
105 |
+
1,2,3,4,R
|
106 |
+
1,2,3,5,R
|
107 |
+
1,2,4,1,R
|
108 |
+
1,2,4,2,R
|
109 |
+
1,2,4,3,R
|
110 |
+
1,2,4,4,R
|
111 |
+
1,2,4,5,R
|
112 |
+
1,2,5,1,R
|
113 |
+
1,2,5,2,R
|
114 |
+
1,2,5,3,R
|
115 |
+
1,2,5,4,R
|
116 |
+
1,2,5,5,R
|
117 |
+
1,3,1,1,L
|
118 |
+
1,3,1,2,L
|
119 |
+
1,3,1,3,B
|
120 |
+
1,3,1,4,R
|
121 |
+
1,3,1,5,R
|
122 |
+
1,3,2,1,L
|
123 |
+
1,3,2,2,R
|
124 |
+
1,3,2,3,R
|
125 |
+
1,3,2,4,R
|
126 |
+
1,3,2,5,R
|
127 |
+
1,3,3,1,B
|
128 |
+
1,3,3,2,R
|
129 |
+
1,3,3,3,R
|
130 |
+
1,3,3,4,R
|
131 |
+
1,3,3,5,R
|
132 |
+
1,3,4,1,R
|
133 |
+
1,3,4,2,R
|
134 |
+
1,3,4,3,R
|
135 |
+
1,3,4,4,R
|
136 |
+
1,3,4,5,R
|
137 |
+
1,3,5,1,R
|
138 |
+
1,3,5,2,R
|
139 |
+
1,3,5,3,R
|
140 |
+
1,3,5,4,R
|
141 |
+
1,3,5,5,R
|
142 |
+
1,4,1,1,L
|
143 |
+
1,4,1,2,L
|
144 |
+
1,4,1,3,L
|
145 |
+
1,4,1,4,B
|
146 |
+
1,4,1,5,R
|
147 |
+
1,4,2,1,L
|
148 |
+
1,4,2,2,B
|
149 |
+
1,4,2,3,R
|
150 |
+
1,4,2,4,R
|
151 |
+
1,4,2,5,R
|
152 |
+
1,4,3,1,L
|
153 |
+
1,4,3,2,R
|
154 |
+
1,4,3,3,R
|
155 |
+
1,4,3,4,R
|
156 |
+
1,4,3,5,R
|
157 |
+
1,4,4,1,B
|
158 |
+
1,4,4,2,R
|
159 |
+
1,4,4,3,R
|
160 |
+
1,4,4,4,R
|
161 |
+
1,4,4,5,R
|
162 |
+
1,4,5,1,R
|
163 |
+
1,4,5,2,R
|
164 |
+
1,4,5,3,R
|
165 |
+
1,4,5,4,R
|
166 |
+
1,4,5,5,R
|
167 |
+
1,5,1,1,L
|
168 |
+
1,5,1,2,L
|
169 |
+
1,5,1,3,L
|
170 |
+
1,5,1,4,L
|
171 |
+
1,5,1,5,B
|
172 |
+
1,5,2,1,L
|
173 |
+
1,5,2,2,L
|
174 |
+
1,5,2,3,R
|
175 |
+
1,5,2,4,R
|
176 |
+
1,5,2,5,R
|
177 |
+
1,5,3,1,L
|
178 |
+
1,5,3,2,R
|
179 |
+
1,5,3,3,R
|
180 |
+
1,5,3,4,R
|
181 |
+
1,5,3,5,R
|
182 |
+
1,5,4,1,L
|
183 |
+
1,5,4,2,R
|
184 |
+
1,5,4,3,R
|
185 |
+
1,5,4,4,R
|
186 |
+
1,5,4,5,R
|
187 |
+
1,5,5,1,B
|
188 |
+
1,5,5,2,R
|
189 |
+
1,5,5,3,R
|
190 |
+
1,5,5,4,R
|
191 |
+
1,5,5,5,R
|
192 |
+
2,1,1,1,L
|
193 |
+
2,1,1,2,B
|
194 |
+
2,1,1,3,R
|
195 |
+
2,1,1,4,R
|
196 |
+
2,1,1,5,R
|
197 |
+
2,1,2,1,B
|
198 |
+
2,1,2,2,R
|
199 |
+
2,1,2,3,R
|
200 |
+
2,1,2,4,R
|
201 |
+
2,1,2,5,R
|
202 |
+
2,1,3,1,R
|
203 |
+
2,1,3,2,R
|
204 |
+
2,1,3,3,R
|
205 |
+
2,1,3,4,R
|
206 |
+
2,1,3,5,R
|
207 |
+
2,1,4,1,R
|
208 |
+
2,1,4,2,R
|
209 |
+
2,1,4,3,R
|
210 |
+
2,1,4,4,R
|
211 |
+
2,1,4,5,R
|
212 |
+
2,1,5,1,R
|
213 |
+
2,1,5,2,R
|
214 |
+
2,1,5,3,R
|
215 |
+
2,1,5,4,R
|
216 |
+
2,1,5,5,R
|
217 |
+
2,2,1,1,L
|
218 |
+
2,2,1,2,L
|
219 |
+
2,2,1,3,L
|
220 |
+
2,2,1,4,B
|
221 |
+
2,2,1,5,R
|
222 |
+
2,2,2,1,L
|
223 |
+
2,2,2,2,B
|
224 |
+
2,2,2,3,R
|
225 |
+
2,2,2,4,R
|
226 |
+
2,2,2,5,R
|
227 |
+
2,2,3,1,L
|
228 |
+
2,2,3,2,R
|
229 |
+
2,2,3,3,R
|
230 |
+
2,2,3,4,R
|
231 |
+
2,2,3,5,R
|
232 |
+
2,2,4,1,B
|
233 |
+
2,2,4,2,R
|
234 |
+
2,2,4,3,R
|
235 |
+
2,2,4,4,R
|
236 |
+
2,2,4,5,R
|
237 |
+
2,2,5,1,R
|
238 |
+
2,2,5,2,R
|
239 |
+
2,2,5,3,R
|
240 |
+
2,2,5,4,R
|
241 |
+
2,2,5,5,R
|
242 |
+
2,3,1,1,L
|
243 |
+
2,3,1,2,L
|
244 |
+
2,3,1,3,L
|
245 |
+
2,3,1,4,L
|
246 |
+
2,3,1,5,L
|
247 |
+
2,3,2,1,L
|
248 |
+
2,3,2,2,L
|
249 |
+
2,3,2,3,B
|
250 |
+
2,3,2,4,R
|
251 |
+
2,3,2,5,R
|
252 |
+
2,3,3,1,L
|
253 |
+
2,3,3,2,B
|
254 |
+
2,3,3,3,R
|
255 |
+
2,3,3,4,R
|
256 |
+
2,3,3,5,R
|
257 |
+
2,3,4,1,L
|
258 |
+
2,3,4,2,R
|
259 |
+
2,3,4,3,R
|
260 |
+
2,3,4,4,R
|
261 |
+
2,3,4,5,R
|
262 |
+
2,3,5,1,L
|
263 |
+
2,3,5,2,R
|
264 |
+
2,3,5,3,R
|
265 |
+
2,3,5,4,R
|
266 |
+
2,3,5,5,R
|
267 |
+
2,4,1,1,L
|
268 |
+
2,4,1,2,L
|
269 |
+
2,4,1,3,L
|
270 |
+
2,4,1,4,L
|
271 |
+
2,4,1,5,L
|
272 |
+
2,4,2,1,L
|
273 |
+
2,4,2,2,L
|
274 |
+
2,4,2,3,L
|
275 |
+
2,4,2,4,B
|
276 |
+
2,4,2,5,R
|
277 |
+
2,4,3,1,L
|
278 |
+
2,4,3,2,L
|
279 |
+
2,4,3,3,R
|
280 |
+
2,4,3,4,R
|
281 |
+
2,4,3,5,R
|
282 |
+
2,4,4,1,L
|
283 |
+
2,4,4,2,B
|
284 |
+
2,4,4,3,R
|
285 |
+
2,4,4,4,R
|
286 |
+
2,4,4,5,R
|
287 |
+
2,4,5,1,L
|
288 |
+
2,4,5,2,R
|
289 |
+
2,4,5,3,R
|
290 |
+
2,4,5,4,R
|
291 |
+
2,4,5,5,R
|
292 |
+
2,5,1,1,L
|
293 |
+
2,5,1,2,L
|
294 |
+
2,5,1,3,L
|
295 |
+
2,5,1,4,L
|
296 |
+
2,5,1,5,L
|
297 |
+
2,5,2,1,L
|
298 |
+
2,5,2,2,L
|
299 |
+
2,5,2,3,L
|
300 |
+
2,5,2,4,L
|
301 |
+
2,5,2,5,B
|
302 |
+
2,5,3,1,L
|
303 |
+
2,5,3,2,L
|
304 |
+
2,5,3,3,L
|
305 |
+
2,5,3,4,R
|
306 |
+
2,5,3,5,R
|
307 |
+
2,5,4,1,L
|
308 |
+
2,5,4,2,L
|
309 |
+
2,5,4,3,R
|
310 |
+
2,5,4,4,R
|
311 |
+
2,5,4,5,R
|
312 |
+
2,5,5,1,L
|
313 |
+
2,5,5,2,B
|
314 |
+
2,5,5,3,R
|
315 |
+
2,5,5,4,R
|
316 |
+
2,5,5,5,R
|
317 |
+
3,1,1,1,L
|
318 |
+
3,1,1,2,L
|
319 |
+
3,1,1,3,B
|
320 |
+
3,1,1,4,R
|
321 |
+
3,1,1,5,R
|
322 |
+
3,1,2,1,L
|
323 |
+
3,1,2,2,R
|
324 |
+
3,1,2,3,R
|
325 |
+
3,1,2,4,R
|
326 |
+
3,1,2,5,R
|
327 |
+
3,1,3,1,B
|
328 |
+
3,1,3,2,R
|
329 |
+
3,1,3,3,R
|
330 |
+
3,1,3,4,R
|
331 |
+
3,1,3,5,R
|
332 |
+
3,1,4,1,R
|
333 |
+
3,1,4,2,R
|
334 |
+
3,1,4,3,R
|
335 |
+
3,1,4,4,R
|
336 |
+
3,1,4,5,R
|
337 |
+
3,1,5,1,R
|
338 |
+
3,1,5,2,R
|
339 |
+
3,1,5,3,R
|
340 |
+
3,1,5,4,R
|
341 |
+
3,1,5,5,R
|
342 |
+
3,2,1,1,L
|
343 |
+
3,2,1,2,L
|
344 |
+
3,2,1,3,L
|
345 |
+
3,2,1,4,L
|
346 |
+
3,2,1,5,L
|
347 |
+
3,2,2,1,L
|
348 |
+
3,2,2,2,L
|
349 |
+
3,2,2,3,B
|
350 |
+
3,2,2,4,R
|
351 |
+
3,2,2,5,R
|
352 |
+
3,2,3,1,L
|
353 |
+
3,2,3,2,B
|
354 |
+
3,2,3,3,R
|
355 |
+
3,2,3,4,R
|
356 |
+
3,2,3,5,R
|
357 |
+
3,2,4,1,L
|
358 |
+
3,2,4,2,R
|
359 |
+
3,2,4,3,R
|
360 |
+
3,2,4,4,R
|
361 |
+
3,2,4,5,R
|
362 |
+
3,2,5,1,L
|
363 |
+
3,2,5,2,R
|
364 |
+
3,2,5,3,R
|
365 |
+
3,2,5,4,R
|
366 |
+
3,2,5,5,R
|
367 |
+
3,3,1,1,L
|
368 |
+
3,3,1,2,L
|
369 |
+
3,3,1,3,L
|
370 |
+
3,3,1,4,L
|
371 |
+
3,3,1,5,L
|
372 |
+
3,3,2,1,L
|
373 |
+
3,3,2,2,L
|
374 |
+
3,3,2,3,L
|
375 |
+
3,3,2,4,L
|
376 |
+
3,3,2,5,R
|
377 |
+
3,3,3,1,L
|
378 |
+
3,3,3,2,L
|
379 |
+
3,3,3,3,B
|
380 |
+
3,3,3,4,R
|
381 |
+
3,3,3,5,R
|
382 |
+
3,3,4,1,L
|
383 |
+
3,3,4,2,L
|
384 |
+
3,3,4,3,R
|
385 |
+
3,3,4,4,R
|
386 |
+
3,3,4,5,R
|
387 |
+
3,3,5,1,L
|
388 |
+
3,3,5,2,R
|
389 |
+
3,3,5,3,R
|
390 |
+
3,3,5,4,R
|
391 |
+
3,3,5,5,R
|
392 |
+
3,4,1,1,L
|
393 |
+
3,4,1,2,L
|
394 |
+
3,4,1,3,L
|
395 |
+
3,4,1,4,L
|
396 |
+
3,4,1,5,L
|
397 |
+
3,4,2,1,L
|
398 |
+
3,4,2,2,L
|
399 |
+
3,4,2,3,L
|
400 |
+
3,4,2,4,L
|
401 |
+
3,4,2,5,L
|
402 |
+
3,4,3,1,L
|
403 |
+
3,4,3,2,L
|
404 |
+
3,4,3,3,L
|
405 |
+
3,4,3,4,B
|
406 |
+
3,4,3,5,R
|
407 |
+
3,4,4,1,L
|
408 |
+
3,4,4,2,L
|
409 |
+
3,4,4,3,B
|
410 |
+
3,4,4,4,R
|
411 |
+
3,4,4,5,R
|
412 |
+
3,4,5,1,L
|
413 |
+
3,4,5,2,L
|
414 |
+
3,4,5,3,R
|
415 |
+
3,4,5,4,R
|
416 |
+
3,4,5,5,R
|
417 |
+
3,5,1,1,L
|
418 |
+
3,5,1,2,L
|
419 |
+
3,5,1,3,L
|
420 |
+
3,5,1,4,L
|
421 |
+
3,5,1,5,L
|
422 |
+
3,5,2,1,L
|
423 |
+
3,5,2,2,L
|
424 |
+
3,5,2,3,L
|
425 |
+
3,5,2,4,L
|
426 |
+
3,5,2,5,L
|
427 |
+
3,5,3,1,L
|
428 |
+
3,5,3,2,L
|
429 |
+
3,5,3,3,L
|
430 |
+
3,5,3,4,L
|
431 |
+
3,5,3,5,B
|
432 |
+
3,5,4,1,L
|
433 |
+
3,5,4,2,L
|
434 |
+
3,5,4,3,L
|
435 |
+
3,5,4,4,R
|
436 |
+
3,5,4,5,R
|
437 |
+
3,5,5,1,L
|
438 |
+
3,5,5,2,L
|
439 |
+
3,5,5,3,B
|
440 |
+
3,5,5,4,R
|
441 |
+
3,5,5,5,R
|
442 |
+
4,1,1,1,L
|
443 |
+
4,1,1,2,L
|
444 |
+
4,1,1,3,L
|
445 |
+
4,1,1,4,B
|
446 |
+
4,1,1,5,R
|
447 |
+
4,1,2,1,L
|
448 |
+
4,1,2,2,B
|
449 |
+
4,1,2,3,R
|
450 |
+
4,1,2,4,R
|
451 |
+
4,1,2,5,R
|
452 |
+
4,1,3,1,L
|
453 |
+
4,1,3,2,R
|
454 |
+
4,1,3,3,R
|
455 |
+
4,1,3,4,R
|
456 |
+
4,1,3,5,R
|
457 |
+
4,1,4,1,B
|
458 |
+
4,1,4,2,R
|
459 |
+
4,1,4,3,R
|
460 |
+
4,1,4,4,R
|
461 |
+
4,1,4,5,R
|
462 |
+
4,1,5,1,R
|
463 |
+
4,1,5,2,R
|
464 |
+
4,1,5,3,R
|
465 |
+
4,1,5,4,R
|
466 |
+
4,1,5,5,R
|
467 |
+
4,2,1,1,L
|
468 |
+
4,2,1,2,L
|
469 |
+
4,2,1,3,L
|
470 |
+
4,2,1,4,L
|
471 |
+
4,2,1,5,L
|
472 |
+
4,2,2,1,L
|
473 |
+
4,2,2,2,L
|
474 |
+
4,2,2,3,L
|
475 |
+
4,2,2,4,B
|
476 |
+
4,2,2,5,R
|
477 |
+
4,2,3,1,L
|
478 |
+
4,2,3,2,L
|
479 |
+
4,2,3,3,R
|
480 |
+
4,2,3,4,R
|
481 |
+
4,2,3,5,R
|
482 |
+
4,2,4,1,L
|
483 |
+
4,2,4,2,B
|
484 |
+
4,2,4,3,R
|
485 |
+
4,2,4,4,R
|
486 |
+
4,2,4,5,R
|
487 |
+
4,2,5,1,L
|
488 |
+
4,2,5,2,R
|
489 |
+
4,2,5,3,R
|
490 |
+
4,2,5,4,R
|
491 |
+
4,2,5,5,R
|
492 |
+
4,3,1,1,L
|
493 |
+
4,3,1,2,L
|
494 |
+
4,3,1,3,L
|
495 |
+
4,3,1,4,L
|
496 |
+
4,3,1,5,L
|
497 |
+
4,3,2,1,L
|
498 |
+
4,3,2,2,L
|
499 |
+
4,3,2,3,L
|
500 |
+
4,3,2,4,L
|
501 |
+
4,3,2,5,L
|
502 |
+
4,3,3,1,L
|
503 |
+
4,3,3,2,L
|
504 |
+
4,3,3,3,L
|
505 |
+
4,3,3,4,B
|
506 |
+
4,3,3,5,R
|
507 |
+
4,3,4,1,L
|
508 |
+
4,3,4,2,L
|
509 |
+
4,3,4,3,B
|
510 |
+
4,3,4,4,R
|
511 |
+
4,3,4,5,R
|
512 |
+
4,3,5,1,L
|
513 |
+
4,3,5,2,L
|
514 |
+
4,3,5,3,R
|
515 |
+
4,3,5,4,R
|
516 |
+
4,3,5,5,R
|
517 |
+
4,4,1,1,L
|
518 |
+
4,4,1,2,L
|
519 |
+
4,4,1,3,L
|
520 |
+
4,4,1,4,L
|
521 |
+
4,4,1,5,L
|
522 |
+
4,4,2,1,L
|
523 |
+
4,4,2,2,L
|
524 |
+
4,4,2,3,L
|
525 |
+
4,4,2,4,L
|
526 |
+
4,4,2,5,L
|
527 |
+
4,4,3,1,L
|
528 |
+
4,4,3,2,L
|
529 |
+
4,4,3,3,L
|
530 |
+
4,4,3,4,L
|
531 |
+
4,4,3,5,L
|
532 |
+
4,4,4,1,L
|
533 |
+
4,4,4,2,L
|
534 |
+
4,4,4,3,L
|
535 |
+
4,4,4,4,B
|
536 |
+
4,4,4,5,R
|
537 |
+
4,4,5,1,L
|
538 |
+
4,4,5,2,L
|
539 |
+
4,4,5,3,L
|
540 |
+
4,4,5,4,R
|
541 |
+
4,4,5,5,R
|
542 |
+
4,5,1,1,L
|
543 |
+
4,5,1,2,L
|
544 |
+
4,5,1,3,L
|
545 |
+
4,5,1,4,L
|
546 |
+
4,5,1,5,L
|
547 |
+
4,5,2,1,L
|
548 |
+
4,5,2,2,L
|
549 |
+
4,5,2,3,L
|
550 |
+
4,5,2,4,L
|
551 |
+
4,5,2,5,L
|
552 |
+
4,5,3,1,L
|
553 |
+
4,5,3,2,L
|
554 |
+
4,5,3,3,L
|
555 |
+
4,5,3,4,L
|
556 |
+
4,5,3,5,L
|
557 |
+
4,5,4,1,L
|
558 |
+
4,5,4,2,L
|
559 |
+
4,5,4,3,L
|
560 |
+
4,5,4,4,L
|
561 |
+
4,5,4,5,B
|
562 |
+
4,5,5,1,L
|
563 |
+
4,5,5,2,L
|
564 |
+
4,5,5,3,L
|
565 |
+
4,5,5,4,B
|
566 |
+
4,5,5,5,R
|
567 |
+
5,1,1,1,L
|
568 |
+
5,1,1,2,L
|
569 |
+
5,1,1,3,L
|
570 |
+
5,1,1,4,L
|
571 |
+
5,1,1,5,B
|
572 |
+
5,1,2,1,L
|
573 |
+
5,1,2,2,L
|
574 |
+
5,1,2,3,R
|
575 |
+
5,1,2,4,R
|
576 |
+
5,1,2,5,R
|
577 |
+
5,1,3,1,L
|
578 |
+
5,1,3,2,R
|
579 |
+
5,1,3,3,R
|
580 |
+
5,1,3,4,R
|
581 |
+
5,1,3,5,R
|
582 |
+
5,1,4,1,L
|
583 |
+
5,1,4,2,R
|
584 |
+
5,1,4,3,R
|
585 |
+
5,1,4,4,R
|
586 |
+
5,1,4,5,R
|
587 |
+
5,1,5,1,B
|
588 |
+
5,1,5,2,R
|
589 |
+
5,1,5,3,R
|
590 |
+
5,1,5,4,R
|
591 |
+
5,1,5,5,R
|
592 |
+
5,2,1,1,L
|
593 |
+
5,2,1,2,L
|
594 |
+
5,2,1,3,L
|
595 |
+
5,2,1,4,L
|
596 |
+
5,2,1,5,L
|
597 |
+
5,2,2,1,L
|
598 |
+
5,2,2,2,L
|
599 |
+
5,2,2,3,L
|
600 |
+
5,2,2,4,L
|
601 |
+
5,2,2,5,B
|
602 |
+
5,2,3,1,L
|
603 |
+
5,2,3,2,L
|
604 |
+
5,2,3,3,L
|
605 |
+
5,2,3,4,R
|
606 |
+
5,2,3,5,R
|
607 |
+
5,2,4,1,L
|
608 |
+
5,2,4,2,L
|
609 |
+
5,2,4,3,R
|
610 |
+
5,2,4,4,R
|
611 |
+
5,2,4,5,R
|
612 |
+
5,2,5,1,L
|
613 |
+
5,2,5,2,B
|
614 |
+
5,2,5,3,R
|
615 |
+
5,2,5,4,R
|
616 |
+
5,2,5,5,R
|
617 |
+
5,3,1,1,L
|
618 |
+
5,3,1,2,L
|
619 |
+
5,3,1,3,L
|
620 |
+
5,3,1,4,L
|
621 |
+
5,3,1,5,L
|
622 |
+
5,3,2,1,L
|
623 |
+
5,3,2,2,L
|
624 |
+
5,3,2,3,L
|
625 |
+
5,3,2,4,L
|
626 |
+
5,3,2,5,L
|
627 |
+
5,3,3,1,L
|
628 |
+
5,3,3,2,L
|
629 |
+
5,3,3,3,L
|
630 |
+
5,3,3,4,L
|
631 |
+
5,3,3,5,B
|
632 |
+
5,3,4,1,L
|
633 |
+
5,3,4,2,L
|
634 |
+
5,3,4,3,L
|
635 |
+
5,3,4,4,R
|
636 |
+
5,3,4,5,R
|
637 |
+
5,3,5,1,L
|
638 |
+
5,3,5,2,L
|
639 |
+
5,3,5,3,B
|
640 |
+
5,3,5,4,R
|
641 |
+
5,3,5,5,R
|
642 |
+
5,4,1,1,L
|
643 |
+
5,4,1,2,L
|
644 |
+
5,4,1,3,L
|
645 |
+
5,4,1,4,L
|
646 |
+
5,4,1,5,L
|
647 |
+
5,4,2,1,L
|
648 |
+
5,4,2,2,L
|
649 |
+
5,4,2,3,L
|
650 |
+
5,4,2,4,L
|
651 |
+
5,4,2,5,L
|
652 |
+
5,4,3,1,L
|
653 |
+
5,4,3,2,L
|
654 |
+
5,4,3,3,L
|
655 |
+
5,4,3,4,L
|
656 |
+
5,4,3,5,L
|
657 |
+
5,4,4,1,L
|
658 |
+
5,4,4,2,L
|
659 |
+
5,4,4,3,L
|
660 |
+
5,4,4,4,L
|
661 |
+
5,4,4,5,B
|
662 |
+
5,4,5,1,L
|
663 |
+
5,4,5,2,L
|
664 |
+
5,4,5,3,L
|
665 |
+
5,4,5,4,B
|
666 |
+
5,4,5,5,R
|
667 |
+
5,5,1,1,L
|
668 |
+
5,5,1,2,L
|
669 |
+
5,5,1,3,L
|
670 |
+
5,5,1,4,L
|
671 |
+
5,5,1,5,L
|
672 |
+
5,5,2,1,L
|
673 |
+
5,5,2,2,L
|
674 |
+
5,5,2,3,L
|
675 |
+
5,5,2,4,L
|
676 |
+
5,5,2,5,L
|
677 |
+
5,5,3,1,L
|
678 |
+
5,5,3,2,L
|
679 |
+
5,5,3,3,L
|
680 |
+
5,5,3,4,L
|
681 |
+
5,5,3,5,L
|
682 |
+
5,5,4,1,L
|
683 |
+
5,5,4,2,L
|
684 |
+
5,5,4,3,L
|
685 |
+
5,5,4,4,L
|
686 |
+
5,5,4,5,L
|
687 |
+
5,5,5,1,L
|
688 |
+
5,5,5,2,L
|
689 |
+
5,5,5,3,L
|
690 |
+
5,5,5,4,L
|
691 |
+
5,5,5,5,B
|
692 |
+
%
|
693 |
+
%
|
694 |
+
%
|
iris.csv
ADDED
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
5.1,3.5,1.4,0.2,Iris-setosa
|
2 |
+
4.9,3.0,1.4,0.2,Iris-setosa
|
3 |
+
4.7,3.2,1.3,0.2,Iris-setosa
|
4 |
+
4.6,3.1,1.5,0.2,Iris-setosa
|
5 |
+
5.0,3.6,1.4,0.2,Iris-setosa
|
6 |
+
5.4,3.9,1.7,0.4,Iris-setosa
|
7 |
+
4.6,3.4,1.4,0.3,Iris-setosa
|
8 |
+
5.0,3.4,1.5,0.2,Iris-setosa
|
9 |
+
4.4,2.9,1.4,0.2,Iris-setosa
|
10 |
+
4.9,3.1,1.5,0.1,Iris-setosa
|
11 |
+
5.4,3.7,1.5,0.2,Iris-setosa
|
12 |
+
4.8,3.4,1.6,0.2,Iris-setosa
|
13 |
+
4.8,3.0,1.4,0.1,Iris-setosa
|
14 |
+
4.3,3.0,1.1,0.1,Iris-setosa
|
15 |
+
5.8,4.0,1.2,0.2,Iris-setosa
|
16 |
+
5.7,4.4,1.5,0.4,Iris-setosa
|
17 |
+
5.4,3.9,1.3,0.4,Iris-setosa
|
18 |
+
5.1,3.5,1.4,0.3,Iris-setosa
|
19 |
+
5.7,3.8,1.7,0.3,Iris-setosa
|
20 |
+
5.1,3.8,1.5,0.3,Iris-setosa
|
21 |
+
5.4,3.4,1.7,0.2,Iris-setosa
|
22 |
+
5.1,3.7,1.5,0.4,Iris-setosa
|
23 |
+
4.6,3.6,1.0,0.2,Iris-setosa
|
24 |
+
5.1,3.3,1.7,0.5,Iris-setosa
|
25 |
+
4.8,3.4,1.9,0.2,Iris-setosa
|
26 |
+
5.0,3.0,1.6,0.2,Iris-setosa
|
27 |
+
5.0,3.4,1.6,0.4,Iris-setosa
|
28 |
+
5.2,3.5,1.5,0.2,Iris-setosa
|
29 |
+
5.2,3.4,1.4,0.2,Iris-setosa
|
30 |
+
4.7,3.2,1.6,0.2,Iris-setosa
|
31 |
+
4.8,3.1,1.6,0.2,Iris-setosa
|
32 |
+
5.4,3.4,1.5,0.4,Iris-setosa
|
33 |
+
5.2,4.1,1.5,0.1,Iris-setosa
|
34 |
+
5.5,4.2,1.4,0.2,Iris-setosa
|
35 |
+
4.9,3.1,1.5,0.1,Iris-setosa
|
36 |
+
5.0,3.2,1.2,0.2,Iris-setosa
|
37 |
+
5.5,3.5,1.3,0.2,Iris-setosa
|
38 |
+
4.9,3.1,1.5,0.1,Iris-setosa
|
39 |
+
4.4,3.0,1.3,0.2,Iris-setosa
|
40 |
+
5.1,3.4,1.5,0.2,Iris-setosa
|
41 |
+
5.0,3.5,1.3,0.3,Iris-setosa
|
42 |
+
4.5,2.3,1.3,0.3,Iris-setosa
|
43 |
+
4.4,3.2,1.3,0.2,Iris-setosa
|
44 |
+
5.0,3.5,1.6,0.6,Iris-setosa
|
45 |
+
5.1,3.8,1.9,0.4,Iris-setosa
|
46 |
+
4.8,3.0,1.4,0.3,Iris-setosa
|
47 |
+
5.1,3.8,1.6,0.2,Iris-setosa
|
48 |
+
4.6,3.2,1.4,0.2,Iris-setosa
|
49 |
+
5.3,3.7,1.5,0.2,Iris-setosa
|
50 |
+
5.0,3.3,1.4,0.2,Iris-setosa
|
51 |
+
7.0,3.2,4.7,1.4,Iris-versicolor
|
52 |
+
6.4,3.2,4.5,1.5,Iris-versicolor
|
53 |
+
6.9,3.1,4.9,1.5,Iris-versicolor
|
54 |
+
5.5,2.3,4.0,1.3,Iris-versicolor
|
55 |
+
6.5,2.8,4.6,1.5,Iris-versicolor
|
56 |
+
5.7,2.8,4.5,1.3,Iris-versicolor
|
57 |
+
6.3,3.3,4.7,1.6,Iris-versicolor
|
58 |
+
4.9,2.4,3.3,1.0,Iris-versicolor
|
59 |
+
6.6,2.9,4.6,1.3,Iris-versicolor
|
60 |
+
5.2,2.7,3.9,1.4,Iris-versicolor
|
61 |
+
5.0,2.0,3.5,1.0,Iris-versicolor
|
62 |
+
5.9,3.0,4.2,1.5,Iris-versicolor
|
63 |
+
6.0,2.2,4.0,1.0,Iris-versicolor
|
64 |
+
6.1,2.9,4.7,1.4,Iris-versicolor
|
65 |
+
5.6,2.9,3.6,1.3,Iris-versicolor
|
66 |
+
6.7,3.1,4.4,1.4,Iris-versicolor
|
67 |
+
5.6,3.0,4.5,1.5,Iris-versicolor
|
68 |
+
5.8,2.7,4.1,1.0,Iris-versicolor
|
69 |
+
6.2,2.2,4.5,1.5,Iris-versicolor
|
70 |
+
5.6,2.5,3.9,1.1,Iris-versicolor
|
71 |
+
5.9,3.2,4.8,1.8,Iris-versicolor
|
72 |
+
6.1,2.8,4.0,1.3,Iris-versicolor
|
73 |
+
6.3,2.5,4.9,1.5,Iris-versicolor
|
74 |
+
6.1,2.8,4.7,1.2,Iris-versicolor
|
75 |
+
6.4,2.9,4.3,1.3,Iris-versicolor
|
76 |
+
6.6,3.0,4.4,1.4,Iris-versicolor
|
77 |
+
6.8,2.8,4.8,1.4,Iris-versicolor
|
78 |
+
6.7,3.0,5.0,1.7,Iris-versicolor
|
79 |
+
6.0,2.9,4.5,1.5,Iris-versicolor
|
80 |
+
5.7,2.6,3.5,1.0,Iris-versicolor
|
81 |
+
5.5,2.4,3.8,1.1,Iris-versicolor
|
82 |
+
5.5,2.4,3.7,1.0,Iris-versicolor
|
83 |
+
5.8,2.7,3.9,1.2,Iris-versicolor
|
84 |
+
6.0,2.7,5.1,1.6,Iris-versicolor
|
85 |
+
5.4,3.0,4.5,1.5,Iris-versicolor
|
86 |
+
6.0,3.4,4.5,1.6,Iris-versicolor
|
87 |
+
6.7,3.1,4.7,1.5,Iris-versicolor
|
88 |
+
6.3,2.3,4.4,1.3,Iris-versicolor
|
89 |
+
5.6,3.0,4.1,1.3,Iris-versicolor
|
90 |
+
5.5,2.5,4.0,1.3,Iris-versicolor
|
91 |
+
5.5,2.6,4.4,1.2,Iris-versicolor
|
92 |
+
6.1,3.0,4.6,1.4,Iris-versicolor
|
93 |
+
5.8,2.6,4.0,1.2,Iris-versicolor
|
94 |
+
5.0,2.3,3.3,1.0,Iris-versicolor
|
95 |
+
5.6,2.7,4.2,1.3,Iris-versicolor
|
96 |
+
5.7,3.0,4.2,1.2,Iris-versicolor
|
97 |
+
5.7,2.9,4.2,1.3,Iris-versicolor
|
98 |
+
6.2,2.9,4.3,1.3,Iris-versicolor
|
99 |
+
5.1,2.5,3.0,1.1,Iris-versicolor
|
100 |
+
5.7,2.8,4.1,1.3,Iris-versicolor
|
101 |
+
6.3,3.3,6.0,2.5,Iris-virginica
|
102 |
+
5.8,2.7,5.1,1.9,Iris-virginica
|
103 |
+
7.1,3.0,5.9,2.1,Iris-virginica
|
104 |
+
6.3,2.9,5.6,1.8,Iris-virginica
|
105 |
+
6.5,3.0,5.8,2.2,Iris-virginica
|
106 |
+
7.6,3.0,6.6,2.1,Iris-virginica
|
107 |
+
4.9,2.5,4.5,1.7,Iris-virginica
|
108 |
+
7.3,2.9,6.3,1.8,Iris-virginica
|
109 |
+
6.7,2.5,5.8,1.8,Iris-virginica
|
110 |
+
7.2,3.6,6.1,2.5,Iris-virginica
|
111 |
+
6.5,3.2,5.1,2.0,Iris-virginica
|
112 |
+
6.4,2.7,5.3,1.9,Iris-virginica
|
113 |
+
6.8,3.0,5.5,2.1,Iris-virginica
|
114 |
+
5.7,2.5,5.0,2.0,Iris-virginica
|
115 |
+
5.8,2.8,5.1,2.4,Iris-virginica
|
116 |
+
6.4,3.2,5.3,2.3,Iris-virginica
|
117 |
+
6.5,3.0,5.5,1.8,Iris-virginica
|
118 |
+
7.7,3.8,6.7,2.2,Iris-virginica
|
119 |
+
7.7,2.6,6.9,2.3,Iris-virginica
|
120 |
+
6.0,2.2,5.0,1.5,Iris-virginica
|
121 |
+
6.9,3.2,5.7,2.3,Iris-virginica
|
122 |
+
5.6,2.8,4.9,2.0,Iris-virginica
|
123 |
+
7.7,2.8,6.7,2.0,Iris-virginica
|
124 |
+
6.3,2.7,4.9,1.8,Iris-virginica
|
125 |
+
6.7,3.3,5.7,2.1,Iris-virginica
|
126 |
+
7.2,3.2,6.0,1.8,Iris-virginica
|
127 |
+
6.2,2.8,4.8,1.8,Iris-virginica
|
128 |
+
6.1,3.0,4.9,1.8,Iris-virginica
|
129 |
+
6.4,2.8,5.6,2.1,Iris-virginica
|
130 |
+
7.2,3.0,5.8,1.6,Iris-virginica
|
131 |
+
7.4,2.8,6.1,1.9,Iris-virginica
|
132 |
+
7.9,3.8,6.4,2.0,Iris-virginica
|
133 |
+
6.4,2.8,5.6,2.2,Iris-virginica
|
134 |
+
6.3,2.8,5.1,1.5,Iris-virginica
|
135 |
+
6.1,2.6,5.6,1.4,Iris-virginica
|
136 |
+
7.7,3.0,6.1,2.3,Iris-virginica
|
137 |
+
6.3,3.4,5.6,2.4,Iris-virginica
|
138 |
+
6.4,3.1,5.5,1.8,Iris-virginica
|
139 |
+
6.0,3.0,4.8,1.8,Iris-virginica
|
140 |
+
6.9,3.1,5.4,2.1,Iris-virginica
|
141 |
+
6.7,3.1,5.6,2.4,Iris-virginica
|
142 |
+
6.9,3.1,5.1,2.3,Iris-virginica
|
143 |
+
5.8,2.7,5.1,1.9,Iris-virginica
|
144 |
+
6.8,3.2,5.9,2.3,Iris-virginica
|
145 |
+
6.7,3.3,5.7,2.5,Iris-virginica
|
146 |
+
6.7,3.0,5.2,2.3,Iris-virginica
|
147 |
+
6.3,2.5,5.0,1.9,Iris-virginica
|
148 |
+
6.5,3.0,5.2,2.0,Iris-virginica
|
149 |
+
6.2,3.4,5.4,2.3,Iris-virginica
|
150 |
+
5.9,3.0,5.1,1.8,Iris-virginica
|
151 |
+
|
requirements.txt
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Please use python V 3.7 to be compatible with all packages
|
2 |
+
gpytorch==1.5.0
|
3 |
+
torch==1.9.0
|
4 |
+
scikit-learn==0.24.2
|
5 |
+
pyyaml==5.4.1
|
6 |
+
seaborn==0.11.2
|
7 |
+
xgboost==1.4.0
|
8 |
+
tqdm==4.62.1
|
9 |
+
numpy==1.21.2
|
10 |
+
openml==0.12.2
|
11 |
+
catboost==0.26.1
|
12 |
+
auto-sklearn==0.14.5
|
13 |
+
hyperopt==0.2.5
|
14 |
+
configspace==0.4.21
|
15 |
+
# autogluon==0.4.0
|
16 |
+
gradio==3.1.1
|