unknown committed
Commit ac3312e · Parent(s): 3e0611d

Initial

Files changed:
- Scripts/UnixCoder/model_gen.py (+0, -31)
- Scripts/UnixCoder/run_one_model.py (+1, -101)
- run_fine_tuning.sh (+1, -1)
Scripts/UnixCoder/model_gen.py
CHANGED
@@ -56,9 +56,6 @@ class Seq2Seq(nn.Module):
         mask = source_ids.ne(1)[:, None, :]*source_ids.ne(1)[:, :, None]
         encoder_output = self.encoder(
             source_ids, attention_mask=mask, use_cache=True)
-        # print("source_ids:", source_ids.size())  # torch.Size([56, 510])
-        # print("exist:", exist.size())  # torch.Size([56, 1])
-        # print("target_ids:", target_ids.size())  # torch.Size([56, 240])
         ids = torch.cat((source_ids, target_ids), -1)

         mask = self.bias[:,
@@ -68,33 +65,15 @@ class Seq2Seq(nn.Module):
         out = self.decoder(target_ids, attention_mask=mask,
                            past_key_values=encoder_output.past_key_values).last_hidden_state

-        # concat first, then pool
-        # print("out:", out.size())  # torch.Size([56, 240, 768])
-
         lm_logits = self.lm_head(out[..., 1:, :])
-        # print("lm_logits:", lm_logits.size())  # torch.Size([56, 239, 51416])
-
         # Shift so that tokens < n predict n
         active_loss = target_ids[..., 2:].ne(1).view(-1)
-        # print("active_loss:", active_loss.size())  # torch.Size([13328])
         shift_logits = lm_logits[..., :-1, :].contiguous()
-        # print("shift_logits:", shift_logits.size())  # torch.Size([56, 238, 51416])
-
         shift_labels = target_ids[..., 2:].contiguous()
-        # print("shift_labels:", shift_labels.size())  # torch.Size([56, 238])

         exist_labels = exist.contiguous()
-        # print("exist_labels:", exist_labels.size())  # torch.Size([56, 1])
-
-        # print("shift_logits.size:", shift_logits.size(-1))  # 51416
-        # print("shift_logits.view(-1, shift_logits.size(-1)):", shift_logits.view(-1, shift_logits.size(-1))[active_loss].size())  # torch.Size([614, 51416])
-        # print("shift_labels.view(-1):", shift_labels.view(-1)[active_loss].size())  # torch.Size([614])
-
         pred_out = out[..., 0, :]
-        # print("pred_out:", pred_out.size())  # torch.Size([56, 768])
         pred_sigmoid = self.sigmoid(self.pred_dense(pred_out))
-        # print("pred_sigmoid:", pred_sigmoid.size())  # torch.Size([56, 1])
-
         # Flatten the tokens
         loss_fct_code = nn.CrossEntropyLoss(ignore_index=-1)
         loss_fct_pred = nn.MSELoss(reduction="mean")
@@ -103,8 +82,6 @@ class Seq2Seq(nn.Module):

         loss_pred = loss_fct_pred(pred_sigmoid, exist_labels)
         loss = loss_pred * self.mse_loss_weight + loss_code * self.ce_loss_weight
-        # loss = loss.to(torch.float32)
-        # loss = loss_pred

         outputs = loss, loss*active_loss.sum(), active_loss.sum(), loss_pred, loss_code
         return outputs
@@ -135,10 +112,7 @@ class Seq2Seq(nn.Module):
         mask = mask & ids[:, None, :].ne(1)
         out = self.decoder(input_ids, attention_mask=mask,
                            past_key_values=context).last_hidden_state
-        # print("out:", out.size())
-        # pool the concatenated out
         hidden_states = out[:, -1, :]
-        # print("hidden_states:", hidden_states.size())
         if out.size(1) == 1:
             pred_sigmoid = self.sigmoid(self.pred_dense(
                 hidden_states.view(-1, 1, hidden_states.size(-1))))
@@ -155,14 +129,9 @@ class Seq2Seq(nn.Module):
         pred = [torch.cat([x.view(-1) for x in p] + [zero] *
                 (self.max_length-len(p))).view(1, -1) for p in pred]
         predicates.append(predicate[0][0])# ZM modified
-        # print("ZM-Model_Debug_P_Each_Itr: %d, %d, %d" % (len(predicate), len(predicate[0]), len(predicate[0][0])))
         preds.append(torch.cat(pred, 0).unsqueeze(0))
-        # print("ZM-Model_Debug_Predicate_Shape: %d" % (len(predicates)))
-        # print("ZM-Model_Debug_Codes_BeforeCat: %d, %d, %d, %d" % (len(preds), len(preds[0]), len(preds[0][0]), len(preds[0][0][0])))
         preds = torch.cat(preds, 0)
         predicates = torch.tensor(predicates, device="cuda")# ZM modified
-        # predicates = torch.cat(predicates, 0).unsqueeze(0)
-        # print("ZM-Model_Debug_Codes_AfterCat: %d, %d, %d" % (len(preds), len(preds[0]), len(preds[0][0])))
         return preds, predicates
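Note: the substance of `Seq2Seq.forward` is unchanged; this commit only strips debug prints and dead code around the model's two-part loss. For reference, a minimal, self-contained sketch of that weighted loss. The 0.9/0.1 weights match `run_fine_tuning.sh`; the tensor shapes and stand-in tensors are illustrative assumptions, not taken from the repo:

```python
import torch
import torch.nn as nn

# Sketch of the loss kept by this commit: token-level cross-entropy on the
# decoder logits plus MSE on the sigmoid "exist" head, combined with the
# same weighting as model_gen.py. Shapes below are illustrative only.
batch, seq_len, vocab = 4, 16, 51416
mse_loss_weight, ce_loss_weight = 0.9, 0.1

shift_logits = torch.randn(batch, seq_len, vocab)         # stand-in for lm_head output
shift_labels = torch.randint(0, vocab, (batch, seq_len))  # stand-in for target_ids[..., 2:]
pred_sigmoid = torch.rand(batch, 1)                       # stand-in for the "exist" head
exist_labels = torch.randint(0, 2, (batch, 1)).float()

loss_code = nn.CrossEntropyLoss(ignore_index=-1)(
    shift_logits.view(-1, vocab), shift_labels.view(-1))
loss_pred = nn.MSELoss(reduction="mean")(pred_sigmoid, exist_labels)

# Same combination as model_gen.py line 84 after this commit.
loss = loss_pred * mse_loss_weight + loss_code * ce_loss_weight
```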
Scripts/UnixCoder/run_one_model.py
CHANGED
@@ -53,7 +53,6 @@ class Example(object):
                  vec,
                  exist,
                  module
-                 # propertyposition,
                  ):
         self.idx = idx
         self.source = source
@@ -77,8 +76,6 @@ def read_examples_no_bracket(filename, is_function_test):
                 break
             line = line.strip()
             js = json.loads(line)
-            if idx > 1000:
-                break
             if js["Stmt"].strip()[0] == "}":
                 continue
             if js["Value"].strip().lower() == "nothing" and '#' in js['FIR']:
@@ -119,11 +116,6 @@ def read_examples_no_bracket(filename, is_function_test):
             mod = ""
             if "Module" in js.keys():
                 mod = js["Module"]
-            # propos = ' '.join(js['pp'])
-            # propos = ' '.join(propos.strip().split(','))
-            # print(code)
-            # print(nl)
-            # print(pro)
             examples.append(
                 Example(
                     idx=idx,
@@ -152,8 +144,6 @@ def read_examples(filename, is_function_test):
                 break
             line = line.strip()
             js = json.loads(line)
-            if idx > 3000:
-                break
             if 'idx' not in js:
                 js['idx'] = idx
             code = ' '.join(js['FIR_token']).replace('\n', ' ')
@@ -188,11 +178,6 @@ def read_examples(filename, is_function_test):
             mod = ""
             if "Module" in js.keys():
                 mod = js["Module"]
-            # propos = ' '.join(js['pp'])
-            # propos = ' '.join(propos.strip().split(','))
-            # print(code)
-            # print(nl)
-            # print(pro)
             examples.append(
                 Example(
                     idx=idx,
@@ -233,7 +218,7 @@ def convert_examples_to_features(examples, tokenizer, args, stage=None):
         # source
         func_tokens = tokenizer.tokenize(example.funcname)
         source_tokens = tokenizer.tokenize(
-                example.source)
+            example.source)
         pro_tokens = tokenizer.tokenize(example.property)
         vec_tokens = example.vec
         source_tokens = [tokenizer.cls_token, "<encoder-decoder>", tokenizer.sep_token, "<mask0>"] + func_tokens + [tokenizer.cls_token] + \
@@ -243,8 +228,6 @@ def convert_examples_to_features(examples, tokenizer, args, stage=None):
         padding_length = args.max_source_length - len(source_ids)
         source_ids += [tokenizer.pad_token_id] * padding_length

-        # target
-        # if stage=="test":
         target_tokens = tokenizer.tokenize(example.target)
         exist = [example.exist]
         target_tokens = [tokenizer.cls_token, "<mask0>"] + \
@@ -252,13 +235,6 @@ def convert_examples_to_features(examples, tokenizer, args, stage=None):
         target_ids = tokenizer.convert_tokens_to_ids(target_tokens)
         padding_length = args.max_target_length - len(target_ids)
         target_ids += [tokenizer.pad_token_id] * padding_length
-        # else:
-        #     target_tokens = tokenizer.tokenize(example.target)
-        #     exist_tokens = tokenizer.tokenize(example.exist)
-        #     target_tokens = ["<mask0>"] + exist_tokens + [tokenizer.cls_token] + target_tokens + [tokenizer.sep_token]
-        #     target_ids = tokenizer.convert_tokens_to_ids(target_tokens)
-        #     padding_length = args.max_target_length - len(target_ids)
-        #     target_ids += [tokenizer.pad_token_id] * padding_length

         features.append(
             InputFeatures(
@@ -470,14 +446,7 @@ def vega_train_main():
         total_eval_all = len(eval_examples_all)
         patience, best_acc, losses, dev_dataset = 0, 0, [], {}
         for epoch in tqdm(range(args.num_train_epochs)):
-            # print(args.num_train_epochs)
-
             for idx, batch in enumerate(train_dataloader):
-                # print("##########Debug################")
-                # print(idx)
-                # print("###############Debug###########")
-                # if idx > 100:
-                #     break
                 batch = tuple(t.to(device) for t in batch)
                 source_ids, exist, target_ids = batch
                 loss, _, _, mse_loss, ce_loss = model(
@@ -572,9 +541,7 @@ def vega_train_main():
             # convert ids to text
             for pred, predicate in zip(preds, predicates):
                 t = pred[0].cpu().numpy()
-                #p = predicate[0].cpu().numpy()
                 p = predicate.float().item()
-                #print("ZM_Debug -- ppp: " + str(p))
                 t = list(t)
                 #p = list(p)
                 tem_i = 0
@@ -608,7 +575,6 @@ def vega_train_main():
                 cnt_iteration += 1
                 pred = ref[0].strip()
                 predicate = ref[1]
-                #print("ZM_Debug -- predicate: " + str(predicate))
                 if gold.property.strip().lower() != "nothing":
                     predicate = 1.0
                 else:
@@ -626,7 +592,6 @@ def vega_train_main():


                 if pred == gt_pred and int(round(predicate)) == int(round(gt_predicate)):
-                    #print("Total correct, Inside this place")
                     EM = EM + 1.0
                     EM_V = EM_V + 1.0
                     EM_P = EM_P + 1.0
@@ -646,43 +611,16 @@ def vega_train_main():

                 model_predicate.append(predicate)
                 groundtruth_predicate.append(gt_predicate)
-                # if len(pred.split(tokenizer.cls_token)) >= 2:
-                #     if pred.split(tokenizer.cls_token)[0].strip() == gt_pred.split(tokenizer.cls_token)[0].strip():
-                #         EM_P += 1
-                #     if pred.split(tokenizer.cls_token)[1].strip() == gt_pred.split(tokenizer.cls_token)[1].strip():
-                #         EM_V += 1
-            # MAE_P = mean_absolute_error(
-            #     np.array(model_predicate), np.array(groundtruth_predicate))
-            # MSE_P = mean_squared_error(
-            #     np.array(model_predicate), np.array(groundtruth_predicate))
-            # RMSE_P = np.sqrt(MSE_P)
             dev_acc = round((100*EM/total), 2)
             dev_acc_v = round((100*EM_V/total), 2)
             dev_acc_p = round((100*EM_P/total), 2)
             logger.info(" %s = %s " % ("Current Acc", str(dev_acc)))
-            #logger.info(" %s = %s "%("Current Edit sim",str(round(edit_sim/total, 2))))
             logger.info(" "+"*"*20)
             logger.info(" %s = %s " % ("Current Acc V", str(dev_acc_v)))
-            #logger.info(" %s = %s "%("Current Edit sim",str(round(edit_sim/total, 2))))
             logger.info(" "+"*"*20)
             logger.info(" %s = %s " % ("Current Acc P", str(dev_acc_p)))
-            #logger.info(" %s = %s "%("Current Edit sim",str(round(edit_sim/total, 2))))
             logger.info(" "+"*"*20)
-            # logger.info(" %s = %s " %
-            #             ("Current MAE P", str(round(MAE_P, 2))))
-            # #logger.info(" %s = %s "%("Current Edit sim",str(round(edit_sim/total, 2))))
-            # logger.info(" "+"*"*20)
-            # logger.info(" %s = %s " %
-            #             ("Current MSE P", str(round(MSE_P, 2))))
-            # #logger.info(" %s = %s "%("Current Edit sim",str(round(edit_sim/total, 2))))
-            # logger.info(" "+"*"*20)
-            # logger.info(" %s = %s " %
-            #             ("Current RMSE P", str(round(RMSE_P, 2))))
-            # #logger.info(" %s = %s "%("Current Edit sim",str(round(edit_sim/total, 2))))
-            # logger.info(" "+"*"*20)
             if dev_acc > best_acc:
-                #logger.info(" Best acc:%s",dev_acc)
-                #logger.info(" "+"*"*20)
                 best_acc = dev_acc
                 # Save best checkpoint for best bleu
                 output_dir = os.path.join(
@@ -694,15 +632,6 @@ def vega_train_main():
                 output_model_file = os.path.join(
                     output_dir, "pytorch_model.bin")
                 torch.save(model_to_save.state_dict(), output_model_file)
-            # with open(args.output_dir+"/p_valid_wrong.csv", 'w', encoding='utf-8', newline="") as fcsv2:
-            #     writer = csv.writer(fcsv2)
-            #     for wl in p_wrong_list:
-            #         writer.writerow(wl)
-            # with open(args.output_dir+"/v_valid_wrong.csv", 'w', encoding='utf-8', newline="") as fcsv2:
-            #     writer = csv.writer(fcsv2)
-            #     for wl in v_wrong_list:
-            #         writer.writerow(wl)
-            #print("ZM Debug--cnt_err_v: " + str(cnt_v))
             logger.info(" Best acc:%s", best_acc)
             logger.info(" " + "*" * 20)

@@ -753,9 +682,7 @@ def vega_train_main():
         # convert ids to text
         for pred, predicate in zip(preds, predicates):
             t = pred[0].cpu().numpy()
-            #p = predicate[0].cpu().numpy()
             p = predicate.float().item()
-            #print("ZM_Debug -- ppp: " + str(p))
             t = list(t)
             tem_i = 0
             if 0 in t:
@@ -802,7 +729,6 @@ def vega_train_main():
             predicate = 0.0
             if 1 in gold.vec[-97:]:
                 predicate = 1.0
-            #my_cls = tokenizer.decode([tokenizer.cls_token_id],clean_up_tokenization_spaces=False)
             gt_pred = gold.target.strip()
             gt_predicate = gold.exist
             is_re = False
@@ -840,30 +766,13 @@ def vega_train_main():

             if pred == gt_pred:
                 EM_V += 1
-            # else:
-            #     print("TEST Wrong pred:", pred, " gt_pred:", gt_pred)
             if round(predicate) == gt_predicate:
                 EM_P += 1
             model_predicate.append(predicate)
             groundtruth_predicate.append(gt_predicate)
-
-        # MAE_P = mean_absolute_error(
-        #     np.array(model_predicate), np.array(groundtruth_predicate))
-        # MSE_P = mean_squared_error(
-        #     np.array(model_predicate), np.array(groundtruth_predicate))
-        # RMSE_P = np.sqrt(MSE_P)
-
         dev_acc = round((100 * EM / total), 2)
         dev_acc_v = round((100 * EM_V / total), 2)
         dev_acc_p = round((100 * EM_P / total), 2)
-        # logger.info(" %s = %s " % ("Test Acc", str(dev_acc)))
-        # logger.info(" %s = %s " % ("Test Acc V", str(dev_acc_v)))
-        # logger.info(" %s = %s " % ("Test Acc P", str(dev_acc_p)))
-        # logger.info(" %s = %s "%("Test Edit sim",str(round(edit_sim/total, 2))))
-        # logger.info(" %s = %s " % ("Test MAE P", str(round(MAE_P, 2))))
-        # logger.info(" %s = %s " % ("Test MSE P", str(round(MSE_P, 2))))
-        # logger.info(" %s = %s " % ("Test RMSE P", str(round(RMSE_P, 2))))
-        # logger.info(" " + "*" * 20)
         predictions = []


@@ -897,15 +806,6 @@ def vega_train_main():
                 json.dump(dic, f2)
                 f2.write('\n')

-
-        # with open(args.output_dir+"/p_wrong.csv", 'w', encoding='utf-8', newline="") as fcsv2:
-        #     writer = csv.writer(fcsv2)
-        #     for wl in p_wrong_list:
-        #         writer.writerow(wl)
-        # with open(args.output_dir+"/v_wrong.csv", 'w', encoding='utf-8', newline="") as fcsv2:
-        #     writer = csv.writer(fcsv2)
-        #     for wl in v_wrong_list:
-        #         writer.writerow(wl)


 if __name__ == "__main__":
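The only behavioral change in this file is in the two readers: `read_examples_no_bracket` and `read_examples` previously stopped after 1000 and 3000 lines respectively (leftover debug caps) and now consume the whole dataset. A minimal sketch of the resulting loader loop, assuming a hypothetical JSONL file with the `Stmt` field used above:

```python
import json

def read_examples_sketch(filename):
    # Sketch of the reader after this commit: the whole JSONL file is
    # consumed now that the "if idx > 1000: break" debug cap is gone.
    # Field names follow the diff; the input file is hypothetical.
    examples = []
    with open(filename, encoding="utf-8") as f:
        for idx, line in enumerate(f):
            js = json.loads(line.strip())
            if js["Stmt"].strip()[0] == "}":
                continue  # skip statements that are just a closing brace
            examples.append((idx, js))
    return examples
```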
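The retained evaluation logic counts an example as fully correct only when both the generated value string and the rounded predicate match the ground truth; `dev_acc_v` and `dev_acc_p` track the two parts separately. A small sketch of the metric on hypothetical predictions:

```python
# Sketch of the exact-match accuracies the retained code computes.
# The (pred text, pred predicate, gt text, gt predicate) tuples are made up.
results = [
    ("x > 0", 1.0, "x > 0", 1),
    ("y == 1", 0.2, "y == 1", 0),
    ("z < 5", 0.9, "z <= 5", 1),
]

EM = EM_V = EM_P = 0.0
for pred, predicate, gt_pred, gt_predicate in results:
    if pred == gt_pred and int(round(predicate)) == int(round(gt_predicate)):
        EM += 1.0       # joint exact match, as in the training loop
    if pred == gt_pred:
        EM_V += 1.0     # value string matches
    if round(predicate) == gt_predicate:
        EM_P += 1.0     # predicate matches after rounding

total = len(results)
dev_acc = round(100 * EM / total, 2)
dev_acc_v = round(100 * EM_V / total, 2)
dev_acc_p = round(100 * EM_P / total, 2)
print(dev_acc, dev_acc_v, dev_acc_p)  # 66.67 66.67 100.0
```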
run_fine_tuning.sh
CHANGED
@@ -10,6 +10,6 @@ python ./Scripts/UnixCoder/run_one_model.py \
     --train_batch_size 64 \
     --eval_batch_size 48 \
     --learning_rate 6e-5 \
-    --num_train_epochs
+    --num_train_epochs 50 \
     --mse_loss_weight 0.9 \
     --ce_loss_weight 0.1
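The old line passed `--num_train_epochs` with no value and, lacking the trailing backslash, also ended the shell command before `--mse_loss_weight` and `--ce_loss_weight` were reached. A reduced, hypothetical sketch of the argparse failure the fix avoids (the real script's parser is assumed to take this flag as an int):

```python
import argparse

# Hypothetical, reduced stand-in for run_one_model.py's argument parser.
parser = argparse.ArgumentParser()
parser.add_argument("--num_train_epochs", type=int)

print(parser.parse_args(["--num_train_epochs", "50"]))  # OK after the fix
# parser.parse_args(["--num_train_epochs"])  # before: "expected one argument"
```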