import argparse
import os
import time
from os import path

# Keep every Hugging Face cache in a "models" folder next to this script.
# These variables MUST be set before huggingface_hub (or anything that imports
# it, e.g. gradio/diffusers) is imported: huggingface_hub reads them once at
# import time, so setting them afterwards has no effect on downloads.
cache_path = path.join(path.dirname(path.abspath(__file__)), "models")
os.environ["TRANSFORMERS_CACHE"] = cache_path
os.environ["HF_HUB_CACHE"] = cache_path
os.environ["HF_HOME"] = cache_path

from safetensors.torch import load_file
import huggingface_hub
from huggingface_hub import hf_hub_download
import gradio as gr
import torch
from diffusers import FluxPipeline

torch.backends.cuda.matmul.allow_tf32 = True

# Log in only when a token is actually provided: login(None) falls back to an
# interactive prompt, which would hang a headless server.
_hf_token = os.getenv('HF_TOKEN')
if _hf_token:
    huggingface_hub.login(_hf_token)
class timer:
    """Context manager that prints how long the wrapped block took.

    Prints "<method> starts" on entry and "<method> took <seconds>s" on exit.
    The measured duration is also stored on ``self.elapsed``. Exceptions
    raised inside the block are never suppressed.

    Args:
        method_name: label used in the printed messages.
    """

    def __init__(self, method_name="timed process"):
        self.method = method_name
        self.start = None
        self.elapsed = None

    def __enter__(self):
        # perf_counter is monotonic; time.time() can jump if the wall clock
        # is adjusted, producing wrong or negative durations.
        self.start = time.perf_counter()
        print(f"{self.method} starts")
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        end = time.perf_counter()
        self.elapsed = end - self.start
        print(f"{self.method} took {str(round(self.elapsed, 2))}s")
        # Returning None (falsy) lets any exception propagate.
        return None
# Ensure the local model cache directory exists. exist_ok=True already makes
# this a no-op when the directory is present, so no separate existence check
# (which would also race with other processes) is needed.
os.makedirs(cache_path, exist_ok=True)
def load_and_fuse_lora_weights(pipe, lora_models):
    """Download each LoRA checkpoint from the Hub, load it into ``pipe`` and fuse it.

    The pipeline is modified in place; nothing is returned.

    Args:
        pipe: pipeline exposing ``load_lora_weights`` and ``fuse_lora``.
        lora_models: iterable of ``(repo_id, filename, lora_scale)`` tuples.
    """
    for entry in lora_models:
        repo_id, filename, scale = entry
        local_file = hf_hub_download(repo_id=repo_id, filename=filename)
        pipe.load_lora_weights(local_file)
        # NOTE(review): fuse_lora runs after every load, not once at the end;
        # confirm with the diffusers version in use that earlier, already-fused
        # adapters are not re-scaled or fused twice.
        pipe.fuse_lora(lora_scale=scale)
# LoRAs merged into the base model: (Hub repo id, weights file, fuse scale).
lora_models = [
    ("mrcuddle/Character_Design_Helper", "CharacterDesign-FluxV2.safetensors", 0.125),
    ("mrcuddle/live2d-model-maker", "LIVE2D-FLUX.safetensors", 0.125),
]

# Base FLUX.1-dev pipeline in bfloat16. The LoRAs are fused into it before the
# whole pipeline is moved to the CUDA device.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16,
)
load_and_fuse_lora_weights(pipe, lora_models)
pipe.to(device="cuda", dtype=torch.bfloat16)
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # NOTE(review): this header names Hyper-FLUX-8steps-LoRA / ByteDance, but the
    # pipeline built above is FLUX.1-dev with two character-design LoRAs —
    # confirm the intended title.
    gr.Markdown(
        """
Hyper-FLUX-8steps-LoRA
AutoML team from ByteDance
"""
    )
    with gr.Row():
        with gr.Column(scale=3):
            with gr.Group():
                prompt = gr.Textbox(
                    label="Your Image Description",
                    placeholder="E.g., A serene landscape with mountains and a lake at sunset",
                    lines=3
                )
                # Hidden textbox carrying the fixed preset tags; its value is
                # passed to process_image alongside the user's prompt.
                preset_prompt = gr.Textbox(
                    label="Preset Prompt",
                    value="live2d,guijiaoxiansheng,separate hand,separate feet,separate head,multiple views,white background,CharacterDisgnFlux,magic particles, multiple references,color pallete reference,simple background,upper body,front,from side",
                    visible=False
                )
            with gr.Accordion("Advanced Settings", open=False):
                with gr.Group():
                    with gr.Row():
                        height = gr.Slider(label="Height", minimum=256, maximum=1152, step=64, value=1024)
                        width = gr.Slider(label="Width", minimum=256, maximum=1152, step=64, value=1024)
                    with gr.Row():
                        steps = gr.Slider(label="Inference Steps", minimum=5, maximum=25, step=1, value=8)
                        # step must be fractional: with step=1 the 3.5 default is
                        # off-grid and cannot be re-selected once moved.
                        scales = gr.Slider(label="Guidance Scale", minimum=0.0, maximum=10.0, step=0.1, value=3.5)
                    # A negative value (the -1 default) or an empty box picks a random seed.
                    seed = gr.Number(label="Seed (for reproducibility)", value=-1, precision=0)
            generate_btn = gr.Button("Generate Image", variant="primary", scale=1)
        with gr.Column(scale=4):
            output = gr.Image(label="Your Generated Image")

    def process_image(height, width, steps, scales, prompt, seed, preset_prompt):
        """Generate one image from the preset tags plus the user's description.

        Args:
            height: output height in pixels.
            width: output width in pixels.
            steps: number of inference steps.
            scales: guidance scale passed to the pipeline.
            prompt: free-text description typed by the user (may be empty).
            seed: integer seed; ``None`` or a negative value selects a fresh
                random seed, which is printed so the result can be reproduced.
            preset_prompt: fixed tag list placed before ``prompt``.

        Returns:
            The first image in the pipeline's ``images`` output.
        """
        import random

        # The preset is a comma-separated tag list, so join with ", " to keep
        # the user's text from merging into the last tag; drop empty parts.
        parts = [p.strip() for p in (preset_prompt, prompt) if p and p.strip()]
        full_prompt = ", ".join(parts)

        # An empty Number box yields None (int(None) would crash); negative
        # values follow the common "-1 = random" convention of the default.
        if seed is None or int(seed) < 0:
            seed = random.randint(0, 2**32 - 1)
        seed = int(seed)
        print(f"seed: {seed}")

        with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16), timer("inference"):
            return pipe(
                prompt=[full_prompt],
                generator=torch.Generator().manual_seed(seed),
                num_inference_steps=int(steps),
                guidance_scale=float(scales),
                height=int(height),
                width=int(width),
                max_sequence_length=256
            ).images[0]

    generate_btn.click(
        process_image,
        inputs=[height, width, steps, scales, prompt, seed, preset_prompt],
        outputs=output
    )
if __name__ == "__main__":
    # Start the Gradio server only when run as a script, not when imported.
    demo.launch()