Charles Chan commited on
Commit
cb8213b
·
1 Parent(s): 9ddf764
Files changed (2) hide show
  1. app.py +26 -1
  2. requirements.txt +1 -0
app.py CHANGED
@@ -2,6 +2,8 @@ import streamlit as st
2
  from langchain_community.llms import HuggingFaceHub
3
  from langchain_community.embeddings import SentenceTransformerEmbeddings
4
  from langchain_community.vectorstores import FAISS
 
 
5
 
6
  # 1. 准备知识库数据 (示例)
7
  knowledge_base = [
@@ -12,10 +14,19 @@ knowledge_base = [
12
  "Gemma 支持多种语言。"
13
  ]
14
 
 
 
 
 
 
 
 
 
 
15
  # 2. 构建向量数据库 (如果需要,仅构建一次)
16
  try:
17
  embeddings = SentenceTransformerEmbeddings(model_name="all-mpnet-base-v2")
18
- db = FAISS.from_texts(knowledge_base, embeddings)
19
  except Exception as e:
20
  st.error(f"向量数据库构建失败:{e}")
21
  st.stop()
@@ -56,6 +67,20 @@ temperature = st.number_input("temperature", value=1.0)
56
  max_length = st.number_input("max_length", value=1024)
57
  question = st.text_area("请输入问题", "Gemma 有哪些特点?")
58
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
  if st.button("提交"):
60
  if not question:
61
  st.warning("请输入问题!")
 
2
  from langchain_community.llms import HuggingFaceHub
3
  from langchain_community.embeddings import SentenceTransformerEmbeddings
4
  from langchain_community.vectorstores import FAISS
5
+ from datasets import load_dataset
6
+ import random
7
 
8
  # 1. 准备知识库数据 (示例)
9
  knowledge_base = [
 
14
  "Gemma 支持多种语言。"
15
  ]
16
 
17
+ try:
18
+ dataset = load_dataset("rorubyy/attack_on_titan_wiki_chinese")
19
+ answer_list = [example["Answer"] for example in dataset["train"]]
20
+
21
+ except Exception as e:
22
+ st.error(f"读取数据集失败:{e}")
23
+ st.stop()
24
+
25
+
26
  # 2. 构建向量数据库 (如果需要,仅构建一次)
27
  try:
28
  embeddings = SentenceTransformerEmbeddings(model_name="all-mpnet-base-v2")
29
+ db = FAISS.from_texts(answer_list, embeddings)
30
  except Exception as e:
31
  st.error(f"向量数据库构建失败:{e}")
32
  st.stop()
 
67
  max_length = st.number_input("max_length", value=1024)
68
  question = st.text_area("请输入问题", "Gemma 有哪些特点?")
69
 
70
+ if st.button("随机"):
71
+ dataset_size = len(dataset["train"])
72
+ random_index = random.randint(0, dataset_size - 1)
73
+ # 读取随机问题
74
+ random_question = dataset["train"][random_index]["Question"]
75
+ origin_answer = dataset["train"][random_index]["Answer"]
76
+ st.write("随机问题:")
77
+ st.write(random_question)
78
+ st.write("原始答案:")
79
+ st.write(origin_answer)
80
+ answer = answer_question(gemma, float(temperature), int(max_length), random_question)
81
+ st.write("生成答案:")
82
+ st.write(answer)
83
+
84
  if st.button("提交"):
85
  if not question:
86
  st.warning("请输入问题!")
requirements.txt CHANGED
@@ -5,3 +5,4 @@ langchain-community
5
  langchain-huggingface
6
  sentence_transformers
7
  faiss-cpu
 
 
5
  langchain-huggingface
6
  sentence_transformers
7
  faiss-cpu
8
+ datasets