import torch, os, io
import numpy as np
from PIL import Image
import streamlit as st
st.set_page_config(layout="wide")
from streamlit_drawable_canvas import st_canvas
from diffsynth.models import ModelManager
from diffsynth.pipelines import SDImagePipeline, SDXLImagePipeline, SD3ImagePipeline, HunyuanDiTImagePipeline
from diffsynth.data.video import crop_and_resize


# Per model type, this records:
#   model_folder     - directory scanned by `load_model_list` for checkpoints,
#   pipeline_class   - class whose `from_model_manager` builds the pipeline,
#   fixed_parameters - generation parameters pinned for this model type; the
#                      sidebar shows no widget for any key listed here.
config = {
    "Stable Diffusion": dict(
        model_folder="models/stable_diffusion",
        pipeline_class=SDImagePipeline,
        fixed_parameters=dict(),
    ),
    "Stable Diffusion XL": dict(
        model_folder="models/stable_diffusion_xl",
        pipeline_class=SDXLImagePipeline,
        fixed_parameters=dict(),
    ),
    "Stable Diffusion 3": dict(
        model_folder="models/stable_diffusion_3",
        pipeline_class=SD3ImagePipeline,
        fixed_parameters=dict(),
    ),
    "Stable Diffusion XL Turbo": dict(
        model_folder="models/stable_diffusion_xl_turbo",
        pipeline_class=SDXLImagePipeline,
        # Single-step sampling without classifier-free guidance at 512x512.
        fixed_parameters=dict(
            negative_prompt="",
            cfg_scale=1.0,
            num_inference_steps=1,
            height=512,
            width=512,
        ),
    ),
    "HunyuanDiT": dict(
        model_folder="models/HunyuanDiT",
        pipeline_class=HunyuanDiTImagePipeline,
        fixed_parameters=dict(
            height=1024,
            width=1024,
        ),
    ),
}


def load_model_list(model_type):
    """Return the sorted checkpoint entries available for ``model_type``.

    Entries are names relative to ``config[model_type]["model_folder"]``:
    every ``*.safetensors`` file, plus, for HunyuanDiT (whose checkpoints are
    loaded from a directory of component files), every sub-directory.

    A missing model folder yields an empty list, so the UI simply offers no
    checkpoints for that model type instead of crashing.
    """
    folder = config[model_type]["model_folder"]
    try:
        entries = os.listdir(folder)
    except (FileNotFoundError, NotADirectoryError):
        # No models of this type have been placed on disk yet.
        return []
    file_list = [name for name in entries if name.endswith(".safetensors")]
    if model_type == "HunyuanDiT":
        file_list += [name for name in entries if os.path.isdir(os.path.join(folder, name))]
    return sorted(file_list)


def release_model():
    """Forget the currently loaded model, if any, and release cached VRAM.

    Calls ``.to("cpu")`` on the stored model manager, then removes the
    ``loaded_model_path``, ``model_manager`` and ``pipeline`` session keys
    written by ``load_model``. Finally it calls ``torch.cuda.empty_cache()``.
    It does nothing when no model manager is stored in the session.
    """
    state = st.session_state
    if "model_manager" not in state:
        return
    state["model_manager"].to("cpu")
    for key in ("loaded_model_path", "model_manager", "pipeline"):
        del state[key]
    torch.cuda.empty_cache()


def load_model(model_type, model_path):
    """Load the checkpoint at ``model_path`` and build its pipeline.

    For HunyuanDiT, ``model_path`` is a directory and its four component
    weight files are loaded explicitly. For every other model type it is a
    single checkpoint file.

    The path, model manager and pipeline are stored in ``st.session_state``
    (``loaded_model_path``, ``model_manager``, ``pipeline``) so that later
    script reruns can reuse them.

    Returns:
        ``(model_manager, pipeline)``.
    """
    manager = ModelManager()
    if model_type != "HunyuanDiT":
        manager.load_model(model_path)
    else:
        component_files = (
            "clip_text_encoder/pytorch_model.bin",
            "mt5/pytorch_model.bin",
            "model/pytorch_model_ema.pt",
            "sdxl-vae-fp16-fix/diffusion_pytorch_model.bin",
        )
        manager.load_models([os.path.join(model_path, name) for name in component_files])
    pipeline_class = config[model_type]["pipeline_class"]
    pipeline = pipeline_class.from_model_manager(manager)
    st.session_state.loaded_model_path = model_path
    st.session_state.model_manager = manager
    st.session_state.pipeline = pipeline
    return manager, pipeline


def use_output_image_as_input(update=True):
    """Report whether a "Use it as input image" button is pressed.

    The function scans the button keys ``use_output_as_input_0``,
    ``use_output_as_input_1``, ... in order. It stops at the first key missing
    from the session, or at the first pressed button.

    When a pressed button is found and ``update`` is true, the matching entry of
    ``st.session_state["output_images"]`` is copied to
    ``st.session_state["input_image"]``.

    Returns:
        True if a pressed button was found, otherwise False.
    """
    state = st.session_state
    chosen = None
    index = 0
    while f"use_output_as_input_{index}" in state:
        if state[f"use_output_as_input_{index}"]:
            chosen = state["output_images"][index]
            break
        index += 1
    if chosen is not None and update:
        state["input_image"] = chosen
    return chosen is not None


def apply_stroke_to_image(stroke_image, image):
    """Alpha-composite drawing-canvas strokes onto ``image``.

    Args:
        stroke_image: Array from the drawing canvas. Its last channel is used as
            alpha on a 0-255 scale and the remaining channels as the stroke
            colour. It is presumably an RGBA ``uint8`` array from ``st_canvas``
            (TODO confirm). It is resized to the size of ``image`` first.
        image: PIL image to draw on; it is converted to RGB before blending.

    Returns:
        A new RGB PIL image with the same size as ``image``.
    """
    base = np.array(image.convert("RGB")).astype(np.float32)
    target_height, target_width, _ = base.shape

    strokes = Image.fromarray(stroke_image).resize((target_width, target_height))
    strokes = np.array(strokes).astype(np.float32)
    alpha = strokes[:, :, -1:] / 255
    color = strokes[:, :, :-1]

    blended = color * alpha + base * (1 - alpha)
    return Image.fromarray(np.clip(blended, 0, 255).astype(np.uint8))


@st.cache_data
def image2bits(image):
    """Encode ``image`` as PNG and return the raw bytes.

    The result is memoised by ``st.cache_data``.
    """
    with io.BytesIO() as buffer:
        image.save(buffer, format="PNG")
        return buffer.getvalue()


def show_output_image(image, image_id=None):
    """Render one generated image with "use as input" and download buttons.

    Args:
        image: PIL image to display.
        image_id: Index of ``image`` in ``st.session_state["output_images"]``.
            It makes the widget keys unique. ``use_output_image_as_input``
            scans the ``use_output_as_input_{image_id}`` keys to find which
            image was chosen.
            When omitted, it falls back to the module-level loop variable
            ``image_id`` set by the script's rendering loops. Existing call
            sites rely on that fallback.

    Raises:
        NameError: if ``image_id`` is omitted and no module-level
            ``image_id`` exists.
    """
    if image_id is None:
        # Backward compatibility: call sites that pass only `image` depend on
        # the global loop variable for the widget keys. Fail loudly if absent.
        if "image_id" not in globals():
            raise NameError("show_output_image() needs image_id: pass it explicitly")
        image_id = globals()["image_id"]
    st.image(image, use_column_width="always")
    st.button("Use it as input image", key=f"use_output_as_input_{image_id}")
    st.download_button("Download", data=image2bits(image), file_name="image.png", mime="image/png", key=f"download_output_{image_id}")


# Main layout: input image / drawing canvas on the left, generated images on the right.
column_input, column_output = st.columns(2)
with st.sidebar:
    # Select a model
    with st.expander("Model", expanded=True):
        model_type = st.selectbox("Model type", list(config))
        # Parameters pinned by this model type; no widget is shown for them below.
        fixed_parameters = config[model_type]["fixed_parameters"]
        model_path_list = ["None"] + load_model_list(model_type)
        model_path = st.selectbox("Model path", model_path_list)

        # Load the model
        if model_path == "None":
            # No models are selected. Release VRAM.
            st.markdown("No models are selected.")
            release_model()
        else:
            # A model is selected. From here on `model_path` is the full path
            # (model folder + selected entry); it is compared against
            # `loaded_model_path` to decide whether a reload is needed.
            model_path = os.path.join(config[model_type]["model_folder"], model_path)
            if st.session_state.get("loaded_model_path", "") != model_path:
                # The loaded model is not the selected model. Reload it.
                st.markdown(f"Loading model at {model_path}.")
                st.markdown("Please wait a moment...")
                release_model()
                model_manager, pipeline = load_model(model_type, model_path)
                st.markdown("Done.")
            else:
                # The selected model is already loaded (this script reruns on
                # every interaction). Reuse it from `st.session_state`.
                st.markdown(f"Loading model at {model_path}.")
                st.markdown("Please wait a moment...")
                model_manager, pipeline = st.session_state.model_manager, st.session_state.pipeline
                st.markdown("Done.")

    # Show parameters. Any parameter listed in `fixed_parameters` is taken
    # from there and gets no widget.
    with st.expander("Prompt", expanded=True):
        prompt = st.text_area("Positive prompt")
        if "negative_prompt" in fixed_parameters:
            negative_prompt = fixed_parameters["negative_prompt"]
        else:
            negative_prompt = st.text_area("Negative prompt")
        if "cfg_scale" in fixed_parameters:
            cfg_scale = fixed_parameters["cfg_scale"]
        else:
            cfg_scale = st.slider("Classifier-free guidance scale", min_value=1.0, max_value=10.0, value=7.5)
    with st.expander("Image", expanded=True):
        if "num_inference_steps" in fixed_parameters:
            num_inference_steps = fixed_parameters["num_inference_steps"]
        else:
            num_inference_steps = st.slider("Inference steps", min_value=1, max_value=100, value=20)
        if "height" in fixed_parameters:
            height = fixed_parameters["height"]
        else:
            height = st.select_slider("Height", options=[256, 512, 768, 1024, 2048], value=512)
        if "width" in fixed_parameters:
            width = fixed_parameters["width"]
        else:
            width = st.select_slider("Width", options=[256, 512, 768, 1024, 2048], value=512)
        # At least one image; zero or negative counts would make "Generate" a silent no-op.
        num_images = st.number_input("Number of images", min_value=1, value=2)
        use_fixed_seed = st.checkbox("Use fixed seed", value=False)
        if use_fixed_seed:
            # `seed` is only defined (and only read) when a fixed seed is requested.
            seed = st.number_input("Random seed", min_value=0, max_value=10**9, step=1, value=0)

    # Defaults for image-to-image parameters. The input column replaces them
    # with sliders when an input image is present.
    denoising_strength = 1.0
    repetition = 1


# Show input image
with column_input:
    with st.expander("Input image (Optional)", expanded=True):
        with st.container(border=True):
            column_white_board, column_upload_image = st.columns([1, 2])
            with column_white_board:
                create_white_board = st.button("Create white board")
                delete_input_image = st.button("Delete input image")
            with column_upload_image:
                upload_image = st.file_uploader("Upload image", type=["png", "jpg"], key="upload_image")

        # Choose the input image source. The order of precedence is:
        #   1. an uploaded file (re-read on every rerun),
        #   2. a fresh white board,
        #   3. an output image whose "Use it as input image" button was pressed.
        if upload_image is not None:
            # Normalise to RGB. PNG uploads may be RGBA, palette or grayscale,
            # while the white board and `apply_stroke_to_image` both produce
            # RGB. Without this, a non-RGB upload would reach the pipeline
            # unchanged whenever the canvas returns no stroke data.
            st.session_state["input_image"] = crop_and_resize(Image.open(upload_image).convert("RGB"), height, width)
        elif create_white_board:
            st.session_state["input_image"] = Image.fromarray(np.ones((height, width, 3), dtype=np.uint8) * 255)
        else:
            use_output_image_as_input()

        # "Delete input image" overrides whatever source was chosen above.
        if delete_input_image and "input_image" in st.session_state:
            del st.session_state.input_image
        if delete_input_image and "upload_image" in st.session_state:
            del st.session_state.upload_image

        input_image = st.session_state.get("input_image", None)
        if input_image is not None:
            # Drawing-tool settings for the canvas below.
            with st.container(border=True):
                column_drawing_mode, column_color_1, column_color_2 = st.columns([4, 1, 1])
                with column_drawing_mode:
                    drawing_mode = st.radio("Drawing tool", ["transform", "freedraw", "line", "rect"], horizontal=True, index=1)
                with column_color_1:
                    stroke_color = st.color_picker("Stroke color")
                with column_color_2:
                    fill_color = st.color_picker("Fill color")
                stroke_width = st.slider("Stroke width", min_value=1, max_value=50, value=10)
            # Image-to-image parameters (these replace the sidebar defaults).
            with st.container(border=True):
                denoising_strength = st.slider("Denoising strength", min_value=0.0, max_value=1.0, value=0.7)
                repetition = st.slider("Repetition", min_value=1, max_value=8, value=1)
            with st.container(border=True):
                input_width, input_height = input_image.size
                canvas_result = st_canvas(
                    fill_color=fill_color,
                    stroke_width=stroke_width,
                    stroke_color=stroke_color,
                    background_color="rgba(255, 255, 255, 0)",
                    background_image=input_image,
                    update_streamlit=True,
                    # Display at a fixed width of 512 px, preserving the aspect ratio.
                    height=int(512 / input_width * input_height),
                    width=512,
                    drawing_mode=drawing_mode,
                    key="canvas"
                )


with column_output:
    run_button = st.button("Generate image", type="primary")
    auto_update = st.checkbox("Auto update", value=False)
    num_image_columns = st.slider("Columns", min_value=1, max_value=8, value=2)
    image_columns = st.columns(num_image_columns)

    # Run
    if (run_button or auto_update) and model_path != "None":

        # Bring the input image to the generation size and burn in any canvas strokes.
        if input_image is not None:
            input_image = input_image.resize((width, height))
            if canvas_result.image_data is not None:
                input_image = apply_stroke_to_image(canvas_result.image_data, input_image)

        output_images = []
        # Keep the loop variable named `image_id`: `show_output_image` derives
        # its widget keys from it.
        for image_id in range(num_images * repetition):
            # Reseed per image; with a fixed seed each image gets `seed + image_id`.
            if use_fixed_seed:
                torch.manual_seed(seed + image_id)
            else:
                torch.manual_seed(np.random.randint(0, 10**9))
            # After the first round of `num_images`, each repetition uses the
            # image generated `num_images` steps earlier as its input.
            if image_id >= num_images:
                input_image = output_images[image_id - num_images]
            with image_columns[image_id % num_image_columns]:
                progress_bar_st = st.progress(0.0)
                image = pipeline(
                    prompt, negative_prompt=negative_prompt,
                    cfg_scale=cfg_scale, num_inference_steps=num_inference_steps,
                    height=height, width=width,
                    input_image=input_image, denoising_strength=denoising_strength,
                    progress_bar_st=progress_bar_st
                )
                output_images.append(image)
                progress_bar_st.progress(1.0)
                show_output_image(image)
                # Updated after every image, so the session always holds the
                # images generated so far.
                st.session_state["output_images"] = output_images

    elif "output_images" in st.session_state:
        # Not generating: redisplay the images from the last generation.
        for image_id, image in enumerate(st.session_state.output_images):
            with image_columns[image_id % num_image_columns]:
                st.progress(1.0)
                show_output_image(image)
    # A keyed file_uploader is always present in session state once rendered
    # (its value is None when empty), so check the value rather than the key.
    if st.session_state.get("upload_image") is not None and use_output_image_as_input(update=False):
        st.markdown("If you want to use an output image as input image, please delete the uploaded image manually.")