import torch
import logging
import os

logger = logging.getLogger(__name__)
from torchvision import transforms
from PIL import Image


class SBInputExample(object):
    """A single text-plus-image example.

    Attributes:
        guid: Unique id for the example.
        text_a: Untokenized text of the first sequence.
        text_b: Untokenized text of the second sequence, or None.
        img_id: Identifier of the image paired with the text.
        label: Label supplied by the caller, or None.
        auxlabel: Auxiliary label supplied by the caller, or None.
    """

    def __init__(self, guid, text_a, text_b, img_id, label=None, auxlabel=None):
        """Construct an SBInputExample.

        Args:
            guid: Unique id for the example.
            text_a: string. The untokenized text of the first sequence. For
                single sequence tasks, only this sequence must be specified.
            text_b: (Optional) string. The untokenized text of the second
                sequence. Only needs to be specified for sequence pair tasks.
            img_id: Identifier of the image that belongs to the text.
            label: (Optional) The label of the example.
            auxlabel: (Optional) Auxiliary label of the example.
        """
        self.guid = guid
        self.text_a = text_a
        self.text_b = text_b
        self.img_id = img_id
        # Previously both values were accepted and silently dropped, so any
        # reader of `example.label` hit an AttributeError. They are stored
        # as given; both default to None when the caller omits them.
        self.label = label
        self.auxlabel = auxlabel


class SBInputFeatures(object):
    """Container for the model-ready features of one example.

    Every constructor argument is stored unchanged under the attribute of
    the same name.
    """

    def __init__(self, input_ids, input_mask, added_input_mask, segment_ids, img_feat):
        """Store the text features, the masks and the image feature."""
        # Text side: ids, their mask and the segment ids.
        self.input_ids, self.input_mask = input_ids, input_mask
        self.segment_ids = segment_ids
        # Mask that also covers the extra image positions.
        self.added_input_mask = added_input_mask
        # Image side.
        self.img_feat = img_feat


def sbreadfile(filename):
    """Read a data file and return its sentences and their image names.

    Lines starting with ``IMGID:`` set the image name (with ``.jpg``
    appended) for the sentence being collected. A blank line ends the
    current sentence. Every other line contributes its first tab-separated
    field as a word; anything after the first tab is ignored.

    Args:
        filename: Path of the UTF-8 text file to read.

    Returns:
        A tuple ``(data, imgs)``: ``data`` is a list of word lists and
        ``imgs`` holds the image name for each sentence (``''`` when no
        ``IMGID:`` line preceded it).
    """
    print("Chuẩn bị dữ liệu từ", filename)
    data, imgs = [], []
    current_words, current_img = [], ''

    with open(filename, encoding='utf8') as handle:
        for raw_line in handle:
            # Drop surrounding whitespace, including the trailing newline.
            stripped = raw_line.strip()
            if stripped.startswith('IMGID:'):
                current_img = stripped.split('IMGID:')[1] + '.jpg'
            elif stripped:
                # Keep only the word, not the label column(s).
                current_words.append(stripped.split('\t')[0])
            elif current_words:
                # A blank line closes a non-empty sentence.
                data.append(current_words)
                imgs.append(current_img)
                current_words, current_img = [], ''

    # Flush the last sentence when the file does not end with a blank line.
    if current_words:
        data.append(current_words)
        imgs.append(current_img)

    print("Số lượng mẫu: " + str(len(data)))
    print("Số lượng hình ảnh: " + str(len(imgs)))
    return data, imgs


class DataProcessor(object):
    """Interface that concrete data set processors implement.

    The three public methods raise ``NotImplementedError`` here and are
    meant to be overridden by subclasses.
    """

    def get_train_examples(self, data_dir):
        """Return the training examples found under `data_dir`."""
        raise NotImplementedError

    def get_dev_examples(self, data_dir):
        """Return the development examples found under `data_dir`."""
        raise NotImplementedError

    def get_labels(self):
        """Return the list of labels of the data set."""
        raise NotImplementedError

    @classmethod
    def _read_sbtsv(cls, input_file, quotechar=None):
        """Parse `input_file` with `sbreadfile`.

        `quotechar` is accepted but not used.
        """
        parsed = sbreadfile(input_file)
        return parsed


class MNERProcessor(DataProcessor):
    """Processor that builds text-plus-image NER examples.

    Each split is read from ``train.txt``, ``dev.txt`` or ``test.txt``
    inside the given data directory. The active label inventory is the one
    marked "vlsp2016" in `get_labels`.
    """

    def get_train_examples(self, data_dir):
        """Return the examples parsed from ``<data_dir>/train.txt``."""
        return self._load_split(data_dir, "train")

    def get_dev_examples(self, data_dir):
        """Return the examples parsed from ``<data_dir>/dev.txt``."""
        return self._load_split(data_dir, "dev")

    def get_test_examples(self, data_dir):
        """Return the examples parsed from ``<data_dir>/test.txt``."""
        return self._load_split(data_dir, "test")

    def _load_split(self, data_dir, split):
        """Read ``<data_dir>/<split>.txt`` and wrap its sentences as examples."""
        split_path = os.path.join(data_dir, split + ".txt")
        sentences, image_ids = self._read_sbtsv(split_path)
        return self._create_examples(sentences, image_ids, split)

    def get_labels(self):
        """Return the active label list, ending with "X", "<s>" and "</s>"."""
        # Disabled alternative label inventory (its data set is not named
        # here), kept for reference:
        #         return [
        # "O","I-PRODUCT-AWARD",
        # "B-MISCELLANEOUS",
        # "B-QUANTITY-NUM",
        # "B-ORGANIZATION-SPORTS",
        # "B-DATETIME",
        # "I-ADDRESS",
        # "I-PERSON",
        # "I-EVENT-SPORT",
        # "B-ADDRESS",
        # "B-EVENT-NATURAL",
        # "I-LOCATION-GPE",
        # "B-EVENT-GAMESHOW",
        # "B-DATETIME-TIMERANGE",
        # "I-QUANTITY-NUM",
        # "I-QUANTITY-AGE",
        # "B-EVENT-CUL",
        # "I-QUANTITY-TEM",
        # "I-PRODUCT-LEGAL",
        # "I-LOCATION-STRUC",
        # "I-ORGANIZATION",
        # "B-PHONENUMBER",
        # "B-IP",
        # "B-QUANTITY-AGE",
        # "I-DATETIME-TIME",
        # "I-DATETIME",
        # "B-ORGANIZATION-MED",
        # "B-DATETIME-SET",
        # "I-EVENT-CUL",
        # "B-QUANTITY-DIM",
        # "I-QUANTITY-DIM",
        # "B-EVENT",
        # "B-DATETIME-DATERANGE",
        # "I-EVENT-GAMESHOW",
        # "B-PRODUCT-AWARD",
        # "B-LOCATION-STRUC",
        # "B-LOCATION",
        # "B-PRODUCT",
        # "I-MISCELLANEOUS",
        # "B-SKILL",
        # "I-QUANTITY-ORD",
        # "I-ORGANIZATION-STOCK",
        # "I-LOCATION-GEO",
        # "B-PERSON",
        # "B-PRODUCT-COM",
        # "B-PRODUCT-LEGAL",
        # "I-LOCATION",
        # "B-QUANTITY-TEM",
        # "I-PRODUCT",
        # "B-QUANTITY-CUR",
        # "I-QUANTITY-CUR",
        # "B-LOCATION-GPE",
        # "I-PHONENUMBER",
        # "I-ORGANIZATION-MED",
        # "I-EVENT-NATURAL",
        # "I-EMAIL",
        # "B-ORGANIZATION",
        # "B-URL",
        # "I-DATETIME-TIMERANGE",
        # "I-QUANTITY",
        # "I-IP",
        # "B-EVENT-SPORT",
        # "B-PERSONTYPE",
        # "B-QUANTITY-PER",
        # "I-QUANTITY-PER",
        # "I-PRODUCT-COM",
        # "I-DATETIME-DURATION",
        # "B-LOCATION-GPE-GEO",
        # "B-QUANTITY-ORD",
        # "I-EVENT",
        # "B-DATETIME-TIME",
        # "B-QUANTITY",
        # "I-DATETIME-SET",
        # "I-LOCATION-GPE-GEO",
        # "B-ORGANIZATION-STOCK",
        # "I-ORGANIZATION-SPORTS",
        # "I-SKILL",
        # "I-URL",
        # "B-DATETIME-DURATION",
        # "I-DATETIME-DATE",
        # "I-PERSONTYPE",
        # "B-DATETIME-DATE",
        # "I-DATETIME-DATERANGE",
        # "B-LOCATION-GEO",
        # "B-EMAIL","X","<s>", "</s>"]

        # vlsp2016 (active)
        entity_tags = [
            "I-LOC", "B-MISC", "I-PER", "I-ORG", "B-LOC",
            "I-MISC", "B-ORG", "O", "B-PER",
        ]
        special_tags = ["X", "<s>", "</s>"]
        return entity_tags + special_tags

        # vlsp2018 (disabled, kept for reference)
        # return [
        #         "O","I-ORGANIZATION",
        #         "B-ORGANIZATION",
        #         "I-LOCATION",
        #         "B-MISCELLANEOUS",
        #         "I-PERSON",
        #         "B-PERSON",
        #         "I-MISCELLANEOUS",
        #         "B-LOCATION",
        #         "X",
        #         "<s>",
        #         "</s>"]

    def get_auxlabels(self):
        """Return the auxiliary label list."""
        aux_tags = ["O", "B", "I"]
        aux_tags.extend(["X", "<s>", "</s>"])
        return aux_tags

    def _label_ids(self):
        """Map every label from `get_labels` to its 1-based position."""
        return {name: idx for idx, name in enumerate(self.get_labels(), start=1)}

    def get_start_label_id(self):
        """Return the 1-based id of the "<s>" label."""
        return self._label_ids()['<s>']

    def get_stop_label_id(self):
        """Return the 1-based id of the "</s>" label."""
        return self._label_ids()['</s>']

    def _create_examples(self, lines, imgs, set_type):
        """Wrap each sentence and its image name in an `SBInputExample`.

        The guid is ``"<set_type>-<index>"``, the words are joined with a
        single space into `text_a`, and `text_b` is always None.
        """
        return [
            SBInputExample(
                guid="%s-%s" % (set_type, idx),
                text_a=' '.join(words),
                text_b=None,
                img_id=imgs[idx],
            )
            for idx, words in enumerate(lines)
        ]


def create_examples(lines, imgs, set_type):
    """Wrap each sentence and its image name in an `SBInputExample`.

    Args:
        lines: Sequence of sentences, each a sequence of word strings.
        imgs: Image names, indexed in parallel with `lines`.
        set_type: Prefix used for the guid, which is ``"<set_type>-<index>"``.

    Returns:
        A list with one `SBInputExample` per sentence; `text_a` is the
        words joined by a single space and `text_b` is always None.
    """
    return [
        SBInputExample(
            guid="%s-%s" % (set_type, position),
            text_a=' '.join(words),
            text_b=None,
            img_id=imgs[position],
        )
        for position, words in enumerate(lines)
    ]


def get_test_examples_predict(data_dir):
    """Build the examples from ``<data_dir>/test.txt``.

    The file is parsed with `sbreadfile` and the result is passed to
    `create_examples` with the set type ``"test"``.
    """
    test_file = os.path.join(data_dir, "test.txt")
    sentences, image_ids = sbreadfile(test_file)
    return create_examples(sentences, image_ids, "test")


def image_process(image_path, transform):
    """Load an image as RGB and run it through `transform`.

    Args:
        image_path: Path of the image file to open.
        transform: Callable applied to the RGB image.

    Returns:
        Whatever `transform` returns for the RGB image.

    Raises:
        Any exception raised while opening or converting the file is
        propagated to the caller; nothing is caught here.
    """
    # The context manager releases the underlying file even when decoding
    # fails. `convert` returns a separate image object, which is what gets
    # handed to the transform after the source file is closed.
    with Image.open(image_path) as source:
        rgb_image = source.convert('RGB')
    return transform(rgb_image)


def _encode_text_for_predict(text, max_seq_length, tokenizer):
    """Tokenize `text` and build fixed-length id, mask and segment lists.

    Returns:
        A tuple ``(tokens, input_ids, input_mask, added_input_mask,
        segment_ids)``. `tokens` are the (possibly truncated) sub-tokens
        without the "<s>"/"</s>" markers; the other lists are padded with 0.
    """
    tokens = []
    for word in text.split(' '):
        tokens.extend(tokenizer.tokenize(word))
    # Leave room for the "<s>" and "</s>" markers added below.
    if len(tokens) >= max_seq_length - 1:
        tokens = tokens[0:(max_seq_length - 2)]

    ntokens = ["<s>"] + tokens + ["</s>"]
    segment_ids = [0] * len(ntokens)
    input_ids = list(tokenizer.convert_tokens_to_ids(ntokens))
    input_mask = [1] * len(input_ids)
    # 49 extra positions are reserved for the image representation
    # (original note: "1 or 49 is for encoding regional image
    # representations").
    added_input_mask = [1] * (len(input_ids) + 49)

    # A non-positive pad count yields empty lists, so nothing is appended.
    pad_count = max_seq_length - len(input_ids)
    input_ids.extend([0] * pad_count)
    input_mask.extend([0] * pad_count)
    added_input_mask.extend([0] * pad_count)
    segment_ids.extend([0] * pad_count)

    assert len(input_ids) == max_seq_length
    assert len(input_mask) == max_seq_length
    assert len(segment_ids) == max_seq_length

    return tokens, input_ids, input_mask, added_input_mask, segment_ids


def convert_mm_examples_to_features_predict(examples,
                                            max_seq_length, tokenizer, crop_size, path_img):
    """Convert examples into `SBInputFeatures` for prediction.

    For every example the text is tokenized, wrapped in "<s>"/"</s>",
    truncated or zero-padded to `max_seq_length`, and the image named by
    `example.img_id` is loaded from `path_img` and transformed. When the
    image cannot be loaded, ``background.jpg`` from `path_img` is used
    instead and the example is counted as problematic.

    Args:
        examples: Iterable of objects exposing `guid`, `text_a` and `img_id`.
        max_seq_length: Length of the returned id, mask and segment lists.
        tokenizer: Object providing `tokenize(word)` and
            `convert_tokens_to_ids(tokens)`.
        crop_size: Size handed to `transforms.RandomCrop`.
        path_img: Directory holding the images and ``background.jpg``.

    Returns:
        A list with exactly one `SBInputFeatures` per example, in input
        order.

    Raises:
        Whatever `image_process` raises when ``background.jpg`` itself
        cannot be loaded.
    """
    features = []
    count = 0

    # NOTE(review): RandomCrop and RandomHorizontalFlip make the image
    # features differ between runs. Left unchanged here so results stay
    # comparable with existing runs — confirm whether a deterministic crop
    # is wanted for prediction.
    transform = transforms.Compose([
        transforms.Resize([256, 256]),
        transforms.RandomCrop(crop_size),  # args.crop_size, by default it is set to be 224
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406),
                             (0.229, 0.224, 0.225))])

    for ex_index, example in enumerate(examples):
        tokens, input_ids, input_mask, added_input_mask, segment_ids = \
            _encode_text_for_predict(example.text_a, max_seq_length, tokenizer)

        image_path = os.path.join(path_img, example.img_id)
        # Report missing files, except paths containing the 'NaN' marker.
        if not os.path.exists(image_path) and 'NaN' not in image_path:
            print(image_path)

        try:
            image = image_process(image_path, transform)
        except Exception:
            # Best effort: fall back to the placeholder image. Exception
            # (rather than a bare except) lets KeyboardInterrupt and
            # SystemExit through.
            count += 1
            image_path_fail = os.path.join(path_img, 'background.jpg')
            image = image_process(image_path_fail, transform)

        if ex_index < 1:
            logger.info("*** Example ***")
            logger.info("guid: %s", example.guid)
            logger.info("tokens: %s", " ".join([str(x) for x in tokens]))
            logger.info("input_ids: %s", " ".join([str(x) for x in input_ids]))
            logger.info("input_mask: %s", " ".join([str(x) for x in input_mask]))
            logger.info("segment_ids: %s", " ".join([str(x) for x in segment_ids]))

        # Append for every example. This used to sit in the `else:` clause
        # of the try statement, so examples that fell back to the
        # placeholder image were dropped and `features` no longer lined up
        # with `examples`.
        features.append(
            SBInputFeatures(input_ids=input_ids, input_mask=input_mask, added_input_mask=added_input_mask,
                            segment_ids=segment_ids, img_feat=image))

    print('the number of problematic samples: ' + str(count))
    return features