|
import os |
|
import re |
|
import requests |
|
import tempfile |
|
|
|
import gradio as gr |
|
from PIL import Image, ImageDraw |
|
|
|
from config import theme |
|
from public.data.images.loras.flux1 import loras as flux1_loras |
|
|
|
|
|
|
|
|
|
with gr.Blocks( |
|
theme=theme, |
|
fill_width=True, |
|
css_paths=[os.path.join("static/css", f) for f in os.listdir("static/css")], |
|
) as demo: |
|
|
|
|
|
# NOTE(review): data_state and local_state are not read or written anywhere in
# this file; the LoRA list that is actually used lives in `selected_loras`.
data_state = gr.State()
# gr.BrowserState keeps its value in the browser across page reloads.
local_state = gr.BrowserState(
    {
        "selected_loras": [],
    }
)

# Main layout: settings column | prompt/outputs/playground column | LoRA column.
with gr.Row():
    # Left column: generation settings.
    with gr.Column(scale=1):
        gr.Label("AllFlux", show_label=False)

        with gr.Accordion("Settings", open=True):
            # Output resolution, 64-2048 px in steps of 64. The outpaint
            # preview also uses these sliders as its canvas size.
            with gr.Group():
                height_slider = gr.Slider(
                    minimum=64,
                    maximum=2048,
                    value=1024,
                    step=64,
                    label="Height",
                    interactive=True,
                )
                width_slider = gr.Slider(
                    minimum=64,
                    maximum=2048,
                    value=1024,
                    step=64,
                    label="Width",
                    interactive=True,
                )

            with gr.Group():
                num_images_slider = gr.Slider(
                    minimum=1,
                    maximum=4,
                    value=1,
                    step=1,
                    label="Number of Images",
                    interactive=True,
                )

                # Multi-select flags; "Randomize Seed" is on by default.
                toggles = gr.CheckboxGroup(
                    choices=["Realtime", "Randomize Seed"],
                    value=["Randomize Seed"],
                    show_label=False,
                    interactive=True,
                )

        # Sampler parameters (collapsed by default).
        with gr.Accordion("Advanced", open=False):
            num_steps_slider = gr.Slider(
                minimum=1,
                maximum=100,
                value=20,
                step=1,
                label="Steps",
                interactive=True,
            )
            guidance_scale_slider = gr.Slider(
                minimum=1,
                maximum=10,
                value=3.5,
                step=0.1,
                label="Guidance Scale",
                interactive=True,
            )
            # Upper bound is 2**32 - 1.
            seed_slider = gr.Slider(
                minimum=0,
                maximum=4294967295,
                value=42,
                step=1,
                label="Seed",
                interactive=True,
            )
            # With min 2, max 4 and step 2, only 2 and 4 are selectable.
            upscale_slider = gr.Slider(
                minimum=2,
                maximum=4,
                value=2,
                step=2,
                label="Upscale",
                interactive=True,
            )
            scheduler_dropdown = gr.Dropdown(
                label="Scheduler",
                choices=[
                    "Euler a",
                    "Euler",
                    "LMS",
                    "Heun",
                    "DPM++ 2",
                    "DPM++ 2 a",
                    "DPM++ SDE",
                    "DPM++ SDE Karras",
                    "DDIM",
                    "PLMS",
                ],
                value="Euler a",
                interactive=True,
            )

        gr.LoginButton()

        gr.Markdown(
            """
            Yurrrrrrrrrrrr, WIP
            """
        )

    # Center column: prompt, outputs, history and the image playground.
    # NOTE(review): no event handler is attached in this file to submit_btn,
    # ai_improve_btn, the upscale/create-similar buttons or the history
    # buttons, so they currently do nothing.
    with gr.Column(scale=3):
        with gr.Group():
            with gr.Row():
                prompt = gr.Textbox(
                    show_label=False,
                    placeholder="Enter your prompt here...",
                    lines=3,
                )

            with gr.Row():
                with gr.Column(scale=3):
                    submit_btn = gr.Button("Submit")
                with gr.Column(scale=1):
                    ai_improve_btn = gr.Button("💡", link="#improve-prompt")

        with gr.Group():
            output_gallery = gr.Gallery(
                label="Outputs", interactive=False, height=500
            )

            with gr.Row():
                upscale_selected_btn = gr.Button("Upscale Selected", size="sm")
                upscale_all_btn = gr.Button("Upscale All", size="sm")
                create_similar_btn = gr.Button("Create Similar", size="sm")

        with gr.Accordion("Output History", open=False):
            with gr.Group():
                output_history_gallery = gr.Gallery(
                    show_label=False, interactive=False, height=500
                )

                with gr.Row():
                    clear_history_btn = gr.Button("Clear All", size="sm")
                    download_history_btn = gr.Button("Download All", size="sm")

        # Per-mode conditioning inputs (img2img, inpaint, outpaint, ControlNet, ...).
        with gr.Accordion("Image Playground", open=True):
|
|
|
def show_info(content: str | None = None):
    """Add a "Show Info" checkbox that reveals a help note while it is checked.

    Must be called inside a Blocks context; the checkbox and the note are
    placed at the call site.

    Args:
        content: Optional extra text appended to the help note. When omitted
            (None) or empty, only the base note is shown (previously the
            literal text "None" was appended).
    """
    info_checkbox = gr.Checkbox(
        value=False, label="Show Info", interactive=True
    )

    message = "Sup, need some help here, please check the community tab."
    if content:
        message = f"{message} {content}"

    @gr.render(inputs=info_checkbox)
    def render_info(show: bool):
        # Components created inside a render function are what gets shown,
        # so the note only exists while the checkbox is ticked.
        if show:
            gr.Markdown(message)
|
|
|
# One tab per conditioning mode of the image playground.
# NOTE(review): no generation handler in this file reads these tab inputs yet.
with gr.Tabs():
    with gr.Tab("Img 2 Img"):
        with gr.Group():
            img2img_img = gr.Image(show_label=False, interactive=True)
            # 0-1 in steps of 0.1, default 1.0.
            img2img_strength_slider = gr.Slider(
                minimum=0,
                maximum=1,
                value=1.0,
                step=0.1,
                label="Strength",
                interactive=True,
            )

        show_info()

    with gr.Tab("Inpaint"):
        with gr.Group():
            # Editor value is delivered as PIL images (type="pil").
            inpaint_img = gr.ImageMask(
                show_label=False, interactive=True, type="pil"
            )
            # NOTE(review): no click handler is attached to this button in this file.
            generate_mask_btn = gr.Button(
                "Remove Background", size="sm"
            )

            use_fill_pipe_inpaint = gr.Checkbox(
                value=True,
                label="Use Fill Pipeline 🧪",
                interactive=True,
            )

        show_info()
|
def fit_inpaint_editor_height(editor_value):
    """Grow the inpaint editor to the uploaded image's height plus 96 px.

    Args:
        editor_value: the ImageMask value (a dict with "background", "layers"
            and "composite" PIL images), or None.

    Returns:
        A height update, or an empty ``gr.update()`` (leave the editor as is)
        when there is no image to measure. Returning None would instead clear
        the editor.
    """
    if not editor_value:
        return gr.update()

    # Prefer the background (the uploaded image itself); fall back to the
    # first layer. Indexing layers[0] unconditionally raised IndexError when
    # the editor had no layers.
    image = editor_value.get("background")
    if image is None:
        layers = editor_value.get("layers") or []
        image = layers[0] if layers else None
    if image is None:
        return gr.update()

    # The extra 96 px is kept from the original handler, presumably room for
    # the editor toolbar.
    return gr.update(height=image.height + 96)


inpaint_img.upload(
    fn=fit_inpaint_editor_height,
    inputs=inpaint_img,
    outputs=inpaint_img,
)
|
# Outpaint: paste the source onto a larger canvas and generate the rest.
with gr.Tab("Outpaint"):
    outpaint_img = gr.Image(
        show_label=False, interactive=True, type="pil"
    )

    with gr.Row(equal_height=True):
        with gr.Column(scale=3):
            # NOTE(review): ratio_9_16 is not passed to any handler in this
            # file; the preview canvas size comes from the Width/Height sliders.
            ratio_9_16 = gr.Radio(
                label="Image Ratio",
                choices=["9:16", "16:9", "1:1", "Height & Width"],
                value="9:16",
                container=True,
                interactive=True,
            )

        with gr.Column(scale=1):
            # Where the source sits on the canvas; passed to
            # prepare_image_and_mask as `alignment`.
            mask_position = gr.Dropdown(
                choices=[
                    "Middle",
                    "Left",
                    "Right",
                    "Top",
                    "Bottom",
                ],
                value="Middle",
                label="Alignment",
                interactive=True,
            )

    # How large the source is relative to its fitted size on the canvas.
    with gr.Group():
        resize_options = gr.Radio(
            choices=["Full", "75%", "50%", "33%", "25%", "Custom"],
            value="Full",
            label="Resize",
            interactive=True,
        )

    # Custom resize percentage passed to prepare_image_and_mask for the
    # "Custom" option. gr.State() starts out as None.
    resize_option_custom = gr.State()
|
@gr.render(inputs=resize_options)
def resize_options_render(resize_option):
    """Show a "Custom Size %" slider while the "Custom" resize option is chosen.

    The slider is created inside this render function, so its value is copied
    into the outer ``resize_option_custom`` state on every change. That state
    is what ``prepare_image_and_mask`` reads. Previously the slider was only
    bound to a local name and its value never reached the state.

    Args:
        resize_option: current value of the ``resize_options`` radio.
    """
    if resize_option != "Custom":
        return

    custom_slider = gr.Slider(
        minimum=1,
        maximum=100,
        value=50,
        step=1,
        label="Custom Size %",
        interactive=True,
    )
    # The state only receives a value once the slider is changed; until then
    # it keeps its previous value (None initially).
    custom_slider.change(
        fn=lambda value: value,
        inputs=custom_slider,
        outputs=resize_option_custom,
    )
|
|
|
# Outpaint mask tuning. The slider and the four checkboxes are inputs to
# prepare_image_and_mask (see mask_preview_btn.click).
with gr.Accordion("Advanced settings", open=False):
    with gr.Group():
        # Percentage of the pasted image's width/height by which the kept
        # region is inset on each overlapped edge.
        mask_overlap_slider = gr.Slider(
            label="Mask Overlap %",
            minimum=1,
            maximum=50,
            value=10,
            step=1,
            interactive=True,
        )
        # Per-edge toggles selecting which edges receive the overlap inset.
        with gr.Row():
            overlap_top = gr.Checkbox(
                value=True,
                label="Overlap Top",
                interactive=True,
            )
            overlap_right = gr.Checkbox(
                value=True,
                label="Overlap Right",
                interactive=True,
            )
        with gr.Row():
            overlap_left = gr.Checkbox(
                value=True,
                label="Overlap Left",
                interactive=True,
            )
            overlap_bottom = gr.Checkbox(
                value=True,
                label="Overlap Bottom",
                interactive=True,
            )
        mask_preview_btn = gr.Button(
            "Preview", interactive=True
        )

    # Starts hidden (visible=False); its .clear handler hides it again.
    mask_preview_img = gr.Image(
        show_label=False, visible=False, interactive=True
    )
|
def prepare_image_and_mask(
    image,
    width,
    height,
    overlap_percentage,
    resize_option,
    custom_resize_percentage,
    alignment,
    overlap_left,
    overlap_right,
    overlap_top,
    overlap_bottom,
):
    """Place ``image`` on a white canvas and build the matching outpaint mask.

    The image is first scaled to fit inside ``width`` x ``height`` (aspect ratio
    kept). It is then shrunk by the chosen resize percentage (never below 64 px
    per side) and pasted at the ``alignment`` position.

    Args:
        image: source PIL image (the Outpaint input uses ``type="pil"``).
        width, height: canvas size in pixels; converted to ``int`` because
            slider values may arrive as floats.
        overlap_percentage: percentage of the pasted image's width/height
            re-generated along each overlapped edge.
        resize_option: "Full", "75%", "50%", "33%", "25%" or "Custom".
        custom_resize_percentage: percentage used for "Custom". When it is
            None, 50 is used (the "Custom Size %" slider's default).
        alignment: "Middle", "Left", "Right", "Top" or "Bottom". Any other
            value is treated as "Middle".
        overlap_left, overlap_right, overlap_top, overlap_bottom: whether that
            edge gets the overlap band. A non-overlapped edge is still inset by
            2 px, except the edge flush with the canvas for the alignment.

    Returns:
        ``(background, mask)``: the RGB canvas, and an "L" mask of the same
        size that is 0 over the kept part of the source and 255 elsewhere.

    Raises:
        gr.Error: if no image was provided.
    """
    if image is None:
        raise gr.Error("Please upload an image to outpaint first.")

    target_width, target_height = int(width), int(height)
    target_size = (target_width, target_height)

    # 1) Fit the source inside the canvas, keeping its aspect ratio. Clamp to
    #    1 px so extreme aspect ratios don't produce a zero-sized resize.
    scale_factor = min(target_width / image.width, target_height / image.height)
    fitted_size = (
        max(int(image.width * scale_factor), 1),
        max(int(image.height * scale_factor), 1),
    )
    source = image.resize(fitted_size, Image.LANCZOS)

    # 2) Shrink by the selected percentage, keeping at least 64 px per side.
    preset_percentages = {"Full": 100, "75%": 75, "50%": 50, "33%": 33, "25%": 25}
    if resize_option in preset_percentages:
        resize_percentage = preset_percentages[resize_option]
    elif custom_resize_percentage is not None:
        resize_percentage = custom_resize_percentage
    else:
        # "Custom" selected but no value reached the state yet; avoid
        # failing on None / 100.
        resize_percentage = 50
    resize_factor = resize_percentage / 100
    new_width = max(int(source.width * resize_factor), 64)
    new_height = max(int(source.height * resize_factor), 64)
    source = source.resize((new_width, new_height), Image.LANCZOS)

    # 3) Overlap band size in pixels (at least 1 px).
    overlap_x = max(int(new_width * (overlap_percentage / 100)), 1)
    overlap_y = max(int(new_height * (overlap_percentage / 100)), 1)

    # 4) Position the source on the canvas.
    free_x = target_width - new_width
    free_y = target_height - new_height
    margins = {
        "Middle": (free_x // 2, free_y // 2),
        "Left": (0, free_y // 2),
        "Right": (free_x, free_y // 2),
        "Top": (free_x // 2, 0),
        "Bottom": (free_x // 2, free_y),
    }
    # Unknown alignments fall back to centering; previously they raised
    # UnboundLocalError.
    margin_x, margin_y = margins.get(alignment, margins["Middle"])
    # Keep the pasted image inside the canvas.
    margin_x = max(0, min(margin_x, free_x))
    margin_y = max(0, min(margin_y, free_y))

    background = Image.new("RGB", target_size, (255, 255, 255))
    background.paste(source, (margin_x, margin_y))

    # 5) Mask: 255 everywhere, 0 over the inset rectangle of the source.
    white_gaps_patch = 2

    def edge_inset(overlap_enabled, overlap_px, flush_with_canvas):
        # Overlapped edges are inset by the overlap band. Other edges are
        # inset by white_gaps_patch, unless they sit on the aligned canvas edge.
        if overlap_enabled:
            return overlap_px
        return 0 if flush_with_canvas else white_gaps_patch

    left = margin_x + edge_inset(overlap_left, overlap_x, alignment == "Left")
    right = margin_x + new_width - edge_inset(
        overlap_right, overlap_x, alignment == "Right"
    )
    top = margin_y + edge_inset(overlap_top, overlap_y, alignment == "Top")
    bottom = margin_y + new_height - edge_inset(
        overlap_bottom, overlap_y, alignment == "Bottom"
    )

    mask = Image.new("L", target_size, 255)
    ImageDraw.Draw(mask).rectangle([(left, top), (right, bottom)], fill=0)

    return background, mask
|
|
|
def preview_outpaint_mask(
    image,
    width,
    height,
    overlap_percentage,
    resize_option,
    custom_resize_percentage,
    alignment,
    overlap_left,
    overlap_right,
    overlap_top,
    overlap_bottom,
):
    """Show the outpaint layout without modifying the uploaded source image.

    Builds the canvas and mask with ``prepare_image_and_mask``, then tints the
    region that will be generated (mask == 255) red at about 25% opacity.

    Returns:
        An update that shows the preview in ``mask_preview_img``, or hides it
        when no image has been uploaded.
    """
    if image is None:
        return gr.update(visible=False)

    background, mask = prepare_image_and_mask(
        image,
        width,
        height,
        overlap_percentage,
        resize_option,
        custom_resize_percentage,
        alignment,
        overlap_left,
        overlap_right,
        overlap_top,
        overlap_bottom,
    )
    red = Image.new("RGB", background.size, (255, 0, 0))
    tinted = Image.blend(background, red, 0.25)
    # composite() takes `tinted` where the mask is 255 and `background` where it is 0.
    preview = Image.composite(tinted, background, mask)
    return gr.update(value=preview, visible=True)


# Only the preview is written. The previous wiring sent the mask into
# outpaint_img (overwriting the user's image) and never made the preview visible.
mask_preview_btn.click(
    fn=preview_outpaint_mask,
    inputs=[
        outpaint_img,
        width_slider,
        height_slider,
        mask_overlap_slider,
        resize_options,
        resize_option_custom,
        mask_position,
        overlap_left,
        overlap_right,
        overlap_top,
        overlap_bottom,
    ],
    outputs=mask_preview_img,
)
mask_preview_img.clear(
    fn=lambda: gr.update(visible=False),
    outputs=mask_preview_img,
)

use_fill_pipe_outpaint = gr.Checkbox(
    value=True,
    label="Use Fill Pipeline 🧪",
    interactive=True,
)

show_info()
|
# NOTE(review): no generation handler in this file reads the inputs of these tabs yet.
with gr.Tab("In-Context"):
    with gr.Group():
        incontext_img = gr.Image(show_label=False, interactive=True)

    show_info(content="1024 res is in beta")
with gr.Tab("IP-Adapter"):
    with gr.Group():
        ip_adapter_img = gr.Image(
            show_label=False, interactive=True
        )
        ip_adapter_img_scale = gr.Slider(
            minimum=0,
            maximum=1,
            value=0.7,
            step=0.1,
            label="Scale",
            interactive=True,
        )

    show_info(content="1024 res is in beta")
# ControlNet tabs: each pairs a conditioning image with a 0-1 conditioning
# scale (per-mode default) and a "Preprocessed" checkbox. The checkbox
# presumably means the upload is already a control map; confirm with the
# generation code.
with gr.Tab("Canny"):
    with gr.Group():
        canny_img = gr.Image(show_label=False, interactive=True)
        with gr.Row(equal_height=True):
            with gr.Column(scale=3):
                canny_controlnet_conditioning_scale = gr.Slider(
                    minimum=0,
                    maximum=1,
                    value=0.65,
                    step=0.05,
                    label="ControlNet Conditioning Scale",
                    interactive=True,
                )
            with gr.Column(scale=1):
                canny_img_is_preprocessed = gr.Checkbox(
                    value=True,
                    label="Preprocessed",
                    interactive=True,
                )
with gr.Tab("Tile"):
    with gr.Group():
        tile_img = gr.Image(show_label=False, interactive=True)
        with gr.Row(equal_height=True):
            with gr.Column(scale=3):
                tile_controlnet_conditioning_scale = gr.Slider(
                    minimum=0,
                    maximum=1,
                    value=0.45,
                    step=0.05,
                    label="ControlNet Conditioning Scale",
                    interactive=True,
                )
            with gr.Column(scale=1):
                tile_img_is_preprocessed = gr.Checkbox(
                    value=True,
                    label="Preprocessed",
                    interactive=True,
                )
with gr.Tab("Depth"):
    with gr.Group():
        depth_img = gr.Image(show_label=False, interactive=True)
        with gr.Row(equal_height=True):
            with gr.Column(scale=3):
                depth_controlnet_conditioning_scale = gr.Slider(
                    minimum=0,
                    maximum=1,
                    value=0.55,
                    step=0.05,
                    label="ControlNet Conditioning Scale",
                    interactive=True,
                )
            with gr.Column(scale=1):
                depth_img_is_preprocessed = gr.Checkbox(
                    value=True,
                    label="Preprocessed",
                    interactive=True,
                )
with gr.Tab("Blur"):
    with gr.Group():
        blur_img = gr.Image(show_label=False, interactive=True)
        with gr.Row(equal_height=True):
            with gr.Column(scale=3):
                blur_controlnet_conditioning_scale = gr.Slider(
                    minimum=0,
                    maximum=1,
                    value=0.45,
                    step=0.05,
                    label="ControlNet Conditioning Scale",
                    interactive=True,
                )
            with gr.Column(scale=1):
                blur_img_is_preprocessed = gr.Checkbox(
                    value=True,
                    label="Preprocessed",
                    interactive=True,
                )
with gr.Tab("Pose"):
    with gr.Group():
        pose_img = gr.Image(show_label=False, interactive=True)
        with gr.Row(equal_height=True):
            with gr.Column(scale=3):
                pose_controlnet_conditioning_scale = gr.Slider(
                    minimum=0,
                    maximum=1,
                    value=0.55,
                    step=0.05,
                    label="ControlNet Conditioning Scale",
                    interactive=True,
                )
            with gr.Column(scale=1):
                pose_img_is_preprocessed = gr.Checkbox(
                    value=True,
                    label="Preprocessed",
                    interactive=True,
                )
with gr.Tab("Gray"):
    with gr.Group():
        gray_img = gr.Image(show_label=False, interactive=True)
        with gr.Row(equal_height=True):
            with gr.Column(scale=3):
                gray_controlnet_conditioning_scale = gr.Slider(
                    minimum=0,
                    maximum=1,
                    value=0.45,
                    step=0.05,
                    label="ControlNet Conditioning Scale",
                    interactive=True,
                )
            with gr.Column(scale=1):
                gray_img_is_preprocessed = gr.Checkbox(
                    value=True,
                    label="Preprocessed",
                    interactive=True,
                )
with gr.Tab("Low Quality"):
    with gr.Group():
        low_quality_img = gr.Image(
            show_label=False, interactive=True
        )
        with gr.Row(equal_height=True):
            with gr.Column(scale=3):
                low_quality_controlnet_conditioning_scale = (
                    gr.Slider(
                        minimum=0,
                        maximum=1,
                        value=0.4,
                        step=0.05,
                        label="ControlNet Conditioning Scale",
                        interactive=True,
                    )
                )
            with gr.Column(scale=1):
                low_quality_img_is_preprocessed = gr.Checkbox(
                    value=True,
                    label="Preprocessed",
                    interactive=True,
                )

# Embeds the hosted LoRA-training Space through the <gradio-app> web component.
# NOTE(review): gr.HTML may not execute <script> tags inserted this way;
# confirm that the embed actually loads.
with gr.Tab("Auto Trainer"):
    gr.HTML(
        """
        <script
            type="module"
            src="https://gradio.s3-us-west-2.amazonaws.com/4.42.0/gradio.js"
        ></script>

        <gradio-app src="https://autotrain-projects-train-flux-lora-ease.hf.space"></gradio-app>
        """
    )
|
# NOTE(review): resize_mode_radio is not read by any event handler in this file.
resize_mode_radio = gr.Radio(
    label="Resize Mode",
    choices=["Crop & Resize", "Resize Only", "Resize & Fill"],
    value="Resize & Fill",
    interactive=True,
)

# Embeds a hosted prompt-generator Space through the <gradio-app> web
# component. This presumably needs gradio.js to be loaded on the page;
# confirm that it renders.
with gr.Accordion("Prompt Generator", open=False):
    gr.HTML(
        """
        <gradio-app src="https://gokaygokay-flux-prompt-generator.hf.space"></gradio-app>
        """
    )
|
|
|
# Right column: LoRA selection.
with gr.Column(scale=1):
    with gr.Accordion("Loras", open=True):
        # LoRAs the user has added; each entry is a dict with at least
        # "title", "weights" and "info" keys (built by add_lora).
        selected_loras = gr.State([])
        # Preset gallery built from flux1_loras as (image, title) pairs.
        # NOTE(review): no select handler is attached, so clicking a preset
        # does not fill lora_selected.
        lora_selector = gr.Gallery(
            show_label=False,
            value=[(l["image"], l["title"]) for l in flux1_loras],
            container=False,
            columns=3,
            show_download_button=False,
            show_fullscreen_button=False,
            allow_preview=False,
        )
        with gr.Group():
            # Free-text source: a CivitAI model URL or a Hugging Face "owner/repo" id.
            lora_selected = gr.Textbox(
                show_label=False,
                placeholder="Select a Lora to apply...",
                container=False,
            )
            add_lora_btn = gr.Button("Add Lora", size="sm")
        gr.Markdown(
            "*You can add a Lora by entering a URL or a Hugging Face repo path."
        )
|
|
|
|
|
# Seconds to wait for a CivitAI / Hugging Face response before giving up.
_LORA_HTTP_TIMEOUT = 60
# Hugging Face "owner/repo" ids (dots are allowed in repo names).
_HF_REPO_ID_PATTERN = re.compile(r"^[A-Za-z0-9_.-]+/[A-Za-z0-9_.-]+$")
# Weight used when a LoRA's metadata does not specify one (same value as the
# selected-LoRA slider default).
_DEFAULT_LORA_WEIGHT = 0.8


def _auth_headers(token):
    """Return a bearer-token header dict, or {} when ``token`` is unset."""
    return {"Authorization": f"Bearer {token}"} if token else {}


def _download_to_temp(url, headers, remote_name):
    """Stream ``url`` into the system temp directory and return the local path.

    ``remote_name`` comes from a remote API, so only its basename is used.
    This keeps the write inside the temp directory.
    """
    file_name = os.path.basename(remote_name) or "lora.safetensors"
    file_path = os.path.join(tempfile.gettempdir(), file_name)
    with requests.get(
        url, headers=headers, stream=True, timeout=_LORA_HTTP_TIMEOUT
    ) as response:
        response.raise_for_status()
        with open(file_path, "wb") as f:
            # Write in 1 MiB chunks instead of holding the whole file in memory.
            for chunk in response.iter_content(chunk_size=1 << 20):
                f.write(chunk)
    return file_path


def _fetch_civitai_lora(url):
    """Download a Flux LoRA from a CivitAI model URL.

    Returns:
        ``(title, weights_path, weight, info)``; ``info`` is the trained
        words joined by ", ", or None.

    Raises:
        ValueError / requests.RequestException on bad input or HTTP failures.
    """
    model_match = re.search(r"/models/(\d+)", url)
    if model_match is None:
        raise ValueError("Could not find a model id in the CivitAI URL")
    version_match = re.search(r"modelVersionId=(\d+)", url)

    api_token = os.getenv("CIVITAI_TOKEN")
    headers = _auth_headers(api_token)

    if version_match:
        api_url = f"https://civitai.com/api/v1/model-versions/{version_match.group(1)}"
    else:
        api_url = f"https://civitai.com/api/v1/models/{model_match.group(1)}"
    response = requests.get(api_url, headers=headers, timeout=_LORA_HTTP_TIMEOUT)
    response.raise_for_status()
    data = response.json()

    # /models/{id} nests versions (the first one is used);
    # /model-versions/{id} returns the version itself.
    version_data = data["modelVersions"][0] if "modelVersions" in data else data

    # Previously `"flux" not in base and "1" not in base`, which accepted
    # e.g. "SDXL 1.0" and "SD 1.5".
    base_model = str(version_data.get("baseModel") or "").lower()
    if "flux" not in base_model:
        raise ValueError("This LoRA is not compatible with Flux base model")

    safetensor_file = next(
        (
            f
            for f in version_data.get("files", [])
            if f.get("name", "").endswith(".safetensors")
        ),
        None,
    )
    if not safetensor_file:
        raise ValueError("No .safetensor file found")

    download_url = safetensor_file["downloadUrl"]
    if api_token:
        # The download URL may already carry a query string.
        separator = "&" if "?" in download_url else "?"
        download_url = f"{download_url}{separator}token={api_token}"
    weights = _download_to_temp(download_url, headers, safetensor_file["name"])

    weight = _DEFAULT_LORA_WEIGHT
    strength_match = re.search(
        r"strength[:\s]+(\d*\.?\d+)",
        version_data.get("description") or "",
        re.IGNORECASE,
    )
    if strength_match:
        weight = float(strength_match.group(1))

    # Version responses name the version; the model name lives under "model".
    title = (data.get("model") or {}).get("name") or data.get("name")
    info = ", ".join(version_data.get("trainedWords") or []) or None
    return title, weights, weight, info


def _fetch_hf_lora(repo_id):
    """Download the first top-level .safetensors file of a Hugging Face repo.

    Returns:
        ``(title, weights_path, weight, info)``; ``info`` joins card trigger
        words and non-"flux-" tags with ", ", or is None.

    Raises:
        ValueError / requests.RequestException on bad input or HTTP failures.
    """
    headers = _auth_headers(os.getenv("HF_TOKEN"))

    response = requests.get(
        f"https://huggingface.co/api/models/{repo_id}",
        headers=headers,
        timeout=_LORA_HTTP_TIMEOUT,
    )
    response.raise_for_status()
    data = response.json()

    tags = data.get("tags") or []
    if "tags" in data and "flux-lora" not in tags:
        raise ValueError("This model is not tagged as a Flux LoRA")

    # The tree endpoint needs a revision; without one it does not list files.
    response = requests.get(
        f"https://huggingface.co/api/models/{repo_id}/tree/main",
        headers=headers,
        timeout=_LORA_HTTP_TIMEOUT,
    )
    response.raise_for_status()
    files = response.json()
    if not isinstance(files, list):
        raise ValueError("Unexpected response while listing repository files")

    safetensor_file = next(
        (f for f in files if f.get("path", "").endswith(".safetensors")),
        None,
    )
    if not safetensor_file:
        raise ValueError("No .safetensor file found")

    download_url = (
        f"https://huggingface.co/{repo_id}"
        f"/resolve/main/{safetensor_file['path']}"
    )
    weights = _download_to_temp(download_url, headers, safetensor_file["path"])

    title = data.get("name", repo_id.split("/")[-1])

    card = data.get("cardData") or {}
    weight = _DEFAULT_LORA_WEIGHT
    if "weight" in card:
        try:
            weight = float(card["weight"])
        except (ValueError, TypeError):
            weight = 1.0

    card_words = card.get("trigger_words") or []
    if isinstance(card_words, str):
        # A single string would otherwise be split into characters.
        card_words = [card_words]
    trigger_words = list(card_words)
    trigger_words.extend(t for t in tags if not t.startswith("flux-"))
    info = ", ".join(trigger_words) if trigger_words else None
    return title, weights, weight, info


def add_lora(lora_selected, loras=None):
    """Resolve a LoRA source and return ``loras`` with the new entry appended.

    Args:
        lora_selected: an index into ``flux1_loras`` (int), a CivitAI model URL,
            or a Hugging Face "owner/repo" id.
        loras: current value of the ``selected_loras`` state (list of dicts).

    Returns:
        A new list: ``loras`` plus ``{"title", "weights", "weight", "info"}``.

    Raises:
        gr.Error: when the input is empty or unsupported, or when fetching fails.
    """
    current = list(loras or [])

    if isinstance(lora_selected, int):
        # Previously this indexed the gr.Gallery component instead of the data.
        preset = flux1_loras[lora_selected]
        title = preset["title"]
        weights = preset["weights"]
        weight = _DEFAULT_LORA_WEIGHT
        info = preset["trigger_word"]
    else:
        source = (lora_selected or "").strip()
        if not source:
            raise gr.Error("Please enter a CivitAI URL or a Hugging Face repo path.")
        if source.startswith("http"):
            if "civitai.com/models/" not in source:
                raise gr.Error("Only CivitAI model URLs are supported.")
            try:
                title, weights, weight, info = _fetch_civitai_lora(source)
            except Exception as e:
                # Errors must be raised (not just constructed) to reach the UI.
                raise gr.Error(f"Error processing CivitAI URL: {str(e)}") from e
        elif _HF_REPO_ID_PATTERN.match(source):
            try:
                title, weights, weight, info = _fetch_hf_lora(source)
            except Exception as e:
                raise gr.Error(f"Error processing Hugging Face repo: {str(e)}") from e
        else:
            raise gr.Error(
                "Enter a CivitAI model URL or a Hugging Face 'owner/repo' path."
            )

    return current + [
        {
            "title": title,
            "weights": weights,
            "weight": weight,
            "info": info,
        }
    ]


# The new list is written back through the event output. A render function
# cannot update a gr.State; the previous code called .append on the component.
add_lora_btn.click(
    fn=add_lora,
    inputs=[lora_selected, selected_loras],
    outputs=selected_loras,
)
|
|
|
|
|
@gr.render(inputs=[selected_loras])
def render_selected_loras(loras):
    """Render one weight slider per entry of the ``selected_loras`` state.

    Re-runs whenever the state changes. Releasing a slider writes its value
    into that entry's ``"weight"`` key.

    Args:
        loras: current value of ``selected_loras`` (a list of dicts with
            "title", "info" and optionally "weight"), or None. The parameter
            is no longer named ``selected_loras``, so that name below refers
            to the gr.State component and can be used as an event input/output.
    """

    def make_weight_updater(index):
        # Bind `index` now; a closure over the loop variable would see only
        # the last index by the time the event fires.
        def update_lora_weight(weight, current_loras):
            # Return a fresh list. The previous handler returned None, which
            # wiped the state.
            updated = [dict(lora) for lora in (current_loras or [])]
            if 0 <= index < len(updated):
                updated[index]["weight"] = weight
            return updated

        return update_lora_weight

    for index, lora in enumerate(loras or []):
        lora_slider = gr.Slider(
            # Explicit range; the default 0-100 range was unusable for a 0.8 weight.
            minimum=0,
            maximum=2,
            step=0.05,
            label=lora["title"],
            value=lora.get("weight", 0.8),
            interactive=True,
            info=lora.get("info"),
        )
        # `release` rather than `change`: update (and re-render) once per
        # adjustment instead of on every drag tick.
        lora_slider.release(
            fn=make_weight_updater(index),
            inputs=[lora_slider, selected_loras],
            outputs=selected_loras,
        )
|
|
|
|
|
# Start the server only when run as a script; importing this module (e.g.
# from `gradio app.py` reload mode or tests) just builds `demo`.
if __name__ == "__main__":
    demo.launch()
|
|