import torch
from diffusers import FluxPipeline
import gradio as gr
import threading
import os

# os.cpu_count() returns None when the count cannot be determined; fall back
# to one thread instead of writing "None" into the env var and passing None
# to torch.set_num_threads (which would raise).
_num_threads = os.cpu_count() or 1
# NOTE(review): torch is already imported at this point, so this env var
# mostly affects libraries that read it later; torch.set_num_threads below is
# what configures torch's own intra-op thread pool.
os.environ["OMP_NUM_THREADS"] = str(_num_threads)
torch.set_num_threads(_num_threads)

# Initialize Flux pipeline (weights are loaded in bfloat16).
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16
)
# Keep submodels on CPU and move each to the accelerator only while it runs.
pipe.enable_model_cpu_offload()

# Set by stop_generation() and polled by generate_images() to abort a run.
stop_event = threading.Event()

def generate_images(
    prompt,
    height,
    width,
    guidance_scale,
    num_inference_steps,
    max_sequence_length,
    seed,
    randomize_seed
):
    """Generate three images for ``prompt`` with the shared Flux pipeline.

    Args:
        prompt: Text prompt forwarded to the pipeline.
        height: Output height in pixels (coerced to int).
        width: Output width in pixels (coerced to int).
        guidance_scale: Guidance scale forwarded to the pipeline.
        num_inference_steps: Number of denoising steps (coerced to int).
        max_sequence_length: Prompt token limit forwarded to the pipeline
            (coerced to int).
        seed: Base seed; image ``i`` uses ``int(seed) + i``. If ``None``
            (e.g. the Seed field was cleared), seeds are randomized instead.
        randomize_seed: When true, draw an independent random seed per image.

    Returns:
        A list of three images, or ``[None, None, None]`` if ``stop_event``
        was set before all three finished.
    """
    stop_event.clear()
    results = []

    def _interrupt_on_stop(pipeline, step, timestep, callback_kwargs):
        # Called after each denoising step: lets a Stop click abort the
        # current image instead of waiting for every remaining step.
        if stop_event.is_set():
            pipeline._interrupt = True
        return callback_kwargs

    # A cleared Seed box yields None; randomize rather than crash on None + i.
    use_random_seed = randomize_seed or seed is None

    for i in range(3):
        if stop_event.is_set():
            return [None] * 3

        # Handle seed randomization
        if use_random_seed:
            current_seed = torch.randint(0, 2**32 - 1, (1,)).item()
        else:
            # manual_seed needs an int; the UI may hand back a float.
            current_seed = int(seed) + i

        generator = torch.Generator(device="cpu").manual_seed(current_seed)

        # Generate image with current parameters
        image = pipe(
            prompt=prompt,
            height=int(height),
            width=int(width),
            guidance_scale=guidance_scale,
            num_inference_steps=int(num_inference_steps),
            max_sequence_length=int(max_sequence_length),
            generator=generator,
            callback_on_step_end=_interrupt_on_stop,
        ).images[0]

        # An interrupted run still decodes its partially denoised latents;
        # discard that image rather than showing it.
        if stop_event.is_set():
            return [None] * 3

        results.append(image)

    return results

def stop_generation():
    """Signal the running ``generate_images`` call to stop and blank outputs.

    Returns:
        A list of three ``None`` values, one per output image component.
    """
    stop_event.set()
    cleared_outputs = [None, None, None]
    return cleared_outputs

# Build the Gradio UI: prompt box, tunable generation parameters, Generate/Stop
# buttons, and three image outputs filled by generate_images().
with gr.Blocks() as interface:
    gr.Markdown("""
    ### FLUX Image Generation
    Adjust parameters below to control the image generation process
    """)

    with gr.Row():
        text_input = gr.Textbox(
            label="Prompt",
            placeholder="Describe what you want to generate...",
            scale=3
        )

    with gr.Accordion("Generation Parameters", open=False):
        with gr.Row():
            height = gr.Number(
                label="Height",
                value=1024,
                minimum=512,
                maximum=4096,
                step=64,
                precision=0
            )
            width = gr.Number(
                label="Width",
                value=1024,
                minimum=512,
                maximum=4096,
                step=64,
                precision=0
            )

        # NOTE(review): diffusers' FLUX.1-dev example uses guidance_scale=3.5;
        # confirm the 7.0 default here is intentional.
        guidance_scale = gr.Slider(
            label="Guidance Scale",
            minimum=0.0,
            maximum=20.0,
            value=7.0,
            step=0.5
        )

        num_inference_steps = gr.Slider(
            label="Inference Steps",
            minimum=10,
            maximum=150,
            value=50,
            step=1
        )

        # FluxPipeline raises ValueError for max_sequence_length > 512, so
        # only offer values it accepts (768/1024 would always fail).
        max_sequence_length = gr.Dropdown(
            label="Max Sequence Length",
            choices=[256, 512],
            value=512
        )

        with gr.Row():
            seed = gr.Number(
                label="Seed",
                value=42,
                precision=0
            )
            randomize_seed = gr.Checkbox(
                label="Randomize Seed",
                value=True
            )

    with gr.Row():
        generate_btn = gr.Button("Generate", variant="primary")
        stop_btn = gr.Button("Stop Generation")

    with gr.Row():
        output1 = gr.Image(label="Output 1", type="pil")
        output2 = gr.Image(label="Output 2", type="pil")
        output3 = gr.Image(label="Output 3", type="pil")

    generate_btn.click(
        generate_images,
        inputs=[
            text_input,
            height,
            width,
            guidance_scale,
            num_inference_steps,
            max_sequence_length,
            seed,
            randomize_seed
        ],
        outputs=[output1, output2, output3]
    )

    stop_btn.click(
        stop_generation,
        inputs=[],
        outputs=[output1, output2, output3],
        # Bypass the queue so Stop runs immediately rather than waiting
        # behind the in-progress generation job.
        queue=False,
    )

interface.launch()