root commited on
Commit
73386d5
·
1 Parent(s): 3692eb0

add_function_pmids_retrieval_and_ref_ench_sent

Browse files
README.md CHANGED
@@ -14,6 +14,7 @@ pinned: false
14
  - 整一个帮我写综述的Agent,希望他能完成文献内容的收集,文本分类和总结,科学事实对比,撰写综述等功能
15
  - 计划用到RAG, function calling等技术
16
  - 还在不断摸索中,欢迎大佬指导!
 
17
 
18
  ## 流程图
19
  基本上就是在上海AIlab的茴香豆上面改的 这里主要讲解使用流程 架构和茴香豆一样 [茴香豆架构](https://github.com/InternLM/HuixiangDou/blob/main/docs/architecture_zh.md)
@@ -81,7 +82,7 @@ git clone https://github.com/jabberwockyang/MedicalReviewAgent.git
81
  cd MedicalReviewAgent
82
  pip install -r requirements.txt
83
  ```
84
- huggingface-cli下载模型
85
 
86
  ```bash
87
  cd /root && mkdir models
@@ -96,7 +97,8 @@ huggingface-cli download maidalun1020/bce-reranker-base_v1 --local-dir /root/mod
96
  ```bash
97
  conda activate ReviewAgent
98
  cd MedicalReviewAgent
99
- python3 app.py
 
100
  ```
101
  gradio在本地7860端口运行
102
 
@@ -170,7 +172,7 @@ python3 app.py
170
 
171
  ## 感谢
172
  1. [茴香豆](https://github.com/InternLM/HuixiangDou)
173
- 2. [E-utilities](https://eutils.ncbi.nlm.nih.gov/entrez/eutils/efetch.fcgi?db=pmc&id=PMCID)
174
  3. [Ragflow](https://github.com/infiniflow/ragflow/blob/main/README_zh.md)
175
  4. [Advanced RAG pipeline](https://medium.aiplanet.com/evaluating-naive-rag-and-advanced-rag-pipeline-using-langchain-v-0-1-0-and-ragas-17d24e74e5cf)
176
 
 
14
  - 整一个帮我写综述的Agent,希望他能完成文献内容的收集,文本分类和总结,科学事实对比,撰写综述等功能
15
  - 计划用到RAG, function calling等技术
16
  - 还在不断摸索中,欢迎大佬指导!
17
+ - [huggingface 体验链接](https://huggingface.co/spaces/Yijun-Yang/ReadReview/), zeroGPUs 比较吝啬 我把本地推理给阉割了 不要用本地模型哈 用API 用本地模型会报错
18
 
19
  ## 流程图
20
  基本上就是在上海AIlab的茴香豆上面改的 这里主要讲解使用流程 架构和茴香豆一样 [茴香豆架构](https://github.com/InternLM/HuixiangDou/blob/main/docs/architecture_zh.md)
 
82
  cd MedicalReviewAgent
83
  pip install -r requirements.txt
84
  ```
85
+ huggingface-cli下载模型(optional, 第一次调用的时候hf会下载,但是可能有墙)
86
 
87
  ```bash
88
  cd /root && mkdir models
 
97
  ```bash
98
  conda activate ReviewAgent
99
  cd MedicalReviewAgent
100
+ python3 app.py --model_downloaded True # 如果已经在/root/models下载了模型 这个参数会换一个配置文件,里面的modelpath是本地路径不是hf的仓库路径 自己显卡跑跑用这个
101
+ python3 app.py # 如果不打算用本地/root/models储存的模型 这是hf的spaces的构建配置
102
  ```
103
  gradio在本地7860端口运行
104
 
 
172
 
173
  ## 感谢
174
  1. [茴香豆](https://github.com/InternLM/HuixiangDou)
175
+ 2. [E-utilities](https://www.ncbi.nlm.nih.gov/books/NBK25499/)
176
  3. [Ragflow](https://github.com/infiniflow/ragflow/blob/main/README_zh.md)
177
  4. [Advanced RAG pipeline](https://medium.aiplanet.com/evaluating-naive-rag-and-advanced-rag-pipeline-using-langchain-v-0-1-0-and-ragas-17d24e74e5cf)
178
 
app.py CHANGED
@@ -42,6 +42,10 @@ def parse_args():
42
  action='store_true',
43
  default=True,
44
  help='Auto deploy required Hybrid LLM Service.')
 
 
 
 
45
  args = parser.parse_args()
46
  return args
47
 
@@ -55,6 +59,8 @@ def update_remote_buttons(remote):
55
  interactive=True,visible=True),
56
  gr.Textbox(label="您的API",lines = 1,
57
  interactive=True,visible=True),
 
 
58
  gr.Dropdown([],label="选择模型",
59
  interactive=True,visible=True)
60
  ]
@@ -80,7 +86,7 @@ def udate_model_dropdown(remote_company):
80
  }
81
  return gr.Dropdown(choices= model_choices[remote_company])
82
 
83
- def update_remote_config(remote_ornot,remote_company = None,api = None,model = None):
84
  with open(CONFIG_PATH, encoding='utf8') as f:
85
  config = pytoml.load(f)
86
 
@@ -91,6 +97,7 @@ def update_remote_config(remote_ornot,remote_company = None,api = None,model = N
91
  config['llm']['enable_remote'] = 1
92
  config['llm']['server']['remote_type'] = remote_company
93
  config['llm']['server']['remote_api_key'] = api
 
94
  config['llm']['server']['remote_llm_model'] = model
95
  else:
96
  config['llm']['enable_local'] = 1
@@ -188,11 +195,17 @@ def upload_file(files):
188
 
189
  return files
190
 
191
- def generate_articles_repo(keywords:str,retmax:int):
192
- keys= [k.strip() for k in keywords.split('\n')]
 
 
 
 
 
193
  repodir, _, _ = get_ready('repo_work')
194
 
195
  articelfinder = ArticleRetrieval(keywords = keys,
 
196
  repo_dir = repodir,
197
  retmax = retmax)
198
  articelfinder.initiallize()
@@ -212,7 +225,7 @@ def delete_articles_repo():
212
 
213
  def update_repo():
214
  keys,len,retmax,pdflen = update_repo_info()
215
- if keys:
216
  newinfo = f"搜索得到文献:\n 关键词:{keys}\n 文献数量:{len}\n 获取上限:{retmax}\n\n上传文献:\n 数量:{pdflen}"
217
  else:
218
  if pdflen:
@@ -400,8 +413,8 @@ def summarize_text(query,chunksize:int,remote_ornot:bool):
400
 
401
  logger.info(f'{code}, {query}, {reply}, {references}')
402
  urls = getpmcurls(references)
403
- mds = '\n'.join([f'[{ref}]({url})' for ref,url in zip(references,urls)])
404
- return reply, gr.Markdown(label="参考文献",value = mds)
405
 
406
  def main_interface():
407
  with gr.Blocks() as demo:
@@ -436,13 +449,14 @@ def main_interface():
436
  remote_company = gr.Dropdown(["kimi", "deepseek", "zhipuai",'gpt'],
437
  label="选择大模型提供商",interactive=False,visible=False)
438
  api = gr.Textbox(label="您的API",lines = 1,interactive=False,visible=False)
 
439
  model = gr.Dropdown([],label="选择模型",interactive=False,visible=False)
440
 
441
  confirm_button = gr.Button("保存配置")
442
 
443
- remote_ornot.change(update_remote_buttons, inputs=[remote_ornot],outputs=[apimd,remote_company,api,model])
444
  remote_company.change(udate_model_dropdown, inputs=[remote_company],outputs=[model])
445
- confirm_button.click(update_remote_config, inputs=[remote_ornot,remote_company,api,model],outputs=[confirm_button])
446
 
447
 
448
  with gr.Tab("文献查找+数据库生成"):
@@ -478,6 +492,7 @@ def main_interface():
478
  with gr.Row(equal_height=True):
479
  with gr.Column(scale=1):
480
  input_keys = gr.Textbox(label="感兴趣的关键词",
 
481
  lines = 5)
482
  retmax = gr.Slider(
483
  minimum=0,
@@ -593,7 +608,7 @@ def main_interface():
593
  query = gr.Textbox(label="想写什么")
594
 
595
  write_button = gr.Button("写综述")
596
- output_text = gr.Textbox(label="看看",lines=10)
597
  output_references = gr.Markdown(label="参考文献")
598
 
599
  update_options.click(update_chunksize_dropdown,
@@ -620,8 +635,12 @@ def main_interface():
620
  # start service
621
  if __name__ == '__main__':
622
  args = parse_args()
623
- # copy config from config-bak
624
- shutil.copy('config-bak.ini', args.config_path) # yyj
 
 
 
 
625
  CONFIG_PATH = args.config_path
626
 
627
  if args.standalone is True:
 
42
  action='store_true',
43
  default=True,
44
  help='Auto deploy required Hybrid LLM Service.')
45
+ parser.add_argument("--model_downloaded",
46
+ type=bool,
47
+ default=False,
48
+ help="If the model has been downloaded in the root/models folder. Default is False.")
49
  args = parser.parse_args()
50
  return args
51
 
 
59
  interactive=True,visible=True),
60
  gr.Textbox(label="您的API",lines = 1,
61
  interactive=True,visible=True),
62
+ gr.Textbox(label="base url",lines = 1,
63
+ interactive=True,visible=True),
64
  gr.Dropdown([],label="选择模型",
65
  interactive=True,visible=True)
66
  ]
 
86
  }
87
  return gr.Dropdown(choices= model_choices[remote_company])
88
 
89
+ def update_remote_config(remote_ornot,remote_company = None,api = None,baseurl = None, model = None):
90
  with open(CONFIG_PATH, encoding='utf8') as f:
91
  config = pytoml.load(f)
92
 
 
97
  config['llm']['enable_remote'] = 1
98
  config['llm']['server']['remote_type'] = remote_company
99
  config['llm']['server']['remote_api_key'] = api
100
+ config['llm']['server']['remote_base_url'] = baseurl
101
  config['llm']['server']['remote_llm_model'] = model
102
  else:
103
  config['llm']['enable_local'] = 1
 
195
 
196
  return files
197
 
198
+ def generate_articles_repo(strings:str,retmax:int):
199
+
200
+ string = [k.strip() for k in strings.split('\n')]
201
+
202
+ pmids = [k for k in string if k.isdigit()]
203
+ keys = [k for k in string if not k.isdigit()]
204
+
205
  repodir, _, _ = get_ready('repo_work')
206
 
207
  articelfinder = ArticleRetrieval(keywords = keys,
208
+ pmids = pmids,
209
  repo_dir = repodir,
210
  retmax = retmax)
211
  articelfinder.initiallize()
 
225
 
226
  def update_repo():
227
  keys,len,retmax,pdflen = update_repo_info()
228
+ if keys or len:
229
  newinfo = f"搜索得到文献:\n 关键词:{keys}\n 文献数量:{len}\n 获取上限:{retmax}\n\n上传文献:\n 数量:{pdflen}"
230
  else:
231
  if pdflen:
 
413
 
414
  logger.info(f'{code}, {query}, {reply}, {references}')
415
  urls = getpmcurls(references)
416
+ mds = '\n\n'.join([f'[{ref}]({url})' for ref,url in zip(references,urls)])
417
+ return gr.Markdown(label="看看",value = reply,line_breaks=True) , gr.Markdown(label="参考文献",value = mds,line_breaks=True)
418
 
419
  def main_interface():
420
  with gr.Blocks() as demo:
 
449
  remote_company = gr.Dropdown(["kimi", "deepseek", "zhipuai",'gpt'],
450
  label="选择大模型提供商",interactive=False,visible=False)
451
  api = gr.Textbox(label="您的API",lines = 1,interactive=False,visible=False)
452
+ baseurl = gr.Textbox(label="base url",lines = 1,interactive=False,visible=False)
453
  model = gr.Dropdown([],label="选择模型",interactive=False,visible=False)
454
 
455
  confirm_button = gr.Button("保存配置")
456
 
457
+ remote_ornot.change(update_remote_buttons, inputs=[remote_ornot],outputs=[apimd,remote_company,api,baseurl,model])
458
  remote_company.change(udate_model_dropdown, inputs=[remote_company],outputs=[model])
459
+ confirm_button.click(update_remote_config, inputs=[remote_ornot,remote_company,api,baseurl,model],outputs=[confirm_button])
460
 
461
 
462
  with gr.Tab("文献查找+数据库生成"):
 
492
  with gr.Row(equal_height=True):
493
  with gr.Column(scale=1):
494
  input_keys = gr.Textbox(label="感兴趣的关键词",
495
+ value = "输入关键词或者PMID, 换行分隔",
496
  lines = 5)
497
  retmax = gr.Slider(
498
  minimum=0,
 
608
  query = gr.Textbox(label="想写什么")
609
 
610
  write_button = gr.Button("写综述")
611
+ output_text = gr.Markdown(label="看看")
612
  output_references = gr.Markdown(label="参考文献")
613
 
614
  update_options.click(update_chunksize_dropdown,
 
635
  # start service
636
  if __name__ == '__main__':
637
  args = parse_args()
638
+ # copy config from config-bak
639
+ if args.model_downloaded:
640
+ shutil.copy('config-mod_down-bak.ini', args.config_path) # yyj
641
+ else:
642
+ shutil.copy('config-bak.ini', args.config_path) # yyj
643
+
644
  CONFIG_PATH = args.config_path
645
 
646
  if args.standalone is True:
applocal.py ADDED
@@ -0,0 +1,663 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import json
3
+ import time
4
+ import os
5
+ import glob
6
+ import random
7
+ import shutil
8
+ from enum import Enum
9
+ from threading import Thread
10
+ from multiprocessing import Process, Value
11
+
12
+ import gradio as gr
13
+ import pytoml
14
+ from loguru import logger
15
+ # import spaces
16
+
17
+ from huixiangdou.service import Worker, llm_serve, ArticleRetrieval, CacheRetriever, FeatureStore, FileOperation
18
+
19
+ class PARAM_CODE(Enum):
20
+ """Parameter code."""
21
+ SUCCESS = 0
22
+ FAILED = 1
23
+ ERROR = 2
24
+
25
+ def parse_args():
26
+ """Parse args."""
27
+ parser = argparse.ArgumentParser(description='Worker.')
28
+ parser.add_argument('--work_dir',
29
+ type=str,
30
+ default='workdir',
31
+ help='Working directory.')
32
+ parser.add_argument('--repo_dir',
33
+ type=str,
34
+ default='repodir',
35
+ help='Repository directory.')
36
+ parser.add_argument(
37
+ '--config_path',
38
+ default='config.ini',
39
+ type=str,
40
+ help='Worker configuration path. Default value is config.ini')
41
+ parser.add_argument('--standalone',
42
+ action='store_true',
43
+ default=True,
44
+ help='Auto deploy required Hybrid LLM Service.')
45
+ parser.add_argument("--model_downloaded",
46
+ type=bool,
47
+ default=False,
48
+ help="If the model has been downloaded in the root/models folder. Default is False.")
49
+ args = parser.parse_args()
50
+ return args
51
+
52
+ def update_remote_buttons(remote):
53
+ if remote:
54
+ return [
55
+ gr.Markdown("[如何配置API]('https://github.com/jabberwockyang/MedicalReviewAgent/blob/main/README.md')",
56
+ visible=True),
57
+ gr.Dropdown(["kimi", "deepseek", "zhipuai",'gpt'],
58
+ label="选择大模型提供商",
59
+ interactive=True,visible=True),
60
+ gr.Textbox(label="您的API",lines = 1,
61
+ interactive=True,visible=True),
62
+ gr.Textbox(label="base url",lines = 1,
63
+ interactive=True,visible=True),
64
+ gr.Dropdown([],label="选择模型",
65
+ interactive=True,visible=True)
66
+ ]
67
+ else:
68
+ return [
69
+ gr.Markdown("[如何配置API]('https://github.com/jabberwockyang/MedicalReviewAgent/blob/main/README.md')",
70
+ visible=False),
71
+ gr.Dropdown(["kimi", "deepseek", "zhipuai",'gpt'],
72
+ label="选择大模型提供商",
73
+ interactive=False,visible=False),
74
+ gr.Textbox(label="您的API",lines = 1,
75
+ interactive=False,visible=False),
76
+ gr.Dropdown([],label="选择模型",
77
+ interactive=False,visible=False)
78
+ ]
79
+
80
+ def udate_model_dropdown(remote_company):
81
+ model_choices = {
82
+ 'kimi': ['moonshot-v1-128k'],
83
+ 'deepseek': ['deepseek-chat'],
84
+ 'zhipuai': ['glm-4'],
85
+ 'gpt': ['gpt-4-32k-0613','gpt-3.5-turbo']
86
+ }
87
+ return gr.Dropdown(choices= model_choices[remote_company])
88
+
89
+ def update_remote_config(remote_ornot,remote_company = None,api = None,baseurl = None, model = None):
90
+ with open(CONFIG_PATH, encoding='utf8') as f:
91
+ config = pytoml.load(f)
92
+
93
+ if remote_ornot:
94
+ if remote_company == None or api == None or model == None:
95
+ raise ValueError('remote_company, api, model not provided')
96
+ config['llm']['enable_local'] = 0
97
+ config['llm']['enable_remote'] = 1
98
+ config['llm']['server']['remote_type'] = remote_company
99
+ config['llm']['server']['remote_api_key'] = api
100
+ config['llm']['server']['remote_base_url'] = baseurl
101
+ config['llm']['server']['remote_llm_model'] = model
102
+ else:
103
+ config['llm']['enable_local'] = 1
104
+ config['llm']['enable_remote'] = 0
105
+ with open(CONFIG_PATH, 'w') as f:
106
+ pytoml.dump(config, f)
107
+ return gr.Button("配置已保存")
108
+
109
+ # @spaces.GPU(duration=360)
110
+ def get_ready(query:str,chunksize=None,k=None):
111
+
112
+ with open(CONFIG_PATH, encoding='utf8') as f:
113
+ config = pytoml.load(f)
114
+ workdir = config['feature_store']['work_dir']
115
+ repodir = config['feature_store']['repo_dir']
116
+
117
+ if query == 'repo_work': # no need to return assistant
118
+ return repodir, workdir, config
119
+ theme = ''
120
+ try:
121
+ with open(os.path.join(config['feature_store']['repo_dir'],'config.json'), 'r') as f:
122
+ repo_config = json.load(f)
123
+ theme = ' '.join(repo_config['keywords'])
124
+ except:
125
+ pass
126
+
127
+ if query == 'annotation':
128
+ if not chunksize or not k:
129
+ raise ValueError('chunksize or k not provided')
130
+ chunkdir = os.path.join(workdir, f'chunksize_{chunksize}')
131
+ clusterdir = os.path.join(chunkdir, 'cluster_features', f'cluster_features_{k}')
132
+ assistant = Worker(work_dir=chunkdir, config_path=CONFIG_PATH,language='en')
133
+ samples_json = os.path.join(clusterdir,'samples.json')
134
+ with open(samples_json, 'r') as f:
135
+ samples = json.load(f)
136
+ f.close()
137
+ return clusterdir, samples, assistant, theme
138
+
139
+ elif query == 'inspiration':
140
+ if not chunksize or not k:
141
+ raise ValueError('chunksize or k not provided')
142
+
143
+ chunkdir = os.path.join(workdir, f'chunksize_{chunksize}')
144
+ clusterdir = os.path.join(chunkdir, 'cluster_features', f'cluster_features_{k}')
145
+ assistant = Worker(work_dir=chunkdir, config_path=CONFIG_PATH,language='en')
146
+ annofile = os.path.join(clusterdir,'annotation.jsonl')
147
+ with open(annofile, 'r') as f:
148
+ annoresult = f.readlines()
149
+
150
+ f.close()
151
+ annoresult = [json.loads(obj) for obj in annoresult]
152
+ return clusterdir, annoresult, assistant, theme
153
+ elif query == 'summarize': # no need for params k
154
+ if not chunksize:
155
+ raise ValueError('chunksize not provided')
156
+ chunkdir = os.path.join(workdir, f'chunksize_{chunksize}')
157
+ assistant = Worker(work_dir=chunkdir, config_path=CONFIG_PATH,language='en')
158
+ return assistant,theme
159
+
160
+ else:
161
+ raise ValueError('query not recognized')
162
+
163
+ def update_repo_info():
164
+ with open(CONFIG_PATH, encoding='utf8') as f:
165
+ config = pytoml.load(f)
166
+ repodir = config['feature_store']['repo_dir']
167
+ if os.path.exists(repodir):
168
+ pdffiles = glob.glob(os.path.join(repodir, '*.pdf'))
169
+ number_of_pdf = len(pdffiles)
170
+ if os.path.exists(os.path.join(repodir,'config.json')):
171
+
172
+ with open(os.path.join(repodir,'config.json'), 'r') as f:
173
+ repo_config = json.load(f)
174
+
175
+ keywords = repo_config['keywords']
176
+ length = repo_config['len']
177
+ retmax = repo_config['retmax']
178
+
179
+ return keywords,length,retmax,number_of_pdf
180
+ else:
181
+ return None,None,None,number_of_pdf
182
+ else:
183
+ return None,None,None,None
184
+
185
+ def upload_file(files):
186
+ repodir, workdir, _ = get_ready('repo_work')
187
+ if not os.path.exists(repodir):
188
+ os.makedirs(repodir)
189
+
190
+ for file in files:
191
+ destination_path = os.path.join(repodir, os.path.basename(file.name))
192
+
193
+ shutil.copy(file.name, destination_path)
194
+
195
+
196
+ return files
197
+
198
+ def generate_articles_repo(strings:str,retmax:int):
199
+
200
+ string = [k.strip() for k in strings.split('\n')]
201
+
202
+ pmids = [k for k in string if k.isdigit()]
203
+ keys = [k for k in string if not k.isdigit()]
204
+
205
+ repodir, _, _ = get_ready('repo_work')
206
+
207
+ articelfinder = ArticleRetrieval(keywords = keys,
208
+ pmids = pmids,
209
+ repo_dir = repodir,
210
+ retmax = retmax)
211
+ articelfinder.initiallize()
212
+ return update_repo()
213
+
214
+ def delete_articles_repo():
215
+ # 在这里运行生成数据库的函数
216
+ repodir, workdir, _ = get_ready('repo_work')
217
+ if os.path.exists(repodir):
218
+ shutil.rmtree(repodir)
219
+ if os.path.exists(workdir):
220
+ shutil.rmtree(workdir)
221
+
222
+ return gr.Textbox(label="文献库概况",lines =3,
223
+ value = '文献库和相关数据库已删除',
224
+ visible = True)
225
+
226
+ def update_repo():
227
+ keys,len,retmax,pdflen = update_repo_info()
228
+ if keys or len:
229
+ newinfo = f"搜索得到文献:\n 关键词:{keys}\n 文献数量:{len}\n 获取上限:{retmax}\n\n上传文献:\n 数量:{pdflen}"
230
+ else:
231
+ if pdflen:
232
+ newinfo = f'搜索得到文献:无\n上传文献:\n 数量:{pdflen}'
233
+ else:
234
+ newinfo = '目前还没有文献库'
235
+
236
+ return gr.Textbox(label="文献库概况",lines =1,
237
+ value = newinfo,
238
+ visible = True)
239
+
240
+ def update_database_info():
241
+ with open(CONFIG_PATH, encoding='utf8') as f:
242
+ config = pytoml.load(f)
243
+ workdir = config['feature_store']['work_dir']
244
+ chunkdirs = glob.glob(os.path.join(workdir, 'chunksize_*'))
245
+ chunkdirs.sort()
246
+ list_of_chunksize = [int(chunkdir.split('_')[-1]) for chunkdir in chunkdirs]
247
+ # print(list_of_chunksize)
248
+ jsonobj = {}
249
+ for chunkdir in chunkdirs:
250
+ k_dir = glob.glob(os.path.join(chunkdir, 'cluster_features','cluster_features_*'))
251
+ k_dir.sort()
252
+ list_of_k = [int(k.split('_')[-1]) for k in k_dir]
253
+ jsonobj[int(chunkdir.split('_')[-1])] = list_of_k
254
+
255
+
256
+ new_options = [f"chunksize:{chunksize}, k:{k}" for chunksize in list_of_chunksize for k in jsonobj[chunksize]]
257
+
258
+ return new_options, jsonobj
259
+
260
+ # @spaces.GPU(duration=360)
261
+ def generate_database(chunksize:int,nclusters:str|list[str]):
262
+ # 在这里运行生成数据库的函数
263
+ repodir, workdir, _ = get_ready('repo_work')
264
+ if not os.path.exists(repodir):
265
+ return gr.Textbox(label="数据库已生成",value = '请先生成文献库',visible = True)
266
+ nclusters = [int(i) for i in nclusters]
267
+ # 文献库和数据库的覆盖删除逻辑待定
268
+ # 理论上 文献库只能生成一次 所以每次生成文献库都要删除之前的文献库和数据库
269
+ # 数据库可以根据文献库多次生成 暂不做删除 目前没有节省算力的逻辑 重复计算后覆盖 以后优化
270
+ # 不同的chunksize和nclusters会放在不同的文件夹下 不会互相覆盖
271
+ # if os.path.exists(workdir):
272
+ # shutil.rmtree(workdir)
273
+
274
+ cache = CacheRetriever(config_path=CONFIG_PATH)
275
+ fs_init = FeatureStore(embeddings=cache.embeddings,
276
+ reranker=cache.reranker,
277
+ chunk_size=chunksize,
278
+ n_clusters=nclusters,
279
+ config_path=CONFIG_PATH)
280
+
281
+ # walk all files in repo dir
282
+ file_opr = FileOperation()
283
+ files = file_opr.scan_dir(repo_dir=repodir)
284
+ fs_init.initialize(files=files, work_dir=workdir,file_opr=file_opr)
285
+ file_opr.summarize(files)
286
+ del fs_init
287
+ cache.pop('default')
288
+ texts, _ = update_database_info()
289
+ return gr.Textbox(label="数据库概况",value = '\n'.join(texts) ,visible = True)
290
+
291
+ def delete_database():
292
+ _, workdir, _ = get_ready('repo_work')
293
+ if os.path.exists(workdir):
294
+ shutil.rmtree(workdir)
295
+ return gr.Textbox(label="数据库概况",lines =3,value = '数据库已删除',visible = True)
296
+
297
+ def update_database_textbox():
298
+ texts, _ = update_database_info()
299
+ if texts == []:
300
+ return gr.Textbox(label="数据库概况",value = '目前还没有数据库',visible = True)
301
+ else:
302
+ return gr.Textbox(label="数据库概况",value = '\n'.join(texts),visible = True)
303
+
304
+ def update_chunksize_dropdown():
305
+ _, jsonobj = update_database_info()
306
+ return gr.Dropdown(choices= jsonobj.keys())
307
+
308
+ def update_ncluster_dropdown(chunksize:int):
309
+ _, jsonobj = update_database_info()
310
+ nclusters = jsonobj[chunksize]
311
+ return gr.Dropdown(choices= nclusters)
312
+
313
+ # @spaces.GPU(duration=360)
314
+ def annotation(n,chunksize:int,nclusters:int,remote_ornot:bool):
315
+ '''
316
+ use llm to annotate cluster
317
+ n: percentage of clusters to annotate
318
+ '''
319
+ query = 'annotation'
320
+ if remote_ornot:
321
+ backend = 'remote'
322
+ else:
323
+ backend = 'local'
324
+
325
+ clusterdir, samples, assistant, theme = get_ready('annotation',chunksize,nclusters)
326
+ new_obj_list = []
327
+ n = round(n * len(samples.keys()))
328
+ for cluster_no in random.sample(samples.keys(), n):
329
+ chunk = '\n'.join(samples[cluster_no]['samples'][:10])
330
+
331
+ code, reply, cluster_no = assistant.annotate_cluster(
332
+ theme = theme,
333
+ cluster_no=cluster_no,
334
+ chunk=chunk,
335
+ history=[],
336
+ groupname='',
337
+ backend=backend)
338
+ references = f"cluster_no: {cluster_no}"
339
+ new_obj = {
340
+ 'cluster_no': cluster_no,
341
+ 'chunk': chunk,
342
+ 'annotation': reply
343
+ }
344
+ new_obj_list.append(new_obj)
345
+ logger.info(f'{code}, {query}, {reply}, {references}')
346
+
347
+ with open(os.path.join(clusterdir, 'annotation.jsonl'), 'a') as f:
348
+ json.dump(new_obj, f, ensure_ascii=False)
349
+ f.write('\n')
350
+
351
+ return '\n\n'.join([obj['annotation'] for obj in new_obj_list])
352
+
353
+ # @spaces.GPU(duration=360)
354
+ def inspiration(annotation:str,chunksize:int,nclusters:int,remote_ornot:bool):
355
+ query = 'inspiration'
356
+ if remote_ornot:
357
+ backend = 'remote'
358
+ else:
359
+ backend = 'local'
360
+
361
+ clusterdir, annoresult, assistant, theme = get_ready('inspiration',chunksize,nclusters)
362
+ new_obj_list = []
363
+
364
+ if annotation is not None: # if the user wants to get inspiration from specific clusters only
365
+ annoresult = [obj for obj in annoresult if obj['annotation'] in [txt.strip() for txt in annotation.split('\n')]]
366
+
367
+ for index in random.sample(range(len(annoresult)), min(5, len(annoresult))):
368
+ cluster_no = annoresult[index]['cluster_no']
369
+ chunks = annoresult[index]['annotation']
370
+
371
+ code, reply = assistant.getinspiration(
372
+ theme = theme,
373
+ annotations = chunks,
374
+ history=[],
375
+ groupname='',backend=backend)
376
+ new_obj = {
377
+ 'inspiration': reply,
378
+ 'cluster_no': cluster_no
379
+ }
380
+ new_obj_list.append(new_obj)
381
+ logger.info(f'{code}, {query}, {cluster_no},{reply}')
382
+
383
+ with open(os.path.join(clusterdir, 'inspiration.jsonl'), 'a') as f:
384
+ json.dump(new_obj, f, ensure_ascii=False)
385
+ with open(os.path.join(clusterdir, 'inspiration.txt'), 'a') as f:
386
+ f.write(f'{reply}\n')
387
+
388
+ return '\n\n'.join(list(set([obj['inspiration'] for obj in new_obj_list])))
389
+
390
+
391
+ def getpmcurls(references):
392
+ urls = []
393
+ for ref in references:
394
+ if ref.startswith('PMC'):
395
+
396
+ refid = ref.replace('.txt','')
397
+ urls.append(f'https://www.ncbi.nlm.nih.gov/pmc/articles/{refid}/')
398
+ else:
399
+ urls.append(ref)
400
+ return urls
401
+
402
+ # @spaces.GPU(duration=360)
403
+ def summarize_text(query,chunksize:int,remote_ornot:bool):
404
+ if remote_ornot:
405
+ backend = 'remote'
406
+ else:
407
+ backend = 'local'
408
+
409
+ assistant,_ = get_ready('summarize',chunksize=chunksize,k=None)
410
+ code, reply, references = assistant.generate(query=query,
411
+ history=[],
412
+ groupname='',backend = backend)
413
+
414
+ logger.info(f'{code}, {query}, {reply}, {references}')
415
+ urls = getpmcurls(references)
416
+ mds = '\n\n'.join([f'[{ref}]({url})' for ref,url in zip(references,urls)])
417
+ return gr.Markdown(label="看看",value = reply,line_breaks=True) , gr.Markdown(label="参考文献",value = mds,line_breaks=True)
418
+
419
+ def main_interface():
420
+ with gr.Blocks() as demo:
421
+ with gr.Row():
422
+ gr.Markdown(
423
+ """
424
+ # 医学文献综述助手 (又名 不想看文献)
425
+ """
426
+ )
427
+
428
+ with gr.Tab("模型服务配置"):
429
+ gr.Markdown("""
430
+ #### 配置模型服务 🛠️
431
+
432
+ 1. **是否使用远程大模型**
433
+ - 勾选此项,如果你想使用远程的大模型服务。
434
+ - 如果不勾选,将默认使用本地模型服务。
435
+
436
+ 2. **API配置**
437
+ - 配置大模型提供商和API,确保模型服务能够正常运行。
438
+ - 提供商选择:kimi、deepseek、zhipuai、gpt。
439
+ - 输入您的API密钥和选择对应模型。
440
+ - 点击“保存配置”按钮以保存您的设置。
441
+
442
+ 📝 **备注**:请参考[如何使用]('https://github.com/jabberwockyang/MedicalReviewAgent/blob/main/README.md')获取更多信息。
443
+
444
+ """)
445
+
446
+ remote_ornot = gr.Checkbox(label="是否使用远程大模型")
447
+ with gr.Accordion("API配置", open=True):
448
+ apimd = gr.Markdown("[如何配置API]('https://github.com/jabberwockyang/MedicalReviewAgent/blob/main/README.md')",visible=False)
449
+ remote_company = gr.Dropdown(["kimi", "deepseek", "zhipuai",'gpt'],
450
+ label="选择大模型提供商",interactive=False,visible=False)
451
+ api = gr.Textbox(label="您的API",lines = 1,interactive=False,visible=False)
452
+ baseurl = gr.Textbox(label="base url",lines = 1,interactive=False,visible=False)
453
+ model = gr.Dropdown([],label="选择模型",interactive=False,visible=False)
454
+
455
+ confirm_button = gr.Button("保存配置")
456
+
457
+ remote_ornot.change(update_remote_buttons, inputs=[remote_ornot],outputs=[apimd,remote_company,api,baseurl,model])
458
+ remote_company.change(udate_model_dropdown, inputs=[remote_company],outputs=[model])
459
+ confirm_button.click(update_remote_config, inputs=[remote_ornot,remote_company,api,baseurl,model],outputs=[confirm_button])
460
+
461
+
462
+ with gr.Tab("文献查找+数据库生成"):
463
+ gr.Markdown("""
464
+ #### 查找文献 📚
465
+
466
+ 1. **输入关键词批量PubMed PMC文献**
467
+ - 在“感兴趣的关键词”框中输入您感兴趣的关键词,每行一个。
468
+ - 设置查找数量(0-1000)。
469
+ - 点击“搜索PubMed PMC”按钮进行文献查找。
470
+
471
+ 2. **上传PDF**
472
+ - 通过“上传PDF”按钮上传您已有的PDF文献文件。
473
+
474
+ 3. **更新文献库情况 删除文献库**
475
+ - 点击“更新文献库情况”按钮,查看当前文献库的概况。
476
+ - 如果需要重置或删除现有文献库,点击“删除文献库”按钮。
477
+
478
+
479
+ #### 生成数据库 🗂️
480
+
481
+ 1. **设置数据库构建参数 生成数据库**
482
+ - 选择块大小(Chunk Size)和聚类数(Number of Clusters)。
483
+ - 提供选项用于选择合适的块大小和聚类数。
484
+ - 点击“生成数据库”按钮开始数据库生成过程。
485
+
486
+ 2. **更新数据库情况 删除数据库**
487
+ - 点击“更新数据库情况”按钮,查看当前数据库的概况。
488
+ - 点击“删除数据库”按钮移除现有数据库。
489
+
490
+ 📝 **备注**:请参考[如何选择数据库构建参数]('https://github.com/jabberwockyang/MedicalReviewAgent/tree/main')获取更多信息。
491
+ """)
492
+ with gr.Row(equal_height=True):
493
+ with gr.Column(scale=1):
494
+ input_keys = gr.Textbox(label="感兴趣的关键词",
495
+ value = "输入关键词或者PMID, 换行分隔",
496
+ lines = 5)
497
+ retmax = gr.Slider(
498
+ minimum=0,
499
+ maximum=1000,
500
+ value=500,
501
+ interactive=True,
502
+ label="查多少",
503
+ )
504
+ generate_repo_button = gr.Button("搜索PubMed PMC")
505
+ with gr.Column(scale=2):
506
+ file_output = gr.File(scale=2)
507
+ upload_button = gr.UploadButton("上传PDF",
508
+ file_types=[".pdf",".csv",".doc"],
509
+ file_count="multiple",scale=0)
510
+
511
+ with gr.Row(equal_height=True):
512
+ with gr.Column(scale=0):
513
+ delete_repo_button = gr.Button("删除文献库")
514
+ update_repo_button = gr.Button("更新文献库情况")
515
+ with gr.Column(scale=2):
516
+
517
+ repo_summary =gr.Textbox(label= '文献库概况', value="目前还没有文献库")
518
+
519
+ generate_repo_button.click(generate_articles_repo,
520
+ inputs=[input_keys,retmax],
521
+ outputs = [repo_summary])
522
+
523
+
524
+ delete_repo_button.click(delete_articles_repo, inputs=None,
525
+ outputs = repo_summary)
526
+ update_repo_button.click(update_repo, inputs=None,
527
+ outputs = repo_summary)
528
+ upload_button.upload(upload_file, upload_button, file_output)
529
+
530
+ with gr.Accordion("数据库构建参数", open=True):
531
+ gr.Markdown("[如何选择数据库构建参数]('https://github.com/jabberwockyang/MedicalReviewAgent/tree/main')")
532
+ chunksize = gr.Slider(label="Chunk Size",
533
+ info= 'How long you want the chunk to be?',
534
+ minimum=128, maximum=4096,value=1024,step=1,
535
+ interactive=True)
536
+ ncluster = gr.CheckboxGroup(["10", "20", "50", '100','200','500','1000'],
537
+ # default=["20", "50", '100'],
538
+ label="Number of Clusters",
539
+ info="How many Clusters you want to generate")
540
+
541
+ with gr.Row():
542
+ gene_database_button = gr.Button("生成数据库")
543
+ delete_database_button = gr.Button("删除数据库")
544
+ update_database_button = gr.Button("更新数据库情况")
545
+
546
+ database_summary = gr.Textbox(label="数据库概况",lines = 1,value="目前还没有数据库")
547
+
548
+
549
+ gene_database_button.click(generate_database, inputs=[chunksize,ncluster],
550
+ outputs = database_summary)
551
+
552
+ update_database_button.click(update_database_textbox,inputs=None,
553
+ outputs = [database_summary])
554
+
555
+ delete_database_button.click(delete_database, inputs=None,
556
+ outputs = database_summary)
557
+ with gr.Tab("写综述"):
558
+ gr.Markdown("""
559
+ #### 写综述 ✍️
560
+
561
+ 1. **更新数据库情况**
562
+ - 点击“更新数据库情况”按钮,确保使用最新的数据库信息。
563
+
564
+ 2. **选择块大小和聚类数**
565
+ - 从下拉菜单中选择合适的块大小和聚类数。
566
+
567
+ 3. **抽样标注文章聚类**
568
+ - 设置抽样标注比例(0-1)。
569
+ - 点击“抽样标注文章聚类”按钮开始标注过程。
570
+
571
+ 4. **获取灵感**
572
+ - 如果不知道写什么,点击“获取灵感”按钮。
573
+ - 系统将基于标注的文章聚类提供相应的综述子问题。
574
+
575
+ 5. **写综述**
576
+ - 输入您想写的内容或主题。
577
+ - 点击“写综述”按钮,生成综述文本。
578
+
579
+ 6. **查看生成结果**
580
+ - 生成的综述文本将显示在“看看”文本框中。
581
+ - 参考文献将显示在“参考文献”框中。
582
+
583
+ 📝 **备注**:可以尝试不同的参数进行标注和灵感获取,有助于提高综述的质量和相关性。
584
+ """)
585
+
586
+ with gr.Accordion("聚类标注相关参数", open=True):
587
+ with gr.Row():
588
+ update_options = gr.Button("更新数据库情况", scale=0)
589
+ chunksize = gr.Dropdown([], label="选择块大小", scale=0)
590
+ nclusters = gr.Dropdown([], label="选择聚类数", scale=0)
591
+ ntoread = gr.Slider(
592
+ minimum=0,maximum=1,value=0.5,
593
+ interactive=True,
594
+ label="抽样标注比例",
595
+ )
596
+
597
+ annotation_button = gr.Button("抽样标注文章聚类")
598
+ annotation_output = gr.Textbox(label="文章聚类标注/片段摘要",
599
+ lines = 5,
600
+ interactive= True,
601
+ show_copy_button=True)
602
+ inspiration_button = gr.Button("获取灵感")
603
+ inspiration_output = gr.Textbox(label="灵光一现",
604
+ lines = 5,
605
+ show_copy_button=True)
606
+
607
+
608
+ query = gr.Textbox(label="想写什么")
609
+
610
+ write_button = gr.Button("写综述")
611
+ output_text = gr.Markdown(label="看看")
612
+ output_references = gr.Markdown(label="参考文献")
613
+
614
+ update_options.click(update_chunksize_dropdown,
615
+ outputs=[chunksize])
616
+
617
+ chunksize.change(update_ncluster_dropdown,
618
+ inputs=[chunksize],
619
+ outputs= [nclusters])
620
+
621
+ annotation_button.click(annotation,
622
+ inputs = [ntoread, chunksize, nclusters,remote_ornot],
623
+ outputs=[annotation_output])
624
+
625
+ inspiration_button.click(inspiration,
626
+ inputs= [annotation_output, chunksize, nclusters,remote_ornot],
627
+ outputs=[inspiration_output])
628
+
629
+ write_button.click(summarize_text,
630
+ inputs=[query, chunksize,remote_ornot],
631
+ outputs =[output_text,output_references])
632
+
633
+ demo.launch(share=False, server_name='0.0.0.0', debug=True,show_error=True,allowed_paths=['img_0.jpg'])
634
+
635
+ # start service
636
+ if __name__ == '__main__':
637
+ args = parse_args()
638
+ # copy config from config-bak
639
+ if args.model_downloaded:
640
+ shutil.copy('config-mod_down-bak.ini', args.config_path) # yyj
641
+ else:
642
+ shutil.copy('config-bak.ini', args.config_path) # yyj
643
+
644
+ CONFIG_PATH = args.config_path
645
+
646
+ if args.standalone is True:
647
+ # hybrid llm serve
648
+ server_ready = Value('i', 0)
649
+ server_process = Process(target=llm_serve,
650
+ args=(args.config_path, server_ready))
651
+ server_process.start()
652
+ while True:
653
+ if server_ready.value == 0:
654
+ logger.info('waiting for server to be ready..')
655
+ time.sleep(3)
656
+ elif server_ready.value == 1:
657
+ break
658
+ else:
659
+ logger.error('start local LLM server failed, quit.')
660
+ raise Exception('local LLM path')
661
+ logger.info('Hybrid LLM Server start.')
662
+
663
+ main_interface()
config-bak.ini CHANGED
@@ -23,6 +23,7 @@ local_llm_max_text_length = 32000
23
  local_llm_bind_port = 8888
24
  remote_type = ""
25
  remote_api_key = ""
 
26
  remote_llm_max_text_length = 32000
27
  remote_llm_model = ""
28
  rpm = 500
 
23
  local_llm_bind_port = 8888
24
  remote_type = ""
25
  remote_api_key = ""
26
+ remote_base_url = ""
27
  remote_llm_max_text_length = 32000
28
  remote_llm_model = ""
29
  rpm = 500
config-mod_down-bak.ini ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [feature_store]
2
+ reject_throttle = 0
3
+ embedding_model_path = "/root/models/bce-embedding-base_v1"
4
+ reranker_model_path = "/root/models/bce-reranker-base_v1"
5
+ repo_dir = "repodir"
6
+ work_dir = "workdir"
7
+ n_clusters = [20, 50]
8
+ chunk_size = 1024
9
+
10
+ [web_search]
11
+ x_api_key = "${YOUR-API-KEY}"
12
+ domain_partial_order = ["openai.com", "pytorch.org", "readthedocs.io", "nvidia.com", "stackoverflow.com", "juejin.cn", "zhuanlan.zhihu.com", "www.cnblogs.com"]
13
+ save_dir = "logs/web_search_result"
14
+
15
+ [llm]
16
+ enable_local = 1
17
+ enable_remote = 1
18
+ client_url = "http://127.0.0.1:8888/inference"
19
+
20
+ [llm.server]
21
+ local_llm_path = "/root/models/Qwen1.5-7B-Chat"
22
+ local_llm_max_text_length = 32000
23
+ local_llm_bind_port = 8888
24
+ remote_type = ""
25
+ remote_api_key = ""
26
+ remote_base_url = ""
27
+ remote_llm_max_text_length = 32000
28
+ remote_llm_model = ""
29
+ rpm = 500
30
+
31
+ [worker]
32
+ enable_sg_search = 0
33
+ save_path = "logs/work.txt"
34
+
35
+ [worker.time]
36
+ start = "00:00:00"
37
+ end = "23:59:59"
38
+ has_weekday = 1
39
+
40
+ [sg_search]
41
+ binary_src_path = "/usr/local/bin/src"
42
+ src_access_token = "${YOUR-SRC-ACCESS-TOKEN}"
43
+
44
+ [sg_search.opencompass]
45
+ github_repo_id = "open-compass/opencompass"
46
+ introduction = "用于评测大型语言模型(LLM). 它提供了完整的开源可复现的评测框架,支持大语言模型、多模态模型的一站式评测,基于分布式技术,对大参数量模型亦能实现高效评测。评测方向汇总为知识、语言、理解、推理、考试五大能力维度,整合集纳了超过70个评测数据集,合计提供了超过40万个模型评测问题,并提供长文本、安全、代码3类大模型特色技术能力评测。"
47
+
48
+ [sg_search.lmdeploy]
49
+ github_repo_id = "internlm/lmdeploy"
50
+ introduction = "lmdeploy 是一个用于压缩、部署和服务 LLM(Large Language Model)的工具包。是一个服务端场景下,transformer 结构 LLM 部署工具,支持 GPU 服务端部署,速度有保障,支持 Tensor Parallel,多并发优化,功能全面,包括模型转换、缓存历史会话的 cache feature 等. 它还提供了 WebUI、命令行和 gRPC 客户端接入。"
51
+
52
+ [frontend]
53
+ type = "none"
54
+ webhook_url = "https://open.feishu.cn/open-apis/bot/v2/hook/xxxxxxxxxxxxxxx"
55
+ message_process_policy = "immediate"
56
+
57
+ [frontend.lark_group]
58
+ app_id = "cli_a53a34dcb778500e"
59
+ app_secret = "2ajhg1ixSvlNm1bJkH4tJhPfTCsGGHT1"
60
+ encrypt_key = "abc"
61
+ verification_token = "def"
62
+
63
+ [frontend.wechat_personal]
64
+ bind_port = 9527
huixiangdou/service/findarticles.py CHANGED
@@ -10,28 +10,61 @@ from lxml import etree
10
 
11
  class ArticleRetrieval:
12
  def __init__(self,
13
- keywords: list,
 
14
  repo_dir = 'repodir',
15
  retmax = 500):
 
 
 
16
  self.keywords = keywords
 
17
  self.repo_dir = repo_dir
18
  self.retmax = retmax
19
-
20
- ## 通过PMC数据库检索文章
21
- def search_pmc(self):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
  base_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi"
24
  params = {
25
- "db": "pmc",
26
  "term": '+'.join(self.keywords),
27
  "retmax": self.retmax
28
  }
29
  response = requests.get(base_url, params=params)
30
  root = ET.fromstring(response.content)
31
- pmc_ids = [id_element.text for id_element in root.findall('.//Id')]
32
- print(f"Found {len(pmc_ids)} articles.")
33
- self.pmc_ids = pmc_ids
34
- return pmc_ids
 
35
 
36
  # 解析XML文件
37
  def _get_all_text(self, element):
@@ -74,8 +107,8 @@ class ArticleRetrieval:
74
  if full_text.strip() == '':
75
  continue
76
  else:
77
- logger.info(full_text[:1000])
78
- with open(os.path.join(self.repo_dir,f'PMC{id}.txt'), 'w') as f:
79
  f.write(full_text)
80
  self.success += 1
81
 
@@ -83,7 +116,12 @@ class ArticleRetrieval:
83
  config = {
84
  'keywords': self.keywords,
85
  'repo_dir': self.repo_dir,
86
- 'pmc_ids': self.pmc_ids,
 
 
 
 
 
87
  'len': self.success,
88
  'retmax': self.retmax
89
  }
@@ -91,12 +129,33 @@ class ArticleRetrieval:
91
  json.dump(config, f, indent=4, ensure_ascii=False)
92
 
93
  def initiallize(self):
94
- self.search_pmc()
95
- self.fetch_full_text()
96
- self.save_config()
 
 
 
 
97
 
98
  if __name__ == '__main__':
99
  if os.path.exists('repodir'):
100
  shutil.rmtree('repodir')
101
- articelfinder = ArticleRetrieval(keywords = ['covid-19'],repo_dir = 'repodir',retmax = 5)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
  articelfinder.initiallize()
 
10
 
11
  class ArticleRetrieval:
12
  def __init__(self,
13
+ keywords: list = [],
14
+ pmids: list = [],
15
  repo_dir = 'repodir',
16
  retmax = 500):
17
+ if keywords is [] and pmids is []:
18
+ raise ValueError("Either keywords or pmids must be provided.")
19
+
20
  self.keywords = keywords
21
+ self.pmids = pmids
22
  self.repo_dir = repo_dir
23
  self.retmax = retmax
24
+ self.pmc_ids = []
25
+
26
+
27
+ def esummary_pmc(self):
28
+ base_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esummary.fcgi?"
29
+ params = {
30
+ "db": "pubmed",
31
+ "id": ','.join(self.pmids),
32
+ # "retmax": self.retmax
33
+ }
34
+ response = requests.get(base_url, params=params)
35
+ root = ET.fromstring(response.content)
36
+ results = []
37
+ for docsum in root.findall('DocSum'):
38
+ pmcid = None
39
+ id_value = docsum.find('Id').text
40
+ for item in docsum.findall('Item'):
41
+ if item.attrib.get('Name') == 'ArticleIds':
42
+ for id_item in item.findall('Item'):
43
+ if id_item.attrib.get('Name') == 'pmc':
44
+ pmcid = id_item.text
45
+ break
46
+
47
+ if pmcid:
48
+ results.append((id_value, pmcid))
49
+ self.esummary = results
50
+ self.pmc_ids = [r[1] for r in results]
51
+
52
+ ## 通过Pubmed数据库检索文章
53
+ def esearch_pmc(self):
54
 
55
  base_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi"
56
  params = {
57
+ "db": "pubmed",
58
  "term": '+'.join(self.keywords),
59
  "retmax": self.retmax
60
  }
61
  response = requests.get(base_url, params=params)
62
  root = ET.fromstring(response.content)
63
+ idlist = root.find('.//IdList')
64
+ pmids = [id_element.text for id_element in idlist.findall('.//Id')]
65
+ print(f"Found {len(pmids)} articles for keywords {self.keywords}.")
66
+ self.pmids.extend(pmids)
67
+
68
 
69
  # 解析XML文件
70
  def _get_all_text(self, element):
 
107
  if full_text.strip() == '':
108
  continue
109
  else:
110
+ logger.info(full_text[:500])
111
+ with open(os.path.join(self.repo_dir,f'{id}.txt'), 'w') as f:
112
  f.write(full_text)
113
  self.success += 1
114
 
 
116
  config = {
117
  'keywords': self.keywords,
118
  'repo_dir': self.repo_dir,
119
+ 'result': [
120
+ {
121
+ 'pmid': r[0],
122
+ 'pmcid': r[1]
123
+ } for r in self.esummary
124
+ ],
125
  'len': self.success,
126
  'retmax': self.retmax
127
  }
 
129
  json.dump(config, f, indent=4, ensure_ascii=False)
130
 
131
  def initiallize(self):
132
+ if self.keywords !=[]:
133
+ print(self.keywords)
134
+ self.esearch_pmc() # get pmids from pubmed database using keywords
135
+
136
+ self.esummary_pmc() # get pmc ids from pubmed database using pmids
137
+ self.fetch_full_text() # get full text from pmc database using pmc ids
138
+ self.save_config() # save config file
139
 
140
  if __name__ == '__main__':
141
  if os.path.exists('repodir'):
142
  shutil.rmtree('repodir')
143
+
144
+ strings = """
145
+ 36944324
146
+ 38453907
147
+ 38300432
148
+ 38651453
149
+ 38398096
150
+ 38255885
151
+ 38035547
152
+ 38734498"""
153
+ string = [k.strip() for k in strings.split('\n')]
154
+
155
+ pmids = [k for k in string if k.isdigit()]
156
+ print(pmids)
157
+ keys = [k for k in string if not k.isdigit() and k != '']
158
+ print(keys)
159
+ articelfinder = ArticleRetrieval(keywords = keys,pmids = pmids,
160
+ repo_dir = 'repodir',retmax = 5)
161
  articelfinder.initiallize()
huixiangdou/service/llm_server_hybrid.py CHANGED
@@ -202,6 +202,7 @@ class HybridLLMServer:
202
  self.remote_model = self.server_config['remote_llm_model']
203
  self.remote_max_length = self.server_config['remote_llm_max_text_length']
204
  self.api_key = self.server_config['remote_api_key']
 
205
 
206
  def call_puyu(self, prompt, history):
207
  self.reload_config()
@@ -345,6 +346,9 @@ class HybridLLMServer:
345
  if base_url is not None:
346
  client = OpenAI(api_key=self.server_config['remote_api_key'],
347
  base_url=base_url)
 
 
 
348
  else:
349
  client = OpenAI(api_key=self.server_config['remote_api_key'])
350
 
 
202
  self.remote_model = self.server_config['remote_llm_model']
203
  self.remote_max_length = self.server_config['remote_llm_max_text_length']
204
  self.api_key = self.server_config['remote_api_key']
205
+ self.base_url = self.server_config['remote_base_url']
206
 
207
  def call_puyu(self, prompt, history):
208
  self.reload_config()
 
346
  if base_url is not None:
347
  client = OpenAI(api_key=self.server_config['remote_api_key'],
348
  base_url=base_url)
349
+ elif self.base_url != '':
350
+ client = OpenAI(api_key=self.server_config['remote_api_key'],
351
+ base_url=self.base_url)
352
  else:
353
  client = OpenAI(api_key=self.server_config['remote_api_key'])
354
 
huixiangdou/service/worker.py CHANGED
@@ -17,6 +17,14 @@ from .sg_search import SourceGraphProxy
17
  from .web_search import WebSearch
18
 
19
 
 
 
 
 
 
 
 
 
20
  class Worker:
21
  """The Worker class orchestrates the logic of handling user queries,
22
  generating responses and managing several aspects of a chat assistant. It
@@ -85,7 +93,7 @@ class Worker:
85
  self.SECURITY_TEMAPLTE = 'Evaluate whether the following sentence involves prohibited content such as politics, insult, pornography, terror, religion, cyber violence, racial discrimination, etc., rate it on a scale of 0-10, do not explain, just give the score. The scoring standard is as follows: any violation directly gets 10 points; completely unrelated gets 0 points. Give the score, no explanation: "{}"' # noqa E501
86
  self.PERPLESITY_TEMPLATE = 'Question: {} Answer: {}\nRead the dialogue above, does the answer express that they don\'t know? The more comprehensive the answer, the lower the score. Rate it on a scale of 0-10, no explanation, just give the score.\nThe scoring standard is as follows: an accurate answer to the question gets 0 points; a detailed answer gets 1 point; knowing some answers but having uncertain information gets 8 points; knowing a small part of the answer but recommends seeking help from others gets 9 points; not knowing any of the answers and directly recommending asking others for help gets 10 points. Just give the score, no explanation.' # noqa E501
87
  self.SUMMARIZE_TEMPLATE = '"{}" \n Read the content above carefully, summarize it in a short and powerful way.' # noqa E501
88
- self.GENERATE_TEMPLATE = 'Background Information: "{}"\n Question: "{}"\n Please read the reference material carefully and answer the question.' # noqa E501
89
  self.ANNOTATE_CLUSTER = 'these are chunklized sentences from different papers about{}, they are clustered by similarity, the following is 10 samples from one of the cluster: "{}"\n Please tag the cluster in one breif sentence.'
90
  self.INSPIRATION_TEMPLATE = 'Given the following summary of the articles content about {0} {1}, give some idea or sub-questions of the review about {0}, one question is sufficient.' # noqa E501
91
 
@@ -205,7 +213,7 @@ class Worker:
205
  tracker=tracker,
206
  throttle=5,
207
  default=10,backend=backend):
208
- context += chunk
209
  context += '\n\n'
210
  refs.append(ref)
211
  refs = list(set(refs))
@@ -219,8 +227,13 @@ class Worker:
219
  backend=backend,
220
  history=history)
221
  tracker.log('feature store doc', [chunk, response])
 
 
222
  return ErrorCode.SUCCESS, response, refs
 
 
223
 
 
224
  # try:
225
  # references = []
226
  # web_context = ''
@@ -266,13 +279,13 @@ class Worker:
266
  # except Exception as e:
267
  # logger.error(e)
268
 
269
- if response is not None and len(response) > 0:
270
- prompt = self.PERPLESITY_TEMPLATE.format(query, response)
271
- if self.single_judge(prompt=prompt,
272
- tracker=tracker,
273
- throttle=10,
274
- default=0,backend = backend):
275
- reborn_code = ErrorCode.BAD_ANSWER
276
 
277
  # if self.config['worker']['enable_sg_search']:
278
  # if reborn_code == ErrorCode.BAD_ANSWER or reborn_code == ErrorCode.NO_SEARCH_RESULT: # noqa E501
@@ -301,11 +314,11 @@ class Worker:
301
  # default=0):
302
  # return ErrorCode.BAD_ANSWER, response, references
303
 
304
- if response is not None and len(response) >= 800:
305
- # reply too long, summarize it
306
- response = self.llm.generate_response(
307
- prompt=self.SUMMARIZE_TEMPLATE.format(response),
308
- backend=backend)
309
 
310
  # if len(response) > 0 and self.single_judge(
311
  # self.SECURITY_TEMAPLTE.format(response),
@@ -314,10 +327,9 @@ class Worker:
314
  # default=0):
315
  # return ErrorCode.SECURITY, response, references
316
 
317
- if reborn_code != ErrorCode.SUCCESS:
318
- return reborn_code, response, references
319
-
320
- return ErrorCode.SUCCESS, response, references
321
 
322
  def annotate_cluster(self,theme, cluster_no, chunk, history, groupname,backend):
323
  """Annotates a cluster of questions based on the user query and
 
17
  from .web_search import WebSearch
18
 
19
 
20
+ def convertid2url(text):
21
+ # Regular expression to find all PMC references
22
+ pattern = r"\[PMC(\d+)\]"
23
+ # Function to replace each match with a URL link
24
+ replacement = lambda match: f"[PMC{match.group(1)}](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC{match.group(1)}/)"
25
+ # Replace all occurrences in the text
26
+ formatted_text = re.sub(pattern, replacement, text)
27
+ return formatted_text
28
  class Worker:
29
  """The Worker class orchestrates the logic of handling user queries,
30
  generating responses and managing several aspects of a chat assistant. It
 
93
  self.SECURITY_TEMAPLTE = 'Evaluate whether the following sentence involves prohibited content such as politics, insult, pornography, terror, religion, cyber violence, racial discrimination, etc., rate it on a scale of 0-10, do not explain, just give the score. The scoring standard is as follows: any violation directly gets 10 points; completely unrelated gets 0 points. Give the score, no explanation: "{}"' # noqa E501
94
  self.PERPLESITY_TEMPLATE = 'Question: {} Answer: {}\nRead the dialogue above, does the answer express that they don\'t know? The more comprehensive the answer, the lower the score. Rate it on a scale of 0-10, no explanation, just give the score.\nThe scoring standard is as follows: an accurate answer to the question gets 0 points; a detailed answer gets 1 point; knowing some answers but having uncertain information gets 8 points; knowing a small part of the answer but recommends seeking help from others gets 9 points; not knowing any of the answers and directly recommending asking others for help gets 10 points. Just give the score, no explanation.' # noqa E501
95
  self.SUMMARIZE_TEMPLATE = '"{}" \n Read the content above carefully, summarize it in a short and powerful way.' # noqa E501
96
+ self.GENERATE_TEMPLATE = 'Background Information: "{}"\n Question: "{}"\n Please read the reference material carefully and answer the question. with reference id at the end of the corresponding content for example: Primary determinants of the therapeutic approach are age, comorbidities, and diagnostic molecular profile [PMC9958584]' # noqa E501
97
  self.ANNOTATE_CLUSTER = 'these are chunklized sentences from different papers about{}, they are clustered by similarity, the following is 10 samples from one of the cluster: "{}"\n Please tag the cluster in one breif sentence.'
98
  self.INSPIRATION_TEMPLATE = 'Given the following summary of the articles content about {0} {1}, give some idea or sub-questions of the review about {0}, one question is sufficient.' # noqa E501
99
 
 
213
  tracker=tracker,
214
  throttle=5,
215
  default=10,backend=backend):
216
+ context += f"reference: {ref} content: {chunk}"
217
  context += '\n\n'
218
  refs.append(ref)
219
  refs = list(set(refs))
 
227
  backend=backend,
228
  history=history)
229
  tracker.log('feature store doc', [chunk, response])
230
+ response = convertid2url(response)
231
+
232
  return ErrorCode.SUCCESS, response, refs
233
+ else:
234
+ return ErrorCode.NO_SEARCH_RESULT, response, references
235
 
236
+
237
  # try:
238
  # references = []
239
  # web_context = ''
 
279
  # except Exception as e:
280
  # logger.error(e)
281
 
282
+ # if response is not None and len(response) > 0:
283
+ # prompt = self.PERPLESITY_TEMPLATE.format(query, response)
284
+ # if self.single_judge(prompt=prompt,
285
+ # tracker=tracker,
286
+ # throttle=10,
287
+ # default=0,backend = backend):
288
+ # reborn_code = ErrorCode.BAD_ANSWER
289
 
290
  # if self.config['worker']['enable_sg_search']:
291
  # if reborn_code == ErrorCode.BAD_ANSWER or reborn_code == ErrorCode.NO_SEARCH_RESULT: # noqa E501
 
314
  # default=0):
315
  # return ErrorCode.BAD_ANSWER, response, references
316
 
317
+ # if response is not None and len(response) >= 800:
318
+ # # reply too long, summarize it
319
+ # response = self.llm.generate_response(
320
+ # prompt=self.SUMMARIZE_TEMPLATE.format(response),
321
+ # backend=backend)
322
 
323
  # if len(response) > 0 and self.single_judge(
324
  # self.SECURITY_TEMAPLTE.format(response),
 
327
  # default=0):
328
  # return ErrorCode.SECURITY, response, references
329
 
330
+ # if reborn_code != ErrorCode.SUCCESS:
331
+ # return reborn_code, response, references
332
+ # return ErrorCode.SUCCESS, response, references
 
333
 
334
  def annotate_cluster(self,theme, cluster_no, chunk, history, groupname,backend):
335
  """Annotates a cluster of questions based on the user query and
run.sh CHANGED
@@ -1,7 +1,7 @@
1
  # 云端运行
2
  conda activate ReviewAgent
3
  cd MedicalReviewAgent
4
- python3 app.py
5
 
6
  # 本地映射端口
7
- ssh -CNg -L 7860:127.0.0.1:7860 root@43.133.72.216
 
1
  # 云端运行
2
  conda activate ReviewAgent
3
  cd MedicalReviewAgent
4
+ python3 applocal.py --model_downloaded True
5
 
6
  # 本地映射端口
7
+ ssh -CNg -L 7860:127.0.0.1:7860 root@youripaddress