import torch
from tqdm import tqdm
from typing import List, Optional, Tuple
from models import PipelineWrapper
import gradio as gr


def inversion_forward_process(model: PipelineWrapper,
                              x0: torch.Tensor,
                              etas: Optional[float] = None,
                              prompts: List[str] = [""],
                              cfg_scales: List[float] = [3.5],
                              num_inference_steps: int = 50,
                              numerical_fix: bool = False,
                              duration: Optional[float] = None,
                              first_order: bool = False,
                              save_compute: bool = True,
                              progress=gr.Progress()) -> Tuple:
    if len(prompts) > 1 or prompts[0] != "":
        text_embeddings_hidden_states, text_embeddings_class_labels, \
            text_embeddings_boolean_prompt_mask = model.encode_text(prompts)

        # In the forward negative prompts are not supported currently (TODO)
        uncond_embeddings_hidden_states, uncond_embeddings_class_lables, uncond_boolean_prompt_mask = model.encode_text(
            [""], negative=True, save_compute=save_compute, cond_length=text_embeddings_class_labels.shape[1]
            if text_embeddings_class_labels is not None else None)
    else:
        uncond_embeddings_hidden_states, uncond_embeddings_class_lables, uncond_boolean_prompt_mask = model.encode_text(
            [""], negative=True, save_compute=False)

    timesteps = model.model.scheduler.timesteps.to(model.device)
    variance_noise_shape = model.get_noise_shape(x0, num_inference_steps)

    if type(etas) in [int, float]:
        etas = [etas]*model.model.scheduler.num_inference_steps
    xts = model.sample_xts_from_x0(x0, num_inference_steps=num_inference_steps)
    zs = torch.zeros(size=variance_noise_shape, device=model.device)
    extra_info = [None] * len(zs)

    if timesteps[0].dtype == torch.int64:
        t_to_idx = {int(v): k for k, v in enumerate(timesteps)}
    elif timesteps[0].dtype == torch.float32:
        t_to_idx = {float(v): k for k, v in enumerate(timesteps)}
    xt = x0
    op = tqdm(timesteps, desc="Inverting")
    model.setup_extra_inputs(xt, init_timestep=timesteps[0], audio_end_in_s=duration,
                             save_compute=save_compute and prompts[0] != "")
    app_op = progress.tqdm(timesteps, desc="Inverting")
    for t, _ in zip(op, app_op):
        idx = num_inference_steps - t_to_idx[int(t) if timesteps[0].dtype == torch.int64 else float(t)] - 1

        # 1. predict noise residual
        xt = xts[idx+1][None]
        xt_inp = model.model.scheduler.scale_model_input(xt, t)

        with torch.no_grad():
            if save_compute and prompts[0] != "":
                comb_out, _, _ = model.unet_forward(
                    xt_inp.expand(2, -1, -1, -1) if hasattr(model.model, 'unet') else xt_inp.expand(2, -1, -1),
                    timestep=t,
                    encoder_hidden_states=torch.cat([uncond_embeddings_hidden_states, text_embeddings_hidden_states
                                                     ], dim=0)
                    if uncond_embeddings_hidden_states is not None else None,
                    class_labels=torch.cat([uncond_embeddings_class_lables, text_embeddings_class_labels], dim=0)
                    if uncond_embeddings_class_lables is not None else None,
                    encoder_attention_mask=torch.cat([uncond_boolean_prompt_mask, text_embeddings_boolean_prompt_mask
                                                      ], dim=0)
                    if uncond_boolean_prompt_mask is not None else None,
                )
                out, cond_out = comb_out.sample.chunk(2, dim=0)
            else:
                out = model.unet_forward(xt_inp, timestep=t,
                                         encoder_hidden_states=uncond_embeddings_hidden_states,
                                         class_labels=uncond_embeddings_class_lables,
                                         encoder_attention_mask=uncond_boolean_prompt_mask)[0].sample
                if len(prompts) > 1 or prompts[0] != "":
                    cond_out = model.unet_forward(
                        xt_inp,
                        timestep=t,
                        encoder_hidden_states=text_embeddings_hidden_states,
                        class_labels=text_embeddings_class_labels,
                        encoder_attention_mask=text_embeddings_boolean_prompt_mask)[0].sample

        if len(prompts) > 1 or prompts[0] != "":
            # # classifier free guidance
            noise_pred = out + (cfg_scales[0] * (cond_out - out)).sum(axis=0).unsqueeze(0)
        else:
            noise_pred = out

        # xtm1 =  xts[idx+1][None]
        xtm1 = xts[idx][None]
        z, xtm1, extra = model.get_zs_from_xts(xt, xtm1, noise_pred, t,
                                               eta=etas[idx], numerical_fix=numerical_fix,
                                               first_order=first_order)
        zs[idx] = z
        # print(f"Fix Xt-1 distance -  NORM:{torch.norm(xts[idx] - xtm1):.4g}, MSE:{((xts[idx] - xtm1)**2).mean():.4g}")
        xts[idx] = xtm1
        extra_info[idx] = extra

    if zs is not None:
        # zs[-1] = torch.zeros_like(zs[-1])
        zs[0] = torch.zeros_like(zs[0])
        # zs_cycle[0] = torch.zeros_like(zs[0])

    del app_op.iterables[0]
    return xt, zs, xts, extra_info


def inversion_reverse_process(model: PipelineWrapper,
                              xT: torch.Tensor,
                              tstart: torch.Tensor,
                              etas: float = 0,
                              prompts: List[str] = [""],
                              neg_prompts: List[str] = [""],
                              cfg_scales: Optional[List[float]] = None,
                              zs: Optional[List[torch.Tensor]] = None,
                              duration: Optional[float] = None,
                              first_order: bool = False,
                              extra_info: Optional[List] = None,
                              save_compute: bool = True,
                              progress=gr.Progress()) -> Tuple[torch.Tensor, torch.Tensor]:

    text_embeddings_hidden_states, text_embeddings_class_labels, \
        text_embeddings_boolean_prompt_mask = model.encode_text(prompts)
    uncond_embeddings_hidden_states, uncond_embeddings_class_lables, \
        uncond_boolean_prompt_mask = model.encode_text(neg_prompts,
                                                       negative=True,
                                                       save_compute=save_compute,
                                                       cond_length=text_embeddings_class_labels.shape[1]
                                                       if text_embeddings_class_labels is not None else None)

    xt = xT[tstart.max()].unsqueeze(0)

    if etas is None:
        etas = 0
    if type(etas) in [int, float]:
        etas = [etas]*model.model.scheduler.num_inference_steps
    assert len(etas) == model.model.scheduler.num_inference_steps
    timesteps = model.model.scheduler.timesteps.to(model.device)

    op = tqdm(timesteps[-zs.shape[0]:], desc="Editing")
    if timesteps[0].dtype == torch.int64:
        t_to_idx = {int(v): k for k, v in enumerate(timesteps[-zs.shape[0]:])}
    elif timesteps[0].dtype == torch.float32:
        t_to_idx = {float(v): k for k, v in enumerate(timesteps[-zs.shape[0]:])}
    model.setup_extra_inputs(xt, extra_info=extra_info, init_timestep=timesteps[-zs.shape[0]],
                             audio_end_in_s=duration, save_compute=save_compute)
    app_op = progress.tqdm(timesteps[-zs.shape[0]:], desc="Editing")
    for it, (t, _) in enumerate(zip(op, app_op)):
        idx = model.model.scheduler.num_inference_steps - t_to_idx[
            int(t) if timesteps[0].dtype == torch.int64 else float(t)] - \
                (model.model.scheduler.num_inference_steps - zs.shape[0] + 1)

        xt_inp = model.model.scheduler.scale_model_input(xt, t)

        # # Unconditional embedding
        with torch.no_grad():
            # print(f'xt_inp.shape: {xt_inp.shape}')
            # print(f't.shape: {t.shape}')
            # print(f'uncond_embeddings_hidden_states.shape: {uncond_embeddings_hidden_states.shape}')
            # print(f'uncond_embeddings_class_lables.shape: {uncond_embeddings_class_lables.shape}')
            # print(f'uncond_boolean_prompt_mask.shape: {uncond_boolean_prompt_mask.shape}')
            # print(f'text_embeddings_hidden_states.shape: {text_embeddings_hidden_states.shape}')
            # print(f'text_embeddings_class_labels.shape: {text_embeddings_class_labels.shape}')
            # print(f'text_embeddings_boolean_prompt_mask.shape: {text_embeddings_boolean_prompt_mask.shape}')

            if save_compute:
                comb_out, _, _ = model.unet_forward(
                    xt_inp.expand(2, -1, -1, -1) if hasattr(model.model, 'unet') else xt_inp.expand(2, -1, -1),
                    timestep=t,
                    encoder_hidden_states=torch.cat([uncond_embeddings_hidden_states, text_embeddings_hidden_states
                                                     ], dim=0)
                    if uncond_embeddings_hidden_states is not None else None,
                    class_labels=torch.cat([uncond_embeddings_class_lables, text_embeddings_class_labels], dim=0)
                    if uncond_embeddings_class_lables is not None else None,
                    encoder_attention_mask=torch.cat([uncond_boolean_prompt_mask, text_embeddings_boolean_prompt_mask
                                                      ], dim=0)
                    if uncond_boolean_prompt_mask is not None else None,
                )
                uncond_out, cond_out = comb_out.sample.chunk(2, dim=0)
            else:
                uncond_out = model.unet_forward(
                    xt_inp, timestep=t,
                    encoder_hidden_states=uncond_embeddings_hidden_states,
                    class_labels=uncond_embeddings_class_lables,
                    encoder_attention_mask=uncond_boolean_prompt_mask,
                    )[0].sample

                # Conditional embedding
                cond_out = model.unet_forward(
                    xt_inp,
                    timestep=t,
                    encoder_hidden_states=text_embeddings_hidden_states,
                    class_labels=text_embeddings_class_labels,
                    encoder_attention_mask=text_embeddings_boolean_prompt_mask,
                    )[0].sample

        z = zs[idx] if zs is not None else None
        z = z.unsqueeze(0)
        # classifier free guidance
        noise_pred = uncond_out + (cfg_scales[0] * (cond_out - uncond_out)).sum(axis=0).unsqueeze(0)

        # 2. compute less noisy image and set x_t -> x_t-1
        xt = model.reverse_step_with_custom_noise(noise_pred, t, xt, variance_noise=z,
                                                  eta=etas[idx], first_order=first_order)

    del app_op.iterables[0]
    return xt, zs