# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Processor class for MERaLiON.
"""

from typing import List, Optional, Union

import numpy as np

from transformers.feature_extraction_utils import BatchFeature
from transformers.processing_utils import ProcessorMixin
from transformers.tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput


# Adapted from transformers.models.qwen2_audio.processing_qwen2_audio.Qwen2AudioProcessor
class MERaLiONProcessor(ProcessorMixin):
    r"""
    Constructs a MERaLiON processor which wraps a Whisper feature extractor and a Gemma tokenizer into a single processor.

    [`MERaLiONProcessor`] offers all the functionalities of [`WhisperFeatureExtractor`] and [`GemmaTokenizer`]. See the
    [`~MERaLiONProcessor.__call__`] and [`~MERaLiONProcessor.decode`] for more information.

    Args:
        feature_extractor ([`WhisperFeatureExtractor`], *optional*):
            The feature extractor is a required input.
        tokenizer ([`GemmaTokenizer`], *optional*):
            The tokenizer is a required input.
        fixed_speech_embeds_length (`int`, *optional*, defaults to 100):
            The number of speech placeholder tokens that each speech signature in the text is expanded to.
        speech_signature (`str`, *optional*, defaults to `"<SpeechHere>"`):
            The placeholder string in the prompt that marks where the audio features are inserted.
        speech_token_index (`int`, *optional*, defaults to 255999):
            The vocabulary id of the speech placeholder token in the tokenizer.
        time_duration_limit (`int`, *optional*, defaults to -1):
            The maximum audio duration in seconds; a non-positive value means no limit.
        do_normalize (`bool`, *optional*, defaults to `True`):
            Whether the feature extractor normalizes the audio input.
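
    Example (a minimal sketch; the checkpoint path is a placeholder, and a remote checkpoint may require
    `trust_remote_code=True`):

    ```python
    >>> from transformers import AutoProcessor

    >>> # Loads the wrapped WhisperFeatureExtractor and GemmaTokenizer together.
    >>> processor = AutoProcessor.from_pretrained("path/to/meralion-checkpoint", trust_remote_code=True)
    ```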
    """

    attributes = ["feature_extractor", "tokenizer"]
    feature_extractor_class = "WhisperFeatureExtractor"
    tokenizer_class = "GemmaTokenizer"
    valid_kwargs = ["fixed_speech_embeds_length", "speech_signature", "speech_token_index", "time_duration_limit", "do_normalize"]

    def __init__(
        self,
        feature_extractor=None,
        tokenizer=None,
        fixed_speech_embeds_length=100,
        speech_signature="<SpeechHere>",
        speech_token_index=255999,
        time_duration_limit=-1,
        do_normalize=True
    ):
        self.fixed_speech_embeds_length = fixed_speech_embeds_length
        self.speech_signature = speech_signature
        self.speech_token_index = speech_token_index
        self.time_duration_limit = time_duration_limit
        self.do_normalize = do_normalize

        super().__init__(feature_extractor, tokenizer)

        # Resolve the literal string of the speech placeholder token from its vocabulary id,
        # so that prompts can be expanded by plain string replacement.
        self.speech_token = self.tokenizer.added_tokens_decoder[self.speech_token_index].content

    def _process_text(self, text, speech_signature):
        """Replace each occurrence of `speech_signature` in `text` with a fixed-length run of speech placeholder tokens."""
        target_string = self.speech_token * self.fixed_speech_embeds_length
        if isinstance(text, (list, tuple)):
            return [item.replace(speech_signature, target_string) for item in text]
        return text.replace(speech_signature, target_string)

    def _slice_audios(self, audios, time_duration_limit, sampling_rate):
        """Truncate audio to at most `time_duration_limit` seconds; a non-positive limit leaves it unchanged.

        For example, with `time_duration_limit=30` and `sampling_rate=16000`, each audio keeps only its
        first 30 * 16000 = 480000 samples.
        """
        if time_duration_limit <= 0:
            return audios

        # Cast to int so the result is a valid index even if `time_duration_limit` is a float.
        slice_length = int(time_duration_limit * sampling_rate)

        if isinstance(audios, np.ndarray):
            # The last axis is the time axis for both a single 1-D waveform and a batched 2-D array.
            return audios[..., :slice_length]

        if isinstance(audios, list):
            return [audio[:slice_length] for audio in audios]

        return audios

    def __call__(
        self,
        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
        audios: Union[np.ndarray, List[np.ndarray]] = None,
        padding: Union[bool, str, PaddingStrategy] = True,
        sampling_rate: Optional[int] = None,
        speech_signature: Optional[str] = None,
        time_duration_limit: Optional[int] = None,
        do_normalize: Optional[bool] = None,
        **kwargs,
    ) -> BatchFeature:
        """
        Main method to prepare for the model one or several sequences(s) and audio(s). This method forwards the `text`
        and `kwargs` arguments to GemmaTokenizer's [`~GemmaTokenizer.__call__`] if `text` is not `None` to encode
        the text. To prepare the audio(s), this method forwards the `audios` and `kwrags` arguments to
        WhisperFeatureExtractor's [`~WhisperFeatureExtractor.__call__`] if `audios` is not `None`. Please refer to the doctsring
        of the above two methods for more information.

        Args:
            text (`str`, `List[str]`, `List[List[str]]`):
                The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
                (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
                `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
            audios (`np.ndarray`, `List[np.ndarray]`):
                The audio or batch of audios to be prepared. Each audio can be a NumPy array.
            padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):
                Select a strategy to pad the returned sequences (according to the model's padding side and padding
                index) among:
                - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
                  sequence is provided).
                - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
                  acceptable input length for the model if that argument is not provided.
                - `False` or `'do_not_pad'`: No padding (i.e., can output a batch with sequences of different
                  lengths).
            sampling_rate (`int`, *optional*):
                The sampling rate, in hertz (Hz), at which the audio files should be digitized. Defaults to the
                feature extractor's sampling rate.
            speech_signature (`str`, *optional*):
                Overrides the processor's default speech placeholder string for this call.
            time_duration_limit (`int`, *optional*):
                Overrides the processor's default maximum audio duration, in seconds, for this call.
            do_normalize (`bool`, *optional*):
                Overrides the processor's default audio normalization setting for this call.
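
        Example (a minimal sketch; the checkpoint path is a placeholder for an actual MERaLiON checkpoint):

        ```python
        >>> import numpy as np
        >>> from transformers import AutoProcessor

        >>> processor = AutoProcessor.from_pretrained("path/to/meralion-checkpoint", trust_remote_code=True)
        >>> prompt = "Please transcribe this speech: <SpeechHere>"
        >>> audio = np.zeros(16000, dtype=np.float32)  # one second of dummy silence at 16 kHz
        >>> inputs = processor(text=prompt, audios=audio)
        >>> sorted(inputs.keys())
        ['attention_mask', 'feature_attention_mask', 'input_features', 'input_ids']
        ```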
        """

        if text is None:
            raise ValueError("You need to specify either a `text` input to process.")
        if sampling_rate is None:
            sampling_rate = self.feature_extractor.sampling_rate
        if speech_signature is None:
            speech_signature = self.speech_signature
        if time_duration_limit is None:
            time_duration_limit = self.time_duration_limit
        if do_normalize is None:
            do_normalize = self.do_normalize

        inputs_dict = {}

        # Expand each speech signature in the prompt into a fixed-length run of speech placeholder tokens.
        text = self._process_text(text, speech_signature)

        text_input = self.tokenizer(
            text=text,
            return_tensors="pt",
            add_special_tokens=False,
            return_attention_mask=True,
            padding=padding,
            **kwargs
        )

        inputs_dict["input_ids"] = text_input.input_ids
        inputs_dict["attention_mask"] = text_input.attention_mask
        
        if audios is not None:
            audios = self._slice_audios(audios, time_duration_limit, sampling_rate)
            
            audio_inputs = self.feature_extractor(
                audios, 
                sampling_rate=sampling_rate, 
                return_tensors="pt",
                return_attention_mask=True, 
                padding="max_length", 
                do_normalize=self.do_normalize,
                **kwargs
            )
            audio_inputs["feature_attention_mask"] = audio_inputs.pop(
                "attention_mask"
            )  # rename attention_mask to prevent conflicts later on
            inputs_dict.update(audio_inputs)

        return BatchFeature(data=inputs_dict)

    def batch_decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to GemmaTokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please
        refer to the docstring of this method for more information.
        """
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to GemmaTokenizer's [`~PreTrainedTokenizer.decode`]. Please refer to
        the docstring of this method for more information.
        """
        return self.tokenizer.decode(*args, **kwargs)

    @property
    def model_input_names(self):
        """Names of the model inputs produced by this processor, merging tokenizer and feature extractor input names."""
        tokenizer_input_names = self.tokenizer.model_input_names
        feature_extractor_input_names = self.feature_extractor.model_input_names
        return list(dict.fromkeys(tokenizer_input_names + feature_extractor_input_names + ["feature_attention_mask"]))