zejunyang commited on
Commit
0da4ece
1 Parent(s): 0c9dedf

update image saving tool

Browse files
Files changed (1) hide show
  1. src/utils/util.py +13 -0
src/utils/util.py CHANGED
@@ -82,6 +82,19 @@ def save_videos_from_pil(pil_images, path, fps=8):
82
  else:
83
  raise ValueError("Unsupported file type. Use .mp4 or .gif.")
84
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
 
86
  def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=8):
87
  videos = rearrange(videos, "b c t h w -> t b c h w")
 
82
  else:
83
  raise ValueError("Unsupported file type. Use .mp4 or .gif.")
84
 
85
+ def save_pil_imgs(videos: torch.Tensor, path: str, rescale=False):
86
+ videos = rearrange(videos, "b c t h w -> t b c h w")
87
+ os.makedirs(path, exist_ok=True)
88
+
89
+ for idx, x in enumerate(videos):
90
+ x = torchvision.utils.make_grid(x, nrow=1) # (c h w)
91
+ x = x.transpose(0, 1).transpose(1, 2).squeeze(-1) # (h w c)
92
+ if rescale:
93
+ x = (x + 1.0) / 2.0 # -1,1 -> 0,1
94
+ x = (x * 255).numpy().astype(np.uint8)
95
+ x = Image.fromarray(x)
96
+ x.save(os.path.join(path, f"{idx:05d}.png"))
97
+
98
 
99
  def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=8):
100
  videos = rearrange(videos, "b c t h w -> t b c h w")