Spaces:
Runtime error
Runtime error
import os | |
import io | |
import cv2 | |
import tqdm | |
import torch | |
import imageio | |
import numpy as np | |
import pandas as pd | |
import matplotlib.pyplot as plt | |
class ReportGenerator():
    """Generate a markdown document summarizing a segmentation training run.

    The report is written to ``<run_id>/report/SegmentationReport.md`` together
    with the PNG/GIF assets it links to.

    Attributes set by ``__init__`` (only when the matching directory is given):
        train_logs, metric_logs: data frames read from ``log_dir``.
        dice, hausdorf, surface: raw per-class metric frames read from ``out_dir``.
        mean_metrics: one row per metric, one column per ``class*`` column,
            holding the mean rounded to three decimals.
    """

    def __init__(self, run_id, out_dir=None, log_dir=None):
        """Store the paths and eagerly load whichever CSV logs are available.

        Args:
            run_id: directory under which the ``report`` folder is created.
            out_dir: directory holding ``*_raw.csv`` metric files and ``preds``.
            log_dir: directory holding ``train_logs.csv`` and ``metric_logs.csv``.
        """
        self.run_id, self.out_dir, self.log_dir = run_id, out_dir, log_dir

        if log_dir:
            self.train_logs = pd.read_csv(os.path.join(log_dir, 'train_logs.csv'))
            self.metric_logs = pd.read_csv(os.path.join(log_dir, 'metric_logs.csv'))

        if out_dir:
            self.dice = pd.read_csv(os.path.join(out_dir, 'MeanDice_raw.csv'))
            self.hausdorf = pd.read_csv(os.path.join(out_dir, 'HausdorffDistance_raw.csv'))
            self.surface = pd.read_csv(os.path.join(out_dir, 'SurfaceDistance_raw.csv'))

            per_metric = (
                ("mean_dice", self.dice),
                ("mean_hausdorf", self.hausdorf),
                ("mean_surface", self.surface),
            )
            self.mean_metrics = pd.DataFrame({
                name: [round(np.mean(df[col]), 3) for col in self._class_columns(df)]
                for name, df in per_metric
            }).transpose()

    @staticmethod
    def _class_columns(df):
        """Return the names of all columns of ``df`` that start with ``class``."""
        return [col for col in df if col.startswith('class')]

    def generate_report(self, loss_plot=True, metric_plot=True, boxplots=True, animation=True):
        """Write the markdown report and the assets selected by the flags.

        Args:
            loss_plot: include the loss / learning-rate figure.
            metric_plot: include the per-epoch metric curves.
            boxplots: include per-class box plots and the mean-metric table.
            animation: include the training-progress GIF.
        """
        report_dir = os.path.join(self.run_id, 'report')
        fn = os.path.join(report_dir, 'SegmentationReport.md')
        os.makedirs(report_dir, exist_ok=True)

        # Start from a fresh file; every section below appends to it.
        with open(fn, 'w', encoding='utf-8') as f:
            f.write('# Segmentation Report\n\n')

        if loss_plot:
            fig = self.plot_loss(self.train_logs, self.metric_logs)
            fig.savefig(os.path.join(report_dir, 'loss_and_lr.png'), dpi=150)
            # plot_loss addresses its figure by label; release it so a later
            # call (e.g. from generate_gif) does not draw over these curves.
            fig.clear()
            plt.close(fig)
            with open(fn, 'a', encoding='utf-8') as f:
                f.write('## Loss, LR-Schedule and Key Metric\n')
                f.write('![Loss, LR-Schedule and Key Metric](loss_and_lr.png)\n\n')

        if metric_plot:
            fig = plt.figure("metrics", figsize=(18, 6))
            panels = (
                ("Mean Dice", "MeanDice"),
                ("Mean Hausdorff Distance", "HausdorffDistance"),
                ("Mean Surface Distance", "SurfaceDistance"),
            )
            for pos, (title, column) in enumerate(panels, start=1):
                plt.subplot(1, 3, pos)
                if column == "MeanDice":
                    plt.ylim([0, 1])  # only the Dice panel gets a fixed range
                plt.title(title)
                plt.xlabel("epoch")
                plt.plot(self.metric_logs.index, self.metric_logs[column])
            fig.savefig(os.path.join(report_dir, 'metrics.png'), dpi=150)
            fig.clear()
            plt.close(fig)
            with open(fn, 'a', encoding='utf-8') as f:
                f.write('## Metrics\n')
                f.write('![metrics](metrics.png)\n\n')

        if boxplots:
            fig = plt.figure("boxplots", figsize=(18, 6))
            panels = (
                ("Dice", self.dice),
                ("Hausdorff Distance", self.hausdorf),
                ("Surface Distance", self.surface),
            )
            for pos, (title, frame) in enumerate(panels, start=1):
                plt.subplot(1, 3, pos)
                plt.title(title)
                plt.xlabel("class")
                plt.boxplot(frame[self._class_columns(frame)])
            fig.savefig(os.path.join(report_dir, 'boxplots.png'), dpi=150)
            fig.clear()
            plt.close(fig)
            with open(fn, 'a', encoding='utf-8') as f:
                f.write("## Individual metrics\n\n")
                f.write(f"{self.mean_metrics.to_markdown()}\n\n")
                f.write("![boxplot](boxplots.png)\n\n")

        if animation:
            self.generate_gif()
            with open(fn, 'a', encoding='utf-8') as f:
                f.write('## Visualization of progress\n')
                f.write('![progress](progress.gif)\n\n')

    def plot_loss(self, train_logs, metric_logs):
        """Draw training/eval loss and the LR schedule; return the figure.

        Args:
            train_logs: frame with ``iteration``, ``epoch``, ``loss`` and ``lr``.
            metric_logs: frame with ``eval_loss``; its index is used as x-axis.

        Returns:
            The matplotlib figure labelled ``"loss and lr"``. The caller is
            responsible for clearing/closing it.
        """
        # Rescale the iteration counter by the number of rows logged for
        # epoch 1 so that the x-axis is expressed in epochs.
        iteration = train_logs.iteration / sum(train_logs.epoch == 1)
        fig = plt.figure("loss and lr", figsize=(12, 6))

        # Cap the y-range at 3; an empty eval log falls back to that cap
        # instead of raising from max() on an empty sequence.
        if len(metric_logs):
            y_max = min(max(metric_logs.eval_loss) + 0.5, 3)
        else:
            y_max = 3

        plt.subplot(1, 2, 1)
        plt.ylim([0, y_max])
        plt.title("Epoch Average Loss")
        plt.xlabel("epoch")
        plt.plot(iteration, train_logs.loss)
        plt.plot(metric_logs.index, metric_logs.eval_loss)

        ax = plt.subplot(1, 2, 2)
        ax.set_yscale('log')
        plt.title("LR Schedule")
        plt.xlabel("epoch")
        plt.plot(iteration, train_logs.lr)
        return fig

    def get_arr_from_fig(self, fig, dpi=180):
        """Render ``fig`` to PNG in memory and decode it into an image array.

        Returns:
            The 3-channel array produced by ``cv2.imdecode`` (OpenCV channel
            order; no colour conversion is applied).
        """
        buf = io.BytesIO()
        fig.savefig(buf, format="png", dpi=dpi)
        img_arr = np.frombuffer(buf.getvalue(), dtype=np.uint8)
        buf.close()
        return cv2.imdecode(img_arr, cv2.IMREAD_COLOR)

    def get_slices(self, im, slices):
        """Tile selected slices of a 3-D tensor into one 2-D image.

        Args:
            im: tensor indexed as ``im[:, :, k]`` along its last axis.
            slices: sequence of indices into the last axis.

        Returns:
            A single row of rotated slices, or two stacked rows when more
            than four slices are requested and their count is even.
        """
        ims = torch.unbind(im[:, :, slices], -1)  # one 2-D tensor per slice
        ims = [i.transpose(0, 1).flip(0) for i in ims]  # rotate by 90 degrees

        n_slices = len(slices)
        if n_slices > 4 and n_slices % 2 == 0:
            half = n_slices // 2
            top = torch.cat(ims[:half], 1)
            bottom = torch.cat(ims[half:], 1)
            return torch.cat([top, bottom], 0)
        return torch.cat(ims, 1)

    def plot_images(self, fns, slices, cmap='Greys_r', figsize=15, **kwargs):
        """Show the slice tiles of several stored predictions stacked vertically.

        Args:
            fns: file names inside ``<out_dir>/preds``.
            slices: slice indices forwarded to ``get_slices``.
            cmap: matplotlib colour map.
            figsize: edge length of the square figure.
            **kwargs: forwarded to ``plt.imshow``.
        """
        # NOTE: torch.load unpickles; only use it on prediction files that
        # this project wrote itself.
        ims = [
            torch.load(os.path.join(self.out_dir, 'preds', fn), map_location="cpu").argmax(0)
            for fn in fns
        ]
        ims = torch.cat([self.get_slices(im, slices) for im in ims], 0)
        plt.figure(figsize=(figsize, figsize))
        plt.imshow(ims, cmap=cmap, **kwargs)
        plt.axis('off')

    def load_segmentation_image(self, fn, size=(224, 224, 112),
                                slices=(40, 48, 56, 74, 82, 90)):
        """Load a stored prediction and turn it into a 0-255 slice tile.

        Args:
            fn: path of a tensor saved with ``torch.save``; a leading batch
                axis is added before resampling, so the stored tensor is
                expected to have one channel axis plus three spatial axes.
            size: spatial size the prediction is resampled to.
            slices: indices along the last spatial axis to show.

        Returns:
            Float tensor with the class labels scaled so that the highest
            label present maps to 255 (all zeros if only label 0 occurs).
        """
        # NOTE: torch.load unpickles; only use it on prediction files that
        # this project wrote itself. map_location keeps this working on
        # CPU-only hosts for tensors saved from a GPU.
        im = torch.load(fn, map_location="cpu").unsqueeze(0)
        im = torch.nn.functional.interpolate(im, size)
        im = im.argmax(1).squeeze()
        im = self.get_slices(im, slices=slices)
        # Clamp the divisor: an all-background prediction has max() == 0 and
        # would otherwise turn the whole tile into NaN.
        return im / im.max().clamp(min=1) * 255

    def generate_gif(self):
        """Write ``progress.gif``: one frame per epoch, prediction over loss plot.

        Epochs without a stored prediction reuse the most recent one; epochs
        before the first stored prediction are skipped.
        """
        # Aim for roughly ten seconds of animation, but never go below 1 fps
        # (runs shorter than ten epochs would otherwise request 0 fps).
        fps = max(1, int(max(self.train_logs.epoch) // 10))
        gif_fn = os.path.join(self.run_id, 'report', 'progress.gif')

        with imageio.get_writer(gif_fn, mode='I', fps=fps) as writer:
            im = None
            for epoch in tqdm.tqdm(list(self.train_logs.epoch.unique())):
                seg_fn = os.path.join(self.out_dir, 'preds', f"pred_epoch_{epoch}.pt")
                if os.path.exists(seg_fn):
                    im = self.load_segmentation_image(seg_fn)
                if im is None:
                    continue  # nothing to show until the first prediction exists

                plt_train_logs = self.train_logs[self.train_logs.epoch <= epoch]
                loss_plt = self.plot_loss(plt_train_logs, self.metric_logs[:epoch])
                # Keep a single channel so the plot can be stacked under the
                # 2-D segmentation tile, then match the tile's size.
                loss_fig = self.get_arr_from_fig(loss_plt)[:, :, 0]
                loss_fig = cv2.resize(loss_fig, (im.shape[1], im.shape[0]))
                images = torch.cat([im, torch.tensor(loss_fig)], 0).numpy().astype(np.uint8)
                writer.append_data(images)
                loss_plt.clear()
                plt.close(loss_plt)