import json
import os
import re
import time
from typing import List, Literal, Tuple, TypedDict

import anthropic
import replicate
import streamlit as st
from dotenv import load_dotenv
from groq import Groq
from openai import OpenAI
from streamlit.delta_generator import DeltaGenerator
from transformers import AutoTokenizer

import constants as C
import utils as U
from helpers.activities import saveLatestActivity
from helpers.auth import runWithAuth
from helpers.imageCdn import getCdnUrl, initCloudinary
from helpers.sidebar import showSidebar

load_dotenv()

ModelType = Literal["GPT4", "CLAUDE", "LLAMA"]
ModelConfig = TypedDict("ModelConfig", {
    "client": OpenAI | Groq | anthropic.Anthropic,
    "model": str,
    "max_context": int,
    "tokenizer": AutoTokenizer,
})

# Which provider drives the main chat; defaults to Claude when MODEL_TYPE is unset/empty.
modelType: ModelType = os.environ.get("MODEL_TYPE") or "CLAUDE"

# NOTE(review): Streamlit re-runs the page script on every interaction, so all
# three clients and tokenizers are rebuilt on each rerun. Consider
# st.cache_resource if this becomes a latency problem.
MODEL_CONFIG: dict[ModelType, ModelConfig] = {
    "GPT4": {
        "client": OpenAI(api_key=os.environ.get("OPENAI_API_KEY")),
        "model": "gpt-4o-mini",
        "max_context": 128000,
        "tokenizer": AutoTokenizer.from_pretrained("Xenova/gpt-4o"),
    },
    "CLAUDE": {
        "client": anthropic.Anthropic(api_key=os.environ.get("ANTHROPIC_API_KEY")),
        "model": "claude-3-5-sonnet-20241022",
        "max_context": 128000,
        "tokenizer": AutoTokenizer.from_pretrained("Xenova/claude-tokenizer"),
    },
    "LLAMA": {
        "client": Groq(api_key=os.environ.get("GROQ_API_KEY")),
        "model": "llama-3.1-70b-versatile",
        "max_context": 128000,
        "tokenizer": AutoTokenizer.from_pretrained("Xenova/Meta-Llama-3.1-Tokenizer"),
    },
}

if modelType not in MODEL_CONFIG:
    # Fail fast with a readable message instead of a bare KeyError below.
    raise ValueError(
        f"Unsupported MODEL_TYPE {modelType!r}; expected one of {list(MODEL_CONFIG)}"
    )

client = MODEL_CONFIG[modelType]["client"]
MODEL = MODEL_CONFIG[modelType]["model"]
MAX_CONTEXT = MODEL_CONFIG[modelType]["max_context"]
tokenizer = MODEL_CONFIG[modelType]["tokenizer"]
isClaudeModel = modelType == "CLAUDE"


def __countTokens(text) -> int:
    """Return the token count of ``str(text)`` using the active model's tokenizer.

    Non-string values (e.g. the message list) are stringified first, so the
    result is only an estimate of what the provider will count.
    """
    return len(tokenizer.encode(str(text), add_special_tokens=False))


st.set_page_config(
    page_title="Kommuneity Story Creator",
    page_icon=C.AI_ICON,
)


def __isInvalidResponse(response: str) -> bool:
    """Heuristically decide whether a (possibly partial) LLM response must be discarded.

    Called on the accumulated text after every streamed chunk; True makes the
    caller abandon the stream and request a fresh response.

    Side effect: when the provider error says roles must alternate, the
    second-to-last entry of ``st.session_state.messages`` is removed so the
    retry has a chance to succeed.
    """
    # Many lines that start with a lowercase letter (outside code fences) look garbled.
    if len(re.findall(r'\n((?!http)[a-z])', response)) > 3 and "```" not in response:
        U.pprint("new line followed by small case char")
        return True

    if len(re.findall(r'\b(\w+)(\s+\1){2,}\b', response)) > 1:
        U.pprint("lot of consecutive repeating words")
        return True

    if len(re.findall(r'\n\n', response)) > 30:
        U.pprint("lots of paragraphs")
        return True

    # __predict yields C.EXCEPTION_KEYWORD when the provider call failed.
    if C.EXCEPTION_KEYWORD in response:
        U.pprint("LLM API threw exception")
        messages = st.session_state.messages
        if (
            'roles must alternate between "user" and "assistant"' in response
            and len(messages) >= 2
        ):
            U.pprint("Removing last msg from context...")
            messages.pop(-2)
        return True

    # A JSON payload must be introduced by the separator so it can be split off.
    if ('{\n "options"' in response) and (C.JSON_SEPARATOR not in response):
        U.pprint("JSON response without json separator")
        return True

    if ('{\n "action"' in response) and (C.JSON_SEPARATOR not in response):
        U.pprint("JSON response without json separator")
        return True

    if response.startswith(C.JSON_SEPARATOR):
        U.pprint("only options with no text")
        return True

    return False


def __matchingKeywordsCount(keywords: List[str], text: str) -> int:
    """Return how many of ``keywords`` occur as substrings of ``text``."""
    return sum(1 for keyword in keywords if keyword in text)


def __getRawImagePromptDetails(
    prompt: str,
    response: str,
) -> Tuple[str | None, str | None, str | None]:
    """Decide whether the current turn should get an illustration.

    Returns ``(enhancePrompt, imagePrompt, loaderText)``:
    - if a story is selected in session state, a prompt built from that story;
    - else if the cleaned response contains ``C.BOOKING_LINK`` and the user
      prompt does not mention "storytelling coach", a prompt built from the
      previous assistant message plus this response;
    - otherwise ``(None, None, None)``.
    """
    # Lowercase and drop everything except a small allowed charset, plus "the " runs.
    regex = r'[^a-z0-9 \n\.\-\:\/]|((the) +)'
    cleanedResponse = re.sub(regex, '', response.lower())
    U.pprint(f"{cleanedResponse=}")
    cleanedPrompt = re.sub(regex, '', prompt.lower())

    if st.session_state.selectedStory:
        imageText = st.session_state.selectedStory
        return (
            f"Extract the story from this text and add few more details about this story:\n{imageText}",
            "Effect: dramatic, bokeh",
            "Painting your story character ...",
        )

    if (
        __matchingKeywordsCount([C.BOOKING_LINK], cleanedResponse) > 0
        and "storytelling coach" not in cleanedPrompt
    ):
        aiResponses = [
            chat.get("content")
            for chat in st.session_state.chatHistory
            if chat.get("role") == "assistant"
        ]
        # The current response is not in chatHistory yet, so there may be no
        # earlier assistant turn at all; don't let that raise IndexError.
        previousAiResponse = aiResponses[-1] if aiResponses else ""
        relevantResponse = f"""
{previousAiResponse}
{response}
"""
        return (
            f"Extract the story from this text:\n{relevantResponse}",
            """
Style: In a storybook, surreal
""",
            "Imagining your story scene ...",
        )

    return (None, None, None)


def __getImagePromptDetails(prompt: str, response: str):
    """Return ``(imagePrompt, loaderText)`` for this turn, or ``(None, None)``.

    When an illustration applies, the raw instructions are rewritten into a
    final image-generation prompt by a separate LLM call. That call always
    uses the LLAMA config, regardless of MODEL_TYPE.
    """
    (enhancePrompt, imagePrompt, loaderText) = __getRawImagePromptDetails(prompt, response)

    if imagePrompt or enhancePrompt:
        promptEnhanceModelType: ModelType = "LLAMA"
        U.pprint(f"{promptEnhanceModelType=}")
        modelConfig = MODEL_CONFIG[promptEnhanceModelType]
        # Local names differ from the module-level client/isClaudeModel on purpose.
        enhanceClient = modelConfig["client"]
        enhanceModel = modelConfig["model"]
        isEnhanceClaudeModel = promptEnhanceModelType == "CLAUDE"

        systemPrompt = "You help in creating prompts for image generation"
        promptPrefix = (
            f"{enhancePrompt}\nAnd then use the above to"
            if enhancePrompt
            else "Use the text below to"
        )

        enhancePrompt = f"""
{promptPrefix} create a prompt for image generation.
{imagePrompt}

Return only the final Image Generation Prompt, and nothing else
"""
        U.pprint(f"[Raw] {enhancePrompt=}")

        llmArgs = {
            "model": enhanceModel,
            "messages": [{"role": "user", "content": enhancePrompt}],
            "temperature": 1,
            "max_tokens": 2000,
        }

        if isEnhanceClaudeModel:
            # Anthropic takes the system prompt as a separate argument.
            llmArgs["system"] = systemPrompt
            response = enhanceClient.messages.create(**llmArgs)
            imagePrompt = response.content[0].text
        else:
            llmArgs["messages"] = [
                {"role": "system", "content": systemPrompt},
                *llmArgs["messages"],
            ]
            response = enhanceClient.chat.completions.create(**llmArgs)
            imagePrompt = response.choices[0].message.content

        U.pprint(f"[Enhanced] {imagePrompt=}")

    return (imagePrompt, loaderText)


def __getMessages():
    """Return ``st.session_state.messages``, trimmed in place to fit MAX_CONTEXT.

    The oldest messages are dropped until the estimated size of the system
    prompt plus history (with a 100-token margin) fits. The most recent
    message is always kept. After trimming, leading non-user messages are also
    dropped so the history still starts with a user turn. NOTE(review): some
    providers (e.g. Anthropic) have rejected histories that start with an
    assistant message.
    """
    messages = st.session_state.messages

    def getContextSize():
        currContextSize = __countTokens(C.SYSTEM_MSG) + __countTokens(messages) + 100
        U.pprint(f"{currContextSize=}")
        return currContextSize

    trimmed = False
    while getContextSize() > MAX_CONTEXT and len(messages) > 1:
        U.pprint("Context size exceeded, removing first message")
        messages.pop(0)
        trimmed = True

    if trimmed:
        while len(messages) > 1 and messages[0].get("role") != "user":
            messages.pop(0)

    return messages


def __logLlmRequest(messagesFormatted: list):
    """Log the estimated token size of an outgoing request and the model name."""
    contextSize = __countTokens(messagesFormatted)
    U.pprint(f"{contextSize=} | {MODEL}")


def __predict():
    """Stream the assistant reply for the current history, yielding text chunks.

    Uses Anthropic streaming for CLAUDE; otherwise it uses the OpenAI-compatible
    chat completions stream. Any exception is logged, and a single chunk
    ``"{C.EXCEPTION_KEYWORD} | {error}"`` is yielded instead of raising.
    """
    messagesFormatted = []

    try:
        if isClaudeModel:
            messagesFormatted.extend(__getMessages())
            __logLlmRequest(messagesFormatted)

            with client.messages.stream(
                model=MODEL,
                messages=messagesFormatted,
                system=C.SYSTEM_MSG,
                temperature=0.9,
                max_tokens=4000,
            ) as stream:
                for text in stream.text_stream:
                    yield text
        else:
            messagesFormatted.append({"role": "system", "content": C.SYSTEM_MSG})
            messagesFormatted.extend(__getMessages())
            __logLlmRequest(messagesFormatted)

            response = client.chat.completions.create(
                model=MODEL,
                messages=messagesFormatted,
                temperature=1,
                max_tokens=4000,
                stream=True,
            )

            for chunk in response:
                if not chunk.choices:
                    U.pprint("Empty chunk")
                    continue

                chunkContent = chunk.choices[0].delta.content
                if chunkContent:
                    yield chunkContent
    except Exception as e:
        U.pprint(f"LLM API Error: {e}")
        yield f"{C.EXCEPTION_KEYWORD} | {e}"


def __generateImageHF(prompt: str):
    """Generate an image with the FLUX.1-schnell Hugging Face Space via gradio_client.

    Returns whatever the Space's ``/infer`` endpoint returns. This function is
    not called anywhere in this file.
    """
    from gradio_client import Client

    fluxClient = Client(
        "black-forest-labs/FLUX.1-schnell",
        os.environ.get("HF_FLUX_CLIENT_TOKEN"),
    )
    result = fluxClient.predict(
        prompt=prompt,
        seed=0,
        randomize_seed=True,
        width=1024,
        height=768,
        num_inference_steps=4,
        api_name="/infer",
    )
    U.pprint(f"imageResult={result}")
    return result


def __generateImage(prompt: str):
    """Generate one 4:3 WebP image with flux-schnell on Replicate.

    Returns the first output item, presumably a URL string because
    ``use_file_output=False`` is passed. Returns None when Replicate returns
    no output.
    """
    result = replicate.run(
        "black-forest-labs/flux-schnell",
        input={
            "prompt": prompt,
            "seed": 0,
            "go_fast": False,
            "megapixels": "1",
            "num_outputs": 1,
            "aspect_ratio": "4:3",
            "output_format": "webp",
            "num_inference_steps": 4,
        },
        use_file_output=False,
    )
    return result[0] if result else None


def __paintImageIfApplicable(
    imageContainer: DeltaGenerator,
    prompt: str,
    response: str,
):
    """Render an illustration for this turn into ``imageContainer`` when one applies.

    While generating, it shows the loader text and a loader image. Best
    effort: any error is logged and the container is cleared. Returns the
    generated image reference, or None.
    """
    imagePath = None
    try:
        (imagePrompt, loaderText) = __getImagePromptDetails(prompt, response)
        if imagePrompt:
            imgContainer = imageContainer.container()
            imgContainer.write(
                f"""
{loaderText}
""",
                unsafe_allow_html=True,
            )
            imgContainer.image(C.IMAGE_LOADER)
            imagePath = __generateImage(imagePrompt)
            if imagePath:
                imageContainer.image(imagePath)
            else:
                # No output: don't leave the loader on screen.
                imageContainer.empty()
    except Exception as e:
        U.pprint(e)
        imageContainer.empty()

    return imagePath


def __selectButton(optionLabel: str):
    """Button callback: stash the clicked option label so the next run uses it as the prompt."""
    st.session_state["buttonValue"] = optionLabel
    U.pprint(f"Selected: {optionLabel}")


def __showButtons(options: list):
    """Render one button per option dict (keys ``label`` and ``id``)."""
    for option in options:
        st.button(
            option["label"],
            key=option["id"],
            on_click=__selectButton,
            args=(option["label"],),
        )


def __showWordsCount(response: str):
    """Show the whitespace-delimited word count of ``response``."""
    wordsCount = len(response.split())
    # NOTE(review): countClass is not referenced in the markup below; confirm
    # that the intended HTML wrapper (e.g. a CSS class) has not been lost.
    countClass = "crossed-limit" if wordsCount > C.WORDS_LIMIT else ""
    st.markdown(
        f"""
{wordsCount} words
""",
        unsafe_allow_html=True,
    )


def __resetButtonState():
    """Clear any pending option-button selection."""
    st.session_state.buttonValue = ""


def __resetButtons():
    """Forget the option buttons offered for the last response."""
    st.session_state.buttons = []


def __resetSelectedStory():
    """Clear the selected story (an empty dict is treated as 'no story')."""
    st.session_state.selectedStory = {}


def __setStartMsg(msg):
    """Set the message the start button injects as the next prompt."""
    st.session_state.startMsg = msg


# Per-session state, initialised only on the first run of a session.
if "ipAddress" not in st.session_state:
    st.session_state.ipAddress = st.context.headers.get("x-forwarded-for")
if "chatHistory" not in st.session_state:
    st.session_state.chatHistory = []
if "messages" not in st.session_state:
    st.session_state.messages = []
if "buttonValue" not in st.session_state:
    __resetButtonState()
if "selectedStory" not in st.session_state:
    __resetSelectedStory()
if "selectedStoryTitle" not in st.session_state:
    st.session_state.selectedStoryTitle = ""
if "isStoryChosen" not in st.session_state:
    st.session_state.isStoryChosen = False
if "buttons" not in st.session_state:
    st.session_state.buttons = []
if "activityId" not in st.session_state:
    st.session_state.activityId = None
if "userActivitiesLog" not in st.session_state:
    st.session_state.userActivitiesLog = []

U.pprint("\n")
U.pprint("\n")

U.applyCommonStyles()
initCloudinary()
st.title("Kommuneity Story Creator 🪄")


def __streamResponse(responseContainer: DeltaGenerator):
    """Stream one LLM reply into ``responseContainer``.

    Returns the full text, or None if __isInvalidResponse rejected it part-way
    through. After C.JSON_SEPARATOR appears, the container stops being updated
    and keeps the last text rendered before it.
    """
    response = ""
    responseContainer.image(C.TEXT_LOADER)

    for chunk in __predict():
        response += chunk
        if __isInvalidResponse(response):
            U.pprint(f"InvalidResponse={response}")
            return None

        if C.JSON_SEPARATOR not in response:
            responseContainer.markdown(response)

    return response


def __handleJsonPayload(jsonStr: str):
    """Act on the JSON tail of a reply (the text after C.JSON_SEPARATOR).

    ``options`` renders option buttons and stores them in
    ``st.session_state.buttons``. ``action == "SHOW_STORY_DATABASE"`` switches
    to the popular-stories page. Failures (bad JSON, unexpected shape) are
    logged and ignored, because the text part has already been shown.
    """
    try:
        jsonObj = json.loads(jsonStr)
        options = jsonObj.get("options")
        action = jsonObj.get("action")

        if options:
            __showButtons(options)
            st.session_state.buttons = options
        elif action:
            U.pprint(f"{action=}")
            if action == "SHOW_STORY_DATABASE":
                time.sleep(1)
                st.switch_page("pages/popular-stories.py")
    except Exception as e:
        U.pprint(e)


def mainApp():
    """Render the chat page: replay history, then handle at most one new prompt per run."""
    if "startMsg" not in st.session_state:
        __setStartMsg("")
        st.button(C.START_MSG, on_click=lambda: __setStartMsg(C.START_MSG))

    for chat in st.session_state.chatHistory:
        role = chat["role"]
        imagePath = chat.get("image")
        buttons = chat.get("buttons")
        avatar = C.AI_ICON if role == "assistant" else C.USER_ICON
        with st.chat_message(role, avatar=avatar):
            st.markdown(chat["content"])
            if imagePath and U.isValidImageUrl(imagePath):
                st.image(imagePath)
            if buttons:
                # Buttons are shown once, then cleared from the history entry.
                __showButtons(buttons)
                chat["buttons"] = []

    # The prompt may come from the chat box, a clicked option button, a selected
    # story title (presumably set on another page), or the start button.
    # st.chat_input() stays first so the input widget is rendered on every run.
    if prompt := (
        st.chat_input()
        or st.session_state["buttonValue"]
        or st.session_state["selectedStoryTitle"]
        or st.session_state["startMsg"]
    ):
        __resetButtonState()
        __resetButtons()
        __setStartMsg("")
        if st.session_state["selectedStoryTitle"] != prompt:
            __resetSelectedStory()
        st.session_state.selectedStoryTitle = ""

        with st.chat_message("user", avatar=C.USER_ICON):
            st.markdown(prompt)
        U.pprint(f"{prompt=}")
        st.session_state.chatHistory.append({"role": "user", "content": prompt})
        st.session_state.messages.append({"role": "user", "content": prompt})

        with st.chat_message("assistant", avatar=C.AI_ICON):
            responseContainer = st.empty()

            # NOTE(review): retries are unbounded; a persistent provider error
            # keeps this loop running.
            response = __streamResponse(responseContainer)
            while not response:
                U.pprint("Empty response. Retrying..")
                time.sleep(0.7)
                response = __streamResponse(responseContainer)

            U.pprint(f"{response=}")
            rawResponse = response
            # partition (not split + 2-name unpack) so a repeated separator
            # cannot raise ValueError; jsonStr is "" when there is no payload.
            response, _, jsonStr = rawResponse.partition(C.JSON_SEPARATOR)

            imageContainer = st.empty()
            imagePath = __paintImageIfApplicable(imageContainer, prompt, response)
            if imagePath:
                imagePath = getCdnUrl(imagePath)

            # The UI history keeps only the visible text; the LLM context keeps
            # the raw reply, including the JSON payload.
            st.session_state.chatHistory.append({
                "role": "assistant",
                "content": response,
                "image": imagePath,
            })
            st.session_state.messages.append({
                "role": "assistant",
                "content": rawResponse,
            })

            if jsonStr:
                __handleJsonPayload(jsonStr)

            __showWordsCount(response)
            saveLatestActivity()


runWithAuth(mainApp)
showSidebar()