# Copyright 2024 The YourMT3 Authors.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Please see the details in the LICENSE file.

"""MIDI <-> Note
• midi2note: convert a MIDI file to a list of Note instances.
• note_event2midi: convert a list of NoteEvent instances to a MIDI file.
"""

import os
import copy
import warnings
import numpy as np
from typing import List, Dict, Optional, Tuple, Union
from mido import MetaMessage, Message, MidiFile, MidiTrack, second2tick
from utils.note_event_dataclasses import Note, NoteEvent
from utils.note2event import validate_notes, trim_overlapping_notes
from utils.note2event import note2note_event
""" midi2note:
Convert a MIDI file to a list of Note instances.

About the new implementation:

 The widely used MIDI parsers (the implementations in pretty_midi,
onsets_and_frames, ReconVAT, and mirdata) apply each offset to the nearest
previous note when notes overlap.
 We often found issues with this lazy-processing approach: the overlapped notes
later in the sequence could become extremely short.

 This code has been re-implemented to address these issues by keeping note
activations in channel-specific buffers, similar to actual DAWs. This also
allows the sustain pedal effect to be applied per channel in multi-channel files.

Example from Slakh, 'Track00805-S00' (bass stem):

(onset, offset)

<actual midi>
(8.83, 9.02*) * first note's offset is later than second note's onset, so overlap occurs.
(9.0, 9.55)

<pretty_midi & mirdata parser>
(8.83, 9.0)
(9.0, 9.02*) * second note is too short, because first note's offset is applied to second note.

<onsets_and_frames & ReconVAT parser>
(8.83, 8.84*) * due to the reverse search, the first note's offset is missed, so the minimum offset is applied.
(9.0, 9.55) 

<your_mt3 parser>
(8.83, 9.0) 
(9.0, 9.55)
"""

def find_channel_of_track_name(midi_file: os.PathLike, track_name_keywords: List[str]) -> List[int]:
    """Find the MIDI channels used by tracks with the given names.

    A track matches when one of its 'track_name' messages equals one of
    `track_name_keywords` exactly, ignoring case. The channel of every
    note_on/note_off message that follows the matching name within that track
    is collected.

    Args:
        midi_file: Path to a MIDI file.
        track_name_keywords: Track names to look for (exact, case-insensitive).

    Returns:
        Sorted list of unique channel numbers; empty if no track matches.
    """
    mid = MidiFile(midi_file)
    keywords = {k.lower() for k in track_name_keywords}
    found_channels = set()

    for track in mid.tracks:
        track_name_found = False
        for msg in track:
            if msg.type == 'track_name' and msg.name.lower() in keywords:
                track_name_found = True  # exact match only
            elif track_name_found and msg.type in ('note_on', 'note_off'):
                found_channels.add(msg.channel)

    # Sorted for a deterministic result (set iteration order is arbitrary).
    return sorted(found_channels)

def midi2note(file: Union[os.PathLike, str],
              binary_velocity: bool = True,
              ch_9_as_drum: bool = False,
              force_all_drum: bool = False,
              force_all_program_to: Optional[int] = None,
              track_name_to_program: Optional[Dict] = None,
              trim_overlap: bool = True,
              fix_offset: bool = True,
              quantize: bool = True,
              verbose: int = 0,
              minimum_offset_sec: float = 0.01,
              drum_offset_sec: float = 0.01,
              ignore_pedal: bool = False,
              return_programs: bool = False
             ) -> Union[Tuple[List[Note], float], Tuple[List[Note], float, List[Optional[int]]]]:
    """Convert a MIDI file to a list of Note instances.

    Note activations are kept in per-channel buffers. A note_off (or a note_on
    with velocity 0) releases only the oldest active note of that pitch on its
    channel. While the sustain pedal (CC64 value >= 64) is down, released notes
    are held and receive their offset when the pedal goes up.

    Args:
        file: Path to the MIDI file.
        binary_velocity: If True, velocities are mapped to 1 (> 0) or 0.
        ch_9_as_drum: If True, channel 9 is fixed to DRUM_PROGRAM and its
            program_change messages are ignored.
        force_all_drum: If True, every note is parsed as a drum hit.
        force_all_program_to: Program assigned to a channel that receives a
            note_on before any program_change.
        track_name_to_program: {track_name: program}. Channels of tracks whose
            name matches a key exactly (case-insensitive) get that program, and
            their program_change messages are ignored.
        trim_overlap: If True, apply `trim_overlapping_notes`.
        fix_offset: If True, apply `validate_notes(..., fix=True)`.
        quantize: If True, round onsets/offsets to 10 ms.
        verbose: If > 0, print parsing statistics.
        minimum_offset_sec: Notes still active or sustained at the end of the
            file get offset = min(end time, onset + minimum_offset_sec).
        drum_offset_sec: Fixed duration of drum notes.
        ignore_pedal: If True, sustain pedal messages are ignored.
        return_programs: If True, also return the per-channel program state.

    Returns:
        (notes, max_time), or (notes, max_time, program_state) when
        `return_programs` is True. `max_time` is `MidiFile.length` in seconds;
        `program_state` is a 16-element list holding the last program of each
        channel (None for channels never assigned).

    Raises:
        ValueError: If a note_on arrives on a channel without a program while
            `force_all_program_to` is None.
    """
    midi = MidiFile(file)
    max_time = midi.length  # in seconds

    finished_notes = []
    program_state = [None] * 16  # program_state[ch] = current program of channel ch
    sustain_state = [None] * 16  # sustain_state[ch] = True while the sustain pedal is down
    active_notes = [[] for _ in range(16)]  # sounding notes without offset, oldest first
    sustained_notes = [[] for _ in range(16)]  # released while the pedal was down; await pedal up

    # Channels whose program is fixed up front; their program_change messages are ignored.
    reserved_channels = set()
    # Mapping track name to program (for geerdes data)
    if track_name_to_program is not None:
        for key, program in track_name_to_program.items():
            for ch in find_channel_of_track_name(file, [key]):
                program_state[ch] = program
                reserved_channels.add(ch)
    # NOTE(review): DRUM_PROGRAM is not defined in the visible part of this file
    # (note_event2midi uses 128 for drums) — confirm it is defined/imported at module level.
    if ch_9_as_drum:
        program_state[9] = DRUM_PROGRAM
        reserved_channels.add(9)

    current_time = 0.
    for msg in midi:
        current_time += msg.time  # iterating a MidiFile yields delta times in seconds
        if msg.type == 'program_change':
            if msg.channel not in reserved_channels:
                program_state[msg.channel] = msg.program
        elif msg.type == 'control_change' and msg.control == 64 and not ignore_pedal:
            ch = msg.channel
            if msg.value >= 64:
                sustain_state[ch] = True
            else:
                # Pedal up: every note held by the pedal ends now.
                sustain_state[ch] = False
                for note in sustained_notes[ch]:
                    note.offset = current_time
                    finished_notes.append(note)
                sustained_notes[ch] = []
        elif msg.type == 'note_on' and msg.velocity > 0:
            ch = msg.channel
            if program_state[ch] is None:
                if force_all_program_to is None:
                    raise ValueError(
                        '📕 midi2note: program_change message is missing. Use `force_all_program_to` option')
                program_state[ch] = force_all_program_to
            if program_state[ch] == DRUM_PROGRAM or force_all_drum:
                # drum's offset, active_notes, sustained_notes are not tracked.
                finished_notes.append(
                    Note(is_drum=True,
                         program=program_state[ch],
                         onset=current_time,
                         offset=current_time + drum_offset_sec,
                         pitch=msg.note,
                         velocity=msg.velocity))
            else:
                active_notes[ch].append(
                    Note(is_drum=False,
                         program=program_state[ch],
                         onset=current_time,
                         offset=None,  # set on note_off, pedal up, or end of file
                         pitch=msg.note,
                         velocity=msg.velocity))
        elif msg.type == 'note_off' or (msg.type == 'note_on' and msg.velocity == 0):
            ch = msg.channel
            # A note_off releases only the oldest active note of this pitch.
            # Remove it and stop iterating, so the list is never mutated mid-iteration.
            for idx, note in enumerate(active_notes[ch]):
                if note.pitch == msg.note:
                    del active_notes[ch][idx]
                    if sustain_state[ch]:
                        sustained_notes[ch].append(note)  # offset decided at pedal up
                    else:
                        note.offset = current_time
                        finished_notes.append(note)
                    break

    # Handle any still-active or still-sustained notes (e.g., the file ends without note_off / pedal up)
    for ch_notes in active_notes + sustained_notes:
        for note in ch_notes:
            note.offset = min(current_time, note.onset + minimum_offset_sec)
            finished_notes.append(note)

    notes = finished_notes

    if binary_velocity:
        for note in notes:
            note.velocity = 1 if note.velocity > 0 else 0

    notes.sort(key=lambda note: (note.onset, note.is_drum, note.program, note.velocity, note.pitch))

    # Quantize notes to 10 ms
    if quantize:
        for note in notes:
            note.onset = round(note.onset * 100) / 100.
            note.offset = round(note.offset * 100) / 100.

    # Trim overlapping notes
    if trim_overlap:
        notes = trim_overlapping_notes(notes, sort=True)

    # fix offset >= onset the Note instances
    if fix_offset:
        notes = validate_notes(notes, fix=True)

    # Print some statistics
    if verbose > 0:
        has_drum = any(note.is_drum for note in notes)
        num_instr = sum(prg is not None for prg in program_state)
        print(
            f'parsed {file}: midi_type={midi.type}, num_notes={len(notes)}, num_instr={num_instr}, has_drum={has_drum}')

    if return_programs:
        return notes, max_time, program_state
    return notes, max_time

def note_event2midi(note_events: List[NoteEvent],
                    output_file: Optional[os.PathLike] = None,
                    velocity: int = 100,
                    ticks_per_beat: int = 480,
                    tempo: int = 500000,
                    singing_program_mapping: int = 65,
                    singing_chorus_program_mapping: int = 53,
                    output_inverse_vocab: Optional[Dict] = None,
                    drum_offset_sec: float = 0.01) -> MidiFile:
    """Converts a list of NoteEvent instances to a type-1 MIDI file.

    One track is written per program. Drums (program 128 or is_drum=True) go to
    channel 9. Other programs take the remaining channels in ascending program
    order; once all 15 are used, further programs share channel 15.

    Args:
        note_events: [NoteEvent(is_drum: bool, program: int, time: Optional[float], velocity: int,
            pitch: int, activity: Optional[Set[int]] = {<factory>})]. Events with velocity > 0
            become note_on; velocity 0 becomes note_off. Neither the list nor its events are modified.
        output_file: If given, the MIDI file is saved to this path.
        velocity: MIDI velocity written for every note_on.
        ticks_per_beat: MIDI time resolution.
        tempo: Tempo in microseconds per beat used to convert seconds to ticks.
            A set_tempo message is written when it differs from the MIDI default (500000).
        singing_program_mapping: GM program written for program 100 (singing voice).
        singing_chorus_program_mapping: GM program written for program 101 (singing chorus).
        output_inverse_vocab: Optional mapping used for track names; the name is
            taken from element [1] of the value for a program.
        drum_offset_sec: Duration given to each drum hit (drum offsets are not encoded).

    Returns:
        The constructed MidiFile (also saved when `output_file` is given).

    Raises:
        ValueError: If event times within one program are not non-decreasing
            after sorting (should not happen).

    Example usage:

        note_event2midi(note_events, 'output.mid')
    """
    DRUM = 128  # internally we use 128 for drum
    DEFAULT_TEMPO = 500000  # MIDI default tempo (120 bpm) when no set_tempo is present

    # Normalize drum events on copies so the caller's list and events stay untouched.
    events = []
    for ne in note_events:
        if ne.is_drum or ne.program == DRUM:
            if not ne.is_drum or ne.program != DRUM:
                ne = copy.copy(ne)
                ne.is_drum = True
                ne.program = DRUM
            if ne.velocity > 0:
                # Drum offsets are not tracked: close each hit after a fixed duration.
                events.append(
                    NoteEvent(is_drum=True, program=DRUM, time=ne.time + drum_offset_sec, pitch=ne.pitch,
                              velocity=0))
        events.append(ne)
    # At equal times, note_off (velocity 0) sorts before note_on.
    events.sort(key=lambda ne: (ne.time, ne.is_drum, ne.program, ne.velocity, ne.pitch))

    # Group by program; groups keep the sorted (time) order.
    events_by_program = {}
    for ne in events:
        events_by_program.setdefault(ne.program, []).append(ne)
    programs = sorted(events_by_program)

    # Assign channels to programs
    program_to_channel = {}
    available_channels = list(range(0, 9)) + list(range(10, 16))
    for prg in programs:
        if prg == DRUM:
            program_to_channel[prg] = 9
        elif available_channels:
            program_to_channel[prg] = available_channels.pop(0)
        else:
            warnings.warn(f'not available channels for program {prg}, share channel 16')
            program_to_channel[prg] = 15

    midi = MidiFile(ticks_per_beat=ticks_per_beat, type=1)

    for track_idx, program in enumerate(programs):
        track = MidiTrack()

        # Add track name
        if program == DRUM:
            program_name = 'Drums'
        elif output_inverse_vocab is not None:
            program_name = output_inverse_vocab.get(program, (program, f'Prg. {str(program)}'))[1]
        else:
            program_name = f'Prg. {str(program)}'
        track.append(MetaMessage('track_name', name=program_name, time=0))

        # Ticks are computed with `tempo`, so the file must carry it unless it is the default.
        if track_idx == 0 and tempo != DEFAULT_TEMPO:
            track.append(MetaMessage('set_tempo', tempo=tempo, time=0))

        # Channel is determined by the program
        channel = program_to_channel[program]

        # Some special treatment for singing voice and drums
        if program == DRUM:
            gm_program = 0  # set 0 but it is ignored in drum channel
        elif program == 100:  # singing voice --> Alto Sax
            gm_program = singing_program_mapping
        elif program == 101:  # singing voice (chorus) --> Voice Oohs
            gm_program = singing_chorus_program_mapping
        else:
            gm_program = program
        track.append(Message('program_change', program=gm_program, time=0, channel=channel))

        current_tick = 0
        for ne in events_by_program[program]:
            absolute_tick = round(second2tick(ne.time, ticks_per_beat, tempo))
            if absolute_tick < current_tick:
                # this should not happen after sorting
                raise ValueError(f'at ne.time {ne.time}, absolute_tick {absolute_tick} < current_tick {current_tick}')
            delta_tick = absolute_tick - current_tick
            current_tick = absolute_tick

            # Create a note on or note off message
            if ne.velocity > 0:
                new_msg = Message('note_on', note=ne.pitch, velocity=velocity, time=delta_tick, channel=channel)
            else:
                new_msg = Message('note_off', note=ne.pitch, velocity=0, time=delta_tick, channel=channel)
            track.append(new_msg)

        midi.tracks.append(track)

    # Save MIDI file
    if output_file is not None:
        midi.save(output_file)
    return midi

def get_pitch_range_from_midi(midi_file: os.PathLike) -> Tuple[int, int]:
    """Returns the pitch range of a MIDI file.

    Notes are parsed with `midi2note` (no quantization, no overlap trimming).
    Drum notes are included.

    Args:
        midi_file (os.PathLike): Path to a MIDI file.

    Returns:
        Tuple[int, int]: The lowest and highest note pitches in the MIDI file.

    Raises:
        ValueError: If the MIDI file contains no notes.
    """
    # midi2note returns (notes, max_time); only the notes are needed here.
    notes, _ = midi2note(midi_file, quantize=False, trim_overlap=False)
    if not notes:
        raise ValueError(f'No notes found in {midi_file}')
    pitches = [n.pitch for n in notes]
    return min(pitches), max(pitches)

def pitch_shift_midi(src_midi_file: os.PathLike,
                     min_pitch_shift: int = -5,
                     max_pitch_shift: int = 6,
                     write_midi_file: bool = True,
                     write_notes_file: bool = True,
                     write_note_events_file: bool = True) -> None:
    """Pitch shifts a MIDI file and writes the results next to it.

    For every shift in range(min_pitch_shift, max_pitch_shift) except 0, the
    non-drum notes are transposed. Drum notes are copied unchanged, because a
    drum pitch selects a percussion instrument. Notes that would leave the MIDI
    pitch range 0..127 are dropped with a warning.

    Args:
        src_midi_file (os.PathLike): Path to a MIDI file. When writing notes /
            note-events files, `{name}_notes.npy` / `{name}_note_events.npy`
            metadata files are expected in the same directory.
        min_pitch_shift (int): Smallest shift in semitones (inclusive).
        max_pitch_shift (int): Largest shift in semitones (exclusive).
        write_midi_file (bool): Write `{name}_pshift{i}.mid`.
        write_notes_file (bool): Write `{name}_pshift{i}_notes.npy` (List[Note] in metadata).
        write_note_events_file (bool): Write `{name}_pshift{i}_note_events.npy` (List[NoteEvent] in metadata).
    """
    # source file
    src_midi_dir = os.path.dirname(src_midi_file)
    src_midi_filename = os.path.basename(src_midi_file).split('.')[0]
    src_notes_file = os.path.join(src_midi_dir, f'{src_midi_filename}_notes.npy')
    src_note_events_file = os.path.join(src_midi_dir, f'{src_midi_filename}_note_events.npy')
    src_notes, _ = midi2note(src_midi_file)

    # Load source metadata once; each shift works on its own shallow copy.
    # NOTE: allow_pickle=True executes pickled data — only use with trusted, locally generated files.
    src_notes_metadata = np.load(src_notes_file, allow_pickle=True).tolist() if write_notes_file else None
    src_note_events_metadata = (np.load(src_note_events_file, allow_pickle=True).tolist()
                                if write_note_events_file else None)

    for pitch_shift in range(min_pitch_shift, max_pitch_shift):
        if pitch_shift == 0:
            continue  # the source itself

        # destination file
        dst_midi_file = os.path.join(src_midi_dir, f'{src_midi_filename}_pshift{pitch_shift}.mid')
        dst_notes_file = os.path.join(src_midi_dir, f'{src_midi_filename}_pshift{pitch_shift}_notes.npy')
        dst_note_events_file = os.path.join(src_midi_dir, f'{src_midi_filename}_pshift{pitch_shift}_note_events.npy')

        dst_notes = []
        num_dropped = 0
        for note in src_notes:
            dst_note = copy.deepcopy(note)
            if not dst_note.is_drum:
                dst_note.pitch += pitch_shift
                if not 0 <= dst_note.pitch <= 127:
                    num_dropped += 1
                    continue
            dst_notes.append(dst_note)
        if num_dropped > 0:
            warnings.warn(f'pitch_shift {pitch_shift}: dropped {num_dropped} note(s) outside MIDI pitch range 0-127')

        dst_note_events = note2note_event(dst_notes)

        # write midi file
        if write_midi_file:
            note_event2midi(dst_note_events, dst_midi_file)
            print(f'Created {dst_midi_file}')

        # write notes file
        if write_notes_file:
            dst_notes_metadata = copy.copy(src_notes_metadata)
            dst_notes_metadata['pitch_shift'] = pitch_shift
            dst_notes_metadata['notes'] = dst_notes
            np.save(dst_notes_file, dst_notes_metadata, allow_pickle=True, fix_imports=False)
            print(f'Created {dst_notes_file}')

        # write note events file
        if write_note_events_file:
            dst_note_events_metadata = copy.copy(src_note_events_metadata)
            dst_note_events_metadata['pitch_shift'] = pitch_shift
            dst_note_events_metadata['note_events'] = dst_note_events
            np.save(dst_note_events_file, dst_note_events_metadata, allow_pickle=True, fix_imports=False)
            print(f'Created {dst_note_events_file}')