import os
import time

import cv2
import imageio
import numpy as np
import torch
from PIL import Image
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import DDIMScheduler, AutoencoderKL, DDIMInverseScheduler

from models.pipeline_flatten import FlattenPipeline
from models.util import sample_trajectories
from models.unet import UNet3DConditionModel


def init_pipeline(device):
    dtype = torch.float16
    sd_path = "stabilityai/stable-diffusion-2-1-base"
    UNET_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "checkpoints", "unet")

    # Inflate the 2D UNet checkpoint into the 3D UNet used by FLATTEN.
    unet = UNet3DConditionModel.from_pretrained_2d(UNET_PATH, dtype=dtype)
    # unet = UNet3DConditionModel.from_pretrained_2d(sd_path, subfolder="unet").to(dtype=dtype)
    vae = AutoencoderKL.from_pretrained(sd_path, subfolder="vae").to(dtype=dtype)
    tokenizer = CLIPTokenizer.from_pretrained(sd_path, subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(sd_path, subfolder="text_encoder").to(dtype=dtype)

    # DDIM scheduler for sampling and its inverse for inverting the source frames.
    scheduler = DDIMScheduler.from_pretrained(sd_path, subfolder="scheduler")
    inverse = DDIMInverseScheduler.from_pretrained(sd_path, subfolder="scheduler")

    pipe = FlattenPipeline(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        scheduler=scheduler,
        inverse_scheduler=inverse,
    )
    pipe.enable_vae_slicing()
    pipe.to(device)
    return pipe


height = 512
width = 512
sample_steps = 50
inject_step = 40

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
pipe = init_pipeline(device)


def inference(
    seed: int = 66,
    prompt: str = None,
    neg_prompt: str = "",
    guidance_scale: float = 10.0,
    video_length: int = 16,
    video_path: str = None,
    output_dir: str = None,
    frame_rate: int = 1,
    fps: int = 15,
    old_qk: int = 0,
):
    generator = torch.Generator(device=device)
    generator.manual_seed(seed)

    # xformers memory-efficient attention (may be needed to run on ZeroGPU).
    pipe.enable_xformers_memory_efficient_attention()

    # Read the first `video_length` frames of the source video and resize them.
    video_reader = imageio.get_reader(video_path, "ffmpeg")
    video = []
    for frame in video_reader:
        if len(video) >= video_length:
            break
        video.append(cv2.resize(frame, (width, height)))  # .transpose(2, 0, 1))
    real_frames = [Image.fromarray(frame) for frame in video]

    # Estimate optical flow on the source frames and sample the point
    # trajectories used by FLATTEN's flow-guided attention.
    trajectories = sample_trajectories(torch.tensor(np.array(video)).permute(0, 3, 1, 2), device)
    torch.cuda.empty_cache()
    for k in trajectories.keys():
        trajectories[k] = trajectories[k].to(device)

    # Run the pipeline and convert the output video tensor (C, T, H, W) in
    # [0, 1] to uint8 frames of shape (T, H, W, C).
    sample = (
        pipe(
            prompt,
            video_length=video_length,
            frames=real_frames,
            num_inference_steps=sample_steps,
            generator=generator,
            guidance_scale=guidance_scale,
            negative_prompt=neg_prompt,
            width=width,
            height=height,
            trajs=trajectories,
            output_dir="tmp/",
            inject_step=inject_step,
            old_qk=old_qk,
        )
        .videos[0]
        .permute(1, 2, 3, 0)
        .cpu()
        .numpy()
        * 255
    ).astype(np.uint8)

    temp_video_name = f"/tmp/{prompt}_{neg_prompt}_{guidance_scale}_{time.time()}.mp4".replace(" ", "-")
    video_writer = imageio.get_writer(temp_video_name, fps=fps)
    for frame in sample:
        video_writer.append_data(frame)
    video_writer.close()

    print(f"Saving video to {temp_video_name}, sample shape: {sample.shape}")
    return temp_video_name


if __name__ == "__main__":
    video_path = "./data/puff.mp4"
    inference(
        video_path=video_path,
        prompt="A Tiger, high quality",
        neg_prompt="a cat with big eyes, deformed",
        guidance_scale=20,
        old_qk=0,
    )
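

# --- Optional: Gradio front end (sketch) --------------------------------------
# A minimal, untested sketch of how `inference` could be exposed as a web demo,
# e.g. on a Hugging Face Space (the ZeroGPU note above suggests that setting).
# `gradio` is an assumed extra dependency and is not used by this script itself;
# the block is left commented out so the script's behaviour is unchanged.
#
# import gradio as gr
#
# def run(video, prompt, neg_prompt, guidance_scale, seed):
#     return inference(
#         video_path=video,
#         prompt=prompt,
#         neg_prompt=neg_prompt,
#         guidance_scale=guidance_scale,
#         seed=int(seed),
#     )
#
# demo = gr.Interface(
#     fn=run,
#     inputs=[
#         gr.Video(label="Source video"),
#         gr.Textbox(label="Prompt"),
#         gr.Textbox(label="Negative prompt"),
#         gr.Slider(1.0, 30.0, value=10.0, label="Guidance scale"),
#         gr.Number(value=66, label="Seed"),
#     ],
#     outputs=gr.Video(label="Edited video"),
# )
# demo.launch()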