# EmailGenie — Streamlit app (Hugging Face Space).
# The original lines here ("Spaces:" / "Runtime error") were page-scrape
# residue from the Space's status banner, not part of the source.
import streamlit as st | |
import os | |
import datetime as DT | |
import pytz | |
import time | |
import json | |
import re | |
from typing import List | |
from transformers import AutoTokenizer | |
from gradio_client import Client | |
from tools import toolsInfo | |
from dotenv import load_dotenv | |
load_dotenv() | |
# Provider/model selection: USE_GPT_4=1 routes to OpenAI, otherwise Groq.
# Both branches bind the same four globals: client, MODEL, MAX_CONTEXT, tokenizer.
useGpt4 = os.environ.get("USE_GPT_4") == "1"

if useGpt4:
    from openai import OpenAI
    client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
    MODEL = "gpt-4o-mini"
    MAX_CONTEXT = 128000
    # local tokenizer mirror used only for context-size accounting
    tokenizer = AutoTokenizer.from_pretrained("Xenova/gpt-4o")
else:
    from groq import Groq
    client = Groq(
        api_key=os.environ.get("GROQ_API_KEY"),
    )
    # BUG FIX: the previous assignment MODEL = "llama-3.1-70b-versatile" was
    # dead code — it was immediately overwritten. Only the effective value is kept.
    MODEL = "llama3-groq-70b-8192-tool-use-preview"
    MAX_CONTEXT = 8000
    tokenizer = AutoTokenizer.from_pretrained("Xenova/Meta-Llama-3.1-Tokenizer")
def countTokens(text):
    """Return how many tokens the active tokenizer produces for *text*.

    Non-string inputs (e.g. the whole messages list) are stringified first.
    """
    return len(tokenizer.encode(str(text), add_special_tokens=False))
# System prompt driving the step-by-step cold-outreach workflow.
# FIX: this was an f-string with no placeholders (ruff F541); a plain literal
# is byte-identical at runtime and cannot accidentally interpolate braces.
SYSTEM_MSG = """
You are a personalized email generator for cold outreach. You take the user through a workflow. Step by step.
- You ask for industry of the recipient
- His/her role
- More details about the recipient
Highlight the exact entity you're requesting for.
Once collected, you store these info in a Google Sheet
"""
# Static asset paths for chat avatars and loading animations.
USER_ICON = "icons/man.png"
ASSISTANT_ICON = "icons/magic-wand(1).png"
TOOL_ICON = "icons/completed-task.png"
IMAGE_LOADER = "icons/ripple.svg"
TEXT_LOADER = "icons/balls.svg"

# Label for the conversation kick-off button.
# NOTE(review): "π" looks like a mojibake'd emoji (likely 🚀 or similar) —
# confirm the intended glyph before shipping.
START_MSG = "Let's start π"

# Maps a chat-history role to the avatar shown next to its messages.
ROLE_TO_AVATAR = {
    "user": USER_ICON,
    "assistant": ASSISTANT_ICON,
    "tool": TOOL_ICON,
}
# Page chrome; st.set_page_config must run before any other Streamlit call.
st.set_page_config(
    page_title="EmailGenie",
    page_icon=ASSISTANT_ICON,
)

# Best-effort client IP taken from the reverse proxy header; None when absent.
# Used only as a log prefix in pprint().
ipAddress = st.context.headers.get("x-forwarded-for")
def __nowInIST() -> DT.datetime:
    """Return the current timezone-aware datetime in IST (Asia/Kolkata).

    Uses the stdlib ``zoneinfo`` module (Python 3.9+) instead of the
    third-party ``pytz``; for ``datetime.now(tz)`` the behavior is identical,
    and pytz's error-prone ``localize`` API is avoided entirely.
    """
    from zoneinfo import ZoneInfo  # stdlib replacement for pytz
    return DT.datetime.now(ZoneInfo("Asia/Kolkata"))
def pprint(log: str):
    """Print *log* to the server console, prefixed with an IST timestamp
    and the requesting client's IP address."""
    stamp = __nowInIST().strftime("%Y-%m-%d %H:%M:%S")
    print(f"[{stamp}] [{ipAddress}] {log}")
# Blank log line to visually separate script reruns in the server console.
pprint("\n")

# Inject page-level CSS: a "blinking" animation (currently unused by visible
# code) and the green ".code" chip style used for tool-status messages.
st.markdown(
    """
<style>
@keyframes blinker {
    0% {
        opacity: 1;
    }
    50% {
        opacity: 0.2;
    }
    100% {
        opacity: 1;
    }
}
.blinking {
    animation: blinker 3s ease-out infinite;
}
.code {
    color: green;
    border-radius: 3px;
    padding: 2px 4px; /* Padding around the text */
    font-family: 'Courier New', Courier, monospace; /* Monospace font */
}
</style>
""",
    unsafe_allow_html=True
)
def __isInvalidResponse(response: str) -> bool:
    """Heuristically detect a degenerate LLM response.

    Returns True when the text shows signs of a broken generation; callers
    abandon the stream and retry.
    """
    # many newlines followed by a lowercase letter: mangled markdown/lists
    if len(re.findall(r'\n[a-z]', response)) > 3:
        return True
    # the same word repeated 3+ times in a row, more than once
    if len(re.findall(r'\b(\w+)(\s+\1){2,}\b', response)) > 1:
        return True
    # excessive number of paragraph breaks
    if len(re.findall(r'\n\n', response)) > 15:
        return True
    # FIX: explicit False instead of the original implicit None fall-through
    # (behavior-compatible: all call sites only test truthiness)
    return False
def __matchingKeywordsCount(keywords: List[str], text: str):
    """Count how many of *keywords* occur as substrings of *text*."""
    # bools sum as 0/1, matching the original 1-if-present accumulation
    return sum(keyword in text for keyword in keywords)
def __isStringNumber(s: str) -> bool:
    """Return True if *s* parses as a float, False otherwise."""
    try:
        float(s)
    except ValueError:
        return False
    return True
def __resetButtonState():
    """Clear any previously selected option-button value."""
    # attribute access on session_state is equivalent to item access
    st.session_state.buttonValue = ""
def __setStartMsg(msg):
    """Record the start-button message in session state."""
    # item access on session_state is equivalent to attribute access
    st.session_state["startMsg"] = msg
# One-time per-session state initialization (runs on every rerun, guarded).
if "chatHistory" not in st.session_state:
    # messages rendered in the UI (user / assistant / tool entries)
    st.session_state.chatHistory = []
if "messages" not in st.session_state:
    # raw conversation history sent to the chat-completions API
    st.session_state.messages = []
if "buttonValue" not in st.session_state:
    __resetButtonState()
if "startMsg" not in st.session_state:
    st.session_state.startMsg = ""
def __getMessages():
    """Return the LLM message history, evicting the oldest messages until
    the estimated token count (system prompt + history + slack) fits
    within MAX_CONTEXT."""
    def __contextSize():
        size = countTokens(SYSTEM_MSG) + countTokens(st.session_state.messages) + 100
        # keep the exact original log text (was f"{currContextSize=}")
        pprint(f"currContextSize={size}")
        return size

    while __contextSize() > MAX_CONTEXT:
        pprint("Context size exceeded, removing first message")
        st.session_state.messages.pop(0)

    return st.session_state.messages
# Tool schemas advertised to the model; implementations are resolved at call
# time via toolsInfo[name]["func"] (see __processToolCalls).
tools = [
    toolsInfo["saveInGSheet"]["schema"]
]
def __showTaskStatus(msg):
    """Render a completed-task icon followed by *msg* styled as a code chip."""
    box = st.container()
    box.image(TOOL_ICON, width=30)
    box.markdown(
        f"""
<div class='code'>
    {msg}
</div>
""",
        unsafe_allow_html=True
    )
def __processToolCalls(tool_calls):
    """Execute each requested tool and record its result.

    Every result is appended to the LLM history as a "tool" message; results
    flagged shouldShow are additionally rendered and stored in chatHistory.
    """
    for call in tool_calls:
        name = call.function.name
        args = json.loads(call.function.arguments)
        result = toolsInfo[name]["func"](**args)
        output = result["response"]
        shown = result["shouldShow"]
        # original log was f"{functionResponse=}" → repr formatting preserved
        pprint(f"functionResponse={output!r}")

        if shown:
            st.session_state.chatHistory.append(
                {"role": "tool", "content": output}
            )
            __showTaskStatus(output)

        st.session_state.messages.append(
            {
                "role": "tool",
                "tool_call_id": call.id,
                "name": name,
                "content": output,
            }
        )
def __process_stream_chunk(chunk):
    """Extract the payload from one streaming chunk.

    Returns the text delta (str), the first tool-call delta object, or None
    when the chunk carries neither.
    """
    delta = chunk.choices[0].delta
    if delta.content:
        return delta.content
    if delta.tool_calls:
        return delta.tool_calls[0]
    return None
def __addToolCallsToMsgs(toolCalls):
    """Append one assistant message describing *toolCalls* to the LLM history.

    The API requires this assistant/tool_calls message to precede the "tool"
    role messages carrying each call's result.
    """
    serialized = []
    for tc in toolCalls:
        serialized.append({
            "id": tc.id,
            "function": {
                "name": tc.function.name,
                "arguments": tc.function.arguments,
            },
            "type": tc.type,
        })
    st.session_state.messages.append({
        "role": "assistant",
        "tool_calls": serialized,
    })
def __add_tool_call(tool_call):
    """Append a single assistant tool-call message to the LLM history.

    CONSISTENCY FIX: this was a hand-rolled duplicate of
    __addToolCallsToMsgs that built the exact same payload and risked
    drifting from it; it now delegates instead. Behavior is unchanged.
    """
    __addToolCallsToMsgs([tool_call])
def predict1():
    """Streaming variant of predict(): yield assistant text chunks as they
    arrive and dispatch any tool call the model requests.

    Tool-call argument fragments arrive across multiple chunks; they are
    accumulated into a single call object before execution. After the tool
    runs, the conversation continues via predict().
    """
    messagesFormatted = [{"role": "system", "content": SYSTEM_MSG}]
    messagesFormatted.extend(__getMessages())
    contextSize = countTokens(messagesFormatted)
    pprint(f"{contextSize=} | {MODEL}")

    response = client.chat.completions.create(
        model=MODEL,
        messages=messagesFormatted,
        temperature=0.8,
        max_tokens=4000,
        stream=True,
        tools=tools
    )

    content = ""
    tool_call = None
    for chunk in response:
        chunk_content = __process_stream_chunk(chunk)
        if isinstance(chunk_content, str):
            content += chunk_content
            yield chunk_content
        elif chunk_content:
            if not tool_call:
                # first tool-call chunk carries id/name; keep it as the base
                tool_call = chunk_content
            else:
                # subsequent chunks only carry argument fragments
                tool_call.function.arguments += chunk_content.function.arguments

    if tool_call:
        pprint(f"{tool_call=}")
        __addToolCallsToMsgs([tool_call])
        try:
            __processToolCalls([tool_call])
            # BUG FIX: "return predict()" inside a generator discards the
            # recursive generator's output (it becomes the StopIteration
            # value); "yield from" actually streams the follow-up reply.
            yield from predict()
        except Exception as e:
            # best-effort: log tool failures without killing the stream
            pprint(e)
def __dedupeToolCalls(toolCalls: list):
    """Drop duplicate tool calls, keeping the LAST call per function name
    (dict insertion order preserves first-seen position)."""
    byName = {tc.function.name: tc for tc in toolCalls}
    deduped = list(byName.values())
    if len(toolCalls) != len(deduped):
        pprint("Deduped tool calls!")
        # original log used f"{toolCalls=} -> {dedupedToolCalls=}" — same text
        pprint(f"toolCalls={toolCalls!r} -> dedupedToolCalls={deduped!r}")
    return deduped
def predict():
    """Yield the assistant's reply for the current history; execute any tool
    calls the model requests, then continue the conversation recursively.

    ``shouldStream`` selects between consuming a chunk stream and a single
    blocking completion (currently the blocking path). Either path leaves
    ``toolCalls`` holding the calls to dispatch.
    """
    shouldStream = False

    messagesFormatted = [{"role": "system", "content": SYSTEM_MSG}]
    messagesFormatted.extend(__getMessages())
    contextSize = countTokens(messagesFormatted)
    pprint(f"{contextSize=} | {MODEL}")
    pprint(f"{messagesFormatted=}")

    response = client.chat.completions.create(
        model=MODEL,
        messages=messagesFormatted,
        temperature=0.8,
        max_tokens=4000,
        stream=shouldStream,
        tools=tools
    )

    if shouldStream:
        content = ""
        toolCall = None
        for chunk in response:
            chunkContent = __process_stream_chunk(chunk)
            if isinstance(chunkContent, str):
                content += chunkContent
                yield chunkContent
            elif chunkContent:
                if not toolCall:
                    toolCall = chunkContent
                else:
                    # later chunks only append argument fragments
                    toolCall.function.arguments += chunkContent.function.arguments
        toolCalls = [toolCall] if toolCall else []
    else:
        responseMessage = response.choices[0].message
        responseContent = responseMessage.content
        pprint(f"{responseContent=}")
        if responseContent:
            yield responseContent
        toolCalls = responseMessage.tool_calls

    if toolCalls:
        pprint(f"{toolCalls=}")
        toolCalls = __dedupeToolCalls(toolCalls)
        __addToolCallsToMsgs(toolCalls)
        try:
            __processToolCalls(toolCalls)
            # BUG FIX: "return predict()" inside a generator silently discards
            # the recursive generator (its value rides on StopIteration), so
            # the model's follow-up after a tool call was never shown.
            # "yield from" delegates and streams it to the caller.
            yield from predict()
        except Exception as e:
            # best-effort: log tool failures without killing the response
            pprint(e)
# NOTE(review): "π" here looks like a mojibake'd emoji — confirm the glyph.
st.title("EmailGenie π")

# Show the kick-off button only until the conversation has started.
if not (st.session_state["buttonValue"] or st.session_state["startMsg"]):
    st.button(START_MSG, on_click=lambda: __setStartMsg(START_MSG))

# Re-render the visible chat history on every Streamlit rerun.
for chat in st.session_state.chatHistory:
    role = chat["role"]
    content = chat["content"]
    imagePath = chat.get("image")
    avatar = ROLE_TO_AVATAR[role]
    with st.chat_message(role, avatar=avatar):
        if role == "tool":
            # tool messages reuse the green ".code" chip styling injected above
            st.markdown(
                f"""
<div class='code'>
    {content}
</div>
""",
                unsafe_allow_html=True
            )
        else:
            st.markdown(content)
        if imagePath:
            st.image(imagePath)
# Main turn handler: a prompt can come from the chat input, a clicked option
# button, or the start button; consume whichever fired and clear its state.
if prompt := (st.chat_input() or st.session_state["buttonValue"] or st.session_state["startMsg"]):
    __resetButtonState()
    __setStartMsg("")

    with st.chat_message("user", avatar=USER_ICON):
        st.markdown(prompt)
    pprint(f"{prompt=}")

    # record the turn in both the LLM history and the rendered history
    st.session_state.messages.append({"role": "user", "content": prompt })
    st.session_state.chatHistory.append({"role": "user", "content": prompt })

    with st.chat_message("assistant", avatar=ASSISTANT_ICON):
        responseContainer = st.empty()

        def __printAndGetResponse():
            # Stream predict() into the placeholder; returns None (restarting
            # the attempt) if the accumulating text looks degenerate.
            response = ""
            # responseContainer.markdown(".....")
            responseContainer.image(TEXT_LOADER)
            responseGenerator = predict()
            for chunk in responseGenerator:
                response += chunk
                if __isInvalidResponse(response):
                    pprint(f"{response=}")
                    return
                responseContainer.markdown(response)
            return response

        response = __printAndGetResponse()
        # NOTE(review): unbounded retry — if predict() keeps returning empty
        # or invalid output this loops forever; consider a max-attempt cap.
        while not response:
            pprint("Empty response. Retrying..")
            time.sleep(0.5)
            response = __printAndGetResponse()

    pprint(f"{response=}")

    def selectButton(optionLabel):
        # on_click callback for dynamically rendered option buttons
        st.session_state["buttonValue"] = optionLabel
        pprint(f"Selected: {optionLabel}")

    # Disabled feature: render clickable options parsed from a JSON tail of
    # the response (kept for reference; JSON_SEPARATOR is not defined here).
    # responseParts = response.split(JSON_SEPARATOR)
    # jsonStr = None
    # if len(responseParts) > 1:
    #     [response, jsonStr] = responseParts
    # if jsonStr:
    #     try:
    #         json.loads(jsonStr)
    #         jsonObj = json.loads(jsonStr)
    #         options = jsonObj["options"]
    #         for option in options:
    #             st.button(
    #                 option["label"],
    #                 key=option["id"],
    #                 on_click=lambda label=option["label"]: selectButton(label)
    #             )
    #         # st.code(jsonStr, language="json")
    #     except Exception as e:
    #         pprint(e)

    # persist the assistant's final text in both histories
    st.session_state.messages.append({
        "role": "assistant",
        "content": response,
    })
    st.session_state.chatHistory.append({
        "role": "assistant",
        "content": response,
    })