Spaces:
Runtime error
Runtime error
"""Python file to serve as the frontend""" | |
import streamlit as st | |
from streamlit_chat import message | |
from langchain.chains import VectorDBQAWithSourcesChain | |
from langchain.embeddings.openai import OpenAIEmbeddings | |
from langchain.vectorstores import Chroma | |
from langchain.chat_models import ChatOpenAI | |
from langchain.prompts.chat import ( | |
ChatPromptTemplate, | |
SystemMessagePromptTemplate, | |
HumanMessagePromptTemplate, | |
) | |
# Must be the first Streamlit call of the script.
st.set_page_config(page_title="D&D 🗡️ Spell QA Bot", page_icon="🗡️")

# Chat prompt for the QA chain: a system message that carries the retrieved
# context ({summaries}) together with the answer-format rules, followed by a
# human message holding the user's question ({question}).
system_template = """Use the following pieces of context to answer the users question.
If you don't know the answer, just say that you don't know, don't try to make up an answer.
ALWAYS return a "SOURCES" part in your answer.
The "SOURCES" part should be a reference to the source of the document from which you got your answer.
Example of your response should be:
```
The answer is foo
SOURCES: xyz
```
Begin!
----------------
{summaries}"""
_system_message = SystemMessagePromptTemplate.from_template(system_template)
_human_message = HumanMessagePromptTemplate.from_template("{question}")
messages = [_system_message, _human_message]
prompt = ChatPromptTemplate.from_messages(messages)
def load_chroma(persist_directory="db_spells"):
    """Open the persisted Chroma vector store that backs the spell QA chain.

    Args:
        persist_directory: Directory containing the persisted Chroma data.
            Defaults to ``"db_spells"``, the index this app queries.

    Returns:
        A ``Chroma`` vector store whose embedding function is an
        ``OpenAIEmbeddings`` instance.
    """
    # Queries are embedded with OpenAIEmbeddings; presumably the index was
    # built with the same embedding model — TODO confirm against the script
    # that populated the persist directory.
    embeddings = OpenAIEmbeddings()
    return Chroma(
        persist_directory=persist_directory,
        embedding_function=embeddings,
    )
# Streamlit re-executes this whole script on every widget interaction. Keep the
# vector store and the QA chain in this session's state so the Chroma index is
# loaded and the chain is built only once per session, not on every rerun.
# NOTE(review): newer LangChain releases deprecate VectorDBQAWithSourcesChain
# in favour of RetrievalQAWithSourcesChain — confirm against the pinned
# LangChain version before upgrading.
if "vectordb" not in st.session_state:
    st.session_state["vectordb"] = load_chroma()
vectordb = st.session_state["vectordb"]

chain_type_kwargs = {"prompt": prompt}
if "chain" not in st.session_state:
    st.session_state["chain"] = VectorDBQAWithSourcesChain.from_chain_type(
        ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0),
        chain_type="stuff",
        vectorstore=vectordb,
        chain_type_kwargs=chain_type_kwargs,
    )
chain = st.session_state["chain"]
# From here down is all the Streamlit UI.
st.header("D&D 🗡️ Spell QA Bot")
# Blank lines separate the paragraphs; without them Markdown would merge the
# three sentences into a single paragraph.
st.markdown(
    """
This is a chatbot that can answer questions about **Dungeons & Dragons spells** based on this [database](https://www.aidedd.org/dnd-filters/spells-5e.php) and built with LangChain and OpenAI API.

The creator of this bot is **[Corentin Meyer (@corentinm_py)](https://twitter.com/corentinm_py)**.

Try it yourself by typing something like: "What's the size of tsunami spell ?"
"""
)
# Chat history lives in session state so it survives reruns:
# "generated" holds the bot's answers and "past" the user's questions,
# aligned by index.
for history_key in ("generated", "past"):
    if history_key not in st.session_state:
        st.session_state[history_key] = []
def get_text():
    """Render the question text box and return its current contents.

    The box is pre-filled with an example question and uses the widget
    key ``"input"``.
    """
    return st.text_input("You: ", "What's the size of tsunami spell ?", key="input")
user_input = get_text()

# Streamlit reruns the script on every interaction, and the text box keeps its
# value (it even starts pre-filled). Query the chain only when the text differs
# from the most recent question. Otherwise every rerun would call the OpenAI
# API again and append a duplicate exchange to the history.
last_question = st.session_state.past[-1] if st.session_state.past else None
if user_input and user_input != last_question:
    result = chain(
        {"question": user_input},
        return_only_outputs=True,
    )
    output = f"Answer: {result['answer']}\nSources: {result['sources']}"
    st.session_state.past.append(user_input)
    st.session_state.generated.append(output)
# Render the conversation newest-first. Within each exchange the bot's answer is
# drawn before the question that produced it. Keys derive from the exchange
# index, so every chat widget gets a unique key.
if st.session_state["generated"]:
    exchanges = list(
        enumerate(zip(st.session_state["generated"], st.session_state["past"]))
    )
    for i, (answer, question) in reversed(exchanges):
        message(answer, key=str(i))
        message(question, is_user=True, key=str(i) + "_user")