import os
import tempfile
import random
import string
from ultralyticsplus import YOLO
import streamlit as st
import numpy as np
import pandas as pd
from process import (
    filter_columns,
    extract_text_of_col,
    prepare_cols,
    process_cols,
    finalize_data,
)
from file_utils import (
    get_img,
    save_excel_file,
    concat_csv,
    convert_pdf_to_image,
    filter_color,
    plot,
    delete_file,
)


def process_img(
    img,
    page_enumeration: int = 0,
    filter=False,
    foldername: str = "",
    filename: str = "",
):
    """Detect the tables on one page image, OCR them and save each one.

    For every table box returned by ``PaddleOCR.table_model`` the table is
    cropped, its columns are detected with ``PaddleOCR.column_model`` and
    previewed in the Streamlit page, the text of every column is extracted,
    and the rows are assembled by ``finalize_data``. Each successfully
    extracted table is then passed to ``save_excel_file`` together with its
    index among this page's extracted tables.

    Failures are reported per table with ``st.warning`` and that table is
    skipped; they never abort the remaining tables.

    Args:
        img: Page image, indexed as ``img[y0:y1, x0:x1]`` (assumed to be an
            H x W [x C] array — confirm against ``get_img`` /
            ``convert_pdf_to_image``).
        page_enumeration: Page number forwarded to ``finalize_data`` and
            ``save_excel_file`` (0 when the upload is a single image).
        filter: If true, each column crop goes through ``filter_color``
            before OCR. (Shadows the builtin, but is part of the public API.)
        foldername: Output directory forwarded to ``save_excel_file``.
        filename: Base name forwarded to ``save_excel_file``.
    """
    detections = PaddleOCR.table_model(img, conf=0.75)
    tables = detections[0].boxes.xyxy.cpu().numpy()
    results = []
    for table in tables:
        try:
            sub_img, cols_data = _detect_table_columns(img, table)
        except Exception:
            st.warning("No Detection")
            # Skip this table. Falling through here used to reuse the crop and
            # columns of the *previous* table (or hit a NameError on the first).
            continue

        try:
            st.image(plot(sub_img, cols_data), channels="RGB")
        except Exception:
            # The preview is best-effort; the detected columns are still usable.
            st.warning("No Detection")

        try:
            data = _extract_table_data(sub_img, cols_data, filter, page_enumeration)
        except Exception:
            st.warning("Text Extraction Failed")
            continue
        results.append(data)

    for table_index, data in enumerate(results):
        save_excel_file(table_index, data, foldername, filename, page_enumeration)


def _detect_table_columns(img, table):
    """Crop the ``(x0, y0, x1, y1)`` box ``table`` out of ``img`` and detect its columns.

    Returns:
        ``(sub_img, cols_data)``: the cropped table image and the column
        detections sorted left-to-right by their first coordinate and merged
        by ``filter_columns``.
    """
    x0, y0, x1, y1 = (int(v.item()) for v in table[:4])
    # * crop the table as an image from the original image
    sub_img = img[y0:y1, x0:x1]
    columns_detect = PaddleOCR.column_model(sub_img, conf=0.75)
    cols_data = columns_detect[0].boxes.data.cpu().numpy()

    # * Sort columns according to the x coordinate
    cols_data = np.array(sorted(cols_data, key=lambda x: x[0]), dtype=np.ndarray)

    # * merge the duplicated columns
    return sub_img, filter_columns(cols_data)


def _extract_table_data(sub_img, cols_data, filter, page_enumeration) -> pd.DataFrame:
    """OCR every detected column of ``sub_img`` and assemble the table rows.

    Args:
        sub_img: Cropped table image.
        cols_data: Column detections; items 0 and 2 of each row are used as
            the horizontal bounds of that column's crop.
        filter: Apply ``filter_color`` to each column crop before OCR.
        page_enumeration: Forwarded to ``finalize_data``.

    Returns:
        The table produced by ``finalize_data``.

    Raises:
        ValueError: If no columns were detected.
    """
    # * Create list of cropped images for each column
    col_imgs = [sub_img[:, int(column[0]): int(column[2])] for column in cols_data]
    if not col_imgs:
        # Previously surfaced as a ZeroDivisionError when averaging thresholds.
        raise ValueError("no columns detected in table")

    cols = []
    total_threshold = 0
    for image in col_imgs:
        if filter:
            # * keep only black color in the image
            image = filter_color(image)

        # * extract text of each column and get the length threshold
        res, threshold = extract_text_of_col(image)
        total_threshold += threshold

        # * arrange the rows of each column with respect to row length threshold
        cols.append(prepare_cols(res, threshold * 0.6))

    mean_threshold = total_threshold / len(col_imgs)

    # * append each element in each column to its right place in the dataframe
    data = process_cols(cols, mean_threshold * 0.6)

    # * merge the related rows together
    return finalize_data(data, page_enumeration)


class PaddleOCR:
    """Streamlit entry point that extracts the tables of an uploaded file to CSV.

    Detection uses the two YOLO models below. They are loaded once, when this
    class body executes (i.e. at import time), and ``process_img`` reads them
    through these class attributes.
    """

    # Load Image Detection models
    # Detects table bounding boxes on a page image.
    table_model = YOLO("table.pt")
    # Detects column boxes inside a cropped table image.
    column_model = YOLO("columns.pt")

    def __call__(self, uploaded, filter: bool = False):
        """Extract every table of ``uploaded`` and offer them as one CSV download.

        PDFs are converted to page images with ``convert_pdf_to_image`` and
        processed page by page (pages numbered from 1). Any other upload is
        read with ``get_img`` and processed as a single image. Per-table
        outputs go to a temporary directory under the cwd, are merged into
        ``<stem>_<5 random uppercase letters>.csv`` by ``concat_csv``, offered
        through ``st.download_button`` and then deleted.

        Args:
            uploaded: Uploaded file exposing ``.name`` and ``.read()``
                (presumably a Streamlit ``UploadedFile`` — confirm with caller).
            filter: Forwarded to ``process_img``; applies ``filter_color`` to
                each column crop before OCR.
        """
        stem = uploaded.name.split(".")[0]
        # Decide on the *last* extension: "report.v2.pdf" is a PDF, and a name
        # without a dot no longer raises IndexError (it is treated as an image).
        is_pdf = os.path.splitext(uploaded.name)[1].lower() == ".pdf"

        workdir = tempfile.TemporaryDirectory(dir=os.getcwd())
        concat_failed = False
        # `with` guarantees the temporary directory is removed even if PDF
        # conversion or model inference raises.
        with workdir:
            if is_pdf:
                pdf_pages = convert_pdf_to_image(uploaded.read())
                for page_enumeration, page in enumerate(pdf_pages, start=1):
                    process_img(
                        np.asarray(page),
                        page_enumeration,
                        filter=filter,
                        foldername=workdir.name,
                        filename=stem,
                    )
            else:
                process_img(
                    get_img(uploaded),
                    filter=filter,
                    foldername=workdir.name,
                    filename=stem,
                )

            # * concatenate all csv files if many. The random suffix presumably
            # avoids name collisions between runs sharing the cwd.
            extra = "".join(random.choices(string.ascii_uppercase, k=5))
            csv_name = f"{stem}_{extra}.csv"
            try:
                # NOTE: concat_csv receives the TemporaryDirectory object itself
                # (not its path), exactly as before.
                concat_csv(workdir, csv_name)
            except Exception:
                concat_failed = True
                st.warning("No results found")

        if os.path.exists(csv_name):
            with open(csv_name, "rb") as fp:
                st.download_button(
                    label="Download CSV file",
                    data=fp,
                    file_name=csv_name,
                    mime="text/csv",
                )
            delete_file(csv_name)
        elif not concat_failed:
            # Warn only once if concatenation already reported the failure.
            st.warning("No results found")