"""Run model-generated Lean sources through ``lake exe repl`` and score them.

Each candidate source is written to a ``.lean`` file, handed to the Lean REPL
as a ``{"path": ..., "allTactics": true}`` request on stdin, and classified
from the REPL's JSON reply.  ``multi`` is the entry point used by ``main``;
``single`` and ``wrapped_function`` are simpler sequential/per-item variants.
"""
import glob
import json
import os
import random
import re
import subprocess
import sys
import tempfile
from argparse import ArgumentParser
from concurrent.futures import ThreadPoolExecutor

try:
    from tqdm import tqdm
except ImportError:  # progress bars are cosmetic; keep working without tqdm
    def tqdm(iterable, *args, **kwargs):
        """Fallback that returns *iterable* unchanged when tqdm is missing."""
        return iterable


# Source returned by ``get_lean`` when a response contains no usable code.
_PLACEHOLDER_THEOREM = "theorem h : f + g = 39 := by exact rfl"

_LEAN_FENCE = re.compile(r"```lean\s*\n(.*?)\n```", re.DOTALL)
_LEAN4_FENCE = re.compile(r"```lean4\s*\n(.*?)\n```", re.DOTALL)

# Keys copied from an input record into the per-record result dict.
_KEPT_KEYS = ('question', 'answer', 'total output', 'results')

# Wall-clock limit (seconds) for one REPL invocation in ``multi``.
_REPL_TIMEOUT_SECONDS = 600


def _warn(message):
    """Print a warning to stderr without interrupting the run."""
    print(f"[warning] {message}", file=sys.stderr)


def _run_repl(lean_path, timeout=None):
    """Send one ``path`` request for *lean_path* to ``lake exe repl``.

    The request is the same JSON line the script previously piped in with
    ``echo``.  The process is started without a shell so that *timeout*
    applies to the REPL process itself rather than to an intermediate shell.

    Returns:
        The ``subprocess.CompletedProcess`` with raw ``stdout``/``stderr``.

    Raises:
        subprocess.CalledProcessError: the REPL exited with a non-zero status.
        subprocess.TimeoutExpired: *timeout* seconds elapsed.
        OSError: the ``lake`` executable could not be started.
    """
    payload = json.dumps({"path": lean_path, "allTactics": True}) + "\n"
    return subprocess.run(
        ["lake", "exe", "repl"],
        input=payload.encode('utf-8'),
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        check=True,
        timeout=timeout,
    )


def wrapped_function(item):
    """Check a single item with the REPL and return its result and pass rate.

    Args:
        item: mapping whose ``'cmd'`` entry is the Lean source to check.

    Returns:
        ``{'results': [one result dict], 'pass_rate': 0.0 or 100.0}``.  The
        item counts as passed when the REPL reply has no ``"messages"`` key.
        NOTE: the result's ``'status'`` is ``'pass'`` whenever the REPL ran and
        produced JSON, even if messages were reported; only ``pass_rate``
        reflects the messages check.
    """
    results = []
    passed = 0
    total = 1

    # A unique file per call: a fixed name in the shared temp directory would
    # be overwritten by any concurrent call.
    fd, temp_file = tempfile.mkstemp(suffix=".lean", prefix="test_")
    try:
        with os.fdopen(fd, "w", encoding='utf-8') as f:
            f.write(item['cmd'])
        try:
            result = _run_repl(temp_file)
            stdout = result.stdout.decode('utf-8')
            stderr = result.stderr.decode('utf-8')
            if "messages" not in json.loads(stdout):
                passed += 1
            results.append({'stdout': stdout, 'stderr': stderr, 'status': 'pass'})
        except (subprocess.CalledProcessError, OSError, ValueError) as e:
            # ValueError covers malformed JSON / undecodable REPL output.
            results.append({'error': str(e), 'status': 'nopass'})
    finally:
        os.remove(temp_file)

    pass_rate = passed / total * 100
    return {'results': results, 'pass_rate': pass_rate}


def single(command_list, args):
    """Check items one after another and write ``results.json``.

    Args:
        command_list: iterable of mappings with a ``'cmd'`` Lean source.
        args: unused; kept for interface compatibility.

    Side effects:
        Overwrites ``test/test.lean`` for every item, prints the pass rate
        (percentage of items whose REPL reply has no ``"messages"`` key) and
        writes ``{'results', 'pass_rate'}`` to ``results.json``.
    """
    results = []
    passed = 0
    total = 0
    lean_path = "test/test.lean"
    os.makedirs(os.path.dirname(lean_path), exist_ok=True)

    for item in tqdm(command_list):
        with open(lean_path, "w", encoding='utf-8') as f:
            f.write(item['cmd'])
        try:
            result = _run_repl(lean_path)
            stdout = result.stdout.decode('utf-8')
            if "messages" not in json.loads(stdout):
                passed += 1
            stderr = result.stderr.decode('utf-8')
            results.append({'stdout': stdout, 'stderr': stderr, 'status': 'pass'})
        except (subprocess.CalledProcessError, OSError, ValueError) as e:
            results.append({'error': str(e), 'status': 'nopass'})
        total += 1

    # An empty command list yields 0 instead of a ZeroDivisionError.
    pass_rate = passed / total * 100 if total else 0.0
    print(pass_rate)

    with open('results.json', 'w', encoding='utf-8') as f:
        json.dump({'results': results, 'pass_rate': pass_rate}, f,
                  indent=2, ensure_ascii=False)


def error_string_pos(cmd, line, column):
    """Compute the ``string_pos`` stored for a REPL error at *line*/*column*.

    Starts from ``column - 1`` and walks backwards from the line preceding the
    error line, adding ``len(line) + 1`` for each line, stopping after the
    first empty line.  Index 0 of *cmd* is never included.
    NOTE(review): presumably this is an offset into the trailing block of
    *cmd*; confirm against whatever consumes ``string_pos``.
    """
    lines = cmd.split('\n')
    position = column - 1
    # Clamp so a reported line past the end of *cmd* cannot raise IndexError.
    first = min(line - 2, len(lines) - 1)
    for line_n in range(first, 0, -1):
        line_len = len(lines[line_n])
        position += line_len + 1
        if not line_len:
            break
    return position


def classify_repl_output(cmd, stdout, stderr):
    """Turn a successful REPL run into a result dict.

    Args:
        cmd: the Lean source that was checked.
        stdout: the REPL reply, already parsed from JSON.
        stderr: the REPL's decoded stderr.

    Returns:
        A dict with ``stdout``, ``stderr`` and ``status``.  ``status`` is
        ``'nopass'`` (plus ``string_pos``) when stderr is non-empty or the
        reply contains a message of severity ``'error'``; otherwise ``'pass'``.
    """
    if stderr:
        return {'stdout': stdout, 'stderr': stderr, 'status': 'nopass', 'string_pos': 0}
    if "messages" in stdout:
        for message in stdout['messages']:
            if message['severity'] == 'error':
                # Only the first error is reported.
                pos = message['pos']
                return {'stdout': stdout, 'stderr': stderr, 'status': 'nopass',
                        'string_pos': error_string_pos(cmd, pos['line'], pos['column'])}
    return {'stdout': stdout, 'stderr': stderr, 'status': 'pass'}


def calculate_pass(result_list, k):
    """Return ``(pass@1, pass@k)`` as fractions of *result_list*.

    A record counts for pass@1 when its first result has status ``'pass'``,
    and for pass@k when any of its first *k* results does.  Records without
    results count as failures; an empty *result_list* yields ``(0, 0)``.
    """
    if not result_list:
        return 0, 0
    pass_1_count = 0
    pass_k_count = 0
    for record in result_list:
        attempts = record.get('results', [])
        if attempts and attempts[0].get('status') == 'pass':
            pass_1_count += 1
        if any(a.get('status') == 'pass' for a in attempts[:max(k, 0)]):
            pass_k_count += 1
    return pass_1_count / len(result_list), pass_k_count / len(result_list)


def _execute_item(item, index, temp_dir):
    """Check every source in ``item['cmd']`` and collect the results.

    Each source is written to ``<temp_dir>/<index>_test_<i>.lean`` (unique per
    record and sample; the files are left in place) and checked with the REPL.
    Failures of a single source are recorded in its result dict instead of
    being raised, so one bad sample cannot abort the whole evaluation.
    """
    result_dict = {key: item[key] for key in item if key in _KEPT_KEYS}
    result_dict['results'] = []

    for i, cmd in enumerate(item['cmd']):
        temp_file = os.path.join(temp_dir, f"{index}_test_{i}.lean")
        with open(temp_file, "w", encoding='utf-8') as f:
            f.write(cmd)

        try:
            result = _run_repl(temp_file, timeout=_REPL_TIMEOUT_SECONDS)
            stdout = json.loads(result.stdout.decode('utf-8'))
            stderr = result.stderr.decode('utf-8')
        except subprocess.TimeoutExpired as e:
            result_item = {'error': str(e), 'status': 'nopass_limit'}
        except (subprocess.CalledProcessError, OSError, ValueError) as e:
            # ValueError covers malformed JSON / undecodable REPL output.
            result_item = {'error': str(e), 'status': 'nopass_error'}
        else:
            result_item = classify_repl_output(cmd, stdout, stderr)

        result_dict['results'].append(result_item)
    return result_dict


def multi(command_list, output_path, k, max_workers=128, temp_dir='/opt/jianqiao'):
    """Check all records concurrently and write pass@1 / pass@k.

    Args:
        command_list: records whose ``'cmd'`` entry is a list of Lean sources.
        output_path: file name (relative to ``pass_rate_results/``) for the
            JSON report.
        k: number of leading samples per record considered for pass@k.
        max_workers: size of the thread pool; each worker runs one REPL
            process at a time.
        temp_dir: directory for the generated ``.lean`` files; created if
            missing.

    Side effects:
        Prints both metrics and writes ``{'results', 'pass_1', 'pass_<k>'}``
        to ``pass_rate_results/<output_path>``.
    """
    results = []
    total = len(command_list)
    os.makedirs(temp_dir, exist_ok=True)

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [executor.submit(_execute_item, record, i, temp_dir)
                   for i, record in enumerate(command_list)]
        # Collected in submission order so results line up with command_list.
        for future in tqdm(futures, total=total, desc="Processing Commands"):
            results.append(future.result())

    pass_1, pass_k = calculate_pass(results, k)
    print("Pass@1:", pass_1)
    print(f"Pass@{k}:", pass_k)

    output_file = f"pass_rate_results/{output_path}"
    os.makedirs(os.path.dirname(output_file), exist_ok=True)
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump({'results': results, 'pass_1': pass_1, f"pass_{k}": pass_k}, f,
                  indent=2, ensure_ascii=False)


def remove_simp_pattern_from_end(s):
    """Remove a trailing ``@[simp ...]`` attribute from the end of *s*."""
    pattern = r'@\[simp\s*.*?\]$'
    return re.sub(pattern, '', s)


def get_lean(text):
    """Extract Lean code from the fenced blocks of a model response.

    All `````lean`` blocks are joined with blank lines; if that yields nothing,
    `````lean4`` blocks are tried.  Unlabelled fences are not extracted.  When
    no code is found (or *text* is not a string) a fixed placeholder theorem
    is returned so callers always receive a non-empty source.
    """
    if not isinstance(text, str):
        return _PLACEHOLDER_THEOREM
    for fence in (_LEAN_FENCE, _LEAN4_FENCE):
        content = "\n\n".join(fence.findall(text))
        if content.strip():
            return content
    return _PLACEHOLDER_THEOREM


def _iter_records(path):
    """Yield ``(location, record)`` for every parsable JSON line of *path*.

    Blank lines are skipped; unparsable lines are reported and skipped.
    """
    with open(path, 'r', encoding='utf-8') as rf:
        for line_no, line in enumerate(rf, start=1):
            if not line.strip():
                continue
            location = f"{path}:{line_no}"
            try:
                record = json.loads(line)
            except ValueError as e:
                _warn(f"{location}: skipping unparsable line ({e})")
                continue
            yield location, record


def _build_cmd(json_item, location):
    """Join a record's working file with the Lean code of its response.

    Everything from the first ``#align`` in the extracted code is dropped.
    """
    working_env = json_item['working_file']
    text = get_lean(json_item['model_response']).split("#align")[0]
    if not text:
        _warn(f"{location}: no Lean code left before '#align'")
    return '\n\n'.join([working_env, text])


def _sample_order(path):
    """Sort key ordering sample files by their leading number, then name."""
    name = os.path.basename(path)
    leading = re.match(r'\d+', name)
    return (int(leading.group()) if leading else 0, name)


def main(args):
    """Collect the samples for every query and evaluate them with ``multi``.

    ``<input_path>/1.jsonl`` defines the records (keyed by ``query_id``) and
    their first sample.  Every other ``*.jsonl`` file in the directory whose
    name starts with a digit 1-9 contributes one further sample per query, in
    numeric file order so that pass@k is reproducible.  Malformed records are
    reported on stderr and skipped.
    """
    all_dicts = {}

    for location, json_item in _iter_records(f"{args.input_path}/1.jsonl"):
        try:
            json_item['cmd'] = [_build_cmd(json_item, location)]
            # NOTE(review): 'statement_poof' looks like a misspelling of
            # 'statement_proof', but it is the key read from the input data.
            json_item['answer'] = json_item['statement_poof']
            all_dicts[json_item['query_id']] = json_item
        except (KeyError, TypeError) as e:
            _warn(f"{location}: skipping record ({e!r})")

    # '[2-9]*.jsonl' would silently ignore 10.jsonl and above.
    file_pattern = os.path.join(args.input_path, '[1-9]*.jsonl')
    sample_files = [p for p in glob.glob(file_pattern)
                    if os.path.basename(p) != '1.jsonl']
    for file_path in sorted(sample_files, key=_sample_order):
        for location, json_item in _iter_records(file_path):
            try:
                cmd = _build_cmd(json_item, location)
                all_dicts[json_item['query_id']]['cmd'].append(cmd)
            except (KeyError, TypeError) as e:
                _warn(f"{location}: skipping record ({e!r})")

    command_list = list(all_dicts.values())
    multi(command_list, args.output_path, args.k)


if __name__ == '__main__':
    arg_parser = ArgumentParser()
    arg_parser.add_argument('--data_path', type=str, default='data/grade-school-math-master/grade_school_math/data/test.jsonl')
    arg_parser.add_argument('--input_path', type=str, default='')
    arg_parser.add_argument('--cuda_num', type=int, default=8)
    arg_parser.add_argument('--k', type=int, default=1)
    arg_parser.add_argument('--output_path', type=str, default='total.json')
    arg_parser.add_argument('--generate_method', type=str, choices=['single', 'sft', 'comp', 'self_consistency', 'single_consistency'])
    arg_parser.add_argument('--method', type=str, choices=['main', 'test', 'get_data'])
    args = arg_parser.parse_args()
    main(args)