# Accuracy evaluation for the visualization tasks of the code interpreter benchmark:
# a judger vision-language model decides whether each generated figure matches its query.
import base64
import logging
import os
import re

import torch

from config import get_model, get_react_parser
from utils.data_utils import load_jsonl, save_jsonl

torch.manual_seed(1234)

EVAL_VISUAL_PROMPT_ZH = """请判断图片是否与下面的[问题]一致,如果一致则回复“right”,不一致则回复“wrong”。
[问题]:{query}
"""

EVAL_VISUAL_PROMPT_EN = """Please judge whether the image is consistent with the [Question] below, if it is consistent then reply "right", if not then reply "wrong".
[Question]: {query}
"""

visualization_code_correctness = {
    'visualization-hard': None,
    'visualization-easy': None,
}


def encode_image(image_path):
    # Read a local image file and return its base64-encoded string.
    with open(image_path, 'rb') as image_file:
        return base64.b64encode(image_file.read()).decode('utf-8')


def judger_model_inference(judger_model_name, judger_model, imgs=[], prompt=''):
    output = ''
    if judger_model_name == 'gpt-4-vision-preview':
        logging.warning('This is an example of `gpt-4-vision-preview`. '
                        'Please set the API key and use according to your actual situation.')
        from openai import OpenAI
        client = OpenAI()
        content_list = [{'type': 'text', 'text': prompt}]
        input_images = []
        for img in imgs:
            # Local images are sent as base64-encoded data URLs; remote URLs are passed through.
            if 'http' not in img:
                base64_image = encode_image(img)
                img = f'data:image/jpeg;base64,{base64_image}'
            input_images.append({'type': 'image_url', 'image_url': img})
        content_list.extend(input_images)
        response = client.chat.completions.create(
            model='gpt-4-vision-preview',
            messages=[{
                'role': 'user',
                'content': content_list,
            }],
            max_tokens=300,
        )
        output = response.choices[0].message.content
    elif judger_model_name in ['qwen-vl-plus', 'qwen-vl-chat']:
        inputs = []
        for img in imgs:
            if 'http' not in img and judger_model_name == 'qwen-vl-plus':
                img = 'file://' + img
            inputs.append({'image': img})
        inputs.append({'text': prompt})
        logging.info('Eval'.center(60, '-'))
        logging.info(inputs)
        output = judger_model.generate(inputs)
        logging.info(output)
        logging.info('=' * 60)
    return output


def extract_images(text):
    # Collect the local file paths of figures referenced as ![fig-*](path) in the generated text.
    regex = re.compile(r'!\[fig-(.+)\]\((.+)\)')
    results = re.findall(regex, text)
    images = []
    for res in results:
        assert len(res) == 2
        if os.path.exists(res[1]):
            images.append(res[1])
    return images


def check_images_observation(text, images, model_name):
    start_flag = get_react_parser(model_name).observation
    for image in images:
        logging.info('Image'.center(60, '-'))
        logging.info(image)

        end_idx = text.find(image)
        tmp_text = text[:end_idx + len(image)]
        start_idx = tmp_text.rfind(start_flag)
        check_text = tmp_text[start_idx + len(start_flag):]

        logging.info('Observation'.center(60, '-'))
        logging.info(check_text)

        # As long as there exists one correctly executed observation, we consider it `True`.
        if 'error:' not in check_text and 'Traceback' not in check_text:
            return True
    return False


eval_visual_prompt = {'zh': EVAL_VISUAL_PROMPT_ZH, 'en': EVAL_VISUAL_PROMPT_EN}


def eval_visualization_acc(output_fname, model_name, judger_model_name='gpt-4-vision-preview'):
    if judger_model_name == 'gpt-4-vision-preview':
        judger_model = None
    elif judger_model_name in ['qwen-vl-chat', 'qwen-vl-plus']:
        if judger_model_name == 'qwen-vl-chat':
            logging.warning('In the 20231206 version of this benchmark, `Qwen-vl-chat` is no longer used as the '
                            'evaluation model for the `Visualization` task. If you insist on using it, '
                            'the evaluation results might differ from the official results.')
        judger_model = get_model(judger_model_name)
    else:
        raise Exception('Not supported judger model.')

    one_action, one_action_right = 0, 0
    zero_action, zero_action_right = 0, 0

    data_list = load_jsonl(output_fname)
    for item in data_list:
        if 'visualization' not in item['tags']:
            continue

        item['vis_acc'] = False
        if '<|im_end|>' in item['query']:
            one_action += 1
            prompt = item['query'].split('<|im_end|>')[0]
        else:
            zero_action += 1
            prompt = item['query']

        images = extract_images(item['gen'])
        if images and check_images_observation(item['gen'], images, model_name):
            input_prompt = eval_visual_prompt[item.get('lang', 'en')]
            format_prompt = input_prompt.format(query=prompt)
            output = judger_model_inference(judger_model_name, judger_model, images, format_prompt)
            if 'right' in output.lower():
                item['vis_acc'] = True
                if '<|im_end|>' in item['query']:
                    one_action_right += 1
                else:
                    zero_action_right += 1

    logging.info('*' * 60)
    logging.info('{:^60}'.format('Visualization Acc.'))
    logging.info('*' * 60)
    logging.info('Visualization-Hard count={}, Visualization-Hard right count={}, Visualization-Hard acc={:.2f}'.format(
        zero_action, zero_action_right, zero_action_right / zero_action * 100))
    logging.info('Visualization-Easy count={}, Visualization-Easy right count={}, Visualization-Easy acc={:.2f}'.format(
        one_action, one_action_right, one_action_right / one_action * 100))
    logging.info('all count={}, all right={}, all acc={:.2f}'.format(
        zero_action + one_action, zero_action_right + one_action_right,
        (zero_action_right + one_action_right) / (zero_action + one_action) * 100))

    visualization_code_correctness['visualization-hard'] = zero_action_right / zero_action * 100
    visualization_code_correctness['visualization-easy'] = one_action_right / one_action * 100

    error_data_list = [
        item for item in data_list
        if 'visualization' in item['tags'] and not item['vis_acc']
    ]
    error_data_output_fname = os.path.splitext(output_fname)[0] + '_vis_error.jsonl'
    save_jsonl(error_data_list, error_data_output_fname)

    return visualization_code_correctness
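

# Example usage: a minimal sketch of how eval_visualization_acc might be invoked.
# The file path and model names below are illustrative assumptions, not values
# defined by this module; substitute the JSONL produced by your own inference run
# and the model under evaluation.
if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)
    # 'outputs/model_outputs.jsonl' and 'qwen-14b-chat' are placeholder values.
    scores = eval_visualization_acc(
        output_fname='outputs/model_outputs.jsonl',
        model_name='qwen-14b-chat',
        judger_model_name='gpt-4-vision-preview',
    )
    print(scores)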