nicolas-dufour commited on
Commit
c4c7cee
·
1 Parent(s): 70a055c

squash: merge all unpushed commits

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. DATASET.md +34 -0
  2. LICENSE +21 -0
  3. __init__.py +0 -0
  4. callbacks/__init__.py +3 -0
  5. callbacks/__pycache__/__init__.cpython-310.pyc +0 -0
  6. callbacks/__pycache__/data.cpython-310.pyc +0 -0
  7. callbacks/__pycache__/ema.cpython-310.pyc +0 -0
  8. callbacks/__pycache__/fix_nans.cpython-310.pyc +0 -0
  9. callbacks/data.py +11 -0
  10. callbacks/ema.py +102 -0
  11. callbacks/fix_nans.py +55 -0
  12. configs/computer/a100.yaml +8 -0
  13. configs/computer/cluster-node-a100.yaml +8 -0
  14. configs/computer/cluster-node-v100.yaml +8 -0
  15. configs/computer/cpu.yaml +8 -0
  16. configs/computer/h100.yaml +8 -0
  17. configs/computer/v100.yaml +8 -0
  18. configs/config.yaml +90 -0
  19. configs/dataset/baselines/im2gps.yaml +16 -0
  20. configs/dataset/baselines/im2gps3k.yaml +16 -0
  21. configs/dataset/baselines/yfcc4k.yaml +16 -0
  22. configs/dataset/combined_emb.yaml +38 -0
  23. configs/dataset/inaturalist_emb.yaml +38 -0
  24. configs/dataset/osv5m.yaml +43 -0
  25. configs/dataset/osv5m_emb.yaml +38 -0
  26. configs/dataset/test_transform/center_crop.yaml +12 -0
  27. configs/dataset/test_transform/clip.yaml +2 -0
  28. configs/dataset/test_transform/empty.yaml +2 -0
  29. configs/dataset/test_transform/fast_clip.yaml +12 -0
  30. configs/dataset/test_transform/fast_resnet.yaml +12 -0
  31. configs/dataset/test_transform/none.yaml +6 -0
  32. configs/dataset/train_transform/augmentation.yaml +85 -0
  33. configs/dataset/train_transform/center_crop.yaml +14 -0
  34. configs/dataset/train_transform/clip.yaml +2 -0
  35. configs/dataset/train_transform/empty.yaml +2 -0
  36. configs/dataset/train_transform/fast_clip.yaml +12 -0
  37. configs/dataset/train_transform/fast_resnet.yaml +12 -0
  38. configs/dataset/train_transform/none.yaml +7 -0
  39. configs/dataset/yfcc_emb.yaml +38 -0
  40. configs/exp/YFCC100M_geoadalnmlp_r2_small_sigmoid_diffusion.yaml +35 -0
  41. configs/exp/YFCC100M_geoadalnmlp_r3_small_linear_flow_rieman.yaml +32 -0
  42. configs/exp/YFCC100M_geoadalnmlp_r3_small_sigmoid_diffusion.yaml +36 -0
  43. configs/exp/YFCC100M_geoadalnmlp_r3_small_sigmoid_flow.yaml +38 -0
  44. configs/exp/YFCC100M_geoadalnmlp_r3_small_sigmoid_flow_riemann.yaml +40 -0
  45. configs/exp/YFCC100M_geoadalnmlp_von_fisher.yaml +26 -0
  46. configs/exp/YFCC100M_geoadalnmlp_von_fisher_mixture.yaml +26 -0
  47. configs/exp/combined_geoadalnmlp_r3_small_sigmoid_flow_riemann.yaml +40 -0
  48. configs/exp/iNaturalist_geoadalnmlp_r2_small_sigmoid_diffusion.yaml +36 -0
  49. configs/exp/iNaturalist_geoadalnmlp_r3_small_sigmoid_diffusion.yaml +37 -0
  50. configs/exp/iNaturalist_geoadalnmlp_r3_small_sigmoid_flow.yaml +39 -0
DATASET.md ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ### Dataset
2
+ To download the datataset, run:
3
+ ```python
4
+ # download the full dataset
5
+ from huggingface_hub import snapshot_download
6
+ snapshot_download(repo_id="osv5m/osv5m", local_dir="datasets/osv5m", repo_type='dataset')
7
+ ```
8
+
9
+ and finally extract:
10
+ ```python
11
+ import os
12
+ import zipfile
13
+ for root, dirs, files in os.walk("datasets/osv5m"):
14
+ for file in files:
15
+ if file.endswith(".zip"):
16
+ with zipfile.ZipFile(os.path.join(root, file), 'r') as zip_ref:
17
+ zip_ref.extractall(root)
18
+ os.remove(os.path.join(root, file))
19
+ ```
20
+
21
+ You can also directly load the dataset using `load_dataset`:
22
+ ```python
23
+ from datasets import load_dataset
24
+ dataset = load_dataset('osv5m/osv5m', full=False)
25
+ ```
26
+ where with `full` you can specify whether you want to load the complete metadata (default: `False`).
27
+
28
+ If you only want to download the test set, you can run the script below:
29
+ ```python
30
+ from huggingface_hub import hf_hub_download
31
+ for i in range(5):
32
+ hf_hub_download(repo_id="osv5m/osv5m", filename=str(i).zfill(2)+'.zip', subfolder="images/test", repo_type='dataset', local_dir="datasets/osv5m")
33
+ hf_hub_download(repo_id="osv5m/osv5m", filename="README.md", repo_type='dataset', local_dir="datasets/osv5m")
34
+ ```
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2024 Nicolas Dufour
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
__init__.py ADDED
File without changes
callbacks/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .ema import EMACallback
2
+ from .fix_nans import FixNANinGrad
3
+ from .data import IncreaseDataEpoch
callbacks/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (278 Bytes). View file
 
callbacks/__pycache__/data.cpython-310.pyc ADDED
Binary file (851 Bytes). View file
 
callbacks/__pycache__/ema.cpython-310.pyc ADDED
Binary file (3.22 kB). View file
 
callbacks/__pycache__/fix_nans.cpython-310.pyc ADDED
Binary file (1.87 kB). View file
 
callbacks/data.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pytorch_lightning.callbacks import Callback
2
+
3
+
4
+ class IncreaseDataEpoch(Callback):
5
+ def __init__(self):
6
+ super().__init__()
7
+
8
+ def on_train_epoch_start(self, trainer, pl_module):
9
+ epoch = pl_module.current_epoch
10
+ if hasattr(trainer.datamodule.train_dataset, "shared_epoch"):
11
+ trainer.datamodule.train_dataset.shared_epoch.set_value(epoch)
callbacks/ema.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pytorch_lightning import Callback
2
+ import copy
3
+ import itertools
4
+ import torch
5
+ import contextlib
6
+ from torch.distributed.fsdp import FullyShardedDataParallel
7
+
8
+
9
+ class EMACallback(Callback):
10
+ def __init__(
11
+ self,
12
+ module_attr_name,
13
+ ema_module_attr_name,
14
+ decay=0.999,
15
+ start_ema_step=0,
16
+ init_ema_random=True,
17
+ ):
18
+ super().__init__()
19
+ self.decay = decay
20
+ self.module_attr_name = module_attr_name
21
+ self.ema_module_attr_name = ema_module_attr_name
22
+ self.start_ema_step = start_ema_step
23
+ self.init_ema_random = init_ema_random
24
+
25
+ def on_train_start(self, trainer, pl_module):
26
+ if pl_module.global_step == 0:
27
+ if not hasattr(pl_module, self.module_attr_name):
28
+ raise ValueError(
29
+ f"Module {pl_module} does not have attribute {self.module_attr_name}"
30
+ )
31
+ if not hasattr(pl_module, self.ema_module_attr_name):
32
+ pl_module.add_module(
33
+ self.ema_module_attr_name,
34
+ copy.deepcopy(getattr(pl_module, self.module_attr_name))
35
+ .eval()
36
+ .requires_grad_(False),
37
+ )
38
+ self.reset_ema(pl_module)
39
+
40
+ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
41
+ if pl_module.global_step == self.start_ema_step:
42
+ self.reset_ema(pl_module)
43
+ elif (
44
+ pl_module.global_step < self.start_ema_step
45
+ and pl_module.global_step % 100 == 0
46
+ ):
47
+ ## slow ema updates for visualisation
48
+ self.update_ema(pl_module, decay=0.9)
49
+ elif pl_module.global_step > self.start_ema_step:
50
+ self.update_ema(pl_module, decay=self.decay)
51
+
52
+ def update_ema(self, pl_module, decay=0.999):
53
+ ema_module = getattr(pl_module, self.ema_module_attr_name)
54
+ module = getattr(pl_module, self.module_attr_name)
55
+ context_manager = self.get_model_context_manager(module)
56
+ with context_manager:
57
+ with torch.no_grad():
58
+ ema_params = ema_module.state_dict()
59
+ for name, param in itertools.chain(
60
+ module.named_parameters(), module.named_buffers()
61
+ ):
62
+ if name in ema_params:
63
+ if param.requires_grad:
64
+ ema_params[name].copy_(
65
+ ema_params[name].detach().lerp(param.detach(), decay)
66
+ )
67
+
68
+ def get_model_context_manager(self, module):
69
+ fsdp_enabled = is_model_fsdp(module)
70
+ model_context_manager = contextlib.nullcontext()
71
+ if fsdp_enabled:
72
+ model_context_manager = module.summon_full_params(module)
73
+ return model_context_manager
74
+
75
+ def reset_ema(self, pl_module):
76
+ ema_module = getattr(pl_module, self.ema_module_attr_name)
77
+ if self.init_ema_random:
78
+ ema_module.init_weights()
79
+ else:
80
+ module = getattr(pl_module, self.module_attr_name)
81
+ context_manager = self.get_model_context_manager(module)
82
+ with context_manager:
83
+ ema_params = ema_module.state_dict()
84
+ for name, param in itertools.chain(
85
+ module.named_parameters(), module.named_buffers()
86
+ ):
87
+ if name in ema_params:
88
+ ema_params[name].copy_(param.detach())
89
+
90
+
91
+ def is_model_fsdp(model: torch.nn.Module) -> bool:
92
+ try:
93
+ if isinstance(model, FullyShardedDataParallel):
94
+ return True
95
+
96
+ # Check if model is wrapped with FSDP
97
+ for _, obj in model.named_children():
98
+ if isinstance(obj, FullyShardedDataParallel):
99
+ return True
100
+ return False
101
+ except ImportError:
102
+ return False
callbacks/fix_nans.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from pytorch_lightning.callbacks import Callback
3
+ import torch
4
+
5
+ log = logging.getLogger(__name__)
6
+
7
+
8
+ class FixNANinGrad(Callback):
9
+ def __init__(self, monitor):
10
+ super().__init__()
11
+ self.monitor = monitor
12
+ self.continuous_nan_batchs = 0
13
+
14
+ def on_before_optimizer_step(self, trainer, pl_module, optimizer) -> None:
15
+ has_nan = []
16
+ is_inf = []
17
+ for name, param in pl_module.named_parameters():
18
+ if param.grad is not None:
19
+ if torch.isnan(param.grad).any():
20
+ has_nan.append(name)
21
+ if torch.isinf(param.grad).any():
22
+ is_inf.append(name)
23
+ torch.nan_to_num(param.grad, nan=0, posinf=0, neginf=0, out=param.grad)
24
+ if len(has_nan) > 0:
25
+ print(f"Found NaN in {has_nan}")
26
+ if len(is_inf) > 0:
27
+ print(f"Found Inf in {is_inf}")
28
+
29
+ def on_train_batch_end(
30
+ self,
31
+ trainer,
32
+ pl_module,
33
+ outputs,
34
+ batch,
35
+ batch_idx,
36
+ ) -> None:
37
+ logs = trainer.callback_metrics
38
+ i = 0
39
+ found_metric = False
40
+ while i < len(self.monitor) and not found_metric:
41
+ if self.monitor[i] in logs.keys():
42
+ current = logs[self.monitor[i]].squeeze()
43
+ found_metric = True
44
+ else:
45
+ i += 1
46
+ if not found_metric:
47
+ raise ValueError("Asked metric not in logs")
48
+
49
+ if not torch.isfinite(current):
50
+ self.continuous_nan_batchs += 1
51
+ if self.continuous_nan_batchs >= 5:
52
+ trainer.should_stop = True
53
+ log.info("Training interrupted because of NaN in {self.monitor}")
54
+ else:
55
+ self.continuous_nan_batchs = 0
configs/computer/a100.yaml ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ devices: 1
2
+ progress_bar_refresh_rate: 2
3
+ num_workers: 8
4
+ sync_batchnorm: False
5
+ accelerator: gpu
6
+ precision: 32
7
+ strategy: auto
8
+ num_nodes: 1
configs/computer/cluster-node-a100.yaml ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ devices: 8
2
+ num_workers: 8
3
+ progress_bar_refresh_rate: 2
4
+ sync_batchnorm: True
5
+ accelerator: gpu
6
+ precision: 32
7
+ strategy: ddp
8
+ num_nodes: 1
configs/computer/cluster-node-v100.yaml ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ devices: 4
2
+ num_workers: 10
3
+ progress_bar_refresh_rate: 2
4
+ sync_batchnorm: True
5
+ accelerator: gpu
6
+ precision: 32
7
+ strategy: ddp
8
+ num_nodes: 1
configs/computer/cpu.yaml ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ devices: null
2
+ num_workers: 0
3
+ progress_bar_refresh_rate: 2
4
+ sync_batchnorm: False
5
+ accelerator: cpu
6
+ precision: 32
7
+ strategy: auto
8
+ num_nodes: null
configs/computer/h100.yaml ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ devices: 1
2
+ progress_bar_refresh_rate: 2
3
+ num_workers: 24
4
+ sync_batchnorm: False
5
+ accelerator: gpu
6
+ precision: 32
7
+ strategy: auto
8
+ num_nodes: 1
configs/computer/v100.yaml ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ devices: 1
2
+ num_workers: 10
3
+ progress_bar_refresh_rate: 2
4
+ sync_batchnorm: False
5
+ accelerator: gpu
6
+ precision: 32
7
+ strategy: auto
8
+ num_nodes: 1
configs/config.yaml ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ defaults:
2
+ - model: default
3
+ - computer: v100
4
+ - dataset: osv5m_emb
5
+ - stage: null
6
+ - _self_
7
+ - exp: ???
8
+
9
+ model:
10
+ val_metrics:
11
+ _target_: metrics.distance_based.HaversineMetrics
12
+ acc_radiuses:
13
+ - 1
14
+ - 25
15
+ - 200
16
+ - 750
17
+ - 2500
18
+ acc_area: []
19
+ test_metrics:
20
+ _target_: metrics.distance_based.HaversineMetrics
21
+ acc_radiuses:
22
+ - 1
23
+ - 25
24
+ - 200
25
+ - 750
26
+ - 2500
27
+ acc_area: ${areas}
28
+
29
+ datamodule:
30
+ _target_: data.datamodule.ImageDataModule
31
+ train_dataset: ${dataset.train_dataset}
32
+ val_dataset: ${dataset.val_dataset}
33
+ test_dataset: ${dataset.test_dataset}
34
+ full_batch_size: ${dataset.full_batch_size}
35
+ eval_batch_size: ${dataset.eval_batch_size}
36
+ num_workers: ${computer.num_workers}
37
+ num_nodes: ${computer.num_nodes}
38
+ num_devices: ${computer.devices}
39
+ val_proportion: 0.02
40
+
41
+ trainer:
42
+ _target_: pytorch_lightning.Trainer
43
+ devices: ${computer.devices}
44
+ accelerator: ${computer.accelerator}
45
+ strategy: ${computer.strategy}
46
+ num_nodes: ${computer.num_nodes}
47
+ precision: ${computer.precision}
48
+ max_steps: 1000000
49
+ val_check_interval: 25000
50
+ check_val_every_n_epoch: null
51
+
52
+ logger:
53
+ _target_: pytorch_lightning.loggers.WandbLogger
54
+ save_dir: ${root_dir}
55
+ name: ${experiment_name}${logger_suffix}
56
+ project: diff_plonk
57
+ log_model: False
58
+ offline: False
59
+
60
+ checkpoints:
61
+ _target_: pytorch_lightning.callbacks.ModelCheckpoint
62
+ dirpath: ${root_dir}/checkpoints/${experiment_name}
63
+ filename: 'epoch_{epoch}'
64
+ monitor: val/loss
65
+ save_last: True
66
+ save_top_k: 0
67
+ every_n_epochs: 1
68
+ enable_version_counter: False
69
+
70
+ progress_bar:
71
+ _target_: pytorch_lightning.callbacks.TQDMProgressBar
72
+ refresh_rate: ${computer.progress_bar_refresh_rate}
73
+
74
+ data_dir: ${root_dir}/datasets
75
+ root_dir: ${hydra:runtime.cwd}
76
+ experiment_name: ${dataset.name}_${model.name}_${experiment_name_suffix}
77
+ experiment_name_suffix: base
78
+ logger_suffix: ""
79
+ mode: train # change that to eval to do the testing
80
+ areas: ['country', 'region', 'sub-region', 'city']
81
+ class_name: null
82
+ streetclip: False
83
+ blur: False
84
+ text_tuning: False
85
+
86
+ hydra:
87
+ run:
88
+ dir: outputs/${hydra.job.name}/${now:%Y-%m-%d_%H-%M-%S}/${experiment_name}
89
+ job:
90
+ chdir: true
configs/dataset/baselines/im2gps.yaml ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ dataset:
2
+ name: im2gps
3
+ full_batch_size: 512
4
+ test_dataset:
5
+ _partial_: true
6
+ _target_: data.data.Baseline
7
+ path: ${data_dir}/baselines/im2gps
8
+ which: 'im2gps'
9
+ transforms: ${dataset.test_transform}
10
+ datamodule:
11
+ _target_: data.datamodule.BaselineDataModule
12
+ test_dataset: ${dataset.test_dataset}
13
+ full_batch_size: ${dataset.full_batch_size}
14
+ num_workers: ${computer.num_workers}
15
+ num_nodes: ${computer.num_nodes}
16
+ num_devices: ${computer.devices}
configs/dataset/baselines/im2gps3k.yaml ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ dataset:
2
+ name: im2gps3k
3
+ full_batch_size: 512
4
+ test_dataset:
5
+ _partial_: true
6
+ _target_: data.data.Baseline
7
+ path: ${data_dir}/baselines/im2gps3k
8
+ which: 'im2gps3k'
9
+ transforms: ${dataset.test_transform}
10
+ datamodule:
11
+ _target_: data.datamodule.BaselineDataModule
12
+ test_dataset: ${dataset.test_dataset}
13
+ full_batch_size: ${dataset.full_batch_size}
14
+ num_workers: ${computer.num_workers}
15
+ num_nodes: ${computer.num_nodes}
16
+ num_devices: ${computer.devices}
configs/dataset/baselines/yfcc4k.yaml ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ dataset:
2
+ name: yfcc4k
3
+ full_batch_size: 512
4
+ test_dataset:
5
+ _partial_: true
6
+ _target_: data.data.Baseline
7
+ path: ${data_dir}/baselines/yfcc4k
8
+ which: 'yfcc4k'
9
+ transforms: ${dataset.test_transform}
10
+ datamodule:
11
+ _target_: data.datamodule.BaselineDataModule
12
+ test_dataset: ${dataset.test_dataset}
13
+ full_batch_size: ${dataset.full_batch_size}
14
+ num_workers: ${computer.num_workers}
15
+ num_nodes: ${computer.num_nodes}
16
+ num_devices: ${computer.devices}
configs/dataset/combined_emb.yaml ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ defaults:
2
+ - train_transform: empty
3
+ - test_transform: empty
4
+ - _self_
5
+
6
+ name: iNaturalist_OSV5M_YFCC100M_${dataset.embedding_name}
7
+ full_batch_size: 2048
8
+ cond_dim: 1024
9
+ eval_batch_size: 4096
10
+ output_type: emb
11
+ embedding_name: dinov2_vitl14_registers
12
+
13
+ train_dataset:
14
+ _partial_: true
15
+ _target_: data.webdataset.GPSWebdataset
16
+ root: ${data_dir}/YFCC100M/train/ ${data_dir}/osv5m/train/ ${data_dir}/inaturalist/train/ ${data_dir}/osv5m/train/ ${data_dir}/inaturalist/train/
17
+ train: true
18
+ embedding_name: ${dataset.embedding_name}
19
+ return_image: false
20
+ metadata_attributes: []
21
+
22
+ val_dataset:
23
+ _partial_: true
24
+ _target_: data.webdataset.GPSWebdataset
25
+ root: ${data_dir}/YFCC100M/yfcc4k/
26
+ train: false
27
+ embedding_name: ${dataset.embedding_name}
28
+ return_image: false
29
+ metadata_attributes: []
30
+
31
+ test_dataset:
32
+ _partial_: true
33
+ _target_: data.webdataset.GPSWebdataset
34
+ root: ${data_dir}/YFCC100M/yfcc4k/
35
+ train: false
36
+ embedding_name: ${dataset.embedding_name}
37
+ return_image: false
38
+ metadata_attributes: []
configs/dataset/inaturalist_emb.yaml ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ defaults:
2
+ - train_transform: empty
3
+ - test_transform: empty
4
+ - _self_
5
+
6
+ name: iNaturalist_${dataset.embedding_name}
7
+ full_batch_size: 512
8
+ cond_dim: 1024
9
+ eval_batch_size: 4096
10
+ output_type: emb
11
+ embedding_name: dinov2_vitl14_registers
12
+
13
+ train_dataset:
14
+ _partial_: true
15
+ _target_: data.webdataset.GPSWebdataset
16
+ root: ${data_dir}/inaturalist/train/
17
+ train: true
18
+ embedding_name: ${dataset.embedding_name}
19
+ return_image: false
20
+ metadata_attributes: []
21
+
22
+ val_dataset:
23
+ _partial_: true
24
+ _target_: data.webdataset.GPSWebdataset
25
+ root: ${data_dir}/inaturalist/val/
26
+ train: false
27
+ embedding_name: ${dataset.embedding_name}
28
+ return_image: false
29
+ metadata_attributes: []
30
+
31
+ test_dataset:
32
+ _partial_: true
33
+ _target_: data.webdataset.GPSWebdataset
34
+ root: ${data_dir}/inaturalist/test/
35
+ train: false
36
+ embedding_name: ${dataset.embedding_name}
37
+ return_image: false
38
+ metadata_attributes: []
configs/dataset/osv5m.yaml ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ defaults:
2
+ - train_transform: fast_clip
3
+ - test_transform: fast_clip
4
+ - _self_
5
+
6
+ name: osv5m
7
+ full_batch_size: 2048
8
+ eval_batch_size: 4096
9
+ train_dataset:
10
+ _partial_: true
11
+ _target_: data.data.OSV5M
12
+ path: ${data_dir}/osv5m/
13
+ split: train
14
+ class_name: ${class_name}
15
+ transforms: ${dataset.train_transform}
16
+ is_baseline: ${is_baseline}
17
+ areas: ${areas}
18
+ streetclip: ${streetclip}
19
+ blur: ${blur}
20
+
21
+ val_dataset:
22
+ _partial_: true
23
+ _target_: data.data.OSV5M
24
+ path: ${data_dir}/osv5m/
25
+ split: val
26
+ class_name: ${class_name}
27
+ transforms: ${dataset.test_transform}
28
+ is_baseline: ${is_baseline}
29
+ areas: ${areas}
30
+ streetclip: ${streetclip}
31
+ blur: ${blur}
32
+
33
+ test_dataset:
34
+ _partial_: true
35
+ _target_: data.data.OSV5M
36
+ path: ${data_dir}/osv5m/
37
+ split: test
38
+ class_name: ${class_name}
39
+ transforms: ${dataset.test_transform}
40
+ is_baseline: ${is_baseline}
41
+ areas: ${areas}
42
+ streetclip: ${streetclip}
43
+ blur: ${blur}
configs/dataset/osv5m_emb.yaml ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ defaults:
2
+ - train_transform: empty
3
+ - test_transform: empty
4
+ - _self_
5
+
6
+ name: osv5m_${dataset.embedding_name}
7
+ full_batch_size: 1024
8
+ eval_batch_size: 4096
9
+ cond_dim: 1024
10
+ output_type: emb
11
+ embedding_name: street_clip
12
+
13
+ train_dataset:
14
+ _partial_: true
15
+ _target_: data.webdataset.GPSWebdataset
16
+ root: ${data_dir}/osv5m/train/
17
+ train: true
18
+ embedding_name: ${dataset.embedding_name}
19
+ return_image: false
20
+ metadata_attributes: []
21
+
22
+ val_dataset:
23
+ _partial_: true
24
+ _target_: data.webdataset.GPSWebdataset
25
+ root: ${data_dir}/osv5m/val/
26
+ train: false
27
+ embedding_name: ${dataset.embedding_name}
28
+ return_image: false
29
+ metadata_attributes: ["unique_country", "unique_region", "unique_sub-region", "unique_city"]
30
+
31
+ test_dataset:
32
+ _partial_: true
33
+ _target_: data.webdataset.GPSWebdataset
34
+ root: ${data_dir}/osv5m/test/
35
+ train: false
36
+ embedding_name: ${dataset.embedding_name}
37
+ return_image: false
38
+ metadata_attributes: ["unique_country", "unique_region", "unique_sub-region", "unique_city"]
configs/dataset/test_transform/center_crop.yaml ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ _target_: torchvision.transforms.Compose
2
+ transforms:
3
+ - _target_: torchvision.transforms.ToTensor
4
+ - _target_: utils.image_processing.CenterCrop
5
+ ratio: "1:1"
6
+ - _target_: torchvision.transforms.Resize
7
+ size: ${dataset.img_resolution}
8
+ interpolation: 3
9
+ antialias: true
10
+ - _target_: torchvision.transforms.Normalize
11
+ mean: 0.5
12
+ std: 0.5
configs/dataset/test_transform/clip.yaml ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ _target_: data.transforms.ClipTransform
2
+ split: val
configs/dataset/test_transform/empty.yaml ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ _target_: data.data.null_transform
2
+ _partial_: true
configs/dataset/test_transform/fast_clip.yaml ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ _target_: torchvision.transforms.Compose
2
+ transforms:
3
+ - _target_: torchvision.transforms.Resize
4
+ size: 224
5
+ interpolation: 3
6
+ antialias: true
7
+ - _target_: torchvision.transforms.CenterCrop
8
+ size: 224
9
+ - _target_: torchvision.transforms.ToTensor
10
+ - _target_: torchvision.transforms.Normalize
11
+ mean: [0.48145466, 0.4578275, 0.40821073]
12
+ std: [0.26862954, 0.26130258, 0.27577711]
configs/dataset/test_transform/fast_resnet.yaml ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ _target_: torchvision.transforms.Compose
2
+ transforms:
3
+ - _target_: torchvision.transforms.Resize
4
+ size: 224
5
+ interpolation: 3
6
+ antialias: true
7
+ - _target_: torchvision.transforms.CenterCrop
8
+ size: 224
9
+ - _target_: torchvision.transforms.ToTensor
10
+ - _target_: torchvision.transforms.Normalize
11
+ mean: [0.485 ,0.456 ,0.406]
12
+ std: [0.229, 0.224, 0.225]
configs/dataset/test_transform/none.yaml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ _target_: torchvision.transforms.Compose
2
+ transforms:
3
+ - _target_: torchvision.transforms.ToTensor
4
+ - _target_: torchvision.transforms.Normalize
5
+ mean: 0.5
6
+ std: 0.5
configs/dataset/train_transform/augmentation.yaml ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ _target_: data.augmentation.ImageAugmentation
2
+ names: "standard_augmentation,geometric_augmentation,clip_transform"
3
+
4
+ # always apply clip_transform at the end
5
+ clip_transform:
6
+ _target_: torchvision.transforms.Compose
7
+ transforms:
8
+ - _target_: torchvision.transforms.Resize
9
+ size: 224
10
+ interpolation: 3
11
+ antialias: true
12
+ - _target_: torchvision.transforms.CenterCrop
13
+ size: 224
14
+ - _target_: torchvision.transforms.ToTensor
15
+ - _target_: torchvision.transforms.Normalize
16
+ mean: [0.48145466, 0.4578275, 0.40821073]
17
+ std: [0.26862954, 0.26130258, 0.27577711]
18
+
19
+ standard_augmentation:
20
+ _target_: data.augmentation.StandardAugmentation
21
+ # by default, we all augmentation methods
22
+ names: "brightness,contrast,sharpness,color,blur,gaussian_noise"
23
+
24
+ # random PIL brigtness
25
+ brightness:
26
+ _target_: data.augmentation.PillowBrightness
27
+ p: 0.2
28
+ factor_interval: [0.5, 1.5]
29
+
30
+ # random PIL contrast
31
+ contrast:
32
+ _target_: data.augmentation.PillowContrast
33
+ p: 0.2
34
+ factor_interval: [0.3, 3]
35
+
36
+ # random PIL sharpness
37
+ sharpness:
38
+ _target_: data.augmentation.PillowSharpness
39
+ p: 0.2
40
+ factor_interval: [0.5, 30.0]
41
+
42
+ # random PIL color
43
+ color:
44
+ _target_: data.augmentation.PillowColor
45
+ p: 0.2
46
+ factor_interval: [0.0, 2.0]
47
+
48
+ # random PIL blur
49
+ blur:
50
+ _target_: data.augmentation.PillowBlur
51
+ p: 0.2
52
+ factor_interval: [1, 2]
53
+
54
+ # random numpy gaussian noise
55
+ gaussian_noise:
56
+ _target_: data.augmentation.NumpyGaussianNoise
57
+ p: 0.2
58
+ factor_interval: [0.1, 0.04]
59
+
60
+ geometric_augmentation:
61
+ _target_: data.augmentation.GeometricAugmentation
62
+ # by default, we all augmentation methods
63
+ names: "random_rotation,random_resized_crop,random_horizontal_flip"
64
+
65
+ # random rotation
66
+ random_rotation:
67
+ _target_: torchvision.transforms.RandomRotation
68
+ degrees: [-15, 15]
69
+
70
+ # random crop
71
+ random_resized_crop:
72
+ _target_: torchvision.transforms.RandomResizedCrop
73
+ scale: [0.5, 1.0]
74
+ ratio: [0.9, 1.1]
75
+ size: 224
76
+
77
+ # random horizontal flip
78
+ random_horizontal_flip:
79
+ _target_: torchvision.transforms.RandomHorizontalFlip
80
+ p: 0.5
81
+
82
+ # random vertical flip
83
+ random_vertical_flip:
84
+ _target_: torchvision.transforms.RandomVerticalFlip
85
+ p: 0.5
configs/dataset/train_transform/center_crop.yaml ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ _target_: torchvision.transforms.Compose
2
+ transforms:
3
+ - _target_: torchvision.transforms.ToTensor
4
+ - _target_: utils.image_processing.CenterCrop
5
+ ratio: "1:1"
6
+ - _target_: torchvision.transforms.Resize
7
+ size: ${dataset.img_resolution}
8
+ interpolation: 3
9
+ antialias: true
10
+ - _target_: torchvision.transforms.RandomHorizontalFlip
11
+ p: 0.5
12
+ - _target_: torchvision.transforms.Normalize
13
+ mean: 0.5
14
+ std: 0.5
configs/dataset/train_transform/clip.yaml ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ _target_: data.transforms.ClipTransform
2
+ split: val
configs/dataset/train_transform/empty.yaml ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ _target_: data.data.null_transform
2
+ _partial_: true
configs/dataset/train_transform/fast_clip.yaml ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ _target_: torchvision.transforms.Compose
2
+ transforms:
3
+ - _target_: torchvision.transforms.Resize
4
+ size: 224
5
+ interpolation: 3
6
+ antialias: true
7
+ - _target_: torchvision.transforms.CenterCrop
8
+ size: 224
9
+ - _target_: torchvision.transforms.ToTensor
10
+ - _target_: torchvision.transforms.Normalize
11
+ mean: [0.48145466, 0.4578275, 0.40821073]
12
+ std: [0.26862954, 0.26130258, 0.27577711]
configs/dataset/train_transform/fast_resnet.yaml ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ _target_: torchvision.transforms.Compose
2
+ transforms:
3
+ - _target_: torchvision.transforms.Resize
4
+ size: 224
5
+ interpolation: 3
6
+ antialias: true
7
+ - _target_: torchvision.transforms.CenterCrop
8
+ size: 224
9
+ - _target_: torchvision.transforms.ToTensor
10
+ - _target_: torchvision.transforms.Normalize
11
+ mean: [0.485 ,0.456 ,0.406]
12
+ std: [0.229, 0.224, 0.225]
configs/dataset/train_transform/none.yaml ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ _target_: torchvision.transforms.Compose
2
+ transforms:
3
+ - _target_: torchvision.transforms.Resize
4
+ size: 224
5
+ interpolation: 3
6
+ antialias: true
7
+ - _target_: torchvision.transforms.ToTensor
configs/dataset/yfcc_emb.yaml ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ defaults:
2
+ - train_transform: empty
3
+ - test_transform: empty
4
+ - _self_
5
+
6
+ name: iNaturalist_${dataset.embedding_name}
7
+ full_batch_size: 2048
8
+ cond_dim: 1024
9
+ eval_batch_size: 4096
10
+ output_type: emb
11
+ embedding_name: dinov2_vitl14_registers
12
+
13
+ train_dataset:
14
+ _partial_: true
15
+ _target_: data.webdataset.GPSWebdataset
16
+ root: ${data_dir}/YFCC100M/train/
17
+ train: true
18
+ embedding_name: ${dataset.embedding_name}
19
+ return_image: false
20
+ metadata_attributes: []
21
+
22
+ val_dataset:
23
+ _partial_: true
24
+ _target_: data.webdataset.GPSWebdataset
25
+ root: ${data_dir}/YFCC100M/yfcc4k/
26
+ train: false
27
+ embedding_name: ${dataset.embedding_name}
28
+ return_image: false
29
+ metadata_attributes: []
30
+
31
+ test_dataset:
32
+ _partial_: true
33
+ _target_: data.webdataset.GPSWebdataset
34
+ root: ${data_dir}/YFCC100M/yfcc4k/
35
+ train: false
36
+ embedding_name: ${dataset.embedding_name}
37
+ return_image: false
38
+ metadata_attributes: []
configs/exp/YFCC100M_geoadalnmlp_r2_small_sigmoid_diffusion.yaml ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ defaults:
4
+ - override /dataset: yfcc_emb
5
+ - override /model: emb_cond
6
+ - override /model/network: geo_adaln_mlp
7
+ - override /model/train_noise_scheduler: sigmoid
8
+ - override /model/inference_noise_scheduler: sigmoid
9
+ - override /model/loss: ddpm
10
+ - _self_
11
+
12
+ model:
13
+ network:
14
+ depth: 12
15
+ dim: 512
16
+ optimizer:
17
+ optim:
18
+ lr: 8e-4
19
+ weight_decay: 0.05
20
+ loss:
21
+ cond_drop_rate: 0.1
22
+ train_noise_scheduler:
23
+ start: -7
24
+ end: 3
25
+ tau: 1.0
26
+ inference_noise_scheduler:
27
+ start: -7
28
+ end: 3
29
+ tau: 1.0
30
+ interpolant: diffusion
31
+ dataset:
32
+ full_batch_size: 1024
33
+
34
+ experiment_name_suffix: small_sigmoid
35
+ areas: []
configs/exp/YFCC100M_geoadalnmlp_r3_small_linear_flow_rieman.yaml ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ defaults:
4
+ - override /dataset: yfcc_emb
5
+ - override /model: emb_cond_cartesian
6
+ - override /model/network: geo_adaln_mlp
7
+ - override /model/train_noise_scheduler: linear
8
+ - override /model/inference_noise_scheduler: linear
9
+ - override /model/loss: riemannian_flow_matching
10
+ - override /model/manifold: sphere
11
+ - override /model/val_sampler: riemannian_flow_matching
12
+ - override /model/test_sampler: riemannian_flow_matching
13
+ - _self_
14
+
15
+ model:
16
+ network:
17
+ depth: 12
18
+ dim: 512
19
+ optimizer:
20
+ optim:
21
+ lr: 8e-4
22
+ weight_decay: 0.05
23
+ loss:
24
+ cond_drop_rate: 0.1
25
+ interpolant: flow_matching
26
+
27
+ dataset:
28
+ full_batch_size: 1024
29
+
30
+ areas: []
31
+
32
+ experiment_name_suffix: small_sigmoid
configs/exp/YFCC100M_geoadalnmlp_r3_small_sigmoid_diffusion.yaml ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ defaults:
4
+ - override /dataset: yfcc_emb
5
+ - override /model: emb_cond_cartesian
6
+ - override /model/network: geo_adaln_mlp
7
+ - override /model/train_noise_scheduler: sigmoid
8
+ - override /model/inference_noise_scheduler: sigmoid
9
+ - override /model/loss: ddpm
10
+ - _self_
11
+
12
+ model:
13
+ network:
14
+ depth: 12
15
+ dim: 512
16
+ optimizer:
17
+ optim:
18
+ lr: 8e-4
19
+ weight_decay: 0.05
20
+ loss:
21
+ cond_drop_rate: 0.1
22
+ train_noise_scheduler:
23
+ start: -7
24
+ end: 3
25
+ tau: 1.0
26
+ inference_noise_scheduler:
27
+ start: -7
28
+ end: 3
29
+ tau: 1.0
30
+ interpolant: diffusion
31
+
32
+ dataset:
33
+ full_batch_size: 1024
34
+
35
+ experiment_name_suffix: small_sigmoid
36
+ areas: []
configs/exp/YFCC100M_geoadalnmlp_r3_small_sigmoid_flow.yaml ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ defaults:
4
+ - override /dataset: yfcc_emb
5
+ - override /model: emb_cond_cartesian
6
+ - override /model/network: geo_adaln_mlp
7
+ - override /model/train_noise_scheduler: sigmoid
8
+ - override /model/inference_noise_scheduler: sigmoid
9
+ - override /model/loss: flow_matching
10
+ - override /model/val_sampler: flow_matching
11
+ - override /model/test_sampler: flow_matching
12
+ - _self_
13
+
14
+ model:
15
+ network:
16
+ depth: 12
17
+ dim: 512
18
+ optimizer:
19
+ optim:
20
+ lr: 8e-4
21
+ weight_decay: 0.05
22
+ loss:
23
+ cond_drop_rate: 0.1
24
+ train_noise_scheduler:
25
+ start: -7
26
+ end: 3
27
+ tau: 1.0
28
+ inference_noise_scheduler:
29
+ start: -7
30
+ end: 3
31
+ tau: 1.0
32
+ interpolant: flow_matching
33
+
34
+ dataset:
35
+ full_batch_size: 1024
36
+
37
+ experiment_name_suffix: small_sigmoid
38
+ areas: []
configs/exp/YFCC100M_geoadalnmlp_r3_small_sigmoid_flow_riemann.yaml ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ defaults:
4
+ - override /dataset: yfcc_emb
5
+ - override /model: emb_cond_cartesian
6
+ - override /model/network: geo_adaln_mlp
7
+ - override /model/train_noise_scheduler: sigmoid
8
+ - override /model/inference_noise_scheduler: sigmoid
9
+ - override /model/loss: riemannian_flow_matching
10
+ - override /model/manifold: sphere
11
+ - override /model/val_sampler: riemannian_flow_matching
12
+ - override /model/test_sampler: riemannian_flow_matching
13
+ - _self_
14
+
15
+ model:
16
+ network:
17
+ depth: 12
18
+ dim: 512
19
+ optimizer:
20
+ optim:
21
+ lr: 8e-4
22
+ weight_decay: 0.05
23
+ loss:
24
+ cond_drop_rate: 0.1
25
+ train_noise_scheduler:
26
+ start: -7
27
+ end: 3
28
+ tau: 1.0
29
+ inference_noise_scheduler:
30
+ start: -7
31
+ end: 3
32
+ tau: 1.0
33
+ interpolant: flow_matching
34
+
35
+ dataset:
36
+ full_batch_size: 1024
37
+
38
+ areas: []
39
+
40
+ experiment_name_suffix: small_sigmoid
configs/exp/YFCC100M_geoadalnmlp_von_fisher.yaml ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ defaults:
4
+ - override /dataset: yfcc_emb
5
+ - override /model: von_fisher
6
+ - override /model/network: geo_adaln_mlp_von_fisher
7
+ - override /model/loss: von_fisher
8
+ - override /model/val_sampler: von_fisher
9
+ - override /model/test_sampler: von_fisher
10
+ - _self_
11
+
12
+ model:
13
+ network:
14
+ depth: 11 # To compensate the increase in params
15
+ dim: 512
16
+ optimizer:
17
+ optim:
18
+ lr: 1e-4
19
+ weight_decay: 0.05
20
+ dataset:
21
+ full_batch_size: 1024
22
+ trainer:
23
+ gradient_clip_val: 0.05
24
+ gradient_clip_algorithm: norm
25
+ areas: []
26
+ experiment_name_suffix: von_fisher
configs/exp/YFCC100M_geoadalnmlp_von_fisher_mixture.yaml ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ defaults:
4
+ - override /dataset: yfcc_emb
5
+ - override /model: von_fisher_mixture
6
+ - override /model/network: geo_adaln_mlp_von_fisher_mixture
7
+ - override /model/loss: von_fisher_mixture
8
+ - override /model/val_sampler: von_fisher_mixture
9
+ - override /model/test_sampler: von_fisher_mixture
10
+ - _self_
11
+
12
+ model:
13
+ network:
14
+ depth: 11 # To compensate the increase in params
15
+ dim: 512
16
+ optimizer:
17
+ optim:
18
+ lr: 1e-5
19
+ weight_decay: 0.05
20
+ dataset:
21
+ full_batch_size: 1024
22
+ trainer:
23
+ gradient_clip_val: 0.01
24
+ gradient_clip_algorithm: norm
25
+ experiment_name_suffix: von_fisher_mixture
26
+ areas: []
configs/exp/combined_geoadalnmlp_r3_small_sigmoid_flow_riemann.yaml ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ defaults:
4
+ - override /dataset: combined_emb
5
+ - override /model: emb_cond_cartesian
6
+ - override /model/network: geo_adaln_mlp
7
+ - override /model/train_noise_scheduler: sigmoid
8
+ - override /model/inference_noise_scheduler: sigmoid
9
+ - override /model/loss: riemannian_flow_matching
10
+ - override /model/manifold: sphere
11
+ - override /model/val_sampler: riemannian_flow_matching
12
+ - override /model/test_sampler: riemannian_flow_matching
13
+ - _self_
14
+
15
+ model:
16
+ network:
17
+ depth: 12
18
+ dim: 512
19
+ optimizer:
20
+ optim:
21
+ lr: 8e-4
22
+ weight_decay: 0.05
23
+ loss:
24
+ cond_drop_rate: 0.1
25
+ train_noise_scheduler:
26
+ start: -7
27
+ end: 3
28
+ tau: 1.0
29
+ inference_noise_scheduler:
30
+ start: -7
31
+ end: 3
32
+ tau: 1.0
33
+ interpolant: flow_matching
34
+
35
+ dataset:
36
+ full_batch_size: 1024
37
+
38
+ areas: []
39
+
40
+ experiment_name_suffix: small_sigmoid
configs/exp/iNaturalist_geoadalnmlp_r2_small_sigmoid_diffusion.yaml ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ defaults:
4
+ - override /dataset: inaturalist_emb
5
+ - override /model: emb_cond
6
+ - override /model/network: geo_adaln_mlp
7
+ - override /model/train_noise_scheduler: sigmoid
8
+ - override /model/inference_noise_scheduler: sigmoid
9
+ - override /model/loss: ddpm
10
+ - _self_
11
+
12
+ model:
13
+ network:
14
+ depth: 12
15
+ dim: 256
16
+ optimizer:
17
+ optim:
18
+ lr: 8e-4
19
+ weight_decay: 0.1
20
+ loss:
21
+ cond_drop_rate: 0.1
22
+ train_noise_scheduler:
23
+ start: -7
24
+ end: 3
25
+ tau: 1.0
26
+ inference_noise_scheduler:
27
+ start: -7
28
+ end: 3
29
+ tau: 1.0
30
+ interpolant: diffusion
31
+ dataset:
32
+ full_batch_size: 512
33
+
34
+ areas: []
35
+
36
+ experiment_name_suffix: small_sigmoid
configs/exp/iNaturalist_geoadalnmlp_r3_small_sigmoid_diffusion.yaml ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ defaults:
4
+ - override /dataset: inaturalist_emb
5
+ - override /model: emb_cond_cartesian
6
+ - override /model/network: geo_adaln_mlp
7
+ - override /model/train_noise_scheduler: sigmoid
8
+ - override /model/inference_noise_scheduler: sigmoid
9
+ - override /model/loss: ddpm
10
+ - _self_
11
+
12
+ model:
13
+ network:
14
+ depth: 12
15
+ dim: 256
16
+ optimizer:
17
+ optim:
18
+ lr: 8e-4
19
+ weight_decay: 0.1
20
+ loss:
21
+ cond_drop_rate: 0.1
22
+ train_noise_scheduler:
23
+ start: -7
24
+ end: 3
25
+ tau: 1.0
26
+ inference_noise_scheduler:
27
+ start: -7
28
+ end: 3
29
+ tau: 1.0
30
+ interpolant: diffusion
31
+
32
+ dataset:
33
+ full_batch_size: 512
34
+
35
+ areas: []
36
+
37
+ experiment_name_suffix: small_sigmoid
configs/exp/iNaturalist_geoadalnmlp_r3_small_sigmoid_flow.yaml ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ defaults:
4
+ - override /dataset: inaturalist_emb
5
+ - override /model: emb_cond_cartesian
6
+ - override /model/network: geo_adaln_mlp
7
+ - override /model/train_noise_scheduler: sigmoid
8
+ - override /model/inference_noise_scheduler: sigmoid
9
+ - override /model/loss: flow_matching
10
+ - override /model/val_sampler: flow_matching
11
+ - override /model/test_sampler: flow_matching
12
+ - _self_
13
+
14
+ model:
15
+ network:
16
+ depth: 12
17
+ dim: 256
18
+ optimizer:
19
+ optim:
20
+ lr: 8e-4
21
+ weight_decay: 0.1
22
+ loss:
23
+ cond_drop_rate: 0.1
24
+ train_noise_scheduler:
25
+ start: -7
26
+ end: 3
27
+ tau: 1.0
28
+ inference_noise_scheduler:
29
+ start: -7
30
+ end: 3
31
+ tau: 1.0
32
+ interpolant: flow_matching
33
+
34
+ dataset:
35
+ full_batch_size: 512
36
+
37
+ areas: []
38
+
39
+ experiment_name_suffix: small_sigmoid