|
import dg.domainbed.datasets.datasets as D |
|
import torch |
|
import numpy as np |
|
import os |
|
|
|
|
|
def load_datasets_of_all_domains(dataset_name, data_dir):
    """Instantiate the multi-domain dataset class named ``dataset_name``.

    The class is looked up by name among the attributes of the
    ``dg.domainbed.datasets.datasets`` module (imported as ``D``) and
    constructed with ``data_dir`` as its only argument.

    Args:
        dataset_name: Name of a dataset class defined in ``D``.
        data_dir: Root directory passed to the dataset constructor.

    Returns:
        The constructed dataset object (indexable by domain index).

    Raises:
        KeyError: If ``D`` has no attribute called ``dataset_name``. The
            type is unchanged from a plain dict lookup so existing callers
            that catch ``KeyError`` keep working; only the message is clearer.
    """
    try:
        dataset_cls = vars(D)[dataset_name]
    except KeyError:
        raise KeyError(
            f"Unknown dataset {dataset_name!r}: not defined in {D.__name__}"
        ) from None
    return dataset_cls(data_dir)
|
|
|
|
|
def load_dataset_of_a_domain(dataset_name, domain_index, data_dir):
    """Return the dataset of a single domain of ``dataset_name``.

    Delegates to ``load_datasets_of_all_domains`` so both loaders share one
    name lookup and one error path, then indexes the result by domain.

    Args:
        dataset_name: Name of a dataset class in ``dg.domainbed.datasets.datasets``.
        domain_index: Index of the domain to return.
        data_dir: Root directory passed to the dataset constructor.

    Returns:
        ``datasets[domain_index]`` of the constructed multi-domain dataset.

    Raises:
        KeyError: If ``dataset_name`` is not a known dataset.
    """
    # NOTE(review): this constructs every domain just to return one; if the
    # constructor is expensive, callers fetching many domains should call
    # load_datasets_of_all_domains once and index it themselves.
    datasets = load_datasets_of_all_domains(dataset_name, data_dir)
    return datasets[domain_index]
|
|
|
|
|
def load_online_data(dataset_name, data_config, data_dir):
    """Draw consecutive, non-overlapping sample chunks from several domains.

    Args:
        dataset_name: Name of a dataset class in ``dg.domainbed.datasets.datasets``.
        data_config: Iterable of ``(domain_index, n_samples)`` pairs, processed
            in order. Each pair takes the next ``n_samples`` items of that
            domain, starting where the previous request for the same domain
            stopped, so repeated requests for one domain never overlap.
        data_dir: Root directory passed to the dataset constructor.

    Returns:
        ``(x, y)``: the chunks concatenated along dim 0 in ``data_config``
        order, or ``(None, None)`` when ``data_config`` is empty.

    Warns:
        UserWarning: When a domain has fewer unread samples left than
            requested; the shorter chunk is still used.

    NOTE(review): assumes each per-domain dataset supports slice indexing and
    returns an ``(x, y)`` pair of tensors (e.g. a
    ``torch.utils.data.TensorDataset``) -- confirm for other dataset types.
    """
    import warnings

    datasets = load_datasets_of_all_domains(dataset_name, data_dir)

    # Next unread position within each domain.
    dataset_anchors = [0] * len(datasets.ENVIRONMENTS)
    xs, ys = [], []

    for domain_index, n_samples in data_config:
        start = dataset_anchors[domain_index]
        x, y = datasets[domain_index][start:start + n_samples]
        dataset_anchors[domain_index] += n_samples

        # Slicing past the end silently truncates; make the shortfall visible
        # instead of handing back fewer samples than the config asked for.
        if len(x) < n_samples:
            warnings.warn(
                f"Domain {domain_index} has only {len(x)} unread samples "
                f"(requested {n_samples}); returning fewer.",
                stacklevel=2,
            )

        xs.append(x)
        ys.append(y)

    if not xs:
        return None, None
    if len(xs) == 1:
        # Single chunk: return it as-is, matching the original behavior.
        return xs[0], ys[0]
    # Concatenate once rather than re-copying the growing result per chunk.
    return torch.cat(xs), torch.cat(ys)
|
|