rookiemango committed on
Commit
dddc1ae
·
verified ·
1 Parent(s): 67945ec

Upload folder using huggingface_hub

Browse files
This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50) hide show
  1. .gitattributes +35 -0
  2. .github/workflows/ci.yml +24 -0
  3. .gitignore +7 -0
  4. .vscode/copyright.code-snippets +13 -0
  5. .vscode/extensions.json +13 -0
  6. .vscode/module-docstring.code-snippets +35 -0
  7. .vscode/settings.json +11 -0
  8. README.md +174 -0
  9. REPL.lean +4 -0
  10. REPL/Frontend.lean +47 -0
  11. REPL/JSON.lean +186 -0
  12. REPL/Lean/ContextInfo.lean +9 -0
  13. REPL/Lean/Environment.lean +31 -0
  14. REPL/Lean/InfoTree.lean +272 -0
  15. REPL/Lean/InfoTree/ToJson.lean +114 -0
  16. REPL/Main.lean +323 -0
  17. REPL/Snapshots.lean +306 -0
  18. REPL/Util/Path.lean +36 -0
  19. REPL/Util/Pickle.lean +44 -0
  20. __pycache__/code.cpython-310.pyc +0 -0
  21. __pycache__/code.cpython-39.pyc +0 -0
  22. __pycache__/openllm_pass_rate_new_test.cpython-39.pyc +0 -0
  23. all_code.py +159 -0
  24. basic_working.json +0 -0
  25. code.py +69 -0
  26. data/basic_working.json +0 -0
  27. data/notlean_dependency.json +3 -0
  28. gpt_pass_rate_multi_pass.py +54 -0
  29. gpt_pass_rate_new_notlean_test.py +289 -0
  30. gpt_pass_rate_new_test.py +287 -0
  31. lake-manifest.json +68 -0
  32. lakefile.lean +17 -0
  33. lean-toolchain +1 -0
  34. nohup.out +4 -0
  35. openllm_pass_rate_multi_pass.py +106 -0
  36. openllm_pass_rate_new_notlean_test.py +265 -0
  37. openllm_pass_rate_new_test.py +306 -0
  38. pass_rate.py +194 -0
  39. pass_rate_atp_pass.py +112 -0
  40. pass_rate_atp_test.py +264 -0
  41. pass_rate_found_item.py +175 -0
  42. pass_rate_multi.py +48 -0
  43. pass_rate_multi_notlean.py +40 -0
  44. pass_rate_multi_notlean_pass.py +43 -0
  45. pass_rate_multi_pass.py +112 -0
  46. pass_rate_new.py +196 -0
  47. pass_rate_new_test.py +255 -0
  48. pass_rate_new_test_allcontent.py +255 -0
  49. pass_rate_notlean.py +202 -0
  50. pass_rate_notlean_test.py +261 -0
.gitattributes CHANGED
@@ -33,3 +33,38 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ results.json filter=lfs diff=lfs merge=lfs -text
37
+ test/15k_state_problem_translation.json filter=lfs diff=lfs merge=lfs -text
38
+ test/lean4_random/1k_test.json filter=lfs diff=lfs merge=lfs -text
39
+ test/lean4_random/5k_first.json filter=lfs diff=lfs merge=lfs -text
40
+ test/lean4_random/5k_second.json filter=lfs diff=lfs merge=lfs -text
41
+ test/lean4_random/5k_third.json filter=lfs diff=lfs merge=lfs -text
42
+ test/result.json filter=lfs diff=lfs merge=lfs -text
43
+ test/zero_shot/lean4_basic_test/generation/lean4_random_5k_first_1epoch/1/result.json filter=lfs diff=lfs merge=lfs -text
44
+ test/zero_shot/lean4_basic_test/generation/lean4_random_5k_first_2epoch/1/result.json filter=lfs diff=lfs merge=lfs -text
45
+ test/zero_shot/lean4_basic_test/generation/lean4_random_5k_first_3epoch/1/result.json filter=lfs diff=lfs merge=lfs -text
46
+ test/zero_shot/lean4_random_test/generation/lean4_random_5k_first_1epoch/1/result.json filter=lfs diff=lfs merge=lfs -text
47
+ test/zero_shot/lean4_random_test/generation/lean4_random_5k_first_2epoch/1/result.json filter=lfs diff=lfs merge=lfs -text
48
+ test/zero_shot/lean4_random_test/generation/lean4_random_5k_first_3epoch/1/result.json filter=lfs diff=lfs merge=lfs -text
49
+ test/zero_shot/math_train/generation/lean4_random_15k_all/2/1/0.json filter=lfs diff=lfs merge=lfs -text
50
+ test/zero_shot/math_train/generation/lean4_random_15k_all/2/1/1.json filter=lfs diff=lfs merge=lfs -text
51
+ gpt_result/lean_basic/gpt3/1.jsonl filter=lfs diff=lfs merge=lfs -text
52
+ gpt_result/lean_basic/gpt3/2.jsonl filter=lfs diff=lfs merge=lfs -text
53
+ gpt_result/lean_basic/gpt3/3.jsonl filter=lfs diff=lfs merge=lfs -text
54
+ gpt_result/lean_basic/gpt3/4.jsonl filter=lfs diff=lfs merge=lfs -text
55
+ gpt_result/lean_basic/gpt3/5.jsonl filter=lfs diff=lfs merge=lfs -text
56
+ gpt_result/lean_basic/gpt4/1.jsonl filter=lfs diff=lfs merge=lfs -text
57
+ gpt_result/lean_basic/gpt4/2.jsonl filter=lfs diff=lfs merge=lfs -text
58
+ gpt_result/lean_basic/gpt4/3.jsonl filter=lfs diff=lfs merge=lfs -text
59
+ gpt_result/lean_basic/gpt4/4.jsonl filter=lfs diff=lfs merge=lfs -text
60
+ gpt_result/lean_basic/gpt4/5.jsonl filter=lfs diff=lfs merge=lfs -text
61
+ gpt_result/lean_random/gpt3/1.jsonl filter=lfs diff=lfs merge=lfs -text
62
+ gpt_result/lean_random/gpt3/2.jsonl filter=lfs diff=lfs merge=lfs -text
63
+ gpt_result/lean_random/gpt3/3.jsonl filter=lfs diff=lfs merge=lfs -text
64
+ gpt_result/lean_random/gpt3/4.jsonl filter=lfs diff=lfs merge=lfs -text
65
+ gpt_result/lean_random/gpt3/5.jsonl filter=lfs diff=lfs merge=lfs -text
66
+ gpt_result/lean_random/gpt4/1.jsonl filter=lfs diff=lfs merge=lfs -text
67
+ gpt_result/lean_random/gpt4/2.jsonl filter=lfs diff=lfs merge=lfs -text
68
+ gpt_result/lean_random/gpt4/3.jsonl filter=lfs diff=lfs merge=lfs -text
69
+ gpt_result/lean_random/gpt4/4.jsonl filter=lfs diff=lfs merge=lfs -text
70
+ gpt_result/lean_random/gpt4/5.jsonl filter=lfs diff=lfs merge=lfs -text
.github/workflows/ci.yml ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: Run Tests
2
+
3
+ on: [push, pull_request]
4
+
5
+ jobs:
6
+ test:
7
+ runs-on: ubuntu-latest
8
+
9
+ steps:
10
+ - name: Checkout code
11
+ uses: actions/checkout@v2
12
+
13
+ - name: install elan
14
+ run: |
15
+ set -o pipefail
16
+ curl -sSfL https://github.com/leanprover/elan/releases/download/v3.0.0/elan-x86_64-unknown-linux-gnu.tar.gz | tar xz
17
+ ./elan-init -y --default-toolchain none
18
+ echo "$HOME/.elan/bin" >> $GITHUB_PATH
19
+
20
+ - name: build
21
+ run: lake build
22
+
23
+ - name: Run tests
24
+ run: ./test.sh
.gitignore ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ /build
2
+ /lake-packages/*
3
+ /lakefile.olean
4
+ /.lake
5
+ /test/Mathlib/.lake
6
+ /test/*.olean
7
+ /test/*.olean.tmp
.vscode/copyright.code-snippets ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "Copyright header for mathlib": {
3
+ "scope": "lean4",
4
+ "prefix": "copyright",
5
+ "body": [
6
+ "/-",
7
+ "Copyright (c) ${CURRENT_YEAR} $1. All rights reserved.",
8
+ "Released under Apache 2.0 license as described in the file LICENSE.",
9
+ "Authors: $1",
10
+ "-/"
11
+ ]
12
+ }
13
+ }
.vscode/extensions.json ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ // See https://go.microsoft.com/fwlink/?LinkId=827846 to learn about workspace recommendations.
3
+ // Extension identifier format: ${publisher}.${name}. Example: vscode.csharp
4
+
5
+ // List of extensions which should be recommended for users of this workspace.
6
+ "recommendations": [
7
+ "leanprover.lean4"
8
+ ],
9
+ // List of extensions recommended by VS Code that should not be recommended for users of this workspace.
10
+ "unwantedRecommendations": [
11
+ "ms-vscode-remote.remote-containers"
12
+ ]
13
+ }
.vscode/module-docstring.code-snippets ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "Module docstring for mathlib": {
3
+ "scope": "lean4",
4
+ "prefix": "module docstring",
5
+ "body": [
6
+ "/-!",
7
+ "# ${TM_FILENAME_BASE/([^_]*)(_?)/${1:/capitalize}${2:+ }/g}",
8
+ "",
9
+ "## Main definitions",
10
+ "",
11
+ "* `FooBar`",
12
+ "",
13
+ "## Main statements",
14
+ "",
15
+ "* `fooBar_unique`",
16
+ "",
17
+ "## Notation",
18
+ "",
19
+ "",
20
+ "",
21
+ "## Implementation details",
22
+ "",
23
+ "",
24
+ "",
25
+ "## References",
26
+ "",
27
+ "* [F. Bar, *Quuxes*][bibkey]",
28
+ "",
29
+ "## Tags",
30
+ "",
31
+ "Foobars, barfoos",
32
+ "-/",
33
+ "",
34
+ ]},
35
+ }
.vscode/settings.json ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "editor.insertSpaces": true,
3
+ "editor.tabSize": 2,
4
+ "editor.rulers" : [100],
5
+ "files.encoding": "utf8",
6
+ "files.eol": "\n",
7
+ "files.insertFinalNewline": true,
8
+ // We don't use this: it messes up our test files!
9
+ // "files.trimFinalNewlines": true,
10
+ "files.trimTrailingWhitespace": true,
11
+ }
README.md ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # A read-eval-print-loop for Lean 4
2
+
3
+ Run using `lake exe repl`.
4
+ Communicates via JSON on stdin and stdout.
5
+ Commands should be separated by blank lines.
6
+
7
+ The REPL works both in "command" mode and "tactic" mode.
8
+
9
+ ## Command mode
10
+
11
+ In command mode, you send complete commands (e.g. declarations) to the REPL.
12
+
13
+ Commands may be of the form
14
+
15
+ ```json
16
+ { "cmd" : "def f := 2" }
17
+ ```
18
+
19
+ ```json
20
+ { "cmd" : "example : f = 2 := rfl", "env" : 1 }
21
+ ```
22
+
23
+ The `env` field, if present,
24
+ must contain a number received in the `env` field of a previous response,
25
+ and causes the command to be run in the existing environment.
26
+
27
+ If there is no `env` field, a new environment is created.
28
+
29
+ You can only use `import` commands when you do not specify the `env` field.
30
+
31
+ You can backtrack simply by using earlier values for `env`.
32
+
33
+ The response includes:
34
+ * A numeric label for the `Environment` after your command,
35
+ which you can use as the starting point for subsequent commands.
36
+ * Any messages generated while processing your command.
37
+ * A list of the `sorry`s in your command, including
38
+ * their expected type, and
39
+ * a numeric label for the proof state at the `sorry`, which you can then use in tactic mode.
40
+
41
+ Example output:
42
+
43
+ ```json
44
+ {"sorries":
45
+ [{"pos": {"line": 1, "column": 18},
46
+ "endPos": {"line": 1, "column": 23},
47
+ "goal": "⊢ Nat",
48
+ "proofState": 0}],
49
+ "messages":
50
+ [{"severity": "error",
51
+ "pos": {"line": 1, "column": 23},
52
+ "endPos": {"line": 1, "column": 26},
53
+ "data":
54
+ "type mismatch\n rfl\nhas type\n f = f : Prop\nbut is expected to have type\n f = 2 : Prop"}],
55
+ "env": 6}
56
+ ```
57
+
58
+ showing any messages generated, and sorries with their goal states.
59
+
60
+ ## File mode
61
+
62
+ There is a simple wrapper around command mode that allows reading in an entire file.
63
+
64
+ If `test/file.lean` contains
65
+ ```lean
66
+ def f : Nat := 37
67
+
68
+ def g := 2
69
+
70
+ theorem h : f + g = 39 := by exact rfl
71
+ ```
72
+
73
+ then
74
+ ```
75
+ echo '{"path": "test/file.lean", "allTactics": true}' | lake exe repl
76
+ ```
77
+ results in output
78
+ ```json
79
+ {"tactics":
80
+ [{"tactic": "exact rfl",
81
+ "proofState": 0,
82
+ "pos": {"line": 5, "column": 29},
83
+ "goals": "⊢ f + g = 39",
84
+ "endPos": {"line": 5, "column": 38}}],
85
+ "env": 0}
86
+ ```
87
+
88
+ ## Tactic mode (experimental)
89
+
90
+ To enter tactic mode issue a command containing a `sorry`,
91
+ and then use the `proofState` index returned for each `sorry`.
92
+
93
+ Example usage:
94
+ ```json
95
+ {"cmd" : "def f (x : Unit) : Nat := by sorry"}
96
+
97
+ {"sorries":
98
+ [{"proofState": 0,
99
+ "pos": {"line": 1, "column": 29},
100
+ "goal": "x : Unit\n⊢ Nat",
101
+ "endPos": {"line": 1, "column": 34}}],
102
+ "messages":
103
+ [{"severity": "warning",
104
+ "pos": {"line": 1, "column": 4},
105
+ "endPos": {"line": 1, "column": 5},
106
+ "data": "declaration uses 'sorry'"}],
107
+ "env": 0}
108
+
109
+ {"tactic": "apply Int.natAbs", "proofState": 0}
110
+
111
+ {"proofState": 1, "goals": ["x : Unit\n⊢ Int"]}
112
+
113
+ {"tactic": "exact -37", "proofState": 1}
114
+
115
+ {"proofState": 2, "goals": []}
116
+ ```
117
+
118
+ You can use `sorry` in tactic mode.
119
+ The result will contain additional `proofState` identifiers for the goal at each sorry.
120
+
121
+ At present there is nothing you can do with a completed proof state:
122
+ we would like to extend this so that you can replace the original `sorry` with your tactic script,
123
+ and obtain the resulting `Environment`
124
+
125
+ ## Pickling
126
+
127
+ The REPL supports pickling environments and proof states to disk as `.olean` files.
128
+ As long as the same imports are available, it should be possible to move such an `.olean` file
129
+ to another machine and unpickle into a new REPL session.
130
+
131
+ The commands are
132
+
133
+ ```json
134
+ {"pickleTo": "path/to/file.olean", "env": 7}
135
+
136
+ {"pickleTo": "path/to/file.olean", "proofState": 17}
137
+
138
+ {"unpickleEnvFrom": "path/to/file.olean"}
139
+
140
+ {"unpickleProofStateFrom": "path/to/file.olean"}
141
+ ```
142
+
143
+ The unpickling commands will report the new "env" or "proofState" identifier that
144
+ you can use in subsequent commands.
145
+
146
+ Pickling is quite efficient:
147
+ * we don't record full `Environment`s, only the changes relative to imports
148
+ * unpickling uses memory mapping
149
+ * file sizes are generally small, but see https://github.com/digama0/leangz if compression is
150
+ desirable
151
+
152
+ ## Using the REPL from another project
153
+
154
+ Set up your project as usual using `lake new` or `lake init`
155
+ (or the interactive setup GUI available via the VSCode extension under the `∀` menu).
156
+
157
+ In that project, add `require` statements in the `lakefile.lean` for any dependencies you need
158
+ (e.g. Mathlib). (You probably should verify that `lake build` works as expected in that project.)
159
+
160
+ Now you can run the REPL as:
161
+ ```shell
162
+ lake env ../path/to/repl/.lake/build/bin/repl < commands.in
163
+ ```
164
+ (Here `../path/to/repl/` represents the path to your checkout of this repository,
165
+ in which you've already run `lake build`.)
166
+
167
+ The `lake env` prefix sets up the environment associated to your local project, so that the REPL
168
+ can find needed imports.
169
+
170
+ ## Future work
171
+
172
+ * Replay tactic scripts from tactic mode back into the original `sorry`.
173
+ * Currently if you create scoped environment extensions (e.g. scoped notations) in a session
174
+ these are not correctly pickled and unpickled in later sessions.
REPL.lean ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ import REPL.Frontend
2
+ import REPL.Lean.InfoTree
3
+ import REPL.JSON
4
+ import REPL.Main
REPL/Frontend.lean ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /-
2
+ Copyright (c) 2023 Scott Morrison. All rights reserved.
3
+ Released under Apache 2.0 license as described in the file LICENSE.
4
+ Authors: Scott Morrison
5
+ -/
6
+ import Lean.Elab.Frontend
7
+
8
+ open Lean Elab
9
+
10
+ namespace Lean.Elab.IO
11
+
12
+ /--
13
+ Wrapper for `IO.processCommands` that enables info states, and returns
14
+ * the new command state
15
+ * messages
16
+ * info trees
17
+ -/
18
+ def processCommandsWithInfoTrees
19
+ (inputCtx : Parser.InputContext) (parserState : Parser.ModuleParserState)
20
+ (commandState : Command.State) : IO (Command.State × List Message × List InfoTree) := do
21
+ let commandState := { commandState with infoState.enabled := true }
22
+ let s ← IO.processCommands inputCtx parserState commandState <&> Frontend.State.commandState
23
+ pure (s, s.messages.msgs.toList, s.infoState.trees.toList)
24
+
25
+ /--
26
+ Process some text input, with or without an existing command state.
27
+ If there is no existing environment, we parse the input for headers (e.g. import statements),
28
+ and create a new environment.
29
+ Otherwise, we add to the existing environment.
30
+
31
+ Returns the resulting command state, along with a list of messages and info trees.
32
+ -/
33
+ def processInput (input : String) (cmdState? : Option Command.State)
34
+ (opts : Options := {}) (fileName : Option String := none) :
35
+ IO (Command.State × List Message × List InfoTree) := unsafe do
36
+ Lean.initSearchPath (← Lean.findSysroot)
37
+ enableInitializersExecution
38
+ let fileName := fileName.getD "<input>"
39
+ let inputCtx := Parser.mkInputContext input fileName
40
+ let (parserState, commandState) ← match cmdState? with
41
+ | none => do
42
+ let (header, parserState, messages) ← Parser.parseHeader inputCtx
43
+ let (env, messages) ← processHeader header opts messages inputCtx
44
+ pure (parserState, (Command.mkState env messages opts))
45
+ | some cmdState => do
46
+ pure ({ : Parser.ModuleParserState }, cmdState)
47
+ processCommandsWithInfoTrees inputCtx parserState commandState
REPL/JSON.lean ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /-
2
+ Copyright (c) 2023 Scott Morrison. All rights reserved.
3
+ Released under Apache 2.0 license as described in the file LICENSE.
4
+ Authors: Scott Morrison
5
+ -/
6
+ import Lean.Data.Json
7
+ import Lean.Message
8
+ import Lean.Elab.InfoTree.Main
9
+
10
+ open Lean Elab InfoTree
11
+
12
+ namespace REPL
13
+
14
+ structure CommandOptions where
15
+ allTactics : Option Bool := none
16
+ /--
17
+ Should be "full", "tactics", "original", or "substantive".
18
+ Anything else is ignored.
19
+ -/
20
+ infotree : Option String
21
+
22
+ /-- Run Lean commands.
23
+ If `env = none`, starts a new session (in which you can use `import`).
24
+ If `env = some n`, builds on the existing environment `n`.
25
+ -/
26
+ structure Command extends CommandOptions where
27
+ env : Option Nat
28
+ cmd : String
29
+ deriving ToJson, FromJson
30
+
31
+ /-- Process a Lean file in a fresh environment. -/
32
+ structure File extends CommandOptions where
33
+ path : System.FilePath
34
+ deriving FromJson
35
+
36
+ /--
37
+ Run a tactic in a proof state.
38
+ -/
39
+ structure ProofStep where
40
+ proofState : Nat
41
+ tactic : String
42
+ deriving ToJson, FromJson
43
+
44
+ /-- Line and column information for error messages and sorries. -/
45
+ structure Pos where
46
+ line : Nat
47
+ column : Nat
48
+ deriving ToJson, FromJson
49
+
50
+ /-- Severity of a message. -/
51
+ inductive Severity
52
+ | trace | info | warning | error
53
+ deriving ToJson, FromJson
54
+
55
+ /-- A Lean message. -/
56
+ structure Message where
57
+ pos : Pos
58
+ endPos : Option Pos
59
+ severity : Severity
60
+ data : String
61
+ deriving ToJson, FromJson
62
+
63
+ /-- Construct the JSON representation of a Lean message. -/
64
+ def Message.of (m : Lean.Message) : IO Message := do pure <|
65
+ { pos := ⟨m.pos.line, m.pos.column⟩,
66
+ endPos := m.endPos.map fun p => ⟨p.line, p.column⟩,
67
+ severity := match m.severity with
68
+ | .information => .info
69
+ | .warning => .warning
70
+ | .error => .error,
71
+ data := (← m.data.toString).trim }
72
+
73
+ /-- A Lean `sorry`. -/
74
+ structure Sorry where
75
+ pos : Pos
76
+ endPos : Pos
77
+ goal : String
78
+ /--
79
+ The index of the proof state at the sorry.
80
+ You can use the `ProofStep` instruction to run a tactic at this state.
81
+ -/
82
+ proofState : Option Nat
83
+ deriving FromJson
84
+
85
+ instance : ToJson Sorry where
86
+ toJson r := Json.mkObj <| .join [
87
+ [("goal", r.goal)],
88
+ [("proofState", toJson r.proofState)],
89
+ if r.pos.line ≠ 0 then [("pos", toJson r.pos)] else [],
90
+ if r.endPos.line ≠ 0 then [("endPos", toJson r.endPos)] else [],
91
+ ]
92
+
93
+ /-- Construct the JSON representation of a Lean sorry. -/
94
+ def Sorry.of (goal : String) (pos endPos : Lean.Position) (proofState : Option Nat) : Sorry :=
95
+ { pos := ⟨pos.line, pos.column⟩,
96
+ endPos := ⟨endPos.line, endPos.column⟩,
97
+ goal,
98
+ proofState }
99
+
100
+ structure Tactic where
101
+ pos : Pos
102
+ endPos : Pos
103
+ goals : String
104
+ tactic : String
105
+ proofState : Option Nat
106
+ deriving ToJson, FromJson
107
+
108
+ /-- Construct the JSON representation of a Lean tactic. -/
109
+ def Tactic.of (goals tactic : String) (pos endPos : Lean.Position) (proofState : Option Nat) : Tactic :=
110
+ { pos := ⟨pos.line, pos.column⟩,
111
+ endPos := ⟨endPos.line, endPos.column⟩,
112
+ goals,
113
+ tactic,
114
+ proofState }
115
+
116
+ /--
117
+ A response to a Lean command.
118
+ `env` can be used in later calls, to build on the stored environment.
119
+ -/
120
+ structure CommandResponse where
121
+ env : Nat
122
+ messages : List Message := []
123
+ sorries : List Sorry := []
124
+ tactics : List Tactic := []
125
+ infotree : Option Json := none
126
+ deriving FromJson
127
+
128
+ def Json.nonemptyList [ToJson α] (k : String) : List α → List (String × Json)
129
+ | [] => []
130
+ | l => [⟨k, toJson l⟩]
131
+
132
+ instance : ToJson CommandResponse where
133
+ toJson r := Json.mkObj <| .join [
134
+ [("env", r.env)],
135
+ Json.nonemptyList "messages" r.messages,
136
+ Json.nonemptyList "sorries" r.sorries,
137
+ Json.nonemptyList "tactics" r.tactics,
138
+ match r.infotree with | some j => [("infotree", j)] | none => []
139
+ ]
140
+
141
+ /--
142
+ A response to a Lean tactic.
143
+ `proofState` can be used in later calls, to run further tactics.
144
+ -/
145
+ structure ProofStepResponse where
146
+ proofState : Nat
147
+ goals : List String
148
+ messages : List Message := []
149
+ sorries : List Sorry := []
150
+ traces : List String
151
+ deriving ToJson, FromJson
152
+
153
+ instance : ToJson ProofStepResponse where
154
+ toJson r := Json.mkObj <| .join [
155
+ [("proofState", r.proofState)],
156
+ [("goals", toJson r.goals)],
157
+ Json.nonemptyList "messages" r.messages,
158
+ Json.nonemptyList "sorries" r.sorries,
159
+ Json.nonemptyList "traces" r.traces
160
+ ]
161
+
162
+ /-- Json wrapper for an error. -/
163
+ structure Error where
164
+ message : String
165
+ deriving ToJson, FromJson
166
+
167
+ structure PickleEnvironment where
168
+ env : Nat
169
+ pickleTo : System.FilePath
170
+ deriving ToJson, FromJson
171
+
172
+ structure UnpickleEnvironment where
173
+ unpickleEnvFrom : System.FilePath
174
+ deriving ToJson, FromJson
175
+
176
+ structure PickleProofState where
177
+ proofState : Nat
178
+ pickleTo : System.FilePath
179
+ deriving ToJson, FromJson
180
+
181
+ structure UnpickleProofState where
182
+ unpickleProofStateFrom : System.FilePath
183
+ env : Option Nat
184
+ deriving ToJson, FromJson
185
+
186
+ end REPL
REPL/Lean/ContextInfo.lean ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ import Lean
2
+
3
+ namespace Lean.Elab.ContextInfo
4
+
5
+ /-- Pretty print an expression in the given `ContextInfo` with the given `LocalContext`. -/
6
+ def ppExpr (ctx : ContextInfo) (lctx : LocalContext) (e : Expr) : IO Format :=
7
+ ctx.runMetaM lctx (do Meta.ppExpr (← instantiateMVars e))
8
+
9
+ end Lean.Elab.ContextInfo
REPL/Lean/Environment.lean ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import REPL.Util.Pickle
2
+ import Lean.Replay
3
+
4
+ open System (FilePath)
5
+
6
+ namespace Lean.Environment
7
+
8
+ /--
9
+ Pickle an `Environment` to disk.
10
+
11
+ We only store:
12
+ * the list of imports
13
+ * the new constants from `Environment.constants`
14
+ and when unpickling, we build a fresh `Environment` from the imports,
15
+ and then add the new constants.
16
+ -/
17
+ def pickle (env : Environment) (path : FilePath) : IO Unit :=
18
+ _root_.pickle path (env.header.imports, env.constants.map₂)
19
+
20
+ /--
21
+ Unpickle an `Environment` from disk.
22
+
23
+ We construct a fresh `Environment` with the relevant imports,
24
+ and then replace the new constants.
25
+ -/
26
+ def unpickle (path : FilePath) : IO (Environment × CompactedRegion) := unsafe do
27
+ let ((imports, map₂), region) ← _root_.unpickle (Array Import × PHashMap Name ConstantInfo) path
28
+ let env ← importModules imports {} 0
29
+ return (← env.replay (HashMap.ofList map₂.toList), region)
30
+
31
+ end Lean.Environment
REPL/Lean/InfoTree.lean ADDED
@@ -0,0 +1,272 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /-
2
+ Copyright (c) 2023 Scott Morrison. All rights reserved.
3
+ Released under Apache 2.0 license as described in the file LICENSE.
4
+ Authors: Scott Morrison
5
+ -/
6
+ import Lean
7
+
8
+ /-!
9
+ Additional functions to deal with `InfoTree`.
10
+ -/
11
+
12
+ open Lean Elab Meta
13
+
14
+ namespace Lean.FileMap
15
+
16
+ /-- Extract the range of a `Syntax` expressed as lines and columns. -/
17
+ -- Extracted from the private declaration `Lean.Elab.formatStxRange`,
18
+ -- in `Lean.Elab.InfoTree.Main`.
19
+ def stxRange (fileMap : FileMap) (stx : Syntax) : Position × Position :=
20
+ let pos := stx.getPos?.getD 0
21
+ let endPos := stx.getTailPos?.getD pos
22
+ (fileMap.toPosition pos, fileMap.toPosition endPos)
23
+
24
+ end Lean.FileMap
25
+
26
+ namespace Lean.Syntax
27
+
28
+ /-- Check if a `Syntax` is an explicit invocation of the `sorry` tactic. -/
29
+ def isSorryTactic (stx : Syntax) : Bool :=
30
+ s!"{stx}" = "(Tactic.tacticSorry \"sorry\")"
31
+
32
+ /-- Check if a `Syntax` is an explicit `sorry` term. -/
33
+ def isSorryTerm (stx : Syntax) : Bool :=
34
+ s!"{stx}" = "(Term.sorry \"sorry\")"
35
+
36
+ end Lean.Syntax
37
+
38
+ namespace Lean.Elab
39
+
40
+ /-- Extract the range of a `Syntax` expressed as lines and columns. -/
41
+ -- Extracted from the private declaration `Lean.Elab.formatStxRange`,
42
+ -- in `Lean.Elab.InfoTree.Main`.
43
+ def stxRange (fileMap : FileMap) (stx : Syntax) : Position × Position :=
44
+ let pos := stx.getPos?.getD 0
45
+ let endPos := stx.getTailPos?.getD pos
46
+ (fileMap.toPosition pos, fileMap.toPosition endPos)
47
+
48
+ end Lean.Elab
49
+
50
+ namespace Lean.Elab.Info
51
+
52
+ /-- The type of a `Lean.Elab.Info`, as a string. -/
53
+ def kind : Info → String
54
+ | .ofTacticInfo _ => "TacticInfo"
55
+ | .ofTermInfo _ => "TermInfo"
56
+ | .ofCommandInfo _ => "CommmandInfo"
57
+ | .ofMacroExpansionInfo _ => "MacroExpansionInfo"
58
+ | .ofOptionInfo _ => "OptionInfo"
59
+ | .ofFieldInfo _ => "FieldInfo"
60
+ | .ofCompletionInfo _ => "CompletionInfo"
61
+ | .ofUserWidgetInfo _ => "UserWidgetInfo"
62
+ | .ofCustomInfo _ => "CustomInfo"
63
+ | .ofFVarAliasInfo _ => "FVarAliasInfo"
64
+ | .ofFieldRedeclInfo _ => "FieldRedeclInfo"
65
+ | .ofOmissionInfo _ => "OmissionInfo"
66
+
67
+ /-- The `Syntax` for a `Lean.Elab.Info`, if there is one. -/
68
+ def stx? : Info → Option Syntax
69
+ | .ofTacticInfo info => info.stx
70
+ | .ofTermInfo info => info.stx
71
+ | .ofCommandInfo info => info.stx
72
+ | .ofMacroExpansionInfo info => info.stx
73
+ | .ofOptionInfo info => info.stx
74
+ | .ofFieldInfo info => info.stx
75
+ | .ofCompletionInfo info => info.stx
76
+ | .ofUserWidgetInfo info => info.stx
77
+ | .ofCustomInfo info => info.stx
78
+ | .ofFVarAliasInfo _ => none
79
+ | .ofFieldRedeclInfo info => info.stx
80
+ | .ofOmissionInfo info => info.stx
81
+
82
+ /-- Is the `Syntax` for this `Lean.Elab.Info` original, or synthetic? -/
83
+ def isOriginal (i : Info) : Bool :=
84
+ match i.stx? with
85
+ | none => true -- Somewhat unclear what to do with `FVarAliasInfo`, so be conservative.
86
+ | some stx => match stx.getHeadInfo with
87
+ | .original .. => true
88
+ | _ => false
89
+
90
+ end Lean.Elab.Info
91
+ namespace Lean.Elab.TacticInfo
92
+
93
+ /-- Find the name for the outermost `Syntax` in this `TacticInfo`. -/
94
+ def name? (t : TacticInfo) : Option Name :=
95
+ match t.stx with
96
+ | Syntax.node _ n _ => some n
97
+ | _ => none
98
+
99
+ /-- Decide whether a tactic is "substantive",
100
+ or is merely a tactic combinator (e.g. `by`, `;`, multiline tactics, parenthesized tactics). -/
101
+ def isSubstantive (t : TacticInfo) : Bool :=
102
+ match t.name? with
103
+ | none => false
104
+ | some `null => false
105
+ | some ``cdot => false
106
+ | some ``cdotTk => false
107
+ | some ``Lean.Parser.Term.byTactic => false
108
+ | some ``Lean.Parser.Tactic.tacticSeq => false
109
+ | some ``Lean.Parser.Tactic.tacticSeq1Indented => false
110
+ | some ``Lean.Parser.Tactic.«tactic_<;>_» => false
111
+ | some ``Lean.Parser.Tactic.paren => false
112
+ | _ => true
113
+
114
+ end Lean.Elab.TacticInfo
115
+
116
+ namespace Lean.Elab.InfoTree
117
+
118
+ /--
119
+ Keep `.node` nodes and `.hole` nodes satisfying predicates.
120
+
121
+ Returns a `List InfoTree`, although in most situations this will be a singleton.
122
+ -/
123
+ partial def filter (p : Info → Bool) (m : MVarId → Bool := fun _ => false) :
124
+ InfoTree → List InfoTree
125
+ | .context ctx tree => tree.filter p m |>.map (.context ctx)
126
+ | .node info children =>
127
+ if p info then
128
+ [.node info (children.toList.map (filter p m)).join.toPArray']
129
+ else
130
+ (children.toList.map (filter p m)).join
131
+ | .hole mvar => if m mvar then [.hole mvar] else []
132
+
133
+ /-- Discard all nodes besides `.context` nodes and `TacticInfo` nodes. -/
134
+ partial def retainTacticInfo (tree : InfoTree) : List InfoTree :=
135
+ tree.filter fun | .ofTacticInfo _ => true | _ => false
136
+
137
+ /-- Retain only nodes with "original" syntax. -/
138
+ partial def retainOriginal (tree : InfoTree) : List InfoTree :=
139
+ tree.filter Info.isOriginal
140
+
141
+ /-- Discard all TacticInfo nodes that are tactic combinators or structuring tactics. -/
142
+ -- There is considerable grey area here: what to do with `classical`?
143
+ partial def retainSubstantive (tree : InfoTree) : List InfoTree :=
144
+ tree.filter fun | .ofTacticInfo i => i.isSubstantive | _ => true
145
+
146
+ /-- Analogue of `Lean.Elab.InfoTree.findInfo?`, but that returns all results. -/
147
+ partial def findAllInfo (t : InfoTree) (ctx? : Option ContextInfo) (p : Info → Bool) :
148
+ List (Info × Option ContextInfo) :=
149
+ match t with
150
+ | context ctx t => t.findAllInfo (ctx.mergeIntoOuter? ctx?) p
151
+ | node i ts =>
152
+ let info := if p i then [(i, ctx?)] else []
153
+ let rest := ts.toList.bind (fun t => t.findAllInfo ctx? p)
154
+ info ++ rest
155
+ | _ => []
156
+
157
+ /-- Return all `TacticInfo` nodes in an `InfoTree` with "original" syntax,
158
+ each equipped with its relevant `ContextInfo`. -/
159
+ def findTacticNodes (t : InfoTree) : List (TacticInfo × ContextInfo) :=
160
+ let infos := t.findAllInfo none fun i => match i with
161
+ | .ofTacticInfo i' => i.isOriginal && i'.isSubstantive
162
+ | _ => false
163
+ infos.filterMap fun p => match p with
164
+ | (.ofTacticInfo i, some ctx) => (i, ctx)
165
+ | _ => none
166
+
167
+ /-- Return all `TacticInfo` nodes in an `InfoTree`
168
+ corresponding to explicit invocations of the `sorry` tactic,
169
+ each equipped with its relevant `ContextInfo`. -/
170
+ def findSorryTacticNodes (t : InfoTree) : List (TacticInfo × ContextInfo) :=
171
+ let infos := t.findAllInfo none fun i => match i with
172
+ | .ofTacticInfo i => i.stx.isSorryTactic && !i.goalsBefore.isEmpty
173
+ | _ => false
174
+ infos.filterMap fun p => match p with
175
+ | (.ofTacticInfo i, some ctx) => (i, ctx)
176
+ | _ => none
177
+
178
+ /-- Return all `TermInfo` nodes in an `InfoTree`
179
+ corresponding to explicit `sorry` terms,
180
+ each equipped with its relevant `ContextInfo`. -/
181
+ def findSorryTermNodes (t : InfoTree) : List (TermInfo × ContextInfo) :=
182
+ let infos := t.findAllInfo none fun i => match i with
183
+ | .ofTermInfo i => i.stx.isSorryTerm
184
+ | _ => false
185
+ infos.filterMap fun p => match p with
186
+ | (.ofTermInfo i, some ctx) => (i, ctx)
187
+ | _ => none
188
+
189
+ inductive SorryType
190
+ | tactic : MVarId → SorryType
191
+ | term : LocalContext → Option Expr → SorryType
192
+ deriving Inhabited
193
+
194
+ /--
195
+ Finds all appearances of `sorry` in an `InfoTree`, reporting
196
+ * the `ContextInfo` at that point,
197
+ * the `MVarId` for a goal that was closed by `sorry`,
198
+ or the `Option Expr` expected type for a term supplied by `sorry`
199
+ * and the start and end positions of the `sorry` in the file.
200
+ -/
201
+ def sorries (t : InfoTree) : List (ContextInfo × SorryType × Position × Position) :=
202
+ (t.findSorryTacticNodes.map fun ⟨i, ctx⟩ =>
203
+ -- HACK: creating a child ngen
204
+ ({ ctx with mctx := i.mctxBefore, ngen := ctx.ngen.mkChild.1 }, .tactic i.goalsBefore.head!,
205
+ stxRange ctx.fileMap i.stx)) ++
206
+ (t.findSorryTermNodes.map fun ⟨i, ctx⟩ =>
207
+ (ctx, .term i.lctx i.expectedType?, stxRange ctx.fileMap i.stx))
208
+
209
+ def tactics (t : InfoTree) : List (ContextInfo × Syntax × List MVarId × Position × Position) :=
210
+ (t.findTacticNodes.map fun ⟨i, ctx⟩ =>
211
+ -- HACK: creating a child ngen
212
+ ({ ctx with mctx := i.mctxBefore, ngen := ctx.ngen.mkChild.1 }, i.stx, i.goalsBefore,
213
+ stxRange ctx.fileMap i.stx))
214
+
215
+
216
+ end Lean.Elab.InfoTree
217
+
218
+ namespace Lean.Elab.TacticInfo
219
+
220
+ /-- Return the range of the tactic, as a pair of file positions. -/
221
+ def range (info : TacticInfo) (ctx : ContextInfo) : Position × Position := ctx.fileMap.stxRange info.stx
222
+
223
+ /-- Pretty print a tactic. -/
224
+ def pp (info : TacticInfo) (ctx : ContextInfo) : IO Format :=
225
+ ctx.runMetaM {} try
226
+ Lean.PrettyPrinter.ppTactic ⟨info.stx⟩
227
+ catch _ =>
228
+ pure "<failed to pretty print>"
229
+
230
+ open Meta
231
+
232
+ /-- Run a tactic on the goals stored in a `TacticInfo`. -/
233
+ def runMetaMGoalsBefore (info : TacticInfo) (ctx : ContextInfo) (x : List MVarId → MetaM α) : IO α := do
234
+ ctx.runMetaM {} <| Meta.withMCtx info.mctxBefore <| x info.goalsBefore
235
+
236
+ /-- Run a tactic on the after goals stored in a `TacticInfo`. -/
237
+ def runMetaMGoalsAfter (info : TacticInfo) (ctx : ContextInfo) (x : List MVarId → MetaM α) : IO α := do
238
+ ctx.runMetaM {} <| Meta.withMCtx info.mctxAfter <| x info.goalsAfter
239
+
240
+ /-- Run a tactic on the main goal stored in a `TacticInfo`. -/
241
+ def runMetaM (info : TacticInfo) (ctx : ContextInfo) (x : MVarId → MetaM α) : IO α := do
242
+ match info.goalsBefore.head? with
243
+ | none => throw <| IO.userError s!"No goals at {← info.pp ctx}"
244
+ | some g => info.runMetaMGoalsBefore ctx fun _ => do g.withContext <| x g
245
+
246
+ def mainGoal (info : TacticInfo) (ctx : ContextInfo) : IO Expr :=
247
+ info.runMetaM ctx (fun g => do instantiateMVars (← g.getType))
248
+
249
+ def formatMainGoal (info : TacticInfo) (ctx : ContextInfo) : IO Format :=
250
+ info.runMetaM ctx (fun g => do ppExpr (← instantiateMVars (← g.getType)))
251
+
252
+ def goalState (info : TacticInfo) (ctx : ContextInfo) : IO (List Format) := do
253
+ info.runMetaMGoalsBefore ctx (fun gs => gs.mapM fun g => do Meta.ppGoal g)
254
+
255
+ def goalStateAfter (info : TacticInfo) (ctx : ContextInfo) : IO (List Format) := do
256
+ info.runMetaMGoalsAfter ctx (fun gs => gs.mapM fun g => do Meta.ppGoal g)
257
+
258
+ def ppExpr (info : TacticInfo) (ctx : ContextInfo) (e : Expr) : IO Format :=
259
+ info.runMetaM ctx (fun _ => do Meta.ppExpr (← instantiateMVars e))
260
+
261
+ end Lean.Elab.TacticInfo
262
+
263
+ namespace Lean.Elab.InfoTree
264
+
265
+ /--
266
+ Finds all tactic invocations in an `InfoTree`,
267
+ ignoring structuring tactics (e.g. `by`, `;`, multiline tactics, parenthesized tactics).
268
+ -/
269
+ def substantiveTactics (t : InfoTree) : List (TacticInfo × ContextInfo) :=
270
+ t.findTacticNodes.filter fun i => i.1.isSubstantive
271
+
272
+ end Lean.Elab.InfoTree
REPL/Lean/InfoTree/ToJson.lean ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import REPL.Lean.InfoTree
2
+ import REPL.Lean.ContextInfo
3
+
4
+ /-!
5
+ # Exporting an `InfoTree` as Json
6
+
7
+ -/
8
+
9
+ namespace Lean.Elab
10
+
11
+ structure InfoTreeNode (α : Type) where
12
+ kind : String
13
+ node : Option α
14
+ children : List Json
15
+ deriving ToJson
16
+
17
+ deriving instance ToJson for Lean.Position
18
+
19
+ structure Syntax.Range where
20
+ synthetic : Bool
21
+ start : Lean.Position
22
+ finish : Lean.Position
23
+ deriving ToJson
24
+
25
+ structure Syntax.Json where
26
+ pp : Option String
27
+ -- raw : String
28
+ range : Range
29
+ deriving ToJson
30
+
31
+ def _root_.Lean.Syntax.toRange (stx : Syntax) (ctx : ContextInfo) : Syntax.Range :=
32
+ let pos := stx.getPos?.getD 0
33
+ let endPos := stx.getTailPos?.getD pos
34
+ { start := ctx.fileMap.toPosition pos
35
+ finish := ctx.fileMap.toPosition endPos
36
+ synthetic := match stx.getHeadInfo with
37
+ | .original .. => false
38
+ | _ => true }
39
+
40
+ def _root_.Lean.Syntax.toJson (stx : Syntax) (ctx : ContextInfo) (lctx : LocalContext) : IO Syntax.Json := do
41
+ return {
42
+ pp := match (← ctx.ppSyntax lctx stx).pretty with
43
+ | "failed to pretty print term (use 'set_option pp.rawOnError true' for raw representation)" => none
44
+ | pp => some pp
45
+ -- raw := toString stx
46
+ range := stx.toRange ctx }
47
+
48
+ structure TacticInfo.Json where
49
+ name : Option Name
50
+ stx : Syntax.Json
51
+ goalsBefore : List String
52
+ goalsAfter : List String
53
+ deriving ToJson
54
+
55
+ -- Note: this is not responsible for converting the children to Json.
56
+ def TacticInfo.toJson (i : TacticInfo) (ctx : ContextInfo) : IO TacticInfo.Json := do
57
+ return {
58
+ name := i.name?
59
+ stx :=
60
+ { pp := Format.pretty (← i.pp ctx),
61
+ -- raw := toString i.info.stx,
62
+ range := i.stx.toRange ctx },
63
+ goalsBefore := (← i.goalState ctx).map Format.pretty,
64
+ goalsAfter := (← i.goalStateAfter ctx).map Format.pretty }
65
+
66
+ structure CommandInfo.Json where
67
+ elaborator : Option Name
68
+ stx : Syntax.Json
69
+ deriving ToJson
70
+
71
+ def CommandInfo.toJson (info : CommandInfo) (ctx : ContextInfo) : IO CommandInfo.Json := do
72
+ return {
73
+ elaborator := match info.elaborator with | .anonymous => none | n => some n,
74
+ stx := ← info.stx.toJson ctx {} }
75
+
76
+ structure TermInfo.Json where
77
+ elaborator : Option Name
78
+ stx : Syntax.Json
79
+ expectedType? : Option String
80
+ expr : String
81
+ isBinder : Bool
82
+ deriving ToJson
83
+
84
+ def TermInfo.toJson (info : TermInfo) (ctx : ContextInfo) : IO TermInfo.Json := do
85
+ return {
86
+ elaborator := match info.elaborator with | .anonymous => none | n => some n,
87
+ stx := ← info.stx.toJson ctx info.lctx,
88
+ expectedType? := ← info.expectedType?.mapM fun ty => do
89
+ pure (← ctx.ppExpr info.lctx ty).pretty
90
+ expr := (← ctx.ppExpr info.lctx info.expr).pretty
91
+ isBinder := info.isBinder }
92
+
93
+ structure InfoTree.HoleJson where
94
+ goalState : String
95
+ deriving ToJson
96
+
97
+ partial def InfoTree.toJson (t : InfoTree) (ctx? : Option ContextInfo) : IO Json := do
98
+ match t with
99
+ | .context ctx t => t.toJson (ctx.mergeIntoOuter? ctx?)
100
+ | .node info children =>
101
+ if let some ctx := ctx? then
102
+ let node : Option Json ← match info with
103
+ | .ofTermInfo info => some <$> (do pure <| Lean.toJson (← info.toJson ctx))
104
+ | .ofCommandInfo info => some <$> (do pure <| Lean.toJson (← info.toJson ctx))
105
+ | .ofTacticInfo info => some <$> (do pure <| Lean.toJson (← info.toJson ctx))
106
+ | _ => pure none
107
+ return Lean.toJson (InfoTreeNode.mk info.kind node (← children.toList.mapM fun t' => t'.toJson ctx))
108
+ else throw <| IO.userError "No `ContextInfo` available."
109
+ | .hole mvarId =>
110
+ if let some ctx := ctx? then
111
+ return Lean.toJson (InfoTree.HoleJson.mk (← ctx.runMetaM {} (do Meta.ppGoal mvarId)).pretty)
112
+ else throw <| IO.userError "No `ContextInfo` available."
113
+
114
+ end Lean.Elab
REPL/Main.lean ADDED
@@ -0,0 +1,323 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /-
2
+ Copyright (c) 2023 Scott Morrison. All rights reserved.
3
+ Released under Apache 2.0 license as described in the file LICENSE.
4
+ Authors: Scott Morrison
5
+ -/
6
+ import REPL.JSON
7
+ import REPL.Frontend
8
+ import REPL.Util.Path
9
+ import REPL.Lean.ContextInfo
10
+ import REPL.Lean.Environment
11
+ import REPL.Lean.InfoTree
12
+ import REPL.Lean.InfoTree.ToJson
13
+ import REPL.Snapshots
14
+
15
+ /-!
16
+ # A REPL for Lean.
17
+
18
+ Communicates via JSON on stdin and stdout. Commands should be separated by blank lines.
19
+
20
+ Commands may be of the form
21
+ ```
22
+ { "cmd" : "import Mathlib.Data.List.Basic\ndef f := 2" }
23
+ ```
24
+ or
25
+ ```
26
+ { "cmd" : "example : f = 2 := rfl", "env" : 3 }
27
+ ```
28
+
29
+ The `env` field, if present,
30
+ must contain a number received in the `env` field of a previous response,
31
+ and causes the command to be run in the existing environment.
32
+
33
+ If there is no `env` field, a new environment is created.
34
+
35
+ You can only use `import` commands when you do not specify the `env` field.
36
+
37
+ You can backtrack simply by using earlier values for `env`.
38
+
39
+ The results are of the form
40
+ ```
41
+ {"sorries":
42
+ [{"pos": {"line": 1, "column": 18},
43
+ "endPos": {"line": 1, "column": 23},
44
+ "goal": "\n⊢ Nat"}],
45
+ "messages":
46
+ [{"severity": "error",
47
+ "pos": {"line": 1, "column": 23},
48
+ "endPos": {"line": 1, "column": 26},
49
+ "data":
50
+ "type mismatch\n rfl\nhas type\n f = f : Prop\nbut is expected to have type\n f = 2 : Prop"}],
51
+ "env": 6}
52
+ ```
53
+ showing any messages generated, or sorries with their goal states.
54
+ Information is generated for tactic mode sorries, but not for term mode sorries.
55
+ -/
56
+
57
+ open Lean Elab
58
+
59
+ namespace REPL
60
+
61
+ /-- The monadic state for the Lean REPL. -/
62
+ structure State where
63
+ /--
64
+ Environment snapshots after complete declarations.
65
+ The user can run a declaration in a given environment using `{"cmd": "def f := 37", "env": 17}`.
66
+ -/
67
+ cmdStates : Array CommandSnapshot := #[]
68
+ /--
69
+ Proof states after individual tactics.
70
+ The user can run a tactic in a given proof state using `{"tactic": "exact 42", "proofState": 5}`.
71
+ Declarations with containing `sorry` record a proof state at each sorry,
72
+ and report the numerical index for the recorded state at each sorry.
73
+ -/
74
+ proofStates : Array ProofSnapshot := #[]
75
+
76
+ /--
77
+ The Lean REPL monad.
78
+
79
+ We only use this with `m := IO`, but it is set up as a monad transformer for flexibility.
80
+ -/
81
+ abbrev M (m : Type → Type) := StateT State m
82
+
83
+ variable [Monad m] [MonadLiftT IO m]
84
+
85
+ /-- Record an `CommandSnapshot` into the REPL state, returning its index for future use. -/
86
+ def recordCommandSnapshot (state : CommandSnapshot) : M m Nat := do
87
+ let id := (← get).cmdStates.size
88
+ modify fun s => { s with cmdStates := s.cmdStates.push state }
89
+ return id
90
+
91
+ /-- Record a `ProofSnapshot` into the REPL state, returning its index for future use. -/
92
+ def recordProofSnapshot (proofState : ProofSnapshot) : M m Nat := do
93
+ let id := (← get).proofStates.size
94
+ modify fun s => { s with proofStates := s.proofStates.push proofState }
95
+ return id
96
+
97
+ def sorries (trees : List InfoTree) (env? : Option Environment) : M m (List Sorry) :=
98
+ trees.bind InfoTree.sorries |>.mapM
99
+ fun ⟨ctx, g, pos, endPos⟩ => do
100
+ let (goal, proofState) ← match g with
101
+ | .tactic g => do
102
+ let s ← ProofSnapshot.create ctx none env? [g]
103
+ pure ("\n".intercalate <| (← s.ppGoals).map fun s => s!"{s}", some s)
104
+ | .term lctx (some t) => do
105
+ let s ← ProofSnapshot.create ctx lctx env? [] [t]
106
+ pure ("\n".intercalate <| (← s.ppGoals).map fun s => s!"{s}", some s)
107
+ | .term _ none => unreachable!
108
+ let proofStateId ← proofState.mapM recordProofSnapshot
109
+ return Sorry.of goal pos endPos proofStateId
110
+
111
+ def ppTactic (ctx : ContextInfo) (stx : Syntax) : IO Format :=
112
+ ctx.runMetaM {} try
113
+ Lean.PrettyPrinter.ppTactic ⟨stx⟩
114
+ catch _ =>
115
+ pure "<failed to pretty print>"
116
+
117
+ def tactics (trees : List InfoTree) : M m (List Tactic) :=
118
+ trees.bind InfoTree.tactics |>.mapM
119
+ fun ⟨ctx, stx, goals, pos, endPos⟩ => do
120
+ let proofState := some (← ProofSnapshot.create ctx none none goals)
121
+ let goals := s!"{(← ctx.ppGoals goals)}".trim
122
+ let tactic := Format.pretty (← ppTactic ctx stx)
123
+ let proofStateId ← proofState.mapM recordProofSnapshot
124
+ return Tactic.of goals tactic pos endPos proofStateId
125
+
126
+ /-- Record a `ProofSnapshot` and generate a JSON response for it. -/
127
+ def createProofStepReponse (proofState : ProofSnapshot) (old? : Option ProofSnapshot := none) :
128
+ M m ProofStepResponse := do
129
+ let messages := proofState.newMessages old?
130
+ let messages ← messages.mapM fun m => Message.of m
131
+ let traces ← proofState.newTraces old?
132
+ let trees := proofState.newInfoTrees old?
133
+ let trees ← match old? with
134
+ | some old => do
135
+ let (ctx, _) ← old.runMetaM do return { ← CommandContextInfo.save with }
136
+ let ctx := PartialContextInfo.commandCtx ctx
137
+ pure <| trees.map fun t => InfoTree.context ctx t
138
+ | none => pure trees
139
+ -- For debugging purposes, sometimes we print out the trees here:
140
+ -- trees.forM fun t => do IO.println (← t.format)
141
+ let sorries ← sorries trees none
142
+ let id ← recordProofSnapshot proofState
143
+ return {
144
+ proofState := id
145
+ goals := (← proofState.ppGoals).map fun s => s!"{s}"
146
+ messages
147
+ sorries
148
+ traces }
149
+
150
+ /-- Pickle a `CommandSnapshot`, generating a JSON response. -/
151
+ def pickleCommandSnapshot (n : PickleEnvironment) : M m (CommandResponse ⊕ Error) := do
152
+ match (← get).cmdStates[n.env]? with
153
+ | none => return .inr ⟨"Unknown environment."⟩
154
+ | some env =>
155
+ discard <| env.pickle n.pickleTo
156
+ return .inl { env := n.env }
157
+
158
+ /-- Unpickle a `CommandSnapshot`, generating a JSON response. -/
159
+ def unpickleCommandSnapshot (n : UnpickleEnvironment) : M IO CommandResponse := do
160
+ let (env, _) ← CommandSnapshot.unpickle n.unpickleEnvFrom
161
+ let env ← recordCommandSnapshot env
162
+ return { env }
163
+
164
+ /-- Pickle a `ProofSnapshot`, generating a JSON response. -/
165
+ -- This generates a new identifier, which perhaps is not what we want?
166
+ def pickleProofSnapshot (n : PickleProofState) : M m (ProofStepResponse ⊕ Error) := do
167
+ match (← get).proofStates[n.proofState]? with
168
+ | none => return .inr ⟨"Unknown proof State."⟩
169
+ | some proofState =>
170
+ discard <| proofState.pickle n.pickleTo
171
+ return .inl (← createProofStepReponse proofState)
172
+
173
+ /-- Unpickle a `ProofSnapshot`, generating a JSON response. -/
174
+ def unpickleProofSnapshot (n : UnpickleProofState) : M IO (ProofStepResponse ⊕ Error) := do
175
+ let (cmdSnapshot?, notFound) ← do match n.env with
176
+ | none => pure (none, false)
177
+ | some i => do match (← get).cmdStates[i]? with
178
+ | some env => pure (some env, false)
179
+ | none => pure (none, true)
180
+ if notFound then
181
+ return .inr ⟨"Unknown environment."⟩
182
+ let (proofState, _) ← ProofSnapshot.unpickle n.unpickleProofStateFrom cmdSnapshot?
183
+ Sum.inl <$> createProofStepReponse proofState
184
+
185
+ /--
186
+ Run a command, returning the id of the new environment, and any messages and sorries.
187
+ -/
188
+ def runCommand (s : Command) : M IO (CommandResponse ⊕ Error) := do
189
+ let (cmdSnapshot?, notFound) ← do match s.env with
190
+ | none => pure (none, false)
191
+ | some i => do match (← get).cmdStates[i]? with
192
+ | some env => pure (some env, false)
193
+ | none => pure (none, true)
194
+ if notFound then
195
+ return .inr ⟨"Unknown environment."⟩
196
+ let initialCmdState? := cmdSnapshot?.map fun c => c.cmdState
197
+ let (cmdState, messages, trees) ← try
198
+ IO.processInput s.cmd initialCmdState?
199
+ catch ex =>
200
+ return .inr ⟨ex.toString⟩
201
+ let messages ← messages.mapM fun m => Message.of m
202
+ -- For debugging purposes, sometimes we print out the trees here:
203
+ -- trees.forM fun t => do IO.println (← t.format)
204
+ let sorries ← sorries trees (initialCmdState?.map (·.env))
205
+ let tactics ← match s.allTactics with
206
+ | some true => tactics trees
207
+ | _ => pure []
208
+ let cmdSnapshot :=
209
+ { cmdState
210
+ cmdContext := (cmdSnapshot?.map fun c => c.cmdContext).getD
211
+ { fileName := "", fileMap := default, tacticCache? := none } }
212
+ let env ← recordCommandSnapshot cmdSnapshot
213
+ let jsonTrees := match s.infotree with
214
+ | some "full" => trees
215
+ | some "tactics" => trees.bind InfoTree.retainTacticInfo
216
+ | some "original" => trees.bind InfoTree.retainTacticInfo |>.bind InfoTree.retainOriginal
217
+ | some "substantive" => trees.bind InfoTree.retainTacticInfo |>.bind InfoTree.retainSubstantive
218
+ | _ => []
219
+ let infotree := if jsonTrees.isEmpty then
220
+ none
221
+ else
222
+ some <| Json.arr (← jsonTrees.toArray.mapM fun t => t.toJson none)
223
+ return .inl
224
+ { env,
225
+ messages,
226
+ sorries,
227
+ tactics
228
+ infotree }
229
+
230
+ def processFile (s : File) : M IO (CommandResponse ⊕ Error) := do
231
+ try
232
+ let cmd ← IO.FS.readFile s.path
233
+ runCommand { s with env := none, cmd }
234
+ catch e =>
235
+ pure <| .inr ⟨e.toString⟩
236
+
237
+ /--
238
+ Run a single tactic, returning the id of the new proof statement, and the new goals.
239
+ -/
240
+ -- TODO detect sorries?
241
+ def runProofStep (s : ProofStep) : M IO (ProofStepResponse ⊕ Error) := do
242
+ match (← get).proofStates[s.proofState]? with
243
+ | none => return .inr ⟨"Unknown proof state."⟩
244
+ | some proofState =>
245
+ try
246
+ let proofState' ← proofState.runString s.tactic
247
+ return .inl (← createProofStepReponse proofState' proofState)
248
+ catch ex =>
249
+ return .inr ⟨"Lean error:\n" ++ ex.toString⟩
250
+
251
+ end REPL
252
+
253
+ open REPL
254
+
255
+ /-- Get lines from stdin until a blank line is entered. -/
256
+ partial def getLines : IO String := do
257
+ let line ← (← IO.getStdin).getLine
258
+ if line.trim.isEmpty then
259
+ return line
260
+ else
261
+ return line ++ (← getLines)
262
+
263
+ instance [ToJson α] [ToJson β] : ToJson (α ⊕ β) where
264
+ toJson x := match x with
265
+ | .inl a => toJson a
266
+ | .inr b => toJson b
267
+
268
+ /-- Commands accepted by the REPL. -/
269
+ inductive Input
270
+ | command : REPL.Command → Input
271
+ | file : REPL.File → Input
272
+ | proofStep : REPL.ProofStep → Input
273
+ | pickleEnvironment : REPL.PickleEnvironment → Input
274
+ | unpickleEnvironment : REPL.UnpickleEnvironment → Input
275
+ | pickleProofSnapshot : REPL.PickleProofState → Input
276
+ | unpickleProofSnapshot : REPL.UnpickleProofState → Input
277
+
278
+ /-- Parse a user input string to an input command. -/
279
+ def parse (query : String) : IO Input := do
280
+ let json := Json.parse query
281
+ match json with
282
+ | .error e => throw <| IO.userError <| toString <| toJson <|
283
+ (⟨"Could not parse JSON:\n" ++ e⟩ : Error)
284
+ | .ok j => match fromJson? j with
285
+ | .ok (r : REPL.ProofStep) => return .proofStep r
286
+ | .error _ => match fromJson? j with
287
+ | .ok (r : REPL.PickleEnvironment) => return .pickleEnvironment r
288
+ | .error _ => match fromJson? j with
289
+ | .ok (r : REPL.UnpickleEnvironment) => return .unpickleEnvironment r
290
+ | .error _ => match fromJson? j with
291
+ | .ok (r : REPL.PickleProofState) => return .pickleProofSnapshot r
292
+ | .error _ => match fromJson? j with
293
+ | .ok (r : REPL.UnpickleProofState) => return .unpickleProofSnapshot r
294
+ | .error _ => match fromJson? j with
295
+ | .ok (r : REPL.Command) => return .command r
296
+ | .error _ => match fromJson? j with
297
+ | .ok (r : REPL.File) => return .file r
298
+ | .error e => throw <| IO.userError <| toString <| toJson <|
299
+ (⟨"Could not parse as a valid JSON command:\n" ++ e⟩ : Error)
300
+
301
+ /-- Read-eval-print loop for Lean. -/
302
+ unsafe def repl : IO Unit :=
303
+ StateT.run' loop {}
304
+ where loop : M IO Unit := do
305
+ let query ← getLines
306
+ if query = "" then
307
+ return ()
308
+ if query.startsWith "#" || query.startsWith "--" then loop else
309
+ IO.println <| toString <| ← match ← parse query with
310
+ | .command r => return toJson (← runCommand r)
311
+ | .file r => return toJson (← processFile r)
312
+ | .proofStep r => return toJson (← runProofStep r)
313
+ | .pickleEnvironment r => return toJson (← pickleCommandSnapshot r)
314
+ | .unpickleEnvironment r => return toJson (← unpickleCommandSnapshot r)
315
+ | .pickleProofSnapshot r => return toJson (← pickleProofSnapshot r)
316
+ | .unpickleProofSnapshot r => return toJson (← unpickleProofSnapshot r)
317
+ IO.println "" -- easier to parse the output if there are blank lines
318
+ loop
319
+
320
+ /-- Main executable function, run as `lake exe repl`. -/
321
+ unsafe def main (_ : List String) : IO Unit := do
322
+ initSearchPath (← Lean.findSysroot)
323
+ repl
REPL/Snapshots.lean ADDED
@@ -0,0 +1,306 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /-
2
+ Copyright (c) 2023 Lean FRO, LLC. All rights reserved.
3
+ Released under Apache 2.0 license as described in the file LICENSE.
4
+ Authors: Scott Morrison
5
+ -/
6
+ import Lean.Replay
7
+ import Lean.Elab.Command
8
+ import REPL.Util.Pickle
9
+
10
+ open Lean Elab
11
+
12
+ namespace Lean.Elab.Command
13
+
14
+ @[inline] def CommandElabM.run (x : CommandElabM α) (ctx : Context) (s : State) : EIO Exception (α × State) :=
15
+ (x ctx).run s
16
+
17
+ @[inline] def CommandElabM.run' (x : CommandElabM α) (ctx : Context) (s : State) : EIO Exception α :=
18
+ Prod.fst <$> x.run ctx s
19
+
20
+ @[inline] def CommandElabM.toIO (x : CommandElabM α) (ctx : Context) (s : State) : IO (α × State) := do
21
+ match (← (x.run ctx s).toIO') with
22
+ | Except.error (Exception.error _ msg) => throw <| IO.userError (← msg.toString)
23
+ | Except.error (Exception.internal id _) => throw <| IO.userError <| "internal exception #" ++ toString id.idx
24
+ | Except.ok a => return a
25
+
26
+ end Lean.Elab.Command
27
+
28
+ namespace REPL
29
+
30
+ /--
31
+ Bundled structure for the `State` and `Context` objects
32
+ for the `CommandElabM` monad.
33
+ -/
34
+ structure CommandSnapshot where
35
+ cmdState : Command.State
36
+ cmdContext : Command.Context
37
+
38
+ namespace CommandSnapshot
39
+
40
+ open Lean.Elab.Command
41
+
42
+ /-- A copy of `Command.State` with the `Environment`, caches, and logging omitted. -/
43
+ structure CompactableCommandSnapshot where
44
+ -- env : Environment
45
+ scopes : List Scope := [{ header := "" }]
46
+ nextMacroScope : Nat := firstFrontendMacroScope + 1
47
+ maxRecDepth : Nat
48
+ nextInstIdx : Nat := 1 -- for generating anonymous instance names
49
+ ngen : NameGenerator := {}
50
+ -- infoState : InfoState := {}
51
+ -- traceState : TraceState := {}
52
+ -- messages : MessageLog := {}
53
+
54
+ open System (FilePath)
55
+
56
+ /--
57
+ Run a `CommandElabM` monadic function in the current `ProofSnapshot`,
58
+ updating the `Command.State`.
59
+ -/
60
+ def runCommandElabM (p : CommandSnapshot) (t : CommandElabM α) : IO (α × CommandSnapshot) := do
61
+ let (a, cmdState) ← (CommandElabM.toIO · p.cmdContext p.cmdState) do t
62
+ return (a, { p with cmdState })
63
+
64
+
65
+ /--
66
+ Pickle a `CommandSnapshot`, discarding closures and non-essential caches.
67
+
68
+ When pickling the `Environment`, we do so relative to its imports.
69
+ -/
70
+ def pickle (p : CommandSnapshot) (path : FilePath) : IO Unit := do
71
+ let env := p.cmdState.env
72
+ let p' := { p with cmdState := { p.cmdState with env := ← mkEmptyEnvironment }}
73
+ _root_.pickle path
74
+ (env.header.imports,
75
+ env.constants.map₂,
76
+ ({ p'.cmdState with } : CompactableCommandSnapshot),
77
+ p'.cmdContext)
78
+
79
+ /--
80
+ Unpickle a `CommandSnapshot`.
81
+ -/
82
+ def unpickle (path : FilePath) : IO (CommandSnapshot × CompactedRegion) := unsafe do
83
+ let ((imports, map₂, cmdState, cmdContext), region) ←
84
+ _root_.unpickle (Array Import × PHashMap Name ConstantInfo × CompactableCommandSnapshot ×
85
+ Command.Context) path
86
+ let env ← (← importModules imports {} 0).replay (HashMap.ofList map₂.toList)
87
+ let p' : CommandSnapshot :=
88
+ { cmdState := { cmdState with env }
89
+ cmdContext }
90
+ let (_, p'') ← p'.runCommandElabM do
91
+ for o in ← getOpenDecls do
92
+ if let .simple ns _ := o then do
93
+ activateScoped ns
94
+ return (p'', region)
95
+
96
+ end CommandSnapshot
97
+
98
+ /--
99
+ Bundled structure for the `State` and `Context` objects
100
+ for the `CoreM`, `MetaM`, `TermElabM`, and `TacticM` monads.
101
+ -/
102
+ structure ProofSnapshot where
103
+ coreState : Core.State
104
+ coreContext : Core.Context
105
+ metaState : Meta.State
106
+ metaContext : Meta.Context
107
+ termState : Term.State
108
+ termContext : Term.Context
109
+ tacticState : Tactic.State
110
+ tacticContext : Tactic.Context
111
+
112
+ namespace ProofSnapshot
113
+
114
+ open Lean Elab Tactic
115
+
116
+ /-- New messages in a `ProofSnapshot`, relative to an optional previous `ProofSnapshot`. -/
117
+ def newMessages (new : ProofSnapshot) (old? : Option ProofSnapshot := none) : List Lean.Message :=
118
+ match old? with
119
+ | none => new.coreState.messages.msgs.toList
120
+ | some old => new.coreState.messages.msgs.toList.drop (old.coreState.messages.msgs.size)
121
+
122
+ /-- New info trees in a `ProofSnapshot`, relative to an optional previous `ProofSnapshot`. -/
123
+ def newInfoTrees (new : ProofSnapshot) (old? : Option ProofSnapshot := none) : List InfoTree :=
124
+ let infoState := new.coreState.infoState
125
+ let trees := match old? with
126
+ | none => infoState.trees.toList
127
+ | some old => infoState.trees.toList.drop (old.coreState.infoState.trees.size)
128
+ trees.map fun t => t.substitute infoState.assignment
129
+
130
+ /-- Run a `CoreM` monadic function in the current `ProofSnapshot`, updating the `Core.State`. -/
131
+ def runCoreM (p : ProofSnapshot) (t : CoreM α) : IO (α × ProofSnapshot) := do
132
+ let (a, coreState) ← (Lean.Core.CoreM.toIO · p.coreContext p.coreState) do t
133
+ return (a, { p with coreState })
134
+
135
+ /-- Run a `MetaM` monadic function in the current `ProofSnapshot`, updating the `Meta.State`. -/
136
+ def runMetaM (p : ProofSnapshot) (t : MetaM α) : IO (α × ProofSnapshot) := do
137
+ let ((a, metaState), p') ←
138
+ p.runCoreM (Lean.Meta.MetaM.run (ctx := p.metaContext) (s := p.metaState) do t)
139
+ return (a, { p' with metaState })
140
+
141
+ /-- Run a `TermElabM` monadic function in the current `ProofSnapshot`, updating the `Term.State`. -/
142
+ def runTermElabM (p : ProofSnapshot) (t : TermElabM α) : IO (α × ProofSnapshot) := do
143
+ let ((a, termState), p') ← p.runMetaM (Lean.Elab.Term.TermElabM.run (s := p.termState)
144
+ (do let r ← t; Term.synthesizeSyntheticMVarsNoPostponing; pure r))
145
+ return (a, { p' with termState })
146
+
147
+ /-- Run a `TacticM` monadic function in the current `ProofSnapshot`, updating the `Tactic.State`. -/
148
+ def runTacticM (p : ProofSnapshot) (t : TacticM α) : IO (α × ProofSnapshot) := do
149
+ let ((a, tacticState), p') ← p.runTermElabM (t p.tacticContext |>.run p.tacticState)
150
+ return (a, { p' with tacticState })
151
+
152
+ /--
153
+ Run a `TacticM` monadic function in the current `ProofSnapshot`, updating the `Tactic.State`,
154
+ and discarding the return value.
155
+ -/
156
+ def runTacticM' (p : ProofSnapshot) (t : TacticM α) : IO ProofSnapshot :=
157
+ Prod.snd <$> p.runTacticM t
158
+
159
+ /-- New traces in a `ProofSnapshot`, relative to an optional previous `ProofSnapshot`. -/
160
+ def newTraces (new : ProofSnapshot) (old? : Option ProofSnapshot := none) : IO (List String) :=
161
+ match old? with
162
+ | none => (·.1) <$> new.runCoreM (do
163
+ (← getTraces).toList.mapM fun t => do pure (← t.msg.toString).trim)
164
+ | some old => do
165
+ let oldCount ← (·.1) <$> old.runCoreM (return (← getTraces).size)
166
+ (·.1) <$> new.runCoreM (do
167
+ ((← getTraces).toList.drop oldCount).mapM fun t => do pure (← t.msg.toString).trim)
168
+
169
+ /--
170
+ Evaluate a `Syntax` into a `TacticM` tactic, and run it in the current `ProofSnapshot`.
171
+ -/
172
+ def runSyntax (p : ProofSnapshot) (t : Syntax) : IO ProofSnapshot :=
173
+ Prod.snd <$> p.runTacticM (evalTactic t)
174
+
175
+ /--
176
+ Parse a string into a `Syntax`, evaluate it as a `TacticM` tactic,
177
+ and run it in the current `ProofSnapshot`.
178
+ -/
179
+ def runString (p : ProofSnapshot) (t : String) : IO ProofSnapshot :=
180
+ match Parser.runParserCategory p.coreState.env `tactic t with
181
+ | .error e => throw (IO.userError e)
182
+ | .ok stx => p.runSyntax stx
183
+
184
+ /-- Pretty print the current goals in the `ProofSnapshot`. -/
185
+ def ppGoals (p : ProofSnapshot) : IO (List Format) :=
186
+ Prod.fst <$> p.runMetaM do p.tacticState.goals.mapM (Meta.ppGoal ·)
187
+ /--
188
+ Construct a `ProofSnapshot` from a `ContextInfo` and optional `LocalContext`, and a list of goals.
189
+
190
+ For convenience, we also allow a list of `Expr`s, and these are appended to the goals
191
+ as fresh metavariables with the given types.
192
+ -/
193
+ def create (ctx : ContextInfo) (lctx? : Option LocalContext) (env? : Option Environment)
194
+ (goals : List MVarId) (types : List Expr := []) : IO ProofSnapshot := do
195
+ ctx.runMetaM (lctx?.getD {}) do
196
+ let goals := goals ++ (← types.mapM fun t => Expr.mvarId! <$> Meta.mkFreshExprMVar (some t))
197
+ goals.head!.withContext do
198
+ let s ← getThe Core.State
199
+ let s := match env? with
200
+ | none => s
201
+ | some env => { s with env }
202
+ pure <|
203
+ { coreState := s
204
+ coreContext := ← readThe Core.Context
205
+ metaState := ← getThe Meta.State
206
+ metaContext := ← readThe Meta.Context
207
+ termState := {}
208
+ termContext := {}
209
+ tacticState := { goals }
210
+ tacticContext := { elaborator := .anonymous } }
211
+
212
+ open Lean.Core in
213
+ /-- A copy of `Core.State` with the `Environment`, caches, and logging omitted. -/
214
+ structure CompactableCoreState where
215
+ -- env : Environment
216
+ nextMacroScope : MacroScope := firstFrontendMacroScope + 1
217
+ ngen : NameGenerator := {}
218
+ -- traceState : TraceState := {}
219
+ -- cache : Core.Cache := {}
220
+ -- messages : MessageLog := {}
221
+ -- infoState : Elab.InfoState := {}
222
+
223
+ open Lean.Meta in
224
+ /-- A copy of `Meta.Context` with closures omitted. -/
225
+ structure CompactableMetaContext where
226
+ config : Config := {}
227
+ lctx : LocalContext := {}
228
+ localInstances : LocalInstances := #[]
229
+ defEqCtx? : Option DefEqContext := none
230
+ synthPendingDepth : Nat := 0
231
+ -- canUnfold? : Option (Config → ConstantInfo → CoreM Bool) := none
232
+
233
+ /-- A copy of `Term.Context` with closures and a cache omitted. -/
234
+ structure CompactableTermContext where
235
+ declName? : Option Name := none
236
+ auxDeclToFullName : FVarIdMap Name := {}
237
+ macroStack : MacroStack := []
238
+ mayPostpone : Bool := true
239
+ errToSorry : Bool := true
240
+ autoBoundImplicit : Bool := false
241
+ autoBoundImplicits : PArray Expr := {}
242
+ -- autoBoundImplicitForbidden : Name → Bool := fun _ => false
243
+ sectionVars : NameMap Name := {}
244
+ sectionFVars : NameMap Expr := {}
245
+ implicitLambda : Bool := true
246
+ isNoncomputableSection : Bool := false
247
+ ignoreTCFailures : Bool := false
248
+ inPattern : Bool := false
249
+ -- tacticCache? : Option (IO.Ref Tactic.Cache) := none
250
+ saveRecAppSyntax : Bool := true
251
+ holesAsSyntheticOpaque : Bool := false
252
+
253
+ open System (FilePath)
254
+
255
+ /--
256
+ Pickle a `ProofSnapshot`, discarding closures and non-essential caches.
257
+
258
+ When pickling the `Environment`, we do so relative to its imports.
259
+ -/
260
+ def pickle (p : ProofSnapshot) (path : FilePath) : IO Unit := do
261
+ let env := p.coreState.env
262
+ let p' := { p with coreState := { p.coreState with env := ← mkEmptyEnvironment }}
263
+ _root_.pickle path
264
+ (env.header.imports,
265
+ env.constants.map₂,
266
+ ({ p'.coreState with } : CompactableCoreState),
267
+ p'.coreContext,
268
+ p'.metaState,
269
+ ({ p'.metaContext with } : CompactableMetaContext),
270
+ p'.termState,
271
+ ({ p'.termContext with } : CompactableTermContext),
272
+ p'.tacticState,
273
+ p'.tacticContext)
274
+
275
+ /--
276
+ Unpickle a `ProofSnapshot`.
277
+ -/
278
+ def unpickle (path : FilePath) (cmd? : Option CommandSnapshot) :
279
+ IO (ProofSnapshot × CompactedRegion) := unsafe do
280
+ let ((imports, map₂, coreState, coreContext, metaState, metaContext, termState, termContext,
281
+ tacticState, tacticContext), region) ←
282
+ _root_.unpickle (Array Import × PHashMap Name ConstantInfo × CompactableCoreState ×
283
+ Core.Context × Meta.State × CompactableMetaContext × Term.State × CompactableTermContext ×
284
+ Tactic.State × Tactic.Context) path
285
+ let env ← match cmd? with
286
+ | none =>
287
+ enableInitializersExecution
288
+ (← importModules imports {} 0).replay (HashMap.ofList map₂.toList)
289
+ | some cmd =>
290
+ cmd.cmdState.env.replay (HashMap.ofList map₂.toList)
291
+ let p' : ProofSnapshot :=
292
+ { coreState := { coreState with env }
293
+ coreContext
294
+ metaState
295
+ metaContext := { metaContext with }
296
+ termState
297
+ termContext := { termContext with }
298
+ tacticState
299
+ tacticContext }
300
+ let (_, p'') ← p'.runCoreM do
301
+ for o in ← getOpenDecls do
302
+ if let .simple ns _ := o then
303
+ activateScoped ns
304
+ return (p'', region)
305
+
306
+ end ProofSnapshot
REPL/Util/Path.lean ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /-
2
+ Copyright (c) 2022 Gabriel Ebner. All rights reserved.
3
+ Released under Apache 2.0 license as described in the file LICENSE.
4
+ Authors: Gabriel Ebner
5
+ -/
6
+ import Lean
7
+
8
+ -- This has been duplicated from Std4 to avoid a dependency.
9
+
10
+ /-!
11
+ # `compile_time_search_path%` term elaborator.
12
+
13
+ Use this as `searchPathRef.set compile_time_search_path%`.
14
+ -/
15
+
16
+ open Lean System
17
+
18
+ -- Ideally this instance would be constructed simply by `deriving instance ToExpr for FilePath`
19
+ -- but for now we have decided not to upstream the `ToExpr` derive handler from `Mathlib`.
20
+ -- https://leanprover.zulipchat.com/#narrow/stream/348111-std4/topic/ToExpr.20derive.20handler/near/386476438
21
+ instance : ToExpr FilePath where
22
+ toTypeExpr := mkConst ``FilePath
23
+ toExpr path := mkApp (mkConst ``FilePath.mk) (toExpr path.1)
24
+
25
+ /--
26
+ Term elaborator that retrieves the current `SearchPath`.
27
+
28
+ Typical usage is `searchPathRef.set compile_time_search_path%`.
29
+
30
+ This must not be used in files that are potentially compiled on another machine and then
31
+ imported.
32
+ (That is, if used in an imported file it will embed the search path from whichever machine
33
+ compiled the `.olean`.)
34
+ -/
35
+ elab "compile_time_search_path%" : term =>
36
+ return toExpr (← searchPathRef.get)
REPL/Util/Pickle.lean ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /-
2
+ Copyright (c) 2023 Mario Carneiro. All rights reserved.
3
+ Released under Apache 2.0 license as described in the file LICENSE.
4
+ Authors: Mario Carneiro
5
+ -/
6
+ import Lean.Environment
7
+
8
+ /-!
9
+ # Pickling and unpickling objects
10
+
11
+ By abusing `saveModuleData` and `readModuleData` we can pickle and unpickle objects to disk.
12
+ -/
13
+
14
+ open Lean System
15
+
16
+ /--
17
+ Save an object to disk.
18
+ If you need to write multiple objects from within a single declaration,
19
+ you will need to provide a unique `key` for each.
20
+ -/
21
+ def pickle {α : Type} (path : FilePath) (x : α) (key : Name := by exact decl_name%) : IO Unit :=
22
+ saveModuleData path key (unsafe unsafeCast x)
23
+
24
+ /--
25
+ Load an object from disk.
26
+ Note: The returned `CompactedRegion` can be used to free the memory behind the value
27
+ of type `α`, using `CompactedRegion.free` (which is only safe once all references to the `α` are
28
+ released). Ignoring the `CompactedRegion` results in the data being leaked.
29
+ Use `withUnpickle` to call `CompactedRegion.free` automatically.
30
+
31
+ This function is unsafe because the data being loaded may not actually have type `α`, and this
32
+ may cause crashes or other bad behavior.
33
+ -/
34
+ unsafe def unpickle (α : Type) (path : FilePath) : IO (α × CompactedRegion) := do
35
+ let (x, region) ← readModuleData path
36
+ pure (unsafeCast x, region)
37
+
38
+ /-- Load an object from disk and run some continuation on it, freeing memory afterwards. -/
39
+ unsafe def withUnpickle [Monad m] [MonadLiftT IO m] {α β : Type}
40
+ (path : FilePath) (f : α → m β) : m β := do
41
+ let (x, region) ← unpickle α path
42
+ let r ← f x
43
+ region.free
44
+ pure r
__pycache__/code.cpython-310.pyc ADDED
Binary file (1.89 kB). View file
 
__pycache__/code.cpython-39.pyc ADDED
Binary file (1.84 kB). View file
 
__pycache__/openllm_pass_rate_new_test.cpython-39.pyc ADDED
Binary file (7.79 kB). View file
 
all_code.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+
3
+ # from mariana.repl.pass_rate_new import main
4
+
5
+ def handle():
6
+
7
+ data = json.load(open('pass_rate_results/lean4_basic_test/lean4_random_15k_all.jsonl'))
8
+ # data['results'] += json.load(open('pass_rate_results/gsm8k_train/lean4_random_15k_all.jsonl'))['results']
9
+
10
+
11
+ PROMPT_DICT = {
12
+ "wild": (
13
+ "# Problem:\n{question}\n\n"
14
+ "# Proof:\n{answer}."
15
+ ),
16
+ "lean4": (
17
+ "Statement and proof in natural language:\n\n"
18
+ "{statement_text}\n\n"
19
+ "Translate the statement and proof in natural language to lean4:"
20
+ ),
21
+ "prompt_no_input": (
22
+ "Below is an instruction that describes a task. "
23
+ "Write a response that appropriately completes the request.\n\n"
24
+ "### Instruction:\n{instruction}\n\n### Response:"
25
+ ),
26
+ }
27
+ training_data = []
28
+ ratio = []
29
+ for item in data['results']:
30
+ if item['status'] == 'pass':
31
+ if not len(item['stderr']):
32
+ ratio.append(1)
33
+ # training_data.append(
34
+ # {
35
+ # "statement_poof":item['statement'],
36
+ # "model_response":PROMPT_DICT["wild"].format(question= item['content']['question'], answer = item['content']['answer']),
37
+ # "task": "statementproof_inform",
38
+ # }
39
+ # )
40
+ else:
41
+ ratio.append(0)
42
+ print ( item['stderr'])
43
+
44
+
45
+ # with open("pass_rate_results/combined_lean4_random_15k_all_passed.jsonl", "w") as f:
46
+ # json.dump(training_data, f, ensure_ascii=False, indent=2)
47
+
48
+
49
+ print("false positives: ", 1 - sum(ratio)/len(ratio))
50
+
51
+
52
+ def savetojson():
53
+
54
+ import_statements = '''import algebra.algebra.basic
55
+ import algebra.order.floor
56
+ import algebra.associated
57
+ import algebra.big_operators.basic
58
+ import algebra.big_operators.enat
59
+ import algebra.big_operators.order
60
+ import algebra.big_operators.pi
61
+ import algebra.geom_sum
62
+ import algebra.group.pi
63
+ import algebra.group.commute
64
+ import algebra.group_power.basic
65
+ import algebra.group_power.identities
66
+ import algebra.order.floor
67
+ import algebra.quadratic_discriminant
68
+ import algebra.ring.basic
69
+ import analysis.asymptotics.asymptotic_equivalent
70
+ import analysis.mean_inequalities
71
+ import analysis.normed_space.basic
72
+ import analysis.inner_product_space.basic
73
+ import analysis.inner_product_space.euclidean_dist
74
+ import analysis.normed_space.pi_Lp
75
+ import analysis.special_functions.exp
76
+ import analysis.special_functions.exp_deriv
77
+ import analysis.special_functions.log
78
+ import analysis.special_functions.logb
79
+ import analysis.special_functions.log_deriv
80
+ import analysis.special_functions.pow
81
+ import analysis.special_functions.sqrt
82
+ import analysis.special_functions.trigonometric.basic
83
+ import analysis.special_functions.trigonometric.complex
84
+ import combinatorics.simple_graph.basic
85
+ import data.complex.basic
86
+ import data.complex.exponential
87
+ import data.finset.basic
88
+ import data.fintype.card
89
+ import data.int.basic
90
+ import data.int.gcd
91
+ import data.int.modeq
92
+ import data.int.parity
93
+ import data.list.intervals
94
+ import data.list.palindrome
95
+ import data.multiset.basic
96
+ import data.nat.basic
97
+ import data.nat.choose.basic
98
+ import data.nat.digits
99
+ import data.nat.factorial.basic
100
+ import data.nat.fib
101
+ import data.nat.modeq
102
+ import data.nat.multiplicity
103
+ import data.nat.parity
104
+ import data.nat.prime
105
+ import data.pnat.basic
106
+ import data.pnat.prime
107
+ import data.polynomial
108
+ import data.polynomial.basic
109
+ import data.polynomial.eval
110
+ import data.rat.basic
111
+ import data.real.basic
112
+ import data.real.ennreal
113
+ import data.real.irrational
114
+ import data.real.nnreal
115
+ import data.real.sqrt
116
+ import data.real.golden_ratio
117
+ import data.set.finite
118
+ import data.sym.sym2
119
+ import data.zmod.basic
120
+ import dynamics.fixed_points.basic
121
+ import field_theory.finite.basic
122
+ import geometry.euclidean.basic
123
+ import geometry.euclidean.circumcenter
124
+ import geometry.euclidean.monge_point
125
+ import geometry.euclidean.sphere
126
+ import init.data.nat.gcd
127
+ import linear_algebra.affine_space.affine_map
128
+ import linear_algebra.affine_space.independent
129
+ import linear_algebra.affine_space.ordered
130
+ import linear_algebra.finite_dimensional
131
+ import logic.equiv.basic
132
+ import measure_theory.integral.interval_integral
133
+ import number_theory.arithmetic_function
134
+ import number_theory.legendre_symbol.quadratic_reciprocity
135
+ import number_theory.primes_congruent_one
136
+ import order.bounds
137
+ import order.filter.basic
138
+ import order.well_founded
139
+ import topology.basic
140
+ import topology.instances.nnreal
141
+ '''
142
+
143
+ data = {
144
+ "working_file": import_statements
145
+ }
146
+
147
+ with open('data/notlean_dependency.json', 'w', encoding='utf-8') as f:
148
+ json.dump(data, f, indent=4)
149
+
150
+
151
+ def load_to_atp():
152
+ data_path = 'pass_rate_results/math_train/1/10pass10.jsonl'
153
+ data = json.load(open(data_path, "r", encoding='utf-8'))
154
+ import pdb
155
+ pdb.set_trace()
156
+ if __name__ == '__main__':
157
+ load_to_atp()
158
+ # get_novel_premises()
159
+ # savetojson()
basic_working.json ADDED
The diff for this file is too large to render. See raw diff
 
code.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import os
3
+ import subprocess
4
+ import json
5
+
6
+ # Set the directory where your .lean files are located
7
+
8
+ # Get a list of all .lean files in the directory
9
+ # lean_files = [f for f in os.listdir(directory) if f.endswith(".lean")]
10
+ # lean_files = ["test/file.lean"]
11
+ def main(args):
12
+ command_list = []
13
+ for i in range(8):
14
+ with open(f"{args.input_path}/{i}.json", 'r', encoding='utf-8') as rf:
15
+ for line in rf.readlines():
16
+ try:
17
+ json_item = json.loads(line)
18
+ json_item['cmd'] = '\n'.join()
19
+ except:
20
+ import pdb
21
+ pdb.set_trace()
22
+ command_list.append(json_item)
23
+ results = []
24
+ passed = 0
25
+ total = 0
26
+
27
+ for item in command_list:
28
+ data = '{"cmd": "%s", "allTactics": true}' % item['cmd']
29
+ command = 'echo \'%s\' | lake exe repl' % data
30
+
31
+ try:
32
+ result = subprocess.run(command, shell=True, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
33
+ stdout = result.stdout.decode('utf-8')
34
+ stderr = result.stderr.decode('utf-8')
35
+ results.append({
36
+ 'file_path': item['file_path'],
37
+ 'stdout': stdout,
38
+ 'stderr': stderr,
39
+ 'status': 'pass'
40
+ })
41
+ passed += 1
42
+ except subprocess.CalledProcessError as e:
43
+ results.append({
44
+ 'file_path': item['file_path'],
45
+ 'error': str(e),
46
+ 'status': 'nopass'
47
+ })
48
+ total += 1
49
+
50
+ # Calculate pass rate
51
+ pass_rate = passed / total * 100
52
+
53
+ # Save results to a JSON file
54
+ with open('results.json', 'w') as f:
55
+ json.dump({'results': results, 'pass_rate': pass_rate}, f, indent=2)
56
+
57
+ if __name__ == '__main__':
58
+ arg_parser = ArgumentParser()
59
+ arg_parser.add_argument('--data_path', type=str,
60
+ default='data/grade-school-math-master/grade_school_math/data/test.jsonl')
61
+ arg_parser.add_argument('--input_path', type=str, default='')
62
+ arg_parser.add_argument('--cuda_num', type=int, default=8)
63
+ arg_parser.add_argument('--output_path', type=str, default='total.json')
64
+ arg_parser.add_argument('--generate_method', type=str,
65
+ choices=['single', 'sft', 'comp', 'self_consistency', 'single_consistency'])
66
+ arg_parser.add_argument('--method', type=str, choices=['main', 'test', 'get_data'])
67
+ args = arg_parser.parse_args()
68
+ main(args)
69
+
data/basic_working.json ADDED
The diff for this file is too large to render. See raw diff
 
data/notlean_dependency.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ {
2
+ "working_file": "import algebra.algebra.basic\nimport algebra.order.floor\nimport algebra.associated\nimport algebra.big_operators.basic\nimport algebra.big_operators.enat\nimport algebra.big_operators.order\nimport algebra.big_operators.pi\nimport algebra.geom_sum\nimport algebra.group.pi\nimport algebra.group.commute\nimport algebra.group_power.basic\nimport algebra.group_power.identities\nimport algebra.order.floor\nimport algebra.quadratic_discriminant\nimport algebra.ring.basic\nimport analysis.asymptotics.asymptotic_equivalent\nimport analysis.mean_inequalities\nimport analysis.normed_space.basic\nimport analysis.inner_product_space.basic\nimport analysis.inner_product_space.euclidean_dist\nimport analysis.normed_space.pi_Lp\nimport analysis.special_functions.exp\nimport analysis.special_functions.exp_deriv\nimport analysis.special_functions.log\nimport analysis.special_functions.logb\nimport analysis.special_functions.log_deriv\nimport analysis.special_functions.pow\nimport analysis.special_functions.sqrt\nimport analysis.special_functions.trigonometric.basic\nimport analysis.special_functions.trigonometric.complex\nimport combinatorics.simple_graph.basic\nimport data.complex.basic\nimport data.complex.exponential\nimport data.finset.basic\nimport data.fintype.card\nimport data.int.basic\nimport data.int.gcd\nimport data.int.modeq\nimport data.int.parity\nimport data.list.intervals\nimport data.list.palindrome\nimport data.multiset.basic\nimport data.nat.basic\nimport data.nat.choose.basic\nimport data.nat.digits\nimport data.nat.factorial.basic\nimport data.nat.fib\nimport data.nat.modeq\nimport data.nat.multiplicity\nimport data.nat.parity\nimport data.nat.prime\nimport data.pnat.basic\nimport data.pnat.prime\nimport data.polynomial\nimport data.polynomial.basic\nimport data.polynomial.eval\nimport data.rat.basic\nimport data.real.basic\nimport data.real.ennreal\nimport data.real.irrational\nimport data.real.nnreal\nimport data.real.sqrt\nimport 
data.real.golden_ratio\nimport data.set.finite\nimport data.sym.sym2\nimport data.zmod.basic\nimport dynamics.fixed_points.basic\nimport field_theory.finite.basic\nimport geometry.euclidean.basic\nimport geometry.euclidean.circumcenter\nimport geometry.euclidean.monge_point\nimport geometry.euclidean.sphere\nimport init.data.nat.gcd\nimport linear_algebra.affine_space.affine_map\nimport linear_algebra.affine_space.independent\nimport linear_algebra.affine_space.ordered\nimport linear_algebra.finite_dimensional\nimport logic.equiv.basic\nimport measure_theory.integral.interval_integral\nimport number_theory.arithmetic_function\nimport number_theory.legendre_symbol.quadratic_reciprocity\nimport number_theory.primes_congruent_one\nimport order.bounds\nimport order.filter.basic\nimport order.well_founded\nimport topology.basic\nimport topology.instances.nnreal\n"
3
+ }
gpt_pass_rate_multi_pass.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pdb
2
+ import subprocess
3
+ import re
4
+
5
+ # Output file
6
+ output_file = "pass_rate_output.txt"
7
+
8
+ # Clearing the output file before appending new content
9
+ with open(output_file, "w") as file:
10
+ file.write("")
11
+
12
+ # List of input paths
13
+ input_path_lists = [
14
+ "test/zero_shot/wild_test/generation/lean4_random_15k_all/2/1/",
15
+ ]
16
+
17
+ def get_output(input_string, k):
18
+ pattern = r"gpt_result/(\w+)/(\w+)"
19
+ match = re.search(pattern, input_string)
20
+ if match:
21
+ part1 = match.group(1)
22
+ part2 = match.group(2)
23
+ result = f"gpt_result/{part2}/{part1}_pass{k}.json"
24
+ print(result)
25
+ return result
26
+ else:
27
+ print("No match found.")
28
+ return None
29
+
30
+ # List of input paths
31
+ input_path_lists = [
32
+ # "gpt_result/lean_basic/gpt4/",
33
+ # "gpt_result/lean_random/gpt4/",
34
+ "gpt_result/wild/gpt4/",
35
+ # "gpt_result/lean_basic/gpt3/",
36
+ # "gpt_result/lean_random/gpt3/",
37
+ "gpt_result/wild/gpt3/",
38
+ ]
39
+
40
+ # Iterate through the input paths and run the command
41
+ for input_path in input_path_lists:
42
+ k = 5
43
+ if "wild" in input_path or "gsm8k_train" in input_path or "math_train" in input_path:
44
+ print(f"wild")
45
+ print(f"Running for input path: {input_path}", file=open(output_file, "a"))
46
+ command = f"python3 gpt_pass_rate_new_notlean_test.py --input_path {input_path} --output_path {get_output(input_path,k)} --k {k}"
47
+ subprocess.run(command, shell=True, stdout=open(output_file, "a"), stderr=subprocess.STDOUT)
48
+ print("\n\n",file=open(output_file, "a"))
49
+ else:
50
+ print(f"lean")
51
+ print(f"Running for input path: {input_path}", file=open(output_file, "a"))
52
+ command = f"python3 gpt_pass_rate_new_test.py --input_path {input_path} --output_path {get_output(input_path, k)} --k {k}"
53
+ subprocess.run(command, shell=True, stdout=open(output_file, "a"), stderr=subprocess.STDOUT)
54
+ print("\n\n",file=open(output_file, "a"))
gpt_pass_rate_new_notlean_test.py ADDED
@@ -0,0 +1,289 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import subprocess
3
+ from argparse import ArgumentParser
4
+ import json
5
+ from concurrent.futures import ThreadPoolExecutor
6
+ from tqdm import tqdm
7
+ import glob
8
+ import tempfile
9
+ import random
10
+
11
+ def wrapped_function(item):
12
+ results = []
13
+ passed = 0
14
+ total = 0
15
+
16
+ temp_dir = tempfile.gettempdir()
17
+ temp_file = os.path.join(temp_dir, f"test.lean")
18
+
19
+ with open(temp_file, "w") as f:
20
+ f.write(item['cmd'])
21
+
22
+ # Rest of the function code...
23
+ # Process the item using the temporary file
24
+ # ...
25
+
26
+ # Clean up the temporary file
27
+ data = '{"path": "%s", "allTactics": true}' %(temp_file)
28
+ command = 'echo \'%s\' | lake exe repl' % data
29
+
30
+ try:
31
+ result = subprocess.run(command, shell=True, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
32
+ stdout = result.stdout.decode('utf-8')
33
+ stderr = result.stderr.decode('utf-8')
34
+ # stdout = result.stdout.decode('utf-8')
35
+ json_stdout = json.loads(stdout)
36
+ if "messages" not in json_stdout.keys():
37
+ passed += 1
38
+ # results.append({'item': item['content'], 'stdout': stdout, 'stderr': stderr, 'status': 'pass'})
39
+ results.append({ 'stdout': stdout, 'stderr': stderr, 'status': 'pass'})
40
+ except subprocess.CalledProcessError as e:
41
+ # results.append({'item': item['content'], 'error': str(e), 'status': 'nopass'})
42
+ results.append({ 'error': str(e), 'status': 'nopass'})
43
+ total += 1
44
+
45
+ pass_rate = passed / (passed + total) * 100
46
+
47
+
48
+ return {'results': results, 'pass_rate': pass_rate}
49
+
50
+ # Set the directory where your .lean files are located
51
+
52
+ # Get a list of all .lean files in the directory
53
+ # lean_files = [f for f in os.listdir(directory) if f.endswith(".lean")]
54
+ # lean_files = ["test/file.lean"]
55
+ def single(command_list, args):
56
+ results = []
57
+ passed = 0
58
+ total = 0
59
+ for item in tqdm(command_list):
60
+ with open("test/test.lean", "w", encoding = 'utf-8') as f:
61
+ f.write(item['cmd'])
62
+ data = '{"path": "test/test.lean", "allTactics": true}'
63
+ # data = '{"cmd": "%s", "allTactics": true}' % item['cmd']
64
+ command = 'echo \'%s\' | lake exe repl' % data
65
+ try:
66
+ # process = subprocess.Popen(['lake', 'exe', 'repl'], stdin=subprocess.PIPE, stdout=subprocess.PIPE,
67
+ # stderr=subprocess.PIPE)
68
+ # stdout, stderr = process.communicate(input=data.encode(encoding='utf-8'))
69
+ # stdout = stdout.decode('utf-8')
70
+ result = subprocess.run(command, shell=True, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
71
+ stdout = result.stdout.decode('utf-8')
72
+ json_stdout = json.loads(stdout)
73
+ if "messages" not in json_stdout.keys():
74
+ passed += 1
75
+ stderr = result.stderr.decode('utf-8')
76
+ results.append({
77
+ # 'item': item['content'],
78
+ 'stdout': stdout,
79
+ 'stderr': stderr,
80
+ 'status': 'pass'
81
+ })
82
+ except subprocess.CalledProcessError as e:
83
+ results.append({
84
+ # 'item': item['content'],
85
+ 'error': str(e),
86
+ 'status': 'nopass'
87
+ })
88
+ total += 1
89
+
90
+ # Calculate pass rate
91
+ pass_rate = passed / total * 100
92
+ print(pass_rate)
93
+
94
+ # Save results to a JSON file
95
+ with open('results.json', 'w') as f:
96
+ json.dump({'results': results, 'pass_rate': pass_rate}, f, indent=2, ensure_ascii=False)
97
+
98
+
99
+ def multi(command_list, output_path, k ):
100
+ results = []
101
+ passed = 0
102
+ total = 0
103
+ def execute_command(item, index):
104
+ temp_dir = '/opt/jianqiao'
105
+ def filter_json(json_data):
106
+ filtered_data = {}
107
+ for key in json_data.keys():
108
+ if key in ['question', 'answer', 'total output', 'results']:
109
+ filtered_data[key] = json_data[key]
110
+ return filtered_data
111
+ # result_dict = filter_json(item)
112
+ result_dict = item
113
+ result_dict['results'] = []
114
+
115
+ for i, cmd in enumerate(item['cmd']):
116
+ temp_file = os.path.join(temp_dir,f"{index}_test_{i}.lean") # Ensure unique filenames
117
+ with open(temp_file, "w") as f:
118
+ f.write(cmd)
119
+
120
+ data = '{"path": "%s", "allTactics": true}' % temp_file
121
+ command = f'echo \'{data}\' | lake exe repl'
122
+
123
+ try:
124
+ result = subprocess.run(command, shell=True, check=True,timeout=600, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
125
+ stdout = json.loads(result.stdout.decode('utf-8'))
126
+ stderr = result.stderr.decode('utf-8')
127
+
128
+ except subprocess.TimeoutExpired as e:
129
+ result_item = {'error': str(e), 'status': 'nopass_limit'}
130
+
131
+ except subprocess.CalledProcessError as e:
132
+ result_item = {'error': str(e), 'status': 'nopass_error'}
133
+
134
+ else:
135
+ if "messages" not in stdout and not len(stderr):
136
+ result_item = {'stdout': stdout, 'stderr': stderr, 'status': 'pass' }
137
+ elif not len(stderr) and "messages" in stdout:
138
+ flag = 0
139
+ for me in stdout['messages']:
140
+ if me['severity'] == 'error':
141
+ flag = 1
142
+ start_line = me['pos']['line'] - 1
143
+ current_column =me['pos']['column'] -1
144
+ for line_n in range(start_line - 1, 0 , -1):
145
+ line_len = len(cmd.split('\n')[line_n])
146
+ current_column += line_len + 1
147
+ if not line_len:
148
+ break
149
+ result_item = {'stdout': stdout, 'stderr': stderr, 'status': 'nopass', 'string_pos':current_column}
150
+ break
151
+ if not flag :
152
+ result_item = {'stdout': stdout, 'stderr': stderr, 'status': 'pass'}
153
+ else:
154
+ assert len(stderr)
155
+ result_item = {'stdout': stdout, 'stderr': stderr, 'status': 'nopass', 'string_pos': 0 }
156
+
157
+ result_dict['results'].append(result_item)
158
+ return result_dict
159
+
160
+
161
+ total = len(command_list)
162
+
163
+ with ThreadPoolExecutor(max_workers=128) as executor:
164
+ futures = [executor.submit(execute_command, cmd, i) for i, cmd in enumerate(command_list)]
165
+ for future in tqdm(futures, total=total, desc="Processing Commands"):
166
+ result = future.result()
167
+ results.append(result)
168
+ # if result['status'] == 'pass':
169
+ # passed += 1
170
+
171
+ def calculate_pass(result_list, k):
172
+ pass_1_count = 0
173
+ pass_k_count = 0
174
+
175
+ for result in result_list:
176
+ results = result.get('results', [])
177
+ if results:
178
+ for j in range(min(1, len(results))):
179
+ if results[j].get('status') == 'pass':
180
+ pass_1_count += 1
181
+ break
182
+
183
+ for j in range(min(k, len(results))):
184
+ if results[j].get('status') == 'pass':
185
+ pass_k_count += 1
186
+ break
187
+
188
+ pass_1 = pass_1_count / len(result_list) if result_list else 0
189
+ pass_k = pass_k_count / len(result_list) if result_list else 0
190
+
191
+ return pass_1, pass_k
192
+
193
+ pass_1, pass_k = calculate_pass(results, k)
194
+ print("Pass@1:", pass_1)
195
+ print(f"Pass@{k}:", pass_k)
196
+
197
+ # pass_rate = (passed / total) * 100
198
+ # print(f"total test: {total}")
199
+ # print(f"Pass rate: {pass_rate}%")
200
+
201
+ output_file = f"pass_rate_results/{output_path}"
202
+ # Create the directory if it doesn't exist
203
+ os.makedirs(os.path.dirname(output_file), exist_ok=True)
204
+
205
+ with open(f"{output_file}", 'w') as f:
206
+ json.dump({'results': results, 'pass_1': pass_1, f"pass_{k}":pass_k}, f, indent=2, ensure_ascii=False)
207
+
208
+ import re
209
+ def remove_simp_pattern_from_end(s):
210
+ pattern = r'@\[simp\s*.*?\]$'
211
+ return re.sub(pattern, '', s)
212
+
213
+
214
+ def get_lean(text):
215
+ content = ""
216
+ try:
217
+ code_block_pattern = r"```lean\s*\n(.*?)\n```"
218
+ code_blocks = re.findall(code_block_pattern, text, re.DOTALL)
219
+ content = "\n\n".join(code_blocks)
220
+ except:
221
+ matches = re.findall(r'```(.*?)```', text, re.DOTALL)
222
+ if len(matches):
223
+ content = "\n\n".join(matches)
224
+ finally:
225
+ if not len(content.strip()):
226
+ try:
227
+ code_block_pattern = r"```lean4\s*\n(.*?)\n```"
228
+ code_blocks = re.findall(code_block_pattern, text, re.DOTALL)
229
+ content = "\n\n".join(code_blocks)
230
+ except:
231
+ content = ''
232
+
233
+ if not len(content.strip()):
234
+ content = "theorem h : f + g = 39 := by exact rfl"
235
+ return content
236
+
237
+ def main(args):
238
+ command_list = []
239
+ # json_filename = 'data/notlean_dependency.json'
240
+ json_filename = 'data/basic_working.json'
241
+
242
+ json_item = json.load(open(json_filename, encoding='utf-8'))
243
+ working_env = json_item['working_file']
244
+
245
+ all_dicts = {}
246
+ with open(f"{args.input_path}/1.jsonl", 'r', encoding='utf-8') as rf:
247
+ for line in rf.readlines():
248
+ try:
249
+ json_item = json.loads(line)
250
+ text = get_lean(json_item['model_response']).split("#align")[0]
251
+ json_item['cmd'] = ['\n\n'.join([working_env, text])]
252
+ all_dicts[json_item['query_id']] = json_item
253
+ assert len(text) > 0
254
+ except:
255
+ import pdb
256
+ pdb.set_trace()
257
+
258
+ file_pattern = os.path.join(args.input_path, '[2-9]*.jsonl')
259
+ for file_path in glob.glob(file_pattern):
260
+ with open(file_path, 'r', encoding='utf-8') as rf:
261
+ for line in rf.readlines():
262
+ try:
263
+ json_item = json.loads(line)
264
+ text = get_lean(json_item['model_response']).split("#align")[0]
265
+ all_dicts[json_item['query_id']]['cmd'].append('\n\n'.join([working_env, text]))
266
+ assert len(text) > 0
267
+ except:
268
+ import pdb
269
+ pdb.set_trace()
270
+ for k, v in all_dicts.items():
271
+ command_list.append(v)
272
+ multi(command_list, args.output_path, args.k)
273
+
274
+ if __name__ == '__main__':
275
+ arg_parser = ArgumentParser()
276
+ arg_parser.add_argument('--data_path', type=str,
277
+ default='data/grade-school-math-master/grade_school_math/data/test.jsonl')
278
+ arg_parser.add_argument('--input_path', type=str, default='')
279
+ arg_parser.add_argument('--cuda_num', type=int, default=8)
280
+ arg_parser.add_argument('--k', type=int, default=1)
281
+ arg_parser.add_argument('--output_path', type=str, default='total.json')
282
+ arg_parser.add_argument('--generate_method', type=str,
283
+ choices=['single', 'sft', 'comp', 'self_consistency', 'single_consistency'])
284
+ arg_parser.add_argument('--method', type=str, choices=['main', 'test', 'get_data'])
285
+ args = arg_parser.parse_args()
286
+ main(args)
287
+
288
+
289
+
gpt_pass_rate_new_test.py ADDED
@@ -0,0 +1,287 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import subprocess
3
+ from argparse import ArgumentParser
4
+ import json
5
+ from concurrent.futures import ThreadPoolExecutor
6
+ from tqdm import tqdm
7
+ import glob
8
+ import tempfile
9
+ import random
10
+
11
+ def wrapped_function(item):
12
+ results = []
13
+ passed = 0
14
+ total = 0
15
+
16
+ temp_dir = tempfile.gettempdir()
17
+ temp_file = os.path.join(temp_dir, f"test.lean")
18
+
19
+ with open(temp_file, "w") as f:
20
+ f.write(item['cmd'])
21
+
22
+ # Rest of the function code...
23
+ # Process the item using the temporary file
24
+ # ...
25
+
26
+ # Clean up the temporary file
27
+ data = '{"path": "%s", "allTactics": true}' %(temp_file)
28
+ command = 'echo \'%s\' | lake exe repl' % data
29
+
30
+ try:
31
+ result = subprocess.run(command, shell=True, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
32
+ stdout = result.stdout.decode('utf-8')
33
+ stderr = result.stderr.decode('utf-8')
34
+ # stdout = result.stdout.decode('utf-8')
35
+ json_stdout = json.loads(stdout)
36
+ if "messages" not in json_stdout.keys():
37
+ passed += 1
38
+ # results.append({'item': item['content'], 'stdout': stdout, 'stderr': stderr, 'status': 'pass'})
39
+ results.append({ 'stdout': stdout, 'stderr': stderr, 'status': 'pass'})
40
+ except subprocess.CalledProcessError as e:
41
+ # results.append({'item': item['content'], 'error': str(e), 'status': 'nopass'})
42
+ results.append({ 'error': str(e), 'status': 'nopass'})
43
+ total += 1
44
+
45
+ pass_rate = passed / (passed + total) * 100
46
+
47
+
48
+ return {'results': results, 'pass_rate': pass_rate}
49
+
50
+ # Set the directory where your .lean files are located
51
+
52
+ # Get a list of all .lean files in the directory
53
+ # lean_files = [f for f in os.listdir(directory) if f.endswith(".lean")]
54
+ # lean_files = ["test/file.lean"]
55
+ def single(command_list, args):
56
+ results = []
57
+ passed = 0
58
+ total = 0
59
+ for item in tqdm(command_list):
60
+ with open("test/test.lean", "w", encoding = 'utf-8') as f:
61
+ f.write(item['cmd'])
62
+ data = '{"path": "test/test.lean", "allTactics": true}'
63
+ # data = '{"cmd": "%s", "allTactics": true}' % item['cmd']
64
+ command = 'echo \'%s\' | lake exe repl' % data
65
+ try:
66
+ # process = subprocess.Popen(['lake', 'exe', 'repl'], stdin=subprocess.PIPE, stdout=subprocess.PIPE,
67
+ # stderr=subprocess.PIPE)
68
+ # stdout, stderr = process.communicate(input=data.encode(encoding='utf-8'))
69
+ # stdout = stdout.decode('utf-8')
70
+ result = subprocess.run(command, shell=True, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
71
+ stdout = result.stdout.decode('utf-8')
72
+ json_stdout = json.loads(stdout)
73
+ if "messages" not in json_stdout.keys():
74
+ passed += 1
75
+ stderr = result.stderr.decode('utf-8')
76
+ results.append({
77
+ # 'item': item['content'],
78
+ 'stdout': stdout,
79
+ 'stderr': stderr,
80
+ 'status': 'pass'
81
+ })
82
+ except subprocess.CalledProcessError as e:
83
+ results.append({
84
+ # 'item': item['content'],
85
+ 'error': str(e),
86
+ 'status': 'nopass'
87
+ })
88
+ total += 1
89
+
90
+ # Calculate pass rate
91
+ pass_rate = passed / total * 100
92
+ print(pass_rate)
93
+
94
+ # Save results to a JSON file
95
+ with open('results.json', 'w') as f:
96
+ json.dump({'results': results, 'pass_rate': pass_rate}, f, indent=2, ensure_ascii=False)
97
+
98
+
99
def multi(command_list, output_path, k):
    """Check every candidate proof of every item in parallel; report pass@1/pass@k.

    Each item carries candidate Lean sources in item['cmd'].  Every candidate
    is written to its own file under /opt/jianqiao and checked with
    `lake exe repl` (10-minute cap per candidate).  Aggregated results are
    written to pass_rate_results/<output_path>.
    """
    results = []
    passed = 0
    total = 0

    def execute_command(item, index):
        temp_dir = '/opt/jianqiao'

        def filter_json(json_data):
            # Keep only the metadata fields worth persisting.
            filtered_data = {}
            for key in json_data.keys():
                if key in ['question', 'answer', 'total output', 'results']:
                    filtered_data[key] = json_data[key]
            return filtered_data

        result_dict = filter_json(item)
        result_dict['results'] = []

        for i, cmd in enumerate(item['cmd']):
            temp_file = os.path.join(temp_dir, f"{index}_test_{i}.lean")  # Ensure unique filenames
            with open(temp_file, "w") as f:
                f.write(cmd)

            data = '{"path": "%s", "allTactics": true}' % temp_file
            command = f'echo \'{data}\' | lake exe repl'

            try:
                result = subprocess.run(command, shell=True, check=True, timeout=600,
                                        stdout=subprocess.PIPE, stderr=subprocess.PIPE)
                stdout = json.loads(result.stdout.decode('utf-8'))
                stderr = result.stderr.decode('utf-8')
            except subprocess.TimeoutExpired as e:
                result_item = {'error': str(e), 'status': 'nopass_limit'}
            except subprocess.CalledProcessError as e:
                result_item = {'error': str(e), 'status': 'nopass_error'}
            except json.JSONDecodeError as e:
                # BUG FIX: non-JSON REPL output used to raise out of the
                # worker thread and abort the evaluation at future.result().
                result_item = {'error': str(e), 'status': 'nopass_error'}
            else:
                if "messages" not in stdout and not len(stderr):
                    result_item = {'stdout': stdout, 'stderr': stderr, 'status': 'pass'}
                elif not len(stderr) and "messages" in stdout:
                    flag = 0
                    for me in stdout['messages']:
                        if me['severity'] == 'error':
                            flag = 1
                            # Convert the first error's (line, column) into a
                            # flat character offset within the snippet.
                            start_line = me['pos']['line'] - 1
                            current_column = me['pos']['column'] - 1
                            for line_n in range(start_line - 1, 0, -1):
                                line_len = len(cmd.split('\n')[line_n])
                                current_column += line_len + 1
                                if not line_len:
                                    break
                            result_item = {'stdout': stdout, 'stderr': stderr,
                                           'status': 'nopass', 'string_pos': current_column}
                            break
                    if not flag:
                        # Only warnings/infos: still counts as a pass.
                        result_item = {'stdout': stdout, 'stderr': stderr, 'status': 'pass'}
                else:
                    assert len(stderr)
                    result_item = {'stdout': stdout, 'stderr': stderr, 'status': 'nopass', 'string_pos': 0}

            result_dict['results'].append(result_item)
        return result_dict

    total = len(command_list)

    with ThreadPoolExecutor(max_workers=128) as executor:
        futures = [executor.submit(execute_command, cmd, i) for i, cmd in enumerate(command_list)]
        for future in tqdm(futures, total=total, desc="Processing Commands"):
            results.append(future.result())

    def calculate_pass(result_list, k):
        # pass@1: first candidate passes; pass@k: any of the first k passes.
        pass_1_count = 0
        pass_k_count = 0
        for result in result_list:
            candidate_results = result.get('results', [])
            if candidate_results:
                for j in range(min(1, len(candidate_results))):
                    if candidate_results[j].get('status') == 'pass':
                        pass_1_count += 1
                        break
                for j in range(min(k, len(candidate_results))):
                    if candidate_results[j].get('status') == 'pass':
                        pass_k_count += 1
                        break
        pass_1 = pass_1_count / len(result_list) if result_list else 0
        pass_k = pass_k_count / len(result_list) if result_list else 0
        return pass_1, pass_k

    pass_1, pass_k = calculate_pass(results, k)
    print("Pass@1:", pass_1)
    print(f"Pass@{k}:", pass_k)

    output_file = f"pass_rate_results/{output_path}"
    # Create the directory if it doesn't exist
    os.makedirs(os.path.dirname(output_file), exist_ok=True)

    with open(output_file, 'w') as f:
        json.dump({'results': results, 'pass_1': pass_1, f"pass_{k}": pass_k}, f, indent=2, ensure_ascii=False)
206
+
207
import re

# A trailing Lean `@[simp ...]` attribute annotation.
_SIMP_AT_END = r'@\[simp\s*.*?\]$'


def remove_simp_pattern_from_end(s):
    """Strip a trailing `@[simp ...]` attribute from *s*, if present."""
    return re.sub(_SIMP_AT_END, '', s)
211
+
212
+
213
def get_lean(text):
    """Extract Lean source from a model response containing fenced code blocks.

    Fence styles are tried in order: ```lean, ```lean4, then anonymous ```.
    (The anonymous-fence branch of the original was unreachable dead code:
    it lived in an `except` clause, but `re.findall` does not raise on
    normal input.)  Falls back to a fixed dummy theorem so callers always
    get a non-empty snippet to feed to the checker.
    """
    fence_patterns = (
        r"```lean\s*\n(.*?)\n```",
        r"```lean4\s*\n(.*?)\n```",
        r"```(.*?)```",
    )
    for pattern in fence_patterns:
        blocks = re.findall(pattern, text, re.DOTALL)
        content = "\n\n".join(blocks)
        if content.strip():
            return content
    # Nothing extractable: return a trivially checkable placeholder.
    return "theorem h : f + g = 39 := by exact rfl"
235
+
236
def main(args):
    """Collect generated Lean proofs from JSONL shards and score them.

    Shard `1.jsonl` defines one entry per query (with its working
    environment); shards 2..9 append extra candidate generations to the
    matching query's `cmd` list.  Assembled items go to `multi` for
    pass@1 / pass@k evaluation.
    """
    all_dicts = {}
    with open(f"{args.input_path}/1.jsonl", 'r', encoding='utf-8') as rf:
        for line_no, line in enumerate(rf, 1):
            try:
                json_item = json.loads(line)
                working_env = json_item['working_file']
                # Drop everything from the first #align directive onward.
                text = get_lean(json_item['model_response']).split("#align")[0]
                json_item['cmd'] = ['\n\n'.join([working_env, text])]
                json_item['answer'] = json_item['statement_poof']
                all_dicts[json_item['query_id']] = json_item
                assert len(text) > 0
            except (json.JSONDecodeError, KeyError, AssertionError) as e:
                # Previously a bare `except:` dropped into pdb, hanging
                # unattended runs; report and skip the malformed record.
                print(f"skipping 1.jsonl:{line_no}: {e!r}")

    file_pattern = os.path.join(args.input_path, '[2-9]*.jsonl')
    for file_path in glob.glob(file_pattern):
        with open(file_path, 'r', encoding='utf-8') as rf:
            for line_no, line in enumerate(rf, 1):
                try:
                    json_item = json.loads(line)
                    working_env = json_item['working_file']
                    text = get_lean(json_item['model_response']).split("#align")[0]
                    all_dicts[json_item['query_id']]['cmd'].append('\n\n'.join([working_env, text]))
                    assert len(text) > 0
                except (json.JSONDecodeError, KeyError, AssertionError) as e:
                    print(f"skipping {file_path}:{line_no}: {e!r}")

    # Dict insertion order matches the original append loop.
    command_list = list(all_dicts.values())
    multi(command_list, args.output_path, args.k)
271
+
272
if __name__ == '__main__':
    # CLI entry point; only --input_path/--output_path/--k are consumed by
    # main(), the remaining flags are kept for pipeline compatibility.
    parser = ArgumentParser()
    parser.add_argument(
        '--data_path', type=str,
        default='data/grade-school-math-master/grade_school_math/data/test.jsonl')
    parser.add_argument('--input_path', type=str, default='')
    parser.add_argument('--cuda_num', type=int, default=8)
    parser.add_argument('--k', type=int, default=1)
    parser.add_argument('--output_path', type=str, default='total.json')
    parser.add_argument(
        '--generate_method', type=str,
        choices=['single', 'sft', 'comp', 'self_consistency', 'single_consistency'])
    parser.add_argument('--method', type=str, choices=['main', 'test', 'get_data'])
    main(parser.parse_args())
285
+
286
+
287
+
lake-manifest.json ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {"version": 7,
2
+ "packagesDir": ".lake/packages",
3
+ "packages":
4
+ [{"url": "https://github.com/leanprover/std4",
5
+ "type": "git",
6
+ "subDir": null,
7
+ "rev": "e5306c3b0edefe722370b7387ee9bcd4631d6c17",
8
+ "name": "std",
9
+ "manifestFile": "lake-manifest.json",
10
+ "inputRev": "main",
11
+ "inherited": true,
12
+ "configFile": "lakefile.lean"},
13
+ {"url": "https://github.com/leanprover-community/quote4",
14
+ "type": "git",
15
+ "subDir": null,
16
+ "rev": "fd760831487e6835944e7eeed505522c9dd47563",
17
+ "name": "Qq",
18
+ "manifestFile": "lake-manifest.json",
19
+ "inputRev": "master",
20
+ "inherited": true,
21
+ "configFile": "lakefile.lean"},
22
+ {"url": "https://github.com/leanprover-community/aesop",
23
+ "type": "git",
24
+ "subDir": null,
25
+ "rev": "8be30c25e3caa06937feeb62f7ca898370f80ee9",
26
+ "name": "aesop",
27
+ "manifestFile": "lake-manifest.json",
28
+ "inputRev": "master",
29
+ "inherited": true,
30
+ "configFile": "lakefile.lean"},
31
+ {"url": "https://github.com/leanprover-community/ProofWidgets4",
32
+ "type": "git",
33
+ "subDir": null,
34
+ "rev": "fb65c476595a453a9b8ffc4a1cea2db3a89b9cd8",
35
+ "name": "proofwidgets",
36
+ "manifestFile": "lake-manifest.json",
37
+ "inputRev": "v0.0.30",
38
+ "inherited": true,
39
+ "configFile": "lakefile.lean"},
40
+ {"url": "https://github.com/leanprover/lean4-cli",
41
+ "type": "git",
42
+ "subDir": null,
43
+ "rev": "be8fa79a28b8b6897dce0713ef50e89c4a0f6ef5",
44
+ "name": "Cli",
45
+ "manifestFile": "lake-manifest.json",
46
+ "inputRev": "main",
47
+ "inherited": true,
48
+ "configFile": "lakefile.lean"},
49
+ {"url": "https://github.com/leanprover-community/import-graph.git",
50
+ "type": "git",
51
+ "subDir": null,
52
+ "rev": "61a79185b6582573d23bf7e17f2137cd49e7e662",
53
+ "name": "importGraph",
54
+ "manifestFile": "lake-manifest.json",
55
+ "inputRev": "main",
56
+ "inherited": true,
57
+ "configFile": "lakefile.lean"},
58
+ {"url": "https://github.com/leanprover-community/mathlib4",
59
+ "type": "git",
60
+ "subDir": null,
61
+ "rev": "3cecb823a74ed737c6ebc115e515eba649ec7715",
62
+ "name": "mathlib",
63
+ "manifestFile": "lake-manifest.json",
64
+ "inputRev": "3cecb823a74ed737c6ebc115e515eba649ec7715",
65
+ "inherited": false,
66
+ "configFile": "lakefile.lean"}],
67
+ "name": "REPL",
68
+ "lakeDir": ".lake"}
lakefile.lean ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import Lake
2
+ open Lake DSL
3
+
4
+ package REPL {
5
+ -- add package configuration options here
6
+ }
7
+
8
+ lean_lib REPL {
9
+ -- add library configuration options here
10
+ }
11
+
12
+ @[default_target]
13
+ lean_exe repl where
14
+ root := `REPL.Main
15
+ supportInterpreter := true
16
+
17
+ require mathlib from git "https://github.com/leanprover-community/mathlib4"@"3cecb823a74ed737c6ebc115e515eba649ec7715"
lean-toolchain ADDED
@@ -0,0 +1 @@
 
 
1
+ leanprover/lean4:v4.7.0-rc2
nohup.out ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+
2
+
3
+ Pass rate: 20.0%
4
+
openllm_pass_rate_multi_pass.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pdb
2
+ import subprocess
3
+ import re
4
+
5
+ # Output file
6
+ output_file = "pass_rate_output.txt"
7
+
8
+ # Clearing the output file before appending new content
9
+ with open(output_file, "w") as file:
10
+ file.write("")
11
+
12
def get_output(input_path, k):
    """Derive the result-file path for a generation directory.

    For a path like .../zero_shot/<testset>/generation/<model>/<k>/ this
    returns "openllm_result/<model>/<testset>_pass<k>.json" (also printed).
    Returns None, with a notice, when the path lacks 'zero_shot' or is too
    short.  NOTE(review): a None return is later interpolated into the
    shell command as the literal string "None" — callers should check it.
    The `k` parameter is unused; kept for call-site compatibility.
    """
    try:
        # Split the input path based on '/'
        parts = input_path.split('/')

        # Find the index of 'zero_shot' and take components relative to it.
        zero_shot_index = parts.index('zero_shot')

        part1 = parts[zero_shot_index + 1]
        part3 = parts[zero_shot_index + 3]
        part4 = parts[zero_shot_index + 4]
        result = f"openllm_result/{part3}/{part1}_pass{part4}.json"
        print(result)
        return result
    except (ValueError, IndexError):
        # BUG FIX: was a bare `except:` (swallowed KeyboardInterrupt etc.).
        # ValueError: 'zero_shot' absent; IndexError: too few components.
        print("No match found.")
        return None
31
+
32
+
33
+ # List of input paths
34
+ input_path_lists = [
35
+ # "../auto-info/generate_result/zero_shot/wild_test/generation/deepseek-math-7b-base/1/",
36
+ # "../auto-info/generate_result/zero_shot/lean4_basic_test/generation/deepseek-math-7b-instruct/1/",
37
+ # "../auto-info/generate_result/zero_shot/lean4_random_test/generation/deepseek-math-7b-instruct/1/",
38
+ # "../auto-info/generate_result/zero_shot/lean4_basic_test/generation/llemma_7b/1/",
39
+ # "../auto-info/generate_result/zero_shot/lean4_random_test/generation/llemma_7b/1/",
40
+ # "../auto-info/generate_result/zero_shot/wild_test/generation/llemma_7b/1/",
41
+ # "../auto-info/generate_result/zero_shot/lean4_basic_test/generation/llemma_34b/1/",
42
+ # "../auto-info/generate_result/zero_shot/lean4_random_test/generation/llemma_34b/1/",
43
+ # "../auto-info/generate_result/zero_shot/wild_test/generation/llemma_34b/1/",
44
+ # "../auto-info/generate_result/zero_shot/lean4_basic_test/generation/internlm2-math-7b/1/",
45
+ # "../auto-info/generate_result/zero_shot/lean4_random_test/generation/internlm2-math-7b/1/",
46
+ # "../auto-info/generate_result/zero_shot/wild_test/generation/internlm2-math-7b/1/",
47
+ # "../auto-info/generate_result/zero_shot/lean4_basic_test/generation/Mistral-7B-Instruct-v0.2/1/",
48
+ # "../auto-info/generate_result/zero_shot/lean4_random_test/generation/Mistral-7B-Instruct-v0.2/1/",
49
+ # "../auto-info/generate_result/zero_shot/wild_test/generation/Mistral-7B-Instruct-v0.2/1/",
50
+ # "../auto-info/generate_result/zero_shot/lean4_basic_test/generation/internlm2-math-20b/1/",
51
+ # "../auto-info/generate_result/zero_shot/lean4_random_test/generation/internlm2-math-20b/1/",
52
+ # "../auto-info/generate_result/zero_shot/wild_test/generation/internlm2-math-20b/1/",
53
+
54
+ # "../auto-info/generate_result/zero_shot/lean4_basic_test/generation/deepseek-math-7b-base/5/",
55
+ # "../auto-info/generate_result/zero_shot/lean4_random_test/generation/deepseek-math-7b-base/5/",
56
+ # "../auto-info/generate_result/zero_shot/wild_test/generation/deepseek-math-7b-base/5/",
57
+ # "../auto-info/generate_result/zero_shot/lean4_basic_test/generation/deepseek-math-7b-instruct/5/",
58
+ # "../auto-info/generate_result/zero_shot/lean4_random_test/generation/deepseek-math-7b-instruct/5/",
59
+ # "../auto-info/generate_result/zero_shot/wild_test/generation/deepseek-math-7b-instruct/5/",
60
+ # "../auto-info/generate_result/zero_shot/lean4_basic_test/generation/llemma_7b/5/",
61
+ # "../auto-info/generate_result/zero_shot/lean4_random_test/generation/llemma_7b/5/",
62
+ # "../auto-info/generate_result/zero_shot/wild_test/generation/llemma_7b/5/",
63
+ # "../auto-info/generate_result/zero_shot/lean4_basic_test/generation/llemma_34b/5/",
64
+ # "../auto-info/generate_result/zero_shot/lean4_random_test/generation/llemma_34b/5/",
65
+ # "../auto-info/generate_result/zero_shot/wild_test/generation/llemma_34b/5/",
66
+ # "../auto-info/generate_result/zero_shot/lean4_basic_test/generation/internlm2-math-7b/5/",
67
+ # "../auto-info/generate_result/zero_shot/lean4_random_test/generation/internlm2-math-7b/5/",
68
+ "../auto-info/generate_result/zero_shot/wild_test/generation/internlm2-math-7b/5/",
69
+ # "../auto-info/generate_result/zero_shot/lean4_basic_test/generation/Mistral-7B-Instruct-v0.2/5/",
70
+ # "../auto-info/generate_result/zero_shot/lean4_random_test/generation/Mistral-7B-Instruct-v0.2/5/",
71
+ # "../auto-info/generate_result/zero_shot/wild_test/generation/Mistral-7B-Instruct-v0.2/5/",
72
+ # "../auto-info/generate_result/zero_shot/lean4_basic_test/generation/internlm2-math-20b/5/",
73
+ # "../auto-info/generate_result/zero_shot/lean4_random_test/generation/internlm2-math-20b/5/",
74
+ "../auto-info/generate_result/zero_shot/wild_test/generation/internlm2-math-20b/5/",
75
+ ]
76
+
77
+
78
def extract_group(input_path):
    """Return the path component four segments past 'zero_shot' (the
    sample count k), or None when the path has no 'zero_shot' segment
    or too few components after it."""
    segments = input_path.split('/')
    try:
        anchor = segments.index('zero_shot')
        return segments[anchor + 4]
    except (IndexError, ValueError):
        return None
90
+
91
# Iterate through the input paths and run the command
# Natural-language test sets (wild / gsm8k_train / math_train) go through
# the not-lean scorer; everything else is checked as Lean.  All child
# output is appended to pass_rate_output.txt.
for input_path in input_path_lists:
    k = extract_group(input_path)

    if "wild" in input_path or "gsm8k_train" in input_path or "math_train" in input_path:
        print(f"wild")
        print(f"Running for input path: {input_path}", file=open(output_file, "a"))
        # NOTE(review): get_output() may return None, which ends up as the
        # literal string "None" in the child command — confirm upstream.
        command = f"python3 openllm_pass_rate_new_notlean_test.py --input_path {input_path} --output_path {get_output(input_path,k)} --k {k}"
        subprocess.run(command, shell=True, stdout=open(output_file, "a"), stderr=subprocess.STDOUT)
        print("\n\n",file=open(output_file, "a"))
    else:
        print(f"lean")
        print(f"Running for input path: {input_path}", file=open(output_file, "a"))
        command = f"python3 openllm_pass_rate_new_test.py --input_path {input_path} --output_path {get_output(input_path, k)} --k {k}"
        subprocess.run(command, shell=True, stdout=open(output_file, "a"), stderr=subprocess.STDOUT)
        print("\n\n",file=open(output_file, "a"))
openllm_pass_rate_new_notlean_test.py ADDED
@@ -0,0 +1,265 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import subprocess
3
+ from argparse import ArgumentParser
4
+ import json
5
+ from concurrent.futures import ThreadPoolExecutor
6
+ from tqdm import tqdm
7
+ import glob
8
+ import tempfile
9
+ import random
10
+ from openllm_pass_rate_new_test import get_lean
11
+
12
def wrapped_function(item):
    """Check a single generated Lean snippet with the Lean REPL.

    Writes item['cmd'] to a temp .lean file, pipes it to `lake exe repl`,
    and returns {'results': [...], 'pass_rate': float}.  A run passes when
    the REPL's JSON output has no "messages" key.
    """
    results = []
    passed = 0
    total = 0

    temp_dir = tempfile.gettempdir()
    temp_file = os.path.join(temp_dir, f"test.lean")

    with open(temp_file, "w") as f:
        f.write(item['cmd'])

    data = '{"path": "%s", "allTactics": true}' % (temp_file)
    command = 'echo \'%s\' | lake exe repl' % data

    try:
        result = subprocess.run(command, shell=True, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        stdout = result.stdout.decode('utf-8')
        stderr = result.stderr.decode('utf-8')
        json_stdout = json.loads(stdout)
        if "messages" not in json_stdout.keys():
            passed += 1
        results.append({'stdout': stdout, 'stderr': stderr, 'status': 'pass'})
    except (subprocess.CalledProcessError, json.JSONDecodeError) as e:
        results.append({'error': str(e), 'status': 'nopass'})
    total += 1

    # BUG FIX: was `passed / (passed + total)`, which double-counts passes
    # (total already counts every attempt), so a fully passing run reported
    # 50% instead of 100%.
    pass_rate = passed / total * 100 if total else 0.0

    return {'results': results, 'pass_rate': pass_rate}
50
+
51
+ # Set the directory where your .lean files are located
52
+
53
+ # Get a list of all .lean files in the directory
54
+ # lean_files = [f for f in os.listdir(directory) if f.endswith(".lean")]
55
+ # lean_files = ["test/file.lean"]
56
def single(command_list, args):
    """Check each candidate sequentially via test/test.lean and the REPL;
    dump per-item outcomes plus the overall pass rate to results.json.
    (`args` is unused; kept for call-site compatibility.)"""
    outcomes = []
    n_pass, n_total = 0, 0
    for entry in tqdm(command_list):
        with open("test/test.lean", "w", encoding='utf-8') as fh:
            fh.write(entry['cmd'])
        payload = '{"path": "test/test.lean", "allTactics": true}'
        shell_cmd = 'echo \'%s\' | lake exe repl' % payload
        try:
            proc = subprocess.run(shell_cmd, shell=True, check=True,
                                  stdout=subprocess.PIPE, stderr=subprocess.PIPE)
            out_text = proc.stdout.decode('utf-8')
            # "messages" appears only when the REPL reports diagnostics.
            if "messages" not in json.loads(out_text).keys():
                n_pass += 1
            outcomes.append({
                'stdout': out_text,
                'stderr': proc.stderr.decode('utf-8'),
                'status': 'pass'
            })
        except subprocess.CalledProcessError as exc:
            outcomes.append({'error': str(exc), 'status': 'nopass'})
        n_total += 1

    # Calculate pass rate
    rate = n_pass / n_total * 100
    print(rate)

    # Save results to a JSON file
    with open('results.json', 'w') as fh:
        json.dump({'results': outcomes, 'pass_rate': rate}, fh, indent=2, ensure_ascii=False)
98
+
99
+
100
def multi(command_list, output_path, k):
    """Check every candidate proof of every item in parallel; report pass@1/pass@k.

    Each item carries candidate Lean sources in item['cmd'].  Every candidate
    is written to its own file under /opt/jianqiao and checked with
    `lake exe repl` (10-minute cap per candidate).  Aggregated results are
    written to pass_rate_results/<output_path>.
    """
    results = []
    passed = 0
    total = 0

    def execute_command(item, index):
        temp_dir = '/opt/jianqiao'

        def filter_json(json_data):
            # Kept for reference; this variant persists the whole item.
            filtered_data = {}
            for key in json_data.keys():
                if key in ['question', 'answer', 'total output', 'results']:
                    filtered_data[key] = json_data[key]
            return filtered_data

        # result_dict = filter_json(item)
        result_dict = item
        result_dict['results'] = []

        for i, cmd in enumerate(item['cmd']):
            temp_file = os.path.join(temp_dir, f"{index}_test_{i}.lean")  # Ensure unique filenames
            with open(temp_file, "w") as f:
                f.write(cmd)

            data = '{"path": "%s", "allTactics": true}' % temp_file
            command = f'echo \'{data}\' | lake exe repl'

            try:
                result = subprocess.run(command, shell=True, check=True, timeout=600,
                                        stdout=subprocess.PIPE, stderr=subprocess.PIPE)
                stdout = json.loads(result.stdout.decode('utf-8'))
                stderr = result.stderr.decode('utf-8')
            except subprocess.TimeoutExpired as e:
                result_item = {'error': str(e), 'status': 'nopass_limit'}
            except subprocess.CalledProcessError as e:
                result_item = {'error': str(e), 'status': 'nopass_error'}
            except json.JSONDecodeError as e:
                # BUG FIX: non-JSON REPL output used to raise out of the
                # worker thread and abort the evaluation at future.result().
                result_item = {'error': str(e), 'status': 'nopass_error'}
            else:
                if "messages" not in stdout and not len(stderr):
                    result_item = {'stdout': stdout, 'stderr': stderr, 'status': 'pass'}
                elif not len(stderr) and "messages" in stdout:
                    flag = 0
                    for me in stdout['messages']:
                        if me['severity'] == 'error':
                            flag = 1
                            # Convert the first error's (line, column) into a
                            # flat character offset within the snippet.
                            start_line = me['pos']['line'] - 1
                            current_column = me['pos']['column'] - 1
                            for line_n in range(start_line - 1, 0, -1):
                                line_len = len(cmd.split('\n')[line_n])
                                current_column += line_len + 1
                                if not line_len:
                                    break
                            result_item = {'stdout': stdout, 'stderr': stderr,
                                           'status': 'nopass', 'string_pos': current_column}
                            break
                    if not flag:
                        # Only warnings/infos: still counts as a pass.
                        result_item = {'stdout': stdout, 'stderr': stderr, 'status': 'pass'}
                else:
                    assert len(stderr)
                    result_item = {'stdout': stdout, 'stderr': stderr, 'status': 'nopass', 'string_pos': 0}

            result_dict['results'].append(result_item)
        return result_dict

    total = len(command_list)

    with ThreadPoolExecutor(max_workers=128) as executor:
        futures = [executor.submit(execute_command, cmd, i) for i, cmd in enumerate(command_list)]
        for future in tqdm(futures, total=total, desc="Processing Commands"):
            results.append(future.result())

    def calculate_pass(result_list, k):
        # pass@1: first candidate passes; pass@k: any of the first k passes.
        pass_1_count = 0
        pass_k_count = 0
        for result in result_list:
            candidate_results = result.get('results', [])
            if candidate_results:
                for j in range(min(1, len(candidate_results))):
                    if candidate_results[j].get('status') == 'pass':
                        pass_1_count += 1
                        break
                for j in range(min(k, len(candidate_results))):
                    if candidate_results[j].get('status') == 'pass':
                        pass_k_count += 1
                        break
        pass_1 = pass_1_count / len(result_list) if result_list else 0
        pass_k = pass_k_count / len(result_list) if result_list else 0
        return pass_1, pass_k

    pass_1, pass_k = calculate_pass(results, k)
    print('total len:', len(results))
    print("Pass@1:", pass_1)
    print(f"Pass@{k}:", pass_k)

    output_file = f"pass_rate_results/{output_path}"
    # Create the directory if it doesn't exist
    os.makedirs(os.path.dirname(output_file), exist_ok=True)

    with open(output_file, 'w') as f:
        json.dump({'results': results, 'pass_1': pass_1, f"pass_{k}": pass_k}, f, indent=2, ensure_ascii=False)
209
+
210
import re


def remove_simp_pattern_from_end(s):
    """Drop a trailing `@[simp ...]` attribute annotation from *s*."""
    match = re.search(r'@\[simp\s*.*?\]$', s)
    return s[:match.start()] if match else s
214
+
215
+
216
+
217
+
218
def main(args):
    """Score generated outputs stored as JSON-lines shards under input_path.

    Each record's first k candidates in 'total output' are prefixed with
    the shared working environment from data/basic_working.json and sent
    to `multi` for pass@1 / pass@k evaluation.
    """
    command_list = []
    # The shared Lean preamble each candidate proof is appended to.
    # json_filename = 'data/notlean_dependency.json'
    json_filename = 'data/basic_working.json'
    with open(json_filename, encoding='utf-8') as jf:
        working_env = json.load(jf)['working_file']

    file_pattern = os.path.join(args.input_path, '[0-9]*.json')
    for file_path in glob.glob(file_pattern):
        with open(file_path, 'r', encoding='utf-8') as rf:
            for line_no, line in enumerate(rf, 1):
                try:
                    json_item = json.loads(line)
                    json_item['cmd'] = []
                    candidates = json_item['total output'][:min(args.k, len(json_item['total output']))]
                    for output in candidates:
                        # llemma models echo the few-shot prompt after '###'.
                        if "llemma" in args.input_path:
                            output = output.split('###')[0]
                        statement = get_lean(output.strip(), args.input_path)
                        json_item['cmd'].append('\n\n'.join([working_env, statement]))
                    json_item['answer'] = json_item['content']['answer']
                except (json.JSONDecodeError, KeyError, TypeError) as e:
                    # Previously a bare `except:` opened pdb (hanging batch
                    # runs) and the broken record was appended anyway.
                    print(f"skipping {file_path}:{line_no}: {e!r}")
                    continue
                command_list.append(json_item)

    multi(command_list, args.output_path, args.k)
246
+
247
+
248
+
249
+
250
if __name__ == '__main__':
    # Command-line interface.  Only --input_path/--output_path/--k are
    # consumed by main(); the other flags exist for pipeline compatibility.
    cli = ArgumentParser()
    cli.add_argument(
        '--data_path', type=str,
        default='data/grade-school-math-master/grade_school_math/data/test.jsonl')
    cli.add_argument('--input_path', type=str, default='')
    cli.add_argument('--cuda_num', type=int, default=8)
    cli.add_argument('--k', type=int, default=1)
    cli.add_argument('--output_path', type=str, default='total.json')
    cli.add_argument(
        '--generate_method', type=str,
        choices=['single', 'sft', 'comp', 'self_consistency', 'single_consistency'])
    cli.add_argument('--method', type=str, choices=['main', 'test', 'get_data'])
    main(cli.parse_args())
263
+
264
+
265
+
openllm_pass_rate_new_test.py ADDED
@@ -0,0 +1,306 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import subprocess
3
+ from argparse import ArgumentParser
4
+ import json
5
+ from concurrent.futures import ThreadPoolExecutor
6
+ from tqdm import tqdm
7
+ import glob
8
+ import tempfile
9
+ import random
10
+
11
def wrapped_function(item):
    """Check a single generated Lean snippet with the Lean REPL.

    Writes item['cmd'] to a temp .lean file, pipes it to `lake exe repl`,
    and returns {'results': [...], 'pass_rate': float}.  A run passes when
    the REPL's JSON output has no "messages" key.
    """
    results = []
    passed = 0
    total = 0

    temp_dir = tempfile.gettempdir()
    temp_file = os.path.join(temp_dir, f"test.lean")

    with open(temp_file, "w") as f:
        f.write(item['cmd'])

    data = '{"path": "%s", "allTactics": true}' % (temp_file)
    command = 'echo \'%s\' | lake exe repl' % data

    try:
        result = subprocess.run(command, shell=True, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        stdout = result.stdout.decode('utf-8')
        stderr = result.stderr.decode('utf-8')
        json_stdout = json.loads(stdout)
        if "messages" not in json_stdout.keys():
            passed += 1
        results.append({'stdout': stdout, 'stderr': stderr, 'status': 'pass'})
    except (subprocess.CalledProcessError, json.JSONDecodeError) as e:
        results.append({'error': str(e), 'status': 'nopass'})
    total += 1

    # BUG FIX: was `passed / (passed + total)`, which double-counts passes
    # (total already counts every attempt), so a fully passing run reported
    # 50% instead of 100%.
    pass_rate = passed / total * 100 if total else 0.0

    return {'results': results, 'pass_rate': pass_rate}
49
+
50
+ # Set the directory where your .lean files are located
51
+
52
+ # Get a list of all .lean files in the directory
53
+ # lean_files = [f for f in os.listdir(directory) if f.endswith(".lean")]
54
+ # lean_files = ["test/file.lean"]
55
def single(command_list, args):
    """Run every candidate one at a time through test/test.lean and the
    Lean REPL; write outcomes and the overall pass rate to results.json.
    (`args` is unused; kept for call-site compatibility.)"""
    records = []
    ok = 0
    seen = 0
    for candidate in tqdm(command_list):
        with open("test/test.lean", "w", encoding='utf-8') as handle:
            handle.write(candidate['cmd'])
        request = '{"path": "test/test.lean", "allTactics": true}'
        pipeline = 'echo \'%s\' | lake exe repl' % request
        try:
            completed = subprocess.run(pipeline, shell=True, check=True,
                                       stdout=subprocess.PIPE, stderr=subprocess.PIPE)
            reply = completed.stdout.decode('utf-8')
            # A reply without "messages" means the REPL found no issues.
            if "messages" not in json.loads(reply).keys():
                ok += 1
            records.append({
                'stdout': reply,
                'stderr': completed.stderr.decode('utf-8'),
                'status': 'pass'
            })
        except subprocess.CalledProcessError as err:
            records.append({'error': str(err), 'status': 'nopass'})
        seen += 1

    # Calculate pass rate
    overall = ok / seen * 100
    print(overall)

    # Save results to a JSON file
    with open('results.json', 'w') as handle:
        json.dump({'results': records, 'pass_rate': overall}, handle, indent=2, ensure_ascii=False)
97
+
98
+
99
def multi(command_list, output_path, k ):
    """Check up to *k* candidate Lean proofs per item via `lake exe repl`.

    Each item in *command_list* carries its candidate sources in item['cmd'].
    Per-candidate results (pass / nopass plus a flat string error position)
    are collected and written to pass_rate_results/<output_path> together
    with pass@1 and pass@k.
    """
    results = []

    def execute_command(item, index):
        # Scratch directory for the per-candidate .lean files.
        temp_dir = '/opt/jianqiao'

        def filter_json(json_data):
            # Keep only the fields we want to carry into the result record.
            filtered_data = {}
            for key in json_data.keys():
                if key in ['question', 'answer', 'total output', 'results']:
                    filtered_data[key] = json_data[key]
            return filtered_data

        result_dict = filter_json(item)
        result_dict['results'] = []

        for i, cmd in enumerate(item['cmd']):
            # Unique filename per (item, candidate) so worker threads don't collide.
            temp_file = os.path.join(temp_dir, f"{index}_test_{i}.lean")
            with open(temp_file, "w") as f:
                f.write(cmd)

            data = '{"path": "%s", "allTactics": true}' % temp_file
            command = f'echo \'{data}\' | lake exe repl'

            try:
                result = subprocess.run(command, shell=True, check=True, timeout=600,
                                        stdout=subprocess.PIPE, stderr=subprocess.PIPE)
                stdout = json.loads(result.stdout.decode('utf-8'))
                stderr = result.stderr.decode('utf-8')
            except subprocess.TimeoutExpired as e:
                result_item = {'error': str(e), 'status': 'nopass_limit'}
            except subprocess.CalledProcessError as e:
                result_item = {'error': str(e), 'status': 'nopass_error'}
            else:
                if "messages" not in stdout and not len(stderr):
                    # REPL produced no diagnostics at all: clean pass.
                    result_item = {'stdout': stdout, 'stderr': stderr, 'status': 'pass' }
                elif not len(stderr) and "messages" in stdout:
                    flag = 0
                    for me in stdout['messages']:
                        if me['severity'] == 'error':
                            flag = 1
                            # Convert (line, column) into a flat character offset by
                            # walking back over preceding lines until a blank line.
                            start_line = me['pos']['line'] - 1
                            current_column = me['pos']['column'] - 1
                            for line_n in range(start_line - 1, 0, -1):
                                line_len = len(cmd.split('\n')[line_n])
                                current_column += line_len + 1
                                if not line_len:
                                    break
                            result_item = {'stdout': stdout, 'stderr': stderr,
                                           'status': 'nopass', 'string_pos': current_column}
                            break
                    if not flag:
                        # Messages present but none of severity 'error' (warnings only).
                        result_item = {'stdout': stdout, 'stderr': stderr, 'status': 'pass'}
                else:
                    assert len(stderr)
                    result_item = {'stdout': stdout, 'stderr': stderr,
                                   'status': 'nopass', 'string_pos': 0 }
            finally:
                # BUG FIX: scratch files were never deleted and accumulated on disk.
                try:
                    os.remove(temp_file)
                except OSError:
                    pass

            result_dict['results'].append(result_item)
        return result_dict

    total = len(command_list)

    with ThreadPoolExecutor(max_workers=128) as executor:
        futures = [executor.submit(execute_command, cmd, i) for i, cmd in enumerate(command_list)]
        for future in tqdm(futures, total=total, desc="Processing Commands"):
            results.append(future.result())

    def calculate_pass(result_list, k):
        """Return (pass@1, pass@k): the fraction of items whose first 1 /
        first k candidates contain at least one pass."""
        pass_1_count = 0
        pass_k_count = 0
        for result in result_list:
            item_results = result.get('results', [])
            if any(r.get('status') == 'pass' for r in item_results[:1]):
                pass_1_count += 1
            if any(r.get('status') == 'pass' for r in item_results[:k]):
                pass_k_count += 1
        pass_1 = pass_1_count / len(result_list) if result_list else 0
        pass_k = pass_k_count / len(result_list) if result_list else 0
        return pass_1, pass_k

    pass_1, pass_k = calculate_pass(results, k)
    print('total len:', len(results))
    print("Pass@1:", pass_1)
    print(f"Pass@{k}:", pass_k)

    output_file = f"pass_rate_results/{output_path}"
    # Create the directory if it doesn't exist.
    os.makedirs(os.path.dirname(output_file), exist_ok=True)

    with open(f"{output_file}", 'w') as f:
        json.dump({'results': results, 'pass_1': pass_1, f"pass_{k}": pass_k}, f, indent=2, ensure_ascii=False)
207
+
208
import re

# Trailing Lean attribute of the form `@[simp ...]` anchored at end of string.
_SIMP_SUFFIX = re.compile(r'@\[simp\s*.*?\]$')

def remove_simp_pattern_from_end(s):
    """Strip a trailing `@[simp ...]` attribute from *s*, if one is present."""
    return _SIMP_SUFFIX.sub('', s)
212
+
213
+
214
+
215
def get_lean(text, input_path):
    """Extract Lean source from a model completion *text*.

    The fenced-code dialect depends on which model produced the text, which is
    inferred from *input_path*. Falls back to any ``` fenced block, and finally
    to a dummy theorem so callers always receive non-empty Lean source.

    Raises NotImplementedError for an unrecognized model path.
    """
    # BUG FIX: `content` was previously unbound (NameError) whenever the
    # fallback branch found no ``` blocks; initialize it up front.
    content = ""
    if any(x in input_path for x in ["deepseek-math-7b-instruct", "deepseek-math-7b-base", "llemma_34b", "llemma_7b"]):
        try:
            code_block_pattern = r"```lean4\s*\n(.*?)\n```"
            code_blocks = re.findall(code_block_pattern, text, re.DOTALL)
            content = "\n\n".join(code_blocks)
        except Exception:
            # Fall back to any fenced block.
            matches = re.findall(r'```(.*?)```', text, re.DOTALL)
            if len(matches):
                content = "\n\n".join(matches)
    elif any(x in input_path for x in ["internlm2-math"]):
        try:
            # internlm emits the language tag on its own line: ```\nlean\n...
            code_block_pattern = r"```\nlean\s*\n(.*?)\n```"
            code_blocks = re.findall(code_block_pattern, text, re.DOTALL)
            content = "\n\n".join(code_blocks)
        except Exception:
            matches = re.findall(r'```(.*?)```', text, re.DOTALL)
            if len(matches):
                content = "\n\n".join(matches)
    elif any(x in input_path for x in ["Mistral-7B-Instruct-v0.2"]):
        try:
            code_block_pattern = r"```lean\s*\n(.*?)\n```"
            code_blocks = re.findall(code_block_pattern, text, re.DOTALL)
            content = "\n\n".join(code_blocks)
        except Exception:
            matches = re.findall(r'```(.*?)```', text, re.DOTALL)
            if len(matches):
                content = "\n\n".join(matches)
    else:
        raise NotImplementedError("not implmemented")

    if not len(content.strip()):
        # Placeholder that downstream checking will reliably reject.
        content = "theorem h : f + g = 39 := by exact rfl"
    return content
262
+
263
def main(args):
    """Collect generation JSONL files under args.input_path, build Lean check
    commands for the top-k outputs of each item, and hand them to multi()."""
    command_list = []
    file_pattern = os.path.join(args.input_path, '[0-9]*.json')
    for file_path in glob.glob(file_pattern):
        with open(file_path, 'r', encoding='utf-8') as rf:
            for line in rf.readlines():
                try:
                    json_item = json.loads(line)
                    working_env = json_item['content']['working_file']
                    json_item['cmd'] = []
                    # Only check the first k generations (or fewer if not enough).
                    for output in json_item['total output'][:min(args.k, len(json_item['total output']))]:
                        if "llemma" in args.input_path:
                            # llemma completions terminate the answer at '###'.
                            output = output.split('###')[0]
                        statement = get_lean(output.strip(), args.input_path)
                        json_item['cmd'].append('\n\n'.join([working_env, statement]))
                    json_item['answer'] = json_item['content']['statement_poof']
                except:
                    # NOTE(review): bare except drops into the debugger on any
                    # malformed record; this hangs non-interactive runs.
                    import pdb
                    pdb.set_trace()
                command_list.append(json_item)
    multi(command_list, args.output_path, args.k)
290
+
291
if __name__ == '__main__':
    # CLI entry point; main() only consumes --input_path, --output_path and --k.
    arg_parser = ArgumentParser()
    arg_parser.add_argument('--data_path', type=str,
                            default='data/grade-school-math-master/grade_school_math/data/test.jsonl')
    arg_parser.add_argument('--input_path', type=str, default='')
    arg_parser.add_argument('--cuda_num', type=int, default=8)
    arg_parser.add_argument('--k', type=int, default=1)
    arg_parser.add_argument('--output_path', type=str, default='total.json')
    arg_parser.add_argument('--generate_method', type=str,
                            choices=['single', 'sft', 'comp', 'self_consistency', 'single_consistency'])
    arg_parser.add_argument('--method', type=str, choices=['main', 'test', 'get_data'])
    args = arg_parser.parse_args()
    main(args)
304
+
305
+
306
+
pass_rate.py ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import subprocess
3
+ from argparse import ArgumentParser
4
+ import json
5
+ from concurrent.futures import ThreadPoolExecutor
6
+ from tqdm import tqdm
7
+ import tempfile
8
+
9
def wrapped_function(item):
    """Check a single Lean snippet (item['cmd']) with `lake exe repl`.

    Writes the snippet to a temp file, asks the REPL to elaborate it, and
    returns {'results': [...], 'pass_rate': ...} where pass_rate is 0 or 100.
    """
    results = []
    passed = 0
    total = 0

    temp_dir = tempfile.gettempdir()
    # NOTE(review): fixed filename — concurrent callers would clobber each other.
    temp_file = os.path.join(temp_dir, "test.lean")

    with open(temp_file, "w") as f:
        f.write(item['cmd'])

    data = '{"path": "%s", "allTactics": true}' % (temp_file)
    command = 'echo \'%s\' | lake exe repl' % data

    try:
        result = subprocess.run(command, shell=True, check=True,
                                stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        stdout = result.stdout.decode('utf-8')
        stderr = result.stderr.decode('utf-8')
        json_stdout = json.loads(stdout)
        # No "messages" key means the REPL reported no diagnostics at all.
        if "messages" not in json_stdout.keys():
            passed += 1
        # NOTE(review): status is 'pass' even when diagnostics exist; only the
        # counter is gated — kept as-is for output compatibility.
        results.append({'stdout': stdout, 'stderr': stderr, 'status': 'pass'})
    except subprocess.CalledProcessError as e:
        results.append({'error': str(e), 'status': 'nopass'})
    finally:
        # BUG FIX: the scratch file was never cleaned up.
        try:
            os.remove(temp_file)
        except OSError:
            pass
    total += 1

    # BUG FIX: was passed / (passed + total), which reported 50% for a pass.
    pass_rate = passed / total * 100

    return {'results': results, 'pass_rate': pass_rate}
47
+
48
+ # Set the directory where your .lean files are located
49
+
50
+ # Get a list of all .lean files in the directory
51
+ # lean_files = [f for f in os.listdir(directory) if f.endswith(".lean")]
52
+ # lean_files = ["test/file.lean"]
53
def single(command_list):
    """Sequentially check each item's Lean snippet via `lake exe repl`.

    Each item's 'cmd' is written to test/test.lean, elaborated, and counted as
    a pass when the REPL emits no diagnostics; results go to results.json.
    """
    results = []
    passed = 0
    total = 0
    for item in tqdm(command_list):
        with open("test/test.lean", "w", encoding='utf-8') as f:
            f.write(item['cmd'])
        data = '{"path": "test/test.lean", "allTactics": true}'
        command = 'echo \'%s\' | lake exe repl' % data
        try:
            result = subprocess.run(command, shell=True, check=True,
                                    stdout=subprocess.PIPE, stderr=subprocess.PIPE)
            stdout = result.stdout.decode('utf-8')
            json_stdout = json.loads(stdout)
            # No "messages" key means no diagnostics from the REPL.
            if "messages" not in json_stdout.keys():
                passed += 1
            stderr = result.stderr.decode('utf-8')
            results.append({
                'stdout': stdout,
                'stderr': stderr,
                'status': 'pass'
            })
        except subprocess.CalledProcessError as e:
            results.append({
                'error': str(e),
                'status': 'nopass'
            })
        total += 1

    # BUG FIX: guard against an empty command_list (ZeroDivisionError).
    pass_rate = passed / total * 100 if total else 0.0
    print(pass_rate)

    # Save results to a JSON file.
    with open('results.json', 'w') as f:
        json.dump({'results': results, 'pass_rate': pass_rate}, f, indent=2, ensure_ascii=False)
95
+
96
+
97
+
98
+
99
def multi(command_list):
    """Run every item's Lean snippet through `lake exe repl` in parallel and
    report the overall pass rate; full results are written to results.json."""
    results = []
    passed = 0
    total = 0

    def execute_command(item):
        temp_dir = '/data/tmp'
        temp_file = os.path.join(temp_dir, f"test_{item['index']}.lean")  # unique per item
        with open(temp_file, "w") as f:
            f.write(item['cmd'])

        data = '{"path": "%s", "allTactics": true}' % temp_file
        command = f'echo \'{data}\' | lake exe repl'

        try:
            result = subprocess.run(command, shell=True, check=True,
                                    stdout=subprocess.PIPE, stderr=subprocess.PIPE)
            stdout = result.stdout.decode('utf-8')
            stderr = result.stderr.decode('utf-8')

            if "messages" not in json.loads(stdout):
                return {'stdout': stdout, 'stderr': stderr, 'status': 'pass'}
            return {'stdout': stdout, 'stderr': stderr, 'status': 'nopass'}
        except subprocess.CalledProcessError as e:
            return {'error': str(e), 'status': 'nopass'}
        finally:
            # BUG FIX: os.remove previously sat after the return statements and
            # never executed, leaking one temp file per item.
            try:
                os.remove(temp_file)
            except OSError:
                pass

    total = len(command_list)

    with ThreadPoolExecutor(max_workers=32) as executor:
        futures = [executor.submit(execute_command, {'index': i, 'cmd': cmd['cmd']}) for i, cmd in enumerate(command_list)]
        for future in tqdm(futures, total=total, desc="Processing Commands"):
            result = future.result()
            results.append(result)
            if result['status'] == 'pass':
                passed += 1

    pass_rate = (passed / total) * 100
    print(f"Pass rate: {pass_rate}%")

    with open('results.json', 'w') as f:
        json.dump({'results': results, 'pass_rate': pass_rate}, f, indent=2, ensure_ascii=False)
142
+
143
import re

# Trailing Lean attribute of the form `@[simp ...]` anchored at end of string.
_TRAILING_SIMP = re.compile(r'@\[simp\s*.*?\]$')

def remove_simp_pattern_from_end(s):
    """Strip a trailing `@[simp ...]` attribute from *s*, if one is present."""
    return _TRAILING_SIMP.sub('', s)
147
+
148
def main(args):
    """Build one Lean check command per generation item from 0..3.json under
    args.input_path and smoke-test the first one via single()."""
    command_list = []
    for i in range(4):
        with open(f"{args.input_path}/{i}.json", 'r', encoding='utf-8') as rf:
            for line in rf.readlines():
                try:
                    json_item = json.loads(line)
                    # Trim the working file back to the last blank line near its end.
                    working_env = json_item['content']['working_file'].split('\n')
                    for loc in range(len(working_env) - 3, 0, -1):
                        if not len(working_env[loc].strip()):
                            break

                    working_env = '\n'.join(working_env[:loc] + ['\n'])
                    # Keep the generated statement only up to its first blank line.
                    statement = json_item['total output'][0].split('\n')
                    for loc in range(len(statement)):
                        if not len(statement[loc].strip()):
                            break
                    statement = '\n'.join(statement[:loc] + ['\n'])
                    json_item['cmd'] = '\n'.join([working_env, statement])
                    assert len(statement) > 0
                except:
                    # NOTE(review): bare except drops into the debugger on any
                    # malformed record; this hangs non-interactive runs.
                    import pdb
                    pdb.set_trace()
                command_list.append(json_item)
    command_list = command_list  # no-op kept from the original
    results = []
    passed = 0
    total = 0
    single(command_list[:1])
181
+
182
if __name__ == '__main__':
    # CLI entry point; main() only consumes --input_path.
    arg_parser = ArgumentParser()
    arg_parser.add_argument('--data_path', type=str,
                            default='data/grade-school-math-master/grade_school_math/data/test.jsonl')
    arg_parser.add_argument('--input_path', type=str, default='')
    arg_parser.add_argument('--cuda_num', type=int, default=8)
    arg_parser.add_argument('--output_path', type=str, default='total.json')
    arg_parser.add_argument('--generate_method', type=str,
                            choices=['single', 'sft', 'comp', 'self_consistency', 'single_consistency'])
    arg_parser.add_argument('--method', type=str, choices=['main', 'test', 'get_data'])
    args = arg_parser.parse_args()
    main(args)
194
+
pass_rate_atp_pass.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Driver script: batch-runs the pass-rate evaluators over result directories.
import pdb
import subprocess
import re

# Output file
output_file = "pass_rate_output.txt"

# Clearing the output file before appending new content
with open(output_file, "w") as file:
    file.write("")

# List of input paths
# NOTE(review): this list is overwritten by a second assignment further down
# in the script, so these entries are effectively dead.
input_path_lists = [
    "test/zero_shot/wild_test/generation/lean4_random_15k_all/2/1/",
]
16
+
17
def get_output(input_string, k):
    """Derive an output path '<testset>/<model>pass<k>.jsonl' from a generation path.

    Expects *input_string* to contain 'zero_shot/<testset>/.../<model>'.
    Raises ValueError when the pattern is absent.
    """
    pattern = r"zero_shot/(\w+)/(.+?)/(\w+)"
    match = re.search(pattern, input_string)
    if not match:
        # BUG FIX: the original did `assert True` (a no-op) and then hit a
        # NameError on `result`; fail loudly with a real exception instead.
        print("No match found.")
        raise ValueError(f"unrecognized input path: {input_string}")
    part1 = match.group(1)
    part2 = match.group(3) + f"pass{k}.jsonl"
    result = "/".join([part1, part2])
    print(result)
    return result
29
+
30
+ # List of input paths
31
+ input_path_lists = [
32
+ # "../auto-info/generate_result/zero_shot/gsm8k_train/generation/lean4_random_15k_all/2/10/",
33
+ # "../auto-info/generate_result/zero_shot/math_train/generation/lean4_random_15k_all/2/10/",
34
+ # "../auto-info/generate_result/zero_shot/wild_test/generation/lean4_rft/1/1",
35
+ # "../auto-info/generate_result/zero_shot/wild_test/generation/lean4_rft/2/1",
36
+ # "../auto-info/generate_result/zero_shot/wild_test/generation/lean4_rft/3/1",
37
+ # "../auto-info/generate_result/zero_shot/wild_test/generation/lean4_verifier/1/1",
38
+ # "../auto-info/generate_result/zero_shot/wild_test/generation/lean4_verifier/2/1",
39
+ # "../auto-info/generate_result/zero_shot/wild_test/generation/lean4_verifier/3/1",
40
+ # "../auto-info/generate_result/zero_shot/wild_test/generation/lean4_verifier_rft/1/1",
41
+ # "../auto-info/generate_result/zero_shot/wild_test/generation/lean4_verifier_rft/2/1",
42
+ # "../auto-info/generate_result/zero_shot/wild_test/generation/lean4_verifier_rft/3/1",
43
+ # "/opt/tiger/auto-info/generate_result/zero_shot/lean4_basic_test/generation/lean4_rft/1/1/",
44
+ # "/opt/tiger/auto-info/generate_result/zero_shot/lean4_basic_test/generation/lean4_rft/2/1/",
45
+ # "/opt/tiger/auto-info/generate_result/zero_shot/lean4_basic_test/generation/lean4_rft/3/1/",
46
+ # "/opt/tiger/auto-info/generate_result/zero_shot/lean4_random_test/generation/lean4_rft/1/1/",
47
+ # "/opt/tiger/auto-info/generate_result/zero_shot/lean4_random_test/generation/lean4_rft/2/1/",
48
+ # "/opt/tiger/auto-info/generate_result/zero_shot/lean4_random_test/generation/lean4_rft/3/1/",
49
+ # "/opt/tiger/auto-info/generate_result/zero_shot/lean4_basic_test/generation/lean4_verifier/1/1/",
50
+ # "/opt/tiger/auto-info/generate_result/zero_shot/lean4_basic_test/generation/lean4_verifier/2/1/",
51
+ # "/opt/tiger/auto-info/generate_result/zero_shot/lean4_basic_test/generation/lean4_verifier/3/1/",
52
+ # "/opt/tiger/auto-info/generate_result/zero_shot/lean4_random_test/generation/lean4_verifier/1/1/",
53
+ # "/opt/tiger/auto-info/generate_result/zero_shot/lean4_random_test/generation/lean4_verifier/2/1/",
54
+ # "/opt/tiger/auto-info/generate_result/zero_shot/lean4_random_test/generation/lean4_verifier/3/1/",
55
+ # "/opt/tiger/auto-info/generate_result/zero_shot/lean4_basic_test/generation/lean4_verifier_rft/1/1/",
56
+ # "/opt/tiger/auto-info/generate_result/zero_shot/lean4_basic_test/generation/lean4_verifier_rft/2/1/",
57
+ # "/opt/tiger/auto-info/generate_result/zero_shot/lean4_basic_test/generation/lean4_verifier_rft/3/1/",
58
+ # "/opt/tiger/auto-info/generate_result/zero_shot/lean4_random_test/generation/lean4_verifier_rft/1/1/",
59
+ # "/opt/tiger/auto-info/generate_result/zero_shot/lean4_random_test/generation/lean4_verifier_rft/2/1/",
60
+ # "/opt/tiger/auto-info/generate_result/zero_shot/lean4_random_test/generation/lean4_verifier_rft/3/1/",
61
+ # "/opt/tiger/auto-info/generate_result/zero_shot/wild_test/generation/lean4_rft/1/1/",
62
+ # "/opt/tiger/auto-info/generate_result/zero_shot/wild_test/generation/lean4_rft/2/1/",
63
+ # "/opt/tiger/auto-info/generate_result/zero_shot/wild_test/generation/lean4_rft/3/1/",
64
+ # "/opt/tiger/auto-info/generate_result/zero_shot/wild_test/generation/lean4_verifier/1/1/",
65
+ # "/opt/tiger/auto-info/generate_result/zero_shot/wild_test/generation/lean4_verifier/2/1/",
66
+ # "/opt/tiger/auto-info/generate_result/zero_shot/wild_test/generation/lean4_verifier/3/1/",
67
+ # "/opt/tiger/auto-info/generate_result/zero_shot/wild_test/generation/lean4_verifier_rft/1/1/",
68
+ # "/opt/tiger/auto-info/generate_result/zero_shot/wild_test/generation/lean4_verifier_rft/2/1/",
69
+ # "/opt/tiger/auto-info/generate_result/zero_shot/wild_test/generation/lean4_verifier_rft/3/1/",
70
+ # "/opt/tiger/auto-info/generate_result/zero_shot/lean4_15k_train/generation/lean4_random_15k_all/2/20/",
71
+ # "/opt/tiger/auto-info/generate_result/zero_shot/lean4_basic_test/generation/lean4_random_15k_all/2/5/",
72
+ # "/opt/tiger/auto-info/generate_result/zero_shot/lean4_random_test/generation/lean4_random_15k_all/2/5/",
73
+ # "/opt/tiger/mariana/auto-info/generate_result/zero_shot/lean4_random_test/generation/lean4_random_5k/2/1/",
74
+ # "test/zero_shot/lean4_random_test/generation/lean4_random_15k_all/3/1/",
75
+ # "/opt/tiger/mariana/auto-info/generate_result/zero_shot/lean4_basic_test/generation/lean4_random_15k_all/2/1/",
76
+ # "/opt/tiger/mariana/auto-info/generate_result/zero_shot/lean4_random_test/generation/lean4_random_15k_all/2/1/",
77
+ # "test/zero_shot/lean4_random_test/generation/lean4_random_15k_all/3/1/",
78
+ # "test/zero_shot/lean4_basic_test/generation/lean4_random_15k_all/3/1/",
79
+ # "/opt/tiger/auto-info/generate_result/zero_shot/lean4_basic_test/generation/lean4_random_15k_all_mathrft/1/1/",
80
+ # "/opt/tiger/auto-info/generate_result/zero_shot/lean4_random_test/generation/lean4_random_15k_all_mathrft/1/1/",
81
+ # "/opt/tiger/auto-info/generate_result/zero_shot/wild_test/generation/lean4_random_15k_all_mathrft/1/1/",
82
+ # "/opt/tiger/auto-info/generate_result/zero_shot/lean4_basic_test/generation/lean4_random_15k_all_mathrft/2/1/",
83
+ # "/opt/tiger/auto-info/generate_result/zero_shot/lean4_random_test/generation/lean4_random_15k_all_mathrft/2/1/",
84
+ # "/opt/tiger/auto-info/generate_result/zero_shot/wild_test/generation/lean4_random_15k_all_mathrft/2/1/",
85
+ # "/opt/tiger/auto-info/generate_result/zero_shot/lean4_basic_test/generation/lean4_random_15k_all_mathrft/3/1/",
86
+ # "/opt/tiger/auto-info/generate_result/zero_shot/lean4_random_test/generation/lean4_random_15k_all_mathrft/3/1/",
87
+ # "/opt/tiger/auto-info/generate_result/zero_shot/wild_test/generation/lean4_random_15k_all_mathrft/3/1/",
88
+ # "/opt/tiger/auto-info/generate_result/zero_shot/gsm8k_train/generation/lean4_random_15k_all_mathrft/2/10/",
89
+ # "/opt/tiger/auto-info/generate_result/zero_shot/math_train/generation/lean4_random_15k_all_mathrft/2/10/",
90
+ # "/opt/tiger/auto-info/generate_result/zero_shot/lean4_15k_train/generation/lean4_random_15k_all_mathrft/2/10/",
91
+ # Add more input paths as needed
92
+ "/opt/tiger/auto-info/generate_result/zero_shot/lean4_basic_test/generation/lean4_random_15k_all_mathrft/2/5/",
93
+ "/opt/tiger/auto-info/generate_result/zero_shot/lean4_random_test/generation/lean4_random_15k_all_mathrft/2/5/",
94
+ "/opt/tiger/auto-info/generate_result/zero_shot/wild_test/generation/lean4_random_15k_all_mathrft/2/5/",
95
+ ]
96
+
97
# Iterate through the input paths and run the matching evaluator; all child
# output is appended to pass_rate_output.txt.
for input_path in input_path_lists:
    k = 5
    # "wild"/GSM8K/MATH sets have no Lean ground truth and use the notlean checker.
    if "wild_test" in input_path or "gsm8k_train" in input_path or "math_train" in input_path:
        print(f"wild")
        print(f"Running for input path: {input_path}", file=open(output_file, "a"))
        command = f"python3 pass_rate_notlean_test.py --input_path {input_path} --output_path {get_output(input_path,k)} --k {k}"
        # NOTE(review): the open(...) handles passed to print/run are never closed.
        subprocess.run(command, shell=True, stdout=open(output_file, "a"), stderr=subprocess.STDOUT)
        print("\n\n",file=open(output_file, "a"))

    else:
        print(f"lean")
        print(f"Running for input path: {input_path}", file=open(output_file, "a"))
        command = f"python3 pass_rate_new_test.py --input_path {input_path} --output_path {get_output(input_path, k)} --k {k}"
        subprocess.run(command, shell=True, stdout=open(output_file, "a"), stderr=subprocess.STDOUT)
        print("\n\n",file=open(output_file, "a"))
pass_rate_atp_test.py ADDED
@@ -0,0 +1,264 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import subprocess
3
+ from argparse import ArgumentParser
4
+ import json
5
+ from concurrent.futures import ThreadPoolExecutor
6
+ from tqdm import tqdm
7
+ import glob
8
+ import tempfile
9
+ import random
10
+
11
def wrapped_function(item):
    """Check a single Lean snippet (item['cmd']) with `lake exe repl`.

    Writes the snippet to a temp file, asks the REPL to elaborate it, and
    returns {'results': [...], 'pass_rate': ...} where pass_rate is 0 or 100.
    """
    results = []
    passed = 0
    total = 0

    temp_dir = tempfile.gettempdir()
    # NOTE(review): fixed filename — concurrent callers would clobber each other.
    temp_file = os.path.join(temp_dir, "test.lean")

    with open(temp_file, "w") as f:
        f.write(item['cmd'])

    data = '{"path": "%s", "allTactics": true}' % (temp_file)
    command = 'echo \'%s\' | lake exe repl' % data

    try:
        result = subprocess.run(command, shell=True, check=True,
                                stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        stdout = result.stdout.decode('utf-8')
        stderr = result.stderr.decode('utf-8')
        json_stdout = json.loads(stdout)
        # No "messages" key means the REPL reported no diagnostics at all.
        if "messages" not in json_stdout.keys():
            passed += 1
        # NOTE(review): status is 'pass' even when diagnostics exist; only the
        # counter is gated — kept as-is for output compatibility.
        results.append({'stdout': stdout, 'stderr': stderr, 'status': 'pass'})
    except subprocess.CalledProcessError as e:
        results.append({'error': str(e), 'status': 'nopass'})
    finally:
        # BUG FIX: the scratch file was never cleaned up.
        try:
            os.remove(temp_file)
        except OSError:
            pass
    total += 1

    # BUG FIX: was passed / (passed + total), which reported 50% for a pass.
    pass_rate = passed / total * 100

    return {'results': results, 'pass_rate': pass_rate}
49
+
50
+ # Set the directory where your .lean files are located
51
+
52
+ # Get a list of all .lean files in the directory
53
+ # lean_files = [f for f in os.listdir(directory) if f.endswith(".lean")]
54
+ # lean_files = ["test/file.lean"]
55
def single(command_list, args):
    """Sequentially check each item's Lean snippet via `lake exe repl`.

    Each item's 'cmd' is written to test/test.lean, elaborated, and counted as
    a pass when the REPL emits no diagnostics; results go to results.json.
    """
    results = []
    passed = 0
    total = 0
    for item in tqdm(command_list):
        with open("test/test.lean", "w", encoding='utf-8') as f:
            f.write(item['cmd'])
        data = '{"path": "test/test.lean", "allTactics": true}'
        command = 'echo \'%s\' | lake exe repl' % data
        try:
            result = subprocess.run(command, shell=True, check=True,
                                    stdout=subprocess.PIPE, stderr=subprocess.PIPE)
            stdout = result.stdout.decode('utf-8')
            json_stdout = json.loads(stdout)
            # No "messages" key means no diagnostics from the REPL.
            if "messages" not in json_stdout.keys():
                passed += 1
            stderr = result.stderr.decode('utf-8')
            results.append({
                'stdout': stdout,
                'stderr': stderr,
                'status': 'pass'
            })
        except subprocess.CalledProcessError as e:
            results.append({
                'error': str(e),
                'status': 'nopass'
            })
        total += 1

    # BUG FIX: guard against an empty command_list (ZeroDivisionError).
    pass_rate = passed / total * 100 if total else 0.0
    print(pass_rate)

    # Save results to a JSON file.
    with open('results.json', 'w') as f:
        json.dump({'results': results, 'pass_rate': pass_rate}, f, indent=2, ensure_ascii=False)
97
+
98
+
99
def multi(command_list, output_path, k ):
    """Check up to *k* candidate Lean proofs per item via `lake exe repl`.

    Each item in *command_list* carries its candidate sources in item['cmd'].
    Per-candidate results (pass / nopass) are collected and written to
    pass_rate_results/<output_path> together with pass@1 and pass@k.
    """
    results = []

    def execute_command(item, index):
        # Scratch directory for the per-candidate .lean files.
        temp_dir = '/opt/jianqiao'

        def filter_json(json_data):
            # Keep only the fields we want to carry into the result record.
            filtered_data = {}
            for key in json_data.keys():
                if key in ['question', 'answer', 'total output', 'results']:
                    filtered_data[key] = json_data[key]
            return filtered_data

        result_dict = filter_json(item)
        result_dict['results'] = []

        for i, cmd in enumerate(item['cmd']):
            # Unique filename per (item, candidate) so worker threads don't collide.
            temp_file = os.path.join(temp_dir, f"{index}_test_{i}.lean")
            with open(temp_file, "w") as f:
                f.write(cmd)

            data = '{"path": "%s", "allTactics": true}' % temp_file
            command = f'echo \'{data}\' | lake exe repl'

            try:
                result = subprocess.run(command, shell=True, check=True, timeout=600,
                                        stdout=subprocess.PIPE, stderr=subprocess.PIPE)
                stdout = json.loads(result.stdout.decode('utf-8'))
                stderr = result.stderr.decode('utf-8')
            except subprocess.TimeoutExpired as e:
                result_item = {'error': str(e), 'status': 'nopass_limit'}
            except subprocess.CalledProcessError as e:
                result_item = {'error': str(e), 'status': 'nopass_error'}
            else:
                if "messages" not in stdout and not len(stderr):
                    # REPL produced no diagnostics at all: clean pass.
                    result_item = {'stdout': stdout, 'stderr': stderr, 'status': 'pass' }
                elif not len(stderr) and "messages" in stdout:
                    flag = 0
                    for me in stdout['messages']:
                        if me['severity'] == 'error':
                            flag = 1
                            # String-offset reconstruction is disabled here; -1 marks
                            # "position not computed".
                            current_column = -1
                            result_item = {'stdout': stdout, 'stderr': stderr,
                                           'status': 'nopass', 'string_pos': current_column}
                            break
                    if not flag:
                        # Messages present but none of severity 'error' (warnings only).
                        result_item = {'stdout': stdout, 'stderr': stderr, 'status': 'pass'}
                else:
                    assert len(stderr)
                    result_item = {'stdout': stdout, 'stderr': stderr,
                                   'status': 'nopass', 'string_pos': 0 }
            finally:
                # BUG FIX: scratch files were never deleted and accumulated on disk.
                try:
                    os.remove(temp_file)
                except OSError:
                    pass

            result_dict['results'].append(result_item)
        return result_dict

    total = len(command_list)

    with ThreadPoolExecutor(max_workers=128) as executor:
        futures = [executor.submit(execute_command, cmd, i) for i, cmd in enumerate(command_list)]
        for future in tqdm(futures, total=total, desc="Processing Commands"):
            results.append(future.result())

    def calculate_pass(result_list, k):
        """Return (pass@1, pass@k): the fraction of items whose first 1 /
        first k candidates contain at least one pass."""
        pass_1_count = 0
        pass_k_count = 0
        for result in result_list:
            item_results = result.get('results', [])
            if any(r.get('status') == 'pass' for r in item_results[:1]):
                pass_1_count += 1
            if any(r.get('status') == 'pass' for r in item_results[:k]):
                pass_k_count += 1
        pass_1 = pass_1_count / len(result_list) if result_list else 0
        pass_k = pass_k_count / len(result_list) if result_list else 0
        return pass_1, pass_k

    pass_1, pass_k = calculate_pass(results, k)
    print("Pass@1:", pass_1)
    print(f"Pass@{k}:", pass_k)

    output_file = f"pass_rate_results/{output_path}"
    # Create the directory if it doesn't exist.
    os.makedirs(os.path.dirname(output_file), exist_ok=True)

    with open(f"{output_file}", 'w') as f:
        json.dump({'results': results, 'pass_1': pass_1, f"pass_{k}": pass_k}, f, indent=2, ensure_ascii=False)
207
+
208
import re

# Trailing Lean attribute of the form `@[simp ...]` anchored at end of string.
_SIMP_AT_END = re.compile(r'@\[simp\s*.*?\]$')

def remove_simp_pattern_from_end(s):
    """Strip a trailing `@[simp ...]` attribute from *s*, if one is present."""
    return _SIMP_AT_END.sub('', s)
212
+
213
def main(args):
    """Assemble MiniF2F check commands — shared import header + per-item
    environment header + (formal statement := generated proof) — then multi()."""
    command_list = []
    file_pattern = os.path.join(args.input_path, '[0-1]*.json')
    # Shared import header prepended to every candidate file.
    head_list = ''
    with open("MiniF2F/Minif2fImport.lean", 'r', encoding='utf8') as rf:
        for line in rf.readlines():
            head_list += line
    for file_path in glob.glob(file_pattern):
        with open(file_path, 'r', encoding='utf-8') as rf:
            for line in rf.readlines():
                try:
                    json_item = json.loads(line)
                    working_env = json_item['content']['header']
                    json_item['cmd'] = []
                    # Only check the first k generations (or fewer if not enough).
                    for output in json_item['total output'][:min(args.k, len(json_item['total output']))]:
                        # Drop everything from '#align' on; keep only the proof body.
                        proof = output.split("#align")[0]
                        # Statement up to ':=' is ground truth; the proof comes from the model.
                        atp = json_item['content']['formal_statement'].split(":=")[0] + ":=" + proof

                        json_item['cmd'].append('\n\n'.join( [head_list, working_env, atp]))
                except:
                    # NOTE(review): bare except drops into the debugger on any
                    # malformed record; this hangs non-interactive runs.
                    import pdb
                    pdb.set_trace()
                command_list.append(json_item)
    multi(command_list, args.output_path, args.k)
248
+
249
if __name__ == '__main__':
    # CLI entry point; main() only consumes --input_path, --output_path and --k.
    arg_parser = ArgumentParser()
    arg_parser.add_argument('--data_path', type=str,
                            default='/opt/tiger/CLIP/theorem_proving/generate_result/zero_shot/minif2f_test/generation/lean5_random_15k_all_mathrft/2/5/')
    arg_parser.add_argument('--input_path', type=str, default='/opt/tiger/CLIP/theorem_proving/generate_result/zero_shot/minif2f_test/generation/lean5_random_15k_all_mathrft/2/5/')
    arg_parser.add_argument('--cuda_num', type=int, default=8)
    arg_parser.add_argument('--k', type=int, default=5)
    arg_parser.add_argument('--output_path', type=str, default='total.json')
    arg_parser.add_argument('--generate_method', type=str,
                            choices=['single', 'sft', 'comp', 'self_consistency', 'single_consistency'])
    arg_parser.add_argument('--method', type=str, choices=['main', 'test', 'get_data'])
    args = arg_parser.parse_args()
    main(args)
262
+
263
+
264
+
pass_rate_found_item.py ADDED
@@ -0,0 +1,175 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import subprocess
3
+ from argparse import ArgumentParser
4
+ import json
5
+ from concurrent.futures import ThreadPoolExecutor
6
+ from tqdm import tqdm
7
+ import tempfile
8
+
9
def wrapped_function(item):
    """Check a single Lean snippet (item['cmd']) with `lake exe repl`.

    Writes the snippet to a temp file, asks the REPL to elaborate it, and
    classifies the run as 'pass' when the REPL's JSON output contains no
    "messages" key (i.e. no errors or warnings were reported).

    Returns a dict {'results': [...], 'pass_rate': float}; for this single
    item the pass rate is either 100.0 or 0.0.
    """
    results = []
    passed = 0
    total = 0

    temp_dir = tempfile.gettempdir()
    # NOTE(review): the filename is not unique, so concurrent callers would
    # clobber each other's snippet — confirm this is only invoked serially.
    temp_file = os.path.join(temp_dir, "test.lean")

    with open(temp_file, "w") as f:
        f.write(item['cmd'])

    data = '{"path": "%s", "allTactics": true}' % (temp_file,)
    command = 'echo \'%s\' | lake exe repl' % data

    try:
        result = subprocess.run(command, shell=True, check=True,
                                stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        stdout = result.stdout.decode('utf-8')
        stderr = result.stderr.decode('utf-8')
        json_stdout = json.loads(stdout)
        # "messages" holds compiler errors/warnings; its absence means success.
        if "messages" not in json_stdout:
            passed += 1
            results.append({'stdout': stdout, 'stderr': stderr, 'status': 'pass'})
        else:
            # BUG FIX: the original appended status 'pass' unconditionally,
            # even when the REPL reported error messages.
            results.append({'stdout': stdout, 'stderr': stderr, 'status': 'nopass'})
    except subprocess.CalledProcessError as e:
        # Non-zero exit (missing lake, REPL crash, ...) counts as a failure.
        results.append({'error': str(e), 'status': 'nopass'})
    total += 1

    # BUG FIX: `total` already counts every item, so the denominator is just
    # `total`; the original `passed / (passed + total)` double-counted passes.
    pass_rate = passed / total * 100

    return {'results': results, 'pass_rate': pass_rate}
47
+
48
+ # Set the directory where your .lean files are located
49
+
50
+ # Get a list of all .lean files in the directory
51
+ # lean_files = [f for f in os.listdir(directory) if f.endswith(".lean")]
52
+ # lean_files = ["test/file.lean"]
53
def single(command_list):
    """Sequentially check each item's Lean snippet via `lake exe repl`.

    Each item['cmd'] is written to test/test.lean and fed to the REPL; an
    item passes when the REPL's JSON output has no "messages" key.  Prints
    the overall pass rate and dumps per-item results to results.json.
    """
    results = []
    passed = 0
    total = 0
    for item in tqdm(command_list):
        with open("test/test.lean", "w", encoding='utf-8') as f:
            f.write(item['cmd'])
        data = '{"path": "test/test.lean", "allTactics": true}'
        command = 'echo \'%s\' | lake exe repl' % data
        try:
            result = subprocess.run(command, shell=True, check=True,
                                    stdout=subprocess.PIPE, stderr=subprocess.PIPE)
            stdout = result.stdout.decode('utf-8')
            json_stdout = json.loads(stdout)
            stderr = result.stderr.decode('utf-8')
            # "messages" carries compiler errors/warnings; absence means success.
            if "messages" not in json_stdout:
                passed += 1
                status = 'pass'
            else:
                # BUG FIX: the original labelled every non-crashing run 'pass',
                # even when the REPL reported error messages.
                status = 'nopass'
            results.append({'stdout': stdout, 'stderr': stderr, 'status': status})
        except subprocess.CalledProcessError as e:
            results.append({'error': str(e), 'status': 'nopass'})
        total += 1

    # BUG FIX: guard against an empty command list (ZeroDivisionError).
    pass_rate = passed / total * 100 if total else 0.0
    print(pass_rate)

    # Save results to a JSON file
    with open('results.json', 'w') as f:
        json.dump({'results': results, 'pass_rate': pass_rate}, f, indent=2, ensure_ascii=False)
95
+
96
+
97
+
98
+
99
def multi(command_list):
    """Check all items concurrently (32 worker threads) via `lake exe repl`.

    Each item's snippet is written to a uniquely named temp file under
    /data/tmp, verified by the REPL, and classified 'pass'/'nopass'.
    Prints the overall pass rate and writes results.json.
    """
    results = []
    passed = 0
    total = 0

    def execute_command(item):
        temp_dir = '/data/tmp'  # assumes this scratch directory exists — TODO confirm
        temp_file = os.path.join(temp_dir, f"test_{item['index']}.lean")  # Ensure unique filenames
        with open(temp_file, "w") as f:
            f.write(item['cmd'])

        data = '{"path": "%s", "allTactics": true}' % temp_file
        command = f'echo \'{data}\' | lake exe repl'

        try:
            result = subprocess.run(command, shell=True, check=True,
                                    stdout=subprocess.PIPE, stderr=subprocess.PIPE)
            stdout = result.stdout.decode('utf-8')
            stderr = result.stderr.decode('utf-8')

            if "messages" not in json.loads(stdout):
                return {'stdout': stdout, 'stderr': stderr, 'status': 'pass'}
            else:
                return {'stdout': stdout, 'stderr': stderr, 'status': 'nopass'}

        except subprocess.CalledProcessError as e:
            return {'error': str(e), 'status': 'nopass'}
        finally:
            # BUG FIX: the original `os.remove(temp_file)` sat *after* the
            # return statements and never ran, leaking one file per item.
            if os.path.exists(temp_file):
                os.remove(temp_file)

    total = len(command_list)

    with ThreadPoolExecutor(max_workers=32) as executor:
        futures = [executor.submit(execute_command, {'index': i, 'cmd': cmd['cmd']})
                   for i, cmd in enumerate(command_list)]
        for future in tqdm(futures, total=total, desc="Processing Commands"):
            result = future.result()
            results.append(result)
            if result['status'] == 'pass':
                passed += 1

    # BUG FIX: avoid ZeroDivisionError when the command list is empty.
    pass_rate = (passed / total) * 100 if total else 0.0
    print(f"Pass rate: {pass_rate}%")

    with open('results.json', 'w') as f:
        json.dump({'results': results, 'pass_rate': pass_rate}, f, indent=2, ensure_ascii=False)
142
+
143
import re

def remove_simp_pattern_from_end(s):
    """Strip a trailing ``@[simp ...]`` attribute from *s*, if one is present."""
    return re.sub(r'@\[simp\s*.*?\]$', '', s)
147
+
148
def main():
    """Build `cmd` strings from the all_basic.jsonl dump and write the
    filtered records (working file shorter than 10k chars) to 1k_test.jsonl.
    """
    input_file = f"/data/haiming/multilevel_isabelle-main/lean4/repl/self_autoformalization/data/mma_filepath/all_basic.jsonl"

    # BUG FIX: the original leaked the input file handle (bare open()).
    with open(input_file, 'r', encoding='utf-8') as rf:
        command_list = json.load(rf)
    new_list = []
    for json_item in command_list:
        try:
            working_env = json_item['working_file']
            statement = json_item['statement_poof']
            json_item['cmd'] = '\n'.join([working_env, statement])
            assert len(statement) > 0
            # Very long working files blow up the REPL; keep short ones only.
            if len(working_env) < 10000:
                new_list.append(json_item)

        except (KeyError, AssertionError) as e:
            # BUG FIX: the original dropped into pdb.set_trace() here, which
            # hangs unattended runs; skip malformed records instead.
            print(f"skipping malformed item: {e!r}")

    output_file = "/data/haiming/multilevel_isabelle-main/data/lean4_basic/1k_test.jsonl"
    with open(output_file, 'w', encoding='utf-8') as file:
        json.dump(new_list, file, indent=4, ensure_ascii=False)


    # multi(new_list)
172
+
173
if __name__ == '__main__':
    # Script entry point: filter the basic dataset and write the test split.
    main()
175
+
pass_rate_multi.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pdb
2
+ import subprocess
3
+ import re
4
+
5
+ # Output file
6
+ output_file = "pass_rate_output.txt"
7
+
8
+ # Clearing the output file before appending new content
9
+ with open(output_file, "w") as file:
10
+ file.write("")
11
+
12
+ # List of input paths
13
+ input_path_lists = [
14
+ "test/zero_shot/wild_test/generation/lean4_random_15k_all/2/1/",
15
+ ]
16
+
17
+ def get_output(input_string):
18
+ pattern = r"zero_shot/(\w+)/(.+?)/(\w+)"
19
+ match = re.search(pattern, input_string)
20
+ if match:
21
+ part1 = match.group(1)
22
+ part2 = match.group(3) + ".jsonl"
23
+ result = "/".join([part1, part2])
24
+ print(result)
25
+ else:
26
+ print("No match found.")
27
+ assert True
28
+ return result
29
+
30
+ # List of input paths
31
+ input_path_lists = [
32
+ "/opt/tiger/auto-info/generate_result/zero_shot/lean4_basic_test/generation/lean4_random_5k/2/1/",
33
+ # "/opt/tiger/mariana/auto-info/generate_result/zero_shot/lean4_random_test/generation/lean4_random_5k/2/1/",
34
+ # "test/zero_shot/lean4_random_test/generation/lean4_random_15k_all/3/1/",
35
+ # "/opt/tiger/mariana/auto-info/generate_result/zero_shot/lean4_basic_test/generation/lean4_random_15k_all/2/1/",
36
+ # "/opt/tiger/mariana/auto-info/generate_result/zero_shot/lean4_random_test/generation/lean4_random_15k_all/2/1/",
37
+ # "test/zero_shot/lean4_random_test/generation/lean4_random_15k_all/3/1/",
38
+ # "test/zero_shot/lean4_basic_test/generation/lean4_random_15k_all/3/1/",
39
+ # Add more input paths as needed
40
+ ]
41
+
42
+ # Iterate through the input paths and run the command
43
+ for input_path in input_path_lists:
44
+ print(f"Running for input path: {input_path}", file=open(output_file, "a"))
45
+ command = f"python3 pass_rate_new.py --input_path {input_path} --output_path {get_output(input_path)}"
46
+ subprocess.run(command, shell=True, stdout=open(output_file, "a"), stderr=subprocess.STDOUT)
47
+ print("\n\n",file=open(output_file, "a"))
48
+
pass_rate_multi_notlean.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import subprocess
2
+ import re
3
+
4
+ # Output file
5
+ output_file = "pass_rate_output_lean.txt"
6
+
7
+ # Clearing the output file before appending new content
8
+ with open(output_file, "w") as file:
9
+ file.write("")
10
+
11
+ # List of input paths
12
+ input_path_lists = [
13
+ "/opt/tiger/mariana/auto-info/generate_result/zero_shot/math_train/generation/lean4_random_5k/2/1/",
14
+ "/opt/tiger/mariana/auto-info/generate_result/zero_shot/math_train/generation/lean4_random_15k_all/2/1/",
15
+ "/opt/tiger/mariana/auto-info/generate_result/zero_shot/gsm8k_train/generation/lean4_random_5k/2/1/",
16
+ "/opt/tiger/mariana/auto-info/generate_result/zero_shot/gsm8k_train/generation/lean4_random_15k_all/2/1/",
17
+ # "test/zero_shot/wild_test/generation/lean4_random_15k_all/2/1/",
18
+ # "test/zero_shot/math_train/generation/lean4_random_15k_all/2/1/",
19
+ # "test/zero_shot/gsm8k_train/generation/lean4_random_15k_all/2/1/",
20
+ ]
21
+
22
def get_output(input_string):
    """Map a path like ``.../zero_shot/<dataset>/.../<model>/...`` to the
    per-dataset output filename ``<dataset>/<model>.jsonl``.

    Raises ValueError when the path does not match the expected layout.
    """
    pattern = r"zero_shot/(\w+)/(.+?)/(\w+)"
    match = re.search(pattern, input_string)
    if match:
        part1 = match.group(1)
        part2 = match.group(3) + ".jsonl"
        result = "/".join([part1, part2])
        print(result)
    else:
        # BUG FIX: the original did `assert True` (a no-op) and then returned
        # the unbound `result`, raising UnboundLocalError; fail loudly instead.
        raise ValueError(f"No match found in input path: {input_string}")
    return result
34
+
35
# Iterate through the input paths and run the command
# Each run shells out to pass_rate_notlean.py, appending output to output_file.
# NOTE(review): each `open(output_file, "a")` leaks a handle (never closed) —
# a single `with open(...)` around the loop would be safer.
for input_path in input_path_lists:
    print(f"Running for input path: {input_path}", file=open(output_file, "a"))
    command = f"python3 pass_rate_notlean.py --input_path {input_path} --output_path {get_output(input_path)}"
    subprocess.run(command, shell=True, stdout=open(output_file, "a"), stderr=subprocess.STDOUT)
    print("\n\n",file=open(output_file, "a"))
pass_rate_multi_notlean_pass.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import subprocess
2
+ import re
3
+
4
+ # Output file
5
+ output_file = "pass_rate_output.txt"
6
+
7
+ # Clearing the output file before appending new content
8
+ with open(output_file, "w") as file:
9
+ file.write("")
10
+
11
+ # List of input paths
12
+ input_path_lists = [
13
+ "/opt/tiger/auto-info/generate_result/zero_shot/math_train/generation/mistral-qa-gsm8kmath-autoform-forml4-rft-math/1/10",
14
+ "/opt/tiger/auto-info/generate_result/zero_shot/math_train/generation/mistral-qa-gsm8kmath-autoform-forml4-rft-math/2/10",
15
+ "/opt/tiger/auto-info/generate_result/zero_shot/math_train/generation/mistral-qa-gsm8kmath-autoform-forml4-rft-math/3/10",
16
+ # "test/zero_shot/wild_test/generation/lean4_random_15k_all/2/1/",
17
+ # "test/zero_shot/math_train/generation/lean4_random_15k_all/2/1/",
18
+ # "test/zero_shot/gsm8k_train/generation/lean4_random_15k_all/2/1/",
19
+ ]
20
+
21
def get_output(input_string, k):
    """Map ``.../zero_shot/<dataset>/<model.../>/<round>/<n>`` to the output
    filename ``<dataset>/<model...>/<round>/<n>pass<k>.jsonl``.

    Raises ValueError when the path does not match the expected layout.
    """
    pattern = r"zero_shot/(\w+)/(.+?)/(\w+)/(\w+)"
    match = re.search(pattern, input_string)
    if match:
        part1 = match.group(1)
        part_model = match.group(2)
        part2 = match.group(3)
        part3 = match.group(4) + f"pass{k}.jsonl"
        result = "/".join([part1, part_model, part2, part3])
        print(result)
    else:
        # BUG FIX: the original did `assert True` (a no-op) and then returned
        # the unbound `result`, raising UnboundLocalError; fail loudly instead.
        raise ValueError(f"No match found in input path: {input_string}")
    return result
35
+
36
+
37
# Iterate through the input paths and run the command
# Each run shells out to pass_rate_notlean_test.py with a fixed pass@k of 10,
# appending output to output_file.
# NOTE(review): each `open(output_file, "a")` leaks a handle (never closed).
for input_path in input_path_lists:
    k = 10  # pass@k cutoff forwarded to the child script
    print(f"Running for input path: {input_path}", file=open(output_file, "a"))
    command = f"python3 pass_rate_notlean_test.py --input_path {input_path} --output_path {get_output(input_path, k)} --k {k}"
    subprocess.run(command, shell=True, stdout=open(output_file, "a"), stderr=subprocess.STDOUT)
    print("\n\n",file=open(output_file, "a"))
pass_rate_multi_pass.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pdb
2
+ import subprocess
3
+ import re
4
+
5
+ # Output file
6
+ output_file = "pass_rate_output.txt"
7
+
8
+ # Clearing the output file before appending new content
9
+ with open(output_file, "w") as file:
10
+ file.write("")
11
+
12
+ # List of input paths
13
+ input_path_lists = [
14
+ "test/zero_shot/wild_test/generation/lean4_random_15k_all/2/1/",
15
+ ]
16
+
17
+ def get_output(input_string, k):
18
+ pattern = r"zero_shot/(\w+)/(.+?)/(\w+)"
19
+ match = re.search(pattern, input_string)
20
+ if match:
21
+ part1 = match.group(1)
22
+ part2 = match.group(3) + f"pass{k}.jsonl"
23
+ result = "/".join([part1, part2])
24
+ print(result)
25
+ else:
26
+ print("No match found.")
27
+ assert True
28
+ return result
29
+
30
+ # List of input paths
31
+ input_path_lists = [
32
+ # "../auto-info/generate_result/zero_shot/gsm8k_train/generation/lean4_random_15k_all/2/10/",
33
+ # "../auto-info/generate_result/zero_shot/math_train/generation/lean4_random_15k_all/2/10/",
34
+ # "../auto-info/generate_result/zero_shot/wild_test/generation/lean4_rft/1/1",
35
+ # "../auto-info/generate_result/zero_shot/wild_test/generation/lean4_rft/2/1",
36
+ # "../auto-info/generate_result/zero_shot/wild_test/generation/lean4_rft/3/1",
37
+ # "../auto-info/generate_result/zero_shot/wild_test/generation/lean4_verifier/1/1",
38
+ # "../auto-info/generate_result/zero_shot/wild_test/generation/lean4_verifier/2/1",
39
+ # "../auto-info/generate_result/zero_shot/wild_test/generation/lean4_verifier/3/1",
40
+ # "../auto-info/generate_result/zero_shot/wild_test/generation/lean4_verifier_rft/1/1",
41
+ # "../auto-info/generate_result/zero_shot/wild_test/generation/lean4_verifier_rft/2/1",
42
+ # "../auto-info/generate_result/zero_shot/wild_test/generation/lean4_verifier_rft/3/1",
43
+ # "/opt/tiger/auto-info/generate_result/zero_shot/lean4_basic_test/generation/lean4_rft/1/1/",
44
+ # "/opt/tiger/auto-info/generate_result/zero_shot/lean4_basic_test/generation/lean4_rft/2/1/",
45
+ # "/opt/tiger/auto-info/generate_result/zero_shot/lean4_basic_test/generation/lean4_rft/3/1/",
46
+ # "/opt/tiger/auto-info/generate_result/zero_shot/lean4_random_test/generation/lean4_rft/1/1/",
47
+ # "/opt/tiger/auto-info/generate_result/zero_shot/lean4_random_test/generation/lean4_rft/2/1/",
48
+ # "/opt/tiger/auto-info/generate_result/zero_shot/lean4_random_test/generation/lean4_rft/3/1/",
49
+ # "/opt/tiger/auto-info/generate_result/zero_shot/lean4_basic_test/generation/lean4_verifier/1/1/",
50
+ # "/opt/tiger/auto-info/generate_result/zero_shot/lean4_basic_test/generation/lean4_verifier/2/1/",
51
+ # "/opt/tiger/auto-info/generate_result/zero_shot/lean4_basic_test/generation/lean4_verifier/3/1/",
52
+ # "/opt/tiger/auto-info/generate_result/zero_shot/lean4_random_test/generation/lean4_verifier/1/1/",
53
+ # "/opt/tiger/auto-info/generate_result/zero_shot/lean4_random_test/generation/lean4_verifier/2/1/",
54
+ # "/opt/tiger/auto-info/generate_result/zero_shot/lean4_random_test/generation/lean4_verifier/3/1/",
55
+ # "/opt/tiger/auto-info/generate_result/zero_shot/lean4_basic_test/generation/lean4_verifier_rft/1/1/",
56
+ # "/opt/tiger/auto-info/generate_result/zero_shot/lean4_basic_test/generation/lean4_verifier_rft/2/1/",
57
+ # "/opt/tiger/auto-info/generate_result/zero_shot/lean4_basic_test/generation/lean4_verifier_rft/3/1/",
58
+ # "/opt/tiger/auto-info/generate_result/zero_shot/lean4_random_test/generation/lean4_verifier_rft/1/1/",
59
+ # "/opt/tiger/auto-info/generate_result/zero_shot/lean4_random_test/generation/lean4_verifier_rft/2/1/",
60
+ # "/opt/tiger/auto-info/generate_result/zero_shot/lean4_random_test/generation/lean4_verifier_rft/3/1/",
61
+ # "/opt/tiger/auto-info/generate_result/zero_shot/wild_test/generation/lean4_rft/1/1/",
62
+ # "/opt/tiger/auto-info/generate_result/zero_shot/wild_test/generation/lean4_rft/2/1/",
63
+ # "/opt/tiger/auto-info/generate_result/zero_shot/wild_test/generation/lean4_rft/3/1/",
64
+ # "/opt/tiger/auto-info/generate_result/zero_shot/wild_test/generation/lean4_verifier/1/1/",
65
+ # "/opt/tiger/auto-info/generate_result/zero_shot/wild_test/generation/lean4_verifier/2/1/",
66
+ # "/opt/tiger/auto-info/generate_result/zero_shot/wild_test/generation/lean4_verifier/3/1/",
67
+ # "/opt/tiger/auto-info/generate_result/zero_shot/wild_test/generation/lean4_verifier_rft/1/1/",
68
+ # "/opt/tiger/auto-info/generate_result/zero_shot/wild_test/generation/lean4_verifier_rft/2/1/",
69
+ # "/opt/tiger/auto-info/generate_result/zero_shot/wild_test/generation/lean4_verifier_rft/3/1/",
70
+ # "/opt/tiger/auto-info/generate_result/zero_shot/lean4_15k_train/generation/lean4_random_15k_all/2/20/",
71
+ # "/opt/tiger/auto-info/generate_result/zero_shot/lean4_basic_test/generation/lean4_random_15k_all/2/5/",
72
+ # "/opt/tiger/auto-info/generate_result/zero_shot/lean4_random_test/generation/lean4_random_15k_all/2/5/",
73
+ # "/opt/tiger/mariana/auto-info/generate_result/zero_shot/lean4_random_test/generation/lean4_random_5k/2/1/",
74
+ # "test/zero_shot/lean4_random_test/generation/lean4_random_15k_all/3/1/",
75
+ # "/opt/tiger/mariana/auto-info/generate_result/zero_shot/lean4_basic_test/generation/lean4_random_15k_all/2/1/",
76
+ # "/opt/tiger/mariana/auto-info/generate_result/zero_shot/lean4_random_test/generation/lean4_random_15k_all/2/1/",
77
+ # "test/zero_shot/lean4_random_test/generation/lean4_random_15k_all/3/1/",
78
+ # "test/zero_shot/lean4_basic_test/generation/lean4_random_15k_all/3/1/",
79
+ # "/opt/tiger/auto-info/generate_result/zero_shot/lean4_basic_test/generation/lean4_random_15k_all_mathrft/1/1/",
80
+ # "/opt/tiger/auto-info/generate_result/zero_shot/lean4_random_test/generation/lean4_random_15k_all_mathrft/1/1/",
81
+ # "/opt/tiger/auto-info/generate_result/zero_shot/wild_test/generation/lean4_random_15k_all_mathrft/1/1/",
82
+ # "/opt/tiger/auto-info/generate_result/zero_shot/lean4_basic_test/generation/lean4_random_15k_all_mathrft/2/1/",
83
+ # "/opt/tiger/auto-info/generate_result/zero_shot/lean4_random_test/generation/lean4_random_15k_all_mathrft/2/1/",
84
+ # "/opt/tiger/auto-info/generate_result/zero_shot/wild_test/generation/lean4_random_15k_all_mathrft/2/1/",
85
+ # "/opt/tiger/auto-info/generate_result/zero_shot/lean4_basic_test/generation/lean4_random_15k_all_mathrft/3/1/",
86
+ # "/opt/tiger/auto-info/generate_result/zero_shot/lean4_random_test/generation/lean4_random_15k_all_mathrft/3/1/",
87
+ # "/opt/tiger/auto-info/generate_result/zero_shot/wild_test/generation/lean4_random_15k_all_mathrft/3/1/",
88
+ # "/opt/tiger/auto-info/generate_result/zero_shot/gsm8k_train/generation/lean4_random_15k_all_mathrft/2/10/",
89
+ # "/opt/tiger/auto-info/generate_result/zero_shot/math_train/generation/lean4_random_15k_all_mathrft/2/10/",
90
+ # "/opt/tiger/auto-info/generate_result/zero_shot/lean4_15k_train/generation/lean4_random_15k_all_mathrft/2/10/",
91
+ # Add more input paths as needed
92
+ "/opt/tiger/auto-info/generate_result/zero_shot/lean4_basic_test/generation/lean4_random_15k_all_mathrft/2/5/",
93
+ "/opt/tiger/auto-info/generate_result/zero_shot/lean4_random_test/generation/lean4_random_15k_all_mathrft/2/5/",
94
+ "/opt/tiger/auto-info/generate_result/zero_shot/wild_test/generation/lean4_random_15k_all_mathrft/2/5/",
95
+ ]
96
+
97
+ # Iterate through the input paths and run the command
98
+ for input_path in input_path_lists:
99
+ k = 5
100
+ if "wild_test" in input_path or "gsm8k_train" in input_path or "math_train" in input_path:
101
+ print(f"wild")
102
+ print(f"Running for input path: {input_path}", file=open(output_file, "a"))
103
+ command = f"python3 pass_rate_notlean_test.py --input_path {input_path} --output_path {get_output(input_path,k)} --k {k}"
104
+ subprocess.run(command, shell=True, stdout=open(output_file, "a"), stderr=subprocess.STDOUT)
105
+ print("\n\n",file=open(output_file, "a"))
106
+
107
+ else:
108
+ print(f"lean")
109
+ print(f"Running for input path: {input_path}", file=open(output_file, "a"))
110
+ command = f"python3 pass_rate_new_test.py --input_path {input_path} --output_path {get_output(input_path, k)} --k {k}"
111
+ subprocess.run(command, shell=True, stdout=open(output_file, "a"), stderr=subprocess.STDOUT)
112
+ print("\n\n",file=open(output_file, "a"))
pass_rate_new.py ADDED
@@ -0,0 +1,196 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import subprocess
3
+ from argparse import ArgumentParser
4
+ import json
5
+ from concurrent.futures import ThreadPoolExecutor
6
+ from tqdm import tqdm
7
+ import glob
8
+ import tempfile
9
+
10
def wrapped_function(item):
    """Check a single Lean snippet (item['cmd']) with `lake exe repl`.

    Writes the snippet to a temp file, asks the REPL to elaborate it, and
    classifies the run as 'pass' when the REPL's JSON output contains no
    "messages" key (i.e. no errors or warnings were reported).

    Returns a dict {'results': [...], 'pass_rate': float}; for this single
    item the pass rate is either 100.0 or 0.0.
    """
    results = []
    passed = 0
    total = 0

    temp_dir = tempfile.gettempdir()
    # NOTE(review): the filename is not unique, so concurrent callers would
    # clobber each other's snippet — confirm this is only invoked serially.
    temp_file = os.path.join(temp_dir, "test.lean")

    with open(temp_file, "w") as f:
        f.write(item['cmd'])

    data = '{"path": "%s", "allTactics": true}' % (temp_file,)
    command = 'echo \'%s\' | lake exe repl' % data

    try:
        result = subprocess.run(command, shell=True, check=True,
                                stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        stdout = result.stdout.decode('utf-8')
        stderr = result.stderr.decode('utf-8')
        json_stdout = json.loads(stdout)
        # "messages" holds compiler errors/warnings; its absence means success.
        if "messages" not in json_stdout:
            passed += 1
            results.append({'stdout': stdout, 'stderr': stderr, 'status': 'pass'})
        else:
            # BUG FIX: the original appended status 'pass' unconditionally,
            # even when the REPL reported error messages.
            results.append({'stdout': stdout, 'stderr': stderr, 'status': 'nopass'})
    except subprocess.CalledProcessError as e:
        # Non-zero exit (missing lake, REPL crash, ...) counts as a failure.
        results.append({'error': str(e), 'status': 'nopass'})
    total += 1

    # BUG FIX: `total` already counts every item, so the denominator is just
    # `total`; the original `passed / (passed + total)` double-counted passes.
    pass_rate = passed / total * 100

    return {'results': results, 'pass_rate': pass_rate}
48
+
49
+ # Set the directory where your .lean files are located
50
+
51
+ # Get a list of all .lean files in the directory
52
+ # lean_files = [f for f in os.listdir(directory) if f.endswith(".lean")]
53
+ # lean_files = ["test/file.lean"]
54
def single(command_list, args):
    """Sequentially check each item's Lean snippet via `lake exe repl`.

    Each item['cmd'] is written to test/test.lean and fed to the REPL; an
    item passes when the REPL's JSON output has no "messages" key.  Prints
    the overall pass rate and dumps per-item results to results.json.
    `args` is accepted for interface compatibility and currently unused.
    """
    results = []
    passed = 0
    total = 0
    for item in tqdm(command_list):
        with open("test/test.lean", "w", encoding='utf-8') as f:
            f.write(item['cmd'])
        data = '{"path": "test/test.lean", "allTactics": true}'
        command = 'echo \'%s\' | lake exe repl' % data
        try:
            result = subprocess.run(command, shell=True, check=True,
                                    stdout=subprocess.PIPE, stderr=subprocess.PIPE)
            stdout = result.stdout.decode('utf-8')
            json_stdout = json.loads(stdout)
            stderr = result.stderr.decode('utf-8')
            # "messages" carries compiler errors/warnings; absence means success.
            if "messages" not in json_stdout:
                passed += 1
                status = 'pass'
            else:
                # BUG FIX: the original labelled every non-crashing run 'pass',
                # even when the REPL reported error messages.
                status = 'nopass'
            results.append({'stdout': stdout, 'stderr': stderr, 'status': status})
        except subprocess.CalledProcessError as e:
            results.append({'error': str(e), 'status': 'nopass'})
        total += 1

    # BUG FIX: guard against an empty command list (ZeroDivisionError).
    pass_rate = passed / total * 100 if total else 0.0
    print(pass_rate)

    # Save results to a JSON file
    with open('results.json', 'w') as f:
        json.dump({'results': results, 'pass_rate': pass_rate}, f, indent=2, ensure_ascii=False)
96
+
97
+
98
def multi(command_list, output_path):
    """Check all items concurrently (32 threads) via `lake exe repl` and
    write {'results', 'pass_rate'} to pass_rate_results/<output_path>.

    An item passes when the REPL's JSON output has no "messages" key *and*
    stderr is empty; timeouts and non-zero exits are recorded as
    'nopass_limit' / 'nopass_error' respectively.
    """
    results = []
    passed = 0
    total = 0

    def execute_command(item):
        temp_dir = '/opt/jianqiao'  # assumes this scratch directory exists — TODO confirm
        temp_file = os.path.join(temp_dir, f"test_{item['index']}.lean")  # Ensure unique filenames
        with open(temp_file, "w") as f:
            f.write(item['cmd'])

        data = '{"path": "%s", "allTactics": true}' % temp_file
        command = f'echo \'{data}\' | lake exe repl'

        try:
            result = subprocess.run(command, shell=True, check=True, timeout=600,
                                    stdout=subprocess.PIPE, stderr=subprocess.PIPE)
            stdout = result.stdout.decode('utf-8')
            stderr = result.stderr.decode('utf-8')

            if "messages" not in json.loads(stdout) and not len(stderr):
                return {'stdout': stdout, 'stderr': stderr, 'status': 'pass',
                        'statement': item['statement'], 'content': item['content']}
            else:
                return {'stdout': stdout, 'stderr': stderr, 'status': 'nopass',
                        'statement': item['statement'], 'content': item['content']}

        except subprocess.TimeoutExpired as e:
            return {'error': str(e), 'status': 'nopass_limit',
                    'statement': item['statement'], 'content': item['content']}

        except subprocess.CalledProcessError as e:
            return {'error': str(e), 'status': 'nopass_error',
                    'statement': item['statement'], 'content': item['content']}
        finally:
            # BUG FIX: the original `os.remove(temp_file)` sat *after* the
            # return statements and never ran, leaking one file per item.
            if os.path.exists(temp_file):
                os.remove(temp_file)

    total = len(command_list)

    with ThreadPoolExecutor(max_workers=32) as executor:
        futures = [executor.submit(execute_command,
                                   {'index': i, 'cmd': cmd['cmd'],
                                    'statement': cmd['statement'],
                                    'content': cmd['content']})
                   for i, cmd in enumerate(command_list)]
        for future in tqdm(futures, total=total, desc="Processing Commands"):
            result = future.result()
            results.append(result)
            if result['status'] == 'pass':
                passed += 1

    # BUG FIX: avoid ZeroDivisionError when the command list is empty.
    pass_rate = (passed / total) * 100 if total else 0.0
    print(f"total test: {total}")
    print(f"Pass rate: {pass_rate}%")

    output_file = f"pass_rate_results/{output_path}"
    # Create the directory if it doesn't exist
    os.makedirs(os.path.dirname(output_file), exist_ok=True)

    with open(f"{output_file}", 'w') as f:
        json.dump({'results': results, 'pass_rate': pass_rate}, f, indent=2, ensure_ascii=False)
149
+
150
import re

def remove_simp_pattern_from_end(s):
    """Strip a trailing ``@[simp ...]`` attribute from *s*, if one is present."""
    return re.sub(r'@\[simp\s*.*?\]$', '', s)
154
+
155
def main(args):
    """Collect generated proofs from <input_path>/[0-9]*.json files (one JSON
    record per line), build each record's `cmd` (working file + first
    generated statement), and check them all with `multi`.
    """
    command_list = []
    file_pattern = os.path.join(args.input_path, '[0-9]*.json')
    for file_path in glob.glob(file_pattern):
        with open(file_path, 'r', encoding='utf-8') as rf:
            for line in rf.readlines():
                try:
                    json_item = json.loads(line)
                    working_env = json_item['content']['working_file']
                    # Drop everything from the first #align marker onwards;
                    # only the Lean 4 statement before it is checked.
                    statement = json_item['total output'][0].split("#align")[0]
                    json_item['statement'] = statement
                    json_item['cmd'] = '\n\n'.join([working_env, statement])
                    assert len(statement) > 0
                except (json.JSONDecodeError, KeyError, IndexError, AssertionError) as e:
                    # BUG FIX: the original opened pdb.set_trace() here (hangs
                    # unattended runs) and still appended the broken item;
                    # skip malformed lines instead.
                    print(f"skipping malformed line in {file_path}: {e!r}")
                    continue
                command_list.append(json_item)
    multi(command_list, args.output_path)
181
+
182
if __name__ == '__main__':
    # CLI entry point: configure input/output locations, then run the
    # pass-rate evaluation over the generated proofs.
    arg_parser = ArgumentParser()
    arg_parser.add_argument('--data_path', type=str,
                            default='data/grade-school-math-master/grade_school_math/data/test.jsonl')
    arg_parser.add_argument('--input_path', type=str, default='')
    arg_parser.add_argument('--cuda_num', type=int, default=8)
    arg_parser.add_argument('--output_path', type=str, default='total.json')
    arg_parser.add_argument('--generate_method', type=str,
                            choices=['single', 'sft', 'comp', 'self_consistency', 'single_consistency'])
    arg_parser.add_argument('--method', type=str, choices=['main', 'test', 'get_data'])
    args = arg_parser.parse_args()
    main(args)
194
+
195
+
196
+
pass_rate_new_test.py ADDED
@@ -0,0 +1,255 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import subprocess
3
+ from argparse import ArgumentParser
4
+ import json
5
+ from concurrent.futures import ThreadPoolExecutor
6
+ from tqdm import tqdm
7
+ import glob
8
+ import tempfile
9
+ import random
10
+
11
def wrapped_function(item):
    """Check a single Lean snippet (item['cmd']) with `lake exe repl`.

    Writes the snippet to a temp file, asks the REPL to elaborate it, and
    classifies the run as 'pass' when the REPL's JSON output contains no
    "messages" key (i.e. no errors or warnings were reported).

    Returns a dict {'results': [...], 'pass_rate': float}; for this single
    item the pass rate is either 100.0 or 0.0.
    """
    results = []
    passed = 0
    total = 0

    temp_dir = tempfile.gettempdir()
    # NOTE(review): the filename is not unique, so concurrent callers would
    # clobber each other's snippet — confirm this is only invoked serially.
    temp_file = os.path.join(temp_dir, "test.lean")

    with open(temp_file, "w") as f:
        f.write(item['cmd'])

    data = '{"path": "%s", "allTactics": true}' % (temp_file,)
    command = 'echo \'%s\' | lake exe repl' % data

    try:
        result = subprocess.run(command, shell=True, check=True,
                                stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        stdout = result.stdout.decode('utf-8')
        stderr = result.stderr.decode('utf-8')
        json_stdout = json.loads(stdout)
        # "messages" holds compiler errors/warnings; its absence means success.
        if "messages" not in json_stdout:
            passed += 1
            results.append({'stdout': stdout, 'stderr': stderr, 'status': 'pass'})
        else:
            # BUG FIX: the original appended status 'pass' unconditionally,
            # even when the REPL reported error messages.
            results.append({'stdout': stdout, 'stderr': stderr, 'status': 'nopass'})
    except subprocess.CalledProcessError as e:
        # Non-zero exit (missing lake, REPL crash, ...) counts as a failure.
        results.append({'error': str(e), 'status': 'nopass'})
    total += 1

    # BUG FIX: `total` already counts every item, so the denominator is just
    # `total`; the original `passed / (passed + total)` double-counted passes.
    pass_rate = passed / total * 100

    return {'results': results, 'pass_rate': pass_rate}
49
+
50
+ # Set the directory where your .lean files are located
51
+
52
+ # Get a list of all .lean files in the directory
53
+ # lean_files = [f for f in os.listdir(directory) if f.endswith(".lean")]
54
+ # lean_files = ["test/file.lean"]
55
def single(command_list, args):
    """Sequentially check each item's Lean snippet via `lake exe repl`.

    Each item['cmd'] is written to test/test.lean and fed to the REPL; an
    item passes when the REPL's JSON output has no "messages" key.  Prints
    the overall pass rate and dumps per-item results to results.json.
    `args` is accepted for interface compatibility and currently unused.
    """
    results = []
    passed = 0
    total = 0
    for item in tqdm(command_list):
        with open("test/test.lean", "w", encoding='utf-8') as f:
            f.write(item['cmd'])
        data = '{"path": "test/test.lean", "allTactics": true}'
        command = 'echo \'%s\' | lake exe repl' % data
        try:
            result = subprocess.run(command, shell=True, check=True,
                                    stdout=subprocess.PIPE, stderr=subprocess.PIPE)
            stdout = result.stdout.decode('utf-8')
            json_stdout = json.loads(stdout)
            stderr = result.stderr.decode('utf-8')
            # "messages" carries compiler errors/warnings; absence means success.
            if "messages" not in json_stdout:
                passed += 1
                status = 'pass'
            else:
                # BUG FIX: the original labelled every non-crashing run 'pass',
                # even when the REPL reported error messages.
                status = 'nopass'
            results.append({'stdout': stdout, 'stderr': stderr, 'status': status})
        except subprocess.CalledProcessError as e:
            results.append({'error': str(e), 'status': 'nopass'})
        total += 1

    # BUG FIX: guard against an empty command list (ZeroDivisionError).
    pass_rate = passed / total * 100 if total else 0.0
    print(pass_rate)

    # Save results to a JSON file
    with open('results.json', 'w') as f:
        json.dump({'results': results, 'pass_rate': pass_rate}, f, indent=2, ensure_ascii=False)
97
+
98
+
99
def multi(command_list, output_path, k):
    """Check up to k candidate Lean proofs per item and report Pass@1 / Pass@k.

    Every candidate in item['cmd'] is written to its own .lean file under
    /opt/jianqiao, elaborated with `lake exe repl`, and classified as
    'pass', 'nopass' (with the flat character offset of the first error),
    'nopass_limit' (timeout) or 'nopass_error' (subprocess failure).
    Results are written to pass_rate_results/<output_path>.

    Args:
        command_list: list of record dicts, each with a list of candidate
            snippets under 'cmd'.
        output_path: file name (relative to pass_rate_results/) for the dump.
        k: number of candidates considered for Pass@k.
    """
    results = []

    def execute_command(item, index):
        temp_dir = '/opt/jianqiao'

        def filter_json(json_data):
            # Keep only the metadata fields worth persisting in the output.
            filtered = {}
            for key in json_data.keys():
                if key in ['question', 'answer', 'total output', 'results']:
                    filtered[key] = json_data[key]
            return filtered

        result_dict = filter_json(item)
        result_dict['results'] = []

        for i, cmd in enumerate(item['cmd']):
            # Unique filename per (item, candidate) so parallel workers
            # cannot clobber each other.
            temp_file = os.path.join(temp_dir, f"{index}_test_{i}.lean")
            with open(temp_file, "w") as f:
                f.write(cmd)

            data = '{"path": "%s", "allTactics": true}' % temp_file
            command = f'echo \'{data}\' | lake exe repl'

            try:
                result = subprocess.run(command, shell=True, check=True, timeout=600,
                                        stdout=subprocess.PIPE, stderr=subprocess.PIPE)
                stdout = json.loads(result.stdout.decode('utf-8'))
                stderr = result.stderr.decode('utf-8')
            except subprocess.TimeoutExpired as e:
                result_item = {'error': str(e), 'status': 'nopass_limit'}
            except subprocess.CalledProcessError as e:
                result_item = {'error': str(e), 'status': 'nopass_error'}
            else:
                if "messages" not in stdout and not len(stderr):
                    result_item = {'stdout': stdout, 'stderr': stderr, 'status': 'pass'}
                elif not len(stderr) and "messages" in stdout:
                    # BUG FIX: removed a leftover `import pdb; pdb.set_trace()`
                    # that suspended every run entering this branch.
                    flag = 0
                    lines = cmd.split('\n')  # hoisted out of the inner loop
                    for me in stdout['messages']:
                        if me['severity'] == 'error':
                            flag = 1
                            # Convert the REPL's 1-based (line, column) error
                            # position into a flat character offset in cmd.
                            start_line = me['pos']['line'] - 1
                            current_column = me['pos']['column'] - 1
                            for line_n in range(start_line - 1, 0, -1):
                                line_len = len(lines[line_n])
                                current_column += line_len + 1
                                if not line_len:
                                    break
                            result_item = {'stdout': stdout, 'stderr': stderr,
                                           'status': 'nopass', 'string_pos': current_column}
                            break
                    if not flag:
                        # Only non-error messages (warnings etc.): still a pass.
                        result_item = {'stdout': stdout, 'stderr': stderr, 'status': 'pass'}
                else:
                    assert len(stderr)
                    result_item = {'stdout': stdout, 'stderr': stderr,
                                   'status': 'nopass', 'string_pos': 0}

            result_dict['results'].append(result_item)
        return result_dict

    total = len(command_list)

    with ThreadPoolExecutor(max_workers=1) as executor:
        futures = [executor.submit(execute_command, cmd, i)
                   for i, cmd in enumerate(command_list)]
        for future in tqdm(futures, total=total, desc="Processing Commands"):
            results.append(future.result())

    def calculate_pass(result_list, k):
        # An item counts toward Pass@n when any of its first n candidates passed.
        pass_1_count = 0
        pass_k_count = 0
        for result in result_list:
            item_results = result.get('results', [])
            if any(r.get('status') == 'pass' for r in item_results[:1]):
                pass_1_count += 1
            if any(r.get('status') == 'pass' for r in item_results[:k]):
                pass_k_count += 1
        denom = len(result_list)
        return ((pass_1_count / denom if denom else 0),
                (pass_k_count / denom if denom else 0))

    pass_1, pass_k = calculate_pass(results, k)
    print("Pass@1:", pass_1)
    print(f"Pass@{k}:", pass_k)

    output_file = f"pass_rate_results/{output_path}"
    # Create the directory if it doesn't exist.
    os.makedirs(os.path.dirname(output_file), exist_ok=True)

    with open(output_file, 'w') as f:
        json.dump({'results': results, 'pass_1': pass_1, f"pass_{k}": pass_k},
                  f, indent=2, ensure_ascii=False)
208
+
209
+ import re
210
def remove_simp_pattern_from_end(s):
    """Return *s* with a trailing `@[simp ...]` attribute removed, if any."""
    return re.sub(r'@\[simp\s*.*?\]$', '', s)
213
+
214
def main(args):
    """Collect generated Lean statements from input JSON files and check them.

    Each line of every matching file is a JSON record; the first k generated
    outputs are truncated at '#align', joined with the record's working file,
    and handed to multi() for evaluation.
    """
    command_list = []
    file_pattern = os.path.join(args.input_path, '[0-1]*.json')
    for file_path in glob.glob(file_pattern):
        with open(file_path, 'r', encoding='utf-8') as rf:
            for line in rf.readlines():
                # BUG FIX: the original bare `except:` dropped into
                # pdb.set_trace(), hanging unattended runs on the first
                # malformed record; such records are now reported and skipped.
                try:
                    json_item = json.loads(line)
                    working_env = json_item['content']['working_file']
                    json_item['cmd'] = []
                    for output in json_item['total output'][:min(args.k, len(json_item['total output']))]:
                        statement = output.split("#align")[0]
                        # Every candidate must be non-empty (the original only
                        # asserted the last one after the loop).
                        assert len(statement) > 0
                        json_item['cmd'].append('\n\n'.join([working_env, statement]))
                    json_item['answer'] = json_item['content']['statement_poof']
                except (KeyError, IndexError, ValueError, AssertionError) as e:
                    print(f"skipping malformed record in {file_path}: {e}")
                    continue
                command_list.append(json_item)
    multi(command_list, args.output_path, args.k)
239
+
240
if __name__ == '__main__':
    # CLI entry point: parse arguments and launch the evaluation.
    parser = ArgumentParser()
    parser.add_argument('--data_path', type=str,
                        default='data/grade-school-math-master/grade_school_math/data/test.jsonl')
    parser.add_argument('--input_path', type=str, default='')
    parser.add_argument('--cuda_num', type=int, default=8)
    parser.add_argument('--k', type=int, default=1)
    parser.add_argument('--output_path', type=str, default='total.json')
    parser.add_argument('--generate_method', type=str,
                        choices=['single', 'sft', 'comp', 'self_consistency', 'single_consistency'])
    parser.add_argument('--method', type=str, choices=['main', 'test', 'get_data'])
    main(parser.parse_args())
253
+
254
+
255
+
pass_rate_new_test_allcontent.py ADDED
@@ -0,0 +1,255 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import subprocess
3
+ from argparse import ArgumentParser
4
+ import json
5
+ from concurrent.futures import ThreadPoolExecutor
6
+ from tqdm import tqdm
7
+ import glob
8
+ import tempfile
9
+ import random
10
+ import random;random.seed(42)
11
+
12
+
13
def wrapped_function(item):
    """Check a single Lean snippet with `lake exe repl` and report pass/fail.

    Writes item['cmd'] to a uniquely named temp .lean file, asks the REPL to
    elaborate it, and counts the run as passing when the REPL returns no
    "messages" (i.e. no diagnostics).

    Returns:
        dict with 'results' (per-run records) and 'pass_rate' (percentage).
    """
    results = []
    passed = 0
    total = 0

    # NamedTemporaryFile yields a unique path, so concurrent callers no
    # longer clobber each other's fixed "test.lean" (the original's bug).
    with tempfile.NamedTemporaryFile("w", suffix=".lean", delete=False) as f:
        f.write(item['cmd'])
        temp_file = f.name

    data = '{"path": "%s", "allTactics": true}' % temp_file
    command = 'echo \'%s\' | lake exe repl' % data

    try:
        result = subprocess.run(command, shell=True, check=True,
                                stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        stdout = result.stdout.decode('utf-8')
        stderr = result.stderr.decode('utf-8')
        json_stdout = json.loads(stdout)
        # "messages" holds compiler diagnostics; absence means a clean run.
        if "messages" not in json_stdout:
            passed += 1
        results.append({'stdout': stdout, 'stderr': stderr, 'status': 'pass'})
    except subprocess.CalledProcessError as e:
        results.append({'error': str(e), 'status': 'nopass'})
    finally:
        total += 1
        # Actually clean up the temp file (the original only commented on it).
        os.remove(temp_file)

    # BUG FIX: the original divided by (passed + total), counting passing
    # items twice in the denominator; the rate is passed out of total.
    pass_rate = passed / total * 100

    return {'results': results, 'pass_rate': pass_rate}
51
+
52
+ # Set the directory where your .lean files are located
53
+
54
+ # Get a list of all .lean files in the directory
55
+ # lean_files = [f for f in os.listdir(directory) if f.endswith(".lean")]
56
+ # lean_files = ["test/file.lean"]
57
def single(command_list, args):
    """Sequentially check each item's Lean snippet with `lake exe repl`.

    Each item['cmd'] is written to test/test.lean, elaborated by the REPL,
    and recorded as pass/nopass.  The overall pass rate is printed and all
    results are dumped to results.json.

    Args:
        command_list: iterable of dicts, each with a 'cmd' string of Lean code.
        args: parsed CLI arguments (currently unused; kept for interface
            compatibility with callers).
    """
    results = []
    passed = 0
    total = 0
    for item in tqdm(command_list):
        with open("test/test.lean", "w", encoding='utf-8') as f:
            f.write(item['cmd'])
        data = '{"path": "test/test.lean", "allTactics": true}'
        command = 'echo \'%s\' | lake exe repl' % data
        try:
            result = subprocess.run(command, shell=True, check=True,
                                    stdout=subprocess.PIPE, stderr=subprocess.PIPE)
            stdout = result.stdout.decode('utf-8')
            stderr = result.stderr.decode('utf-8')
            json_stdout = json.loads(stdout)
            # BUG FIX: the original appended status 'pass' unconditionally,
            # even when the REPL reported diagnostics, so the saved records
            # disagreed with the `passed` counter.
            if "messages" not in json_stdout:
                passed += 1
                status = 'pass'
            else:
                status = 'nopass'
            results.append({'stdout': stdout, 'stderr': stderr, 'status': status})
        except subprocess.CalledProcessError as e:
            results.append({'error': str(e), 'status': 'nopass'})
        total += 1

    # Guard against an empty command list (the original raised
    # ZeroDivisionError in that case).
    pass_rate = passed / total * 100 if total else 0.0
    print(pass_rate)

    # Persist per-item results alongside the aggregate rate.
    with open('results.json', 'w') as f:
        json.dump({'results': results, 'pass_rate': pass_rate}, f,
                  indent=2, ensure_ascii=False)
99
+
100
+
101
def multi(command_list, output_path, k):
    """Check up to k candidate Lean proofs per item and report Pass@1 / Pass@k.

    Every candidate in item['cmd'] is written to its own .lean file under
    /opt/jianqiao, elaborated with `lake exe repl`, and classified as
    'pass', 'nopass' (with the flat character offset of the first error),
    'nopass_limit' (timeout) or 'nopass_error' (subprocess failure).
    Results are written to pass_rate_results/<output_path>.
    """
    results = []

    def execute_command(item, index):
        temp_dir = '/opt/jianqiao'

        def filter_json(json_data):
            # Keep only the fields worth persisting in the output dump.
            kept = {}
            for key in json_data.keys():
                if key in ['question', 'answer', 'total output', 'results', 'cmd']:
                    kept[key] = json_data[key]
            return kept

        record = filter_json(item)
        record['results'] = []

        for cand_idx, cmd in enumerate(item['cmd']):
            # Unique filename per (item, candidate) so workers don't collide.
            temp_file = os.path.join(temp_dir, f"{index}_test_{cand_idx}.lean")
            with open(temp_file, "w") as f:
                f.write(cmd)

            data = '{"path": "%s", "allTactics": true}' % temp_file
            command = f'echo \'{data}\' | lake exe repl'

            try:
                proc = subprocess.run(command, shell=True, check=True, timeout=600,
                                      stdout=subprocess.PIPE, stderr=subprocess.PIPE)
                stdout = json.loads(proc.stdout.decode('utf-8'))
                stderr = proc.stderr.decode('utf-8')
            except subprocess.TimeoutExpired as e:
                outcome = {'error': str(e), 'status': 'nopass_limit'}
            except subprocess.CalledProcessError as e:
                outcome = {'error': str(e), 'status': 'nopass_error'}
            else:
                if "messages" not in stdout and not len(stderr):
                    outcome = {'stdout': stdout, 'stderr': stderr, 'status': 'pass'}
                elif not len(stderr) and "messages" in stdout:
                    found_error = 0
                    src_lines = cmd.split('\n')
                    for msg in stdout['messages']:
                        if msg['severity'] == 'error':
                            found_error = 1
                            # Fold the 1-based (line, column) position into a
                            # flat character offset within cmd.
                            start_line = msg['pos']['line'] - 1
                            offset = msg['pos']['column'] - 1
                            for ln in range(start_line - 1, 0, -1):
                                length = len(src_lines[ln])
                                offset += length + 1
                                if not length:
                                    break
                            outcome = {'stdout': stdout, 'stderr': stderr,
                                       'status': 'nopass', 'string_pos': offset}
                            break
                    if not found_error:
                        # Only warnings/info messages: still counts as a pass.
                        outcome = {'stdout': stdout, 'stderr': stderr, 'status': 'pass'}
                else:
                    assert len(stderr)
                    outcome = {'stdout': stdout, 'stderr': stderr,
                               'status': 'nopass', 'string_pos': 0}

            record['results'].append(outcome)
        return record

    total = len(command_list)

    with ThreadPoolExecutor(max_workers=128) as executor:
        futures = [executor.submit(execute_command, cmd, i)
                   for i, cmd in enumerate(command_list)]
        for future in tqdm(futures, total=total, desc="Processing Commands"):
            results.append(future.result())

    def calculate_pass(result_list, k):
        # An item counts toward Pass@n when any of its first n candidates passed.
        pass_1_count = 0
        pass_k_count = 0
        for record in result_list:
            item_results = record.get('results', [])
            if any(r.get('status') == 'pass' for r in item_results[:1]):
                pass_1_count += 1
            if any(r.get('status') == 'pass' for r in item_results[:k]):
                pass_k_count += 1
        denom = len(result_list)
        return ((pass_1_count / denom if denom else 0),
                (pass_k_count / denom if denom else 0))

    pass_1, pass_k = calculate_pass(results, k)
    print("Pass@1:", pass_1)
    print(f"Pass@{k}:", pass_k)

    output_file = f"pass_rate_results/{output_path}"
    # Create the directory if it doesn't exist.
    os.makedirs(os.path.dirname(output_file), exist_ok=True)

    with open(output_file, 'w') as f:
        json.dump({'results': results, 'pass_1': pass_1, f"pass_{k}": pass_k},
                  f, indent=2, ensure_ascii=False)
208
+
209
+ import re
210
def remove_simp_pattern_from_end(s):
    """Return *s* with a trailing `@[simp ...]` attribute removed, if any."""
    return re.sub(r'@\[simp\s*.*?\]$', '', s)
213
+
214
def main(args):
    """Collect generated Lean statements from input JSON files and check a sample.

    Each line of every matching file is a JSON record; the first k generated
    outputs are truncated at '#align', joined with the record's working file,
    and a random sample of the records is handed to multi() for evaluation.
    """
    command_list = []
    file_pattern = os.path.join(args.input_path, '[0-9]*.json')
    for file_path in glob.glob(file_pattern):
        with open(file_path, 'r', encoding='utf-8') as rf:
            for line in rf.readlines():
                # BUG FIX: the original bare `except:` dropped into
                # pdb.set_trace(), hanging unattended runs on the first
                # malformed record; such records are now reported and skipped.
                try:
                    json_item = json.loads(line)
                    working_env = json_item['content']['working_file']
                    json_item['cmd'] = []
                    for output in json_item['total output'][:min(args.k, len(json_item['total output']))]:
                        statement = output.split("#align")[0]
                        assert len(statement) > 0
                        json_item['cmd'].append('\n\n'.join([working_env, statement]))
                    json_item['answer'] = json_item['content']['statement_poof']
                except (KeyError, IndexError, ValueError, AssertionError) as e:
                    print(f"skipping malformed record in {file_path}: {e}")
                    continue
                command_list.append(json_item)
    # BUG FIX: random.sample raises ValueError when fewer than 1000 records
    # were collected; cap the sample size at the population size.
    sample_size = min(1000, len(command_list))
    multi(random.sample(command_list, sample_size), args.output_path, args.k)
239
+
240
if __name__ == '__main__':
    # CLI entry point: parse arguments and launch the evaluation.
    parser = ArgumentParser()
    parser.add_argument('--data_path', type=str,
                        default='data/grade-school-math-master/grade_school_math/data/test.jsonl')
    parser.add_argument('--input_path', type=str, default='')
    parser.add_argument('--cuda_num', type=int, default=8)
    parser.add_argument('--k', type=int, default=1)
    parser.add_argument('--output_path', type=str, default='total.json')
    parser.add_argument('--generate_method', type=str,
                        choices=['single', 'sft', 'comp', 'self_consistency', 'single_consistency'])
    parser.add_argument('--method', type=str, choices=['main', 'test', 'get_data'])
    main(parser.parse_args())
253
+
254
+
255
+
pass_rate_notlean.py ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import subprocess
3
+ from argparse import ArgumentParser
4
+ import json
5
+ from concurrent.futures import ThreadPoolExecutor
6
+ from tqdm import tqdm
7
+ import tempfile
8
+ import glob
9
+ import pdb
10
+
11
def wrapped_function(item):
    """Check a single Lean snippet with `lake exe repl` and report pass/fail.

    Writes item['cmd'] to a uniquely named temp .lean file, asks the REPL to
    elaborate it, and counts the run as passing when the REPL returns no
    "messages" (i.e. no diagnostics).

    Returns:
        dict with 'results' (per-run records) and 'pass_rate' (percentage).
    """
    results = []
    passed = 0
    total = 0

    # NamedTemporaryFile yields a unique path, so concurrent callers no
    # longer clobber each other's fixed "test.lean" (the original's bug).
    with tempfile.NamedTemporaryFile("w", suffix=".lean", delete=False) as f:
        f.write(item['cmd'])
        temp_file = f.name

    data = '{"path": "%s", "allTactics": true}' % temp_file
    command = 'echo \'%s\' | lake exe repl' % data

    try:
        result = subprocess.run(command, shell=True, check=True,
                                stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        stdout = result.stdout.decode('utf-8')
        stderr = result.stderr.decode('utf-8')
        json_stdout = json.loads(stdout)
        # "messages" holds compiler diagnostics; absence means a clean run.
        if "messages" not in json_stdout:
            passed += 1
        results.append({'stdout': stdout, 'stderr': stderr, 'status': 'pass'})
    except subprocess.CalledProcessError as e:
        results.append({'error': str(e), 'status': 'nopass'})
    finally:
        total += 1
        # Actually clean up the temp file (the original only commented on it).
        os.remove(temp_file)

    # BUG FIX: the original divided by (passed + total), counting passing
    # items twice in the denominator; the rate is passed out of total.
    pass_rate = passed / total * 100

    return {'results': results, 'pass_rate': pass_rate}
49
+
50
+ # Set the directory where your .lean files are located
51
+
52
+ # Get a list of all .lean files in the directory
53
+ # lean_files = [f for f in os.listdir(directory) if f.endswith(".lean")]
54
+ # lean_files = ["test/file.lean"]
55
def single(command_list, output_path):
    """Sequentially check each item's Lean snippet with `lake exe repl`.

    Each item['cmd'] is written to test/test.lean, elaborated by the REPL,
    and recorded as pass/nopass.  The overall pass rate is printed and all
    results are dumped to results.json.

    Args:
        command_list: iterable of dicts, each with a 'cmd' string of Lean code.
        output_path: unused; kept for interface compatibility with callers.
    """
    results = []
    passed = 0
    total = 0
    for item in tqdm(command_list):
        with open("test/test.lean", "w", encoding='utf-8') as f:
            f.write(item['cmd'])
        data = '{"path": "test/test.lean", "allTactics": true}'
        command = 'echo \'%s\' | lake exe repl' % data

        try:
            result = subprocess.run(command, shell=True, check=True,
                                    stdout=subprocess.PIPE, stderr=subprocess.PIPE)
            stdout = result.stdout.decode('utf-8')
            stderr = result.stderr.decode('utf-8')
            json_stdout = json.loads(stdout)
            # BUG FIX: the original appended status 'pass' unconditionally,
            # even when the REPL reported diagnostics, so the saved records
            # disagreed with the `passed` counter.
            if "messages" not in json_stdout:
                passed += 1
                status = 'pass'
            else:
                status = 'nopass'
            results.append({'stdout': stdout, 'stderr': stderr, 'status': status})
        except subprocess.CalledProcessError as e:
            results.append({'error': str(e), 'status': 'nopass'})
        total += 1

    # Guard against an empty command list (the original raised
    # ZeroDivisionError in that case).
    pass_rate = passed / total * 100 if total else 0.0
    print(pass_rate)

    # Persist per-item results alongside the aggregate rate.
    with open('results.json', 'w') as f:
        json.dump({'results': results, 'pass_rate': pass_rate}, f,
                  indent=2, ensure_ascii=False)
99
+
100
+
101
+
102
+
103
def multi(command_list, output_path):
    """Check one Lean candidate per item in parallel and report the pass rate.

    Each item's 'cmd' is written to its own .lean file under /opt/jianqiao,
    elaborated with `lake exe repl`, and classified as 'pass', 'nopass',
    'nopass_limit' (timeout) or 'nopass_error' (subprocess failure).
    Results are written to pass_rate_results/<output_path>.
    """
    results = []
    passed = 0

    def execute_command(item):
        temp_dir = '/opt/jianqiao'
        # Unique filename per item so parallel workers cannot collide.
        temp_file = os.path.join(temp_dir, f"test_{item['index']}.lean")
        with open(temp_file, "w") as f:
            f.write(item['cmd'])

        data = '{"path": "%s", "allTactics": true}' % temp_file
        command = f'echo \'{data}\' | lake exe repl'

        try:
            result = subprocess.run(command, shell=True, check=True, timeout=600,
                                    stdout=subprocess.PIPE, stderr=subprocess.PIPE)
            stdout = result.stdout.decode('utf-8')
            stderr = result.stderr.decode('utf-8')

            # "messages" absent and empty stderr means a clean elaboration.
            if "messages" not in json.loads(stdout) and not len(stderr):
                return {'stdout': stdout, 'stderr': stderr, 'status': 'pass',
                        'statement': item['statement'], 'content': item['content']}
            else:
                return {'stdout': stdout, 'stderr': stderr, 'status': 'nopass',
                        'statement': item['statement'], 'content': item['content']}
        except subprocess.TimeoutExpired as e:
            return {'error': str(e), 'status': 'nopass_limit',
                    'statement': item['statement'], 'content': item['content']}
        except subprocess.CalledProcessError as e:
            return {'error': str(e), 'status': 'nopass_error',
                    'statement': item['statement'], 'content': item['content']}
        finally:
            # BUG FIX: the original placed os.remove after the return
            # statements, so it never ran and leaked one .lean file per item;
            # `finally` guarantees cleanup on every path.
            os.remove(temp_file)

    total = len(command_list)

    with ThreadPoolExecutor(max_workers=32) as executor:
        futures = [executor.submit(execute_command,
                                   {'index': i, 'cmd': cmd['cmd'],
                                    'statement': cmd['statement'],
                                    'content': cmd['content']})
                   for i, cmd in enumerate(command_list)]
        for future in tqdm(futures, total=total, desc="Processing Commands"):
            result = future.result()
            results.append(result)
            if result['status'] == 'pass':
                passed += 1

    # Guard against an empty command list (ZeroDivisionError otherwise).
    pass_rate = (passed / total) * 100 if total else 0.0
    print(f"Pass rate: {pass_rate}%")

    output_file = f"pass_rate_results/{output_path}"
    # Create the directory if it doesn't exist.
    os.makedirs(os.path.dirname(output_file), exist_ok=True)

    with open(output_file, 'w') as f:
        json.dump({'results': results, 'pass_rate': pass_rate}, f,
                  indent=2, ensure_ascii=False)
153
+
154
+ import re
155
def remove_simp_pattern_from_end(s):
    """Return *s* with a trailing `@[simp ...]` attribute removed, if any."""
    return re.sub(r'@\[simp\s*.*?\]$', '', s)
158
+
159
def main(args):
    """Build single-candidate Lean check commands and evaluate them.

    The shared working environment is loaded once from data/basic_working.json
    and prepended to the first generated statement (truncated at '#align') of
    every record in the input JSON files; the results go to multi().
    """
    command_list = []
    # json_filename = 'data/notlean_dependency.json'
    json_filename = 'data/basic_working.json'
    json_item = json.load(open(json_filename, encoding='utf-8'))
    working_env = json_item['working_file']

    file_pattern = os.path.join(args.input_path, '[0-9]*.json')
    for file_path in glob.glob(file_pattern):
        with open(file_path, 'r', encoding='utf-8') as rf:
            for line in rf.readlines():
                # BUG FIX: the original bare `except:` dropped into
                # pdb.set_trace(), hanging unattended runs on the first
                # malformed record; such records are now reported and skipped.
                try:
                    json_item = json.loads(line)
                    statement = json_item['total output'][0].split("#align")[0]
                    assert len(statement) > 0
                    json_item['statement'] = statement
                    json_item['cmd'] = '\n\n'.join([working_env, statement])
                except (KeyError, IndexError, ValueError, AssertionError) as e:
                    print(f"skipping malformed record in {file_path}: {e}")
                    continue
                command_list.append(json_item)

    multi(command_list, args.output_path)
187
+
188
if __name__ == '__main__':
    # CLI entry point: parse arguments and launch the evaluation.
    parser = ArgumentParser()
    parser.add_argument('--data_path', type=str,
                        default='data/grade-school-math-master/grade_school_math/data/test.jsonl')
    parser.add_argument('--input_path', type=str, default='')
    parser.add_argument('--cuda_num', type=int, default=8)
    parser.add_argument('--output_path', type=str, default='total.json')
    parser.add_argument('--generate_method', type=str,
                        choices=['single', 'sft', 'comp', 'self_consistency', 'single_consistency'])
    parser.add_argument('--method', type=str, choices=['main', 'test', 'get_data'])
    main(parser.parse_args())
+
201
+
202
+
pass_rate_notlean_test.py ADDED
@@ -0,0 +1,261 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import subprocess
3
+ from argparse import ArgumentParser
4
+ import json
5
+ from concurrent.futures import ThreadPoolExecutor
6
+ from tqdm import tqdm
7
+ import tempfile
8
+ import glob
9
+ import pdb
10
+
11
def wrapped_function(item):
    """Check a single Lean snippet with `lake exe repl` and report pass/fail.

    Writes item['cmd'] to a uniquely named temp .lean file, asks the REPL to
    elaborate it, and counts the run as passing when the REPL returns no
    "messages" (i.e. no diagnostics).

    Returns:
        dict with 'results' (per-run records) and 'pass_rate' (percentage).
    """
    results = []
    passed = 0
    total = 0

    # NamedTemporaryFile yields a unique path, so concurrent callers no
    # longer clobber each other's fixed "test.lean" (the original's bug).
    with tempfile.NamedTemporaryFile("w", suffix=".lean", delete=False) as f:
        f.write(item['cmd'])
        temp_file = f.name

    data = '{"path": "%s", "allTactics": true}' % temp_file
    command = 'echo \'%s\' | lake exe repl' % data

    try:
        result = subprocess.run(command, shell=True, check=True,
                                stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        stdout = result.stdout.decode('utf-8')
        stderr = result.stderr.decode('utf-8')
        json_stdout = json.loads(stdout)
        # "messages" holds compiler diagnostics; absence means a clean run.
        if "messages" not in json_stdout:
            passed += 1
        results.append({'stdout': stdout, 'stderr': stderr, 'status': 'pass'})
    except subprocess.CalledProcessError as e:
        results.append({'error': str(e), 'status': 'nopass'})
    finally:
        total += 1
        # Actually clean up the temp file (the original only commented on it).
        os.remove(temp_file)

    # BUG FIX: the original divided by (passed + total), counting passing
    # items twice in the denominator; the rate is passed out of total.
    pass_rate = passed / total * 100

    return {'results': results, 'pass_rate': pass_rate}
49
+
50
+ # Set the directory where your .lean files are located
51
+
52
+ # Get a list of all .lean files in the directory
53
+ # lean_files = [f for f in os.listdir(directory) if f.endswith(".lean")]
54
+ # lean_files = ["test/file.lean"]
55
def single(command_list, output_path):
    """Sequentially check each item's Lean snippet with `lake exe repl`.

    Each item['cmd'] is written to test/test.lean, elaborated by the REPL,
    and recorded as pass/nopass.  The overall pass rate is printed and all
    results are dumped to results.json.

    Args:
        command_list: iterable of dicts, each with a 'cmd' string of Lean code.
        output_path: unused; kept for interface compatibility with callers.
    """
    results = []
    passed = 0
    total = 0
    for item in tqdm(command_list):
        with open("test/test.lean", "w", encoding='utf-8') as f:
            f.write(item['cmd'])
        data = '{"path": "test/test.lean", "allTactics": true}'
        command = 'echo \'%s\' | lake exe repl' % data

        try:
            result = subprocess.run(command, shell=True, check=True,
                                    stdout=subprocess.PIPE, stderr=subprocess.PIPE)
            stdout = result.stdout.decode('utf-8')
            stderr = result.stderr.decode('utf-8')
            json_stdout = json.loads(stdout)
            # BUG FIX: the original appended status 'pass' unconditionally,
            # even when the REPL reported diagnostics, so the saved records
            # disagreed with the `passed` counter.
            if "messages" not in json_stdout:
                passed += 1
                status = 'pass'
            else:
                status = 'nopass'
            results.append({'stdout': stdout, 'stderr': stderr, 'status': status})
        except subprocess.CalledProcessError as e:
            results.append({'error': str(e), 'status': 'nopass'})
        total += 1

    # Guard against an empty command list (the original raised
    # ZeroDivisionError in that case).
    pass_rate = passed / total * 100 if total else 0.0
    print(pass_rate)

    # Persist per-item results alongside the aggregate rate.
    with open('results.json', 'w') as f:
        json.dump({'results': results, 'pass_rate': pass_rate}, f,
                  indent=2, ensure_ascii=False)
99
+
100
+
101
+
102
+
103
+
104
+
105
def multi(command_list, output_path, k):
    """Check up to k candidate Lean proofs per item and report Pass@1 / Pass@k.

    Every candidate in item['cmd'] is written to its own .lean file under
    /opt/jianqiao, elaborated with `lake exe repl` (480 s timeout), and
    classified as 'pass', 'nopass' (with the flat character offset of the
    first error), 'nopass_limit' or 'nopass_error'.  Results are written to
    pass_rate_results/<output_path>.
    """
    results = []

    def execute_command(item, index):
        temp_dir = '/opt/jianqiao'

        def filter_json(json_data):
            # Keep only the fields worth persisting in the output dump.
            kept = {}
            for key in json_data.keys():
                if key in ['question', 'answer', 'total output', 'results']:
                    kept[key] = json_data[key]
            return kept

        record = filter_json(item)
        record['results'] = []

        for cand_idx, cmd in enumerate(item['cmd']):
            # Unique filename per (item, candidate) so workers don't collide.
            temp_file = os.path.join(temp_dir, f"{index}_test_{cand_idx}.lean")
            with open(temp_file, "w") as f:
                f.write(cmd)

            data = '{"path": "%s", "allTactics": true}' % temp_file
            command = f'echo \'{data}\' | lake exe repl'

            try:
                proc = subprocess.run(command, shell=True, check=True, timeout=480,
                                      stdout=subprocess.PIPE, stderr=subprocess.PIPE)
                stdout = json.loads(proc.stdout.decode('utf-8'))
                stderr = proc.stderr.decode('utf-8')
            except subprocess.TimeoutExpired as e:
                outcome = {'error': str(e), 'status': 'nopass_limit'}
            except subprocess.CalledProcessError as e:
                outcome = {'error': str(e), 'status': 'nopass_error'}
            else:
                if "messages" not in stdout and not len(stderr):
                    outcome = {'stdout': stdout, 'stderr': stderr, 'status': 'pass'}
                elif not len(stderr) and "messages" in stdout:
                    found_error = 0
                    src_lines = cmd.split('\n')
                    for msg in stdout['messages']:
                        if msg['severity'] == 'error':
                            found_error = 1
                            # Fold the 1-based (line, column) position into a
                            # flat character offset within cmd.
                            start_line = msg['pos']['line'] - 1
                            offset = msg['pos']['column'] - 1
                            for ln in range(start_line - 1, 0, -1):
                                length = len(src_lines[ln])
                                offset += length + 1
                                if not length:
                                    break
                            outcome = {'stdout': stdout, 'stderr': stderr,
                                       'status': 'nopass', 'string_pos': offset}
                            break
                    if not found_error:
                        # Only warnings/info messages: still counts as a pass.
                        outcome = {'stdout': stdout, 'stderr': stderr, 'status': 'pass'}
                else:
                    assert len(stderr)
                    outcome = {'stdout': stdout, 'stderr': stderr,
                               'status': 'nopass', 'string_pos': 0}

            record['results'].append(outcome)
        return record

    total = len(command_list)

    with ThreadPoolExecutor(max_workers=128) as executor:
        futures = [executor.submit(execute_command, cmd, i)
                   for i, cmd in enumerate(command_list)]
        for future in tqdm(futures, total=total, desc="Processing Commands"):
            results.append(future.result())

    def calculate_pass(result_list, k):
        # An item counts toward Pass@n when any of its first n candidates passed.
        pass_1_count = 0
        pass_k_count = 0
        for record in result_list:
            item_results = record.get('results', [])
            if any(r.get('status') == 'pass' for r in item_results[:1]):
                pass_1_count += 1
            if any(r.get('status') == 'pass' for r in item_results[:k]):
                pass_k_count += 1
        denom = len(result_list)
        return ((pass_1_count / denom if denom else 0),
                (pass_k_count / denom if denom else 0))

    pass_1, pass_k = calculate_pass(results, k)
    print("Pass@1:", pass_1)
    print(f"Pass@{k}:", pass_k)

    output_file = f"pass_rate_results/{output_path}"
    # Create the directory if it doesn't exist.
    os.makedirs(os.path.dirname(output_file), exist_ok=True)

    with open(output_file, 'w') as f:
        json.dump({'results': results, 'pass_1': pass_1, f"pass_{k}": pass_k},
                  f, indent=2, ensure_ascii=False)
213
+
214
+ import re
215
def remove_simp_pattern_from_end(s):
    """Strip a trailing ``@[simp ...]`` attribute annotation from *s*.

    Returns *s* unchanged when no such annotation terminates the string.
    """
    trailing_simp_attr = re.compile(r'@\[simp\s*.*?\]$')
    return trailing_simp_attr.sub('', s)
218
+
219
def main(args):
    """Build the per-candidate Lean command list from the input JSON files
    and launch the pass-rate evaluation via ``multi``.

    For every item, up to ``args.k`` generated outputs are truncated at the
    first ``#align`` marker and prefixed with the shared Lean working
    environment so each candidate can be checked standalone.

    Args:
        args: parsed CLI namespace; uses ``input_path``, ``output_path``
            and ``k``.
    """
    # Shared Lean preamble prepended to every candidate statement.
    with open('data/leandojo.txt', 'r', encoding='utf-8') as rf:
        working_env = rf.read()

    command_list = []
    file_pattern = os.path.join(args.input_path, '[0-9]*.json')
    for file_path in glob.glob(file_pattern):
        with open(file_path, 'r', encoding='utf-8') as rf:
            for line in rf:
                json_item = json.loads(line)
                json_item['cmd'] = []
                # Slicing already clamps to the available length, so no
                # explicit min(k, len(...)) is needed.
                for output in json_item['total output'][:args.k]:
                    # Keep only the statement text before the first #align.
                    statement = output.split("#align")[0]
                    json_item['cmd'].append('\n\n'.join([working_env, statement]))
                command_list.append(json_item)

    multi(command_list, args.output_path, args.k)
245
+
246
if __name__ == '__main__':
    # Command-line entry point: declare the CLI surface and dispatch to main().
    parser = ArgumentParser()
    parser.add_argument('--data_path', type=str,
                        default='data/grade-school-math-master/grade_school_math/data/test.jsonl')
    parser.add_argument('--input_path', type=str, default='')
    parser.add_argument('--cuda_num', type=int, default=8)
    parser.add_argument('--output_path', type=str, default='total.json')
    parser.add_argument('--k', type=int, default=1)
    parser.add_argument('--generate_method', type=str,
                        choices=['single', 'sft', 'comp', 'self_consistency', 'single_consistency'])
    parser.add_argument('--method', type=str, choices=['main', 'test', 'get_data'])
    main(parser.parse_args())
259
+
260
+
261
+