import gradio as gr
import json
import os
import random
from PIL import Image
import base64
import io

# Define categories at the top so they are accessible throughout the code.
# Each entry is (display name, key under which the tags are stored in category_tags.json).
categories = [
    ('Setting', 'scene_tags'),
    ('Position', 'position_tags'),
    ('Outfit', 'outfit_tags'),
    ('Camera View/Angle', 'camera_tags'),
    ('Concept', 'concept_tags'),
    ('Facial Expression', 'facial_expression_tags'),
    ('Pose', 'pose_tags'),
    ('Additional', 'additional_tags'),
    ('LORA', 'lora_tags'),
]

# Quality/style tags appended to the end of every generated prompt.
ENDING_TAGS = "source_anime, score_9, score_8_up, score_7_up, masterpiece, best quality, very aesthetic, absurdres, anime artwork, anime style, vibrant, studio anime, highly detailed"


def _parse_tag_string(tags_string):
    """Split a comma-separated string into stripped, non-empty tags ([] for empty/None)."""
    if not tags_string:
        return []
    return [t.strip() for t in tags_string.split(',') if t.strip()]


def _character_label(char):
    """Return the label used for *char* in the character selection list."""
    # .get: characters stored without a gender must not crash the option list.
    return f"{char['name']} ({char.get('gender')})"


class DataManager:
    """Persist characters, persistent tags and per-category tags as JSON files under *base_dir*.

    Character images are stored as PNG files in ``<base_dir>/character_images``.
    """

    def __init__(self, base_dir='/data'):
        self.base_dir = base_dir
        # exist_ok avoids the race between an existence check and makedirs.
        os.makedirs(self.base_dir, exist_ok=True)
        self.characters_file = os.path.join(self.base_dir, 'characters.json')
        self.persistent_tags_file = os.path.join(self.base_dir, 'persistent_tags.json')
        self.category_tags_file = os.path.join(self.base_dir, 'category_tags.json')
        self.images_folder = os.path.join(self.base_dir, 'character_images')
        os.makedirs(self.images_folder, exist_ok=True)
        self.load_data()

    @staticmethod
    def _load_json(path, default):
        """Return the JSON content of *path*, or *default* if the file does not exist."""
        if not os.path.exists(path):
            return default
        with open(path, 'r', encoding='utf-8') as f:
            return json.load(f)

    @staticmethod
    def _save_json(path, data):
        """Write *data* to *path* via a temp file + os.replace so a crash cannot truncate it."""
        tmp_path = path + '.tmp'
        with open(tmp_path, 'w', encoding='utf-8') as f:
            json.dump(data, f)
        os.replace(tmp_path, path)

    def load_data(self):
        """(Re)load characters, persistent tags and category tags from disk."""
        self.characters = self._load_json(self.characters_file, [])
        self.persistent_tags = self._load_json(self.persistent_tags_file, [])
        self.load_category_tags()

    def save_characters(self):
        """Write the character list to disk, omitting the transient 'image' key."""
        # get_characters() attaches an 'image' key to each stored dict; it is derived
        # from 'image_path' on every read, so it must not be persisted.
        serializable = [
            {key: value for key, value in char.items() if key != 'image'}
            for char in self.characters
        ]
        self._save_json(self.characters_file, serializable)

    def save_persistent_tags(self):
        """Write the persistent tag list to disk."""
        self._save_json(self.persistent_tags_file, self.persistent_tags)

    def load_category_tags(self):
        """Load the category -> tags mapping from disk ({} if the file is missing)."""
        self.category_tags = self._load_json(self.category_tags_file, {})

    def save_category_tags(self):
        """Write the category -> tags mapping to disk."""
        self._save_json(self.category_tags_file, self.category_tags)

    def get_category_tags(self, category_var_name):
        """Return the tag list stored for *category_var_name* ([] if none)."""
        return self.category_tags.get(category_var_name, [])

    def set_category_tags(self, category_var_name, tags_list):
        """Replace the tags for *category_var_name* and save the mapping."""
        self.category_tags[category_var_name] = tags_list
        self.save_category_tags()

    def get_characters(self):
        """Return the stored character dicts.

        Each dict's 'image' key is set to its 'image_path' if that file exists,
        otherwise to None.
        """
        for char in self.characters:
            image_path = char.get('image_path')
            char['image'] = image_path if image_path and os.path.exists(image_path) else None
        return self.characters

    def set_persistent_tags(self, tags_list):
        """Replace the persistent tags and save them."""
        self.persistent_tags = tags_list
        self.save_persistent_tags()

    def get_persistent_tags(self):
        """Return the persistent tag list."""
        return self.persistent_tags

    def _unique_image_path(self, name):
        """Return an unused PNG path in the images folder derived from *name*."""
        safe_name = "".join(c for c in name if c.isalnum() or c in (' ', '_', '-')).rstrip()
        # A name made only of stripped characters would otherwise become ".png".
        base = safe_name or 'character'
        candidate = os.path.join(self.images_folder, f"{base}.png")
        counter = 1
        # Never overwrite another character's image when names collide.
        while os.path.exists(candidate):
            candidate = os.path.join(self.images_folder, f"{base}_{counter}.png")
            counter += 1
        return candidate

    def add_character(self, character):
        """Store *character*, saving its image to disk, and persist the character list.

        character['image'] may be a base64 data URI ("data:...;base64,<data>"),
        bare base64, or None. It is replaced by 'image_path' (None if no image).
        character['traits'] may be a comma-separated string or a list; it is
        normalised to a list of stripped, non-empty strings.
        """
        image_data = character.get('image')
        if image_data:
            # split(",", 1)[-1] handles both a data URI and bare base64.
            encoded = image_data.split(",", 1)[-1]
            image_path = self._unique_image_path(character['name'])
            with Image.open(io.BytesIO(base64.b64decode(encoded))) as image:
                image.save(image_path)
            character['image_path'] = image_path
        else:
            character['image_path'] = None
        character.pop('image', None)

        traits = character.get('traits', [])
        if isinstance(traits, str):
            traits = traits.split(',')
        # Drop empties left by trailing or doubled commas (same rule as the Tags tab).
        character['traits'] = [t.strip() for t in traits if t.strip()]

        self.characters.append(character)
        self.save_characters()


def character_creation_app(data_manager):
    """Build the "Character Creation" tab, which saves new characters via *data_manager*."""
    with gr.Tab("Character Creation"):
        with gr.Row():
            name_input = gr.Textbox(label="Character Name")
            traits_input = gr.Textbox(label="Traits/Appearance Tags (comma separated)")
            image_input = gr.Image(label="Upload Character Image", type="filepath")
            gender_input = gr.Radio(choices=["Boy", "Girl"], label="Gender")
        save_button = gr.Button("Save Character")
        output = gr.Textbox(label="Status", interactive=False)

        def save_character(name, traits, image_path, gender):
            name = (name or '').strip()
            # Whitespace-only names and comma-only trait strings count as missing.
            if not name or not _parse_tag_string(traits) or not gender:
                return "Please enter all fields."
            character = {'name': name, 'traits': traits, 'gender': gender, 'image': None}
            if image_path:
                with open(image_path, "rb") as img_file:
                    image_data = base64.b64encode(img_file.read()).decode('utf-8')
                # The MIME type here is only a label: add_character decodes the bytes
                # with PIL, which detects the real format.
                character['image'] = f"data:image/png;base64,{image_data}"
            data_manager.add_character(character)
            return f"Character '{name}' saved successfully."

        save_button.click(save_character,
                          inputs=[name_input, traits_input, image_input, gender_input],
                          outputs=output)


def prompt_generator_app(data_manager):
    """Build the "Prompt Generator" tab that assembles prompts from tags and characters."""
    with gr.Tab("Prompt Generator"):
        gr.Markdown("## Prompt Generator")
        refresh_tags_button = gr.Button("Refresh Tags")

        # (tag markdown, count slider) per category, in the same order as `categories`.
        tag_displays = []
        for category_name, var_name in categories:
            tags_list = data_manager.get_category_tags(var_name)
            with gr.Group():
                gr.Markdown(f"### {category_name}")
                tag_display = gr.Markdown(f"**Tags:** {', '.join(tags_list)}")
                # Default to selecting one tag when the category has any, else zero.
                tag_num = gr.Slider(minimum=0, maximum=len(tags_list), step=1,
                                    value=min(1, len(tags_list)),
                                    label=f"Number of {category_name} Tags to Select")
            tag_displays.append((tag_display, tag_num))

        with gr.Group():
            gr.Markdown("### Character Selection")

            def get_character_options():
                return [_character_label(char) for char in data_manager.get_characters()]

            character_select = gr.CheckboxGroup(choices=get_character_options(),
                                                label="Select Characters", interactive=True)
            refresh_characters_button = gr.Button("Refresh Character List")

            def refresh_characters():
                return gr.update(choices=get_character_options())

            refresh_characters_button.click(refresh_characters, outputs=character_select)
            random_characters = gr.Checkbox(label="Select Random Characters")
            num_characters = gr.Slider(minimum=1, maximum=10, step=1, value=1,
                                       label="Number of Characters (if random)")

        generate_button = gr.Button("Generate Prompt")
        prompt_output = gr.Textbox(label="Generated Prompt", lines=5)

        def generate_prompt(*args):
            # args: one slider value per category, then the character selection,
            # the "random" checkbox and the random-count slider (see inputs_list).
            *tag_counts, selected_character_options, random_chars, num_random_chars = args

            prompt_tags = []
            for (_, var_name), tags_num in zip(categories, tag_counts):
                tags_list = data_manager.get_category_tags(var_name)
                if tags_list and tags_num and tags_num > 0:
                    prompt_tags.extend(random.sample(tags_list, min(len(tags_list), int(tags_num))))

            characters = data_manager.get_characters()
            if random_chars:
                selected_chars = random.sample(characters, min(len(characters), int(num_random_chars)))
            else:
                # Match the full label (name + gender) exactly; splitting the label on
                # " (" broke for names that themselves contain " (".
                by_label = {}
                for char in characters:
                    by_label.setdefault(_character_label(char), char)  # first match wins
                selected_chars = [by_label[option]
                                  for option in (selected_character_options or [])
                                  if option in by_label]

            num_girls = sum(1 for char in selected_chars if char.get('gender') == 'Girl')
            num_boys = sum(1 for char in selected_chars if char.get('gender') == 'Boy')
            character_count_tags = []
            if num_girls > 0:
                character_count_tags.append(f"{num_girls}girl" if num_girls == 1 else f"{num_girls}girls")
            if num_boys > 0:
                character_count_tags.append(f"{num_boys}boy" if num_boys == 1 else f"{num_boys}boys")

            prompt_parts = []
            if character_count_tags:
                prompt_parts.append(', '.join(character_count_tags))

            # Each character's traits are wrapped in parentheses and the groups are
            # joined with " AND " (intended by the author for SDXL models).
            # Characters without traits are skipped rather than emitting "()".
            character_descriptions = [f"({', '.join(char['traits'])})"
                                      for char in selected_chars if char.get('traits')]
            if character_descriptions:
                prompt_parts.append(' AND '.join(character_descriptions))

            if prompt_tags:
                prompt_parts.append(', '.join(prompt_tags))

            persistent_tags = data_manager.get_persistent_tags()
            if persistent_tags:
                prompt_parts.append(', '.join(persistent_tags))

            prompt_parts.append(ENDING_TAGS)
            return ', '.join(prompt_parts)

        # Order must match the unpacking at the top of generate_prompt.
        inputs_list = [tag_num for _, tag_num in tag_displays]
        inputs_list.extend([character_select, random_characters, num_characters])
        generate_button.click(generate_prompt, inputs=inputs_list, outputs=prompt_output)

        def refresh_tags():
            """Reload every category's tags and return (markdown, slider) updates in order."""
            updates = []
            for _, var_name in categories:
                tags_list = data_manager.get_category_tags(var_name)
                updates.append(gr.update(value=f"**Tags:** {', '.join(tags_list)}"))
                updates.append(gr.update(maximum=len(tags_list), value=min(1, len(tags_list))))
            return updates

        # Flattened in the same (markdown, slider) order that refresh_tags emits.
        outputs = [component for pair in tag_displays for component in pair]
        refresh_tags_button.click(refresh_tags, outputs=outputs)


def tags_app(data_manager):
    """Build the "Tags" tab for editing per-category and persistent tags."""
    with gr.Tab("Tags"):
        gr.Markdown("## Edit Tags for Each Category")

        def make_save_category_tags_fn(var_name, category_name):
            # Factory so each button's handler binds its own category, avoiding
            # the late-binding closure pitfall inside the loop below.
            def fn(tags_string):
                data_manager.set_category_tags(var_name, _parse_tag_string(tags_string))
                return f"{category_name} tags saved successfully."
            return fn

        for category_name, var_name in categories:
            gr.Markdown(f"### {category_name} Tags")
            tags_string = ', '.join(data_manager.get_category_tags(var_name))
            tag_input = gr.Textbox(label=f"{category_name} Tags (comma separated)", value=tags_string)
            save_button = gr.Button(f"Save {category_name} Tags")
            status_output = gr.Textbox(label="", interactive=False)
            save_button.click(make_save_category_tags_fn(var_name, category_name),
                              inputs=tag_input, outputs=status_output)

        gr.Markdown("### Persistent Tags")
        persistent_tags_string = ', '.join(data_manager.get_persistent_tags())
        persistent_tags_input = gr.Textbox(label="Persistent Tags (comma separated)",
                                           value=persistent_tags_string)
        save_persistent_tags_button = gr.Button("Save Persistent Tags")
        persistent_status_output = gr.Textbox(label="", interactive=False)

        def save_persistent_tags(tags_string):
            data_manager.set_persistent_tags(_parse_tag_string(tags_string))
            return "Persistent tags saved successfully."

        save_persistent_tags_button.click(save_persistent_tags,
                                          inputs=persistent_tags_input,
                                          outputs=persistent_status_output)


def main():
    """Create the data manager, build all tabs and launch the Gradio app."""
    data_manager = DataManager(base_dir='/data')
    with gr.Blocks() as demo:
        prompt_generator_app(data_manager)
        character_creation_app(data_manager)
        tags_app(data_manager)
    demo.launch()


if __name__ == "__main__":
    main()