|
""" |
|
Approximate the bits/dimension for an image model. |
|
""" |
|
|
|
import argparse |
|
import os |
|
|
|
import numpy as np |
|
import torch.distributed as dist |
|
|
|
from guided_diffusion import dist_util, logger |
|
from guided_diffusion.image_datasets import load_data |
|
from guided_diffusion.script_util import ( |
|
model_and_diffusion_defaults, |
|
create_model_and_diffusion, |
|
add_dict_to_argparser, |
|
args_to_dict, |
|
) |
|
|
|
|
|
def main():
    """Script entry point: build the model from CLI flags and report bits/dim.

    Parses the command line, initialises the distributed setup and logger,
    restores the model weights from ``--model_path``, builds a deterministic
    data loader over ``--data_dir`` and hands everything to
    ``run_bpd_evaluation``.
    """
    args = create_argparser().parse_args()

    dist_util.setup_dist()
    logger.configure()

    logger.log("creating model and diffusion...")
    # Only the flags that belong to the model/diffusion factory are forwarded.
    factory_kwargs = args_to_dict(args, model_and_diffusion_defaults().keys())
    model, diffusion = create_model_and_diffusion(**factory_kwargs)
    # The checkpoint is read onto the CPU first; the model is moved afterwards.
    checkpoint = dist_util.load_state_dict(args.model_path, map_location="cpu")
    model.load_state_dict(checkpoint)
    model.to(dist_util.dev())
    model.eval()

    logger.log("creating data loader...")
    data = load_data(
        data_dir=args.data_dir,
        batch_size=args.batch_size,
        image_size=args.image_size,
        class_cond=args.class_cond,
        deterministic=True,
    )

    logger.log("evaluating...")
    run_bpd_evaluation(
        model,
        diffusion,
        data,
        args.num_samples,
        args.clip_denoised,
    )
|
|
|
|
|
def _mean_across_ranks(local_value):
    """Return the mean of ``local_value`` over all ranks of the default group."""
    # Each rank contributes value / world_size; all_reduce's default op is SUM,
    # so the in-place reduction leaves the cross-rank mean on every rank.
    shared = local_value / dist.get_world_size()
    dist.all_reduce(shared)
    return shared


def run_bpd_evaluation(model, diffusion, data, num_samples, clip_denoised):
    """Estimate bits/dim for ``model`` and save the per-term averages.

    Args:
        model: network forwarded to ``diffusion.calc_bpd_loop``.
        diffusion: object whose ``calc_bpd_loop(model, batch, clip_denoised=...,
            model_kwargs=...)`` returns a mapping with at least the keys
            ``"vb"``, ``"mse"``, ``"xstart_mse"`` and ``"total_bpd"``.
        data: iterator yielding ``(batch, model_kwargs)`` pairs; ``batch`` and
            every value in ``model_kwargs`` must support ``.to(device)``.
        num_samples: number of images to evaluate, counted over all ranks.
            Whole batches are consumed, so the final count may overshoot.
        clip_denoised: forwarded unchanged to ``calc_bpd_loop``.

    Side effects:
        Logs the running mean bits/dim after every batch. On rank 0, writes
        ``vb_terms.npz``, ``mse_terms.npz`` and ``xstart_mse_terms.npz`` into
        ``logger.get_dir()``; each file holds the mean of the cross-rank
        averaged terms over all evaluated batches. No files are written if no
        batch was evaluated (``num_samples <= 0``). All ranks meet at a
        barrier before returning.
    """
    world_size = dist.get_world_size()
    device = dist_util.dev()

    all_bpd = []
    all_metrics = {"vb": [], "mse": [], "xstart_mse": []}
    num_complete = 0
    while num_complete < num_samples:
        batch, model_kwargs = next(data)
        batch = batch.to(device)
        model_kwargs = {k: v.to(device) for k, v in model_kwargs.items()}
        minibatch_metrics = diffusion.calc_bpd_loop(
            model, batch, clip_denoised=clip_denoised, model_kwargs=model_kwargs
        )

        # Collectives must be issued in the same order on every rank; dict
        # iteration order is fixed, so this loop guarantees that.
        for key, term_list in all_metrics.items():
            terms = _mean_across_ranks(minibatch_metrics[key].mean(dim=0))
            term_list.append(terms.detach().cpu().numpy())

        total_bpd = _mean_across_ranks(minibatch_metrics["total_bpd"].mean())
        all_bpd.append(total_bpd.item())
        num_complete += world_size * batch.shape[0]

        logger.log(f"done {num_complete} samples: bpd={np.mean(all_bpd)}")

    if not all_bpd:
        # Without this guard np.stack([]) raises on rank 0 while every other
        # rank waits in the barrier below, hanging the whole job.
        logger.log("no batches evaluated (num_samples <= 0); not saving terms")
    elif dist.get_rank() == 0:
        for name, terms in all_metrics.items():
            out_path = os.path.join(logger.get_dir(), f"{name}_terms.npz")
            logger.log(f"saving {name} terms to {out_path}")
            np.savez(out_path, np.mean(np.stack(terms), axis=0))

    dist.barrier()
    logger.log("evaluation complete")
|
|
|
|
|
def create_argparser():
    """Return an ``argparse`` parser for this script's command-line flags.

    Script-specific flags come first. The model/diffusion defaults are merged
    in afterwards and take precedence for any shared key. Each entry becomes a
    ``--<key>`` option via ``add_dict_to_argparser``.
    """
    flag_defaults = {
        "data_dir": "",
        "clip_denoised": True,
        "num_samples": 1000,
        "batch_size": 1,
        "model_path": "",
        **model_and_diffusion_defaults(),
    }
    argument_parser = argparse.ArgumentParser()
    add_dict_to_argparser(argument_parser, flag_defaults)
    return argument_parser
|
|
|
|
|
# Run the evaluation only when this file is executed directly, not on import.
if __name__ == "__main__":
    main()
|
|