import os
import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from safetensors.torch import load_file
import rembg
import gradio as gr

# Download the pretrained LGM checkpoint from the Hugging Face Hub.
# `ckpt_path` is handed to `Options(resume=...)` further down and loaded into the model.
from huggingface_hub import hf_hub_download
ckpt_path = hf_hub_download(repo_id="ashawkey/LGM", filename="model_fp16.safetensors")

# The Gaussian rasterizer may not be installed yet; if the import fails,
# install it from the local `./diff-gaussian-rasterization` directory.
try:
    import diff_gaussian_rasterization
except ImportError:
    import importlib
    import subprocess
    import sys

    # Run pip through the interpreter executing this script so the package is
    # installed into the active environment, and pass an argument list so no
    # shell is involved.
    _pip_result = subprocess.run(
        [sys.executable, "-m", "pip", "install", "./diff-gaussian-rasterization"],
        check=False,
    )
    if _pip_result.returncode != 0:
        # Best effort, like the original: report the failure instead of hiding it,
        # and let the later import of the rasterizer raise the real error.
        print(f'[WARN] pip install ./diff-gaussian-rasterization failed with exit code {_pip_result.returncode}')
    # Packages installed while the process is running are only guaranteed to be
    # importable after the finder caches have been invalidated.
    importlib.invalidate_caches()

import kiui
from kiui.op import recenter

from core.options import Options
from core.models import LGM
from mvdream.pipeline_mvdream import MVDreamPipeline

# Per-channel ImageNet statistics; `run` uses them to normalize the RGB views.
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)

# Directory that receives the generated files; created if it does not exist.
TMP_DIR = '/tmp'
os.makedirs(TMP_DIR, exist_ok=True)

# Options are fixed in code here instead of being parsed with `tyro.cli(AllConfigs)`.
opt = Options(
    # checkpoint to restore the model from
    resume=ckpt_path,
    # resolutions
    input_size=256,
    splat_size=128,
    output_size=512,  # render & supervise Gaussians at a higher resolution.
    # up-sampling stack
    up_channels=(1024, 1024, 512, 256, 128),  # one more decoder
    up_attention=(True, True, True, False, False),
    # batching and precision settings
    num_views=8,
    batch_size=8,
    gradient_accumulation_steps=1,
    mixed_precision='bf16',
)

# Build the model from the options above.
model = LGM(opt)

# Restore the pretrained checkpoint named by `opt.resume`, if any.
if opt.resume is not None:
    if opt.resume.endswith('safetensors'):
        ckpt = load_file(opt.resume, device='cpu')
    else:
        # NOTE(security): torch.load unpickles the file, which can execute arbitrary
        # code. Only point `resume` at checkpoints from a trusted source.
        ckpt = torch.load(opt.resume, map_location='cpu')

    # strict=False tolerates key mismatches between checkpoint and model; report
    # them so a partially loaded model does not go unnoticed.
    load_result = model.load_state_dict(ckpt, strict=False)
    missing_keys = getattr(load_result, 'missing_keys', ())
    unexpected_keys = getattr(load_result, 'unexpected_keys', ())
    if missing_keys:
        print(f'[WARN] {len(missing_keys)} model parameter(s) missing from checkpoint: {list(missing_keys)}')
    if unexpected_keys:
        print(f'[WARN] {len(unexpected_keys)} unexpected checkpoint key(s) ignored: {list(unexpected_keys)}')
    print(f'[INFO] Loaded checkpoint from {opt.resume}')
else:
    print('[WARN] model randomly initialized, are you sure?')

# Run on the GPU when one is available; the model is converted to half precision
# and switched to evaluation mode.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.half().to(device)
model.eval()

# Perspective projection matrix built from the vertical field of view and the
# near/far planes stored in `opt`.
# NOTE(review): `proj_matrix` is not referenced by the other code visible in this file.
tan_half_fov = np.tan(0.5 * np.deg2rad(opt.fovy))
_neg_inv_tan = -1 / tan_half_fov
_depth_span = opt.zfar - opt.znear

proj_matrix = torch.zeros(4, 4, dtype=torch.float32, device=device)
# focal terms on the diagonal (same value for x and y)
proj_matrix[0, 0] = _neg_inv_tan
proj_matrix[1, 1] = _neg_inv_tan
# depth terms
proj_matrix[2, 2] = - (opt.zfar + opt.znear) / _depth_span
proj_matrix[3, 2] = - (opt.zfar * opt.znear) / _depth_span
proj_matrix[2, 3] = 1

def _load_mvdream_pipeline(repo_id):
    """Load the MVDream pipeline stored at `repo_id` in fp16 and move it to `device`."""
    pipeline = MVDreamPipeline.from_pretrained(
        repo_id,  # remote weights
        torch_dtype=torch.float16,
        trust_remote_code=True,
        # local_files_only=True,
    )
    return pipeline.to(device)


# load dreams: the text-conditioned pipeline first, then the image-conditioned one
# NOTE(review): only `pipe_image` is used by `run` in this file; `pipe_text` is
# loaded but not referenced by the code visible here.
pipe_text = _load_mvdream_pipeline('ashawkey/mvdream-sd2.1-diffusers')
pipe_image = _load_mvdream_pipeline("ashawkey/imagedream-ipmv-diffusers")

# load rembg: one background-removal session shared by every call to `run`
bg_remover = rembg.new_session()

# process function
def run(input_image):
    """Generate 3D Gaussians from a single image and return the path of the saved PLY.

    The image has its background removed, is recentered and composited onto a
    white background, then expanded into four views by the image-conditioned
    MVDream pipeline. The views are normalized, concatenated with the model's
    default ray embeddings and passed to `model.forward_gaussians`; the result
    is written with `model.gs.save_ply`.

    Args:
        input_image: Image supplied by the Gradio image input (configured with
            `type='pil'`); it is converted with `np.array`.

    Returns:
        Path of the `.ply` file written inside `TMP_DIR`.

    Raises:
        gr.Error: If no image was provided, or if background removal leaves no
            foreground pixels.
    """
    import tempfile

    # Clicking "Generate" without an image passes None; fail with a readable message.
    if input_image is None:
        raise gr.Error("Please provide an input image first.")

    prompt_neg = "ugly, blurry, pixelated obscure, unnatural colors, poor lighting, dull, unclear, cropped, lowres, low quality, artifacts, duplicate"

    # fixed seed so repeated runs on the same image are reproducible
    kiui.seed_everything(42)

    input_image = np.array(input_image) # uint8
    # bg removal
    carved_image = rembg.remove(input_image, session=bg_remover) # [H, W, 4]
    mask = carved_image[..., -1] > 0
    # Recentering needs at least one foreground pixel to work with.
    if not mask.any():
        raise gr.Error("No foreground object was found in the image.")
    image = recenter(carved_image, mask, border_ratio=0.2)
    image = image.astype(np.float32) / 255.0
    # alpha-composite the RGB channels onto a white background
    image = image[..., :3] * image[..., 3:4] + (1 - image[..., 3:4])
    mv_image = pipe_image("", image, negative_prompt=prompt_neg, num_inference_steps=30, guidance_scale=5.0,  elevation=0)

    # generate gaussians
    # The four generated views are reordered so that view 0 comes last
    # (presumably the order the model expects - TODO confirm).
    input_image = np.stack([mv_image[1], mv_image[2], mv_image[3], mv_image[0]], axis=0) # [4, 256, 256, 3], float32
    input_image = torch.from_numpy(input_image).permute(0, 3, 1, 2).float().to(device) # [4, 3, 256, 256]
    input_image = F.interpolate(input_image, size=(opt.input_size, opt.input_size), mode='bilinear', align_corners=False)
    input_image = TF.normalize(input_image, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)

    rays_embeddings = model.prepare_default_rays(device, elevation=0)
    input_image = torch.cat([input_image, rays_embeddings], dim=1).unsqueeze(0) # [1, 4, 9, H, W]

    # Reserve a unique output file per call so that overlapping or consecutive
    # requests never overwrite each other's result.
    with tempfile.NamedTemporaryFile(dir=TMP_DIR, prefix='output_', suffix='.ply', delete=False) as ply_file:
        output_ply_path = ply_file.name

    with torch.no_grad():
        # autocast is pinned to CUDA / fp16, matching the half-precision model
        with torch.autocast(device_type='cuda', dtype=torch.float16):
            # generate gaussians
            gaussians = model.forward_gaussians(input_image)

        # save gaussians
        model.gs.save_ply(gaussians, output_ply_path)

    return output_ply_path

# gradio UI

_TITLE = '''LGM Mini'''

_DESCRIPTION = '''
<div>
A lightweight version of <a href="https://huggingface.co/spaces/ashawkey/LGM">LGM: Large Multi-View Gaussian Model for High-Resolution 3D Content Creation</a>.
</div>
'''

css = '''
#duplicate-button {
    margin: auto;
    color: white;
    background: #1565c0;
    border-radius: 100vh;
}
'''

# Page layout. Components are created in the same order as before, because the
# nesting and creation order inside the context managers define the layout.
with gr.Blocks(title=_TITLE, css=css) as block:
    gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")

    # heading, followed by the description blurb
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown(f'# {_TITLE}')
    gr.Markdown(_DESCRIPTION)

    with gr.Row(variant='panel'):
        # left column: image input and the button that starts generation
        with gr.Column(scale=1):
            input_image = gr.Image(label="image", type='pil', height=320)
            button_gen = gr.Button("Generate")

        # right column: viewer for the generated Gaussians
        with gr.Column(scale=1):
            output_splat = gr.Model3D(label="3D Gaussians")

        # pressing the button runs the image-to-3D function
        button_gen.click(fn=run, inputs=[input_image], outputs=[output_splat])

    # bundled example images; their outputs are cached
    _example_stems = (
        "frog_sweater",
        "bird",
        "boy",
        "cat_statue",
        "dragontoy",
        "gso_rabbit",
    )
    gr.Examples(
        examples=[f"data_test/{stem}.jpg" for stem in _example_stems],
        inputs=[input_image],
        outputs=[output_splat],
        fn=lambda x: run(input_image=x),
        cache_examples=True,
        label='Image-to-3D Examples'
    )

# enable the request queue, then start the server
block.queue().launch(debug=True, share=True)