File size: 7,291 Bytes
fb6f6d9
630d3f4
fdfcf53
131ded0
a06a315
cb8213b
7c119cb
5ce5698
1e06430
5ce5698
51c0f15
c2aa18b
de611e2
 
c2aa18b
 
1e21aa9
 
51c0f15
1e06430
 
 
c2aa18b
 
de611e2
b9bde1a
51c0f15
c2aa18b
 
1e21aa9
de611e2
1e21aa9
 
 
cb8213b
908d31d
1e21aa9
 
 
 
 
21d443e
 
 
 
 
 
1e21aa9
de611e2
1e21aa9
 
 
 
7c119cb
1e21aa9
 
 
 
 
 
742b788
 
 
 
 
4b8fc6b
742b788
 
10b5e55
b2369fc
908d31d
1e06430
 
 
 
1e21aa9
 
 
edfb894
1e21aa9
de611e2
1e21aa9
 
 
 
 
 
5b5abf5
908d31d
fb6f6d9
10b5e55
 
 
 
908d31d
1e06430
a6c563b
9aa294c
5b5abf5
8db8541
 
 
 
 
309abbd
908d31d
 
4bf350f
630d3f4
de611e2
630d3f4
8db8541
742b788
8db8541
 
742b788
fb6f6d9
 
4b8fc6b
908d31d
 
 
 
 
 
 
fb6f6d9
86b4310
 
edfb894
 
 
 
742b788
 
 
 
bca6e22
38e677d
edfb894
 
 
908d31d
 
86b4310
de611e2
908d31d
 
de611e2
 
908d31d
 
fb6f6d9
908d31d
 
 
 
edfb894
cb8213b
908d31d
8db8541
86b4310
908d31d
 
 
edfb894
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
import streamlit as st
import random
from langchain_community.llms import HuggingFaceHub
from langchain_community.embeddings import SentenceTransformerEmbeddings
from langchain_community.vectorstores import FAISS
from datasets import load_dataset

# Streamlit UI: page title for the external-knowledge-base QA demo.
st.title("外挂知识库问答系统")

# Load the "fake knowledge" dataset once per session and cache it in
# st.session_state (Streamlit reruns the whole script on every interaction).
if "data_list" not in st.session_state:
    st.session_state.data_list = []
    st.session_state.answer_list = []

if not st.session_state.data_list:
    try:
        with st.spinner("正在读取数据库..."):
            dataset = load_dataset("zeerd/fake_knowledge")
            # Debug: dump the first five records to the console.
            print(dataset["train"][:5])

            rows = dataset["train"]
            # answer_list feeds the vector store; data_list keeps full Q/A pairs.
            st.session_state.answer_list = [row["Answer"] for row in rows]
            st.session_state.data_list = [
                {"Question": row["Question"], "Answer": row["Answer"]}
                for row in rows
            ]
            st.success("数据库读取完成!")
            print("数据库读取完成!")
    except Exception as e:
        st.error(f"读取数据集失败:{e}")
        st.stop()

# Build the FAISS vector store from the answers (only once per session).
if "vector_created" not in st.session_state:
    st.session_state.vector_created = False

if not st.session_state.vector_created:
    try:
        with st.spinner("正在构建向量数据库..."):
            # all-mpnet-base-v2 is a pre-trained Sentence Transformers model
            # for high-quality sentence embeddings; it performs well on
            # semantic similarity, retrieval and clustering tasks and
            # captures sentence meaning in a representative vector.
            embeddings = SentenceTransformerEmbeddings(model_name="all-mpnet-base-v2")
            st.session_state.embeddings = embeddings
            st.session_state.db = FAISS.from_texts(st.session_state.answer_list, embeddings)
            st.success("向量数据库构建完成!")
            print("向量数据库构建完成!")
    except Exception as e:
        st.error(f"向量数据库构建失败:{e}")
        st.stop()
    st.session_state.vector_created = True

# Cached LLM parameters; the model is re-created only when any of them change.
for _key in ("repo_id", "temperature", "max_length"):
    if _key not in st.session_state:
        st.session_state[_key] = ''

def get_answer(prompt):
    """Query the current LLM with *prompt* and return the cleaned reply."""
    raw = st.session_state.llm.invoke(prompt)
    # The reply may echo the prompt text; strip it out before returning.
    cleaned = raw.replace(prompt, "").strip()
    print(cleaned)
    return cleaned

# 问答函数
def answer_question(repo_id, temperature, max_length, question):
    # 初始化 Gemma 模型
    print('repo_id: ' + repo_id)
    print('temperature: ' + str(temperature))
    print('max_length: ' + str(max_length))

    if repo_id != st.session_state.repo_id or temperature != st.session_state.temperature or max_length != st.session_state.max_length:
        try:
            with st.spinner("正在初始化 Gemma 模型..."):
                st.session_state.llm = HuggingFaceHub(repo_id=repo_id, model_kwargs={"temperature": temperature, "max_length": max_length})
                st.success("Gemma 模型初始化完成!")
                print("Gemma 模型初始化完成!")
                st.session_state.repo_id = repo_id
                st.session_state.temperature = temperature
                st.session_state.max_length = max_length
        except Exception as e:
            st.error(f"Gemma 模型加载失败:{e}")
            st.stop()

    # 获取答案
    try:
        with st.spinner("正在生成答案(基于模型自身)..."):
            pure_answer = get_answer(question)
            st.success("答案生成完毕(基于模型自身)!")
            print("答案生成完毕(基于模型自身)!")
        with st.spinner("正在筛选本地数据集..."):
            question_embedding = st.session_state.embeddings.embed_query(question)
            # question_embedding_str = " ".join(map(str, question_embedding))
            docs_and_scores = st.session_state.db.similarity_search_by_vector(question_embedding, 8)

            context_list = []
            for doc, score in docs_and_scores:
                print(str(score) + ' : ' + doc.page_content)
                context_list.append(doc.page_content)
            context = "\n".join(context_list)

            prompt = f"请根据以下知识库回答问题:\n{context}\n问题:{question}"
            print('prompt: ' + prompt)

            st.success("本地数据集筛选完成!")
            print("本地数据集筛选完成!")

        with st.spinner("正在生成答案(基于本地数据集)..."):
            answer = get_answer(prompt)
            st.success("答案生成完毕(基于本地数据集)!")
            print("答案生成完毕(基于本地数据集)!")
        return {"prompt": prompt, "answer": answer, "pure_answer": pure_answer}
    except Exception as e:
        st.error(f"问答过程出错:{e}")
        return {"prompt": "", "answer": "An error occurred during the answering process.", "pure_answer": ""}

# Model-selection and generation-parameter widgets, side by side.
col1, col2 = st.columns(2)
with col1:
    model_choices = (
        "google/gemma-2-9b-it",
        "google/gemma-2-2b-it",
        "google/recurrentgemma-2b-it",
    )
    # Default to the third option (index 2), the smallest model.
    gemma = st.selectbox("repo-id", model_choices, 2)
with col2:
    temperature = st.number_input("temperature", value=1.0)
    max_length = st.number_input("max_length", value=1024)

st.divider()

def generate_answer(repo_id, temperature, max_length, question):
    """Run the QA pipeline for *question* and render both answers in the UI.

    Shows the bare-model answer, the retrieval-augmented prompt used as
    reference text, and the RAG answer.
    """
    result = answer_question(repo_id, float(temperature), int(max_length), question)
    print('prompt: ' + result["prompt"])
    print('answer: ' + result["answer"])
    print('pure_answer: ' + result["pure_answer"])
    st.write("生成答案(无参考):")
    st.write(result["pure_answer"])
    st.divider()
    st.write("参考文字:")
    # BUG FIX: st.markdown escapes HTML by default, so the inserted <br/>
    # tags were displayed literally; allow raw HTML so they render as breaks.
    st.markdown(result["prompt"].replace('\n', '<br/>'), unsafe_allow_html=True)
    st.write("生成答案:")
    st.write(result["answer"])

col3, col4 = st.columns(2)
with col3:
    if st.button("使用原数据集中的随机问题"):
        # Pick a random Q/A pair from the loaded dataset.
        dataset_size = len(st.session_state.data_list)
        random_index = random.randrange(dataset_size)
        random_question = st.session_state.data_list[random_index]["Question"]
        origin_answer = st.session_state.data_list[random_index]["Answer"]
        # BUG FIX: log prefix was "[]" instead of the opening "[" of "[i/n]".
        print('[' + str(random_index) + '/' + str(dataset_size) + ']random_question: ' + random_question)
        print('origin_answer: ' + origin_answer)

        st.write("随机问题:")
        st.write(random_question)
        st.write("原始答案:")
        st.write(origin_answer)
        generate_answer(gemma, float(temperature), int(max_length), random_question)

with col4:
    # Free-form question entry; the default is a deliberately absurd prompt.
    question = st.text_area("请输入问题", "太阳为什么是绿色的?")
    if st.button("提交输入的问题"):
        if question:
            generate_answer(gemma, float(temperature), int(max_length), question)
        else:
            st.warning("请输入问题!")