# Copyright 2024 The YourMT3 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Please see the details in the LICENSE file.
""" tokenizer.py: Encodes and decodes events to/from tokens. """
import numpy as np
import warnings
from abc import ABC, abstractmethod
from utils.note_event_dataclasses import Event, EventRange, Note  #, Codec
from utils.event_codec import FastCodec as Codec
from utils.note_event_dataclasses import NoteEvent
from utils.note2event import note_event2event
from utils.event2note import event2note_event, note_event2note
from typing import List, Optional, Union, Tuple, Dict, Counter


#TODO: Too complex to be an abstract class.
class EventTokenizerBase(ABC):
    """
    A base class for encoding and decoding events to and from tokens.
    """

    def __init__(
        self,
        base_codec: Union[Codec, str] = 'mt3',
        special_tokens: List[str] = ['PAD', 'EOS', 'UNK'],
        extra_tokens: List[str] = [],
        max_shift_steps: int = 206,  # 1001 in Gardner et al.
        program_vocabulary: Optional[Dict] = None,
        drum_vocabulary: Optional[Dict] = None,
    ) -> None:
        """
        Initializes the EventTokenizerBase object.

        :param base_codec: The codec to use for encoding and decoding.
        :param special_tokens: None or list of special tokens to include in the vocabulary.
        :param extra_tokens: None or list of tokens to be treated as additional special tokens.
        :param program_vocabulary: None or a dictionary mapping program names to program indices.
        :param drum_vocabulary: None or a dictionary mapping drum names to drum indices.
        :param max_shift_steps: The maximum number of shift steps to use for the codec.
        """
        # Initialize the codec attribute based on the input codec parameter.
        if isinstance(base_codec, str):
            # If codec is a string, initialize codec with the appropriate Codec object.
            if base_codec.lower() == 'mt3':
                event_ranges = [
                    EventRange('pitch', min_value=0, max_value=127),
                    EventRange('velocity', min_value=0, max_value=1),
                    EventRange('tie', min_value=0, max_value=0),
                    EventRange('program', min_value=0, max_value=127),
                    EventRange('drum', min_value=0, max_value=127),
                ]
            else:
                raise ValueError(f'Unknown codec name: {base_codec}')

            # Initialize codec
            self.codec = Codec(special_tokens=special_tokens + extra_tokens,
                               max_shift_steps=max_shift_steps,
                               event_ranges=event_ranges,
                               program_vocabulary=program_vocabulary,
                               drum_vocabulary=drum_vocabulary,
                               name='mt3')

        elif isinstance(base_codec, Codec):
            # If codec is a Codec object, store it directly.
            self.codec = base_codec
            if program_vocabulary is not None or drum_vocabulary is not None:
                warnings.warn("Vocabulary arguments are ignored when a custom Codec instance is provided.")
        else:
            # If codec is neither a string nor a Codec object, raise a TypeError.
            raise TypeError(f'Unknown codec type: {type(base_codec)}')
        self.num_tokens = self.codec._num_classes

    def _encode(self, events: List[Event]) -> List[int]:
        return [self.codec.encode_event(e) for e in events]

    def _decode(self, tokens: List[int]) -> List[Event]:
        return [self.codec.decode_event_index(idx) for idx in tokens]

    @abstractmethod
    def encode(self):
        """ Encode your custom events to tokens. """
        pass

    @abstractmethod
    def decode(self):
        """ Decode your custom tokens to events."""
        pass


class EventTokenizer(EventTokenizerBase):
    """
    Encodes and decodes events to and from tokens.
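
    Example (a minimal sketch; the `Event` values are illustrative and assume
    the default 'mt3' base codec):

        >>> tokenizer = EventTokenizer('mt3')
        >>> tokens = tokenizer.encode([Event('program', 33), Event('pitch', 60)])
        >>> events = tokenizer.decode(tokens)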
    """

    def __init__(self,
                 base_codec: Union[Codec, str] = 'mt3',
                 special_tokens: List[str] = ['PAD', 'EOS', 'UNK'],
                 extra_tokens: List[str] = [],
                 max_shift_steps: int = 206,
                 program_vocabulary: Optional[Dict] = None,
                 drum_vocabulary: Optional[Dict] = None) -> None:
        """
        Initializes the EventTokenizer object.

        :param base_codec: The codec to use for encoding and decoding.
        :param special_tokens: None or list of special tokens to include in the vocabulary.
        :param extra_tokens: None or list of tokens to be treated as additional special tokens.
        :param program_vocabulary: None or a dictionary mapping program names to program indices.
        :param drum_vocabulary: None or a dictionary mapping drum names to drum indices.
        :param max_shift_steps: The maximum number of shift steps to use for the codec.
        """
        # Initialize the codec attribute based on the input codec parameter.
        super().__init__(
            base_codec=base_codec,
            special_tokens=special_tokens,
            extra_tokens=extra_tokens,
            max_shift_steps=max_shift_steps,
            program_vocabulary=program_vocabulary,
            drum_vocabulary=drum_vocabulary,
        )

    def encode(self, events):
        """ Encode your custom events to tokens. """
        return super()._encode(events)

    def decode(self, tokens):
        """ Decode your custom tokens to events."""
        return super()._decode(tokens)


class NoteEventTokenizer(EventTokenizerBase):
    """ Encodes and decodes note events to/from tokens. """

    def __init__(
            self,
            base_codec: Union[Codec, str] = 'mt3',
            max_length: int = 1024,  # max length of tokens 
            tps: int = 100,
            sort_note_event: bool = True,
            special_tokens: List[str] = ['PAD', 'EOS', 'UNK'],
            extra_tokens: List[str] = [],
            max_shift_steps: int = 206,
            program_vocabulary: Optional[Dict] = None,
            drum_vocabulary: Optional[Dict] = None,
            ignore_decoding_tokens: List[str] = [],
            ignore_decoding_tokens_from_and_to: Optional[List[str]] = None,
            debug_mode: bool = False) -> None:
        """
        Initializes the NoteEventTokenizer object.

        encode: List[NoteEvent] -> List[int]
        decode: List[int] -> Tuple[List[NoteEvent], List[NoteEvent], ...]

        :param base_codec: The codec to use for encoding and decoding.
        :param max_length: The maximum number of tokens in an encoded sequence.
        :param tps: Time steps per second used for shift events.
        :param sort_note_event: Whether to sort note events before encoding.
        :param special_tokens: None or list of special tokens to include in the vocabulary.
        :param extra_tokens: None or list of tokens to be treated as additional special tokens.
        :param program_vocabulary: None or a dictionary mapping program names to program indices.
        :param drum_vocabulary: None or a dictionary mapping drum names to drum indices.
        :param max_shift_steps: The maximum number of shift steps to use for the codec.

        :param ignore_decoding_tokens: List of tokens to ignore during decoding.
        :param ignore_decoding_tokens_from_and_to: None or a [from, to] pair of tokens delimiting a span to ignore during decoding.
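
        Example (a sketch; `note_events` is assumed to be a List[NoteEvent]
        built elsewhere, e.g. by utils.note2event):

            >>> tokenizer = NoteEventTokenizer('mt3', max_length=1024)
            >>> tokens = tokenizer.encode_plus(note_events, pad_to_max_length=True)
            >>> note_events, tie_events, last_activity, err_cnt = tokenizer.decode(tokens)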
        """
        super().__init__(base_codec=base_codec,
                         special_tokens=special_tokens,
                         extra_tokens=extra_tokens,
                         max_shift_steps=max_shift_steps,
                         program_vocabulary=program_vocabulary,
                         drum_vocabulary=drum_vocabulary)
        self.max_length = max_length
        self.tps = tps
        self.sort = sort_note_event

        # Prepare prefix, suffix and pad tokens.
        self._prefix = []
        self._suffix = []
        self._zero_pad = []
        for stk in self.codec.special_tokens:
            if stk == 'EOS':
                self._suffix.append(self.codec.special_tokens.index('EOS'))
            elif stk == 'PAD':
                # Pad buffer uses the actual PAD id and is sized to max_length;
                # encode_plus() resizes it on demand.
                self._zero_pad = [self.codec.special_tokens.index('PAD')] * max_length
            elif stk == 'UNK':
                pass
            else:
                pass
                # raise NotImplementedError(f'Unknown special token: {stk}')
        self.eos_id = self.codec.special_tokens.index('EOS')
        self.pad_id = self.codec.special_tokens.index('PAD')
        self._zero_pad_task = []  # pad buffer for encode_task(), resized on demand
        self.ids_to_ignore_decoding = [self.codec.special_tokens.index(t) for t in ignore_decoding_tokens]
        self.ignore_tokens_from_and_to = ignore_decoding_tokens_from_and_to
        self.debug_mode = debug_mode

    def _decode(self, tokens):
        # This is the event detokenizer (not note_event); it is used for
        # displaying events in the validation dashboard.
        return super()._decode(tokens)

    def encode(
        self,
        note_events: List[NoteEvent],
        tie_note_events: Optional[List[NoteEvent]] = None,
        start_time: float = 0.,
    ) -> List[int]:
        """ Encodes note events and tie note events to tokens. """
        events = note_event2event(
            note_events=note_events,
            tie_note_events=tie_note_events,
            start_time=start_time,  # required for calculating relative time
            tps=self.tps,
            sort=self.sort)
        return super()._encode(events)

    def encode_plus(
            self,
            note_events: List[NoteEvent],
            tie_note_events: Optional[List[NoteEvent]] = None,
            start_times: float = 0.,  # Fixing bug: start_time --> start_times 
            add_special_tokens: Optional[bool] = True,
            max_length: Optional[int] = None,  #  if None, use self.max_length
            pad_to_max_length: Optional[bool] = True,
            return_attention_mask: bool = False) -> Union[List[int], Tuple[List[int], List[int]]]:
        """ Encodes note events and tie note info to padded tokens. """
        encoded = self.encode(note_events, tie_note_events, start_times)

        # if task_events:
        #     encoded = super()._encode(task_events) + encoded
        if add_special_tokens:
            if self._prefix:
                encoded = self._prefix + encoded
            if self._suffix:
                encoded = encoded + self._suffix

        if max_length is None:
            max_length = self.max_length

        length = len(encoded)
        if length >= max_length:
            encoded = encoded[:max_length]
            length = max_length

        if return_attention_mask:
            attention_mask = [1] * length

        # <PAD>
        if pad_to_max_length is True:
            if len(self._zero_pad) != max_length:
                self._zero_pad = [self.pad_id] * max_length
            if return_attention_mask:
                # The mask is 0 over padding (not pad_id, which may be non-zero).
                attention_mask += [0] * (max_length - length)
            encoded = encoded + self._zero_pad[length:]

        if return_attention_mask:
            return encoded, attention_mask

        return encoded
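
    # A usage sketch for encode_plus() with an attention mask (illustrative;
    # `note_events` is assumed to be a List[NoteEvent] built elsewhere):
    #
    #   tokens, mask = tokenizer.encode_plus(note_events, return_attention_mask=True)
    #   # `mask` is 1 over real tokens and 0 over padding, aligned with `tokens`.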

    def encode_task(self, task_events: List[Event], max_length: Optional[int] = None) -> List[int]:
        # NOTE: This encodes task Event objects directly to token ids; it does
        # not operate on NoteEvent objects.
        encoded = super()._encode(task_events)

        # <PAD>
        if max_length is not None:
            if len(self._zero_pad_task) != max_length:
                self._zero_pad_task = [self.pad_id] * max_length
            length = len(encoded)
            encoded = encoded + self._zero_pad_task[length:]

        return encoded

    def decode(
        self,
        tokens: List[int],
        start_time: float = 0.,
        return_events: bool = False,
    ) -> Union[Tuple[List[NoteEvent], List[NoteEvent], List[Tuple[int]], int],
               Tuple[List[NoteEvent], List[NoteEvent], List[Tuple[int]], List[Event], int]]:
        """Decodes a sequence of tokens into note events.

        Args:
            tokens (List[int]): The list of tokens to be decoded.
            start_time (float, optional): The starting time for the note events. Defaults to 0.
            return_events (bool, optional): Indicates whether to include the raw events in the return value.
                                            Defaults to False.

        Returns:
            If `return_events` is False, a tuple of (`note_events`, `tie_note_events`,
            `last_activity`, `err_cnt`).
            If `return_events` is True, a tuple of (`note_events`, `tie_note_events`,
            `last_activity`, `events`, `err_cnt`).
        """
        if self.debug_mode:
            ignored_tokens_from_input = [t for t in tokens if t in self.ids_to_ignore_decoding]
            print(ignored_tokens_from_input)

        if self.ids_to_ignore_decoding:
            tokens = [t for t in tokens if t not in self.ids_to_ignore_decoding]

        events = super()._decode(tokens)
        note_events, tie_note_events, last_activity, err_cnt = event2note_event(events, start_time, True, self.tps)
        if return_events:
            return note_events, tie_note_events, last_activity, events, err_cnt
        else:
            return note_events, tie_note_events, last_activity, err_cnt

    def decode_batch(
        self,
        batch_tokens: Union[List[List[int]], np.ndarray],
        start_times: List[float],
        return_events: bool = False
    ) -> Union[Tuple[List[Tuple[List[NoteEvent], List[NoteEvent], List[Tuple[int]], float]], int],
               Tuple[List[Tuple[List[NoteEvent], List[NoteEvent], List[Tuple[int]], float]], List[List[Event]],
                     int]]:
        """ 
        Decodes a batch of tokens to note_events and tie_note_events.

        Args:
            batch_tokens (List[List[int]] or np.ndarray): Tokens to be decoded.
            start_times (List[float]): List of start times for each token set.
            return_events (bool, optional): Flag to determine if events should be returned. Defaults to False.
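
        Example (sketch; token arrays and start times are illustrative):
            >>> zipped, total_err_cnt = tokenizer.decode_batch(batch_tokens, start_times=[0.0, 2.0])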

        """
        if isinstance(batch_tokens, np.ndarray):
            batch_tokens = batch_tokens.tolist()

        if len(batch_tokens) != len(start_times):
            raise ValueError('The length of batch_tokens and start_times must be the same.')

        zipped_note_events_and_tie = []
        list_events = []
        total_err_cnt = 0

        for tokens, start_time in zip(batch_tokens, start_times):
            if return_events:
                note_events, tie_note_events, last_activity, events, err_cnt = self.decode(
                    tokens, start_time, return_events)
                list_events.append(events)
            else:
                note_events, tie_note_events, last_activity, err_cnt = self.decode(tokens, start_time, return_events)

            zipped_note_events_and_tie.append((note_events, tie_note_events, last_activity, start_time))
            total_err_cnt += err_cnt

        if return_events:
            return zipped_note_events_and_tie, list_events, total_err_cnt
        else:
            return zipped_note_events_and_tie, total_err_cnt

    def decode_list_batches(
        self,
        list_batch_tokens: Union[List[List[List[int]]], List[np.ndarray]],
        list_start_times: Union[List[List[float]], List[float]],
        return_events: bool = False
    ) -> Union[Tuple[List[Tuple[List[NoteEvent], List[NoteEvent], List[Tuple[int]], float]], Counter[str]],
               Tuple[List[Tuple[List[NoteEvent], List[NoteEvent], List[Tuple[int]], float]],
                     List[List[Event]], Counter[str]]]:
        """ 
        Decodes a list of variable-size batches of token arrays into a flat
        list of zipped note_events and tie_note_events.

        Args:
            list_batch_tokens: List[np.ndarray], where each array has shape (batch_size, variable_length)
            list_start_times: List[float], whose length is the sum of all batch sizes.
            return_events: bool, Defaults to False.

        Returns:
            zipped_note_events_and_tie:
                List[
                    Tuple[
                        List[NoteEvent]: A list of note events.
                        List[NoteEvent]: A list of tie note events.
                        List[Tuple[int]]: The last activity of the segment, as [(program, pitch), ...]. This is
                            useful for validating notes across segments extracted from a file.
                        float: The segment start time.
                    ]
                ]
            (Optional) list_events:
                List[List[Event]]
            total_err_cnt:
                Counter[str]: error counter.
        """
        list_tokens = []
        for arr in list_batch_tokens:
            for tokens in arr:
                list_tokens.append(tokens)
        assert len(list_tokens) == len(list_start_times)

        zipped_note_events_and_tie = []
        list_events = []
        total_err_cnt = Counter()
        for tokens, start_time in zip(list_tokens, list_start_times):
            # Always request events from decode() so the 5-tuple unpacking is
            # valid; events are appended below only when return_events is True.
            note_events, tie_note_events, last_activity, events, err_cnt = self.decode(
                tokens, start_time, return_events=True)
            zipped_note_events_and_tie.append((note_events, tie_note_events, last_activity, start_time))
            if return_events:
                list_events.append(events)
            total_err_cnt += err_cnt

        if return_events:
            return zipped_note_events_and_tie, list_events, total_err_cnt
        else:
            return zipped_note_events_and_tie, total_err_cnt
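

if __name__ == '__main__':
    # A minimal round-trip sketch (illustrative only, not part of the library
    # API); assumes the default 'mt3' base codec and that Event takes
    # (type, value) as in utils.note_event_dataclasses.
    tokenizer = EventTokenizer('mt3')
    demo_events = [Event('program', 33), Event('pitch', 60), Event('velocity', 1)]
    demo_tokens = tokenizer.encode(demo_events)
    print('tokens:', demo_tokens)
    print('events:', tokenizer.decode(demo_tokens))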