import streamlit as st
import random
from langchain_community.llms import HuggingFaceHub
from langchain_community.embeddings import SentenceTransformerEmbeddings
from langchain_community.vectorstores import FAISS
from datasets import load_dataset
# Streamlit UI
st.title("外挂知识库问答系统")

# Load the fake-knowledge dataset once per session; cache it in session_state
# so Streamlit reruns do not re-download it.
if "data_list" not in st.session_state:
    st.session_state.data_list = []
    st.session_state.answer_list = []

if not st.session_state.data_list:
    try:
        with st.spinner("正在读取数据库..."):
            dataset = load_dataset("zeerd/fake_knowledge")
            # Peek at the first five records for debugging.
            print(dataset["train"][:5])
            qa_pairs = []
            answers = []
            for row in dataset["train"]:
                answers.append(row["Answer"])
                qa_pairs.append({"Question": row["Question"], "Answer": row["Answer"]})
            st.session_state.answer_list = answers
            st.session_state.data_list = qa_pairs
            st.success("数据库读取完成!")
            print("数据库读取完成!")
    except Exception as e:
        st.error(f"读取数据集失败:{e}")
        st.stop()
# Build the FAISS vector store once per session (guarded by a session flag).
if "vector_created" not in st.session_state:
    st.session_state.vector_created = False

if not st.session_state.vector_created:
    try:
        with st.spinner("正在构建向量数据库..."):
            # all-mpnet-base-v2 is a Sentence Transformers pre-trained model
            # that produces high-quality sentence embeddings; it performs well
            # on semantic similarity, retrieval and clustering tasks.
            st.session_state.embeddings = SentenceTransformerEmbeddings(model_name="all-mpnet-base-v2")
            st.session_state.db = FAISS.from_texts(st.session_state.answer_list, st.session_state.embeddings)
            st.success("向量数据库构建完成!")
            print("向量数据库构建完成!")
    except Exception as e:
        st.error(f"向量数据库构建失败:{e}")
        st.stop()
    st.session_state.vector_created = True
# Cached model settings; the empty-string sentinels guarantee the first call
# to answer_question() sees a mismatch and initialises the model.
for _key in ("repo_id", "temperature", "max_length"):
    if _key not in st.session_state:
        st.session_state[_key] = ''
def get_answer(prompt):
    """Query the current LLM and strip the echoed prompt from its reply."""
    raw = st.session_state.llm.invoke(prompt)
    # The hub model echoes the prompt ahead of its completion; drop it.
    cleaned = raw.replace(prompt, "").strip()
    print(cleaned)
    return cleaned
# 问答函数
def answer_question(repo_id, temperature, max_length, question):
    """Answer *question* twice: once from the bare model, once RAG-augmented.

    Args:
        repo_id: HuggingFace Hub repo id of the Gemma model.
        temperature: Sampling temperature forwarded to the model.
        max_length: Maximum generation length forwarded to the model.
        question: The user's question text.

    Returns:
        Dict with keys "prompt" (the RAG prompt), "answer" (RAG answer) and
        "pure_answer" (model-only answer); empty/error values on failure.
    """
    print('repo_id: ' + repo_id)
    print('temperature: ' + str(temperature))
    print('max_length: ' + str(max_length))
    # (Re)initialise the model only when one of the settings changed.
    if repo_id != st.session_state.repo_id or temperature != st.session_state.temperature or max_length != st.session_state.max_length:
        try:
            with st.spinner("正在初始化 Gemma 模型..."):
                st.session_state.llm = HuggingFaceHub(repo_id=repo_id, model_kwargs={"temperature": temperature, "max_length": max_length})
                st.success("Gemma 模型初始化完成!")
                print("Gemma 模型初始化完成!")
                st.session_state.repo_id = repo_id
                st.session_state.temperature = temperature
                st.session_state.max_length = max_length
        except Exception as e:
            st.error(f"Gemma 模型加载失败:{e}")
            st.stop()
    try:
        with st.spinner("正在生成答案(基于模型自身)..."):
            pure_answer = get_answer(question)
            st.success("答案生成完毕(基于模型自身)!")
            print("答案生成完毕(基于模型自身)!")
        with st.spinner("正在筛选本地数据集..."):
            question_embedding = st.session_state.embeddings.embed_query(question)
            # BUG FIX: similarity_search_by_vector returns plain Documents;
            # the (doc, score) unpacking below requires the *_with_score_*
            # variant, which returns (Document, score) pairs.
            docs_and_scores = st.session_state.db.similarity_search_with_score_by_vector(question_embedding, k=8)
            context_list = []
            for doc, score in docs_and_scores:
                print(str(score) + ' : ' + doc.page_content)
                context_list.append(doc.page_content)
            context = "\n".join(context_list)
            prompt = f"请根据以下知识库回答问题:\n{context}\n问题:{question}"
            print('prompt: ' + prompt)
            st.success("本地数据集筛选完成!")
            print("本地数据集筛选完成!")
        with st.spinner("正在生成答案(基于本地数据集)..."):
            answer = get_answer(prompt)
            st.success("答案生成完毕(基于本地数据集)!")
            print("答案生成完毕(基于本地数据集)!")
        return {"prompt": prompt, "answer": answer, "pure_answer": pure_answer}
    except Exception as e:
        st.error(f"问答过程出错:{e}")
        return {"prompt": "", "answer": "An error occurred during the answering process.", "pure_answer": ""}
# Model-selection and generation-parameter widgets, laid out in two columns.
col1, col2 = st.columns(2)
with col1:
    # Gemma models on the HuggingFace Hub; index 2 (recurrentgemma) is the default.
    gemma = st.selectbox("repo-id", ("google/gemma-2-9b-it", "google/gemma-2-2b-it", "google/recurrentgemma-2b-it"), 2)
with col2:
    temperature = st.number_input("temperature", value=1.0)
    max_length = st.number_input("max_length", value=1024)
st.divider()
def generate_answer(repo_id, temperature, max_length, question):
    """Run the QA pipeline and render both answers in the Streamlit page.

    Args:
        repo_id: HuggingFace Hub repo id to pass through to answer_question().
        temperature: Sampling temperature (coerced to float).
        max_length: Maximum generation length (coerced to int).
        question: The question to answer.
    """
    result = answer_question(repo_id, float(temperature), int(max_length), question)
    print('prompt: ' + result["prompt"])
    print('answer: ' + result["answer"])
    print('pure_answer: ' + result["pure_answer"])
    st.write("生成答案(无参考):")
    st.write(result["pure_answer"])
    st.divider()
    st.write("参考文字:")
    # BUG FIX: st.markdown escapes HTML by default, so the inserted <br/>
    # tags were displayed literally; allow raw HTML so they render as breaks.
    st.markdown(result["prompt"].replace('\n', '<br/>'), unsafe_allow_html=True)
    st.write("生成答案:")
    st.write(result["answer"])
# Two interaction paths: a random question from the dataset, or free text.
col3, col4 = st.columns(2)
with col3:
    if st.button("使用原数据集中的随机问题"):
        dataset_size = len(st.session_state.data_list)
        # randrange(n) is the idiomatic "random index into n items".
        random_index = random.randrange(dataset_size)
        # 读取随机问题
        random_question = st.session_state.data_list[random_index]["Question"]
        origin_answer = st.session_state.data_list[random_index]["Answer"]
        # BUG FIX: opening bracket was mistyped as '[]' in this log line.
        print('[' + str(random_index) + '/' + str(dataset_size) + ']random_question: ' + random_question)
        print('origin_answer: ' + origin_answer)
        st.write("随机问题:")
        st.write(random_question)
        st.write("原始答案:")
        st.write(origin_answer)
        generate_answer(gemma, float(temperature), int(max_length), random_question)
with col4:
    question = st.text_area("请输入问题", "太阳为什么是绿色的?")
    if st.button("提交输入的问题"):
        if not question:
            st.warning("请输入问题!")
        else:
            generate_answer(gemma, float(temperature), int(max_length), question)