Spaces:
its-magick
/
Runtime error

pixe / app.py
marks
Added install step
47ada1c
import os
# Install the pruna dependency at startup; this must run before the
# `from pruna import ...` line further down.
# - Invoke pip via the running interpreter (`python -m pip`) so the package is
#   installed into the environment that actually executes this app.
# - Pass an argv list instead of a shell string so the `[gpu]` extra is not
#   subject to shell glob expansion.
import subprocess
import sys

_pip_result = subprocess.run(
    [
        sys.executable, "-m", "pip", "install",
        "pruna[gpu]==0.1.2",
        "--extra-index-url", "https://prunaai.pythonanywhere.com/",
    ],
    check=False,
)
if _pip_result.returncode != 0:
    # Best-effort, like the original os.system call: don't abort here (pruna may
    # already be installed), but make the failure visible in the logs.
    print(
        f"WARNING: pip install of pruna exited with code {_pip_result.returncode}",
        file=sys.stderr,
    )
import gradio as gr
import numpy as np
import random
import spaces
import torch
import time
from diffusers import DiffusionPipeline, AutoencoderTiny
from diffusers.models.attention_processor import AttnProcessor2_0
from custom_pipeline import FluxWithCFGPipeline
from pruna import SmashConfig
# Permit TF32 for CUDA matmuls (PyTorch backend flag).
torch.backends.cuda.matmul.allow_tf32 = True

# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
# Upper bound for randomly drawn seeds (largest int32 value, 2**31 - 1).
MAX_SEED = np.iinfo(np.int32).max
# Upper bound for the width/height sliders in the UI.
MAX_IMAGE_SIZE = 2048
# Initial image size shown in the UI sliders.
DEFAULT_WIDTH, DEFAULT_HEIGHT = 1024, 1024
# Initial value of the "Inference Steps" slider.
DEFAULT_INFERENCE_STEPS = 1
# ---------------------------------------------------------------------------
# Device and model setup
# ---------------------------------------------------------------------------
dtype = torch.bfloat16  # dtype used for every pretrained load below

print('Initializing pipeline...')
pipe = FluxWithCFGPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=dtype)

print('Loading VAE...')
# Optional tiny-VAE swap, currently disabled, so pipe.vae stays whatever
# from_pretrained loaded above:
# pipe.vae = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype)
pipe.to("cuda")

# pruna "flux_caching" settings, written into the SmashConfig key by key in this
# order (dicts preserve insertion order).
# NOTE(review): nothing in the visible code passes `smash_config` to a
# `smash(...)` call, so as written these settings appear to have no effect on
# `pipe` — confirm whether a smashing step is missing.
_flux_caching_options = {
    'compilers': ['flux_caching'],
    'comp_flux_caching_cache_interval': 2,  # higher is faster, but reduces quality
    'comp_flux_caching_start_step': 2,      # best kept equal to cache_interval
    'comp_flux_caching_compile': True,      # additionally compile for extra speed-up
    'comp_flux_caching_save_model': True,   # save after compilation vs. inference only
}
smash_config = SmashConfig()
for _option, _value in _flux_caching_options.items():
    smash_config[_option] = _value

print('Pipeline and VAE loaded to CUDA.')
# LoRA weights to load: (repo id or path, weight file name, adapter name).
# The adapter names are the handles later passed to set_adapters / fuse_lora.
LORA_SOURCES = [
    ("Shakker-Labs/FLUX.1-dev-LoRA-add-details", "FLUX-dev-lora-add_details.safetensors", "detail"),
    ("its-magick/merlin-test-loras", "Canopus-LoRA-Flux-UltraRealism.safetensors", "ultrarealism"),
    ("its-magick/merlin-test-loras", "Canopus-LoRA-Flux-FaceRealism.safetensors", "faces"),
    ("miike-ai/merlin-ironman", "lora.safetensors", "ironman"),
    ("its-magick/merlin-food", "lora.safetensors", "food"),
    ("its-magick/merlin-logos", "merlin-logos.safetensors", "logos"),
    ("its-magick/merlin-mobile-app", "lora.safetensors", "mobile"),
    ("its-magick/merlin-infographic", "lora.safetensors", "infographic"),
    ("its-magick/merlin-anti-blur", "merlin-anti-blur.safetensors", "deblur"),
    ("its-magick/merlin-office", "lora.safetensors", "office"),
    ("its-magick/merlin-channel-letters", "lora.safetensors", "channel-letters"),
    ("its-magick/merlin-headshots", "lora.safetensors", "headshots"),
    ("its-magick/merlin-panoramic", "lora.safetensors", "panoramic"),
    # NOTE(review): this first element looks like a file name rather than a Hub
    # repo id (it contains spaces and ends in ".safetensors"), so loading it is
    # likely to fail unless such a local path exists — confirm the intended repo.
    ("its-magick/perfection style v1.safetensors", "perfection style v1.safetensors", "perfection"),
]

# Load every LoRA, logging each source consistently (previously two loads had
# no progress message).
for _repo_id, _weight_name, _adapter_name in LORA_SOURCES:
    print(f'Loading weights from repo: {_repo_id}')
    pipe.load_lora_weights(_repo_id, weight_name=_weight_name, adapter_name=_adapter_name)
print('All safetensor files have loaded successfully.')
# Per-adapter strengths as (set_adapters weight, fuse_lora scale), taken from
# the previous per-adapter calls. The fuse scale is None for adapters that were
# activated but never fused ("mobile", "perfection"); they are fused at their
# plain weight below.
LORA_STRENGTHS = {
    "detail": (0.6, 0.6),
    "faces": (0.6, 0.6),
    "ultrarealism": (0.6, 0.6),
    "ironman": (0.6, 0.4),
    "food": (0.6, 0.4),
    "logos": (0.6, 0.4),
    "mobile": (0.6, None),
    "infographic": (0.6, 0.6),
    "deblur": (0.6, 0.6),
    "office": (0.6, 0.6),
    "channel-letters": (0.6, 0.6),
    "headshots": (0.6, 0.6),
    "panoramic": (0.6, 0.6),
    "perfection": (0.8, None),
}

print('Setting adapters...')
# set_adapters() sets the complete list of active adapters. The former
# one-call-per-adapter sequence therefore left only the last adapter
# ("perfection") active. Activate all adapters in a single call instead.
# NOTE(review): each fuse scale is folded into the weight (weight * scale),
# assuming fuse_lora's lora_scale acts as a multiplier on the adapter weight.
# Separate fuse_lora calls with different scales could also rescale adapters
# that are still unfused. Confirm against the installed diffusers version.
_lora_names = list(LORA_STRENGTHS)
_lora_weights = [
    weight * (1.0 if fuse_scale is None else fuse_scale)
    for weight, fuse_scale in LORA_STRENGTHS.values()
]
pipe.set_adapters(_lora_names, adapter_weights=_lora_weights)
print('Adapters have been set.')

print('Fusing LoRAs...')
# fuse_lora's documented selector keyword is `adapter_names`. The former
# `adapter_name=` spelling is not one of its parameters. The scales already
# live in the adapter weights, hence lora_scale=1.0.
pipe.fuse_lora(adapter_names=_lora_names, lora_scale=1.0)
print('LoRAs have been fused.')
# Inference function
@spaces.GPU(duration=25)
def generate_image(prompt, seed=24, width=DEFAULT_WIDTH, height=DEFAULT_HEIGHT, randomize_seed=False, num_inference_steps=2, progress=gr.Progress(track_tqdm=True)):
    """Generate one image for ``prompt`` with the module-level ``pipe``.

    Args:
        prompt: Text prompt.
        seed: Seed for a CPU ``torch.Generator``. Replaced by a random integer
            in ``[0, MAX_SEED]`` when ``randomize_seed`` is true. Converted
            via ``int(float(seed))`` so float/numeric-string UI values work.
        width: Output width in pixels; cast to ``int`` (sliders may deliver
            floats).
        height: Output height in pixels; cast to ``int``.
        randomize_seed: Draw a fresh seed instead of using ``seed``.
        num_inference_steps: Number of inference steps; cast to ``int``.
        progress: Gradio progress object; not used directly in this body.

    Returns:
        ``(image, seed, latency_text)`` matching the ``[result, seed, latency]``
        outputs wired up in the UI.
    """
    from collections import deque
    from collections.abc import Iterator

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    generator = torch.Generator().manual_seed(int(float(seed)))
    start_time = time.time()
    print(f'Starting image generation with prompt: "{prompt}", seed: {seed}, width: {width}, height: {height}, steps: {num_inference_steps}')
    output = pipe.generate_images(
        prompt=prompt,
        width=int(width),
        height=int(height),
        num_inference_steps=int(num_inference_steps),
        generator=generator,
    )
    # Only keep the last image in the sequence.
    # NOTE(review): generate_images is defined in custom_pipeline (not visible
    # here). It appears to yield one image per step. If it returns an
    # iterator, drain it and keep the final image; handing the iterator object
    # itself to gr.Image cannot be displayed.
    if isinstance(output, Iterator):
        tail = deque(output, maxlen=1)
        img = tail[0] if tail else None
    else:
        img = output
    latency = f"Latency: {(time.time()-start_time):.2f} seconds"
    print(f'Image generation completed. {latency}')
    return img, seed, latency
# Example prompts (plain strings). Not referenced by any active code in this
# file. Long prompts are split with implicit string concatenation; each entry
# is still a single string.
examples = [
    "a tiny astronaut hatching from an egg on the moon",
    "a cute white cat holding a sign that says hello world",
    "an anime illustration of Steve Jobs",
    "Create image of Modern house in minecraft style",
    (
        "photo of a woman on the beach, shot from above. She is facing the sea, "
        "while wearing a white dress. She has long blonde hair"
    ),
    (
        "Selfie photo of a wizard with long beard and purple robes, he is "
        "apparently in the middle of Tokyo. Probably taken from a phone."
    ),
    (
        "Photo of a young woman with long, wavy brown hair tied in a bun and "
        "glasses. She has a fair complexion and is wearing subtle makeup, "
        "emphasizing her eyes and lips. She is dressed in a black top. The "
        "background appears to be an urban setting with a building facade, and "
        "the sunlight casts a warm glow on her face."
    ),
]
# --- Gradio UI ---
with gr.Blocks() as demo:
    with gr.Column(elem_id="app-container"):
        # Markdown headings need a space after '#'; "#pixe" rendered as plain text.
        gr.Markdown("# pixe")
        gr.Markdown("Generate stunning images in real-time.")
        with gr.Row():
            # Integer scales 5:2 keep the original 2.5:1 image/controls ratio
            # (Gradio expects integer `scale` values).
            with gr.Column(scale=5):
                result = gr.Image(label="Generated Image", show_label=False, interactive=False)
            with gr.Column(scale=2):
                prompt = gr.Text(
                    label="Prompt",
                    placeholder="Describe the image you want to generate...",
                    lines=3,
                    show_label=False,
                    container=False,
                )
                # Label was mis-encoded (UTF-8 emoji bytes decoded with the wrong codec).
                generateBtn = gr.Button("🖼️ Generate Image")
                # gr.Column takes keyword-only arguments, so a title cannot be
                # passed positionally. An initially open Accordion provides the
                # "Advanced Options" heading.
                with gr.Accordion("Advanced Options", open=True):
                    with gr.Row():
                        realtime = gr.Checkbox(label="Realtime Toggler", info="If TRUE then uses more GPU but create image in realtime.", value=False)
                        latency = gr.Text(label="Latency")
                    with gr.Row():
                        seed = gr.Number(label="Seed", value=42)
                        randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
                    with gr.Row():
                        width = gr.Slider(label="Width", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=DEFAULT_WIDTH)
                        height = gr.Slider(label="Height", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=DEFAULT_HEIGHT)
                        num_inference_steps = gr.Slider(label="Inference Steps", minimum=1, maximum=4, step=1, value=DEFAULT_INFERENCE_STEPS)

    # Components fed to generate_image, in its positional parameter order,
    # and the components that receive its (image, seed, latency) result.
    generation_inputs = [prompt, seed, width, height, randomize_seed, num_inference_steps]
    generation_outputs = [result, seed, latency]

    generateBtn.click(
        fn=generate_image,
        inputs=generation_inputs,
        outputs=generation_outputs,
        show_progress="full",
        api_name="RealtimeFlux",
        queue=False,
    )

    def update_ui(realtime_enabled):
        """Hide the Generate button while realtime mode is on; keep the prompt editable."""
        return {
            prompt: gr.update(interactive=True),
            generateBtn: gr.update(visible=not realtime_enabled),
        }

    realtime.change(
        fn=update_ui,
        inputs=[realtime],
        outputs=[prompt, generateBtn],
        queue=False,
        concurrency_limit=None,
    )

    def realtime_generation(*args):
        """Regenerate on input changes, but only while realtime mode is enabled.

        ``args`` are the values of ``[realtime, *generation_inputs]``. The first
        is the realtime checkbox; the rest are forwarded positionally to
        generate_image.
        """
        realtime_enabled, *generation_args = args
        if not realtime_enabled:
            # Leave all three outputs unchanged instead of returning a bare None.
            return gr.update(), gr.update(), gr.update()
        # generate_image returns a tuple (it is not a generator), so calling
        # next() on its result raised TypeError.
        return generate_image(*generation_args)

    prompt.submit(
        fn=generate_image,
        inputs=generation_inputs,
        outputs=generation_outputs,
        show_progress="full",
        queue=False,
        concurrency_limit=None,
    )

    for component in [prompt, width, height, num_inference_steps]:
        component.input(
            fn=realtime_generation,
            inputs=[realtime, *generation_inputs],
            outputs=generation_outputs,
            show_progress="hidden",
            trigger_mode="always_last",
            queue=False,
            concurrency_limit=None,
        )
# ---------------------------------------------------------------------------
# Entry point: start the Gradio server for `demo`.
# NOTE(review): in a plain script, launch() normally blocks while the server
# runs, so the final message is likely printed only on shutdown — confirm.
# ---------------------------------------------------------------------------
_launch_options = {"share": True}
print('Launching the app...')
demo.launch(**_launch_options)
print('App launched successfully.')