FaYo committed on
Commit
0d29d74
·
1 Parent(s): 64a76db
dataset/gen_dataset/gen_dataset.py ADDED
@@ -0,0 +1,375 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ from copy import deepcopy
3
+ import json
4
+ import random
5
+ import re
6
+ from http import HTTPStatus
7
+ from pathlib import Path
8
+
9
+ import dashscope
10
+ import requests
11
+ import yaml
12
+ from tqdm import tqdm
13
+
14
+
15
def set_api_key(api_type, api_yaml_path):
    """Load the API key for the requested provider from a YAML config file.

    For ``"qwen"`` the key is also installed on the ``dashscope`` module so
    later DashScope calls are authenticated.

    Args:
        api_type (str): provider name, either ``"qwen"`` or ``"ernie"``.
        api_yaml_path (str): path to the YAML file holding the API keys
            (``ali_qwen_api_key`` / ``baidu_ernie_api_key``).

    Returns:
        The key read from the YAML file for the selected provider.

    Raises:
        ValueError: if ``api_type`` is neither ``"qwen"`` nor ``"ernie"``.
    """
    # The config is read first, so a missing file is reported before a bad api_type.
    with open(api_yaml_path, "r", encoding="utf-8") as cfg_file:
        api_cfg = yaml.safe_load(cfg_file)

    if api_type not in ("qwen", "ernie"):
        raise ValueError("api_type must be qwen or ernie")

    if api_type == "ernie":
        return api_cfg["baidu_ernie_api_key"]

    qwen_key = api_cfg["ali_qwen_api_key"]
    dashscope.api_key = qwen_key
    return qwen_key
36
+
37
+
38
def call_qwen_message(content_str, model_type=dashscope.Generation.Models.qwen_turbo):
    """Send a prompt to a DashScope generation model and return the reply text.

    The call is retried exactly once if the first attempt raises; an
    exception from the retry propagates to the caller.

    Args:
        content_str (str): prompt text sent to the model.
        model_type: DashScope model identifier (defaults to qwen_turbo).

    Returns:
        str: the generated text, or the literal string ``"Error"`` when the
        service answers with a non-OK status code.
    """
    try:
        reply = dashscope.Generation.call(model_type, prompt=content_str)
    except Exception as err:
        print(f"Maybe connect error , try again : {err}")
        reply = dashscope.Generation.call(model_type, prompt=content_str)

    if reply.status_code != HTTPStatus.OK:
        failure_details = (reply.request_id, reply.status_code, reply.code, reply.message)
        print("Request id: %s, Status code: %s, error code: %s, error message: %s" % failure_details)
        return "Error"

    print("Used token: ", reply.usage)
    return reply.output.text
62
+
63
+
64
def call_ernie_message(content_str, access_token):
    """Send a single-turn user message to the Baidu ERNIE chat endpoint.

    Args:
        content_str (str): the user message to send.
        access_token (str): token appended to the request URL as
            ``access_token``.

    Returns:
        str: the ``result`` field of the JSON reply, or the literal string
        ``"Error"`` when the HTTP status is not 200.
    """
    endpoint = f"https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro?access_token={access_token}"

    request_body = {
        "messages": [
            {"role": "user", "content": content_str},
        ],
        "disable_search": False,
        "enable_citation": False,
    }

    http_reply = requests.request(
        "POST",
        endpoint,
        headers={"Content-Type": "application/json"},
        data=json.dumps(request_body),
    )

    if http_reply.status_code != HTTPStatus.OK:
        return "Error"

    # Decode the response body.
    reply_json = http_reply.json()
    print("Used token: ", reply_json["usage"])
    return reply_json["result"]
91
+
92
+
93
def format_json_from_response(func, content_str, func_args, model_name):
    """Call a model and parse its (possibly fenced) reply as JSON.

    Args:
        func: callable invoked as ``func(content_str, func_args)``; must
            return the model reply as a string.
        content_str (str): prompt forwarded to ``func``.
        func_args: second positional argument forwarded to ``func``.
        model_name (str): when equal to ``"qwen"``, stray double quotes
            inside the first ``"output"`` value are turned into single quotes
            before parsing.

    Returns:
        tuple: ``(parsed_json, cleaned_text)`` where ``cleaned_text`` is the
        exact string handed to ``json.loads``.

    Raises:
        json.JSONDecodeError: if the cleaned text is still not valid JSON.
        IndexError: if the reply contains "```json" without a closing fence.
    """
    text = func(content_str, func_args)

    if "```json" in text:
        text = re.findall(r"```json(.*)```", text, flags=re.DOTALL)[0]

    # Remove / normalise characters that commonly break JSON decoding.
    text = text.replace("\\", "\\\\")
    text = text.replace("\n\n", "\n")
    text = text.replace("”", '"')
    text = text.replace("“", '"')

    if model_name == "qwen":
        # qwen replies may contain raw double quotes inside the generated text;
        # swap them for single quotes within the first "output" value.
        marker = '"output": "'
        marker_pos = text.find(marker)
        brace_pos = text.find("}", marker_pos + 1) if marker_pos != -1 else -1
        if brace_pos != -1:
            value_start = marker_pos + len(marker)
            # The last 10 characters before the closing brace are left untouched
            # so the value's own terminating quote survives; clamp so a short
            # value yields an empty range.
            value_stop = max(value_start, brace_pos - 10)
            text = text[:value_start] + text[value_start:value_stop].replace('"', "'") + text[value_stop:]

    # strict=False tolerates control characters inside JSON strings.
    parsed = json.loads(text, strict=False)

    return parsed, text
125
+
126
+
127
def process_request(func, content_str, func_args, model_name):
    """Request a JSON reply from a model, retrying once on any failure.

    Args:
        func: model-calling function forwarded to
            ``format_json_from_response``.
        content_str (str): prompt text.
        func_args: extra argument for ``func``; per the callers this is the
            model type for qwen and the API key for ernie.
        model_name (str): model family name; also used in the name of the
            error log file (``error-<model_name>.log``).

    Returns:
        dict: the parsed JSON reply, or ``{"Error": "Error"}`` when both
        attempts failed (the second failure is appended to the error log).
    """
    try:
        parsed, _ = format_json_from_response(func, content_str, func_args, model_name)
        return parsed
    except Exception as first_error:
        try:
            # Second (and last) attempt.
            print(f"\n Got error, try again <== {first_error} \n")
            if isinstance(first_error, json.JSONDecodeError):
                print(f"JSONDecodeError doc 1: {str(first_error.doc)} \n")
            parsed, _ = format_json_from_response(func, content_str, func_args, model_name)
            return parsed
        except Exception as second_error:
            is_decode_error = isinstance(second_error, json.JSONDecodeError)

            print(f"\n Got error <== {second_error} \n")
            if is_decode_error:
                print(f"JSONDecodeError doc 2: {str(second_error.doc)} \n")

            # Keep a persistent record so failed generations can be inspected later.
            with open(f"error-{model_name}.log", "a+", encoding="utf-8") as error_log:
                if is_decode_error:
                    error_log.write(f"JSONDecodeError doc: {str(second_error.doc)} \n")
                error_log.write(str(second_error))
                error_log.flush()

            return {"Error": "Error"}
162
+
163
+
164
def _parse_highlights_reply(reply_text):
    """Extract the Python literal from a model reply and parse it safely.

    Args:
        reply_text (str): raw model reply, optionally wrapped in a
            "```python ... ```" fence and optionally prefixed with
            ``name = ``.

    Returns:
        The parsed Python literal (expected to be a dict of lists).

    Raises:
        ValueError, SyntaxError: if the text is not a valid Python literal.
    """
    import ast  # function-scope import: `ast` is not imported at file level

    fenced = re.findall(r"```python(.*)```", reply_text, flags=re.DOTALL)
    code_block = fenced[0] if fenced else reply_text

    # Drop a leading "variable = " so only the literal itself is parsed.
    if " = " in code_block[:20]:
        code_block = code_block.split(" = ", 1)[1]

    # The text comes from an LLM: literal_eval parses data only and cannot
    # execute arbitrary code the way eval() would.
    return ast.literal_eval(code_block.strip())


def gen_product_highlights(dastset_yaml_path, api_yaml_path):
    """Ask qwen for the highlights of every product class and store them in the dataset YAML.

    For each entry under ``product_list`` the model is asked for a Python
    dict of features; the parsed result replaces the entry and the YAML file
    is rewritten in place once all entries were processed.

    Args:
        dastset_yaml_path (str): path of the dataset YAML file (read, then overwritten).
        api_yaml_path (str): path of the YAML file holding the API keys.

    Raises:
        ValueError, SyntaxError: if a model reply cannot be parsed as a
            Python literal; the YAML file is left untouched in that case.
    """
    # Load the dataset description.
    with open(dastset_yaml_path, "r", encoding="utf-8") as f:
        dataset_yaml = yaml.safe_load(f)

    set_api_key("qwen", api_yaml_path)

    system_str = "现在你精通医院里的各种事物,你帮我举例每个科室中五个细分专业治疗方法中每个细分治疗方法的六个优势或者特点,然后用python-dic的形式输出:{类名:[特点1,特点2,...]},去掉1,2的字样,除python字典外的其他都不要输出,不要有任何的警告信息"

    for products in dataset_yaml["product_list"].values():
        for product_class, product in products.items():
            product_str = str(product).replace("'", "")
            print(f"Process: {product_str}")

            # call_qwen_message only accepts a prompt and a model type, so the
            # instruction is prepended to the prompt instead of being passed
            # as an unsupported `system_str` keyword.
            product_highlights = call_qwen_message(
                content_str=f"{system_str}\n{product_str}",
                model_type=dashscope.Generation.Models.qwen_turbo,
            )

            # Only the value of an existing key is replaced, which is safe
            # while iterating over the dict.
            products[product_class] = _parse_highlights_reply(product_highlights)

    # Write the enriched dataset description back.
    with open(f"{dastset_yaml_path}", "w", encoding="utf-8") as f:
        yaml.dump(dataset_yaml, f, allow_unicode=True)
198
+
199
+
200
def _pick_random_subset(candidates, pick_count):
    """Return up to ``pick_count`` items of ``candidates`` in random order without mutating the input."""
    pool = list(candidates)
    # random.shuffle() works in place and returns None, so it cannot be used
    # as an expression; sampling the whole pool yields a shuffled copy.
    return random.sample(pool, min(pick_count, len(pool)))


def gen_dataset(dastset_yaml_path: str, api_yaml_path: str, save_json_root: Path, model_name: str, specific_name=""):
    """Generate conversation training data for every role and product via an LLM API.

    For each role in the dataset YAML (or only ``specific_name`` when given)
    and each product, ``each_product_gen`` conversations are requested from
    the selected model. Results are stored in
    ``<save_json_root>/<model_name>_<role_type>_train.json`` as
    ``{product: [conversation_json, ...]}``. The file is rewritten after
    every product, and products already present in it are skipped, so an
    interrupted run can be resumed.

    Args:
        dastset_yaml_path (str): path of the dataset YAML file.
        api_yaml_path (str): path of the YAML file holding the API keys.
        save_json_root (Path): output directory (created if missing).
        model_name (str): ``"qwen"`` or ``"ernie"``.
        specific_name (str): when non-empty, only this role is generated.

    Raises:
        AssertionError: if ``specific_name`` is not a role in the dataset YAML.
        ValueError: if ``model_name`` is not supported.
    """
    # Make sure the output directory exists.
    save_json_root.mkdir(parents=True, exist_ok=True)

    # Load the dataset description.
    with open(dastset_yaml_path, "r", encoding="utf-8") as f:
        dataset_yaml = yaml.safe_load(f)

    if specific_name != "":
        assert (
            specific_name in dataset_yaml["role_type"]
        ), f"{specific_name} not in dataset_yaml['role_type'] ({dataset_yaml['role_type']}), pls check dataset yaml!"

    # Configure the API key.
    api_key = set_api_key(model_name, api_yaml_path)

    data_gen_setting = dataset_yaml["data_generation_setting"]
    gen_num = data_gen_setting["each_product_gen"]
    each_pick_hightlight = data_gen_setting["each_pick_hightlight"]
    each_pick_question = data_gen_setting["each_pick_question"]

    # qwen: one model type per generation round (currently always qwen_max).
    qwen_model_type = [dashscope.Generation.Models.qwen_max] * gen_num

    for role_type, role_character in dataset_yaml["role_type"].items():

        if specific_name != "" and role_type != specific_name:
            # Only the requested role is generated.
            print(f"specific_name = {specific_name}, skipping for {role_type}")
            continue

        gen_json = dict()

        save_json_path = save_json_root.joinpath(f"{model_name}_{role_type}_train.json")
        bk_json_path = save_json_root.joinpath(f"{model_name}_{role_type}_train.json.bk")

        # Resume from a previous run if its output exists.
        if save_json_path.exists():
            with open(save_json_path, "r", encoding="utf-8") as f:
                gen_json = json.load(f)

            # The main file loaded fine, so a leftover backup is no longer needed.
            if bk_json_path.exists():
                bk_json_path.unlink()

        # Flatten all product names up front so the progress bar knows its total.
        list_product = [
            product_name
            for products in dataset_yaml["product_list"].values()
            for product_name_list in products.values()
            for product_name in product_name_list
        ]

        # Role personality description.
        character = "、".join(role_character)

        pbar = tqdm(total=len(list_product))

        # Walk over every product.
        for products in dataset_yaml["product_list"].values():
            for product_name_list in products.values():
                for product, hightlights in product_name_list.items():
                    pbar.set_description(product)

                    if product in gen_json:
                        # Already generated in a previous run.
                        pbar.update(1)
                        continue

                    gen_json[product] = []

                    # Generate the conversations for this product.
                    for idx in range(gen_num):

                        # Randomly pick up to `each_pick_hightlight` product highlights.
                        hightlights_list = _pick_random_subset(hightlights, each_pick_hightlight)
                        hightlight_str = "、".join(hightlights_list)

                        # Randomly pick up to `each_pick_question` question angles.
                        customer_question_type = _pick_random_subset(
                            dataset_yaml["customer_question_type"], each_pick_question
                        )
                        customer_question_str = "、".join(customer_question_type)

                        # Product information block.
                        product_info_str = dataset_yaml["product_info_struct"][0].replace("{name}", product)
                        product_info_str += dataset_yaml["product_info_struct"][1].replace("{highlights}", hightlight_str)

                        content_str = (
                            data_gen_setting["dataset_gen_prompt"]
                            .replace("{role_type}", role_type)
                            .replace("{character}", character)
                            .replace("{product_info}", product_info_str)
                            .replace("{customer_question}", customer_question_str)
                            .replace("{each_conversation_qa}", str(data_gen_setting["each_conversation_qa"]))
                            .replace(
                                "{dataset_json_format}",
                                data_gen_setting["dataset_json_format"].replace("{product_info}", product_info_str),
                            )
                        )

                        print(f"\n Resquest [ {model_name} ] {idx + 1}/{gen_num} ==> {content_str} \n")
                        if model_name == "qwen":
                            format_json = process_request(call_qwen_message, content_str, qwen_model_type[idx], model_name)
                        elif model_name == "ernie":
                            format_json = process_request(call_ernie_message, content_str, api_key, model_name)
                        else:
                            raise ValueError(f"model_name {model_name} not support")

                        if "conversation" in format_json and len(format_json["conversation"]) > 0:

                            # To save tokens the model does not echo system/input for the
                            # first turn, so they are restored here.
                            conversation_setting = deepcopy(dataset_yaml["conversation_setting"])
                            system_str = (
                                conversation_setting["system"].replace("{role_type}", role_type).replace("{character}", character)
                            )
                            input_str = conversation_setting["first_input"].replace("{product_info}", product_info_str)

                            # Rebuild the first turn with the required fields.
                            format_json["conversation"][0] = {
                                "system": system_str,
                                "input": input_str,
                                "output": format_json["conversation"][0]["output"],
                            }
                        else:
                            format_json = {"Error": "Error"}

                        print(f"\n Response [ {model_name} ] {idx + 1}/{gen_num} <== {format_json} \n")
                        gen_json[product].append(format_json)

                    pbar.update(1)

                    # Keep the previous file as a backup while the new one is written.
                    if save_json_path.exists():
                        save_json_path.rename(bk_json_path)

                    # Save the accumulated results.
                    with open(save_json_path, "w", encoding="utf-8") as f:
                        json.dump(gen_json, f, indent=4, ensure_ascii=False)

                    # The save succeeded, so the backup can go.
                    if bk_json_path.exists():
                        bk_json_path.unlink()

        pbar.close()
352
+
353
+
354
if __name__ == "__main__":

    # Example: generate data for all roles with the Qwen API
    #   cd /path/to/Streamer-Sales/dataset/gen_dataset
    #   python gen_dataset.py qwen

    # Command-line arguments.
    cli = argparse.ArgumentParser(description="Gen Dataset")
    cli.add_argument("model_name", type=str, choices=["qwen", "ernie"], help="Model name for data generation")

    # (flag, default, help) for every optional string argument.
    optional_flags = (
        ("--data_yaml", "../../configs/conversation_cfg.yaml", "data setting file path"),
        ("--api_yaml", "../../configs/api_cfg.yaml", "api setting file path"),
        ("--output_dir", "./train_dataset/response", "generation json output dir"),
        ("--specific_name", "", "Character name for data generation"),
    )
    for flag, default_value, help_text in optional_flags:
        cli.add_argument(flag, type=str, default=default_value, help=help_text)

    options = cli.parse_args()

    # Optional step: generate the product highlights first.
    # gen_product_highlights(options.data_yaml, options.api_yaml)

    # Generate the conversation dataset.
    output_dir = Path(options.output_dir)
    gen_dataset(
        options.data_yaml,
        options.api_yaml,
        output_dir,
        model_name=options.model_name,
        specific_name=options.specific_name,
    )
dataset/gen_dataset/merge_dataset.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import json
3
+ from pathlib import Path
4
+ import random
5
+
6
+
7
def gen_self_self_aware_dataset():
    """Build the self-introduction ("who are you") training samples.

    Each canned answer is paired with a randomly chosen greeting/identity
    question.

    Returns:
        list: one ``{"conversation": [{"input": ..., "output": ...}]}`` dict
        per canned answer, in the order the answers are listed.
    """
    # Questions a user may ask about the assistant's identity.
    identity_questions = [
        "你好",
        "你是谁",
        "你叫什么名字",
        "请做一下自我介绍",
        "介绍下你自己",
    ]

    # Canned self-introductions of the assistant.
    identity_answers = [
        "您好,我是智能医导,随时准备解答您的医疗疑问。",
        "您好,我是智能医导,助您轻松就医。",
        "您好,我是智能医导,提供专业医疗指导。",
        "您好,我是智能医导,解答您的健康疑惑。",
        "您好,我是智能医导,帮助您了解医疗服务。",
        "您好,我是智能医导,您的医疗问题助手。",
        "您好,我是智能医导,助您快速获取医疗信息。",
        "您好,我是智能医导,为您提供医疗解答。",
        "您好,我是智能医导,帮助您理解医疗流程。",
        "您好,我是智能医导,解答您的医疗咨询。",
        "您好,我是智能医导,助您掌握健康知识。",
        "您好,我是智能医导,提供医疗信息查询。",
        "您好,我是智能医导,助您解决就医难题。",
        "您好,我是智能医导,您的私人医疗顾问。",
        "您好,我是智能医导,随时为您提供帮助。",
    ]

    return [
        {"conversation": [{"input": random.choice(identity_questions), "output": answer}]}
        for answer in identity_answers
    ]
43
+
44
+
45
def merge_dataset(save_json_root: Path, final_save_json_path: Path):
    """Merge all generated response JSON files into one cleaned training file.

    Every ``*.json`` file in ``save_json_root`` is loaded, error entries and
    turns lacking ``input`` or ``output`` are filtered out, the system prompt
    of each conversation's first turn is overwritten with a fixed prompt, and
    the self-introduction samples are appended.

    The result is written next to ``final_save_json_path`` as
    ``<count>_<name>``; rejected entries, if any, go to ``error_<name>`` in
    the same directory.

    Args:
        save_json_root (Path): directory containing the generated JSON files.
        final_save_json_path (Path): target path; its parent and file name
            are used to build the output file names.
    """
    # Load every generated file.
    loaded_files = []
    for json_file in save_json_root.glob("*.json"):
        with open(json_file, "r", encoding="utf-8") as f:
            loaded_files.append(json.load(f))

    accepted_keys = ("input", "output", "system")
    clean_conversations = []
    rejected_items = []

    for per_file_data in loaded_files:
        for product_name, generated_items in per_file_data.items():
            for generated in generated_items:
                if isinstance(generated, dict) and "Error" in generated:
                    print(f"Got error data in {product_name}")
                    rejected_items.append(generated)
                    continue

                kept_turns = []
                for turn in generated["conversation"]:
                    # Drop keys that do not belong in the training format.
                    turn = {key: value for key, value in turn.items() if key in accepted_keys}

                    # A usable turn needs both sides of the exchange; this also
                    # covers turns left with fewer than two keys.
                    if not ("input" in turn and "output" in turn):
                        rejected_items.append(turn)
                        continue

                    kept_turns.append(turn)

                if kept_turns:
                    clean_conversations.append({"conversation": kept_turns})

    # Normalise the system prompt on the first turn of every conversation.
    for conversation in clean_conversations:
        conversation["conversation"][0][
            "system"
        ] = "现在你是一位医院大厅里的智能医导小助手,你的名字叫智能医导小助手,你的说话方式是严肃端庄。你能够根据病人的需求提供专业的医疗咨询,并且结合医疗知识解答用户提出的各种健康相关疑问。"

    # Append the self-introduction samples (these keep no system prompt).
    clean_conversations += gen_self_self_aware_dataset()

    output_dir = final_save_json_path.parent

    # Save the merged dataset.
    merged_path = output_dir.joinpath(f"{len(clean_conversations)}_{final_save_json_path.name}")
    with open(merged_path, "w", encoding="utf-8") as f:
        json.dump(clean_conversations, f, ensure_ascii=False, indent=4)

    if rejected_items:
        # Keep the rejected data so the user can inspect and fix it manually.
        with open(output_dir.joinpath(f"error_{final_save_json_path.name}"), "w", encoding="utf-8") as f:
            json.dump(rejected_items, f, ensure_ascii=False, indent=4)

    total_turns = sum(len(conversation["conversation"]) for conversation in clean_conversations)
    print(
        f"总生成有效 conversion 数据 {len(clean_conversations)} 组,内含 {total_turns} 条对话,剔除脏对话 {len(rejected_items)} 条,保存到 error_{final_save_json_path.name} 中。"
    )
113
+
114
+
115
if __name__ == "__main__":
    # Command-line arguments.
    # TODO: currently only the 乐乐喵 role is supported
    parser = argparse.ArgumentParser(description="Merge Dataset")
    parser.add_argument("data_root", type=str, help="path to response dir")
    # The help text used to be a copy of the data_root help; describe what the
    # argument actually controls.
    parser.add_argument(
        "output_path",
        type=str,
        help="output json path; the merged file is written into its directory as <count>_<file name>",
    )
    args = parser.parse_args()

    save_json_root = Path(args.data_root)
    final_save_json_path = Path(args.output_path)
    merge_dataset(save_json_root, final_save_json_path)
dataset/gen_dataset/train_dataset/90_train.jsonl ADDED
The diff for this file is too large to render. See raw diff
 
dataset/gen_dataset/train_dataset/response/qwen_智能医导小助手_train.json ADDED
The diff for this file is too large to render. See raw diff