# Extract Leaf Patches From Plates

## Imports

In [None]:
# Reload imported modules automatically before each cell executes, so edits to
# the project helper modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from datetime import datetime as dt
import warnings
import random

from tqdm import tqdm

import cv2

import pandas as pd

from siuba import _ as s
from siuba import filter as sfilter
from siuba import mutate, select, if_else

import panel as pn

import torch

from pytorch_lightning.callbacks import (
 RichProgressBar,
 ModelCheckpoint,
 LearningRateMonitor,
)
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger


import com_const as cc
import com_image as ci
import com_func as cf
import leaf_patch_extractor_model as lpem

## Setup

In [None]:
# Suppress UserWarning and FutureWarning messages for the rest of the session.
for ignored_category in (UserWarning, FutureWarning):
    warnings.simplefilter(action="ignore", category=ignored_category)

In [None]:
# Pandas display limits used by every table rendered in this notebook:
# wide cells/frames are shown in full while the number of rows stays short.
display_options = {
    "display.max_colwidth": 500,
    "display.max_columns": 500,
    "display.width": 1000,
    "display.max_rows": 16,
}
for option_name, option_value in display_options.items():
    pd.set_option(option_name, option_value)

In [None]:
# Initialise Panel for this notebook with notifications enabled and console
# output disabled.
pn.extension(notifications=True, console_output="disable")

## Train Disc Detector

### Load Datasets

In [None]:
# Load the three "ldd_<split>.csv" files from the data folder, in
# train / val / test order.
train, val, test = [
    cf.read_dataframe(cc.path_to_data.joinpath(f"ldd_{split}.csv"))
    for split in ("train", "val", "test")
]

# Print labelled sizes so the numbers cannot be attributed to the wrong split
# (the unpacking order above is train, val, test).
print(f"train: {len(train)}, val: {len(val)}, test: {len(test)}")

### Test Augmentations

In [None]:
# Disabled preview cell. When uncommented it builds the "resize" + "train"
# augmentations, samples one plate name from `train` and draws that plate with
# its boxes 12 times in a 3x4 grid.
# aug_ = lpem.get_augmentations(image_size=10, kinds=["resize", "train"])

# test_aug_dataset = lpem.LeafDiskDetectorDataset(csv=train, transform=aug_)

# file_name = train.sample(n=1).plate_name.to_list()[0]

# print(aug_[0].width, aug_[0].height)

# lpem.make_patches_grid(
#     images=[
#         test_aug_dataset.draw_image_with_boxes(plate_name=file_name) for _ in range(12)
#     ],
#     row_count=3,
#     col_count=4,
#     figsize=(12, 6),
# )

### Train

In [None]:
# Disabled training cell. When uncommented it builds the detector from the
# train/val/test frames, runs a forward pass on a random 2x3x128x128 batch in
# eval mode and calls hr_desc() on the model.
# model = lpem.LeafDiskDetector(
#     batch_size=15,
#     learning_rate=7.0e-05,
#     image_factor=10,
#     max_epochs=1000,
#     train_data=train,
#     val_data=val,
#     test_data=test,
#     augmentations_kinds=["resize", "train", "to_tensor"],
#     augmentations_params={"gamma": (60, 180)},
#     num_workers=2,
#     accumulate_grad_batches=5,
#     scheduler="steplr",
#     scheduler_params={"step_size": 10, "gamma": 0.80},
# )

# model.eval()
# len(model(torch.rand(2, 3, 128, 128)))

# model.hr_desc()

In [None]:
# Disabled training cell. When uncommented it configures a Lightning Trainer
# (TensorBoard logger, early stopping and top-1 checkpointing on "val_loss",
# per-epoch learning-rate monitoring) and fits `model`, which must have been
# created by the disabled model-construction cell.
# trainer = Trainer(
#     default_root_dir=str(cc.path_to_chk_detector),
#     logger=TensorBoardLogger(
#         save_dir=str(cc.path_to_chk_detector),
#         version=model.model_name + "_" + dt.now().strftime("%Y%m%d_%H%M%S"),
#         name="lightning_logs",
#     ),
#     accelerator="gpu",
#     max_epochs=model.max_epochs,
#     log_every_n_steps=5,
#     callbacks=[
#         RichProgressBar(),
#         EarlyStopping(monitor="val_loss", mode="min", patience=10, min_delta=0.0005),
#         ModelCheckpoint(
#             save_top_k=1,
#             monitor="val_loss",
#             auto_insert_metric_name=True,
#             filename=model.model_name
#             + "-{val_loss:.3f}-{epoch}-{train_loss:.3f}-{step}",
#         ),
#         LearningRateMonitor(logging_interval="epoch"),
#     ],
#     accumulate_grad_batches=model.accumulate_grad_batches,
# )

# trainer.fit(model)

## Extract Patches

### Load Model

In [None]:
# Restore the detector from its saved checkpoint, then call hr_desc() on it
# (presumably a human-readable model summary; confirm in lpem).
detector_checkpoint = cc.path_to_chk_detector.joinpath("leaf_disc_detector.ckpt")
ld_model: lpem.LeafDiskDetector = lpem.LeafDiskDetector.load_from_checkpoint(
    detector_checkpoint
)
ld_model.hr_desc()

### Predict All Bounding Boxes

In [None]:
bb_predictions_path = cc.path_to_data.joinpath("train_ld_bounding_boxes.csv")

# Reuse previously saved predictions when the CSV exists; otherwise start from
# an empty frame.
if bb_predictions_path.is_file():
    bb_predictions = cf.read_dataframe(bb_predictions_path)
else:
    bb_predictions = pd.DataFrame()

bb_predictions

In [None]:
# Recursively gather every file matching "*.JPG" below the plates folder.
plates = [*cc.path_to_plates.rglob("*.JPG")]
len(plates)

In [None]:
# Index every plate that is not yet in the cached predictions, append the new
# bounding boxes, persist the result and reload it from disk.
errors = []

# An empty cache (no CSV yet) has no "file_name" column, so the lookup must be
# guarded here rather than inside the loop. A set keeps the per-plate
# membership test O(1).
handled_plates = (
    set(bb_predictions.file_name.unique()) if "file_name" in bb_predictions else set()
)

# Per-plate frames are collected and concatenated once after the loop instead
# of re-copying the growing frame on every iteration.
new_predictions = []
for plate in tqdm(plates):
    if plate.name in handled_plates:
        continue
    try:
        new_predictions.append(
            ld_model.index_plate(plate)
            >> mutate(
                # Both replacements go through the `.str` accessor with
                # regex=False: a plain Series.replace only swaps whole values
                # equal to ".JPG", and a regex "." would match any character.
                disc_name=s.file_name.str.replace(" ", "", regex=False).str.replace(
                    ".JPG", "", regex=False
                )
                + "_"
                + s.row.astype(str)
                + "_"
                + s.col.astype(str)
                + ".png"
            )
        )
    except Exception:
        # Best effort: remember the failing plate and keep going.
        # KeyboardInterrupt / SystemExit are deliberately not caught.
        errors.append(plate)

if new_predictions:
    bb_predictions = pd.concat([bb_predictions, *new_predictions])

print(errors)
cf.write_dataframe(
    bb_predictions.sort_values(["file_name", "col", "row"]).reset_index(drop=True)
    # Defensive clean-up of any disc name that still carries the plate extension.
    >> mutate(disc_name=s.disc_name.str.replace(".JPG", "", regex=False)),
    bb_predictions_path,
)

bb_predictions = cf.read_dataframe(bb_predictions_path)
bb_predictions

In [None]:
# Pick one plate at random and keep only the predicted boxes that belong to it.
selected_image = random.choice(plates)
bboxes = bb_predictions >> sfilter(s.file_name == selected_image.name)

# Image returned by print_boxes for the selected plate and its boxes.
annotated_plate = lpem.print_boxes(
    image_name=selected_image,
    boxes=bboxes,
    draw_first_line=True,
    return_plot=False,
)

# Title, box table and annotated image stacked in one Panel column.
pn.Column(
    pn.pane.Markdown(f"### {selected_image.name}"),
    pn.pane.DataFrame(bboxes),
    pn.pane.Image(ci.to_pil(annotated_plate), sizing_mode="scale_width"),
)

### Extract Needed Patches

#### Model Training

In [None]:
# Stack the "oiv_<split>.csv" train/val/test files into a single frame ordered
# by file name, with a fresh 0..n-1 index.
oiv_splits = [
    cf.read_dataframe(cc.path_to_data.joinpath(f"oiv_{split}.csv"))
    for split in ("train", "val", "test")
]
df_model_training = pd.concat(oiv_splits)
df_model_training = df_model_training.sort_values(["file_name"]).reset_index(drop=True)
df_model_training

In [None]:
# Run handle_bbox on the predicted box of every disc listed in the model
# training frame; `err` is handed to handle_bbox and displayed afterwards.
err = {}

for file_name in tqdm(df_model_training.file_name):
    # Predictions whose disc name equals this file name.
    disc_rows = (bb_predictions >> sfilter(s.disc_name == file_name)).reset_index(
        drop=True
    )
    # NOTE(review): raises IndexError when no prediction matches this file name.
    first_match = disc_rows.iloc[0]
    lpem.handle_bbox(
        first_match,
        add_process_image=True,
        paths={
            "segmented_leaf_disc": cc.path_to_leaf_discs,
            "leaf_disc_patch": cc.path_to_leaf_patches,
            "plates": cc.path_to_plates,
        },
        errors=err,
    )
err

#### Genotype Differentiation

In [None]:
# Read the genotype dataset CSV from the data folder and display it.
genotype_csv = cc.path_to_data.joinpath("genotype_differenciation_dataset.csv")
df_gd = cf.read_dataframe(genotype_csv)
df_gd

In [None]:
# Run handle_bbox on the predicted box of every disc listed in the genotype
# frame; `err` is handed to handle_bbox and displayed afterwards.
err = {}

for file_name in tqdm(df_gd.file_name):
    # Predictions whose disc name equals this file name.
    disc_rows = (bb_predictions >> sfilter(s.disc_name == file_name)).reset_index(
        drop=True
    )
    # NOTE(review): raises IndexError when no prediction matches this file name.
    first_match = disc_rows.iloc[0]
    lpem.handle_bbox(
        first_match,
        add_process_image=True,
        paths={
            "segmented_leaf_disc": cc.path_to_leaf_discs,
            "leaf_disc_patch": cc.path_to_leaf_patches,
            "plates": cc.path_to_plates,
        },
        errors=err,
    )
err