import random
from typing import Any, Mapping, Optional, Union

import pandas as pd
from joblib import load
from pydantic import BaseModel, ValidationInfo, field_validator

# Search-space constraints per hyperparameter, keyed by the matching field
# name on ``Parameterization`` and consumed by
# ``Parameterization.check_constraints``:
#   - "range":  the value must satisfy bounds[0] <= value <= bounds[1]
#               (both ends inclusive).
#   - "choice": the value must be one of ``values``.
# Fields without an entry here (e.g. ``losscurve``, ``learningcurve``) skip
# the range/choice check entirely.
PARAM_CONSTRAINTS = {
    "N": {"type": "range", "bounds": [1, 10]},
    "alpha": {"type": "range", "bounds": [0.0, 1.0]},
    "d_model": {"type": "range", "bounds": [100, 1024]},
    "dim_feedforward": {"type": "range", "bounds": [1024, 4096]},
    "dropout": {"type": "range", "bounds": [0.0, 1.0]},
    "emb_scaler": {"type": "range", "bounds": [0.0, 1.0]},
    "eps": {"type": "range", "bounds": [1e-7, 1e-4]},
    "epochs_step": {"type": "range", "bounds": [5, 20]},
    "fudge": {"type": "range", "bounds": [0.0, 0.1]},
    "heads": {"type": "range", "bounds": [1, 10]},
    "k": {"type": "range", "bounds": [2, 10]},
    "lr": {"type": "range", "bounds": [1e-4, 6e-3]},
    "pe_resolution": {"type": "range", "bounds": [2500, 10000]},
    "ple_resolution": {"type": "range", "bounds": [2500, 10000]},
    "pos_scaler": {"type": "range", "bounds": [0.0, 1.0]},
    "weight_decay": {"type": "range", "bounds": [0.0, 1.0]},
    "batch_size": {"type": "range", "bounds": [32, 256]},
    "out_hidden4": {"type": "range", "bounds": [32, 512]},
    "betas1": {"type": "range", "bounds": [0.5, 0.9999]},
    "betas2": {"type": "range", "bounds": [0.5, 0.9999]},
    "bias": {"type": "choice", "values": [False, True]},
    "criterion": {"type": "choice", "values": ["RobustL1", "RobustL2"]},
    "elem_prop": {"type": "choice", "values": ["mat2vec", "magpie", "onehot"]},
    "train_frac": {"type": "range", "bounds": [0.01, 1.0]},
}


class Parameterization(BaseModel):
    """A validated set of CrabNet hyperparameters.

    Each field with an entry in ``PARAM_CONSTRAINTS`` is checked against it:
    an inclusive numeric range or a fixed set of choices. Two cross-field
    rules are also enforced:

    * ``betas1 <= betas2``
    * ``emb_scaler + pos_scaler <= 1.0``

    ``losscurve`` and ``learningcurve`` have no constraint entry and are only
    type-checked by pydantic.
    """

    N: int
    alpha: float
    d_model: int
    dim_feedforward: int
    dropout: float
    emb_scaler: float
    epochs_step: int
    eps: float
    fudge: float
    heads: int
    k: int
    lr: float
    pe_resolution: int
    ple_resolution: int
    pos_scaler: float
    # The allowed range is [0.0, 1.0]. An ``int`` annotation would reject
    # every fractional value inside that range.
    weight_decay: float
    batch_size: int
    out_hidden4: int
    betas1: float
    betas2: float
    losscurve: bool
    learningcurve: bool
    bias: bool
    criterion: str
    elem_prop: str
    train_frac: float

    @field_validator("*")
    @classmethod
    def check_constraints(cls, v: Any, info: ValidationInfo) -> Any:
        """Check one field against ``PARAM_CONSTRAINTS`` and the cross-field rules.

        Args:
            v: The field value, already coerced to the annotated type.
            info: Pydantic validation context. ``info.field_name`` names the
                field, and ``info.data`` holds the fields validated so far.

        Returns:
            ``v`` unchanged.

        Raises:
            ValueError: The value is out of range or not an allowed choice, or
                a cross-field rule is violated.
        """
        name = info.field_name
        param = PARAM_CONSTRAINTS.get(name)
        if param is None:
            return v

        if param["type"] == "range":
            min_val, max_val = param["bounds"]
            if not min_val <= v <= max_val:
                raise ValueError(f"{name} must be between {min_val} and {max_val}")
        elif param["type"] == "choice":
            if v not in param["values"]:
                raise ValueError(f"{name} must be one of {param['values']}")

        # ``info.data`` holds only the fields validated before this one
        # (declaration order) that passed validation. Merge in the current
        # value, and check a pair only once both of its members are present.
        # The check therefore fires on whichever member is validated second.
        seen = {**info.data, name: v}

        if name in ("betas1", "betas2") and "betas1" in seen and "betas2" in seen:
            if seen["betas1"] > seen["betas2"]:
                raise ValueError("betas1 must be less than or equal to betas2")

        if (
            name in ("emb_scaler", "pos_scaler")
            and "emb_scaler" in seen
            and "pos_scaler" in seen
        ):
            if seen["emb_scaler"] + seen["pos_scaler"] > 1.0:
                raise ValueError(
                    "The sum of emb_scaler and pos_scaler must be less than or equal to 1.0"
                )

        return v


class CrabNetSurrogateModel(object):
    """Predict CrabNet training outcomes from hyperparameters via fitted surrogates.

    ``self.models`` is whatever ``joblib.load`` returns for ``fpath``. The
    code below indexes it with the keys ``"mae"``, ``"rmse"``, ``"runtime"``
    and ``"model_size"``, and calls ``.predict`` on each value.
    """

    # Categories of ``elem_prop``, in the order their one-hot columns are
    # appended to the feature row.
    _ELEM_PROPS = ("magpie", "mat2vec", "onehot")

    def __init__(self, fpath="surrogate_models.pkl"):
        """Load the fitted surrogate models from ``fpath``.

        NOTE(review): ``joblib.load`` unpickles the file, which can execute
        arbitrary code. Only load model files from a trusted source.
        """
        self.models = load(fpath)

    def prepare_params_for_eval(
        self, raw_params: Union["Parameterization", Mapping[str, Any]]
    ) -> dict:
        """Turn a hyperparameter set into the flat feature dict used by the models.

        Transformations:

        * ``bias`` becomes ``0``/``1``.
        * ``criterion`` is replaced by the boolean ``use_RobustL1``.
        * ``losscurve`` and ``learningcurve`` are dropped.
        * ``elem_prop`` is replaced by one-hot columns ``elem_prop_magpie``,
          ``elem_prop_mat2vec`` and ``elem_prop_onehot``.

        Args:
            raw_params: A ``Parameterization`` instance (anything exposing
                ``model_dump()``) or a plain mapping with the same keys. It is
                not modified.

        Returns:
            A new dict of features.

        Raises:
            KeyError: ``bias``, ``criterion`` or ``elem_prop`` is missing.
            ValueError: ``elem_prop`` is not one of the known categories.
        """
        # Work on a copy. pydantic models are not subscriptable, and
        # mutating a caller's dict in place is a surprising side effect.
        if hasattr(raw_params, "model_dump"):
            params = raw_params.model_dump()
        else:
            params = dict(raw_params)

        params["bias"] = int(params["bias"])
        params["use_RobustL1"] = params.pop("criterion") == "RobustL1"

        # Not model features. Remove them rather than leaving None-valued
        # columns in the frame handed to ``predict``.
        params.pop("losscurve", None)
        params.pop("learningcurve", None)

        elem_prop = params.pop("elem_prop")
        if elem_prop not in self._ELEM_PROPS:
            raise ValueError(
                f"elem_prop must be one of {list(self._ELEM_PROPS)}, got {elem_prop!r}"
            )
        for category in self._ELEM_PROPS:
            params[f"elem_prop_{category}"] = int(category == elem_prop)

        return params

    def surrogate_evaluate(
        self,
        params: Union["Parameterization", Mapping[str, Any]],
        percentile: Optional[float] = None,
    ):
        """Predict ``(mae, rmse, runtime, model_size)`` for one hyperparameter set.

        The mae, rmse and runtime models receive an extra ``<name>_rank``
        column holding ``percentile``. The model_size model receives only the
        hyperparameter features.

        Args:
            params: Hyperparameters accepted by ``prepare_params_for_eval``.
            percentile: Value in ``[0, 1]`` fed as the rank column. If
                ``None``, it is drawn from ``random.uniform(0, 1)``. That uses
                the global ``random`` state, so pass a value for reproducible
                results.

        Returns:
            The four ``predict`` outputs as a tuple. Their exact type depends
            on the loaded models; presumably one-element arrays, to confirm.

        Raises:
            ValueError: ``percentile`` is outside ``[0, 1]``.
        """
        parameters = pd.DataFrame([self.prepare_params_for_eval(params)])

        if percentile is None:
            percentile = random.uniform(0, 1)
        elif not 0.0 <= percentile <= 1.0:
            raise ValueError(f"percentile must be in [0, 1], got {percentile}")

        mae = self.models["mae"].predict(parameters.assign(mae_rank=[percentile]))
        rmse = self.models["rmse"].predict(parameters.assign(rmse_rank=[percentile]))
        runtime = self.models["runtime"].predict(
            parameters.assign(runtime_rank=[percentile])
        )
        model_size = self.models["model_size"].predict(parameters)

        return mae, rmse, runtime, model_size