File size: 4,279 Bytes
96cd96f
 
 
 
 
 
 
 
 
 
 
 
 
a10fbed
65d97b1
96cd96f
 
 
 
 
 
 
 
 
 
65d97b1
 
 
 
 
 
 
 
 
 
96cd96f
 
 
 
 
3c24b5a
 
96cd96f
3c24b5a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96cd96f
 
 
 
 
 
 
 
 
 
3c24b5a
96cd96f
 
 
 
 
 
 
 
 
65d97b1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
#!/usr/bin/env python
# -*- coding:utf-8 _*-
"""
@author:quincy qiang
@license: Apache Licence
@file: model.py
@time: 2023/04/17
@contact: [email protected]
@software: PyCharm
@description: coding..
"""
from langchain.chains import RetrievalQA
from langchain.prompts.prompt import PromptTemplate

from clc.config import LangChainCFG
from clc.gpt_service import ChatGLMService
from clc.source_service import SourceService


class LangChainApplication(object):
    def __init__(self, config):
        self.config = config
        self.llm_service = ChatGLMService()
        self.llm_service.load_model(model_name_or_path=self.config.llm_model_name)
        self.source_service = SourceService(config)

        # if self.config.kg_vector_stores is None:
        #     print("init a source vector store")
        #     self.source_service.init_source_vector()
        # else:
        #     print("load zh_wikipedia source vector store ")
        #     try:
        #         self.source_service.load_vector_store(self.config.kg_vector_stores['初始化知识库'])
        #     except Exception as e:
        #         self.source_service.init_source_vector()

    def get_knowledge_based_answer(self, query,
                                   history_len=5,
                                   temperature=0.1,
                                   top_p=0.9,
                                   top_k=4,
                                   web_content='',
                                   chat_history=[]):
        if web_content:
            prompt_template = f"""基于以下已知信息,简洁和专业的来回答用户的问题。
                                如果无法从中得到答案,请说 "根据已知信息无法回答该问题" 或 "没有提供足够的相关信息",不允许在答案中添加编造成分,答案请使用中文。
                                已知网络检索内容:{web_content}""" + """
                                已知内容:
                                {context}
                                问题:
                                {question}"""
        else:
            prompt_template = """基于以下已知信息,简洁和专业的来回答用户的问题。
                                            如果无法从中得到答案,请说 "根据已知信息无法回答该问题" 或 "没有提供足够的相关信息",不允许在答案中添加编造成分,答案请使用中文。
                                            已知内容:
                                            {context}
                                            问题:
                                            {question}"""
        prompt = PromptTemplate(template=prompt_template,
                                input_variables=["context", "question"])
        self.llm_service.history = chat_history[-history_len:] if history_len > 0 else []

        self.llm_service.temperature = temperature
        self.llm_service.top_p = top_p

        knowledge_chain = RetrievalQA.from_llm(
            llm=self.llm_service,
            retriever=self.source_service.vector_store.as_retriever(
                search_kwargs={"k": top_k}),
            prompt=prompt)
        knowledge_chain.combine_documents_chain.document_prompt = PromptTemplate(
            input_variables=["page_content"], template="{page_content}")

        knowledge_chain.return_source_documents = True

        result = knowledge_chain({"query": query})
        return result

    def get_llm_answer(self, query='', web_content=''):
        if web_content:
            prompt = f'基于网络检索内容:{web_content},回答以下问题{query}'
        else:
            prompt = query
        result = self.llm_service._call(prompt)
        return result


if __name__ == '__main__':
    config = LangChainCFG()
    application = LangChainApplication(config)
    # result = application.get_knowledge_based_answer('马保国是谁')
    # print(result)
    # application.source_service.add_document('/home/searchgpt/yq/Knowledge-ChatGLM/docs/added/马保国.txt')
    # result = application.get_knowledge_based_answer('马保国是谁')
    # print(result)
    result = application.get_llm_answer('马保国是谁')
    print(result)