import os
import torch
import numpy as np
import argparse
from peft import LoraConfig
from pipeline_dedit_sdxl import DEditSDXLPipeline
from pipeline_dedit_sd import DEditSDPipeline
from utils import load_image, load_mask, load_mask_edit
from utils_mask import process_mask_move_torch, process_mask_remove_torch, mask_union_torch, mask_substract_torch, create_outer_edge_mask_torch
from utils_mask import check_mask_overlap_torch, check_cover_all_torch, visualize_mask_list, get_mask_difference_torch, save_mask_list_to_npys

    
def run_main(
    name="example_tmp",
    name_2=None,
    mask_np_list=None,
    mask_label_list=None,
    image_gt=None,
    dpm="sd",
    resolution=512,
    seed=42,
    # training hyperparameters
    embedding_learning_rate=1e-4,
    max_emb_train_steps=200,
    diffusion_model_learning_rate=5e-5,
    max_diffusion_train_steps=200,
    train_batch_size=1,
    gradient_accumulation_steps=1,
    num_tokens=1,
    # sampling
    load_trained=False,
    num_sampling_steps=50,
    guidance_scale=3,
    strength=0.8,
    # LoRA fine-tuning
    train_full_lora=False,
    lora_rank=4,
    lora_alpha=4,
    # auxiliary prompt templates ("*" is replaced by the learned token string)
    prompt_auxin_list=None,
    prompt_auxin_idx_list=None,
    # mask handling
    load_edited_mask=False,
    load_edited_processed_mask=False,
    edge_thickness=20,
    num_imgs=1,
    active_mask_list=None,
    tgt_index=None,
    # reconstruction mode
    recon=False,
    recon_an_item=False,
    recon_prompt=None,
    # text-guided editing mode
    text=False,
    tgt_prompt=None,
    # image-guided editing mode
    image=False,
    src_index=None,
    tgt_name=None,
    # move / resize mode
    move_resize=False,
    tgt_indices_list=None,
    delta_x_list=None,
    delta_y_list=None,
    priority_list=None,
    force_mask_remain=None,
    resize_list=None,
    # removal mode
    remove=False,
    load_edited_removemask=False,  # accepted for API compatibility; unused below
):
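    """Fine-tune per-segment token embeddings (and optionally the diffusion
    model) on the input image(s), then run the requested DEdit editing mode.

    Modes, selected by boolean flags: recon re-samples the image from the
    learned embeddings; text re-describes the segment at tgt_index with
    tgt_prompt; remove deletes a segment and regenerates the vacated region;
    image transplants a segment from a second image (name_2); move_resize
    shifts and/or rescales segments. Segmentation masks arrive as numpy
    arrays in mask_np_list and must match `resolution`.
    """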

    # Seed all RNGs for reproducible sampling.
    torch.cuda.manual_seed_all(seed)
    torch.manual_seed(seed)

    base_input_folder = "."
    base_output_folder = "."

    input_folder = os.path.join(base_input_folder, name)
    # Convert the numpy segmentation masks to torch tensors.
    mask_list = [torch.from_numpy(mask_np.astype(np.uint8)) for mask_np in mask_np_list]
    
    assert mask_list[0].shape[0] == resolution, "Segmentation masks must match the target resolution {}".format(resolution)

    if image:
        # Image-guided editing also needs the second image and its segmentation.
        input_folder_2 = os.path.join(base_input_folder, name_2)
        mask_list_2, mask_label_list_2 = load_mask(input_folder_2)
        assert mask_list_2[0].shape[0] == resolution, "Segmentation masks must match the target resolution {}".format(resolution)
        try:
            image_gt_2 = load_image(os.path.join(input_folder_2, "img_{}.png".format(resolution)), size = resolution)
        except OSError:  # fall back to .jpg when no .png exists
            image_gt_2 = load_image(os.path.join(input_folder_2, "img_{}.jpg".format(resolution)), size = resolution)
        output_dir = os.path.join(base_output_folder, name + "_" + name_2)
    else:
        output_dir = os.path.join(base_output_folder, name)
    os.makedirs(output_dir, exist_ok = True)

    if dpm == "sd":
        if image:
            pipe = DEditSDPipeline(mask_list, mask_label_list, mask_list_2, mask_label_list_2, resolution = resolution, num_tokens = num_tokens)
        else:
            pipe = DEditSDPipeline(mask_list, mask_label_list, resolution = resolution, num_tokens = num_tokens)
    elif dpm == "sdxl":
        if image:
            pipe = DEditSDXLPipeline(mask_list, mask_label_list, mask_list_2, mask_label_list_2, resolution = resolution, num_tokens = num_tokens)
        else:
            pipe = DEditSDXLPipeline(mask_list, mask_label_list, resolution = resolution, num_tokens = num_tokens)
    else:
        raise NotImplementedError("dpm must be 'sd' or 'sdxl', got {!r}".format(dpm))

    set_string_list = pipe.set_string_list
    if prompt_auxin_list is not None:
        # Wrap selected token strings in auxiliary prompt templates; "*" in the
        # template is replaced by the learned token string for that segment.
        for auxin_idx, auxin_prompt in zip(prompt_auxin_idx_list, prompt_auxin_list):
            set_string_list[auxin_idx] = auxin_prompt.replace("*", set_string_list[auxin_idx])
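        # Illustrative values (assumed, not from the original script): with
        # prompt_auxin_list=["a photo of * in the snow"] and prompt_auxin_idx_list=[0],
        # segment 0's learned token string would replace "*" in the template.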
    print(set_string_list)

    if image:
        set_string_list_2 = pipe.set_string_list_2
        print(set_string_list_2)

    if load_trained:
        # Reload fine-tuned weights saved by a previous run of this script.
        unet_save_path = os.path.join(output_dir, "unet.pt")
        unet_state_dict = torch.load(unet_save_path)
        text_encoder1_save_path = os.path.join(output_dir, "text_encoder1.pt")
        text_encoder1_state_dict = torch.load(text_encoder1_save_path)
        if dpm == "sdxl":
            text_encoder2_save_path = os.path.join(output_dir, "text_encoder2.pt")
            text_encoder2_state_dict = torch.load(text_encoder2_save_path)

        # If the checkpoint was trained with LoRA, the adapter modules must be
        # re-created on the UNet before its state dict can be loaded.
        if 'lora' in ''.join(unet_state_dict.keys()):
            unet_lora_config = LoraConfig(
                r=lora_rank,
                lora_alpha=lora_alpha,
                init_lora_weights="gaussian",
                target_modules=["to_k", "to_q", "to_v", "to_out.0"],
            )
            pipe.unet.add_adapter(unet_lora_config)
        
        pipe.unet.load_state_dict(unet_state_dict)
        pipe.text_encoder.load_state_dict(text_encoder1_state_dict)
        if dpm == "sdxl":
            pipe.text_encoder_2.load_state_dict(text_encoder2_state_dict)
    else:
        if image:
            # Two-image setting: move both mask sets to GPU and train the token
            # embeddings jointly over both images.
            pipe.mask_list = [m.cuda() for m in pipe.mask_list]
            pipe.mask_list_2 = [m.cuda() for m in pipe.mask_list_2]
            pipe.train_emb_2imgs(
                image_gt,
                image_gt_2, 
                set_string_list,
                set_string_list_2,
                gradient_accumulation_steps = gradient_accumulation_steps,
                embedding_learning_rate = embedding_learning_rate,
                max_emb_train_steps = max_emb_train_steps,
                train_batch_size = train_batch_size,
            )
            
            pipe.train_model_2imgs(
                image_gt,
                image_gt_2, 
                set_string_list,
                set_string_list_2,
                gradient_accumulation_steps = gradient_accumulation_steps,
                max_diffusion_train_steps = max_diffusion_train_steps,
                diffusion_model_learning_rate = diffusion_model_learning_rate,
                train_batch_size = train_batch_size,
                train_full_lora = train_full_lora,
                lora_rank = lora_rank, 
                lora_alpha = lora_alpha
            )
            
        else:
            # Single-image setting: move masks to GPU, learn the token
            # embeddings, then fine-tune the diffusion model.
            pipe.mask_list = [m.cuda() for m in pipe.mask_list]
            pipe.train_emb(
                image_gt,
                set_string_list,
                gradient_accumulation_steps = gradient_accumulation_steps,
                embedding_learning_rate = embedding_learning_rate,
                max_emb_train_steps = max_emb_train_steps,
                train_batch_size = train_batch_size,
            )

            pipe.train_model(
                image_gt,
                set_string_list,
                gradient_accumulation_steps = gradient_accumulation_steps,
                max_diffusion_train_steps = max_diffusion_train_steps,
                diffusion_model_learning_rate = diffusion_model_learning_rate,
                train_batch_size = train_batch_size,
                train_full_lora = train_full_lora,
                lora_rank = lora_rank, 
                lora_alpha = lora_alpha
            )

        
        # Persist the fine-tuned weights so later runs can set load_trained=True.
        unet_save_path = os.path.join(output_dir, "unet.pt")
        torch.save(pipe.unet.state_dict(), unet_save_path)
        text_encoder1_save_path = os.path.join(output_dir, "text_encoder1.pt")
        torch.save(pipe.text_encoder.state_dict(), text_encoder1_save_path)
        if dpm == "sdxl":
            text_encoder2_save_path = os.path.join(output_dir, "text_encoder2.pt")
            torch.save(pipe.text_encoder_2.state_dict(), text_encoder2_save_path)
        

    if recon:
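        # Reconstruction: re-sample the full image from the learned per-segment
        # token embeddings, without editing anything.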
        output_dir = os.path.join(output_dir, "recon")
        os.makedirs(output_dir, exist_ok = True)
        if recon_an_item:
            # Collapse everything into a single full-image mask and describe only
            # the target item; "*" in recon_prompt stands for its token string.
            mask_list = [torch.from_numpy(np.ones_like(mask_list[0].numpy()))]
            tgt_string = set_string_list[tgt_index]
            tgt_string = recon_prompt.replace("*", tgt_string)
            set_string_list = [tgt_string]
        print(set_string_list)
        save_path = os.path.join(output_dir, "out_recon.png")
        pipe.inference_with_mask(
            save_path,
            guidance_scale = guidance_scale,
            num_sampling_steps = num_sampling_steps,
            seed = seed,
            num_imgs = num_imgs,
            set_string_list = set_string_list,
            mask_list = mask_list
        )
    
    if text:
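        # Text-guided editing: the segment at tgt_index is re-described by
        # tgt_prompt; only the active region (plus a soft edge band) is
        # regenerated while the rest of the image stays frozen.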
        print("*** Text-guided editing ")
        output_dir = os.path.join(output_dir, "text")
        os.makedirs(output_dir, exist_ok = True)
        save_path = os.path.join(output_dir, "out_text.png")
        set_string_list[tgt_index] = tgt_prompt
        mask_active = torch.zeros_like(mask_list[0])
        mask_active = mask_union_torch(mask_active, mask_list[tgt_index])
        
        if active_mask_list is not None:
            for midx in active_mask_list:
                mask_active = mask_union_torch(mask_active, mask_list[midx])

        if load_edited_mask:
            # A manually edited segmentation replaces the original; any pixels
            # that differ between the two are added to the active region.
            mask_list_edited, _ = load_mask_edit(input_folder)
            mask_diff = get_mask_difference_torch(mask_list_edited, mask_list)
            mask_active = mask_union_torch(mask_active, mask_diff)
            mask_list = mask_list_edited
            save_path = os.path.join(output_dir, "out_textEdited.png")
        
        # mask_hard freezes everything outside the active region; mask_soft is a
        # thin outer band around it that gets blended for a seamless transition.
        mask_hard = mask_substract_torch(torch.ones_like(mask_list[0]), mask_active)
        mask_soft = create_outer_edge_mask_torch(mask_active, edge_thickness = edge_thickness)
        mask_hard = mask_substract_torch(mask_hard, mask_soft)

        pipe.inference_with_mask(
            save_path,
            orig_image = image_gt,
            set_string_list = set_string_list,
            guidance_scale = guidance_scale,
            strength = strength,
            num_imgs = num_imgs,
            mask_hard = mask_hard,
            mask_soft = mask_soft,
            mask_list = mask_list,
            seed = seed,
            num_sampling_steps = num_sampling_steps
        )

    if remove:
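        # Removal: delete the segment at tgt_index, reassign its pixels to the
        # remaining masks, and regenerate the vacated region.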
        output_dir = os.path.join(output_dir, "remove")
        save_path = os.path.join(output_dir, "out_remove.png")
        os.makedirs(output_dir, exist_ok = True)
        mask_active = torch.zeros_like(mask_list[0])
        
        if load_edited_mask:
            mask_list_edited, _ = load_mask_edit(input_folder)
            mask_diff = get_mask_difference_torch(mask_list_edited, mask_list)
            mask_active = mask_union_torch(mask_active, mask_diff)
            mask_list = mask_list_edited
            
        if load_edited_processed_mask:
            # manually edit or draw masks after removing one index, then load
            mask_list_processed, _ = load_mask_edit(output_dir)
            mask_remain = get_mask_difference_torch(mask_list_processed, mask_list)
        else:
            # generate masks after removing one index, using nearest neighbor algorithm
            mask_list_processed, mask_remain = process_mask_remove_torch(mask_list, tgt_index)
            save_mask_list_to_npys(output_dir, mask_list_processed, mask_label_list, name = "mask")
            visualize_mask_list(mask_list_processed, os.path.join(output_dir, "seg_removed.png"))
        # Sanity check: the processed masks must still cover every pixel.
        check_cover_all_torch(*mask_list_processed)
        mask_active = mask_union_torch(mask_active, mask_remain)
        
        if active_mask_list is not None:
            for midx in active_mask_list:
                mask_active = mask_union_torch(mask_active, mask_list[midx])

        # Freeze everything except the vacated region and its soft border band.
        mask_hard = 1 - mask_active
        mask_soft = create_outer_edge_mask_torch(mask_remain, edge_thickness = edge_thickness)
        mask_hard = mask_substract_torch(mask_hard, mask_soft)

        pipe.inference_with_mask(
            save_path, 
            orig_image = image_gt,
            guidance_scale = guidance_scale,
            strength = strength,
            num_imgs = num_imgs,
            mask_hard = mask_hard,
            mask_soft = mask_soft,
            mask_list = mask_list_processed, 
            seed = seed,
            num_sampling_steps = num_sampling_steps
        )

    if image:
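        # Image-guided editing: copy a segment's learned concept from one image
        # into the other by swapping token strings between the two mask lists.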
        output_dir = os.path.join(output_dir, "image")
        save_path = os.path.join(output_dir, "out_image.png")
        os.makedirs(output_dir, exist_ok = True)
        mask_active = torch.zeros_like(mask_list[0])
        
        if None not in (tgt_name, src_index, tgt_index):
            if tgt_name == name:
                set_string_list_tgt = set_string_list
                set_string_list_src = set_string_list_2
                image_tgt = image_gt
                if load_edited_mask:
                    mask_list_edited, _ = load_mask_edit(input_folder)
                    mask_diff = get_mask_difference_torch(mask_list_edited, mask_list)
                    mask_active = mask_union_torch(mask_active, mask_diff)
                    mask_list = mask_list_edited
                    save_path = os.path.join(output_dir, "out_imageEdited.png")
                mask_list_tgt = mask_list
                
            elif tgt_name == name_2:
                set_string_list_tgt = set_string_list_2
                set_string_list_src = set_string_list
                image_tgt = image_gt_2
                if load_edited_mask:
                    mask_list_2_edited, _ = load_mask_edit(input_folder_2)
                    mask_diff = get_mask_difference_torch(mask_list_2_edited, mask_list_2)
                    mask_active = mask_union_torch(mask_active, mask_diff)
                    mask_list_2 = mask_list_2_edited
                    save_path = os.path.join(output_dir, "out_imageEdited.png")
                mask_list_tgt = mask_list_2
            else:
                raise ValueError("tgt_name should be either name or name_2")
                
            # Transplant the source segment's token string into the target slot.
            set_string_list_tgt[tgt_index] = set_string_list_src[src_index]

            mask_active = mask_list_tgt[tgt_index]
            mask_frozen = (1 - mask_active.float()).to(mask_active.device)
            mask_soft = create_outer_edge_mask_torch(mask_active.cpu(), edge_thickness = edge_thickness)
            mask_hard = mask_substract_torch(mask_frozen.cpu(), mask_soft.cpu())

            mask_list_tgt = [m.cuda() for m in mask_list_tgt]

            pipe.inference_with_mask(
                save_path,
                set_string_list = set_string_list_tgt,
                mask_list = mask_list_tgt, 
                guidance_scale = guidance_scale,
                num_sampling_steps = num_sampling_steps,
                mask_hard = mask_hard.cuda(),
                mask_soft = mask_soft.cuda(), 
                num_imgs = num_imgs,
                orig_image = image_tgt,
                strength = strength,
            )

    if move_resize:
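        # Move/resize: shift segments by (delta_x, delta_y) and/or rescale them,
        # then regenerate the moved segments and the regions they vacated.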
        output_dir = os.path.join(output_dir, "move_resize")
        os.makedirs(output_dir, exist_ok = True)
        save_path = os.path.join(output_dir, "out_moveresize.png")
        mask_active = torch.zeros_like(mask_list[0])
        
        if load_edited_mask:
            mask_list_edited, _ = load_mask_edit(input_folder)
            mask_diff = get_mask_difference_torch(mask_list_edited, mask_list)
            mask_active = mask_union_torch(mask_active, mask_diff)
            mask_list = mask_list_edited
            # save_path = os.path.join(output_dir, "out_moveresizeEdited.png")
            
        if load_edited_processed_mask:
            mask_list_processed, _ = load_mask_edit(output_dir)
            mask_remain = get_mask_difference_torch(mask_list_processed, mask_list)
        else:
            mask_list_processed, mask_remain = process_mask_move_torch(
                mask_list,
                tgt_indices_list,
                delta_x_list,
                delta_y_list,
                priority_list,
                force_mask_remain = force_mask_remain,
                resize_list = resize_list
            )
            save_mask_list_to_npys(output_dir, mask_list_processed, mask_label_list, name = "mask")
            visualize_mask_list(mask_list_processed, os.path.join(output_dir, "seg_move_resize.png"))
        active_idxs = tgt_indices_list

        # Regenerate every moved segment plus the region it vacated.
        mask_active = mask_union_torch(mask_active, *[m for midx, m in enumerate(mask_list_processed) if midx in active_idxs])
        mask_active = mask_union_torch(mask_remain, mask_active)
        if active_mask_list is not None:
            for midx in active_mask_list:
                mask_active = mask_union_torch(mask_active, mask_list_processed[midx])

        mask_frozen = (1 - mask_active.float())
        mask_soft = create_outer_edge_mask_torch(mask_active, edge_thickness = edge_thickness)
        mask_hard = mask_substract_torch(mask_frozen, mask_soft)

        # Sanity check: hard and soft masks must not overlap.
        check_mask_overlap_torch(mask_hard, mask_soft)

        pipe.inference_with_mask(
            save_path,
            strength = strength,
            orig_image = image_gt,
            guidance_scale = guidance_scale,
            num_sampling_steps = num_sampling_steps,
            num_imgs = num_imgs,
            mask_hard = mask_hard,
            mask_soft = mask_soft,
            mask_list = mask_list_processed,
            seed = seed
        )
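

# Minimal CLI entry-point sketch (an assumption; the original argparse wiring is
# not shown in this file). It loads the segmentation and ground-truth image the
# same way the image-guided branch does, converts the masks to numpy for
# run_main, and forwards a few common flags. Flag names are illustrative only.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="DEdit mask-based image editing")
    parser.add_argument("--name", default="example_tmp")
    parser.add_argument("--dpm", default="sd", choices=["sd", "sdxl"])
    parser.add_argument("--resolution", type=int, default=512)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--load_trained", action="store_true")
    parser.add_argument("--recon", action="store_true")
    parser.add_argument("--recon_prompt", default=None)
    args = parser.parse_args()

    input_folder = os.path.join(".", args.name)
    # load_mask is assumed to return CPU torch tensors, as the assertions above imply.
    mask_list, mask_label_list = load_mask(input_folder)
    mask_np_list = [m.numpy() for m in mask_list]
    try:
        image_gt = load_image(os.path.join(input_folder, "img_{}.png".format(args.resolution)), size=args.resolution)
    except OSError:  # fall back to .jpg when no .png exists
        image_gt = load_image(os.path.join(input_folder, "img_{}.jpg".format(args.resolution)), size=args.resolution)

    run_main(
        name=args.name,
        mask_np_list=mask_np_list,
        mask_label_list=mask_label_list,
        image_gt=image_gt,
        dpm=args.dpm,
        resolution=args.resolution,
        seed=args.seed,
        load_trained=args.load_trained,
        recon=args.recon,
        recon_prompt=args.recon_prompt,
    )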