unknown committed on
Commit ac3312e · 1 Parent(s): 3e0611d
Scripts/UnixCoder/model_gen.py CHANGED
@@ -56,9 +56,6 @@ class Seq2Seq(nn.Module):
56
  mask = source_ids.ne(1)[:, None, :]*source_ids.ne(1)[:, :, None]
57
  encoder_output = self.encoder(
58
  source_ids, attention_mask=mask, use_cache=True)
59
- # print("source_ids:", source_ids.size()) # torch.Size([56, 510])
60
- # print("exist:", exist.size()) # torch.Size([56, 1])
61
- # print("target_ids:", target_ids.size()) # torch.Size([56, 240])
62
  ids = torch.cat((source_ids, target_ids), -1)
63
 
64
  mask = self.bias[:,
@@ -68,33 +65,15 @@ class Seq2Seq(nn.Module):
68
  out = self.decoder(target_ids, attention_mask=mask,
69
  past_key_values=encoder_output.past_key_values).last_hidden_state
70
 
71
- # concat first, then pool
72
- # print("out:", out.size()) # torch.Size([56, 240, 768])
73
-
74
  lm_logits = self.lm_head(out[..., 1:, :])
75
- # print("lm_logits:", lm_logits.size()) # torch.Size([56, 239, 51416])
76
-
77
  # Shift so that tokens < n predict n
78
  active_loss = target_ids[..., 2:].ne(1).view(-1)
79
- # print("active_loss:", active_loss.size()) # torch.Size([13328])
80
  shift_logits = lm_logits[..., :-1, :].contiguous()
81
- # print("shift_logits:", shift_logits.size()) # torch.Size([56, 238, 51416])
82
-
83
  shift_labels = target_ids[..., 2:].contiguous()
84
- # print("shift_labels:", shift_labels.size()) # torch.Size([56, 238])
85
 
86
  exist_labels = exist.contiguous()
87
- # print("exist_labels:", exist_labels.size()) # torch.Size([56, 1])
88
-
89
- # print("shift_logits.size:", shift_logits.size(-1)) # 51416
90
- # print("shift_logits.view(-1, shift_logits.size(-1)):", shift_logits.view(-1, shift_logits.size(-1))[active_loss].size()) # torch.Size([614, 51416])
91
- # print("shift_labels.view(-1):", shift_labels.view(-1)[active_loss].size()) # torch.Size([614])
92
-
93
  pred_out = out[..., 0, :]
94
- # print("pred_out:", pred_out.size()) # torch.Size([56, 768])
95
  pred_sigmoid = self.sigmoid(self.pred_dense(pred_out))
96
- # print("pred_sigmoid:", pred_sigmoid.size()) # torch.Size([56, 1])
97
-
98
  # Flatten the tokens
99
  loss_fct_code = nn.CrossEntropyLoss(ignore_index=-1)
100
  loss_fct_pred = nn.MSELoss(reduction="mean")
@@ -103,8 +82,6 @@ class Seq2Seq(nn.Module):
103
 
104
  loss_pred = loss_fct_pred(pred_sigmoid, exist_labels)
105
  loss = loss_pred * self.mse_loss_weight + loss_code * self.ce_loss_weight
106
- # loss = loss.to(torch.float32)
107
- # loss = loss_pred
108
 
109
  outputs = loss, loss*active_loss.sum(), active_loss.sum(), loss_pred, loss_code
110
  return outputs
@@ -135,10 +112,7 @@ class Seq2Seq(nn.Module):
135
  mask = mask & ids[:, None, :].ne(1)
136
  out = self.decoder(input_ids, attention_mask=mask,
137
  past_key_values=context).last_hidden_state
138
- # print("out:", out.size())
139
- # concat and pool out
140
  hidden_states = out[:, -1, :]
141
- # print("hidden_states:", hidden_states.size())
142
  if out.size(1) == 1:
143
  pred_sigmoid = self.sigmoid(self.pred_dense(
144
  hidden_states.view(-1, 1, hidden_states.size(-1))))
@@ -155,14 +129,9 @@ class Seq2Seq(nn.Module):
155
  pred = [torch.cat([x.view(-1) for x in p] + [zero] *
156
  (self.max_length-len(p))).view(1, -1) for p in pred]
157
  predicates.append(predicate[0][0])# ZM modified
158
- #print("ZM-Model_Debug_P_Each_Itr: %d, %d, %d" % (len(predicate), len(predicate[0]), len(predicate[0][0])))
159
  preds.append(torch.cat(pred, 0).unsqueeze(0))
160
- #print("ZM-Model_Debug_Predicate_Shape: %d" % (len(predicates)))
161
- #print("ZM-Model_Debug_Codes_BeforeCat: %d, %d, %d, %d" % (len(preds), len(preds[0]), len(preds[0][0]), len(preds[0][0][0])))
162
  preds = torch.cat(preds, 0)
163
  predicates = torch.tensor(predicates, device="cuda")# ZM modified
164
- # predicates = torch.cat(predicates, 0).unsqueeze(0)
165
- #print("ZM-Model_Debug_Codes_AfterCat: %d, %d, %d" % (len(preds), len(preds[0]), len(preds[0][0])))
166
  return preds, predicates
167
 
168
 
 
Scripts/UnixCoder/run_one_model.py CHANGED
@@ -53,7 +53,6 @@ class Example(object):
53
  vec,
54
  exist,
55
  module
56
- # propertyposition,
57
  ):
58
  self.idx = idx
59
  self.source = source
@@ -77,8 +76,6 @@ def read_examples_no_bracket(filename, is_function_test):
77
  break
78
  line = line.strip()
79
  js = json.loads(line)
80
- if idx > 1000:
81
- break
82
  if js["Stmt"].strip()[0] == "}":
83
  continue
84
  if js["Value"].strip().lower() == "nothing" and '#' in js['FIR']:
@@ -119,11 +116,6 @@ def read_examples_no_bracket(filename, is_function_test):
119
  mod = ""
120
  if "Module" in js.keys():
121
  mod = js["Module"]
122
- # propos = ' '.join(js['pp'])
123
- # propos = ' '.join(propos.strip().split(','))
124
- # print(code)
125
- # print(nl)
126
- # print(pro)
127
  examples.append(
128
  Example(
129
  idx=idx,
@@ -152,8 +144,6 @@ def read_examples(filename, is_function_test):
152
  break
153
  line = line.strip()
154
  js = json.loads(line)
155
- if idx > 3000:
156
- break
157
  if 'idx' not in js:
158
  js['idx'] = idx
159
  code = ' '.join(js['FIR_token']).replace('\n', ' ')
@@ -188,11 +178,6 @@ def read_examples(filename, is_function_test):
188
  mod = ""
189
  if "Module" in js.keys():
190
  mod = js["Module"]
191
- # propos = ' '.join(js['pp'])
192
- # propos = ' '.join(propos.strip().split(','))
193
- # print(code)
194
- # print(nl)
195
- # print(pro)
196
  examples.append(
197
  Example(
198
  idx=idx,
@@ -233,7 +218,7 @@ def convert_examples_to_features(examples, tokenizer, args, stage=None):
233
  # source
234
  func_tokens = tokenizer.tokenize(example.funcname)
235
  source_tokens = tokenizer.tokenize(
236
- example.source) # [:args.max_source_length-5]
+ example.source)
237
  pro_tokens = tokenizer.tokenize(example.property)
238
  vec_tokens = example.vec
239
  source_tokens = [tokenizer.cls_token, "<encoder-decoder>", tokenizer.sep_token, "<mask0>"] + func_tokens + [tokenizer.cls_token] + \
@@ -243,8 +228,6 @@ def convert_examples_to_features(examples, tokenizer, args, stage=None):
243
  padding_length = args.max_source_length - len(source_ids)
244
  source_ids += [tokenizer.pad_token_id] * padding_length
245
 
246
- # target
247
- # if stage=="test":
248
  target_tokens = tokenizer.tokenize(example.target)
249
  exist = [example.exist]
250
  target_tokens = [tokenizer.cls_token, "<mask0>"] + \
@@ -252,13 +235,6 @@ def convert_examples_to_features(examples, tokenizer, args, stage=None):
252
  target_ids = tokenizer.convert_tokens_to_ids(target_tokens)
253
  padding_length = args.max_target_length - len(target_ids)
254
  target_ids += [tokenizer.pad_token_id] * padding_length
255
- # else:
256
- # target_tokens = tokenizer.tokenize(example.target)
257
- # exist_tokens = tokenizer.tokenize(example.exist)
258
- # target_tokens = ["<mask0>"] + exist_tokens + [tokenizer.cls_token] + target_tokens + [tokenizer.sep_token]
259
- # target_ids = tokenizer.convert_tokens_to_ids(target_tokens)
260
- # padding_length = args.max_target_length - len(target_ids)
261
- # target_ids += [tokenizer.pad_token_id] * padding_length
262
 
263
  features.append(
264
  InputFeatures(
@@ -470,14 +446,7 @@ def vega_train_main():
470
  total_eval_all = len(eval_examples_all)
471
  patience, best_acc, losses, dev_dataset = 0, 0, [], {}
472
  for epoch in tqdm(range(args.num_train_epochs)):
473
- # print(args.num_train_epochs)
474
-
475
  for idx, batch in enumerate(train_dataloader):
476
- # print("##########Debug################")
477
- # print(idx)
478
- # print("###############Debug###########")
479
- # if idx > 100:
480
- # break
481
  batch = tuple(t.to(device) for t in batch)
482
  source_ids, exist, target_ids = batch
483
  loss, _, _, mse_loss, ce_loss = model(
@@ -572,9 +541,7 @@ def vega_train_main():
572
  # convert ids to text
573
  for pred, predicate in zip(preds, predicates):
574
  t = pred[0].cpu().numpy()
575
- #p = predicate[0].cpu().numpy()
576
  p = predicate.float().item()
577
- #print("ZM_Debug -- ppp: " + str(p))
578
  t = list(t)
579
  #p = list(p)
580
  tem_i = 0
@@ -608,7 +575,6 @@ def vega_train_main():
608
  cnt_iteration += 1
609
  pred = ref[0].strip()
610
  predicate = ref[1]
611
- #print("ZM_Debug -- predicate: " + str(predicate))
612
  if gold.property.strip().lower() != "nothing":
613
  predicate = 1.0
614
  else:
@@ -626,7 +592,6 @@ def vega_train_main():
626
 
627
 
628
  if pred == gt_pred and int(round(predicate)) == int(round(gt_predicate)):
629
- #print("Total correct, Inside this place")
630
  EM = EM + 1.0
631
  EM_V = EM_V + 1.0
632
  EM_P = EM_P + 1.0
@@ -646,43 +611,16 @@ def vega_train_main():
646
 
647
  model_predicate.append(predicate)
648
  groundtruth_predicate.append(gt_predicate)
649
- # if len(pred.split(tokenizer.cls_token)) >= 2:
650
- # if pred.split(tokenizer.cls_token)[0].strip() == gt_pred.split(tokenizer.cls_token)[0].strip():
651
- # EM_P += 1
652
- # if pred.split(tokenizer.cls_token)[1].strip() == gt_pred.split(tokenizer.cls_token)[1].strip():
653
- # EM_V += 1
654
- # MAE_P = mean_absolute_error(
655
- # np.array(model_predicate), np.array(groundtruth_predicate))
656
- # MSE_P = mean_squared_error(
657
- # np.array(model_predicate), np.array(groundtruth_predicate))
658
- # RMSE_P = np.sqrt(MSE_P)
659
  dev_acc = round((100*EM/total), 2)
660
  dev_acc_v = round((100*EM_V/total), 2)
661
  dev_acc_p = round((100*EM_P/total), 2)
662
  logger.info(" %s = %s " % ("Current Acc", str(dev_acc)))
663
- #logger.info(" %s = %s "%("Current Edit sim",str(round(edit_sim/total, 2))))
664
  logger.info(" "+"*"*20)
665
  logger.info(" %s = %s " % ("Current Acc V", str(dev_acc_v)))
666
- #logger.info(" %s = %s "%("Current Edit sim",str(round(edit_sim/total, 2))))
667
  logger.info(" "+"*"*20)
668
  logger.info(" %s = %s " % ("Current Acc P", str(dev_acc_p)))
669
- #logger.info(" %s = %s "%("Current Edit sim",str(round(edit_sim/total, 2))))
670
  logger.info(" "+"*"*20)
671
- # logger.info(" %s = %s " %
672
- # ("Current MAE P", str(round(MAE_P, 2))))
673
- # #logger.info(" %s = %s "%("Current Edit sim",str(round(edit_sim/total, 2))))
674
- # logger.info(" "+"*"*20)
675
- # logger.info(" %s = %s " %
676
- # ("Current MSE P", str(round(MSE_P, 2))))
677
- # #logger.info(" %s = %s "%("Current Edit sim",str(round(edit_sim/total, 2))))
678
- # logger.info(" "+"*"*20)
679
- # logger.info(" %s = %s " %
680
- # ("Current RMSE P", str(round(RMSE_P, 2))))
681
- # #logger.info(" %s = %s "%("Current Edit sim",str(round(edit_sim/total, 2))))
682
- # logger.info(" "+"*"*20)
683
  if dev_acc > best_acc:
684
- #logger.info(" Best acc:%s",dev_acc)
685
- #logger.info(" "+"*"*20)
686
  best_acc = dev_acc
687
  # Save best checkpoint for best bleu
688
  output_dir = os.path.join(
@@ -694,15 +632,6 @@ def vega_train_main():
694
  output_model_file = os.path.join(
695
  output_dir, "pytorch_model.bin")
696
  torch.save(model_to_save.state_dict(), output_model_file)
697
- # with open(args.output_dir+"/p_valid_wrong.csv", 'w', encoding='utf-8', newline="") as fcsv2:
698
- # writer = csv.writer(fcsv2)
699
- # for wl in p_wrong_list:
700
- # writer.writerow(wl)
701
- # with open(args.output_dir+"/v_valid_wrong.csv", 'w', encoding='utf-8', newline="") as fcsv2:
702
- # writer = csv.writer(fcsv2)
703
- # for wl in v_wrong_list:
704
- # writer.writerow(wl)
705
- #print("ZM Debug--cnt_err_v: " + str(cnt_v))
706
  logger.info(" Best acc:%s", best_acc)
707
  logger.info(" " + "*" * 20)
708
 
@@ -753,9 +682,7 @@ def vega_train_main():
753
  # convert ids to text
754
  for pred, predicate in zip(preds, predicates):
755
  t = pred[0].cpu().numpy()
756
- #p = predicate[0].cpu().numpy()
757
  p = predicate.float().item()
758
- #print("ZM_Debug -- ppp: " + str(p))
759
  t = list(t)
760
  tem_i = 0
761
  if 0 in t:
@@ -802,7 +729,6 @@ def vega_train_main():
802
  predicate = 0.0
803
  if 1 in gold.vec[-97:]:
804
  predicate = 1.0
805
- #my_cls = tokenizer.decode([tokenizer.cls_token_id],clean_up_tokenization_spaces=False)
806
  gt_pred = gold.target.strip()
807
  gt_predicate = gold.exist
808
  is_re = False
@@ -840,30 +766,13 @@ def vega_train_main():
840
 
841
  if pred == gt_pred:
842
  EM_V += 1
843
- # else:
844
- # print("TEST Wrong pred:", pred, " gt_pred:", gt_pred)
845
  if round(predicate) == gt_predicate:
846
  EM_P += 1
847
  model_predicate.append(predicate)
848
  groundtruth_predicate.append(gt_predicate)
849
-
850
- # MAE_P = mean_absolute_error(
851
- # np.array(model_predicate), np.array(groundtruth_predicate))
852
- # MSE_P = mean_squared_error(
853
- # np.array(model_predicate), np.array(groundtruth_predicate))
854
- # RMSE_P = np.sqrt(MSE_P)
855
-
856
  dev_acc = round((100 * EM / total), 2)
857
  dev_acc_v = round((100 * EM_V / total), 2)
858
  dev_acc_p = round((100 * EM_P / total), 2)
859
- # logger.info(" %s = %s " % ("Test Acc", str(dev_acc)))
860
- # logger.info(" %s = %s " % ("Test Acc V", str(dev_acc_v)))
861
- # logger.info(" %s = %s " % ("Test Acc P", str(dev_acc_p)))
862
- # logger.info(" %s = %s "%("Test Edit sim",str(round(edit_sim/total, 2))))
863
- # logger.info(" %s = %s " % ("Test MAE P", str(round(MAE_P, 2))))
864
- # logger.info(" %s = %s " % ("Test MSE P", str(round(MSE_P, 2))))
865
- # logger.info(" %s = %s " % ("Test RMSE P", str(round(RMSE_P, 2))))
866
- # logger.info(" " + "*" * 20)
867
  predictions = []
868
 
869
 
@@ -897,15 +806,6 @@ def vega_train_main():
897
  json.dump(dic, f2)
898
  f2.write('\n')
899
 
900
-
901
- # with open(args.output_dir+"/p_wrong.csv", 'w', encoding='utf-8', newline="") as fcsv2:
902
- # writer = csv.writer(fcsv2)
903
- # for wl in p_wrong_list:
904
- # writer.writerow(wl)
905
- # with open(args.output_dir+"/v_wrong.csv", 'w', encoding='utf-8', newline="") as fcsv2:
906
- # writer = csv.writer(fcsv2)
907
- # for wl in v_wrong_list:
908
- # writer.writerow(wl)
909
 
910
 
911
  if __name__ == "__main__":
 
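The hunks above also touch convert_examples_to_features; as a rough, self-contained illustration of the source-side id building and padding it performs, here is a sketch with a stub tokenizer. The stub, the 64-token length, and the trailing sep_token (the original concatenation line is cut off in the hunk) are assumptions for illustration only.

# Stub tokenizer so the id-building/padding logic runs standalone;
# the real script uses whatever tokenizer run_one_model.py loads.
class StubTokenizer:
    cls_token = "<s>"
    sep_token = "</s>"
    pad_token_id = 1

    def tokenize(self, text):
        return text.split()

    def convert_tokens_to_ids(self, tokens):
        return [abs(hash(t)) % 50000 for t in tokens]  # stand-in ids only


def build_source_ids(tokenizer, funcname, source, max_source_length=64):
    func_tokens = tokenizer.tokenize(funcname)
    source_tokens = tokenizer.tokenize(source)
    # Layout taken from the hunk above; the tail after the second cls token is
    # truncated there, so ending with source_tokens + sep_token is an assumption.
    tokens = ([tokenizer.cls_token, "<encoder-decoder>", tokenizer.sep_token, "<mask0>"]
              + func_tokens + [tokenizer.cls_token]
              + source_tokens + [tokenizer.sep_token])
    source_ids = tokenizer.convert_tokens_to_ids(tokens)
    # Pad to a fixed length so examples batch cleanly (over-long inputs would need truncation).
    source_ids += [tokenizer.pad_token_id] * (max_source_length - len(source_ids))
    return source_ids


print(len(build_source_ids(StubTokenizer(), "Foo", "a = b + c")))  # -> 64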
run_fine_tuning.sh CHANGED
@@ -10,6 +10,6 @@ python ./Scripts/UnixCoder/run_one_model.py \
10
  --train_batch_size 64 \
11
  --eval_batch_size 48 \
12
  --learning_rate 6e-5 \
13
- --num_train_epochs 3 \
+ --num_train_epochs 50 \
14
  --mse_loss_weight 0.9 \
15
  --ce_loss_weight 0.1
 
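For orientation, a hypothetical sketch of how these flags could be consumed on the Python side; the actual argparse setup in run_one_model.py is not visible in this diff, so the names and defaults below simply mirror the script.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--train_batch_size", type=int, default=64)
parser.add_argument("--eval_batch_size", type=int, default=48)
parser.add_argument("--learning_rate", type=float, default=6e-5)
parser.add_argument("--num_train_epochs", type=int, default=50)
parser.add_argument("--mse_loss_weight", type=float, default=0.9)
parser.add_argument("--ce_loss_weight", type=float, default=0.1)

args = parser.parse_args(["--num_train_epochs", "50"])
print(args.num_train_epochs, args.mse_loss_weight, args.ce_loss_weight)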