import gradio as gr
import tensorflow as tf
import numpy as np
import cv2
from PIL import Image
from huggingface_hub import from_pretrained_keras

def resize_image(img_in, input_height, input_width):
    """Return *img_in* resized to ``input_height`` x ``input_width`` pixels.

    Nearest-neighbour interpolation is used, so no new pixel values are
    introduced (important for class/label images).
    """
    # cv2.resize takes the destination size as (width, height).
    target_size = (input_width, input_height)
    return cv2.resize(img_in, target_size, interpolation=cv2.INTER_NEAREST)

def otsu_copy_binary(img):
    """Binarize the first channel of *img* with Otsu's threshold.

    Returns a float64 array of shape (H, W, 3) whose three channels all hold
    the thresholded first channel, with values 0 or 255.
    """
    # NOTE(review): OpenCV's THRESH_OTSU mode expects an 8-bit single-channel
    # input, so img is presumably uint8 here — confirm for new callers.
    _, binary = cv2.threshold(img[:, :, 0], 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    return np.repeat(binary[:, :, np.newaxis], 3, axis=2).astype(np.float64)

def visualize_model_output(prediction, img, model_name):
    """Turn a raw model prediction into an image for display.

    Args:
        prediction: array indexed as ``prediction[:, :, c]``; channel 0 holds
            the per-pixel class ids. The callers in this file pass a
            three-channel array whose channels are identical.
        img: the original input image, used as the background of the overlay
            (not used for the binarization model).
        model_name: Hugging Face id of the model that produced *prediction*.

    Returns:
        For ``SBB/eynollah-binarization``: ``(1 - prediction) * 255``, i.e. the
        two class ids inverted and mapped to 0/255.
        For every other model: a colour map of the class ids blended onto
        *img* (resized to the prediction's height/width) via
        ``cv2.addWeighted`` with weights 0.5 (image) and 0.1 (colours), both
        inputs cast to int32.

    Raises:
        KeyError: if a class id has no entry in the colour table (ids 0-15).
    """
    if model_name == "SBB/eynollah-binarization":
        # Written as ``1 - prediction`` instead of ``prediction * -1 + 1``:
        # multiplying an unsigned integer array (do_prediction passes uint8)
        # by the Python int -1 raises OverflowError under NumPy 2's scalar
        # promotion rules. ``1 - x`` stays in range for 0/1 class ids.
        return (1 - prediction) * 255

    # NOTE(review): ids 4/8 and 6/9 share the same colour.
    rgb_colors = {'0' : [255, 255, 255],
                  '1' : [255, 0, 0],
                  '2' : [255, 125, 0],
                  '3' : [255, 0, 125],
                  '4' : [125, 125, 125],
                  '5' : [125, 125, 0],
                  '6' : [0, 125, 255],
                  '7' : [0, 125, 0],
                  '8' : [125, 125, 125],
                  '9' : [0, 125, 255],
                  '10' : [125, 0, 125],
                  '11' : [0, 255, 0],
                  '12' : [0, 0, 255],
                  '13' : [0, 255, 255],
                  '14' : [255, 125, 125],
                  '15' : [255, 0, 255]}

    class_map = prediction[:, :, 0]
    output = np.zeros(prediction.shape)
    for class_id in np.unique(class_map):
        color = rgb_colors[str(int(class_id))]
        # Compute the class mask once and reuse it for all three channels.
        mask = class_map == class_id
        for channel in range(3):
            output[:, :, channel][mask] = color[channel]

    img = resize_image(img, output.shape[0], output.shape[1])

    return cv2.addWeighted(img.astype(np.int32), 0.5, output.astype(np.int32), 0.1, 0)

def return_num_columns(img):
    """Estimate how many text columns the page in *img* has.

    Loads the ``SBB/eynollah-column-classifier`` model (on every call), feeds
    it a 448x448, three-channel copy of the grayscale image scaled to [0, 1],
    and returns the arg-max class index plus one.
    """
    classifier = from_pretrained_keras("SBB/eynollah-column-classifier")

    # NOTE(review): COLOR_BGR2GRAY assumes BGR channel order; if the caller
    # passes RGB the red/blue weights are swapped — confirm this matches how
    # the classifier was trained.
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) / 255.0
    gray = cv2.resize(gray, (448, 448), interpolation=cv2.INTER_NEAREST)

    # Shape (1, 448, 448, 3): one sample, gray value replicated to 3 channels.
    batch = np.repeat(gray[np.newaxis, :, :, np.newaxis], 3, axis=3)

    scores = classifier.predict(batch, verbose=0)
    return np.argmax(scores[0]) + 1

def _scaled_width_full(num_col, width_early):
    """Return the target width for the full-resolution scaling rule.

    For 1-5 columns, *width_early* is kept while it lies inside a
    per-column-count window ``[lower, upper)``; outside that window a fixed
    replacement width is used. For 6 columns only a lower bound applies. Any
    other column count keeps *width_early*.
    """
    # num_col -> (lower bound, upper bound, replacement width)
    windows = {
        1: (1100, 2500, 2000),
        2: (2000, 3500, 2400),
        3: (2000, 4000, 3000),
        4: (2500, 5000, 4000),
        5: (3700, 7000, 5000),
    }
    if num_col in windows:
        lower, upper, replacement = windows[num_col]
        if width_early < lower or width_early >= upper:
            return replacement
        return width_early
    if num_col == 6 and width_early < 4500:
        return 6500  # 5400
    return width_early


def _scaled_width_light(num_col):
    """Return the fixed target width used by the light models for *num_col*."""
    return {1: 1000, 2: 1500, 3: 2000, 4: 2500, 5: 3000}.get(num_col, 4000)


def return_scaled_image(img, num_col, width_early, model_name):
    """Resize *img* to the working resolution expected by *model_name*.

    The target width depends on the detected column count; the height is
    derived from it so the aspect ratio of *img* is preserved.

    * ``SBB/eynollah-main-regions`` and ``SBB/eynollah-textline_light`` use a
      fixed width per column count (1000 ... 4000 px), and the resized image
      is binarized with ``otsu_copy_binary``.
    * Every other model uses ``_scaled_width_full``: *width_early* is kept if
      it lies in the window for the column count, otherwise a fixed width is
      used.

    Args:
        img: image array whose first two axes are height and width.
        num_col: column count (e.g. from ``return_num_columns``).
        width_early: width kept when it is acceptable; do_prediction passes
            ``img.shape[1]``.
        model_name: Hugging Face id of the model the image is prepared for.

    Returns:
        The resized image. For the light models this is the three-channel
        float64 0/255 output of ``otsu_copy_binary``.
    """
    # Explicit membership test. A chain like ``model_name == "a" or "b"`` is
    # always true (a non-empty string literal is truthy), which would make the
    # light branch unreachable.
    is_light_model = model_name in ("SBB/eynollah-main-regions", "SBB/eynollah-textline_light")

    if is_light_model:
        img_w_new = _scaled_width_light(num_col)
    else:
        img_w_new = _scaled_width_full(num_col, width_early)
    img_h_new = int(img.shape[0] / float(img.shape[1]) * img_w_new)

    img_new = resize_image(img, img_h_new, img_w_new)
    if is_light_model:
        img_new = otsu_copy_binary(img_new)
    return img_new
        
def do_prediction(model_name, img):
    img_org = np.copy(img)
    model = from_pretrained_keras(model_name)

    match model_name:
        # numerical output
        case "SBB/eynollah-column-classifier": 

            img_1ch = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

            img_1ch = img_1ch / 255.0
            img_1ch = cv2.resize(img_1ch, (448, 448), interpolation=cv2.INTER_NEAREST)
            img_in = np.zeros((1, img_1ch.shape[0], img_1ch.shape[1], 3))
            img_in[0, :, :, 0] = img_1ch[:, :]
            img_in[0, :, :, 1] = img_1ch[:, :]
            img_in[0, :, :, 2] = img_1ch[:, :]
                      
            label_p_pred = model.predict(img_in, verbose=0)
            num_col = np.argmax(label_p_pred[0]) + 1
            return "Found {} columns".format(num_col), None

        case "SBB/eynollah-page-extraction":
            img_height_model = model.layers[len(model.layers) - 1].output_shape[1]
            img_width_model = model.layers[len(model.layers) - 1].output_shape[2]
    
            img_h_page = img.shape[0]
            img_w_page = img.shape[1]

            img = img / float(255.0)
            img = resize_image(img, img_height_model, img_width_model)

            label_p_pred = model.predict(img.reshape(1, img.shape[0], img.shape[1], img.shape[2]),
                                         verbose=0)

            seg = np.argmax(label_p_pred, axis=3)[0]
            seg_color = np.repeat(seg[:, :, np.newaxis], 3, axis=2)
            prediction_true = resize_image(seg_color, img_h_page, img_w_page)
            prediction_true = prediction_true.astype(np.uint8)
            
            imgray = cv2.cvtColor(prediction_true, cv2.COLOR_BGR2GRAY)
            _, thresh = cv2.threshold(imgray, 0, 255, 0)
            #thresh = cv2.dilate(thresh, KERNEL, iterations=3)
            contours, _ = cv2.findContours(thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)

            if len(contours)>0:
                cnt_size = np.array([cv2.contourArea(contours[j]) for j in range(len(contours))])
                cnt = contours[np.argmax(cnt_size)]
                x, y, w, h = cv2.boundingRect(cnt)
                if x <= 30:
                    w += x
                    x = 0
                if (img_org.shape[1] - (x + w)) <= 30:
                    w = w + (img_org.shape[1] - (x + w))
                if y <= 30:
                    h = h + y
                    y = 0
                if (img_org.shape[0] - (y + h)) <= 30:
                    h = h + (img_org.shape[0] - (y + h))

                box = [x, y, w, h]


                img_border = np.zeros((prediction_true.shape[0],prediction_true.shape[1]))
                img_border[y:y+h, x:x+w] = 1
                img_border = np.repeat(img_border[:, :, np.newaxis], 3, axis=2)
            else:
                img_border = np.zeros((prediction_true.shape[0],prediction_true.shape[1]))
                img_border[:, :] = 1
                img_border = np.repeat(img_border[:, :, np.newaxis], 3, axis=2)
                
            return "No numerical output", visualize_model_output(img_border,img_org, model_name)
                
                            
        # bitmap output
        case "SBB/eynollah-binarization" | "SBB/eynollah-textline" | "SBB/eynollah-textline_light" | "SBB/eynollah-enhancement" | "SBB/eynollah-tables" | "SBB/eynollah-main-regions" | "SBB/eynollah-main-regions-aug-rotation" | "SBB/eynollah-main-regions-aug-scaling" | "SBB/eynollah-main-regions-ensembled" | "SBB/eynollah-full-regions-1column" | "SBB/eynollah-full-regions-3pluscolumn": 
            
            img_height_model=model.layers[len(model.layers)-1].output_shape[1]
            img_width_model=model.layers[len(model.layers)-1].output_shape[2]
            n_classes=model.layers[len(model.layers)-1].output_shape[3]



            img_org = np.copy(img)
            img_height_h = img_org.shape[0]
            img_width_h = img_org.shape[1]
    
            num_col_classifier = return_num_columns(img)
            width_early = img.shape[1]
    
            
            img = return_scaled_image(img, num_col_classifier, width_early, model_name)


            

            if img.shape[0] < img_height_model:
                img = resize_image(img, img_height_model, img.shape[1])

            if img.shape[1] < img_width_model:
                img = resize_image(img, img.shape[0], img_width_model)


            

            marginal_of_patch_percent = 0.1
            margin = int(marginal_of_patch_percent * img_height_model)
            width_mid = img_width_model - 2 * margin
            height_mid = img_height_model - 2 * margin
            img = img / float(255.0)
            img = img.astype(np.float16)
            img_h = img.shape[0]
            img_w = img.shape[1]
            prediction_true = np.zeros((img_h, img_w, 3))
            mask_true = np.zeros((img_h, img_w))
            nxf = img_w / float(width_mid)
            nyf = img_h / float(height_mid)
            nxf = int(nxf) + 1 if nxf > int(nxf) else int(nxf)
            nyf = int(nyf) + 1 if nyf > int(nyf) else int(nyf)

            for i in range(nxf):
                for j in range(nyf):
                    if i == 0:
                        index_x_d = i * width_mid
                        index_x_u = index_x_d + img_width_model
                    else:
                        index_x_d = i * width_mid
                        index_x_u = index_x_d + img_width_model
                    if j == 0:
                        index_y_d = j * height_mid
                        index_y_u = index_y_d + img_height_model
                    else:
                        index_y_d = j * height_mid
                        index_y_u = index_y_d + img_height_model
                    if index_x_u > img_w:
                        index_x_u = img_w
                        index_x_d = img_w - img_width_model
                    if index_y_u > img_h:
                        index_y_u = img_h
                        index_y_d = img_h - img_height_model

                    img_patch = img[index_y_d:index_y_u, index_x_d:index_x_u, :]
                    label_p_pred = model.predict(img_patch.reshape(1, img_patch.shape[0], img_patch.shape[1], img_patch.shape[2]),
                                                 verbose=0)

                    if model_name == "SBB/eynollah-enhancement":
                        seg_color = label_p_pred[0, :, :, :]
                        seg_color = seg_color * 255
                    else:     
                        seg = np.argmax(label_p_pred, axis=3)[0]
                        seg_color = np.repeat(seg[:, :, np.newaxis], 3, axis=2)

                    if i == 0 and j == 0:
                        seg_color = seg_color[0 : seg_color.shape[0] - margin, 0 : seg_color.shape[1] - margin, :]
                        #seg = seg[0 : seg.shape[0] - margin, 0 : seg.shape[1] - margin]
                        #mask_true[index_y_d + 0 : index_y_u - margin, index_x_d + 0 : index_x_u - margin] = seg
                        prediction_true[index_y_d + 0 : index_y_u - margin, index_x_d + 0 : index_x_u - margin, :] = seg_color
                    elif i == nxf - 1 and j == nyf - 1:
                        seg_color = seg_color[margin : seg_color.shape[0] - 0, margin : seg_color.shape[1] - 0, :]
                        #seg = seg[margin : seg.shape[0] - 0, margin : seg.shape[1] - 0]
                        #mask_true[index_y_d + margin : index_y_u - 0, index_x_d + margin : index_x_u - 0] = seg
                        prediction_true[index_y_d + margin : index_y_u - 0, index_x_d + margin : index_x_u - 0, :] = seg_color
                    elif i == 0 and j == nyf - 1:
                        seg_color = seg_color[margin : seg_color.shape[0] - 0, 0 : seg_color.shape[1] - margin, :]
                        #seg = seg[margin : seg.shape[0] - 0, 0 : seg.shape[1] - margin]
                        #mask_true[index_y_d + margin : index_y_u - 0, index_x_d + 0 : index_x_u - margin] = seg
                        prediction_true[index_y_d + margin : index_y_u - 0, index_x_d + 0 : index_x_u - margin, :] = seg_color
                    elif i == nxf - 1 and j == 0:
                        seg_color = seg_color[0 : seg_color.shape[0] - margin, margin : seg_color.shape[1] - 0, :]
                        #seg = seg[0 : seg.shape[0] - margin, margin : seg.shape[1] - 0]
                        #mask_true[index_y_d + 0 : index_y_u - margin, index_x_d + margin : index_x_u - 0] = seg
                        prediction_true[index_y_d + 0 : index_y_u - margin, index_x_d + margin : index_x_u - 0, :] = seg_color
                    elif i == 0 and j != 0 and j != nyf - 1:
                        seg_color = seg_color[margin : seg_color.shape[0] - margin, 0 : seg_color.shape[1] - margin, :]
                        #seg = seg[margin : seg.shape[0] - margin, 0 : seg.shape[1] - margin]
                        #mask_true[index_y_d + margin : index_y_u - margin, index_x_d + 0 : index_x_u - margin] = seg
                        prediction_true[index_y_d + margin : index_y_u - margin, index_x_d + 0 : index_x_u - margin, :] = seg_color
                    elif i == nxf - 1 and j != 0 and j != nyf - 1:
                        seg_color = seg_color[margin : seg_color.shape[0] - margin, margin : seg_color.shape[1] - 0, :]
                        #seg = seg[margin : seg.shape[0] - margin, margin : seg.shape[1] - 0]
                        #mask_true[index_y_d + margin : index_y_u - margin, index_x_d + margin : index_x_u - 0] = seg
                        prediction_true[index_y_d + margin : index_y_u - margin, index_x_d + margin : index_x_u - 0, :] = seg_color
                    elif i != 0 and i != nxf - 1 and j == 0:
                        seg_color = seg_color[0 : seg_color.shape[0] - margin, margin : seg_color.shape[1] - margin, :]
                        #seg = seg[0 : seg.shape[0] - margin, margin : seg.shape[1] - margin]
                        #mask_true[index_y_d + 0 : index_y_u - margin, index_x_d + margin : index_x_u - margin] = seg
                        prediction_true[index_y_d + 0 : index_y_u - margin, index_x_d + margin : index_x_u - margin, :] = seg_color
                    elif i != 0 and i != nxf - 1 and j == nyf - 1:
                        seg_color = seg_color[margin : seg_color.shape[0] - 0, margin : seg_color.shape[1] - margin, :]
                        #seg = seg[margin : seg.shape[0] - 0, margin : seg.shape[1] - margin]
                        #mask_true[index_y_d + margin : index_y_u - 0, index_x_d + margin : index_x_u - margin] = seg
                        prediction_true[index_y_d + margin : index_y_u - 0, index_x_d + margin : index_x_u - margin, :] = seg_color
                    else:
                        seg_color = seg_color[margin : seg_color.shape[0] - margin, margin : seg_color.shape[1] - margin, :]
                        #seg = seg[margin : seg.shape[0] - margin, margin : seg.shape[1] - margin]
                        #mask_true[index_y_d + margin : index_y_u - margin, index_x_d + margin : index_x_u - margin] = seg
                        prediction_true[index_y_d + margin : index_y_u - margin, index_x_d + margin : index_x_u - margin, :] = seg_color

            if model_name == "SBB/eynollah-enhancement":
                prediction_true = prediction_true.astype(int)
                return "No numerical output", prediction_true
            else:
                prediction_true = prediction_true.astype(np.uint8)
                return "No numerical output", visualize_model_output(prediction_true,img_org, model_name)

            
        
        # catch-all (we should not reach this)
        case _:
            return None, None

# Heading and HTML introduction shown at the top of the Gradio interface.
title = "Welcome to the Eynollah Demo page! 👁️"
description = """
 <div class="row" style="display: flex">
  <div class="column" style="flex: 50%; font-size: 17px">
        This Space demonstrates the functionality of various Eynollah models developed at <a rel="nofollow" href="https://huggingface.co/SBB">SBB</a>.
        <br><br>
        The Eynollah suite introduces an <u>end-to-end pipeline</u> to extract layout, text lines and reading order for historic documents, where the output can be used as an input for OCR engines.
        Please keep in mind that with this demo you can just use <u>one of the 13 sub-modules</u> of the whole Eynollah system <u>at a time</u>.
  </div>
  <div class="column" style="flex: 5%; font-size: 17px"></div>
  <div class="column" style="flex: 45%; font-size: 17px">
    <strong style="font-size: 19px">Resources for more information:</strong>
        <ul>
            <li>The GitHub Repo can be found <a rel="nofollow" href="https://github.com/qurator-spk/eynollah">here</a></li>
            <li>Associated Paper: <a rel="nofollow" href="https://doi.org/10.1145/3604951.3605513">Document Layout Analysis with Deep Learning and Heuristics</a></li>
            <li>The full Eynollah pipeline can be viewed <a rel="nofollow" href="https://huggingface.co/spaces/SBB/eynollah-demo/blob/main/eynollah-flow.png">here</a></li>
        </ul>
  </div>
</div> 
"""
# Model ids offered in the dropdown; do_prediction loads the chosen one with
# from_pretrained_keras.
MODEL_CHOICES = [
    "SBB/eynollah-binarization",
    "SBB/eynollah-enhancement",
    "SBB/eynollah-page-extraction",
    "SBB/eynollah-column-classifier",
    "SBB/eynollah-tables",
    "SBB/eynollah-textline",
    "SBB/eynollah-textline_light",
    "SBB/eynollah-main-regions",
    "SBB/eynollah-main-regions-aug-rotation",
    "SBB/eynollah-main-regions-aug-scaling",
    "SBB/eynollah-main-regions-ensembled",
    "SBB/eynollah-full-regions-1column",
    "SBB/eynollah-full-regions-3pluscolumn",
]

# Inputs: (model id, image); outputs: (text result, image result), matching
# the (text, image) pair returned by do_prediction.
iface = gr.Interface(
    fn=do_prediction,
    inputs=[
        gr.Dropdown(MODEL_CHOICES, label="Select one model of the Eynollah suite 👇", info=""),
        gr.Image(),
    ],
    outputs=[
        gr.Textbox(label="Output of model (numerical or bitmap) ⬇️"),
        gr.Image(),
    ],
    title=title,
    description=description,
    # examples=[['example-1.jpg']]
)
iface.launch()