Charles Chan committed on
Commit 630d3f4 · 1 Parent(s): 86b4310
Files changed (1)
  1. app.py +6 -1
app.py CHANGED
@@ -1,9 +1,10 @@
import streamlit as st
+ import random
from langchain_community.llms import HuggingFaceHub
from langchain_community.embeddings import SentenceTransformerEmbeddings
from langchain_community.vectorstores import FAISS
from datasets import load_dataset
- import random
+ from transformers import pipeline

# 使用 進擊的巨人 数据集
try:
@@ -31,6 +32,7 @@ def answer_question(repo_id, temperature, max_length, question):
try:
with st.spinner("正在初始化 Gemma 模型..."):
llm = HuggingFaceHub(repo_id=repo_id, model_kwargs={"temperature": temperature, "max_length": max_length})
+ st.success("Gemma 模型初始化完成!")
except Exception as e:
st.error(f"Gemma 模型加载失败:{e}")
st.stop()
@@ -49,10 +51,13 @@ def answer_question(repo_id, temperature, max_length, question):
prompt = f"请根据以下知识库回答问题:\n{context}\n问题:{question}"
print('prompt: ' + prompt)

+ st.success("本地数据集筛选完成!")
+
with st.spinner("正在生成答案..."):
answer = llm.invoke(prompt)
# 去掉 prompt 的内容
answer = answer.replace(prompt, "").strip()
+ st.success("答案已经生成!")
return {"prompt": prompt, "answer": answer}
except Exception as e:
st.error(f"问答过程出错:{e}")
 