# AllFlux / app.py - Gradio UI
import os
import re
import requests
import tempfile
import gradio as gr
from PIL import Image, ImageDraw
from config import config, theme
from public.data.images.loras.flux1 import loras as flux1_loras
# os.makedirs(config.get("HF_HOME"), exist_ok=True)
# UI
with gr.Blocks(
theme=theme,
fill_width=True,
css_paths=[os.path.join("static/css", f) for f in os.listdir("static/css")],
) as demo:
# States
data_state = gr.State()
local_state = gr.BrowserState(
{
"selected_loras": [],
}
)
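    # data_state is per-session and server-side; BrowserState persists in the
    # browser's localStorage, so LoRA selections survive page reloads.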
with gr.Row():
with gr.Column(scale=1):
gr.Label("AllFlux", show_label=False)
with gr.Accordion("Settings", open=True):
with gr.Group():
height_slider = gr.Slider(
minimum=64,
maximum=2048,
value=1024,
step=64,
label="Height",
interactive=True,
)
width_slider = gr.Slider(
minimum=64,
maximum=2048,
value=1024,
step=64,
label="Width",
interactive=True,
)
with gr.Group():
num_images_slider = gr.Slider(
minimum=1,
maximum=4,
value=1,
step=1,
label="Number of Images",
interactive=True,
)
toggles = gr.CheckboxGroup(
choices=["Realtime", "Randomize Seed"],
value=["Randomize Seed"],
show_label=False,
interactive=True,
)
with gr.Accordion("Advanced", open=False):
num_steps_slider = gr.Slider(
minimum=1,
maximum=100,
value=20,
step=1,
label="Steps",
interactive=True,
)
guidance_scale_slider = gr.Slider(
minimum=1,
maximum=10,
value=3.5,
step=0.1,
label="Guidance Scale",
interactive=True,
)
seed_slider = gr.Slider(
minimum=0,
maximum=4294967295,
value=42,
step=1,
label="Seed",
interactive=True,
)
upscale_slider = gr.Slider(
minimum=2,
maximum=4,
value=2,
step=2,
label="Upscale",
interactive=True,
)
scheduler_dropdown = gr.Dropdown(
label="Scheduler",
choices=[
"Euler a",
"Euler",
"LMS",
"Heun",
"DPM++ 2",
"DPM++ 2 a",
"DPM++ SDE",
"DPM++ SDE Karras",
"DDIM",
"PLMS",
],
value="Euler a",
interactive=True,
)
gr.LoginButton()
gr.Markdown(
"""
Yurrrrrrrrrrrr, WIP
"""
)
with gr.Column(scale=3):
with gr.Group():
with gr.Row():
prompt = gr.Textbox(
show_label=False,
placeholder="Enter your prompt here...",
lines=3,
)
with gr.Row():
with gr.Column(scale=3):
submit_btn = gr.Button("Submit")
with gr.Column(scale=1):
ai_improve_btn = gr.Button("💡", link="#improve-prompt")
with gr.Group():
output_gallery = gr.Gallery(
label="Outputs", interactive=False, height=500
)
with gr.Row():
upscale_selected_btn = gr.Button("Upscale Selected", size="sm")
upscale_all_btn = gr.Button("Upscale All", size="sm")
create_similar_btn = gr.Button("Create Similar", size="sm")
with gr.Accordion("Output History", open=False):
with gr.Group():
output_history_gallery = gr.Gallery(
show_label=False, interactive=False, height=500
)
with gr.Row():
clear_history_btn = gr.Button("Clear All", size="sm")
download_history_btn = gr.Button("Download All", size="sm")
with gr.Accordion("Image Playground", open=True):
                def show_info(content: str | None = None):
                    info_checkbox = gr.Checkbox(
                        value=False, label="Show Info", interactive=True
                    )

                    @gr.render(inputs=info_checkbox)
                    def render_info(show):
                        # Render the help text only while the box is ticked;
                        # skip the suffix when no extra content was passed.
                        if show:
                            text = "Sup, need some help here, please check the community tab."
                            if content:
                                text += f" {content}"
                            gr.Markdown(text)
with gr.Tabs():
with gr.Tab("Img 2 Img"):
with gr.Group():
img2img_img = gr.Image(show_label=False, interactive=True)
img2img_strength_slider = gr.Slider(
minimum=0,
maximum=1,
value=1.0,
step=0.1,
label="Strength",
interactive=True,
)
show_info()
with gr.Tab("Inpaint"):
with gr.Group():
inpaint_img = gr.ImageMask(
show_label=False, interactive=True, type="pil"
)
generate_mask_btn = gr.Button(
"Remove Background", size="sm"
)
use_fill_pipe_inpaint = gr.Checkbox(
value=True,
label="Use Fill Pipeline 🧪",
interactive=True,
)
show_info()
                        inpaint_img.upload(
                            fn=lambda x: (
                                # Size the editor to the uploaded background;
                                # "layers" can be empty right after an upload.
                                # +96 presumably leaves room for the toolbar.
                                gr.update(height=x["background"].height + 96)
                                if x is not None and x.get("background") is not None
                                else gr.update()
                            ),
                            inputs=inpaint_img,
                            outputs=inpaint_img,
                        )
with gr.Tab("Outpaint"):
outpaint_img = gr.Image(
show_label=False, interactive=True, type="pil"
)
with gr.Row(equal_height=True):
with gr.Column(scale=3):
ratio_9_16 = gr.Radio(
label="Image Ratio",
choices=["9:16", "16:9", "1:1", "Height & Width"],
value="9:16",
container=True,
interactive=True,
)
with gr.Column(scale=1):
mask_position = gr.Dropdown(
choices=[
"Middle",
"Left",
"Right",
"Top",
"Bottom",
],
value="Middle",
label="Alignment",
interactive=True,
)
with gr.Group():
resize_options = gr.Radio(
choices=["Full", "75%", "50%", "33%", "25%", "Custom"],
value="Full",
label="Resize",
interactive=True,
)
                            resize_option_custom = gr.State(50)

                            @gr.render(inputs=resize_options)
                            def resize_options_render(resize_option):
                                if resize_option == "Custom":
                                    custom_size_slider = gr.Slider(
                                        minimum=1,
                                        maximum=100,
                                        value=50,
                                        step=1,
                                        label="Custom Size %",
                                        interactive=True,
                                    )
                                    # Write the slider value into the State so the
                                    # mask-preview handler can read it; a plain local
                                    # assignment never updates the State.
                                    custom_size_slider.change(
                                        fn=lambda v: v,
                                        inputs=custom_size_slider,
                                        outputs=resize_option_custom,
                                    )
with gr.Accordion("Advanced settings", open=False):
with gr.Group():
mask_overlap_slider = gr.Slider(
label="Mask Overlap %",
minimum=1,
maximum=50,
value=10,
step=1,
interactive=True,
)
with gr.Row():
overlap_top = gr.Checkbox(
value=True,
label="Overlap Top",
interactive=True,
)
overlap_right = gr.Checkbox(
value=True,
label="Overlap Right",
interactive=True,
)
with gr.Row():
overlap_left = gr.Checkbox(
value=True,
label="Overlap Left",
interactive=True,
)
overlap_bottom = gr.Checkbox(
value=True,
label="Overlap Bottom",
interactive=True,
)
mask_preview_btn = gr.Button(
"Preview", interactive=True
)
mask_preview_img = gr.Image(
show_label=False, visible=False, interactive=True
)
def prepare_image_and_mask(
image,
width,
height,
overlap_percentage,
resize_option,
custom_resize_percentage,
alignment,
overlap_left,
overlap_right,
overlap_top,
overlap_bottom,
):
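                                """Fit `image` inside (width, height), paste it on a
                                white canvas, and build the outpaint mask: black (0)
                                marks the region to keep, white (255) the area the
                                model should fill. The overlap flags shrink the kept
                                region so the model can blend across each enabled edge."""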
target_size = (width, height)
scale_factor = min(
target_size[0] / image.width,
target_size[1] / image.height,
)
new_width = int(image.width * scale_factor)
new_height = int(image.height * scale_factor)
source = image.resize(
(new_width, new_height), Image.LANCZOS
)
if resize_option == "Full":
resize_percentage = 100
elif resize_option == "75%":
resize_percentage = 75
elif resize_option == "50%":
resize_percentage = 50
elif resize_option == "33%":
resize_percentage = 33
elif resize_option == "25%":
resize_percentage = 25
else: # Custom
resize_percentage = custom_resize_percentage
# Calculate new dimensions based on percentage
resize_factor = resize_percentage / 100
new_width = int(source.width * resize_factor)
new_height = int(source.height * resize_factor)
# Ensure minimum size of 64 pixels
new_width = max(new_width, 64)
new_height = max(new_height, 64)
# Resize the image
source = source.resize(
(new_width, new_height), Image.LANCZOS
)
# Calculate the overlap in pixels based on the percentage
overlap_x = int(new_width * (overlap_percentage / 100))
overlap_y = int(new_height * (overlap_percentage / 100))
# Ensure minimum overlap of 1 pixel
overlap_x = max(overlap_x, 1)
overlap_y = max(overlap_y, 1)
# Calculate margins based on alignment
if alignment == "Middle":
margin_x = (target_size[0] - new_width) // 2
margin_y = (target_size[1] - new_height) // 2
elif alignment == "Left":
margin_x = 0
margin_y = (target_size[1] - new_height) // 2
elif alignment == "Right":
margin_x = target_size[0] - new_width
margin_y = (target_size[1] - new_height) // 2
elif alignment == "Top":
margin_x = (target_size[0] - new_width) // 2
margin_y = 0
elif alignment == "Bottom":
margin_x = (target_size[0] - new_width) // 2
margin_y = target_size[1] - new_height
# Adjust margins to eliminate gaps
margin_x = max(
0, min(margin_x, target_size[0] - new_width)
)
margin_y = max(
0, min(margin_y, target_size[1] - new_height)
)
# Create a new background image and paste the resized source image
background = Image.new(
"RGB", target_size, (255, 255, 255)
)
background.paste(source, (margin_x, margin_y))
# Create the mask
mask = Image.new("L", target_size, 255)
mask_draw = ImageDraw.Draw(mask)
# Calculate overlap areas
white_gaps_patch = 2
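                                # 2 px inset used when an edge's overlap is disabled: a
                                # thin strip of the source edge stays regeneratable so
                                # no white seam shows at the border.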
left_overlap = (
margin_x + overlap_x
if overlap_left
else margin_x + white_gaps_patch
)
right_overlap = (
margin_x + new_width - overlap_x
if overlap_right
else margin_x + new_width - white_gaps_patch
)
top_overlap = (
margin_y + overlap_y
if overlap_top
else margin_y + white_gaps_patch
)
bottom_overlap = (
margin_y + new_height - overlap_y
if overlap_bottom
else margin_y + new_height - white_gaps_patch
)
if alignment == "Left":
left_overlap = (
margin_x + overlap_x
if overlap_left
else margin_x
)
elif alignment == "Right":
right_overlap = (
margin_x + new_width - overlap_x
if overlap_right
else margin_x + new_width
)
elif alignment == "Top":
top_overlap = (
margin_y + overlap_y
if overlap_top
else margin_y
)
elif alignment == "Bottom":
bottom_overlap = (
margin_y + new_height - overlap_y
if overlap_bottom
else margin_y + new_height
)
# Draw the mask
mask_draw.rectangle(
[
(left_overlap, top_overlap),
(right_overlap, bottom_overlap),
],
fill=0,
)
return background, mask
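                            # Worked example: a 512x1024 source, 1024x1024 target,
                            # "Full" resize, "Middle" alignment and 10% overlap is
                            # pasted at (256, 0); the black keep-rectangle spans
                            # (307, 102) to (717, 922), leaving white bands to fill.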
                            def preview_mask(*args):
                                if args[0] is None:
                                    raise gr.Error("Upload an image first")
                                background, mask = prepare_image_and_mask(*args)
                                # Reveal the hidden preview with the generated mask and
                                # keep the padded background in the outpaint editor
                                # (the original wiring left the preview invisible).
                                return gr.update(value=mask, visible=True), background

                            mask_preview_btn.click(
                                fn=preview_mask,
                                inputs=[
                                    outpaint_img,
                                    width_slider,
                                    height_slider,
                                    mask_overlap_slider,
                                    resize_options,
                                    resize_option_custom,
                                    mask_position,
                                    overlap_left,
                                    overlap_right,
                                    overlap_top,
                                    overlap_bottom,
                                ],
                                outputs=[mask_preview_img, outpaint_img],
                            )
mask_preview_img.clear(
fn=lambda: gr.update(visible=False),
outputs=mask_preview_img,
)
use_fill_pipe_outpaint = gr.Checkbox(
value=True,
label="Use Fill Pipeline 🧪",
interactive=True,
)
show_info()
with gr.Tab("In-Context"):
with gr.Group():
incontext_img = gr.Image(show_label=False, interactive=True)
# https://huggingface.co/spaces/Yuanshi/OminiControl
show_info(content="1024 res is in beta")
with gr.Tab("IP-Adapter"):
with gr.Group():
ip_adapter_img = gr.Image(
show_label=False, interactive=True
)
ip_adapter_img_scale = gr.Slider(
minimum=0,
maximum=1,
value=0.7,
step=0.1,
label="Scale",
interactive=True,
)
# https://huggingface.co/InstantX/FLUX.1-dev-IP-Adapter
show_info(content="1024 res is in beta")
with gr.Tab("Canny"):
with gr.Group():
canny_img = gr.Image(show_label=False, interactive=True)
with gr.Row(equal_height=True):
with gr.Column(scale=3):
canny_controlnet_conditioning_scale = gr.Slider(
minimum=0,
maximum=1,
value=0.65,
step=0.05,
label="ControlNet Conditioning Scale",
interactive=True,
)
with gr.Column(scale=1):
canny_img_is_preprocessed = gr.Checkbox(
value=True,
label="Preprocessed",
interactive=True,
)
with gr.Tab("Tile"):
with gr.Group():
tile_img = gr.Image(show_label=False, interactive=True)
with gr.Row(equal_height=True):
with gr.Column(scale=3):
tile_controlnet_conditioning_scale = gr.Slider(
minimum=0,
maximum=1,
value=0.45,
step=0.05,
label="ControlNet Conditioning Scale",
interactive=True,
)
with gr.Column(scale=1):
tile_img_is_preprocessed = gr.Checkbox(
value=True,
label="Preprocessed",
interactive=True,
)
with gr.Tab("Depth"):
with gr.Group():
depth_img = gr.Image(show_label=False, interactive=True)
with gr.Row(equal_height=True):
with gr.Column(scale=3):
depth_controlnet_conditioning_scale = gr.Slider(
minimum=0,
maximum=1,
value=0.55,
step=0.05,
label="ControlNet Conditioning Scale",
interactive=True,
)
with gr.Column(scale=1):
depth_img_is_preprocessed = gr.Checkbox(
value=True,
label="Preprocessed",
interactive=True,
)
with gr.Tab("Blur"):
with gr.Group():
blur_img = gr.Image(show_label=False, interactive=True)
with gr.Row(equal_height=True):
with gr.Column(scale=3):
blur_controlnet_conditioning_scale = gr.Slider(
minimum=0,
maximum=1,
value=0.45,
step=0.05,
label="ControlNet Conditioning Scale",
interactive=True,
)
with gr.Column(scale=1):
blur_img_is_preprocessed = gr.Checkbox(
value=True,
label="Preprocessed",
interactive=True,
)
with gr.Tab("Pose"):
with gr.Group():
pose_img = gr.Image(show_label=False, interactive=True)
with gr.Row(equal_height=True):
with gr.Column(scale=3):
pose_controlnet_conditioning_scale = gr.Slider(
minimum=0,
maximum=1,
value=0.55,
step=0.05,
label="ControlNet Conditioning Scale",
interactive=True,
)
with gr.Column(scale=1):
pose_img_is_preprocessed = gr.Checkbox(
value=True,
label="Preprocessed",
interactive=True,
)
with gr.Tab("Gray"):
with gr.Group():
gray_img = gr.Image(show_label=False, interactive=True)
with gr.Row(equal_height=True):
with gr.Column(scale=3):
gray_controlnet_conditioning_scale = gr.Slider(
minimum=0,
maximum=1,
value=0.45,
step=0.05,
label="ControlNet Conditioning Scale",
interactive=True,
)
with gr.Column(scale=1):
gray_img_is_preprocessed = gr.Checkbox(
value=True,
label="Preprocessed",
interactive=True,
)
with gr.Tab("Low Quality"):
with gr.Group():
low_quality_img = gr.Image(
show_label=False, interactive=True
)
with gr.Row(equal_height=True):
with gr.Column(scale=3):
low_quality_controlnet_conditioning_scale = (
gr.Slider(
minimum=0,
maximum=1,
value=0.4,
step=0.05,
label="ControlNet Conditioning Scale",
interactive=True,
)
)
with gr.Column(scale=1):
low_quality_img_is_preprocessed = gr.Checkbox(
value=True,
label="Preprocessed",
interactive=True,
)
# with gr.Tab("Official Canny"):
# with gr.Group():
# gr.HTML(
# """
# <script
# type="module"
# src="https://gradio.s3-us-west-2.amazonaws.com/5.6.0/gradio.js"
# ></script>
# <gradio-app src="https://black-forest-labs-flux-1-canny-dev.hf.space"></gradio-app>
# """
# )
# with gr.Tab("Official Depth"):
# with gr.Group():
# gr.HTML(
# """
# <script
# type="module"
# src="https://gradio.s3-us-west-2.amazonaws.com/5.6.0/gradio.js"
# ></script>
# <gradio-app src="https://black-forest-labs-flux-1-depth-dev.hf.space"></gradio-app>
# """
# )
with gr.Tab("Auto Trainer"):
gr.HTML(
"""
<script
type="module"
src="https://gradio.s3-us-west-2.amazonaws.com/4.42.0/gradio.js"
></script>
<gradio-app src="https://autotrain-projects-train-flux-lora-ease.hf.space"></gradio-app>
"""
)
resize_mode_radio = gr.Radio(
label="Resize Mode",
choices=["Crop & Resize", "Resize Only", "Resize & Fill"],
value="Resize & Fill",
interactive=True,
)
with gr.Accordion("Prompt Generator", open=False):
gr.HTML(
"""
<gradio-app src="https://gokaygokay-flux-prompt-generator.hf.space"></gradio-app>
"""
)
with gr.Column(scale=1):
# Loras
with gr.Accordion("Loras", open=True):
selected_loras = gr.State([])
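                # Each entry: {"title": str, "weights": safetensors path,
                #              "info": trigger words, "weight": float}.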
lora_selector = gr.Gallery(
show_label=False,
value=[(l["image"], l["title"]) for l in flux1_loras],
container=False,
columns=3,
show_download_button=False,
show_fullscreen_button=False,
allow_preview=False,
)
with gr.Group():
lora_selected = gr.Textbox(
show_label=False,
placeholder="Select a Lora to apply...",
container=False,
)
add_lora_btn = gr.Button("Add Lora", size="sm")
gr.Markdown(
"*You can add a Lora by entering a URL or a Hugging Face repo path."
)
# update the selected_loras state with the new lora
                @gr.render(
                    inputs=[lora_selected, selected_loras],
                    triggers=[add_lora_btn.click],
                )
                def add_lora(lora_selected, selected_loras):
                    title = None
                    weights = None
                    info = None
                    weight = None
                    if isinstance(lora_selected, int):
                        # Add from the lora selector: index the source list,
                        # not the Gallery component itself.
                        title = flux1_loras[lora_selected]["title"]
                        weights = flux1_loras[lora_selected]["weights"]
                        info = flux1_loras[lora_selected]["trigger_word"]
elif isinstance(lora_selected, str):
# check if url
if lora_selected.startswith("http"):
# Check if it's a CivitAI URL
if "civitai.com/models/" in lora_selected:
try:
# Extract model ID and version ID from URL
model_id = re.search(
r"/models/(\d+)", lora_selected
).group(1)
version_id = re.search(
r"modelVersionId=(\d+)", lora_selected
)
version_id = (
version_id.group(1) if version_id else None
)
# Get API token from config
api_token = config.get("CIVITAI_TOKEN")
headers = (
{"Authorization": f"Bearer {api_token}"}
if api_token
else {}
)
# Get model version info
if version_id:
url = f"https://civitai.com/api/v1/model-versions/{version_id}"
else:
# Get latest version if no specific version
url = f"https://civitai.com/api/v1/models/{model_id}"
response = requests.get(url, headers=headers)
data = response.json()
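                            # Abridged shape of the version payload used below:
                            # {"name": ..., "baseModel": "Flux.1 ...",
                            #  "trainedWords": [...],
                            #  "files": [{"name": "*.safetensors",
                            #             "downloadUrl": ...}], ...}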
# For models endpoint, get first version
if "modelVersions" in data:
version_data = data["modelVersions"][0]
else:
version_data = data
                            # Verify it's a LoRA for a Flux base model (the
                            # original `and "1" not in ...` check passed almost
                            # anything).
                            if "flux" not in version_data["baseModel"].lower():
                                raise ValueError(
                                    "This LoRA is not compatible with Flux base model"
                                )
# Find .safetensor file
safetensor_file = next(
(
f
for f in version_data["files"]
if f["name"].endswith(".safetensors")
),
None,
)
                            if not safetensor_file:
                                raise ValueError("No .safetensors file found")
# Download file to temp location
temp_dir = tempfile.gettempdir()
file_path = os.path.join(
temp_dir, safetensor_file["name"]
)
download_url = safetensor_file["downloadUrl"]
if api_token:
download_url += f"?token={api_token}"
response = requests.get(
download_url, headers=headers
)
with open(file_path, "wb") as f:
f.write(response.content)
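                            # Note: reads the whole file into memory; acceptable
                            # for typical LoRA checkpoints (tens of MB).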
# Set info from model data
title = data["name"]
weights = file_path
# Check usage tips for default weight
if "description" in version_data:
strength_match = re.search(
r"strength[:\s]+(\d*\.?\d+)",
version_data["description"],
re.IGNORECASE,
)
if strength_match:
weight = float(strength_match.group(1))
info = ", ".join(
version_data.get("trainedWords", [])
)
                        except Exception as e:
                            # gr.Error must be raised to surface in the UI.
                            raise gr.Error(f"Error processing CivitAI URL: {str(e)}")
else:
# check if a hugging face repo (user/repo)
if re.match(
r"^[a-zA-Z0-9_-]+/[a-zA-Z0-9_-]+$", lora_selected
):
try:
# Get API token from config
api_token = config.get("HF_TOKEN")
headers = (
{"Authorization": f"Bearer {api_token}"}
if api_token
else {}
)
# Get model info
url = f"https://huggingface.co/api/models/{lora_selected}"
response = requests.get(url, headers=headers)
data = response.json()
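                                # Fields consumed below: "tags" (list) and an
                                # optional "cardData" dict with "weight" and
                                # "trigger_words".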
# Verify it's a LoRA for Flux
if (
"tags" in data
and "flux-lora" not in data["tags"]
):
raise ValueError(
"This model is not tagged as a Flux LoRA"
)
# Find .safetensor file
files_url = f"https://huggingface.co/api/models/{lora_selected}/tree"
response = requests.get(files_url, headers=headers)
files = response.json()
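                                # The tree endpoint returns entries like
                                # {"type": "file", "path": "...", ...}; pick the
                                # first .safetensors file.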
safetensor_file = next(
(
f
for f in files
if f.get("path", "").endswith(
".safetensors"
)
),
None,
)
                                if not safetensor_file:
                                    raise ValueError("No .safetensors file found")
# Download file to temp location
temp_dir = tempfile.gettempdir()
file_name = os.path.basename(
safetensor_file["path"]
)
file_path = os.path.join(temp_dir, file_name)
download_url = (
f"https://huggingface.co/{lora_selected}"
f"/resolve/main/{safetensor_file['path']}"
)
response = requests.get(
download_url, headers=headers
)
with open(file_path, "wb") as f:
f.write(response.content)
# Set info from model data
title = data.get(
"name", lora_selected.split("/")[-1]
)
weights = file_path
# Check model card for weight recommendations
if (
"cardData" in data
and "weight" in data["cardData"]
):
try:
weight = float(data["cardData"]["weight"])
except (ValueError, TypeError):
weight = 1.0
# Get trigger words from tags or model card
trigger_words = []
if (
"cardData" in data
and "trigger_words" in data["cardData"]
):
trigger_words.extend(
data["cardData"]["trigger_words"]
)
if "tags" in data:
trigger_words.extend(
t
for t in data["tags"]
if not t.startswith("flux-")
)
info = (
", ".join(trigger_words)
if trigger_words
else None
)
                            except Exception as e:
                                # gr.Error must be raised to surface in the UI.
                                raise gr.Error(
                                    f"Error processing Hugging Face repo: {str(e)}"
                                )
                    # add lora to selected_loras; fall back to 0.8 when no
                    # recommended weight was found in the model metadata
                    selected_loras.append(
                        {
                            "title": title,
                            "weights": weights,  # i.e. the safetensors file path
                            "info": info,
                            "weight": weight if weight is not None else 0.8,
                        }
                    )
# render the selected_loras state as sliders
                @gr.render(inputs=[selected_loras])
                def render_selected_loras(loras):
                    for i, lora in enumerate(loras):
                        lora_slider = gr.Slider(
                            label=lora["title"],
                            # Assumed sensible LoRA weight range.
                            minimum=0,
                            maximum=2,
                            step=0.05,
                            value=lora.get("weight", 0.8),
                            interactive=True,
                            info=lora["info"],
                        )

                        def update_lora_weight(value, loras, index=i):
                            # The event passes the slider's *value*, not the
                            # component, so match the LoRA by index rather
                            # than by label.
                            loras[index]["weight"] = value
                            return loras

                        lora_slider.change(
                            fn=update_lora_weight,
                            inputs=[lora_slider, selected_loras],
                            outputs=selected_loras,
                        )
if __name__ == "__main__":
    demo.launch()