sosa123454321 committed on
Commit
184cc2a
·
verified ·
1 Parent(s): 1784337

Create app.py

Files changed (1)
  1. app.py +130 -0
app.py ADDED
@@ -0,0 +1,130 @@
+ import streamlit as st
+ from typing import Generator
+ from groq import Groq
+
+ st.set_page_config(page_icon="💬", layout="wide", page_title="Groq Goes Brrrrrrrr...")
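+ # st.set_page_config must be the first Streamlit command executed in the script.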
+
+
+ def icon(emoji: str):
+     """Shows an emoji as a Notion-style page icon."""
+     st.write(
+         f'<span style="font-size: 78px; line-height: 1">{emoji}</span>',
+         unsafe_allow_html=True,
+     )
+
+
+ icon("🏎️")
+
+ st.subheader("Groq Chat Streamlit App", divider="rainbow", anchor=False)
+
+ client = Groq(
+     api_key=st.secrets["GROQ_API_KEY"],
+ )
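+ # The key is looked up in Streamlit's secrets store (e.g. .streamlit/secrets.toml);
+ # see the setup note after the diff.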
+
+ # Initialize chat history and selected model
+ if "messages" not in st.session_state:
+     st.session_state.messages = []
+
+ if "selected_model" not in st.session_state:
+     st.session_state.selected_model = None
+
+ # Define model details
+ models = {
+     "llama3-70b-8192": {
+         "name": "LLaMA3-70b-Instruct",
+         "tokens": 8192,
+         "developer": "Meta",
+     },
+     "llama3-8b-8192": {
+         "name": "LLaMA3-8b-Instruct",
+         "tokens": 8192,
+         "developer": "Meta",
+     },
+     "mixtral-8x7b-32768": {
+         "name": "Mixtral-8x7b-Instruct-v0.1",
+         "tokens": 32768,
+         "developer": "Mistral",
+     },
+     "gemma-7b-it": {"name": "Gemma-7b-it", "tokens": 8192, "developer": "Google"},
+ }
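+ # "tokens" is each model's context window; it caps the Max Tokens slider below.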
+
+ # Layout for model selection and max_tokens slider
+ col1, col2 = st.columns(2)
+
+ with col1:
+     model_option = st.selectbox(
+         "Choose a model:",
+         options=list(models.keys()),
+         format_func=lambda x: models[x]["name"],
+         index=0,  # Default to the first model in the list
+     )
+
+ # Clear the chat history whenever the selected model changes
+ if st.session_state.selected_model != model_option:
+     st.session_state.messages = []
+     st.session_state.selected_model = model_option
+
+ max_tokens_range = models[model_option]["tokens"]
+
+ with col2:
+     # Adjust the max_tokens slider dynamically based on the selected model
+     max_tokens = st.slider(
+         "Max Tokens:",
+         min_value=512,  # Minimum value to allow some flexibility
+         max_value=max_tokens_range,
+         # Default value, capped at the maximum the model allows
+         value=min(32768, max_tokens_range),
+         step=512,
+         help=f"Adjust the maximum number of tokens in the model's response. Max for the selected model: {max_tokens_range}",
+     )
+
+ # Display chat messages from history on app rerun
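+ # (Streamlit reruns the whole script on every interaction, so the saved
+ # history must be redrawn each time.)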
+ for message in st.session_state.messages:
+     avatar = "🤖" if message["role"] == "assistant" else "👨‍💻"
+     with st.chat_message(message["role"], avatar=avatar):
+         st.markdown(message["content"])
+
+
+ def generate_chat_responses(chat_completion) -> Generator[str, None, None]:
+     """Yield chat response content from the Groq API response."""
+     for chunk in chat_completion:
+         if chunk.choices[0].delta.content:
+             yield chunk.choices[0].delta.content
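+ # The truthiness check above skips chunks whose delta carries no text (such as
+ # the final chunk of a stream), so only real content reaches st.write_stream.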
+
+
+ if prompt := st.chat_input("Enter your prompt here..."):
+     st.session_state.messages.append({"role": "user", "content": prompt})
+
+     with st.chat_message("user", avatar="👨‍💻"):
+         st.markdown(prompt)
+
+     # Fetch response from Groq API
+     full_response = None  # Set before the try block so a failed request leaves it defined
+     try:
+         chat_completion = client.chat.completions.create(
+             model=model_option,
+             messages=[
+                 {"role": m["role"], "content": m["content"]}
+                 for m in st.session_state.messages
+             ],
+             max_tokens=max_tokens,
+             stream=True,
+         )
+
+         # Use the generator function with st.write_stream
+         with st.chat_message("assistant", avatar="🤖"):
+             chat_responses_generator = generate_chat_responses(chat_completion)
+             full_response = st.write_stream(chat_responses_generator)
+     except Exception as e:
+         st.error(e, icon="🚨")
+
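+     # st.write_stream returns the full response: a plain string when only text
+     # was streamed, otherwise a list of everything that was written.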
+     # Append the full response to session_state.messages
+     if isinstance(full_response, str):
+         st.session_state.messages.append(
+             {"role": "assistant", "content": full_response}
+         )
+     elif full_response is not None:
+         # Handle the case where full_response is a list of streamed objects
+         combined_response = "\n".join(str(item) for item in full_response)
+         st.session_state.messages.append(
+             {"role": "assistant", "content": combined_response}
+         )
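
Setup note: the app reads the API key via st.secrets, so a key must be configured
before the first run. A minimal sketch for running locally, assuming the default
.streamlit/secrets.toml location (the key value below is a placeholder, not a real key):

    # .streamlit/secrets.toml
    GROQ_API_KEY = "gsk_your_key_here"

    # install dependencies and start the app
    pip install streamlit groq
    streamlit run app.py

On a hosted deployment, the same GROQ_API_KEY has to be supplied through whatever
secrets mechanism the host exposes to st.secrets.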