File size: 1,039 Bytes
b84549f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 |
import dg.domainbed.datasets.datasets as D
import torch
import numpy as np
import os
def load_datasets_of_all_domains(dataset_name, data_dir):
    """Construct the multi-domain dataset registered under *dataset_name*.

    The constructor is looked up by name in the namespace of
    ``dg.domainbed.datasets.datasets`` and called with ``data_dir`` as its
    only argument.

    Args:
        dataset_name: Name of a class/callable defined in that module.
        data_dir: Root directory handed to the dataset constructor.

    Returns:
        Whatever the looked-up constructor returns for ``data_dir``.

    Raises:
        KeyError: If the module defines no attribute named *dataset_name*.
    """
    dataset_factory = vars(D)[dataset_name]
    return dataset_factory(data_dir)
def load_dataset_of_a_domain(dataset_name, domain_index, data_dir):
    """Return the dataset of a single domain.

    Builds the full multi-domain dataset for *dataset_name* and indexes it
    with *domain_index*.

    Args:
        dataset_name: Name of the dataset class in
            ``dg.domainbed.datasets.datasets``.
        domain_index: Index into the multi-domain dataset.
        data_dir: Root directory handed to the dataset constructor.

    Returns:
        ``datasets[domain_index]`` of the constructed multi-domain dataset.
    """
    all_domains = load_datasets_of_all_domains(dataset_name, data_dir)
    return all_domains[domain_index]
def load_online_data(dataset_name, data_config, data_dir):
    """Gather samples drawn sequentially from several domains into one batch.

    ``data_config`` is processed in order. For each ``(domain_index,
    n_samples)`` pair, the next ``n_samples`` items of that domain are taken.
    Reading starts where earlier pairs for the same domain stopped, so a
    domain listed more than once never yields the same items twice.

    Args:
        dataset_name: Name of the dataset class in
            ``dg.domainbed.datasets.datasets``.
        data_config: Iterable of ``(domain_index, n_samples)`` pairs.
        data_dir: Root directory handed to the dataset constructor.

    Returns:
        ``(x, y)``: the per-pair slices concatenated along dim 0 in
        ``data_config`` order, or ``(None, None)`` if ``data_config`` is empty.

    Notes:
        Assumes ``datasets[domain_index][start:stop]`` returns an ``(x, y)``
        pair of tensors, as ``torch.utils.data.TensorDataset`` does. TODO:
        confirm for every dataset type used here. A slice past the end of a
        domain presumably yields fewer than ``n_samples`` items without error.
    """
    datasets = load_datasets_of_all_domains(dataset_name, data_dir)
    # Next unread position for each domain.
    anchors = [0] * len(datasets.ENVIRONMENTS)
    x_parts, y_parts = [], []
    for domain_index, n_samples in data_config:
        start = anchors[domain_index]
        stop = start + n_samples
        x, y = datasets[domain_index][start:stop]
        anchors[domain_index] = stop
        x_parts.append(x)
        y_parts.append(y)
    if not x_parts:
        return None, None
    # Concatenate once. Calling torch.cat inside the loop re-copied the
    # growing result on every iteration, which is quadratic in total samples.
    return torch.cat(x_parts), torch.cat(y_parts)
|