from typing import List

import numpy as np
import pandas as pd
import torch
import xgboost as xgb
from tqdm import tqdm
from transformers import AutoTokenizer, BertForSequenceClassification


class BertEmbedder:
    def __init__(self, tokenizer_path: str, model_path: str, cut_head: bool = False):
        """
        cut_head = True if the checkpoint has a classification head on top;
        the head is dropped so the bare BERT encoder is used for embeddings.
        """
        self.embedder = BertForSequenceClassification.from_pretrained(model_path)
        self.max_length = self.embedder.config.max_position_embeddings
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, model_max_length=self.max_length)

        if cut_head:
            # Keep only the base BERT encoder, whose outputs expose pooler_output.
            self.embedder = self.embedder.bert

        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
        print(f"Device used for BERT: {self.device}", flush=True)
        self.embedder.to(self.device)
        self.embedder.eval()  # inference only: disable dropout

    def __call__(self, text: str):
        encoded_input = self.tokenizer(text,
                                       return_tensors='pt',
                                       max_length=self.max_length,
                                       padding=True,
                                       truncation=True).to(self.device)
        # pooler_output is only exposed by the bare encoder (cut_head=True);
        # with the classification head attached the model returns logits instead.
        with torch.no_grad():
            model_output = self.embedder(**encoded_input)
        text_embed = model_output.pooler_output[0].cpu()
        return text_embed

    def batch_predict(self, texts: List[str]):
        encoded_input = self.tokenizer(texts,
                                       return_tensors='pt',
                                       max_length=self.max_length,
                                       padding=True,
                                       truncation=True).to(self.device)
        with torch.no_grad():
            model_output = self.embedder(**encoded_input)
        texts_embeds = model_output.pooler_output.cpu()
        return texts_embeds
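
# Illustrative usage sketch (the "./bert-checkpoint" path is hypothetical; assumes
# a checkpoint saved with a classification head, hence cut_head=True):
#
#   embedder = BertEmbedder("./bert-checkpoint", "./bert-checkpoint", cut_head=True)
#   vec = embedder("some text")                    # tensor of shape (hidden_size,)
#   vecs = embedder.batch_predict(["one", "two"])  # tensor of shape (2, hidden_size)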


class PredictModel:
    """
    Two-level text classifier: `classifier_code` predicts full labels
    (e.g. '12.3') and `classifier_group` predicts the group prefix before
    the first dot (e.g. '12'), both over embeddings produced by `embedder`.
    """

    def __init__(self, embedder, classifier_code, classifier_group, batch_size=8):
        self.batch_size = batch_size
        self.embedder = embedder
        self.classifier_code = classifier_code
        self.classifier_group = classifier_group

    def _texts2vecs(self, texts, logging=False):
        embeds = []
        # Guard against len(texts) < batch_size, which would request 0 sections.
        n_batches = max(1, len(texts) // self.batch_size)
        batches_texts = np.array_split(texts, n_batches)
        iterator = tqdm(batches_texts) if logging else batches_texts
        for batch_texts in iterator:
            embeds += self.embedder.batch_predict(batch_texts.tolist()).tolist()
        return np.array(embeds)

    def fit(self, texts: List[str], labels: List[str], logging: bool = False):
        if logging:
            print('Start text2vec transform')
        embeds = self._texts2vecs(texts, logging)
        if logging:
            print('Start codes-classifier fitting')
        self.classifier_code.fit(embeds, labels)
        # Group labels are the code prefix before the first dot, e.g. '12.3' -> '12'.
        group_labels = [l.split('.')[0] for l in labels]
        if logging:
            print('Start groups-classifier fitting')
        self.classifier_group.fit(embeds, group_labels)

    def predict_code(self, texts: List[str], logging: bool = False):
        if logging:
            print('Start text2vec transform')
        embeds = self._texts2vecs(texts, logging)
        if logging:
            print('Start classifier prediction')
        prediction = self.classifier_code.predict(embeds)
        return prediction

    def predict_group(self, texts: List[str], logging: bool=False):
        if logging:
            print('Start text2vec transform')
        embeds = self._texts2vecs(texts, logging)
        if logging:
            print('Start classifier prediction')
        prediction = self.classifier_group.predict(embeds)
        return prediction
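
# Illustrative usage sketch (assumes an `embedder` as above; the '<group>.<code>'
# label format follows `fit`):
#
#   model = PredictModel(embedder, CustomXGBoost(use_gpu=False),
#                        CustomXGBoost(use_gpu=False), batch_size=8)
#   model.fit(["text a", "text b"], ["12.3", "7.1"], logging=True)
#   codes = model.predict_code(["new text"])    # e.g. array(['12.3'])
#   groups = model.predict_group(["new text"])  # e.g. array(['12'])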


class CustomXGBoost:
    """Thin wrapper around xgb.XGBClassifier that preserves string class labels."""

    def __init__(self, use_gpu):
        if use_gpu:
            # "gpu_hist" is the pre-2.0 XGBoost spelling; on XGBoost >= 2.0 use
            # tree_method="hist" together with device="cuda" instead.
            self.model = xgb.XGBClassifier(tree_method="gpu_hist")
        else:
            self.model = xgb.XGBClassifier()
        self.classes_ = None

    def fit(self, X, y, **kwargs):
        # XGBoost expects integer targets; remember the label order for decoding.
        self.classes_ = np.unique(y).tolist()
        label_to_idx = {l: i for i, l in enumerate(self.classes_)}
        y = [label_to_idx[l] for l in y]
        self.model.fit(X, y, **kwargs)

    def predict_proba(self, X):
        pred = self.model.predict_proba(X)
        return pred

    def predict(self, X):
        preds = self.model.predict_proba(X)
        # Decode integer argmax predictions back to the original labels.
        return np.array([self.classes_[p] for p in np.argmax(preds, axis=1)])
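
# Illustrative sketch: unlike a raw xgb.XGBClassifier, the wrapper accepts and
# returns the original string labels (shapes here are arbitrary assumptions):
#
#   clf = CustomXGBoost(use_gpu=False)
#   clf.fit(np.random.rand(4, 8), ["a", "b", "a", "b"])
#   clf.predict(np.random.rand(2, 8))  # -> array of 'a'/'b' labels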


class SimpleModel:
    """Degenerate baseline: always predicts the single class seen during fit."""

    def __init__(self):
        self.classes_ = None

    def fit(self, X, y):
        # Assumes y contains a single class; only the first label is kept.
        self.classes_ = [y[0]]

    def predict(self, X):
        return np.array([self.classes_[0]] * len(X))

    def predict_proba(self, X):
        return np.array([[1.0]] * len(X))
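
# Illustrative sketch: SimpleModel is a constant baseline for single-class data:
#
#   sm = SimpleModel()
#   sm.fit(None, ["only_class"])
#   sm.predict(range(3))  # -> array(['only_class', 'only_class', 'only_class'])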


def balance_dataset(labels_train_for_group, vecs_train_for_group, balance=None, logging=True):
    """
    Balance classes either by down-sampling every class to the size of the
    smallest one (balance='remove') or by duplicating samples until every
    class matches the largest one (balance='duplicate'). Any other value of
    `balance` returns the data unchanged.
    """
    if balance == 'remove':
        # Find the size of the smallest class.
        codes, counts = np.unique(labels_train_for_group, return_counts=True)
        if logging:
            for code_l, cur_len in zip(codes, counts):
                print(code_l, cur_len)
        min_len = counts.min()
        if logging:
            print('min_len is', min_len)
        df_train_group = pd.DataFrame()
        df_train_group['labels'] = labels_train_for_group
        df_train_group['vecs'] = vecs_train_for_group.tolist()
        # Randomly keep min_len samples per class (sampling without replacement).
        df_train_group = df_train_group.groupby('labels', as_index=False).apply(
            lambda array: array.loc[np.random.choice(array.index, min_len, False), :])
        labels_train_for_group = df_train_group['labels'].values
        vecs_train_for_group = np.array(df_train_group['vecs'].tolist())

    elif balance == 'duplicate':
        df_train_group = pd.DataFrame()
        df_train_group['labels'] = labels_train_for_group
        df_train_group['vecs'] = vecs_train_for_group.tolist()
        # Find the size of the largest class.
        max_len = 0
        for code_l, code_df in df_train_group.groupby('labels'):
            cur_len = len(code_df)
            if logging:
                print(code_l, cur_len)
            max_len = max(max_len, cur_len)
        if logging:
            print('max_len is', max_len)
        labels_train_for_group = []
        vecs_train_for_group = []
        # Grow each class by doubling until it reaches max_len, then truncate.
        for code_l, code_df in df_train_group.groupby('labels'):
            cur_len = len(code_df)
            cur_labels = code_df['labels'].values.tolist()
            cur_vecs = code_df['vecs'].values.tolist()
            while cur_len < max_len:
                cur_len *= 2
                cur_labels += cur_labels
                cur_vecs += cur_vecs
            labels_train_for_group += cur_labels[:max_len]
            vecs_train_for_group += cur_vecs[:max_len]

        labels_train_for_group = np.array(labels_train_for_group)
        vecs_train_for_group = np.array(vecs_train_for_group)
    return labels_train_for_group, vecs_train_for_group
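
# Illustrative usage sketch (label values and vector shapes are assumptions):
#
#   labels = np.array(['a', 'a', 'a', 'b'])
#   vecs = np.random.rand(4, 768)
#   labels_bal, vecs_bal = balance_dataset(labels, vecs, balance='duplicate')
#   # -> the 'b' sample is duplicated until both classes have 3 examples each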