|
import streamlit as st |
|
from streamlit.delta_generator import DeltaGenerator |
|
import os |
|
import time |
|
import json |
|
import re |
|
from typing import List, Literal, TypedDict, Tuple |
|
from transformers import AutoTokenizer |
|
from gradio_client import Client |
|
import constants as C |
|
import utils as U |
|
|
|
from openai import OpenAI |
|
import anthropic |
|
from groq import Groq |
|
|
|
from dotenv import load_dotenv |
|
load_dotenv() |
|
|
|
# Identifiers of the LLM backends this app can be configured with.
ModelType = Literal["GPT4", "CLAUDE", "LLAMA"]


class ModelConfig(TypedDict):
    """Per-backend settings stored in ``MODEL_CONFIG``."""

    client: OpenAI | Groq | anthropic.Anthropic  # SDK client used for requests
    model: str  # model identifier passed to the provider API
    max_context: int  # token budget used when trimming the chat history
    tokenizer: AutoTokenizer  # tokenizer used for token counting
|
|
|
# Backend for the main chat, chosen through the MODEL_TYPE environment
# variable; an unset or empty value falls back to "CLAUDE".
modelType: ModelType = os.getenv("MODEL_TYPE") or "CLAUDE"


def _modelEntry(client, model: str, tokenizerRepo: str) -> ModelConfig:
    """Assemble one ``MODEL_CONFIG`` entry, loading its tokenizer from ``tokenizerRepo``."""
    return {
        "client": client,
        "model": model,
        "max_context": 128000,
        "tokenizer": AutoTokenizer.from_pretrained(tokenizerRepo),
    }


# Every backend is configured up front (client + tokenizer), keyed by ModelType.
MODEL_CONFIG: dict[ModelType, ModelConfig] = {
    "GPT4": _modelEntry(
        OpenAI(api_key=os.getenv("OPENAI_API_KEY")),
        "gpt-4o-mini",
        "Xenova/gpt-4o",
    ),
    "CLAUDE": _modelEntry(
        anthropic.Anthropic(api_key=os.getenv("ANTHROPIC_API_KEY")),
        "claude-3-5-sonnet-20240620",
        "Xenova/claude-tokenizer",
    ),
    "LLAMA": _modelEntry(
        Groq(api_key=os.getenv("GROQ_API_KEY")),
        "llama-3.1-70b-versatile",
        "Xenova/Meta-Llama-3.1-Tokenizer",
    ),
}
|
|
|
# Fail fast with a readable message when MODEL_TYPE names an unknown backend,
# instead of surfacing a bare KeyError from the lookups below.
if modelType not in MODEL_CONFIG:
    raise ValueError(
        f"Unsupported MODEL_TYPE {modelType!r}; expected one of {sorted(MODEL_CONFIG)}"
    )

# Settings of the backend that serves the main chat.
_activeModelConfig = MODEL_CONFIG[modelType]
client = _activeModelConfig["client"]
MODEL = _activeModelConfig["model"]
MAX_CONTEXT = _activeModelConfig["max_context"]
tokenizer = _activeModelConfig["tokenizer"]

# The Anthropic client is called through a different API shape than the
# OpenAI-style clients, so callers branch on this flag.
isClaudeModel = modelType == "CLAUDE"
|
|
|
|
|
def __countTokens(text):
    """Return the number of tokens the active tokenizer produces for ``text``.

    ``text`` is converted with ``str()`` first, so non-string values (for
    example a list of message dicts) are measured by their string
    representation. Special tokens are not added.
    """
    return len(tokenizer.encode(str(text), add_special_tokens=False))
|
|
|
|
|
# Browser-tab title and icon for the app page.
st.set_page_config(page_title="Kommuneity Story Creator", page_icon=C.AI_ICON)
|
|
|
|
|
def __isInvalidResponse(response: str) -> bool:
    """Report whether an LLM response (possibly still streaming) should be rejected.

    The response is rejected when any of these holds:

    * more than 3 newlines are directly followed by a lowercase letter that
      does not start "http", and the text contains no code fence;
    * a word is repeated three or more times in a row, in more than one place;
    * it contains more than 30 blank-line paragraph breaks;
    * it contains ``C.EXCEPTION_KEYWORD`` (the LLM call reported an error);
    * it contains an "options"/"action" JSON object but no ``C.JSON_SEPARATOR``;
    * it starts with ``C.JSON_SEPARATOR`` (JSON with no text before it).

    Side effect: when the error text says that roles must alternate, the
    second-to-last entry of ``st.session_state.messages`` is removed (if
    there is one).

    Returns:
        True when the response should be discarded, False otherwise.
    """
    brokenLineStarts = re.findall(r'\n((?!http)[a-z])', response)
    if len(brokenLineStarts) > 3 and "```" not in response:
        U.pprint("new line followed by small case char")
        return True

    if len(re.findall(r'\b(\w+)(\s+\1){2,}\b', response)) > 1:
        U.pprint("lot of consecutive repeating words")
        return True

    if response.count("\n\n") > 30:
        U.pprint("lots of paragraphs")
        return True

    if C.EXCEPTION_KEYWORD in response:
        U.pprint("LLM API threw exception")
        if 'roles must alternate between "user" and "assistant"' in response:
            messages = st.session_state.messages
            # pop(-2) needs at least two entries; with fewer there is nothing
            # before the latest message to drop.
            if len(messages) >= 2:
                U.pprint("Removing last msg from context...")
                messages.pop(-2)
        return True

    hasJsonPayload = any(
        marker in response
        for marker in ('{\n "options"', '{\n "action"')
    )
    if hasJsonPayload and C.JSON_SEPARATOR not in response:
        U.pprint("JSON response without json separator")
        return True

    if response.startswith(C.JSON_SEPARATOR):
        U.pprint("only options with no text")
        return True

    return False
|
|
|
|
|
def __matchingKeywordsCount(keywords: List[str], text: str) -> int:
    """Count how many entries of ``keywords`` occur as a substring of ``text``.

    Each list entry is counted at most once, however often it appears in
    ``text``; duplicate entries in ``keywords`` are counted separately.
    Matching is case-sensitive.
    """
    # Booleans sum as 0/1, so no intermediate list is needed.
    return sum(keyword in text for keyword in keywords)
|
|
|
|
|
def __getRawImagePromptDetails(
    prompt: str,
    response: str,
) -> Tuple[str | None, str | None, str | None]:
    """Decide whether this turn gets an illustration and with which raw prompts.

    Args:
        prompt: The user's message for this turn.
        response: The assistant's visible reply for this turn.

    Returns:
        A tuple ``(enhancePrompt, imagePrompt, loaderText)``:

        * when ``st.session_state.selectedStory`` is set: an instruction built
          from that story, an effect hint, and a loader caption;
        * otherwise, when the cleaned response contains ``C.BOOKING_LINK`` and
          the cleaned prompt does not contain "storytelling coach": an
          instruction built from the previous assistant reply (if any) plus
          this response, a style hint, and a loader caption;
        * otherwise ``(None, None, None)``.
    """
    # Lowercased text is reduced to a small character set and "the " is removed.
    regex = r'[^a-z0-9 \n\.\-\:\/]|((the) +)'

    cleanedResponse = re.sub(regex, '', response.lower())
    U.pprint(f"{cleanedResponse=}")

    if st.session_state.selectedStory:
        imageText = st.session_state.selectedStory
        return (
            f"Extract the story from this text and add few more details about this story:\n{imageText}",
            "Effect: dramatic, bokeh",
            "Painting your story character ...",
        )

    cleanedPrompt = re.sub(regex, '', prompt.lower())
    mentionsBookingLink = __matchingKeywordsCount([C.BOOKING_LINK], cleanedResponse) > 0

    if mentionsBookingLink and "storytelling coach" not in cleanedPrompt:
        aiResponses = [
            chat.get("content")
            for chat in st.session_state.chatHistory
            if chat.get("role") == "assistant"
        ]
        # There may be no earlier assistant turn yet; fall back to this
        # response alone instead of raising IndexError on aiResponses[-1].
        previousResponse = aiResponses[-1] if aiResponses else ""
        relevantResponse = f"\n{previousResponse}\n{response}\n"

        return (
            f"Extract the story from this text:\n{relevantResponse}",
            "\nStyle: In a storybook, surreal\n",
            "Imagining your story scene ...",
        )

    return (None, None, None)
|
|
|
|
|
def __getImagePromptDetails(prompt: str, response: str):
    """Build the final image-generation prompt for this turn, if one applies.

    The raw details come from ``__getRawImagePromptDetails``. When there is
    something to illustrate, the "LLAMA" entry of ``MODEL_CONFIG`` is asked to
    turn them into a single image prompt.

    Returns:
        ``(imagePrompt, loaderText)``; both are whatever the raw lookup
        returned (``None`` values) when no image applies.
    """
    (enhancePrompt, imagePrompt, loaderText) = __getRawImagePromptDetails(prompt, response)

    # Nothing to illustrate for this turn.
    if not (imagePrompt or enhancePrompt):
        return (imagePrompt, loaderText)

    promptEnhanceModelType: ModelType = "LLAMA"
    U.pprint(f"{promptEnhanceModelType=}")

    enhancerConfig = MODEL_CONFIG[promptEnhanceModelType]
    enhancerClient = enhancerConfig["client"]
    usesAnthropicApi = promptEnhanceModelType == "CLAUDE"

    systemPrompt = "You help in creating prompts for image generation"
    if enhancePrompt:
        promptPrefix = f"{enhancePrompt}\nAnd then use the above to"
    else:
        promptPrefix = "Use the text below to"
    enhancePrompt = (
        f"\n{promptPrefix} create a prompt for image generation.\n"
        f"\n{imagePrompt}\n"
        "\nReturn only the final Image Generation Prompt, and nothing else\n"
    )
    U.pprint(f"[Raw] {enhancePrompt=}")

    userMessage = {"role": "user", "content": enhancePrompt}
    sharedArgs = {
        "model": enhancerConfig["model"],
        "temperature": 1,
        "max_tokens": 2000,
    }

    if usesAnthropicApi:
        # Anthropic takes the system prompt as a separate argument.
        completion = enhancerClient.messages.create(
            **sharedArgs,
            messages=[userMessage],
            system=systemPrompt,
        )
        imagePrompt = completion.content[0].text
    else:
        # OpenAI-style clients take the system prompt as the first message.
        completion = enhancerClient.chat.completions.create(
            **sharedArgs,
            messages=[{"role": "system", "content": systemPrompt}, userMessage],
        )
        imagePrompt = completion.choices[0].message.content

    U.pprint(f"[Enhanced] {imagePrompt=}")
    return (imagePrompt, loaderText)
|
|
|
|
|
def __getMessages():
    """Return the session's LLM messages, trimmed to fit ``MAX_CONTEXT``.

    The oldest message is dropped repeatedly while the token count of the
    system message plus the message list (plus a margin of 100) exceeds the
    limit. The list in ``st.session_state.messages`` is modified in place and
    returned.
    """
    history = st.session_state.messages

    def contextOverflows() -> bool:
        currContextSize = __countTokens(C.SYSTEM_MSG) + __countTokens(history) + 100
        U.pprint(f"{currContextSize=}")
        return currContextSize > MAX_CONTEXT

    while contextOverflows():
        U.pprint("Context size exceeded, removing first message")
        history.pop(0)

    return history
|
|
|
|
|
def __logLlmRequest(messagesFormatted: list) -> None:
    """Log the token count of the outgoing message list and the active model name."""
    contextSize = __countTokens(messagesFormatted)
    U.pprint(f"{contextSize=} | {MODEL}")
|
|
|
|
|
|
|
def predict():
    """Stream the assistant's reply to the current conversation as text chunks.

    Uses the Anthropic streaming API when ``isClaudeModel`` is set, otherwise
    the OpenAI-style chat-completions API with ``stream=True``.

    Yields:
        Pieces of response text. If anything raises, the error is logged and
        a single chunk ``"{C.EXCEPTION_KEYWORD} | {error}"`` is yielded
        instead of propagating the exception.
    """
    try:
        if isClaudeModel:
            # Anthropic receives the system prompt as a separate argument.
            messagesFormatted = [*__getMessages()]
            __logLlmRequest(messagesFormatted)

            with client.messages.stream(
                model=MODEL,
                messages=messagesFormatted,
                system=C.SYSTEM_MSG,
                temperature=0.9,
                max_tokens=4000,
            ) as stream:
                yield from stream.text_stream
        else:
            # OpenAI-style APIs receive the system prompt as the first message.
            messagesFormatted = [
                {"role": "system", "content": C.SYSTEM_MSG},
                *__getMessages(),
            ]
            __logLlmRequest(messagesFormatted)

            completionStream = client.chat.completions.create(
                model=MODEL,
                messages=messagesFormatted,
                temperature=1,
                max_tokens=4000,
                stream=True,
            )
            for chunk in completionStream:
                if not chunk.choices:
                    U.pprint("Empty chunk")
                    continue
                chunkContent = chunk.choices[0].delta.content
                if chunkContent:
                    yield chunkContent
    except Exception as e:
        U.pprint(f"LLM API Error: {e}")
        yield f"{C.EXCEPTION_KEYWORD} | {e}"
|
|
|
|
|
def __generateImage(prompt: str):
    """Request a 1024x768 image for ``prompt`` from the FLUX.1-schnell Gradio app.

    A new Gradio client is created for every call. The endpoint's result is
    logged and returned unchanged; the caller unpacks it as
    ``(imagePath, seed)``.
    """
    inferenceArgs = {
        "prompt": prompt,
        "seed": 0,
        "randomize_seed": True,
        "width": 1024,
        "height": 768,
        "num_inference_steps": 4,
    }
    result = Client("black-forest-labs/FLUX.1-schnell").predict(
        **inferenceArgs,
        api_name="/infer",
    )
    U.pprint(f"imageResult={result}")
    return result
|
|
|
|
|
def __paintImageIfApplicable(
    imageContainer: DeltaGenerator,
    prompt: str,
    response: str,
):
    """Generate and display an image for this turn when one applies.

    While the image is being generated, a loader caption and a loader image
    are shown inside ``imageContainer``; they are then replaced by the result.
    Any exception is logged and the container is cleared.

    Returns:
        The generated image path, or ``None`` when no image was produced.
    """
    imagePath = None
    try:
        (imagePrompt, loaderText) = __getImagePromptDetails(prompt, response)
        if imagePrompt:
            loaderArea = imageContainer.container()
            loaderArea.write(
                f"\n<div class='blinking code'>\n{loaderText}\n</div>\n",
                unsafe_allow_html=True,
            )
            loaderArea.image(C.IMAGE_LOADER)

            # The seed returned alongside the path is not used.
            (imagePath, _seed) = __generateImage(imagePrompt)
            imageContainer.image(imagePath)
    except Exception as e:
        U.pprint(e)
        imageContainer.empty()

    return imagePath
|
|
|
|
|
def __resetButtonState():
    """Clear the stored option-button value in session state."""
    st.session_state["buttonValue"] = ""
|
|
|
|
|
def __resetSelectedStory():
    """Reset the selected story in session state to an empty dict."""
    st.session_state["selectedStory"] = {}
|
|
|
|
|
def __setStartMsg(msg):
    """Store ``msg`` as the start message in session state."""
    st.session_state["startMsg"] = msg
|
|
|
|
|
# Seed session state on first use; keys that already exist are left untouched.
if "ipAddress" not in st.session_state:
    st.session_state["ipAddress"] = st.context.headers.get("x-forwarded-for")

# Remaining keys, each paired with a factory for its initial value.
for _stateKey, _makeDefault in (
    ("chatHistory", list),        # [] - messages replayed in the UI
    ("messages", list),           # [] - messages sent to the LLM
    ("buttonValue", str),         # ""
    ("selectedStory", dict),      # {}
    ("selectedStoryTitle", str),  # ""
    ("isStoryChosen", bool),      # False
):
    if _stateKey not in st.session_state:
        st.session_state[_stateKey] = _makeDefault()
|
|
|
# Two blank log entries at the start of each script run.
for _ in range(2):
    U.pprint("\n")

U.applyCommonStyles()
st.title("Kommuneity Story Creator 🪄")

# The start button is only rendered on a run where "startMsg" has not been
# stored yet; clicking it stores C.START_MSG as the start message.
if "startMsg" not in st.session_state:
    __setStartMsg("")
    st.button(C.START_MSG, on_click=__setStartMsg, args=(C.START_MSG,))
|
|
|
# Replay the stored conversation, including any image attached to a message.
for chat in st.session_state.chatHistory:
    role, content, imagePath = chat["role"], chat["content"], chat.get("image")
    avatar = C.USER_ICON if role != "assistant" else C.AI_ICON
    with st.chat_message(role, avatar=avatar):
        st.markdown(content)
        if imagePath:
            st.image(imagePath)
|
|
|
|
|
|
|
|
|
|
|
# A turn starts from the first non-empty source: typed chat input, a clicked
# option button, a selected story title, or the start message.
if prompt := (
    st.chat_input()
    or st.session_state["buttonValue"]
    or st.session_state["selectedStoryTitle"]
    or st.session_state["startMsg"]
):
    # Clear the button/start triggers now that they have been read.
    __resetButtonState()
    __setStartMsg("")
    # Keep the selected story only when this turn was triggered by its title.
    if st.session_state["selectedStoryTitle"] != prompt:
        __resetSelectedStory()
        st.session_state.selectedStoryTitle = ""

    with st.chat_message("user", avatar=C.USER_ICON):
        st.markdown(prompt)
        U.pprint(f"{prompt=}")
        st.session_state.chatHistory.append({"role": "user", "content": prompt})
        st.session_state.messages.append({"role": "user", "content": prompt})

    with st.chat_message("assistant", avatar=C.AI_ICON):
        responseContainer = st.empty()

        def __printAndGetResponse():
            """Stream one attempt into ``responseContainer``.

            Returns the complete response text, or ``None`` as soon as the
            accumulated text is judged invalid.
            """
            response = ""
            responseContainer.image(C.TEXT_LOADER)

            for chunk in predict():
                response += chunk
                if __isInvalidResponse(response):
                    U.pprint(f"InvalidResponse={response}")
                    return None

                # Stop updating the visible text once the JSON part has started.
                if C.JSON_SEPARATOR not in response:
                    responseContainer.markdown(response)

            return response

        response = __printAndGetResponse()
        # NOTE(review): retries have no upper bound; a persistently failing
        # API keeps this loop running.
        while not response:
            U.pprint("Empty response. Retrying..")
            time.sleep(0.7)
            response = __printAndGetResponse()

        U.pprint(f"{response=}")

        def selectButton(optionLabel):
            """Remember the clicked option so the next run uses it as the prompt."""
            st.session_state["buttonValue"] = optionLabel
            U.pprint(f"Selected: {optionLabel}")

        rawResponse = response
        # Split at the first separator only: text before it is shown, the rest
        # is treated as JSON. A two-target unpack of split() raised ValueError
        # when the separator occurred more than once.
        response, _separator, jsonStr = rawResponse.partition(C.JSON_SEPARATOR)

        imageContainer = st.empty()
        imagePath = __paintImageIfApplicable(imageContainer, prompt, response)

        # The UI history keeps the visible text and image; the LLM context
        # keeps the raw response including any JSON part.
        st.session_state.chatHistory.append({
            "role": "assistant",
            "content": response,
            "image": imagePath,
        })
        st.session_state.messages.append({
            "role": "assistant",
            "content": rawResponse,
        })

        if jsonStr:
            try:
                jsonObj = json.loads(jsonStr)
                options = jsonObj.get("options")
                action = jsonObj.get("action")

                if options:
                    for option in options:
                        st.button(
                            option["label"],
                            key=option["id"],
                            # Bind the label now; a plain closure would see the
                            # loop variable's final value.
                            on_click=lambda label=option["label"]: selectButton(label),
                        )
                elif action:
                    U.pprint(f"{action=}")
                    if action == "SHOW_STORY_DATABASE":
                        time.sleep(0.5)
                        st.switch_page("pages/popular-stories.py")

            except Exception as e:
                # Best effort: a malformed JSON part must not break the turn.
                U.pprint(e)
|
|