import logging
import os
import re
import base64
import torch
from config import get_model, get_react_parser
from utils.data_utils import load_jsonl, save_jsonl

torch.manual_seed(1234)

# The Chinese prompt below asks the judger to reply "right" if the image is
# consistent with the [问题] (Question) and "wrong" otherwise; it parallels the
# English prompt that follows.
EVAL_VISUAL_PROMPT_ZH = """请判断图片是否与下面的[问题]一致,如果一致则回复“right”,不一致则回复“wrong”。
[问题]:{query}
"""

EVAL_VISUAL_PROMPT_EN = """Please judge whether the image is consistent with the [Question] below, if it is consistent then reply "right", if not then reply "wrong".
[Question]: {query}
"""

visualization_code_correctness = {
    'visualization-hard': None,
    'visualization-easy': None,
}


def encode_image(image_path):
    # Read the image file and return its contents as a base64-encoded string.
    with open(image_path, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode('utf-8')


def judger_model_inference(judger_model_name, judger_model, imgs=None, prompt=''):
    imgs = imgs or []
    output = ""
    if judger_model_name == 'gpt-4-vision-preview':
        logging.warning("This is an example of `gpt-4-vision-preview`. "
                        "Please set the API key and use according to your actual situation.")
        from openai import OpenAI
        client = OpenAI()
        content_list = [{"type": "text", "text": prompt}]
        for img in imgs:
            # Local images are inlined as base64 data URLs; http(s) URLs are passed through.
            if 'http' not in img:
                base64_image = encode_image(img)
                img = f"data:image/jpeg;base64,{base64_image}"
            content_list.append({"type": "image_url", "image_url": {"url": img}})
        response = client.chat.completions.create(
            model="gpt-4-vision-preview",
            messages=[
                {
                    "role": "user",
                    "content": content_list,
                }
            ],
            max_tokens=300,
        )
        output = response.choices[0].message.content
    elif judger_model_name in ['qwen-vl-plus', 'qwen-vl-chat']:
        inputs = []
        for img in imgs:
            if 'http' not in img and judger_model_name == 'qwen-vl-plus':
                img = "file://" + img
            inputs.append({'image': img})
        inputs.append({'text': prompt})

        logging.info('Eval'.center(60, '-'))
        logging.info(inputs)
        output = judger_model.generate(inputs)
    logging.info(output)
    logging.info('=' * 60)
    return output
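
# Hypothetical usage (the image path is made up): with the GPT-4V judger, the model
# object is unused and can be None, e.g.
#     judger_model_inference('gpt-4-vision-preview', None,
#                            imgs=['workspace/output_plot.png'],
#                            prompt=EVAL_VISUAL_PROMPT_EN.format(query='Plot y = x ** 2'))
# The return value is the judger's raw reply, expected to contain "right" or "wrong".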


def extract_images(text):
    regex = re.compile(r'!\[fig-(.+)\]\((.+)\)')
    results = re.findall(regex, text)
    images = []
    for res in results:
        assert len(res) == 2
        if os.path.exists(res[1]):
            images.append(res[1])
    return images
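
# A hypothetical usage example for `extract_images`; the file name below is made up
# for illustration, and only paths that actually exist on disk are returned:
#
#     text = "Observation: the figure has been saved.\n![fig-001](workspace/output_plot.png)"
#     extract_images(text)  # -> ['workspace/output_plot.png'], if that file exists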


def check_images_observation(text, images, model_name):
    start_flag = get_react_parser(model_name).observation
    for image in images:
        logging.info('Image'.center(60, '-'))
        logging.info(image)

        end_idx = text.find(image)
        tmp_text = text[:end_idx + len(image)]
        start_idx = tmp_text.rfind(start_flag)
        check_text = tmp_text[start_idx + len(start_flag):]

        logging.info('Observation'.center(60, '-'))
        logging.info(check_text)

        # If at least one observation executed without errors, we consider it `True`.
        if 'error:' not in check_text and 'Traceback' not in check_text:
            return True
    return False


eval_visual_prompt = {'zh': EVAL_VISUAL_PROMPT_ZH, 'en': EVAL_VISUAL_PROMPT_EN}


def eval_visualization_acc(output_fname, model_name, judger_model_name='gpt-4-vision-preview'):
    if judger_model_name == 'gpt-4-vision-preview':
        judger_model = None
    elif judger_model_name in ['qwen-vl-chat', 'qwen-vl-plus']:
        if judger_model_name == 'qwen-vl-chat':
            logging.warning('In the 20231206 version of this benchmark, `Qwen-vl-chat` is no longer used as the '
                            'evaluation model for the `Visualization` task. If you insist on using it, '
                            'the evaluation results may differ from the official results.')
        judger_model = get_model(judger_model_name)
    else:
        raise Exception("Not supported judger model.")

    one_action, one_action_right = 0, 0
    zero_action, zero_action_right = 0, 0

    data_list = load_jsonl(output_fname)
    for item in data_list:
        if 'visualization' not in item['tags']:
            continue

        item['vis_acc'] = False
        if '<|im_end|>' in item['query']:
            one_action += 1
            prompt = item['query'].split('<|im_end|>')[0]
        else:
            zero_action += 1
            prompt = item['query']

        images = extract_images(item['gen'])

        if images and check_images_observation(item['gen'], images,
                                               model_name):
            input_prompt = eval_visual_prompt[item.get('lang', 'en')]
            format_prompt = input_prompt.format(query=prompt)
            output = judger_model_inference(judger_model_name, judger_model, images, format_prompt)
            if 'right' in output.lower():
                item['vis_acc'] = True
                if '<|im_end|>' in item['query']:
                    one_action_right += 1
                else:
                    zero_action_right += 1

    logging.info('*' * 60)
    logging.info('{:^60}'.format('Visualization Acc.'))
    logging.info('*' * 60)
    logging.info(
        'Visualization-Hard count={}, Visualization-Hard right count={}, Visualization-Hard acc={:.2f}'
        .format(zero_action, zero_action_right,
                zero_action_right / zero_action * 100))
    logging.info(
        'Visualization-Easy count={}, Visualization-Easy right count={}, Visualization-Easy acc={:.2f}'
        .format(one_action, one_action_right,
                one_action_right / one_action * 100))
    logging.info('all count={}, all right={}, all acc={:.2f}'.format(
        zero_action + one_action, zero_action_right + one_action_right,
        (zero_action_right + one_action_right) / (zero_action + one_action) *
        100))

    visualization_code_correctness[
        'visualization-hard'] = zero_action_right / zero_action * 100
    visualization_code_correctness[
        'visualization-easy'] = one_action_right / one_action * 100

    error_data_list = [
        item for item in data_list
        if 'visualization' in item['tags'] and not item['vis_acc']
    ]
    error_data_output_fname = os.path.splitext(
        output_fname)[0] + '_vis_error.jsonl'
    save_jsonl(error_data_list, error_data_output_fname)

    return visualization_code_correctness
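

# A minimal usage sketch, not an official entry point of this benchmark. The file
# path and model names below are placeholders; it assumes `output_fname` points to
# a JSONL file produced by the inference stage, with the `query`, `gen`, and `tags`
# fields consumed above.
if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)
    results = eval_visualization_acc(
        output_fname='output_data/example_infer_results.jsonl',  # placeholder path
        model_name='qwen-72b-chat',  # placeholder: the model whose outputs are evaluated
        judger_model_name='gpt-4-vision-preview',
    )
    logging.info(results)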