import streamlit as st
from streamlit.delta_generator import DeltaGenerator
import os
import time
import json
import re
from typing import List, Literal, TypedDict, Tuple
from transformers import AutoTokenizer
import replicate
from openai import OpenAI
import anthropic
from groq import Groq
import constants as C
import utils as U
from helpers.auth import runWithAuth
from helpers.sidebar import showSidebar
from helpers.activities import saveLatestActivity
from helpers.imageCdn import initCloudinary, getCdnUrl
from dotenv import load_dotenv
load_dotenv()
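# LLM backend selection: the MODEL_TYPE env var picks one of the configured
# providers below (OpenAI, Anthropic, or Groq-hosted Llama); each entry pairs a
# client with its model name, context window, and an HF tokenizer used only for
# counting tokens.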
ModelType = Literal["GPT4", "CLAUDE", "LLAMA"]
ModelConfig = TypedDict("ModelConfig", {
"client": OpenAI | Groq | anthropic.Anthropic,
"model": str,
"max_context": int,
"tokenizer": AutoTokenizer
})
modelType: ModelType = os.environ.get("MODEL_TYPE") or "CLAUDE"  # defaults to Claude when MODEL_TYPE is unset or empty
MODEL_CONFIG: dict[ModelType, ModelConfig] = {
"GPT4": {
"client": OpenAI(api_key=os.environ.get("OPENAI_API_KEY")),
"model": "gpt-4o-mini",
"max_context": 128000,
"tokenizer": AutoTokenizer.from_pretrained("Xenova/gpt-4o")
},
"CLAUDE": {
"client": anthropic.Anthropic(api_key=os.environ.get("ANTHROPIC_API_KEY")),
"model": "claude-3-5-sonnet-20241022",
"max_context": 128000,
"tokenizer": AutoTokenizer.from_pretrained("Xenova/claude-tokenizer")
},
"LLAMA": {
"client": Groq(api_key=os.environ.get("GROQ_API_KEY")),
"model": "llama-3.1-70b-versatile",
"max_context": 128000,
"tokenizer": AutoTokenizer.from_pretrained("Xenova/Meta-Llama-3.1-Tokenizer")
}
}
client = MODEL_CONFIG[modelType]["client"]
MODEL = MODEL_CONFIG[modelType]["model"]
MAX_CONTEXT = MODEL_CONFIG[modelType]["max_context"]
tokenizer = MODEL_CONFIG[modelType]["tokenizer"]
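# Anthropic's SDK takes the system prompt as a separate argument and streams via
# client.messages, so Claude needs its own code path in __predict() below.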
isClaudeModel = modelType == "CLAUDE"
def __countTokens(text):
text = str(text)
tokens = tokenizer.encode(text, add_special_tokens=False)
return len(tokens)
st.set_page_config(
page_title="Kommuneity Story Creator",
page_icon=C.AI_ICON,
# menu_items={"About": None}
)
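# Heuristic checks for malformed LLM output (broken markdown, runaway repetition,
# missing JSON separator, upstream API errors). A truthy return makes the caller
# discard the reply and retry.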
def __isInvalidResponse(response: str):
if len(re.findall(r'\n((?!http)[a-z])', response)) > 3 and "```" not in response:
U.pprint("new line followed by small case char")
return True
if len(re.findall(r'\b(\w+)(\s+\1){2,}\b', response)) > 1:
U.pprint("lot of consecutive repeating words")
return True
if len(re.findall(r'\n\n', response)) > 30:
U.pprint("lots of paragraphs")
return True
if C.EXCEPTION_KEYWORD in response:
U.pprint("LLM API threw exception")
if 'roles must alternate between "user" and "assistant"' in str(response):
U.pprint("Removing last msg from context...")
st.session_state.messages.pop(-2)
return True
if ('{\n "options"' in response) and (C.JSON_SEPARATOR not in response):
U.pprint("JSON response without json separator")
return True
if ('{\n "action"' in response) and (C.JSON_SEPARATOR not in response):
U.pprint("JSON response without json separator")
return True
    if response.startswith(C.JSON_SEPARATOR):
        U.pprint("only options with no text")
        return True
    return False
def __matchingKeywordsCount(keywords: List[str], text: str):
return sum([
1 if keyword in text else 0
for keyword in keywords
])
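# Decides whether the current turn should produce an image. Returns a tuple of
# (story-extraction prompt, image style hints, loader caption), or
# (None, None, None) when no image is needed.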
def __getRawImagePromptDetails(prompt: str, response: str) -> Tuple[str | None, str | None, str | None]:
regex = r'[^a-z0-9 \n\.\-\:\/]|((the) +)'
cleanedResponse = re.sub(regex, '', response.lower())
U.pprint(f"{cleanedResponse=}")
cleanedPrompt = re.sub(regex, '', prompt.lower())
if (st.session_state.selectedStory):
imageText = st.session_state.selectedStory
return (
f"Extract the story from this text and add few more details about this story:\n{imageText}",
"Effect: dramatic, bokeh",
"Painting your story character ...",
)
if (
__matchingKeywordsCount(
[C.BOOKING_LINK],
cleanedResponse
) > 0
and "storytelling coach" not in cleanedPrompt
):
aiResponses = [
chat.get("content")
for chat in st.session_state.chatHistory
if chat.get("role") == "assistant"
]
relevantResponse = f"""
{aiResponses[-1]}
{response}
"""
return (
f"Extract the story from this text:\n{relevantResponse}",
"""
Style: In a storybook, surreal
""",
"Imagining your story scene ...",
)
return (None, None, None)
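# Turns the raw details above into a final image-generation prompt by asking the
# Groq-hosted Llama model to rewrite them (a fixed choice here, independent of
# which backend handles the chat itself).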
def __getImagePromptDetails(prompt: str, response: str):
(enhancePrompt, imagePrompt, loaderText) = __getRawImagePromptDetails(prompt, response)
if imagePrompt or enhancePrompt:
# U.pprint(f"[Raw] {enhancePrompt=} | {imagePrompt=}")
promptEnhanceModelType: ModelType = "LLAMA"
U.pprint(f"{promptEnhanceModelType=}")
modelConfig = MODEL_CONFIG[promptEnhanceModelType]
client = modelConfig["client"]
model = modelConfig["model"]
isClaudeModel = promptEnhanceModelType == "CLAUDE"
systemPrompt = "You help in creating prompts for image generation"
promptPrefix = f"{enhancePrompt}\nAnd then use the above to" if enhancePrompt else "Use the text below to"
enhancePrompt = f"""
{promptPrefix} create a prompt for image generation.
{imagePrompt}
Return only the final Image Generation Prompt, and nothing else
"""
U.pprint(f"[Raw] {enhancePrompt=}")
llmArgs = {
"model": model,
"messages": [{
"role": "user",
"content": enhancePrompt
}],
"temperature": 1,
"max_tokens": 2000
}
if isClaudeModel:
llmArgs["system"] = systemPrompt
response = client.messages.create(**llmArgs)
imagePrompt = response.content[0].text
else:
llmArgs["messages"] = [
{"role": "system", "content": systemPrompt},
*llmArgs["messages"]
]
response = client.chat.completions.create(**llmArgs)
responseMessage = response.choices[0].message
imagePrompt = responseMessage.content
U.pprint(f"[Enhanced] {imagePrompt=}")
return (imagePrompt, loaderText)
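# Trims the oldest messages until the conversation (plus the system prompt and a
# small margin) fits within MAX_CONTEXT tokens.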
def __getMessages():
def getContextSize():
currContextSize = __countTokens(C.SYSTEM_MSG) + __countTokens(st.session_state.messages) + 100
U.pprint(f"{currContextSize=}")
return currContextSize
while getContextSize() > MAX_CONTEXT:
U.pprint("Context size exceeded, removing first message")
st.session_state.messages.pop(0)
return st.session_state.messages
def __logLlmRequest(messagesFormatted: list):
contextSize = __countTokens(messagesFormatted)
U.pprint(f"{contextSize=} | {MODEL}")
# U.pprint(f"{messagesFormatted=}")
def __predict():
messagesFormatted = []
try:
if isClaudeModel:
messagesFormatted.extend(__getMessages())
__logLlmRequest(messagesFormatted)
with client.messages.stream(
model=MODEL,
messages=messagesFormatted,
system=C.SYSTEM_MSG,
temperature=0.9,
max_tokens=4000,
) as stream:
for text in stream.text_stream:
yield text
else:
messagesFormatted.append(
{"role": "system", "content": C.SYSTEM_MSG}
)
messagesFormatted.extend(__getMessages())
__logLlmRequest(messagesFormatted)
response = client.chat.completions.create(
model=MODEL,
messages=messagesFormatted,
temperature=1,
max_tokens=4000,
stream=True
)
for chunk in response:
choices = chunk.choices
if not choices:
U.pprint("Empty chunk")
continue
chunkContent = chunk.choices[0].delta.content
if chunkContent:
yield chunkContent
except Exception as e:
U.pprint(f"LLM API Error: {e}")
yield f"{C.EXCEPTION_KEYWORD} | {e}"
def __generateImageHF(prompt: str):
from gradio_client import Client
    fluxClient = Client(
        "black-forest-labs/FLUX.1-schnell",
        hf_token=os.environ.get("HF_FLUX_CLIENT_TOKEN")
    )
result = fluxClient.predict(
prompt=prompt,
seed=0,
randomize_seed=True,
width=1024,
height=768,
num_inference_steps=4,
api_name="/infer"
)
U.pprint(f"imageResult={result}")
return result
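# Generates a 4:3 webp image with FLUX schnell on Replicate and returns the first
# output; use_file_output=False asks the client for plain URL strings rather than
# file objects.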
def __generateImage(prompt: str):
result = replicate.run(
"black-forest-labs/flux-schnell",
input={
"prompt": prompt,
"seed": 0,
"go_fast": False,
"megapixels": "1",
"num_outputs": 1,
"aspect_ratio": "4:3",
"output_format": "webp",
# "output_quality": 80,
"num_inference_steps": 4
},
use_file_output=False
)
if result:
result = result[0]
return result
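# If the turn calls for an image, shows a loader inside the chat message, generates
# the image, and swaps it in. Any failure simply clears the placeholder so the chat
# continues without an image.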
def __paintImageIfApplicable(
imageContainer: DeltaGenerator,
prompt: str,
response: str,
):
imagePath = None
try:
(imagePrompt, loaderText) = __getImagePromptDetails(prompt, response)
if imagePrompt:
imgContainer = imageContainer.container()
imgContainer.write(
f"""
<div class='blinking code'>
{loaderText}
</div>
""",
unsafe_allow_html=True
)
imgContainer.image(C.IMAGE_LOADER)
imagePath = __generateImage(imagePrompt)
imageContainer.image(imagePath)
except Exception as e:
U.pprint(e)
imageContainer.empty()
return imagePath
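# Streamlit UI helpers: option buttons write the clicked label into session state so
# the next rerun can treat it as the user's prompt; the word counter flags replies
# that exceed C.WORDS_LIMIT.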
def __selectButton(optionLabel: str):
st.session_state["buttonValue"] = optionLabel
U.pprint(f"Selected: {optionLabel}")
def __showButtons(options: list):
for option in options:
st.button(
option["label"],
key=option["id"],
on_click=lambda label=option["label"]: __selectButton(label)
)
def __showWordsCount(response: str):
wordsCount = len(response.split())
countClass = "crossed-limit" if wordsCount > C.WORDS_LIMIT else ""
st.markdown(
f"""
<div class="words-count {countClass}">
{wordsCount} words
</div>
""",
unsafe_allow_html=True
)
def __resetButtonState():
st.session_state.buttonValue = ""
def __resetButtons():
st.session_state.buttons = []
def __resetSelectedStory():
st.session_state.selectedStory = {}
def __setStartMsg(msg):
st.session_state.startMsg = msg
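# One-time per-session state initialization (chat history, button/story selections,
# activity tracking).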
if "ipAddress" not in st.session_state:
st.session_state.ipAddress = st.context.headers.get("x-forwarded-for")
if "chatHistory" not in st.session_state:
st.session_state.chatHistory = []
if "messages" not in st.session_state:
st.session_state.messages = []
if "buttonValue" not in st.session_state:
__resetButtonState()
if "selectedStory" not in st.session_state:
__resetSelectedStory()
if "selectedStoryTitle" not in st.session_state:
st.session_state.selectedStoryTitle = ""
if "isStoryChosen" not in st.session_state:
st.session_state.isStoryChosen = False
if "buttons" not in st.session_state:
st.session_state.buttons = []
if "activityId" not in st.session_state:
st.session_state.activityId = None
if "userActivitiesLog" not in st.session_state:
st.session_state.userActivitiesLog = []
U.pprint("\n")
U.pprint("\n")
U.applyCommonStyles()
initCloudinary()
st.title("Kommuneity Story Creator 🪄")
def mainApp():
if "startMsg" not in st.session_state:
__setStartMsg("")
st.button(C.START_MSG, on_click=lambda: __setStartMsg(C.START_MSG))
    for chat in st.session_state.chatHistory:
role = chat["role"]
content = chat["content"]
imagePath = chat.get("image")
buttons = chat.get("buttons")
avatar = C.AI_ICON if role == "assistant" else C.USER_ICON
with st.chat_message(role, avatar=avatar):
st.markdown(content)
if imagePath and U.isValidImageUrl(imagePath):
st.image(imagePath)
if buttons:
__showButtons(buttons)
chat["buttons"] = []
# U.pprint(f"{st.session_state.buttonValue=}")
# U.pprint(f"{st.session_state.selectedStoryTitle=}")
# U.pprint(f"{st.session_state.startMsg=}")
if prompt := (
st.chat_input()
or st.session_state["buttonValue"]
or st.session_state["selectedStoryTitle"]
or st.session_state["startMsg"]
):
__resetButtonState()
__resetButtons()
__setStartMsg("")
if st.session_state["selectedStoryTitle"] != prompt:
__resetSelectedStory()
st.session_state.selectedStoryTitle = ""
with st.chat_message("user", avatar=C.USER_ICON):
st.markdown(prompt)
U.pprint(f"{prompt=}")
st.session_state.chatHistory.append({"role": "user", "content": prompt })
st.session_state.messages.append({"role": "user", "content": prompt})
with st.chat_message("assistant", avatar=C.AI_ICON):
responseContainer = st.empty()
def __printAndGetResponse():
response = ""
responseContainer.image(C.TEXT_LOADER)
responseGenerator = __predict()
for chunk in responseGenerator:
response += chunk
if __isInvalidResponse(response):
U.pprint(f"InvalidResponse={response}")
return
if C.JSON_SEPARATOR not in response:
responseContainer.markdown(response)
return response
response = __printAndGetResponse()
while not response:
U.pprint("Empty response. Retrying..")
time.sleep(0.7)
response = __printAndGetResponse()
U.pprint(f"{response=}")
rawResponse = response
responseParts = response.split(C.JSON_SEPARATOR)
jsonStr = None
if len(responseParts) > 1:
[response, jsonStr] = responseParts
imageContainer = st.empty()
imagePath = __paintImageIfApplicable(imageContainer, prompt, response)
if imagePath:
imagePath = getCdnUrl(imagePath)
st.session_state.chatHistory.append({
"role": "assistant",
"content": response,
"image": imagePath,
})
st.session_state.messages.append({
"role": "assistant",
"content": rawResponse,
})
if jsonStr:
try:
jsonObj = json.loads(jsonStr)
options = jsonObj.get("options")
action = jsonObj.get("action")
if options:
__showButtons(options)
st.session_state.buttons = options
elif action:
U.pprint(f"{action=}")
if action == "SHOW_STORY_DATABASE":
time.sleep(1)
st.switch_page("pages/popular-stories.py")
# st.code(jsonStr, language="json")
except Exception as e:
U.pprint(e)
__showWordsCount(response)
saveLatestActivity()
runWithAuth(mainApp)
showSidebar()