import logging
import os
import re
import base64
import torch
from config import get_model, get_react_parser
from utils.data_utils import load_jsonl, save_jsonl

# Seed torch's global random number generator with a fixed value as soon as
# this module is imported.
# NOTE(review): presumably for reproducible judger-model output — confirm.
torch.manual_seed(1234)

# Chinese prompt template for the judger model: asks whether the image is
# consistent with the question and requests a reply of "right" or "wrong".
# The `{query}` placeholder is filled in with `str.format(query=...)`.
EVAL_VISUAL_PROMPT_ZH = """请判断图片是否与下面的[问题]一致,如果一致则回复“right”,不一致则回复“wrong”。
[问题]:{query}
"""

# English counterpart of the prompt template above; same `{query}` placeholder
# and same expected "right" / "wrong" reply.
EVAL_VISUAL_PROMPT_EN = """Please judge whether the image is consistent with the [Question] below, if it is consistent then reply "right", if not then reply "wrong".
[Question]: {query}
"""

# Module-level result holder. `eval_visualization_acc` overwrites both values
# with accuracy percentages and returns this same dict object; the values stay
# `None` until that function has run.
visualization_code_correctness = {
    'visualization-hard': None,
    'visualization-easy': None,
}


def encode_image(image_path):
    """Return the content of the file at ``image_path`` as base64 text.

    The file is opened in binary mode, its bytes are base64-encoded, and the
    encoded bytes are decoded as UTF-8 so that a plain ``str`` is returned.
    """
    with open(image_path, "rb") as handle:
        raw_bytes = handle.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode('utf-8')


def _is_http_url(location):
    """Return True when ``location`` is an ``http://`` or ``https://`` URL."""
    return location.startswith(('http://', 'https://'))


def judger_model_inference(judger_model_name, judger_model, imgs=None, prompt=''):
    """Send ``imgs`` and ``prompt`` to a vision judger model and return its reply.

    Args:
        judger_model_name: ``'gpt-4-vision-preview'``, ``'qwen-vl-plus'`` or
            ``'qwen-vl-chat'``. Any other name performs no inference.
        judger_model: Object exposing ``generate(inputs)``. Only used by the
            two Qwen-VL judgers; ignored for ``'gpt-4-vision-preview'``.
        imgs: Image locations, each either an ``http(s)://`` URL or a local
            file path. ``None`` (the default) means "no images".
        prompt: Text prompt sent alongside the images.

    Returns:
        For ``'gpt-4-vision-preview'``: the message text of the first choice
        (``""`` if the API returned no text). For the Qwen-VL judgers:
        whatever ``judger_model.generate`` returns. For any other judger
        name: ``""``.
    """
    # Avoid a shared mutable default argument.
    imgs = [] if imgs is None else imgs
    output = ""
    if judger_model_name == 'gpt-4-vision-preview':
        logging.warning("This is an example of `gpt-4-vision-preview`. "
                        "Please set the API key and use according to your actual situation.")
        # Imported lazily so the dependency is only needed for this judger.
        from openai import OpenAI
        client = OpenAI()
        content_list = [{"type": "text", "text": prompt}]
        for img in imgs:
            if not _is_http_url(img):
                # Local files are inlined as a base64 data URL.
                # NOTE(review): the MIME type is hard-coded to JPEG whatever
                # the real file format is — confirm the API tolerates this.
                base64_image = encode_image(img)
                img = f"data:image/jpeg;base64,{base64_image}"
            content_list.append({"type": "image_url", "image_url": {"url": img}})
        response = client.chat.completions.create(
            model="gpt-4-vision-preview",
            messages=[
                {
                    "role": "user",
                    "content": content_list,
                }
            ],
            max_tokens=300,
        )
        # Return the reply text (not the `Choice` object) so callers can
        # treat the result as a string.
        output = response.choices[0].message.content or ""
    elif judger_model_name in ['qwen-vl-plus', 'qwen-vl-chat']:
        inputs = []
        for img in imgs:
            # Only `qwen-vl-plus` gets the `file://` prefix on local paths.
            if judger_model_name == 'qwen-vl-plus' and not _is_http_url(img):
                img = "file://" + img
            inputs.append({'image': img})
        inputs.append({'text': prompt})

        logging.info('Eval'.center(60, '-'))
        logging.info(inputs)
        output = judger_model.generate(inputs)
    logging.info(output)
    logging.info('=' * 60)
    return output


def extract_images(text):
    """Return the existing image paths referenced as ``![fig-...](path)`` in ``text``.

    Links are reported in order of appearance, and only paths for which
    ``os.path.exists`` is true are kept.

    The greedy pattern matches at most once per line and, when a line holds
    several links, captures only the last path. Each greedy match is therefore
    re-scanned with a non-greedy pattern so every link on the line is found.
    The greedy capture is kept as a fallback so that a path which itself
    contains ``)`` is still returned.
    """
    line_pattern = re.compile(r'!\[fig-(.+)\]\((.+)\)')
    link_pattern = re.compile(r'!\[fig-(.+?)\]\((.+?)\)')

    images = []
    for match in line_pattern.finditer(text):
        found = [
            path for _, path in link_pattern.findall(match.group(0))
            if os.path.exists(path)
        ]
        greedy_path = match.group(2)
        if greedy_path not in found and os.path.exists(greedy_path):
            found.append(greedy_path)
        images.extend(found)
    return images


def check_images_observation(text, images, model_name):
    """Return True if at least one image follows an error-free observation.

    For each image path, the text from the last observation marker (taken
    from ``get_react_parser(model_name).observation``) up to and including
    the first occurrence of the path is inspected. That segment counts as
    successful when it contains neither ``'error:'`` nor ``'Traceback'``.
    Returns False when no image qualifies.
    """
    observation_marker = get_react_parser(model_name).observation
    marker_len = len(observation_marker)

    for image_path in images:
        logging.info('Image'.center(60, '-'))
        logging.info(image_path)

        # Text up to and including the first mention of this image.
        upto_image = text[:text.find(image_path) + len(image_path)]
        # Keep only what follows the last observation marker before it.
        segment_start = upto_image.rfind(observation_marker) + marker_len
        observation = upto_image[segment_start:]

        logging.info('Observation'.center(60, '-'))
        logging.info(observation)

        # One successfully executed observation is enough.
        has_failure = 'error:' in observation or 'Traceback' in observation
        if not has_failure:
            return True
    return False


# Language code -> judging prompt template. `eval_visualization_acc` looks up
# an item's 'lang' value here, using 'en' when the item has no 'lang' key.
eval_visual_prompt = {'zh': EVAL_VISUAL_PROMPT_ZH, 'en': EVAL_VISUAL_PROMPT_EN}


def _vis_accuracy_percent(right_count, total_count):
    """Return ``right_count / total_count`` as a percentage; 0.0 if the total is 0."""
    if not total_count:
        return 0.0
    return right_count / total_count * 100


def eval_visualization_acc(output_fname, model_name, judger_model_name='gpt-4-vision-preview'):
    """Score the visualization samples in ``output_fname`` with a judger model.

    Every item loaded from ``output_fname`` whose ``'tags'`` contain
    ``'visualization'`` gets a boolean ``'vis_acc'`` field. An item is
    correct when its ``'gen'`` text references at least one existing image,
    one of those images follows an error-free observation, and the judger's
    reply contains ``'right'`` (case-insensitive).

    Items whose ``'query'`` contains ``'<|im_end|>'`` are counted as
    "Visualization-Easy" (only the text before that marker is sent to the
    judger); all others are counted as "Visualization-Hard".

    Incorrect items are written with ``save_jsonl`` to
    ``<output_fname without extension>_vis_error.jsonl``.

    Args:
        output_fname: Path of the JSONL file to evaluate.
        model_name: Name passed to ``check_images_observation``.
        judger_model_name: ``'gpt-4-vision-preview'``, ``'qwen-vl-chat'`` or
            ``'qwen-vl-plus'``.

    Returns:
        The module-level ``visualization_code_correctness`` dict, updated
        with both accuracies in percent. A category with no samples is
        reported as ``0.0`` instead of raising ``ZeroDivisionError``.

    Raises:
        ValueError: If ``judger_model_name`` is not supported.
    """
    if judger_model_name == 'gpt-4-vision-preview':
        judger_model = None
    elif judger_model_name in ['qwen-vl-chat', 'qwen-vl-plus']:
        if judger_model_name == 'qwen-vl-chat':
            logging.warning('In this benchmark of version 20231206, `Qwen-vl-chat` is no longer used as the '
                            'evaluation model for `Visualization` task.. If you insist on using it, '
                            'the evaluation results might differ from the official results.')
        judger_model = get_model(judger_model_name)
    else:
        # ValueError subclasses Exception, so existing handlers still work.
        raise ValueError("Not supported judger model.")

    one_action, one_action_right = 0, 0
    zero_action, zero_action_right = 0, 0

    data_list = load_jsonl(output_fname)
    for item in data_list:
        if 'visualization' not in item['tags']:
            continue

        item['vis_acc'] = False
        is_one_action = '<|im_end|>' in item['query']
        if is_one_action:
            one_action += 1
            prompt = item['query'].split('<|im_end|>')[0]
        else:
            zero_action += 1
            prompt = item['query']

        images = extract_images(item['gen'])
        if not images or not check_images_observation(item['gen'], images, model_name):
            continue

        input_prompt = eval_visual_prompt[item.get('lang', 'en')]
        format_prompt = input_prompt.format(query=prompt)
        output = judger_model_inference(judger_model_name, judger_model, images, format_prompt)
        if 'right' not in output.lower():
            continue

        item['vis_acc'] = True
        if is_one_action:
            one_action_right += 1
        else:
            zero_action_right += 1

    # Compute each accuracy once; an empty category yields 0.0.
    all_count = zero_action + one_action
    all_right = zero_action_right + one_action_right
    hard_acc = _vis_accuracy_percent(zero_action_right, zero_action)
    easy_acc = _vis_accuracy_percent(one_action_right, one_action)
    all_acc = _vis_accuracy_percent(all_right, all_count)

    logging.info('*' * 60)
    logging.info('{:^60}'.format('Visualization Acc.'))
    logging.info('*' * 60)
    logging.info(
        'Visualization-Hard count={}, Visualization-Hard right count={}, Visualization-Hard acc={:.2f}'
        .format(zero_action, zero_action_right, hard_acc))
    logging.info(
        'Visualization-Easy count={}, Visualization-Easy right count={}, Visualization-Easy acc={:.2f}'
        .format(one_action, one_action_right, easy_acc))
    logging.info('all count={}, all right={}, all acc={:.2f}'.format(
        all_count, all_right, all_acc))

    visualization_code_correctness['visualization-hard'] = hard_acc
    visualization_code_correctness['visualization-easy'] = easy_acc

    error_data_list = [
        item for item in data_list
        if 'visualization' in item['tags'] and not item['vis_acc']
    ]
    error_data_output_fname = os.path.splitext(
        output_fname)[0] + '_vis_error.jsonl'
    save_jsonl(error_data_list, error_data_output_fname)

    return visualization_code_correctness