import os
import re
import requests
import tempfile
import gradio as gr
from PIL import Image, ImageDraw
from config import config, theme
from public.data.images.loras.flux1 import loras as flux1_loras
# os.makedirs(config.get("HF_HOME"), exist_ok=True)
# UI
with gr.Blocks(
theme=theme,
fill_width=True,
css_paths=[os.path.join("static/css", f) for f in os.listdir("static/css")],
) as demo:
# States
# NOTE(review): data_state is not read or written anywhere in the visible code.
data_state = gr.State()
# gr.BrowserState keeps its value in the user's browser storage across
# reloads. NOTE(review): local_state is not referenced elsewhere in the
# visible code; the LoRA list the handlers use is the `selected_loras` gr.State.
local_state = gr.BrowserState(
{
"selected_loras": [],
}
)
# Three-column layout: settings (left), prompt/outputs/playground (middle),
# LoRA picker (right).
with gr.Row():
# Left column: title, generation settings, login button and a status note.
with gr.Column(scale=1):
gr.Label("AllFlux", show_label=False)
with gr.Accordion("Settings", open=True):
# Output size: 64-2048 px in steps of 64, default 1024 x 1024.
with gr.Group():
height_slider = gr.Slider(
minimum=64,
maximum=2048,
value=1024,
step=64,
label="Height",
interactive=True,
)
width_slider = gr.Slider(
minimum=64,
maximum=2048,
value=1024,
step=64,
label="Width",
interactive=True,
)
# Batch size (1-4) plus behaviour toggles; "Randomize Seed" is on by default.
with gr.Group():
num_images_slider = gr.Slider(
minimum=1,
maximum=4,
value=1,
step=1,
label="Number of Images",
interactive=True,
)
toggles = gr.CheckboxGroup(
choices=["Realtime", "Randomize Seed"],
value=["Randomize Seed"],
show_label=False,
interactive=True,
)
with gr.Accordion("Advanced", open=False):
num_steps_slider = gr.Slider(
minimum=1,
maximum=100,
value=20,
step=1,
label="Steps",
interactive=True,
)
guidance_scale_slider = gr.Slider(
minimum=1,
maximum=10,
value=3.5,
step=0.1,
label="Guidance Scale",
interactive=True,
)
# Upper bound is 2**32 - 1.
seed_slider = gr.Slider(
minimum=0,
maximum=4294967295,
value=42,
step=1,
label="Seed",
interactive=True,
)
# Only two positions are possible: 2 or 4 (min 2, max 4, step 2).
upscale_slider = gr.Slider(
minimum=2,
maximum=4,
value=2,
step=2,
label="Upscale",
interactive=True,
)
scheduler_dropdown = gr.Dropdown(
label="Scheduler",
choices=[
"Euler a",
"Euler",
"LMS",
"Heun",
"DPM++ 2",
"DPM++ 2 a",
"DPM++ SDE",
"DPM++ SDE Karras",
"DDIM",
"PLMS",
],
value="Euler a",
interactive=True,
)
gr.LoginButton()
gr.Markdown(
"""
Yurrrrrrrrrrrr, WIP
"""
)
# Middle column: prompt, current outputs, history and the image playground.
# NOTE(review): submit_btn and the output/history buttons below have no
# event handlers in the visible code.
with gr.Column(scale=3):
with gr.Group():
with gr.Row():
prompt = gr.Textbox(
show_label=False,
placeholder="Enter your prompt here...",
lines=3,
)
with gr.Row():
with gr.Column(scale=3):
submit_btn = gr.Button("Submit")
with gr.Column(scale=1):
# `link` makes this button navigate to the "#improve-prompt" anchor.
# NOTE(review): no component with that id is created in the visible code.
ai_improve_btn = gr.Button("💡", link="#improve-prompt")
with gr.Group():
output_gallery = gr.Gallery(
label="Outputs", interactive=False, height=500
)
with gr.Row():
upscale_selected_btn = gr.Button("Upscale Selected", size="sm")
upscale_all_btn = gr.Button("Upscale All", size="sm")
create_similar_btn = gr.Button("Create Similar", size="sm")
with gr.Accordion("Output History", open=False):
with gr.Group():
output_history_gallery = gr.Gallery(
show_label=False, interactive=False, height=500
)
with gr.Row():
clear_history_btn = gr.Button("Clear All", size="sm")
download_history_btn = gr.Button("Download All", size="sm")
# Tabbed playground of image-conditioned generation modes.
with gr.Accordion("Image Playground", open=True):
def show_info(content: str | None = None):
    """Add a "Show Info" checkbox that toggles a help note below it.

    Must be called inside a ``gr.Blocks`` layout context: it creates a
    checkbox and a ``gr.render`` area that shows a Markdown note only while
    the checkbox is ticked.

    Args:
        content: Optional extra text appended to the note. With the default
            ``None`` nothing is appended (instead of the literal "None").
    """
    info_checkbox = gr.Checkbox(value=False, label="Show Info", interactive=True)
    suffix = f" {content}" if content else ""

    # Distinct inner name so it does not shadow `show_info` itself.
    @gr.render(inputs=info_checkbox)
    def _render_info(checked):
        # Components created inside a render function are added to the page;
        # nothing needs to be returned.
        if checked:
            gr.Markdown(
                f"Sup, need some help here, please check the community tab.{suffix}"
            )
# One tab per image-conditioned mode.
with gr.Tabs():
with gr.Tab("Img 2 Img"):
with gr.Group():
img2img_img = gr.Image(show_label=False, interactive=True)
# 0-1 in steps of 0.1, default 1.0.
img2img_strength_slider = gr.Slider(
minimum=0,
maximum=1,
value=1.0,
step=0.1,
label="Strength",
interactive=True,
)
# Adds the "Show Info" toggle and its help note.
show_info()
with gr.Tab("Inpaint"):
with gr.Group():
# Image editor with a mask-drawing layer; its value is delivered to
# handlers with PIL images (type="pil").
inpaint_img = gr.ImageMask(
show_label=False, interactive=True, type="pil"
)
# NOTE(review): this button has no click handler in the visible code.
generate_mask_btn = gr.Button(
"Remove Background", size="sm"
)
use_fill_pipe_inpaint = gr.Checkbox(
value=True,
label="Use Fill Pipeline 🧪",
interactive=True,
)
show_info()
def _fit_inpaint_editor_height(editor_value):
    """Resize the inpaint editor to the height of the uploaded image.

    ``editor_value`` is the ``gr.ImageMask`` payload (``type="pil"``): a dict
    with ``"background"``, ``"layers"`` and ``"composite"`` entries. The new
    editor height is the image height plus 96 px (presumably room for the
    editor toolbar). When no image can be found, a no-op update is returned
    so the editor's current value is not cleared.
    """
    if not editor_value:
        return gr.update()
    image = editor_value.get("background")
    if image is None:
        # `layers` may be empty, so never index it blindly.
        layers = editor_value.get("layers") or []
        image = layers[0] if layers else editor_value.get("composite")
    if image is None:
        return gr.update()
    return gr.update(height=image.height + 96)


inpaint_img.upload(
    fn=_fit_inpaint_editor_height,
    inputs=inpaint_img,
    outputs=inpaint_img,
)
with gr.Tab("Outpaint"):
# Source image for outpainting, delivered to handlers as a PIL image.
outpaint_img = gr.Image(
show_label=False, interactive=True, type="pil"
)
with gr.Row(equal_height=True):
with gr.Column(scale=3):
# NOTE(review): despite its name this radio holds every ratio choice, and
# no handler in the visible code reads it (the mask preview uses the
# Height/Width sliders instead).
ratio_9_16 = gr.Radio(
label="Image Ratio",
choices=["9:16", "16:9", "1:1", "Height & Width"],
value="9:16",
container=True,
interactive=True,
)
with gr.Column(scale=1):
# Where the source image is placed on the outpaint canvas; passed to
# prepare_image_and_mask as `alignment`.
mask_position = gr.Dropdown(
choices=[
"Middle",
"Left",
"Right",
"Top",
"Bottom",
],
value="Middle",
label="Alignment",
interactive=True,
)
with gr.Group():
# How much to shrink the fitted source image before placing it.
resize_options = gr.Radio(
choices=["Full", "75%", "50%", "33%", "25%", "Custom"],
value="Full",
label="Resize",
interactive=True,
)
# Last percentage picked on the "Custom Size %" slider. It starts at 50, the
# slider's own default, so "Custom" has a usable value before the slider is
# ever moved.
resize_option_custom = gr.State(50)


@gr.render(
    inputs=[resize_options, resize_option_custom],
    triggers=[resize_options.change],
)
def resize_options_render(resize_option, custom_percentage=None):
    """Show a "Custom Size %" slider while the "Custom" resize option is selected.

    Re-rendered only when ``resize_options`` changes, so dragging the slider
    does not rebuild it. The slider starts at the stored custom percentage
    and mirrors every change into ``resize_option_custom``, the state the
    mask-preview handler reads.

    Args:
        resize_option: Current value of the "Resize" radio.
        custom_percentage: Current value of ``resize_option_custom``.
    """
    if resize_option != "Custom":
        return
    custom_slider = gr.Slider(
        minimum=1,
        maximum=100,
        value=50 if custom_percentage is None else custom_percentage,
        step=1,
        label="Custom Size %",
        interactive=True,
    )
    # A local component only reaches other handlers through the state.
    custom_slider.change(
        fn=lambda value: value,
        inputs=custom_slider,
        outputs=resize_option_custom,
    )
# Mask tuning for outpainting; every control here feeds prepare_image_and_mask.
with gr.Accordion("Advanced settings", open=False):
with gr.Group():
# Inset of the kept region, as a percentage of the resized source's size.
mask_overlap_slider = gr.Slider(
label="Mask Overlap %",
minimum=1,
maximum=50,
value=10,
step=1,
interactive=True,
)
# Per-edge switches: when checked, that edge of the mask is inset by the
# overlap so the generated area blends into the source.
with gr.Row():
overlap_top = gr.Checkbox(
value=True,
label="Overlap Top",
interactive=True,
)
overlap_right = gr.Checkbox(
value=True,
label="Overlap Right",
interactive=True,
)
with gr.Row():
overlap_left = gr.Checkbox(
value=True,
label="Overlap Left",
interactive=True,
)
overlap_bottom = gr.Checkbox(
value=True,
label="Overlap Bottom",
interactive=True,
)
mask_preview_btn = gr.Button(
"Preview", interactive=True
)
# Created hidden (visible=False).
mask_preview_img = gr.Image(
show_label=False, visible=False, interactive=True
)
def prepare_image_and_mask(
    image,
    width,
    height,
    overlap_percentage,
    resize_option,
    custom_resize_percentage,
    alignment,
    overlap_left,
    overlap_right,
    overlap_top,
    overlap_bottom,
):
    """Place *image* on a white canvas and build the matching outpaint mask.

    The source is first fitted inside a ``width`` x ``height`` canvas
    (aspect ratio preserved), then shrunk by the ``resize_option``
    percentage (never below 64 px per side) and pasted according to
    ``alignment``. The mask is white (255) where new content is to be
    generated and black (0) over the kept part of the source. The kept part
    is inset by ``overlap_percentage`` on every edge whose ``overlap_*`` flag
    is set, and by 2 px on the other edges.

    Args:
        image: Source PIL image.
        width: Target canvas width in pixels. Cast to ``int``, since slider
            values may arrive as floats.
        height: Target canvas height in pixels. Also cast to ``int``.
        overlap_percentage: Inset of the kept region, as a percentage of the
            resized source's width/height.
        resize_option: "Full", "75%", "50%", "33%", "25%" or "Custom".
        custom_resize_percentage: Percentage used for "Custom"; falls back
            to 50 when unset.
        alignment: "Middle", "Left", "Right", "Top" or "Bottom". Unknown
            values are treated as "Middle".
        overlap_left: Whether to inset the mask on the left edge.
        overlap_right: Whether to inset the mask on the right edge.
        overlap_top: Whether to inset the mask on the top edge.
        overlap_bottom: Whether to inset the mask on the bottom edge.

    Returns:
        ``(background, mask)``: an RGB canvas with the source pasted on it and
        an ``L``-mode mask of the same size.

    Raises:
        gr.Error: If no image was supplied.
    """
    if image is None:
        raise gr.Error("Please upload an image to outpaint first.")

    target_width, target_height = int(width), int(height)
    target_size = (target_width, target_height)

    # Fit the source inside the canvas while keeping its aspect ratio.
    scale_factor = min(target_width / image.width, target_height / image.height)
    fitted = image.resize(
        (
            max(int(image.width * scale_factor), 1),
            max(int(image.height * scale_factor), 1),
        ),
        Image.LANCZOS,
    )

    # Shrink by the selected percentage, never below 64 px per side.
    presets = {"Full": 100, "75%": 75, "50%": 50, "33%": 33, "25%": 25}
    if resize_option in presets:
        resize_percentage = presets[resize_option]
    else:  # "Custom"
        resize_percentage = (
            50 if custom_resize_percentage is None else custom_resize_percentage
        )
    resize_factor = resize_percentage / 100
    new_width = max(int(fitted.width * resize_factor), 64)
    new_height = max(int(fitted.height * resize_factor), 64)
    source = fitted.resize((new_width, new_height), Image.LANCZOS)

    # Overlap in pixels, at least 1.
    overlap_x = max(int(new_width * (overlap_percentage / 100)), 1)
    overlap_y = max(int(new_height * (overlap_percentage / 100)), 1)

    # Top-left corner of the source on the canvas, clamped so it stays inside.
    free_x = target_width - new_width
    free_y = target_height - new_height
    margin_x, margin_y = {
        "Left": (0, free_y // 2),
        "Right": (free_x, free_y // 2),
        "Top": (free_x // 2, 0),
        "Bottom": (free_x // 2, free_y),
    }.get(alignment, (free_x // 2, free_y // 2))
    margin_x = max(0, min(margin_x, free_x))
    margin_y = max(0, min(margin_y, free_y))

    background = Image.new("RGB", target_size, (255, 255, 255))
    background.paste(source, (margin_x, margin_y))

    # Edges that are not overlapped are still inset by a small gap
    # (presumably to repaint thin white seams at the paste boundary).
    gap = 2
    left = margin_x + (overlap_x if overlap_left else gap)
    right = margin_x + new_width - (overlap_x if overlap_right else gap)
    top = margin_y + (overlap_y if overlap_top else gap)
    bottom = margin_y + new_height - (overlap_y if overlap_bottom else gap)
    # The edge pushed against the canvas border gets no gap when its overlap
    # is off.
    if alignment == "Left" and not overlap_left:
        left = margin_x
    elif alignment == "Right" and not overlap_right:
        right = margin_x + new_width
    elif alignment == "Top" and not overlap_top:
        top = margin_y
    elif alignment == "Bottom" and not overlap_bottom:
        bottom = margin_y + new_height

    mask = Image.new("L", target_size, 255)
    ImageDraw.Draw(mask).rectangle([(left, top), (right, bottom)], fill=0)
    return background, mask
def _preview_outpaint_mask(*args):
    """Show the outpaint canvas with the area to be generated tinted red.

    Accepts exactly the positional arguments of ``prepare_image_and_mask``.
    Returns an update that fills and reveals ``mask_preview_img``. The
    uploaded ``outpaint_img`` is left untouched.
    """
    background, mask = prepare_image_and_mask(*args)
    # Red at alpha 64 (~25 %), pasted through the mask, so only mask-white
    # pixels (the region to be generated) are tinted.
    tint = Image.new("RGBA", background.size, (255, 0, 0, 64))
    overlay = Image.new("RGBA", background.size, (0, 0, 0, 0))
    overlay.paste(tint, (0, 0), mask)
    preview = Image.alpha_composite(background.convert("RGBA"), overlay)
    return gr.update(value=preview, visible=True)


mask_preview_btn.click(
    fn=_preview_outpaint_mask,
    inputs=[
        outpaint_img,
        width_slider,
        height_slider,
        mask_overlap_slider,
        resize_options,
        resize_option_custom,
        mask_position,
        overlap_left,
        overlap_right,
        overlap_top,
        overlap_bottom,
    ],
    outputs=mask_preview_img,
)
# Hide the preview again once the user clears it.
mask_preview_img.clear(
    fn=lambda: gr.update(visible=False),
    outputs=mask_preview_img,
)
use_fill_pipe_outpaint = gr.Checkbox(
value=True,
label="Use Fill Pipeline 🧪",
interactive=True,
)
# Adds the "Show Info" toggle and its help note.
show_info()
with gr.Tab("In-Context"):
with gr.Group():
incontext_img = gr.Image(show_label=False, interactive=True)
# Presumably the reference implementation for this tab:
# https://huggingface.co/spaces/Yuanshi/OminiControl
show_info(content="1024 res is in beta")
with gr.Tab("IP-Adapter"):
with gr.Group():
ip_adapter_img = gr.Image(
show_label=False, interactive=True
)
# Adapter scale: 0-1 in steps of 0.1, default 0.7.
ip_adapter_img_scale = gr.Slider(
minimum=0,
maximum=1,
value=0.7,
step=0.1,
label="Scale",
interactive=True,
)
# Presumably the model this tab targets:
# https://huggingface.co/InstantX/FLUX.1-dev-IP-Adapter
show_info(content="1024 res is in beta")
def _controlnet_tab(label, default_scale):
    """Build one ControlNet tab and return its three input components.

    Each tab holds an input image, a "ControlNet Conditioning Scale" slider
    (0-1, step 0.05) starting at *default_scale*, and a "Preprocessed"
    checkbox that starts checked.

    Returns:
        ``(image, conditioning_scale_slider, is_preprocessed_checkbox)``.
    """
    with gr.Tab(label):
        with gr.Group():
            image = gr.Image(show_label=False, interactive=True)
            with gr.Row(equal_height=True):
                with gr.Column(scale=3):
                    scale = gr.Slider(
                        minimum=0,
                        maximum=1,
                        value=default_scale,
                        step=0.05,
                        label="ControlNet Conditioning Scale",
                        interactive=True,
                    )
                with gr.Column(scale=1):
                    is_preprocessed = gr.Checkbox(
                        value=True,
                        label="Preprocessed",
                        interactive=True,
                    )
    return image, scale, is_preprocessed


canny_img, canny_controlnet_conditioning_scale, canny_img_is_preprocessed = (
    _controlnet_tab("Canny", 0.65)
)
tile_img, tile_controlnet_conditioning_scale, tile_img_is_preprocessed = (
    _controlnet_tab("Tile", 0.45)
)
depth_img, depth_controlnet_conditioning_scale, depth_img_is_preprocessed = (
    _controlnet_tab("Depth", 0.55)
)
blur_img, blur_controlnet_conditioning_scale, blur_img_is_preprocessed = (
    _controlnet_tab("Blur", 0.45)
)
pose_img, pose_controlnet_conditioning_scale, pose_img_is_preprocessed = (
    _controlnet_tab("Pose", 0.55)
)
gray_img, gray_controlnet_conditioning_scale, gray_img_is_preprocessed = (
    _controlnet_tab("Gray", 0.45)
)
(
    low_quality_img,
    low_quality_controlnet_conditioning_scale,
    low_quality_img_is_preprocessed,
) = _controlnet_tab("Low Quality", 0.4)
# Disabled tabs, kept commented out.
# with gr.Tab("Official Canny"):
# with gr.Group():
# gr.HTML(
# """
#
#
# """
# )
# with gr.Tab("Official Depth"):
# with gr.Group():
# gr.HTML(
# """
#
#
# """
# )
# Placeholder tab: its HTML body is empty.
with gr.Tab("Auto Trainer"):
gr.HTML(
"""
"""
)
# NOTE(review): resize_mode_radio is not read by any handler in the visible code.
resize_mode_radio = gr.Radio(
label="Resize Mode",
choices=["Crop & Resize", "Resize Only", "Resize & Fill"],
value="Resize & Fill",
interactive=True,
)
# Placeholder accordion: its HTML body is empty.
with gr.Accordion("Prompt Generator", open=False):
gr.HTML(
"""
"""
)
# Right column: LoRA picker.
with gr.Column(scale=1):
# Loras
with gr.Accordion("Loras", open=True):
# Per-session list of chosen LoRA entries (dicts); starts empty.
selected_loras = gr.State([])
# Thumbnail grid of the bundled Flux.1 LoRAs, captioned with their titles.
# NOTE(review): no select handler is attached in the visible code, so
# clicking a thumbnail does not add that LoRA.
lora_selector = gr.Gallery(
show_label=False,
value=[(l["image"], l["title"]) for l in flux1_loras],
container=False,
columns=3,
show_download_button=False,
show_fullscreen_button=False,
allow_preview=False,
)
with gr.Group():
# Free-text LoRA reference read when "Add Lora" is clicked.
lora_selected = gr.Textbox(
show_label=False,
placeholder="Select a Lora to apply...",
container=False,
)
add_lora_btn = gr.Button("Add Lora", size="sm")
gr.Markdown(
"*You can add a Lora by entering a URL or a Hugging Face repo path."
)
# --- Adding LoRAs to the `selected_loras` state -----------------------------
_DEFAULT_LORA_WEIGHT = 0.8
# Hugging Face repo ids look like "user/repo"; names may contain dots
# (e.g. "black-forest-labs/FLUX.1-dev").
_HF_REPO_ID_RE = re.compile(r"^[A-Za-z0-9_.-]+/[A-Za-z0-9_.-]+$")
_HTTP_TIMEOUT = 30  # seconds per connect/read


def _auth_headers(token):
    """Return a bearer-token header dict, or ``{}`` when *token* is falsy."""
    return {"Authorization": f"Bearer {token}"} if token else {}


def _get_json(url, headers):
    """GET *url* and return its decoded JSON body; raise on non-2xx status."""
    response = requests.get(url, headers=headers, timeout=_HTTP_TIMEOUT)
    response.raise_for_status()
    return response.json()


def _download_to_temp(url, file_name, headers=None, params=None):
    """Stream *url* to ``<tempdir>/<basename(file_name)>`` and return that path.

    Only the base name of the remote-supplied *file_name* is used, so it
    cannot point outside the temp directory. The body is written in 1 MiB
    chunks instead of being buffered in memory.

    Raises:
        requests.RequestException: On network errors or a non-2xx status.
    """
    file_path = os.path.join(tempfile.gettempdir(), os.path.basename(file_name))
    with requests.get(
        url, headers=headers, params=params, stream=True, timeout=_HTTP_TIMEOUT
    ) as response:
        response.raise_for_status()
        with open(file_path, "wb") as f:
            for chunk in response.iter_content(chunk_size=1 << 20):
                f.write(chunk)
    return file_path


def _fetch_civitai_lora(url):
    """Resolve a CivitAI model URL into a selected-LoRA entry.

    Uses the ``modelVersionId`` query parameter when present, otherwise the
    first entry of the model's ``modelVersions`` list. The version's
    ``.safetensors`` file is downloaded to the temp directory.

    Raises:
        ValueError: No model id in the URL, base model is not Flux, or no
            ``.safetensors`` file is listed.
        requests.RequestException: On network/HTTP failures.
    """
    model_match = re.search(r"/models/(\d+)", url)
    if model_match is None:
        raise ValueError("No model id found in the CivitAI URL")
    version_match = re.search(r"modelVersionId=(\d+)", url)
    api_token = config.get("CIVITAI_TOKEN")
    headers = _auth_headers(api_token)

    if version_match:
        version_data = _get_json(
            f"https://civitai.com/api/v1/model-versions/{version_match.group(1)}",
            headers,
        )
        # NOTE(review): assumes this payload nests the parent model under
        # "model"; falls back to the version's own name.
        title = (version_data.get("model") or {}).get("name") or version_data.get(
            "name"
        )
    else:
        model_data = _get_json(
            f"https://civitai.com/api/v1/models/{model_match.group(1)}", headers
        )
        versions = model_data.get("modelVersions") or []
        if not versions:
            raise ValueError("This CivitAI model has no versions")
        # Assumed to be the latest version - verify against the CivitAI API.
        version_data = versions[0]
        title = model_data.get("name")

    # Require "flux" in the base model name; any name containing "1"
    # (e.g. "SD 1.5") must not slip through.
    if "flux" not in (version_data.get("baseModel") or "").lower():
        raise ValueError("This LoRA is not compatible with Flux base model")

    safetensor_file = next(
        (
            f
            for f in version_data.get("files") or []
            if f.get("name", "").endswith(".safetensors")
        ),
        None,
    )
    if safetensor_file is None:
        raise ValueError("No .safetensor file found")
    weights = _download_to_temp(
        safetensor_file["downloadUrl"],
        safetensor_file["name"],
        headers=headers,
        # `params` is merged with any query string already on the URL.
        params={"token": api_token} if api_token else None,
    )

    # Pick up a suggested strength from the version description, if any.
    weight = _DEFAULT_LORA_WEIGHT
    strength_match = re.search(
        r"strength[:\s]+(\d*\.?\d+)",
        version_data.get("description") or "",
        re.IGNORECASE,
    )
    if strength_match:
        weight = float(strength_match.group(1))

    return {
        "title": title or f"CivitAI model {model_match.group(1)}",
        "weights": weights,  # local .safetensors path
        "info": ", ".join(version_data.get("trainedWords") or []) or None,
        "weight": weight,
    }


def _fetch_hf_lora(repo_id):
    """Resolve a Hugging Face ``user/repo`` id into a selected-LoRA entry.

    Downloads the first top-level ``.safetensors`` file of the ``main``
    revision to the temp directory. Trigger words come from the model card's
    ``trigger_words`` and ``instance_prompt`` fields. Repo tags are metadata
    such as licenses, so they are not used as trigger words.

    Raises:
        ValueError: The repo is tagged but not as a Flux LoRA, or has no
            ``.safetensors`` file.
        requests.RequestException: On network/HTTP failures.
    """
    headers = _auth_headers(config.get("HF_TOKEN"))
    data = _get_json(f"https://huggingface.co/api/models/{repo_id}", headers)
    if "tags" in data and "flux-lora" not in data["tags"]:
        raise ValueError("This model is not tagged as a Flux LoRA")

    # The tree endpoint needs an explicit revision; "main" matches the
    # /resolve/main/ download URL below.
    files = _get_json(f"https://huggingface.co/api/models/{repo_id}/tree/main", headers)
    safetensor_file = next(
        (f for f in files if f.get("path", "").endswith(".safetensors")), None
    )
    if safetensor_file is None:
        raise ValueError("No .safetensor file found")
    file_path = safetensor_file["path"]
    weights = _download_to_temp(
        f"https://huggingface.co/{repo_id}/resolve/main/{file_path}",
        file_path,
        headers=headers,
    )

    card = data.get("cardData") or {}
    weight = _DEFAULT_LORA_WEIGHT
    if "weight" in card:
        try:
            weight = float(card["weight"])
        except (TypeError, ValueError):
            pass  # keep the default when the card value is not a number

    trigger_words = card.get("trigger_words") or []
    if isinstance(trigger_words, str):
        trigger_words = [trigger_words]
    instance_prompt = card.get("instance_prompt")
    if instance_prompt and instance_prompt not in trigger_words:
        trigger_words = [*trigger_words, instance_prompt]

    return {
        "title": data.get("name", repo_id.split("/")[-1]),
        "weights": weights,  # local .safetensors path
        "info": ", ".join(str(w) for w in trigger_words) or None,
        "weight": weight,
    }


def add_lora(lora_selected, current_loras=None):
    """Resolve a LoRA reference and return the LoRA list with it appended.

    Args:
        lora_selected: An ``int`` index into ``flux1_loras``, a CivitAI model
            URL, or a Hugging Face ``user/repo`` id.
        current_loras: Current value of the ``selected_loras`` state. It is
            not mutated.

    Returns:
        A new list of ``{"title", "weights", "info", "weight"}`` dicts.

    Raises:
        gr.Error: Shown in the UI when the reference is invalid or cannot be
            resolved; the state is then left unchanged.
    """
    if isinstance(lora_selected, int):
        # NOTE(review): assumes flux1_loras entries carry "title", "weights"
        # and "trigger_word" keys - confirm against the loras data module.
        entry = flux1_loras[lora_selected]
        lora = {
            "title": entry["title"],
            "weights": entry["weights"],
            "info": entry.get("trigger_word"),
            "weight": _DEFAULT_LORA_WEIGHT,
        }
    else:
        reference = (lora_selected or "").strip()
        if reference.startswith("http"):
            if "civitai.com/models/" not in reference:
                raise gr.Error("Only CivitAI model URLs are supported.")
            try:
                lora = _fetch_civitai_lora(reference)
            except Exception as e:  # UI boundary: surface any failure to the user
                raise gr.Error(f"Error processing CivitAI URL: {e}") from e
        elif _HF_REPO_ID_RE.match(reference):
            try:
                lora = _fetch_hf_lora(reference)
            except Exception as e:  # UI boundary: surface any failure to the user
                raise gr.Error(f"Error processing Hugging Face repo: {e}") from e
        else:
            raise gr.Error(
                "Enter a CivitAI model URL or a Hugging Face repo id (user/repo)."
            )
    # Return a new list so the State registers a change.
    return [*(current_loras or []), lora]


add_lora_btn.click(
    fn=add_lora,
    inputs=[lora_selected, selected_loras],
    outputs=selected_loras,
)
# Render one weight slider per entry in the `selected_loras` state.
@gr.render(inputs=[selected_loras])
def render_selected_loras(loras):
    """Render a weight slider for every selected LoRA.

    Each slider starts at the entry's ``"weight"`` (0.8 when missing).
    Moving it stores the new value under that entry's ``"weight"`` key in
    the ``selected_loras`` state.

    Args:
        loras: Current value of the ``selected_loras`` state (a list of dicts
            with at least ``"title"`` and ``"info"``). The parameter is not
            named ``selected_loras``, so that name still refers to the State
            component used as an event input/output below.
    """

    def _make_weight_updater(index):
        # Bind the index now; a closure over the loop variable would only
        # ever see its final value.
        def update_weight(weight, current_loras):
            updated = [dict(lora) for lora in current_loras or []]
            if 0 <= index < len(updated):
                updated[index]["weight"] = weight
            return updated

        return update_weight

    for index, lora in enumerate(loras or []):
        lora_slider = gr.Slider(
            # Explicit range: the default 0-100 range is unusable for a LoRA
            # scale. NOTE(review): 0-2 is an assumed sensible upper bound.
            minimum=0,
            maximum=2,
            step=0.05,
            label=lora.get("title"),
            value=lora.get("weight", 0.8),
            interactive=True,
            info=lora.get("info"),
            # A stable key lets Gradio keep the slider across re-renders.
            key=f"lora-weight-{index}",
        )
        lora_slider.change(
            fn=_make_weight_updater(index),
            inputs=[lora_slider, selected_loras],
            outputs=selected_loras,
        )
# Start the Gradio server. This runs whenever the module is executed or
# imported, because it is not guarded by `if __name__ == "__main__":`.
demo.launch()