import argparse
import io
import logging
import os
import random
from functools import partial
from pathlib import Path

import numpy as np
import torch
from fastai.callbacks.general_sched import GeneralScheduler, TrainingPhase
from fastai.distributed import *
from fastai.vision import *
from torch.backends import cudnn

from callbacks import DumpPrediction, IterationCallback, TextAccuracy, TopKTextAccuracy
from dataset import ImageDataset, TextDataset
from losses import MultiLosses
from utils import Config, Logger, MyDataParallel, MyConcatDataset


def _set_random_seed(seed):
    if seed is not None:
        random.seed(seed)
        torch.manual_seed(seed)
        cudnn.deterministic = True
        logging.warning('You have chosen to seed training. '
                        'This will slow down your training!')

def _get_training_phases(config, n):
    lr = np.array(config.optimizer_lr)
    periods = config.optimizer_scheduler_periods
    sigma = [config.optimizer_scheduler_gamma ** i for i in range(len(periods))]
    phases = [TrainingPhase(n * periods[i]).schedule_hp('lr', lr * sigma[i])
                for i in range(len(periods))]
    return phases

def _get_dataset(ds_type, paths, is_training, config, **kwargs):
    kwargs.update({
        'img_h': config.dataset_image_height,
        'img_w': config.dataset_image_width,
        'max_length': config.dataset_max_length,
        'case_sensitive': config.dataset_case_sensitive,
        'charset_path': config.dataset_charset_path,
        'data_aug': config.dataset_data_aug,
        'deteriorate_ratio': config.dataset_deteriorate_ratio,
        'is_training': is_training,
        'multiscales': config.dataset_multiscales,
        'one_hot_y': config.dataset_one_hot_y,
    })
    datasets = [ds_type(p, **kwargs) for p in paths]
    if len(datasets) > 1: return MyConcatDataset(datasets)
    else: return datasets[0]


def _get_language_databaunch(config):
    kwargs = {
        'max_length': config.dataset_max_length,
        'case_sensitive': config.dataset_case_sensitive,
        'charset_path': config.dataset_charset_path,
        'smooth_label': config.dataset_smooth_label,
        'smooth_factor': config.dataset_smooth_factor,
        'one_hot_y': config.dataset_one_hot_y,
        'use_sm': config.dataset_use_sm,
    }
    train_ds = TextDataset(config.dataset_train_roots[0], is_training=True, **kwargs)
    valid_ds = TextDataset(config.dataset_test_roots[0], is_training=False, **kwargs)
    data = DataBunch.create(
        path=train_ds.path,
        train_ds=train_ds,
        valid_ds=valid_ds,
        bs=config.dataset_train_batch_size,
        val_bs=config.dataset_test_batch_size,
        num_workers=config.dataset_num_workers,
        pin_memory=config.dataset_pin_memory)
    logging.info(f'{len(data.train_ds)} training items found.')
    if not data.empty_val:
        logging.info(f'{len(data.valid_ds)} valid items found.')
    return data

def _get_databaunch(config):
    # An awkward way to reduce data-loading time during the test phase
    if config.global_phase == 'test': config.dataset_train_roots = config.dataset_test_roots
    train_ds = _get_dataset(ImageDataset, config.dataset_train_roots, True, config)
    valid_ds = _get_dataset(ImageDataset, config.dataset_test_roots, False, config)
    data = ImageDataBunch.create(
        train_ds=train_ds,
        valid_ds=valid_ds,
        bs=config.dataset_train_batch_size,
        val_bs=config.dataset_test_batch_size,
        num_workers=config.dataset_num_workers,
        pin_memory=config.dataset_pin_memory).normalize(imagenet_stats)
    ar_tfm = lambda x: ((x[0], x[1]), x[1])  # auto-regression only for dtd
    data.add_tfm(ar_tfm)

    logging.info(f'{len(data.train_ds)} training items found.')
    if not data.empty_val:
        logging.info(f'{len(data.valid_ds)} valid items found.')
    
    return data

def _get_model(config):
    import importlib
    names = config.model_name.split('.')
    module_name, class_name = '.'.join(names[:-1]), names[-1]
    cls = getattr(importlib.import_module(module_name), class_name)
    model = cls(config)
    logging.info(model)
    return model


def _get_learner(config, data, model, local_rank=None):
    strict = ifnone(config.model_strict, True)
    if config.global_stage == 'pretrain-language':
        metrics = [TopKTextAccuracy(
            k=ifnone(config.model_k, 5),
            charset_path=config.dataset_charset_path,
            max_length=config.dataset_max_length + 1,
            case_sensitive=config.dataset_eval_case_sensisitves,
            model_eval=config.model_eval)] 
    else:
        metrics = [TextAccuracy(
            charset_path=config.dataset_charset_path,
            max_length=config.dataset_max_length + 1,
            case_sensitive=config.dataset_eval_case_sensisitves,
            model_eval=config.model_eval)]
    opt_type = getattr(torch.optim, config.optimizer_type)
    learner = Learner(data, model, silent=True, model_dir='.',
                      true_wd=config.optimizer_true_wd,
                      wd=config.optimizer_wd,
                      bn_wd=config.optimizer_bn_wd,
                      path=config.global_workdir,
                      metrics=metrics,
                      opt_func=partial(opt_type, **(config.optimizer_args or dict())),
                      loss_func=MultiLosses(one_hot=config.dataset_one_hot_y))
    learner.split(lambda m: children(m))

    if config.global_phase == 'train':
        num_replicas = 1 if local_rank is None else torch.distributed.get_world_size()
        phases = _get_training_phases(config, len(learner.data.train_dl)//num_replicas)
        learner.callback_fns += [
            partial(GeneralScheduler, phases=phases),
            partial(GradientClipping, clip=config.optimizer_clip_grad),
            partial(IterationCallback, name=config.global_name,
                    show_iters=config.training_show_iters,
                    eval_iters=config.training_eval_iters,
                    save_iters=config.training_save_iters,
                    start_iters=config.training_start_iters,
                    stats_iters=config.training_stats_iters)]
    else:
        learner.callbacks += [
            DumpPrediction(learn=learner,
                           dataset='-'.join([Path(p).name for p in config.dataset_test_roots]),
                           charset_path=config.dataset_charset_path,
                           model_eval=config.model_eval,
                           debug=config.global_debug,
                           image_only=config.global_image_only)]

    learner.rank = local_rank
    if local_rank is not None:
        logging.info(f'Set model to distributed with rank {local_rank}.')
        learner.model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(learner.model)
        learner.model.to(local_rank)
        learner = learner.to_distributed(local_rank)

    if torch.cuda.device_count() > 1 and local_rank is None:
        logging.info(f'Use {torch.cuda.device_count()} GPUs.')
        learner.model = MyDataParallel(learner.model)

    if config.model_checkpoint:
        if Path(config.model_checkpoint).exists():
            with open(config.model_checkpoint, 'rb') as f:
                buffer = io.BytesIO(f.read())
            learner.load(buffer, strict=strict)
        else:
            from distutils.dir_util import copy_tree
            src = Path('/data/fangsc/model')/config.global_name
            trg = Path('/output')/config.global_name
            if src.exists(): copy_tree(str(src), str(trg))
            learner.load(config.model_checkpoint, strict=strict)
        logging.info(f'Read model from {config.model_checkpoint}')
    elif config.global_phase == 'test':
        learner.load(f'best-{config.global_name}', strict=strict)
        logging.info(f'Read model from best-{config.global_name}')

    if learner.opt_func.func.__name__ == 'Adadelta':    # work around a fastai bug (fixed after 1.0.60)
        learner.fit(epochs=0, lr=config.optimizer_lr)
        learner.opt.mom = 0.

    return learner

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, required=True,
                        help='path to config file')
    parser.add_argument('--phase', type=str, default=None, choices=['train', 'test'])
    parser.add_argument('--name', type=str, default=None)
    parser.add_argument('--checkpoint', type=str, default=None)
    parser.add_argument('--test_root', type=str, default=None)
    parser.add_argument("--local_rank", type=int, default=None)
    parser.add_argument('--debug', action='store_true', default=None)
    parser.add_argument('--image_only', action='store_true', default=None)
    parser.add_argument('--model_strict', action='store_false', default=None)
    parser.add_argument('--model_eval', type=str, default=None, 
                        choices=['alignment', 'vision', 'language'])
    args = parser.parse_args()
    config = Config(args.config)
    if args.name is not None: config.global_name = args.name
    if args.phase is not None: config.global_phase = args.phase
    if args.test_root is not None: config.dataset_test_roots = [args.test_root]
    if args.checkpoint is not None: config.model_checkpoint = args.checkpoint
    if args.debug is not None: config.global_debug = args.debug
    if args.image_only is not None: config.global_image_only = args.image_only
    if args.model_eval is not None: config.model_eval = args.model_eval
    if args.model_strict is not None: config.model_strict = args.model_strict

    Logger.init(config.global_workdir, config.global_name, config.global_phase)
    Logger.enable_file()
    _set_random_seed(config.global_seed)
    logging.info(config)

    if args.local_rank is not None:
        logging.info(f'Initialize distributed training on device {args.local_rank}.')
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl', init_method='env://')

    logging.info('Construct dataset.')
    if config.global_stage == 'pretrain-language': data = _get_language_databaunch(config)
    else: data = _get_databaunch(config)

    logging.info('Construct model.')
    model = _get_model(config)

    logging.info('Construct learner.')
    learner = _get_learner(config, data, model, args.local_rank)

    if config.global_phase == 'train':
        logging.info('Start training.')
        learner.fit(epochs=config.training_epochs,
                    lr=config.optimizer_lr)
    else:
        logging.info('Start validation.')
        last_metrics = learner.validate()
        log_str = f'eval loss = {last_metrics[0]:6.3f},  ' \
                  f'ccr = {last_metrics[1]:6.3f},  cwr = {last_metrics[2]:6.3f},  ' \
                  f'ted = {last_metrics[3]:6.3f},  ned = {last_metrics[4]:6.0f},  ' \
                  f'ted/w = {last_metrics[5]:6.3f}'
        logging.info(log_str)

if __name__ == '__main__':
    main()