shekkari21 committed
Commit
2a21e9f
1 Parent(s): 5f767e8

Deploy to HF Space

__pycache__/configs.cpython-39.pyc ADDED
Binary file (5.38 kB).

__pycache__/fastapi_app.cpython-39.pyc ADDED
Binary file (4.81 kB).

__pycache__/models.cpython-39.pyc ADDED
Binary file (6.67 kB).

__pycache__/utils.cpython-39.pyc ADDED
Binary file (27.8 kB).
clear.py ADDED
@@ -0,0 +1,49 @@
+ from clearml import Model
+ import torch
+ import os
+ # Import needed classes for local loading and LoRA construction
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
+ from peft import LoraConfig, get_peft_model
+
+ # 1. Download the LoRA checkpoint artifact from ClearML
+ CLEARML_MODEL_ID = "34e25deb24c64b74b29c8519ed15fe3e"
+ model_obj = Model(model_id=CLEARML_MODEL_ID)
+ checkpoint_path = model_obj.get_local_copy()
+ adapter_dir = os.path.dirname(checkpoint_path)
+ print(f"LoRA checkpoint downloaded to: {checkpoint_path}")
+
+ # 2. Load the base pretrained CodeReviewer model and tokenizer from the Hugging Face Hub
+ BASE_MODEL_PATH = "microsoft/codereviewer"
+ tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_PATH)
+ base_model = AutoModelForSeq2SeqLM.from_pretrained(BASE_MODEL_PATH)
+
+ # Print all base model parameters and their shapes
+ print("\nBase model parameters:")
+ for name, param in base_model.named_parameters():
+     print(f"{name}: {tuple(param.shape)}")
+
+ # 3. Reconstruct and attach LoRA adapters
+ lora_config = LoraConfig(
+     r=64,
+     lora_alpha=128,
+     target_modules=["q", "k", "v", "o", "wi", "wo"],
+     lora_dropout=0.05,
+     bias="none",
+     task_type="SEQ_2_SEQ_LM"
+ )
+ model = get_peft_model(base_model, lora_config)
+
+ # 4. Load LoRA adapter weights from the ClearML checkpoint
+ adapter_state = torch.load(checkpoint_path, map_location="cpu")
+ model.load_state_dict(adapter_state, strict=False)
+
+ # 5. Move to CPU and set evaluation mode
+ model.to("cpu").eval()
+
+ print("Model with LoRA adapters loaded and ready for inference.")
+
+ # Print all LoRA adapter parameter names and shapes
+ print("\nFinetuned (LoRA adapter) parameters:")
+ for name, param in model.named_parameters():
+     if "lora_" in name:
+         print(f"{name}: {tuple(param.shape)}")
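Editor's note: a minimal smoke test (not part of this commit) for the model assembled above; the sample input and generation settings are illustrative assumptions.

# Hypothetical smoke test; assumes `model` and `tokenizer` from clear.py are in scope.
sample = "def add(a, b):\n    return a + b"
inputs = tokenizer(sample, return_tensors="pt")
with torch.no_grad():
    outputs = model.generate(**inputs, max_length=64, num_beams=3)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))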
configs.py ADDED
@@ -0,0 +1,252 @@
+ import random
+ import torch
+ import logging
+ import multiprocessing
+ import numpy as np
+
+ logger = logging.getLogger(__name__)
+
+
+ def add_args(parser):
+     parser.add_argument(
+         "--task",
+         type=str,
+         required=False,
+         choices=[
+             "review",
+         ],
+     )
+     parser.add_argument(
+         "--model_type",
+         default="codet5",
+         type=str,
+         choices=["roberta", "t5", "bart", "codet5", "scratch"],
+     )
+     parser.add_argument("--add_lang_ids", action="store_true")
+     parser.add_argument("--from_scratch", action="store_true")
+     parser.add_argument("--debug", action="store_true")
+     parser.add_argument("--start_epoch", default=0, type=int)
+     parser.add_argument("--train_epochs", default=10, type=int)
+     parser.add_argument("--tokenizer_path", type=str, required=False)
+
+     parser.add_argument(
+         "--output_dir",
+         default=None,
+         type=str,
+         required=False,
+         help="The output directory where the model predictions and checkpoints will be written.",
+     )
+     parser.add_argument(
+         "--load_model_path",
+         default=None,
+         type=str,
+         required=False
+     )
+     parser.add_argument(
+         "--model_name_or_path",
+         default=None,
+         type=str,
+         help="Path to the trained model; should contain the .bin files.",
+     )
+     ## Other parameters
+     parser.add_argument(
+         "--train_path",
+         default=None,
+         type=str,
+         help="The pretrain files path. Should contain the .jsonl files for this task.",
+     )
+     parser.add_argument(
+         "--eval_chunkname",
+         default=None,
+         type=str,
+         help="The eval file name.",
+     )
+     parser.add_argument(
+         "--train_filename",
+         default=None,
+         type=str,
+         help="The train filename. Should contain the .jsonl files for this task.",
+     )
+     parser.add_argument(
+         "--dev_filename",
+         default=None,
+         type=str,
+         help="The dev filename. Should contain the .jsonl files for this task.",
+     )
+     parser.add_argument(
+         "--test_filename",
+         default=None,
+         type=str,
+         help="The test filename. Should contain the .jsonl files for this task.",
+     )
+     parser.add_argument(
+         "--gold_filename",
+         default=None,
+         type=str,
+         help="The gold filename. Should contain the .jsonl files for this task.",
+     )
+     parser.add_argument(
+         "--config_name",
+         default="Salesforce/codet5-base",
+         type=str,
+         help="Pretrained config name or path if not the same as model_name",
+     )
+     parser.add_argument(
+         "--max_source_length",
+         default=64,
+         type=int,
+         help="The maximum total source sequence length after tokenization. Sequences longer "
+         "than this will be truncated, sequences shorter will be padded.",
+     )
+     parser.add_argument(
+         "--max_target_length",
+         default=32,
+         type=int,
+         help="The maximum total target sequence length after tokenization. Sequences longer "
+         "than this will be truncated, sequences shorter will be padded.",
+     )
+     parser.add_argument(
+         "--do_train", action="store_true", help="Whether to run training."
+     )
+     parser.add_argument(
+         "--do_eval", action="store_true", help="Whether to run eval on the dev set."
+     )
+     parser.add_argument(
+         "--do_test", action="store_true", help="Whether to run eval on the test set."
+     )
+     parser.add_argument(
+         "--raw_input", action="store_true", help="Whether to use simple input format (set for baselines)."
+     )
+     parser.add_argument(
+         "--do_lower_case",
+         action="store_true",
+         help="Set this flag if you are using an uncased model.",
+     )
+     parser.add_argument(
+         "--no_cuda", action="store_true", help="Avoid using CUDA when available"
+     )
+     parser.add_argument(
+         "--train_batch_size",
+         default=8,
+         type=int,
+         help="Batch size per GPU/CPU for training.",
+     )
+     parser.add_argument(
+         "--eval_batch_size",
+         default=8,
+         type=int,
+         help="Batch size per GPU/CPU for evaluation.",
+     )
+     parser.add_argument(
+         "--gradient_accumulation_steps",
+         type=int,
+         default=1,
+         help="Number of update steps to accumulate before performing a backward/update pass.",
+     )
+     parser.add_argument(
+         "--learning_rate",
+         default=5e-5,
+         type=float,
+         help="The initial learning rate for Adam.",
+     )
+     parser.add_argument(
+         "--mask_rate", default=0.15, type=float, help="The masked rate of input lines.",
+     )
+     parser.add_argument(
+         "--beam_size", default=6, type=int, help="beam size for beam search"
+     )
+     parser.add_argument(
+         "--weight_decay", default=0.0, type=float, help="Weight decay if we apply some."
+     )
+     parser.add_argument(
+         "--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer."
+     )
+     parser.add_argument(
+         "--max_grad_norm", default=1.0, type=float, help="Max gradient norm."
+     )
+     parser.add_argument(
+         "--save_steps", default=-1, type=int,
+     )
+     parser.add_argument(
+         "--log_steps", default=-1, type=int,
+     )
+     parser.add_argument("--eval_steps", default=-1, type=int, help="")
+     parser.add_argument("--eval_file", default="", type=str)
+     parser.add_argument("--out_file", default="", type=str)
+     parser.add_argument("--break_cnt", default=-1, type=int)
+     parser.add_argument("--train_steps", default=-1, type=int, help="")
+     parser.add_argument(
+         "--warmup_steps", default=100, type=int, help="Linear warmup over warmup_steps."
+     )
+     parser.add_argument(
+         "--gpu_per_node",
+         type=int,
+         default=4,
+         help="gpus per node",
+     )
+     parser.add_argument(
+         "--node_index",
+         type=int,
+         default=0,
+         help="For distributed training: node_index",
+     )
+     parser.add_argument(
+         "--local_rank",
+         type=int,
+         default=-1,
+         help="For distributed training: local_rank",
+     )
+     parser.add_argument(
+         "--seed", type=int, default=2233, help="random seed for initialization"
+     )  # previous one 42
+
+     parser.add_argument(
+         "--clearml_train_dataset_id",
+         type=str,
+         default=None,
+         help="ClearML Dataset ID to fetch training data from. Overrides train_filename if provided.",
+     )
+     parser.add_argument(
+         "--clearml_valid_dataset_id",
+         type=str,
+         default=None,
+         help="ClearML Dataset ID to fetch validation data from. Overrides dev_filename if provided.",
+     )
+     args = parser.parse_args()
+     return args
+
+
+ def set_dist(args):
+     # Setup CUDA, GPU & distributed training
+     if args.local_rank == -1 or args.no_cuda:
+         device = torch.device(
+             "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu"
+         )
+         args.n_gpu = torch.cuda.device_count()
+     else:
+         # Setup for distributed data parallel
+         torch.cuda.set_device(args.local_rank)
+         device = torch.device("cuda", args.local_rank)
+         torch.distributed.init_process_group(backend="nccl")
+         args.n_gpu = 1
+     cpu_count = multiprocessing.cpu_count()
+     logger.warning(
+         "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, cpu count: %d",
+         args.local_rank,
+         device,
+         args.n_gpu,
+         bool(args.local_rank != -1),
+         cpu_count,
+     )
+     args.device = device
+     args.cpu_count = cpu_count
+
+
+ def set_seed(args):
+     """Set random seed."""
+     random.seed(args.seed)
+     np.random.seed(args.seed)
+     torch.manual_seed(args.seed)
+     # if args.n_gpu > 0:
+     torch.cuda.manual_seed_all(args.seed)
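Editor's note: a sketch of how these helpers are typically wired together in a driver script; the printed fields are just a check, and nothing here is specific to this Space.

# Hypothetical driver for the helpers above.
import argparse
from configs import add_args, set_dist, set_seed

parser = argparse.ArgumentParser()
args = add_args(parser)   # add_args parses sys.argv itself and returns the Namespace
set_dist(args)            # fills in args.device, args.n_gpu, args.cpu_count
set_seed(args)            # seeds random, numpy, and torch
print(args.device, args.seed)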
extract_pr_comment.py ADDED
@@ -0,0 +1,46 @@
+ import os
+ import json
+ import base64
+ import requests
+ from github import Github
+
+ # Path to the GitHub Actions event payload
+ event_path = os.environ.get("GITHUB_EVENT_PATH")
+ if not event_path or not os.path.exists(event_path):
+     print("No event payload found.")
+     exit(1)
+
+ with open(event_path, "r") as f:
+     event = json.load(f)
+
+ # Only proceed if this is a PR comment event
+ if "pull_request" not in event.get("issue", {}):
+     print("Not a PR comment event.")
+     exit(0)
+
+ pr_number = event["issue"]["number"]
+ comment_body = event["comment"]["body"]
+ repo_full_name = event["repository"]["full_name"]
+ token = os.environ.get("GITHUB_TOKEN")
+
+ if not token:
+     print("No GITHUB_TOKEN found in environment.")
+     exit(1)
+
+ gh = Github(token)
+ repo = gh.get_repo(repo_full_name)
+ pr = repo.get_pull(pr_number)
+
+ files = []
+ for file in pr.get_files():
+     cf = repo.get_contents(file.filename, ref=pr.head.sha)
+     content = base64.b64decode(cf.content).decode("utf-8")
+     files.append({"filename": file.filename, "content": content})
+
+ fastapi_url = "http://127.0.0.1:8000/pr-comments"
+ payload = {
+     "comment": comment_body,
+     "files": files
+ }
+ response = requests.post(fastapi_url, json=payload)
+ print(f"FastAPI response: {response.status_code} {response.text}")
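Editor's note: for reference, a sketch of the minimal issue_comment event payload this script consumes; the `pull_request` key nested under `issue` is what marks the comment as a PR comment. The field values below are illustrative assumptions.

# Hypothetical GITHUB_EVENT_PATH payload shape read by extract_pr_comment.py.
example_event = {
    "issue": {"number": 4, "pull_request": {}},
    "comment": {"body": "Please refactor this function"},
    "repository": {"full_name": "Habil7/git-demo"},
}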
fastapi_app.py ADDED
@@ -0,0 +1,153 @@
+ from fastapi import FastAPI, Request, Form
+ from fastapi.responses import HTMLResponse, RedirectResponse, JSONResponse
+ from pydantic import BaseModel
+ from typing import List
+ from clearml import Model
+ import torch
+ from configs import add_args
+ from models import build_or_load_gen_model
+ import argparse
+ from argparse import Namespace
+ import os
+ from peft import PeftModel, PeftConfig, get_peft_model, LoraConfig
+
+ MAX_SOURCE_LENGTH = 512
+
+ def pad_assert(tokenizer, source_ids):
+     source_ids = source_ids[:MAX_SOURCE_LENGTH - 2]
+     source_ids = [tokenizer.bos_id] + source_ids + [tokenizer.eos_id]
+     pad_len = MAX_SOURCE_LENGTH - len(source_ids)
+     source_ids += [tokenizer.pad_id] * pad_len
+     assert len(source_ids) == MAX_SOURCE_LENGTH, "Not equal length."
+     return source_ids
+
+ # Encode code content and comment into model input
+ def encode_diff(tokenizer, code, comment):
+     # Tokenize code file content
+     code_ids = tokenizer.encode(code, max_length=MAX_SOURCE_LENGTH, truncation=True)[1:-1]
+     # Tokenize comment
+     comment_ids = tokenizer.encode(comment, max_length=MAX_SOURCE_LENGTH, truncation=True)[1:-1]
+     # Concatenate: [BOS] + code + [EOS] + [msg_id] + comment
+     source_ids = [tokenizer.bos_id] + code_ids + [tokenizer.eos_id]
+     source_ids += [tokenizer.msg_id] + comment_ids
+     # Pad/truncate to fixed length
+     source_ids = source_ids[:MAX_SOURCE_LENGTH - 2]
+     source_ids = [tokenizer.bos_id] + source_ids + [tokenizer.eos_id]
+     pad_len = MAX_SOURCE_LENGTH - len(source_ids)
+     source_ids += [tokenizer.pad_id] * pad_len
+     assert len(source_ids) == MAX_SOURCE_LENGTH, "Not equal length."
+     return source_ids
+
+ # Load base model architecture and tokenizer from HuggingFace
+ BASE_MODEL_NAME = "microsoft/codereviewer"
+ args = Namespace(
+     model_name_or_path=BASE_MODEL_NAME,
+     load_model_path=None,
+     # Add other necessary default arguments if build_or_load_gen_model requires them
+ )
+ print(f"Loading base model architecture and tokenizer from: {BASE_MODEL_NAME}")
+ config, base_model, tokenizer = build_or_load_gen_model(args)
+ print("Base model architecture and tokenizer loaded.")
+
+ # Download the fine-tuned weights from ClearML
+ CLEARML_MODEL_ID = "34e25deb24c64b74b29c8519ed15fe3e"
+ model_obj = Model(model_id=CLEARML_MODEL_ID)
+ finetuned_weights_path = model_obj.get_local_copy()
+ adapter_dir = os.path.dirname(finetuned_weights_path)
+
+ print(f"Fine-tuned adapter weights downloaded to directory: {adapter_dir}")
+
+ # Create LoRA configuration matching the fine-tuned checkpoint
+ lora_cfg = LoraConfig(
+     r=64,
+     lora_alpha=128,
+     target_modules=["q", "wo", "wi", "v", "o", "k"],
+     lora_dropout=0.05,
+     bias="none",
+     task_type="SEQ_2_SEQ_LM"
+ )
+ # Wrap base model with PEFT LoRA
+ peft_model = get_peft_model(base_model, lora_cfg)
+ # Load adapter-only weights and merge into base
+ adapter_state = torch.load(finetuned_weights_path, map_location="cpu")
+ peft_model.load_state_dict(adapter_state, strict=False)
+ model = peft_model.merge_and_unload()
+ print("Merged base model with LoRA adapters.")
+
+ model.to("cpu")
+ model.eval()
+ print("Model ready for inference.")
+
+ app = FastAPI()
+
+ last_payload = {"comment": "", "files": []}
+ last_infer_result = {"generated_code": ""}
+
+ class FileContent(BaseModel):
+     filename: str
+     content: str
+
+ class PRPayload(BaseModel):
+     comment: str
+     files: List[FileContent]
+
+ class InferenceRequest(BaseModel):
+     comment: str
+     files: List[FileContent]
+
+
+ @app.get("/")
+ def root():
+     return {"message": "FastAPI PR comment service is running"}
+
+ @app.post("/pr-comments")
+ async def receive_pr_comment(payload: PRPayload):
+     global last_payload
+     last_payload = payload.dict()
+     # Return the received payload as JSON and point the client at /show
+     return JSONResponse(content={"status": "received", "payload": last_payload, "redirect": "/show"})
+
+ @app.get("/show", response_class=HTMLResponse)
+ def show_last_comment():
+     html = f"<h2>Received Comment</h2><p>{last_payload['comment']}</p><hr>"
+     for file in last_payload["files"]:
+         html += f"<h3>{file['filename']}</h3><pre>{file['content']}</pre><hr>"
+     return html
+
+ @app.post("/infer")
+ async def infer(request: InferenceRequest):
+     global last_infer_result
+     print("[DEBUG] Received /infer request with:", request.dict())
+
+     code = request.files[0].content if request.files else ""
+     source_ids = encode_diff(tokenizer, code, request.comment)
+     # print("[DEBUG] source_ids:", source_ids)
+     # tokens = [tokenizer.decode([sid], skip_special_tokens=False) for sid in source_ids]
+     # print("[DEBUG] tokens:", tokens)
+     inputs = torch.tensor([source_ids], dtype=torch.long)
+     inputs_mask = inputs.ne(tokenizer.pad_id)
+
+     preds = model.generate(
+         inputs,
+         attention_mask=inputs_mask,
+         use_cache=True,
+         num_beams=5,
+         early_stopping=True,
+         max_length=100,
+         num_return_sequences=1
+     )
+
+     pred = preds[0].cpu().numpy()
+     pred_nl = tokenizer.decode(pred[2:], skip_special_tokens=True, clean_up_tokenization_spaces=False)
+     last_infer_result = {"generated_code": pred_nl}
+     return last_infer_result
+
+ @app.get("/show-infer", response_class=HTMLResponse)
+ def show_infer_result():
+     html = f"<h2>Generated Message</h2><pre>{last_infer_result['generated_code']}</pre>"
+     return html
+
+ if __name__ == "__main__":
+     # Place any CLI/training logic here if needed
+     # This block is NOT executed when running with uvicorn
+     pass
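Editor's note: a minimal client sketch for the /infer endpoint above; the payload values are illustrative assumptions, and the URL matches the local default used elsewhere in this commit.

# Hypothetical client for fastapi_app.py's /infer endpoint.
import requests

payload = {
    "comment": "Add input validation here",
    "files": [{"filename": "app.py", "content": "def handler(x):\n    return x"}],
}
resp = requests.post("http://127.0.0.1:8000/infer", json=payload, timeout=120)
print(resp.json()["generated_code"])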
models.py ADDED
@@ -0,0 +1,208 @@
+ import os
+ import torch.nn as nn
+ import torch
+ import torch.nn.functional as F
+ from torch.nn import CrossEntropyLoss, BCEWithLogitsLoss
+ import numpy as np
+ from utils import MyTokenizer
+ from transformers import (
+     RobertaConfig,
+     RobertaModel,
+     RobertaTokenizer,
+     BartConfig,
+     BartForConditionalGeneration,
+     BartTokenizer,
+     T5Config,
+     T5ForConditionalGeneration,
+     T5Tokenizer,
+ )
+ import logging
+
+ logger = logging.getLogger(__name__)
+
+
+ class ReviewerModel(T5ForConditionalGeneration):
+
+     def __init__(self, config):
+         super().__init__(config)
+         self.cls_head = nn.Linear(self.config.d_model, 2, bias=True)
+         self.init()
+
+     def init(self):
+         nn.init.xavier_uniform_(self.lm_head.weight)
+         factor = self.config.initializer_factor
+         self.cls_head.weight.data.normal_(
+             mean=0.0, std=factor * ((self.config.d_model) ** -0.5)
+         )
+         self.cls_head.bias.data.zero_()
+
+     def forward(self, *argv, **kwargs):
+         r"""
+         Doc from Huggingface transformers:
+         labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
+             Labels for computing the sequence classification/regression loss. Indices should be in
+             :obj:`[-100, 0, ..., config.vocab_size - 1]`. All labels set to ``-100`` are ignored
+             (masked); the loss is only computed for labels in ``[0, ..., config.vocab_size]``.
+         Returns:
+         Examples::
+             >>> from transformers import T5Tokenizer, T5ForConditionalGeneration
+             >>> tokenizer = T5Tokenizer.from_pretrained('t5-small')
+             >>> model = T5ForConditionalGeneration.from_pretrained('t5-small')
+             >>> # training
+             >>> input_ids = tokenizer('The <extra_id_0> walks in <extra_id_1> park', return_tensors='pt').input_ids
+             >>> labels = tokenizer('<extra_id_0> cute dog <extra_id_1> the <extra_id_2>', return_tensors='pt').input_ids
+             >>> outputs = model(input_ids=input_ids, labels=labels)
+             >>> loss = outputs.loss
+             >>> logits = outputs.logits
+             >>> # inference
+             >>> input_ids = tokenizer("summarize: studies have shown that owning a dog is good for you", return_tensors="pt").input_ids  # Batch size 1
+             >>> outputs = model.generate(input_ids)
+             >>> print(tokenizer.decode(outputs[0], skip_special_tokens=True))
+             >>> # studies have shown that owning a dog is good for you.
+         """
+         if "cls" in kwargs:
+             assert (
+                 "input_ids" in kwargs
+                 and "labels" in kwargs
+                 and "attention_mask" in kwargs
+             )
+             return self.cls(
+                 input_ids=kwargs["input_ids"],
+                 labels=kwargs["labels"],
+                 attention_mask=kwargs["attention_mask"],
+             )
+         if "input_labels" in kwargs:
+             assert (
+                 "input_ids" in kwargs
+                 and "input_labels" in kwargs
+                 and "decoder_input_ids" in kwargs
+                 and "attention_mask" in kwargs
+                 and "decoder_attention_mask" in kwargs
+             ), "Please give these arg keys."
+             input_ids = kwargs["input_ids"]
+             input_labels = kwargs["input_labels"]
+             decoder_input_ids = kwargs["decoder_input_ids"]
+             attention_mask = kwargs["attention_mask"]
+             decoder_attention_mask = kwargs["decoder_attention_mask"]
+             if "encoder_loss" not in kwargs:
+                 encoder_loss = True
+             else:
+                 encoder_loss = kwargs["encoder_loss"]
+             return self.review_forward(
+                 input_ids, input_labels, decoder_input_ids,
+                 attention_mask, decoder_attention_mask, encoder_loss
+             )
+         return super().forward(*argv, **kwargs)
+
+     def cls(self, input_ids, labels, attention_mask):
+         encoder_outputs = self.encoder(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             output_attentions=False,
+             return_dict=False,
+         )
+         hidden_states = encoder_outputs[0]
+         first_hidden = hidden_states[:, 0, :]
+         first_hidden = nn.Dropout(0.3)(first_hidden)
+         logits = self.cls_head(first_hidden)
+         loss_fct = CrossEntropyLoss()
+         if labels is not None:
+             loss = loss_fct(logits, labels)
+             return loss
+         return logits
+
+     def review_forward(
+         self,
+         input_ids,
+         input_labels,
+         decoder_input_ids,
+         attention_mask,
+         decoder_attention_mask,
+         encoder_loss=True
+     ):
+         encoder_outputs = self.encoder(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             output_attentions=False,
+             return_dict=False,
+         )
+         hidden_states = encoder_outputs[0]
+         decoder_inputs = self._shift_right(decoder_input_ids)
+         # Decode
+         decoder_outputs = self.decoder(
+             input_ids=decoder_inputs,
+             attention_mask=decoder_attention_mask,
+             encoder_hidden_states=hidden_states,
+             encoder_attention_mask=attention_mask,
+             output_attentions=False,
+             return_dict=False,
+         )
+         sequence_output = decoder_outputs[0]
+         if self.config.tie_word_embeddings:  # this is True by default
+             sequence_output = sequence_output * (self.model_dim ** -0.5)
+         if encoder_loss:
+             # print(self.encoder.get_input_embeddings().weight.shape)
+             cls_logits = nn.functional.linear(hidden_states, self.encoder.get_input_embeddings().weight)
+             # cls_logits = self.cls_head(hidden_states)
+         lm_logits = self.lm_head(sequence_output)
+         if decoder_input_ids is not None:
+             lm_loss_fct = CrossEntropyLoss(ignore_index=0)  # Warning: PAD_ID should be 0
+             loss = lm_loss_fct(lm_logits.view(-1, lm_logits.size(-1)), decoder_input_ids.view(-1))
+             if encoder_loss and input_labels is not None:
+                 cls_loss_fct = CrossEntropyLoss(ignore_index=-100)
+                 loss += cls_loss_fct(cls_logits.view(-1, cls_logits.size(-1)), input_labels.view(-1))
+             return loss
+         return cls_logits, lm_logits
+
+
+ def get_model_size(model):
+     model_parameters = filter(lambda p: p.requires_grad, model.parameters())
+     model_size = sum([np.prod(p.size()) for p in model_parameters])
+     return "{}M".format(round(model_size / 1e6))
+
+
+ def build_or_load_gen_model(args):
+     config_class, model_class, tokenizer_class = T5Config, ReviewerModel, RobertaTokenizer
+
+     config = config_class.from_pretrained(args.model_name_or_path)
+     tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
+     model = model_class.from_pretrained(args.model_name_or_path, config=config)
+
+     tokenizer.special_dict = {
+         f"<e{i}>": tokenizer.get_vocab()[f"<e{i}>"] for i in range(99, -1, -1)
+     }
+
+     tokenizer.mask_id = tokenizer.get_vocab()["<mask>"]
+     tokenizer.bos_id = tokenizer.get_vocab()["<s>"]
+     tokenizer.pad_id = tokenizer.get_vocab()["<pad>"]
+     tokenizer.eos_id = tokenizer.get_vocab()["</s>"]
+     tokenizer.msg_id = tokenizer.get_vocab()["<msg>"]
+     tokenizer.keep_id = tokenizer.get_vocab()["<keep>"]
+     tokenizer.add_id = tokenizer.get_vocab()["<add>"]
+     tokenizer.del_id = tokenizer.get_vocab()["<del>"]
+     tokenizer.start_id = tokenizer.get_vocab()["<start>"]
+     tokenizer.end_id = tokenizer.get_vocab()["<end>"]
+
+     logger.info(
+         "Finish loading model [%s] from %s",
+         get_model_size(model),
+         args.model_name_or_path,
+     )
+
+     if args.load_model_path is not None:
+         model_path = os.path.join(args.load_model_path, "pytorch_model.bin")
+         logger.info("Reload model from {}".format(model_path))
+         try:
+             model.load_state_dict(torch.load(model_path, map_location="cpu"))
+         except RuntimeError:
+             # Checkpoint was saved without the classification head: load around it
+             saved = model.cls_head
+             model.cls_head = None
+             model.load_state_dict(torch.load(model_path, map_location="cpu"))
+             model.cls_head = saved
+         model.to(args.local_rank)
+
+     return config, model, tokenizer
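Editor's note: callers load the model the way fastapi_app.py does; the sketch below also prints the special-token ids that build_or_load_gen_model attaches to the tokenizer. Argument values mirror those already used in this commit.

# Loading sketch mirroring fastapi_app.py's usage of build_or_load_gen_model.
from argparse import Namespace
from models import build_or_load_gen_model

args = Namespace(model_name_or_path="microsoft/codereviewer", load_model_path=None)
config, model, tokenizer = build_or_load_gen_model(args)
print(tokenizer.msg_id, tokenizer.keep_id, tokenizer.add_id, tokenizer.del_id)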
test.py ADDED
@@ -0,0 +1,29 @@
+ import base64  # For decoding Base64 content
+ import requests  # For HTTP GET on raw_url
+ from github import Github  # PyGithub
+
+ # ==== CHANGE VALUES BELOW ====================================
+ TOKEN = "<YOUR_GITHUB_PAT>"  # <-- Change: Your GitHub PAT (never commit a real token)
+ OWNER = "Habil7"  # <-- Change: Repo owner
+ REPO_NAME = "git-demo"  # <-- Change: Repo name
+ PR_NUMBER = 4  # <-- Change: Pull request number
+ # =============================================================
+
+ gh = Github(TOKEN)
+ repo = gh.get_repo(f"{OWNER}/{REPO_NAME}")
+ pr = repo.get_pull(PR_NUMBER)
+ print(pr)
+
+ # Print PR comments
+ print("\n--- PR Comments ---")
+ for comment in pr.get_issue_comments():
+     print(f"{comment.user.login}: {comment.body}")
+
+ print(f"Number of files in PR: {pr.get_files().totalCount}")
+
+ for file in pr.get_files():
+     print(f"\n=== {file.filename} ===")
+     # Fetch and decode via PyGithub get_contents
+     cf = repo.get_contents(file.filename, ref=pr.head.sha)
+     content_via_api = base64.b64decode(cf.content).decode("utf-8")
+     print(content_via_api)
utils.py ADDED
@@ -0,0 +1,823 @@
+ import re, json
+ import os, random
+ import torch, logging
+ from copy import deepcopy as cp
+ from torch.utils.data import Dataset
+ from tokenizers import ByteLevelBPETokenizer
+ from transformers import T5Tokenizer, RobertaTokenizer
+ import nltk
+
+ logging.basicConfig(
+     format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+     datefmt="%m/%d/%Y %H:%M:%S",
+     level=logging.INFO,
+ )
+ logger = logging.getLogger(__name__)
+
+
+ class MyTokenizer(object):
+     """
+     Wrapper for ByteLevelBPETokenizer
+     """
+     def __init__(self, vocab=None, merges=None, **kwargs):
+         self.tokenizer = ByteLevelBPETokenizer(vocab, merges, **kwargs)
+         self.update_id2token()
+
+     @staticmethod
+     def from_pretrained(path):
+         vocabp = os.path.join(path, "vocab.json")
+         mergesp = os.path.join(path, "merges.txt")
+         mytoken = MyTokenizer(vocabp, mergesp)
+         return mytoken
+
+     def update_id2token(self):
+         vocab = self.tokenizer.get_vocab()
+         self.id2token = {vocab[token]: token for token in vocab}
+
+     def add_special_tokens(self, dic):
+         for values in dic.values():
+             self.tokenizer.add_special_tokens(values)
+         self.update_id2token()
+
+     def convert_ids_to_tokens(self, ids):
+         vocab = self.id2token
+         return [vocab[i] for i in ids]
+
+     def decode(self, ids, **kwargs):  # to be updated
+         tokens = self.convert_ids_to_tokens(ids)
+         return " ".join(tokens)
+
+     def encode(self, text, **kwargs):
+         text = text.encode("ascii", errors="ignore").decode("ascii")
+         return self.tokenizer.encode(text).ids
+
+     def get_vocab(self):
+         return self.tokenizer.get_vocab()
+
+     def __len__(self):
+         return len(self.tokenizer.get_vocab())
+
+
+ class RefineFeatures(object):
+     def __init__(self, example_id, source_ids, target_ids):
+         self.example_id = example_id
+         self.source_ids = source_ids
+         self.target_ids = target_ids
+
+ class RefineDataset(Dataset):
+     def __init__(self, tokenizer, pool, args, file_path, samplenum=-1):
+         self.tokenizer = tokenizer
+         self.args = args
+         logger.info("Reading examples from {}".format(file_path))
+         examples = [json.loads(line) for line in open(file_path)]
+         for i in range(len(examples)):
+             if "id" not in examples[i]:
+                 examples[i]["id"] = i
+         if samplenum > 0:
+             examples = examples[:samplenum]
+         logger.info(f"Tokenize examples: {file_path}")
+         self.feats = pool.map(self.tokenize,
+             [(example, tokenizer, args) for example in examples])
+
+     def tokenize(self, item):
+         example, tokenizer, args = item
+         oldlines = example["old"].split("\n")
+         newlines = example["new"].split("\n")
+         oldlines = [line[1:].strip() for line in oldlines]
+         newlines = [line[1:].strip() for line in newlines]
+         oldlines = "\n".join(oldlines)
+         newlines = "\n".join(newlines)
+         oldlines = "<add>" + oldlines.replace("\n", "<add>")
+         newlines = "<add>" + newlines.replace("\n", "<add>")
+         comment = example["comment"]
+         srcids = self.encode_remove(tokenizer, oldlines, args)
+         srcids += [tokenizer.msg_id] + self.encode_remove(tokenizer, comment, args)
+         tgtids = self.encode_remove(tokenizer, newlines, args)
+         srcids, tgtids = self.pad_assert(srcids, tgtids, args, tokenizer)
+         return RefineFeatures(example["id"], srcids, tgtids)
+
+     @staticmethod
+     def process_pred_gold(pred, gold):
+         gold = gold.split("\n")
+         gold = [line[1:].strip() for line in gold]
+         gold = " ".join(gold)
+         pred = " ".join(pred.split())
+         pred = pred.replace("<add> ", "")
+         return pred, gold
+
+     def pad_assert(self, source_ids, target_ids, args, tokenizer):
+         source_ids = source_ids[:args.max_source_length - 2]
+         source_ids = [tokenizer.bos_id] + source_ids + [tokenizer.eos_id]
+         pad_len = args.max_source_length - len(source_ids)
+         source_ids += [tokenizer.pad_id] * pad_len
+         target_ids = target_ids[:args.max_target_length - 2]
+         target_ids = [tokenizer.bos_id] + target_ids + [tokenizer.eos_id]
+         pad_len = args.max_target_length - len(target_ids)
+         target_ids += [tokenizer.pad_id] * pad_len
+         assert len(source_ids) == args.max_source_length, "Not equal length."
+         assert len(target_ids) == args.max_target_length, "Not equal length."
+         return source_ids, target_ids
+
+     def encode_remove(self, tokenizer, text, args):
+         text = tokenizer.encode(text, max_length=args.max_source_length, truncation=True)
+         if type(tokenizer) == T5Tokenizer:
+             return text[:-1]
+         elif type(tokenizer) == RobertaTokenizer:
+             return text[1:-1]
+         elif type(tokenizer) == MyTokenizer:
+             return text
+         else:
+             raise NotImplementedError
+
+     def __len__(self):
+         return len(self.feats)
+
+     def __getitem__(self, i):
+         return self.feats[i]
+
+ class SimpleRefineDataset(RefineDataset):
+     def tokenize(self, item):
+         example, tokenizer, args = item
+         oldlines = example["old"].split("\n")
+         newlines = example["new"].split("\n")
+         oldlines = [line[1:].strip() for line in oldlines]
+         newlines = [line[1:].strip() for line in newlines]
+         oldlines = " ".join(oldlines)
+         newlines = " ".join(newlines)
+         comment = example["comment"]
+         srcids = self.encode_remove(tokenizer, oldlines, args)
+         srcids += [tokenizer.msg_id] + self.encode_remove(tokenizer, comment, args)
+         tgtids = self.encode_remove(tokenizer, newlines, args)
+         srcids, tgtids = self.pad_assert(srcids, tgtids, args, tokenizer)
+         return RefineFeatures(example["id"], srcids, tgtids)
+
+     @staticmethod
+     def process_pred_gold(pred, gold):
+         gold = gold.split("\n")
+         gold = [line[1:].strip() for line in gold]
+         gold = " ".join(gold)
+         pred = " ".join(pred.split())
+         return pred, gold
+
+
+ class Seq2SeqDataset(RefineDataset):
+     def tokenize(self, item):
+         example, tokenizer, args = item
+         inputs, outputs = example["old"], example["new"]
+         inputs = " ".join(inputs.split())
+         outputs = " ".join(outputs.split())
+         srcids = self.encode_remove(tokenizer, inputs, args)
+         tgtids = self.encode_remove(tokenizer, outputs, args)
+         srcids, tgtids = self.pad_assert(srcids, tgtids, args, tokenizer)
+         return RefineFeatures(example["id"], srcids, tgtids)
+
+     @staticmethod
+     def process_pred_gold(pred, gold):
+         gold = " ".join(gold.split())
+         pred = " ".join(pred.split())
+         return pred, gold
+
+
+ class TextDataset(Dataset):
+     def __init__(self, tokenizer, pool, args, file_path, samplenum=-1):
+         self.cnt = 0
+         self.tokenizer = tokenizer
+         self.args = args
+         if isinstance(tokenizer, MyTokenizer):
+             tokenizer_type = "mytok"
+         elif isinstance(tokenizer, T5Tokenizer):
+             tokenizer_type = ""
+         elif isinstance(tokenizer, RobertaTokenizer):
+             tokenizer_type = "rb"
+         else:
+             tokenizer_type = "unk"
+         savep = file_path.replace(".jsonl", tokenizer_type + ".exps")
+         # savep = "/home/v-zhuoli1/lzzz/processed/chunk_25.exps"
+         if os.path.exists(savep):
+             logger.info("Loading examples from {}".format(savep))
+             examples = torch.load(savep)
+         else:
+             logger.info("Reading examples from {}".format(file_path))
+             examples = read_review_examples(file_path, samplenum, tokenizer)
+             logger.info(f"Tokenize examples: {file_path}")
+             examples = pool.map(self.tokenize,
+                 [(example, tokenizer, args) for example in examples])
+             torch.save(examples, savep)
+         logger.info("Convert examples to features...")
+         self.set_start_end_ids(examples)
+         self.featss = pool.map(self.convert_examples_to_features,
+             [(example, tokenizer, args) for example in examples])
+         self.feats = [feat for feats in self.featss for feat in feats]  # expand the lists
+
+     def __len__(self):
+         return len(self.feats)
+
+     def __getitem__(self, i):
+         return self.feats[i]
+
+     def reset_len(self, data_len):
+         assert len(self.feats) >= data_len
+         self.feats = self.feats[:data_len]
+
+     def set_start_end_ids(self, examples):
+         for example in examples:
+             labels = example.labels
+             start_id = 0
+             end_id = len(labels) - 1
+             for i, label in enumerate(labels):
+                 if label != -100:  # find the first label
+                     start_id = i
+                     break
+             for i in range(len(labels) - 1, -1, -1):
+                 label = labels[i]
+                 if label != -100:
+                     end_id = i
+                     break
+             example.start_id = start_id
+             example.end_id = end_id
+
+     def tokenize(self, item):
+         example, tokenizer, args = item
+         example.input = self.encode_remove(tokenizer, example.input, args)
+         e0id = tokenizer.special_dict["<e0>"]
+         inputs = " ".join(str(id) for id in example.input)
+         lines = inputs.split(" " + str(e0id) + " ")
+         lines = [
+             [int(v) for v in line.split(" ") if len(v) > 0] for line in lines
+         ]
+         lens = [len(line) for line in lines]
+         # if 0 in lens:
+         #     logger.info("Warning: empty line in an example.")
+         lens = list(map(len, lines))
+         curlen = len(lens) + sum(lens)
+         left, right = 0, len(lines)
+         while curlen > args.max_source_length - 2:
+             if left % 2 == 0:
+                 curlen -= 1 + len(lines[left])
+                 left += 1
+             else:
+                 right -= 1
+                 curlen -= 1 + len(lines[right])
+         lines = lines[left:right]
+         labels = example.labels[left:right]
+         assert len(lines) + sum(map(len, lines)) <= args.max_source_length - 2, "Too long inputs in TextDataset.tokenize."
+         if len(lines) != len(labels):
+             logger.info("Not equal length in TextDataset.tokenize.")
+             lines = lines[:len(labels)]
+             labels = labels[:len(lines)]
+         example.lines = lines
+         example.labels = labels
+         example.msg = self.encode_remove(tokenizer, example.msg, args)
+         return example
+
+     def convert_examples_to_features(self, item):
+         example, _, _ = item
+         if len(example.msg) > 0:
+             exs = []
+             for _ in range(3):  # up sampling
+                 if random.random() < 0.5:
+                     exs.append(self.genmsg_example(item))
+                 else:
+                     exs.append(self.daemsg_example(item))
+             return exs
+         if random.random() < 0.5:
+             return [self.encoder_example(item)]
+         return [self.decoder_example(item)]
+
+     def encoder_example(self, item):
+         example, tokenizer, args = item
+         lines = example.lines
+         labels = example.labels
+         target_ids = [tokenizer.pad_id] * args.max_target_length
+         source_ids, input_labels = [], []
+         for i, (line, label) in enumerate(zip(lines, labels)):
+             if i == example.start_id:
+                 source_ids.append(tokenizer.start_id)
+                 input_labels.append(-100)
+             if label != -100:  # only insert special tokens at diffs, not context
+                 source_ids.append(tokenizer.mask_id)
+                 input_labels.append(label)
+             source_ids.extend(line)
+             input_labels.extend([-100] * len(line))
+             if i == example.end_id:
+                 source_ids.append(tokenizer.end_id)
+                 input_labels.append(-100)
+         assert len(input_labels) == len(source_ids), "Not equal length."
+         assert len(input_labels) <= args.max_source_length, f"Too long inputs: {len(input_labels)}."
+         source_ids = source_ids[:args.max_source_length - 2]
+         input_labels = input_labels[:args.max_source_length - 2]
+         source_ids = [tokenizer.bos_id] + source_ids + [tokenizer.eos_id]
+         input_labels = [-100] + input_labels + [-100]
+         pad_len = args.max_source_length - len(source_ids)
+         source_ids += [tokenizer.pad_id] * pad_len
+         input_labels += [-100] * pad_len
+
+         new_input_labels = []
+         map_dict = {0: tokenizer.del_id, 1: tokenizer.add_id, 2: tokenizer.keep_id}
+         for label in input_labels:
+             if label == -100:
+                 new_input_labels.append(-100)
+             else:
+                 new_input_labels.append(map_dict[label])
+         input_labels = new_input_labels
+         assert len(source_ids) == args.max_source_length, "Not equal length."
+         assert len(input_labels) == args.max_source_length, "Not equal length."
+         return ReviewFeatures(example.idx, source_ids, input_labels, target_ids, type="label")
+
+     def decoder_example(self, item):
+         example, tokenizer, args = item
+         lines = example.lines
+         labels = example.labels
+
+         input_labels = [-100] * args.max_source_length
+         source_ids, target_ids = [], []
+         SPECIAL_ID = 0
+         mask_idxs = random.choices(range(len(lines)), k=int(len(lines) * args.mask_rate))
+         id_dict = {0: tokenizer.del_id, 1: tokenizer.add_id, 2: tokenizer.keep_id}
+         for i, (line, label) in enumerate(zip(lines, labels)):
+             if i == example.start_id:
+                 source_ids.append(tokenizer.start_id)
+             if label in id_dict:
+                 source_ids.append(id_dict[label])
+             if i in mask_idxs:
+                 source_ids.append(tokenizer.special_dict[f"<e{SPECIAL_ID}>"])
+                 target_ids.append(tokenizer.special_dict[f"<e{SPECIAL_ID}>"])
+                 target_ids.extend(line)
+                 if SPECIAL_ID < 99:  # only 0-99 ids in vocab
+                     SPECIAL_ID += 1
+             else:
+                 source_ids.extend(line)
+             if i == example.end_id:
+                 source_ids.append(tokenizer.end_id)
+         source_ids, target_ids = self.pad_assert(source_ids, target_ids, args, tokenizer)
+         return ReviewFeatures(example.idx, source_ids, input_labels, target_ids, type="line")
+
+     def genmsg_example(self, item):
+         example, tokenizer, args = item
+         lines = example.lines
+         labels = example.labels
+         input_labels = [-100] * args.max_source_length
+         source_ids, target_ids = [], []
+         id_dict = {0: tokenizer.del_id, 1: tokenizer.add_id, 2: tokenizer.keep_id}
+         for i, (line, label) in enumerate(zip(lines, labels)):
+             if i == example.start_id:
+                 source_ids.append(tokenizer.start_id)
+             if label != -100:
+                 source_ids.append(id_dict[label])
+             source_ids.extend(line)
+             if i == example.end_id:
+                 source_ids.append(tokenizer.end_id)
+         target_ids.append(tokenizer.msg_id)
+         target_ids.extend(example.msg)
+         assert len(source_ids) <= args.max_source_length, f"Too long inputs: {len(source_ids)}."
+         source_ids, target_ids = self.pad_assert(source_ids, target_ids, args, tokenizer)
+         return ReviewFeatures(example.idx, source_ids, input_labels, target_ids, type="genmsg")
+
+     def daemsg_example(self, item):
+         example, tokenizer, args = item
+         input_labels = [-100] * args.max_source_length
+         source_ids, target_ids = [], []
+         msg_ids = cp(example.msg)
+         masks = [random.random() < 0.20 for _ in range(len(msg_ids))]
+         if sum(masks) == 0:
+             idx = random.choice(range(len(msg_ids)))
+             masks[idx] = True
+         source_ids, target_ids = [], []
+         i = 0
+         SPECIAL_ID = 0
+         while i < len(masks):
+             j = i
+             while j < len(masks) and not masks[j]:
+                 source_ids.append(msg_ids[j])
+                 j += 1
+             if j == len(masks):
+                 break
+             source_ids.append(tokenizer.special_dict[f"<e{SPECIAL_ID}>"])
+             target_ids.append(tokenizer.special_dict[f"<e{SPECIAL_ID}>"])
+             while j < len(masks) and masks[j]:
+                 target_ids.append(msg_ids[j])
+                 j += 1
+             if SPECIAL_ID < 99:  # only 0-99 ids in vocab
+                 SPECIAL_ID += 1
+             i = j
+         source_ids, target_ids = self.pad_assert(source_ids, target_ids, args, tokenizer)
+         return ReviewFeatures(example.idx, source_ids, input_labels, target_ids, type="daemsg")
+
+     def pad_assert(self, source_ids, target_ids, args, tokenizer):
+         source_ids = source_ids[:args.max_source_length - 2]
+         source_ids = [tokenizer.bos_id] + source_ids + [tokenizer.eos_id]
+         pad_len = args.max_source_length - len(source_ids)
+         source_ids += [tokenizer.pad_id] * pad_len
+         target_ids = target_ids[:args.max_target_length - 1]
+         target_ids = target_ids + [tokenizer.eos_id]
+         pad_len = args.max_target_length - len(target_ids)
+         target_ids += [tokenizer.pad_id] * pad_len
+         assert len(source_ids) == args.max_source_length, "Not equal length."
+         assert len(target_ids) == args.max_target_length, "Not equal length."
+         return source_ids, target_ids
+
+     def encode_remove(self, tokenizer, text, args):
+         text = tokenizer.encode(text, max_length=args.max_source_length, truncation=True)
+         if type(tokenizer) == T5Tokenizer:
+             return text[:-1]
+         elif type(tokenizer) == RobertaTokenizer:
+             return text[1:-1]
+         elif type(tokenizer) == MyTokenizer:
+             return text
+         else:
+             raise NotImplementedError
+
+
+ class CommentGenDataset(TextDataset):
+     def __init__(self, tokenizer, pool, args, file_path, samplenum=-1):
+         self.tokenizer = tokenizer
+         if isinstance(tokenizer, MyTokenizer):
+             tokenizer_type = "mytok"
+         elif isinstance(tokenizer, T5Tokenizer):
+             tokenizer_type = ""
+         elif isinstance(tokenizer, RobertaTokenizer):
+             tokenizer_type = "rb"
+         else:
+             tokenizer_type = "unk"
+         savep = file_path.replace(".jsonl", tokenizer_type + ".exps")
+         if os.path.exists(savep):
+             logger.info("Loading examples from {}".format(savep))
+             examples = torch.load(savep)
+         else:
+             logger.info("Reading examples from {}".format(file_path))
+             examples = read_review_examples(file_path, samplenum, tokenizer)
+             # for i in range(len(examples)):
+             #     examples[i].msg = " ".join(nltk.word_tokenize(examples[i].msg))
+             logger.info(f"Tokenize examples: {file_path}")
+             examples = pool.map(self.tokenize,
+                 [(example, tokenizer, args) for example in examples])
+             torch.save(examples, savep)
+         logger.info("Convert examples to features...")
+         self.set_start_end_ids(examples)
+         self.feats = pool.map(self.convert_examples_to_features,
+             [(example, tokenizer, args) for example in examples])
+         self.feats = [feat for feat in self.feats if feat is not None]
+
+     def convert_examples_to_features(self, item):
+         example, tokenizer, args = item
+         if len(example.msg) == 0:
+             return None
+         return self.genmsg_example(item)
+
+
+ class CommentClsDataset(TextDataset):
+     def __init__(self, tokenizer, pool, args, file_path, samplenum=-1):
+         self.tokenizer = tokenizer
+         if isinstance(tokenizer, MyTokenizer):
+             tokenizer_type = "mytok"
+         elif isinstance(tokenizer, T5Tokenizer):
+             tokenizer_type = ""
+         elif isinstance(tokenizer, RobertaTokenizer):
+             tokenizer_type = "rb"
+         else:
+             tokenizer_type = "unk"
+         savep = file_path.replace(".jsonl", tokenizer_type + ".exps")
+         if os.path.exists(savep):
+             logger.info("Loading examples from {}".format(savep))
+             examples = torch.load(savep)
+         else:
+             logger.info("Reading examples from {}".format(file_path))
+             examples = read_review_examples(file_path, samplenum, tokenizer)
+             logger.info(f"Tokenize examples: {file_path}")
+             examples = pool.map(self.tokenize,
+                 [(example, tokenizer, args) for example in examples])
+             torch.save(examples, savep)
+         logger.info("Convert examples to features...")
+         self.set_start_end_ids(examples)
+         self.feats = pool.map(self.convert_examples_to_features,
+             [(example, tokenizer, args) for example in examples])
+
+     def convert_examples_to_features(self, item):
+         example, tokenizer, args = item
+         tmpfeature = self.genmsg_example(item)
+         return ClsFeatures(tmpfeature.example_id, tmpfeature.source_ids, example.y)
+
+
+ class SimpleClsDataset(TextDataset):
+     def __init__(self, tokenizer, pool, args, file_path, samplenum=-1):
+         self.tokenizer = tokenizer
+         if isinstance(tokenizer, MyTokenizer):
+             tokenizer_type = "mytok"
+         elif isinstance(tokenizer, T5Tokenizer):
+             tokenizer_type = ""
+         elif isinstance(tokenizer, RobertaTokenizer):
+             tokenizer_type = "rb"
+         else:
+             tokenizer_type = "unk"
+         savep = file_path.replace(".jsonl", tokenizer_type + ".simpexps")
+         if os.path.exists(savep):
+             logger.info("Loading examples from {}".format(savep))
+             self.feats = torch.load(savep)
+         else:
+             logger.info("Reading examples from {}".format(file_path))
+             examples = read_review_examples(file_path, samplenum, tokenizer)
+             logger.info(f"Tokenize examples: {file_path}")
+             self.feats = pool.map(self.convert_examples_to_features,
+                 [(example, tokenizer, args) for example in examples])
+             torch.save(self.feats, savep)
+
+     def convert_examples_to_features(self, item):
+         example, tokenizer, args = item
+         example.input_lines = example.input.split("<e0>")
+         labels_l = len(example.labels)
+         example.input_lines = example.input_lines[:labels_l]
+         for i in range(len(example.input_lines)):
+             if example.labels[i] == 1:
+                 example.input_lines[i] = "+ " + example.input_lines[i]
+             elif example.labels[i] == 0:
+                 example.input_lines[i] = "- " + example.input_lines[i]
+         example.input = " ".join(example.input_lines)
+         input_ids = self.encode_remove(tokenizer, example.input, args)
+         exceed_l = len(input_ids) - args.max_source_length + 2
+         if exceed_l > 0:
+             halfexl = (exceed_l + 1) // 2
+             input_ids = input_ids[halfexl:-halfexl]
+         source_ids = input_ids[:args.max_source_length - 2]
+         source_ids = [tokenizer.bos_id] + source_ids + [tokenizer.eos_id]
+         pad_len = args.max_source_length - len(source_ids)
+         source_ids += [tokenizer.pad_id] * pad_len
+         example_id = example.idx
+         y = example.y
+         return ClsFeatures(example_id, source_ids, y)
+
+
+ class SimpleGenDataset(TextDataset):
+     def __init__(self, tokenizer, pool, args, file_path, samplenum=-1):
+         self.tokenizer = tokenizer
+         if isinstance(tokenizer, MyTokenizer):
+             tokenizer_type = "mytok"
+         elif isinstance(tokenizer, T5Tokenizer):
+             tokenizer_type = ""
+         elif isinstance(tokenizer, RobertaTokenizer):
+             tokenizer_type = "rb"
+         else:
+             tokenizer_type = "unk"
+         savep = file_path.replace(".jsonl", tokenizer_type + ".simpgenexps")
+         if os.path.exists(savep):
+             logger.info("Loading examples from {}".format(savep))
+             self.feats = torch.load(savep)
+         else:
+             logger.info("Reading examples from {}".format(file_path))
+             data = read_jsonl(file_path)
+             # data = [dic for dic in data if len(dic["patch"].split("\n")) <= 20]
+             for i in range(len(data)):
+                 data[i]["idx"] = i
+             logger.info(f"Tokenize examples: {file_path}")
+             # self.feats = pool.map(self.convert_examples_to_features,
+             #     [(dic, tokenizer, args) for dic in data])
+             self.feats = [self.convert_examples_to_features((dic, tokenizer, args)) for dic in data]
+             torch.save(self.feats, savep)
+
+     def convert_examples_to_features(self, item):
+         dic, tokenizer, args = item
+         diff, msg = dic["patch"], dic["msg"]
+         difflines = diff.split("\n")[1:]  # remove start @@
+         difflines = [line for line in difflines if len(line.strip()) > 0]
+         map_dic = {"-": 0, "+": 1, " ": 2}
+         def f(s):
+             if s in map_dic:
+                 return map_dic[s]
+             else:
+                 return 2
+         labels = [f(line[0]) for line in difflines]
+         difflines = [line[1:].strip() for line in difflines]
+         inputstr = ""
+         for label, line in zip(labels, difflines):
+             if label == 1:
+                 inputstr += "<add>" + line
+             elif label == 0:
+                 inputstr += "<del>" + line
+             else:
+                 inputstr += "<keep>" + line
+         source_ids = self.encode_remove(tokenizer, inputstr, args)
+         target_ids = []
+         target_ids.append(tokenizer.msg_id)
+         msg = self.encode_remove(tokenizer, dic["msg"], args)
+         target_ids.extend(msg)
+         source_ids, target_ids = self.pad_assert(source_ids, target_ids, args, tokenizer)
+         input_labels = [-100] * len(source_ids)
+         return ReviewFeatures(dic["idx"], source_ids, input_labels, target_ids, type="genmsg")
+
+
+ class InputFeatures(object):
+     """A single training/test features for an example."""
+
+     def __init__(self, example_id, source_ids, target_ids, url=None):
+         self.example_id = example_id
+         self.source_ids = source_ids
+         self.target_ids = target_ids
+         self.url = url
+
+
+ class ReviewFeatures(object):
+     def __init__(self, example_id, source_ids, source_labels, target_ids, type):
+         self.example_id = example_id
+         self.source_ids = source_ids
+         self.source_labels = source_labels
+         self.target_ids = target_ids
+         assert type in ("label", "line", "genmsg", "daemsg")
+         self.type = type
+
+ class ClsFeatures(object):
+     def __init__(self, example_id, source_ids, y):
+         self.example_id = example_id
+         self.source_ids = source_ids
+         self.y = y
+
+ class ReviewExample(object):
+     """A single training/test example."""
+
+     def __init__(
+         self, idx, oldf, diff, msg, cmtid, max_len, y
+     ):
+         self.idx = idx  # idx is useless yet
+         self.oldf = oldf
+         self.diff = diff
+         self.msg = msg
+         self.cmtid = cmtid
+         self.max_len = max_len
+         self.y = y
+         self.prevlines = []
+         self.afterlines = []
+         self.lines = []
+         self.labels = []
+         self.avail = False
+         self.input = ""
+         self.align_and_clean()
+         self.postprocess()
+
+     def postprocess(self):
+         if not self.avail:
+             return
+         # Warning: lines is not self.lines
+         # lines for rough length estimation
+         lines = [source_str.split() for source_str in self.lines]
+         inputl = len(lines)  # line tag
+         inputl += sum(map(len, lines))
+         left, right = 0, len(lines)
+         while inputl > self.max_len:
+             if left % 2 == 0:
+                 inputl -= len(lines[left]) + 1
+                 left += 1
+             else:
+                 right -= 1
+                 inputl -= len(lines[right]) + 1
+         lines = lines[left:right]
+         self.lines = self.lines[left:right]
+         self.labels = self.labels[left:right]
+         prevlines = self.prevlines
+         afterlines = self.afterlines
+         prev_after_len = max(len(prevlines), len(afterlines))
+         i = 0
+         while inputl < self.max_len and i < prev_after_len:
+             if i < len(prevlines):
+                 newl = inputl + len(prevlines[-1-i].split()) + 1
+                 if newl > self.max_len:
+                     break
+                 self.lines.insert(0, prevlines[-1-i])
+                 self.labels.insert(0, -100)
+                 inputl = newl  # tag
+             if i < len(afterlines):
+                 newl = inputl + len(afterlines[i].split()) + 1
+                 if newl > self.max_len:
+                     break
+                 self.lines.append(afterlines[i])
+                 self.labels.append(-100)
+                 inputl = newl  # tag
+             i += 1
+         assert inputl <= self.max_len, "Too long inputs."
+         assert len(self.lines) == len(self.labels), "Not equal length."
+         self.input = "<e0>".join(self.lines)
+         self.prevlines, self.lines, self.afterlines = [], [], []
+
+     def remove_space_clean(self, line):
+         """
+         Remove start and end empty chars.
+         """
+         rep = " \t\r"
+         totallen = len(line)
+         i = 0
+         while i < totallen and line[i] in rep:
+             i += 1
+         j = totallen - 1
+         while j >= 0 and line[j] in rep:
+             j -= 1
+         line = line[i : j + 1]
+         return line
+
+     def align_and_clean(self):
+         oldflines = self.oldf.split("\n")
+         difflines = self.diff.split("\n")
+         first_line = difflines[0]
+         difflines = difflines[1:]
+         difflines = [line for line in difflines if line != r""]
+         regex = r"@@ -(\d+),(\d+) \+(\d+),(\d+) @@"
+         matchres = re.match(regex, first_line)
+         if matchres:
+             startline, rangelen, startpos, endpos = matchres.groups()
+             self.avail = True
+         else:
+             self.avail = False
+             return
+         startline, rangelen = int(startline) - 1, int(rangelen)
+         endline = startline + rangelen
+         self.prevlines = oldflines[:startline]
+         self.afterlines = oldflines[endline:]
+         for line in difflines:
+             if line.startswith("-"):
+                 self.lines.append(line[1:])
+                 self.labels.append(0)
+             elif line.startswith("+"):
+                 self.lines.append(line[1:])
+                 self.labels.append(1)
+             else:
+                 self.lines.append(line)
+                 self.labels.append(2)
+         self.prevlines = [self.remove_space_clean(line) for line in self.prevlines]
+         self.afterlines = [self.remove_space_clean(line) for line in self.afterlines]
+         self.lines = [self.remove_space_clean(line) for line in self.lines]
+         self.msg = self.remove_space_clean(self.msg)
+         self.prevlines = [line for line in self.prevlines if len(line) > 0]
+         self.afterlines = [line for line in self.afterlines if len(line) > 0]
+         # print("\n".join(self.prevlines))
+         # print("\n\n\n\n")
+         # print("\n".join(self.lines))
+         # print("\n\n\n\n")
+         # print("\n".join(self.afterlines))
+         # print("\n\n\n\n")
+         assert len(self.lines) == len(self.labels), "Not equal length in align."
+         topack = list(
+             zip(
+                 *[
+                     (line, label)
+                     for line, label in zip(self.lines, self.labels)
+                     if len(line) > 0
+                 ]
+             )
+         )
+         if topack == []:
+             self.avail = False
+             return
+         else:
+             self.lines, self.labels = topack
+         # tuple->list, convenient for later operation
+         self.lines = list(self.lines)
+         self.labels = list(self.labels)
+
+
+ def read_review_examples(filename, data_num=-1, tokenizer=None):
+     """Read examples from filename."""
+     examples = []
+     idx = 0
+     with open(filename) as f:
+         for line in f:
+             try:
+                 js = json.loads(line.strip())
+             except:
+                 print("Error during reading json data.")
+                 continue
+             maxl = 200
+             if "y" not in js:
+                 js["y"] = 0
+                 if "msg" in js and len(js["msg"]) > 0:
+                     js["y"] = 1
+             example = ReviewExample(
+                 idx=idx,
+                 oldf=js["oldf"],
+                 diff=js["patch"],
+                 msg=js["msg"] if "msg" in js else "",
+                 cmtid=js["cmtid"] if "cmtid" in js else "",
+                 max_len=maxl,
+                 y=js["y"]
+             )
+             if example.avail:
+                 examples.append(example)
+                 idx += 1
+                 if idx == data_num:
+                     break
+             else:
+                 # print(f"Passing {idx} because of invalid diff.")
+                 idx += 1
+                 if idx == data_num:
+                     break
+
+     return examples
+
+
+ def read_jsonl(path):
+     data = []
+     with open(path) as f:
+         for line in f:
+             try:
+                 js = json.loads(line.strip())
+             except:
+                 print("Error during reading json data.")
+                 continue
+             data.append(js)
+     return data
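Editor's note: to illustrate the diff alignment above, a small sketch (the patch text is an illustrative assumption). align_and_clean parses the hunk header, splits the old file into before/after context, and labels each diff line 0 (deleted), 1 (added), or 2 (kept).

# Hypothetical demonstration of ReviewExample's diff alignment.
ex = ReviewExample(
    idx=0,
    oldf="def f():\n    return 1\n",
    diff="@@ -1,2 +1,2 @@\n def f():\n-    return 1\n+    return 2",
    msg="use two",
    cmtid="c0",
    max_len=200,
    y=1,
)
print(ex.avail)    # True: the hunk header parsed
print(ex.labels)   # [2, 0, 1] -> kept, deleted, added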