# osbm's picture
# Upload 9 files
# 3953219
# raw
# history blame
# 7.48 kB
import os
import io
import cv2
import tqdm
import torch
import imageio
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
class ReportGenerator():
    """Generate a markdown document summarizing a segmentation training run.

    The report is written to ``<run_id>/report/SegmentationReport.md`` together
    with the images it references.

    Inputs read in ``__init__`` (only when the respective directory is given):

    * ``log_dir/train_logs.csv`` -- the columns ``iteration``, ``epoch``,
      ``loss`` and ``lr`` are used.
    * ``log_dir/metric_logs.csv`` -- the columns ``eval_loss``, ``MeanDice``,
      ``HausdorffDistance`` and ``SurfaceDistance`` are used.
    * ``out_dir/MeanDice_raw.csv``, ``out_dir/HausdorffDistance_raw.csv`` and
      ``out_dir/SurfaceDistance_raw.csv`` -- every column whose name starts
      with ``class`` is treated as one class.

    ``generate_gif`` additionally reads ``out_dir/preds/pred_epoch_<epoch>.pt``.
    """

    def __init__(self, run_id, out_dir=None, log_dir=None):
        """Load the logs and raw metrics needed for the report.

        Args:
            run_id: Directory below which the ``report`` folder is created.
            out_dir: Directory with the raw metric CSV files and ``preds``.
                If falsy, ``dice``, ``hausdorf``, ``surface`` and
                ``mean_metrics`` are not set.
            log_dir: Directory with ``train_logs.csv`` and ``metric_logs.csv``.
                If falsy, ``train_logs`` and ``metric_logs`` are not set.
        """
        self.run_id, self.out_dir, self.log_dir = run_id, out_dir, log_dir

        if log_dir:
            self.train_logs = pd.read_csv(os.path.join(log_dir, 'train_logs.csv'))
            self.metric_logs = pd.read_csv(os.path.join(log_dir, 'metric_logs.csv'))

        if out_dir:
            self.dice = pd.read_csv(os.path.join(out_dir, 'MeanDice_raw.csv'))
            self.hausdorf = pd.read_csv(os.path.join(out_dir, 'HausdorffDistance_raw.csv'))
            self.surface = pd.read_csv(os.path.join(out_dir, 'SurfaceDistance_raw.csv'))
            # One row per metric, one column per class.
            self.mean_metrics = pd.DataFrame({
                "mean_dice": self._class_means(self.dice),
                "mean_hausdorf": self._class_means(self.hausdorf),
                "mean_surface": self._class_means(self.surface),
            }).transpose()

    @staticmethod
    def _class_columns(frame):
        """Return the names of all columns of *frame* starting with ``class``."""
        return [col for col in frame if col.startswith('class')]

    @classmethod
    def _class_means(cls, frame):
        """Return the per-class column means of *frame*, rounded to 3 digits."""
        return [round(np.mean(frame[col]), 3) for col in cls._class_columns(frame)]

    def generate_report(self, loss_plot=True, metric_plot=True, boxplots=True, animation=True):
        """Write the markdown report and the images it embeds.

        Args:
            loss_plot: Add the loss / learning-rate figure (needs ``log_dir``).
            metric_plot: Add the per-epoch metric curves (needs ``log_dir``).
            boxplots: Add the per-class table and boxplots (needs ``out_dir``).
            animation: Add the training-progress GIF (needs both directories).
        """
        report_dir = os.path.join(self.run_id, 'report')
        os.makedirs(report_dir, exist_ok=True)
        fn = os.path.join(report_dir, 'SegmentationReport.md')

        # The file is opened once; sections are written as they are produced so
        # that a failure in a later section still leaves the earlier ones on disk.
        with open(fn, 'w', encoding='utf-8') as f:
            f.write('# Segmentation Report\n\n')

            if loss_plot:
                fig = self.plot_loss(self.train_logs, self.metric_logs)
                try:
                    fig.savefig(os.path.join(report_dir, 'loss_and_lr.png'), dpi=150)
                finally:
                    # plot_loss reuses the figure labelled "loss and lr"; leaving
                    # it open would make later calls draw on top of this plot.
                    plt.close(fig)
                f.write('## Loss, LR-Schedule and Key Metric\n')
                f.write('![Loss, LR-Schedule and Key Metric](loss_and_lr.png)\n\n')

            if metric_plot:
                self._save_metric_curves(os.path.join(report_dir, 'metrics.png'))
                f.write('## Metrics\n')
                f.write('![metrics](metrics.png)\n\n')

            if boxplots:
                self._save_boxplots(os.path.join(report_dir, 'boxplots.png'))
                f.write("## Individual metrics\n\n")
                f.write(f"{self.mean_metrics.to_markdown()}\n\n")
                f.write("![boxplot](boxplots.png)\n\n")

            if animation:
                self.generate_gif()
                f.write('## Visualization of progress\n')
                f.write('![progress](progress.gif)\n\n')

    def _save_metric_curves(self, path):
        """Plot the three per-epoch metric curves side by side and save to *path*."""
        # (title, values, fixed y-limits or None for autoscaling)
        panels = (
            ("Mean Dice", self.metric_logs.MeanDice, [0, 1]),
            ("Mean Hausdorff Distance", self.metric_logs.HausdorffDistance, None),
            ("Mean Surface Distance", self.metric_logs.SurfaceDistance, None),
        )
        fig = plt.figure("metrics", (18, 6))
        try:
            for position, (title, values, y_limits) in enumerate(panels, start=1):
                ax = fig.add_subplot(1, 3, position)
                if y_limits is not None:
                    ax.set_ylim(y_limits)
                ax.set_title(title)
                ax.set_xlabel("epoch")
                ax.plot(self.metric_logs.index, values)
            fig.savefig(path, dpi=150)
        finally:
            plt.close(fig)

    def _save_boxplots(self, path):
        """Plot per-class boxplots of the three raw metrics and save to *path*."""
        panels = (
            ("Dice", self.dice),
            ("Hausdorff Distance", self.hausdorf),
            ("Surface Distance", self.surface),
        )
        fig = plt.figure("boxplots", (18, 6))
        try:
            for position, (title, frame) in enumerate(panels, start=1):
                ax = fig.add_subplot(1, 3, position)
                ax.set_title(title)
                ax.set_xlabel("class")
                ax.boxplot(frame[self._class_columns(frame)])
            fig.savefig(path, dpi=150)
        finally:
            plt.close(fig)

    def plot_loss(self, train_logs, metric_logs):
        """Draw training/eval loss and the learning-rate schedule.

        Args:
            train_logs: DataFrame with ``iteration``, ``epoch``, ``loss``, ``lr``.
            metric_logs: DataFrame with ``eval_loss``; its index is used as the
                x position of the evaluation loss.

        Returns:
            The matplotlib figure labelled ``"loss and lr"``. The caller is
            responsible for closing it.
        """
        # Convert iterations to (fractional) epochs using the number of
        # iterations logged for epoch 1.
        steps_per_epoch = int((train_logs.epoch == 1).sum())
        if steps_per_epoch == 0 and len(train_logs) > 0:
            # No row is tagged epoch 1; use the first logged epoch instead of
            # dividing by zero.
            first_epoch = train_logs.epoch.iloc[0]
            steps_per_epoch = int((train_logs.epoch == first_epoch).sum())
        iteration = train_logs.iteration / max(steps_per_epoch, 1)

        # Upper y-limit: a little above the worst eval loss, capped at 3.
        # NaN entries are dropped and an empty log falls back to the cap,
        # because matplotlib rejects NaN axis limits.
        eval_loss = metric_logs.eval_loss.dropna()
        y_max = min(eval_loss.max() + 0.5, 3) if len(eval_loss) else 3

        fig = plt.figure("loss and lr", (12, 6))

        loss_ax = fig.add_subplot(1, 2, 1)
        loss_ax.set_ylim([0, y_max])
        loss_ax.set_title("Epoch Average Loss")
        loss_ax.set_xlabel("epoch")
        loss_ax.plot(iteration, train_logs.loss)
        loss_ax.plot(metric_logs.index, metric_logs.eval_loss)

        lr_ax = fig.add_subplot(1, 2, 2)
        lr_ax.set_yscale('log')
        lr_ax.set_title("LR Schedule")
        lr_ax.set_xlabel("epoch")
        lr_ax.plot(iteration, train_logs.lr)
        return fig

    def get_arr_from_fig(self, fig, dpi=180):
        """Render *fig* to PNG in memory and decode it into an image array.

        Args:
            fig: Object with a matplotlib-style ``savefig`` method.
            dpi: Resolution used for rendering.

        Returns:
            The image as decoded by ``cv2.imdecode`` in colour mode. OpenCV
            uses BGR channel order; no conversion to RGB is applied.
        """
        with io.BytesIO() as buf:
            fig.savefig(buf, format="png", dpi=dpi)
            encoded = np.frombuffer(buf.getvalue(), dtype=np.uint8)
        return cv2.imdecode(encoded, cv2.IMREAD_COLOR)

    def get_slices(self, im, slices):
        """Extract slices along the last axis of a 3D tensor and tile them.

        Each slice is transposed and flipped along its first axis (a 90 degree
        rotation). The slices are placed side by side; an even number of more
        than four slices is arranged in two rows instead.

        Args:
            im: Tensor indexed as ``im[:, :, slices]``.
            slices: Sequence of indices into the last axis of *im*.

        Returns:
            A 2D tensor containing the tiled slices.
        """
        picked = torch.unbind(im[:, :, slices], -1)  # extract n slices
        rotated = [s.transpose(0, 1).flip(0) for s in picked]  # rotate by 90 degrees

        if len(slices) > 4 and len(slices) % 2 == 0:
            half = len(slices) // 2
            top_row = torch.cat(rotated[:half], 1)
            bottom_row = torch.cat(rotated[half:], 1)
            return torch.cat([top_row, bottom_row], 0)
        return torch.cat(rotated, 1)  # single row

    def plot_images(self, fns, slices, cmap='Greys_r', figsize=15, **kwargs):
        """Show tiled slices of several saved predictions in one figure.

        Args:
            fns: File names inside ``out_dir/preds``.
            slices: Slice indices passed to ``get_slices``.
            cmap: Colormap for ``plt.imshow``.
            figsize: Width and height of the (square) figure.
            **kwargs: Forwarded to ``plt.imshow``.
        """
        # NOTE: torch.load unpickles the file; only use it on prediction files
        # written by your own training run.
        preds = [
            torch.load(os.path.join(self.out_dir, 'preds', fn)).cpu().argmax(0)
            for fn in fns
        ]
        tiles = [self.get_slices(pred, slices) for pred in preds]
        plt.figure(figsize=(figsize, figsize))
        plt.imshow(torch.cat(tiles, 0), cmap=cmap, **kwargs)
        plt.axis('off')

    def load_segmentation_image(self, fn):
        """Load a saved prediction and turn it into a tiled grey-scale image.

        The prediction is resized to ``(224, 224, 112)``, reduced to a label map
        with ``argmax`` over the first axis of the stored tensor, tiled from six
        fixed slices and scaled so that the largest label maps to 255.

        Args:
            fn: Path (or file-like object) accepted by ``torch.load``.

        Returns:
            A 2D float tensor with values in ``[0, 255]``.
        """
        pred = torch.load(fn).cpu().unsqueeze(0)
        pred = torch.nn.functional.interpolate(pred, (224, 224, 112))
        labels = pred.argmax(1).squeeze()
        tile = self.get_slices(labels, slices=(40, 48, 56, 74, 82, 90))

        peak = tile.max()
        if peak > 0:
            return tile / peak * 255
        # Only label 0 is present: dividing by the maximum would give NaN, so
        # return the all-zero image as float instead.
        return tile.float()

    def generate_gif(self):
        """Write ``<run_id>/report/progress.gif`` showing training progress.

        One frame is produced per epoch in ``train_logs``: the tiled prediction
        of that epoch stacked on top of the loss plot up to that epoch. If an
        epoch has no prediction file, the most recent prediction is reused;
        epochs before the first available prediction are skipped.
        """
        report_dir = os.path.join(self.run_id, 'report')
        os.makedirs(report_dir, exist_ok=True)

        # Aim for a ~10 second animation, but never go below 1 fps: runs with
        # fewer than 10 epochs would otherwise request a frame rate of 0.
        fps = max(1, int(max(self.train_logs.epoch)) // 10)

        im = None  # most recent prediction image
        with imageio.get_writer(
                os.path.join(report_dir, 'progress.gif'),
                mode='I',
                fps=fps) as writer:
            for epoch in tqdm.tqdm(self.train_logs.epoch.unique()):
                seg_fn = os.path.join(self.out_dir, 'preds', f"pred_epoch_{epoch}.pt")
                if os.path.exists(seg_fn):
                    im = self.load_segmentation_image(seg_fn)
                if im is None:
                    continue  # nothing to show yet

                plt_train_logs = self.train_logs[self.train_logs.epoch <= epoch]
                loss_plt = self.plot_loss(plt_train_logs, self.metric_logs[:epoch])
                try:
                    # A single channel is enough for the grey-scale frame.
                    loss_fig = self.get_arr_from_fig(loss_plt)[:, :, 0]
                finally:
                    plt.close(loss_plt)

                # cv2.resize expects (width, height); match the prediction tile.
                loss_fig = cv2.resize(loss_fig, (im.shape[1], im.shape[0]))
                frame = torch.cat([im, torch.tensor(loss_fig)], 0).numpy().astype(np.uint8)
                writer.append_data(frame)