import os
from collections.abc import Sequence
from concurrent.futures import ThreadPoolExecutor
from typing import Any

import numpy
import lightning.pytorch as pl
from PIL import Image
from lightning.pytorch import Callback
from lightning.pytorch.utilities.types import STEP_OUTPUT
from lightning_utilities.core.rank_zero import rank_zero_info


def process_fn(image, path):
    # Runs in a worker thread: encode one HWC uint8 array and write it to disk.
    Image.fromarray(image).save(path)


class SaveImagesHook(Callback):
    """Saves model outputs during validation/prediction epochs: individual
    images via a thread pool and, if ``compressed``, a single output.npz."""

    def __init__(self, save_dir="val", max_save_num=0, compressed=True):
        self.save_dir = save_dir
        self.max_save_num = max_save_num  # cap on individually saved image files
        self.compressed = compressed  # also dump gathered samples to output.npz

    def save_start(self, target_dir):
        self.target_dir = target_dir
        self.executor_pool = ThreadPoolExecutor(max_workers=8)
        if not os.path.exists(self.target_dir):
            os.makedirs(self.target_dir, exist_ok=True)
        elif os.listdir(target_dir) and "debug" not in str(target_dir):
            # Refuse to overwrite earlier results unless this is a debug run.
            raise FileExistsError(f"{self.target_dir} already exists and is not empty!")
        self.samples = []
        self._have_saved_num = 0
        rank_zero_info(f"Saving images to {self.target_dir}")

    def save_image(self, images, filenames):
        # Expects a (B, C, H, W) uint8 tensor; convert to (B, H, W, C) arrays.
        images = images.permute(0, 2, 3, 1).cpu().numpy()
        for sample, filename in zip(images, filenames):
            # A str is itself a Sequence, so exclude it before unwrapping.
            if isinstance(filename, Sequence) and not isinstance(filename, str):
                filename = filename[0]
            if self._have_saved_num >= self.max_save_num:
                break
            path = os.path.join(self.target_dir, filename)
            self.executor_pool.submit(process_fn, sample, path)
            self._have_saved_num += 1

    def process_batch(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        samples: STEP_OUTPUT,
        batch: Any,
    ) -> None:
        b, c, h, w = samples.shape
        xT, y, metadata = batch  # only the metadata (filenames) is used here
        # Each rank saves its own images; the gathered tensor feeds the npz dump.
        all_samples = pl_module.all_gather(samples).view(-1, c, h, w)
        self.save_image(samples, metadata)
        if trainer.is_global_zero:
            all_samples = all_samples.permute(0, 2, 3, 1).cpu().numpy()
            self.samples.append(all_samples)

    def save_end(self):
        if self.compressed and len(self.samples) > 0:
            samples = numpy.concatenate(self.samples)
            numpy.savez(f"{self.target_dir}/output.npz", arr_0=samples)
        # Block until every queued image write has finished, then reset state.
        self.executor_pool.shutdown(wait=True)
        self.samples = []
        self.target_dir = None
        self._have_saved_num = 0
        self.executor_pool = None

    def on_validation_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        target_dir = os.path.join(trainer.default_root_dir, self.save_dir, f"iter_{trainer.global_step}")
        self.save_start(target_dir)

    def on_validation_batch_end(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        outputs: STEP_OUTPUT,
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        return self.process_batch(trainer, pl_module, outputs, batch)

    def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        self.save_end()

    def on_predict_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        target_dir = os.path.join(trainer.default_root_dir, self.save_dir, "predict")
        self.save_start(target_dir)

    def on_predict_batch_end(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        samples: Any,
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        return self.process_batch(trainer, pl_module, samples, batch)

    def on_predict_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        self.save_end()
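

# Usage sketch (not part of the original file): one way this callback might be
# attached to a Trainer. `MyLightningModule` and `val_loader` are hypothetical
# placeholders; validation_step is assumed to return a (B, C, H, W) uint8
# tensor and each batch a (xT, y, metadata) triple, matching process_batch.
if __name__ == "__main__":
    hook = SaveImagesHook(save_dir="val", max_save_num=64, compressed=True)
    # A "debug" root dir skips the non-empty-directory check in save_start.
    trainer = pl.Trainer(default_root_dir="runs/debug", callbacks=[hook])
    # trainer.validate(MyLightningModule(), dataloaders=val_loader)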