# LLMBB-Agent/benchmark/metrics/code_execution.py
# Hosted-copy page metadata: vlff李飞飞 · "update md" · commit 2319518 · 8.7 kB
import importlib
import logging
import os
import subprocess
import sys

import func_timeout
from func_timeout import func_set_timeout

from config import get_react_parser
from utils.code_utils import extract_code, replace_upload_fname
from utils.data_utils import load_jsonl, save_jsonl
# Prelude prepended to every generated script before it is exec'd.
# NOTE: the exec runs in this process, so the os.chdir below changes the
# working directory of the whole evaluation run (it is skipped once the cwd
# already contains 'upload_file'). The body of the `if` must stay indented,
# otherwise every exec'd script fails with an IndentationError.
pre_load = """
import os
if 'upload_file' not in os.getcwd():
    os.chdir("./upload_file/")
import seaborn as sns
import matplotlib
# matplotlib.use('Agg')
import matplotlib.pyplot as plt
plt.ion()
import numpy as np
import pandas as pd
from sympy import Eq, symbols, solve
import re
import json
import math
"""
# Per-tag execution settings consumed by eval_code_execution_rate:
#   timelimit          -- run the script through the 10 s timeout wrapper
#   extract_first_code -- stop after the first "Action Input" code block
tags_config = dict(
    visualization=dict(timelimit=True, extract_first_code=True),
    math=dict(timelimit=True, extract_first_code=False),
    general=dict(timelimit=False, extract_first_code=True),
)

# Execution rate (percent) per task tag; populated by log_result.
code_executability = dict.fromkeys(('math', 'visualization', 'general'))
@func_set_timeout(10)
def exec_limit_time(text):
    """Execute the Python source ``text`` under a 10-second time limit.

    The limit is enforced by ``func_set_timeout``; when it is exceeded the
    caller sees ``func_timeout.exceptions.FunctionTimedOut``. The code runs
    in a fresh namespace whose only pre-set name is ``text`` itself (the
    same contents ``locals()`` has at this point).
    """
    namespace = {'text': text}
    exec(text, namespace)
def exec_code(text, timelimit=False):
    """Execute the Python source ``text``.

    Args:
        text: code to run.
        timelimit: when true, delegate to ``exec_limit_time`` (10 s cap);
            otherwise run directly in a namespace seeded with this
            function's locals (``text`` and ``timelimit``).

    Exceptions raised by the executed code propagate to the caller.
    """
    if not timelimit:
        exec(text, locals())
        return
    exec_limit_time(text)
def postprocess_code(gen_code, line):
    """Turn extracted model code into a standalone script ready for ``exec``.

    Steps, in order:
      * if the sample's query contains ``<|im_start|>``, the Action Input
        code found in the query is prepended (presumably so state from
        earlier turns is available to the generated code);
      * the code and ``line['input_file_path']`` (or ``[]``) are passed to
        ``replace_upload_fname``;
      * a trailing ``solution()`` call is appended when the code defines it;
      * matplotlib figures are paused/closed after ``plt.show()`` and closed
        after seaborn plotting;
      * the shared ``pre_load`` prelude is prepended.

    Args:
        gen_code: code extracted from the model response.
        line: benchmark sample dict; must contain ``'query'``.

    Returns:
        The complete script as a string.
    """
    query = line['query']
    if '<|im_start|>' in query:
        gen_code = get_action_input_code(query) + gen_code

    upload_fnames = []
    if line and 'input_file_path' in line:
        upload_fnames = line['input_file_path']
    gen_code = replace_upload_fname(gen_code, upload_fnames)

    if 'def solution()' in gen_code:
        gen_code += '\nsolution()\n'
    # `plt` is imported by pre_load, so these cleanup lines always resolve.
    if 'plt.show()' in gen_code:
        gen_code += "\nplt.pause(1)\nplt.close('all')\n"
    if 'sns.' in gen_code and 'plot' in gen_code:
        gen_code += "\nplt.close('all')\n"

    return pre_load + gen_code
def get_action_input_code(text,
                          model_name='qwen-14b-chat',
                          extract_first_code=False):
    """Collect the code of the ReAct "Action Input"s found in ``text``.

    The parser from ``get_react_parser(model_name)`` is asked repeatedly for
    the first action input of the remaining text. After each hit, scanning
    resumes right after that occurrence.

    Args:
        text: model response (or prompt) to scan.
        model_name: selects the ReAct parser.
        extract_first_code: if true, stop after the first action input.

    Returns:
        For each action input, in order: a ``# concat`` line, the result of
        ``extract_code`` on it, and a newline. Returns ``''`` when none is
        found.
    """
    react_parser = get_react_parser(model_name)
    action_inputs = []
    remaining = text
    while True:
        action_input = react_parser.get_first_action_input(remaining)
        if not action_input:
            break
        action_inputs.append(action_input)
        # Resume after the FIRST occurrence only. Splitting on every
        # occurrence would discard everything after a repeated, identical
        # action input. If the parsed text is not present verbatim, stop
        # instead of raising or looping forever.
        _, found, remaining = remaining.partition(action_input)
        if not found or not remaining or extract_first_code:
            break
    return ''.join('# concat\n' + extract_code(action_input) + '\n'
                   for action_input in action_inputs)
def _resolve_exec_settings(tags_list, timelimit, extract_first_code):
    """Return the ``(timelimit, extract_first_code)`` pair for one sample.

    Starts from the caller's defaults for every sample, so settings picked
    for an earlier sample never leak into a later one. Each tag with an
    entry in ``tags_config`` overrides both values; when several configured
    tags are present, the last one wins. Tags without an entry (such as
    ``'all_ci'``) leave the values unchanged.
    """
    for cur_tag in tags_list:
        tag_cfg = tags_config.get(cur_tag)
        if tag_cfg is not None:
            timelimit = tag_cfg['timelimit']
            extract_first_code = tag_cfg['extract_first_code']
    return timelimit, extract_first_code


def _missing_module_name(ex):
    """Return the top-level module name an ImportError refers to, or ''."""
    name = getattr(ex, 'name', None)
    if not name:
        # Fall back to the first quoted token of the message,
        # e.g. "No module named 'foo'".
        parts = str(ex).split("'")
        name = parts[1].strip() if len(parts) > 1 else ''
    top_level = name.split('.')[0]
    # Real module names are identifiers; anything else cannot be installed
    # and must not be handed to pip.
    return top_level if top_level.isidentifier() else ''


def _install_package(package):
    """Best-effort pip install of ``package`` into the running interpreter."""
    # NOTE(review): the name comes from model-generated code, so this installs
    # arbitrary packages from the index; only run the benchmark sandboxed.
    subprocess.run([sys.executable, '-m', 'pip', 'install', package],
                   check=False)
    # Let the import system notice modules installed after startup.
    importlib.invalidate_caches()
    logging.info(f'Automatic installation: {package}')


def _run_generated_code(gen_code, line, timelimit, attempted_packages):
    """Execute ``gen_code`` and record the outcome on ``line``.

    On success sets ``line['executable_code'] = True``. Otherwise stores the
    error text in ``line['code_error_info']``. An ImportError for a module
    not yet in ``attempted_packages`` triggers an install and a re-run of
    the whole script.
    """
    while True:
        try:
            exec_code(gen_code, timelimit=timelimit)
        except func_timeout.exceptions.FunctionTimedOut as ex:
            line['code_error_info'] = str(ex)
            return
        except ImportError as ex:  # also covers ModuleNotFoundError
            package = _missing_module_name(ex)
            if package and package not in attempted_packages:
                attempted_packages.add(package)
                _install_package(package)
                continue
            line['code_error_info'] = str(ex)
            return
        except Exception as ex:
            line['code_error_info'] = str(ex)
            return
        line['executable_code'] = True
        return


def _double_check_against_observation(line, gen_code, model_name):
    """Warn when local execution disagrees with the recorded IPython result.

    The first observation parsed from ``line['gen']`` counts as a failure
    when it contains ``'error:'``.
    """
    # `or ''` guards a parser that may return None when there is no
    # observation (TODO confirm against get_first_observation).
    observation = get_react_parser(model_name).get_first_observation(
        line['gen']) or ''
    ipython_failed = 'error:' in observation
    if line['executable_code'] and ipython_failed:
        logging.warning(
            'The code executes correctly, but it has an error in IPython!')
        logging.warning(f'Code:\n{gen_code}')
        logging.warning(f'IPython error info:\n{observation}')
        logging.info('=' * 60)
    elif not line['executable_code'] and not ipython_failed:
        logging.warning(
            'The code has an execution error, but it runs correctly in IPython!'
        )
        logging.warning(f'Code:\n{gen_code}')
        logging.warning(f"Exec error info:\n{line['code_error_info']}")
        logging.warning(f'IPython observation:\n{observation}')
        logging.info('=' * 60)


def eval_code_execution_rate(output_fname,
                             tag='all_ci',
                             model_name='qwen-14b-chat',
                             timelimit=False,
                             extract_first_code=False):
    """Execute the code in each model response and report execution rates.

    Every record in the JSONL file ``output_fname`` gets an ``'idx'``.
    Records whose comma-separated ``'tags'`` contain ``tag`` are evaluated:
    their Action Input code is extracted, post-processed and executed. The
    keys ``'executable_code'``, ``'missing_code'``, ``'code_error_info'``
    and ``'code'`` are set on each evaluated record.

    Args:
        output_fname: path of the JSONL file with model outputs.
        tag: only records carrying this tag are evaluated.
        model_name: selects the ReAct parser.
        timelimit: default for running under the 10 s limit; overridden
            per record by tags found in ``tags_config``.
        extract_first_code: default for keeping only the first action input;
            overridden like ``timelimit``.

    Returns:
        The module-level ``code_executability`` dict, as updated by
        ``log_result``.

    Side effects:
        Writes failing / code-less records to
        ``<output_fname stem>_exec_error.jsonl``. May pip-install missing
        modules. Restores the working directory that the executed code may
        have changed.
    """
    data_list = load_jsonl(output_fname)
    # Resolve now: pre_load (and generated code) may chdir during execution.
    error_data_output_fname = os.path.abspath(
        os.path.splitext(output_fname)[0] + '_exec_error.jsonl')
    start_cwd = os.getcwd()
    attempted_packages = set()
    evaluated = []
    try:
        for line_id, line in enumerate(data_list):
            line['idx'] = line_id
            tags_list = line['tags'].split(',')
            if tag not in tags_list:
                continue
            evaluated.append(line)
            line_timelimit, line_extract_first = _resolve_exec_settings(
                tags_list, timelimit, extract_first_code)

            line['executable_code'] = False
            line['missing_code'] = False
            line['code_error_info'] = ''

            gen_code = get_action_input_code(
                line['gen'],
                model_name=model_name,
                extract_first_code=line_extract_first)
            if not gen_code:
                line['missing_code'] = True
                line['code'] = ''
                line['code_error_info'] = 'missing code'
                continue
            line['code'] = gen_code
            gen_code = postprocess_code(gen_code, line)
            _run_generated_code(gen_code, line, line_timelimit,
                                attempted_packages)
            _double_check_against_observation(line, gen_code, model_name)
    finally:
        # Undo any chdir done by the exec'd code so callers are unaffected.
        os.chdir(start_cwd)

    # Only evaluated records carry the result keys read below.
    error_data_list = [
        item for item in evaluated
        if not item['executable_code'] or item['missing_code']
    ]
    save_jsonl(error_data_list, error_data_output_fname)
    log_result(evaluated)
    return code_executability
def _percent(numerator, denominator):
    """Return ``numerator / denominator * 100``, or NaN if ``denominator`` is 0."""
    if denominator == 0:
        return float('nan')
    return numerator / denominator * 100


def log_result(data_list, verbose=True):
    """Log per-sample details and per-tag code-execution statistics.

    Args:
        data_list: evaluated sample dicts. Each needs ``'tags'``
            (comma-separated), ``'missing_code'``, ``'executable_code'`` and
            ``'code_error_info'``. When ``verbose``, each also needs
            ``'idx'``, ``'query'``, ``'gen'`` and ``'code'``.
        verbose: also log every sample and, per tag, the sorted error list.

    Side effects:
        For every tag other than ``'all_ci'``, stores the execution rate
        (percent of that tag's samples whose code ran) in the module-level
        ``code_executability`` dict. A rate whose denominator is zero is
        logged as ``nan`` instead of raising ZeroDivisionError.
    """
    if verbose:
        logging.info('*' * 60)
        logging.info('{:^60}'.format('Detail'))
        logging.info('*' * 60)
        for line_id, line in enumerate(data_list):
            logging.info(f'Question {line_id}'.center(60, '='))
            logging.info(line['query'])
            logging.info(f'Generated {line_id}'.center(60, '-'))
            logging.info('\n' + line['gen'])
            logging.info(f'Code {line_id}'.center(60, '-'))
            logging.info('\n' + line['code'])
            logging.info(f'Exec Result {line_id}'.center(60, '-'))
            prefix_info = ('Exec Success'
                           if line['executable_code'] else 'Exec Error: ')
            logging.info(prefix_info + line['code_error_info'])
    logging.info('=' * 60)
    logging.info('{:^60}'.format('Code Execuation Rate'))
    logging.info('=' * 60)
    # Sorted so the per-task sections appear in a stable order across runs.
    involved_tags = sorted(
        {cur_tag for line in data_list for cur_tag in line['tags'].split(',')})
    for key in involved_tags:
        logging.info(f'task: {key}'.center(60, '='))
        # Exact tag match: a substring test on the raw 'tags' string would
        # also count samples whose tag merely contains `key`.
        key_item_list = [
            item for item in data_list if key in item['tags'].split(',')
        ]
        all_count = len(key_item_list)  # >= 1: key came from these samples
        missing_code_count = sum(
            1 for item in key_item_list if item['missing_code'])
        executable_code_count = sum(
            1 for item in key_item_list if item['executable_code'])
        non_executable_count = (all_count - missing_code_count -
                                executable_code_count)
        logging.info(f'All Test: {all_count}')
        logging.info(f'Missing Code: {missing_code_count}')
        logging.info(f'Predict Exec Success: {executable_code_count}')
        # Zero when every sample of this tag is missing code -> logs nan.
        logging.info('Codes available && Execution Rate: {:.2f}'.format(
            _percent(executable_code_count, all_count - missing_code_count)))
        logging.info('Execution Rate: {:.2f}'.format(
            _percent(executable_code_count, all_count)))
        logging.info('Non-executable rate: {:.2f}'.format(
            _percent(non_executable_count, all_count)))
        logging.info('Missing code rate: {:.2f}'.format(
            _percent(missing_code_count, all_count)))
        if key != 'all_ci':
            code_executability[key] = _percent(executable_code_count,
                                               all_count)
        if verbose:
            logging.info('Error List: ')
            error_list = sorted(
                ((item['idx'], item['code_error_info'])
                 for item in key_item_list if item['code_error_info']),
                key=lambda x: x[1])
            for x in error_list:
                logging.info(x)