|
import chromadb |
|
from langchain import LLMChain, PromptTemplate |
|
from langchain_openai import ChatOpenAI |
|
from langchain.chains import RetrievalQA |
|
from langchain.output_parsers import StrOutputParser |
|
from langchain.embeddings import ZhipuAIEmbeddings |
|
from langchain.vectorstores import Chroma |
|
from diffusers import StableDiffusionPipeline |
|
import requests |
|
import gradio as gr |
|
import os |
|
from dotenv import load_dotenv, find_dotenv |
|
|
|
_ = load_dotenv(find_dotenv()) |
|
zhipuai_api_key = os.environ['ZHIPUAI_API_KEY'] |
|
|
|
class HealthcareAgent: |
|
def __init__(self): |
|
self.vectordb = self.get_vectordb() |
|
self.llm = ChatOpenAI( |
|
model="glm-3-turbo", |
|
temperature=0.7, |
|
openai_api_key=zhipuai_api_key, |
|
openai_api_base="https://open.bigmodel.cn/api/paas/v4/" |
|
) |
|
self.diffusion_model = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5").to("cuda") |
|
|
|
def get_vectordb(self): |
|
embedding = ZhipuAIEmbeddings() |
|
persist_directory = '/Users/chenshuyi/Documents/agent/data_base/vector_db' |
|
vectordb = Chroma( |
|
persist_directory=persist_directory, |
|
embedding_function=embedding |
|
) |
|
return vectordb |
|
|
|
def generate_response(self, input_text): |
|
output = self.llm.invoke(input_text) |
|
output_parser = StrOutputParser() |
|
output = output_parser.invoke(output) |
|
return output |
|
|
|
def rag_search(self, symptoms): |
|
template = """使用以下上下文来回答关于症状的问题。如果你不知道答案,就说你不知道,不要试图编造答案。最多使用三句话。尽量使答案简明扼要。总是在回答的最后说"谢谢你的提问!"。 |
|
上下文: {context} |
|
问题: 基于这些症状 "{symptoms}",可能是什么疾病?请列出这些疾病的其他常见症状。 |
|
回答格式: |
|
可能的疾病: [疾病1, 疾病2, ...] |
|
其他常见症状: [症状1, 症状2, ...] |
|
回答:""" |
|
QA_CHAIN_PROMPT = PromptTemplate(input_variables=["context", "symptoms"], template=template) |
|
retriever = self.vectordb.as_retriever() |
|
qa_chain = RetrievalQA.from_chain_type( |
|
self.llm, |
|
retriever=retriever, |
|
return_source_documents=True, |
|
chain_type_kwargs={"prompt": QA_CHAIN_PROMPT} |
|
) |
|
result = qa_chain({"query": symptoms}) |
|
return result["result"] |
|
|
|
def assess_severity(self, condition, symptoms): |
|
template = """使用以下上下文来评估疾病的严重程度。 |
|
上下文: {context} |
|
疾病: {condition} |
|
症状: {symptoms} |
|
请根据给定的疾病和症状,评估病情的严重程度。将严重程度分为轻度、中度和重度三个等级。 |
|
同时,请给出这个评估的理由,并提供一些建议。 |
|
回答格式: |
|
严重程度: [轻度/中度/重度] |
|
理由: [您的解释] |
|
建议: [您的建议] |
|
回答:""" |
|
QA_CHAIN_PROMPT = PromptTemplate( |
|
input_variables=["context", "condition", "symptoms"], |
|
template=template |
|
) |
|
retriever = self.vectordb.as_retriever() |
|
qa_chain = RetrievalQA.from_chain_type( |
|
self.llm, |
|
retriever=retriever, |
|
return_source_documents=True, |
|
chain_type_kwargs={"prompt": QA_CHAIN_PROMPT} |
|
) |
|
result = qa_chain({"query": f"{condition} {symptoms}", "condition": condition, "symptoms": symptoms}) |
|
return result["result"] |
|
|
|
def generate_skin_condition_image(self, condition): |
|
severities = ["轻度", "中度", "重度"] |
|
images = [] |
|
for severity in severities: |
|
prompt = f"{severity}{condition}的皮肤症状" |
|
image = self.diffusion_model(prompt, num_inference_steps=50, guidance_scale=7.5).images[0] |
|
images.append(image) |
|
return images |
|
|
|
def recommend_medical_facility(self, user_location, condition, severity): |
|
|
|
template = """ |
|
基于以下信息推荐合适的医疗设施类型: |
|
|
|
疾病: {condition} |
|
严重程度: {severity} |
|
|
|
请从以下选项中选择最合适的医疗设施类型: |
|
1. 药房 |
|
2. 社区医院 |
|
3. 二甲医院 |
|
4. 三甲医院 |
|
|
|
只需回复数字1-4,不需要其他解释。 |
|
|
|
推荐: |
|
""" |
|
|
|
prompt = PromptTemplate(template=template, input_variables=["condition", "severity"]) |
|
llm_chain = LLMChain(prompt=prompt, llm=self.llm) |
|
facility_type = llm_chain.run(condition=condition, severity=severity).strip() |
|
|
|
|
|
facility_types = { |
|
"1": "药房", |
|
"2": "社区医院", |
|
"3": "二甲医院", |
|
"4": "三甲医院" |
|
} |
|
recommended_type = facility_types.get(facility_type, "医院") |
|
|
|
|
|
amap_key = "您的高德地图API密钥" |
|
url = f"https://restapi.amap.com/v3/place/text?key={amap_key}&keywords={recommended_type}&city={user_location}&offset=10&page=1&extensions=all" |
|
|
|
response = requests.get(url) |
|
if response.status_code == 200: |
|
data = response.json() |
|
if data["status"] == "1" and data["pois"]: |
|
facilities = data["pois"] |
|
|
|
top_facilities = facilities[:3] |
|
result = f"根据您的情况,我们推荐您去{recommended_type}。以下是附近的几个选择:\n\n" |
|
for facility in top_facilities: |
|
result += f"名称: {facility['name']}\n" |
|
result += f"地址: {facility['address']}\n" |
|
result += f"电话: {facility.get('tel', '未提供')}\n\n" |
|
return result |
|
else: |
|
return f"抱歉,我们无法在您的位置找到合适的{recommended_type}。请考虑寻求紧急医疗帮助或咨询当地卫生部门。" |
|
else: |
|
return "抱歉,我们暂时无法获取医疗设施信息。请稍后再试或直接联系当地医疗机构。" |
|
|
|
def interact(self, symptoms, user_location): |
|
condition = self.rag_search(symptoms) |
|
|
|
if "皮肤" in condition: |
|
images = self.generate_skin_condition_image(condition) |
|
return condition, images, True, None |
|
else: |
|
severity_assessment = self.assess_severity(condition, symptoms) |
|
severity, reason, advice = self.parse_severity_result(severity_assessment) |
|
facility_recommendation = self.recommend_medical_facility(user_location, condition, severity) |
|
return condition, (severity, reason, advice), False, facility_recommendation |
|
|
|
def parse_severity_result(self, result): |
|
|
|
|
|
lines = result.split('\n') |
|
severity = "" |
|
reason = "" |
|
advice = "" |
|
for line in lines: |
|
if line.startswith("严重程度:"): |
|
severity = line.split(':')[1].strip() |
|
elif line.startswith("理由:"): |
|
reason = line.split(':')[1].strip() |
|
elif line.startswith("建议:"): |
|
advice = line.split(':')[1].strip() |
|
return severity, reason, advice |
|
|
|
def gradio_interface(): |
|
agent = HealthcareAgent() |
|
|
|
def process_input(symptoms, user_location): |
|
condition, result, is_skin_condition, facility_recommendation = agent.interact(symptoms, user_location) |
|
if is_skin_condition: |
|
return gr.update(visible=True, value=condition), gr.update(visible=True, value=result), gr.update(visible=False), gr.update(visible=True, value=facility_recommendation) |
|
else: |
|
severity, reason, advice = result |
|
return gr.update(visible=True, value=f"诊断: {condition}\n严重程度: {severity}\n理由: {reason}\n建议: {advice}"), gr.update(visible=False), gr.update(visible=False), gr.update(visible=True, value=facility_recommendation) |
|
|
|
def on_select(evt: gr.SelectData): |
|
severities = ["轻度", "中度", "重度"] |
|
return f"您选择的严重程度为: {severities[evt.index]}" |
|
|
|
with gr.Blocks() as iface: |
|
gr.Markdown("# 医疗保健助手") |
|
symptoms_input = gr.Textbox(label="请描述您的症状") |
|
location_input = gr.Textbox(label="请输入您的位置") |
|
submit_btn = gr.Button("提交") |
|
|
|
with gr.Group() as output_group: |
|
text_output = gr.Textbox(label="诊断结果", visible=False) |
|
image_output = gr.Gallery(label="请选择最接近您症状的图片", visible=False, columns=3, height=300) |
|
severity_output = gr.Textbox(label="严重程度", visible=False) |
|
facility_output = gr.Textbox(label="推荐医疗设施", visible=False) |
|
|
|
submit_btn.click(process_input, inputs=[symptoms_input, location_input], outputs=[text_output, image_output, severity_output, facility_output]) |
|
image_output.select(on_select, None, severity_output) |
|
|
|
return iface |
|
|
|
if __name__ == "__main__": |
|
iface = gradio_interface() |
|
iface.launch() |
|
|