Spaces:
Sleeping
Sleeping
Duplicate from codeparrot/apps_metric
Browse filesCo-authored-by: Loubna Ben Allal <[email protected]>
- .gitattributes +27 -0
- README.md +53 -0
- app.py +6 -0
- apps_metric.py +82 -0
- example_script.py +133 -0
- requirements.txt +3 -0
- test_examples/solutions_problem_1.json +1 -0
- test_examples/solutions_problem_2.json +1 -0
- testing_util.py +525 -0
- tests.py +14 -0
- utils.py +212 -0
.gitattributes
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
6 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
7 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
8 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
9 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
10 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
11 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
12 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
13 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
14 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
15 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
16 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
17 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
18 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
19 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
20 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
21 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
22 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
23 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
24 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
25 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
26 |
+
*.zstandard filter=lfs diff=lfs merge=lfs -text
|
27 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
README.md
ADDED
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
title: APPS Metric
|
3 |
+
emoji: 📊
|
4 |
+
colorFrom: blue
|
5 |
+
colorTo: pink
|
6 |
+
datasets:
|
7 |
+
- null
|
8 |
+
tags:
|
9 |
+
- evaluate
|
10 |
+
- metric
|
11 |
+
description: Evaluation metric for the APPS benchmark
|
12 |
+
sdk: gradio
|
13 |
+
sdk_version: 3.0.2
|
14 |
+
app_file: app.py
|
15 |
+
pinned: false
|
16 |
+
duplicated_from: codeparrot/apps_metric
|
17 |
+
---
|
18 |
+
|
19 |
+
# Metric Card for apps_metric
|
20 |
+
|
21 |
+
## Metric Description
|
22 |
+
This metric is used to evaluate code generation on the [APPS benchmark](https://huggingface.co/datasets/codeparrot/apps).
|
23 |
+
|
24 |
+
## How to Use
|
25 |
+
You can load the metric and use it with the following commands:
|
26 |
+
|
27 |
+
```python
|
28 |
+
from evaluate import load
|
29 |
+
apps_metric = load('codeparrot/apps_metric')
|
30 |
+
# to evaluate generations made for all levels for example
|
31 |
+
results = apps_metric.compute(predictions=generations, level="all")
|
32 |
+
```
|
33 |
+
|
34 |
+
### Inputs
|
35 |
+
**generations** list(list(str)): List of code generations, each sub-list corresponds to the generations for a problem in APPS dataset, **the order of the samples in the dataset must be kept (with respect to the difficulty level)**.
|
36 |
+
|
37 |
+
### Output Values
|
38 |
+
|
39 |
+
**average accuracy**: when a single solution is generated, average accuracy computes the average of test cases that are passed.
|
40 |
+
|
41 |
+
**strict accuracy**: when a single solution is generated, strict accuracy computes the average number of problems that pass all their test cases.
|
42 |
+
|
43 |
+
**pass@k**: when multiple solutions are generated per problem, pass@k is the metric originally used for the [HumanEval](https://huggingface.co/datasets/openai_humaneval) benchmark. For more details please refer to the [metric space](https://huggingface.co/spaces/evaluate-metric/code_eval) and [Codex paper](https://arxiv.org/pdf/2107.03374v2.pdf).
|
44 |
+
|
45 |
+
## Citation
|
46 |
+
```
|
47 |
+
@article{hendrycksapps2021,
|
48 |
+
title={Measuring Coding Challenge Competence With APPS},
|
49 |
+
author={Dan Hendrycks and Steven Basart and Saurav Kadavath and Mantas Mazeika and Akul Arora and Ethan Guo and Collin Burns and Samir Puranik and Horace He and Dawn Song and Jacob Steinhardt},
|
50 |
+
journal={NeurIPS},
|
51 |
+
year={2021}
|
52 |
+
}
|
53 |
+
```
|
app.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import evaluate
|
2 |
+
from evaluate.utils import launch_gradio_widget
|
3 |
+
|
4 |
+
|
5 |
+
module = evaluate.load("loubnabnl/apps_metric")
|
6 |
+
launch_gradio_widget(module)
|
apps_metric.py
ADDED
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
"""Evaluation of code generation on the APPS benchmark"""
|
15 |
+
|
16 |
+
import evaluate
|
17 |
+
import datasets
|
18 |
+
from .utils import compute_metrics
|
19 |
+
from .testing_util import run_test
|
20 |
+
|
21 |
+
|
22 |
+
_CITATION = """\
|
23 |
+
@article{hendrycksapps2021,
|
24 |
+
title={Measuring Coding Challenge Competence With APPS},
|
25 |
+
author={Dan Hendrycks and Steven Basart and Saurav Kadavath and Mantas Mazeika and Akul Arora and Ethan Guo and Collin Burns and Samir Puranik and Horace He and Dawn Song and Jacob Steinhardt},
|
26 |
+
journal={NeurIPS},
|
27 |
+
year={2021}
|
28 |
+
}
|
29 |
+
"""
|
30 |
+
|
31 |
+
|
32 |
+
_DESCRIPTION = """\
|
33 |
+
This is a metric to evaluate code generation using the APPS benchmark "Measuring Coding Challenge Competence With
|
34 |
+
APPS" (https://arxiv.org/pdf/2105.09938.pdf).
|
35 |
+
"""
|
36 |
+
|
37 |
+
|
38 |
+
# TODO: Add description of the arguments of the module here
|
39 |
+
_KWARGS_DESCRIPTION = """
|
40 |
+
Computes Average accuracy and strict accuracy for single generations, and pass@k for multiple generations.
|
41 |
+
Args:
|
42 |
+
predictions: list of code generations to score. It's a list of list(s), each corresponding to a problem from APPS dataset.
|
43 |
+
|
44 |
+
Returns:
|
45 |
+
metrics: dict of three metrics: average accuracy, stric accuracy, and pass@k.
|
46 |
+
Examples:
|
47 |
+
>>> my_new_module = evaluate.load("loubnabnl/apps_metric")
|
48 |
+
>>> results = my_new_module.compute(predictions=[["s=input()\nprint(s)"]])
|
49 |
+
>>> print(results)
|
50 |
+
{'avg_accuracy': 0, 'strict_accuracy': 0, 'pass_at_k': None}
|
51 |
+
"""
|
52 |
+
|
53 |
+
|
54 |
+
|
55 |
+
|
56 |
+
@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
|
57 |
+
class apps_metric(evaluate.EvaluationModule):
|
58 |
+
"""Evaluate code generation on APPS benchmark.
|
59 |
+
The generations are compiled and their corresponding unit tests are run"""
|
60 |
+
|
61 |
+
def _info(self):
|
62 |
+
|
63 |
+
return evaluate.EvaluationModuleInfo(
|
64 |
+
|
65 |
+
module_type="metric",
|
66 |
+
description=_DESCRIPTION,
|
67 |
+
citation=_CITATION,
|
68 |
+
inputs_description=_KWARGS_DESCRIPTION,
|
69 |
+
|
70 |
+
features=datasets.Features({
|
71 |
+
'predictions': datasets.Sequence(datasets.Value("string")),
|
72 |
+
}),
|
73 |
+
homepage="https://github.com/hendrycks/apps",
|
74 |
+
reference_urls=["https://huggingface.co/datasets/codeparrot/apps"]
|
75 |
+
)
|
76 |
+
|
77 |
+
|
78 |
+
|
79 |
+
def _compute(self, predictions, k_list=[1, 10, 100], count_errors=True, level="all", debug=False):
|
80 |
+
"""Returns the scores"""
|
81 |
+
metrics = compute_metrics(predictions, k_list=k_list, count_errors=count_errors, level=level, debug=debug)
|
82 |
+
return metrics
|
example_script.py
ADDED
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""This is an example script to evaluate a code generation model on APPS, you can also use the APPS solutions as code generations
|
2 |
+
> python example_script.py --model_ckpt MODEL_NAME --num_tasks 10 --difficulty introductory --n_samples 1
|
3 |
+
> python example_script.py --use_solutions True --num_tasks 10 --difficulty introductory --n_samples 1"""
|
4 |
+
|
5 |
+
import json
|
6 |
+
import pprint
|
7 |
+
from tqdm import tqdm
|
8 |
+
from datasets import load_dataset
|
9 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, set_seed
|
10 |
+
from evaluate import load
|
11 |
+
|
12 |
+
def generate_prompt(sample):
|
13 |
+
starter_code = None if len(sample["starter_code"]) == 0 else sample["starter_code"]
|
14 |
+
try:
|
15 |
+
input_outpout = json.loads(sample["input_output"])
|
16 |
+
fn_name = None if not input_outpout.get("fn_name") else input_outpout["fn_name"]
|
17 |
+
except ValueError:
|
18 |
+
fn_name = None
|
19 |
+
_input = "\nQUESTION:\n"
|
20 |
+
_input += sample["question"]
|
21 |
+
if starter_code:
|
22 |
+
_input += starter_code
|
23 |
+
if fn_name:
|
24 |
+
_input += "\nUse Standard Input format"
|
25 |
+
else:
|
26 |
+
_input += "\nUse Call-Based format"
|
27 |
+
|
28 |
+
_input += "\nANSWER:\n"
|
29 |
+
return _input
|
30 |
+
|
31 |
+
|
32 |
+
def complete_code(pipe, prompt, num_completions=1, max_length=256, **gen_kwargs):
|
33 |
+
"""Complete prompt with text generation pipeline and return num_completions."""
|
34 |
+
prompt = pipe.tokenizer.eos_token + prompt
|
35 |
+
try:
|
36 |
+
code_gens = pipe(prompt, num_return_sequences=num_completions, max_length=max_length, **gen_kwargs)
|
37 |
+
return [code_gen["generated_text"][len(prompt):] for code_gen in code_gens]
|
38 |
+
except IndexError:
|
39 |
+
print("prompt is longer than the context size of the model, generation skipped")
|
40 |
+
code_gens = ""
|
41 |
+
return [""]
|
42 |
+
|
43 |
+
|
44 |
+
def make_generations(dataset, args, model, tokenizer):
|
45 |
+
set_seed(args.seed)
|
46 |
+
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, device=args.device_int)
|
47 |
+
|
48 |
+
# Generation settings
|
49 |
+
gen_kwargs = {
|
50 |
+
"do_sample": args.do_sample,
|
51 |
+
"temperature": args.temperature,
|
52 |
+
"top_p": args.top_p,
|
53 |
+
"top_k": args.top_k
|
54 |
+
}
|
55 |
+
|
56 |
+
# Generate completions for evaluation set
|
57 |
+
n_tasks = args.num_tasks if args.num_tasks is not None else len(dataset)
|
58 |
+
print(f"ntasks is {n_tasks}")
|
59 |
+
generations = []
|
60 |
+
for task in tqdm(range(n_tasks)):
|
61 |
+
task_generations = []
|
62 |
+
prompt = generate_prompt(dataset[task]).strip()
|
63 |
+
task_generations.extend(complete_code(pipe, prompt, num_completions=args.n_samples, max_length=args.max_length, **gen_kwargs))
|
64 |
+
generations.append([gen.replace(args.eos, "") for gen in task_generations])
|
65 |
+
return generations
|
66 |
+
|
67 |
+
|
68 |
+
def main(args):
|
69 |
+
DATA_PATH = "codeparrot/apps"
|
70 |
+
argsdict = vars(args)
|
71 |
+
print(pprint.pformat(argsdict))
|
72 |
+
|
73 |
+
# setup
|
74 |
+
print("Loading evaluation dataset...")
|
75 |
+
dataset = load_dataset(DATA_PATH, split="test", difficulties=[args.difficulty])
|
76 |
+
if args.use_solutions:
|
77 |
+
print("Using data solutions as code generations")
|
78 |
+
model = None
|
79 |
+
tokenizer = None
|
80 |
+
generations = []
|
81 |
+
for index in range(args.num_tasks+1):
|
82 |
+
try:
|
83 |
+
sol = json.loads(dataset[index]["solutions"])
|
84 |
+
generations.append(sol[:args.n_solutions])
|
85 |
+
except ValueError:
|
86 |
+
print(f"No solutions for task {index} or not enough to have {args.n_solutions} solutions")
|
87 |
+
break
|
88 |
+
|
89 |
+
else:
|
90 |
+
print("Loading tokenizer and model...")
|
91 |
+
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)
|
92 |
+
model = AutoModelForCausalLM.from_pretrained(args.model_ckpt)
|
93 |
+
generations = make_generations(dataset, args, model, tokenizer)
|
94 |
+
|
95 |
+
metric = load("loubnabnl/apps_metric")
|
96 |
+
results = metric.compute(predictions=generations, level=args.difficulty, k_list=args.k_list, count_errors=args.count_errors, debug=args.debug)
|
97 |
+
print(results)
|
98 |
+
with open(args.output_file, "w") as fp:
|
99 |
+
json.dump(results, fp)
|
100 |
+
|
101 |
+
|
102 |
+
if __name__ == "__main__":
|
103 |
+
import argparse
|
104 |
+
|
105 |
+
parser = argparse.ArgumentParser(description="Testing a Language Model on APPS Python Code dataset")
|
106 |
+
#model and tokenizer arguments
|
107 |
+
parser.add_argument("--model_ckpt", default="loubnabnl/apps-1.5B-model", type=str, help="path to model checkpoint.")
|
108 |
+
parser.add_argument("--tokenizer", default="gpt2", type=str, help="tokenizer to use.")
|
109 |
+
parser.add_argument("--eos", default="<|endoftext|>", type=str, help="end of sentence token.")
|
110 |
+
# generation arguments
|
111 |
+
parser.add_argument("--do_sample", default=True, type=bool, help="do sampling in generation")
|
112 |
+
parser.add_argument("--temperature", default=0.2, type=float, help="temperature for sampling")
|
113 |
+
parser.add_argument("--top_p", default=0.95, type=float, help="top p for sampling")
|
114 |
+
parser.add_argument("--top_k", default=0, type=float, help="top k for sampling")
|
115 |
+
parser.add_argument("--max_length", default=1024, type=int, help="max length of generated code")
|
116 |
+
# evaluation arguments
|
117 |
+
parser.add_argument("--difficulty", default="all", type=str, help="difficulty level to select in the dataset from:\
|
118 |
+
'all', 'introductory', 'interview' and 'competition' ")
|
119 |
+
parser.add_argument("--num_tasks", default=6, type=int, help="number of tasks to evaluate")
|
120 |
+
parser.add_argument("--use_solutions", default=False, type=bool, help="use solutions instead of generating new code")
|
121 |
+
parser.add_argument("--n_samples", default=1, type=int, help="number of samples to generate")
|
122 |
+
parser.add_argument("--n_solutions", default=1, type=int, help="number of solutions to use")
|
123 |
+
parser.add_argument("--k_list", default=[1, 2, 3], type=list, help="list of k values to evaluate pass@k")
|
124 |
+
parser.add_argument("--count_errors", default=False, type=bool, help="count compilation and runtime errors for single generations")
|
125 |
+
# configuration
|
126 |
+
parser.add_argument("--seed", default=0, type=int, help="generation seed")
|
127 |
+
parser.add_argument("--device_int", default=-1, type=int, help="device on which code generation is run, if positive use GPU")
|
128 |
+
parser.add_argument("--debug", default=False, type=bool, help="debug mode")
|
129 |
+
# save
|
130 |
+
parser.add_argument("--output_file", default="apps_metrics.json", type=str, help="output file to save the results")
|
131 |
+
|
132 |
+
args = parser.parse_args()
|
133 |
+
main(args)
|
requirements.txt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
evaluate==0.1.0
|
2 |
+
datasets~=2.0
|
3 |
+
pyext==0.7
|
test_examples/solutions_problem_1.json
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
["s = input()\nn = len(s)\nind = -1\nf = False\nfor i in range(n):\n if s[i] == '[':\n f = True\n elif s[i] == ':':\n if f:\n ind = i\n break\nbind = -1\nf = False\nfor i in range(n-1,-1,-1):\n if s[i] == ']':\n f = True\n elif s[i] == ':':\n if f:\n bind = i\n break\n# print(ind,bind)\nif ind == -1 or bind == -1:\n print(-1)\nelif ind >= bind:\n print(-1)\nelse:\n ans = 4\n for i in range(ind+1,bind):\n if s[i] == '|':\n ans += 1\n print(ans)\n", "def main():\n s = input()\n \n if s.count('[') == 0 or s.count(']') == 0:\n print(-1)\n return\n \n t = s[s.find('['):s.rfind(']')+1]\n \n if t.count(':') < 2:\n print(-1)\n return\n \n t = t[t.find(':'):t.rfind(':')+1]\n print(4 + t.count('|'))\n\nmain()", "s = input()\nif '[' in s:\n s = s[s.find('[') + 1:]\n if ']' in s:\n s = s[:s.rfind(']')]\n if s.count(':') >= 2:\n s = s[s.find(':') + 1 : s.rfind(':')]\n print(s.count('|') + 4)\n\n else:\n print(-1)\n else:\n print(-1)\nelse:\n print(-1)", "import sys\ns = input()\nst = s.find('[')\nif st==-1: print((-1)); return\ns = s[st+1:]\n#print(s)\nst = s.find(':')\nif st==-1: print((-1)); return\ns = s[st+1:]\n#print(s)\ns = s[::-1]\nst = s.find(']')\nif st==-1: print((-1)); return\ns = s[st+1:]\n#print(s)\nst = s.find(':')\nif st==-1: print((-1)); return\ns = s[st+1:]\n#print(s)\nx = s.count('|')\nprint(x+4 if x>=0 else -1)\n", "s = input()\n\nsb,eb,sc,ec = -1, -1, -1, -1\n\nfor i in range(len(s)):\n\tif s[i] == '[' and sb == -1:\n\t\tsb = i\n\telif s[i] == ']':\n\t\teb = i\n\telif s[i] == ':' and sc == -1 and sb!=-1:\n\t\tsc = i\n\nif eb <= sb or sc>eb:\n\tprint(-1)\nelif sb ==-1 or eb==-1 or sc==-1:\n\tprint(-1)\nelse:\n\tfor i in range(sc+1, eb):\n\t\tif s[i] == ':':\n\t\t\tec = i\n\tif ec == -1:\n\t\tprint(-1)\n\telse:\n\t\tcnt = 0\n\t\tfor i in range(sc,ec):\n\t\t\tif (s[i] == '|'):\n\t\t\t\tcnt += 1\n\t\tprint(cnt+4)", "s = input()\nt_d = 0\ntry:\n left = -1\n was_b = False\n for i in range(len(s)):\n if s[i] == '[' and not was_b:\n was_b = True\n continue\n if s[i] == ':' and was_b:\n left = i\n break\n t_d += 1\n if left == -1:\n raise ArithmeticError()\n right = -1\n was_b = False\n for i in range(len(s) - 1, -1, -1):\n if s[i] == ']' and not was_b:\n was_b = True\n continue\n if s[i] == ':' and was_b:\n right = i\n break\n t_d += 1\n if right == -1 or right <= left:\n raise ArithmeticError()\n for i in range(left + 1, right):\n if s[i] != '|':\n t_d += 1\n print(len(s) - t_d)\nexcept:\n print(-1)\n \n", "s = input()\n\nmode = 0\nl = len(s)\nr = -1\nfor i in range(len(s)):\n if mode == 0:\n if s[i] == \"[\":\n mode = 1\n if mode == 1:\n if s[i] == \":\":\n l = i\n break\n\nmode = 0\nfor i in range(len(s)-1, -1, -1):\n if mode == 0:\n if s[i] == \"]\":\n mode = 1\n if mode == 1:\n if s[i] == \":\":\n r = i\n break\n \nif l >= r:\n print(-1)\nelse:\n c = 0\n for i in range(l+1, r):\n if s[i] == \"|\":\n c += 1\n print(c+4)\n", "s = input()\n\nf1 = False\nf2 = False\nl1 = -1\nfor l in range(len(s)):\n if f1 == False and s[l] == '[':\n f1 = True\n elif f1 == True and s[l] == ':':\n f2 = True\n l1 = l\n break\ng1 = False\ng2 = False\nr1 = -1\nfor r in range(len(s) - 1, -1, -1):\n if g1 == False and s[r] == ']':\n g1 = True\n elif g1 == True and s[r] == ':':\n g2 = True\n r1 = r\n break\nif (l1 == -1 or r1 == -1) or (r1 <= l1):\n print(-1)\n \nelse:\n ans = 4\n for i in range(l1 + 1, r1):\n if s[i] == '|': ans += 1\n print(ans)", "s=input()\npos1=-1\npos2=-1\npos3=-1\npos4=-1\nfor i in range(0,len(s)):\n if(s[i]=='['):\n pos1=i\n break\nfor i in range(len(s)-1,pos1,-1):\n if(s[i]==']'):\n pos2=i\n break\nfor i in range(pos1,pos2+1):\n if(s[i]==':'):\n pos3=i\n break\nfor i in range(pos2,pos3,-1):\n if(s[i]==':'):\n pos4=i\n break\n \nif(pos1==-1 or pos2==-1 or pos3==-1 or pos4==-1 or len(s)<4):\n print('-1')\nelse:\n c=0\n for j in range(pos3,pos4):\n if(s[j]=='|'):\n c=c+1\n print(c+4)\n", "def ii():\n return int(input())\ndef mi():\n return list(map(int, input().split()))\ndef li():\n return list(mi())\n\ns = input().strip()\nn = len(s)\nans = -1\nfb = s.find('[')\nif fb >= 0:\n fc = s.find(':', fb)\n if fc >= 0:\n lb = s.rfind(']')\n if lb > fc:\n lc = s.rfind(':', 0, lb)\n if lc > fc:\n ans = 4 + s[fc:lc].count('|')\nprint(ans)\n", "s = input()\n\ndef sovle(s):\n\n i1 = s.find('[')\n if i1 == -1:\n return -1\n s = s[i1+1:]\n i2 = s.find(':')\n if i2 == -1:\n return -1\n\n s = s[i2+1 :]\n i1 = s.rfind(']')\n if i1 == -1:\n return -1\n s = s[:i1]\n i2 = s.rfind(':')\n if i2 == -1:\n return -1\n s = s[:i2]\n x = s.count('|')\n return x+4\n\nprint(sovle(s))", "def solve(s):\n if s.find('[') == -1:\n return -1\n s = s[s.find('['):]\n #print(s)\n if s.find(':') == -1:\n return -1\n s = s[s.find(':') + 1:]\n #print(s)\n if s.find(']') == -1:\n return -1\n s = s[:s.rfind(']')]\n #print(s)\n if s.find(':') == -1:\n return -1\n s = s[:s.rfind(':')]\n #print(s)\n return s.count('|') + 4\n\ns = input()\nprint(solve(s))", "s=input()\ni=s.find('[')\nif i==-1:\n print(-1)\n return\ns=s[i:]\ni=s.rfind(']')\n\nif i==-1:\n print(-1)\n return\ns=s[:i+1]\nl,h=0,0\nfor i,d in enumerate(s):\n if d==':':\n l=i\n break\nfor i,d in enumerate(s):\n if d==':':\n h=i\nif l==h:\n print(-1)\n return\nc=0\nfor i in range(l+1,h):\n if s[i]=='|':\n c+=1\nprint(c+4)\n", "from sys import stdin\ns=stdin.readline().strip()\nx=-1\nfor i in range(len(s)):\n if s[i]==\"[\":\n x=i\n break\ny=-1\nfor i in range(len(s)-1,-1,-1):\n if s[i]==\"]\":\n y=i\n break\nif x==-1 or y==-1 or y<x:\n print(-1)\n return\nx1=-1\nfor i in range(x,y):\n if s[i]==\":\":\n x1=i\n break\ny1=-1\nfor i in range(y-1,x,-1):\n if s[i]==\":\":\n y1=i\n break\nif x1==-1 or y1==-1 or y1<=x1:\n print(-1)\n return\nans=4\nfor i in range(x1,y1):\n if s[i]==\"|\":\n ans+=1\nprint(ans)\n", "s = str(input().strip())\ni = 0\nn = len(s)\nwhile i < n and s[i] != '[':\n i+=1\nif(i == n):\n print(-1)\n return\nj = n-1\nwhile j > i and s[j] != ']':\n j-=1\nif(j <= i):\n print(-1)\n return\nwhile i < j and s[i] != ':':\n i+=1\nif(i == j):\n print(-1)\n return\nwhile j > i and s[j] != ':':\n j-=1\nif(j == i):\n print(-1)\n return\nk = i+1\nc = 0\nwhile k < j:\n if(s[k] == '|'):\n c+=1\n k+=1\nprint(c+4)\n", "import sys\ns = input()\nl = len(s)\ns_list = [x for x in s]\n\ncounter = 0\ntry:\n\ta = s_list.index('[')\n\tcounter += a\n\ts_list = s_list[a + 1:]\nexcept:\n\tprint(-1)\n\treturn\n\ntry:\n\ta = s_list.index(':')\n\tcounter += a\n\ts_list = s_list[a + 1:]\nexcept:\n\tprint(-1)\n\treturn\n\ns_list_rev = s_list.copy()\ns_list_rev.reverse()\n\ntry:\n\tb = s_list_rev.index(']')\n\tcounter += b\n\ts_list_rev = s_list_rev[b+1:]\nexcept:\n\tprint(-1)\n\treturn\n\ntry:\n\tb = s_list_rev.index(':')\n\tcounter += b\n\ts_list_rev = s_list_rev[b+1:]\nexcept:\n\tprint(-1)\n\treturn\ns_list_rev = [x for x in s_list_rev if x != '|']\ncounter += len(s_list_rev)\nprint(l - counter)", "MOD = 10**9 + 7\nI = lambda:list(map(int,input().split()))\n\ns = input()\nres = 0\nn = len(s)\nst = -1\ne = -1\nfor i in range(n):\n if s[i] == '[':\n st = i\n break\nfor i in range(n-1, -1, -1):\n if s[i] == ']':\n e = i\n break\n# print(st , e)\nif st > e or st == -1 or e == -1:\n print(-1)\n return\na = -1\nb = -1\nfor i in range(st, e):\n if s[i] == ':':\n a = i\n break\nfor i in range(e, st, -1):\n if s[i] == ':':\n b = i\n break\nif a == b or a == -1 or b == -1:\n print(-1)\n return\ncount = 0\nfor i in range(a, b):\n if s[i] == '|':\n count += 1\nprint(4 + count)", "s=input()\nst=\"\"\nidx=-1\nfor i in range(len(s)):\n if s[i]=='[':\n idx=i\n break\nif idx==-1:\n print(-1)\n return\nidxl=-1\nfor i in range(len(s)-1,-1,-1):\n if s[i]==']' and i>idx:\n idxl=i\n break\nif idxl==-1:\n print(-1)\n return\ncol=col2=-1\nfor i in range(len(s)):\n if s[i]==':' and i>idx and i<idxl:\n col=i\n break\nif col==-1:\n print(-1)\n return\nfor i in range(len(s)-1,-1,-1):\n if s[i]==':' and i>col and i<idxl:\n col2=i\n break\nif col2==-1:\n print(-1)\n return\nans=0\nfor i in range(col+1,col2):\n if s[i]=='|':\n ans+=1\nprint(4+ans)\n \n\n\n", "s = input()\nrev = s[::-1]\n\nleft = s.find(\"[\")\nif left != -1:\n left = s.find(\":\", left)\n\nright = rev.find(\"]\")\nif right != -1:\n right = rev.find(\":\", right)\n\nif left == -1 or right == -1:\n print(-1)\n return\nright = len(s)-right-1\nif left >= right:\n print(-1)\n return\n\nprint(4 + s[left:right].count(\"|\"))\n", "def ba(s):\n c1 = s.find('[')\n c2 = s.find(':', c1+1)\n c3 = s.rfind(']', c2+1)\n c4 = s.rfind(':', c2+1, c3)\n if -1 in [c1, c2, c3, c4]:\n return -1\n return s.count('|', c2, c4)+4\n\n\nprint(ba(input()))\n\n", "s = input()\nif '[' in s and ']' in s:\n a = s.index('[') + 1\n b = len(s)-s[::-1].index(']') - 1\nelse:\n print(-1)\n return\ns = s[a:b]\nif s.count(':') >= 2:\n a = s.index(':')+1\n b = len(s)-s[::-1].index(':')-1\nelse:\n print(-1)\n return\nc = 0\nfor el in s[a:b]:\n if el =='|':\n c += 1\nprint(4 + c)", "s = input()\n\nb = [0]*len(s)\n\nob = 0\ncc = 0\np = -1\nq = -1\n\ncount = 0\n\nfor ind,c in enumerate(s):\n if c == '[':\n ob = 1\n elif c == ':' and p >= 0:\n q = ind\n elif c == ':' and ob == 1 and p < 0:\n p = ind\n elif c == ']' and q >= 0:\n cc = q\n elif c == '|':\n count += 1\n b[ind] = count\n\nif cc > 0:\n print( 4 + b[cc]-b[p])\nelse:\n print(-1)\n", "s = input()\nif '[' in s and ']' in s and ':' in s:\n e = s.count(':')\n if e<2:\n print(-1)\n else:\n a = s.index('[')\n b = len(s)-1-s[::-1].index(']')\n if b<a:\n print(-1)\n else:\n if s[a+1:b].count(':')<2:\n print(-1)\n else:\n st1 = True\n count = 0\n for i in range(a+1, b):\n if st1 and s[i]==':':\n pos1 = i\n st1 = False\n if s[i]==':':\n pos2 = i\n \n for i in range(pos1+1, pos2):\n if s[i]=='|':\n count+=1\n \n print(count+4)\nelse:\n print(-1) ", "s=input()\ni1=-1\ni2=-1\nk1=-1\nk2=-1\nc=0\nfor i in range(len(s)):\n if(s[i]=='['):\n i1=i\n break\nfor i in range(len(s)-1,-1,-1):\n if(s[i]==']'):\n i2=i\n break\nfor i in range(i1,i2+1):\n if(s[i]==':'):\n k1=i\n break\nfor i in range(i2,i1-1,-1):\n if(s[i]==':'):\n k2=i\n break\nfor i in range(k1,k2+1):\n if(s[i]=='|'):\n c+=1\n\nif(i1==-1 or i2==-1 or i1>=i2 or k1==-1 or k2==-1 or k1==k2):\n print(-1)\nelse:\n print(4+c)", "s = input()\nl = 0\nend = 0\ni = 1\n\nwhile i <= len(s):\n if l == 0 and s[-i] == ']':\n l += 1\n elif l == 1 and s[-i] == ':':\n l += 1\n end = len(s) - i\n break\n i += 1\n\nif l < 2:\n print(-1)\n return\n\nfor i in range(0, end):\n if l >= 4 and s[i] == '|':\n l += 1\n elif l == 2 and s[i] == '[':\n l += 1\n elif l == 3 and s[i] == ':':\n l += 1\n\nif l >= 4:\n print(l)\nelse:\n print(-1)"]
|
test_examples/solutions_problem_2.json
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
["num = list(map(int, input()))\nbest = num[:]\nfor i in range(-1, -len(num) - 1, -1):\n if num[i] == 0:\n continue\n num[i] -= 1\n for j in range(i + 1, 0):\n num[j] = 9\n if sum(num) > sum(best):\n best = num[:]\ns = ''.join(map(str, best)).lstrip('0')\nprint(s)\n", "s_num = input()\nnum = int(s_num)\ndigs = [int(s_num[i]) for i in range(len(s_num))]\n\nmax_sum = sum(digs)\nres = num\nfor i in range(len(s_num)):\n if (digs[i] != 0):\n digs[i] -= 1\n n_sum = sum(digs[:i + 1]) + 9 * (len(s_num) - i - 1)\n if n_sum >= max_sum:\n n_res = int(''.join([str(digs[i]) for i in range(i + 1)]) + '9' * (len(s_num) - i - 1))\n if (n_sum == max_sum):\n res = max(n_res, res)\n else:\n res = n_res\n max_sum = n_sum\n\n digs[i] += 1\nprint(res)\n", "a=int(input())\nif(a//10==0):\n print(a)\n return\nk=9\nwhile(k<a):\n k=k*10+9\nif(k==a):\n print(k)\nelse:\n k//=10\n k=int(str(a)[0]+str(k))\n i=len(str(k))-1\n z=k\n while(z>a):\n z=int(str(k)[0:i]+str(int(str(k)[i])-1)+str(k)[i+1:len(str(k))])\n i-=1\n print(z) ", "x = int(input())\nif x < 10:\n print(x)\nelif x == int(str(x)[0] + '9'*(len(str(x))-1)):\n print(x)\nelse:\n a = str(x)[0] + '9' * (len(str(x)) - 1)\n a = list(a)\n for i in range(len(a) - 1, -1, -1):\n k = a[i]\n a[i] = str(int(a[i]) - 1)\n if x >= int(''.join(a)):\n print(int(''.join(a)))\n break\n a[i] = k\n", "def sum_str(y):\n return sum(map(int, str(y)))\n\n\nx = input()\nlength = len(x)\nbad_answer = str(int(x[0]) - 1) + '9' * (length - 1) \ntotal = sum_str(bad_answer)\n\n\nif length == 1 or sum_str(x) >= total:\n print(x)\nelse:\n for i in range(length - 1, 0, -1):\n new_total = 9 * (length - i)\n new_answer = str(int(x[:i]) - 1)\n new_total += sum_str(new_answer)\n\n if new_total >= total:\n new_answer = new_answer if new_answer != '0' else ''\n print(new_answer + '9' * (length - i))\n break\n else:\n print(bad_answer)\n", "import sys\n\ndef calc(s):\n res =0\n for c in s:\n res+= int(c)\n return res\n\n\ns = list(sys.stdin.readline().rstrip())\nbest = \"\".join(s) \ncount = calc(s)\n\ni = len(s)-1\nwhile i!=0:\n i-=1\n if s[i+1]!= '9':\n s[i+1] = '9'\n while s[i]=='0':\n s[i]='9'\n i-=1\n s[i] = chr(ord(s[i])-1)\n c = calc(s)\n if count < c:\n count = c\n best = \"\".join(s)\n\nif best[0] == '0':\n best = best[1:]\n\nprint(best)", "x = input()\nn = len(x)\nif n == 1:\n print(x)\n return\nans = \"\"\ns = 0\nps = 0\npn = \"\"\nfor i in range(n):\n ts = ps + int(x[i]) - 1 + 9 * (n - i - 1)\n if ts >= s:\n ans = pn + str(int(x[i]) - 1) + \"9\" * (n - i - 1)\n s = ts\n ps += int(x[i])\n pn += x[i]\nif ps >= s:\n ans = pn\nprint(int(ans))", "n = int(input())\n\ndef f(numb):\n lst = [numb]\n cap = 10\n\n while numb // cap > 0:\n lst.append((numb // cap - 1) * cap + cap - 1)\n cap *= 10\n\n return lst\n\ndef g(numb):\n lst = []\n while numb != 0:\n lst.append(numb % 10)\n numb //= 10\n\n return lst\n\n\nmaximum = max([sum(g(i)) for i in f(n)])\n\nmaximum = [i for i in f(n) if maximum == sum(g(i))]\n\nprint(max(maximum))", "\"\"\" Created by Shahen Kosyan on 3/11/17 \"\"\"\n\ndef __starting_point():\n x = input()\n\n if int(x) < 10:\n print(x)\n return\n\n arr = [int(a) for a in list(x)]\n x_sum = sum(arr)\n\n i = len(arr) - 1\n answer = ''\n while i > 0:\n if arr[i] != 9 and arr[i] != 8:\n arr[i - 1] -= 1\n answer = '9' + answer\n else:\n change = False\n for j in range(i - 1, 0, -1):\n if arr[j] < 9:\n change = True\n break\n\n if arr[i] == 8 and change:\n answer = '9' + answer\n arr[i - 1] -= 1\n else:\n if not change:\n answer = str(arr[i]) + answer\n else:\n answer = '9' + answer\n\n if i == 1 and arr[0] != 0:\n answer = str(arr[0]) + answer\n i -= 1\n\n answer = [int(a) for a in list(answer)]\n if x_sum == sum(answer):\n print(x)\n else:\n answer = [str(a) for a in answer]\n print(''.join(answer))\n\n__starting_point()", "x=input()\nl=len(x)\nx=int(x)\ns='9'*l\nsx=str(x)\nm=int(s)\nc=0\nwhile c!=1:\n if m>x:\n m=m-10**(l-1)\n else:\n c=1\nsm=str(m)\nmm=[] \nfor i in range(len(sm)):\n mm.append(int(sm[i]))\nxx=[] \nfor i in range(l):\n xx.append(int(sx[i]))\nif m==x:\n print(m)\nelif sum(xx)==sum(mm):\n print(x)\nelse:\n k=len(xx)-1\n while k>=0:\n if sum(xx)<sum(mm):\n if xx[k]==9:\n k-=1\n else:\n xx[k]=9\n xx[k-1]-=1\n k-=1\n else:\n if xx[0]==0:\n xx.remove(0)\n for b in range(len(xx)):\n xx[b]=str(xx[b])\n ww=''.join(xx)\n print(ww)\n break", "x = input()\nvariants = [x] + [str(int(x[:i]) - 1) +\n '9' * (len(x) - i) for i in range(1, len(x))]\nprint(int(max(variants, key=lambda x: (sum(map(int, x)), int(x)))))\n", "def sum_div(n):\n summa = 0\n while n > 0:\n summa = summa + n % 10\n n = n // 10\n return summa\n\n\ndef run(n):\n l_n = len(n)\n left = ''\n if l_n > 2 and '9' * l_n != n and n[1] == '9' and '9' * (l_n - 1) != n[1:]:\n left = n[0]\n n = n[1:]\n while l_n > 1 and n[1] == '9':\n left += n[1]\n n = n[1:]\n l_n = len(n)\n l_n = len(n)\n if len(n) == 1:\n return n\n elif '9' * (l_n - 1) == n[1:]:\n return left + n\n elif n[0] != '1':\n min_number = int(str(int(n[0]) - 1) + '9' * (l_n - 1))\n if sum_div(min_number) > sum_div(int(n)):\n return left + str(min_number)\n else:\n return left + n\n else:\n min_number = int('9' * (l_n - 1)) if l_n > 1 else 0\n if sum_div(min_number) > sum_div(int(n)):\n return left + str(min_number)\n else:\n return left + n\n\n\nn = input()\nprint(run(n))\n", "#This code is dedicated to Olya S.\n\ndef e(x):\n s=0\n while x>0:\n s+=x%10\n x//=10\n return s\n\ndef down(x):\n l=len(x)-1\n return str(int(x[0])-1)+'9'*l\n\nn=input()\nif len(n)>1 and n[1]=='9':\n print(n[0],end='')\n n=n[1:]\n while len(n)>1 and n[0]=='9' and n[1]=='9':\n print('9',end='')\n n=n[1:]\n\nif e(int(n))>=e(int(down(n))):\n print(n)\nelse:\n print(int(down(n)))\n\n \n \n\n\n\n \n\n", "def sum_n(n):\n l = len(n)\n\n summ = 0\n for i in range(l):\n summ += int(n[i])\n\n return summ\n\ndef transfer(x, i):\n x = list(x)\n \n x[i+1] = '9'\n if x[i] != '0':\n x[i] = str(int(x[i])-1)\n else:\n j = i\n while (j > 0) and (int(x[j]) == 0):\n x[j] = '9'\n j -= 1\n x[j] = str(int(x[j])-1)\n if (x[0] == '0'):\n del x[0]\n\n return x\n\nx = list(input())\nmax_cifr = sum_n(x)\nmaxnum = x\nres = ''\n\nfor i in range(len(x)-2, -1, -1):\n x = transfer(x, i)\n if(max_cifr < sum_n(x)):\n max_cifr = sum_n(x)\n maxnum = x\n\nfor i in range(len(maxnum)):\n res = res+maxnum[i]\n \nprint(res)\n", "x = input()\nsum = 0\nfor i in x:\n temp = int(i)\n sum += temp\n\nxlen = len(x)\none = int(x[0])\ntry:\n two = int(x[1])\nexcept:\n two = 0\n\nif (two == 9):\n count = 1\n for i in range(1, xlen):\n z = int(x[i])\n if (z == 9):\n count = i\n else:\n break\n answ = x[0:count] + \"8\" + (\"9\" * (xlen - count - 1))\nelif (one == 1):\n answ = '9' * (xlen - 1)\nelse:\n answ = str((one - 1)) + (\"9\" * (xlen-1))\n\nansw = str(answ)\nsumansw = 0\nfor i in answ:\n temp = int(i)\n sumansw += temp\n\nif (sum >= sumansw):\n print(x)\nelse:\n print(answ)", "def sum1(x): # \u043f\u043e\u0434\u0441\u0447\u0451\u0442 \u0441\u0443\u043c\u043c\u044b \u0446\u0438\u0444\u0440 \u0447\u0438\u0441\u043b\u0430 x\n summa = 0\n for i in x:\n summa += int(i)\n return summa\n\n\nx = input()\nc = sum1(x)\nresult = int(x)\nn = len(x) - 1\nj = n\nfor i in range(0, n):\n if x[i] != '0':\n ni = int(x[i]) - 1 # \u0443\u043c\u0435\u043d\u044c\u0448\u0430\u044e i-\u044b\u0439 \u0440\u0430\u0437\u0440\u044f\u0434 \u043d\u0430 1\n xi = x[0:i] + str(ni) + '9' * j # \u0441\u0442\u0440\u043e\u044e \u043d\u043e\u0432\u043e\u0435 \u0447\u0438\u0441\u043b\u043e\n j -= 1\n ci = sum1(xi)\n if c < ci:\n c = ci\n result = int(xi)\n elif c == ci and result < int(xi):\n result = int(xi)\n else:\n j -= 1\n continue\nprint(result)\n", "def f(n, k):\n n = str(n)\n if n[k] == \"0\":\n return f(n, k - 1)\n a = []\n for i in n:\n a.append(int(i))\n n = a\n n[k] = int(n[k]) - 1\n n[k + 1::] = [9] * (len(n) - k - 1)\n return n\na = input()\nn = len(a)\nans = [int(x) for x in a]\nms = sum(ans)\nfor i in range(0, n):\n ca = f(a, i)\n cs = sum(ca)\n if cs> ms:\n ans = ca\n ms = cs\n elif cs == ms:\n if int(''.join([str(_) for _ in ca])) > int(''.join([str(_) for _ in ans])):\n ans = ca\nprint(int(''.join([str(_) for _ in ans])))", "n = int(input().strip())\n\ns = []\nwhile n > 0:\n s.append(n % 10)\n n //= 10\ns = s[::-1]\n\nn = len(s)\nans = 0\nbest = -1\nfor i in range(n):\n res = sum(s[:i + 1]) - 1 + 9 * (n - i - 1)\n if res >= ans:\n ans = res\n best = i\n\ndef get(s, pos):\n ans = 0\n for i in range(len(s)):\n if i > pos:\n ans = ans * 10 + 9\n else:\n ans = ans * 10 + s[i]\n if i == pos:\n ans -= 1\n return ans\n\nif sum(s) >= ans:\n print(get(s, n))\nelse:\n print(get(s, best))\n\n", "def main():\n\n\tdef sum(x):\n\t\tres = 0\n\n\t\twhile x > 0:\n\t\t\tres += x % 10\n\t\t\tx //= 10\n\n\t\treturn res\n\n\tn = input()\n\tfirst = n[0]\n\tp = [1]\n\n\tfor i in range(1, 20):\n\t\tp.append(p[-1] * 10)\n\n\tdata = []\t\n\tfor i in range(len(n)):\n\t\tif i > 0 and n[i] == '0':\n\t\t\tcontinue\n\t\ttemp = n[:i] + str(max(0, int(n[i]) - 1)) + \"9\"* (len(n) - i - 1)\n\t\tdata.append((sum(int(temp)), int(temp)))\n\n\tdata.append((sum(int(n)), int(n)))\n\t\n\tdata.sort(reverse=True)\n\n\tprint(data[0][1])\n\n\treturn\n\ndef __starting_point():\n\tmain()\n__starting_point()", "def cnt_sum(str_num):\n\tsum = 0\n\tfor a in str_num:\n\t\tsum += ord(a) - ord('0')\n\treturn sum\n\nstr_a = input().strip()\nmax_sum = cnt_sum(str_a)\nans = str_a\ncnt_digit = len(str_a)\n\nfor i in range(cnt_digit - 1, -1, -1):\n\tif str_a[i] != '0':\n\t\tnew_str = str_a[:i] + chr(ord(str_a[i]) - 1) + '9'*(cnt_digit - i - 1)\n\t\tcur_sum = cnt_sum(new_str)\n\t\tif cur_sum > max_sum:\n\t\t\tmax_sum = cur_sum\n\t\t\tans = new_str\n\nprint(int(ans))\n", "def summaX(x):\n k=0\n for el in x:\n k+=int(el)\n return k\nn=input();N=[];Z=[]\nfor el in n:\n N.append(el)\nz=summaX(N)\nZ=N.copy()\nfor i in range(1,len(N)):\n if int(N[i])!=9:\n N[i-1]=int(N[i-1])-1\n for j in range(i,len(n)):\n N[j]=9\nif z>=summaX(N):\n for el in Z:\n print(el,end='')\nelse:\n if N[0]==0:\n N.pop(0)\n for el in N:\n print(el,end='')\n", "n = int(input())\n\ndef sumd(n):\n\tj = n\n\tsumn = 0\n\twhile j:\n\t\tsumn += j % 10\n\t\tj //= 10\n\treturn sumn\n\nj = n\nstrn = str(n)\nl = len(strn)\nsumn = sumd(n)\n\nstra = [i for i in str(n)]\ni = 1\nwhile i < l and stra[i] == '9':\n\ti += 1\nif (i != l):\n\tstra[i - 1] = str(int(stra[i - 1]) - 1)\n\twhile i < l:\n\t\tstra[i] = '9'\n\t\ti += 1\n\nss = ''\nfor i in range(l):\n\tss += stra[i]\nif ss[0] == '0':\n\tss = ss[1:]\nsn = int(ss)\n\nif sn < n and sumd(sn) <= sumn:\n\tss = strn\n\tsn = n\n\nprint(ss)\n", "from random import randint\n\ndef f(s):\n a = 0\n for i in s:\n a += int(i)\n return a\n\ndef solve(n):\n n1 = list(str(n))\n ans = 0\n maxx = 0\n for i in range(len(n1)):\n n2 = n1[:i] + [str(int(n1[i]) - 1)] + ['9' for j in range(len(n1) - i - 1)]\n if f(n2) >= maxx:\n maxx = f(n2)\n ans = n2\n if f(n1) >= maxx:\n maxx = f(n1)\n ans = n1\n return [int(''.join(ans)), maxx]\n\ndef tl(n):\n ans = 0\n maxx = 0\n for i in range(1, n + 1):\n if f(list(str(i))) >= maxx:\n maxx = f(list(str(i)))\n ans = i\n return [ans, maxx]\n\n'''for kkk in range(100):\n n = randint(1, 10 ** 5)\n c1 = solve(n)\n c2 = tl(n)\n if c1 != c2:\n print(n)\n print(c1)\n print(c2)\nprint('ok')'''\nn = int(input())\nprint(solve(n)[0])\n", "a = [1, 2, 3, 4, 5, 6, 7, 8, 9]\nfor length in range(2, 30):\n for first in range(1, 10):\n for pos in range(1, length):\n a.append(int(str(first) + '9' * (pos - 1) + '8' + '9' * (length - pos - 1)))\n a.append(int(str(first) + '9' * (length - 1)))\n \nn = int(input())\nl = 0\nr = len(a)\nwhile l < r - 1:\n middle = (l + r) // 2\n if (a[middle] <= n):\n l = middle\n else:\n r = middle\n \nprint(a[l])", "def get(s):\n ans = 0\n for i in s:\n ans += (ord(i) - ord('0'))\n return ans\n\n\ndef solve1():\n x = input()\n n = len(x)\n best_ans = x\n best_val = get(x)\n ans = str('' if int(x[0]) - 1 == 0 else int(x[0]) - 1) + '9' * (n - 1)\n if get(ans) > best_val or (get(ans) >= best_val and int(ans) > int(best_ans)):\n best_ans = ans\n best_val = get(ans)\n for i in range(1, n):\n #print(ans)\n ans = x[:i] + str(int(x[i]) - 1) + '9' * (n - i - 1)\n if get(ans) > best_val or (get(ans) >= best_val and int(ans) > int(best_ans)):\n best_ans = ans\n best_val = get(ans)\n return best_ans\n \nbest = [0] * 10000\ndef solve2():\n nonlocal best\n was = 0\n for i in range(1, 10000):\n if get(str(i)) >= was:\n best[i] = i\n was = get(str(i))\n else:\n best[i] = best[i - 1]\n \ndef stress():\n solve2()\n for i in range(1, 10000):\n if int(solve1(str(i))) != best[i]:\n print(i, best[i], solve1(str(i)))\n\n#stress()\nprint(solve1())"]
|
testing_util.py
ADDED
@@ -0,0 +1,525 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import sys
|
3 |
+
import faulthandler
|
4 |
+
import platform
|
5 |
+
|
6 |
+
# used for debugging to time steps
|
7 |
+
from datetime import datetime
|
8 |
+
|
9 |
+
# to run the solution files we're using a timing based approach
|
10 |
+
import signal
|
11 |
+
|
12 |
+
import numpy as np
|
13 |
+
# for capturing the stdout
|
14 |
+
from io import StringIO
|
15 |
+
# used for testing the code that reads from input
|
16 |
+
from unittest.mock import patch, mock_open
|
17 |
+
|
18 |
+
from pyext import RuntimeModule
|
19 |
+
|
20 |
+
from enum import Enum
|
21 |
+
class CODE_TYPE(Enum):
|
22 |
+
call_based = 0
|
23 |
+
standard_input = 1
|
24 |
+
|
25 |
+
# stuff for setting up signal timer
|
26 |
+
class TimeoutException(Exception):
|
27 |
+
pass
|
28 |
+
def timeout_handler(signum, frame):
|
29 |
+
print("alarm went off")
|
30 |
+
#return
|
31 |
+
raise TimeoutException
|
32 |
+
signal.signal(signal.SIGALRM, timeout_handler)
|
33 |
+
timeout = 4 # seconds
|
34 |
+
|
35 |
+
# used to capture stdout as a list
|
36 |
+
# from https://stackoverflow.com/a/16571630/6416660
|
37 |
+
# alternative use redirect_stdout() from contextlib
|
38 |
+
class Capturing(list):
|
39 |
+
def __enter__(self):
|
40 |
+
self._stdout = sys.stdout
|
41 |
+
sys.stdout = self._stringio = StringIO()
|
42 |
+
# Make closing the StringIO a no-op
|
43 |
+
self._stringio.close = lambda x: 1
|
44 |
+
return self
|
45 |
+
def __exit__(self, *args):
|
46 |
+
self.extend(self._stringio.getvalue().splitlines())
|
47 |
+
del self._stringio # free up some memory
|
48 |
+
sys.stdout = self._stdout
|
49 |
+
|
50 |
+
|
51 |
+
def run_test(sample, test=None, debug=False):
|
52 |
+
"""
|
53 |
+
if test(generated_code) is not None it'll try to run the code.
|
54 |
+
otherwise it'll just return an input and output pair.
|
55 |
+
"""
|
56 |
+
# Disable functionalities that can make destructive changes to the test.
|
57 |
+
reliability_guard()
|
58 |
+
|
59 |
+
if debug:
|
60 |
+
print(f"start = {datetime.now().time()}")
|
61 |
+
|
62 |
+
try:
|
63 |
+
in_outs = json.loads(sample["input_output"])
|
64 |
+
except ValueError:
|
65 |
+
in_outs = None
|
66 |
+
if in_outs:
|
67 |
+
if in_outs.get("fn_name") is None:
|
68 |
+
which_type = CODE_TYPE.standard_input # Standard input
|
69 |
+
method_name = None
|
70 |
+
else:
|
71 |
+
which_type = CODE_TYPE.call_based # Call-based
|
72 |
+
method_name = in_outs["fn_name"]
|
73 |
+
|
74 |
+
if debug:
|
75 |
+
print(f"loaded input_output = {datetime.now().time()}")
|
76 |
+
|
77 |
+
if test is None:
|
78 |
+
return in_outs
|
79 |
+
elif test is not None:
|
80 |
+
results = []
|
81 |
+
sol = "import sys\nimport time\nimport itertools\nfrom itertools import accumulate, product, permutations, combinations\nimport collections\nfrom collections import Counter, OrderedDict, deque, defaultdict, ChainMap\nfrom functools import lru_cache\nimport math\nfrom math import sqrt, sin, cos, tan, ceil, fabs, floor, gcd, exp, log, log2\nimport fractions\nfrom typing import List, Tuple\nimport numpy as np\nimport random\nimport heapq\nfrom heapq import *\n"
|
82 |
+
if debug:
|
83 |
+
print(f"loading test code = {datetime.now().time()}")
|
84 |
+
|
85 |
+
if which_type == CODE_TYPE.call_based:
|
86 |
+
sol += test
|
87 |
+
if debug:
|
88 |
+
print(f"sol = {sol}")
|
89 |
+
signal.alarm(timeout)
|
90 |
+
try:
|
91 |
+
tmp_sol = RuntimeModule.from_string("tmp_sol", "", sol)
|
92 |
+
if "class Solution" not in test:
|
93 |
+
tmp = tmp_sol
|
94 |
+
else:
|
95 |
+
tmp = tmp_sol.Solution()
|
96 |
+
signal.alarm(0)
|
97 |
+
except Exception as e:
|
98 |
+
signal.alarm(0)
|
99 |
+
if debug:
|
100 |
+
print(f"type 0 compilation error = {e}")
|
101 |
+
results.append(-2)
|
102 |
+
return results
|
103 |
+
signal.alarm(0)
|
104 |
+
|
105 |
+
elif which_type == CODE_TYPE.standard_input:
|
106 |
+
# sol
|
107 |
+
tmp_test = test.split("\n")
|
108 |
+
|
109 |
+
new_test = []
|
110 |
+
for x in tmp_test:
|
111 |
+
if (not x.startswith("from ")) and (not x.startswith("import ")):
|
112 |
+
new_test.append("\t" + x + "\n")
|
113 |
+
else:
|
114 |
+
new_test.append(x + "\n")
|
115 |
+
tmp_test = new_test
|
116 |
+
|
117 |
+
new_test = ""
|
118 |
+
started = False
|
119 |
+
for i in tmp_test:
|
120 |
+
if i.startswith("\t") and not started:
|
121 |
+
new_test += "stdin = sys.stdin\nstdout = sys.stdout\n"
|
122 |
+
new_test += "def code():\n"
|
123 |
+
new_test += i
|
124 |
+
started = True
|
125 |
+
elif started and ((i.startswith("from ")) or (i.startswith("import "))):
|
126 |
+
new_test += "\t" + i
|
127 |
+
else:
|
128 |
+
new_test += i
|
129 |
+
tmp_test = new_test
|
130 |
+
|
131 |
+
sol += tmp_test
|
132 |
+
if debug:
|
133 |
+
print(f"sol = {sol}")
|
134 |
+
method_name = "code"
|
135 |
+
signal.alarm(timeout)
|
136 |
+
try:
|
137 |
+
tmp_sol = RuntimeModule.from_string("tmp_sol", "", sol)
|
138 |
+
tmp = tmp_sol
|
139 |
+
signal.alarm(0)
|
140 |
+
except Exception as e:
|
141 |
+
signal.alarm(0)
|
142 |
+
if debug:
|
143 |
+
print(f"type 1 compilation error = {e}")
|
144 |
+
results.append(-2)
|
145 |
+
return results
|
146 |
+
signal.alarm(0)
|
147 |
+
if debug:
|
148 |
+
print(f"get method = {datetime.now().time()}")
|
149 |
+
|
150 |
+
try:
|
151 |
+
method = getattr(tmp, method_name) # get_attr second arg must be str
|
152 |
+
except:
|
153 |
+
signal.alarm(0)
|
154 |
+
e = sys.exc_info()
|
155 |
+
print(f"unable to get function error = {e}")
|
156 |
+
results.append(-2)
|
157 |
+
return results
|
158 |
+
|
159 |
+
for index, inputs in enumerate(in_outs["inputs"]):
|
160 |
+
# JSON forces dictionaries to have string keys; this undoes this (assuming a singleton list)
|
161 |
+
try:
|
162 |
+
if isinstance(inputs[0], dict):
|
163 |
+
inputs = [{int(k): v for k,v in inputs[0].items()}]
|
164 |
+
except:
|
165 |
+
True
|
166 |
+
try:
|
167 |
+
if isinstance(in_outs["outputs"][index], dict):
|
168 |
+
in_outs["outputs"][index] = [{int(k): v for k,v in in_outs["outputs"][index].items()}]
|
169 |
+
except:
|
170 |
+
True
|
171 |
+
try:
|
172 |
+
if isinstance(in_outs["outputs"][index][0], dict):
|
173 |
+
in_outs["outputs"][index] = [{int(k): v for k,v in in_outs["outputs"][index][0].items()}]
|
174 |
+
except:
|
175 |
+
True
|
176 |
+
|
177 |
+
if debug:
|
178 |
+
print(f"time: {datetime.now().time()} testing index = {index} inputs = {inputs}, {type(inputs)}. type = {which_type}")
|
179 |
+
if which_type == CODE_TYPE.call_based: # Call-based
|
180 |
+
signal.alarm(timeout)
|
181 |
+
faulthandler.enable()
|
182 |
+
try:
|
183 |
+
output = method(*inputs)
|
184 |
+
|
185 |
+
# ground truth sequences are not tuples
|
186 |
+
if isinstance(output, tuple):
|
187 |
+
output = list(output)
|
188 |
+
|
189 |
+
tmp_result = output == in_outs["outputs"][index]
|
190 |
+
if isinstance(in_outs["outputs"][index], list) and in_outs["outputs"][index]:
|
191 |
+
tmp_result = tmp_result or (output == in_outs["outputs"][index][0])
|
192 |
+
|
193 |
+
# ground truth sequences are not tuples
|
194 |
+
try:
|
195 |
+
if isinstance(output[0], tuple):
|
196 |
+
tmp_result = tmp_result or ([list(x) for x in output] == in_outs["outputs"][index][0])
|
197 |
+
except:
|
198 |
+
True
|
199 |
+
results.append(tmp_result)
|
200 |
+
|
201 |
+
# reset the alarm
|
202 |
+
signal.alarm(0)
|
203 |
+
except Exception as e:
|
204 |
+
signal.alarm(0)
|
205 |
+
faulthandler.disable()
|
206 |
+
if debug:
|
207 |
+
print(f"Standard input runtime error or time limit exceeded error = {e}")
|
208 |
+
results.append(-1)
|
209 |
+
continue
|
210 |
+
faulthandler.disable()
|
211 |
+
signal.alarm(0)
|
212 |
+
if debug:
|
213 |
+
print(f"outputs = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs}, {type(inputs)}, {output == [in_outs['outputs'][index]]}")
|
214 |
+
elif which_type == CODE_TYPE.standard_input: # Standard input
|
215 |
+
faulthandler.enable()
|
216 |
+
signal.alarm(timeout)
|
217 |
+
passed = False
|
218 |
+
|
219 |
+
if isinstance(inputs, list):
|
220 |
+
inputs = "\n".join(inputs)
|
221 |
+
if isinstance(in_outs['outputs'][index], list):
|
222 |
+
in_outs['outputs'][index] = "\n".join(in_outs['outputs'][index])
|
223 |
+
|
224 |
+
with Capturing() as output:
|
225 |
+
try:
|
226 |
+
call_method(method, inputs)
|
227 |
+
# reset the alarm
|
228 |
+
signal.alarm(0)
|
229 |
+
passed = True
|
230 |
+
except Exception as e:
|
231 |
+
# runtime error or took too long
|
232 |
+
signal.alarm(0)
|
233 |
+
print(f"Call-based runtime error or time limit exceeded error = {repr(e)}{e}")
|
234 |
+
results.append(-1)
|
235 |
+
signal.alarm(0)
|
236 |
+
|
237 |
+
if not passed:
|
238 |
+
if debug:
|
239 |
+
nl = "\n"
|
240 |
+
if not isinstance(inputs, list):
|
241 |
+
print(f"not passed output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs.replace(nl,' new-line ')}, {type(inputs)}, {output == [in_outs['outputs'][index]]}")
|
242 |
+
else:
|
243 |
+
print(f"not passed output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs}, {type(inputs)}, {output == [in_outs['outputs'][index]]}")
|
244 |
+
continue
|
245 |
+
|
246 |
+
if passed and debug:
|
247 |
+
print(f"==> output = {output}, test outputs = {in_outs['outputs'][index]}")
|
248 |
+
|
249 |
+
if custom_compare_(output, in_outs['outputs'][index]):
|
250 |
+
tmp_result = True
|
251 |
+
results.append(tmp_result)
|
252 |
+
continue
|
253 |
+
|
254 |
+
# ground truth sequences are expressed as lists not tuples
|
255 |
+
if isinstance(output, tuple):
|
256 |
+
output = list(output)
|
257 |
+
|
258 |
+
tmp_result = False
|
259 |
+
try:
|
260 |
+
tmp_result = (output == [in_outs["outputs"][index]])
|
261 |
+
if isinstance(in_outs["outputs"][index], list):
|
262 |
+
tmp_result = tmp_result or (output == in_outs["outputs"][index])
|
263 |
+
if isinstance(output[0], str):
|
264 |
+
tmp_result = tmp_result or ([e.strip() for e in output] == in_outs["outputs"][index])
|
265 |
+
except Exception as e:
|
266 |
+
if debug:
|
267 |
+
print(f"Failed check1 exception = {e}")
|
268 |
+
pass
|
269 |
+
|
270 |
+
if tmp_result == True:
|
271 |
+
results.append(tmp_result)
|
272 |
+
continue
|
273 |
+
|
274 |
+
# try one more time without \n
|
275 |
+
if isinstance(in_outs["outputs"][index], list):
|
276 |
+
for tmp_index, i in enumerate(in_outs["outputs"][index]):
|
277 |
+
in_outs["outputs"][index][tmp_index] = i.split("\n")
|
278 |
+
in_outs["outputs"][index][tmp_index] = [x.strip() for x in in_outs["outputs"][index][tmp_index] if x]
|
279 |
+
else:
|
280 |
+
in_outs["outputs"][index] = in_outs["outputs"][index].split("\n")
|
281 |
+
in_outs["outputs"][index] = list(filter(len, in_outs["outputs"][index]))
|
282 |
+
in_outs["outputs"][index] = list(map(lambda x:x.strip(), in_outs["outputs"][index]))
|
283 |
+
|
284 |
+
try:
|
285 |
+
tmp_result = (output == [in_outs["outputs"][index]])
|
286 |
+
if isinstance(in_outs["outputs"][index], list):
|
287 |
+
tmp_result = tmp_result or (output == in_outs["outputs"][index])
|
288 |
+
except Exception as e:
|
289 |
+
if debug:
|
290 |
+
print(f"Failed check2 exception = {e}")
|
291 |
+
pass
|
292 |
+
|
293 |
+
if tmp_result == True:
|
294 |
+
results.append(tmp_result)
|
295 |
+
continue
|
296 |
+
|
297 |
+
# try by converting the output into a split up list too
|
298 |
+
if isinstance(output, list):
|
299 |
+
output = list(filter(len, output))
|
300 |
+
|
301 |
+
if debug:
|
302 |
+
nl = "\n"
|
303 |
+
if not isinstance(inputs, list):
|
304 |
+
print(f"output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs.replace(nl,' new-line ')}, {type(inputs)}, {output == [in_outs['outputs'][index]]}")
|
305 |
+
else:
|
306 |
+
print(f"output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs}, {type(inputs)}, {output == [in_outs['outputs'][index]]}")
|
307 |
+
|
308 |
+
if tmp_result == True:
|
309 |
+
results.append(tmp_result)
|
310 |
+
continue
|
311 |
+
|
312 |
+
try:
|
313 |
+
tmp_result = (output == [in_outs["outputs"][index]])
|
314 |
+
if isinstance(in_outs["outputs"][index], list):
|
315 |
+
tmp_result = tmp_result or (output == in_outs["outputs"][index])
|
316 |
+
except Exception as e:
|
317 |
+
if debug:
|
318 |
+
print(f"Failed check3 exception = {e}")
|
319 |
+
pass
|
320 |
+
|
321 |
+
try:
|
322 |
+
output_float = [float(e) for e in output]
|
323 |
+
gt_float = [float(e) for e in in_outs['outputs'][index]]
|
324 |
+
tmp_result = tmp_result or ((len(output_float) == len(gt_float)) and np.allclose(output_float, gt_float))
|
325 |
+
except Exception as e:
|
326 |
+
pass
|
327 |
+
try:
|
328 |
+
if isinstance(output[0], list):
|
329 |
+
output_float = [float(e) for e in output[0]]
|
330 |
+
gt_float = [float(e) for e in in_outs['outputs'][index][0]]
|
331 |
+
tmp_result = tmp_result or ((len(output_float) == len(gt_float)) and np.allclose(output_float, gt_float))
|
332 |
+
except Exception as e:
|
333 |
+
pass
|
334 |
+
|
335 |
+
if tmp_result == True:
|
336 |
+
results.append(tmp_result)
|
337 |
+
continue
|
338 |
+
|
339 |
+
# try by converting the stuff into split up list
|
340 |
+
if isinstance(in_outs["outputs"][index], list):
|
341 |
+
for tmp_index, i in enumerate(in_outs["outputs"][index]):
|
342 |
+
in_outs["outputs"][index][tmp_index] = set(i.split())
|
343 |
+
else:
|
344 |
+
in_outs["outputs"][index] = set(in_outs["outputs"][index].split())
|
345 |
+
|
346 |
+
try:
|
347 |
+
tmp_result = (output == in_outs["outputs"][index])
|
348 |
+
except Exception as e:
|
349 |
+
if debug:
|
350 |
+
print(f"Failed check4 exception = {e}")
|
351 |
+
continue
|
352 |
+
|
353 |
+
if tmp_result == True:
|
354 |
+
results.append(tmp_result)
|
355 |
+
continue
|
356 |
+
|
357 |
+
# try by converting the output into a split up list too
|
358 |
+
if isinstance(output, list):
|
359 |
+
for tmp_index, i in enumerate(output):
|
360 |
+
output[tmp_index] = i.split()
|
361 |
+
output = list(filter(len, output))
|
362 |
+
for tmp_index, i in enumerate(output):
|
363 |
+
output[tmp_index] = set(i)
|
364 |
+
else:
|
365 |
+
output = output.split()
|
366 |
+
output = list(filter(len, output))
|
367 |
+
output = set(output)
|
368 |
+
|
369 |
+
try:
|
370 |
+
tmp_result = (set(frozenset(s) for s in output) == set(frozenset(s) for s in in_outs["outputs"][index]))
|
371 |
+
except Exception as e:
|
372 |
+
if debug:
|
373 |
+
print(f"Failed check5 exception = {e}")
|
374 |
+
|
375 |
+
|
376 |
+
# if they are all numbers, round so that similar numbers are treated as identical
|
377 |
+
try:
|
378 |
+
tmp_result = tmp_result or (set(frozenset(round(float(t),3) for t in s) for s in output) ==\
|
379 |
+
set(frozenset(round(float(t),3) for t in s) for s in in_outs["outputs"][index]))
|
380 |
+
except Exception as e:
|
381 |
+
if debug:
|
382 |
+
print(f"Failed check6 exception = {e}")
|
383 |
+
|
384 |
+
if tmp_result == True and debug:
|
385 |
+
print("PASSED")
|
386 |
+
|
387 |
+
results.append(tmp_result)
|
388 |
+
|
389 |
+
if debug:
|
390 |
+
nl = "\n"
|
391 |
+
if not isinstance(inputs, list):
|
392 |
+
print(f"output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs.replace(nl,' new-line ')}, {type(inputs)}, {output == [in_outs['outputs'][index]]}")
|
393 |
+
else:
|
394 |
+
print(f"output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs}, {type(inputs)}, {output == [in_outs['outputs'][index]]}")
|
395 |
+
|
396 |
+
|
397 |
+
return results
|
398 |
+
|
399 |
+
|
400 |
+
def custom_compare_(output, ground_truth):
|
401 |
+
|
402 |
+
if isinstance(output, list):
|
403 |
+
output_1 = "\n".join(output)
|
404 |
+
if stripped_string_compare(output_1, ground_truth):
|
405 |
+
return True
|
406 |
+
|
407 |
+
if isinstance(output, list):
|
408 |
+
output_2 = [o.lstrip().rstrip() for o in output]
|
409 |
+
output_2 = "\n".join(output_2)
|
410 |
+
if stripped_string_compare(output_2, ground_truth):
|
411 |
+
return True
|
412 |
+
|
413 |
+
return False
|
414 |
+
|
415 |
+
def stripped_string_compare(s1, s2):
|
416 |
+
s1 = s1.lstrip().rstrip()
|
417 |
+
s2 = s2.lstrip().rstrip()
|
418 |
+
return s1 == s2
|
419 |
+
|
420 |
+
def call_method(method, inputs):
|
421 |
+
|
422 |
+
if isinstance(inputs, list):
|
423 |
+
inputs = "\n".join(inputs)
|
424 |
+
|
425 |
+
inputs_line_iterator = iter(inputs.split("\n"))
|
426 |
+
|
427 |
+
# sys.setrecursionlimit(10000)
|
428 |
+
|
429 |
+
# @patch('builtins.input', side_effect=inputs.split("\n"))
|
430 |
+
@patch('builtins.open', mock_open(read_data=inputs))
|
431 |
+
@patch('sys.stdin', StringIO(inputs))
|
432 |
+
@patch('sys.stdin.readline', lambda *args: next(inputs_line_iterator))
|
433 |
+
@patch('sys.stdin.readlines', lambda *args: inputs.split("\n"))
|
434 |
+
@patch('sys.stdin.read', lambda *args: inputs)
|
435 |
+
# @patch('sys.stdout.write', print)
|
436 |
+
def _inner_call_method(_method):
|
437 |
+
try:
|
438 |
+
return _method()
|
439 |
+
except SystemExit as e:
|
440 |
+
pass
|
441 |
+
finally:
|
442 |
+
pass
|
443 |
+
return _inner_call_method(method)
|
444 |
+
|
445 |
+
|
446 |
+
|
447 |
+
|
448 |
+
def reliability_guard(maximum_memory_bytes=None):
|
449 |
+
"""
|
450 |
+
This disables various destructive functions and prevents the generated code
|
451 |
+
from interfering with the test (e.g. fork bomb, killing other processes,
|
452 |
+
removing filesystem files, etc.)
|
453 |
+
WARNING
|
454 |
+
This function is NOT a security sandbox. Untrusted code, including, model-
|
455 |
+
generated code, should not be blindly executed outside of one. See the
|
456 |
+
Codex paper for more information about OpenAI's code sandbox, and proceed
|
457 |
+
with caution.
|
458 |
+
"""
|
459 |
+
|
460 |
+
if maximum_memory_bytes is not None:
|
461 |
+
import resource
|
462 |
+
|
463 |
+
resource.setrlimit(resource.RLIMIT_AS, (maximum_memory_bytes, maximum_memory_bytes))
|
464 |
+
resource.setrlimit(resource.RLIMIT_DATA, (maximum_memory_bytes, maximum_memory_bytes))
|
465 |
+
if not platform.uname().system == "Darwin":
|
466 |
+
resource.setrlimit(resource.RLIMIT_STACK, (maximum_memory_bytes, maximum_memory_bytes))
|
467 |
+
|
468 |
+
faulthandler.disable()
|
469 |
+
|
470 |
+
import builtins
|
471 |
+
|
472 |
+
builtins.exit = None
|
473 |
+
builtins.quit = None
|
474 |
+
|
475 |
+
import os
|
476 |
+
|
477 |
+
os.environ["OMP_NUM_THREADS"] = "1"
|
478 |
+
|
479 |
+
os.kill = None
|
480 |
+
os.system = None
|
481 |
+
os.putenv = None
|
482 |
+
os.remove = None
|
483 |
+
os.removedirs = None
|
484 |
+
os.rmdir = None
|
485 |
+
os.fchdir = None
|
486 |
+
os.setuid = None
|
487 |
+
os.fork = None
|
488 |
+
os.forkpty = None
|
489 |
+
os.killpg = None
|
490 |
+
os.rename = None
|
491 |
+
os.renames = None
|
492 |
+
os.truncate = None
|
493 |
+
os.replace = None
|
494 |
+
os.unlink = None
|
495 |
+
os.fchmod = None
|
496 |
+
os.fchown = None
|
497 |
+
os.chmod = None
|
498 |
+
os.chown = None
|
499 |
+
os.chroot = None
|
500 |
+
os.fchdir = None
|
501 |
+
os.lchflags = None
|
502 |
+
os.lchmod = None
|
503 |
+
os.lchown = None
|
504 |
+
os.getcwd = None
|
505 |
+
os.chdir = None
|
506 |
+
|
507 |
+
import shutil
|
508 |
+
|
509 |
+
shutil.rmtree = None
|
510 |
+
shutil.move = None
|
511 |
+
shutil.chown = None
|
512 |
+
|
513 |
+
import subprocess
|
514 |
+
|
515 |
+
subprocess.Popen = None # type: ignore
|
516 |
+
|
517 |
+
__builtins__["help"] = None
|
518 |
+
|
519 |
+
import sys
|
520 |
+
|
521 |
+
sys.modules["ipdb"] = None
|
522 |
+
sys.modules["joblib"] = None
|
523 |
+
sys.modules["resource"] = None
|
524 |
+
sys.modules["psutil"] = None
|
525 |
+
sys.modules["tkinter"] = None
|
tests.py
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
from evaluate import load
|
3 |
+
|
4 |
+
solution_sample1 = json.load(open("test_examples/solutions_problem_1.json", "r"))
|
5 |
+
solution_sample2 = json.load(open("test_examples/solutions_problem_2.json", "r"))
|
6 |
+
single_solutions = [solution_sample1[:1], solution_sample2[:1]]
|
7 |
+
multiple_solutions = [solution_sample1[:3], solution_sample2[:3]]
|
8 |
+
|
9 |
+
metric = load("codeparrot/apps_metric")
|
10 |
+
result_1 = metric.compute(predictions=single_solutions, level="all")
|
11 |
+
result_2 = metric.compute(predictions=multiple_solutions, level="all", k_list=[1, 2, 3])
|
12 |
+
|
13 |
+
assert result_1 == {'avg_accuracy': 1.0, 'strict_accuracy': 1.0, 'pass_at_k': None}
|
14 |
+
assert result_2 == {'avg_accuracy': None, 'strict_accuracy': None, 'pass_at_k': {'pass@1': 1.0, 'pass@2': 1.0, 'pass@3': 1.0}}
|
utils.py
ADDED
@@ -0,0 +1,212 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import itertools
|
2 |
+
import json
|
3 |
+
import multiprocessing
|
4 |
+
import numpy as np
|
5 |
+
from typing import Dict
|
6 |
+
from datasets import load_dataset
|
7 |
+
from .testing_util import run_test
|
8 |
+
|
9 |
+
DATASET = "codeparrot/apps"
|
10 |
+
TIMEOUT = 10
|
11 |
+
|
12 |
+
def check_correctness(sample, generation, timeout, debug=True):
|
13 |
+
"""Check correctness of code generation with a global timeout.
|
14 |
+
The global timeout is to catch some extreme/rare cases not handled by the timeouts
|
15 |
+
inside `run_test`"""
|
16 |
+
def _temp_run(sample, generation, debug, result):
|
17 |
+
result.append(run_test(sample, test=generation, debug=debug))
|
18 |
+
|
19 |
+
manager = multiprocessing.Manager()
|
20 |
+
result = manager.list()
|
21 |
+
p = multiprocessing.Process(target=_temp_run, args=(sample, generation, debug, result))
|
22 |
+
p.start()
|
23 |
+
p.join(timeout=timeout + 1)
|
24 |
+
if p.is_alive():
|
25 |
+
p.kill()
|
26 |
+
if not result:
|
27 |
+
in_outs = json.loads(sample["input_output"])
|
28 |
+
# consider that all tests failed
|
29 |
+
result = [[-1 for i in range(len(in_outs["inputs"]))]]
|
30 |
+
if debug:
|
31 |
+
print(f"global timeout")
|
32 |
+
return result[0]
|
33 |
+
|
34 |
+
|
35 |
+
def evaluate_generations(generations: list, level: str = "all", debug: bool = False):
|
36 |
+
"""We take the list of code generations and try to compile them
|
37 |
+
and the run their corresponding unit tests which are retrieved from the APPS dataset.
|
38 |
+
|
39 |
+
Args:
|
40 |
+
generations: list of code generations (same order as samples in APPS dataset)
|
41 |
+
level: difficulty level used in the generation, can be "all", "introductory", "interview" or "competition"
|
42 |
+
|
43 |
+
Returns:
|
44 |
+
results: dictionary of results, key is the problem index, value is a list of results for each generation
|
45 |
+
[-2] = compile error, [-1] = runtime error [False] = failed test case [True] = passed test case
|
46 |
+
"""
|
47 |
+
|
48 |
+
# generations are code generations in the same order of the dataset
|
49 |
+
apps_eval = load_dataset(DATASET, split="test", difficulties=[level])
|
50 |
+
results = {}
|
51 |
+
for index in range(len(generations)):
|
52 |
+
# code generations for problem (index)
|
53 |
+
problem_generations = generations[index]
|
54 |
+
# get corresponding samples from APPS dataset
|
55 |
+
sample = apps_eval[index]
|
56 |
+
res = []
|
57 |
+
# loop over the generations
|
58 |
+
for o_idx, o in enumerate(problem_generations):
|
59 |
+
curr_res = [-2]
|
60 |
+
try:
|
61 |
+
curr_res = check_correctness(sample, o, timeout=TIMEOUT, debug=debug)
|
62 |
+
if debug:
|
63 |
+
print(f"\nSuccessful compilation of task {index}!")
|
64 |
+
fixed = []
|
65 |
+
for e in curr_res:
|
66 |
+
if isinstance(e, np.ndarray):
|
67 |
+
e = e.item(0)
|
68 |
+
if isinstance(e, np.bool_):
|
69 |
+
e = bool(e)
|
70 |
+
fixed.append(e)
|
71 |
+
curr_res = fixed
|
72 |
+
if not np.all(curr_res):
|
73 |
+
if debug:
|
74 |
+
print(f"Results were not True for all test cases")
|
75 |
+
except Exception as e:
|
76 |
+
if debug:
|
77 |
+
print(f"Compilation failed, test framework exception = {repr(e)}{e}\n")
|
78 |
+
break
|
79 |
+
finally:
|
80 |
+
assert isinstance(curr_res, list)
|
81 |
+
res.append(curr_res)
|
82 |
+
results[index] = res
|
83 |
+
return results
|
84 |
+
|
85 |
+
|
86 |
+
def estimate_pass_at_k(num_samples, num_correct, k):
|
87 |
+
"""Estimates pass@k of each problem and returns them in an array."""
|
88 |
+
|
89 |
+
def estimator(n: int, c: int, k: int) -> float:
|
90 |
+
"""Calculates 1 - comb(n - c, k) / comb(n, k)."""
|
91 |
+
if n - c < k:
|
92 |
+
return 1.0
|
93 |
+
return 1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1))
|
94 |
+
|
95 |
+
if isinstance(num_samples, int):
|
96 |
+
num_samples_it = itertools.repeat(num_samples, len(num_correct))
|
97 |
+
else:
|
98 |
+
assert len(num_samples) == len(num_correct)
|
99 |
+
num_samples_it = iter(num_samples)
|
100 |
+
|
101 |
+
return np.array([estimator(int(n), int(c), k) for n, c in zip(num_samples_it, num_correct)])
|
102 |
+
|
103 |
+
|
104 |
+
def get_results(results: Dict[int, list], count_errors: bool = False, k_list: list = [1, 10, 100]):
|
105 |
+
"""
|
106 |
+
Given the results evaluated against the testcases we output some statistics.
|
107 |
+
For single generations:
|
108 |
+
>>> example_results = {0: [[-2]], 1: [[False,False]], 2: [[True,True]], 3: [[False,True,False,True]], 4: [[-1,-1]]}
|
109 |
+
>>> get_results(example_results, count_errors=True)
|
110 |
+
Computing accuracy metrics...
|
111 |
+
number of compile errors = 1 avg = 0.2
|
112 |
+
number of runtime errors = 1 avg = 0.2
|
113 |
+
number of problems evaluated = 5
|
114 |
+
Average Accuracy : 0.3
|
115 |
+
Strict Accuracy : 0.2
|
116 |
+
{'avg_accuracy': 0.3, 'strict_accuracy': 0.2, 'pass_at_k': None}
|
117 |
+
|
118 |
+
For multiple generations:
|
119 |
+
>>> example_results = {0: [[-2], [True, True, True]], 1: [[-1,-1, -1], [True, False, True]]}
|
120 |
+
>>> get_results(example_results, k_list=[1, 2])
|
121 |
+
Computing pass@k metric for multiple generations...
|
122 |
+
{'pass@1': 0.25, 'pass@2': 0.5}
|
123 |
+
{'avg_accuracy': None, 'strict_accuracy': None, 'pass_at_k': {'pass@1': 0.25, 'pass@2': 0.5}}
|
124 |
+
"""
|
125 |
+
|
126 |
+
metrics = {"avg_accuracy": None, "strict_accuracy": None, "pass_at_k": None}
|
127 |
+
|
128 |
+
if len(results[0]) == 1:
|
129 |
+
# for single generations we compute average accuracy and stric accuracy: original APPS metrics
|
130 |
+
print("Computing accuracy metrics...")
|
131 |
+
res = []
|
132 |
+
per_prob_res = []
|
133 |
+
all_correct = []
|
134 |
+
for index in results:
|
135 |
+
problem_results = np.asarray(results[index])
|
136 |
+
res.extend(problem_results)
|
137 |
+
per_prob_res.append(np.mean(problem_results > 0))
|
138 |
+
all_correct.append(np.all(problem_results > 0))
|
139 |
+
# we count campilation and runtime errors once per pronlem
|
140 |
+
compile_errors = len([e for e in res if -2 in e])
|
141 |
+
runtime_errors = len([e for e in res if -1 in e])
|
142 |
+
total_testcases = len(res)
|
143 |
+
if count_errors:
|
144 |
+
print(f"number of compile errors = {compile_errors} avg = {compile_errors / total_testcases}")
|
145 |
+
print(f"number of runtime errors = {runtime_errors} avg = {runtime_errors / total_testcases}")
|
146 |
+
print(f"number of problems evaluated = {total_testcases}")
|
147 |
+
|
148 |
+
print(f"Average Accuracy : {np.mean(per_prob_res)}")
|
149 |
+
print(f"Strict Accuracy : {np.mean(all_correct)}")
|
150 |
+
metrics["avg_accuracy"] = np.mean(per_prob_res)
|
151 |
+
metrics["strict_accuracy"] = np.mean(all_correct)
|
152 |
+
|
153 |
+
else:
|
154 |
+
# for multiple generations we use pass@k metric used in the HumanEval benchmark
|
155 |
+
# we use strict accuracy, a generation is valid if it has to pass all the tests
|
156 |
+
print("Computing pass@k metric for multiple generations...")
|
157 |
+
# total is list with nb generations per task (task=index)
|
158 |
+
# correct is number of generations that passed all tests per task
|
159 |
+
total = []
|
160 |
+
correct = []
|
161 |
+
for index in results:
|
162 |
+
all_correct = []
|
163 |
+
for generation in results[index]:
|
164 |
+
gen = np.array(generation)
|
165 |
+
all_correct.append(np.all(gen>0))
|
166 |
+
total.append(len(all_correct))
|
167 |
+
correct.append(sum(all_correct))
|
168 |
+
total = np.array(total)
|
169 |
+
correct = np.array(correct)
|
170 |
+
ks = k_list
|
171 |
+
pass_at_k = {f"pass@{k}": estimate_pass_at_k(total, correct, k).mean() for k in ks if (total >= k).all()}
|
172 |
+
print(pass_at_k)
|
173 |
+
metrics["pass_at_k"] = pass_at_k
|
174 |
+
return metrics
|
175 |
+
|
176 |
+
def compute_metrics(generations, level="all", k_list=[1, 10, 100], count_errors=True, debug=False):
|
177 |
+
"""Return metrics for the given generations.
|
178 |
+
Args:
|
179 |
+
generations: list of code generations for each problem (each generation is a list of generations)
|
180 |
+
k_list: list of k values to compute pass@k when using multiple generations
|
181 |
+
count_errors: whether to count compilation and runtime errors when using single generations
|
182 |
+
level: difficulty level in APPS dataset that was used for the given generations (from: "all", "introductory", "interview", "competition")
|
183 |
+
Returns:
|
184 |
+
metrics: dict of metrics
|
185 |
+
|
186 |
+
Examples:
|
187 |
+
|
188 |
+
>>> import json
|
189 |
+
>>> # lists of solutions to the two first APPS problems (note not all solutions pass all tests)
|
190 |
+
>>> solution_sample1 = json.load(open("test_examples/solutions_problem_1.json", "r"))
|
191 |
+
>>> solution_sample2 = json.load(open("test_examples/solutions_problem_2.json", "r"))
|
192 |
+
>>> single_solutions = [solution_sample1[:1], solution_sample2[:1]]
|
193 |
+
>>> compute_metrics(single_solutions, level="all")
|
194 |
+
Computing accuracy metrics...
|
195 |
+
number of compile errors = 0 avg = 0.0
|
196 |
+
number of runtime errors = 0 avg = 0.0
|
197 |
+
number of problems evaluated = 2
|
198 |
+
Average Accuracy : 1.0
|
199 |
+
Strict Accuracy : 1.0
|
200 |
+
{'avg_accuracy': 1.0, 'strict_accuracy': 1.0, 'pass_at_k': None}
|
201 |
+
>>> multiple_solutions = [solution_sample1[:3], solution_sample2[:3]]
|
202 |
+
>>> compute_metrics(multiple_solutions, level="all", k_list=[1, 2, 3])
|
203 |
+
Computing pass@k metric for multiple generations...
|
204 |
+
{'pass@1': 1.0, 'pass@2': 1.0, 'pass@3': 1.0}
|
205 |
+
{'avg_accuracy': None, 'strict_accuracy': None, 'pass_at_k': {'pass@1': 1.0, 'pass@2': 1.0, 'pass@3': 1.0}}
|
206 |
+
"""
|
207 |
+
results = evaluate_generations(generations, level=level, debug=debug)
|
208 |
+
metrics = get_results(results, count_errors=count_errors, k_list=k_list)
|
209 |
+
return metrics
|
210 |
+
|
211 |
+
# import doctest
|
212 |
+
# doctest.testmod()
|