lean4-compile / pass_rate_new.py
rookiemango's picture
Upload folder using huggingface_hub
dddc1ae verified
history blame
7.79 kB
import os
import subprocess
from argparse import ArgumentParser
import json
from concurrent.futures import ThreadPoolExecutor
from tqdm import tqdm
import glob
import tempfile
def wrapped_function(item):
    """Check one Lean snippet with the Lean REPL and report whether it passed.

    The source in ``item['cmd']`` is written to a private temporary ``.lean``
    file and passed to ``lake exe repl`` as a
    ``{"path": ..., "allTactics": true}`` request on stdin. The file is
    deleted afterwards. The snippet passes when the REPL exits successfully
    and its JSON reply has no ``"messages"`` key.

    Args:
        item: mapping with a ``'cmd'`` key holding the Lean source
            (presumably shaped like the dicts ``multi`` builds -- TODO confirm).

    Returns:
        ``{'results': [record], 'pass_rate': float}``. ``record`` holds
        ``'stdout'``/``'stderr'``/``'status'`` for a completed run, or
        ``'error'``/``'status'`` when the REPL could not be run or its reply
        could not be decoded. ``pass_rate`` is 100.0 or 0.0.
    """
    results = []
    passed = 0
    # mkstemp gives every call its own file. A fixed name such as "test.lean"
    # in the shared temp dir would let concurrent calls overwrite each other.
    fd, temp_file = tempfile.mkstemp(suffix=".lean")
    try:
        # Write and close (flush) the file before the REPL opens it.
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            f.write(item['cmd'])
        # json.dumps escapes the path correctly; the trailing newline mirrors
        # what piping through `echo` used to send.
        request = json.dumps({"path": temp_file, "allTactics": True}) + "\n"
        try:
            result = subprocess.run(
                ["lake", "exe", "repl"],
                input=request.encode("utf-8"),
                check=True,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
            )
            stdout = result.stdout.decode('utf-8')
            stderr = result.stderr.decode('utf-8')
            reply = json.loads(stdout)
        except (subprocess.CalledProcessError, OSError, ValueError) as e:
            # OSError: `lake` not found / not runnable.
            # ValueError: non-UTF-8 output or a reply that is not JSON.
            results.append({'error': str(e), 'status': 'nopass'})
        else:
            # Any message reported by the REPL counts as a failure.
            if "messages" in reply:
                status = 'nopass'
            else:
                status = 'pass'
                passed += 1
            results.append({'stdout': stdout, 'stderr': stderr, 'status': status})
    finally:
        os.remove(temp_file)
    # Exactly one attempt is recorded, so this is 100.0 on pass, 0.0 otherwise.
    pass_rate = passed / len(results) * 100
    return {'results': results, 'pass_rate': pass_rate}
# Set the directory where your .lean files are located
# Get a list of all .lean files in the directory
# lean_files = [f for f in os.listdir(directory) if f.endswith(".lean")]
# lean_files = ["test/file.lean"]
def single(command_list, args):
    """Check Lean snippets sequentially and write the outcome to ``results.json``.

    Every ``item['cmd']`` is written to ``test/test.lean`` (relative to the
    current directory) and checked with ``lake exe repl``. An item passes when
    the REPL exits successfully and its JSON reply has no ``"messages"`` key.

    Args:
        command_list: iterable of dicts with a ``'cmd'`` key holding Lean source.
        args: not used; kept so existing callers keep working.

    Returns:
        The summary written to ``results.json``:
        ``{'results': [...], 'pass_rate': float}`` (one record per item).
    """
    lean_path = "test/test.lean"
    # The request text is identical for every item. json.dumps yields exactly
    # '{"path": "test/test.lean", "allTactics": true}'; the trailing newline
    # mirrors what `echo` used to send.
    request = (json.dumps({"path": lean_path, "allTactics": True}) + "\n").encode('utf-8')
    os.makedirs(os.path.dirname(lean_path), exist_ok=True)
    results = []
    passed = 0
    for item in tqdm(command_list):
        # Close the file before the REPL reads it so the content is flushed.
        with open(lean_path, "w", encoding='utf-8') as f:
            f.write(item['cmd'])
        try:
            result = subprocess.run(
                ["lake", "exe", "repl"],
                input=request,
                check=True,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
            )
            stdout = result.stdout.decode('utf-8')
            stderr = result.stderr.decode('utf-8')
            reply = json.loads(stdout)
        except (subprocess.CalledProcessError, OSError, ValueError) as e:
            # OSError: `lake` not found; ValueError: undecodable / non-JSON reply.
            results.append({'error': str(e), 'status': 'nopass'})
            continue
        # Any REPL message counts as a failure; the recorded status agrees
        # with the pass counter.
        if "messages" in reply:
            status = 'nopass'
        else:
            status = 'pass'
            passed += 1
        results.append({'stdout': stdout, 'stderr': stderr, 'status': status})
    total = len(results)
    # An empty command_list yields 0.0 instead of ZeroDivisionError.
    pass_rate = passed / total * 100 if total else 0.0
    summary = {'results': results, 'pass_rate': pass_rate}
    # ensure_ascii=False writes raw Unicode, so the file must be UTF-8.
    with open('results.json', 'w', encoding='utf-8') as f:
        json.dump(summary, f, indent=2, ensure_ascii=False)
    return summary
def multi(command_list, output_path, *, temp_dir=None, max_workers=32, timeout=600):
    """Compile many Lean snippets in parallel and save per-item results.

    Each entry of *command_list* is written to its own ``test_<index>.lean``
    file and checked with ``lake exe repl``. An item passes only if three
    conditions hold: the REPL exits successfully, its JSON reply has no
    ``"messages"`` key, and nothing was written to stderr.

    Args:
        command_list: sequence of dicts with ``'cmd'`` (full Lean source),
            ``'statement'`` and ``'content'`` keys. The last two are copied
            into every result record.
        output_path: file name (or relative path) under ``pass_rate_results/``
            where the JSON summary is written.
        temp_dir: directory for the per-item ``.lean`` files. ``None`` (the
            default) uses a fresh temporary directory that is removed
            afterwards. Pass e.g. ``'/opt/jianqiao'`` to use a fixed location;
            it is created if missing.
        max_workers: number of REPL processes run concurrently.
        timeout: per-item limit in seconds. Items exceeding it get status
            ``'nopass_limit'``.

    Returns:
        The summary that was written: ``{'results': [...], 'pass_rate': float}``,
        with results in the same order as *command_list*.
    """
    def execute_command(item, work_dir):
        """Run the REPL on one item and return its result record."""
        # The index makes the file name unique per item, so threads never collide.
        temp_file = os.path.join(work_dir, f"test_{item['index']}.lean")
        extra = {'statement': item['statement'], 'content': item['content']}
        with open(temp_file, "w", encoding="utf-8") as f:
            f.write(item['cmd'])
        try:
            # json.dumps escapes the path; the trailing newline mirrors `echo`.
            request = json.dumps({"path": temp_file, "allTactics": True}) + "\n"
            result = subprocess.run(
                ["lake", "exe", "repl"],
                input=request.encode("utf-8"),
                check=True,
                timeout=timeout,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
            )
            stdout = result.stdout.decode('utf-8')
            stderr = result.stderr.decode('utf-8')
            clean = "messages" not in json.loads(stdout) and not stderr
        except subprocess.TimeoutExpired as e:
            return {'error': str(e), 'status': 'nopass_limit', **extra}
        except (subprocess.CalledProcessError, OSError, ValueError) as e:
            # OSError: `lake` not runnable. ValueError: undecodable or
            # non-JSON reply. Recorded per item rather than aborting the
            # whole run when future.result() re-raises.
            return {'error': str(e), 'status': 'nopass_error', **extra}
        finally:
            os.remove(temp_file)
        status = 'pass' if clean else 'nopass'
        return {'stdout': stdout, 'stderr': stderr, 'status': status, **extra}

    tasks = [
        {'index': i, 'cmd': cmd['cmd'], 'statement': cmd['statement'], 'content': cmd['content']}
        for i, cmd in enumerate(command_list)
    ]
    total = len(tasks)
    results = []
    passed = 0
    owned_dir = tempfile.TemporaryDirectory() if temp_dir is None else None
    work_dir = owned_dir.name if owned_dir is not None else temp_dir
    try:
        os.makedirs(work_dir, exist_ok=True)
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            futures = [executor.submit(execute_command, task, work_dir) for task in tasks]
            for future in tqdm(futures, total=total, desc="Processing Commands"):
                result = future.result()
                # Keep every record so the saved JSON carries per-item details.
                results.append(result)
                if result['status'] == 'pass':
                    passed += 1
    finally:
        if owned_dir is not None:
            owned_dir.cleanup()
    # An empty command_list yields 0.0 instead of ZeroDivisionError.
    pass_rate = (passed / total) * 100 if total else 0.0
    print(f"total test: {total}")
    print(f"Pass rate: {pass_rate}%")
    output_file = f"pass_rate_results/{output_path}"
    # Create the directory if it doesn't exist
    os.makedirs(os.path.dirname(output_file), exist_ok=True)
    summary = {'results': results, 'pass_rate': pass_rate}
    # ensure_ascii=False writes raw Unicode, so the file must be UTF-8.
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(summary, f, indent=2, ensure_ascii=False)
    return summary
import re

# Matches a Lean ``@[simp ...]`` attribute at the end of a string. The body is
# limited to non-``]`` characters so a match cannot start at an earlier
# ``@[simp`` and run across other brackets (e.g. ``by simp [h]``) to the end.
_TRAILING_SIMP_ATTR = re.compile(r'@\[simp\s*[^\]\n]*\]$')


def remove_simp_pattern_from_end(s):
    """Strip a ``@[simp ...]`` attribute that ends *s*.

    As with a non-multiline ``$``, the attribute may be followed by a single
    final newline, which is kept. Text without a trailing simp attribute is
    returned unchanged.

    >>> remove_simp_pattern_from_end("x @[simp, norm_cast]")
    'x '
    >>> remove_simp_pattern_from_end("@[simp] theorem t := by simp [h]")
    '@[simp] theorem t := by simp [h]'
    """
    return _TRAILING_SIMP_ATTR.sub('', s)
def main(args):
    """Load generated statements from ``args.input_path`` and score them with ``multi``.

    Every ``[0-9]*.json`` file in ``args.input_path`` is read as JSON Lines.
    For each record, the first entry of ``'total output'`` is cut at its first
    ``#align`` marker and stored as ``'statement'``. That statement is then
    appended to ``record['content']['working_file']`` to form ``'cmd'``, and
    the record is queued for compilation. Results go to
    ``pass_rate_results/<args.output_path>``.

    Raises:
        ValueError: if a record's statement is empty before ``#align``.
    """
    command_list = []
    file_pattern = os.path.join(args.input_path, '[0-9]*.json')
    # sorted() makes item indices (and hence temp names / result order) reproducible.
    for file_path in sorted(glob.glob(file_pattern)):
        with open(file_path, 'r', encoding='utf-8') as rf:
            for line_no, line in enumerate(rf, start=1):
                if not line.strip():
                    continue  # tolerate blank lines in the JSONL file
                json_item = json.loads(line)
                working_env = json_item['content']['working_file']
                statement = json_item['total output'][0].split("#align")[0]
                # An explicit raise (unlike assert) is not stripped under `python -O`.
                if not statement:
                    raise ValueError(f"{file_path}:{line_no}: empty statement before '#align'")
                json_item['statement'] = statement
                json_item['cmd'] = '\n\n'.join([working_env, statement])
                command_list.append(json_item)
    multi(command_list, args.output_path)
if __name__ == '__main__':
    arg_parser = ArgumentParser()
    # NOTE(review): the original default for --data_path is not recoverable
    # from this copy; '' is a placeholder (main() does not read this option).
    arg_parser.add_argument('--data_path', type=str, default='')
    arg_parser.add_argument('--input_path', type=str, default='')
    arg_parser.add_argument('--cuda_num', type=int, default=8)
    arg_parser.add_argument('--output_path', type=str, default='total.json')
    arg_parser.add_argument('--generate_method', type=str,
                            choices=['single', 'sft', 'comp', 'self_consistency', 'single_consistency'])
    arg_parser.add_argument('--method', type=str, choices=['main', 'test', 'get_data'])
    args = arg_parser.parse_args()
    # Without this call the script parses its arguments and then does nothing.
    main(args)