import math
import multiprocessing
import os
from datetime import timedelta
from functools import partial
from itertools import chain

import torch
# import bitsandbytes as bnb

from torch.distributed.fsdp import (
    FullyShardedDataParallel,
    MixedPrecision,
    BackwardPrefetch,
    ShardingStrategy,
)
from accelerate import Accelerator
from accelerate.utils import (DummyOptim, InitProcessGroupKwargs)
from accelerate.logging import get_logger


from datasets import load_dataset
from lion_pytorch import Lion
from torch.nn import LayerNorm


from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointImpl, apply_activation_checkpointing, checkpoint_wrapper)
from torch.distributed.fsdp.wrap import (
    transformer_auto_wrap_policy
)


from torch.optim import AdamW
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import (AutoTokenizer, default_data_collator,
                          get_cosine_schedule_with_warmup,
                          get_linear_schedule_with_warmup, set_seed)


from Andromeda.utils.stable_adamw import StableAdamWUnfused
from Andromeda.core.transformer import Transformer, AndromedaEmbedding
# from Andromeda.model import Andromeda, AndromedaEmbedding
from Andromeda.configs import Andromeda1Billion

########### SETUP CONFIG
import torch.distributed as dist


from accelerate.state import AcceleratorState

# state = AcceleratorState()


logger = get_logger(__name__, log_level="INFO")

class CFG:
    BATCH_SIZE = 1
    GRADIENT_ACCUMULATE_EVERY: int = 1
    SEED: int = 42
    LEARNING_RATE: float = 1e-4  # 1e-4 for Lion; 3e-4 otherwise
    WEIGHT_DECAY: float = 0.1
    SEQ_LEN: int = 8192
    NUM_CPU: int = multiprocessing.cpu_count()
    USE_DEEPSPEED: bool = True
    USE_FSDP: bool = True
    USE_PRETOKENIZED: bool = True
    USE_ACTIVATION_CHECKPOINTING: bool = True
    RESUME_FROM_CHECKPOINT: str = ""  # path to a checkpoint directory; "" disables resuming
    CHECKPOINTING_STEPS: int = 1000
    OUTPUT_DIR: str = 'checkpoints/' # Folder
    ENTITY_NAME: str = "Andromeda"
    LOGGING_STEPS: int = 100
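
    # Tokens processed per optimizer step per process are the product of the
    # settings above: BATCH_SIZE * GRADIENT_ACCUMULATE_EVERY * SEQ_LEN
    # (1 * 1 * 8192 = 8192 tokens with these defaults).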


# helpers


def print_num_params(model, accelerator: Accelerator):
    n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    accelerator.print(f"Number of parameters in model: {n_params}")


# activation checkpointing


def activation_checkpointing(
    model: torch.nn.Module,
    offload_to_cpu: bool = False,
    accelerator: Accelerator = None,
):
    """
    Apply activation checkpointing to a model.

    Args:
        model (Module): The model to which to apply activation checkpointing.
        offload_to_cpu (bool, optional): Whether to offload the activations to CPU. Defaults to False.
        accelerator (Accelerator, optional): The Accelerate library accelerator. Defaults to None.
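
    Example:
        A minimal sketch, assuming the model contains Transformer blocks:

            activation_checkpointing(model, offload_to_cpu=False, accelerator=accelerator)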
    """
    if accelerator is not None:
        accelerator.print("Using activation checkpointing")
    def check_fn(submodule):
        return isinstance(submodule, Transformer)
    non_reentrant_wrapper = partial(
        checkpoint_wrapper,
        offload_to_cpu=offload_to_cpu,
        checkpoint_impl=CheckpointImpl.NO_REENTRANT,
    )
    apply_activation_checkpointing(
        model, checkpoint_wrapper_fn=non_reentrant_wrapper, check_fn=check_fn
    )


# FSDP


def fsdp(
    model: torch.nn.Module,
    auto_wrap: bool = False,
    mp: str = "fp32",
    shard_strat: str = "NO_SHARD",
):
    """
    This function wraps a given PyTorch model with the FullyShardedDataParallel (FSDP) wrapper to enable efficient data parallelism and model sharding.

    Args:
        model (torch.nn.Module): The original PyTorch model to be wrapped with FSDP.
        auto_wrap (bool, optional): If True, it enables automatic wrapping of the model's layers according to the transformer_auto_wrap_policy. Default is False.
        mp (str, optional): The mixed precision mode to be used. Can be 'bf16' for BFloat16, 'fp16' for Float16 or 'fp32' for Float32 precision. Default is 'fp32'.
        shard_strat (str, optional): The sharding strategy to be used. Can be 'SHARD_GRAD' for sharding at gradient computation, 'FULL_SHARD' for full model sharding or 'NO_SHARD' for no sharding. Default is 'NO_SHARD'.

    Raises:
        ValueError: If the provided mp (mixed precision mode) is not 'bf16', 'fp16' or 'fp32'.
        ValueError: If the provided shard_strat (sharding strategy) is not 'SHARD_GRAD', 'FULL_SHARD' or 'NO_SHARD'.

    Returns:
        torch.nn.Module: The input model wrapped with FSDP.
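
    Example:
        A minimal sketch, assuming a torch.distributed process group is already initialized:

            model = fsdp(model, auto_wrap=True, mp="bf16", shard_strat="FULL_SHARD")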
    """
    if auto_wrap:
        Andromeda_auto_wrap_policy = partial(
            transformer_auto_wrap_policy,
            transformer_layer_cls={
                Transformer,
            },
        )
    else:
        Andromeda_auto_wrap_policy = None

    if mp == "bf16":
        mp_fsdp = MixedPrecision(
            param_dtype=torch.bfloat16,
            # Gradient communication precision.
            reduce_dtype=torch.bfloat16,
            # Buffer precision.
            buffer_dtype=torch.bfloat16,
        )
    elif mp == "fp16":
        mp_fsdp = MixedPrecision(
            param_dtype=torch.float16,
            # Gradient communication precision.
            reduce_dtype=torch.float16,
            # Buffer precision.
            buffer_dtype=torch.float16,
        )
    elif mp == "fp32":
        mp_fsdp = MixedPrecision(
            param_dtype=torch.float32,
            # Gradient communication precision.
            reduce_dtype=torch.float32,
            # Buffer precision.
            buffer_dtype=torch.float32,
        )
    else:
        raise ValueError(
            "Invalid mp. Expected 'bf16', 'fp16' or 'fp32', got: {}".format(mp)
        )

    if shard_strat == "SHARD_GRAD":
        sharding_strat_fsdp = ShardingStrategy.SHARD_GRAD_OP 
    elif shard_strat == "FULL_SHARD":
        sharding_strat_fsdp = ShardingStrategy.FULL_SHARD
    elif shard_strat == "NO_SHARD":
        sharding_strat_fsdp = ShardingStrategy.NO_SHARD
    else:
        raise ValueError(
            "Invalid shard_strat. Expected 'SHARD_GRAD', 'FULL_SHARD' or 'NO_SHARD', got: {}".format(
                shard_strat
            )
        )

    model = FullyShardedDataParallel(
        model,
        auto_wrap_policy=Andromeda_auto_wrap_policy,
        mixed_precision=mp_fsdp,
        backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
        sharding_strategy=sharding_strat_fsdp,
        forward_prefetch=True,
        use_orig_params=True,
    )

    return model


# learning rate scheduler


def get_lr_scheduler_with_warmup(
    optimizer: torch.optim.Optimizer,
    scheduler_type: str,
    num_warmup_steps: int,
    max_train_steps: int,
    grad_accumulate_every: int = 1,
    accelerator: Accelerator = None,
):
    """
    Get a learning rate scheduler with warmup.

    Args:
        optimizer (Optimizer): The optimizer for which to create the learning rate scheduler.
        scheduler_type (str): The type of learning rate scheduler to create, either "linear" or "cosine".
        num_warmup_steps (int): The number of warmup steps for the learning rate scheduler.
        max_train_steps (int): The maximum number of training steps.
        grad_accumulate_every (int, optional): The gradient accumulation factor. Defaults to 1.
        accelerator (Accelerator, optional): The Accelerate library accelerator. Defaults to None.

    Returns:
        The learning rate scheduler with warmup.

    Raises:
        ValueError: If scheduler_type is not "linear" or "cosine".
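
    Example:
        A sketch with illustrative step counts:

            lr_scheduler = get_lr_scheduler_with_warmup(
                optimizer=optim,
                scheduler_type="cosine",
                num_warmup_steps=100,
                max_train_steps=10_000,
            )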
    """
    NUM_WARMUP_STEPS = num_warmup_steps
    GRADIENT_ACCUMULATE_EVERY = grad_accumulate_every
    if accelerator is not None:
        accelerator.print(f"Using {scheduler_type} lr scheduler")
    if scheduler_type == "linear":
        return get_linear_schedule_with_warmup(
            optimizer=optimizer,
            num_warmup_steps=NUM_WARMUP_STEPS * GRADIENT_ACCUMULATE_EVERY,
            num_training_steps=max_train_steps * GRADIENT_ACCUMULATE_EVERY,
        )
    elif scheduler_type == "cosine":
        return get_cosine_schedule_with_warmup(
            optimizer=optimizer,
            num_warmup_steps=NUM_WARMUP_STEPS * GRADIENT_ACCUMULATE_EVERY,
            num_training_steps=max_train_steps * GRADIENT_ACCUMULATE_EVERY,
        )
    else:
        raise ValueError(
            "Invalid scheduler_type. Expected 'linear' or 'cosine', got: {}".format(
                scheduler_type
            )
        )


# optimizers


def decoupled_optimizer(
    model: torch.nn.Module,
    learning_rate: float,
    weight_decay: float,
    beta_1: float,
    beta_2: float,
    optimizer_type: str,
    use_fsdp: bool = True,
    accelerator: Accelerator = None,
):
    """
    Decouples the optimizer from the training process.

    This function sets up the optimizer for the model by creating two groups of parameters:
    one for weight decay and one without weight decay. Then, it initializes the optimizer
    with these two groups of parameters.

    Args:
        model (Module): The model whose parameters are optimized.
        learning_rate (float): The learning rate for the optimizer.
        weight_decay (float): The weight decay for the optimizer.
        beta_1 (float): The exponential decay rate for the 1st moment estimates.
        beta_2 (float): The exponential decay rate for the 2nd moment estimates.
        optimizer_type (str): The type of the optimizer. Can be 'lion', 'adamw', or 'stable_adamw'.
        use_fsdp (bool, optional): If True, the optimizer will work with fully sharded data parallelism. Defaults to True.
        accelerator (Accelerator, optional): The accelerator from HuggingFace's Accelerate library. Defaults to None.

    Returns:
        Optimizer: The initialized optimizer.

    Raises:
        ValueError: If the optimizer type is not 'lion', 'adamw' or 'stable_adamw'.
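
    Example:
        A sketch using the Lion optimizer (the hyperparameters below are illustrative):

            optim = decoupled_optimizer(
                model=model,
                learning_rate=1e-4,
                weight_decay=0.1,
                beta_1=0.9,
                beta_2=0.95,
                optimizer_type="lion",
                use_fsdp=False,
                accelerator=None,
            )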
    """
    accelerator.print(f"Using {optimizer_type} optimizer")
    # Create an empty dictionary called param_dict to store the model's named parameters.
    param_dict = {}
    # Iterate over the model's named parameters and populate the param_dict with key-value pairs.
    for param_name, param in model.named_parameters():
        param_dict[param_name] = param

    # Separate the model's named modules into two groups: decay and no_decay.

    # Create an empty list to store the names of the LayerNorm and Embedding layer weights with no weight decay.
    no_decay = []

    if use_fsdp:
        exclude_module = "_fsdp_wrapped_module.token_emb"
    else:
        exclude_module = "token_emb"

    # Iterate through the named modules of the model.
    for module_name, module in model.named_modules():
        # Check if the current module is a LayerNorm or an Embedding layer.
        for ndim in [LayerNorm, torch.nn.Embedding]:
            if isinstance(module, ndim):
                if module_name == exclude_module:
                    # Token embedding: exclude its ".weight" from weight decay.
                    no_decay.append(f"{module_name}.weight")
                else:
                    # LayerNorm: exclude its ".gamma" scale parameter.
                    no_decay.append(f"{module_name}.gamma")
                # Exit the inner loop since the desired module has been found.
                break

    # Create an empty list to store the names of the Linear layer weights with weight decay.
    decay = []

    # Iterate through the named modules of the model.
    for module_name, module in model.named_modules():
        # Check if the current module is an instance of the desired type (torch.nn.Linear).
        for ndim in [torch.nn.Linear]:
            if isinstance(module, ndim):
                # If the module is an instance of torch.nn.Linear, append its name with a ".weight" suffix to the decay list.
                decay.append(f"{module_name}.weight")
                # Exit the inner loop since the desired module has been found.
                break

    # Create two separate lists of model parameters: decay_param and no_decay_param.
    # The decay_param list contains the parameters that should have weight decay applied.
    # The no_decay_param list contains the parameters that should not have weight decay applied, excluding the 'to_logits.weight' parameter.

    # Create an empty list called decay_param to store the parameters with weight decay.
    decay_param = []

    if use_fsdp:
        exclude_param = "_fsdp_wrapped_module.to_logits.weight"
    else:
        exclude_param = "to_logits.weight"

    # Iterate over the decay list and append the corresponding parameters from
    # param_dict, skipping the excluded output projection weight.
    for param in decay:
        if param != exclude_param:
            decay_param.append(param_dict[param])

    # Create an empty list called no_decay_param to store the parameters without weight decay.
    no_decay_param = []

    # Iterate over the no_decay list, which contains the names of the parameters without weight decay.
    for param in no_decay:
        try:
            # Append the corresponding parameter from param_dict to the no_decay_param list.
            no_decay_param.append(param_dict[param])
        except KeyError:
            # Skip names that do not exist in the model (e.g. ".gamma" when a
            # standard torch.nn.LayerNorm names its scale ".weight").
            pass

    # Create a list called grouped_params that contains two dictionaries.
    # The first dictionary has the decay_param list and the corresponding weight_decay value.
    # The second dictionary has the no_decay_param list and a weight_decay value of 0.0.
    grouped_params = [
        {"params": decay_param, "weight_decay": weight_decay},
        {"params": no_decay_param, "weight_decay": 0.0},
    ]

    # Create a variable called optimizer that stores an instance of the optimizer.
    if optimizer_type == "lion":
        optimizer = Lion(grouped_params, lr=learning_rate, betas=(beta_1, beta_2),)
    elif optimizer_type == "adamw":
        optimizer = AdamW(grouped_params, lr=learning_rate, betas=(beta_1, beta_2),)
    elif optimizer_type == "deepspeed":
        optimizer = DummyOptim(grouped_params, lr=learning_rate, betas=(beta_1, beta_2),)
    elif optimizer_type == "stable_adamw":
        optimizer = StableAdamWUnfused(
            grouped_params, lr=learning_rate, betas=(beta_1, beta_2),
        )
    # elif optimizer_type=="Adam8bit":
    #     optimizer = bnb.optim.Adam8bit(grouped_params, lr=learning_rate, betas=(beta_1, beta_2))
    # elif optimizer_type=="Lion8Bit":
    #     optimizer = bnb.optim.Lion8bit(grouped_params, lr=learning_rate, betas=(beta_1, beta_2))
    else:
        raise ValueError(
            "Invalid optimizer_type. Expected 'lion', 'adamw', 'deepspeed' or 'stable_adamw', got: {}".format(
                optimizer_type
            )
        )

    # Return the optimizer.
    return optimizer


# dataloaders


def build_dataloaders():
    """
    Build data loaders for training.

    This function performs the following steps:
    1. Load the tokenizer from the pretrained "EleutherAI/gpt-neox-20b" model.
    2. Load the "openwebtext" dataset.
    3. Tokenize the dataset, adding the end-of-sentence token to each text.
    4. Process the tokenized dataset into chunks of a specified block size.

    Returns:
        Dataset: The processed dataset ready for training.
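
    Example:
        With block_size=4, tokenized texts [1, 2, 3, 4, 5] and [6, 7, 8] are
        concatenated to [1, ..., 8] and split into chunks [1, 2, 3, 4] and
        [5, 6, 7, 8]; a trailing remainder shorter than block_size is dropped.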
    """
    tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
    dataset = load_dataset("openwebtext", split="train")

    tokenized_dataset = dataset.map(
        lambda example: tokenizer([t + tokenizer.eos_token for t in example["text"]]),
        batched=True,
        num_proc=CFG.NUM_CPU,
        remove_columns=["text"],
    )

    block_size = CFG.SEQ_LEN

    # Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size.
    def group_texts(examples):
        # Concatenate all texts.
        concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
        total_length = len(concatenated_examples[list(examples.keys())[0]])
        # We drop the small remainder; padding could be added instead if the model
        # supported it. Customize this part to your needs.
        if total_length >= block_size:
            total_length = (total_length // block_size) * block_size
        # Split by chunks of max_len.
        result = {
            k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
            for k, t in concatenated_examples.items()
        }
        return result

    train_dataset = tokenized_dataset.map(
        group_texts, batched=True, num_proc=CFG.NUM_CPU,
    )

    return train_dataset

# TODO: switch to the Falcon web dataset
def build_pre_tokenized():
    # NOTE: only the first 10 examples are loaded here (split="train[:10]");
    # widen the split to train on more data.
    d0 = load_dataset("conceptofmind/c4_0-to-20_neox_with_eos_8k", split="train[:10]")
    # d1 = load_dataset("conceptofmind/c4_21-to-40_neox_with_eos_8k", split="train")
    # d2 = load_dataset("conceptofmind/c4_41-to-60_neox_with_eos_8k", split="train")
    # d3 = load_dataset("conceptofmind/c4_61-to-80_neox_with_eos_8k", split="train")
    # d4 = load_dataset("conceptofmind/c4_81-to-100_neox_with_eos_8k", split="train")
    # train_dataset = concatenate_datasets([d0, d1, d2, d3, d4])
    return d0



def Train():
    # accelerator

    timeout = InitProcessGroupKwargs(timeout=timedelta(seconds=1_000_000))

    accelerator = Accelerator(
        gradient_accumulation_steps=CFG.GRADIENT_ACCUMULATE_EVERY,
        mixed_precision="fp16",
        log_with="wandb",
        kwargs_handlers=[timeout],
    )

    state = AcceleratorState()

    # Only valid when a DeepSpeed plugin is configured (e.g. via `accelerate config`).
    if state.deepspeed_plugin is not None:
        state.deepspeed_plugin.deepspeed_config['train_micro_batch_size_per_gpu'] = CFG.BATCH_SIZE

    accelerator.init_trackers(
        project_name="Andromeda",
        config={
            "batch_size": CFG.BATCH_SIZE,
            "gradient_accumulate_every": CFG.GRADIENT_ACCUMULATE_EVERY,
            "learning_rate": CFG.LEARNING_RATE,
            "seq_len": CFG.SEQ_LEN,
        },
        # init_kwargs={"wandb": {"entity": CFG.ENTITY_NAME}},
    )

    accelerator.print(f"Total GPUS: {accelerator.num_processes}")

    # set seed

    set_seed(CFG.SEED)

    # model = Andromeda(
    #     num_tokens=50432,
    #     max_seq_len=8192,
    #     dim=3072,
    #     depth=24,
    #     dim_head=128,
    #     heads=12,
    #     use_abs_pos_emb=False, 
    #     alibi_pos_bias=True, 
    #     alibi_num_heads=6, 
    #     rotary_xpos=True,
    #     attn_flash=True, 
    #     shift_tokens=1, 
    #     attn_one_kv_head=True, 
    #     qk_norm=True, 
    #     attn_qk_norm=True, 
    #     attn_qk_norm_dim_scale=True, 
    #     embedding_provider=AndromedaEmbedding()
    # )
    model = Andromeda1Billion()

    print_num_params(model, accelerator)

    if CFG.USE_FSDP:
        model = fsdp(
            model,
            mp="fp16",
            shard_strat="SHARD_GRAD"
        )

    if CFG.USE_ACTIVATION_CHECKPOINTING:
        activation_checkpointing(model, accelerator=accelerator)

    model = accelerator.prepare(model)

    # dataloaders

    if CFG.USE_PRETOKENIZED:
        train_dataset = build_pre_tokenized()
    else:
        train_dataset = build_dataloaders()

    train_loader = DataLoader(
        train_dataset, batch_size=CFG.BATCH_SIZE, collate_fn=default_data_collator,
    )


    # optimizer
    optim = decoupled_optimizer(
        model=model,
        learning_rate=CFG.LEARNING_RATE, 
        weight_decay=CFG.WEIGHT_DECAY, 
        beta_1=0.90, 
        beta_2=0.95, 
        optimizer_type='lion',  
        use_fsdp=True,
        accelerator=accelerator
    )

    # Determine number of training steps

    max_train_steps = math.ceil(len(train_loader) / CFG.GRADIENT_ACCUMULATE_EVERY)
    accelerator.print(f"Max train steps: {max_train_steps}")

    # lr scheduler

    NUM_WARMUP_STEPS = int(max_train_steps * 0.01)
    accelerator.print(f"Num warmup steps: {NUM_WARMUP_STEPS}")

    # if False: # if CFG.USE_DEEPSPEED:
    #     lr_scheduler = DummyScheduler(
    #         optim, 
    #         total_num_steps=max_train_steps * accelerator.num_processes, 
    #         warmup_num_steps=NUM_WARMUP_STEPS
    #     )
    # else:
    lr_scheduler = get_lr_scheduler_with_warmup(
        optimizer=optim,
        scheduler_type="cosine",
        num_warmup_steps=NUM_WARMUP_STEPS,
        max_train_steps=max_train_steps,
        grad_accumulate_every=CFG.GRADIENT_ACCUMULATE_EVERY,
    )

    # prepare

    optim, train_loader, lr_scheduler = accelerator.prepare(
        optim, train_loader, lr_scheduler
    )

    # checkpoint scheduler

    accelerator.register_for_checkpointing(lr_scheduler)

    # Recalculate max_train_steps: accelerator.prepare can change the dataloader
    # length once batches are sharded across processes.

    max_train_steps = math.ceil(len(train_loader) / CFG.GRADIENT_ACCUMULATE_EVERY)
    accelerator.print(f"Max train steps recalculated: {max_train_steps}")

    # Total batch size for logging

    total_batch_size = (
        CFG.BATCH_SIZE * accelerator.num_processes * CFG.GRADIENT_ACCUMULATE_EVERY
    )
    accelerator.print(f"Total batch size: {total_batch_size}")

    # resume training

    progress_bar = tqdm(
        range(max_train_steps), disable=not accelerator.is_local_main_process
    )
    completed_steps = 0

    if CFG.RESUME_FROM_CHECKPOINT:
        accelerator.print(f"Resuming from checkpoint {CFG.RESUME_FROM_CHECKPOINT}")
        accelerator.load_state(CFG.RESUME_FROM_CHECKPOINT)
        path = os.path.basename(CFG.RESUME_FROM_CHECKPOINT)
        training_difference = os.path.splitext(path)[0]

        # multiply by `gradient_accumulation_steps` to reflect the real number of batches seen
        resume_step = (
            int(training_difference.replace("step_", ""))
            * CFG.GRADIENT_ACCUMULATE_EVERY
        )

        train_loader = accelerator.skip_first_batches(train_loader, resume_step)
        completed_steps += resume_step
        progress_bar.update(resume_step)

    # training

    model.train()
    for step, batch in enumerate(train_loader):
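        # accelerator.accumulate(model) skips gradient synchronization on
        # intermediate iterations; accelerator.sync_gradients is True only on the
        # accumulation boundary (every CFG.GRADIENT_ACCUMULATE_EVERY iterations),
        # which is when the gradient clipping below runs.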
        with accelerator.accumulate(model):
            inputs = batch["input_ids"].to(accelerator.device)
            loss = model(inputs, return_loss=True)
            accelerator.backward(loss)

            accelerator.log({"loss": loss.item()}, step=step)

            if accelerator.sync_gradients:
                accelerator.clip_grad_norm_(model.parameters(), 1.0)

            optim.step()
            lr_scheduler.step()
            optim.zero_grad()

        if accelerator.sync_gradients:
            progress_bar.update(1)
            completed_steps += 1

        if isinstance(CFG.CHECKPOINTING_STEPS, int):
            if completed_steps % CFG.CHECKPOINTING_STEPS == 0:
                output_dir = f"step_{completed_steps }"
                if CFG.OUTPUT_DIR is not None:
                    output_dir = os.path.join(CFG.OUTPUT_DIR, output_dir)
                accelerator.save_state(output_dir)

        if completed_steps >= max_train_steps:
            break

        # log every CFG.LOGGING_STEPS
        if CFG.LOGGING_STEPS > 0 and step % CFG.LOGGING_STEPS == 0:
            logger.info(
                f"Step: {completed_steps}/{max_train_steps}, Loss: {loss.item():.5f}"
            )

    # end training

    # accelerator.print(f"Training Finished")
    accelerator.end_training()

    # save final model

    # accelerator.print(f"Saving model to {CFG.OUTPUT_DIR}")
    if CFG.OUTPUT_DIR is not None:
        accelerator.wait_for_everyone()
        unwrapped_model = accelerator.unwrap_model(model)
        save_dir = os.path.join(CFG.OUTPUT_DIR, "final")
        os.makedirs(save_dir, exist_ok=True)
        with accelerator.main_process_first():
            accelerator.save(
                unwrapped_model.state_dict(), os.path.join(save_dir, "final_model.pt")
            )


def main():
    # [CRITICAL] Pay attention to these when scaling to multiple GPUs and clusters;
    # prefer setting them through your launcher or "accelerate config". The
    # defaults below apply only when the variables are not already set.
    os.environ.setdefault('MASTER_ADDR', 'localhost')
    os.environ.setdefault('MASTER_PORT', '9994')
    os.environ.setdefault('RANK', str(0))  # rank of this process
    os.environ.setdefault('WORLD_SIZE', str(torch.cuda.device_count()))  # total number of processes

    dist.init_process_group(backend='nccl')  # init_method defaults to "env://"
    
    Train()
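
# Typical launch (a sketch; exact flags depend on your hardware and cluster):
#   torchrun --nproc_per_node=<num_gpus> <this_script>.py
# or, after running `accelerate config`:
#   accelerate launch <this_script>.py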

if __name__ == '__main__':
    main()