Spaces:
Runtime error
Runtime error
zejunyang
commited on
Commit
•
0da4ece
1
Parent(s):
0c9dedf
update image saving tool
Browse files- src/utils/util.py +13 -0
src/utils/util.py
CHANGED
@@ -82,6 +82,19 @@ def save_videos_from_pil(pil_images, path, fps=8):
|
|
82 |
else:
|
83 |
raise ValueError("Unsupported file type. Use .mp4 or .gif.")
|
84 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
85 |
|
86 |
def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=8):
|
87 |
videos = rearrange(videos, "b c t h w -> t b c h w")
|
|
|
82 |
else:
|
83 |
raise ValueError("Unsupported file type. Use .mp4 or .gif.")
|
84 |
|
85 |
+
def save_pil_imgs(videos: torch.Tensor, path: str, rescale=False):
|
86 |
+
videos = rearrange(videos, "b c t h w -> t b c h w")
|
87 |
+
os.makedirs(path, exist_ok=True)
|
88 |
+
|
89 |
+
for idx, x in enumerate(videos):
|
90 |
+
x = torchvision.utils.make_grid(x, nrow=1) # (c h w)
|
91 |
+
x = x.transpose(0, 1).transpose(1, 2).squeeze(-1) # (h w c)
|
92 |
+
if rescale:
|
93 |
+
x = (x + 1.0) / 2.0 # -1,1 -> 0,1
|
94 |
+
x = (x * 255).numpy().astype(np.uint8)
|
95 |
+
x = Image.fromarray(x)
|
96 |
+
x.save(os.path.join(path, f"{idx:05d}.png"))
|
97 |
+
|
98 |
|
99 |
def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=8):
|
100 |
videos = rearrange(videos, "b c t h w -> t b c h w")
|