Ashhar
feedbacks suggested by Kommuneity
71a2c91
raw
history blame
14 kB
import streamlit as st
from streamlit.delta_generator import DeltaGenerator
import os
import time
import json
import re
from typing import List, Literal, TypedDict, Tuple
from transformers import AutoTokenizer
from gradio_client import Client
import constants as C
import utils as U
from openai import OpenAI
import anthropic
from groq import Groq
from dotenv import load_dotenv
load_dotenv()

# Keys of MODEL_CONFIG: the chat providers this app can be pointed at.
ModelType = Literal["GPT4", "CLAUDE", "LLAMA"]


class ModelConfig(TypedDict):
    """Settings for one provider: API client, model id, token budget and tokenizer."""

    client: OpenAI | Groq | anthropic.Anthropic
    model: str
    max_context: int
    tokenizer: AutoTokenizer
# Which provider answers the chat. Overridable via the MODEL_TYPE env var; surrounding whitespace
# and letter case are ignored (e.g. "claude" selects "CLAUDE").
modelType: ModelType = (os.environ.get("MODEL_TYPE") or "CLAUDE").strip().upper()

# NOTE(review): every entry below is built eagerly at import time. All three API clients are
# constructed and all three tokenizers are loaded via AutoTokenizer.from_pretrained. Yet only the
# selected provider is used for chat, plus "LLAMA", which this file's image-prompt step pins.
MODEL_CONFIG: dict[ModelType, ModelConfig] = {
    "GPT4": {
        "client": OpenAI(api_key=os.environ.get("OPENAI_API_KEY")),
        "model": "gpt-4o-mini",
        "max_context": 128000,
        "tokenizer": AutoTokenizer.from_pretrained("Xenova/gpt-4o"),
    },
    "CLAUDE": {
        "client": anthropic.Anthropic(api_key=os.environ.get("ANTHROPIC_API_KEY")),
        "model": "claude-3-5-sonnet-20240620",
        "max_context": 128000,
        "tokenizer": AutoTokenizer.from_pretrained("Xenova/claude-tokenizer"),
    },
    "LLAMA": {
        "client": Groq(api_key=os.environ.get("GROQ_API_KEY")),
        "model": "llama-3.1-70b-versatile",
        "max_context": 128000,
        "tokenizer": AutoTokenizer.from_pretrained("Xenova/Meta-Llama-3.1-Tokenizer"),
    },
}

# Fail fast with a readable message instead of a bare KeyError on an unsupported MODEL_TYPE.
if modelType not in MODEL_CONFIG:
    raise ValueError(
        f"Unsupported MODEL_TYPE {modelType!r}; expected one of: {', '.join(MODEL_CONFIG)}"
    )

# Shortcuts for the selected chat provider.
client = MODEL_CONFIG[modelType]["client"]
MODEL = MODEL_CONFIG[modelType]["model"]
# Token budget; compared against tokenizer-based counts when trimming history.
MAX_CONTEXT = MODEL_CONFIG[modelType]["max_context"]
tokenizer = MODEL_CONFIG[modelType]["tokenizer"]
# The Anthropic SDK uses a different request shape (messages.* and a separate system prompt).
isClaudeModel = modelType == "CLAUDE"
def __countTokens(text):
    """Return how many tokens the active model's tokenizer produces for ``str(text)``.

    Non-string values (for example a list of message dicts) are measured through their
    ``str()`` representation. Special tokens are not added.
    """
    return len(tokenizer.encode(str(text), add_special_tokens=False))
# Browser-tab title and icon for the app.
st.set_page_config(
    page_icon=C.AI_ICON,
    page_title="Kommuneity Story Creator",
)
# Compiled once: __isInvalidResponse is re-run on the growing response after every streamed chunk.
_LOWERCASE_AFTER_NEWLINE_RE = re.compile(r'\n((?!http)[a-z])')
_REPEATED_WORD_RE = re.compile(r'\b(\w+)(\s+\1){2,}\b')
# A reply that opens one of these JSON payloads must also contain C.JSON_SEPARATOR.
# NOTE(review): each literal has exactly one space after the newline; a payload pretty-printed
# with wider indentation will not match — confirm against the model's actual output.
_JSON_PAYLOAD_PREFIXES = ('{\n "options"', '{\n "action"')


def __isInvalidResponse(response: str) -> bool:
    """Return True if a (possibly partial) model reply should be discarded and regenerated.

    Heuristics, checked in order:
      * more than 3 newlines followed by a lower-case letter (not "http"), outside code fences;
      * more than one run of a word repeated 3+ times in a row;
      * more than 30 blank-line paragraph breaks;
      * the reply carries C.EXCEPTION_KEYWORD (how predict() reports API errors);
      * a JSON options/action payload without C.JSON_SEPARATOR;
      * the reply starts with C.JSON_SEPARATOR, i.e. options with no text.

    Side effect: on a 'roles must alternate' API error, the second-to-last entry of
    ``st.session_state.messages`` is removed so the next attempt can succeed.
    """
    if len(_LOWERCASE_AFTER_NEWLINE_RE.findall(response)) > 3 and "```" not in response:
        U.pprint("new line followed by small case char")
        return True

    if len(_REPEATED_WORD_RE.findall(response)) > 1:
        U.pprint("lot of consecutive repeating words")
        return True

    if response.count("\n\n") > 30:
        U.pprint("lots of paragraphs")
        return True

    if C.EXCEPTION_KEYWORD in response:
        U.pprint("LLM API threw exception")
        messages = st.session_state.messages
        # Guarded: pop(-2) on a shorter list would raise IndexError mid-stream.
        if 'roles must alternate between "user" and "assistant"' in response and len(messages) >= 2:
            U.pprint("Removing last msg from context...")
            messages.pop(-2)
        return True

    if C.JSON_SEPARATOR not in response and any(
        prefix in response for prefix in _JSON_PAYLOAD_PREFIXES
    ):
        U.pprint("JSON response without json separator")
        return True

    if response.startswith(C.JSON_SEPARATOR):
        U.pprint("only options with no text")
        return True

    return False
def __matchingKeywordsCount(keywords: List[str], text: str):
    """Return how many entries of ``keywords`` occur as substrings of ``text``.

    Duplicate keywords are counted once per occurrence in the list.
    """
    # Each membership test is a bool; summing bools yields an int count.
    return sum(keyword in text for keyword in keywords)
def __getRawImagePromptDetails(
    prompt: str, response: str
) -> Tuple[str | None, str | None, str | None]:
    """Decide whether this turn gets an illustration and draft the inputs for its prompt.

    Args:
        prompt: The user's message for this turn.
        response: The assistant's reply for this turn.

    Returns:
        ``(enhancePrompt, imagePrompt, loaderText)``: an instruction describing what text to
        extract the scene from, style hints for the image, and the caption shown while the
        image is being painted. All three are ``None`` when no image should be generated.
    """
    # Lower-case and keep only a small character set (also dropping every "the " run), so the
    # keyword checks below tolerate punctuation and casing differences.
    regex = r'[^a-z0-9 \n\.\-\:\/]|((the) +)'
    cleanedResponse = re.sub(regex, '', response.lower())
    U.pprint(f"{cleanedResponse=}")
    cleanedPrompt = re.sub(regex, '', prompt.lower())

    # A story chosen from the story database takes priority.
    if st.session_state.selectedStory:
        imageText = st.session_state.selectedStory
        return (
            f"Extract the story from this text and add few more details about this story:\n{imageText}",
            "Effect: dramatic, bokeh",
            "Painting your story character ...",
        )

    # NOTE(review): C.BOOKING_LINK is searched for in the *cleaned* response, so it only matches
    # if the link survives the cleaning above (lower-case, restricted charset) — confirm.
    bookingLinkShared = __matchingKeywordsCount([C.BOOKING_LINK], cleanedResponse) > 0
    if bookingLinkShared and "storytelling coach" not in cleanedPrompt:
        aiResponses = [
            chat.get("content")
            for chat in st.session_state.chatHistory
            if chat.get("role") == "assistant"
        ]
        # In this file's main flow, the current reply is appended to chatHistory only after
        # painting, so the last entry is the previous reply. On the first turn there is none;
        # fall back to an empty string instead of raising IndexError.
        previousResponse = aiResponses[-1] if aiResponses else ""
        relevantResponse = f"\n{previousResponse}\n{response}\n"
        return (
            f"Extract the story from this text:\n{relevantResponse}",
            "\nStyle: In a storybook, surreal\n",
            "Imagining your story scene ...",
        )

    return (None, None, None)
def __getImagePromptDetails(prompt: str, response: str):
    """Turn this turn's raw image cues into a finished image-generation prompt.

    The raw cues come from __getRawImagePromptDetails. The "LLAMA" entry of MODEL_CONFIG then
    rewrites them into one image-generation prompt.

    Returns:
        ``(imagePrompt, loaderText)``. Both are ``None`` when the turn gets no image.
    """
    enhancePrompt, imagePrompt, loaderText = __getRawImagePromptDetails(prompt, response)
    if not (imagePrompt or enhancePrompt):
        return (imagePrompt, loaderText)

    # Prompt enhancement is pinned to one provider, independent of the chat model.
    promptEnhanceModelType: ModelType = "LLAMA"
    U.pprint(f"{promptEnhanceModelType=}")
    enhancerConfig = MODEL_CONFIG[promptEnhanceModelType]
    enhancerClient = enhancerConfig["client"]
    systemPrompt = "You help in creating prompts for image generation"

    if enhancePrompt:
        promptPrefix = f"{enhancePrompt}\nAnd then use the above to"
    else:
        promptPrefix = "Use the text below to"
    enhancePrompt = (
        "\n"
        f"{promptPrefix} create a prompt for image generation.\n"
        f"{imagePrompt}\n"
        "Return only the final Image Generation Prompt, and nothing else\n"
    )
    U.pprint(f"[Raw] {enhancePrompt=}")

    userTurn = {"role": "user", "content": enhancePrompt}
    if promptEnhanceModelType == "CLAUDE":
        # The Anthropic SDK takes the system prompt as a separate argument.
        completion = enhancerClient.messages.create(
            model=enhancerConfig["model"],
            messages=[userTurn],
            temperature=1,
            max_tokens=2000,
            system=systemPrompt,
        )
        imagePrompt = completion.content[0].text
    else:
        # OpenAI-style APIs take the system prompt as the first message.
        completion = enhancerClient.chat.completions.create(
            model=enhancerConfig["model"],
            messages=[{"role": "system", "content": systemPrompt}, userTurn],
            temperature=1,
            max_tokens=2000,
        )
        imagePrompt = completion.choices[0].message.content

    U.pprint(f"[Enhanced] {imagePrompt=}")
    return (imagePrompt, loaderText)
def __getMessages():
    """Return the session's chat messages, trimmed to fit the model's context window.

    Mutates ``st.session_state.messages`` in place and returns that same list:
      1. The oldest messages are dropped until the token count of SYSTEM_MSG plus all
         messages, plus a fixed 100-token margin, is at most MAX_CONTEXT.
      2. Any leading non-user turns are then dropped, so the conversation still opens with a
         user message.
    """
    messages = st.session_state.messages

    def getContextSize():
        currContextSize = __countTokens(C.SYSTEM_MSG) + __countTokens(messages) + 100
        U.pprint(f"{currContextSize=}")
        return currContextSize

    # Stop once the list is empty: popping from an empty list would raise IndexError.
    while messages and getContextSize() > MAX_CONTEXT:
        U.pprint("Context size exceeded, removing first message")
        messages.pop(0)

    # Trimming one message at a time can leave an assistant turn first, with its user prompt
    # already gone. Role-strict chat APIs may reject that ordering, so drop such orphaned turns.
    # In normal operation messages[0] is a user turn and this loop does nothing.
    while messages and messages[0].get("role") != "user":
        messages.pop(0)

    return messages
def __logLlmRequest(messagesFormatted: list):
    """Log the token size of an outgoing LLM request together with the model name."""
    U.pprint(f"contextSize={__countTokens(messagesFormatted)!r} | {MODEL}")
def predict():
    """Stream the assistant's reply for the current conversation, chunk by chunk.

    Sends SYSTEM_MSG plus the context-trimmed session messages to the configured chat model
    and yields text fragments as they arrive. Any exception raised along the way is logged
    and emitted as one final chunk, ``f"{C.EXCEPTION_KEYWORD} | {error}"``, instead of
    propagating.
    """
    try:
        if isClaudeModel:
            # Anthropic: the system prompt is a top-level argument, not a message.
            messagesFormatted = [*__getMessages()]
            __logLlmRequest(messagesFormatted)
            with client.messages.stream(
                model=MODEL,
                messages=messagesFormatted,
                system=C.SYSTEM_MSG,
                temperature=0.9,
                max_tokens=4000,
            ) as stream:
                for textPiece in stream.text_stream:
                    yield textPiece
            return

        # OpenAI-compatible clients (OpenAI, Groq): the system prompt leads the message list.
        messagesFormatted = [{"role": "system", "content": C.SYSTEM_MSG}, *__getMessages()]
        __logLlmRequest(messagesFormatted)
        completionStream = client.chat.completions.create(
            model=MODEL,
            messages=messagesFormatted,
            temperature=1,
            max_tokens=4000,
            stream=True,
        )
        for chunk in completionStream:
            if not chunk.choices:
                U.pprint("Empty chunk")
                continue
            deltaText = chunk.choices[0].delta.content
            if deltaText:
                yield deltaText
    except Exception as e:
        U.pprint(f"LLM API Error: {e}")
        yield f"{C.EXCEPTION_KEYWORD} | {e}"
def __generateImage(prompt: str):
    """Render ``prompt`` via the FLUX.1-schnell Hugging Face Space and return its raw result.

    This file unpacks the result as ``(imagePath, seed)``.
    """
    inferenceResult = Client("black-forest-labs/FLUX.1-schnell").predict(
        api_name="/infer",
        prompt=prompt,
        width=1024,
        height=768,
        num_inference_steps=4,
        seed=0,
        randomize_seed=True,
    )
    U.pprint(f"imageResult={inferenceResult}")
    return inferenceResult
def __paintImageIfApplicable(
    imageContainer: DeltaGenerator,
    prompt: str,
    response: str,
):
    """Generate and show an illustration for this turn inside ``imageContainer``, if one applies.

    While generating, a blinking caption and a loader image are shown. The whole step is
    best-effort: any exception is logged and the container is cleared.

    Returns:
        The image path from __generateImage, or None if no image prompt applied or an error
        occurred before a path was obtained.
    """
    imagePath = None
    try:
        imagePrompt, loaderText = __getImagePromptDetails(prompt, response)
        if not imagePrompt:
            return imagePath

        loaderBox = imageContainer.container()
        loaderBox.write(
            f"\n<div class='blinking code'>\n{loaderText}\n</div>\n",
            unsafe_allow_html=True,
        )
        loaderBox.image(C.IMAGE_LOADER)
        imagePath, _seed = __generateImage(imagePrompt)
        # Writing to the outer placeholder replaces the loader container.
        imageContainer.image(imagePath)
    except Exception as e:
        U.pprint(e)
        imageContainer.empty()
    return imagePath
def __resetButtonState():
    """Clear the stored option-button label, which the main flow reads as a prompt source."""
    st.session_state["buttonValue"] = ""
def __resetSelectedStory():
    """Forget any story chosen from the story database (stored as an empty dict)."""
    st.session_state["selectedStory"] = {}
def __setStartMsg(msg):
    """Store ``msg`` as the pending start message; the main flow reads it as a prompt source."""
    st.session_state["startMsg"] = msg
# Each key is seeded only when it is missing from this session's state; existing values are
# kept across reruns. Values are built lazily, only for missing keys.
_SESSION_DEFAULTS = {
    "ipAddress": lambda: st.context.headers.get("x-forwarded-for"),
    "chatHistory": list,
    "messages": list,
    "buttonValue": lambda: "",
    "selectedStory": dict,
    "selectedStoryTitle": lambda: "",
    "isStoryChosen": lambda: False,
}
for _stateKey, _makeDefault in _SESSION_DEFAULTS.items():
    if _stateKey not in st.session_state:
        st.session_state[_stateKey] = _makeDefault()

for _ in range(2):
    U.pprint("\n")
U.applyCommonStyles()
st.title("Kommuneity Story Creator 🪄")

if "startMsg" not in st.session_state:
    st.session_state["startMsg"] = ""
# Clicking queues C.START_MSG, which the main flow reads as a prompt source.
st.button(C.START_MSG, on_click=lambda: __setStartMsg(C.START_MSG))
# Re-render the conversation so far, including any image attached to a turn.
for pastTurn in st.session_state.chatHistory:
    speaker = pastTurn["role"]
    turnText = pastTurn["content"]
    turnImage = pastTurn.get("image")
    speakerIcon = C.AI_ICON if speaker == "assistant" else C.USER_ICON
    with st.chat_message(speaker, avatar=speakerIcon):
        st.markdown(turnText)
        if turnImage:
            st.image(turnImage)
# --- Handle one chat turn ---------------------------------------------------------------------
# A turn is triggered by typed input, an option button, a story title picked elsewhere, or the
# start button. The first non-empty source wins.
if prompt := (
    st.chat_input()
    or st.session_state["buttonValue"]
    or st.session_state["selectedStoryTitle"]
    or st.session_state["startMsg"]
):
    # Consume the one-shot triggers so they do not fire again on the next rerun.
    __resetButtonState()
    __setStartMsg("")
    if st.session_state["selectedStoryTitle"] != prompt:
        __resetSelectedStory()
    st.session_state.selectedStoryTitle = ""

    with st.chat_message("user", avatar=C.USER_ICON):
        st.markdown(prompt)
    U.pprint(f"{prompt=}")
    st.session_state.chatHistory.append({"role": "user", "content": prompt})
    st.session_state.messages.append({"role": "user", "content": prompt})

    with st.chat_message("assistant", avatar=C.AI_ICON):
        responseContainer = st.empty()

        def __printAndGetResponse():
            """Stream one model reply into responseContainer; return it, or None if rejected."""
            response = ""
            responseContainer.image(C.TEXT_LOADER)
            for chunk in predict():
                response += chunk
                if __isInvalidResponse(response):
                    U.pprint(f"InvalidResponse={response}")
                    return None
                # Stop refreshing the display once the JSON separator appears.
                if C.JSON_SEPARATOR not in response:
                    responseContainer.markdown(response)
            return response

        # Retry empty or rejected replies, but give up after a bounded number of attempts
        # instead of looping (and calling the API) forever.
        maxResponseAttempts = 5
        response = __printAndGetResponse()
        attempt = 1
        while not response and attempt < maxResponseAttempts:
            U.pprint("Empty response. Retrying..")
            time.sleep(0.7)
            response = __printAndGetResponse()
            attempt += 1

        if not response:
            U.pprint("Giving up after repeated empty/invalid responses")
            responseContainer.error(
                "Sorry, something went wrong while generating a reply. Please try again."
            )
            st.stop()

        U.pprint(f"{response=}")

        def selectButton(optionLabel):
            # Read as the next turn's prompt on the rerun that follows the click.
            st.session_state["buttonValue"] = optionLabel
            U.pprint(f"Selected: {optionLabel}")

        # Replies may carry a JSON payload after C.JSON_SEPARATOR; only the text before it is
        # displayed. partition() splits on the first separator only, so a reply containing the
        # separator more than once no longer crashes with "too many values to unpack".
        rawResponse = response
        textPart, separator, jsonPart = response.partition(C.JSON_SEPARATOR)
        jsonStr = jsonPart if separator else None
        response = textPart

        imageContainer = st.empty()
        imagePath = __paintImageIfApplicable(imageContainer, prompt, response)

        st.session_state.chatHistory.append({
            "role": "assistant",
            "content": response,
            "image": imagePath,
        })
        # The model's own context keeps the full reply, JSON payload included.
        st.session_state.messages.append({
            "role": "assistant",
            "content": rawResponse,
        })

        if jsonStr:
            try:
                jsonObj = json.loads(jsonStr)
                options = jsonObj.get("options")
                action = jsonObj.get("action")

                if options:
                    for option in options:
                        st.button(
                            option["label"],
                            key=option["id"],
                            # Default arg binds each label now, avoiding late binding in the loop.
                            on_click=lambda label=option["label"]: selectButton(label),
                        )
                elif action:
                    U.pprint(f"{action=}")
                    if action == "SHOW_STORY_DATABASE":
                        time.sleep(0.5)
                        st.switch_page("pages/popular-stories.py")
            except Exception as e:
                # Best-effort: a malformed payload only loses the option buttons / action.
                U.pprint(e)