# TonAI-Creative / app.py
# Author: tungedng2710
# Initial commit (8eca2ce)
import os
import random
import torch
import gc
import gradio as gr
import style as sty
from PIL import Image
from scheduler_mapping import schedulers, apply_scheduler
from utils import *
from diffusers.utils import logging
from query_comfyui import *
# Configure diffusers logging: library verbosity is set to INFO, then the
# "diffusers" logger itself is raised to ERROR to keep the console quiet.
logging.set_verbosity_info()
logging.get_logger("diffusers").setLevel(logging.ERROR)

# Sampler names offered in the UI: the "Default" sentinel comes first,
# followed by every scheduler known to scheduler_mapping.
SCHEDULERS = ["Default", *schedulers.keys()]
def _free_gpu_memory():
    """Run the garbage collector and release PyTorch's cached CUDA memory."""
    gc.collect()
    torch.cuda.empty_cache()


def _load_pipeline(model, mode, fast_infer, width, height, num_images):
    """Instantiate the diffusers pipeline described by a model entry.

    Args:
        model: entry of ``TEXT_TO_IMAGE_DICTIONARY``; must provide
            ``"pipeline"`` (a pipeline class) and ``"path"``, and may provide
            ``"device_map"`` which is forwarded to the loader.
        mode: display name of the model; used to detect FLUX models.
        fast_infer: load non-FLUX weights in float16 when True.
        width, height, num_images: passed to ``get_gpu_info`` to obtain the
            ``max_memory`` budget for loading.

    Returns:
        The loaded pipeline with its safety checker disabled.
    """
    _, current_max_memory = get_gpu_info(width, height, num_images)
    diffusion_configs = {
        "use_safetensors": True,
        "max_memory": current_max_memory,
    }
    if "device_map" in model:
        diffusion_configs["device_map"] = model["device_map"]
    # FLUX is always loaded in bfloat16; other models switch to float16 only
    # when fast inference is requested.
    if "FLUX" in mode:
        diffusion_configs["torch_dtype"] = torch.bfloat16
    elif fast_infer:
        diffusion_configs["torch_dtype"] = torch.float16

    pipeline_cls = model["pipeline"]
    if model["path"].endswith(".safetensors"):
        pipeline = pipeline_cls.from_single_file(
            model["path"], **diffusion_configs)
    else:
        pipeline = pipeline_cls.from_pretrained(
            model["path"], **diffusion_configs)
    pipeline.safety_checker = None
    return pipeline


def _load_lora(pipeline, lora_weight_file):
    """Load an uploaded LoRA adapter into ``pipeline``.

    Returns True when the adapter was loaded, False when no file was given
    or loading failed (a warning is shown to the user in that case).
    """
    if lora_weight_file is None:
        return False
    try:
        # Depending on the Gradio version the upload is a path string or a
        # file-like object exposing ``.name``; accept both.
        path = getattr(lora_weight_file, "name", lora_weight_file)
        directory, file_name = os.path.split(path)
        pipeline.load_lora_weights(
            directory,
            weight_name=file_name,
            adapter_name=file_name.replace(".safetensors", ''))
    except Exception as e:
        print(e)
        gr.Warning("Cannot load LoRA weight, your model won't use adapter", duration=5)
        return False
    gr.Info("LoRA weight loaded successfully", duration=5)
    return True


def gen_image(prompt, negative_prompt, width, height,
              num_steps, mode, seed, guidance_scale,
              lora_weight_file, lora_scale, fast_infer,
              scheduler, num_images, progress=gr.Progress(track_tqdm=True)):
    """
    Run a diffusion model to generate images for the Gradio UI.

    "Stable Diffusion 3.5" modes are delegated to the ComfyUI backend via
    ``query_sd35``. Every other mode loads a diffusers pipeline described by
    ``TEXT_TO_IMAGE_DICTIONARY[mode]``, runs it once and frees it again.

    Args:
        prompt: text prompt; an empty or whitespace-only prompt is rejected.
        negative_prompt: negative prompt (not passed to FLUX pipelines).
        width, height: output size; rounded with ``nearest_divisible_by_8``
            for diffusers pipelines.
        num_steps: number of inference steps.
        mode: model name (key of ``TEXT_TO_IMAGE_DICTIONARY`` or an SD 3.5 mode).
        seed: RNG seed; any value ``int()`` accepts.
        guidance_scale: classifier-free guidance scale.
        lora_weight_file: optional uploaded LoRA ``.safetensors`` file.
        lora_scale: LoRA strength, applied only if the adapter loaded.
        fast_infer: load non-FLUX models in float16.
        scheduler: sampler name passed to ``apply_scheduler``.
        num_images: number of images per prompt.
        progress: Gradio progress tracker (``track_tqdm=True`` also mirrors
            tqdm bars created during generation).

    Returns:
        The generated images, or a one-element list containing the placeholder
        logo when the prompt is empty or no GPU is available.

    Raises:
        gr.Error: if ``seed`` is not an integer or the pipeline call fails.
    """
    # gr.Progress takes a fraction in [0, 1], not a percentage.
    progress(0, desc="Starting image generation...")
    images = [Image.open("stuffs/logo.png")]
    if not prompt or not prompt.strip():
        gr.Info("Please input prompt!", duration=5)
        return images
    try:
        seed = int(seed)
    except (TypeError, ValueError):
        raise gr.Error(f"Invalid seed {seed!r}: please enter an integer.",
                       duration=5) from None

    # Query ComfyUI backend
    if "Stable Diffusion 3.5" in mode:
        if "Medium" in mode:
            ckpt_name = "sd3.5_medium.safetensors"
        else:
            ckpt_name = "sd3.5_large.safetensors"
        return query_sd35(ckpt_name, prompt, negative_prompt,
                          int(width), int(height),
                          int(num_images), seed,
                          float(guidance_scale), int(num_steps))

    model = TEXT_TO_IMAGE_DICTIONARY[mode]
    pipeline = None
    try:
        pipeline = _load_pipeline(model, mode, fast_infer,
                                  width, height, num_images)
        try:
            pipeline = apply_scheduler(scheduler, pipeline)
        except Exception:
            # Exception (not BaseException) so Ctrl-C / SystemExit still propagate.
            gr.Warning(f"Cannot apply {scheduler} for {mode}. Use default sampler instead")
            pipeline = apply_scheduler("Default", pipeline)

        use_lora = _load_lora(pipeline, lora_weight_file)

        # Assign GPU for pipeline
        device = assign_gpu(required_vram=10000,
                            width=width,
                            height=height,
                            num_images=num_images)
        if device == "cpu":
            gr.Warning("No available GPUs for inference")
            return images

        # NOTE(review): the generator always lives on the default CUDA device,
        # while non-FLUX pipelines are moved to `device`; confirm these agree
        # on multi-GPU hosts.
        generator = torch.Generator("cuda").manual_seed(seed)
        try:
            pipeline_configs = {
                "prompt": prompt,
                "negative_prompt": negative_prompt,
                "width": nearest_divisible_by_8(int(width)),
                "height": nearest_divisible_by_8(int(height)),
                "num_inference_steps": int(num_steps),
                "generator": generator,
                "guidance_scale": float(guidance_scale),
                "num_images_per_prompt": int(num_images),
            }
            if "FLUX" not in mode:
                pipeline = pipeline.to(device)
            else:
                # FLUX pipelines are not given a negative prompt and take at
                # most 256 prompt tokens here; they are not moved with .to().
                del pipeline_configs["negative_prompt"]
                pipeline_configs["max_sequence_length"] = 256
            if use_lora:
                lora_kwargs = {"scale": lora_scale}
                if "FLUX" in mode or "Stable Diffusion 3" in mode:
                    pipeline_configs["joint_attention_kwargs"] = lora_kwargs
                else:
                    pipeline_configs["cross_attention_kwargs"] = lora_kwargs
            # Generate images
            images = pipeline(**pipeline_configs).images
        except Exception as e:
            raise gr.Error(f"Exception: {e}", duration=5) from e
        progress(1.0, desc="Completed!")
        return images
    finally:
        # Runs on success, early return and failure alike, so the pipeline
        # reference is dropped and cached GPU memory released every time.
        pipeline = None
        _free_gpu_memory()
# -------------------------------------------- Gradio App -------------------------------------------- #
# Gradio UI: generation controls on the left, result gallery on the right.
with gr.Blocks(title="TonAI Creative",
               theme=sty.app_theme,
               css=sty.custom_css) as interface:
    gr.HTML(sty.tonai_creative_html)
    with gr.Row():
        with gr.Column(scale=2):
            with gr.Accordion("Basic Usage", open=True):
                with gr.Row():
                    prompt = gr.Textbox(
                        label="Prompt",
                        placeholder="Describe the image you want to generate")
                with gr.Row():
                    width = gr.components.Slider(
                        minimum=512, maximum=1920, value=1024, step=8,
                        label="Width",
                        scale=1
                    )
                    height = gr.components.Slider(
                        minimum=512, maximum=1920, value=1024, step=8,
                        label="Height",
                        scale=1
                    )
                    # Materialise the model names once as a list; the first
                    # entry is the default selection.
                    mode_choices = list(TEXT_TO_IMAGE_DICTIONARY.keys())
                    mode = gr.Dropdown(
                        choices=mode_choices,
                        label="Mode",
                        filterable=False,
                        # First model is the default (noted as "FLUX.1 Merged"
                        # originally — confirm against TEXT_TO_IMAGE_DICTIONARY).
                        value=mode_choices[0],
                        interactive=True,
                        scale=1)
                with gr.Row():
                    generate_btn = gr.Button("Generate", scale=2)
                    stop_btn = gr.Button("Stop", elem_id="stop-button", scale=1)
            with gr.Accordion("Advanced Settings", open=False):
                negative_prompt = gr.Textbox(
                    label="Negative Prompt",
                    value="ugly, disfigured, deformed",
                    placeholder="Instruct the AI model that it should not include")
                with gr.Row():
                    with gr.Column(scale=4):
                        with gr.Row():
                            num_steps = gr.components.Slider(
                                minimum=3, maximum=50, value=20, step=1,
                                label="Inference Steps",
                                scale=2
                            )
                        with gr.Row():
                            guidance_scale = gr.components.Slider(
                                minimum=0, maximum=20, value=3, step=0.1,
                                label="CFG Scale",
                                scale=1
                            )
                        with gr.Row():
                            num_images = gr.components.Slider(
                                minimum=1, maximum=6, value=1, step=1,
                                label="Number of generated images",
                                scale=1)
                            scheduler = gr.Dropdown(
                                choices=SCHEDULERS,
                                label="Sampler",
                                filterable=False,
                                value=SCHEDULERS[0],
                                interactive=True,
                                scale=1)
                    with gr.Column(scale=1):
                        seed = gr.Textbox(label="RNG Seed", value=0)
                        rng_btn = gr.Button("Roll the 🎲", scale=1)
                        rng_btn.click(
                            fn=generate_number, inputs=None, outputs=seed)
                        fast_infer = gr.Checkbox(
                            label="Fast Inference",
                            info="Faster run with FP16",
                            value=True,
                            scale=1)
                with gr.Row():
                    lora_weight_file = gr.File(
                        label="LoRA safetensors file",
                        elem_classes="file-uploader",
                        # Gradio treats entries with a leading dot as file
                        # extensions; bare words are read as file categories
                        # (e.g. "image"), so "safetensors" would not match.
                        file_types=[".safetensors"],
                        min_width=50, height=30, scale=2)
                    lora_scale = gr.components.Slider(
                        minimum=0, maximum=1, value=0.8, step=0.01,
                        label="LoRA Scale",
                        scale=1
                    )
            with gr.Accordion("Helps", open=False):
                gr.Markdown(sty.tips_content)
        with gr.Column(scale=1):
            gallery = gr.Gallery(
                label="Generated Images",
                format="png",
                elem_id="gallery",
                columns=2, rows=2,
                preview=True,
                object_fit="contain")

    # Wire the Generate button to gen_image; input order must match its
    # positional parameters.
    click_button_behavior = {
        "fn": gen_image,
        "outputs": gallery,
        "concurrency_limit": 10
    }
    click_event = generate_btn.click(inputs=[prompt,
                                             negative_prompt,
                                             width,
                                             height,
                                             num_steps,
                                             mode,
                                             seed,
                                             guidance_scale,
                                             lora_weight_file,
                                             lora_scale,
                                             fast_infer,
                                             scheduler,
                                             num_images],
                                     **click_button_behavior)
    # Stop cancels the in-flight generation event.
    stop_btn.click(fn=None, inputs=None, outputs=None, cancels=[click_event])
    # Randomise the seed each time the page loads.
    interface.load(
        lambda: gr.update(
            value=random.randint(
                0, 999999)), None, seed)
if __name__ == '__main__':
    # Server settings; root_path suggests the app is served under "/tonai"
    # behind a proxy (confirm with the deployment config).
    launch_options = dict(
        share=False,
        root_path="/tonai",
        server_name="0.0.0.0",
        show_error=True,
        favicon_path="stuffs/favicon.png",
        # Static assets the server is allowed to serve directly.
        allowed_paths=["stuffs/splash.png", "stuffs/favicon.png"],
        max_threads=10,
    )
    interface.queue(default_concurrency_limit=10)
    interface.launch(**launch_options)