from copy import copy
from typing import Dict, List, Optional, Tuple, Union


class LMTemplateParser:
    """Intermediate prompt template parser, specifically for language models.

    Turns a dialog (a plain string, or a list mixing raw strings and message
    dicts with ``role``/``content`` and optional ``name`` keys) into a single
    prompt string. With a meta template, each message is wrapped with the
    ``begin``/``end`` markers configured for its role; without one, the
    non-empty items are joined with newlines.

    Args:
        meta_template (list of dict, optional): The meta template for the
            model. Each dict configures one role and must carry a unique
            ``role`` key. Other keys read here: ``begin`` (str, or dict with
            ``with_name``/``without_name``/``name``), ``end``, ``generate``
            and ``fallback_role``.
    """

    def __init__(self, meta_template: Optional[List[Dict]] = None):
        self.meta_template = meta_template
        # Maps role name to a shallow copy of its config. Always defined so
        # attribute access works even when no meta template is given.
        self.roles: Dict[str, dict] = dict()
        if meta_template:
            assert isinstance(meta_template, list)
            for item in meta_template:
                assert isinstance(item, dict)
                assert item['role'] not in self.roles, \
                    'role in meta prompt must be unique!'
                self.roles[item['role']] = item.copy()

    def __call__(self, dialog) -> str:
        """Parse a prompt template, and wrap it with meta template if
        applicable.

        Args:
            dialog (str or list): Either a ready-made prompt string (returned
                unchanged) or a list whose items are raw strings or message
                dicts (``role``, ``content`` and optionally ``name``).

        Returns:
            str: The final string.
        """
        assert isinstance(dialog, (str, list))
        if isinstance(dialog, str):
            return dialog
        if self.meta_template:
            last_index = len(dialog) - 1
            pieces = []
            for index, item in enumerate(dialog):
                if isinstance(item, str):
                    pieces.append(item)
                else:
                    # Only the final message may leave the prompt open for
                    # generation (see ``_prompt2str``).
                    pieces.append(self._prompt2str(item, index == last_index))
            return ''.join(pieces)
        # In case the model does not have any meta template: keep non-empty
        # raw strings and message contents, separated by newlines.
        pieces = []
        for item in dialog:
            text = item if isinstance(item, str) else item.get('content', '')
            if text:
                pieces.append(text)
        return '\n'.join(pieces)

    def _get_role_cfg(self, role: str) -> dict:
        """Return the config registered for ``role`` or raise a clear error.

        Raises:
            KeyError: If ``role`` is not defined in the meta template.
        """
        try:
            return self.roles[role]
        except KeyError:
            raise KeyError(
                f'role {role!r} is not defined in the meta template') from None

    def _format_begin(self, role_cfg, message):
        """Build the opening marker for ``message`` from ``role_cfg``.

        ``role_cfg['begin']`` may be missing (empty marker), a plain string
        (used verbatim, with or without a message name), or a dict with
        ``without_name`` and ``with_name`` templates. ``with_name`` is
        ``str.format``-ed with ``name``, which is first mapped through the
        optional ``begin['name']`` alias dict.

        Raises:
            TypeError: If ``begin`` is neither a str nor a dict.
        """
        begin_cfg = role_cfg.get('begin', '')
        if isinstance(begin_cfg, str):
            # A plain-string begin has no name placeholder to fill.
            return begin_cfg
        if not isinstance(begin_cfg, dict):
            raise TypeError(
                f"'begin' must be a str or dict, got {type(begin_cfg)!r}")
        name = message.get('name', None)
        if name is None:
            return begin_cfg.get('without_name', '')
        aliases = begin_cfg.get('name', {})
        return begin_cfg.get('with_name', '').format(
            name=aliases.get(name, name))

    def _prompt2str(self,
                    prompt: Union[str, Dict],
                    last: bool = False) -> str:
        """Render one dialog item using its role config.

        Args:
            prompt (str or dict): A raw string (returned as-is) or a message
                dict with ``role``, ``content`` and optionally ``name``.
            last (bool): Whether this is the final item of the dialog.

        Returns:
            str: ``begin + content + end``. For the last item, a role with
            ``generate`` omits ``end``, leaving the prompt open. Any other
            non-assistant role gets the assistant's ``begin`` appended.

        Raises:
            KeyError: If the role (or its fallback, or ``assistant`` when
                needed) is not defined in the meta template.
        """
        if isinstance(prompt, str):
            return prompt
        role_cfg = self._get_role_cfg(prompt['role'])
        if role_cfg.get('fallback_role'):
            role_cfg = self._get_role_cfg(role_cfg['fallback_role'])
        res = self._format_begin(role_cfg, prompt) + prompt.get('content', '')
        if last and role_cfg.get('generate', False):
            return res
        res += role_cfg.get('end', '')
        if last and role_cfg['role'] != 'assistant':
            # Prime the model to answer as the assistant.
            res += self._format_begin(self._get_role_cfg('assistant'), {})
        return res


class BaseLLM:
    """Base class for model wrapper.

    Args:
        path (str): The path to the model.
        tokenizer_only (bool): If True, only the tokenizer will be initialized.
            Defaults to False.
        template_parser (type): Class (or factory) called with
            ``meta_template`` to build ``self.template_parser``.
            Defaults to :class:`LMTemplateParser`.
        meta_template (list of dict, optional): The model's meta prompt
            template if needed, in case the requirement of injecting or
            wrapping of any meta instructions. If an item (or a dict-style
            template) has an ``eos_token_id`` key, it is stored on
            ``self.eos_token_id``.
        max_new_tokens (int): Maximum length of output expected to be
            generated by the model. Defaults to 512.
        top_p (float): Default ``top_p`` stored in ``self.gen_params``.
            Defaults to 0.8.
        top_k (int): Default ``top_k`` stored in ``self.gen_params``.
            Defaults to 40.
        temperature (float): Default ``temperature`` stored in
            ``self.gen_params``. Defaults to 0.8.
        repetition_penalty (float): Default ``repetition_penalty`` stored in
            ``self.gen_params``. Defaults to 1.0.
        stop_words (str or list of str, optional): Stop word(s); a single
            string is wrapped into a one-element list. Defaults to None.
    """

    def __init__(self,
                 path: str,
                 tokenizer_only: bool = False,
                 template_parser: 'LMTemplateParser' = LMTemplateParser,
                 meta_template: Optional[List[Dict]] = None,
                 *,
                 max_new_tokens: int = 512,
                 top_p: float = 0.8,
                 top_k: int = 40,
                 temperature: float = 0.8,
                 repetition_penalty: float = 1.0,
                 stop_words: Optional[Union[List[str], str]] = None):
        self.path = path
        self.tokenizer_only = tokenizer_only
        # meta template
        self.template_parser = template_parser(meta_template)
        self.eos_token_id = None
        if isinstance(meta_template, dict):
            # Dict-style template (only usable with a custom parser).
            self.eos_token_id = meta_template.get('eos_token_id')
        elif meta_template:
            # A list of role dicts: a membership test on the list itself can
            # never see the key, so search inside the items.
            for item in meta_template:
                if isinstance(item, dict) and 'eos_token_id' in item:
                    self.eos_token_id = item['eos_token_id']
                    break

        if isinstance(stop_words, str):
            stop_words = [stop_words]
        self.gen_params = dict(
            max_new_tokens=max_new_tokens,
            top_p=top_p,
            top_k=top_k,
            temperature=temperature,
            repetition_penalty=repetition_penalty,
            stop_words=stop_words)

    def generate(self, inputs: Union[str, List[str]],
                 **gen_params) -> Union[str, List[str]]:
        """Generate results given a str (or list of) inputs.

        Must be implemented by subclasses.

        Args:
            inputs (Union[str, List[str]]): A prompt or a batch of prompts.
            gen_params (dict): The input params for generation.

        Returns:
            Union[str, List[str]]: A (list of) generated strings.

        eg.
            batched = True
            if isinstance(inputs, str):
                inputs = [inputs]
                batched = False
            response = ['']
            if batched:
                return response
            return response[0]
        """
        raise NotImplementedError

    def stream_generate(self, inputs: str, **gen_params) -> List[str]:
        """Generate results as streaming given a str inputs.

        Must be implemented by subclasses.

        Args:
            inputs (str): The prompt.
            gen_params (dict): The input params for generation.

        Returns:
            str: A generated string.
        """
        raise NotImplementedError

    def chat(self,
             inputs: Union[List[dict], List[List[dict]]],
             session_ids: Union[int, List[int]] = None,
             **gen_params):
        """Generate completion from a list of templates.

        Args:
            inputs (Union[List[dict], List[List[dict]]]): A single dialog
                (list of message dicts) or a batch of dialogs. Each dialog
                is rendered with ``self.template_parser``.
            session_ids (int or list of int, optional): Accepted for parity
                with the async interface; this synchronous implementation
                does not forward it to :meth:`generate`.
            gen_params (dict): The input params for generation.

        Returns:
            Whatever :meth:`generate` returns for the rendered prompt(s).
        """
        # ``inputs and`` guards an empty dialog, which would otherwise raise
        # IndexError on ``inputs[0]``.
        if inputs and isinstance(inputs[0], list):
            _inputs = [self.template_parser(msg) for msg in inputs]
        else:
            _inputs = self.template_parser(inputs)
        return self.generate(_inputs, **gen_params)

    def stream_chat(self, inputs: List[dict], **gen_params):
        """Generate results as streaming given a list of templates.

        Must be implemented by subclasses.

        Args:
            inputs (List[dict]): A single dialog (list of message dicts).
            gen_params (dict): The input params for generation.
        """
        raise NotImplementedError

    def tokenize(self, prompts: Union[str, List[str], List[dict],
                                      List[List[dict]]]):
        """Tokenize the input prompts.

        Must be implemented by subclasses.

        Args:
            prompts(str | List[str]): user's prompt, or a batch prompts

        Returns:
            Tuple(numpy.ndarray, numpy.ndarray, numpy.ndarray): prompt's token
            ids, ids' length and requested output length
        """
        raise NotImplementedError

    def update_gen_params(self, **kwargs):
        """Return ``self.gen_params`` overridden by ``kwargs``.

        A shallow copy is updated, so ``self.gen_params`` is left unchanged.
        """
        gen_params = copy(self.gen_params)
        gen_params.update(kwargs)
        return gen_params


class AsyncLLMMixin:
    """Coroutine counterparts of the model-wrapper entry points.

    ``chat`` reads ``self.template_parser``, so this mixin expects to be
    combined with a class that sets it (``BaseLLM.__init__`` does).
    """

    async def generate(self,
                       inputs: Union[str, List[str]],
                       session_ids: Union[int, List[int]] = None,
                       **gen_params) -> Union[str, List[str]]:
        """Generate results given a str (or list of) inputs.

        Must be implemented by subclasses.

        Args:
            inputs (Union[str, List[str]]): A prompt or a batch of prompts.
            session_ids (int or list of int, optional): Session id(s) passed
                through from :meth:`chat`.
            gen_params (dict): The input params for generation.

        Returns:
            Union[str, List[str]]: A (list of) generated strings.

        eg.
            batched = True
            if isinstance(inputs, str):
                inputs = [inputs]
                batched = False
            response = ['']
            if batched:
                return response
            return response[0]
        """
        raise NotImplementedError

    async def stream_generate(self, inputs: str, **gen_params) -> List[str]:
        """Generate results as streaming given a str inputs.

        Must be implemented by subclasses.

        Args:
            inputs (str): The prompt.
            gen_params (dict): The input params for generation.

        Returns:
            str: A generated string.
        """
        raise NotImplementedError

    async def chat(self,
                   inputs: Union[List[dict], List[List[dict]]],
                   session_ids: Union[int, List[int]] = None,
                   **gen_params):
        """Generate completion from a list of templates.

        Args:
            inputs (Union[List[dict], List[List[dict]]]): A single dialog
                (list of message dicts) or a batch of dialogs. Each dialog
                is rendered with ``self.template_parser``.
            session_ids (int or list of int, optional): Forwarded
                positionally to :meth:`generate`.
            gen_params (dict): The input params for generation.

        Returns:
            Whatever :meth:`generate` returns for the rendered prompt(s).
        """
        # ``inputs and`` guards an empty dialog, which would otherwise raise
        # IndexError on ``inputs[0]``.
        if inputs and isinstance(inputs[0], list):
            _inputs = [self.template_parser(msg) for msg in inputs]
        else:
            _inputs = self.template_parser(inputs)
        return await self.generate(_inputs, session_ids, **gen_params)

    async def stream_chat(self, inputs: List[dict], **gen_params):
        """Generate results as streaming given a list of templates.

        Must be implemented by subclasses.

        Args:
            inputs (List[dict]): A single dialog (list of message dicts).
            gen_params (dict): The input params for generation.
        """
        raise NotImplementedError

    async def tokenize(self, prompts: Union[str, List[str], List[dict],
                                            List[List[dict]]]):
        """Tokenize the input prompts.

        Must be implemented by subclasses.

        Args:
            prompts(str | List[str]): user's prompt, or a batch prompts

        Returns:
            Tuple(numpy.ndarray, numpy.ndarray, numpy.ndarray): prompt's token
            ids, ids' length and requested output length
        """
        raise NotImplementedError


class AsyncBaseLLM(AsyncLLMMixin, BaseLLM):
    """Base class for asynchronous model wrappers.

    Inherits construction and helpers (``__init__``, ``update_gen_params``)
    from :class:`BaseLLM`. The coroutine versions of ``generate``, ``chat``,
    ``stream_generate``, ``stream_chat`` and ``tokenize`` come from
    :class:`AsyncLLMMixin`, which wins because it is listed first in the MRO.
    """