# DDT/src/callbacks/save_images.py
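"""Lightning callback that saves generated images during validation/prediction.

Each rank writes its own samples to individual image files via a thread pool;
on the global-zero rank the gathered samples from all ranks can additionally
be packed into a single compressed ``output.npz`` archive.
"""
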
import os
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Sequence

import numpy
from PIL import Image

import lightning.pytorch as pl
from lightning.pytorch import Callback
from lightning.pytorch.utilities.types import STEP_OUTPUT
from lightning_utilities.core.rank_zero import rank_zero_info


def process_fn(image, path):
    # Runs in a worker thread: `image` is expected to be a uint8 (H, W, C) array.
    Image.fromarray(image).save(path)


class SaveImagesHook(Callback):
    def __init__(self, save_dir="val", max_save_num=0, compressed=True):
        super().__init__()
        self.save_dir = save_dir
        # Upper bound on the number of individual image files written per rank.
        self.max_save_num = max_save_num
        # If True, additionally pack all gathered samples into one output.npz.
        self.compressed = compressed

    def save_start(self, target_dir):
        self.target_dir = target_dir
        self.executor_pool = ThreadPoolExecutor(max_workers=8)
        os.makedirs(self.target_dir, exist_ok=True)
        # Refuse to overwrite results from a previous run, except in debug dirs.
        if os.listdir(self.target_dir) and "debug" not in str(self.target_dir):
            raise FileExistsError(f"{self.target_dir} already exists and is not empty!")
        self.samples = []
        self._have_saved_num = 0
        rank_zero_info(f"Saving images to {self.target_dir}")

    def save_image(self, images, filenames):
        # (B, C, H, W) tensor -> (B, H, W, C) numpy array for PIL.
        images = images.permute(0, 2, 3, 1).cpu().numpy()
        for sample, filename in zip(images, filenames):
            if self._have_saved_num >= self.max_save_num:
                break
            # Some dataloaders wrap each filename in a one-element sequence;
            # unwrap it, but never index into a plain string (str is a Sequence).
            if isinstance(filename, Sequence) and not isinstance(filename, str):
                filename = filename[0]
            path = f"{self.target_dir}/{filename}"
            # Offload encoding and disk I/O to the thread pool.
            self.executor_pool.submit(process_fn, sample, path)
            self._have_saved_num += 1

    def process_batch(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        samples: STEP_OUTPUT,
        batch: Any,
    ) -> None:
        _, c, h, w = samples.shape
        # Batches are expected as (noise, label, metadata), where metadata
        # carries the output filenames.
        xT, y, metadata = batch
        # Gather from all ranks: (world_size, B, C, H, W) -> (world_size * B, C, H, W).
        all_samples = pl_module.all_gather(samples).view(-1, c, h, w)
        # Each rank writes its own (ungathered) samples as individual files.
        self.save_image(samples, metadata)
        # Only rank zero accumulates the gathered samples for the archive.
        if trainer.is_global_zero:
            all_samples = all_samples.permute(0, 2, 3, 1).cpu().numpy()
            self.samples.append(all_samples)

    def save_end(self):
        if self.compressed and len(self.samples) > 0:
            # Concatenate every gathered batch into one (N, H, W, C) array.
            samples = numpy.concatenate(self.samples)
            numpy.savez(f"{self.target_dir}/output.npz", arr_0=samples)
        # Block until all queued image writes have finished, then reset state.
        self.executor_pool.shutdown(wait=True)
        self.samples = []
        self.target_dir = None
        self._have_saved_num = 0
        self.executor_pool = None
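
    # numpy.savez stores the array under the key "arr_0", so a downstream
    # consumer (e.g. an FID evaluation script, an assumed use that is not
    # enforced here) would read the archive as:
    #
    #     samples = numpy.load("output.npz")["arr_0"]  # (N, H, W, C)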

    def on_validation_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        target_dir = os.path.join(trainer.default_root_dir, self.save_dir, f"iter_{trainer.global_step}")
        self.save_start(target_dir)

    def on_validation_batch_end(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        outputs: STEP_OUTPUT,
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        self.process_batch(trainer, pl_module, outputs, batch)

    def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        self.save_end()

    def on_predict_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        target_dir = os.path.join(trainer.default_root_dir, self.save_dir, "predict")
        self.save_start(target_dir)

    def on_predict_batch_end(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        samples: Any,
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        self.process_batch(trainer, pl_module, samples, batch)

    def on_predict_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        self.save_end()
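

# Usage sketch (assumed wiring, not part of this module): `model` and
# `datamodule` are hypothetical placeholders for the user's own objects.
#
#     from lightning.pytorch import Trainer
#
#     trainer = Trainer(
#         default_root_dir="outputs",
#         callbacks=[SaveImagesHook(save_dir="val", max_save_num=64)],
#     )
#     trainer.validate(model, datamodule=datamodule)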