import itertools
from typing import List

import torch

from .utils import compute_time_delta


class PriorsDataset:
    """Wrapper that adds (currently disabled) prior-study lookup to a dataset.

    The wrapped ``dataset`` must support ``len()``, integer/list indexing and
    column access via ``dataset['study_id']``.
    NOTE(review): presumably a Hugging Face ``datasets.Dataset`` -- confirm.

    Attributes that are not defined on the wrapper are forwarded to the
    wrapped dataset (see ``__getattr__``).
    """

    def __init__(self, dataset, history, time_delta_map):
        """Store the wrapped dataset and precompute lookup helpers.

        Args:
            dataset: The underlying dataset; its ``'study_id'`` column is read
                once to build ``study_id_to_index``.
            history: Truthy to request prior studies. Indexing then raises
                ``NotImplementedError`` (see ``__getitem__``).
            time_delta_map: Callable applied to a time delta. It is invoked
                once here with ``float('inf')`` and the result is cached in
                ``inf_time_delta_value``.
        """
        self.dataset = dataset
        self.history = history
        # Maps each study id to its row position in ``dataset``. If an id
        # occurs more than once, the last position wins (dict semantics).
        self.study_id_to_index = dict(zip(dataset['study_id'], range(len(dataset))))
        self.time_delta_map = time_delta_map
        self.inf_time_delta_value = time_delta_map(float('inf'))

    def __getitem__(self, idx):
        """Return ``self.dataset[idx]``.

        Raises:
            NotImplementedError: If ``self.history`` is truthy.
        """
        batch = self.dataset[idx]
        if self.history:
            raise NotImplementedError("Priors were made not available in the public release.")
        return batch

    def __len__(self):
        """Return the number of examples in the wrapped dataset."""
        return len(self.dataset)

    def __getattr__(self, name):
        """Forward lookups of unknown attributes to the wrapped dataset.

        Python only calls ``__getattr__`` after normal attribute lookup has
        failed, so attributes set in ``__init__`` never reach this method.

        Raises:
            AttributeError: If ``name`` is ``'dataset'`` or a dunder name, or
                if the wrapped dataset does not have the attribute either.
        """
        # Reaching here with name == 'dataset' means the instance has no
        # ``dataset`` yet (created via pickle/copy/__new__ without __init__).
        # Delegating would evaluate ``self.dataset`` again and recurse until
        # RecursionError, which breaks unpickling (e.g. in worker processes).
        #
        # Dunder lookups (``__setstate__``, ``__getstate__``, ``__deepcopy__``,
        # ...) are protocol probes aimed at *this* object; answering them with
        # the wrapped dataset's hooks would make copy/pickle operate on the
        # wrong object, so they are refused as well.
        if name == "dataset" or (name.startswith("__") and name.endswith("__")):
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            )
        return getattr(self.dataset, name)

    def __getitems__(self, keys: List):
        """Fetch several examples at once and return them as a list of rows.

        ``self[keys]`` is expected to return a column-oriented mapping
        (column name -> sequence of values); it is transposed here into one
        ``{column: value}`` dict per example.
        NOTE(review): recent ``torch.utils.data`` fetchers call
        ``__getitems__`` for batched reads when a dataset defines it --
        confirm against the installed torch version.

        Args:
            keys: The indices to fetch, passed unchanged to ``__getitem__``.

        Returns:
            A list with one dict per fetched example.
        """
        batch = self.__getitem__(keys)
        # Every column is assumed to have the same length, so the first
        # column determines the number of examples.
        n_examples = len(batch[next(iter(batch))])
        return [{col: array[i] for col, array in batch.items()} for i in range(n_examples)]