import logging
import os
from pathlib import Path
from time import sleep, time

import hydra
import yaml
from addict import Dict
from comet_ml import ExistingExperiment, Experiment
from omegaconf import OmegaConf

from climategan.trainer import Trainer
from climategan.utils import (
    comet_kwargs,
    copy_run_files,
    env_to_path,
    find_existing_training,
    flatten_opts,
    get_existing_comet_id,
    get_git_branch,
    get_git_revision_hash,
    get_increased_path,
    kill_job,
    load_opts,
    pprint,
)

logging.basicConfig()
logging.getLogger().setLevel(logging.ERROR)

hydra_config_path = Path(__file__).resolve().parent / "shared/trainer/config.yaml"


# requires hydra-core==0.11.3 and omegaconf==1.4.1
@hydra.main(config_path=hydra_config_path, strict=False)
def main(opts):
    """
    Opts prevalence:
        1. Load file specified in args.default (or shared/trainer/defaults.yaml
           if none is provided)
        2. Update with file specified in args.config (or no update if none is provided)
        3. Update with parsed command-line arguments

        e.g.
        `python train.py args.config=config/large-lr.yaml data.loaders.batch_size=10`
        loads defaults, overrides with values in large-lr.yaml and sets batch_size to 10
    """

    # -----------------------------
    # -----  Parse arguments  -----
    # -----------------------------

    hydra_opts = Dict(OmegaConf.to_container(opts))
    args = hydra_opts.pop("args", None)
    auto_resumed = {}

    config_path = args.config

    if hydra_opts.train.resume:
        out_ = str(env_to_path(hydra_opts.output_path))
        config_path = Path(out_) / "opts.yaml"
        if not config_path.exists():
            config_path = None
            print("WARNING: could not reuse the opts in {}".format(out_))

    default = args.default or Path(__file__).parent / "shared/trainer/defaults.yaml"

    # -----------------------
    # -----  Load opts  -----
    # -----------------------

    opts = load_opts(config_path, default=default, commandline_opts=hydra_opts)
    if args.resume:
        opts.train.resume = True

    opts.jobID = os.environ.get("SLURM_JOBID")
    opts.slurm_partition = os.environ.get("SLURM_JOB_PARTITION")
    opts.output_path = str(env_to_path(opts.output_path))
    print("Config output_path:", opts.output_path)

    exp = comet_previous_id = None

    # -------------------------------
    # -----  Check output_path  -----
    # -------------------------------

    # Auto-continue if same slurm job ID (=job was requeued)
    if not opts.train.resume and opts.train.auto_resume:
        print("\n\nTrying to auto-resume...")
        existing_path = find_existing_training(opts)
        if existing_path is not None and existing_path.exists():
            auto_resumed["original output_path"] = str(opts.output_path)
            auto_resumed["existing_path"] = str(existing_path)
            opts.train.resume = True
            opts.output_path = str(existing_path)

    # Still not resuming: creating new output path
    if not opts.train.resume:
        opts.output_path = str(get_increased_path(opts.output_path))
        Path(opts.output_path).mkdir(parents=True, exist_ok=True)

    # Copy the opts's sbatch_file to output_path
    copy_run_files(opts)
    # store git hash
    opts.git_hash = get_git_revision_hash()
    opts.git_branch = get_git_branch()

    if not args.no_comet:
        # ----------------------------------
        # -----  Set Comet Experiment  -----
        # ----------------------------------

        if opts.train.resume:
            # Is resuming: get existing comet exp id
            assert Path(opts.output_path).exists(), "Output_path does not exist"

            comet_previous_id = get_existing_comet_id(opts.output_path)
            # Continue existing experiment
            if comet_previous_id is None:
                print("WARNING could not retreive previous comet id")
                print(f"from {opts.output_path}")
            else:
                print("Continuing previous experiment", comet_previous_id)
                auto_resumed["continuing exp id"] = comet_previous_id
                exp = ExistingExperiment(
                    previous_experiment=comet_previous_id, **comet_kwargs
                )
                print("Comet Experiment resumed")

        if exp is None:
            # Create new experiment
            print("Starting new experiment")
            exp = Experiment(project_name="climategan", **comet_kwargs)
            exp.log_asset_folder(
                str(Path(__file__).parent / "climategan"),
                recursive=True,
                log_file_name=True,
            )
            exp.log_asset(str(Path(__file__)))

        # Log note
        if args.note:
            exp.log_parameter("note", args.note)

        # Merge and log tags
        if args.comet_tags or opts.comet.tags:
            tags = set([f"branch:{opts.git_branch}"])
            if args.comet_tags:
                tags.update(args.comet_tags)
            if opts.comet.tags:
                tags.update(opts.comet.tags)
            opts.comet.tags = list(tags)
            print("Logging to comet.ml with tags", opts.comet.tags)
            exp.add_tags(opts.comet.tags)

        # Log all opts
        exp.log_parameters(flatten_opts(opts))
        if auto_resumed:
            exp.log_text("\n".join(f"{k:20}: {v}" for k, v in auto_resumed.items()))

        # allow some time for comet to get its url
        sleep(1)

        # Save comet exp url
        url_path = get_increased_path(Path(opts.output_path) / "comet_url.txt")
        with open(url_path, "w") as f:
            f.write(exp.url)

        # Save config file
        opts_path = get_increased_path(Path(opts.output_path) / "opts.yaml")
        with (opts_path).open("w") as f:
            yaml.safe_dump(opts.to_dict(), f)

    pprint("Running model in", opts.output_path)

    # -------------------
    # -----  Train  -----
    # -------------------

    trainer = Trainer(opts, comet_exp=exp, verbose=1)
    trainer.logger.time.start_time = time()
    trainer.setup()
    trainer.train()

    # -----------------------------
    # -----  End of training  -----
    # -----------------------------

    pprint("Done training")
    kill_job(opts.jobID)


if __name__ == "__main__":

    main()