import os

import gradio as gr

import sys
sys.path.insert(0, 'U-2-Net')

from skimage import io, transform
import torch
import torchvision
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms#, utils
# import torch.optim as optim

import numpy as np
from PIL import Image
import glob

from data_loader import RescaleT
from data_loader import ToTensor
from data_loader import ToTensorLab
from data_loader import SalObjDataset

from model import U2NET # full size version 173.6 MB
from model import U2NETP # small version u2net 4.7 MB

from modnet import ModNet
import huggingface_hub

# normalize the predicted SOD probability map
def normPRED(d):
    """Min-max normalize a tensor to the range [0, 1].

    Args:
        d: Tensor of predicted values (any shape, non-empty).

    Returns:
        A tensor of the same shape where the minimum maps to 0 and the
        maximum maps to 1. If every element is equal (zero range), the
        result is all zeros instead of the NaNs a 0/0 division would produce.
    """
    ma = torch.max(d)
    mi = torch.min(d)

    value_range = ma - mi
    # A constant map has zero range, and dividing by it gives 0/0 = NaN.
    # Substituting 1 leaves (d - mi), which is all zeros in that case.
    value_range = torch.where(value_range == 0, torch.ones_like(value_range), value_range)

    return (d - mi) / value_range
def save_output(image_name, pred, d_dir):
    """Save a prediction map as a PNG resized to match the source image.

    Args:
        image_name: Path to the source image. Its height/width set the output
            size, and its file stem (name without the final extension) names
            the output file.
        pred: Prediction tensor. It is squeezed, scaled by 255 and converted
            to an RGB image. Presumably its values lie in [0, 1] (e.g. the
            output of ``normPRED``) -- confirm at call sites.
        d_dir: Existing directory to write the PNG into.

    Returns:
        The path of the written file, ``os.path.join(d_dir, stem + '.png')``.
    """
    predict_np = pred.squeeze().detach().cpu().numpy()

    im = Image.fromarray(predict_np * 255).convert('RGB')
    # The source image is read only for its size, so the prediction can be
    # resized back to the original resolution.
    image = io.imread(image_name)
    imo = im.resize((image.shape[1], image.shape[0]), resample=Image.BILINEAR)

    # splitext drops only the final extension. This matches the old "join
    # every dot-separated part except the last" logic, but also handles names
    # without an extension (the old code raised IndexError for those).
    imidx = os.path.splitext(os.path.basename(image_name))[0]

    out_path = os.path.join(d_dir, imidx + '.png')
    imo.save(out_path)
    return out_path



# --------- 0. fetch the MODNet ONNX model used for pre-segmentation ---------
modnet_path = huggingface_hub.hf_hub_download('hylee/apdrawing_model',
                                              'modnet.onnx',
                                              force_filename='modnet.onnx')
modnet = ModNet(modnet_path)

# --------- 1. get image path and name ---------
# NOTE(review): model_name and image_dir are not used in the visible code;
# they are kept in case anything else imports them.
model_name = 'u2net_portrait'  # u2netp
image_dir = 'portrait_im'

# Output directory for predictions (relative to the current working dir).
prediction_dir = 'portrait_results'
# makedirs(exist_ok=True) avoids the exists()/mkdir() check-then-create race
# and does not fail if the directory is already there.
os.makedirs(prediction_dir, exist_ok=True)

# Absolute path to the U^2-Net portrait checkpoint *file* (despite the name),
# resolved relative to this script rather than the working directory.
model_dir = os.path.join(os.path.abspath(os.path.dirname(__file__)), 'U-2-Net/saved_models/u2net_portrait/u2net_portrait.pth')


# --------- 3. model define ---------

print("...load U2NET---173.6 MB")
net = U2NET(3, 1)

# map_location='cpu' loads the weights onto the CPU regardless of the device
# they were saved from; the model is not moved to a GPU afterwards.
net.load_state_dict(torch.load(model_dir, map_location='cpu'))
# Switch the network to evaluation mode for inference.
net.eval()


def process(im):
    """Gradio handler: segment the upload with MODNet, then run U^2-Net on it.

    The uploaded image is passed through ``modnet.segment``. The result is
    saved as a file in the current working directory under the upload's base
    name, and that file is fed through the U^2-Net portrait model. The
    inverted, min-max-normalized prediction is written to ``prediction_dir``
    and returned.

    Args:
        im: Uploaded file object from Gradio's file-type image input. Only its
            ``.name`` (a filesystem path) is used.

    Returns:
        The saved prediction as an in-memory PIL image.
    """
    image = modnet.segment(im.name)
    # NOTE(review): written to the CWD under the upload's base name, so two
    # concurrent uploads with the same file name can overwrite each other.
    im_path = os.path.abspath(os.path.basename(im.name))
    Image.fromarray(np.uint8(image)).save(im_path)

    img_name_list = [im_path]
    print("Number of images: ", len(img_name_list))
    # --------- 2. dataloader ---------
    test_salobj_dataset = SalObjDataset(img_name_list=img_name_list,
                                        lbl_name_list=[],
                                        transform=transforms.Compose([RescaleT(512),
                                                                      ToTensorLab(flag=0)])
                                        )
    # num_workers=0 loads in the calling process. Spawning a worker subprocess
    # on every request only adds start-up cost when there is a single image.
    test_salobj_dataloader = DataLoader(test_salobj_dataset,
                                        batch_size=1,
                                        shuffle=False,
                                        num_workers=0)

    results = []
    # --------- 4. inference for each image ---------
    # This is inference only, so no_grad() skips building the autograd graph
    # and saves memory and time on every request.
    with torch.no_grad():
        for i_test, data_test in enumerate(test_salobj_dataloader):

            print("inferencing:", img_name_list[i_test].split(os.sep)[-1])

            inputs_test = data_test['image'].type(torch.FloatTensor)

            d1, d2, d3, d4, d5, d6, d7 = net(inputs_test)

            # Invert channel 0 of the first output, then min-max normalize it.
            pred = normPRED(1.0 - d1[:, 0, :, :])

            # save results to prediction_dir
            results.append(save_output(img_name_list[i_test], pred, prediction_dir))

            del d1, d2, d3, d4, d5, d6, d7

    print(results)

    # Copy inside a context manager so the file handle is closed right away,
    # instead of staying open behind a lazily loaded image.
    with Image.open(results[0]) as result_img:
        return result_img.copy()
        
# --------- Gradio UI: wire `process` to an image-in / image-out page ---------
# NOTE(review): gr.inputs/gr.outputs, allow_screenshot and the launch()
# keywords below belong to the older (2.x-era) Gradio API. Confirm the pinned
# gradio version before upgrading.
title = "U-2-Net"
description = "Gradio demo for U-2-Net, https://github.com/xuebinqin/U-2-Net"
article = ""

# Uploads arrive as a file object (process() reads its .name); the result is
# returned as a PIL image.
input_components = [gr.inputs.Image(type="file", label="Input")]
output_components = [gr.outputs.Image(type="pil", label="Output")]

demo = gr.Interface(
    process,
    input_components,
    output_components,
    title=title,
    description=description,
    article=article,
    examples=[],
    allow_flagging=False,
    allow_screenshot=False,
)
demo.launch(enable_queue=True, cache_examples=True)