# Bank_Statement_Parser / predict.py
# Provenance: uploaded by "binery" ("Upload 8 files", commit 16aad69) —
# the original HuggingFace web-UI header lines are kept here as comments
# so this file is valid Python.
import os
import tempfile
import random
import string
from ultralyticsplus import YOLO
import streamlit as st
import numpy as np
import pandas as pd
from process import (
filter_columns,
extract_text_of_col,
prepare_cols,
process_cols,
finalize_data,
)
from file_utils import (
get_img,
save_excel_file,
concat_csv,
convert_pdf_to_image,
filter_color,
plot,
delete_file,
)
def process_img(
    img,
    page_enumeration: int = 0,
    filter=False,
    foldername: str = "",
    filename: str = "",
):
    """Detect tables in a page image, OCR each table column by column,
    and save one spreadsheet file per detected table.

    Args:
        img: page image as a numpy array (as produced by ``get_img`` /
            ``convert_pdf_to_image`` upstream — exact color order not
            visible here; assumed RGB-compatible since it is shown via
            ``st.image(..., channels="RGB")``).
        page_enumeration: 1-based page number, embedded in output names.
        filter: when True, run ``filter_color`` on each column crop to
            keep only dark pixels before OCR. The name shadows the
            builtin but is part of the caller-visible interface.
        foldername: directory where per-table files are written.
        filename: base name (without extension) for the output files.
    """
    # Detect table bounding boxes on the full page (confidence >= 0.75).
    tables = PaddleOCR.table_model(img, conf=0.75)
    tables = tables[0].boxes.xyxy.cpu().numpy()
    results = []
    for table in tables:
        try:
            # * crop the table as an image from the original image
            sub_img = img[
                int(table[1].item()): int(table[3].item()),
                int(table[0].item()): int(table[2].item()),
            ]
            columns_detect = PaddleOCR.column_model(sub_img, conf=0.75)
            cols_data = columns_detect[0].boxes.data.cpu().numpy()
            # * Sort columns according to the x coordinate
            cols_data = np.array(
                sorted(cols_data, key=lambda x: x[0]), dtype=np.ndarray
            )
            # * merge the duplicated columns
            cols_data = filter_columns(cols_data)
            st.image(plot(sub_img, cols_data), channels="RGB")
        except Exception:
            st.warning("No Detection")
            # BUG FIX: the original fell through here, so the next try
            # block raised NameError on the undefined `cols_data` /
            # `sub_img`, which the bare except silently swallowed.
            continue
        try:
            columns = cols_data[:, 0:4]
            # * Create list of cropped images for each column
            sub_imgs = [
                sub_img[:, int(column[0]): int(column[2])]
                for column in columns
            ]
            if not sub_imgs:
                # BUG FIX: guard against ZeroDivisionError in the
                # threshold average below when no column survived.
                st.warning("Text Extraction Failed")
                continue
            cols = []
            thr = 0
            for image in sub_imgs:
                if filter:
                    # * keep only black color in the image
                    image = filter_color(image)
                # * extract text of each column and get the length threshold
                res, threshold = extract_text_of_col(image)
                thr += threshold
                # * arrange the rows of each column with respect to row length threshold
                cols.append(prepare_cols(res, threshold * 0.6))
            # Average threshold across all columns of this table.
            thr = thr / len(sub_imgs)
            # * append each element in each column to its right place in the dataframe
            data = process_cols(cols, thr * 0.6)
            # * merge the related rows together
            data: pd.DataFrame = finalize_data(data, page_enumeration)
            results.append(data)
        except Exception:
            st.warning("Text Extraction Failed")
            continue
    # Persist one file per extracted table (plain loop instead of the
    # original side-effecting list(map(...))).
    for index, data in enumerate(results):
        save_excel_file(
            index,
            data,
            foldername,
            filename,
            page_enumeration,
        )
class PaddleOCR:
    """End-to-end pipeline: detect tables and columns with YOLO, OCR the
    contents page by page, and offer the concatenated CSV for download
    through the Streamlit UI."""

    # Image-detection models, loaded once at class-definition time and
    # shared by all instances (used as class attributes in process_img).
    table_model = YOLO("table.pt")
    column_model = YOLO("columns.pt")

    def __call__(self, uploaded, filter=False):
        """Process an uploaded file (PDF or image) end to end.

        Args:
            uploaded: Streamlit UploadedFile-like object exposing
                ``.name`` and ``.read()``.
            filter: forwarded to ``process_img``; when True, keep only
                dark pixels in each column crop before OCR.
        """
        foldername = tempfile.TemporaryDirectory(dir=os.getcwd())
        # BUG FIX: the original split(".")[1] raised IndexError for names
        # without a dot and picked the wrong piece for multi-dot names;
        # splitext handles both correctly.
        filename, extension = os.path.splitext(uploaded.name)
        if extension.lower() == ".pdf":
            pdf_pages = convert_pdf_to_image(uploaded.read())
            for page_enumeration, page in enumerate(pdf_pages, start=1):
                process_img(
                    np.asarray(page),
                    page_enumeration,
                    filter=filter,
                    foldername=foldername.name,
                    filename=filename,
                )
        else:
            img = get_img(uploaded)
            process_img(
                img,
                filter=filter,
                foldername=foldername.name,
                filename=filename,
            )
        # * concatenate all csv files if many
        # Random suffix avoids clobbering a previous run's output.
        extra = "".join(random.choices(string.ascii_uppercase, k=5))
        # BUG FIX: the original embedded a literal "(unknown)" placeholder
        # here instead of the computed base filename.
        filename = f"{filename}_{extra}.csv"
        try:
            # NOTE(review): concat_csv receives the TemporaryDirectory
            # object (not .name) as in the original — confirm the helper
            # expects that.
            concat_csv(foldername, filename)
        except Exception:
            st.warning("No results found")
        foldername.cleanup()
        if os.path.exists(filename):
            # BUG FIX: the original opened a file literally named
            # "(unknown)", which never matched the path checked above.
            with open(filename, "rb") as fp:
                st.download_button(
                    label="Download CSV file",
                    data=fp,
                    file_name=filename,
                    mime="text/csv",
                )
            delete_file(filename)
        else:
            st.warning("No results found")