"""Gradio demo that upscales an image with the Flux.1-dev Upscaler ControlNet.

At import time the script downloads the FLUX.1-dev weights, builds the
ControlNet pipeline, declares the Gradio UI and launches it.
"""

import logging
import os
import random
import warnings

import gradio as gr
import numpy as np
import spaces
import torch
from diffusers import FluxControlNetModel
from diffusers.pipelines import FluxControlNetPipeline
from huggingface_hub import snapshot_download
from PIL import Image

css = """
#col-container {
    margin: 0 auto;
    max-width: 512px;
}
"""

# Check for GPU availability
if torch.cuda.is_available():
    power_device = "GPU"
    device = "cuda"
else:
    power_device = "CPU"
    device = "cpu"

# Load HuggingFace model.
# The historical (misspelled) variable name is read first so existing
# deployments keep working; the correctly spelled name is a fallback.
huggingface_token = os.getenv("HUGGINFACE_TOKEN") or os.getenv("HUGGINGFACE_TOKEN")

model_path = snapshot_download(
    repo_id="black-forest-labs/FLUX.1-dev",
    repo_type="model",
    # "*.gitattributes" (single dot) so the pattern actually matches the
    # repository's .gitattributes file.
    ignore_patterns=["*.md", "*.gitattributes"],
    local_dir="FLUX.1-dev",
    token=huggingface_token,
)

# Load pipeline
controlnet = FluxControlNetModel.from_pretrained(
    "jasperai/Flux.1-dev-Controlnet-Upscaler", torch_dtype=torch.bfloat16
).to(device)
pipe = FluxControlNetPipeline.from_pretrained(
    model_path, controlnet=controlnet, torch_dtype=torch.bfloat16
)
pipe.to(device)

MAX_SEED = 1000000
MAX_PIXEL_BUDGET = 1024 * 1024

# Image sides are snapped down to a multiple of this value before inference.
_SIDE_MULTIPLE = 8


def process_input(input_image, upscale_factor):
    """Fit *input_image* to the pixel budget and snap its sides to multiples of 8.

    If the upscaled image (input size times ``upscale_factor``) would exceed
    ``MAX_PIXEL_BUDGET`` pixels, the input is first shrunk so that the
    upscaled result fits the budget while keeping the original aspect ratio.

    Args:
        input_image: PIL image to prepare.
        upscale_factor: Factor by which each side will later be multiplied.

    Returns:
        A tuple ``(image, w_original, h_original, was_resized)`` where
        ``image`` is the prepared PIL image, ``w_original``/``h_original``
        are the sides of the image as passed in, and ``was_resized`` tells
        whether the budget-driven shrink was applied.
    """
    w_original, h_original = input_image.size
    was_resized = False

    if w_original * h_original * upscale_factor**2 > MAX_PIXEL_BUDGET:
        aspect_ratio = w_original / h_original
        # Solve width * height == budget with width / height == aspect_ratio,
        # then divide by the upscale factor to get the pre-upscale size.
        # Using sqrt(aspect_ratio) (not aspect_ratio itself) is what keeps
        # the proportions of the original image.
        target_w = int((MAX_PIXEL_BUDGET * aspect_ratio) ** 0.5 // upscale_factor)
        target_h = int((MAX_PIXEL_BUDGET / aspect_ratio) ** 0.5 // upscale_factor)
        # Extreme aspect ratios could otherwise collapse a side below the
        # smallest size that survives the multiple-of-8 snap further down.
        target_w = max(target_w, _SIDE_MULTIPLE)
        target_h = max(target_h, _SIDE_MULTIPLE)

        warnings.warn(
            f"Requested output image is too large "
            f"({w_original * upscale_factor}x{h_original * upscale_factor}). "
            f"Resizing to ({target_w}x{target_h}) pixels."
        )
        input_image = input_image.resize((target_w, target_h))
        was_resized = True

    # Resize to multiple of 8; never let a side drop to zero, which would
    # make the resize call fail for very small inputs.
    w, h = input_image.size
    w = max(w - w % _SIDE_MULTIPLE, _SIDE_MULTIPLE)
    h = max(h - h % _SIDE_MULTIPLE, _SIDE_MULTIPLE)

    return input_image.resize((w, h)), w_original, h_original, was_resized


@spaces.GPU
def infer(
    seed,
    randomize_seed,
    input_image_path,
    num_inference_steps,
    upscale_factor,
    controlnet_conditioning_scale,
):
    """Upscale the image stored at *input_image_path* with the ControlNet pipeline.

    Args:
        seed: Seed for the random generator; ignored when ``randomize_seed``
            is true.
        randomize_seed: When true, a seed is drawn from ``[0, MAX_SEED]``.
        input_image_path: Path of the image to upscale, or ``None`` when the
            user has not provided one.
        num_inference_steps: Number of steps passed to the pipeline.
        upscale_factor: Factor by which each image side is multiplied.
        controlnet_conditioning_scale: Conditioning scale passed to the
            pipeline.

    Returns:
        A tuple ``(input image, upscaled image, seed actually used)``.

    Raises:
        gr.Error: If no input image was provided.

    Side effects:
        Writes the upscaled image to ``output.jpg`` in the working directory.
    """
    if input_image_path is None:
        # Surface a readable message in the UI instead of an AttributeError
        # from Image.open(None).
        raise gr.Error("Please upload an input image before running.")

    # UI sliders may deliver floats; the generator seed and the PIL size
    # tuples below require integers.
    upscale_factor = int(upscale_factor)
    num_inference_steps = int(num_inference_steps)

    # Handle random seed if specified
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    seed = int(seed)

    # Load image
    true_input_image = Image.open(input_image_path)
    # Work on an RGB copy so images with alpha, palette or grayscale modes
    # are handled uniformly; the untouched original is what gets displayed.
    input_image = true_input_image.convert("RGB")

    input_image, w_original, h_original, was_resized = process_input(
        input_image, upscale_factor
    )

    # Rescale with upscale factor
    w, h = input_image.size
    control_image = input_image.resize((w * upscale_factor, h * upscale_factor))

    generator = torch.Generator().manual_seed(seed)

    # Upscale
    image = pipe(
        prompt="",
        control_image=control_image,
        controlnet_conditioning_scale=controlnet_conditioning_scale,
        num_inference_steps=num_inference_steps,
        guidance_scale=3.5,
        height=control_image.size[1],
        width=control_image.size[0],
        generator=generator,
    ).images[0]

    # Resize output if initially resized, so the result has the size the
    # user asked for even though inference ran on a smaller image.
    if was_resized:
        image = image.resize(
            (w_original * upscale_factor, h_original * upscale_factor)
        )

    image.save("output.jpg")

    return true_input_image, image, seed


# Gradio setup without ImageSlider
with gr.Blocks(css=css) as demo:
    gr.Markdown(
        """
# ⚡ Flux.1-dev Upscaler ControlNet ⚡
This is an interactive demo of [Flux.1-dev Upscaler ControlNet](https://huggingface.co/jasperai/Flux.1-dev-Controlnet-Upscaler).
"""
    )

    run_button = gr.Button(value="Run")
    input_im = gr.Image(label="Input Image", type="filepath")
    num_inference_steps = gr.Slider(
        label="Number of Inference Steps", minimum=8, maximum=50, step=1, value=28
    )
    upscale_factor = gr.Slider(
        label="Upscale Factor", minimum=1, maximum=4, step=1, value=4
    )
    controlnet_conditioning_scale = gr.Slider(
        label="Controlnet Conditioning Scale",
        minimum=0.1,
        maximum=1.5,
        step=0.1,
        value=0.6,
    )
    seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=42)
    randomize_seed = gr.Checkbox(label="Randomize seed", value=True)

    input_image_display = gr.Image(label="Input Image Display")
    output_image_display = gr.Image(label="Upscaled Image Display")
    used_seed_display = gr.Textbox(label="Used Seed")

    run_button.click(
        infer,
        inputs=[
            seed,
            randomize_seed,
            input_im,
            num_inference_steps,
            upscale_factor,
            controlnet_conditioning_scale,
        ],
        outputs=[input_image_display, output_image_display, used_seed_display],
    )

demo.queue().launch(share=False, show_api=True, show_error=True)