pablovela5620 committed
Commit: 8c45713
1 Parent(s): c2a846f

Refactor predict_normal function and add DSINE demo

Files changed (1):
main.py +14 -6
main.py CHANGED
@@ -20,7 +20,7 @@ model = utils.load_checkpoint("./checkpoints/dsine.pt", model)
 model.eval()
 
 
-def predict_normal(img_np: np.ndarray):
+def predict_normal(img_np: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
     # normalize
     normalize = transforms.Normalize(
         mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
@@ -28,7 +28,7 @@ def predict_normal(img_np: np.ndarray):
 
     with torch.no_grad():
         img = np.array(img_np).astype(np.float32) / 255.0
-        img = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0).to("cpu")
+        img = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0).to(device)
         _, _, orig_H, orig_W = img.shape
 
         # zero-pad the input image so that both the width and height are multiples of 32
@@ -39,7 +39,7 @@ def predict_normal(img_np: np.ndarray):
         # NOTE: if intrins is not given, we just assume that the principal point is at the center
         # and that the field-of-view is 60 degrees (feel free to modify this assumption)
         intrins = utils.get_intrins_from_fov(
-            new_fov=60.0, H=orig_H, W=orig_W, device="cpu"
+            new_fov=60.0, H=orig_H, W=orig_W, device=device
         ).unsqueeze(0)
 
         intrins[:, 0, 2] += l
@@ -48,7 +48,6 @@ def predict_normal(img_np: np.ndarray):
         pred_norm = model(img, intrins=intrins)[-1]
         pred_norm = pred_norm[:, :, t : t + orig_H, l : l + orig_W]
 
-        # save to output folder
         # NOTE: by saving the prediction as uint8 png format, you lose a lot of precision
         # if you want to use the predicted normals for downstream tasks, we recommend saving them as float32 NPY files
         pred_norm_np = (
@@ -60,6 +59,12 @@ def predict_normal(img_np: np.ndarray):
 
 
 with gr.Blocks() as demo:
+    gr.Markdown(
+        """
+        # DSINE
+        Unofficial Gradio demo of [DSINE: Rethinking Inductive Biases for Surface Normal Estimation](https://github.com/baegwangbin/DSINE)
+        """
+    )
     with gr.Group():
         with gr.Row():
             input_img = gr.Image(label="Input image", image_mode="RGB")
@@ -71,9 +76,12 @@ with gr.Blocks() as demo:
 
     with Modal(visible=True, allow_user_close=False) as modal:
         gr.Markdown(
-            "To use this space, you must agree to the terms and conditions. found [here](https://github.com/baegwangbin/DSINE/blob/main/LICENSE)."
+            """
+            To use this space, you must agree to the terms and conditions.
+            Found [HERE](https://github.com/baegwangbin/DSINE/blob/main/LICENSE).
+            """,
         )
-        btn = gr.Button("I agree")
+        btn = gr.Button("I Agree to the Terms and Conditions")
         btn.click(lambda: Modal(visible=False), None, modal)
 
 if __name__ == "__main__":
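The refactor above replaces the hard-coded "cpu" strings with a `device` variable whose definition sits outside the visible hunks. A minimal sketch of how that setup presumably looks near the top of main.py, assuming the usual CUDA-if-available pattern (only the name `device` is confirmed by the diff; the exact lines are an assumption):

import torch

# Assumed setup, not shown in this commit: prefer a GPU when one is available,
# otherwise fall back to CPU, and move the loaded model there once.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

With this in place, the input tensor built inside predict_normal and the intrinsics returned by utils.get_intrins_from_fov end up on the same device as the model.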