# Ashhar — gsheet save tool (commit 061239f)
import streamlit as st
import os
import datetime as DT
import pytz
import time
import json
import re
from typing import List
from transformers import AutoTokenizer
from gradio_client import Client
from tools import toolsInfo

from dotenv import load_dotenv
load_dotenv()

# Select the LLM backend from the environment: OpenAI's gpt-4o-mini when
# USE_GPT_4=1, otherwise a Groq-hosted Llama model. Each branch also sets
# the matching context budget and a tokenizer used for context trimming.
useGpt4 = os.environ.get("USE_GPT_4") == "1"

if useGpt4:
    from openai import OpenAI
    client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
    MODEL = "gpt-4o-mini"
    MAX_CONTEXT = 128000
    tokenizer = AutoTokenizer.from_pretrained("Xenova/gpt-4o")
else:
    from groq import Groq
    client = Groq(
        api_key=os.environ.get("GROQ_API_KEY"),
    )
    MODEL = "llama-3.1-70b-versatile"
    # NOTE(review): this immediately overrides the assignment above, so the
    # "versatile" model is never used — confirm which model is intended and
    # delete the other line.
    MODEL = "llama3-groq-70b-8192-tool-use-preview"
    MAX_CONTEXT = 8000
    tokenizer = AutoTokenizer.from_pretrained("Xenova/Meta-Llama-3.1-Tokenizer")
def countTokens(text):
    """Return how many tokens the active tokenizer produces for *text*.

    Non-string inputs (e.g. whole message lists) are stringified first so
    arbitrary payloads can be measured. Special tokens are excluded.
    """
    encoded = tokenizer.encode(str(text), add_special_tokens=False)
    return len(encoded)
# System prompt that drives the step-by-step outreach workflow:
# collect industry -> role -> details, then persist them to a Google Sheet.
SYSTEM_MSG = f"""
You are a personalized email generator for cold outreach. You take the user through a workflow. Step by step.
- You ask for industry of the recipient
- His/her role
- More details about the recipient
Highlight the exact entity you're requesting for.
Once collected, you store these info in a Google Sheet
"""

# Static UI assets (paths relative to the app root).
USER_ICON = "icons/man.png"
ASSISTANT_ICON = "icons/magic-wand(1).png"
TOOL_ICON = "icons/completed-task.png"
IMAGE_LOADER = "icons/ripple.svg"
TEXT_LOADER = "icons/balls.svg"

# Label for the one-time conversation-starter button.
START_MSG = "Let's start 😊"

# Avatar image shown next to each chat role when rendering history.
ROLE_TO_AVATAR = {
    "user": USER_ICON,
    "assistant": ASSISTANT_ICON,
    "tool": TOOL_ICON,
}
st.set_page_config(
    page_title="EmailGenie",
    page_icon=ASSISTANT_ICON,
)

# Client IP as forwarded by the reverse proxy; used only to prefix log lines.
# NOTE(review): may be None when the app is not behind a proxy.
ipAddress = st.context.headers.get("x-forwarded-for")
def __nowInIST() -> DT.datetime:
    """Return the current timezone-aware datetime in Asia/Kolkata (IST)."""
    istZone = pytz.timezone("Asia/Kolkata")
    return DT.datetime.now(istZone)
def pprint(log: str):
    """Log *log* to stdout, prefixed with an IST timestamp and the client IP."""
    stamp = __nowInIST().strftime("%Y-%m-%d %H:%M:%S")
    print(f"[{stamp}] [{ipAddress}] {log}")
pprint("\n")
st.markdown(
"""
<style>
@keyframes blinker {
0% {
opacity: 1;
}
50% {
opacity: 0.2;
}
100% {
opacity: 1;
}
}
.blinking {
animation: blinker 3s ease-out infinite;
}
.code {
color: green;
border-radius: 3px;
padding: 2px 4px; /* Padding around the text */
font-family: 'Courier New', Courier, monospace; /* Monospace font */
}
</style>
""",
unsafe_allow_html=True
)
def __isInvalidResponse(response: str) -> bool:
    """Heuristically detect degenerate LLM output.

    Returns True when the text looks broken (mid-sentence line breaks,
    word-repetition loops, or an excessive number of paragraphs), else False.
    The original implicitly returned None on the valid path; an explicit
    ``return False`` makes the contract a real bool.
    """
    # A newline followed by a lowercase letter usually means the model broke
    # a sentence across lines; more than 3 of these is suspicious.
    if len(re.findall(r'\n[a-z]', response)) > 3:
        return True
    # The same word repeated 3+ times in a row signals a degeneration loop.
    if len(re.findall(r'\b(\w+)(\s+\1){2,}\b', response)) > 1:
        return True
    # Too many paragraph breaks for a single chat reply.
    if len(re.findall(r'\n\n', response)) > 15:
        return True
    return False
def __matchingKeywordsCount(keywords: List[str], text: str):
    """Count how many of *keywords* occur as substrings of *text*."""
    count = 0
    for keyword in keywords:
        if keyword in text:
            count += 1
    return count
def __isStringNumber(s: str) -> bool:
    """True when *s* parses as a float (e.g. "3", "3.14", "-2e5")."""
    try:
        float(s)
    except ValueError:
        return False
    return True
def __resetButtonState():
    """Clear any previously clicked suggestion-button value from session state.

    Uses attribute access; Streamlit documents attribute and key access on
    session_state as equivalent.
    """
    st.session_state.buttonValue = ""
def __setStartMsg(msg):
    """Store *msg* as the start-button message in session state.

    Uses key access; Streamlit documents key and attribute access on
    session_state as equivalent.
    """
    st.session_state["startMsg"] = msg
# One-time per-session state initialisation (Streamlit reruns this script on
# every interaction, so each default is guarded).
if "chatHistory" not in st.session_state:
    # Conversation as rendered in the UI (includes visible tool results).
    st.session_state.chatHistory = []
if "messages" not in st.session_state:
    # Conversation in API message format, sent to the model each turn.
    st.session_state.messages = []
if "buttonValue" not in st.session_state:
    __resetButtonState()
if "startMsg" not in st.session_state:
    st.session_state.startMsg = ""
def __getMessages():
    """Return the chat messages, evicting the oldest until they fit the model.

    The budget is system-prompt tokens + message tokens + a 100-token safety
    margin; messages are dropped from the front while it exceeds MAX_CONTEXT.
    """
    def _budget():
        used = countTokens(SYSTEM_MSG) + countTokens(st.session_state.messages) + 100
        pprint(f"currContextSize={used}")
        return used

    while _budget() > MAX_CONTEXT:
        pprint("Context size exceeded, removing first message")
        st.session_state.messages.pop(0)
    return st.session_state.messages
# Tool schemas advertised to the model; only the Google-Sheet saver for now.
tools = [
    toolsInfo["saveInGSheet"]["schema"]
]
def __showTaskStatus(msg):
    # Render a completed-task row: the tool icon plus *msg* styled with the
    # green monospace .code CSS class injected at the top of the script.
    taskContainer = st.container()
    taskContainer.image(TOOL_ICON, width=30)
    taskContainer.markdown(
        f"""
        <div class='code'>
            {msg}
        </div>
        """,
        unsafe_allow_html=True
    )
def __processToolCalls(tool_calls):
    """Execute each requested tool call and record its result.

    For every call: look up the implementation in toolsInfo, invoke it with
    the JSON-decoded arguments, append a "tool" message to the model
    conversation, and mirror results flagged shouldShow into the visible
    chat history (with an on-screen status row).
    """
    for call in tool_calls:
        name = call.function.name
        args = json.loads(call.function.arguments)
        result = toolsInfo[name]["func"](**args)
        output = result["response"]
        pprint(f"functionResponse={output!r}")

        if result["shouldShow"]:
            st.session_state.chatHistory.append({
                "role": "tool",
                "content": output,
            })
            __showTaskStatus(output)

        st.session_state.messages.append({
            "role": "tool",
            "tool_call_id": call.id,
            "name": name,
            "content": output,
        })
def __process_stream_chunk(chunk):
    """Extract the payload of one streamed completion chunk.

    Returns the text delta (str) when present, otherwise the first tool-call
    delta, otherwise None (keep-alive/empty chunk).
    """
    delta = chunk.choices[0].delta
    if delta.content:
        return delta.content
    if delta.tool_calls:
        return delta.tool_calls[0]
    return None
def __addToolCallsToMsgs(toolCalls):
    """Append an assistant message that echoes the model's tool calls.

    The API requires the assistant's tool_calls message to precede the
    corresponding "tool" result messages in the conversation.
    """
    serialized = []
    for tc in toolCalls:
        serialized.append({
            "id": tc.id,
            "function": {
                "name": tc.function.name,
                "arguments": tc.function.arguments,
            },
            "type": tc.type,
        })
    st.session_state.messages.append({
        "role": "assistant",
        "tool_calls": serialized,
    })
def __add_tool_call(tool_call):
    """Record a single assistant tool call in the conversation messages.

    Delegates to __addToolCallsToMsgs so the tool-call message structure is
    defined in exactly one place — the original duplicated the dict layout
    verbatim, which risks the two copies drifting apart.
    """
    __addToolCallsToMsgs([tool_call])
def predict1():
    # NOTE(review): streaming-only variant of predict() below; nothing in
    # this file calls it — likely superseded. Confirm before deleting.
    shouldStream = True
    # Prompt = system message + history trimmed to fit the context window.
    messagesFormatted = [{"role": "system", "content": SYSTEM_MSG}]
    messagesFormatted.extend(__getMessages())
    contextSize = countTokens(messagesFormatted)
    pprint(f"{contextSize=} | {MODEL}")
    response = client.chat.completions.create(
        model=MODEL,
        messages=messagesFormatted,
        temperature=0.8,
        max_tokens=4000,
        stream=shouldStream,
        tools=tools
    )
    content = ""
    tool_call = None
    for chunk in response:
        chunk_content = __process_stream_chunk(chunk)
        if isinstance(chunk_content, str):
            # Text delta: accumulate and stream it to the caller.
            content += chunk_content
            yield chunk_content
        elif chunk_content:
            # Tool-call delta: the first fragment carries id/name; later
            # fragments only extend the JSON arguments string.
            if not tool_call:
                tool_call = chunk_content
            else:
                tool_call.function.arguments += chunk_content.function.arguments
    if tool_call:
        pprint(f"{tool_call=}")
        __addToolCallsToMsgs([tool_call])
        try:
            __processToolCalls([tool_call])
            # NOTE(review): `return` inside a generator only ends iteration;
            # the recursive predict() generator rides on StopIteration.value
            # and is never consumed by a plain for-loop consumer.
            return predict()
        except Exception as e:
            # Best-effort: tool failures are logged, not surfaced to the UI.
            pprint(e)
def __dedupeToolCalls(toolCalls: list):
    """Drop duplicate tool calls, keeping the last call per function name.

    Insertion order of first occurrences is preserved (dict ordering); when
    anything was dropped, the before/after lists are logged.
    """
    byName = {}
    for call in toolCalls:
        byName[call.function.name] = call
    deduped = list(byName.values())

    if len(deduped) != len(toolCalls):
        pprint("Deduped tool calls!")
        pprint(f"toolCalls={toolCalls!r} -> dedupedToolCalls={deduped!r}")
    return deduped
def predict():
    """Yield the assistant's reply text; run any requested tools afterwards.

    Builds the prompt (system message + context-trimmed history), calls the
    chat completions API, yields the reply (streamed or whole), then
    executes any tool calls and re-queries the model so it can react to the
    tool output.
    """
    # Flip to True to use the streaming branch below instead of a single
    # non-streamed response.
    shouldStream = False
    messagesFormatted = [{"role": "system", "content": SYSTEM_MSG}]
    messagesFormatted.extend(__getMessages())
    contextSize = countTokens(messagesFormatted)
    pprint(f"{contextSize=} | {MODEL}")
    pprint(f"{messagesFormatted=}")
    response = client.chat.completions.create(
        model=MODEL,
        messages=messagesFormatted,
        temperature=0.8,
        max_tokens=4000,
        stream=shouldStream,
        tools=tools
    )
    # pprint(f"llmResponse: {response}")
    if shouldStream:
        # Accumulate text chunks; tool-call fragments are merged into a
        # single call by concatenating their argument deltas.
        content = ""
        toolCall = None
        for chunk in response:
            chunkContent = __process_stream_chunk(chunk)
            if isinstance(chunkContent, str):
                content += chunkContent
                yield chunkContent
            elif chunkContent:
                if not toolCall:
                    toolCall = chunkContent
                else:
                    toolCall.function.arguments += chunkContent.function.arguments
        toolCalls = [toolCall] if toolCall else []
    else:
        responseMessage = response.choices[0].message
        # pprint(f"{responseMessage=}")
        responseContent = responseMessage.content
        pprint(f"{responseContent=}")
        if responseContent:
            yield responseContent
        toolCalls = responseMessage.tool_calls
        # pprint(f"{toolCalls=}")
    if toolCalls:
        pprint(f"{toolCalls=}")
        toolCalls = __dedupeToolCalls(toolCalls)
        __addToolCallsToMsgs(toolCalls)
        try:
            __processToolCalls(toolCalls)
            # NOTE(review): `return` inside a generator only stops iteration;
            # the recursive predict() generator is attached to StopIteration
            # and never iterated, so the model's follow-up reply after a tool
            # call is effectively dropped. Consider `yield from predict()`.
            return predict()
        except Exception as e:
            # Best-effort: tool failures are logged, not surfaced to the UI.
            pprint(e)
# Fix: the title emoji was mojibake ("πŸ’Œ" — the UTF-8 bytes of 💌 decoded
# with the wrong codec); other emoji in this file (e.g. START_MSG) are intact.
st.title("EmailGenie 💌")

# Show the intro "start" button only until the conversation has begun
# (either via this button or a typed message).
if not (st.session_state["buttonValue"] or st.session_state["startMsg"]):
    st.button(START_MSG, on_click=lambda: __setStartMsg(START_MSG))
# Re-render the whole visible conversation on every Streamlit rerun.
for chat in st.session_state.chatHistory:
    role = chat["role"]
    content = chat["content"]
    # Optional image attached to the entry; absent for plain text messages.
    imagePath = chat.get("image")
    avatar = ROLE_TO_AVATAR[role]
    with st.chat_message(role, avatar=avatar):
        if role == "tool":
            # Tool results reuse the green monospace .code style injected
            # at the top of the script.
            st.markdown(
                f"""
                <div class='code'>
                    {content}
                </div>
                """,
                unsafe_allow_html=True
            )
        else:
            st.markdown(content)
        if imagePath:
            st.image(imagePath)
# Accept input from the chat box, a previously clicked suggestion button, or
# the start button — whichever is non-empty first wins.
if prompt := (st.chat_input() or st.session_state["buttonValue"] or st.session_state["startMsg"]):
    # Consume the one-shot triggers so they don't refire on the next rerun.
    __resetButtonState()
    __setStartMsg("")

    with st.chat_message("user", avatar=USER_ICON):
        st.markdown(prompt)
    pprint(f"{prompt=}")
    st.session_state.messages.append({"role": "user", "content": prompt })
    st.session_state.chatHistory.append({"role": "user", "content": prompt })

    with st.chat_message("assistant", avatar=ASSISTANT_ICON):
        responseContainer = st.empty()

        def __printAndGetResponse():
            # Stream the LLM reply into the placeholder; returns None when
            # the reply trips the invalid-response heuristics (treated as an
            # empty response by the retry loop below).
            response = ""
            # responseContainer.markdown(".....")
            responseContainer.image(TEXT_LOADER)
            responseGenerator = predict()
            for chunk in responseGenerator:
                response += chunk
                if __isInvalidResponse(response):
                    pprint(f"{response=}")
                    return
                responseContainer.markdown(response)
            return response

        response = __printAndGetResponse()
        # NOTE(review): this retry loop has no attempt cap — a model that
        # keeps producing invalid/empty output would loop indefinitely.
        while not response:
            pprint("Empty response. Retrying..")
            time.sleep(0.5)
            response = __printAndGetResponse()
        pprint(f"{response=}")

    def selectButton(optionLabel):
        # on_click handler for suggestion buttons; only referenced by the
        # commented-out options rendering below.
        st.session_state["buttonValue"] = optionLabel
        pprint(f"Selected: {optionLabel}")

    # responseParts = response.split(JSON_SEPARATOR)
    # jsonStr = None
    # if len(responseParts) > 1:
    #     [response, jsonStr] = responseParts
    # if jsonStr:
    #     try:
    #         json.loads(jsonStr)
    #         jsonObj = json.loads(jsonStr)
    #         options = jsonObj["options"]
    #         for option in options:
    #             st.button(
    #                 option["label"],
    #                 key=option["id"],
    #                 on_click=lambda label=option["label"]: selectButton(label)
    #             )
    #         # st.code(jsonStr, language="json")
    #     except Exception as e:
    #         pprint(e)

    st.session_state.messages.append({
        "role": "assistant",
        "content": response,
    })
    st.session_state.chatHistory.append({
        "role": "assistant",
        "content": response,
    })