learnmlf commited on
Commit
2680a94
Β·
verified Β·
1 Parent(s): 304e099

Upload 7 files

Browse files
Files changed (7) hide show
  1. README.md +8 -7
  2. main.py +640 -0
  3. requirements.txt +13 -0
  4. space_demo.py +180 -0
  5. testcase_utils.py +222 -0
  6. timeout_utils.py +67 -0
  7. utils.py +549 -0
README.md CHANGED
@@ -1,14 +1,15 @@
1
  ---
2
- title: MGDebugger
3
- emoji: 😻
4
- colorFrom: blue
5
- colorTo: pink
6
  sdk: gradio
7
  sdk_version: 4.44.1
8
- app_file: app.py
9
  pinned: false
10
- license: apache-2.0
11
- short_description: 'MGDebugger Demo: Multi-Granularity LLM Debugger'
12
  ---
13
 
 
 
14
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: Test
3
+ emoji: 📈
4
+ colorFrom: pink
5
+ colorTo: green
6
  sdk: gradio
7
  sdk_version: 4.44.1
8
+ app_file: space_demo.py
9
  pinned: false
10
+ license: mit
 
11
  ---
12
 
13
+ We are highly inspired by [LDB](https://huggingface.co/spaces/shangdatalab-ucsd/LDB) 🚀✨
14
+
15
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
main.py ADDED
@@ -0,0 +1,640 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ from openai import OpenAI
3
+ from collections import Counter
4
+ from loguru import logger
5
+ import json
6
+ import os
7
+ from tqdm import tqdm
8
+ import sys
9
+ import io
10
+ import traceback
11
+ import ast
12
+ import time
13
+ import re
14
+ from groq import Groq
15
+ from utils import split_nested_functions, get_dependency_graph_str, evaluate, parse_json_response, extract_code_blocks, extract_functions, extract_function, create_dependency_graph, topological_sort, merge_changes_to_parents, evaluate_simple, parse_transcoder_problem_content
16
+ from testcase_utils import get_parameter_names, parse_tests
17
+
18
+ # CUDA_VISIBLE_DEVICES=0 python -m vllm.entrypoints.openai.api_server --model deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct --trust-remote-code --dtype auto --api-key token-abc123s --port 18889
19
+ # CUDA_VISIBLE_DEVICES=0 python -m vllm.entrypoints.openai.api_server --model TechxGenus/Codestral-22B-v0.1-GPTQ --dtype auto --api-key token-abc123s --port 18890 --trust-remote-code --chat-template helper/codestral_template.jinja
20
+ # CUDA_VISIBLE_DEVICES=0 python -m vllm.entrypoints.openai.api_server --model Qwen/CodeQwen1.5-7B-Chat --dtype auto --api-key token-abc123s --port 18892 --trust-remote-code
21
+
22
+ MODEL = "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct"
23
+ # MODEL = "TechxGenus/Codestral-22B-v0.1-GPTQ"
24
+ # MODEL = "Qwen/CodeQwen1.5-7B-Chat"
25
+
26
+
27
+ client = OpenAI(
28
+ base_url="http://localhost:18889/v1",
29
+ api_key="token-abc123s",
30
+ )
31
+
32
+ # dscoder 18889
33
+ # codestral 18890
34
+ # codeqwen 18892
35
+
36
+ # Hyperparameters
37
+ MAX_VLLM_RETRIES = 10 # maximum number of retries for the VLLM call
38
+ MAX_PARSE_RETRIES = 3 # we sometimes fail to parse the code/test cases from the response, so we retry
39
+ MAX_DEBUG_RETRIES = 1 # 3 seems to be slightly better than 1, but the difference is not significant
40
+ RETRY_DELAY = 0 # seconds
41
+ REPEAT_CONVERT_HIERARCHICAL_NUM = 1 # seems unimportant
42
+ REPEAT_TEST_CASE_GENERATION_NUM = 1 # 1 seems to be better than 3
43
+ MAX_OUTER_RETRY = 10 # maximum number of retries for the entire debugging process
44
+ CONTINUOUS_RETRY = False # whether to retry from the last fixed code, else retry from the original buggy code. it seems better to set this to False
45
+ TEMPERATURE = 0.8 # 0.8 better than 1.0 better than 0.2
46
+ MINP = 0.05
47
+
48
+ # for collecting stats
49
+ TOTAL_PROMPT_TOKENS = 0
50
+ TOTAL_COMPLETION_TOKENS = 0
51
+ TOTAL_MG_DEBUG_CALLS = 0
52
+ TOTAL_DEBUG_FUNCTION_CALLS = 0
53
+ TOTAL_GENERATE_TEST_CASES_CALLS = 0
54
+ TOTAL_CONVERT_HIERARCHICAL_CALLS = 0
55
+
56
+
57
def get_completion_with_retry(messages, model=None, MAX_VLLM_RETRIES=MAX_VLLM_RETRIES):
    """Call the chat-completion endpoint, retrying up to MAX_VLLM_RETRIES times.

    Args:
        messages: list of chat messages ({'role': ..., 'content': ...}).
        model: model identifier to query. Defaults to None, which means "use
            the current module-level MODEL" resolved at call time (so the demo
            entry point, which rebinds the global, keeps working).
        MAX_VLLM_RETRIES: maximum number of attempts before giving up.

    Returns:
        The assistant's response text (str).

    Raises:
        Exception: re-raises the last client error once all retries are exhausted.
    """
    global TOTAL_PROMPT_TOKENS, TOTAL_COMPLETION_TOKENS
    # BUG FIX: the original passed `model=MODEL` to the API, silently ignoring
    # the `model` argument. Resolve the default lazily so that rebinding the
    # global MODEL (see mg_debug_demo) still takes effect for default callers.
    if model is None:
        model = MODEL
    for attempt in range(MAX_VLLM_RETRIES):
        try:
            logger.info(f"Attempting LLM call (attempt {attempt + 1}/{MAX_VLLM_RETRIES})")
            logger.info(f"Input messages: {messages[-1]['content']}")
            chat_completion = client.chat.completions.create(
                messages=messages,
                model=model,
                temperature=TEMPERATURE
            )
            response = chat_completion.choices[0].message.content
            logger.info(f"LLM response: {response}")
            logger.info("LLM call received")

            # Update the global token counters used by the final statistics report.
            TOTAL_PROMPT_TOKENS += chat_completion.usage.prompt_tokens
            TOTAL_COMPLETION_TOKENS += chat_completion.usage.completion_tokens

            return response
        except Exception as e:
            logger.error(f"LLM call failed: {str(e)}")
            if attempt < MAX_VLLM_RETRIES - 1:
                logger.info(f"Retrying in {RETRY_DELAY} seconds...")
                time.sleep(RETRY_DELAY)
            else:
                logger.error("Max retries reached. Giving up.")
                raise
85
+
86
+
87
def generate_test_cases(full_code, gold_test_cases, function_name, MAX_PARSE_RETRIES=MAX_PARSE_RETRIES):
    """Ask the LLM to derive test cases for `function_name`, then majority-vote.

    Args:
        full_code: the full (hierarchical) program containing the function.
        gold_test_cases: gold test cases for the program's main entry point.
        function_name: the sub-function to generate test cases for.
        MAX_PARSE_RETRIES: retries per generation round when the LLM response
            cannot be parsed into test cases.

    Returns:
        A list of test-case dicts (at most len(gold_test_cases) after voting),
        or [] when no valid test case could be generated.
    """
    global TOTAL_GENERATE_TEST_CASES_CALLS
    TOTAL_GENERATE_TEST_CASES_CALLS += 1

    logger.info(f"Generating test cases for function: {function_name}")

    prompt = f"""
Analyze the following Python code and focus on the function named `{function_name}`.
Generate the same number of test cases for this specific function based on the provided gold test cases for the main function.
Ensure that the generated test cases are consistent with the behavior expected in the gold test cases.

### Full Code
{full_code}

### Gold Test Cases for the main function
{json.dumps(gold_test_cases, indent=2)}

### Function to generate test cases for: `{function_name}`

### Output format:
**Test Case 1:**
Input: ...
Analysis: ...
Expected Output: ...

**Test Case 2:**
...

**All Test Cases:**
```json
{{
    "test_cases": [
        {{"input": {{'var1': 'value1', 'var2': 'value2'}}, "expected_output": "expected_output"}},
        ...
    ]
}}
```

### Hint
Analyze how the `{function_name}` function is used within the main function and how it contributes to the expected outputs in the gold test cases. Generate test cases that reflect this behavior. For each test case, you should analyze step-by-step based on both the input and the expected output of the main function, and then provide the corresponding input and expected output for the `{function_name}` function.
"""

    messages = [
        {'role': 'system', 'content': "You are an AI assistant specialized in analyzing Python functions and generating test cases."},
        {'role': 'user', 'content': prompt},
    ]

    if "starcoder2" in MODEL.lower():
        # starcoder2 chat template does not support a system role; drop it.
        messages = messages[1:]

    all_test_cases = []
    for _ in range(REPEAT_TEST_CASE_GENERATION_NUM):
        # BUG FIX: MAX_PARSE_RETRIES was accepted but never used; a response
        # that failed to parse was silently dropped. Retry parsing instead.
        for parse_attempt in range(MAX_PARSE_RETRIES):
            response = get_completion_with_retry(messages)
            parsed_response = parse_json_response(response)
            if parsed_response and 'test_cases' in parsed_response:
                all_test_cases.extend(parsed_response['test_cases'])
                break
            logger.warning(f"Failed to parse test cases from response (attempt {parse_attempt + 1}/{MAX_PARSE_RETRIES})")

    if not all_test_cases:
        logger.error("Failed to generate any valid test cases")
        return []

    # Perform majority voting on expected outputs: identical test cases (as
    # canonical JSON strings) are counted, and the most frequent ones win.
    logger.info(f"Generated {len(all_test_cases)} test cases for {function_name}: {all_test_cases}")
    test_case_counter = Counter([json.dumps(tc) for tc in all_test_cases])
    # show the frequency of each test case
    logger.info(f"Test case frequency:")
    for tc, freq in test_case_counter.items():
        logger.info(f"Test case: {tc}, Frequency: {freq}")
    most_common_test_cases = test_case_counter.most_common(len(gold_test_cases))
    final_test_cases = [json.loads(tc[0]) for tc in most_common_test_cases]

    logger.info(f"Voted {len(final_test_cases)} test cases for {function_name}: {final_test_cases}")
    return final_test_cases
162
+
163
+
164
def debug_function(function_code, function_name, test_cases, MAX_PARSE_RETRIES=MAX_PARSE_RETRIES):
    """Ask the LLM to debug a single function against its test cases.

    Args:
        function_code: source of the (failing) function.
        function_name: name of the function to extract from the LLM answer.
        test_cases: test cases the function is currently failing.
        MAX_PARSE_RETRIES: attempts to obtain a parseable fixed function.

    Returns:
        (analysis, fixed_function) on success; (None, None) when no valid
        fixed function could be extracted after all retries.
    """
    global TOTAL_DEBUG_FUNCTION_CALLS
    TOTAL_DEBUG_FUNCTION_CALLS += 1

    logger.info(f"Debugging function:\n{function_code}\nWith test cases: {test_cases}")
    prompt = f"""
Debug the following Python function. The function is not passing all test cases.
Analyze the code, identify the bug, and provide a fixed version of the function.

### Function:
{function_code}

### Test Cases:
{json.dumps(test_cases, indent=2)}

### Task:

Provide your response in the following format:
1. try to work as a python interpreter to execute the code step-by-step.
2. identify the change of each variable as you "run" the code line-by-line.
3. based on the execution trace, try to identify the bug
4. provide the final fixed code in a Python code block (```python ... ```)

Make sure to include the entire function, including the function signature. And you will be rewarded to simulate the code execution in your mind and provide step-by-step trace of the code execution.

"""

    messages = [
        {'role': 'system', 'content': 'You are an AI assistant helping to debug Python functions.'},
        {'role': 'user', 'content': prompt},
    ]

    if "starcoder2" in MODEL.lower():
        # starcoder2 chat template does not support a system role; drop it.
        messages = messages[1:]

    for attempt in range(MAX_PARSE_RETRIES):
        try:
            response = get_completion_with_retry(messages)
            code_blocks = extract_code_blocks(response)
            # Search the blocks in reverse so we pick the FINAL fixed version
            # (models often restate the buggy code before the fix).
            for block in code_blocks[::-1]:
                fixed_function = extract_function(block, function_name)
                if fixed_function:
                    # Everything before the first python fence is the analysis.
                    analysis = response.split("```python")[0].strip()
                    logger.info("Generated debug analysis and fixed code")
                    return analysis, fixed_function
            # BUG FIX: raise unconditionally when nothing usable was found —
            # including the "no code blocks at all" case — so that every
            # failure mode goes through the retry path below instead of
            # falling through and returning an implicit None.
            raise ValueError("No valid fixed function found in the response")
        except ValueError as e:
            logger.error(f"Failed to extract fixed function (attempt {attempt + 1}/{MAX_PARSE_RETRIES}): {str(e)}")
            if attempt < MAX_PARSE_RETRIES - 1:
                logger.info(f"Retrying in {RETRY_DELAY} seconds...")
                time.sleep(RETRY_DELAY)
            else:
                logger.error("Max retries reached. Returning None for analysis and fixed code.")
                return None, None
221
+
222
+
223
def convert_to_hierarchical(code, include_example=False):
    """Refactor `code` into a tree of small functions via the LLM.

    Generates REPEAT_CONVERT_HIERARCHICAL_NUM candidate conversions and keeps
    the one with the most sub-functions. Falls back to merely splitting nested
    functions in the original code when no conversion is produced.

    Args:
        code: source code to refactor.
        include_example: whether to embed a worked example in the prompt.

    Returns:
        The refactored (or fallback) source code as a string.
    """
    global TOTAL_CONVERT_HIERARCHICAL_CALLS
    TOTAL_CONVERT_HIERARCHICAL_CALLS += 1

    logger.info("Converting code to hierarchical structure")

    example = """
### Example of tree-style hierarchical structure:

```python
def main_function(input):
    preprocessed_data = preprocess(input)
    result = process(preprocessed_data)
    return result

def preprocess(data):
    cleaned_data = clean_data(data)
    normalized_data = normalize_data(cleaned_data)
    return normalized_data

def clean_data(data):
    # Implementation of data cleaning
    pass

def normalize_data(data):
    # Implementation of data normalization
    pass

def process(data):
    feature_vector = extract_features(data)
    result = classify(feature_vector)
    return result

def extract_features(data):
    # Implementation of feature extraction
    pass

def classify(feature_vector):
    # Implementation of classification
    pass
```
""" if include_example else ""

    prompt = f"""
Convert the following Python code into a tree-style hierarchical structure with multiple levels of sub-functions.
Each significant step or logical block should be its own function, and functions can call other sub-functions.
Ensure that the main function calls these sub-functions in the correct order, creating a tree-like structure.

### Original Code:
{code}

{example}

### Instructions:
Please first analyze the codes step by step, and then provide the converted code in a Python code block (```python ... ```). When providing the final converted code, make sure to include all the functions in a flattened format, where each function is defined separately.
"""

    messages = [
        {'role': 'system', 'content': 'You are an AI assistant specialized in refactoring Python code into a tree-style hierarchical structure.'},
        {'role': 'user', 'content': prompt},
    ]

    if "starcoder2" in MODEL.lower():
        # starcoder2 chat template does not support a system role; drop it.
        messages = messages[1:]

    best_code = None
    best_count = 0

    for _ in range(REPEAT_CONVERT_HIERARCHICAL_NUM):
        reply = get_completion_with_retry(messages)
        blocks = extract_code_blocks(reply)
        if not blocks:
            continue
        candidate = blocks[0]
        # Subtract 1 to exclude the main function from the sub-function count.
        num_subfunctions = len(extract_functions(candidate)) - 1
        if num_subfunctions > best_count:
            best_count = num_subfunctions
            best_code = candidate

    if not best_code:
        logger.error("Failed to convert code to tree-style hierarchical structure")
        return split_nested_functions(code)

    logger.info(f"Converted code to tree-style hierarchical structure with {best_count} sub-functions")
    # Flatten any remaining nested definitions before returning.
    return split_nested_functions(best_code)
314
+
315
def mg_debug_demo(full_code, gold_test_cases, max_debug_attempts=MAX_DEBUG_RETRIES, given_model=MODEL, given_client=client):
    """Demo entry point: swap in the caller-selected model and API client,
    then delegate to the normal mg_debug pipeline."""
    global MODEL, client
    MODEL, client = given_model, given_client
    return mg_debug(full_code, gold_test_cases, max_debug_attempts)
320
+
321
+
322
def mg_debug(full_code, gold_test_cases, max_debug_attempts=MAX_DEBUG_RETRIES):
    """Multi-granularity debugging loop.

    Converts `full_code` into a tree of small functions, generates per-function
    test cases from the gold tests, then debugs each function bottom-up
    (topological order), merging every fix into its callers.

    Args:
        full_code: the buggy program text.
        gold_test_cases: gold tests for the program's main entry point.
        max_debug_attempts: debug retries per function.

    Returns:
        The reconstructed program with fixed functions substituted in.

    Raises:
        RuntimeError: when the hierarchical conversion fails MAX_PARSE_RETRIES times.
    """
    global TOTAL_MG_DEBUG_CALLS
    TOTAL_MG_DEBUG_CALLS += 1
    logger.info("Starting main debugging process")

    for convert_attempt in range(MAX_PARSE_RETRIES):
        try:
            # Convert to tree-style hierarchical structure
            hierarchical_code = convert_to_hierarchical(full_code, include_example=False)
            logger.info(f"Converted code to tree-style hierarchical structure:\n{hierarchical_code}")

            functions = extract_functions(hierarchical_code)

            # Create a dependency graph
            dependency_graph = create_dependency_graph(functions)
            logger.info(f"Dependency graph:\n{get_dependency_graph_str(dependency_graph)}")
            break
        except Exception as e:
            logger.error(f"Failed to convert code to hierarchical structure (attempt {convert_attempt + 1}/{MAX_PARSE_RETRIES}): {str(e)}")
    else:
        # BUG FIX: the original fell through after exhausting the retries with
        # `functions` / `dependency_graph` undefined, crashing later with a
        # NameError. Fail fast with a clear error instead.
        raise RuntimeError(f"Failed to convert code to hierarchical structure after {MAX_PARSE_RETRIES} attempts")

    # Sort functions based on their dependencies (bottom-up)
    sorted_functions = topological_sort(dependency_graph)
    logger.info(f"Sorted functions: {sorted_functions}")

    for func_name in sorted_functions:
        logger.info(f"Processing function: {func_name}")

        func_code = functions[func_name]
        test_cases = generate_test_cases(hierarchical_code, gold_test_cases, func_name)
        fixed_code = func_code
        # BUG FIX: initialized before the loop so the post-loop check cannot
        # hit a NameError when max_debug_attempts == 0.
        all_tests_pass = True

        for debug_attempt in range(max_debug_attempts):
            all_tests_pass = True

            for test_case in test_cases:
                passed, result = evaluate(hierarchical_code, func_name, test_case)
                if not passed:
                    all_tests_pass = False
                    break

            if all_tests_pass:
                logger.info(f"All tests passed for function: {func_name}")
                break

            logger.info(f"Debugging function: {func_name} (Attempt {debug_attempt + 1}/{max_debug_attempts})")
            analysis, new_fixed_code = debug_function(fixed_code, func_name, test_cases)

            if new_fixed_code:
                logger.info(f"New fixed code for {func_name}:\n{new_fixed_code}")
                fixed_code = new_fixed_code
                functions[func_name] = fixed_code
                logger.info(f"Merging {func_name} changes")
                # Propagate the fix into every caller of this function.
                hierarchical_code = merge_changes_to_parents(func_name, dependency_graph, functions)
                logger.info(f"Code after merging updates in {func_name}:\n{hierarchical_code}")
            else:
                logger.warning(f"Failed to fix {func_name}. Keeping previous implementation.")
                break

        if not all_tests_pass:
            logger.warning(f"Could not fix {func_name} after {max_debug_attempts} attempts. Keeping original implementation.")

    # Reconstruct the full code with fixed functions
    fixed_full_code = "\n\n".join(functions.values())
    logger.info("Debugging process completed. Reconstructed full code.")

    return fixed_full_code
392
+
393
+
394
def test():
    """Smoke test: run mg_debug on a known-buggy make_palindrome, then
    re-evaluate the repaired program against the gold tests."""
    buggy_code = '''
def make_palindrome(string: str) -> str:
    """ Find the shortest palindrome that begins with the supplied string. """

    def is_palindrome(s: str) -> bool:
        return s == s[::-1]

    suffix_start = 0
    for i in range(len(string)):
        if is_palindrome(string[i:]):
            suffix_start = i

    return string + string[:suffix_start][::-1]
'''.strip()

    gold_test_cases = [
        {'input': 'cat', 'expected_output': 'catac'},
        {'input': 'cata', 'expected_output': 'catac'},
        {'input': '', 'expected_output': ''}
    ]

    entry_point = 'make_palindrome'

    fixed_code = mg_debug(buggy_code, gold_test_cases)
    logger.info(f"Fixed code:\n{fixed_code}")

    logger.info("============= Final evaluation with private test cases =============")
    # Re-check the repaired program against every gold test case.
    all_tests = gold_test_cases
    for testcase in all_tests:
        result, testcase = evaluate(fixed_code, entry_point, testcase)
        logger.info(f"Passed: {result}, Test case: {testcase}")
427
+
428
+
429
def debug_humaneval(input_seeds: str, max_examples: int = None, output_folder: str = None):
    """Debug every unsolved problem in a seed file and write results and stats.

    Args:
        input_seeds: path to a .jsonl seed file; each line is a problem dict
            with (at least) "is_solved", "task_id", "prompt", "solution",
            "entry_point", "given_tests" and "test" keys.
        max_examples: optional cap on the number of problems to process.
        output_folder: directory where per-problem results, statistics and
            score files are written.
    """
    fixed_problems = 0
    total_unsolved = 0

    with open(input_seeds, "r") as f:
        seeds = f.readlines()

    # Keep only the problems the seed model failed to solve.
    unsolved_seeds = []
    for line in seeds:
        problem = json.loads(line)
        if not problem["is_solved"]:
            unsolved_seeds.append(problem)

    # BUG FIX: removed a leftover debugging filter that restricted the whole
    # run to the single problem 'HumanEval/38':
    # unsolved_seeds = [seed for seed in unsolved_seeds if seed['task_id'] == 'HumanEval/38']

    # parse transcoder problems
    if "transcoder" in input_seeds.lower():
        logger.info(f"Parsing the problem content for transcoder problems")
        unsolved_seeds = [parse_transcoder_problem_content(problem) for problem in tqdm(unsolved_seeds)]

    total_unsolved = len(unsolved_seeds)
    logger.info(f"Debugging {total_unsolved} unsolved problems")
    if max_examples is not None:
        unsolved_seeds = unsolved_seeds[:max_examples]
        logger.info(f"Filtering to {max_examples} examples")

    # Output path is derived from the seed's model directory and our model name.
    model_to_be_fixed = input_seeds.split("/")[-2]
    model_name = MODEL.split("/")[-1]
    results_path = f"{output_folder}/{model_name}_debugging_seeds_from_{model_to_be_fixed}.jsonl"

    for problem in tqdm(unsolved_seeds, ncols=100):

        # Checkpoint the current state of every problem before starting the next one.
        with open(results_path, "w+") as f:
            for seed in unsolved_seeds:
                f.write(json.dumps(seed) + "\n")

        logger.info(f"Processing unsolved problem: {problem['task_id']}")
        logger.info(f"Problem: {problem}")
        logger.info(f"Problem Raw Prompt: \n{problem['prompt']}")

        try:
            buggy_code = problem["solution"]
            entry_point = problem["entry_point"]
            try:
                parameter_names = get_parameter_names(problem["prompt"], entry_point)
            except Exception:
                # BUG FIX: narrowed a bare `except:`; fall back to the solution
                # when the prompt does not contain the signature.
                parameter_names = get_parameter_names(problem["solution"], entry_point)
            logger.info(f"Extracted parameter names: {parameter_names}")

            # in order to save time, we extract the first 3 given tests for transcoder
            if "transcoder" in problem["task_id"].lower():
                logger.info(f"Extracted {len(problem['given_tests'])} given tests, only using the first 3 samples")
                problem["given_tests"] = problem["given_tests"][:3]

            gold_tests_raw = "\n".join(problem["given_tests"]).replace(entry_point, "candidate")
            gold_tests = parse_tests(gold_tests_raw, parameter_names, entry_point)["test_cases"]
            logger.info(f"Extracted gold test cases: {gold_tests}")

            problem['fixed_codes'] = []  # Initialize list to store fixed codes for each retry
            problem['mg_debug_retries'] = 0  # Initialize retry counter

            # BUG FIX: defined before the loop so the post-loop check cannot
            # hit a NameError when MAX_OUTER_RETRY == 0.
            all_passed = False
            for outer_retry in range(MAX_OUTER_RETRY):
                fixed_code = mg_debug(buggy_code, gold_tests)
                problem['fixed_codes'].append(fixed_code)
                problem['mg_debug_retries'] += 1

                all_passed = evaluate_simple(fixed_code, entry_point, problem["test"])
                if all_passed:
                    fixed_problems += 1
                    problem['debugged'] = True
                    logger.info(f"Successfully fixed problem: {problem['task_id']} on retry {outer_retry + 1}")
                    break
                else:
                    logger.info(f"Failed to fix problem: {problem['task_id']} on retry {outer_retry + 1}")
                    if outer_retry < MAX_OUTER_RETRY - 1:
                        # Either iterate on the last attempt or restart from the original bug.
                        buggy_code = fixed_code if CONTINUOUS_RETRY else problem["solution"]

            if not all_passed:
                problem['debugged'] = False
                logger.info(f"Failed to fix problem: {problem['task_id']} after {MAX_OUTER_RETRY} retries")

        except Exception:
            logger.error(f"Error occurred while processing problem: {problem['task_id']}")
            logger.error(traceback.format_exc())
            problem['fixed_codes'] = []
            problem['debugged'] = False
            problem['mg_debug_retries'] = 0

    # Add statistics to the log
    logger.info("=== Statistics ===")
    logger.info(f"Total prompt tokens: {TOTAL_PROMPT_TOKENS}")
    logger.info(f"Total completion tokens: {TOTAL_COMPLETION_TOKENS}")
    logger.info(f"Total mg_debug calls: {TOTAL_MG_DEBUG_CALLS}")
    logger.info(f"Total debug_function calls: {TOTAL_DEBUG_FUNCTION_CALLS}")
    logger.info(f"Total generate_test_cases calls: {TOTAL_GENERATE_TEST_CASES_CALLS}")
    logger.info(f"Total convert_hierarchical calls: {TOTAL_CONVERT_HIERARCHICAL_CALLS}")

    # Compute and log average statistics
    avg_prompt_tokens = TOTAL_PROMPT_TOKENS / TOTAL_MG_DEBUG_CALLS if TOTAL_MG_DEBUG_CALLS > 0 else 0
    avg_completion_tokens = TOTAL_COMPLETION_TOKENS / TOTAL_MG_DEBUG_CALLS if TOTAL_MG_DEBUG_CALLS > 0 else 0
    avg_debug_function_calls = TOTAL_DEBUG_FUNCTION_CALLS / TOTAL_MG_DEBUG_CALLS if TOTAL_MG_DEBUG_CALLS > 0 else 0
    avg_generate_test_cases_calls = TOTAL_GENERATE_TEST_CASES_CALLS / TOTAL_MG_DEBUG_CALLS if TOTAL_MG_DEBUG_CALLS > 0 else 0

    logger.info(f"Average prompt tokens per mg_debug call: {avg_prompt_tokens:.2f}")
    logger.info(f"Average completion tokens per mg_debug call: {avg_completion_tokens:.2f}")
    logger.info(f"Average debug_function calls per mg_debug call: {avg_debug_function_calls:.2f}")
    logger.info(f"Average generate_test_cases calls per mg_debug call: {avg_generate_test_cases_calls:.2f}")

    # distribution of debug retries that solved the problem
    debug_retries = [problem['mg_debug_retries'] for problem in unsolved_seeds if problem.get('debugged')]
    debug_retries_counter = Counter(debug_retries)

    # BUG FIX: guard divisions — a seed file with no unsolved problems (or an
    # empty file) previously raised ZeroDivisionError here.
    success_rate = fixed_problems / total_unsolved * 100 if total_unsolved else 0.0
    previous_accuracy = (len(seeds) - total_unsolved) / len(seeds) * 100 if seeds else 0.0
    final_accuracy = (len(seeds) - total_unsolved + fixed_problems) / len(seeds) * 100 if seeds else 0.0

    logger.info(f"=== Final Results ===")
    logger.info(f"Total unsolved problems in seeds: {total_unsolved}")
    logger.info(f"Problems fixed by our method: {fixed_problems}")
    logger.info(f"Success rate: {success_rate:.2f}%")
    logger.info(f"Debug retries distribution for solved problems: {debug_retries_counter}")

    logger.info(f"Total solved problems before: {len(seeds) - total_unsolved}")
    logger.info(f"Total solved problems after: {len(seeds) - total_unsolved + fixed_problems}")
    logger.info(f"Previous accuracy: {previous_accuracy:.2f}%")
    logger.info(f"Final accuracy: {final_accuracy:.2f}%")

    # save the final results to a file
    with open(results_path, "w+") as f:
        for seed in unsolved_seeds:
            f.write(json.dumps(seed) + "\n")

    stats = {
        "total_prompt_tokens": TOTAL_PROMPT_TOKENS,
        "total_completion_tokens": TOTAL_COMPLETION_TOKENS,
        "total_mg_debug_calls": TOTAL_MG_DEBUG_CALLS,
        "total_debug_function_calls": TOTAL_DEBUG_FUNCTION_CALLS,
        "total_generate_test_cases_calls": TOTAL_GENERATE_TEST_CASES_CALLS,
        "total_convert_hierarchical_calls": TOTAL_CONVERT_HIERARCHICAL_CALLS,
        "avg_prompt_tokens_per_mg_debug": avg_prompt_tokens,
        "avg_completion_tokens_per_mg_debug": avg_completion_tokens,
        "avg_debug_function_calls_per_mg_debug": avg_debug_function_calls,
        "avg_generate_test_cases_calls_per_mg_debug": avg_generate_test_cases_calls
    }
    scores = {
        "total_unsolved": total_unsolved,
        "fixed_problems": fixed_problems,
        "success_rate": success_rate,
        "debug_retries_distribution": dict(debug_retries_counter),
        "total_solved_problems_before": len(seeds) - total_unsolved,
        "total_solved_problems_after": len(seeds) - total_unsolved + fixed_problems,
        "previous_accuracy": previous_accuracy,
        "final_accuracy": final_accuracy
    }
    # BUG FIX: the original called json.dump twice on the same file handle,
    # producing two concatenated objects — invalid JSON. Write one document.
    with open(f"{output_folder}/statistics.json", "w") as f:
        json.dump({"stats": stats, "scores": scores}, f, indent=2)
597
+
598
+
599
if __name__ == "__main__":

    input_seeds = "input_data/humaneval/seed/reflexion/seed.jsonl"
    # Alternative seed files:
    # input_seeds = "input_data/humaneval/seed/codestral/seed.jsonl"
    # input_seeds = "input_data/humaneval/seed/codeqwen/seed.jsonl"
    # input_seeds = "input_data/mbpp/seed/codestral/seed.jsonl"
    # input_seeds = "input_data/mbpp/seed/deepseekcoder/seed.jsonl"
    # input_seeds = "input_data/mbpp/seed/codeqwen/seed.jsonl"
    # input_seeds = "input_data/humanevalfix/seeds.jsonl"

    # Derive a unique output folder from the seed path and the current time.
    seed_stamp = input_seeds.split("input_data/")[-1].replace("/seed.jsonl", "")
    timestamp = time.strftime("%Y%m%d-%H%M%S")
    output_folder = f"./output_data/{seed_stamp}/{timestamp}"
    os.makedirs(output_folder, exist_ok=True)

    # Configure logger: mirror all output to stderr and to a per-run log file.
    log_format = "<green>{time:YYYY-MM-DD HH:mm:ss}</green> | <level>{level: <8}</level> | <cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - <level>{message}</level>"
    logger.remove()
    logger.add(sys.stderr, format=log_format)
    logger.add(f"{output_folder}/all_info.log", format=log_format)

    # Persist every important run parameter for reproducibility.
    info = {
        "input_seeds": input_seeds,
        "output_folder": output_folder,
        "MAX_DEBUG_RETRIES": MAX_DEBUG_RETRIES,
        "MAX_PARSE_RETRIES": MAX_PARSE_RETRIES,
        "RETRY_DELAY": RETRY_DELAY,
        "REPEAT_CONVERT_HIERARCHICAL_NUM": REPEAT_CONVERT_HIERARCHICAL_NUM,
        "REPEAT_TEST_CASE_GENERATION_NUM": REPEAT_TEST_CASE_GENERATION_NUM,
        "MODEL": MODEL,
        "MAX_OUTER_RETRY": MAX_OUTER_RETRY,
        "CONTINUOUS_RETRY": CONTINUOUS_RETRY,
        "TEMPERATURE": TEMPERATURE,
        "MAX_VLLM_RETRIES": MAX_VLLM_RETRIES,
        "MINP": MINP
    }
    with open(f"{output_folder}/params.json", "w+") as f:
        f.write(json.dumps(info, indent=2))

    debug_humaneval(input_seeds, max_examples=None, output_folder=output_folder)
requirements.txt ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ jsonlines==3.1.0
2
+ openai==1.12.0
3
+ datasets
4
+ tenacity==8.1.0
5
+ astunparse
6
+ transformers
7
+ accelerate
8
+ astor
9
+ graphviz
10
+ vllm
11
+ astroid
12
+ groq
13
+ loguru
space_demo.py ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import pandas as pd
3
+ import gradio as gr
4
+ import ast
5
+ import random
6
+ from main import mg_debug_demo
7
+ from groq import Groq
8
+ from openai import OpenAI
9
+
10
+
11
def debug_code(key, model, code, fixed_code, test_cases):
    """Run MGDebugger on *code* with the selected backend model.

    Builds the API client that matches the chosen provider, invokes the
    debugger, and returns the (unchanged) code and test cases together with
    the fixed code so all three Gradio outputs refresh in one call.
    """
    # Pick the API client that matches the selected model family.
    if model == 'deepseek-coder':
        client = OpenAI(api_key=key, base_url="https://api.deepseek.com")
    elif model in ['gemma2-9b-it']:
        client = Groq(api_key=key)
    else:
        # Default: official OpenAI endpoint (gpt-4o-mini, gpt-4-1106-preview, ...).
        client = OpenAI(api_key=key)
    fixed_code = mg_debug_demo(code, test_cases, given_model=model, given_client=client)
    return code, fixed_code, test_cases
27
+
28
+
29
# Gradio UI: code + public tests on the left, fixed code on the right.
app = gr.Blocks(
    theme=gr.themes.Default(primary_hue="red", secondary_hue="pink", neutral_hue="gray")
)

with app:
    with gr.Row():
        gr.Markdown("# MGDebugger Demo")
    with gr.Row():
        with gr.Column():
            with gr.Row():
                openai_key_input = gr.Textbox(
                    label="API Key",
                    placeholder="Enter your API key here",
                    type="password",
                )
                model_selector = gr.Dropdown(
                    label="Choose Model",
                    choices=[("gemma2. Please get your api key from https://console.groq.com/keys", "gemma2-9b-it"),
                             ("deepseek-coder. Please get your api key from https://platform.deepseek.com/api-docs/", "deepseek-coder"),
                             ("gpt-4o-mini. Please get your api key from https://platform.openai.com/settings/profile?tab=api-keys", "gpt-4o-mini"),
                             ("gpt-4-1106-preview. Please get your api key from https://platform.openai.com/settings/profile?tab=api-keys", "gpt-4-1106-preview")],
                    value="deepseek-coder",
                )
            code_input = gr.TextArea(
                label="Code Input",
                placeholder="Enter your code here",
                lines=10,
            )
            test_cases = gr.TextArea(
                label="Public Test Cases",
                placeholder="Enter your public test cases",
                lines=10,
            )
            with gr.Row():  # This Row will contain the buttons
                debug_button = gr.Button("Debug", variant="primary")
                clear_button = gr.Button("Clear", variant="neutral")
        with gr.Column():
            fixed_code_output = gr.TextArea(
                label="Fixed Code",
                placeholder="Fixed code will be shown here",
                lines=10,
                interactive=False,
                visible=True,
            )
    debug_button.click(
        debug_code,
        inputs=[
            openai_key_input,
            model_selector,
            code_input,
            fixed_code_output,
            test_cases,
        ],
        outputs=[code_input, fixed_code_output, test_cases],
    )

    def clear_inputs():
        # BUG FIX: this callback previously returned five values (including a
        # pandas DataFrame for a component that no longer exists) while
        # clear_button.click lists only three output components, which makes
        # Gradio raise on every Clear click. Return exactly one empty string
        # per output component: code_input, test_cases, fixed_code_output.
        return "", "", ""

    clear_button.click(
        clear_inputs,
        inputs=[],
        outputs=[code_input, test_cases, fixed_code_output],
    )

    gr.Markdown("## Text Examples")
    gr.Examples(
        [
            [
                {'input': {'music_string': 'o o| .| o| o| .| .| .| .| o o'}, 'expected_output': [4, 2, 1, 2, 2, 1, 1, 1, 1, 4, 4]}
                ,
                '''
def parse_music(music_string: str) -> List[int]:
    """ Input to this function is a string representing musical notes in a special ASCII format.
    Your task is to parse this string and return list of integers corresponding to how many beats does each
    not last.

    Here is a legend:
    'o' - whole note, lasts four beats
    'o|' - half note, lasts two beats
    '.|' - quater note, lasts one beat

    >>> parse_music('o o| .| o| o| .| .| .| .| o o')
    [4, 2, 1, 2, 2, 1, 1, 1, 1, 4, 4]
    """
    note_map = {'o': 3, 'o|': 2, '.|': 1}
    return [note_map[x] for x in music_string.split(' ') if x]
''',
            ],
            [
                {'input': {'numbers': [1.0, 2.0, 3.0], 'threshold': 0.5}, 'expected_output': False},
                '''
def has_close_elements(numbers: List[float], threshold: float) -> bool:
    """ Check if in given list of numbers, are any two numbers closer to each other than
    given threshold.
    >>> has_close_elements([1.0, 2.0, 3.0], 0.5)
    False
    >>> has_close_elements([1.0, 2.8, 3.0, 4.0, 5.0, 2.0], 0.3)
    True
    """
    for idx, elem in enumerate(numbers):
        for idx2, elem2 in enumerate(numbers):
            if idx != idx2:
                distance = elem - elem2
                if distance < threshold:
                    return True

    return False
'''
            ],
            [
                {'input': {'operations': [1, 2, -4, 5]}, 'expected_output': False},
                '''
def below_zero(operations: List[int]) -> bool:
    """ You're given a list of deposit and withdrawal operations on a bank account that starts with
    zero balance. Your task is to detect if at any point the balance of account fallls below zero, and
    at that point function should return True. Otherwise it should return False.
    >>> below_zero([1, 2, 3])
    False
    >>> below_zero([1, 2, -4, 5])
    True
    """
    balance = 0

    for op in operations:
        balance += op
        if balance == 0:
            return True

    return False
'''
            ]
        ],
        inputs=[test_cases, code_input],
    )


app.launch()
testcase_utils.py ADDED
@@ -0,0 +1,222 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Dict, Any
2
+ from loguru import logger
3
+ import ast
4
+ import re
5
+ import json
6
+ from tqdm import tqdm
7
+
8
+
9
def get_parameter_names(prompt: str, entry_point: str) -> List[str]:
    """Return the parameter names of the *entry_point* function defined in *prompt*.

    Returns an empty list when no function of that name is found.
    """
    module = ast.parse(prompt)
    matches = [
        node for node in ast.walk(module)
        if isinstance(node, ast.FunctionDef) and node.name == entry_point
    ]
    if not matches:
        return []
    return [arg.arg for arg in matches[0].args.args]
22
+
23
+
24
def parse_tests(test: str, parameter_names: List[str], entry_point: str) -> Dict[str, List[Dict[str, Any]]]:
    """Parse assert-style tests into ``{"test_cases": [{"input", "expected_output"}, ...]}``."""
    # Drop the METADATA dict that HumanEval-style test files may start with.
    cleaned = re.sub(r'METADATA = \{[^}]*\}', '', test)

    tree = ast.parse(cleaned)

    cases = []
    for node in ast.walk(tree):
        if not isinstance(node, ast.Assert):
            continue
        # Each assert statement may yield one structured test case.
        parsed = process_assert(node, entry_point, parameter_names)
        if parsed:
            cases.append(parsed)

    return {"test_cases": cases}
43
+
44
+
45
def process_assert(node: ast.Assert, entry_point: str, parameter_names: List[str]) -> Dict[str, Any]:
    """Turn ``assert candidate(...) == expected`` into an input/expected_output record.

    Returns None for assert statements that do not match that exact shape.
    """
    comparison = node.test
    if not (isinstance(comparison, ast.Compare) and isinstance(comparison.ops[0], ast.Eq)):
        return None

    call = comparison.left
    expected_node = comparison.comparators[0]
    if not (isinstance(call, ast.Call) and isinstance(call.func, ast.Name) and call.func.id == "candidate"):
        return None

    input_dict = process_input(call.args, parameter_names)

    try:
        # Literal expressions are evaluated safely first.
        expected_output = ast.literal_eval(expected_node)
    except ValueError:
        # SECURITY NOTE: eval on the expected-value expression; the inputs
        # here are trusted benchmark test files, not user data.
        expected_output = eval(compile(ast.Expression(expected_node), filename="<ast>", mode="eval"))

    return {"input": input_dict, "expected_output": expected_output}
71
+
72
+
73
def process_input(args: List[ast.expr], parameter_names: List[str]) -> Dict[str, Any]:
    """Evaluate call arguments and zip them with their parameter names.

    Extra positional arguments beyond the known names are stored under
    ``arg_<index>`` keys.
    """
    result: Dict[str, Any] = {}

    for idx, arg_node in enumerate(args):
        try:
            # Prefer the safe literal evaluator for plain constants/containers.
            value = ast.literal_eval(arg_node)
        except ValueError:
            # SECURITY NOTE: eval fallback for non-literal expressions;
            # inputs are trusted benchmark files.
            value = eval(compile(ast.Expression(arg_node), filename="<ast>", mode="eval"))

        key = parameter_names[idx] if idx < len(parameter_names) else f"arg_{idx}"
        result[key] = value

    return result
95
+
96
+
97
def parse_all_problems(problems):
    """Try to parse every problem line and report how many parse successfully.

    *problems* is a list of JSON strings (one seed problem per line). Each
    problem's given tests and hidden tests are run through ``parse_tests``;
    failures are logged and counted when the problem is also unsolved.
    """
    success_count = 0
    unhandled_failures = 0
    for problem in problems:
        try:
            problem = json.loads(problem)

            entry_point = problem["entry_point"]
            parameter_names = get_parameter_names(problem["prompt"], entry_point)

            # Tests call the function under the generic name "candidate".
            given_tests_raw = "\n".join(problem["given_tests"]).replace(entry_point, "candidate")
            given_tests = parse_tests(given_tests_raw, parameter_names, entry_point)

            # Parse the hidden test cases using the parameter names.
            parsed_tests = parse_tests(problem["test"], parameter_names, entry_point)
            success_count += 1
        except Exception:
            # BUG FIX: was a bare `except:`, which also swallowed
            # KeyboardInterrupt/SystemExit; narrowed to Exception.
            logger.exception(f"Error processing problem {problem['task_id']}")
            if problem['is_solved'] == False:
                unhandled_failures += 1
            continue

    logger.info(f"Success count: {success_count}")
    logger.info(f"Total problems: {len(problems)}")
    logger.info(f"Unhandled failures: {unhandled_failures}")
127
+
128
+
129
def parse_specific_problem(problem):
    """Parse a single problem (dict or JSON string) and return its parsed tests.

    Returns the parsed hidden test cases, or None when parsing fails.
    """
    try:
        if isinstance(problem, str):
            problem = json.loads(problem)

        logger.info(f"Problem: {problem}")
        logger.debug(f"Test: {problem['test']}")
        logger.debug(f"Given Test: {problem['given_tests']}")

        entry_point = problem["entry_point"]
        parameter_names = get_parameter_names(problem["prompt"], entry_point)
        logger.debug(f"Parameter names: {parameter_names}")

        # Tests call the function under the generic name "candidate".
        given_tests_raw = "\n".join(problem["given_tests"]).replace(entry_point, "candidate")
        given_tests = parse_tests(given_tests_raw, parameter_names, entry_point)
        logger.debug(f"Given tests: {given_tests}")

        # Parse the hidden test cases using the parameter names.
        all_tests = parse_tests(problem["test"], parameter_names, entry_point)
        logger.debug(f"Parsed tests: {all_tests}")
        return all_tests
    except Exception:
        # BUG FIX: was a bare `except:`; narrowed to Exception so
        # KeyboardInterrupt/SystemExit are no longer swallowed.
        logger.exception(f"Error processing problem {problem['task_id']}")
        return None
153
+
154
+ #assert next_smallest([]) is None
155
+ #assert decode_cyclic(encode_cyclic("abc")) == "abc"
156
+ #assert round(find_zero([-6, 11, -6, 1]), 2) == 1.0
157
+ #assert abs(candidate(1.33) - 0.33) < 1e-6
158
+
159
def check_all_problems(problems):
    """Sanity-check that every problem's hidden tests parse completely.

    For each JSON problem line, compares the number of 'candidate' mentions
    in the raw test string (minus one for the check() call) against the
    number of parsed test cases. Failures on unsolved problems are counted
    and their task_ids collected in problems_q.
    """
    problems_q = []
    success_count = 0
    fail_count = 0
    for problem in tqdm(problems):
        try:
            problem = json.loads(problem)

            logger.info(f"Problem: {problem}")
            logger.debug(f"Test: {problem['test']}")
            logger.debug(f"All Test: {problem['given_tests']}")

            entry_point = problem["entry_point"]
            parameter_names = get_parameter_names(problem["prompt"], entry_point)
            logger.info(f"Parameter names: {parameter_names}")

            # given_tests_len = len(problem["given_tests"])
            # given_tests_raw = "\n".join(problem["given_tests"]).replace(entry_point, "candidate")
            # given_tests = parse_tests(given_tests_raw, parameter_names, entry_point)
            # parsed_given_tests_len = len(given_tests['test_cases'])
            # assert given_tests_len == parsed_given_tests_len
            # success_count += 1

            # Parse the test cases using the parameter names.
            tests_len_candidate = problem["test"].count('candidate')
            parsed_tests = parse_tests(problem["test"], parameter_names, entry_point)
            parsed_test_len = len(parsed_tests['test_cases'])
            #assert parsed_test_len != 0
            # One 'candidate' occurrence belongs to the check() signature,
            # so the parsed count should be exactly one less.
            assert tests_len_candidate - 1 == parsed_test_len
            logger.info(f"Parsed tests: {parsed_tests}")
            success_count += 1
        except:
            # NOTE(review): bare except also catches KeyboardInterrupt —
            # consider narrowing to Exception.
            logger.exception(f"Error processing problem {problem['task_id']}")
            if problem['is_solved'] == False:
                fail_count += 1
                problems_q.append(problem['task_id'])
            continue

    # Cross-check the failing task_ids against a previously debugged run.
    # NOTE(review): hard-coded output path from a specific experiment run.
    with open('output_data/humaneval/seed/deepseek-coder-v2-lite-instruct/20240828-174550/dscoder_debugged_seeds_deepseek-coder-v2-lite-instruct_1_1_10.jsonl', "r") as f:
        fixed = f.readlines()
    for fix_problem in fixed:
        fix_problem = json.loads(fix_problem)
        if fix_problem['task_id'] in problems_q:
            print(1)

    logger.info(f"Success count: {success_count}")
    logger.info(f"Total problems: {len(problems)}")
    logger.info(f"Unhandled failures: {fail_count}")
207
+
208
+ if __name__ == "__main__":
209
+ input_seeds = "input_data/humaneval/seed/deepseek-coder-v2-lite-instruct/seed.jsonl"
210
+
211
+ with open(input_seeds, "r") as f:
212
+ problems = f.readlines()
213
+
214
+ check_all_problems(problems)
215
+ #parse_all_problems(problems)
216
+
217
+ # parse the one with 'task_id': 'HumanEval/32'
218
+ # for problem in problems:
219
+ # problem = json.loads(problem)
220
+ # if problem['task_id'] == 'HumanEval/33':
221
+ # parsed_tests = parse_specific_problem(problem)
222
+ # break
timeout_utils.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ from threading import Thread
4
+ import resource
5
+ import signal
6
+ import json
7
+ import os
8
+ import traceback
9
+ from loguru import logger
10
+
11
+
12
def set_memory_limit(max_memory):
    """Build a zero-argument hook that caps this process's address space.

    The returned callable applies ``RLIMIT_AS = max_memory`` (bytes) via the
    POSIX-only ``resource`` module; suitable as a subprocess/thread setup hook.
    """
    return lambda: resource.setrlimit(resource.RLIMIT_AS, (max_memory, max_memory))
16
+
17
+
18
def timeout_handler(_, __):
    # Signal handler (signum and frame are ignored): converts e.g. SIGALRM
    # into a TimeoutError so callers can catch it like a normal exception.
    raise TimeoutError()
20
+
21
+
22
def to_jsonl(dict_data, file_path):
    """Append *dict_data* as a single JSON line to *file_path*.

    BUG FIX: write "\n" explicitly instead of os.linesep. The file is opened
    in text mode, which already translates "\n" to the platform terminator,
    so os.linesep produced "\r\r\n" lines on Windows and broke readers.
    """
    with open(file_path, 'a') as file:
        file.write(json.dumps(dict_data) + "\n")
26
+
27
+
28
class PropagatingThread(Thread):
    """Thread that captures its target's return value and re-raises any
    exception from the worker in the caller's thread via ``join``."""

    def run(self):
        # Exception raised by the target, if any; surfaced in join().
        self.exc = None
        try:
            if hasattr(self, '_Thread__target'):
                # Thread uses name mangling prior to Python 3.
                self.ret = self._Thread__target(*self._Thread__args, **self._Thread__kwargs)
            else:
                self.ret = self._target(*self._args, **self._kwargs)
        except Exception as e:
            self.exc = e

    def join(self, timeout=None):
        # Wait for the worker, then propagate its exception (if it raised).
        super(PropagatingThread, self).join(timeout)
        if self.exc:
            raise self.exc
        if self.is_alive():
            # join() timed out: the worker is still running, no result yet.
            return None
        return self.ret

    def terminate(self):
        # NOTE(review): Thread._stop() is a CPython internal — it marks the
        # thread as stopped but does NOT kill it; the worker may keep running
        # in the background. Confirm this best-effort behavior is intended.
        self._stop()
50
+
51
+
52
def function_with_timeout(func, args, timeout, max_memory=100 * 1024 * 1024):
    """Run ``func(*args)`` in a helper thread with a *timeout* in seconds.

    Returns the function's result, re-raises any exception the worker threw
    (via PropagatingThread.join), and raises TimeoutError when the worker is
    still alive after *timeout*. Note: *max_memory* is currently unused
    because the rlimit hook is disabled.
    """
    results = []

    def runner():
        # set_memory_limit(max_memory)()  # rlimit hook intentionally disabled
        results.append(func(*args))

    worker = PropagatingThread(target=runner)
    worker.start()
    worker.join(timeout)
    if not worker.is_alive():
        return results[0]
    logger.error(f"Timeout Error\n {args[0]} with timeout {timeout}")
    worker.terminate()
    raise TimeoutError()
utils.py ADDED
@@ -0,0 +1,549 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import ast
2
+ from loguru import logger
3
+ import sys
4
+ import io
5
+ import json
6
+ import re
7
+ import traceback
8
+ import os
9
+ from timeout_utils import function_with_timeout
10
+
11
# Import statements prepended to every generated program so that common
# helper modules are always in scope for the code under test.
# FIX: removed a duplicate "from collections import *" entry.
helpers = [
    "import math",
    "import re",
    "import sys",
    "import copy",
    "import datetime",
    "import itertools",
    "import collections",
    "import heapq",
    "import statistics",
    "import functools",
    "import hashlib",
    "import numpy",
    "import numpy as np",
    "import string",
    "from typing import *",
    "from collections import *",
    "import heapq as hq",
    "from itertools import *",
    "from math import *",
    "from statistics import *",
    "from functools import *",
    "from datetime import *",
    "from copy import *",
]

# Single preamble string, one import per line.
STARTING_CODE = "\n".join(helpers)
39
+
40
+
41
def create_dependency_graph(functions):
    """Build ``{function_name: set of other function names its source mentions}``.

    Dependency detection is a plain substring search over the source text,
    so a name that only appears in a string or comment also counts.
    """
    return {
        name: {
            other
            for other in functions
            if other != name and other in source
        }
        for name, source in functions.items()
    }
48
+
49
+
50
def topological_sort(graph):
    """Post-order DFS over *graph*: dependencies come before their dependents.

    *graph* maps each node to the set of nodes it depends on; nodes are
    visited in the dict's insertion order, matching the original behavior.
    """
    seen = set()
    order = []

    def visit(node):
        seen.add(node)
        for child in graph[node]:
            if child not in seen:
                visit(child)
        # Appended only after all dependencies are placed.
        order.append(node)

    for node in graph:
        if node not in seen:
            visit(node)

    return order
66
+
67
+
68
def merge_changes_to_parents(func_name, dependency_graph, functions):
    """Propagate an updated *func_name* into the full program text.

    Returns the regenerated source built by joining all entries of
    *functions* (which the caller has already updated for *func_name*).
    """
    # Update the function in the functions dictionary
    logger.info(f"Updating function {func_name} in the functions dictionary")

    # For any function that calls the modified function, update its code
    for parent, children in dependency_graph.items():
        if func_name in children:
            parent_code = functions[parent]
            # NOTE(review): this replaces func_name with itself
            # (f"{func_name}" == func_name), so the replace is a no-op;
            # presumably a renaming scheme was planned — confirm intent.
            updated_parent_code = parent_code.replace(func_name, f"{func_name}")
            functions[parent] = updated_parent_code
            logger.info(f"Updated references to {func_name} in parent function {parent}")

    # Regenerate the full code
    full_code = "\n\n".join(functions.values())

    logger.info(f"Merged changes from {func_name} to all relevant functions")
    return full_code
85
+
86
+
87
def extract_functions(code):
    """Map every function name defined in *code* to its source segment.

    Walks the full AST, so nested functions are captured as well.
    """
    logger.info("Extracting functions from code")
    module = ast.parse(code)
    functions = {
        node.name: ast.get_source_segment(code, node)
        for node in ast.walk(module)
        if isinstance(node, ast.FunctionDef)
    }
    logger.info(f"Extracted {len(functions)} functions: {', '.join(functions.keys())}")
    return functions
97
+
98
+
99
def extract_code_blocks(response):
    """Return the contents of every ```python fenced block in *response*."""
    fence = re.compile(r'```python\s*(.*?)\s*```', re.DOTALL)
    return fence.findall(response)
102
+
103
+
104
def extract_function(code_block, function_name):
    """Return the source of *function_name* inside *code_block*, or None.

    Returns None both when the block fails to parse and when no function
    of that name is defined in it.
    """
    try:
        tree = ast.parse(code_block)
    except (SyntaxError, ValueError):
        # FIX: narrowed from a bare `except:`; ast.parse raises SyntaxError
        # (and ValueError for e.g. null bytes) — anything else should surface.
        logger.error(f"Failed to parse code block for function: {function_name} from\n{code_block}")
        return None
    for node in ast.walk(tree):
        if isinstance(node, ast.FunctionDef) and node.name == function_name:
            return ast.get_source_segment(code_block, node)
    return None
115
+
116
+
117
def evaluate_given_tests(code, given_tests, max_memory=100 * 1024 * 1024):
    """Exec *code* followed by *given_tests* under a 10 s timeout.

    Returns True when everything runs without error, False otherwise;
    failures are logged by category.
    """
    program = f"{STARTING_CODE}\n\n{code}\n\n{given_tests}"
    try:
        function_with_timeout(exec, (program, globals()), timeout=10, max_memory=max_memory)
    except TimeoutError as e:
        logger.error(f"Timeout Error: {str(e)}")
    except MemoryError as e:
        logger.error(f"Memory Error: {str(e)}")
    except AssertionError as e:
        logger.error(f"Assertion Error: {str(e)}")
    except Exception as e:
        logger.error(f'Error: {str(e)}')
        logger.error(f'Traceback: {traceback.format_exc()}')
    else:
        return True
    return False
132
+
133
def evaluate_simple(code, entry_point, all_test, max_memory=100 * 1024 * 1024):
    """Evaluate *code* against the private test suite by direct concatenation.

    Concatenates helper imports, the code, the test definitions, and a
    ``check(entry_point)`` call, then execs the whole program under a 10 s
    timeout. Returns True iff everything passes.
    """
    program = f"{STARTING_CODE}\n\n{code}\n\n{all_test}\n\ncheck({entry_point})"
    try:
        function_with_timeout(exec, (program, globals()), timeout=10, max_memory=max_memory)
    except TimeoutError as e:
        logger.error(f"Timeout Error: {str(e)}")
    except MemoryError as e:
        logger.error(f"Memory Error: {str(e)}")
    except AssertionError as e:
        logger.error(f"Assertion Error: {str(e)}")
    except Exception as e:
        logger.error(f'Error: {str(e)}')
        logger.error(f'Traceback: {traceback.format_exc()}')
    else:
        return True
    return False
152
+
153
+
154
def evaluate(code, entry_point, testcase, return_trace=False):
    """Run *entry_point* from *code* on a single testcase and compare outputs.

    Returns ``(passed, testcase)`` where the testcase dict is enriched with
    'actual_output' (and 'traceback' when *return_trace* is set and a
    non-assertion exception occurred).
    """
    logger.info(f"Evaluating {entry_point} with testcase: {testcase['input']}")

    # Extract all functions from the code
    try:
        functions = extract_functions(code)
    except Exception:
        # FIX: narrowed from a bare `except:`. NOTE: if extraction fails,
        # `functions` is unbound and the next line raises NameError — same
        # flow as the original code.
        logger.error(f"Failed to extract functions from code {code}")
    logger.info(f"Extracted functions: {', '.join(functions.keys())}")

    # Keep only the functions whose names appear in the entry point's source
    # (plain substring search, not a call-graph analysis).
    entry_point_function = functions[entry_point]
    functions = {name: func for name, func in functions.items() if name in entry_point_function}
    logger.info(f"Filtered functions: {', '.join(functions.keys())}")

    # Combine all kept functions into a single code block
    full_code = "\n\n".join(functions.values())

    # Convert the input to a string representation that can be safely evaluated
    input_repr = repr(testcase['input'])

    if isinstance(testcase['input'], dict):
        # Dict inputs are unpacked as keyword arguments
        test_code = f'''{full_code}\n\nprint(repr({entry_point}(**{input_repr})))'''
    else:
        test_code = f'''{full_code}\n\nprint(repr({entry_point}({input_repr})))'''

    # Prepend the common helper imports
    test_code = f"{STARTING_CODE}\n\n{test_code}"

    # Capture whatever the generated program prints
    old_stdout = sys.stdout
    new_stdout = io.StringIO()
    sys.stdout = new_stdout

    try:
        function_with_timeout(exec, (test_code, globals()), timeout=10)
        output = new_stdout.getvalue().strip()
        sys.stdout = old_stdout

        # Compare repr() strings so both sides use the same representation
        expected_output = repr(testcase["expected_output"])

        # Record actual_output before the assertion so failures keep it too
        testcase['actual_output'] = ast.literal_eval(output)

        assert output == expected_output, f"Expected {expected_output}, but got {output}"
        logger.info(f'Test case passed: {testcase}')
        logger.info(f'Expected: {expected_output}, Got: {output}')
        return True, testcase
    except TimeoutError as e:
        # BUG FIX: bind the exception as `e` — previously this branch was
        # `except TimeoutError:` but logged str(e), raising a NameError.
        logger.error(f'Test case failed: {testcase}')
        logger.error(f"Timeout Error: {str(e)}")
    except AssertionError as e:
        logger.error(f'Test case failed: {testcase}')
        logger.error(str(e))
    except Exception as e:
        logger.error(f'Test case failed: {testcase}')
        logger.error(f'Error: {str(e)}')
        logger.error(f'Traceback: {traceback.format_exc()}')
        testcase['actual_output'] = str(e)
        if return_trace:
            testcase['traceback'] = traceback.format_exc()
    finally:
        sys.stdout = old_stdout

    return False, testcase
227
+
228
+
229
def extract_json_from_string(s):
    """Return the contents of the last ```json fenced block in *s*, or None."""
    blocks = re.findall(r'```json\s*(.*?)\s*```', s, re.DOTALL)
    return blocks[-1] if blocks else None
236
+
237
+
238
def parse_json_response(response):
    """Extract and parse the last ```json block of *response*.

    Applies ad-hoc textual repairs for common LLM output quirks (Python
    literals, single quotes, tuple notation, inline comments) before
    parsing. Returns the parsed object, or None when no block is found or
    parsing ultimately fails.
    """
    json_str = extract_json_from_string(response)
    if json_str:
        try:
            # Standard JSON corrections
            # NOTE(review): these blind replaces also rewrite content inside
            # string values (apostrophes, the words True/False/None), which
            # can corrupt data — confirm acceptable for this use.
            json_str = json_str.strip().replace("True", "true")
            json_str = json_str.replace("False", "false")
            json_str = json_str.replace("'", '"')
            json_str = json_str.replace("None", "null")

            # Convert tuple notation to list notation
            json_str = re.sub(r'\((-?\d+),\s*(-?\d+)\)', r'[\1, \2]', json_str)

            logger.info(f"Extracted JSON string: {json_str}")
            try:
                return json.loads(json_str)
            except:
                # remove comments (for mistral model)
                json_str = re.sub(r'#.*', '', json_str)
                return json.loads(json_str)
        except json.JSONDecodeError as e:
            logger.error(f"Failed to parse extracted JSON: {json_str}")
            logger.error(f"JSONDecodeError: {str(e)}")
    else:
        logger.error("No JSON object found in the response")
    return None
266
+
267
+
268
def get_dependency_graph_str(graph, root=None, prefix="", is_last=True):
    """Render *graph* as a tree drawing (one node per line).

    With ``root=None``, every node that no other node depends on is treated
    as a root and all trees are joined with newlines.
    """
    if root is None:
        # Roots are nodes that never appear among anyone's children.
        referenced = set()
        for children in graph.values():
            referenced.update(children)
        roots = [node for node in graph if node not in referenced]
        return "\n".join(
            get_dependency_graph_str(graph, node, "", idx == len(roots) - 1)
            for idx, node in enumerate(roots)
        )

    lines = [prefix + ("└── " if is_last else "β”œβ”€β”€ ") + root]

    if root in graph:
        children = sorted(graph[root])
        # Continue the vertical rail only while siblings remain below us.
        child_prefix = prefix + ("    " if is_last else "β”‚   ")
        for idx, child in enumerate(children):
            lines.append(
                get_dependency_graph_str(graph, child, child_prefix, idx == len(children) - 1)
            )

    return "\n".join(lines)
289
+
290
+
291
def extract_functions_from_code(node, parent=None):
    """Recursively link each FunctionDef to its parent and fill ``children``.

    Callers must have initialised ``node.children = []`` on every AST node
    beforehand (see split_nested_functions).
    """
    if isinstance(node, ast.FunctionDef):
        node.parent = parent
        # isinstance(None, ...) is False, so no explicit None check needed.
        if isinstance(parent, (ast.FunctionDef, ast.Module)):
            parent.children.append(node)
    if isinstance(node, (ast.Module, ast.FunctionDef)):
        for child in node.body:
            extract_functions_from_code(child, parent=node)
302
+
303
+
304
def split_nested_functions(code):
    """Flatten nested function definitions in *code* into top-level functions.

    Nested defs are removed from their parent's body and emitted as separate
    top-level functions, joined by blank lines.
    """
    tree = ast.parse(code)
    # Give every node a children list for extract_functions_from_code to fill.
    for node in ast.walk(tree):
        node.children = []
    extract_functions_from_code(tree)

    flat_functions = []

    def flatten_functions(node):
        if isinstance(node, ast.FunctionDef):
            flat_functions.append(node)
            # Remove nested function definitions from the body
            node.body = [n for n in node.body if not isinstance(n, ast.FunctionDef)]
        for child in node.children:
            flatten_functions(child)

    flatten_functions(tree)

    # Function to correct indentation for function docstrings
    def correct_indentation(functions):
        for func in functions:
            # Get existing docstring if present
            docstring = ast.get_docstring(func)
            if docstring:
                # Replace existing docstring node with corrected indentation
                # (whitespace-only lines are collapsed to empty strings).
                corrected_docstring = "\n".join([line if line.strip() != "" else "" for line in docstring.split("\n")])
                # NOTE(review): assigning ``.s`` relies on the deprecated
                # ast.Str alias; modern unparsers read Constant.value, so
                # this may have no effect — confirm on the target Python.
                func.body[0].value.s = corrected_docstring

    correct_indentation(flat_functions)

    return '\n\n'.join(ast.unparse(f).strip() for f in flat_functions)
335
+
336
+
337
def remove_unused_functions(code, entry_point):
    """Strip function definitions not (transitively) reachable from *entry_point*.

    Returns ``(cleaned_code, removed_function_names)``.

    FIX: the previous implementation iterated a set (`function_calls`) that
    its own visitor kept mutating — a RuntimeError hazard — and it also kept
    functions that were referenced only by dead code. This version does an
    explicit worklist reachability walk from the entry point.
    """
    tree = ast.parse(code)

    # All function names defined anywhere in the module.
    function_defs = {node.name: node for node in ast.walk(tree) if isinstance(node, ast.FunctionDef)}
    function_names = set(function_defs)

    def _calls_in(node):
        """Names of defined functions called (by simple name) inside *node*."""
        found = set()
        for sub in ast.walk(node):
            if isinstance(sub, ast.Call) and isinstance(sub.func, ast.Name) and sub.func.id in function_names:
                found.add(sub.func.id)
        return found

    # Worklist traversal: mark everything reachable from the entry point.
    used_functions = set()
    pending = [entry_point]
    while pending:
        name = pending.pop()
        if name in used_functions:
            continue
        used_functions.add(name)
        if name in function_defs:
            pending.extend(_calls_in(function_defs[name]))

    # Only keep the functions that are used (non-function statements stay).
    tree.body = [node for node in tree.body if not isinstance(node, ast.FunctionDef) or node.name in used_functions]

    all_unused_functions = function_names - used_functions

    # Convert back to code.
    return ast.unparse(tree), all_unused_functions
372
+
373
+
374
def test_remove_unused_functions():
    """Smoke test: ``clean_data`` is never called and should be stripped."""
    code = '''
def rolling_max(numbers: List[int]) -> List[int]:
    """From a given list of integers, generate a list of rolling maximum element found until given moment
    in the sequence.
    >>> rolling_max([1, 2, 3, 2, 3, 4, 2])
    [1, 2, 3, 3, 3, 4, 4]"""
    (max_so_far, rolling_max_list) = initialize_max_and_list(numbers)
    for num in numbers[1:]:
        (max_so_far, rolling_max_list) = update_max_and_list(max_so_far, num, rolling_max_list)
    return rolling_max_list

def initialize_max_and_list(numbers: List[int]) -> Tuple[int, List[int]]:
    max_so_far = numbers[0]
    rolling_max_list = [max_so_far]
    return (max_so_far, rolling_max_list)

def update_max_and_list(max_so_far: int, num: int, rolling_max_list: List[int]) -> Tuple[int, List[int]]:
    max_so_far = max(max_so_far, num)
    rolling_max_list.append(max_so_far)
    return (max_so_far, rolling_max_list)

def clean_data(data: List[str]) -> List[str]:
    return [d.strip() for d in data]
'''.strip()

    entry_point = "rolling_max"
    logger.info(f"Original code:\n{code}")
    output, unused_functions = remove_unused_functions(code, entry_point)
    logger.info(f"Unused functions: {unused_functions}")
    logger.info(f"Cleaned code:\n{output}")
406
+
407
def test_split_nested_functions():
    """Smoke-test split_nested_functions on a sample with doubly-nested defs."""
    nested_sample = '''
def find_suffix_start(s: str) -> int:
    for i in range(len(s)):
        if is_palindrome(s[i:]):
            return i
    return 0

def make_palindrome(string: str) -> str:
    """This function takes a string and returns a palindrome by appending the reverse of the prefix of the string that makes it a palindrome."""

    def is_palindrome(s: str) -> bool:
        """
        This function takes a string and returns True if it is a palindrome, False otherwise.
        """

        def compare(s: str) -> bool:
            """
            This function takes a string and returns True if it is a palindrome, False otherwise.
            inner function
            """

            return s == s[::-1]

        return compare(s)

    suffix_start = find_suffix_start(string)
    return string + string[:suffix_start][::-1]
'''.strip()

    # Split the nested functions out to top level and fix the indentation.
    print(split_nested_functions(nested_sample))
443
+
444
+
445
def test_parse_json_response():
    """Smoke-test parse_json_response on a fenced-JSON model reply."""
    reply = """
**All Test Cases:**
```json
{
    "test_cases": [
        {"input": {"date": "03-11-2000"}, "expected_output": [11, 3, 2000]},
        {"input": {"date": "15-01-2012"}, "expected_output": [15, 1, 2012]},
        {"input": {"date": "04-0-2040"}, "expected_output": None},
        {"input": {"date": "06-04-2020"}, "expected_output": [4, 6, 2020]},
        {"input": {"date": "06/04/2020"}, "expected_output": None}
    ]
}
```
""".strip()

    print(parse_json_response(reply))
464
+
465
+
466
def insert_docstring(code, docstring):
    """Insert *docstring* as the first statement of the first ``def`` in *code*.

    String-based fallback used when the solution cannot be parsed by ``ast``:
    the docstring is wrapped in triple quotes and spliced in directly after the
    first line starting with ``def``, indented one level (4 spaces) deeper than
    that line.

    Args:
        code: Source text expected to contain a function definition.
        docstring: Raw docstring text, without surrounding quotes.

    Returns:
        The modified source text.
    """
    # Surround the docstring with triple quotes.
    docstring = f'"""{docstring}"""'

    lines = code.split('\n')

    # First non-empty line: fallback anchor in case no 'def' line exists.
    first_line = next((i for i, line in enumerate(lines) if line.strip()), 0)

    # Find the 'def' line (fall back to the first non-empty line).
    def_line = next((i for i, line in enumerate(lines) if line.strip().startswith('def ')), first_line)

    # Indent relative to the 'def' line itself — not the first line, which may
    # be a comment or decorator at a different indentation level.
    indentation = len(lines[def_line]) - len(lines[def_line].lstrip())

    # Insert the docstring just below the 'def' line, one level deeper.
    docstring_lines = [' ' * (indentation + 4) + line for line in docstring.split('\n')]
    lines = lines[:def_line + 1] + docstring_lines + lines[def_line + 1:]

    return '\n'.join(lines)
486
+
487
+
488
def parse_transcoder_problem_content(problem):
    """Prepend the originating C++ snippet to the problem's Python solution as a docstring.

    Extracts the last block of text between ``[c++]`` and ``[python]`` markers in
    ``problem["prompt"]`` and inserts it as a docstring at the top of the first
    function in ``problem["solution"]`` (or at the end of the module if no
    function is found). Falls back to string-based insertion when the solution
    does not parse.

    Args:
        problem: Problem dict with at least "prompt", "solution" and "task_id".

    Returns:
        The same dict, with "solution" replaced by the documented version.
    """
    # Extract the last group of content between [c++] and [python].
    cpp_code = problem["prompt"].split("[c++]")[-1].split("[python]")[0].strip()
    full_question = f'This function is translated into Python from the following C++ code: \n{cpp_code}\n'

    try:
        # Try to parse the existing solution.
        tree = ast.parse(problem["solution"])

        # ast.Str was removed in Python 3.12; ast.Constant is the supported node.
        docstring = ast.Expr(value=ast.Constant(value=full_question))

        # Find the first function definition in the AST.
        for node in tree.body:
            if isinstance(node, ast.FunctionDef):
                # Insert the docstring at the beginning of the function body.
                node.body.insert(0, docstring)
                break
        else:
            # No function definition found: append the docstring to the module.
            tree.body.append(docstring)

        # Convert the modified AST back to source code.
        modified_solution = ast.unparse(tree)

    except SyntaxError:
        # If there's a syntax error, use the string-based method.
        logger.debug(f"Failed to parse solution for problem: {problem['task_id']}")
        modified_solution = insert_docstring(problem["solution"], full_question)
        logger.debug(f"Modified solution: {modified_solution}")

    # Update the problem dictionary with the modified solution.
    problem["solution"] = modified_solution

    return problem
523
+
524
+
525
def test_parse_transcoder_problem_content():
    """Smoke-test docstring insertion over every transcoder seed problem."""
    seed_path = "input_data/transcoder/seed/starcoder/seed.jsonl"
    with open(seed_path, "r") as fh:
        problems = []
        for raw_line in fh:
            problems.append(json.loads(raw_line))

    for prob in problems:
        try:
            result = parse_transcoder_problem_content(prob)
        except Exception as e:
            logger.error(f"Failed to parse solution for problem: {prob['task_id']}")
            logger.error(f"The solution is: \n{prob['solution']}")
            logger.error(f"Error: {str(e)}")

    logger.info("Successfully parsed all solutions")
    # show an example
    logger.info(f"Example result: {result['solution']}")
542
+
543
+
544
if __name__ == "__main__":

    # Ad-hoc drivers for the helpers in this module; swap the comment markers
    # below to exercise a different one.
    # test_split_nested_functions()
    # test_parse_json_response()
    # test_remove_unused_functions()
    test_parse_transcoder_problem_content()