import os
import itertools
import argparse
import time
import datetime
import yaml
from contextlib import nullcontext

import torch
from torch import nn

import utils
from transformer import TransformerModel
from utils import get_cosine_schedule_with_warmup, get_openai_lr, StoreDictKeyPair, get_weighted_single_eval_pos_sampler, get_uniform_single_eval_pos_sampler
import priors
import encoders
import positional_encodings
from utils import init_dist
from torch.cuda.amp import autocast
# BarDistribution, FullSupportBarDistribution and get_bucket_limits are used by the bar-based
# losses further down; the module name below is assumed, adjust it to wherever they live in this repo.
from bar_distribution import BarDistribution, FullSupportBarDistribution, get_bucket_limits
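

# Commonly used unreduced losses (reduction='none'). Criteria passed to train() are expected to
# behave like these, returning per-sample losses so they can be masked and averaged manually
# inside train_step.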
class Losses():
    gaussian = nn.GaussianNLLLoss(full=True, reduction='none')
    mse = nn.MSELoss(reduction='none')
    ce = lambda weight: nn.CrossEntropyLoss(reduction='none', weight=weight)
    bce = nn.BCEWithLogitsLoss(reduction='none')
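

# train() builds a TransformerModel on top of the given prior dataloader class and criterion and
# runs the prior-fitting loop. It returns the final mean loss, the per-position losses, the trained
# model (moved to CPU) and the dataloader.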
def train(priordataloader_class, criterion, encoder_generator, emsize=200, nhid=200, nlayers=6, nhead=2, dropout=0.2,
          epochs=10, steps_per_epoch=100, batch_size=200, bptt=10, lr=None, weight_decay=0.0, warmup_epochs=10, input_normalization=False,
          y_encoder_generator=None, pos_encoder_generator=None, decoder=None, extra_prior_kwargs_dict={}, scheduler=get_cosine_schedule_with_warmup,
          load_weights_from_this_state_dict=None, validation_period=10, single_eval_pos_gen=None, bptt_extra_samples=None, gpu_device='cuda:0',
          aggregate_k_gradients=1, verbose=True, style_encoder_generator=None, check_is_compatible=True, epoch_callback=None,
          initializer=None, initialize_with_model=None, train_mixed_precision=False, total_available_time_in_s=None, normalize_labels=True, **model_extra_args
          ):
    assert (epochs is None) != (total_available_time_in_s is None)
    start_of_training = time.time()
    device = gpu_device if torch.cuda.is_available() else 'cpu:0'
    print(f'Using {device} device')
    using_dist, rank, device = init_dist(device)

    bptt_sampler = (lambda: single_eval_pos_gen() + bptt_extra_samples if callable(single_eval_pos_gen) else single_eval_pos_gen + bptt_extra_samples) if bptt_extra_samples is not None else bptt
    dl = priordataloader_class(num_steps=steps_per_epoch, batch_size=batch_size, seq_len=bptt_sampler, seq_len_maximum=bptt + (bptt_extra_samples if bptt_extra_samples else 0), device=device, **extra_prior_kwargs_dict)
    if dl.fuse_x_y:
        raise Exception("Illegal parameter")

    encoder = encoder_generator(dl.num_features + 1 if dl.fuse_x_y else dl.num_features, emsize)
    style_def = next(iter(dl))[0][0]  # a batch is ((style, x, y), target); grab the style part to size the style encoder
    style_encoder = style_encoder_generator(hyperparameter_definitions=style_def[0], em_size=emsize) if (style_def is not None) else None
    n_out = dl.num_outputs
    if isinstance(criterion, nn.GaussianNLLLoss):
        n_out *= 2
    elif isinstance(criterion, nn.CrossEntropyLoss):
        n_out *= criterion.weight.shape[0]
    model = TransformerModel(encoder, n_out, emsize, nhead, nhid, nlayers, dropout, style_encoder=style_encoder,
                             y_encoder=y_encoder_generator(dl.num_outputs, emsize), input_normalization=input_normalization,
                             pos_encoder=(pos_encoder_generator or positional_encodings.NoPositionalEncoding)(emsize, bptt * 2),
                             decoder=decoder, init_method=initializer, **model_extra_args
                             )
    model.criterion = criterion
    if load_weights_from_this_state_dict is not None:
        model.load_state_dict(load_weights_from_this_state_dict)
    if initialize_with_model is not None:
        model.init_from_small_model(initialize_with_model)

    print(f"Using a Transformer with {sum(p.numel() for p in model.parameters())/1000/1000:.2f} M parameters")

    try:
        for (k, v), (k2, v2) in zip(model.state_dict().items(), initialize_with_model.state_dict().items()):
            print(k, ((v - v2) / v).abs().mean(), v.shape)
    except Exception:
        pass

    model.to(device)
    if using_dist:
        print("Distributed training")
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[rank], output_device=rank, broadcast_buffers=False)

    # learning rate
    if lr is None:
        lr = get_openai_lr(model)
        print(f"Using OpenAI max lr of {lr}.")
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = scheduler(optimizer, warmup_epochs, epochs if epochs is not None else 100)  # when training for a fixed time, the lr schedule takes 100 steps
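
    # train_step runs one epoch over the prior dataloader: it draws synthetic batches, computes
    # the loss only on positions after `single_eval_pos`, masks out incompatible batch elements,
    # and applies gradient-accumulated, gradient-clipped AdamW updates.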
    def train_step():
        model.train()  # Turn on the train mode
        total_loss = 0.
        total_positional_losses = 0.
        total_positional_losses_recorded = 0
        before_get_batch = time.time()
        assert len(dl) % aggregate_k_gradients == 0, 'Please set the number of steps per epoch s.t. `aggregate_k_gradients` divides it.'
        valid_batch_steps = 0.0
        for batch, (data, targets) in enumerate(dl):
            if using_dist and not (batch % aggregate_k_gradients == aggregate_k_gradients - 1):
                cm = model.no_sync()
                #print(f'p={rank}, no_sync', force=True)
            else:
                cm = nullcontext()
                #print(f'p={rank}, sync', force=True)
            with cm:
                time_to_get_batch = time.time() - before_get_batch
                before_forward = time.time()

                if bptt_extra_samples is None:
                    single_eval_pos = single_eval_pos_gen() if callable(single_eval_pos_gen) else single_eval_pos_gen
                else:
                    single_eval_pos = targets.shape[0] - bptt_extra_samples

                is_compatible = torch.ones((targets.shape[1])).bool()
                if check_is_compatible or normalize_labels:
                    for b in range(targets.shape[1]):
                        targets_in_train = torch.unique(targets[:single_eval_pos, b], sorted=True)
                        targets_in_eval = torch.unique(targets[single_eval_pos:, b], sorted=True)
                        if check_is_compatible:
                            is_compatible[b] = len(targets_in_train) == len(targets_in_eval) and (targets_in_train == targets_in_eval).all()
                            is_compatible[b] = is_compatible[b] and len(targets_in_train) > 1

                        # Set targets to range starting from 0 (e.g. targets 0, 2, 5, 2 will be converted to 0, 1, 2, 1)
                        if normalize_labels:
                            targets[:, b] = (targets[:, b] > torch.unique(targets[:, b]).unsqueeze(1)).sum(axis=0).unsqueeze(0)
                valid_batch_steps += is_compatible.float().mean()
                is_compatible = is_compatible.to(device)

                #if using_dist and check_is_compatible:
                #    print('step share before reduce', curr_step_share, force=True)
                #    curr_step_share = curr_step_share.to(device)
                #    torch.distributed.all_reduce_multigpu([curr_step_share], op=torch.distributed.ReduceOp.SUM)
                #    curr_step_share = curr_step_share.cpu() / torch.distributed.get_world_size()
                #    print('step share after reduce', curr_step_share, torch.distributed.get_world_size(), force=True)

                # If style is set to None, it should not be transferred to device
                output = model(tuple(e.to(device) if torch.is_tensor(e) else e for e in data) if isinstance(data, tuple) else data.to(device),
                               single_eval_pos=single_eval_pos)
                forward_time = time.time() - before_forward

                #output, targets = output[:, is_compatible], targets[:, is_compatible]
                if single_eval_pos is not None:
                    targets = targets[single_eval_pos:]
                if isinstance(criterion, nn.GaussianNLLLoss):
                    assert output.shape[-1] == 2, \
                        'need to write a little bit of code to handle multiple regression targets at once'

                    mean_pred = output[..., 0]
                    var_pred = output[..., 1].abs()
                    losses = criterion(mean_pred.flatten(), targets.to(device).flatten(), var=var_pred.flatten())
                elif isinstance(criterion, (nn.MSELoss, nn.BCEWithLogitsLoss)):
                    losses = criterion(output.flatten(), targets.to(device).flatten())
                elif isinstance(criterion, nn.CrossEntropyLoss):
                    #print(n_out, targets.min(), targets.max(), force=True)
                    losses = criterion(output.reshape(-1, n_out), targets.to(device).long().flatten())
                else:
                    losses = criterion(output.reshape(-1, n_out), targets.to(device).flatten())
                losses = losses.view(*output.shape[0:2])
                loss = losses.mean(0) @ is_compatible.float() / losses.shape[1]
                #loss = torch_nanmean(losses, axis=[0, 1]) * is_compatible.float().mean()
                # not sure whether we can go without the nan checks.

                loss.backward()

                if ((batch % aggregate_k_gradients == aggregate_k_gradients - 1) and (not check_is_compatible or using_dist)) \
                        or (valid_batch_steps >= aggregate_k_gradients and (check_is_compatible and not using_dist)):
                    with torch.no_grad():
                        for p in model.parameters():
                            if p.grad is not None:
                                p.grad.div_(valid_batch_steps)
                    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.)
                    try:
                        optimizer.step()
                    except:
                        print("Invalid optimization step encountered")
                    optimizer.zero_grad()
                    valid_batch_steps = 0.0

                step_time = time.time() - before_forward

                if not torch.isnan(loss):
                    total_loss += loss.item()
                    total_positional_losses += losses.mean(1).cpu().detach() if single_eval_pos is None else \
                        nn.functional.one_hot(torch.tensor(single_eval_pos), bptt) * loss.cpu().detach()
                    total_positional_losses_recorded += torch.ones(bptt) if single_eval_pos is None else \
                        nn.functional.one_hot(torch.tensor(single_eval_pos), bptt)

            before_get_batch = time.time()
        return total_loss / steps_per_epoch, (
            total_positional_losses / total_positional_losses_recorded).tolist(), time_to_get_batch, forward_time, step_time

    best_val_loss = float("inf")
    best_model = None
    total_loss = float('inf')
    total_positional_losses = float('inf')
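
    # Main training loop: run for a fixed number of epochs, or until the wallclock budget
    # `total_available_time_in_s` is used up. A KeyboardInterrupt stops training early and the
    # current model is returned.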
    try:
        for epoch in (range(1, epochs + 1) if epochs is not None else itertools.count(1)):
            epoch_start_time = time.time()
            if train_mixed_precision:
                with autocast():
                    total_loss, total_positional_losses, time_to_get_batch, forward_time, step_time = train_step()
            else:
                total_loss, total_positional_losses, time_to_get_batch, forward_time, step_time = train_step()

            if hasattr(dl, 'validate') and epoch % validation_period == 0:
                with torch.no_grad():
                    val_score = dl.validate(model)
            else:
                val_score = None

            if verbose:
                print('-' * 89)
                print(
                    f'| end of epoch {epoch:3d} | time: {(time.time() - epoch_start_time):5.2f}s | mean loss {total_loss:5.2f} | '
                    f"pos losses {','.join([f'{l:5.2f}' for l in total_positional_losses])}, lr {scheduler.get_last_lr()[0]}"
                    f' data time {time_to_get_batch:5.2f} step time {step_time:5.2f}'
                    f' forward time {forward_time:5.2f}' + (f' val score {val_score}' if val_score is not None else ''))
                print('-' * 89)

            # stepping with wallclock time based scheduler
            current_time = time.time()
            if epoch_callback is not None and rank == 0:
                epoch_callback(model, epoch / epochs if total_available_time_in_s is None else  # noqa
                               (current_time - start_of_training) / total_available_time_in_s  # noqa
                               )
            if epochs is None and (current_time - start_of_training) > total_available_time_in_s:  # noqa
                break
            if epochs is None:
                scheduler.step((current_time - epoch_start_time) / total_available_time_in_s * 100)
            else:
                scheduler.step()
    except KeyboardInterrupt:
        pass

    return total_loss, total_positional_losses, model.to('cpu'), dl


def _parse_args(config_parser, parser):
    # Do we have a config file to parse?
    args_config, remaining = config_parser.parse_known_args()
    if args_config.config:
        with open(args_config.config, 'r') as f:
            cfg = yaml.safe_load(f)
            parser.set_defaults(**cfg)

    # The main arg parser parses the rest of the args, the usual
    # defaults will have been overridden if config file specified.
    args = parser.parse_args(remaining)

    # Cache the args as a text string to save them in the output dir later
    args_text = yaml.safe_dump(args.__dict__, default_flow_style=False)
    return args, args_text
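

# Script entry point: parse CLI/config arguments, resolve the prior dataloader, criterion and
# encoder generators they name, then hand everything over to train().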
if __name__ == '__main__':
    config_parser = argparse.ArgumentParser(description='Only used as a first parser for the config file path.')
    config_parser.add_argument('--config')
    parser = argparse.ArgumentParser()
    parser.add_argument('prior')
    parser.add_argument('--loss_function', default='barnll')
    # Optional args for `--loss_function barnll`
    parser.add_argument('--min_y', type=float, help='barnll can only model y in strict ranges, this is the minimum y can take.')
    parser.add_argument('--max_y', type=float, help='barnll can only model y in strict ranges, this is the maximum y can take.')
    parser.add_argument('--num_buckets', default=100, type=int)
    #parser.add_argument('--num_features', default=None, type=int, help='Specify depending on the prior.')
    parser.add_argument("--extra_prior_kwargs_dict", default={'fuse_x_y': False}, dest="extra_prior_kwargs_dict", action=StoreDictKeyPair, nargs="+", metavar="KEY=VAL", help='Specify depending on the prior.')
    parser.add_argument('--encoder', default='linear', type=str, help='Specify depending on the prior.')
    parser.add_argument('--y_encoder', default='linear', type=str, help='Specify depending on the prior. You should specify this if you do not fuse x and y.')
    parser.add_argument('--pos_encoder', default='sinus', type=str, help='Specify depending on the prior.')
    parser.add_argument('--bptt', default=10, type=int)
    parser.add_argument('--epochs', default=200, type=int)
    parser.add_argument('--warmup_epochs', default=50, type=int)
    parser.add_argument('--validation_period', default=10, type=int)
    parser.add_argument('--permutation_invariant_max_eval_pos', default=None, type=int, help='Set this to an int to sample the evaluation position (single_eval_pos) up to this value for every batch; see --permutation_invariant_sampling.')
    parser.add_argument('--permutation_invariant_sampling', default='weighted', help="Only relevant if --permutation_invariant_max_eval_pos is set.")

    # these can likely be mostly left at defaults
    parser.add_argument('--emsize', default=512, type=int)  # sometimes even larger is better, e.g. 1024
    parser.add_argument('--nlayers', default=6, type=int)
    parser.add_argument('--nhid', default=None, type=int)  # 2*emsize is the default
    parser.add_argument('--nhead', default=4, type=int)  # nhead = emsize / 64 in the original paper
    parser.add_argument('--dropout', default=.0, type=float)
    parser.add_argument('--steps_per_epoch', default=10, type=int)
    parser.add_argument('--batch_size', default=1000, type=int)
    parser.add_argument('--lr', '--learning_rate', default=.001, type=float)  # try also .0003, .0001, go lower with lower batch size

    args, _ = _parse_args(config_parser, parser)

    if args.nhid is None:
        args.nhid = 2 * args.emsize

    prior = args.__dict__.pop('prior')
    if prior == 'gp':
        prior = priors.fast_gp.DataLoader
    elif prior == 'ridge':
        prior = priors.ridge.DataLoader
    elif prior == 'stroke':
        prior = priors.stroke.DataLoader
    elif prior == 'mix_gp':
        prior = priors.fast_gp_mix.DataLoader
    else:
        raise NotImplementedError(f'Prior == {prior}.')

    loss_function = args.__dict__.pop('loss_function')

    criterion = nn.GaussianNLLLoss(reduction='none', full=True)
    classification_criterion = nn.CrossEntropyLoss(reduction='none')
    num_buckets = args.__dict__.pop('num_buckets')
    max_y = args.__dict__.pop('max_y')
    min_y = args.__dict__.pop('min_y')
    # criterion = nn.MSELoss(reduction='none')

    # `device` is needed by get_y_sample below; mirror the device selection used inside train().
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu:0'

    def get_y_sample():
        dl = prior(num_steps=1, batch_size=args.batch_size * args.steps_per_epoch, seq_len=args.bptt, device=device,
                   **args.extra_prior_kwargs_dict)
        y_sample = next(iter(dl))[-1]
        print(f'Creating Bar distribution with borders from y sample of size {y_sample.numel()}')
        return y_sample

    if loss_function == 'ce':
        criterion = nn.CrossEntropyLoss(reduction='none')
    elif loss_function == 'gaussnll':
        criterion = nn.GaussianNLLLoss(reduction='none', full=True)
    elif loss_function == 'mse':
        criterion = nn.MSELoss(reduction='none')
    elif loss_function == 'barnll':
        criterion = BarDistribution(borders=get_bucket_limits(num_buckets, full_range=(min_y, max_y)))
    elif loss_function == 'adaptivebarnll':
        borders = get_bucket_limits(num_buckets, ys=get_y_sample(), full_range=(min_y, max_y))
        criterion = BarDistribution(borders=borders)
    elif loss_function == 'adaptivefullsupportbarnll':
        assert min_y is None and max_y is None, "Please do not specify `min_y` and `max_y` with `adaptivefullsupportbarnll`."
        borders = get_bucket_limits(num_buckets, ys=get_y_sample())
        criterion = FullSupportBarDistribution(borders=borders)
    else:
        raise NotImplementedError(f'loss_function == {loss_function}.')

    encoder = args.__dict__.pop('encoder')
    y_encoder = args.__dict__.pop('y_encoder')

    def get_encoder_generator(encoder):
        if encoder == 'linear':
            encoder_generator = encoders.Linear
        elif encoder == 'mlp':
            encoder_generator = encoders.MLP
        elif encoder == 'positional':
            encoder_generator = encoders.Positional
        else:
            raise NotImplementedError(f'A {encoder} encoder is not valid.')
        return encoder_generator

    encoder_generator = get_encoder_generator(encoder)
    y_encoder_generator = get_encoder_generator(y_encoder)

    pos_encoder = args.__dict__.pop('pos_encoder')
    if pos_encoder == 'none':
        pos_encoder_generator = None
    elif pos_encoder == 'sinus':
        pos_encoder_generator = positional_encodings.PositionalEncoding
    elif pos_encoder == 'learned':
        pos_encoder_generator = positional_encodings.LearnedPositionalEncoding
    elif pos_encoder == 'paired_scrambled_learned':
        pos_encoder_generator = positional_encodings.PairedScrambledPositionalEncodings
    else:
        raise NotImplementedError(f'pos_encoder == {pos_encoder} is not valid.')

    permutation_invariant_max_eval_pos = args.__dict__.pop('permutation_invariant_max_eval_pos')
    permutation_invariant_sampling = args.__dict__.pop('permutation_invariant_sampling')
    if permutation_invariant_max_eval_pos is not None:
        if permutation_invariant_sampling == 'weighted':
            get_sampler = get_weighted_single_eval_pos_sampler
        elif permutation_invariant_sampling == 'uniform':
            get_sampler = get_uniform_single_eval_pos_sampler
        else:
            raise ValueError()
        args.__dict__['single_eval_pos_gen'] = get_sampler(permutation_invariant_max_eval_pos)

    print("ARGS for `train`:", args.__dict__)

    train(prior, criterion, encoder_generator,
          y_encoder_generator=y_encoder_generator, pos_encoder_generator=pos_encoder_generator,
          **args.__dict__)
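
# Example invocation (illustrative only; the prior, encoder and loss names must match the options
# registered above, and some priors may need extra kwargs via --extra_prior_kwargs_dict):
#   python train.py gp --loss_function gaussnll --encoder linear --y_encoder linear \
#       --pos_encoder none --bptt 10 --epochs 50 --steps_per_epoch 100 --batch_size 100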