# Keep `spaces` first: Hugging Face ZeroGPU expects it to be imported before
# torch / any CUDA-initializing package.
import spaces

import argparse
import os
import re
import shutil
import subprocess
import sys
import time
from datetime import datetime
from os import path

import gradio as gr
import torch
from diffusers import FluxPipeline
from diffusers.pipelines.stable_diffusion import safety_checker
from huggingface_hub import hf_hub_download
from PIL import Image
from safetensors.torch import load_file
from transformers import AutoProcessor, AutoModelForCausalLM

# Install Flash Attention at startup.
# FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE tells the flash-attn build not to compile
# its CUDA kernels locally (presumably so a prebuilt wheel is used — confirm).
# The variable is layered on top of the current environment: a bare dict passed
# as `env` *replaces* the whole environment, dropping PATH, HOME and any
# virtualenv settings, so the wrong `pip` (or none at all) could run. Invoking
# pip through sys.executable pins the install to the interpreter running this app.
# check=False keeps the original best-effort behaviour: a failed install is not
# fatal (nothing in this file imports flash_attn directly).
subprocess.run(
    [sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    check=False,
)

# Setup and initialization code.
# Model cache: a "models" directory next to this file.
# Gallery: "$PERSISTENT_DIR/gallery" (PERSISTENT_DIR defaults to the working dir).
cache_path = path.join(path.dirname(path.abspath(__file__)), "models")
PERSISTENT_DIR = os.environ.get("PERSISTENT_DIR", ".")
gallery_path = path.join(PERSISTENT_DIR, "gallery")

# NOTE(review): transformers / huggingface_hub / diffusers are imported at the
# top of the file, before these variables are set. Those libraries appear to
# read their cache locations at import time, so these assignments likely do not
# redirect downloads — verify. Setting them before the imports (or passing
# cache_dir= explicitly) would be needed for them to take effect.
os.environ["TRANSFORMERS_CACHE"] = cache_path
os.environ["HF_HUB_CACHE"] = cache_path
os.environ["HF_HOME"] = cache_path

# Allow TF32 math for CUDA matmuls (faster, slightly lower precision).
torch.backends.cuda.matmul.allow_tf32 = True

# Create the gallery directory. exist_ok=True already makes this a no-op when it
# exists, so no separate existence check (and no check-then-create race).
os.makedirs(gallery_path, exist_ok=True)

# Florence-2 captioning checkpoints selectable in the UI. Every model and its
# matching processor is loaded once, at import time, in the order listed here
# (the first entry is the UI's default caption model).
# trust_remote_code=True executes Python code shipped in the checkpoint repo —
# presumably required because these checkpoints use custom Florence-2 code; confirm.
_FLORENCE_MODEL_IDS = (
    'gokaygokay/Florence-2-Flux-Large',
    'gokaygokay/Florence-2-Flux',
)

# Models are switched to eval mode right after loading (inference only).
florence_models = {
    model_id: AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True).eval()
    for model_id in _FLORENCE_MODEL_IDS
}

florence_processors = {
    model_id: AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
    for model_id in _FLORENCE_MODEL_IDS
}

def filter_prompt(prompt):
    """Screen a text prompt against a keyword blocklist.

    Matching is case-insensitive and anchored at the start of a word, so
    inflected forms ("killing", "nudes", "murderer") are still rejected while
    unrelated words that merely contain a keyword ("skill", "whatever",
    "chateau", "Essex") are not. Trade-off: a keyword glued onto the end of
    another word (e.g. "xyzporn") is no longer caught.

    Args:
        prompt: The user's prompt text.

    Returns:
        ``(True, prompt)`` when no keyword matches, otherwise
        ``(False, message)`` where ``message`` is the user-facing (Korean)
        rejection notice.
    """
    inappropriate_keywords = [
        "nude", "naked", "nsfw", "porn", "sex", "explicit", "adult", "xxx",
        "erotic", "sensual", "seductive", "provocative", "intimate",
        "violence", "gore", "blood", "death", "kill", "murder", "torture",
        "drug", "suicide", "abuse", "hate", "discrimination"
    ]

    prompt_lower = prompt.lower()

    for keyword in inappropriate_keywords:
        # \b anchors the keyword to a word start; a plain substring test
        # rejected innocent prompts ("kill" in "skill", "hate" in "whatever").
        if re.search(r"\b" + re.escape(keyword), prompt_lower):
            return False, "부적절한 내용이 포함된 프롬프트입니다."

    return True, prompt

class timer:
    """Context manager that prints how long its ``with`` body took.

    Usage::

        with timer("inference"):
            ...

    On entry it prints "<name> starts"; on exit it prints
    "<name> took <seconds>s" (seconds rounded to 2 decimals). Exceptions
    raised in the body are not suppressed.
    """

    def __init__(self, method_name="timed process"):
        self.method = method_name

    def __enter__(self):
        # perf_counter is monotonic and high-resolution; time.time() follows the
        # wall clock, which can jump, so it is unsuitable for measuring intervals.
        self.start = time.perf_counter()
        print(f"{self.method} starts")
        # Returning self lets callers write `with timer(...) as t:`.
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        end = time.perf_counter()
        print(f"{self.method} took {str(round(end - self.start, 2))}s")
        # False: never swallow an exception raised inside the with-body.
        return False

# Model initialization
# Ensure the local model cache directory exists. (By this point the Florence
# models above have already been loaded.)
if not path.exists(cache_path):
    os.makedirs(cache_path, exist_ok=True)

# FLUX.1-dev text-to-image pipeline, loaded in bfloat16.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", 
    torch_dtype=torch.bfloat16
)
# Hyper-SD 8-step LoRA for FLUX.1-dev; matches the UI's default of 8 inference steps.
pipe.load_lora_weights(
    hf_hub_download(
        "ByteDance/Hyper-SD", 
        "Hyper-FLUX.1-dev-8steps-lora.safetensors"
    )
)
# Fuse the LoRA into the base weights at scale 0.125
# (presumably the value recommended on the Hyper-SD model card — confirm).
pipe.fuse_lora(lora_scale=0.125)
pipe.to(device="cuda", dtype=torch.bfloat16)
# NOTE(review): FluxPipeline does not appear to consult a `safety_checker`
# attribute during generation (unlike the Stable Diffusion pipelines), so this
# checker is likely loaded but never applied. It is also left on the CPU, and it
# is not moved by the pipe.to(...) call above because it is attached afterwards.
# Verify against the installed diffusers version; if it is unused, either call
# it explicitly on generated images or drop it.
pipe.safety_checker = safety_checker.StableDiffusionSafetyChecker.from_pretrained(
    "CompVis/stable-diffusion-safety-checker"
)

# Base CSS for the Gradio UI (title, tabs, buttons, upload area).
css = """
footer {display: none !important}
.gradio-container {
    max-width: 1200px;
    margin: auto;
}
.contain {
    background: rgba(255, 255, 255, 0.05);
    border-radius: 12px;
    padding: 20px;
}
.generate-btn {
    background: linear-gradient(90deg, #4B79A1 0%, #283E51 100%) !important;
    border: none !important;
    color: white !important;
}
.generate-btn:hover {
    transform: translateY(-2px);
    box-shadow: 0 5px 15px rgba(0,0,0,0.2);
}
.title {
    text-align: center;
    font-size: 2.5em;
    font-weight: bold;
    margin-bottom: 1em;
    background: linear-gradient(90deg, #4B79A1 0%, #283E51 100%);
    -webkit-background-clip: text;
    -webkit-text-fill-color: transparent;
}
.tabs {
    margin-top: 20px;
    border-radius: 10px;
    overflow: hidden;
}
.tab-nav {
    background: linear-gradient(90deg, #4B79A1 0%, #283E51 100%);
    padding: 10px;
}
.tab-nav button {
    color: white;
    border: none;
    padding: 10px 20px;
    margin: 0 5px;
    border-radius: 5px;
    transition: all 0.3s ease;
}
.tab-nav button.selected {
    background: rgba(255, 255, 255, 0.2);
}
.image-upload-container {
    border: 2px dashed #4B79A1;
    border-radius: 10px;
    padding: 20px;
    text-align: center;
    transition: all 0.3s ease;
}
.image-upload-container:hover {
    border-color: #283E51;
    background: rgba(75, 121, 161, 0.1);
}
"""

# Additional styles appended to the base CSS (primary button, divider, sections).
# NOTE(review): .input-section / .output-section are not referenced by any
# elem_classes in the UI code of this file.
additional_css = """
.primary-btn {
    background: linear-gradient(90deg, #4B79A1 0%, #283E51 100%) !important;
    font-size: 1.2em !important;
    padding: 12px 20px !important;
    margin-top: 20px !important;
}
hr {
    border: none;
    border-top: 1px solid rgba(75, 121, 161, 0.2);
    margin: 20px 0;
}
.input-section {
    background: rgba(255, 255, 255, 0.03);
    border-radius: 12px;
    padding: 20px;
    margin-bottom: 20px;
}
.output-section {
    background: rgba(255, 255, 255, 0.03);
    border-radius: 12px;
    padding: 20px;
}
"""

# Append the additional styles to the base CSS; the result is passed to gr.Blocks.
css = css + additional_css

def save_image(image):
    """Save a generated image as a PNG in the gallery directory.

    Args:
        image: A ``PIL.Image.Image``, or an array accepted by
            ``Image.fromarray``.

    Returns:
        The path of the written file, or ``None`` if anything failed. Failures
        are printed rather than raised, so a failed save never breaks the
        generation request that called this.
    """
    try:
        os.makedirs(gallery_path, exist_ok=True)
    except OSError as e:
        print(f"Failed to create gallery directory: {str(e)}")
        return None

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    # The random suffix keeps names unique when several images land in the same second.
    random_suffix = os.urandom(4).hex()
    filename = f"generated_{timestamp}_{random_suffix}.png"
    filepath = os.path.join(gallery_path, filename)

    try:
        if not isinstance(image, Image.Image):
            image = Image.fromarray(image)
        # PNG is lossless; Pillow's PNG writer has no `quality` option, so none is passed.
        image.save(filepath, "PNG")
    except Exception as e:  # best-effort: conversion, encoding or I/O failure
        print(f"Failed to save image: {str(e)}")
        return None

    if not os.path.exists(filepath):
        print(f"Warning: Failed to verify saved image at {filepath}")
        return None

    return filepath

def load_gallery():
    """List gallery images (.png/.jpg/.jpeg), newest first by modification time.

    Creates the gallery directory if it is missing. Returns a list of file
    paths, or ``[]`` if the directory cannot be created or listed (the error
    is printed).
    """
    try:
        os.makedirs(gallery_path, exist_ok=True)
        names = os.listdir(gallery_path)
    except OSError as e:
        print(f"Error loading gallery: {str(e)}")
        return []

    entries = []
    for name in names:
        if not name.lower().endswith(('.png', '.jpg', '.jpeg')):
            continue
        full_path = os.path.join(gallery_path, name)
        try:
            mtime = os.path.getmtime(full_path)
        except OSError:
            # The file vanished (or became unreadable) between listdir and
            # stat. Skip just this file instead of blanking the whole gallery.
            continue
        entries.append((full_path, mtime))

    entries.sort(key=lambda entry: entry[1], reverse=True)
    return [full_path for full_path, _ in entries]

@spaces.GPU
def generate_caption(image, model_name='gokaygokay/Florence-2-Flux-Large'):
    """Describe an uploaded image with a Florence-2 model, for use as a prompt.

    Args:
        image: The uploaded image as a numpy array (the UI's input component
            uses ``type="numpy"``), or ``None`` when nothing was uploaded.
        model_name: Key into ``florence_models`` / ``florence_processors``.
            Falls back to the default model when empty.

    Returns:
        The generated description text.

    Raises:
        gr.Error: if no image was uploaded (Gradio shows the message to the user).
    """
    if image is None:
        raise gr.Error("Please upload an image first.")
    model_name = model_name or 'gokaygokay/Florence-2-Flux-Large'

    image = Image.fromarray(image)
    if image.mode != "RGB":
        image = image.convert("RGB")

    # Florence-2 is steered by a task token; the text after it is the instruction.
    task_prompt = "<DESCRIPTION>"
    prompt = task_prompt + "Describe this image in great detail."

    model = florence_models[model_name]
    processor = florence_processors[model_name]

    inputs = processor(text=prompt, images=image, return_tensors="pt")
    generated_ids = model.generate(
        input_ids=inputs["input_ids"],
        pixel_values=inputs["pixel_values"],
        max_new_tokens=1024,
        num_beams=3,
        repetition_penalty=1.10,
    )
    # Special tokens are kept because post_process_generation parses them.
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
    parsed_answer = processor.post_process_generation(
        generated_text, task=task_prompt, image_size=(image.width, image.height)
    )
    # The parsed result is a dict keyed by the task token.
    return parsed_answer[task_prompt]

@spaces.GPU
def process_and_save_image(height, width, steps, scales, prompt, seed):
    """Generate an image with the FLUX pipeline, save it, and refresh the gallery.

    Args:
        height: Output height in pixels (from the UI slider).
        width: Output width in pixels (from the UI slider).
        steps: Number of inference steps.
        scales: Guidance scale.
        prompt: Text prompt. Rejected if empty or if ``filter_prompt`` flags it.
        seed: Seed for the torch generator, so a seed reproduces its image.

    Returns:
        ``(image, gallery_paths)``. ``image`` is ``None`` when the prompt is
        rejected or generation fails (the user is warned). ``gallery_paths``
        is the current gallery listing in every case.
    """
    if not prompt or not prompt.strip():
        gr.Warning("Please enter an image description.")
        return None, load_gallery()

    is_safe, filtered_prompt = filter_prompt(prompt)
    if not is_safe:
        gr.Warning("부적절한 내용이 포함된 프롬프트입니다.")
        return None, load_gallery()

    with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16), timer("inference"):
        try:
            generated_image = pipe(
                prompt=[filtered_prompt],
                generator=torch.Generator().manual_seed(int(seed)),
                num_inference_steps=int(steps),
                guidance_scale=float(scales),
                height=int(height),
                width=int(width),
                max_sequence_length=256
            ).images[0]
        except Exception as e:
            # Keep the details in the server log, but tell the user something
            # went wrong instead of silently returning an empty output.
            print(f"Error in image generation: {str(e)}")
            gr.Warning("Image generation failed. Please try again.")
            return None, load_gallery()

    # save_image reports its own failures and returns None instead of raising.
    if save_image(generated_image) is None:
        print("Warning: Failed to save generated image")

    return generated_image, load_gallery()

def get_random_seed():
    """Draw a random seed in [0, 1_000_000) from torch's global RNG, as a Python int."""
    draw = torch.randint(low=0, high=1_000_000, size=(1,))
    return int(draw[0])

def update_seed():
    """Event handler: produce a fresh random value for the Seed field."""
    new_seed = get_random_seed()
    return new_seed

with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
    gr.HTML('<div class="title">AI Image Generator & Caption</div>')
    gr.HTML('<div style="text-align: center; margin-bottom: 2em;">Upload an image for caption or create from text description</div>')
    
    with gr.Row():
        # Left column: inputs
        with gr.Column(scale=3):
            # Image upload section (optional; only used for captioning)
            input_image = gr.Image(
                label="Upload Image (Optional)",
                type="numpy",
                elem_classes=["image-upload-container"]
            )
            
            florence_model = gr.Dropdown(
                choices=list(florence_models.keys()),
                label="Caption Model",
                value='gokaygokay/Florence-2-Flux-Large',
                visible=True
            )
            
            caption_button = gr.Button(
                "🔍 Generate Caption from Image",
                elem_classes=["generate-btn"]
            )
            
            # Divider
            gr.HTML('<hr style="margin: 20px 0;">')
            
            # Text prompt section
            prompt = gr.Textbox(
                label="Image Description",
                placeholder="Enter text description or use generated caption above...",
                lines=3
            )
            
            with gr.Accordion("Advanced Settings", open=False):
                with gr.Row():
                    height = gr.Slider(
                        label="Height",
                        minimum=256,
                        maximum=1152,
                        step=64,
                        value=1024
                    )
                    width = gr.Slider(
                        label="Width",
                        minimum=256,
                        maximum=1152,
                        step=64,
                        value=1024
                    )
                
                with gr.Row():
                    steps = gr.Slider(
                        label="Inference Steps",
                        minimum=6,
                        maximum=25,
                        step=1,
                        value=8
                    )
                    scales = gr.Slider(
                        label="Guidance Scale",
                        minimum=0.0,
                        maximum=5.0,
                        step=0.1,
                        value=3.5
                    )
                
                # Passing the function itself makes Gradio evaluate it on every
                # page load. Each visitor gets a fresh seed, instead of all of
                # them sharing one value drawn once at startup.
                seed = gr.Number(
                    label="Seed",
                    value=get_random_seed,
                    precision=0
                )
                
                randomize_seed = gr.Button(
                    "🎲 Randomize Seed", 
                    elem_classes=["generate-btn"]
                )
            
            generate_btn = gr.Button(
                "✨ Generate Image",
                elem_classes=["generate-btn", "primary-btn"]
            )

        # Right column: outputs
        with gr.Column(scale=4):
            output = gr.Image(
                label="Generated Image",
                elem_classes=["output-image"]
            )
            
            # The loader is passed as a callable, so Gradio re-lists the gallery
            # on each page load through the component's normal value handling.
            # Previously `gallery.value` was assigned after construction, which
            # skipped the constructor's value processing and only ever showed
            # the files present when the app started.
            gallery = gr.Gallery(
                label="Generated Images Gallery",
                value=load_gallery,
                show_label=True,
                columns=[4],
                rows=[2],
                height="auto",
                object_fit="cover",
                elem_classes=["gallery-container"]
            )

    # Event handlers
    caption_button.click(
        generate_caption,
        inputs=[input_image, florence_model],
        outputs=[prompt]
    )
    
    generate_btn.click(
        process_and_save_image,
        inputs=[height, width, steps, scales, prompt, seed],
        outputs=[output, gallery]
    )
    
    randomize_seed.click(
        update_seed,
        outputs=[seed]
    )
    
    # Re-randomize the seed on every generate click.
    # NOTE(review): this listener fires on the same click as the one above. It
    # relies on Gradio capturing `seed` as an input when the click is dispatched,
    # so that generation uses the pre-click seed — confirm.
    generate_btn.click(
        update_seed,
        outputs=[seed]
    )

if __name__ == "__main__":
    # Only the gallery directory needs to be servable: the Gallery shows files
    # from it. Allowing all of PERSISTENT_DIR, which defaults to ".", would
    # expose every file under the working directory, potentially including
    # application files and secrets stored there.
    demo.launch(allowed_paths=[gallery_path])