BAAI
/

ldwang committed on
Commit
a70a7f0
·
verified ·
1 Parent(s): e1a4125

Upload scorer_pred_local.py

Browse files
Files changed (1) hide show
  1. scorer_pred_local.py +94 -0
scorer_pred_local.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
#!/usr/bin/env python
# -*- coding:utf-8 -*-
"""Score JSON Lines records with a sequence-classification scorer model.

Every record read from the input (a single JSON Lines file, or every regular
file directly inside a directory) is tokenized, passed through the model, and
written to the output file with the first logit stored under ``--output-key``.
Records that lack ``--text-key`` are dropped; with ``--do-score-filter``,
records scoring below ``--score-thres`` are dropped as well.
"""

import argparse
import os
import time

import torch
from transformers import AutoModelForSequenceClassification
from transformers import AutoTokenizer


def _parse_args():
    """Build the command-line parser and return the parsed arguments."""
    parser = argparse.ArgumentParser(
        description="Score JSON Lines records with a sequence-classification model.")
    parser.add_argument('--scorer-model-path', type=str, required=True,
                        help="path of the scorer model directory")
    parser.add_argument('--input-file-path', type=str, required=True,
                        help="input JSON Lines file, or a directory of such files")
    parser.add_argument('--output-file-path', type=str, required=True,
                        help="output JSON Lines file path")
    parser.add_argument('--score-thres', type=float, default=3.0, required=False,
                        help="score threshold used with --do-score-filter")
    parser.add_argument('--text-key', type=str, default="text", required=False,
                        help="record key holding the text to score")
    parser.add_argument('--output-key', type=str, default="score", required=False,
                        help="record key the score is written to")
    parser.add_argument('--max-length', type=int, default=2048, required=False,
                        help="maximum number of tokens per sample (longer inputs are truncated)")
    parser.add_argument('--do-score-filter', action='store_true', default=False,
                        help='do score filter or not', dest='do_score_filter')
    return parser.parse_args()


def _resolve_input_files(input_path, output_path):
    """Return the list of input files denoted by *input_path*.

    A directory expands to its regular files in sorted (deterministic) order;
    sub-directories are skipped and *output_path* is excluded so the script
    never reads the file it is writing. Any other path is returned as-is --
    it is NOT joined with itself, which used to break relative file paths.
    """
    if not os.path.isdir(input_path):
        return [input_path]

    output_abs = os.path.abspath(output_path)
    files = []
    for name in sorted(os.listdir(input_path)):
        candidate = os.path.join(input_path, name)
        if os.path.isfile(candidate) and os.path.abspath(candidate) != output_abs:
            files.append(candidate)
    return files


def _report_progress(lines, start_time, filtered=None):
    """Print throughput statistics; include the filtered count when given."""
    elapsed_time = time.time() - start_time
    # Guard against a zero interval on coarse clocks or empty inputs.
    samples_per_second = lines / elapsed_time if elapsed_time > 0 else 0.0
    if filtered is None:
        print(f"Processed {lines} lines in {elapsed_time:.2f} seconds.", flush=True)
    else:
        print(f"Processed {lines} lines in {elapsed_time:.2f} seconds, Filtered {filtered} samples.", flush=True)
    print(f"Samples per second: {samples_per_second:.2f}.", flush=True)


def main():
    """Load the scorer, score every input record and write the kept ones."""
    args = _parse_args()

    # Imported here (as in the original script) so that merely importing this
    # module does not require the jsonlines package.
    import jsonlines

    # Use the GPU when one is present; otherwise fall back to CPU instead of
    # failing inside model.cuda().
    device = "cuda" if torch.cuda.is_available() else "cpu"

    model_dir = args.scorer_model_path
    model = AutoModelForSequenceClassification.from_pretrained(
        model_dir,
        trust_remote_code=False,
        ignore_mismatched_sizes=False,
    )
    model.to(device)
    model.eval()

    tokenizer = AutoTokenizer.from_pretrained(
        model_dir,
        use_fast=True,
        token=None,
        trust_remote_code=False,
    )

    # Resolve inputs BEFORE creating the output file so that an output file
    # living inside the input directory is not picked up as input.
    input_files = _resolve_input_files(args.input_file_path, args.output_file_path)

    lines = 0
    filtered = 0
    start_time = time.time()

    # The context manager closes (and flushes) the writer even on error;
    # no_grad avoids building an autograd graph for pure inference.
    with jsonlines.open(args.output_file_path, mode='w') as writer, torch.no_grad():
        for input_file in input_files:
            with jsonlines.open(input_file) as reader:
                for line in reader:
                    lines += 1
                    if lines % 1000 == 0:
                        _report_progress(lines, start_time)

                    # Records without the text field cannot be scored.
                    if args.text_key not in line:
                        filtered += 1
                        continue

                    sentence = line[args.text_key]
                    # return_tensors="pt" already yields tensors, so they only
                    # need to be moved to the model's device.
                    encoded = tokenizer(
                        [sentence],
                        padding=False,
                        max_length=args.max_length,
                        truncation=True,
                        return_tensors="pt",
                    ).to(device)

                    model_out = model(**encoded)
                    # Batch of one: the score is the first logit of row 0.
                    score = float(model_out.logits[0][0].item())
                    if args.do_score_filter and score < args.score_thres:
                        filtered += 1
                        continue

                    line[args.output_key] = score
                    writer.write(line)

    _report_progress(lines, start_time, filtered=filtered)


if __name__ == "__main__":
    main()