# Copyright 2024 The YourMT3 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Please see the details in the LICENSE file.
"""event2note.py:

Event to NoteEvent:
• event2note_event

NoteEvent to Note:
• note_event2note
• merge_zipped_note_events_and_ties_to_notes

"""
import warnings
from collections import Counter
from typing import List, Tuple, Optional, Dict, Counter

from utils.note_event_dataclasses import Note, NoteEvent
from utils.note_event_dataclasses import Event
from utils.note2event import validate_notes, trim_overlapping_notes

MINIMUM_OFFSET_SEC = 0.01

DECODING_ERR_TYPES = [
    'decoding_time', 'Err/Missing prg in tie', 'Err/Missing tie', 'Err/Shift out of range', 'Err/Missing prg',
    'Err/Missing vel', 'Err/Multi-tie type 1', 'Err/Multi-tie type 2', 'Err/Unknown event', 'Err/onset not found',
    'Err/active ne incomplete', 'Err/merging segment tie', 'Err/long note > 10s'
]


def event2note_event(events: List[Event],
                     start_time: float = 0.0,
                     sort: bool = True,
                     tps: int = 100) -> Tuple[List[NoteEvent], List[NoteEvent], List[Tuple[int]], Counter[str]]:
    """Convert events to note events.

    Args:
        events: A list of events.
        start_time: The start time of the segment.
        sort: Whether to sort the note events.
        tps: Ticks per second.

    Returns:
        List[NoteEvent]: A list of note events.
        List[NoteEvent]: A list of tie note events.
        List[Tuple[int]]: A list of last activity of segment. [(program, pitch), ...]. This is useful
            for validating notes within a batch of segments extracted from a file.
        Counter[str]: A dictionary of error counters.
    """
    assert (start_time >= 0.)

    # Collect tie events
    tie_index = program_state = None
    tie_note_events = []
    last_activity = []  # For activity check and last activity of segment. [(program, pitch), ...]
    error_counter = {}  # Add a dictionary to count the errors by their types

    for i, e in enumerate(events):
        try:
            if e.type == 'tie':
                tie_index = i
                break
            if e.type == 'shift':
                break
            elif e.type == 'program':
                program_state = e.value
            elif e.type == 'pitch':
                if program_state is None:
                    raise ValueError('Err/Missing prg in tie')
                tie_note_events.append(
                    NoteEvent(is_drum=False, program=program_state, time=None, velocity=1, pitch=e.value))
                last_activity.append((program_state, e.value))  # (program, pitch)
        except ValueError as ve:
            error_type = str(ve)
            error_counter[error_type] = error_counter.get(error_type, 0.) + 1

    try:
        if tie_index is None:
            raise ValueError('Err/Missing tie')
        else:
            events = events[tie_index + 1:]
    except ValueError as ve:
        error_type = str(ve)
        error_counter[error_type] = error_counter.get(error_type, 0.) + 1
        return [], [], [], error_counter

    # Collect main events:
    note_events = []
    velocity_state = None
    start_tick = round(start_time * tps)
    tick_state = start_tick
    # keep the program_state of last tie event...

    for e in events:
        try:
            if e.type == 'shift':
                if e.value <= 0 or e.value > 1000:
                    raise ValueError('Err/Shift out of range')
                # tick_state += e.value
                tick_state = start_tick + e.value
            elif e.type == 'drum':
                note_events.append(
                    NoteEvent(is_drum=True, program=128, time=tick_state / tps, velocity=1, pitch=e.value))
            elif e.type == 'program':
                program_state = e.value
            elif e.type == 'velocity':
                velocity_state = e.value
            elif e.type == 'pitch':
                if program_state is None:
                    raise ValueError('Err/Missing prg')
                elif velocity_state is None:
                    raise ValueError('Err/Missing vel')
                # Check activity
                if velocity_state > 0:
                    last_activity.append((program_state, e.value))  # (program, pitch)
                elif velocity_state == 0 and (program_state, e.value) in last_activity:
                    last_activity.remove((program_state, e.value))
                else:
                    # print(f'tick_state: {tick_state}') # <-- This displays unresolved offset errors!!
                    raise ValueError('Err/Note off without note on')
                note_events.append(
                    NoteEvent(is_drum=False,
                              program=program_state,
                              time=tick_state / tps,
                              velocity=velocity_state,
                              pitch=e.value))
            elif e.type == 'EOS':
                break
            elif e.type == 'PAD':
                continue
            elif e.type == 'UNK':
                continue
            elif e.type == 'tie':
                if tick_state == start_tick:
                    raise ValueError('Err/Multi-tie type 1')
                else:
                    raise ValueError('Err/Multi-tie type 2')
            else:
                raise ValueError(f'Err/Unknown event')
        except ValueError as ve:
            error_type = str(ve)
            error_counter[error_type] = error_counter.get(error_type, 0.) + 1

    if sort:
        note_events.sort(key=lambda n_ev: (n_ev.time, n_ev.is_drum, n_ev.program, n_ev.velocity, n_ev.pitch))
        tie_note_events.sort(key=lambda n_ev: (n_ev.is_drum, n_ev.program, n_ev.pitch))

    return note_events, tie_note_events, last_activity, error_counter


def note_event2note(
    note_events: List[NoteEvent],
    tie_note_events: Optional[List[NoteEvent]] = None,
    sort: bool = True,
    fix_offset: bool = True,
    trim_overlap: bool = True,
) -> Tuple[List[Note], Counter[str]]:
    """Convert note events to notes.

    Returns:
        List[Note]: A list of merged note events.
        Counter[str]: A dictionary of error counters.
    """

    notes = []
    active_note_events = {}

    error_counter = {}  # Add a dictionary to count the errors by their types

    if tie_note_events is not None:
        for ne in tie_note_events:
            active_note_events[(ne.pitch, ne.program)] = ne

    if sort:
        note_events.sort(key=lambda ne: (ne.time, ne.is_drum, ne.pitch, ne.velocity, ne.program))

    for ne in note_events:
        try:
            if ne.time == None:
                continue
            elif ne.is_drum:
                if ne.velocity == 1:
                    notes.append(
                        Note(is_drum=True,
                             program=128,
                             onset=ne.time,
                             offset=ne.time + MINIMUM_OFFSET_SEC,
                             pitch=ne.pitch,
                             velocity=1))
                else:
                    continue
            elif ne.velocity == 1:
                active_ne = active_note_events.get((ne.pitch, ne.program))
                if active_ne is not None:
                    active_note_events.pop((ne.pitch, ne.program))
                    notes.append(
                        Note(False, active_ne.program, active_ne.time, ne.time, active_ne.pitch, active_ne.velocity))
                active_note_events[(ne.pitch, ne.program)] = ne

            elif ne.velocity == 0:
                active_ne = active_note_events.pop((ne.pitch, ne.program), None)
                if active_ne is not None:
                    notes.append(
                        Note(False, active_ne.program, active_ne.time, ne.time, active_ne.pitch, active_ne.velocity))
                else:
                    raise ValueError('Err/onset not found')
        except ValueError as ve:
            error_type = str(ve)
            error_counter[error_type] = error_counter.get(error_type, 0.) + 1

    for ne in active_note_events.values():
        try:
            if ne.velocity == 1:
                if ne.program == None or ne.pitch == None:
                    raise ValueError('Err/active ne incomplete')
                elif ne.time == None:
                    continue
                else:
                    notes.append(
                        Note(is_drum=False,
                             program=ne.program,
                             onset=ne.time,
                             offset=ne.time + MINIMUM_OFFSET_SEC,
                             pitch=ne.pitch,
                             velocity=1))
        except ValueError as ve:
            error_type = str(ve)
            error_counter[error_type] = error_counter.get(error_type, 0.) + 1

    if fix_offset:
        for n in list(notes):
            try:
                if n.offset - n.onset > 10:
                    n.offset = n.onset + MINIMUM_OFFSET_SEC
                    raise ValueError('Err/long note > 10s')
            except ValueError as ve:
                error_type = str(ve)
                error_counter[error_type] = error_counter.get(error_type, 0.) + 1

    if sort:
        notes.sort(key=lambda note: (note.onset, note.is_drum, note.program, note.velocity, note.pitch))

    if fix_offset:
        notes = validate_notes(notes, fix=True)

    if trim_overlap:
        notes = trim_overlapping_notes(notes, sort=True)

    return notes, error_counter


def merge_zipped_note_events_and_ties_to_notes(zipped_note_events_and_ties,
                                               force_note_off_missing_tie=True,
                                               fix_offset=True) -> Tuple[List[Note], Counter[str]]:
    """Merge zipped note events and ties.
    
    Args:
        zipped_note_events_and_ties: A list of tuples of (note events, tie note events, last_activity, start time).
        force_note_off_missing_tie: Whether to force note off for missing tie note events.
        fix_offset: Whether to fix the offset of notes.

    Returns:
        List[Note]: A list of merged note events.
        Counter[str]: A dictionary of error counters.
    """
    merged_note_events = []
    prev_last_activity = None
    seg_merge_err_cnt = Counter()
    for nes, tie_nes, last_activity, start_time in zipped_note_events_and_ties:
        if prev_last_activity is not None and force_note_off_missing_tie:
            # Check mismatch between prev_last_activity and current tie_note_events
            prog_pitch_tie = set([(ne.program, ne.pitch) for ne in tie_nes])
            for prog_pitch_pla in prev_last_activity:  # (program, pitch) of previous last active notes
                if prog_pitch_pla not in prog_pitch_tie:
                    # last acitve notes of previous segment is missing in tie information.
                    # We create a note off event for these notes at the beginning of current note events.
                    merged_note_events.append(
                        NoteEvent(is_drum=False,
                                  program=prog_pitch_pla[0],
                                  time=start_time,
                                  velocity=0,
                                  pitch=prog_pitch_pla[1]))
                    seg_merge_err_cnt['Err/merging segment tie'] += 1
            else:
                pass
        merged_note_events += nes
        prev_last_activity = last_activity

    # merged_note_events to notes
    notes, err_cnt = note_event2note(merged_note_events, tie_note_events=None, fix_offset=fix_offset)

    # gather error counts
    err_cnt.update(seg_merge_err_cnt)
    return notes, err_cnt