miaohaiyuan commited on
Commit
9cf8e68
β€’
1 Parent(s): 296d10f

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +121 -0
app.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from typing import Generator
3
+ from groq import Groq
4
+
5
+ st.set_page_config(page_icon="πŸ’¬", layout="wide", page_title="Groq Goes Brrrrrrrr...")
6
+
7
+
8
+ def icon(emoji: str):
9
+ """Shows an emoji as a Notion-style page icon."""
10
+ st.write(
11
+ f'<span style="font-size: 78px; line-height: 1">{emoji}</span>',
12
+ unsafe_allow_html=True,
13
+ )
14
+
15
+
16
+ icon("🏎️")
17
+
18
+ st.subheader("Groq Chat Streamlit App", divider="rainbow", anchor=False)
19
+
20
+ client = Groq(
21
+ api_key=st.secrets["GROQ_API_KEY"],
22
+ )
23
+
24
+ # Initialize chat history and selected model
25
+ if "messages" not in st.session_state:
26
+ st.session_state.messages = []
27
+
28
+ if "selected_model" not in st.session_state:
29
+ st.session_state.selected_model = None
30
+
31
+ # Define model details
32
+ models = {
33
+ "mixtral-8x7b-32768": {
34
+ "name": "Mixtral-8x7b-Instruct-v0.1",
35
+ "tokens": 32768,
36
+ "developer": "Mistral",
37
+ },
38
+ "llama2-70b-4096": {"name": "LLaMA2-70b-chat", "tokens": 4096, "developer": "Meta"},
39
+ "gemma-7b-it": {"name": "Gemma-7b-it", "tokens": 8192, "developer": "Google"},
40
+ }
41
+
42
+ # Layout for model selection and max_tokens slider
43
+ col1, col2 = st.columns(2)
44
+
45
+ with col1:
46
+ model_option = st.selectbox(
47
+ "Choose a model:",
48
+ options=list(models.keys()),
49
+ format_func=lambda x: models[x]["name"],
50
+ index=0, # Default to the first model in the list
51
+ )
52
+
53
+ # Detect model change and clear chat history if model has changed
54
+ if st.session_state.selected_model != model_option:
55
+ st.session_state.messages = []
56
+ st.session_state.selected_model = model_option
57
+
58
+ max_tokens_range = models[model_option]["tokens"]
59
+
60
+ with col2:
61
+ # Adjust max_tokens slider dynamically based on the selected model
62
+ max_tokens = st.slider(
63
+ "Max Tokens:",
64
+ min_value=512, # Minimum value to allow some flexibility
65
+ max_value=max_tokens_range,
66
+ # Default value or max allowed if less
67
+ value=min(32768, max_tokens_range),
68
+ step=512,
69
+ help=f"Adjust the maximum number of tokens (words) for the model's response. Max for selected model: {max_tokens_range}",
70
+ )
71
+
72
+ # Display chat messages from history on app rerun
73
+ for message in st.session_state.messages:
74
+ avatar = "πŸ€–" if message["role"] == "assistant" else "πŸ‘¨β€πŸ’»"
75
+ with st.chat_message(message["role"], avatar=avatar):
76
+ st.markdown(message["content"])
77
+
78
+
79
+ def generate_chat_responses(chat_completion) -> Generator[str, None, None]:
80
+ """Yield chat response content from the Groq API response."""
81
+ for chunk in chat_completion:
82
+ if chunk.choices[0].delta.content:
83
+ yield chunk.choices[0].delta.content
84
+
85
+
86
+ if prompt := st.chat_input("Enter your prompt here..."):
87
+ st.session_state.messages.append({"role": "user", "content": prompt})
88
+
89
+ with st.chat_message("user", avatar="πŸ‘¨β€πŸ’»"):
90
+ st.markdown(prompt)
91
+
92
+ # Fetch response from Groq API
93
+ try:
94
+ chat_completion = client.chat.completions.create(
95
+ model=model_option,
96
+ messages=[
97
+ {"role": m["role"], "content": m["content"]}
98
+ for m in st.session_state.messages
99
+ ],
100
+ max_tokens=max_tokens,
101
+ stream=True,
102
+ )
103
+
104
+ # Use the generator function with st.write_stream
105
+ with st.chat_message("assistant", avatar="πŸ€–"):
106
+ chat_responses_generator = generate_chat_responses(chat_completion)
107
+ full_response = st.write_stream(chat_responses_generator)
108
+ except Exception as e:
109
+ st.error(e, icon="🚨")
110
+
111
+ # Append the full response to session_state.messages
112
+ if isinstance(full_response, str):
113
+ st.session_state.messages.append(
114
+ {"role": "assistant", "content": full_response}
115
+ )
116
+ else:
117
+ # Handle the case where full_response is not a string
118
+ combined_response = "\n".join(str(item) for item in full_response)
119
+ st.session_state.messages.append(
120
+ {"role": "assistant", "content": combined_response}
121
+ )