# flx-upscale / app.py
# Author: fantaxy; last commit c43a736 "Update app.py" (verified), 5.23 kB.
import logging
import random
import warnings
import os
import gradio as gr
import numpy as np
import torch
from diffusers import FluxControlNetModel
from diffusers.pipelines import FluxControlNetPipeline
from gradio_imageslider import ImageSlider
from PIL import Image
from huggingface_hub import snapshot_download
import gc
# Everything is configured for the CPU in full fp32 precision.
# NOTE(review): `device` is not referenced in the visible code; model_config
# hard-codes "cpu" instead.
device, dtype = "cpu", torch.float32

# Reclaim whatever memory is free before the large models are loaded.
gc.collect()
if torch.cuda.is_available():
    # Only meaningful when a CUDA runtime happens to be present.
    torch.cuda.empty_cache()
# Page stylesheet passed to gr.Blocks.
# NOTE(review): no component in the visible UI sets elem_id="col-container",
# so this rule appears to have no effect.
css = (
    "\n"
    "#col-container {\n"
    "margin: 0 auto;\n"
    "max-width: 512px;\n"
    "}\n"
)
# Optional Hugging Face access token from the HF_TOKEN environment variable
# (None when unset).
huggingface_token = os.getenv("HF_TOKEN")

# Loading options shared by the ControlNet and the pipeline below.
model_config = {
    "low_cpu_mem_usage": True,
    "torch_dtype": dtype,
    "use_safetensors": True,
    # NOTE(review): some diffusers releases only accept "balanced" as a
    # pipeline-level device_map -- confirm against the pinned version.
    "device_map": "cpu",
}

# Download FLUX.1-dev into ./FLUX.1-dev, skipping markdown files, the
# .gitattributes file and .bin weights (the loaders request safetensors).
model_path = snapshot_download(
    repo_id="black-forest-labs/FLUX.1-dev",
    repo_type="model",
    # Patterns are fnmatch-style: "*.gitattributes" also matches the bare
    # ".gitattributes" file, which the former "*..gitattributes" never did.
    ignore_patterns=["*.md", "*.gitattributes", "*.bin"],
    local_dir="FLUX.1-dev",
    token=huggingface_token,
)
# Load the upscaler ControlNet, then attach it to the FLUX.1-dev base
# pipeline; both use the CPU loading options from model_config.
controlnet = FluxControlNetModel.from_pretrained(
    pretrained_model_name_or_path="jasperai/Flux.1-dev-Controlnet-Upscaler",
    **model_config,
)
pipe = FluxControlNetPipeline.from_pretrained(
    pretrained_model_name_or_path=model_path,
    controlnet=controlnet,
    **model_config,
)

# Memory-saving modes that trade speed for lower peak memory: slice_size=1
# requests the finest attention slicing, plus sliced VAE decoding.
# NOTE(review): confirm both actually take effect for the Flux transformer/VAE
# in the pinned diffusers version.
pipe.enable_attention_slicing(slice_size=1)
pipe.enable_vae_slicing()
# Inclusive upper bound for random seeds and for the seed slider.
MAX_SEED = 1_000_000
# Pixel budget for the control image: 4096 px, i.e. a 64 x 64 square.
# Kept extremely small because inference runs on the CPU.
MAX_PIXEL_BUDGET = 4_096
def process_input(input_image, upscale_factor):
    """Shrink ``input_image`` so it fits the CPU pixel budget.

    The image is converted to RGB. It is then scaled down (never up) so that
    its longer side is at most ``int(sqrt(MAX_PIXEL_BUDGET))`` pixels, with the
    aspect ratio preserved. Finally it is resized so that both sides are
    multiples of 8 (at least 8).

    Args:
        input_image: PIL image from the ``gr.Image(type="pil")`` input.
        upscale_factor: Accepted for interface compatibility but not applied;
            the UI slider is pinned to 1.

    Returns:
        ``(image, width, height)`` where ``image.size == (width, height)``.
    """
    input_image = input_image.convert("RGB")
    w, h = input_image.size
    max_size = int(np.sqrt(MAX_PIXEL_BUDGET))

    # Scale both sides by one common factor so non-square images keep their
    # aspect ratio; clamping each side independently would squash them.
    scale = min(1.0, max_size / max(w, h))
    new_w = max(1, round(w * scale))
    new_h = max(1, round(h * scale))
    input_image = input_image.resize((new_w, new_h), Image.LANCZOS)

    # Round down to a multiple of 8, but never below 8, so tiny inputs do not
    # collapse to a zero-sized image.
    # NOTE(review): FLUX pipelines may expect multiples of 16 -- confirm
    # against the pinned diffusers version.
    w = max(8, new_w - new_w % 8)
    h = max(8, new_h - new_h % 8)
    return input_image.resize((w, h)), w, h
def infer(
    seed,
    randomize_seed,
    input_image,
    num_inference_steps,
    upscale_factor,
    controlnet_conditioning_scale,
    progress=gr.Progress(track_tqdm=True),
):
    """Run the FLUX ControlNet pipeline on a downscaled copy of ``input_image``.

    Args:
        seed: Seed for the torch generator; replaced when ``randomize_seed``.
        randomize_seed: If true, draw a new seed in ``[0, MAX_SEED]``.
        input_image: PIL image from the UI, or ``None`` if nothing was uploaded.
        num_inference_steps: Number of denoising steps for the pipeline.
        upscale_factor: Forwarded to ``process_input``.
        controlnet_conditioning_scale: ControlNet strength for the pipeline.
        progress: Gradio progress tracker (``track_tqdm=True``).

    Returns:
        ``[control_image, generated_image, seed]``. The ImageSlider output
        presumably shows only the first two entries -- verify.

    Raises:
        gr.Error: If no image was given or inference fails. It must be
            *raised*, not just constructed, for Gradio to show it to the user.
    """
    if input_image is None:
        raise gr.Error("Please upload an input image.")

    gc.collect()
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # Slider values may arrive as floats; torch and the scheduler need ints.
    seed = int(seed)

    try:
        control_image, w, h = process_input(input_image, upscale_factor)
        with torch.inference_mode():
            generator = torch.Generator().manual_seed(seed)
            image = pipe(
                prompt="",
                control_image=control_image,
                controlnet_conditioning_scale=controlnet_conditioning_scale,
                num_inference_steps=int(num_inference_steps),
                guidance_scale=1.5,
                height=h,
                width=w,
                generator=generator,
            ).images[0]
    except Exception as e:
        # Top-level UI boundary: log the traceback, then surface the message.
        logging.exception("Inference failed")
        raise gr.Error(f"Error: {str(e)}") from e
    finally:
        gc.collect()

    return [control_image, image, seed]
with gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange", css=css) as demo:
    # Top row: the single action button.
    with gr.Row():
        run_button = gr.Button(value="Run")

    # Middle row: wide input-image column next to a narrow controls column.
    with gr.Row():
        with gr.Column(scale=4):
            input_im = gr.Image(label="Input Image", type="pil")
        with gr.Column(scale=1):
            num_inference_steps = gr.Slider(
                label="Steps", minimum=1, maximum=10, step=1, value=5
            )
            # minimum == maximum, so the scale is effectively fixed at 1.
            upscale_factor = gr.Slider(
                label="Scale", minimum=1, maximum=1, step=1, value=1
            )
            controlnet_conditioning_scale = gr.Slider(
                label="Control Scale", minimum=0.1, maximum=0.3, step=0.1, value=0.2
            )
            seed = gr.Slider(
                label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=42
            )
            randomize_seed = gr.Checkbox(label="Random Seed", value=True)

    # Bottom row: before/after comparison slider.
    with gr.Row():
        result = ImageSlider(label="Result", type="pil", interactive=True)

    # Order must match infer()'s positional parameters.
    infer_inputs = [
        seed,
        randomize_seed,
        input_im,
        num_inference_steps,
        upscale_factor,
        controlnet_conditioning_scale,
    ]

    # Example images are looked up next to this script.
    current_dir = os.path.dirname(os.path.abspath(__file__))
    examples = gr.Examples(
        examples=[
            [42, False, os.path.join(current_dir, image_name), 5, 1, 0.2]
            for image_name in ("z1.webp", "z2.webp")
        ],
        inputs=infer_inputs,
        fn=infer,
        outputs=result,
        cache_examples=False,
    )

    gr.on(
        [run_button.click],
        fn=infer,
        inputs=infer_inputs,
        outputs=result,
        show_api=False,
    )
# Minimal launch configuration: at most one queued event and one worker
# thread, because every request runs the full pipeline on the CPU.
# `enable_queue` and `cache_examples` are not `Blocks.launch()` parameters in
# Gradio 4+, so they were dropped. Queueing is already enabled by `.queue()`,
# and example caching is configured on `gr.Examples(cache_examples=False)`.
demo.queue(max_size=1).launch(
    share=False,
    debug=True,
    show_error=True,
    max_threads=1,
    quiet=True,
)