|
import os |
|
import subprocess |
|
from argparse import ArgumentParser |
|
import json |
|
from concurrent.futures import ThreadPoolExecutor |
|
from tqdm import tqdm |
|
import glob |
|
import tempfile |
|
import random |
|
from openllm_pass_rate_new_test import get_lean |
|
|
|
def wrapped_function(item):
    """Run one Lean snippet through ``lake exe repl`` and summarise the outcome.

    The text in ``item['cmd']`` is written to a fresh temporary ``.lean`` file
    and the REPL is sent a ``{"path": ..., "allTactics": true}`` request on
    stdin.

    Args:
        item: Mapping whose ``'cmd'`` entry holds the Lean source text.

    Returns:
        A dict with two keys:

        * ``'results'``: a one-element list.  The entry has status ``'pass'``
          (with the raw ``stdout``/``stderr`` text) when the process exited
          with code 0 and printed parseable JSON, otherwise status
          ``'nopass'`` with an ``'error'`` description.
        * ``'pass_rate'``: percentage of runs (here exactly one) whose JSON
          reply contained no ``"messages"`` key, i.e. ``0.0`` or ``100.0``.
    """
    results = []
    passed = 0
    total = 0

    # A unique file per call: a shared fixed name in the temp directory would
    # be overwritten when several callers run at the same time.
    fd, temp_file = tempfile.mkstemp(suffix=".lean")
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            f.write(item['cmd'])

        # json.dumps escapes the path correctly; the trailing newline mirrors
        # what ``echo`` used to append.
        request = json.dumps({"path": temp_file, "allTactics": True}) + "\n"

        try:
            # Argument list + stdin instead of a shell pipeline: nothing from
            # the path is interpreted by a shell.
            result = subprocess.run(
                ["lake", "exe", "repl"],
                input=request.encode("utf-8"),
                check=True,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
            )
            stdout = result.stdout.decode('utf-8')
            stderr = result.stderr.decode('utf-8')

            if "messages" not in json.loads(stdout):
                passed += 1

            # NOTE: 'pass' here only records that the process ran cleanly;
            # whether the reply carried messages is reflected in pass_rate.
            results.append({'stdout': stdout, 'stderr': stderr, 'status': 'pass'})
        except (subprocess.CalledProcessError, OSError, ValueError) as e:
            # CalledProcessError: non-zero exit.  OSError: ``lake`` could not
            # be started (a shell would have reported this as exit code 127).
            # ValueError: undecodable or non-JSON output.
            results.append({'error': str(e), 'status': 'nopass'})
        total += 1
    finally:
        os.unlink(temp_file)

    # The denominator is the number of runs, not passed + total.
    pass_rate = passed / total * 100

    return {'results': results, 'pass_rate': pass_rate}
|
|
|
|
|
|
|
|
|
|
|
|
|
def single(command_list, args):
    """Check every item sequentially with ``lake exe repl`` and save a report.

    Each item's ``'cmd'`` text is written to ``test/test.lean`` (overwritten on
    every iteration) and the REPL is asked to process that path.  The pass
    rate is printed and the full report is written to ``results.json`` in the
    current working directory.

    Args:
        command_list: Iterable of mappings, each with a ``'cmd'`` entry
            holding Lean source text.
        args: Unused by this function; kept for interface compatibility.

    Returns:
        None.
    """
    lean_path = "test/test.lean"
    # The request never changes because the file path is fixed; the trailing
    # newline mirrors what ``echo`` used to append.
    request = b'{"path": "test/test.lean", "allTactics": true}\n'

    results = []
    passed = 0
    total = 0

    # Create the working directory up front instead of failing on first write.
    os.makedirs(os.path.dirname(lean_path), exist_ok=True)

    for item in tqdm(command_list):
        with open(lean_path, "w", encoding='utf-8') as f:
            f.write(item['cmd'])

        try:
            result = subprocess.run(
                ["lake", "exe", "repl"],
                input=request,
                check=True,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
            )
            stdout = result.stdout.decode('utf-8')
            if "messages" not in json.loads(stdout):
                passed += 1
            stderr = result.stderr.decode('utf-8')
            # NOTE: 'pass' records that the process ran cleanly; replies that
            # carry messages are only excluded from the pass count.
            results.append({
                'stdout': stdout,
                'stderr': stderr,
                'status': 'pass'
            })
        except (subprocess.CalledProcessError, OSError, ValueError) as e:
            # Non-zero exit, ``lake`` not startable, or undecodable/non-JSON
            # output: record the failure and keep going with the next item.
            results.append({
                'error': str(e),
                'status': 'nopass'
            })
        total += 1

    # Guard against an empty command_list (previously ZeroDivisionError).
    pass_rate = passed / total * 100 if total else 0.0
    print(pass_rate)

    # ensure_ascii=False emits raw non-ASCII characters, so the file encoding
    # must be explicit rather than locale-dependent.
    with open('results.json', 'w', encoding='utf-8') as f:
        json.dump({'results': results, 'pass_rate': pass_rate}, f, indent=2, ensure_ascii=False)
|
|
|
|
|
def _error_string_pos(cmd, message):
    """Return the ``string_pos`` value recorded for a REPL error message.

    Starts from ``message['pos']['column'] - 1`` and walks upwards from the
    line above the reported one, adding each line's length plus one.  The walk
    stops after an empty line has been counted, or once line index 1 has been
    processed (index 0 is never visited).
    """
    lines = cmd.split('\n')
    start_line = message['pos']['line'] - 1
    offset = message['pos']['column'] - 1
    for line_n in range(start_line - 1, 0, -1):
        line_len = len(lines[line_n])
        offset += line_len + 1
        if not line_len:
            break
    return offset


def _classify_repl_reply(cmd, stdout, stderr):
    """Build the result entry for a REPL run that exited with code 0.

    ``stdout`` is the parsed JSON reply, ``stderr`` the decoded stderr text.
    """
    if stderr:
        # Anything on stderr counts as a failure with no usable position.
        return {'stdout': stdout, 'stderr': stderr, 'status': 'nopass', 'string_pos': 0}
    if "messages" in stdout:
        for message in stdout['messages']:
            if message['severity'] == 'error':
                # Only the first error message determines the position.
                return {'stdout': stdout, 'stderr': stderr, 'status': 'nopass',
                        'string_pos': _error_string_pos(cmd, message)}
    # No messages at all, or only non-error messages.
    return {'stdout': stdout, 'stderr': stderr, 'status': 'pass'}


def _verify_item(item, index, temp_dir, timeout):
    """Run every snippet in ``item['cmd']`` and store outcomes in ``item['results']``.

    The item is modified in place and also returned.
    """
    item['results'] = []

    for i, cmd in enumerate(item['cmd']):
        temp_file = os.path.join(temp_dir, f"{index}_test_{i}.lean")
        with open(temp_file, "w", encoding="utf-8") as f:
            f.write(cmd)

        # Trailing newline mirrors what ``echo`` used to append.
        request = json.dumps({"path": temp_file, "allTactics": True}) + "\n"

        try:
            # Running ``lake`` directly (no intermediate shell) means the
            # timeout kill is aimed at ``lake`` itself.
            result = subprocess.run(
                ["lake", "exe", "repl"],
                input=request.encode('utf-8'),
                check=True,
                timeout=timeout,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
            )
            stdout = json.loads(result.stdout.decode('utf-8'))
            stderr = result.stderr.decode('utf-8')
        except subprocess.TimeoutExpired as e:
            result_item = {'error': str(e), 'status': 'nopass_limit'}
        except (subprocess.CalledProcessError, OSError, ValueError) as e:
            # Non-zero exit, ``lake`` not startable, or undecodable/non-JSON
            # output.  Without this a single bad reply aborted the whole batch.
            result_item = {'error': str(e), 'status': 'nopass_error'}
        else:
            result_item = _classify_repl_reply(cmd, stdout, stderr)

        item['results'].append(result_item)
    return item


def _calculate_pass(result_list, k):
    """Return ``(pass@1, pass@k)`` as fractions of ``result_list``.

    An item counts for pass@1 when its first result has status ``'pass'`` and
    for pass@k when any of its first ``k`` results does.  Both are ``0`` for an
    empty ``result_list``.
    """
    if not result_list:
        return 0, 0

    pass_1_count = 0
    pass_k_count = 0
    for result in result_list:
        statuses = [entry.get('status') for entry in result.get('results', [])]
        if statuses[:1] == ['pass']:
            pass_1_count += 1
        # max(k, 0): a non-positive k must inspect no results at all.
        if 'pass' in statuses[:max(k, 0)]:
            pass_k_count += 1

    return pass_1_count / len(result_list), pass_k_count / len(result_list)


def multi(command_list, output_path, k, temp_dir='/opt/jianqiao', max_workers=128, timeout=600):
    """Check all items concurrently with ``lake exe repl`` and write a report.

    Every snippet in each item's ``'cmd'`` list is written to
    ``<temp_dir>/<index>_test_<i>.lean`` and sent to the REPL.  Outcomes are
    stored under the item's ``'results'`` key (the items in ``command_list``
    are modified in place).  Pass@1 and Pass@k are printed and the report is
    written to ``pass_rate_results/<output_path>``.

    Args:
        command_list: Sequence of dicts, each with a ``'cmd'`` list of Lean
            source strings.
        output_path: Report path relative to ``pass_rate_results/``.
        k: Number of leading results per item considered for Pass@k.
        temp_dir: Directory for the generated ``.lean`` files; created if
            missing.
        max_workers: Size of the thread pool.
        timeout: Per-snippet time limit in seconds.

    Returns:
        None.
    """
    os.makedirs(temp_dir, exist_ok=True)

    total = len(command_list)
    results = []

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [
            executor.submit(_verify_item, item, i, temp_dir, timeout)
            for i, item in enumerate(command_list)
        ]
        # Collect in submission order so results line up with command_list.
        for future in tqdm(futures, total=total, desc="Processing Commands"):
            results.append(future.result())

    pass_1, pass_k = _calculate_pass(results, k)
    print('total len:', len(results))
    print("Pass@1:", pass_1)
    print(f"Pass@{k}:", pass_k)

    output_file = f"pass_rate_results/{output_path}"

    os.makedirs(os.path.dirname(output_file), exist_ok=True)

    # ensure_ascii=False emits raw non-ASCII characters, so the file encoding
    # must be explicit rather than locale-dependent.
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump({'results': results, 'pass_1': pass_1, f"pass_{k}": pass_k}, f, indent=2, ensure_ascii=False)
|
|
|
import re |
|
def remove_simp_pattern_from_end(s):
    """Strip a trailing ``@[simp ...]`` attribute from ``s``.

    The attribute must end the string (or sit right before a final newline);
    text that does not end in such an attribute is returned unchanged.
    """
    trailing_simp_attr = re.compile(r'@\[simp\s*.*?\]$')
    return trailing_simp_attr.sub('', s)
|
|
|
|
|
|
|
|
|
def main(args):
    """Build Lean commands from model outputs and evaluate them with ``multi``.

    Reads the ``'working_file'`` entry of ``data/basic_working.json``, then
    every file matching ``[0-9]*.json`` inside ``args.input_path``.  Each
    non-blank line of those files is parsed as a JSON record; its first
    ``args.k`` entries under ``'total output'`` are converted with
    ``get_lean`` and each is joined to the working-file text with a blank
    line.  The resulting list is stored under ``'cmd'`` and the record's
    ``['content']['answer']`` is copied to ``'answer'``.

    Records that cannot be processed are reported on stderr and skipped.

    Args:
        args: Parsed command-line namespace; ``input_path``, ``k`` and
            ``output_path`` are read.

    Returns:
        None.
    """
    import sys

    json_filename = 'data/basic_working.json'
    with open(json_filename, encoding='utf-8') as wf:
        working_env = json.load(wf)['working_file']

    command_list = []
    file_pattern = os.path.join(args.input_path, '[0-9]*.json')
    # glob order is filesystem-dependent; sort so runs are reproducible.
    for file_path in sorted(glob.glob(file_pattern)):
        with open(file_path, 'r', encoding='utf-8') as rf:
            for line_no, line in enumerate(rf, start=1):
                if not line.strip():
                    continue
                try:
                    json_item = json.loads(line)
                    json_item['cmd'] = []
                    for output in json_item['total output'][:args.k]:
                        if "llemma" in args.input_path:
                            output = output.split('###')[0]
                        statement = get_lean(output.strip(), args.input_path)
                        json_item['cmd'].append('\n\n'.join([working_env, statement]))
                    json_item['answer'] = json_item['content']['answer']
                except Exception as exc:
                    # Previously this dropped into pdb, which stalls batch
                    # runs and then appended a stale or partial record.
                    print(f"Skipping {file_path}:{line_no}: {exc!r}", file=sys.stderr)
                    continue
                command_list.append(json_item)

    multi(command_list, args.output_path, args.k)
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Command-line options as (flag, keyword-arguments) pairs, registered in
    # this order so the generated help text is unchanged.
    option_specs = [
        ('--data_path', {
            'type': str,
            'default': 'data/grade-school-math-master/grade_school_math/data/test.jsonl',
        }),
        ('--input_path', {'type': str, 'default': ''}),
        ('--cuda_num', {'type': int, 'default': 8}),
        ('--k', {'type': int, 'default': 1}),
        ('--output_path', {'type': str, 'default': 'total.json'}),
        ('--generate_method', {
            'type': str,
            'choices': ['single', 'sft', 'comp', 'self_consistency', 'single_consistency'],
        }),
        ('--method', {'type': str, 'choices': ['main', 'test', 'get_data']}),
    ]

    arg_parser = ArgumentParser()
    for flag, spec in option_specs:
        arg_parser.add_argument(flag, **spec)

    args = arg_parser.parse_args()
    main(args)
|
|
|
|
|
|
|
|