# -*- coding: utf-8 -*-

# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
# holder of all proprietary rights on this computer program.
# You can only use this computer program if you have closed
# a license agreement with MPG or you get the right to use the computer
# program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and
# liable to prosecution.
#
# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
# for Intelligent Systems. All rights reserved.
#
# Contact: ps-license@tuebingen.mpg.de

import os.path as osp
import numpy as np
from PIL import Image
import torchvision.transforms as transforms


class NormalDataset():
    """Dataset of rendered views used to train/evaluate a normal predictor.

    Every item corresponds to one (subject, rotation) pair. For that pair the
    dataset loads each input listed in ``cfg.net.in_nml`` plus the front/back
    normal maps (``normal_F`` / ``normal_B``) from
    ``<root>/<dataset>_12views/<subject>/<name>/<rotation>.png``.
    """

    def __init__(self, cfg, split='train'):
        """Read the configuration, index the subjects and build transforms.

        Args:
            cfg: configuration object exposing ``root``, ``overfit``,
                ``dataset`` (with ``types``, ``input_size``, ``set_splits``,
                ``scales``, ``pifu`` and, for training, ``rotation_num``) and
                ``net.in_nml`` (a sequence of ``(name, channels)`` pairs).
            split: name of the split to load; ``'train'`` uses
                ``rotation_num`` evenly spaced views, any other split uses
                the three views 0/120/240.
        """

        self.split = split
        self.root = cfg.root
        self.overfit = cfg.overfit

        self.opt = cfg.dataset
        self.datasets = self.opt.types
        self.input_size = self.opt.input_size
        self.set_splits = self.opt.set_splits
        self.scales = self.opt.scales
        self.pifu = self.opt.pifu

        # input data types and dimensions
        self.in_nml = [item[0] for item in cfg.net.in_nml]
        self.in_nml_dim = [item[1] for item in cfg.net.in_nml]
        self.in_total = self.in_nml + ['normal_F', 'normal_B']
        self.in_total_dim = self.in_nml_dim + [3, 3]

        if self.split != 'train':
            self.rotations = range(0, 360, 120)
        else:
            self.rotations = self._train_rotations(self.opt.rotation_num)

        self.datasets_dict = {}
        for dataset_id, dataset in enumerate(self.datasets):
            dataset_dir = osp.join(self.root, dataset, "smplx")
            self.datasets_dict[dataset] = {
                # atleast_1d: a one-line all.txt would otherwise load as a
                # 0-d array, which has no len() and cannot be sliced.
                "subjects":
                np.atleast_1d(
                    np.loadtxt(osp.join(self.root, dataset, "all.txt"),
                               dtype=str)),
                "path":
                dataset_dir,
                "scale":
                self.scales[dataset_id]
            }

        self.subject_list = self.get_subject_list(split)

        # PIL to tensor, RGB values mapped from [0, 1] to [-1, 1]
        self.image_to_tensor = transforms.Compose([
            transforms.Resize(self.input_size),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

        # PIL to tensor, mask values left in [0, 1] (identity normalisation)
        self.mask_to_tensor = transforms.Compose([
            transforms.Resize(self.input_size),
            transforms.ToTensor(),
            transforms.Normalize((0.0, ), (1.0, ))
        ])

    @staticmethod
    def _train_rotations(rotation_num):
        """Return ``rotation_num`` evenly spaced integer angles in [0, 360)."""
        # The builtin ``int`` replaces ``np.int``, an alias that was removed
        # from NumPy in 1.24.
        return np.arange(0, 360, 360 / rotation_num).astype(int)

    @staticmethod
    def _read_name_list(path):
        """Load a one-name-per-line text file as a sorted list of strings."""
        # np.loadtxt yields a 0-d array for a single-line file; without
        # atleast_1d, tolist() would return a bare str and sorted() would
        # split it into characters.
        return sorted(np.atleast_1d(np.loadtxt(path, dtype=str)).tolist())

    def get_subject_list(self, split):
        """Collect the subject names of ``split`` over all datasets.

        For every dataset the list is read from ``<split>.txt`` (or
        ``<split>_pifu.txt`` when ``self.pifu`` is set). If that file does
        not exist, ``train.txt``, ``val.txt`` and ``test.txt`` are generated
        from the dataset's full subject list using ``self.set_splits`` and
        ``<split>.txt`` is read afterwards. Subjects listed in
        ``<root>/bug.txt`` are removed from the result.

        Args:
            split: name of the split to load.

        Returns:
            List of subject names, sorted within each dataset.
        """

        subject_list = []

        for dataset in self.datasets:

            if self.pifu:
                txt = osp.join(self.root, dataset, f'{split}_pifu.txt')
            else:
                txt = osp.join(self.root, dataset, f'{split}.txt')

            if osp.exists(txt):
                print(f"load from {txt}")
                names = self._read_name_list(txt)

                if self.pifu:
                    miss_pifu = set(
                        self._read_name_list(
                            osp.join(self.root, dataset, "miss_pifu.txt")))
                    # Filter and prefix only the names loaded for this
                    # dataset, so subjects accumulated from earlier datasets
                    # are not prefixed a second time.
                    names = [
                        "renderpeople/" + subject for subject in names
                        if subject not in miss_pifu
                    ]

                subject_list += names

            else:
                train_txt = osp.join(self.root, dataset, 'train.txt')
                val_txt = osp.join(self.root, dataset, 'val.txt')
                test_txt = osp.join(self.root, dataset, 'test.txt')

                print(
                    f"generate lists of [train, val, test] \n {train_txt} \n {val_txt} \n {test_txt} \n"
                )

                split_txt = osp.join(self.root, dataset, f'{split}.txt')

                subjects = self.datasets_dict[dataset]['subjects']
                train_split = int(len(subjects) * self.set_splits[0])
                val_split = int(
                    len(subjects) * self.set_splits[1]) + train_split

                # Consecutive slices of the full list: train | val | test.
                chunks = (
                    (train_txt, subjects[:train_split]),
                    (val_txt, subjects[train_split:val_split]),
                    (test_txt, subjects[val_split:]),
                )
                for chunk_txt, chunk in chunks:
                    with open(chunk_txt, "w") as f:
                        f.write("\n".join(dataset + "/" + item
                                          for item in chunk))

                subject_list += self._read_name_list(split_txt)

        # A set gives O(1) membership tests while filtering.
        bug_set = set(self._read_name_list(osp.join(self.root, 'bug.txt')))

        return [subject for subject in subject_list if subject not in bug_set]

    def __len__(self):
        """Return the number of (subject, rotation) pairs."""
        return len(self.subject_list) * len(self.rotations)

    def __getitem__(self, index):
        """Load all configured inputs for one (subject, rotation) pair.

        Args:
            index: flat index; rotations vary fastest. When ``self.overfit``
                is set the index is ignored and the first pair is returned.

        Returns:
            dict with ``'dataset'``, ``'subject'``, ``'rotation'`` and one
            tensor per entry of ``self.in_total``.
        """

        # only pick the first data if overfitting
        if self.overfit:
            index = 0

        num_rotations = len(self.rotations)
        rid = index % num_rotations
        mid = index // num_rotations

        rotation = self.rotations[rid]

        # choose specific test sets
        subject = self.subject_list[mid]

        # subjects are stored as "<dataset>/<name>"; the renders live in the
        # sibling "<dataset>_12views" folder
        parts = subject.split("/")
        dataset_name = parts[0]
        subject_render = "/".join([dataset_name + "_12views", parts[1]])

        # setup paths
        data_dict = {
            'dataset':
            dataset_name,
            'subject':
            subject,
            'rotation':
            rotation,
            'image_path':
            osp.join(self.root, subject_render, 'render',
                     f'{rotation:03d}.png')
        }

        # image/normal/depth loader
        for name, channel in zip(self.in_total, self.in_total_dim):

            # 'image_path' is already set above; every other input lives in
            # a folder named after the input itself
            if name != 'image':
                data_dict[f'{name}_path'] = osp.join(self.root,
                                                     subject_render, name,
                                                     f'{rotation:03d}.png')
            data_dict[name] = self.imagepath2tensor(data_dict[f'{name}_path'],
                                                    channel,
                                                    inv='depth_B' in name)

        # only tensors and metadata are returned, drop the helper paths
        path_keys = [
            key for key in data_dict if '_path' in key or '_dir' in key
        ]
        for key in path_keys:
            del data_dict[key]

        return data_dict

    def imagepath2tensor(self, path, channel=3, inv=False):
        """Load an image file as a masked, normalised float tensor.

        The image is read as RGBA; its alpha channel is used as a mask that
        is multiplied onto the normalised RGB tensor.

        Args:
            path: path of the image file.
            channel: number of leading channels to keep.
            inv: if true, the sign of the result is flipped.

        Returns:
            Float tensor holding the first ``channel`` channels.
        """

        rgba = Image.open(path).convert('RGBA')
        mask = self.mask_to_tensor(rgba.split()[-1])
        image = self.image_to_tensor(rgba.convert('RGB'))
        image = (image * mask)[:channel]

        # Equivalent to the former ``(0.5 - inv) * 2.0`` factor, without
        # relying on bool arithmetic.
        sign = -1.0 if inv else 1.0

        return (image * sign).float()