import streamlit as st
import os
import time
import json
import re
from typing import List, Literal, TypedDict, Tuple
from transformers import AutoTokenizer
from gradio_client import Client
import constants as C
import utils as U
from openai import OpenAI
import anthropic
from groq import Groq
from dotenv import load_dotenv

load_dotenv()

ModelType = Literal["GPT4", "CLAUDE", "LLAMA"]
ModelConfig = TypedDict("ModelConfig", {
    "client": OpenAI | Groq | anthropic.Anthropic,
    "model": str,
    "max_context": int,
    "tokenizer": AutoTokenizer,
})

modelType: ModelType = os.environ.get("MODEL_TYPE") or "CLAUDE"

MODEL_CONFIG: dict[ModelType, ModelConfig] = {
    "GPT4": {
        "client": OpenAI(api_key=os.environ.get("OPENAI_API_KEY")),
        "model": "gpt-4o-mini",
        "max_context": 128000,
        "tokenizer": AutoTokenizer.from_pretrained("Xenova/gpt-4o"),
    },
    "CLAUDE": {
        "client": anthropic.Anthropic(api_key=os.environ.get("ANTHROPIC_API_KEY")),
        "model": "claude-3-sonnet-20240229",
        "max_context": 128000,
        "tokenizer": AutoTokenizer.from_pretrained("Xenova/claude-tokenizer"),
    },
    "LLAMA": {
        "client": Groq(api_key=os.environ.get("GROQ_API_KEY")),
        "model": "llama-3.1-70b-versatile",
        "max_context": 128000,
        "tokenizer": AutoTokenizer.from_pretrained("Xenova/Meta-Llama-3.1-Tokenizer"),
    },
}

client = MODEL_CONFIG[modelType]["client"]
MODEL = MODEL_CONFIG[modelType]["model"]
MAX_CONTEXT = MODEL_CONFIG[modelType]["max_context"]
tokenizer = MODEL_CONFIG[modelType]["tokenizer"]
isClaudeModel = modelType == "CLAUDE"


def __countTokens(text):
    text = str(text)
    tokens = tokenizer.encode(text, add_special_tokens=False)
    return len(tokens)


st.set_page_config(
    page_title="Kommuneity Story Creator",
    page_icon=C.AI_ICON,
    # menu_items={"About": None}
)
U.pprint("\n")
U.pprint("\n")


def __isInvalidResponse(response: str):
    # newline followed by a lower-case char (broken formatting)
    if len(re.findall(r'\n[a-z]', response)) > 3:
        return True

    # lots of repeating words
    if len(re.findall(r'\b(\w+)(\s+\1){2,}\b', response)) > 1:
        return True

    # too many paragraphs
    if len(re.findall(r'\n\n', response)) > 20:
        return True

    # LLM API threw an exception
    if C.EXCEPTION_KEYWORD in response:
        return True

    # JSON response without the JSON separator
    if ('{\n "options"' in response) and (C.JSON_SEPARATOR not in response):
        return True
    if ('{\n "action"' in response) and (C.JSON_SEPARATOR not in response):
        return True

    # only options, with no leading text
    if response.startswith(C.JSON_SEPARATOR):
        return True

    return False


def __matchingKeywordsCount(keywords: List[str], text: str):
    return sum([
        1 if keyword in text else 0
        for keyword in keywords
    ])


def __getRawImagePromptDetails(prompt: str, response: str) -> Tuple[str | None, str | None, str | None]:
    # strip everything except lower-case letters, digits, whitespace, '.', '-',
    # and drop the word "the", to make keyword matching more reliable
    regex = r'[^a-z0-9 \n\.\-]|((the) +)'
    cleanedResponse = re.sub(regex, '', response.lower())
    U.pprint(f"{cleanedResponse=}")
    cleanedPrompt = re.sub(regex, '', prompt.lower())
    U.pprint(f"{cleanedPrompt=}")

    if (
        __matchingKeywordsCount(
            ["adapt", "personal branding", "purpose", "use case"],
            cleanedResponse
        ) > 2
        and "story so far" not in cleanedResponse
    ):
        return (
            f"Extract the name of the selected story from this text and add a few more details about this story:\n{response}",
            "Effect: dramatic, bokeh",
            "Painting your character ...",
        )

    if __matchingKeywordsCount(
        [C.BOOKING_LINK],
        cleanedResponse
    ) > 0:
        relevantResponse = f"""
        {st.session_state.chatHistory[-1].get("content")}
        {response}
        """
        return (
            f"Extract the story from this text:\n{relevantResponse}",
            """
            Style: In a storybook, surreal
            """,
            "Imagining your scene (beta) ...",
        )

    return (None, None, None)


def __getImagePromptDetails(prompt: str, response: str):
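    """
    Returns (imagePrompt, loaderText) for the current chat turn.

    Stage 1: __getRawImagePromptDetails keyword-matches the cleaned
    prompt/response and yields a raw (enhancePrompt, imagePrompt, loaderText).
    Stage 2: if anything matched, a LLAMA call rewrites it into the final
    image-generation prompt that is later passed to __generateImage.
    Both returned values are None when no image should be generated.
    """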
    (enhancePrompt, imagePrompt, loaderText) = __getRawImagePromptDetails(prompt, response)

    if imagePrompt or enhancePrompt:
        U.pprint(f"[Raw] {enhancePrompt=} | {imagePrompt=}")
        promptEnhanceModelType: ModelType = "LLAMA"
        U.pprint(f"{promptEnhanceModelType=}")
        modelConfig = MODEL_CONFIG[promptEnhanceModelType]
        client = modelConfig["client"]
        model = modelConfig["model"]
        isClaudeModel = promptEnhanceModelType == "CLAUDE"

        systemPrompt = "You help in creating prompts for image generation"
        promptPrefix = f"{enhancePrompt}\nAnd then use the above to" if enhancePrompt else "Use the text below to"
        llmArgs = {
            "model": model,
            "messages": [{
                "role": "user",
                "content": f"""
                {promptPrefix} create a prompt for image generation (limit to less than 500 words)
                {imagePrompt}
                Return only the final Image Generation Prompt, and nothing else
                """
            }],
            "temperature": 1,
            "max_tokens": 2000,
        }

        if isClaudeModel:
            # Anthropic takes the system prompt as a dedicated argument
            llmArgs["system"] = systemPrompt
            response = client.messages.create(**llmArgs)
            imagePrompt = response.content[0].text
        else:
            # OpenAI-compatible APIs take it as the first message
            llmArgs["messages"] = [
                {"role": "system", "content": systemPrompt},
                *llmArgs["messages"]
            ]
            response = client.chat.completions.create(**llmArgs)
            responseMessage = response.choices[0].message
            imagePrompt = responseMessage.content

        U.pprint(f"[Enhanced] {imagePrompt=}")

    return (imagePrompt, loaderText)


def __getMessages():
    def getContextSize():
        currContextSize = (
            __countTokens(C.SYSTEM_MSG) + __countTokens(st.session_state.messages) + 100
        )
        U.pprint(f"{currContextSize=}")
        return currContextSize

    # drop the oldest messages until the conversation fits the context window
    while getContextSize() > MAX_CONTEXT:
        U.pprint("Context size exceeded, removing first message")
        st.session_state.messages.pop(0)

    return st.session_state.messages


def __logLlmRequest(messagesFormatted: list):
    contextSize = __countTokens(messagesFormatted)
    U.pprint(f"{contextSize=} | {MODEL}")
    # U.pprint(f"{messagesFormatted=}")


def predict():
    messagesFormatted = []
    try:
        if isClaudeModel:
            messagesFormatted.extend(__getMessages())
            __logLlmRequest(messagesFormatted)

            with client.messages.stream(
                model=MODEL,
                messages=messagesFormatted,
                temperature=0.8,
                system=C.SYSTEM_MSG,
                max_tokens=4000,
            ) as stream:
                for text in stream.text_stream:
                    yield text
        else:
            messagesFormatted.append(
                {"role": "system", "content": C.SYSTEM_MSG}
            )
            messagesFormatted.extend(__getMessages())
            __logLlmRequest(messagesFormatted)

            response = client.chat.completions.create(
                model=MODEL,
                messages=messagesFormatted,
                temperature=1,
                max_tokens=4000,
                stream=True,
            )
            for chunk in response:
                choices = chunk.choices
                if not choices:
                    U.pprint("Empty chunk")
                    continue
                chunkContent = choices[0].delta.content
                if chunkContent:
                    yield chunkContent
    except Exception as e:
        U.pprint(f"LLM API Error: {e}")
        yield C.EXCEPTION_KEYWORD


def __generateImage(prompt: str):
    fluxClient = Client("black-forest-labs/FLUX.1-schnell")
    result = fluxClient.predict(
        prompt=prompt,
        seed=0,
        randomize_seed=True,
        width=1024,
        height=768,
        num_inference_steps=4,
        api_name="/infer"
    )
    U.pprint(f"imageResult={result}")
    return result


U.applyCommonStyles()
st.title("Kommuneity Story Creator 🪄")


def __resetButtonState():
    st.session_state.buttonValue = ""


def __resetSelectedStory():
    st.session_state.selectedStory = {}


def __setStartMsg(msg):
    st.session_state.startMsg = msg


if "chatHistory" not in st.session_state:
    st.session_state.chatHistory = []
if "messages" not in st.session_state:
    st.session_state.messages = []
if "buttonValue" not in st.session_state:
    __resetButtonState()
if "selectedStory" not in st.session_state:
    __resetSelectedStory()
if "storyChosen" not in st.session_state:
    st.session_state.storyChosen = False
if "startMsg" not in st.session_state:
    __setStartMsg("")
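# The transient keys above (buttonValue, selectedStory, startMsg) let button
# clicks and story picks re-enter the normal chat flow: whichever one is set
# gets picked up by the walrus expression feeding `prompt` below, and is then
# immediately reset so it only fires once per rerun.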
st.button(C.START_MSG, on_click=lambda: __setStartMsg(C.START_MSG))

for chat in st.session_state.chatHistory:
    role = chat["role"]
    content = chat["content"]
    imagePath = chat.get("image")
    avatar = C.AI_ICON if role == "assistant" else C.USER_ICON
    with st.chat_message(role, avatar=avatar):
        st.markdown(content)
        if imagePath:
            st.image(imagePath)

# U.pprint(f"{st.session_state.buttonValue=}")
# U.pprint(f"{st.session_state.selectedStory=}")
# U.pprint(f"{st.session_state.startMsg=}")

if prompt := (
    st.chat_input()
    or st.session_state["buttonValue"]
    or st.session_state["selectedStory"].get("title")
    or st.session_state["startMsg"]
):
    __resetButtonState()
    __resetSelectedStory()
    __setStartMsg("")

    with st.chat_message("user", avatar=C.USER_ICON):
        st.markdown(prompt)
    U.pprint(f"{prompt=}")
    st.session_state.chatHistory.append({"role": "user", "content": prompt})
    st.session_state.messages.append({"role": "user", "content": prompt})

    with st.chat_message("assistant", avatar=C.AI_ICON):
        responseContainer = st.empty()

        def __printAndGetResponse():
            response = ""
            responseContainer.image(C.TEXT_LOADER)
            responseGenerator = predict()

            for chunk in responseGenerator:
                response += chunk
                if __isInvalidResponse(response):
                    U.pprint(f"InvalidResponse={response}")
                    return
                if C.JSON_SEPARATOR not in response:
                    responseContainer.markdown(response)

            return response

        response = __printAndGetResponse()
        while not response:
            U.pprint("Empty response. Retrying ...")
            time.sleep(0.7)
            response = __printAndGetResponse()
        U.pprint(f"{response=}")

        def selectButton(optionLabel):
            st.session_state["buttonValue"] = optionLabel
            U.pprint(f"Selected: {optionLabel}")

        rawResponse = response
        responseParts = response.split(C.JSON_SEPARATOR)
        jsonStr = None
        if len(responseParts) > 1:
            [response, jsonStr] = responseParts

        imagePath = None
        imageContainer = st.empty()
        try:
            (imagePrompt, loaderText) = __getImagePromptDetails(prompt, response)
            if imagePrompt:
                imgContainer = imageContainer.container()
                imgContainer.write(
                    f"""