rookiemango
committed on
Upload folder using huggingface_hub
Browse files
- all_code.py +66 -0
- auto_backform.py +0 -0
- autobackform_train.py +305 -0
- autoform_train.py +300 -0
- autosolver_train.py +303 -0
- autostatement_train.py +308 -0
- generation_method.py +472 -0
- model_train.py +318 -0
- repl/.lake/packages/mathlib/scripts/align-import.py +61 -0
- repl/.lake/packages/mathlib/scripts/align.py +64 -0
- repl/.lake/packages/mathlib/scripts/bench/accumulate_profile.py +21 -0
- repl/.lake/packages/mathlib/scripts/detect_sha_changes.py +104 -0
- repl/.lake/packages/mathlib/scripts/fix-comments.py +162 -0
- repl/.lake/packages/mathlib/scripts/fix-line-breaks.py +23 -0
- repl/.lake/packages/mathlib/scripts/fix-lints.py +48 -0
- repl/.lake/packages/mathlib/scripts/lint-style.py +451 -0
- repl/.lake/packages/mathlib/scripts/make_port_status.py +218 -0
- repl/.lake/packages/mathlib/scripts/polyrith_sage.py +101 -0
- repl/.lake/packages/mathlib/scripts/polyrith_sage_helper.py +13 -0
- repl/.lake/packages/mathlib/scripts/yaml_check.py +69 -0
- repl/pass_rate.py +195 -0
- whole_generation.py +491 -0
all_code.py
ADDED
@@ -0,0 +1,66 @@
import tqdm
import json


def filtered():
    import json
    with open("data/lean4_basic/1k_test.jsonl", "r") as f:
        test_data = json.load(f)
    test_data = [item['statement_poof'] for item in test_data]

    # Function to filter items based on existence in test data

    with open("data/lean4_random/5k_second.json", "r") as f:
        second_5k = json.load(f)

    def filter_items(data, test_data):
        filtered_data = [item for item in tqdm.tqdm(data) if item['statement_poof'][:-2] not in test_data]
        return filtered_data

    # Filter and save filtered data
    filtered_second_5k = filter_items(second_5k, test_data)
    with open("data/lean4_random/5k_second_filtered.json", "w") as f:
        json.dump(filtered_second_5k, f, ensure_ascii=False, indent=2)
    print("Filtered second file length:", len(filtered_second_5k))


def insert_label_for_autoformalization():
    input_lists = ["data/lean4_statement_translate/15k_state_problem_translation.json"]
    for input_file in input_lists:
        with open(input_file, "r") as f:
            test_data = json.load(f)
        for item in test_data:
            item['task'] = 'statement_form'

        with open("data/lean4_statement_translate/15k_state_problem_translation_statement_form.json", "w") as f:
            json.dump(test_data, f, ensure_ascii=False, indent=2)

    input_lists = ["data/lean4_statement_translate/15k_state_problem_translation.json"]
    for input_file in input_lists:
        with open(input_file, "r") as f:
            test_data = json.load(f)
        for item in test_data:
            item['task'] = 'statementproof_inform'
        with open("data/lean4_statement_translate/15k_state_problem_translation_statementproof_inform.json", "w") as f:
            json.dump(test_data, f, ensure_ascii=False, indent=2)

    input_lists = ["all_theorem.jsonl"]
    for input_file in input_lists:
        with open(input_file, "r") as f:
            test_data = json.load(f)
        for item in test_data:
            item['task'] = 'solver'

        with open("data/all_theorem_solver.jsonl", "w") as f:
            json.dump(test_data, f, ensure_ascii=False, indent=2)


if __name__ == '__main__':
    insert_label_for_autoformalization()
    # filtered()
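The helpers above key every record on fields such as 'statement_poof', 'model_response', and 'task'. A minimal sketch of the record shape `filter_items` is written against, with invented placeholder values (only the key names come from the code above):

# Hypothetical record; the values are illustrative, not taken from the dataset.
example_record = {
    "statement_poof": "theorem t (a b : Nat) : a + b = b + a := by\n  simp [Nat.add_comm]",
    "model_response": "A natural-language restatement and proof of commutativity.",
}
# filter_items keeps records whose truncated Lean text is absent from the test split:
# kept = filter_items([example_record], test_data=["..."])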
auto_backform.py
ADDED
File without changes
autobackform_train.py
ADDED
@@ -0,0 +1,305 @@
# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified by Zheng Yuan and Hongyi Yuan

import os
import copy
import logging
from dataclasses import dataclass, field
from typing import Optional, Dict, Sequence
import io
import torch
import transformers
from torch.utils.data import Dataset
from transformers import Trainer
import argparse
import json
import random;random.seed(42)

def _make_r_io_base(f, mode: str):
    if not isinstance(f, io.IOBase):
        f = open(f, mode=mode)
    return f

def jload(f, mode="r"):
    """Load a .json file into a dictionary."""
    f = _make_r_io_base(f, mode)
    jdict = json.load(f)
    f.close()
    return jdict

IGNORE_INDEX = -100
DEFAULT_PAD_TOKEN = "[PAD]"
DEFAULT_EOS_TOKEN = "</s>"
DEFAULT_BOS_TOKEN = "<s>"
DEFAULT_UNK_TOKEN = "<unk>"
PROMPT_DICT = {
    "lean4": (
        "Statement and proof in natural language:\n\n"
        "{statement_text}\n\n"
        "Translate the statement and proof in natural language to lean4:"
    ),
    "backform": (
        "Statement and proof in lean4:\n\n"
        "{statement_text}\n\n"
        "Translate the statement and proof in lean4 to natural language:"
    ),
    "prompt_no_input": (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Response:"
    ),
}
#### 28
@dataclass
class ModelArguments:
    model_name_or_path: Optional[str] = field(default="facebook/opt-125m")


@dataclass
class DataArguments:
    data_path: str = field(default=None, metadata={"help": "Path to the training data."})


@dataclass
class TrainingArguments(transformers.TrainingArguments):
    cache_dir: Optional[str] = field(default=None)
    optim: str = field(default="adamw_torch")
    model_max_length: int = field(
        default=2048,
        metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
    )
    overwrite_output_dir: bool = field(default=True)


def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str):
    """Collects the state dict and dump to disk."""
    state_dict = trainer.model.state_dict()
    if trainer.args.should_save:
        cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
        del state_dict
        trainer._save(output_dir, state_dict=cpu_state_dict)  # noqa


def smart_tokenizer_and_embedding_resize(
    special_tokens_dict: Dict,
    tokenizer: transformers.PreTrainedTokenizer,
    model: transformers.PreTrainedModel,
):
    """Resize tokenizer and embedding.
    Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
    """
    num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
    model.resize_token_embeddings(len(tokenizer))

    if num_new_tokens > 0:
        input_embeddings = model.get_input_embeddings().weight.data
        output_embeddings = model.get_output_embeddings().weight.data

        input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
        output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)

        input_embeddings[-num_new_tokens:] = input_embeddings_avg
        output_embeddings[-num_new_tokens:] = output_embeddings_avg


def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer) -> Dict:
    """Tokenize a list of strings."""
    tokenized_list = [
        tokenizer(
            text,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        )
        for text in strings
    ]
    input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
    input_ids_lens = labels_lens = [
        tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list
    ]
    return dict(
        input_ids=input_ids,
        labels=labels,
        input_ids_lens=input_ids_lens,
        labels_lens=labels_lens,
    )


def preprocess(
    sources: Sequence[str],
    targets: Sequence[str],
    tokenizer: transformers.PreTrainedTokenizer,
) -> Dict:
    """Preprocess the data by tokenizing."""
    examples = [s + t for s, t in zip(sources, targets)]
    examples_tokenized, sources_tokenized = [_tokenize_fn(strings, tokenizer) for strings in (examples, sources)]
    input_ids = examples_tokenized["input_ids"]
    labels = copy.deepcopy(input_ids)
    for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]):
        label[:source_len] = IGNORE_INDEX
    return dict(input_ids=input_ids, labels=labels)

class SupervisedDataset(Dataset):
    """Dataset for supervised fine-tuning."""
    def __init__(self, data_args, tokenizer: transformers.PreTrainedTokenizer):
        super(SupervisedDataset, self).__init__()
        logging.warning("Loading data...")
        data_path = data_args.data_path
        try:
            data_path = data_path_map[data_path]
        except:
            data_path = data_path
        list_data_dict = []
        for item in data_path.split(','):
            try:
                list_data_dict += jload(item)
            except BaseException:
                with open(item, 'r') as f:
                    lines = f.readlines()
                list_data_dict += [json.loads(line.strip()) for line in lines]

        list_data_dict = random.sample(list_data_dict, len(list_data_dict))
        list_data_dict = list_data_dict[:data_args.data_length]

        logging.warning("Formatting inputs...")
        prompt_lean4 = PROMPT_DICT["backform"]

        # list_data_dict = [{'instruction':data['items'][0]['value'], 'input':'', 'output':data['items'][1]['value']} for data in list_data_dict]

        list_data_dict = [{'instruction':prompt_lean4.format(statement_text = data['statement_poof']), 'input':'', 'output':data['model_response']} for data in list_data_dict]
        print(f"len of {len(list_data_dict)}")
        sources = [example['instruction'] for example in list_data_dict]
        targets = [f"{example['output']}{tokenizer.eos_token}" for example in list_data_dict]
        # targets = [example['output'] for example in list_data_dict]

        self.sources = sources
        self.targets = targets

    def __len__(self):
        return len(self.sources)

    def naive__getitem__(self, i) -> Dict[str, torch.Tensor]:
        return dict(input_ids=self.input_ids[i], labels=self.labels[i])

    def __getitem__(self, i):
        return dict(input_ids=self.sources[i], labels=self.targets[i])

@dataclass
class DataCollatorForSupervisedDataset(object):
    """Collate examples for supervised fine-tuning."""

    tokenizer: transformers.PreTrainedTokenizer

    def naive__call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels"))
        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
        )
        labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
        return dict(
            input_ids=input_ids,
            labels=labels,
            attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
        )

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        sources = []
        targets = []
        for instance in instances:
            source = instance['input_ids']
            target = instance['labels']
            sources.append(source)
            targets.append(target)

        data_dict = preprocess(sources, targets, self.tokenizer)
        input_ids, labels = data_dict['input_ids'], data_dict['labels']
        # input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels"))
        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
        )
        labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
        return dict(
            input_ids=input_ids,
            labels=labels,
            attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
        )

def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args) -> Dict:
    """Make dataset and collator for supervised fine-tuning."""
    train_dataset = SupervisedDataset(tokenizer=tokenizer, data_args=data_args)
    data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
    return dict(train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator)


os.environ["WANDB_PROJECT"] = "auto_backform"

def train():

    parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args, remaining_args = parser.parse_args_into_dataclasses(return_remaining_strings=True)
    data_args.data_length = int(remaining_args[1])

    model = transformers.AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",
    )

    model.config.use_cache = False
    model.gradient_checkpointing_enable()

    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        model_max_length=training_args.model_max_length,
        padding_side="right",
        use_fast=False,
    )
    if tokenizer.pad_token is None:
        smart_tokenizer_and_embedding_resize(
            special_tokens_dict=dict(pad_token=DEFAULT_PAD_TOKEN),
            tokenizer=tokenizer,
            model=model,
        )
    if "llama" in model_args.model_name_or_path:
        tokenizer.add_special_tokens(
            {
                "eos_token": DEFAULT_EOS_TOKEN,
                "bos_token": DEFAULT_BOS_TOKEN,
                "unk_token": DEFAULT_UNK_TOKEN,
            }
        )
    try:
        tokenizer.pad_token = tokenizer.unk_token
    except:
        pass

    data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)
    trainer = Trainer(model=model, tokenizer=tokenizer, args=training_args, **data_module)
    trainer.train()
    model.config.use_cache = True
    # trainer.save_state()
    # if os.environ.get('LOCAL_RANK') == '0':
    safe_save_model_for_hf_trainer(trainer=trainer, output_dir=training_args.output_dir)


if __name__ == "__main__":
    train()
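In the collator above, `preprocess` tokenizes the prompt and the prompt+completion separately, then overwrites the prompt positions of the labels with IGNORE_INDEX so only the completion contributes to the loss. A minimal, self-contained sketch of that masking step with toy token ids (no real tokenizer involved; everything here is illustrative):

IGNORE_INDEX = -100

def mask_prompt(prompt_len, example_ids):
    # Copy the full sequence, then blank out the prompt span so that
    # cross-entropy is computed only on the completion tokens.
    labels = list(example_ids)
    labels[:prompt_len] = [IGNORE_INDEX] * prompt_len
    return labels

prompt_ids = [101, 7, 8, 9]        # toy ids for the instruction text
completion_ids = [21, 22, 23, 2]   # toy ids for the target plus eos
print(mask_prompt(len(prompt_ids), prompt_ids + completion_ids))
# -> [-100, -100, -100, -100, 21, 22, 23, 2]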
autoform_train.py
ADDED
@@ -0,0 +1,300 @@
# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified by Zheng Yuan and Hongyi Yuan

import os
import copy
import logging
from dataclasses import dataclass, field
from typing import Optional, Dict, Sequence
import io
import torch
import transformers
from torch.utils.data import Dataset
from transformers import Trainer
import argparse
import json
import random;random.seed(42)

def _make_r_io_base(f, mode: str):
    if not isinstance(f, io.IOBase):
        f = open(f, mode=mode)
    return f

def jload(f, mode="r"):
    """Load a .json file into a dictionary."""
    f = _make_r_io_base(f, mode)
    jdict = json.load(f)
    f.close()
    return jdict

IGNORE_INDEX = -100
DEFAULT_PAD_TOKEN = "[PAD]"
DEFAULT_EOS_TOKEN = "</s>"
DEFAULT_BOS_TOKEN = "<s>"
DEFAULT_UNK_TOKEN = "<unk>"
PROMPT_DICT = {
    "lean4": (
        "Statement and proof in natural language:\n\n"
        "{statement_text}\n\n"
        "Translate the statement and proof in natural language to lean4:"
    ),
    "prompt_no_input": (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Response:"
    ),
}
#### 28
@dataclass
class ModelArguments:
    model_name_or_path: Optional[str] = field(default="facebook/opt-125m")


@dataclass
class DataArguments:
    data_path: str = field(default=None, metadata={"help": "Path to the training data."})


@dataclass
class TrainingArguments(transformers.TrainingArguments):
    cache_dir: Optional[str] = field(default=None)
    optim: str = field(default="adamw_torch")
    model_max_length: int = field(
        default=2048,
        metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
    )
    overwrite_output_dir: bool = field(default=True)


def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str):
    """Collects the state dict and dump to disk."""
    state_dict = trainer.model.state_dict()
    if trainer.args.should_save:
        cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
        del state_dict
        trainer._save(output_dir, state_dict=cpu_state_dict)  # noqa


def smart_tokenizer_and_embedding_resize(
    special_tokens_dict: Dict,
    tokenizer: transformers.PreTrainedTokenizer,
    model: transformers.PreTrainedModel,
):
    """Resize tokenizer and embedding.
    Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
    """
    num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
    model.resize_token_embeddings(len(tokenizer))

    if num_new_tokens > 0:
        input_embeddings = model.get_input_embeddings().weight.data
        output_embeddings = model.get_output_embeddings().weight.data

        input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
        output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)

        input_embeddings[-num_new_tokens:] = input_embeddings_avg
        output_embeddings[-num_new_tokens:] = output_embeddings_avg


def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer) -> Dict:
    """Tokenize a list of strings."""
    tokenized_list = [
        tokenizer(
            text,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        )
        for text in strings
    ]
    input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
    input_ids_lens = labels_lens = [
        tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list
    ]
    return dict(
        input_ids=input_ids,
        labels=labels,
        input_ids_lens=input_ids_lens,
        labels_lens=labels_lens,
    )


def preprocess(
    sources: Sequence[str],
    targets: Sequence[str],
    tokenizer: transformers.PreTrainedTokenizer,
) -> Dict:
    """Preprocess the data by tokenizing."""
    examples = [s + t for s, t in zip(sources, targets)]
    examples_tokenized, sources_tokenized = [_tokenize_fn(strings, tokenizer) for strings in (examples, sources)]
    input_ids = examples_tokenized["input_ids"]
    labels = copy.deepcopy(input_ids)
    for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]):
        label[:source_len] = IGNORE_INDEX
    return dict(input_ids=input_ids, labels=labels)

class SupervisedDataset(Dataset):
    """Dataset for supervised fine-tuning."""
    def __init__(self, data_args, tokenizer: transformers.PreTrainedTokenizer):
        super(SupervisedDataset, self).__init__()
        logging.warning("Loading data...")
        data_path = data_args.data_path
        try:
            data_path = data_path_map[data_path]
        except:
            data_path = data_path
        list_data_dict = []
        for item in data_path.split(','):
            try:
                list_data_dict += jload(item)
            except BaseException:
                with open(item, 'r') as f:
                    lines = f.readlines()
                list_data_dict += [json.loads(line.strip()) for line in lines]

        list_data_dict = random.sample(list_data_dict, len(list_data_dict))
        list_data_dict = list_data_dict[:data_args.data_length]

        logging.warning("Formatting inputs...")
        prompt_lean4 = PROMPT_DICT["lean4"]

        # list_data_dict = [{'instruction':data['items'][0]['value'], 'input':'', 'output':data['items'][1]['value']} for data in list_data_dict]

        list_data_dict = [{'instruction':prompt_lean4.format(statement_text = data['model_response']), 'input':'', 'output':data['statement_poof']} for data in list_data_dict]
        print(f"len of {len(list_data_dict)}")
        sources = [example['instruction'] for example in list_data_dict]
        targets = [f"{example['output']}{tokenizer.eos_token}" for example in list_data_dict]
        # targets = [example['output'] for example in list_data_dict]

        self.sources = sources
        self.targets = targets

    def __len__(self):
        return len(self.sources)

    def naive__getitem__(self, i) -> Dict[str, torch.Tensor]:
        return dict(input_ids=self.input_ids[i], labels=self.labels[i])

    def __getitem__(self, i):
        return dict(input_ids=self.sources[i], labels=self.targets[i])

@dataclass
class DataCollatorForSupervisedDataset(object):
    """Collate examples for supervised fine-tuning."""

    tokenizer: transformers.PreTrainedTokenizer

    def naive__call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels"))
        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
        )
        labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
        return dict(
            input_ids=input_ids,
            labels=labels,
            attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
        )

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        sources = []
        targets = []
        for instance in instances:
            source = instance['input_ids']
            target = instance['labels']
            sources.append(source)
            targets.append(target)

        data_dict = preprocess(sources, targets, self.tokenizer)
        input_ids, labels = data_dict['input_ids'], data_dict['labels']
        # input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels"))
        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
        )
        labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
        return dict(
            input_ids=input_ids,
            labels=labels,
            attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
        )

def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args) -> Dict:
    """Make dataset and collator for supervised fine-tuning."""
    train_dataset = SupervisedDataset(tokenizer=tokenizer, data_args=data_args)
    data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
    return dict(train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator)


os.environ["WANDB_PROJECT"] = "auto_form"

def train():

    parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args, remaining_args = parser.parse_args_into_dataclasses(return_remaining_strings=True)
    data_args.data_length = int(remaining_args[1])

    model = transformers.AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",
    )

    model.config.use_cache = False
    model.gradient_checkpointing_enable()

    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        model_max_length=training_args.model_max_length,
        padding_side="right",
        use_fast=False,
    )
    if tokenizer.pad_token is None:
        smart_tokenizer_and_embedding_resize(
            special_tokens_dict=dict(pad_token=DEFAULT_PAD_TOKEN),
            tokenizer=tokenizer,
            model=model,
        )
    if "llama" in model_args.model_name_or_path:
        tokenizer.add_special_tokens(
            {
                "eos_token": DEFAULT_EOS_TOKEN,
                "bos_token": DEFAULT_BOS_TOKEN,
                "unk_token": DEFAULT_UNK_TOKEN,
            }
        )
    try:
        tokenizer.pad_token = tokenizer.unk_token
    except:
        pass

    data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)
    trainer = Trainer(model=model, tokenizer=tokenizer, args=training_args, **data_module)
    trainer.train()
    model.config.use_cache = True
    # trainer.save_state()
    # if os.environ.get('LOCAL_RANK') == '0':
    safe_save_model_for_hf_trainer(trainer=trainer, output_dir=training_args.output_dir)


if __name__ == "__main__":
    train()
autosolver_train.py
ADDED
@@ -0,0 +1,303 @@
# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified by Zheng Yuan and Hongyi Yuan

import os
import copy
import logging
from dataclasses import dataclass, field
from typing import Optional, Dict, Sequence
import io
import torch
import transformers
from torch.utils.data import Dataset
from transformers import Trainer
import argparse
import json
import random;random.seed(42)

def _make_r_io_base(f, mode: str):
    if not isinstance(f, io.IOBase):
        f = open(f, mode=mode)
    return f

def jload(f, mode="r"):
    """Load a .json file into a dictionary."""
    f = _make_r_io_base(f, mode)
    jdict = json.load(f)
    f.close()
    return jdict

IGNORE_INDEX = -100
DEFAULT_PAD_TOKEN = "[PAD]"
DEFAULT_EOS_TOKEN = "</s>"
DEFAULT_BOS_TOKEN = "<s>"
DEFAULT_UNK_TOKEN = "<unk>"
PROMPT_DICT = {
    "lean4": (
        "Statement and proof in natural language:\n\n"
        "{statement_text}\n\n"
        "Translate the statement and proof in natural language to lean4:"
    ),
    "plain": (
        "{statement_text}"
    ),
    "prompt_no_input": (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Response:"
    ),
}
#### 28
@dataclass
class ModelArguments:
    model_name_or_path: Optional[str] = field(default="facebook/opt-125m")


@dataclass
class DataArguments:
    data_path: str = field(default=None, metadata={"help": "Path to the training data."})


@dataclass
class TrainingArguments(transformers.TrainingArguments):
    cache_dir: Optional[str] = field(default=None)
    optim: str = field(default="adamw_torch")
    model_max_length: int = field(
        default=2048,
        metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
    )
    overwrite_output_dir: bool = field(default=True)


def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str):
    """Collects the state dict and dump to disk."""
    state_dict = trainer.model.state_dict()
    if trainer.args.should_save:
        cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
        del state_dict
        trainer._save(output_dir, state_dict=cpu_state_dict)  # noqa


def smart_tokenizer_and_embedding_resize(
    special_tokens_dict: Dict,
    tokenizer: transformers.PreTrainedTokenizer,
    model: transformers.PreTrainedModel,
):
    """Resize tokenizer and embedding.
    Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
    """
    num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
    model.resize_token_embeddings(len(tokenizer))

    if num_new_tokens > 0:
        input_embeddings = model.get_input_embeddings().weight.data
        output_embeddings = model.get_output_embeddings().weight.data

        input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
        output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)

        input_embeddings[-num_new_tokens:] = input_embeddings_avg
        output_embeddings[-num_new_tokens:] = output_embeddings_avg


def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer) -> Dict:
    """Tokenize a list of strings."""
    tokenized_list = [
        tokenizer(
            text,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        )
        for text in strings
    ]
    input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
    input_ids_lens = labels_lens = [
        tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list
    ]
    return dict(
        input_ids=input_ids,
        labels=labels,
        input_ids_lens=input_ids_lens,
        labels_lens=labels_lens,
    )


def preprocess(
    sources: Sequence[str],
    targets: Sequence[str],
    tokenizer: transformers.PreTrainedTokenizer,
) -> Dict:
    """Preprocess the data by tokenizing."""
    examples = [s + t for s, t in zip(sources, targets)]
    examples_tokenized, sources_tokenized = [_tokenize_fn(strings, tokenizer) for strings in (examples, sources)]
    input_ids = examples_tokenized["input_ids"]
    labels = copy.deepcopy(input_ids)
    for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]):
        label[:source_len] = IGNORE_INDEX
    return dict(input_ids=input_ids, labels=labels)

class SupervisedDataset(Dataset):
    """Dataset for supervised fine-tuning."""
    def __init__(self, data_args, tokenizer: transformers.PreTrainedTokenizer):
        super(SupervisedDataset, self).__init__()
        logging.warning("Loading data...")
        data_path = data_args.data_path
        try:
            data_path = data_path_map[data_path]
        except:
            data_path = data_path
        list_data_dict = []
        for item in data_path.split(','):
            try:
                list_data_dict += jload(item)
            except BaseException:
                with open(item, 'r') as f:
                    lines = f.readlines()
                list_data_dict += [json.loads(line.strip()) for line in lines]

        list_data_dict = random.sample(list_data_dict, len(list_data_dict))
        list_data_dict = list_data_dict[:data_args.data_length]

        logging.warning("Formatting inputs...")
        prompt_lean4 = PROMPT_DICT["plain"]

        # list_data_dict = [{'instruction':data['items'][0]['value'], 'input':'', 'output':data['items'][1]['value']} for data in list_data_dict]

        list_data_dict = [{'instruction':prompt_lean4.format(statement_text = data['statement']), 'input':'', 'output':data['proof']} for data in list_data_dict]
        print(f"len of {len(list_data_dict)}")
        sources = [example['instruction'] for example in list_data_dict]
        targets = [f"{example['output']}{tokenizer.eos_token}" for example in list_data_dict]
        # targets = [example['output'] for example in list_data_dict]

        self.sources = sources
        self.targets = targets

    def __len__(self):
        return len(self.sources)

    def naive__getitem__(self, i) -> Dict[str, torch.Tensor]:
        return dict(input_ids=self.input_ids[i], labels=self.labels[i])

    def __getitem__(self, i):
        return dict(input_ids=self.sources[i], labels=self.targets[i])

@dataclass
class DataCollatorForSupervisedDataset(object):
    """Collate examples for supervised fine-tuning."""

    tokenizer: transformers.PreTrainedTokenizer

    def naive__call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels"))
        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
        )
        labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
        return dict(
            input_ids=input_ids,
            labels=labels,
            attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
        )

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        sources = []
        targets = []
        for instance in instances:
            source = instance['input_ids']
            target = instance['labels']
            sources.append(source)
            targets.append(target)

        data_dict = preprocess(sources, targets, self.tokenizer)
        input_ids, labels = data_dict['input_ids'], data_dict['labels']
        # input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels"))
        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
        )
        labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
        return dict(
            input_ids=input_ids,
            labels=labels,
            attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
        )

def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args) -> Dict:
    """Make dataset and collator for supervised fine-tuning."""
    train_dataset = SupervisedDataset(tokenizer=tokenizer, data_args=data_args)
    data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
    return dict(train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator)


os.environ["WANDB_PROJECT"] = "auto_solver"

def train():

    parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args, remaining_args = parser.parse_args_into_dataclasses(return_remaining_strings=True)
    data_args.data_length = int(remaining_args[1])

    model = transformers.AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",
    )

    model.config.use_cache = False
    model.gradient_checkpointing_enable()

    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        model_max_length=training_args.model_max_length,
        padding_side="right",
        use_fast=False,
    )
    if tokenizer.pad_token is None:
        smart_tokenizer_and_embedding_resize(
            special_tokens_dict=dict(pad_token=DEFAULT_PAD_TOKEN),
            tokenizer=tokenizer,
            model=model,
        )
    if "llama" in model_args.model_name_or_path:
        tokenizer.add_special_tokens(
            {
                "eos_token": DEFAULT_EOS_TOKEN,
                "bos_token": DEFAULT_BOS_TOKEN,
                "unk_token": DEFAULT_UNK_TOKEN,
            }
        )
    try:
        tokenizer.pad_token = tokenizer.unk_token
    except:
        pass

    data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)
    trainer = Trainer(model=model, tokenizer=tokenizer, args=training_args, **data_module)
    trainer.train()
    model.config.use_cache = True
    # trainer.save_state()
    # if os.environ.get('LOCAL_RANK') == '0':
    safe_save_model_for_hf_trainer(trainer=trainer, output_dir=training_args.output_dir)


if __name__ == "__main__":
    train()
autostatement_train.py
ADDED
@@ -0,0 +1,308 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# Modified by Zheng Yuan and Hongyi Yuan
|
15 |
+
|
16 |
+
import os
|
17 |
+
import copy
|
18 |
+
import logging
|
19 |
+
from dataclasses import dataclass, field
|
20 |
+
from typing import Optional, Dict, Sequence
|
21 |
+
import io
|
22 |
+
import torch
|
23 |
+
import transformers
|
24 |
+
from torch.utils.data import Dataset
|
25 |
+
from transformers import Trainer
|
26 |
+
import argparse
|
27 |
+
import json
|
28 |
+
import random;random.seed(42)
|
29 |
+
|
30 |
+
def _make_r_io_base(f, mode: str):
|
31 |
+
if not isinstance(f, io.IOBase):
|
32 |
+
f = open(f, mode=mode)
|
33 |
+
return f
|
34 |
+
|
35 |
+
def jload(f, mode="r"):
|
36 |
+
"""Load a .json file into a dictionary."""
|
37 |
+
f = _make_r_io_base(f, mode)
|
38 |
+
jdict = json.load(f)
|
39 |
+
f.close()
|
40 |
+
return jdict
|
41 |
+
|
42 |
+
IGNORE_INDEX = -100
|
43 |
+
DEFAULT_PAD_TOKEN = "[PAD]"
|
44 |
+
DEFAULT_EOS_TOKEN = "</s>"
|
45 |
+
DEFAULT_BOS_TOKEN = "<s>"
|
46 |
+
DEFAULT_UNK_TOKEN = "<unk>"
|
47 |
+
PROMPT_DICT = {
|
48 |
+
"lean4": (
|
49 |
+
"Statement and proof in natural language:\n\n"
|
50 |
+
"{statement_text}\n\n"
|
51 |
+
"Translate the statement and proof in natural language to lean4:"
|
52 |
+
),
|
53 |
+
"plain": (
|
54 |
+
"{statement_text}"
|
55 |
+
),
|
56 |
+
"statement": (
|
57 |
+
"Statement in natural language:\n"
|
58 |
+
"{problem}\n"
|
59 |
+
"Translate the statement in natural language to Lean4:"
|
60 |
+
),
|
61 |
+
"prompt_no_input": (
|
62 |
+
"Below is an instruction that describes a task. "
|
63 |
+
"Write a response that appropriately completes the request.\n\n"
|
64 |
+
"### Instruction:\n{instruction}\n\n### Response:"
|
65 |
+
),
|
66 |
+
}
|
67 |
+
#### 28
|
68 |
+
@dataclass
|
69 |
+
class ModelArguments:
|
70 |
+
model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
|
71 |
+
|
72 |
+
|
73 |
+
@dataclass
|
74 |
+
class DataArguments:
|
75 |
+
data_path: str = field(default=None, metadata={"help": "Path to the training data."})
|
76 |
+
|
77 |
+
|
78 |
+
@dataclass
|
79 |
+
class TrainingArguments(transformers.TrainingArguments):
|
80 |
+
cache_dir: Optional[str] = field(default=None)
|
81 |
+
optim: str = field(default="adamw_torch")
|
82 |
+
model_max_length: int = field(
|
83 |
+
default=2048,
|
84 |
+
metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
|
85 |
+
)
|
86 |
+
overwrite_output_dir: bool = field(default=True)
|
87 |
+
|
88 |
+
|
89 |
+
def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str):
|
90 |
+
"""Collects the state dict and dump to disk."""
|
91 |
+
state_dict = trainer.model.state_dict()
|
92 |
+
if trainer.args.should_save:
|
93 |
+
cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
|
94 |
+
del state_dict
|
95 |
+
trainer._save(output_dir, state_dict=cpu_state_dict) # noqa
|
96 |
+
|
97 |
+
|
98 |
+
def smart_tokenizer_and_embedding_resize(
|
99 |
+
special_tokens_dict: Dict,
|
100 |
+
tokenizer: transformers.PreTrainedTokenizer,
|
101 |
+
model: transformers.PreTrainedModel,
|
102 |
+
):
|
103 |
+
"""Resize tokenizer and embedding.
|
104 |
+
Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
|
105 |
+
"""
|
106 |
+
num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
|
107 |
+
model.resize_token_embeddings(len(tokenizer))
|
108 |
+
|
109 |
+
if num_new_tokens > 0:
|
110 |
+
input_embeddings = model.get_input_embeddings().weight.data
|
111 |
+
output_embeddings = model.get_output_embeddings().weight.data
|
112 |
+
|
113 |
+
input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
|
114 |
+
output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
|
115 |
+
|
116 |
+
input_embeddings[-num_new_tokens:] = input_embeddings_avg
|
117 |
+
output_embeddings[-num_new_tokens:] = output_embeddings_avg
|
118 |
+
|
119 |
+
|
120 |
+
def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer) -> Dict:
|
121 |
+
"""Tokenize a list of strings."""
|
122 |
+
tokenized_list = [
|
123 |
+
tokenizer(
|
124 |
+
text,
|
125 |
+
return_tensors="pt",
|
126 |
+
padding="longest",
|
127 |
+
max_length=tokenizer.model_max_length,
|
128 |
+
truncation=True,
|
129 |
+
)
|
130 |
+
for text in strings
|
131 |
+
]
|
132 |
+
input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
|
133 |
+
input_ids_lens = labels_lens = [
|
134 |
+
tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list
|
135 |
+
]
|
136 |
+
return dict(
|
137 |
+
input_ids=input_ids,
|
138 |
+
labels=labels,
|
139 |
+
input_ids_lens=input_ids_lens,
|
140 |
+
labels_lens=labels_lens,
|
141 |
+
)
|
142 |
+
|
143 |
+
|
144 |
+
def preprocess(
|
145 |
+
sources: Sequence[str],
|
146 |
+
targets: Sequence[str],
|
147 |
+
tokenizer: transformers.PreTrainedTokenizer,
|
148 |
+
) -> Dict:
|
149 |
+
"""Preprocess the data by tokenizing."""
|
150 |
+
examples = [s + t for s, t in zip(sources, targets)]
|
151 |
+
examples_tokenized, sources_tokenized = [_tokenize_fn(strings, tokenizer) for strings in (examples, sources)]
|
152 |
+
input_ids = examples_tokenized["input_ids"]
|
153 |
+
labels = copy.deepcopy(input_ids)
|
154 |
+
for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]):
|
155 |
+
label[:source_len] = IGNORE_INDEX
|
156 |
+
return dict(input_ids=input_ids, labels=labels)
|
157 |
+
|
158 |
+
class SupervisedDataset(Dataset):
|
159 |
+
"""Dataset for supervised fine-tuning."""
|
160 |
+
    def __init__(self, data_args, tokenizer: transformers.PreTrainedTokenizer):
        super(SupervisedDataset, self).__init__()
        logging.warning("Loading data...")
        data_path = data_args.data_path
        try:
            data_path = data_path_map[data_path]
        except:
            data_path = data_path
        list_data_dict = []
        for item in data_path.split(','):
            try:
                list_data_dict += jload(item)
            except BaseException:
                with open(item, 'r') as f:
                    lines = f.readlines()
                    list_data_dict += [json.loads(line.strip()) for line in lines]

        list_data_dict = random.sample(list_data_dict, len(list_data_dict))
        list_data_dict = list_data_dict[:data_args.data_length]

        logging.warning("Formatting inputs...")
        prompt_lean4 = PROMPT_DICT["statement"]

        # list_data_dict = [{'instruction':data['items'][0]['value'], 'input':'', 'output':data['items'][1]['value']} for data in list_data_dict]

        list_data_dict = [{'instruction': prompt_lean4.format(problem=data['problem']), 'input': '', 'output': data['statement']} for data in list_data_dict]
        print(f"len of {len(list_data_dict)}")
        sources = [example['instruction'] for example in list_data_dict]
        targets = [f"{example['output']}{tokenizer.eos_token}" for example in list_data_dict]
        # targets = [example['output'] for example in list_data_dict]

        self.sources = sources
        self.targets = targets

    def __len__(self):
        return len(self.sources)

    def naive__getitem__(self, i) -> Dict[str, torch.Tensor]:
        return dict(input_ids=self.input_ids[i], labels=self.labels[i])

    def __getitem__(self, i):
        return dict(input_ids=self.sources[i], labels=self.targets[i])


@dataclass
class DataCollatorForSupervisedDataset(object):
    """Collate examples for supervised fine-tuning."""

    tokenizer: transformers.PreTrainedTokenizer

    def naive__call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels"))
        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
        )
        labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
        return dict(
            input_ids=input_ids,
            labels=labels,
            attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
        )

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        sources = []
        targets = []
        for instance in instances:
            source = instance['input_ids']
            target = instance['labels']
            sources.append(source)
            targets.append(target)

        data_dict = preprocess(sources, targets, self.tokenizer)
        input_ids, labels = data_dict['input_ids'], data_dict['labels']
        # input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels"))
        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
        )
        labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
        return dict(
            input_ids=input_ids,
            labels=labels,
            attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
        )


def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args) -> Dict:
    """Make dataset and collator for supervised fine-tuning."""
    train_dataset = SupervisedDataset(tokenizer=tokenizer, data_args=data_args)
    data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
    return dict(train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator)


os.environ["WANDB_PROJECT"] = "auto_statement"


def train():
    parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args, remaining_args = parser.parse_args_into_dataclasses(return_remaining_strings=True)
    data_args.data_length = int(remaining_args[1])

    model = transformers.AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",
    )

    model.config.use_cache = False
    model.gradient_checkpointing_enable()

    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        model_max_length=training_args.model_max_length,
        padding_side="right",
        use_fast=False,
    )
    if tokenizer.pad_token is None:
        smart_tokenizer_and_embedding_resize(
            special_tokens_dict=dict(pad_token=DEFAULT_PAD_TOKEN),
            tokenizer=tokenizer,
            model=model,
        )
    if "llama" in model_args.model_name_or_path:
        tokenizer.add_special_tokens(
            {
                "eos_token": DEFAULT_EOS_TOKEN,
                "bos_token": DEFAULT_BOS_TOKEN,
                "unk_token": DEFAULT_UNK_TOKEN,
            }
        )
    try:
        tokenizer.pad_token = tokenizer.unk_token
    except:
        pass

    data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)
    trainer = Trainer(model=model, tokenizer=tokenizer, args=training_args, **data_module)
    trainer.train()
    model.config.use_cache = True
    # trainer.save_state()
    # if os.environ.get('LOCAL_RANK') == '0':
    safe_save_model_for_hf_trainer(trainer=trainer, output_dir=training_args.output_dir)


if __name__ == "__main__":
    train()
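Note that the dataset above returns raw (prompt, target) strings from __getitem__ and defers tokenization to DataCollatorForSupervisedDataset.__call__, which calls preprocess to build input_ids and the loss labels. The following is a minimal, self-contained sketch of the label-masking convention this relies on; it is illustrative only, not part of the uploaded file, and the toy token ids are made up:

# Prompt tokens are hidden from the loss with IGNORE_INDEX (-100); only target tokens are scored.
IGNORE_INDEX = -100

def mask_prompt(prompt_ids, target_ids):
    # Concatenate prompt and target, then mask the prompt portion of the labels.
    input_ids = list(prompt_ids) + list(target_ids)
    labels = [IGNORE_INDEX] * len(prompt_ids) + list(target_ids)
    return input_ids, labels

print(mask_prompt([101, 7, 9], [42, 43, 2]))
# ([101, 7, 9, 42, 43, 2], [-100, -100, -100, 42, 43, 2])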
generation_method.py
ADDED
@@ -0,0 +1,472 @@
import argparse
import random
import glob

from tqdm import tqdm
import re
import sys
import os
import numpy as np


def generate_few_shot(prompt):
    base_gsm8k_list = [
        {
            'question': "John and his best friend Steve bought 12 cupcakes together. Each cupcake cost $1.50. If they split the costs evenly, how much did each person pay?",
            'answer': "The total cost of cupcakes was 1.5*12=$<<1.5*12=18>>18\\nSo they each paid 18/2=$<<18/2=9>>9.",
            'direct_answer': "9"
        },
        {
            'question': "Lizzy has to ship 540 pounds of fish that are packed into 30-pound crates. If the shipping cost of each crate is $1.5, how much will Lizzy pay for the shipment?",
            'answer': "There are 540 pounds / 30 pounds/crate = <<540/30=18>>18 crates of fish needed.\\nHence, the total cost for the shipment is $1.5/crate x 18 crates = $<<1.5*18=27>>27.",
            'direct_answer': "27"
        },
        {
            'question': "Tom, Tim, and Paul are collecting photos of cars. Paul has 10 photos more than Tim. Tim has one hundred photos less than the total amount of photos which is 152. How many photos does Tom have?",
            'answer': "Tim has 152 photos - 100 photos = <<152-100=52>>52 photos.\\nWhen Tim has 52 photos, then Paul has 52 + 10 photos = <<52+10=62>>62 photos.\\nTim and Paul have together 52 photos + 62 photos = <<52+62=114>>114 photos.\\nThat leaves Tom with 152 photos - 114 photos = <<152-114=38>>38 photos.",
            'direct_answer': "38"
        },
    ]
    index_list = list(range(len(base_gsm8k_list)))
    random.shuffle(index_list)
    few_shot_example = ""
    for i in index_list:
        item = base_gsm8k_list[i]
        few_shot_example += "Q: " + item['question'] + "\n" + "A: " + item['answer'] + "\nThe answer is " + item['direct_answer'] + "\n"

    few_shot_example += "Q: " + prompt + "A: "
    return few_shot_example


def generate_prompt_generation(args, question):
    if args.evaluation_mode == 'generation':
        if args.method == 'zero_shot_cot':
            content = question + " Let's think step by step."
        elif args.method == 'zero_shot':
            content = question
        elif args.method == 'few_shot':
            content = generate_few_shot(question)
        else:
            raise ValueError("we do not method for such model type yet")

    if "generator" not in args.model_type:
        MODEL_DICT = {
            "llama": (
                "[INST] \n{content}\n [/INST]"
            ),
            "mistral": (
                "<s>[INST] {content} [/INST]"
            ),
            "chatglm": (
                "<|user|> \n{content}\n <|assistant|>"
            ),
            "qianwen": (
                "<|im_start|>user\n{content}<|im_end|>\n<|im_start|>assistant\n"
            ),
            "baichuan": (
                "<reserved_106>{content}<reserved_107>"
            )
        }

        if args.model_type in ["qianwen", "qianwen-13b", "qianwen-70b"]:
            content = MODEL_DICT['qianwen'].format_map(
                {'content': content}
            )

        elif args.model_type in ["chatglm"]:
            pass

        elif args.model_type in ['llama2-7b-chat']:
            content = MODEL_DICT['llama'].format_map(
                {'content': content}
            )

        elif args.model_type in ["mistral", 'mixtral']:
            content = MODEL_DICT['mistral'].format_map(
                {'content': content}
            )

    return content


few_shot_list = [
    {
        'question': "There are 15 trees in the grove. Grove workers will plant trees in the grove today. After they are done, there will be 21 trees. How many trees did the grove workers plant today?",
        'answer': "There are 15 trees originally. Then there were 21 trees after some more were planted. So there must have been 21 - 15 = 6.",
        'direct_answer': "6"
    },
    {
        'question': "If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot?",
        'answer': "There are originally 3 cars. 2 more cars arrive. 3 + 2 = 5.",
        'direct_answer': "5",
    },
    {
        'question': "Leah had 32 chocolates and her sister had 42. If they ate 35, how many pieces do they have left in total?",
        'answer': "Originally, Leah had 32 chocolates. Her sister had 42. So in total they had 32 + 42 = 74. After eating 35, they had 74 - 35 = 39.",
        'direct_answer': "39",
    },
    {
        'question': "Jason had 20 lollipops. He gave Denny some lollipops. Now Jason has 12 lollipops. How many lollipops did Jason give to Denny?",
        'answer': "Jason started with 20 lollipops. Then he had 12 after giving some to Denny. So he gave Denny 20 - 12 = 8.",
        'direct_answer': "8",
    },
    {
        'question': "Shawn has five toys. For Christmas, he got two toys each from his mom and dad. How many toys does he have now?",
        'answer': "Shawn started with 5 toys. If he got 2 toys each from his mom and dad, then that is 4 more toys. 5 + 4 = 9.",
        'direct_answer': "9",
    },
    {
        'question': "There were nine computers in the server room. Five more computers were installed each day, from monday to thursday. How many computers are now in the server room?",
        'answer': "There were originally 9 computers. For each of 4 days, 5 more computers were added. So 5 * 4 = 20 computers were added. 9 + 20 is 29.",
        'direct_answer': "29",
    },
    {
        'question': "Michael had 58 golf balls. On tuesday, he lost 23 golf balls. On wednesday, he lost 2 more. How many golf balls did he have at the end of wednesday?",
        'answer': "Michael started with 58 golf balls. After losing 23 on tuesday, he had 58 - 23 = 35. After losing 2 more, he had 35 - 2 = 33 golf balls.",
        'direct_answer': "33",
    },
    {
        'question': "Olivia has $23. She bought five bagels for $3 each. How much money does she have left?",
        'answer': "Olivia had 23 dollars. 5 bagels for 3 dollars each will be 5 x 3 = 15 dollars. So she has 23 - 15 dollars left. 23 - 15 is 8.",
        'direct_answer': "8",
    },
]
import json

from collections import Counter


def self_consistency(pairs):
    val_counts = Counter(value for key, value in pairs)
    most = val_counts.most_common(1)[0][0]
    for key, value in pairs:
        if value == most:
            return key


#
def find_feedback(content):
    match = re.search(r'Judgement: (.+)', content)
    if match:
        judgement = match.group(1)
    else:
        judgement = "None"
    return judgement


def str2bool(s):
    s = s.lower()
    if s == 'true':
        return True
    elif s == 'false':
        return False
    else:
        raise ValueError('invalid value: {}, must be true or false'.format(s))


def parse_arguments():
    parser = argparse.ArgumentParser(description="Zero-shot-CoT")

    # parser.add_argument(
    #     "--dataset", type=str, default="plan",
    #     choices=["plan", 'tool_use_awareness', 'tool_selection', 'tool_selection_harder', 'tool_creation_awareness',
    #              'tool_creation_awareness_harder', 'tool_creation',
    #              'arguments_filling'], help="dataset used for experiment")
    parser.add_argument(
        "--cot_trigger_no", type=int, default=1,
        help="A trigger sentence that elicits a model to execute chain of thought"
    )
    parser.add_argument("--dataset", type=str, default="")
    parser.add_argument("--data_path", type=str, default="")
    parser.add_argument("--evaluation_mode", type=str, default="")
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--eval_method", type=str, default="")

    parser.add_argument("--model_path", type=str, default="")

    parser.add_argument("--model_type", type=str, default="chatglm")

    parser.add_argument("--output_dir", type=str, default="generation_test")

    parser.add_argument("--lora_path", type=str, default="")

    parser.add_argument("--iter_num", type=int, default=1)
    parser.add_argument("--method", type=str, default="few_shot_cot")
    parser.add_argument("--data_question_key", type=str, default="question")
    parser.add_argument("--data_answer_key", type=str, default="answer")

    parser.add_argument("--sample_num", type=int, default=1)

    parser.add_argument("--cuda_ind", type=int, default=0)
    parser.add_argument("--tensor_parallel", type=int, default=1)
    parser.add_argument("--cuda_start", type=int, default=0)
    parser.add_argument("--cuda_num", type=int, default=8)

    parser.add_argument("--load_in_8bit", type=str2bool, default=False)
    parser.add_argument("--rewrite", type=str2bool, default=True)
    parser.add_argument("--notlean", type=str2bool, default=True)

    parser.add_argument("--use_typewriter", type=int, default=0)

    parser.add_argument("--temperature", type=float, default=0.0)
    parser.add_argument("--top_p", type=float, default=1)
    parser.add_argument("--iter_max_new_tokens", type=int, default=512)
    parser.add_argument("--init_max_new_tokens", type=int, default=2048)
    parser.add_argument("--min_new_tokens", type=int, default=1)
    parser.add_argument("--correct_response_format", type=str, default="The correct response is:")

    args = parser.parse_args()
    if args.evaluation_mode == 'generation':
        if "lean" in args.dataset:
            args.data_question_key = 'model_response'
            args.data_answer_key = 'statement_poof'

        if args.dataset == "lean4_5k_test":
            args.data_path = "data/lean4_gpt_5k/test/data.jsonl"

        elif args.dataset == "math_train":
            args.data_path = "data/test/math/train.jsonl"

        elif args.dataset == "gsm8k_train":
            args.data_path = "data/test/gsm8k/train.jsonl"

        elif args.dataset == "wild_test":
            args.data_path = "/hpc2hdd/home/zyang398/data_2/wild_sample1k.jsonl"

        elif args.dataset == "lean4_basic_test":
            args.data_path = "data/lean4_basic/1k_test.jsonl"
        elif args.dataset == "lean4_random_test":
            args.data_path = "data/lean4_random/1k_test.json"
        elif args.dataset == "lean4_random_first_train":
            args.data_path = "data/lean4_random/5k_first.json"
        elif args.dataset == "lean4_random_second_train":
            args.data_path = "data/lean4_random/5k_second.json"
        elif args.dataset == "lean4_random_third_train":
            args.data_path = "data/lean4_random/5k_third.json"

    if args.model_type == 'mistral_generator':
        args.model_path = 'models/gsm8k/generators/mistral-ep2/'
    elif args.model_type == 'mistral_generator_original':
        args.model_path = '/data/OVM-Mistral-7b/mistral7b-ep2/'
    elif args.model_type == 'gemma_generator':
        args.model_path = 'models/gsm8k/generators/gemma2b2-ep2/'
    elif args.model_type == 'phi2_generator':
        args.model_path = 'models/gsm8k/generators/phi2b-ep2/'

    elif args.model_type == 'mixtral':
        args.model_path = '/data/Mixtral-8x7B-Instruct-v0.1'

    elif args.model_type == 'mistral':
        args.model_path = '/data/mistral-instruct'

    elif args.model_type == 'qianwen-70b':
        args.model_path = '/data/Qwen-72B-Chat'

    elif args.model_type == 'llama2-7b-chat':
        args.model_path = '/data/Llama-2-7b-chat/'

    if args.cot_trigger_no == 1:
        args.cot_trigger = "Let's think step by step."

    return args


def create_demo_text(args, cot_flag, index_list):
    # Concatenate demonstration examples ...
    demo_text = ""
    for i in index_list:
        item = few_shot_list[i]
        if cot_flag:
            demo_text += "Q: " + item['question'] + "\nA: " + item['answer'] + " " + \
                         args.direct_answer_trigger_for_fewshot + " " + item['direct_answer'] + ".\n\n"
        else:
            demo_text += "Q: " + item['question'] + "\nA: " + \
                         args.direct_answer_trigger_for_fewshot + " " + item['direct_answer'] + ".\n\n"

    return demo_text


def str2bool(s):
    s = s.lower()
    if s == 'true':
        return True
    elif s == 'false':
        return False
    else:
        raise ValueError('invalid value: {}, must be true or false'.format(s))


def batchify(pairs, batch_size):

    """Split the list into batches of the given size."""
    for i in range(0, len(pairs), batch_size):
        yield pairs[i:i + batch_size]


def generate_prompts(questions, args):
    """Generate a prompt for each question."""
    prompts = [generate_prompt_generation(args, question) for question in questions]
    return prompts

PROMPT_DICT = {
    "wild": (
        "Statement and proof in natural language:\n\n"
        "# Problem:\n{question}\n\n"
        "# Proof:\n{answer}\n\n"
        "Translate the statement and proof in natural language to lean4:"
    ),
    "lean4": (
        "Statement and proof in natural language:\n\n"
        "{statement_text}\n\n"
        "Translate the statement and proof in natural language to lean4:"
    ),
    "prompt_no_input": (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Response:"
    ),
}

def get_question_answer(args):
    allfilepath = args.data_path
    questions = []
    answers = []

    # Attempt to read the file as a regular JSON file
    for filepath in allfilepath.split(','):
        try:
            with open(filepath, 'r') as file:
                data = json.load(file)
                # If the data is a list, assume it's an array of objects
                if isinstance(data, list):
                    for json_item in data:
                        questions.append(json_item[args.data_question_key])
                        answers.append(json_item)
                # If the data is a dict, assume it's a single object (or adjust logic as needed)
                elif isinstance(data, dict):
                    questions.append(data[args.data_question_key])
                    answers.append(json_item)

        except ValueError:
            # If it fails, assume the file is in JSON Lines format
            with open(filepath, 'r') as file:
                for line in file:
                    json_item = json.loads(line)
                    questions.append(json_item[args.data_question_key])
                    answers.append(json_item)

    if args.notlean:
        questions = [PROMPT_DICT['wild'].format(question=questions[id], answer=answers[id][args.data_answer_key]) for id in range(len(questions))]

    else:
        questions = [PROMPT_DICT['lean4'].format(statement_text=item) for item in questions]

    return questions, answers


def main3(args):
    from vllm import LLM, SamplingParams
    import torch

    model = LLM(model=args.model_path, dtype="bfloat16", trust_remote_code=True,
                tensor_parallel_size=args.tensor_parallel, gpu_memory_utilization=0.95)
    print(args.model_path)

    if "qianwen" in args.model_type:
        model.llm_engine.tokenizer.eos_token_id = 151645
        # model.llm_engine.tokenizer.pad_token_id = 151645
        model.llm_engine.tokenizer.pad_token_id = None
        # model.llm_engine.tokenizer.eos_token_id = None

    print("load data")

    questions, answers = get_question_answer(args)

    question_exist_list = []
    write_pattern = 'w' if args.rewrite else "a+"
    if os.path.exists(args.output_dir) and not args.rewrite:
        # If the output files already exist, load the questions that were already answered
        # Loop through each file that matches the pattern
        file_pattern = os.path.join(args.output_dir, '[0-9]*.json')
        for file_path in glob.glob(file_pattern):
            # Open and read the JSON file
            with open(file_path, 'r') as fp:
                # Extract the 'question' field from each line and add it to the list
                for line in fp.readlines():
                    question_exist_list.append(json.loads(line)['question'])
    else:
        try:
            os.mkdir(args.output_dir)
        except:
            pass
    qa_pairs = [(questions[idx], answers[idx]) for idx in range(len(questions)) if questions[idx] not in question_exist_list]
    cuda_pieces = np.array_split(range(len(qa_pairs)), args.cuda_num // args.tensor_parallel)
    print(f"fitered {len(questions) - len(qa_pairs)} already")

    with open(f"{args.output_dir}/{args.cuda_ind // args.tensor_parallel + args.cuda_start}.json", write_pattern,
              encoding='utf-8') as wf:
        start = cuda_pieces[args.cuda_start + args.cuda_ind // args.tensor_parallel][0]
        end = cuda_pieces[args.cuda_start + args.cuda_ind // args.tensor_parallel][-1] + 1
        subset_length = end - start
        total_batches = (subset_length + args.batch_size - 1) // args.batch_size  # Calculate the total number of batches
        for batch in tqdm(batchify(qa_pairs[start:end], args.batch_size), total=total_batches):
            questions, answers = zip(*batch)  # unpack questions and answers
            prompts = generate_prompts(questions, args)

            with torch.no_grad():
                output_all = []
                try:
                    for i in range(args.sample_num):
                        sample_list = []
                        sampling_params = SamplingParams(temperature=args.temperature, top_p=args.top_p,
                                                         max_tokens=args.init_max_new_tokens)
                        generations = model.generate(prompts, sampling_params, use_tqdm=False)
                        for generation_output in generations:
                            output = generation_output.outputs[0].text
                            sample_list.append(output)
                        output_all.append(sample_list)

                    output_all = list(map(list, zip(*output_all)))
                except Exception as e:
                    print(str(e))
                    exit
                dicts = []
                for question, answer, output, prompt in zip(questions, answers, output_all, prompts):
                    dicts.append({
                        "question": question,
                        "prompt": prompt,
                        "content": answer,
                        "total output": output,
                    })

                for dict in dicts:
                    wf.writelines(json.dumps(dict, ensure_ascii=False) + '\n')

                wf.flush()


def main(argv=None):
    args = parse_arguments()
    print('*****************************')
    print(args)
    print('*****************************')
    if args.evaluation_mode == 'generation':
        main3(args)
    else:
        raise ValueError("we do not yet inplement")


if __name__ == "__main__":
    main()
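For reference, self_consistency in the file above implements a simple majority vote over sampled generations: each pair is (full model output, extracted final answer), and the first output whose answer equals the most frequent one is returned. A small usage sketch follows (the sample strings are made up for illustration):

from collections import Counter

def self_consistency(pairs):
    # Count the extracted answers and return the first output that carries the majority answer.
    val_counts = Counter(value for key, value in pairs)
    most = val_counts.most_common(1)[0][0]
    for key, value in pairs:
        if value == most:
            return key

samples = [("... so the shipment costs 27", "27"),
           ("... which gives 26", "26"),
           ("... hence the answer is 27", "27")]
print(self_consistency(samples))  # prints the first sample whose extracted answer is "27"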
model_train.py
ADDED
@@ -0,0 +1,318 @@
# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified by Zheng Yuan and Hongyi Yuan

import os
import copy
import logging
from dataclasses import dataclass, field
from typing import Optional, Dict, Sequence
import io
import torch
import transformers
from torch.utils.data import Dataset
from transformers import Trainer
import argparse
import json
import random;

random.seed(42)


def _make_r_io_base(f, mode: str):
    if not isinstance(f, io.IOBase):
        f = open(f, mode=mode)
    return f


def jload(f, mode="r"):
    """Load a .json file into a dictionary."""
    f = _make_r_io_base(f, mode)
    jdict = json.load(f)
    f.close()
    return jdict


IGNORE_INDEX = -100
DEFAULT_PAD_TOKEN = "[PAD]"
DEFAULT_EOS_TOKEN = "</s>"
DEFAULT_BOS_TOKEN = "<s>"
DEFAULT_UNK_TOKEN = "<unk>"
PROMPT_DICT = {
    "statement_form": (
        "Statement in natural language:\n"
        "{problem}\n"
        "Translate the statement in natural language to Lean4:"
    ),
    "solver": (
        "{statement_text}"
    ),
    "statementproof_inform": (
        "Statement and proof in lean4:\n\n"
        "{statement_text}\n\n"
        "Translate the statement and proof in lean4 to natural language:"
    ),
}
KEY_DICT = {
    "statement_form": ["problem", "statement"],
    "solver": ["statement", "proof"],
    "statementproof_inform": ["statement_poof", "model_response"]
}


#### 28
@dataclass
class ModelArguments:
    model_name_or_path: Optional[str] = field(default="facebook/opt-125m")


@dataclass
class DataArguments:
    data_path: str = field(default=None, metadata={"help": "Path to the training data."})


@dataclass
class TrainingArguments(transformers.TrainingArguments):
    cache_dir: Optional[str] = field(default=None)
    optim: str = field(default="adamw_torch")
    model_max_length: int = field(
        default=2048,
        metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
    )
    overwrite_output_dir: bool = field(default=True)


def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str):
    """Collects the state dict and dump to disk."""
    state_dict = trainer.model.state_dict()
    if trainer.args.should_save:
        cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
        del state_dict
        trainer._save(output_dir, state_dict=cpu_state_dict)  # noqa


def smart_tokenizer_and_embedding_resize(
    special_tokens_dict: Dict,
    tokenizer: transformers.PreTrainedTokenizer,
    model: transformers.PreTrainedModel,
):
    """Resize tokenizer and embedding.
    Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
    """
    num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
    model.resize_token_embeddings(len(tokenizer))

    if num_new_tokens > 0:
        input_embeddings = model.get_input_embeddings().weight.data
        output_embeddings = model.get_output_embeddings().weight.data

        input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
        output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)

        input_embeddings[-num_new_tokens:] = input_embeddings_avg
        output_embeddings[-num_new_tokens:] = output_embeddings_avg


def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer) -> Dict:
    """Tokenize a list of strings."""
    tokenized_list = [
        tokenizer(
            text,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        )
        for text in strings
    ]
    input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
    input_ids_lens = labels_lens = [
        tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list
    ]
    return dict(
        input_ids=input_ids,
        labels=labels,
        input_ids_lens=input_ids_lens,
        labels_lens=labels_lens,
    )


def preprocess(
    sources: Sequence[str],
    targets: Sequence[str],
    tokenizer: transformers.PreTrainedTokenizer,
) -> Dict:
    """Preprocess the data by tokenizing."""
    examples = [s + t for s, t in zip(sources, targets)]
    examples_tokenized, sources_tokenized = [_tokenize_fn(strings, tokenizer) for strings in (examples, sources)]
    input_ids = examples_tokenized["input_ids"]
    labels = copy.deepcopy(input_ids)
    for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]):
        label[:source_len] = IGNORE_INDEX
    return dict(input_ids=input_ids, labels=labels)


class SupervisedDataset(Dataset):
    """Dataset for supervised fine-tuning."""

    def __init__(self, data_args, tokenizer: transformers.PreTrainedTokenizer):
        super(SupervisedDataset, self).__init__()
        logging.warning("Loading data...")
        data_path = data_args.data_path
        try:
            data_path = data_path_map[data_path]
        except:
            data_path = data_path
        list_data_dict = []
        for item in data_path.split(','):
            try:
                list_data_dict += jload(item)
            except BaseException:
                with open(item, 'r') as f:
                    lines = f.readlines()
                    list_data_dict += [json.loads(line.strip()) for line in lines]

        list_data_dict = random.sample(list_data_dict, len(list_data_dict))
        list_data_dict = list_data_dict[:data_args.data_length]

        logging.warning("Formatting inputs...")

        # list_data_dict = [{'instruction':data['items'][0]['value'], 'input':'', 'output':data['items'][1]['value']} for data in list_data_dict]

        list_data_dict = [{'instruction': PROMPT_DICT[data['task']].format(statement_text=data[KEY_DICT[data['task']][0]]), 'input': '',
                           'output': data[KEY_DICT[data['task']][1]]} for data in list_data_dict]
        print(f"len of {len(list_data_dict)}")
        sources = [example['instruction'] for example in list_data_dict]
        targets = [f"{example['output']}{tokenizer.eos_token}" for example in list_data_dict]
        # targets = [example['output'] for example in list_data_dict]

        self.sources = sources
        self.targets = targets

    def __len__(self):
        return len(self.sources)

    def naive__getitem__(self, i) -> Dict[str, torch.Tensor]:
        return dict(input_ids=self.input_ids[i], labels=self.labels[i])

    def __getitem__(self, i):
        return dict(input_ids=self.sources[i], labels=self.targets[i])


@dataclass
class DataCollatorForSupervisedDataset(object):
    """Collate examples for supervised fine-tuning."""

    tokenizer: transformers.PreTrainedTokenizer

    def naive__call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels"))
        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
        )
        labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
        return dict(
            input_ids=input_ids,
            labels=labels,
            attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
        )

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        sources = []
        targets = []
        for instance in instances:
            source = instance['input_ids']
            target = instance['labels']
            sources.append(source)
            targets.append(target)

        data_dict = preprocess(sources, targets, self.tokenizer)
        input_ids, labels = data_dict['input_ids'], data_dict['labels']
        # input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels"))
        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
        )
        labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
        return dict(
            input_ids=input_ids,
            labels=labels,
            attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
        )


def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args) -> Dict:
    """Make dataset and collator for supervised fine-tuning."""
    train_dataset = SupervisedDataset(tokenizer=tokenizer, data_args=data_args)
    data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
    return dict(train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator)


os.environ["WANDB_PROJECT"] = "train_in_one_model"


def train():
    parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args, remaining_args = parser.parse_args_into_dataclasses(
        return_remaining_strings=True)
    data_args.data_length = int(remaining_args[1])

    model = transformers.AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",
    )

    model.config.use_cache = False
    model.gradient_checkpointing_enable()

    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        model_max_length=training_args.model_max_length,
        padding_side="right",
        use_fast=False,
    )
    if tokenizer.pad_token is None:
        smart_tokenizer_and_embedding_resize(
            special_tokens_dict=dict(pad_token=DEFAULT_PAD_TOKEN),
            tokenizer=tokenizer,
            model=model,
        )
    if "llama" in model_args.model_name_or_path:
        tokenizer.add_special_tokens(
            {
                "eos_token": DEFAULT_EOS_TOKEN,
                "bos_token": DEFAULT_BOS_TOKEN,
                "unk_token": DEFAULT_UNK_TOKEN,
            }
        )
    try:
        tokenizer.pad_token = tokenizer.unk_token
    except:
        pass

    data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)
    trainer = Trainer(model=model, tokenizer=tokenizer, args=training_args, **data_module)
    trainer.train()
    model.config.use_cache = True
    # trainer.save_state()
    # if os.environ.get('LOCAL_RANK') == '0':
    safe_save_model_for_hf_trainer(trainer=trainer, output_dir=training_args.output_dir)


if __name__ == "__main__":
    train()
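model_train.py differs from the single-task training scripts mainly in that each record carries a 'task' field, and PROMPT_DICT plus KEY_DICT select the prompt template and the input/output keys per task. A minimal sketch of the intent of that mapping (illustrative only, not part of the uploaded file; the record contents below are made up):

# Hypothetical record for the 'statement_form' task; the template is filled with the
# input key named by KEY_DICT, and the output key becomes the supervision target.
PROMPT = ("Statement in natural language:\n{problem}\n"
          "Translate the statement in natural language to Lean4:")
KEY_DICT = {"statement_form": ["problem", "statement"]}

record = {"task": "statement_form",
          "problem": "The sum of two even integers is even.",
          "statement": "example (a b : Int) : ..."}

input_key, output_key = KEY_DICT[record["task"]]
instruction = PROMPT.format(problem=record[input_key])
output = record[output_key]
print(instruction)
print(output)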
repl/.lake/packages/mathlib/scripts/align-import.py
ADDED
@@ -0,0 +1,61 @@
#!/usr/bin/env python3

# This script was written by ChatGPT.
# https://chat.openai.com/share/e0363ebf-ed6f-4fd8-9b76-ebf422ed9f62

import re
import sys

def update_file_header(file_path):
    with open(file_path, 'r') as f:
        lines = f.readlines()

    # Initialize variables
    end_of_header_index = 0
    source_module = ""
    repo_ref = ""
    commit_id = ""

    # Lines to delete
    delete_indices = []

    for i, line in enumerate(lines):
        # Check for the end of the "import" lines
        if line.startswith('import'):
            end_of_header_index = i
        elif end_of_header_index != 0 and not line.startswith('import'):
            break

        # Extract the necessary info for the align import line and mark lines for deletion
        if line.startswith('! This file was ported from'):
            source_module = line.split()[-1]
            delete_indices.append(i)
        elif line.startswith('!') and 'commit' in line and commit_id == "":
            split_line = line.split()
            repo_ref = split_line[1]
            commit_id = split_line[-1]
            delete_indices.append(i)
        elif line.startswith('!'):
            delete_indices.append(i)
        elif line == "\n" and lines[i+1].startswith("!"):
            delete_indices.append(i)

    # Only proceed if we have found the necessary info for the align import line
    if source_module and repo_ref and commit_id:
        # Generate the new line
        new_line = f'\n#align_import {source_module} from "{repo_ref}"@"{commit_id}"\n'

        # Delete the marked lines
        for index in sorted(delete_indices, reverse=True):
            del lines[index]

        # Insert the new line after the "import" lines
        lines.insert(end_of_header_index - len(delete_indices) + 1, new_line)

        # Write the updated lines back to the file
        with open(file_path, 'w') as f:
            f.writelines(lines)

# The first command line argument is the file path
file_path = sys.argv[1]
update_file_header(file_path)
repl/.lake/packages/mathlib/scripts/align.py
ADDED
@@ -0,0 +1,64 @@
#!/usr/bin/env python3

# Tool to add source headers to ported theory files,
# archived for historical purposes.

from pathlib import Path
import re
import yaml

excepts = {
    'categorytheory.category.rel': 'categorytheory.category.relcat',
    'categorytheory.isomorphism': 'categorytheory.iso',
    'categorytheory.naturalisomorphism': 'categorytheory.natiso',
    'categorytheory.naturaltransformation': 'categorytheory.nattrans',
    'leancore.data.vector': 'data.vector',
    'order.monovary': 'order.monotone.monovary'
}

def condense(s):
    if s.startswith('Mathlib/'):
        s = s[len('Mathlib/'):]
    if s.endswith('.lean'):
        s = s[:-5]
    s = s.lower()
    s = s.replace('/', '.')
    s = s.replace('_', '')
    if s in excepts:
        s = excepts[s]
    return s

port_status = yaml.load(open("mathlib4-port-status.yaml").read())

# map from condensed names to mathlib4 paths
map = {}
for path in Path('Mathlib').glob('**/*.lean'):
    path = str(path)
    map[condense(path)] = path

count = 0
for key, val in port_status.items():
    if val.startswith('Yes'):
        sha = val.split()[2]
        mathlib3 = key
        mathlib4 = map[condense(key)]

        place = '(\n-/\n\n?import )'
        blob = "\n\n! This file was ported from Lean 3 source module " + mathlib3 + "\n" + \
            "! leanprover-community/mathlib commit " + sha + "\n" + \
            "! Please do not edit these lines, except to modify the commit id\n" + \
            "! if you have ported upstream changes."
        old = open(mathlib4).read()

        if blob[1:] in old:  # match even without leading newline
            print(f'{mathlib4} already has header')
        elif "! leanprover-community/mathlib commit " in old:
            m = re.search("^! leanprover-community/mathlib commit (.*)$", old, flags=re.MULTILINE)
            print(f'file says {m.groups()[0]} but we want {sha}')
            assert(False)
        else:
            new = re.sub(place, blob + '\\1', old, flags=re.MULTILINE)
            open(mathlib4, 'w').write(new)
            count += 1

print(count)
repl/.lake/packages/mathlib/scripts/bench/accumulate_profile.py
ADDED
@@ -0,0 +1,21 @@
#!/usr/bin/env python3
# sum up times of lines a la `elaboration 100ms`

import collections
import re
import sys

cats = collections.defaultdict(lambda: 0)
for line in sys.stdin:
    sys.stderr.write(line)
    if m := re.match("(.+?) ([\d.]+)(m?)s$", line):
        cats[m[1].strip()] += float(m[2]) * (1e-3 if m[3] else 1)

for cat in sorted(cats.keys()):
    cat2 = cat
    if len(sys.argv) > 1:
        cat2 = f"{sys.argv[1]} {cat}"
    # default unit to `s`
    if "|" not in cat2:
        cat2 += "|s"
    print(f"{cat2!r}: {cats[cat]:f}")
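The script above folds profiler lines such as "elaboration 120ms" into per-category totals expressed in seconds. A tiny self-contained sketch of the same accumulation on made-up input lines:

import collections
import re

lines = ["elaboration 120ms", "elaboration 0.3s", "type checking 50ms"]  # made-up sample input
cats = collections.defaultdict(lambda: 0)
for line in lines:
    # Millisecond values are converted to seconds before summing per category.
    if m := re.match(r"(.+?) ([\d.]+)(m?)s$", line):
        cats[m[1].strip()] += float(m[2]) * (1e-3 if m[3] else 1)
print(dict(cats))  # {'elaboration': 0.42, 'type checking': 0.05}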
repl/.lake/packages/mathlib/scripts/detect_sha_changes.py
ADDED
@@ -0,0 +1,104 @@
"""
This script is called by a github action to verify that commit SHAs used in porting are valid.
It also produces links to the port-status webpage.

Note that only the first 10 annotations created with this action are guaranteed to appear, so we
produce the errors first.
"""

import dataclasses
import re
import sys
from typing import Optional

import git

# upstream bug
git.Git.CatFileContentStream.__next__ = git.Git.CatFileContentStream.next

align_import_re = re.compile(
    r'^#align_import ([^ ]*) from "(leanprover-community/[a-z]*)" ?@ ?"([0-9a-f]*)"')

@dataclasses.dataclass(eq=True, frozen=True)
class VersionInfo:
    module: str
    repo: Optional[str]
    commit: Optional[str]
    commit_line_no: Optional[int] = dataclasses.field(compare=False)

    def to_commit(self):
        try:
            repo = git.Repo('port-repos/' + self.repo)
        except git.exc.NoSuchPathError:
            raise ValueError(f"Repo {self.repo} not recognized")
        try:
            repo.remotes.origin.fetch(self.commit, depth=1)
        except Exception:
            pass
        return repo.commit(self.commit)

def get_mathlib4_module_commit_infos(contents):
    for i, line in enumerate(contents, 1):
        m = align_import_re.match(line)
        if m:
            module = m.group(1)
            repo = m.group(2)
            commit = m.group(3)
            yield VersionInfo(module, repo, commit, i)

def get_mathlib4_module_commit_info_from_blob(blob: Optional[git.Blob]):
    if blob is None:
        return
    yield from get_mathlib4_module_commit_infos(
        l.decode('utf8') for l in blob.data_stream.stream)

def encode_msg_text_for_github(msg):
    # even though this is probably url quoting, we match the implementation at
    # https://github.com/actions/toolkit/blob/af821474235d3c5e1f49cee7c6cf636abb0874c4/packages/core/src/command.ts#L36-L94
    return msg.replace('%', '%25').replace('\r', '%0D').replace('\n', '%0A')

if __name__ == '__main__':
    repo = git.Repo('.')
    base = repo.commit(sys.argv[1])
    head = repo.commit(sys.argv[2])
    any_errors = False

    diff_infos = []
    for diff in base.diff(head, paths=['Mathlib']):
        a_info = set(get_mathlib4_module_commit_info_from_blob(diff.a_blob))
        b_info = set(get_mathlib4_module_commit_info_from_blob(diff.b_blob))
        if b_info <= a_info:
            continue
        diff_infos.append((diff, a_info, b_info))

    all_refs = {}

    # produce errors first
    for diff, a_infos, b_infos in diff_infos:
        for b_info in b_infos:
            try:
                b_info.to_commit()
            except Exception as e:
                print(f"::error file={diff.b_blob.path},line={b_info.commit_line_no},title=Invalid header::{encode_msg_text_for_github(str(e))}")
                any_errors = True
                continue

    for diff, a_info, b_info in diff_infos:
        same = a_info.intersection(b_info)
        a_info -= same
        b_info -= same
        if a_info != {} and b_info != {}:
            a_info_by_mod = {a.module: a for a in a_info}
            b_info_by_mod = {b.module: b for b in b_info}
            for k in set(a_info_by_mod.keys()) | set(b_info_by_mod.keys()):
                a_info = a_info_by_mod.get(k, None)
                b_info = b_info_by_mod.get(k, None)
                if a_info is None or b_info is None:
                    pass
                elif a_info.module == b_info.module:
                    mod_path = a_info.module.replace('.', '/')
                    msg = f"See review instructions and diff at\nhttps://leanprover-community.github.io/mathlib-port-status/file/{mod_path}?range={a_info.commit}..{b_info.commit}"
                    print(f"::notice file={diff.b_blob.path},line={b_info.commit_line_no},title=Synchronization::{encode_msg_text_for_github(msg)}")

    if any_errors:
        raise SystemExit("Setting a failure due to errors above")
repl/.lake/packages/mathlib/scripts/fix-comments.py
ADDED
@@ -0,0 +1,162 @@
#!/usr/bin/env python3

import sys
import os
from pathlib import Path
import subprocess
import re

if len(sys.argv) != 2 or not sys.argv[1].endswith('.lean'):
    print("usage: fix-comments.py X.lean")
    sys.exit(1)

leanfile = sys.argv[1]

is_clean = subprocess.run(
    ['git', 'status', '--untracked-files=no', '--porcelain'],
    capture_output=True,
    check=True,
    encoding='utf-8').stdout.rstrip()

if is_clean != "":
    print("Certain files tracked by git have uncommitted changes.\n")
    os.system("git status --untracked-files=no")
    print("\n")
    s = input("Type y to continue. ")
    if s != 'y':
        sys.exit(1)

root_dir = subprocess.run(
    ['git', 'rev-parse', '--show-toplevel'],
    capture_output=True,
    check=True,
    encoding='utf-8').stdout.rstrip()

align_files = subprocess.run(
    ['git', 'grep', '-l', '^#align'],
    cwd=root_dir,
    capture_output=True,
    check=True,
    encoding='utf-8')

name_map = dict()
for f in align_files.stdout.splitlines():
    with open(os.path.join(root_dir, f), encoding="utf-8") as fh:
        contents = fh.read()
        for p in contents.split(sep='\n#align')[1:]:
            n3, n4, *_ = p.split(maxsplit=2)
            name_map[n3] = n4

def replace_names(s):
    # Terrible hack to treat `.` as a word character
    # (to match qualified names)
    s = s.replace('.', 'Ᾰ')
    # re.DOTALL means that `.` can also match a newline.
    # `\A` and `\Z` match only at the start/end of the string respectively.
    w = re.findall(r'(?:\b|\A).+?(?:\b|\Z)', s, flags=re.DOTALL)
    for i in range(len(w)):
        w[i] = w[i].replace('Ᾰ', '.')
        w[i] = name_map.get(w[i], w[i])
    return ''.join(w)

def process_backticked_names(s):
    w = s.split(sep='`')
    for i in range(len(w)):
        if i % 2 == 1:
            w[i] = replace_names(w[i])
    return '`'.join(w)

rewritten_contents = ''

in_block_comment = False
in_line_comment = False
prev_char = None
comment_so_far = None  # contains end marker but not begin marker

def finish_comment():
    global rewritten_contents
    global in_block_comment
    global in_line_comment
    global comment_so_far
    if comment_so_far is not None:
        rewritten_contents += process_backticked_names(comment_so_far)
        in_block_comment = False
        in_line_comment = False
        comment_so_far = None

with open(leanfile, encoding="utf-8") as F:
    while 1:
        char = F.read(1)
        if not char:
            finish_comment()
            break

        if in_block_comment or in_line_comment:
            comment_so_far = comment_so_far + char
        else:
            rewritten_contents += char

        if in_block_comment and prev_char == '-' and char == '/':
            finish_comment()

        if in_line_comment and char == '\n':
            finish_comment()

        if comment_so_far is None and prev_char == '/' and char == '-':
            in_block_comment = True
            comment_so_far = ''

        if comment_so_far is None and prev_char == '-' and char == '-':
            in_line_comment = True
            comment_so_far = ''

        prev_char = char

def mktree(path, sha, tree=True):
    if path == Path('.'):
        return sha
    if tree:
        inp = f"040000 tree {sha}\t{path.name}"
    else:
        inp = f"100644 blob {sha}\t{path.name}"
    tree_sha = subprocess.run(
        ['git', 'mktree'],
        cwd=root_dir,
        input=inp,
        capture_output=True,
        check=True,
        encoding='utf8').stdout.rstrip()
    return mktree(path.parent, tree_sha)

path = Path(subprocess.run(
    ['git', 'ls-files', '--full-name', leanfile],
    capture_output=True,
    check=True,
    encoding='utf-8').stdout.rstrip())

blob_sha = subprocess.run(
    ['git', 'hash-object', '-w', '--stdin'],
    input=rewritten_contents,
    cwd=root_dir,
    capture_output=True,
    check=True,
    encoding='utf-8').stdout.rstrip()

tree_sha = mktree(path, blob_sha, tree=False)

print(f"The script will now interactively suggest changes to {leanfile}.\n")
s = input("Type y to continue. ")
if s != 'y':
    sys.exit(1)

subprocess.run(['git', 'restore', '--patch', '--source=' + tree_sha, '--', leanfile], check=True)

r = subprocess.run(['git', 'diff', '--quiet', leanfile])
if r.returncode == 0:
    pass
elif r.returncode == 1:  # file was changed
    print("\nPerhaps you would now like to run:")
    print(f"git add {leanfile} && git commit -m 'auto: naming'")
else:
    # something went wrong
    r.check_returncode()
repl/.lake/packages/mathlib/scripts/fix-line-breaks.py
ADDED
@@ -0,0 +1,23 @@
#!/usr/bin/env python3
import sys
from collections import deque

lns = deque([], 2)
with open(sys.argv[1], "r", encoding="utf-8", newline="\n") as f, \
     open(sys.argv[2], "w", encoding="utf-8", newline="\n") as g:
    for ln_raw in f:
        ln = ln_raw.strip("\n")
        lns.append(ln)
        if len(lns) <= 1:
            continue
        if lns[1].lstrip() == "by" and len(lns[0]) < 98 and not lns[0].lstrip().startswith("--"):
            lns.pop()
            lns[0] += " by"
        elif lns[1].lstrip() == "where" and len(lns[0]) < 95 and not lns[0].lstrip().startswith("--"):
            lns.pop()
            lns[0] += " where"
        else:
            print(lns[0], file=g)
            lns.popleft()
    for ln in lns:
        print(ln, file=g)
repl/.lake/packages/mathlib/scripts/fix-lints.py
ADDED
@@ -0,0 +1,48 @@
#!/usr/bin/env python3

import sys
import os
import subprocess
import shutil

def getpos(line):
    _, line, col, _ = line.split(sep=':', maxsplit=3)
    return int(line), int(col)

if len(sys.argv) != 2 or not sys.argv[1].endswith('.lean'):
    print("usage: fix-lints.py Mathlib/A/B/C.lean")
    sys.exit(1)

leanfile = sys.argv[1]
leanmodule = leanfile[:-5].replace('/', '.')

# try to build
log = subprocess.run(
    ['lake', 'build', leanmodule],
    capture_output=True, encoding='utf8')
if log.returncode == 0:
    print("no errors 🎉")
    exit(0)

shutil.copyfile(leanfile, leanfile + '.bak')

with open(leanfile + '.bak', encoding='utf8') as fp:
    f = list(fp)
count = 0
for l in reversed(log.stderr.splitlines()):
    if 'linter.unusedVariables' in l:
        line, col = getpos(l)
        f[line-1] = f[line-1][0:col] + '_' + f[line-1][col:]
        count += 1
    elif 'linter.unnecessarySeqFocus' in l:
        line, col = getpos(l)
        f[line-1] = f[line-1][0:col].rstrip() + ';' + f[line-1][col+3:]
        count += 1
    else:
        print(l, file=sys.stderr)

print(f'Fixed {count} warnings', file=sys.stderr)

with open(leanfile, 'w', encoding='utf8') as fp:
    fp.write(''.join(f))
os.remove(leanfile + '.bak')
repl/.lake/packages/mathlib/scripts/lint-style.py
ADDED
@@ -0,0 +1,451 @@
#!/usr/bin/env python3
"""
Lint a file or files from mathlib for style.

Sample usage:

    $ ./scripts/lint-style.py $(find Mathlib -name '*.lean')

which will lint all of the Lean files in the specified directories.

The resulting error output will contain one line for each style error
encountered that isn't in the list of allowed / ignored style exceptions.

Paths with no errors will not appear in the output, and the script will
exit with successful return code if there are no errors encountered in
any provided paths.

Paths emitted in the output will match the paths provided on the
command line for any files containing errors -- in particular, linting
a relative path (like ``Mathlib/Foo/Bar.lean``) will produce errors
that contain the relative path, whilst linting absolute paths (like
``/root/mathlib4/Mathlib/Foo/Bar.lean``) will produce errors with the
absolute path.

This script can also be used to regenerate the list of allowed / ignored style
exceptions by redirecting the output to ``style-exceptions.txt``. Use:

    $ ./scripts/update-style-exceptions.sh

to perform this update.
"""

# TODO: This is adapted from the linter for mathlib3. It should be rewritten in Lean.

from pathlib import Path
import sys
import re
import shutil

ERR_COP = 0  # copyright header
ERR_MOD = 2  # module docstring
ERR_LIN = 3  # line length
ERR_OPT = 6  # set_option
ERR_AUT = 7  # malformed authors list
ERR_TAC = 9  # imported Mathlib.Tactic
ERR_IBY = 11  # isolated by
ERR_DOT = 12  # isolated or low focusing dot
ERR_SEM = 13  # the substring " ;"
ERR_WIN = 14  # Windows line endings "\r\n"
ERR_TWS = 15  # trailing whitespace
ERR_CLN = 16  # line starts with a colon
ERR_IND = 17  # second line not correctly indented
ERR_ARR = 18  # space after "←"
ERR_NUM_LIN = 19  # file is too large
ERR_NSP = 20  # non-terminal simp

exceptions = []

SCRIPTS_DIR = Path(__file__).parent.resolve()
ROOT_DIR = SCRIPTS_DIR.parent


with SCRIPTS_DIR.joinpath("style-exceptions.txt").open(encoding="utf-8") as f:
    for exline in f:
        filename, _, _, _, _, errno, *extra = exline.split()
        path = ROOT_DIR / filename
        if errno == "ERR_COP":
            exceptions += [(ERR_COP, path, None)]
        elif errno == "ERR_MOD":
            exceptions += [(ERR_MOD, path, None)]
        elif errno == "ERR_LIN":
            exceptions += [(ERR_LIN, path, None)]
        elif errno == "ERR_OPT":
            exceptions += [(ERR_OPT, path, None)]
        elif errno == "ERR_AUT":
            exceptions += [(ERR_AUT, path, None)]
        elif errno == "ERR_TAC":
            exceptions += [(ERR_TAC, path, None)]
        elif errno == "ERR_NUM_LIN":
            exceptions += [(ERR_NUM_LIN, path, extra[1])]
        else:
            print(f"Error: unexpected errno in style-exceptions.txt: {errno}")
            sys.exit(1)

new_exceptions = False

def annotate_comments(enumerate_lines):
    """
    Take a list of tuples of enumerated lines of the form
    (line_number, line, ...)
    and return a list of
    (line_number, line, ..., True/False)
    where lines have True attached when they are in comments.
    """
    nesting_depth = 0  # We're in a comment when `nesting_depth > 0`.
    starts_in_comment = False  # Whether we're in a comment when starting the line.
    for line_nr, line, *rem in enumerate_lines:
        # We assume multiline comments do not begin or end within single-line comments.
        if line == "\n" or line.lstrip().startswith("--"):
            yield line_nr, line, *rem, True
            continue
        # We assume that "/-/" and "-/-" never occur outside of "--" comments.
        # We assume that we do not encounter "... -/ <term> /- ...".
        # We also don't account for "/-" and "-/" appearing in strings.
        starts_in_comment = (nesting_depth > 0)
        nesting_depth = nesting_depth + line.count("/-") - line.count("-/")
        in_comment = (starts_in_comment or line.lstrip().startswith("/-")) and \
            (nesting_depth > 0 or line.rstrip().endswith("-/"))
        yield line_nr, line, *rem, in_comment

def annotate_strings(enumerate_lines):
    """
    Take a list of tuples of enumerated lines of the form
    (line_number, line, ...)
    and return a list of
    (line_number, line, ..., True/False)
    where lines have True attached when they are in strings.
    """
    in_string = False
    in_comment = False
    for line_nr, line, *rem in enumerate_lines:
        # ignore comment markers inside string literals
        if not in_string:
            if "/-" in line:
                in_comment = True
            if "-/" in line:
                in_comment = False
        # ignore quotes inside comments
        if not in_comment:
            # crude heuristic: if the number of non-escaped quote signs is odd,
            # we're starting / ending a string literal
            if line.count("\"") - line.count("\\\"") % 2 == 1:
                in_string = not in_string
            # if there are quote signs in this line,
            # a string literal probably begins and / or ends here,
            # so we skip this line
            if line.count("\"") > 0:
                yield line_nr, line, *rem, True
                continue
        if in_string:
            yield line_nr, line, *rem, True
            continue
        yield line_nr, line, *rem, False

def set_option_check(lines, path):
    errors = []
    newlines = []
    for line_nr, line, in_comment, in_string in annotate_strings(annotate_comments(lines)):
        if line.strip().startswith('set_option') and not in_comment and not in_string:
            option_prefix = line.strip().split(' ', 2)[1].split('.', 1)[0]
            # forbidden options: pp, profiler, trace
            if option_prefix in {'pp', 'profiler', 'trace'}:
                errors += [(ERR_OPT, line_nr, path)]
                # skip adding this line to newlines so that we suggest removal
                continue
        newlines.append((line_nr, line))
    return errors, newlines

def line_endings_check(lines, path):
    errors = []
    newlines = []
    for line_nr, line in lines:
        if "\r\n" in line:
            errors += [(ERR_WIN, line_nr, path)]
            line = line.replace("\r\n", "\n")
        if line.endswith(" \n"):
            errors += [(ERR_TWS, line_nr, path)]
            line = line.rstrip() + "\n"
        newlines.append((line_nr, line))
    return errors, newlines

def four_spaces_in_second_line(lines, path):
    # TODO: also fix the space for all lines before ":=", right now we only fix the line after
    # the first line break
    errors = []
    # We never alter the first line, as it does not occur as next_line in the iteration over the
    # zipped lines below, hence we add it here
    newlines = [lines[0]]
    annotated_lines = list(annotate_comments(lines))
    for (_, line, is_comment), (next_line_nr, next_line, _) in zip(annotated_lines,
                                                                   annotated_lines[1:]):
        # Check if the current line matches "(lemma|theorem) .* :"
        new_next_line = next_line
        if (not is_comment) and re.search(r"^(protected )?(def|lemma|theorem) (?!.*:=).*(where)?$",
                                          line):
            # Calculate the number of spaces before the first non-space character in the next line
            stripped_next_line = next_line.lstrip()
            if not (next_line == '\n' or next_line.startswith("#") or stripped_next_line.startswith("--")):
                num_spaces = len(next_line) - len(stripped_next_line)
                # The match with "| " could potentially match with a different usage of the same
                # symbol, e.g. some sort of norm. In that case a space is not necessary, so
                # looking for "| " should be enough.
                if stripped_next_line.startswith("| ") or line.endswith("where\n"):
                    # Check and fix if the number of leading space is not 2
                    if num_spaces != 2:
                        errors += [(ERR_IND, next_line_nr, path)]
                        new_next_line = ' ' * 2 + stripped_next_line
                # Check and fix if the number of leading spaces is not 4
                else:
                    if num_spaces != 4:
                        errors += [(ERR_IND, next_line_nr, path)]
                        new_next_line = ' ' * 4 + stripped_next_line
        newlines.append((next_line_nr, new_next_line))
    return errors, newlines

def nonterminal_simp_check(lines, path):
    errors = []
    newlines = []
    annotated_lines = list(annotate_comments(lines))
    for (line_nr, line, is_comment), (_, next_line, _) in zip(annotated_lines,
                                                              annotated_lines[1:]):
        # Check if the current line matches whitespace followed by "simp"
        new_line = line
        # TODO it would be better to use a regex like r"^\s*simp( \[.*\])?( at .*)?$" and thereby
        # catch all possible simp invocations. Adding this will require more initial cleanup or
        # nolint.
        if (not is_comment) and re.search(r"^\s*simp$", line):
            # Calculate the number of spaces before the first non-space character in the line
            num_spaces = len(line) - len(line.lstrip())
            # Calculate the number of spaces before the first non-space character in the next line
            stripped_next_line = next_line.lstrip()
            if not (next_line == '\n' or next_line.startswith("#") or stripped_next_line.startswith("--") or "rfl" in next_line):
                num_next_spaces = len(next_line) - len(stripped_next_line)
                # Check if the number of leading spaces is the same
                if num_spaces == num_next_spaces:
                    # If so, the simp is nonterminal
                    errors += [(ERR_NSP, line_nr, path)]
                    new_line = line.replace("simp", "simp?")
        newlines.append((line_nr, new_line))
    newlines.append(lines[-1])
    return errors, newlines

def long_lines_check(lines, path):
    errors = []
    # TODO: find a good way to break long lines
    # TODO: some string literals (in e.g. tactic output messages) can be excepted from this rule
    for line_nr, line in lines:
        if "http" in line or "#align" in line:
            continue
        if len(line) > 101:
            errors += [(ERR_LIN, line_nr, path)]
    return errors, lines

def import_only_check(lines, path):
    for _, line, is_comment in annotate_comments(lines):
        if is_comment:
            continue
        imports = line.split()
        if imports[0] == "#align_import":
            continue
        if imports[0] != "import":
            return False
    return True

def regular_check(lines, path):
    errors = []
    copy_started = False
    copy_done = False
    copy_start_line_nr = 1
    copy_lines = ""
    for line_nr, line in lines:
        if not copy_started and line == "\n":
            errors += [(ERR_COP, copy_start_line_nr, path)]
            continue
        if not copy_started and line == "/-\n":
            copy_started = True
            copy_start_line_nr = line_nr
            continue
        if not copy_started:
            errors += [(ERR_COP, line_nr, path)]
        if copy_started and not copy_done:
            copy_lines += line
            if "Author" in line:
                # Validating names is not a reasonable thing to do,
                # so we just look for the two common variations:
                # using ' and ' between names, and a '.' at the end of line.
                if ((not line.startswith("Authors: ")) or
                    ("  " in line) or
                    (" and " in line) or
                    (line[-2] == '.')):
                    errors += [(ERR_AUT, line_nr, path)]
            if line == "-/\n":
                if ((not "Copyright" in copy_lines) or
                    (not "Apache" in copy_lines) or
                    (not "Authors: " in copy_lines)):
                    errors += [(ERR_COP, copy_start_line_nr, path)]
                copy_done = True
            continue
        if copy_done and line == "\n":
            continue
        words = line.split()
        if words[0] != "import" and words[0] != "--" and words[0] != "/-!" and words[0] != "#align_import":
            errors += [(ERR_MOD, line_nr, path)]
            break
        if words[0] == "/-!":
            break
    return errors, lines

def banned_import_check(lines, path):
    errors = []
    for line_nr, line, is_comment in annotate_comments(lines):
        if is_comment:
            continue
        imports = line.split()
        if imports[0] != "import":
            break
        if imports[1] in ["Mathlib.Tactic"]:
            errors += [(ERR_TAC, line_nr, path)]
    return errors, lines

def isolated_by_dot_semicolon_check(lines, path):
    errors = []
    newlines = []
    for line_nr, line in lines:
        if line.strip() == "by":
            # We excuse those "by"s following a comma or ", fun ... =>", since generally hanging "by"s
            # should not be used in the second or later arguments of a tuple/anonymous constructor
            # See https://github.com/leanprover-community/mathlib4/pull/3825#discussion_r1186702599
            prev_line = lines[line_nr - 2][1].rstrip()
            if not prev_line.endswith(",") and not re.search(", fun [^,]* (=>|↦)$", prev_line):
                errors += [(ERR_IBY, line_nr, path)]
        if line.lstrip().startswith(". "):
            errors += [(ERR_DOT, line_nr, path)]
            line = line.replace(". ", "· ", 1)
        if line.strip() in (".", "·"):
            errors += [(ERR_DOT, line_nr, path)]
        if " ;" in line:
            errors += [(ERR_SEM, line_nr, path)]
            line = line.replace(" ;", ";")
        if line.lstrip().startswith(":"):
            errors += [(ERR_CLN, line_nr, path)]
        newlines.append((line_nr, line))
    return errors, newlines

def left_arrow_check(lines, path):
    errors = []
    newlines = []
    for line_nr, line, is_comment, in_string in annotate_strings(annotate_comments(lines)):
        if is_comment or in_string:
            newlines.append((line_nr, line))
            continue
        # Allow "←" to be followed by "%" or "`", but not by "`(" or "``(" (since "`()" and "``()"
        # are used for syntax quotations). Otherwise, insert a space after "←".
        new_line = re.sub(r'←(?:(?=``?\()|(?![%`]))(\S)', r'← \1', line)
        if new_line != line:
            errors += [(ERR_ARR, line_nr, path)]
        newlines.append((line_nr, new_line))
    return errors, newlines

def output_message(path, line_nr, code, msg):
    if len(exceptions) == 0:
        # we are generating a new exceptions file
        # filename first, then line so that we can call "sort" on the output
        print(f"{path} : line {line_nr} : {code} : {msg}")
    else:
        if code.startswith("ERR"):
            msg_type = "error"
        if code.startswith("WRN"):
            msg_type = "warning"
        # We are outputting for github. We duplicate path, line_nr and code,
        # so that they are also visible in the plaintext output.
        print(f"::{msg_type} file={path},line={line_nr},code={code}::{path}#L{line_nr}: {code}: {msg}")

def format_errors(errors):
    global new_exceptions
    for errno, line_nr, path in errors:
        if (errno, path.resolve(), None) in exceptions:
            continue
        new_exceptions = True
        if errno == ERR_COP:
            output_message(path, line_nr, "ERR_COP", "Malformed or missing copyright header")
        if errno == ERR_MOD:
            output_message(path, line_nr, "ERR_MOD", "Module docstring missing, or too late")
        if errno == ERR_LIN:
            output_message(path, line_nr, "ERR_LIN", "Line has more than 100 characters")
        if errno == ERR_OPT:
            output_message(path, line_nr, "ERR_OPT", "Forbidden set_option command")
        if errno == ERR_AUT:
            output_message(path, line_nr, "ERR_AUT", "Authors line should look like: 'Authors: Jean Dupont, Иван Иванович Иванов'")
        if errno == ERR_TAC:
            output_message(path, line_nr, "ERR_TAC", "Files in mathlib cannot import the whole tactic folder")
        if errno == ERR_IBY:
            output_message(path, line_nr, "ERR_IBY", "Line is an isolated 'by'")
        if errno == ERR_DOT:
            output_message(path, line_nr, "ERR_DOT", "Line is an isolated focusing dot or uses . instead of ·")
        if errno == ERR_SEM:
            output_message(path, line_nr, "ERR_SEM", "Line contains a space before a semicolon")
        if errno == ERR_WIN:
            output_message(path, line_nr, "ERR_WIN", "Windows line endings (\\r\\n) detected")
        if errno == ERR_TWS:
            output_message(path, line_nr, "ERR_TWS", "Trailing whitespace detected on line")
        if errno == ERR_CLN:
            output_message(path, line_nr, "ERR_CLN", "Put : and := before line breaks, not after")
        if errno == ERR_IND:
            output_message(path, line_nr, "ERR_IND", "If the theorem/def statement requires multiple lines, indent it correctly (4 spaces or 2 for `|`)")
        if errno == ERR_ARR:
            output_message(path, line_nr, "ERR_ARR", "Missing space after '←'.")
        if errno == ERR_NSP:
            output_message(path, line_nr, "ERR_NSP", "Non-terminal simp. Replace with `simp?` and use the suggested output")

def lint(path, fix=False):
    global new_exceptions
    with path.open(encoding="utf-8", newline="") as f:
        # We enumerate the lines so that we can report line numbers in the error messages correctly
        # we will modify lines as we go, so we need to keep track of the original line numbers
        lines = f.readlines()
        enum_lines = enumerate(lines, 1)
        newlines = enum_lines
        for error_check in [line_endings_check,
                            four_spaces_in_second_line,
                            long_lines_check,
                            isolated_by_dot_semicolon_check,
                            set_option_check,
                            left_arrow_check,
                            nonterminal_simp_check]:
            errs, newlines = error_check(newlines, path)
            format_errors(errs)

        if not import_only_check(newlines, path):
            # Check for too long files: either longer than 1500 lines, or not covered by an exception.
            # Each exception contains a "watermark". If the file is longer than that, we also complain.
            if len(lines) > 1500:
                ex = [e for e in exceptions if e[1] == path.resolve()]
                if ex:
                    (_ERR_NUM, _path, watermark) = list(ex)[0]
                    assert int(watermark) > 500  # protect against parse error
                    is_too_long = len(lines) > int(watermark)
                else:
                    is_too_long = True
                if is_too_long:
                    new_exceptions = True
                    # add up to 200 lines of slack, so simple PRs don't trigger this right away
                    watermark = len(lines) // 100 * 100 + 200
                    output_message(path, 1, "ERR_NUM_LIN", f"{watermark} file contains {len(lines)} lines, try to split it up")
            errs, newlines = regular_check(newlines, path)
            format_errors(errs)
            errs, newlines = banned_import_check(newlines, path)
            format_errors(errs)
    # if we haven't been asked to fix errors, or there are no errors or no fixes, we're done
    if fix and new_exceptions and enum_lines != newlines:
        path.with_name(path.name + '.bak').write_text("".join(l for _, l in newlines), encoding="utf8")
        shutil.move(path.with_name(path.name + '.bak'), path)

fix = "--fix" in sys.argv
argv = (arg for arg in sys.argv[1:] if arg != "--fix")

for filename in argv:
    lint(Path(filename), fix=fix)

if new_exceptions:
    exit(1)
repl/.lake/packages/mathlib/scripts/make_port_status.py
ADDED
@@ -0,0 +1,218 @@
#!/usr/bin/env python3

import pytz
import datetime
import github
import os
import re
import requests
import subprocess
import sys
import yaml
import networkx as nx
from collections import defaultdict
from pathlib import Path

# Must run from root of mathlib4 directory.

if not os.path.exists('port-repos/mathlib'):
    print("Make sure you are in the root of the mathlib4 directory")
    print("and have checked out mathlib under port-repos/mathlib.")
    sys.exit(1)

GITHUB_TOKEN_FILE = 'port-repos/github-token'
github_token = open(GITHUB_TOKEN_FILE).read().strip()

mathlib3_root = 'port-repos/mathlib'
mathlib4_root = './'

source_module_re = re.compile(r"^! .*source module (.*)$")
commit_re = re.compile(r"^! (leanprover-community/[a-z]*) commit ([0-9a-f]*)")
import_re = re.compile(r"^import ([^ ]*)")

align_import_re = re.compile(
    r'^#align_import ([^ ]*) from "(leanprover-community/[a-z]*)" ?@ ?"([0-9a-f]*)"')

def mk_label(path: Path) -> str:
    rel = path.relative_to(Path(mathlib3_root))
    rel = Path(*rel.parts[1:])
    return str(rel.with_suffix('')).replace(os.sep, '.')

paths = []
for path in Path(mathlib3_root).glob('**/*.lean'):
    if path.relative_to(mathlib3_root).parts[0] not in ['src', 'archive', 'counterexamples']:
        continue
    if path.relative_to(mathlib3_root).parts[1] in ['tactic', 'meta']:
        continue
    paths.append(path)

graph = nx.DiGraph()
for path in paths:
    graph.add_node(mk_label(path))

for path in paths:
    label = mk_label(path)
    for line in path.read_text().split('\n'):
        m = import_re.match(line)
        if m:
            imported = m.group(1)
            if imported.startswith('tactic.') or imported.startswith('meta.') or imported.startswith('.'):
                continue
            if imported not in graph.nodes:
                if imported + '.default' in graph.nodes:
                    imported = imported + '.default'
                else:
                    imported = imported
            graph.add_edge(imported, label)

def get_mathlib4_module_commit_info(contents):
    module = repo = commit = None
    for line in contents.split('\n'):
        m = align_import_re.match(line)
        if m:
            module = m.group(1)
            repo = m.group(2)
            commit = m.group(3)
            break
        m = source_module_re.match(line)
        if m:
            module = m.group(1)
        m = commit_re.match(line)
        if m:
            repo = m.group(1)
            commit = m.group(2)
    return module, repo, commit

# contains ported files
# lean 3 module name -> { mathlib4_file, mathlib3_hash }
data = dict()
for path4 in Path(mathlib4_root).glob('**/*.lean'):
    # we definitely do not want to look in `port-repos` here!
    if path4.relative_to(mathlib4_root).parts[0] not in ('Mathlib', 'Archive', 'Counterexamples'):
        continue
    module, repo, commit = get_mathlib4_module_commit_info(path4.read_text())
    if module is None:
        continue

    if commit is None:
        print(f"Commit is None for module: {module}")
        continue

    log = subprocess.run(
        ['git', 'log', '--oneline', str(path4)],
        capture_output=True)
    pr_matches = re.search(r'#([0-9]+)\)$', log.stdout.decode().splitlines()[-1])
    if pr_matches:
        mathlib4_pr = int(pr_matches.groups()[0])
    else:
        mathlib4_pr = None

    data[module] = {
        'mathlib4_file': str(path4.relative_to(mathlib4_root)),
        'mathlib4_pr': mathlib4_pr,
        'source': dict(repo=repo, commit=commit)
    }

    graph.add_node(module)

prs = {}
fetch_args = ['git', 'fetch', 'origin']
nums = []
sync_prs = defaultdict(set)
mathlib4repo = github.Github(github_token).get_repo("leanprover-community/mathlib4")
for pr in mathlib4repo.get_pulls(state='open'):
    if pr.created_at < datetime.datetime(2022, 12, 1, 0, 0, 0, tzinfo=pytz.UTC):
        continue
    if 'no-source-header' in (l.name for l in pr.labels):
        continue
    if 'mathlib3-pair' in (l.name for l in pr.labels):
        for file in (f.filename for f in pr.get_files()):
            sync_prs[file].add(pr.number)
    num = pr.number
    nums.append(num)
    prs[num] = pr
    fetch_args.append(f'pull/{num}/head:port-status-pull/{num}')

os.system("git branch -D $(git branch --list 'port-status-pull/*')")
subprocess.run(fetch_args)

prs_of_import = {}
for num in nums:
    p = subprocess.run(
        ['git', 'diff', '--name-only', '--diff-filter=A',
         f'origin/master...port-status-pull/{num}'],
        capture_output=True)
    for l in p.stdout.decode().splitlines():
        f = subprocess.run(
            ['git', 'cat-file', 'blob', f'port-status-pull/{num}:{l}'],
            capture_output=True)
        import_, repo, commit = get_mathlib4_module_commit_info(f.stdout.decode(encoding='utf8', errors='replace'))
        prs_of_import.setdefault(import_, []).append({'pr': num, 'repo': repo, 'commit': commit, 'fname': l})

COMMENTS_URL = "https://raw.githubusercontent.com/wiki/leanprover-community/mathlib4/port-comments.md"
comments_dict = yaml.safe_load(requests.get(COMMENTS_URL).content.replace(b"```", b""))

yaml_dict = {}
new_yaml_dict = {}
for node in sorted(graph.nodes):
    if node in data:
        new_status = dict(
            ported=True,
            mathlib4_file=data[node]['mathlib4_file'],
            mathlib4_pr=data[node]['mathlib4_pr'],
            source=data[node]['source']
        )
        _sync_prs = [
            dict(
                num=sync_pr_num,
                labels=[dict(name=l.name, color=l.color) for l in prs[sync_pr_num].labels]
            )
            for sync_pr_num in sync_prs[data[node]['mathlib4_file']]
        ]
        if _sync_prs:
            new_status.update(mathlib4_sync_prs=_sync_prs)
        pr_status = f"mathlib4#{data[node]['mathlib4_pr']}" if data[node]['mathlib4_pr'] is not None else "_"
        sha = data[node]['source']['commit'] if data[node]['source']['repo'] == 'leanprover-community/mathlib' else "_"

        status = f"Yes {pr_status} {sha}"
    else:
        new_status = dict(ported=False)
        status = f'No'
        if node in prs_of_import:
            pr_info = prs_of_import[node][0]
            if pr_info['commit'] is None:
                print('PR seems to be missing a source header', node, pr_info)
                assert(False)
            new_status.update(
                mathlib4_pr=pr_info['pr'],
                mathlib4_file=pr_info['fname'],
                source=dict(repo=pr_info['repo'], commit=pr_info['commit']))
            labels = [{'name': l.name, 'color': l.color} for l in prs[pr_info['pr']].labels]
            if labels:
                new_status.update(labels=labels)
            sha = pr_info['commit'] if pr_info['repo'] == 'leanprover-community/mathlib' else "_"
            status += f" mathlib4#{pr_info['pr']} {sha}"
    try:
        comment_data = comments_dict[node]
    except KeyError:
        pass
    else:
        if isinstance(comment_data, str):
            # old comment format
            comment_data = dict(message=comment_data)
        # new comment format
        status += ' ' + comment_data['message']
        new_status.update(comment=comment_data)
    yaml_dict[node] = status
    new_yaml_dict[node] = new_status

DO_NOT_EDIT_MESSAGE = """
# Do not edit this file.
# If you want to add free-form comments about files that don't have PRs yet,
# edit https://github.com/leanprover-community/mathlib4/wiki/port-comments/_edit instead.
""" + ("\n" * 37)

with open('port_status.yaml', 'w') as f:
    f.write(DO_NOT_EDIT_MESSAGE + "```\n" + yaml.dump(yaml_dict) + "```\n")
with open('port_status_new.yaml', 'w') as f:
    f.write(DO_NOT_EDIT_MESSAGE + "```\n" + yaml.dump(new_yaml_dict) + "```\n")
repl/.lake/packages/mathlib/scripts/polyrith_sage.py
ADDED
@@ -0,0 +1,101 @@
# This file is part of the `polyrith` tactic in `src/tactic/polyrith.lean`.
# It interfaces between Lean and the Sage web interface.

import requests
import json
import sys
from os.path import join, dirname

# These functions are used to format the output of Sage for parsing in Lean.
# They are stored here as a string since they are passed to Sage via the web API.
with open(join(dirname(__file__), "polyrith_sage_helper.py"), encoding='utf8') as f:
    polynomial_formatting_functions = f.read()

# future extensions may change behavior depending on the base type
def type_str(type):
    return "QQ"

def create_query(type: str, n_vars: int, eq_list, goal_type):
    """ Create a query to invoke Sage's `MPolynomial_libsingular.lift`. See
    https://github.com/sagemath/sage/blob/f8df80820dc7321dc9b18c9644c3b8315999670b/src/sage/rings/polynomial/multi_polynomial_libsingular.pyx#L4472-L4518
    for a description of this method. """
    var_list = [f"var{i}" for i in range(n_vars)] + ['aux']
    query = f'''
if {n_vars!r} != 0:
    P = PolynomialRing({type_str(type)}, {var_list})
    [{", ".join(var_list)}] = P.gens()
    p = P({goal_type})
    gens = {eq_list} + [1 - p*aux]
    I = P.ideal(gens)
    coeffs = P(1).lift(I)
    power = max(cf.degree(aux) for cf in coeffs)
    coeffs = [P(cf.subs(aux = 1/p)*p^power) for cf in coeffs[:int(-1)]]
    print(str(power)+';'+serialize_polynomials(coeffs))
else:
    # workaround for a Sage shortcoming with `n_vars = 0`,
    # `TypeError: no conversion of this ring to a Singular ring defined`
    # In this case, there is no need to look for membership in the *radical*;
    # we just check for membership in the ideal, and return exponent 1
    # if coefficients are found.
    P = PolynomialRing({type_str(type)}, 'var', 1)
    p = P({goal_type})
    I = P.ideal({eq_list})
    coeffs = p.lift(I)
    print('1;'+serialize_polynomials(coeffs))
'''
    return query

class EvaluationError(Exception):
    def __init__(self, ename, evalue, message='Error in Sage communication'):
        self.ename = ename
        self.evalue = evalue
        self.message = message
        super().__init__(self.message)

def parse_response(resp: str) -> str:
    exp, data = resp.split(';', 1)
    return dict(power=int(exp), coeffs=json.loads(data))


def evaluate_in_sage(query: str) -> str:
    data = {'code': query}
    headers = {'content-type': 'application/x-www-form-urlencoded'}
    response = requests.post('https://sagecell.sagemath.org/service', data, headers=headers).json()
    if response['success']:
        return parse_response(response.get('stdout'))
    elif 'execute_reply' in response and 'ename' in response['execute_reply'] and 'evalue' in response['execute_reply']:
        raise EvaluationError(response['execute_reply']['ename'], response['execute_reply']['evalue'])
    else:
        raise Exception(response)

def main():
    '''The system args contain the following:
    0 - the path to this python file
    1 - a string containing "true" or "false" depending on whether polyrith was called with trace enabled
    2 - a string representing the base type of the target
    3 - the number of variables used
    4 - a list of the polynomial hypotheses/proof terms in terms of the variables
    5 - a single polynomial representing the target

    This returns a json object with format:
    ```
    { success: bool,
      data: Optional[list[str]],
      trace: Optional[str],
      name: Optional[str],
      value: Optional[str] }
    ```
    '''
    command = create_query(sys.argv[2], int(sys.argv[3]), sys.argv[4], sys.argv[5])
    final_query = polynomial_formatting_functions + "\n" + command
    if sys.argv[1] == 'true':  # trace dry run enabled
        output = dict(success=True, trace=command)
    else:
        try:
            output = dict(success=True, data=evaluate_in_sage(final_query))
        except EvaluationError as e:
            output = dict(success=False, name=e.ename, value=e.evalue)
    print(json.dumps(output))

if __name__ == "__main__":
    main()
repl/.lake/packages/mathlib/scripts/polyrith_sage_helper.py
ADDED
@@ -0,0 +1,13 @@
# this file will be run by the remote sage server, so should not import local files.
from typing import Iterable

def q_arr(coeff: QQ) -> str:
    return "[" + str(coeff.numerator()) + "," + str(coeff.denominator()) + "]"

def arr(args: Iterable[str]) -> str:
    return "[" + ",".join(args) + "]"

def serialize_polynomials(coeffs) -> str:
    return arr(
        arr(arr([arr(arr([str(t[0]), str(t[1])]) for t in etuple.sparse_iter()), q_arr(coeff)])
            for etuple, coeff in c.dict().items()) for c in coeffs)
repl/.lake/packages/mathlib/scripts/yaml_check.py
ADDED
@@ -0,0 +1,69 @@
"""
This file is copied from the mathlib3 file of the same name.
It reads in the three yaml files, and translates them to simpler json files that are easier to
process in Lean.
"""
from typing import Dict, Optional, Union, Tuple, List
import yaml
import json
import sys

TieredDict = Dict[str, Union[Optional[str], 'TieredDict']]

def tiered_extract(db: TieredDict) -> List[Tuple[List[str], str]]:
    """From a nested dictionary, return a list of (key_path, values)
    of the deepest level."""
    out = []
    for name, entry in db.items():
        if isinstance(entry, dict):
            for subname, value in tiered_extract(entry):
                out.append(([name] + subname, value))
        else:
            if entry and '/' not in entry:
                out.append(([name], entry))
    return out

def flatten_names(data: List[Tuple[List[str], str]]) -> List[Tuple[str, str]]:
    return [(' :: '.join(id), v) for id, v in data]

def print_list(fn: str, pairs: List[Tuple[str, str]]) -> None:
    with open(fn, 'w', encoding='utf8') as out:
        for (id, val) in pairs:
            out.write(f'{id}\n{val.strip()}\n\n')

hundred_yaml = sys.argv[1]
overview_yaml = sys.argv[2]
undergrad_yaml = sys.argv[3]

with open(hundred_yaml, 'r', encoding='utf8') as hy:
    hundred = yaml.safe_load(hy)
with open(overview_yaml, 'r', encoding='utf8') as hy:
    overview = yaml.safe_load(hy)
with open(undergrad_yaml, 'r', encoding='utf8') as hy:
    undergrad = yaml.safe_load(hy)

hundred_decls: List[Tuple[str, str]] = []

for index, entry in hundred.items():
    title = entry['title']
    if 'decl' in entry:
        hundred_decls.append((f'{index} {title}', entry['decl']))
    elif 'decls' in entry:
        if not isinstance(entry['decls'], list):
            raise ValueError(f"For key {index} ({title}): did you mean `decl` instead of `decls`?")
        hundred_decls = hundred_decls + [(f'{index} {title}', d) for d in entry['decls']]

overview_decls = tiered_extract(overview)
assert all(len(n) == 3 for n, _ in overview_decls)
overview_decls = flatten_names(overview_decls)

undergrad_decls = tiered_extract(undergrad)
assert all(len(n) >= 3 for n, _ in undergrad_decls)
undergrad_decls = flatten_names(undergrad_decls)

with open('100.json', 'w', encoding='utf8') as f:
    json.dump(hundred_decls, f)
with open('overview.json', 'w', encoding='utf8') as f:
    json.dump(overview_decls, f)
with open('undergrad.json', 'w', encoding='utf8') as f:
    json.dump(undergrad_decls, f)
repl/pass_rate.py
ADDED
@@ -0,0 +1,195 @@
import os
import subprocess
from argparse import ArgumentParser
import json
from concurrent.futures import ThreadPoolExecutor
from tqdm import tqdm
import tempfile

def wrapped_function(item):
    results = []
    passed = 0
    total = 0

    temp_dir = tempfile.gettempdir()
    temp_file = os.path.join(temp_dir, f"test.lean")

    with open(temp_file, "w") as f:
        f.write(item['cmd'])

    # Rest of the function code...
    # Process the item using the temporary file
    # ...

    # Clean up the temporary file
    data = '{"path": "%s", "allTactics": true}' % (temp_file)
    command = 'echo \'%s\' | lake exe repl' % data

    try:
        result = subprocess.run(command, shell=True, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        stdout = result.stdout.decode('utf-8')
        stderr = result.stderr.decode('utf-8')
        json_stdout = json.loads(stdout)
        if "messages" not in json_stdout.keys():
            passed += 1
        # results.append({'item': item['content'], 'stdout': stdout, 'stderr': stderr, 'status': 'pass'})
        results.append({'stdout': stdout, 'stderr': stderr, 'status': 'pass'})
    except subprocess.CalledProcessError as e:
        # results.append({'item': item['content'], 'error': str(e), 'status': 'nopass'})
        results.append({'error': str(e), 'status': 'nopass'})
    total += 1

    pass_rate = passed / (passed + total) * 100

    return {'results': results, 'pass_rate': pass_rate}

# Set the directory where your .lean files are located

# Get a list of all .lean files in the directory
# lean_files = [f for f in os.listdir(directory) if f.endswith(".lean")]
# lean_files = ["test/file.lean"]
def single(command_list):
    results = []
    passed = 0
    total = 0
    for item in tqdm(command_list):
        with open("test/test.lean", "w", encoding='utf-8') as f:
            f.write(item['cmd'])
        data = '{"path": "test/test.lean", "allTactics": true}'
        # data = '{"cmd": "%s", "allTactics": true}' % item['cmd']
        command = 'echo \'%s\' | lake exe repl' % data
        try:
            # process = subprocess.Popen(['lake', 'exe', 'repl'], stdin=subprocess.PIPE, stdout=subprocess.PIPE,
            #                            stderr=subprocess.PIPE)
            # stdout, stderr = process.communicate(input=data.encode(encoding='utf-8'))
            # stdout = stdout.decode('utf-8')
            import pdb
            pdb.set_trace()
            result = subprocess.run(command, shell=True, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
            stdout = result.stdout.decode('utf-8')
            json_stdout = json.loads(stdout)
            if "messages" not in json_stdout.keys():
                passed += 1
            stderr = result.stderr.decode('utf-8')
            results.append({
                # 'item': item['content'],
                'stdout': stdout,
                'stderr': stderr,
                'status': 'pass'
            })
        except subprocess.CalledProcessError as e:
            results.append({
                # 'item': item['content'],
                'error': str(e),
                'status': 'nopass'
            })
        total += 1

    # Calculate pass rate
    pass_rate = passed / total * 100
    print(pass_rate)

    # Save results to a JSON file
    with open('results.json', 'w') as f:
        json.dump({'results': results, 'pass_rate': pass_rate}, f, indent=2, ensure_ascii=False)


def multi(command_list):
    results = []
    passed = 0
    total = 0
    def execute_command(item):
        temp_dir = '/hpc2hdd/home/zyang398/lujianqiao/lean4/repl/tmp'
        temp_file = os.path.join(temp_dir, f"test_{item['index']}.lean")  # Ensure unique filenames
        with open(temp_file, "w") as f:
            f.write(item['cmd'])

        data = '{"path": "%s", "allTactics": true}' % temp_file
        command = f'echo \'{data}\' | lake exe repl'

        try:
            result = subprocess.run(command, shell=True, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
            stdout = result.stdout.decode('utf-8')
            stderr = result.stderr.decode('utf-8')

            if "messages" not in json.loads(stdout):
                return {'stdout': stdout, 'stderr': stderr, 'status': 'pass'}
            else:
                return {'stdout': stdout, 'stderr': stderr, 'status': 'nopass'}

        except subprocess.CalledProcessError as e:
            return {'error': str(e), 'status': 'nopass'}

        os.remove(temp_file)

    total = len(command_list)

    with ThreadPoolExecutor(max_workers=32) as executor:
        futures = [executor.submit(execute_command, {'index': i, 'cmd': cmd['cmd']}) for i, cmd in enumerate(command_list)]
        for future in tqdm(futures, total=total, desc="Processing Commands"):
            result = future.result()
            results.append(result)
            if result['status'] == 'pass':
                passed += 1

    pass_rate = (passed / total) * 100
    print(f"Pass rate: {pass_rate}%")

    with open('results.json', 'w') as f:
        json.dump({'results': results, 'pass_rate': pass_rate}, f, indent=2, ensure_ascii=False)

import re
def remove_simp_pattern_from_end(s):
    pattern = r'@\[simp\s*.*?\]$'
    return re.sub(pattern, '', s)

def main(args):
    command_list = []
    for i in range(args.cuda_num):
        with open(f"{args.input_path}/{i}.json", 'r', encoding='utf-8') as rf:
            for line in rf.readlines():
                try:
                    json_item = json.loads(line)
                    # json_item['content']['statement_poof']
                    # json_item['cmd'] = '\n'.join([json_item['content']['working_file'] , json_item['total output'][0]])
                    working_env = json_item['content']['working_file']

                    # statement = json_item['content']['statement_poof'].split('\n')
                    statement = json_item['total output'][0]

                    json_item['cmd'] = '\n'.join([working_env, statement])
                    # print(json_item['cmd'])
                    assert len(statement) > 0
                    # json_item['cmd'] = '\n'.join([working_env, json_item['total output'][0]])
                except:
                    import pdb
                    pdb.set_trace()
                command_list.append(json_item)
    command_list = command_list
    results = []
    passed = 0
    total = 0
    single(command_list)

if __name__ == '__main__':
    arg_parser = ArgumentParser()
    arg_parser.add_argument('--data_path', type=str,
                            default='data/grade-school-math-master/grade_school_math/data/test.jsonl')
    arg_parser.add_argument('--input_path', type=str, default='')
    arg_parser.add_argument('--cuda_num', type=int, default=4)
    arg_parser.add_argument('--output_path', type=str, default='total.json')
    arg_parser.add_argument('--generate_method', type=str,
                            choices=['single', 'sft', 'comp', 'self_consistency', 'single_consistency'])
    arg_parser.add_argument('--method', type=str, choices=['main', 'test', 'get_data'])
    args = arg_parser.parse_args()
    main(args)
whole_generation.py
ADDED
@@ -0,0 +1,491 @@
1 |
+
import argparse
|
2 |
+
import random
|
3 |
+
import glob
|
4 |
+
|
5 |
+
from tqdm import tqdm
|
6 |
+
import re
|
7 |
+
import sys
|
8 |
+
import os
|
9 |
+
import numpy as np
|
10 |
+
|
11 |
+
PROMPT_DICT = {
|
12 |
+
"lean4": (
|
13 |
+
"Statement and proof in natural language:\n\n"
|
14 |
+
"{statement_text}\n\n"
|
15 |
+
"Translate the statement and proof in natural language to lean4:"
|
16 |
+
),
|
17 |
+
"plain": (
|
18 |
+
"{statement_text}"
|
19 |
+
),
|
20 |
+
"statement": (
|
21 |
+
"Statement in natural language:\n"
|
22 |
+
"{problem}\n"
|
23 |
+
"Translate the statement in natural language to Lean4:"
|
24 |
+
),
|
25 |
+
"prompt_no_input": (
|
26 |
+
"Below is an instruction that describes a task. "
|
27 |
+
"Write a response that appropriately completes the request.\n\n"
|
28 |
+
"### Instruction:\n{instruction}\n\n### Response:"
|
29 |
+
),
|
30 |
+
}
|
31 |
+
|
32 |
+
|
33 |
+
def generate_few_shot(prompt):
|
34 |
+
base_gsm8k_list = [
|
35 |
+
{
|
36 |
+
'question': "John and his best friend Steve bought 12 cupcakes together. Each cupcake cost $1.50. If they split the costs evenly, how much did each person pay?",
|
37 |
+
'answer': "The total cost of cupcakes was 1.5*12=$<<1.5*12=18>>18\\nSo they each paid 18/2=$<<18/2=9>>9.",
|
38 |
+
'direct_answer': "9"
|
39 |
+
},
|
40 |
+
{
|
41 |
+
'question': "Lizzy has to ship 540 pounds of fish that are packed into 30-pound crates. If the shipping cost of each crate is $1.5, how much will Lizzy pay for the shipment?",
|
42 |
+
'answer': "There are 540 pounds / 30 pounds/crate = <<540/30=18>>18 crates of fish needed.\\nHence, the total cost for the shipment is $1.5/crate x 18 crates = $<<1.5*18=27>>27.",
|
43 |
+
'direct_answer': "27"
|
44 |
+
},
|
45 |
+
{
|
46 |
+
'question': "Tom, Tim, and Paul are collecting photos of cars. Paul has 10 photos more than Tim. Tim has one hundred photos less than the total amount of photos which is 152. How many photos does Tom have?",
|
47 |
+
'answer': "Tim has 152 photos - 100 photos = <<152-100=52>>52 photos.\\nWhen Tim has 52 photos, then Paul has 52 + 10 photos = <<52+10=62>>62 photos.\\nTim and Paul have together 52 photos + 62 photos = <<52+62=114>>114 photos.\\nThat leaves Tom with 152 photos - 114 photos = <<152-114=38>>38 photos.",
|
48 |
+
'direct_answer': "38"
|
49 |
+
},
|
50 |
+
|
51 |
+
]
|
52 |
+
index_list = list(range(len(base_gsm8k_list)))
|
53 |
+
random.shuffle(index_list)
|
54 |
+
few_shot_example = ""
|
55 |
+
for i in index_list:
|
56 |
+
item = base_gsm8k_list[i]
|
57 |
+
few_shot_example += "Q: " + item['question'] + "\n" + "A: "+ item['answer'] + "\nThe answer is " + item['direct_answer'] + "\n"
|
58 |
+
|
59 |
+
few_shot_example += "Q: " + prompt + "A: "
|
60 |
+
return few_shot_example
|
61 |
+
|
62 |
+
|
63 |
+
def generate_prompt_translate(args, question):
|
64 |
+
return PROMPT_DICT['statement'].format(problem= question)
|
65 |
+
|
66 |
+
def generate_prompt_solver(args, question):
|
67 |
+
return PROMPT_DICT['plain'].format(statement_text= question)
|
68 |
+
|
69 |
+
|
70 |
+
def generate_prompt_generation(args, question):
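# Wrap a question in the prompt style selected by args.method (zero_shot, zero_shot_cot, few_shot)
# and, for chat models, in the model-specific chat template from MODEL_DICT below.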
|
71 |
+
if args.evaluation_mode == 'generation':
|
72 |
+
if args.method == 'zero_shot_cot':
|
73 |
+
content = question + " Let's think step by step."
|
74 |
+
elif args.method == 'zero_shot':
|
75 |
+
content = question
|
76 |
+
elif args.method == 'few_shot':
|
77 |
+
content = generate_few_shot(question)
|
78 |
+
else:
|
79 |
+
raise ValueError("we do not method for such model type yet")
|
80 |
+
|
81 |
+
if "generator" not in args.model_type:
|
82 |
+
MODEL_DICT = {
|
83 |
+
"llama": (
|
84 |
+
"[INST] \n{content}\n [/INST]"
|
85 |
+
),
|
86 |
+
"mistral": (
|
87 |
+
"<s>[INST] {content} [/INST]"
|
88 |
+
),
|
89 |
+
"chatglm": (
|
90 |
+
"<|user|> \n{content}\n <|assistant|>"
|
91 |
+
),
|
92 |
+
"qianwen": (
|
93 |
+
"<|im_start|>user\n{content}<|im_end|>\n<|im_start|>assistant\n"
|
94 |
+
),
|
95 |
+
"baichuan": (
|
96 |
+
"<reserved_106>{content}<reserved_107>"
|
97 |
+
)
|
98 |
+
}
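# Illustrative rendering with a placeholder question:
#   MODEL_DICT['mistral'].format_map({'content': "What is 2 + 2?"})
#   -> "<s>[INST] What is 2 + 2? [/INST]"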
|
99 |
+
|
100 |
+
if args.model_type in ["qianwen", "qianwen-13b", "qianwen-70b"]:
|
101 |
+
content = MODEL_DICT['qianwen'].format_map(
|
102 |
+
{'content': content}
|
103 |
+
)
|
104 |
+
|
105 |
+
elif args.model_type in ["chatglm"]:
|
106 |
+
pass
|
107 |
+
|
108 |
+
|
109 |
+
elif args.model_type in ['llama2-7b-chat']:
|
110 |
+
content = MODEL_DICT['llama'].format_map(
|
111 |
+
{'content': content}
|
112 |
+
)
|
113 |
+
|
114 |
+
elif args.model_type in ["mistral", 'mixtral']:
|
115 |
+
content = MODEL_DICT['mistral'].format_map(
|
116 |
+
{'content': content}
|
117 |
+
)
|
118 |
+
|
119 |
+
|
120 |
+
return content
|
121 |
+
|
122 |
+
|
123 |
+
|
124 |
+
|
125 |
+
few_shot_list = [
|
126 |
+
{
|
127 |
+
'question': "There are 15 trees in the grove. Grove workers will plant trees in the grove today. After they are done, there will be 21 trees. How many trees did the grove workers plant today?",
|
128 |
+
'answer': "There are 15 trees originally. Then there were 21 trees after some more were planted. So there must have been 21 - 15 = 6.",
|
129 |
+
'direct_answer': "6"
|
130 |
+
},
|
131 |
+
{
|
132 |
+
'question': "If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot?",
|
133 |
+
'answer': "There are originally 3 cars. 2 more cars arrive. 3 + 2 = 5.",
|
134 |
+
'direct_answer': "5",
|
135 |
+
},
|
136 |
+
{
|
137 |
+
'question': "Leah had 32 chocolates and her sister had 42. If they ate 35, how many pieces do they have left in total?",
|
138 |
+
'answer': "Originally, Leah had 32 chocolates. Her sister had 42. So in total they had 32 + 42 = 74. After eating 35, they had 74 - 35 = 39.",
|
139 |
+
'direct_answer': "39",
|
140 |
+
},
|
141 |
+
{
|
142 |
+
'question': "Jason had 20 lollipops. He gave Denny some lollipops. Now Jason has 12 lollipops. How many lollipops did Jason give to Denny?",
|
143 |
+
'answer': "Jason started with 20 lollipops. Then he had 12 after giving some to Denny. So he gave Denny 20 - 12 = 8.",
|
144 |
+
'direct_answer': "8",
|
145 |
+
},
|
146 |
+
{
|
147 |
+
'question': "Shawn has five toys. For Christmas, he got two toys each from his mom and dad. How many toys does he have now?",
|
148 |
+
'answer': "Shawn started with 5 toys. If he got 2 toys each from his mom and dad, then that is 4 more toys. 5 + 4 = 9.",
|
149 |
+
'direct_answer': "9",
|
150 |
+
},
|
151 |
+
{
|
152 |
+
'question': "There were nine computers in the server room. Five more computers were installed each day, from monday to thursday. How many computers are now in the server room?",
|
153 |
+
'answer': "There were originally 9 computers. For each of 4 days, 5 more computers were added. So 5 * 4 = 20 computers were added. 9 + 20 is 29.",
|
154 |
+
'direct_answer': "29",
|
155 |
+
},
|
156 |
+
{
|
157 |
+
'question': "Michael had 58 golf balls. On tuesday, he lost 23 golf balls. On wednesday, he lost 2 more. How many golf balls did he have at the end of wednesday?",
|
158 |
+
'answer': "Michael started with 58 golf balls. After losing 23 on tuesday, he had 58 - 23 = 35. After losing 2 more, he had 35 - 2 = 33 golf balls.",
|
159 |
+
'direct_answer': "33",
|
160 |
+
},
|
161 |
+
{
|
162 |
+
'question': "Olivia has $23. She bought five bagels for $3 each. How much money does she have left?",
|
163 |
+
'answer': "Olivia had 23 dollars. 5 bagels for 3 dollars each will be 5 x 3 = 15 dollars. So she has 23 - 15 dollars left. 23 - 15 is 8.",
|
164 |
+
'direct_answer': "8",
|
165 |
+
},
|
166 |
+
]
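# The eight GSM8K chain-of-thought demonstrations above are consumed by create_demo_text below.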
|
167 |
+
import json
|
168 |
+
|
169 |
+
from collections import Counter
|
170 |
+
|
171 |
+
|
172 |
+
def self_consistency(pairs):
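# Majority vote over (candidate, value) pairs: return the first candidate whose value equals the most common value.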
|
173 |
+
val_counts = Counter(value for key, value in pairs)
|
174 |
+
most = val_counts.most_common(1)[0][0]
|
175 |
+
for key, value in pairs:
|
176 |
+
if value == most:
|
177 |
+
return key
|
178 |
+
|
179 |
+
|
180 |
+
#
|
181 |
+
def find_feedback(content):
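# Extract the text following "Judgement: " from a model response; returns "None" if no judgement is found.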
|
182 |
+
match = re.search(r'Judgement: (.+)', content)
|
183 |
+
if match:
|
184 |
+
judgement = match.group(1)
|
185 |
+
else:
|
186 |
+
judgement = "None"
|
187 |
+
return judgement
|
188 |
+
|
189 |
+
|
190 |
+
def str2bool(s):
|
191 |
+
s = s.lower()
|
192 |
+
if s == 'true':
|
193 |
+
return True
|
194 |
+
elif s == 'false':
|
195 |
+
return False
|
196 |
+
else:
|
197 |
+
raise ValueError('invalid value: {}, must be true or false'.format(s))
|
198 |
+
|
199 |
+
|
200 |
+
def parse_arguments():
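# Parse command-line flags, then derive data paths, question/answer keys, model paths and trigger
# sentences from the chosen dataset and model_type.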
|
201 |
+
parser = argparse.ArgumentParser(description="Zero-shot-CoT")
|
202 |
+
|
203 |
+
# parser.add_argument(
|
204 |
+
# "--dataset", type=str, default="plan",
|
205 |
+
# choices=["plan", 'tool_use_awareness', 'tool_selection', 'tool_selection_harder', 'tool_creation_awareness',
|
206 |
+
# 'tool_creation_awareness_harder', 'tool_creation',
|
207 |
+
# 'arguments_filling'], help="dataset used for experiment")
|
208 |
+
parser.add_argument(
|
209 |
+
"--cot_trigger_no", type=int, default=1,
|
210 |
+
help="A trigger sentence that elicits a model to execute chain of thought"
|
211 |
+
)
|
212 |
+
parser.add_argument("--dataset", type=str, default="")
|
213 |
+
parser.add_argument("--data_path", type=str, default="")
|
214 |
+
parser.add_argument("--evaluation_mode", type=str, default="")
|
215 |
+
parser.add_argument("--batch_size", type=int, default=1)
|
216 |
+
parser.add_argument("--eval_method", type=str, default="")
|
217 |
+
|
218 |
+
parser.add_argument("--model_path", type=str, default="")
|
219 |
+
|
220 |
+
parser.add_argument("--model_type", type=str, default="chatglm")
|
221 |
+
|
222 |
+
parser.add_argument("--output_dir", type=str, default="generation_test")
|
223 |
+
|
224 |
+
parser.add_argument("--lora_path", type=str, default="")
|
225 |
+
|
226 |
+
parser.add_argument("--iter_num", type=int, default=1)
|
227 |
+
parser.add_argument("--method", type=str, default="few_shot_cot")
|
228 |
+
parser.add_argument("--data_question_key", type=str, default="question")
|
229 |
+
parser.add_argument("--data_answer_key", type=str, default="answer")
|
230 |
+
|
231 |
+
parser.add_argument("--sample_num", type=int, default=1)
|
232 |
+
|
233 |
+
parser.add_argument("--cuda_ind", type=int, default=0)
|
234 |
+
parser.add_argument("--tensor_parallel", type=int, default=1)
|
235 |
+
parser.add_argument("--cuda_start", type=int, default=0)
|
236 |
+
parser.add_argument("--cuda_num", type=int, default=8)
|
237 |
+
|
238 |
+
parser.add_argument("--load_in_8bit", type=str2bool, default=False)
|
239 |
+
parser.add_argument("--rewrite", type=str2bool, default=True)
|
240 |
+
|
241 |
+
parser.add_argument("--use_typewriter", type=int, default=0)
|
242 |
+
|
243 |
+
parser.add_argument("--temperature", type=float, default=0.0)
|
244 |
+
parser.add_argument("--top_p", type=float, default=1)
|
245 |
+
parser.add_argument("--iter_max_new_tokens", type=int, default=512)
|
246 |
+
parser.add_argument("--init_max_new_tokens", type=int, default=2048)
|
247 |
+
parser.add_argument("--min_new_tokens", type=int, default=1)
|
248 |
+
parser.add_argument("--correct_response_format", type=str, default="The correct response is:")
|
249 |
+
|
250 |
+
args = parser.parse_args()
|
251 |
+
if args.evaluation_mode == 'generation':
|
252 |
+
if "lean" in args.dataset:
|
253 |
+
args.data_question_key = 'model_response'
|
254 |
+
args.data_answer_key = 'statement_poof'
|
255 |
+
|
256 |
+
if args.dataset == "lean4_5k_test":
|
257 |
+
args.data_path = "data/lean4_gpt_5k/test/data.jsonl"
|
258 |
+
elif args.dataset == "lean4_basic_test":
|
259 |
+
args.data_path = "data/lean4_basic/1k_test.jsonl"
|
260 |
+
elif args.dataset == "lean4_random_test":
|
261 |
+
args.data_path = "data/lean4_random/1k_test.json"
|
262 |
+
elif args.dataset == "lean4_random_first_train":
|
263 |
+
args.data_path = "data/lean4_random/5k_first.json"
|
264 |
+
elif args.dataset == "lean4_random_second_train":
|
265 |
+
args.data_path = "data/lean4_random/5k_second.json"
|
266 |
+
elif args.dataset == "lean4_random_third_train":
|
267 |
+
args.data_path = "data/lean4_random/5k_third.json"
|
268 |
+
|
269 |
+
if args.model_type == 'mistral_generator':
|
270 |
+
args.model_path = 'models/gsm8k/generators/mistral-ep2/'
|
271 |
+
elif args.model_type == 'mistral_generator_original':
|
272 |
+
args.model_path = '/data/OVM-Mistral-7b/mistral7b-ep2/'
|
273 |
+
elif args.model_type == 'gemma_generator':
|
274 |
+
args.model_path = 'models/gsm8k/generators/gemma2b2-ep2/'
|
275 |
+
elif args.model_type == 'phi2_generator':
|
276 |
+
args.model_path = 'models/gsm8k/generators/phi2b-ep2/'
|
277 |
+
|
278 |
+
elif args.model_type == 'mixtral':
|
279 |
+
args.model_path = '/data/Mixtral-8x7B-Instruct-v0.1'
|
280 |
+
|
281 |
+
elif args.model_type == 'mistral':
|
282 |
+
args.model_path = '/data/mistral-instruct'
|
283 |
+
|
284 |
+
elif args.model_type == 'qianwen-70b':
|
285 |
+
args.model_path = '/data/Qwen-72B-Chat'
|
286 |
+
|
287 |
+
|
288 |
+
elif args.model_type == 'llama2-7b-chat':
|
289 |
+
args.model_path = '/data/Llama-2-7b-chat/'
|
290 |
+
|
291 |
+
if args.cot_trigger_no == 1:
|
292 |
+
args.cot_trigger = "Let's think step by step."
|
293 |
+
|
294 |
+
return args
|
295 |
+
|
296 |
+
|
297 |
+
def create_demo_text(args, cot_flag, index_list):
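# Concatenate the few_shot_list demonstrations selected by index_list, with or without the
# chain-of-thought rationale, into a single demo string.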
|
298 |
+
# Concatenate demonstration examples ...
|
299 |
+
demo_text = ""
|
300 |
+
for i in index_list:
|
301 |
+
item = few_shot_list[i]
|
302 |
+
if cot_flag:
|
303 |
+
demo_text += "Q: " + item['question'] + "\nA: " + item['answer'] + " " + \
|
304 |
+
args.direct_answer_trigger_for_fewshot + " " + item['direct_answer'] + ".\n\n"
|
305 |
+
else:
|
306 |
+
demo_text += "Q: " + item['question'] + "\nA: " + \
|
307 |
+
args.direct_answer_trigger_for_fewshot + " " + item['direct_answer'] + ".\n\n"
|
308 |
+
|
309 |
+
return demo_text
|
310 |
+
|
311 |
+
|
312 |
+
def str2bool(s):
|
313 |
+
s = s.lower()
|
314 |
+
if s == 'true':
|
315 |
+
return True
|
316 |
+
elif s == 'false':
|
317 |
+
return False
|
318 |
+
else:
|
319 |
+
raise ValueError('invalid value: {}, must be true or false'.format(s))
|
320 |
+
|
321 |
+
|
322 |
+
def batchify(pairs, batch_size):
|
323 |
+
|
324 |
+
"""将列表分成指定大小的批次"""
|
325 |
+
for i in range(0, len(pairs), batch_size):
|
326 |
+
yield pairs[i:i + batch_size]
|
327 |
+
|
328 |
+
|
329 |
+
def generate_prompts(questions, func, args):
|
330 |
+
"""为每个问题生成提示"""
|
331 |
+
prompts = [func(args, question) for question in questions]
|
332 |
+
return prompts
|
333 |
+
|
334 |
+
|
335 |
+
def get_question_answer(args):
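# Load questions and full records from one or more comma-separated JSON / JSONL files, then wrap
# every question in the 'lean4' translation prompt.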
|
336 |
+
allfilepath = args.data_path
|
337 |
+
questions = []
|
338 |
+
answers = []
|
339 |
+
|
340 |
+
# Attempt to read the file as a regular JSON file
|
341 |
+
for filepath in allfilepath.split(','):
|
342 |
+
try:
|
343 |
+
with open(filepath, 'r') as file:
|
344 |
+
data = json.load(file)
|
345 |
+
# If the data is a list, assume it's an array of objects
|
346 |
+
if isinstance(data, list):
|
347 |
+
for json_item in data:
|
348 |
+
questions.append(json_item[args.data_question_key])
|
349 |
+
answers.append(json_item)
|
350 |
+
# If the data is a dict, assume it's a single object (or adjust logic as needed)
|
351 |
+
elif isinstance(data, dict):
|
352 |
+
questions.append(data[args.data_question_key])
|
353 |
+
answers.append(data)
|
354 |
+
|
355 |
+
except ValueError:
|
356 |
+
# If it fails, assume the file is in JSON Lines format
|
357 |
+
with open(filepath, 'r') as file:
|
358 |
+
for line in file:
|
359 |
+
json_item = json.loads(line)
|
360 |
+
questions.append(json_item[args.data_question_key])
|
361 |
+
answers.append(json_item)
|
362 |
+
|
363 |
+
questions = [PROMPT_DICT['lean4'].format(statement_text=item) for item in questions]
|
364 |
+
|
365 |
+
return questions, answers
|
366 |
+
|
367 |
+
|
368 |
+
def main3(args):
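# Two-stage generation: (1) translate each natural-language statement to Lean 4 with the model at
# args.translate_model_path, (2) feed the translations to the solver model at args.solver_model_path,
# then write question / translation / solver-output records to one JSONL shard per GPU worker.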
|
369 |
+
from vllm import LLM, SamplingParams
|
370 |
+
import torch
|
371 |
+
|
372 |
+
|
373 |
+
|
374 |
+
print("load data")
|
375 |
+
|
376 |
+
|
377 |
+
|
378 |
+
questions, answers = get_question_answer(args)
|
379 |
+
|
380 |
+
|
381 |
+
|
382 |
+
question_exist_list = []
|
383 |
+
write_pattern = 'w' if args.rewrite else "a+"
|
384 |
+
if os.path.exists(args.output_dir) and not args.rewrite:
|
385 |
+
# If the output files already exist, load the questions that were already processed
|
386 |
+
# Loop through each file that matches the pattern
|
387 |
+
file_pattern = os.path.join(args.output_dir, '[0-9]*.json')
|
388 |
+
for file_path in glob.glob(file_pattern):
|
389 |
+
# Open and read the JSON file
|
390 |
+
with open(file_path, 'r') as fp:
|
391 |
+
# Extract the 'question' field from each line and add it to the list
|
392 |
+
for line in fp.readlines():
|
393 |
+
question_exist_list.append(json.loads(line)['question'])
|
394 |
+
else:
|
395 |
+
try:
|
396 |
+
os.mkdir(args.output_dir)
|
397 |
+
except FileExistsError:
|
398 |
+
pass
|
399 |
+
qa_pairs = [(questions[idx], answers[idx]) for idx in range(len(questions)) if questions[idx] not in question_exist_list ]
|
400 |
+
cuda_pieces = np.array_split(range(len(qa_pairs)), args.cuda_num // args.tensor_parallel)
|
401 |
+
print(f"fitered {len(questions) - len(qa_pairs)} already")
|
402 |
+
|
403 |
+
with open(f"{args.output_dir}/{args.cuda_ind // args.tensor_parallel + args.cuda_start}.json", write_pattern,
|
404 |
+
encoding='utf-8') as wf:
|
405 |
+
start = cuda_pieces[args.cuda_start + args.cuda_ind // args.tensor_parallel][0]
|
406 |
+
end = cuda_pieces[args.cuda_start + args.cuda_ind // args.tensor_parallel][-1] + 1
|
407 |
+
subset_length = end - start
|
408 |
+
total_batches = (subset_length + args.batch_size - 1) // args.batch_size # Calculate the total number of batches
|
409 |
+
for batch in tqdm(batchify(qa_pairs[start:end], args.batch_size), total=total_batches):
|
410 |
+
|
411 |
+
|
412 |
+
questions, answers = zip(*batch)  # unpack questions and answers
|
413 |
+
with torch.no_grad():
|
414 |
+
model = LLM(model=args.translate_model_path, dtype="bfloat16", trust_remote_code=True,
|
415 |
+
tensor_parallel_size=args.tensor_parallel, gpu_memory_utilization=0.95)
|
416 |
+
|
417 |
+
translate_prompts = generate_prompts(questions, generate_prompt_translate, args)
|
418 |
+
translate_output_all = []
|
419 |
+
try:
|
420 |
+
for i in range(args.sample_num):
|
421 |
+
sample_list = []
|
422 |
+
sampling_params = SamplingParams(temperature=args.temperature, top_p=args.top_p,
|
423 |
+
max_tokens=args.init_max_new_tokens)
|
424 |
+
generations = model.generate(translate_prompts, sampling_params, use_tqdm=False)
|
425 |
+
for generation_output in generations:
|
426 |
+
output = generation_output.outputs[0].text
|
427 |
+
sample_list.append(output)
|
428 |
+
translate_output_all.append(sample_list)
|
429 |
+
|
430 |
+
translate_output_all = list(map(list, zip(*translate_output_all)))
|
431 |
+
except Exception as e:
|
432 |
+
print(str(e))
|
433 |
+
sys.exit(1)
|
434 |
+
|
435 |
+
for batch in tqdm(batchify(qa_pairs[start:end], args.batch_size), total=total_batches):
|
436 |
+
questions, answers = zip(*batch)  # unpack questions and answers
|
437 |
+
with torch.no_grad():
|
438 |
+
|
439 |
+
model = LLM(model=args.solver_model_path, dtype="bfloat16", trust_remote_code=True,
|
440 |
+
tensor_parallel_size=args.tensor_parallel, gpu_memory_utilization=0.95)
|
441 |
+
solver_prompts = generate_prompts(translate_output_all, generate_prompt_solver, args)
|
442 |
+
solver_output_all = []
|
443 |
+
|
444 |
+
try:
|
445 |
+
for i in range(args.sample_num):
|
446 |
+
solver_sample_list = []
|
447 |
+
sampling_params = SamplingParams(temperature=args.temperature, top_p=args.top_p,
|
448 |
+
max_tokens=args.init_max_new_tokens)
|
449 |
+
generations = model.generate(solver_prompts, sampling_params, use_tqdm=False)
|
450 |
+
for generation_output in generations:
|
451 |
+
output = generation_output.outputs[0].text
|
452 |
+
solver_sample_list.append(output)
|
453 |
+
solver_output_all.append(solver_sample_list)
|
454 |
+
|
455 |
+
solver_output_all = list(map(list, zip(*solver_output_all)))
|
456 |
+
|
457 |
+
except Exception as e:
|
458 |
+
print(str(e))
|
459 |
+
sys.exit(1)
|
460 |
+
dicts = []
|
461 |
+
|
462 |
+
for question, answer, translate_output, translate_prompt, solver_output, solver_prompt in zip(questions, answers, translate_output_all, translate_prompts, solver_output_all, solver_prompts):
|
463 |
+
dicts.append({
|
464 |
+
"question": question,
|
465 |
+
"translate output": translate_output,
|
466 |
+
"translate prompt": translate_prompt,
|
467 |
+
"soler output": solver_output,
|
468 |
+
"soler prompt": solver_prompt,
|
469 |
+
"answer": answer,
|
470 |
+
})
|
471 |
+
|
472 |
+
for record in dicts:
|
473 |
+
wf.write(json.dumps(record, ensure_ascii=False) + '\n')
|
474 |
+
|
475 |
+
wf.flush()
|
476 |
+
|
477 |
+
|
478 |
+
def main(argv=None):
|
479 |
+
args = parse_arguments()
|
480 |
+
print('*****************************')
|
481 |
+
print(args)
|
482 |
+
print('*****************************')
|
483 |
+
if args.evaluation_mode == 'generation':
|
484 |
+
main3(args)
|
485 |
+
else:
|
486 |
+
raise ValueError("we do not yet inplement")
|
487 |
+
|
488 |
+
|
489 |
+
if __name__ == "__main__":
|
490 |
+
main()
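# Illustrative launch command (dataset, model type and output directory are placeholder choices;
# the two model-path flags must point at real checkpoints):
#   python <this_script>.py --evaluation_mode generation --dataset lean4_basic_test \
#       --model_type mistral --translate_model_path <translator_ckpt> --solver_model_path <solver_ckpt> \
#       --output_dir generation_test --cuda_num 1 --tensor_parallel 1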
|
491 |
+
|