pablovela5620 commited on
Commit
169b74e
1 Parent(s): 8c45713

Delete test_samples function and its dependencies

Browse files
Files changed (1) hide show
  1. test.py +0 -86
test.py DELETED
@@ -1,86 +0,0 @@
1
- import os
2
- import sys
3
- import glob
4
- import argparse
5
- import numpy as np
6
-
7
- import torch
8
- import torch.nn.functional as F
9
- from torchvision import transforms
10
- from PIL import Image
11
- import utils.utils as utils
12
-
13
-
14
- def test_samples(args, model, intrins=None, device="cpu"):
15
- img_paths = glob.glob("./samples/img/*.png") + glob.glob("./samples/img/*.jpg")
16
- img_paths.sort()
17
-
18
- # normalize
19
- normalize = transforms.Normalize(
20
- mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
21
- )
22
-
23
- with torch.no_grad():
24
- for img_path in img_paths:
25
- print(img_path)
26
- ext = os.path.splitext(img_path)[1]
27
- img = Image.open(img_path).convert("RGB")
28
- img = np.array(img).astype(np.float32) / 255.0
29
- img = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0).to(device)
30
- _, _, orig_H, orig_W = img.shape
31
-
32
- # zero-pad the input image so that both the width and height are multiples of 32
33
- l, r, t, b = utils.pad_input(orig_H, orig_W)
34
- img = F.pad(img, (l, r, t, b), mode="constant", value=0.0)
35
- img = normalize(img)
36
-
37
- intrins_path = img_path.replace(ext, ".txt")
38
- if os.path.exists(intrins_path):
39
- # NOTE: camera intrinsics should be given as a txt file
40
- # it should contain the values of fx, fy, cx, cy
41
- intrins = utils.get_intrins_from_txt(
42
- intrins_path, device=device
43
- ).unsqueeze(0)
44
- else:
45
- # NOTE: if intrins is not given, we just assume that the principal point is at the center
46
- # and that the field-of-view is 60 degrees (feel free to modify this assumption)
47
- intrins = utils.get_intrins_from_fov(
48
- new_fov=60.0, H=orig_H, W=orig_W, device=device
49
- ).unsqueeze(0)
50
-
51
- intrins[:, 0, 2] += l
52
- intrins[:, 1, 2] += t
53
-
54
- pred_norm = model(img, intrins=intrins)[-1]
55
- pred_norm = pred_norm[:, :, t : t + orig_H, l : l + orig_W]
56
-
57
- # save to output folder
58
- # NOTE: by saving the prediction as uint8 png format, you lose a lot of precision
59
- # if you want to use the predicted normals for downstream tasks, we recommend saving them as float32 NPY files
60
- pred_norm_np = (
61
- pred_norm.cpu().detach().numpy()[0, :, :, :].transpose(1, 2, 0)
62
- ) # (H, W, 3)
63
- pred_norm_np = ((pred_norm_np + 1.0) / 2.0 * 255.0).astype(np.uint8)
64
- target_path = img_path.replace("/img/", "/output/").replace(ext, ".png")
65
- im = Image.fromarray(pred_norm_np)
66
- im.save(target_path)
67
-
68
-
69
- if __name__ == "__main__":
70
- parser = argparse.ArgumentParser()
71
- parser.add_argument("--ckpt", default="dsine", type=str, help="model checkpoint")
72
- parser.add_argument("--mode", default="samples", type=str, help="{samples}")
73
- args = parser.parse_args()
74
-
75
- # define model
76
- device = torch.device("cpu")
77
-
78
- from models.dsine import DSINE
79
-
80
- model = DSINE().to(device)
81
- model.pixel_coords = model.pixel_coords.to(device)
82
- model = utils.load_checkpoint("./checkpoints/%s.pt" % args.ckpt, model)
83
- model.eval()
84
-
85
- if args.mode == "samples":
86
- test_samples(args, model, intrins=None, device=device)