from __future__ import division
import os
import shutil
import json
import cv2
from PIL import Image

import numpy as np
from torch.utils.data import Dataset

from utils.image import _palette


class VOSTest(Dataset):
    """Frame-level dataset for a single video sequence.

    Each item is a dict with ``'current_img'`` (float32 array read with
    OpenCV), ``'current_label'`` when a label file for that frame is listed
    in ``labels``, and a ``'meta'`` dict describing the frame.

    Args:
        image_root: directory holding one frame sub-folder per sequence.
        label_root: directory holding one label sub-folder per sequence.
        seq_name: sequence sub-folder name under both roots.
        images: ordered frame file names for this sequence.
        labels: label file names available for this sequence; a frame
            ``<stem>.<ext>`` is labelled when ``<stem>.png`` is listed here.
        rgb: reorder OpenCV's BGR channels to RGB when True.
        transform: optional callable applied to every sample dict.
        single_obj: binarize labels (any non-zero id becomes 1) when True.
        resolution: if set, the ``height``/``width`` reported in ``meta`` are
            rescaled so that height equals ``resolution`` (aspect ratio kept,
            width rounded up). The image itself is not resized here.

    Attributes:
        obj_nums: per frame, the number of non-zero object ids seen in the
            labels of earlier frames; frame 0 instead reports the ids found
            in its own label.
        obj_indices: per frame, the list of object ids (starting with 0)
            seen up to and including that frame; passed to ``read_label`` to
            remap raw ids to consecutive indices.
    """

    def __init__(self,
                 image_root,
                 label_root,
                 seq_name,
                 images,
                 labels,
                 rgb=True,
                 transform=None,
                 single_obj=False,
                 resolution=None):
        self.image_root = image_root
        self.label_root = label_root
        self.seq_name = seq_name
        self.images = images
        self.labels = labels
        self.obj_num = 1
        self.num_frame = len(self.images)
        self.transform = transform
        self.rgb = rgb
        self.single_obj = single_obj
        self.resolution = resolution

        self.obj_nums = []
        self.obj_indices = []

        # Running list of object ids seen so far. 0 is seeded so that
        # ``len(curr_objs) - 1`` counts only the non-zero ids.
        curr_objs = [0]
        for img_name in self.images:
            self.obj_nums.append(len(curr_objs) - 1)
            current_label_name = self._label_name(img_name)
            if current_label_name in self.labels:
                current_label = self.read_label(current_label_name)
                for obj_idx in np.unique(current_label):
                    if obj_idx not in curr_objs:
                        curr_objs.append(obj_idx)
            self.obj_indices.append(curr_objs.copy())

        # Nothing precedes frame 0, so report the ids introduced by its own
        # label. This equals obj_nums[1] when a second frame exists, but
        # (unlike indexing [1]) also works for one-frame sequences.
        if self.obj_indices:
            self.obj_nums[0] = len(self.obj_indices[0]) - 1

    @staticmethod
    def _label_name(img_name):
        """Map a frame file name to its label file name (``<stem>.png``)."""
        return os.path.splitext(img_name)[0] + '.png'

    def __len__(self):
        return len(self.images)

    def read_image(self, idx):
        """Read frame ``idx`` as a float32 array (RGB order when ``rgb``).

        Raises:
            FileNotFoundError: if OpenCV cannot read the frame.
        """
        img_name = self.images[idx]
        img_path = os.path.join(self.image_root, self.seq_name, img_name)
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread reports failure by returning None instead of raising.
            raise FileNotFoundError(
                'Failed to read image: {}'.format(img_path))
        img = np.array(img, dtype=np.float32)
        if self.rgb:
            img = img[:, :, [2, 1, 0]]
        return img

    def read_label(self, label_name, squeeze_idx=None):
        """Read a label file as a uint8 array.

        Args:
            label_name: label file name inside ``label_root/seq_name``.
            squeeze_idx: optional sequence of raw object ids; when given (and
                ``single_obj`` is False) each raw id ``squeeze_idx[k]`` is
                rewritten to ``k`` and ids not listed become 0.

        Returns:
            The (possibly binarized or remapped) uint8 label array.
        """
        label_path = os.path.join(self.label_root, self.seq_name, label_name)
        # Close the file handle as soon as the pixels have been copied out.
        with Image.open(label_path) as label_file:
            label = np.array(label_file, dtype=np.uint8)
        if self.single_obj:
            label = (label > 0).astype(np.uint8)
        elif squeeze_idx is not None:
            squeezed_label = np.zeros_like(label)
            for new_id, obj_id in enumerate(squeeze_idx):
                if obj_id == 0:
                    continue
                squeezed_label[label == obj_id] = new_id
            label = squeezed_label
        return label

    def __getitem__(self, idx):
        img_name = self.images[idx]
        current_img = self.read_image(idx)
        height, width = current_img.shape[:2]
        if self.resolution is not None:
            # Scale the reported size so height == resolution, keeping the
            # aspect ratio (width rounded up).
            width = int(np.ceil(
                float(width) * self.resolution / float(height)))
            height = int(self.resolution)

        current_label_name = self._label_name(img_name)
        obj_num = self.obj_nums[idx]
        obj_idx = self.obj_indices[idx]

        if current_label_name in self.labels:
            current_label = self.read_label(current_label_name, obj_idx)
            sample = {
                'current_img': current_img,
                'current_label': current_label
            }
        else:
            sample = {'current_img': current_img}

        sample['meta'] = {
            'seq_name': self.seq_name,
            'frame_num': self.num_frame,
            'obj_num': obj_num,
            'current_name': img_name,
            'height': height,
            'width': width,
            'flip': False,
            'obj_idx': obj_idx
        }

        if self.transform is not None:
            sample = self.transform(sample)
        return sample


class YOUTUBEVOS_Test(object):
    """Sequence-level loader for the sparsely annotated YouTube-VOS split.

    Sequences come from ``<root>/<year>/<split>/meta.json``. Indexing
    returns a ``VOSTest`` whose frames are every frame listed for any object
    in meta.json, and whose labels are the first listed frame of each object.

    Args:
        root: dataset root containing ``<year>/<split>`` folders.
        year: dataset year sub-folder.
        split: split name; ``'val'`` is mapped to ``'valid'``.
        transform: optional callable forwarded to every ``VOSTest``.
        rgb: forwarded to every ``VOSTest``.
        result_root: output folder; each sequence's first label is copied
            there (best effort) when the sequence is loaded.

    Raises:
        FileNotFoundError: if ``meta.json`` does not exist.
    """

    def __init__(self,
                 root='./datasets/YTB',
                 year=2018,
                 split='val',
                 transform=None,
                 rgb=True,
                 result_root=None):
        if split == 'val':
            split = 'valid'
        root = os.path.join(root, str(year), split)
        self.db_root_dir = root
        self.result_root = result_root
        self.rgb = rgb
        self.transform = transform
        self.seq_list_file = os.path.join(self.db_root_dir, 'meta.json')
        if not self._check_preprocess():
            # Fail here with a clear message instead of an AttributeError on
            # the missing ``ann_f`` below.
            raise FileNotFoundError(
                'YouTube-VOS meta file not found: {}'.format(
                    self.seq_list_file))
        self.seqs = list(self.ann_f.keys())
        self.image_root = os.path.join(root, 'JPEGImages')
        self.label_root = os.path.join(root, 'Annotations')

    def __len__(self):
        return len(self.seqs)

    def __getitem__(self, idx):
        seq_name = self.seqs[idx]
        data = self.ann_f[seq_name]['objects']
        images = []
        labels = []
        for obj_info in data.values():
            frames = list(obj_info['frames'])
            images.extend(frame + '.jpg' for frame in frames)
            # Each object is annotated on the first frame it appears in.
            labels.append(frames[0] + '.png')
        images = np.sort(np.unique(images))
        labels = np.sort(np.unique(labels))

        self._prepare_result_folder(seq_name, labels)

        seq_dataset = VOSTest(self.image_root,
                              self.label_root,
                              seq_name,
                              images,
                              labels,
                              transform=self.transform,
                              rgb=self.rgb)
        return seq_dataset

    def _prepare_result_folder(self, seq_name, labels):
        """Best effort: copy the first label into ``result_root/seq_name``.

        Failures (including a missing ``result_root`` or empty ``labels``)
        are printed and ignored.
        """
        try:
            seq_result_folder = os.path.join(self.result_root, seq_name)
            result_label_path = os.path.join(seq_result_folder, labels[0])
            if not os.path.isfile(result_label_path):
                os.makedirs(seq_result_folder, exist_ok=True)
                shutil.copy(
                    os.path.join(self.label_root, seq_name, labels[0]),
                    result_label_path)
        except Exception as inst:
            print(inst)
            print('Failed to create a result folder for sequence {}.'.format(
                seq_name))

    def _check_preprocess(self):
        """Load ``meta.json`` into ``self.ann_f``; return False if missing."""
        if not os.path.isfile(self.seq_list_file):
            return False
        with open(self.seq_list_file, 'r') as f:
            self.ann_f = json.load(f)['videos']
        return True


class YOUTUBEVOS_DenseTest(object):
    """Sequence-level loader for YouTube-VOS evaluated on all frames.

    Sequence names and annotated (sparse) frames come from the sparse split's
    ``meta.json``. Frames are read from ``<split>_all_frames/JPEGImages`` and
    limited to the range between the first and last sparse frame. Labels are
    every file in the sparse split's ``Annotations/<seq>`` folder.

    Args:
        root: dataset root containing ``<year>/<split>`` folders.
        year: dataset year sub-folder.
        split: split name; ``'val'`` is mapped to ``'valid'``.
        transform: optional callable forwarded to every ``VOSTest``.
        rgb: forwarded to every ``VOSTest``.
        result_root: output folder; each sequence's first label is copied
            there (best effort) when the sequence is loaded.

    Raises:
        FileNotFoundError: if the sparse ``meta.json`` does not exist.
    """

    def __init__(self,
                 root='./datasets/YTB',
                 year=2018,
                 split='val',
                 transform=None,
                 rgb=True,
                 result_root=None):
        if split == 'val':
            split = 'valid'
        root_sparse = os.path.join(root, str(year), split)
        root_dense = root_sparse + '_all_frames'
        self.db_root_dir = root_dense
        self.result_root = result_root
        self.rgb = rgb
        self.transform = transform
        self.seq_list_file = os.path.join(root_sparse, 'meta.json')
        if not self._check_preprocess():
            # Fail here with a clear message instead of an AttributeError on
            # the missing ``ann_f`` below.
            raise FileNotFoundError(
                'YouTube-VOS meta file not found: {}'.format(
                    self.seq_list_file))
        self.seqs = list(self.ann_f.keys())
        self.image_root = os.path.join(root_dense, 'JPEGImages')
        self.label_root = os.path.join(root_sparse, 'Annotations')

    def __len__(self):
        return len(self.seqs)

    def __getitem__(self, idx):
        seq_name = self.seqs[idx]

        data = self.ann_f[seq_name]['objects']
        images_sparse = []
        for obj_info in data.values():
            images_sparse.extend(frame + '.jpg' for frame in obj_info['frames'])
        images_sparse = np.sort(np.unique(images_sparse))

        images = np.sort(os.listdir(os.path.join(self.image_root, seq_name)))
        # Keep only the dense frames spanned by the annotated frames.
        images = self._frame_range(images, images_sparse[0],
                                   images_sparse[-1])
        labels = np.sort(os.listdir(os.path.join(self.label_root, seq_name)))

        self._prepare_result_folder(seq_name, labels)

        seq_dataset = VOSTest(self.image_root,
                              self.label_root,
                              seq_name,
                              images,
                              labels,
                              transform=self.transform,
                              rgb=self.rgb)
        seq_dataset.images_sparse = images_sparse

        return seq_dataset

    @staticmethod
    def _frame_range(images, start_img, end_img):
        """Slice ``images`` from the first entry containing ``start_img``
        through the last entry containing ``end_img`` (substring match).

        Raises:
            ValueError: if either frame name is not found in ``images``.
        """
        start_idx = next(
            (i for i, name in enumerate(images) if start_img in name), None)
        end_idx = next(
            (i for i in range(len(images) - 1, -1, -1)
             if end_img in images[i]), None)
        if start_idx is None or end_idx is None:
            raise ValueError(
                'Annotated frames {} / {} not found among dense frames.'.format(
                    start_img, end_img))
        return images[start_idx:(end_idx + 1)]

    def _prepare_result_folder(self, seq_name, labels):
        """Best effort: copy the first label into ``result_root/seq_name``.

        Failures (including a missing ``result_root`` or empty ``labels``)
        are printed and ignored.
        """
        try:
            seq_result_folder = os.path.join(self.result_root, seq_name)
            result_label_path = os.path.join(seq_result_folder, labels[0])
            if not os.path.isfile(result_label_path):
                os.makedirs(seq_result_folder, exist_ok=True)
                shutil.copy(
                    os.path.join(self.label_root, seq_name, labels[0]),
                    result_label_path)
        except Exception as inst:
            print(inst)
            print('Failed to create a result folder for sequence {}.'.format(
                seq_name))

    def _check_preprocess(self):
        """Load ``meta.json`` into ``self.ann_f``; return False if missing."""
        if not os.path.isfile(self.seq_list_file):
            return False
        with open(self.seq_list_file, 'r') as f:
            self.ann_f = json.load(f)['videos']
        return True


class DAVIS_Test(object):
    """Sequence-level loader for DAVIS evaluation.

    Sequence names are read from ``<root>/ImageSets/<year>/<split>.txt``
    (one per line, blank lines ignored, duplicates merged and sorted).
    Indexing returns a ``VOSTest`` over all frames of the sequence, with the
    first frame's label as the only label.

    Args:
        split: split name or iterable of split names; ``'test'`` is mapped
            to ``'test-dev'``.
        root: DAVIS root folder.
        year: ImageSets year; 2016 is treated as single-object, and its
            labels are binarized.
        transform: optional callable forwarded to every ``VOSTest``.
        rgb: forwarded to every ``VOSTest``.
        full_resolution: read from ``Full-Resolution`` instead of ``480p``.
        result_root: output folder that is seeded with each sequence's
            first-frame label when the sequence is loaded.
    """

    def __init__(self,
                 split=('val',),
                 root='./DAVIS',
                 year=2017,
                 transform=None,
                 rgb=True,
                 full_resolution=False,
                 result_root=None):
        self.transform = transform
        self.rgb = rgb
        self.result_root = result_root
        self.single_obj = year == 2016
        resolution = 'Full-Resolution' if full_resolution else '480p'
        self.image_root = os.path.join(root, 'JPEGImages', resolution)
        self.label_root = os.path.join(root, 'Annotations', resolution)

        # Accept a single split name as well as a list/tuple of them;
        # iterating a bare string would otherwise yield its characters.
        if isinstance(split, str):
            split = [split]
        seq_names = []
        for spt in split:
            if spt == 'test':
                spt = 'test-dev'
            list_path = os.path.join(root, 'ImageSets', str(year),
                                     spt + '.txt')
            with open(list_path) as f:
                # Skip blank lines so they don't become a '' sequence.
                seq_names.extend(line.strip() for line in f if line.strip())
        self.seqs = sorted(set(seq_names))

    def __len__(self):
        return len(self.seqs)

    def __getitem__(self, idx):
        seq_name = self.seqs[idx]
        images = list(
            np.sort(os.listdir(os.path.join(self.image_root, seq_name))))
        # Only the first frame's label is provided to the dataset.
        labels = [os.path.splitext(images[0])[0] + '.png']

        self._prepare_result_label(seq_name, labels[0])

        seq_dataset = VOSTest(self.image_root,
                              self.label_root,
                              seq_name,
                              images,
                              labels,
                              transform=self.transform,
                              rgb=self.rgb,
                              single_obj=self.single_obj,
                              resolution=480)
        return seq_dataset

    def _prepare_result_label(self, seq_name, label_name):
        """Write ``label_name`` into ``result_root/seq_name`` if it is missing.

        For single-object data the source label is binarized and saved with
        the shared palette; otherwise the file is copied unchanged. A failure
        to create the folder is printed, and the write is still attempted.
        """
        seq_result_folder = os.path.join(self.result_root, seq_name)
        result_label_path = os.path.join(seq_result_folder, label_name)
        if os.path.isfile(result_label_path):
            return
        try:
            os.makedirs(seq_result_folder, exist_ok=True)
        except OSError as inst:
            print(inst)
            print('Failed to create a result folder for sequence {}.'.format(
                seq_name))
        source_label_path = os.path.join(self.label_root, seq_name,
                                         label_name)
        if self.single_obj:
            with Image.open(source_label_path) as src:
                label = np.array(src, dtype=np.uint8)
            label = (label > 0).astype(np.uint8)
            label = Image.fromarray(label).convert('P')
            label.putpalette(_palette)
            label.save(result_label_path)
        else:
            shutil.copy(source_label_path, result_label_path)


class _EVAL_TEST(Dataset):
    """Synthetic ten-frame sequence used to smoke-test the evaluation loop.

    Every frame is a 400x400 three-channel float32 array of zeros. Only frame
    0 carries a ``'current_label'``, which is filled with the value 2.
    """

    def __init__(self, transform, seq_name):
        self.seq_name = seq_name
        self.num_frame = 10
        self.transform = transform

    def __len__(self):
        return self.num_frame

    def __getitem__(self, idx):
        obj_count = 2
        side = 400
        sample = {'current_img': np.zeros((side, side, 3), dtype=np.float32)}
        if idx == 0:
            sample['current_label'] = np.full((side, side), obj_count,
                                              dtype=np.uint8)
        sample['meta'] = {
            'seq_name': self.seq_name,
            'frame_num': self.num_frame,
            'obj_num': obj_count,
            'current_name': 'test{}.jpg'.format(idx),
            'height': side,
            'width': side,
            'flip': False
        }
        if self.transform is None:
            return sample
        return self.transform(sample)


class EVAL_TEST(object):
    """Synthetic benchmark made of three fake sequences.

    Indexing creates ``result_root/<seq_name>`` when it does not already
    exist and returns an ``_EVAL_TEST`` dataset for that sequence.
    """

    def __init__(self, transform=None, result_root=None):
        self.transform = transform
        self.result_root = result_root
        self.seqs = ['test{}'.format(i) for i in (1, 2, 3)]

    def __len__(self):
        return len(self.seqs)

    def __getitem__(self, idx):
        seq_name = self.seqs[idx]
        seq_result_dir = os.path.join(self.result_root, seq_name)
        if not os.path.exists(seq_result_dir):
            os.makedirs(seq_result_dir)
        return _EVAL_TEST(self.transform, seq_name)