import os

import numpy as np
import torch

import dg.domainbed.datasets.datasets as D


def load_datasets_of_all_domains(dataset_name: str, data_dir):
    """Instantiate the dataset collection registered under *dataset_name*.

    The name is looked up among the module-level attributes of the
    ``dg.domainbed.datasets.datasets`` module and the attribute found is
    called with ``data_dir`` as its only argument.

    Args:
        dataset_name: Name of an attribute (presumably a dataset class) of
            the datasets module.
        data_dir: Passed through unchanged to the dataset constructor.

    Returns:
        Whatever the looked-up callable returns for ``data_dir``.

    Raises:
        KeyError: If the datasets module has no attribute *dataset_name*.
    """
    dataset_factory = vars(D)[dataset_name]
    return dataset_factory(data_dir)


def load_dataset_of_a_domain(dataset_name: str, domain_index, data_dir):
    """Return the entry at *domain_index* of the named dataset collection.

    Args:
        dataset_name: Name of an attribute of the datasets module.
        domain_index: Index used to subscript the dataset collection.
        data_dir: Passed through unchanged to the dataset constructor.

    Returns:
        ``datasets[domain_index]`` of the instantiated collection.

    Raises:
        KeyError: If the datasets module has no attribute *dataset_name*.
    """
    # Reuse the shared loader so the lookup logic lives in one place.
    return load_datasets_of_all_domains(dataset_name, data_dir)[domain_index]


def load_online_data(dataset_name: str, data_config, data_dir):
    """Build one ``(x, y)`` stream by taking consecutive slices from domains.

    ``data_config`` is consumed in order. Each entry ``(domain_index,
    n_samples)`` takes the next ``n_samples`` items of that domain, starting
    where the previous request for the same domain stopped. A per-domain
    cursor is kept, so repeated requests never re-read the same range.

    NOTE(review): this assumes each per-domain dataset supports slice
    indexing and returns an ``(x, y)`` pair of tensors that ``torch.cat``
    can join along dimension 0 -- confirm against the dataset classes.

    Args:
        dataset_name: Name of an attribute of the datasets module.
        data_config: Iterable of ``(domain_index, n_samples)`` pairs.
        data_dir: Passed through unchanged to the dataset constructor.

    Returns:
        A tuple ``(x, y)`` with all requested slices concatenated in request
        order, or ``(None, None)`` when ``data_config`` yields no entries.

    Raises:
        KeyError: If the datasets module has no attribute *dataset_name*.
    """
    datasets = load_datasets_of_all_domains(dataset_name, data_dir)

    # One read cursor per environment, all starting at the first sample.
    cursors = [0] * len(datasets.ENVIRONMENTS)

    x_chunks = []
    y_chunks = []
    for domain_index, n_samples in data_config:
        start = cursors[domain_index]
        stop = start + n_samples
        x, y = datasets[domain_index][start:stop]
        cursors[domain_index] = stop
        x_chunks.append(x)
        y_chunks.append(y)

    if not x_chunks:
        # Nothing was requested; keep the historical (None, None) result.
        return None, None

    # Concatenate once at the end instead of re-copying the accumulated
    # tensors on every iteration.
    return torch.cat(x_chunks), torch.cat(y_chunks)