from typing import Dict, List, Any
import torch
from videollama2.conversation import conv_templates, SeparatorStyle
from videollama2.constants import DEFAULT_MMODAL_TOKEN, MMODAL_TOKEN_INDEX
from videollama2.mm_utils import get_model_name_from_path, tokenizer_MMODAL_token, KeywordsStoppingCriteria, process_video, process_image
from videollama2.model.builder import load_pretrained_model

class EndpointHandler:
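    """Hugging Face Inference Endpoints handler for VideoLLaMA2.

    Loads the model once at startup and answers a single image- or
    video-grounded question per request.
    """
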
    def __init__(self, path="DAMO-NLP-SG/VideoLLaMA2-8x7B"):
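        # Load the pretrained tokenizer, model, and visual processor for the
        # checkpoint at `path`, then move the model onto the first GPU.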
        model_name = get_model_name_from_path(path)
        self.tokenizer, self.model, self.processor, self.context_len = load_pretrained_model(path, None, model_name)
        self.model = self.model.to('cuda:0')

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
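        """Answer a single visual question.

        Expected payload keys (only the first element of each list is used):
            paths      -- paths to the input image or video
            questions  -- questions to ask about the input
            modal_list -- modality of each input, 'video' or 'image'

        Returns a one-element list containing {"output": answer} on success,
        or {"error": message} if required fields are missing.
        """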
        # get inputs
        paths = data.get("paths", [])
        questions = data.get("questions", [])
        modal_list = data.get("modal_list", [])

        # make sure all required fields are present and non-empty
        if not paths or not questions or not modal_list:
            return [{"error": "Missing paths, questions, or modal_list"}]

        # Visual preprocess (load & transform image or video)
        if modal_list[0] == 'video':
            tensor = process_video(paths[0], self.processor, self.model.config.image_aspect_ratio).to(dtype=torch.float16, device='cuda', non_blocking=True)
            default_mm_token = DEFAULT_MMODAL_TOKEN["VIDEO"]
            modal_token_index = MMODAL_TOKEN_INDEX["VIDEO"]
        else:
            tensor = process_image(paths[0], self.processor, self.model.config.image_aspect_ratio)[0].to(dtype=torch.float16, device='cuda', non_blocking=True)
            default_mm_token = DEFAULT_MMODAL_TOKEN["IMAGE"]
            modal_token_index = MMODAL_TOKEN_INDEX["IMAGE"]
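        # generate() consumes a list of visual tensors, so wrap the single input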
        tensor = [tensor]

        # Text preprocess (tag process & generate prompt)
        question = default_mm_token + "\n" + questions[0]
        conv_mode = 'llama_2'
        conv = conv_templates[conv_mode].copy()
        conv.append_message(conv.roles[0], question)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()
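        # Tokenize the prompt, mapping the multimodal placeholder to its special token index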
        input_ids = tokenizer_MMODAL_token(prompt, self.tokenizer, modal_token_index, return_tensors='pt').unsqueeze(0).to('cuda:0')

        # Generate a response according to visual signals and prompts
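        # Stop decoding once the conversation separator sequence appears in the output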
        stop_str = conv.sep if conv.sep_style in [SeparatorStyle.SINGLE] else conv.sep2
        keywords = [stop_str]
        stopping_criteria = KeywordsStoppingCriteria(keywords, self.tokenizer, input_ids)
        with torch.inference_mode():
            output_ids = self.model.generate(
                input_ids,
                images_or_videos=tensor,
                modal_list=modal_list,
                do_sample=True,
                temperature=0.2,
                max_new_tokens=1024,
                use_cache=True,
                stopping_criteria=[stopping_criteria],
            )
        outputs = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)
        return [{"output": outputs[0]}]
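

# Minimal local smoke-test sketch (not part of the endpoint contract): the media
# path below is a placeholder and must point to a real video file on your machine.
if __name__ == "__main__":
    handler = EndpointHandler()
    example_request = {
        "paths": ["/path/to/sample_video.mp4"],  # placeholder path
        "questions": ["What is happening in this video?"],
        "modal_list": ["video"],
    }
    print(handler(example_request))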