import base64
import logging
import os
import re

import torch

from config import get_model, get_react_parser
from utils.data_utils import load_jsonl, save_jsonl
torch.manual_seed(1234)
EVAL_VISUAL_PROMPT_ZH = """请判断图片是否与下面的[问题]一致,如果一致则回复“right”,不一致则回复“wrong”。
[问题]:{query}
"""
EVAL_VISUAL_PROMPT_EN = """Please judge whether the image is consistent with the [Question] below, if it is consistent then reply "right", if not then reply "wrong".
[Question]: {query}
"""
visualization_code_correctness = {
    'visualization-hard': None,
    'visualization-easy': None,
}
def encode_image(image_path):
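    """Read a local image file and return its contents base64-encoded as a UTF-8 string."""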
    with open(image_path, "rb") as image_file:
        a = base64.b64encode(image_file.read()).decode('utf-8')
    return a
def judger_model_inference(judger_model_name, judger_model, imgs=[], prompt=''):
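    """Ask the judger model whether the generated images match the query prompt.

    Supports `gpt-4-vision-preview` (queried through the OpenAI client created here)
    and `qwen-vl-plus` / `qwen-vl-chat` (queried through the `judger_model` wrapper
    returned by `get_model`). Returns the judger's text reply.
    """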
output = ""
if judger_model_name == 'gpt-4-vision-preview':
logging.warning("This is an example of `gpt-4-vision-preview`. "
"Please set the API key and use according to your actual situation.")
from openai import OpenAI
client = OpenAI()
content_list = []
content_list.append({"type": "text", "text": prompt})
input_images = []
for img in imgs:
if 'http' not in img:
base64_image = encode_image(img)
img = f"data:image/jpeg;base64,{base64_image}"
input_images.append({"type": "image_url", 'image_url': img})
content_list.extend(input_images)
response = client.chat.completions.create(
model="gpt-4-vision-preview",
messages=[
{
"role": "user",
"content": content_list,
}
],
max_tokens=300,
)
output = response.choices[0]
elif judger_model_name in ['qwen-vl-plus', 'qwen-vl-chat']:
inputs = []
for img in imgs:
if 'http' not in img and judger_model_name == 'qwen-vl-plus':
img = "file://" + img
inputs.append({'image': img})
inputs.append({'text': prompt})
logging.info('Eval'.center(60, '-'))
logging.info(inputs)
output = judger_model.generate(inputs)
logging.info(output)
logging.info('=' * 60)
return output
def extract_images(text):
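    """Collect local image paths referenced as `![fig-...](path)` in the generated text."""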
    regex = re.compile(r'!\[fig-(.+)\]\((.+)\)')
    results = re.findall(regex, text)
    images = []
    for res in results:
        assert len(res) == 2
        if os.path.exists(res[1]):
            images.append(res[1])
    return images
def check_images_observation(text, images, model_name):
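    """Return True if at least one image in `text` follows an observation block that contains no error or traceback."""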
    start_flag = get_react_parser(model_name).observation
    for image in images:
        logging.info('Image'.center(60, '-'))
        logging.info(image)
        end_idx = text.find(image)
        tmp_text = text[:end_idx + len(image)]
        start_idx = tmp_text.rfind(start_flag)
        check_text = tmp_text[start_idx + len(start_flag):]
        logging.info('Observation'.center(60, '-'))
        logging.info(check_text)
        # As long as one observation executed without errors, we consider the check passed.
        if 'error:' not in check_text and 'Traceback' not in check_text:
            return True
    return False
eval_visual_prompt = {'zh': EVAL_VISUAL_PROMPT_ZH, 'en': EVAL_VISUAL_PROMPT_EN}
def eval_visualization_acc(output_fname, model_name, judger_model_name='gpt-4-vision-preview'):
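    """Evaluate visualization accuracy for the generations stored in `output_fname`.

    Queries containing '<|im_end|>' count as Visualization-Easy (one action); the rest
    count as Visualization-Hard (zero action). Mis-judged items are saved to a
    `*_vis_error.jsonl` file next to `output_fname`, and the per-category accuracies
    (in %) are returned via `visualization_code_correctness`.
    """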
    if judger_model_name == 'gpt-4-vision-preview':
        judger_model = None
    elif judger_model_name in ['qwen-vl-chat', 'qwen-vl-plus']:
        if judger_model_name == 'qwen-vl-chat':
            logging.warning('In version 20231206 of this benchmark, `Qwen-VL-Chat` is no longer used as the '
                            'evaluation model for the `Visualization` task. If you insist on using it, '
                            'the evaluation results might differ from the official results.')
        judger_model = get_model(judger_model_name)
    else:
        raise Exception('Unsupported judger model.')
    one_action, one_action_right = 0, 0
    zero_action, zero_action_right = 0, 0
    data_list = load_jsonl(output_fname)
    for item in data_list:
        if 'visualization' not in item['tags']:
            continue
        item['vis_acc'] = False
        if '<|im_end|>' in item['query']:
            one_action += 1
            prompt = item['query'].split('<|im_end|>')[0]
        else:
            zero_action += 1
            prompt = item['query']
        images = extract_images(item['gen'])
        if images and check_images_observation(item['gen'], images, model_name):
            input_prompt = eval_visual_prompt[item.get('lang', 'en')]
            format_prompt = input_prompt.format(query=prompt)
            output = judger_model_inference(judger_model_name, judger_model, images, format_prompt)
            if 'right' in output.lower():
                item['vis_acc'] = True
                if '<|im_end|>' in item['query']:
                    one_action_right += 1
                else:
                    zero_action_right += 1
    logging.info('*' * 60)
    logging.info('{:^60}'.format('Visualization Acc.'))
    logging.info('*' * 60)
    logging.info(
        'Visualization-Hard count={}, Visualization-Hard right count={}, Visualization-Hard acc={:.2f}'
        .format(zero_action, zero_action_right,
                zero_action_right / zero_action * 100))
    logging.info(
        'Visualization-Easy count={}, Visualization-Easy right count={}, Visualization-Easy acc={:.2f}'
        .format(one_action, one_action_right,
                one_action_right / one_action * 100))
    logging.info('all count={}, all right={}, all acc={:.2f}'.format(
        zero_action + one_action, zero_action_right + one_action_right,
        (zero_action_right + one_action_right) / (zero_action + one_action) * 100))
    visualization_code_correctness['visualization-hard'] = zero_action_right / zero_action * 100
    visualization_code_correctness['visualization-easy'] = one_action_right / one_action * 100
    error_data_list = [
        item for item in data_list
        if 'visualization' in item['tags'] and not item['vis_acc']
    ]
    error_data_output_fname = os.path.splitext(output_fname)[0] + '_vis_error.jsonl'
    save_jsonl(error_data_list, error_data_output_fname)
    return visualization_code_correctness
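
# --- Usage sketch (illustrative only; the file and model names below are hypothetical) ---
# Assumes `outputs/qwen_out.jsonl` was produced by the inference stage and contains the
# fields used above ('query', 'gen', 'tags', and optionally 'lang').
#
#   acc = eval_visualization_acc('outputs/qwen_out.jsonl',
#                                model_name='qwen-14b-chat',
#                                judger_model_name='gpt-4-vision-preview')
#   # -> {'visualization-hard': <acc %>, 'visualization-easy': <acc %>}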