EdgeTA / utils /datasets.py
LINC-BIT's picture
Upload 1912 files
b84549f verified
import dg.domainbed.datasets.datasets as D
import torch
import numpy as np
import os
def load_datasets_of_all_domains(dataset_name, data_dir):
    """Build the multi-domain dataset registered under *dataset_name*.

    The constructor is found by name among the attributes of the
    ``dg.domainbed.datasets.datasets`` module (imported as ``D``) and is
    called with *data_dir* as its only argument.

    Args:
        dataset_name: attribute name of the dataset constructor in ``D``.
        data_dir: directory passed straight to that constructor.

    Returns:
        Whatever the constructor returns (presumably a multi-domain
        dataset object indexable by domain — TODO confirm).

    Raises:
        KeyError: if ``D`` has no attribute named *dataset_name*.
    """
    dataset_cls = vars(D)[dataset_name]
    return dataset_cls(data_dir)
def load_dataset_of_a_domain(dataset_name, domain_index, data_dir):
    """Return the dataset of a single domain of *dataset_name*.

    Builds the full multi-domain dataset and then picks out one domain.

    Args:
        dataset_name: attribute name of the dataset constructor in ``D``.
        domain_index: index of the wanted domain within the dataset.
        data_dir: directory passed to the dataset constructor.

    Returns:
        The element at *domain_index* of the multi-domain dataset.

    Raises:
        KeyError: if ``D`` has no attribute named *dataset_name*.
    """
    all_domains = load_datasets_of_all_domains(dataset_name, data_dir)
    return all_domains[domain_index]
def load_online_data(dataset_name, data_config, data_dir):
    """Draw consecutive sample chunks from one or more domains and stack them.

    Each domain is read sequentially. A per-domain cursor starts at 0 and
    advances by ``n_samples`` every time that domain is requested, so
    repeated entries for the same domain yield disjoint, contiguous slices.

    Args:
        dataset_name: attribute name of the dataset constructor, resolved
            by ``load_datasets_of_all_domains``.
        data_config: iterable of ``(domain_index, n_samples)`` pairs,
            processed in order.
        data_dir: directory forwarded to the dataset constructor.

    Returns:
        ``(x, y)``: the inputs and labels of every requested chunk,
        concatenated along dim 0 in ``data_config`` order as freshly
        allocated tensors. Returns ``(None, None)`` if ``data_config`` is
        empty.
    """
    datasets = load_datasets_of_all_domains(dataset_name, data_dir)
    # One read cursor per domain.
    dataset_anchors = [0] * len(datasets.ENVIRONMENTS)
    x_chunks, y_chunks = [], []
    for domain_index, n_samples in data_config:
        dataset = datasets[domain_index]
        start = dataset_anchors[domain_index]
        # Assumes slicing a domain dataset yields an (inputs, labels) pair of
        # tensors (e.g. a TensorDataset) — TODO confirm for every dataset.
        # NOTE(review): a slice running past the end is silently truncated,
        # so a chunk may hold fewer than n_samples items.
        x, y = dataset[start:start + n_samples]
        dataset_anchors[domain_index] = start + n_samples
        x_chunks.append(x)
        y_chunks.append(y)
    if not x_chunks:
        return None, None
    # Concatenate once. Calling torch.cat inside the loop re-copied all
    # accumulated data on every iteration. It also returned a slice of the
    # dataset (not a copy) when only one chunk was requested.
    return torch.cat(x_chunks), torch.cat(y_chunks)