demirali committed on
Commit
3a638f5
·
verified ·
1 Parent(s): 0a1a3d2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +106 -5
app.py CHANGED
@@ -1,7 +1,108 @@
1
- import gradio as gr
 
 
 
 
 
 
 
 
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
5
 
6
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
"""Streamlit RAG chatbot: Chroma retrieval + Groq-hosted LLM with chat memory."""

import os

import streamlit as st
from dotenv import load_dotenv
from langchain.chains.question_answering import load_qa_chain
from langchain.memory import ConversationBufferMemory
from langchain_chroma import Chroma
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_core.prompts import PromptTemplate
from langchain_groq import ChatGroq
from sentence_transformers import SentenceTransformer  # noqa: F401 -- unused here; presumably kept so the model package is present — TODO confirm

st.title("Chatbot")

# Load environment variables and fail fast if the API key is missing.
# A bare `assert` is stripped under `python -O`, so report through the UI instead.
load_dotenv()
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
if not GROQ_API_KEY:
    st.error("GROQ_API_KEY environment variable not set.")
    st.stop()

# One-time-setup flag kept in session state so heavy resources survive reruns.
if "initialized" not in st.session_state:
    st.session_state.initialized = False

# BUG FIX: the original ran the whole initialization on *every* Streamlit rerun,
# defeating the `initialized` flag. Guard it so the embeddings model, vector
# store, LLM client, memory, and chain are built exactly once per session.
if not st.session_state.initialized:
    try:
        with st.spinner("Initializing..."):
            # Embeddings model (small/fast variant, CPU-only).
            model_path = "sentence-transformers/all-MiniLM-L12-v2"
            st.session_state.embedding_function = HuggingFaceEmbeddings(
                model_name=model_path,
                model_kwargs={'device': 'cpu'},
                encode_kwargs={'normalize_embeddings': False},
            )

            # Persistent Chroma store used for document similarity search.
            persist_directory = "doc_db"
            st.session_state.docsearch = Chroma(
                persist_directory=persist_directory,
                embedding_function=st.session_state.embedding_function,
            )

            # Groq-hosted chat model; temperature 0 for deterministic answers.
            st.session_state.chat_model = ChatGroq(
                model="llama-3.1-8b-instant",
                temperature=0,
                api_key=GROQ_API_KEY,
            )

            # Prompt template and conversation memory for the QA chain.
            template = """You are a chatbot having a conversation with a human. Your name is Devrim.
Given the following extracted parts of a long document and a question, create a final answer. If the answer is not in the document or irrelevant, just say that you don't know, don't try to make up an answer.
{context}
{chat_history}
Human: {human_input}
Chatbot:"""
            prompt = PromptTemplate(
                input_variables=["chat_history", "human_input", "context"],
                template=template,
            )
            st.session_state.memory = ConversationBufferMemory(
                memory_key="chat_history", input_key="human_input"
            )

            # "stuff" chain: all retrieved docs are packed into a single prompt.
            st.session_state.qa_chain = load_qa_chain(
                llm=st.session_state.chat_model,
                chain_type="stuff",
                memory=st.session_state.memory,
                prompt=prompt,
            )

        st.session_state.initialized = True
        st.success("Initialization successful.")
    except Exception as e:  # surface any setup failure in the UI
        st.session_state.initialized = False
        st.error(f"Initialization failed: {e}")

# Clear chat history button.
if st.button("Clear Chat History"):
    if 'memory' in st.session_state:
        st.session_state.memory.clear()
    # FIX: st.experimental_rerun() was deprecated and removed; st.rerun() is
    # available in every Streamlit version that ships st.chat_input/chat_message.
    st.rerun()

# Replay the stored conversation so it survives Streamlit reruns.
if st.session_state.initialized and 'memory' in st.session_state:
    for message in st.session_state.memory.buffer_as_messages:
        if message.type == "ai":
            st.chat_message(name="ai", avatar="🤖").write(message.content)
        else:
            st.chat_message(name="human", avatar="👤").write(message.content)

# Handle a new user query.
query = st.chat_input("Ask something")
if query:
    try:
        with st.spinner("Answering..."):
            # k=1 keeps retrieval fast; raise k for better recall if needed.
            docs = st.session_state.docsearch.similarity_search(query, k=1)
            # FIX: `.invoke` replaces the deprecated `Chain.__call__` API; the
            # returned dict still exposes the answer under "output_text".
            response = st.session_state.qa_chain.invoke(
                {"input_documents": docs, "human_input": query}
            )["output_text"]

        # Show the new exchange (earlier turns are replayed from memory above).
        st.chat_message(name="human", avatar="👤").write(query)
        st.chat_message(name="ai", avatar="🤖").write(response)
    except Exception as e:
        st.error(f"An error occurred: {e}")