zhang-ziang commited on
Commit
00fe360
1 Parent(s): 864becb
Files changed (2) hide show
  1. app.py +18 -63
  2. inference.py +49 -0
app.py CHANGED
@@ -1,17 +1,14 @@
1
  import gradio as gr
2
  from paths import *
3
- import numpy as np
4
  from vision_tower import DINOv2_MLP
5
  from transformers import AutoImageProcessor
6
  import torch
7
- import os
8
- from PIL import Image
9
-
10
- import torch.nn.functional as F
11
  from utils import *
12
 
13
  from huggingface_hub import hf_hub_download
14
- ckpt_path = hf_hub_download(repo_id="Viglong/OriNet", filename="celarge/dino_weight.pt", repo_type="model", cache_dir='./', resume_download=True)
15
  print(ckpt_path)
16
 
17
  save_path = './'
@@ -19,79 +16,37 @@ device = 'cpu'
19
  dino = DINOv2_MLP(
20
  dino_mode = 'large',
21
  in_dim = 1024,
22
- out_dim = 360+180+60+2,
23
  evaluate = True,
24
  mask_dino = False,
25
  frozen_back = False
26
- ).to(device)
27
 
28
  dino.eval()
29
  print('model create')
30
  dino.load_state_dict(torch.load(ckpt_path, map_location='cpu'))
 
31
  print('weight loaded')
32
  val_preprocess = AutoImageProcessor.from_pretrained(DINO_LARGE, cache_dir='./')
33
 
34
-
35
- def get_3angle(image):
36
-
37
- # image = Image.open(image_path).convert('RGB')
38
- image_inputs = val_preprocess(images = image)
39
- image_inputs['pixel_values'] = torch.from_numpy(np.array(image_inputs['pixel_values'])).to(device)
40
- with torch.no_grad():
41
- dino_pred = dino(image_inputs)
42
-
43
- gaus_ax_pred = torch.argmax(dino_pred[:, 0:360], dim=-1)
44
- gaus_pl_pred = torch.argmax(dino_pred[:, 360:360+180], dim=-1)
45
- gaus_ro_pred = torch.argmax(dino_pred[:, 360+180:360+180+60], dim=-1)
46
- confidence = F.softmax(dino_pred[:, -2:], dim=-1)[0][0]
47
- angles = torch.zeros(4)
48
- angles[0] = gaus_ax_pred
49
- angles[1] = gaus_pl_pred - 90
50
- angles[2] = gaus_ro_pred - 30
51
- angles[3] = confidence
52
- return angles
53
-
54
- def get_3angle_infer_aug(origin_img, rm_bkg_img):
55
-
56
- # image = Image.open(image_path).convert('RGB')
57
- image = get_crop_images(origin_img, num=3) + get_crop_images(rm_bkg_img, num=3)
58
- image_inputs = val_preprocess(images = image)
59
- image_inputs['pixel_values'] = torch.from_numpy(np.array(image_inputs['pixel_values'])).to(device)
60
- with torch.no_grad():
61
- dino_pred = dino(image_inputs)
62
-
63
- gaus_ax_pred = torch.argmax(dino_pred[:, 0:360], dim=-1).to(torch.float32)
64
- gaus_pl_pred = torch.argmax(dino_pred[:, 360:360+180], dim=-1).to(torch.float32)
65
- gaus_ro_pred = torch.argmax(dino_pred[:, 360+180:360+180+60], dim=-1).to(torch.float32)
66
-
67
- gaus_ax_pred = remove_outliers_and_average_circular(gaus_ax_pred)
68
- gaus_pl_pred = remove_outliers_and_average(gaus_pl_pred)
69
- gaus_ro_pred = remove_outliers_and_average(gaus_ro_pred)
70
-
71
- confidence = torch.mean(F.softmax(dino_pred[:, -2:], dim=-1), dim=0)[0]
72
- angles = torch.zeros(4)
73
- angles[0] = gaus_ax_pred
74
- angles[1] = gaus_pl_pred - 90
75
- angles[2] = gaus_ro_pred - 30
76
- angles[3] = confidence
77
- return angles
78
-
79
  def infer_func(img, do_rm_bkg, do_infer_aug):
80
  origin_img = Image.fromarray(img)
81
  if do_infer_aug:
82
  rm_bkg_img = background_preprocess(origin_img, True)
83
- angles = get_3angle_infer_aug(origin_img, rm_bkg_img)
84
  else:
85
  rm_bkg_img = background_preprocess(origin_img, do_rm_bkg)
86
- angles = get_3angle(rm_bkg_img)
87
 
88
  phi = np.radians(angles[0])
89
  theta = np.radians(angles[1])
90
  gamma = angles[2]
91
-
92
-
93
- render_axis = render_3D_axis(phi, theta, gamma)
94
- res_img = overlay_images_with_scaling(render_axis, rm_bkg_img)
 
 
95
 
96
  # axis_model = "axis.obj"
97
  return [res_img, round(float(angles[0]), 2), round(float(angles[1]), 2), round(float(angles[2]), 2), round(float(angles[3]), 2)]
@@ -107,10 +62,10 @@ server = gr.Interface(
107
  outputs=[
108
  gr.Image(height=512, width=512, label="result image"),
109
  # gr.Model3D(clear_color=[0.0, 0.0, 0.0, 0.0], label="3D Model"),
110
- gr.Textbox(lines=1, label='Azimuth(0~360°)'),
111
- gr.Textbox(lines=1, label='Polar(-90~90°)'),
112
- gr.Textbox(lines=1, label='Rotation(-90~90°)'),
113
- gr.Textbox(lines=1, label='Confidence(0~1)')
114
  ]
115
  )
116
 
 
1
  import gradio as gr
2
  from paths import *
3
+
4
  from vision_tower import DINOv2_MLP
5
  from transformers import AutoImageProcessor
6
  import torch
7
+ from inference import *
 
 
 
8
  from utils import *
9
 
10
  from huggingface_hub import hf_hub_download
11
+ ckpt_path = hf_hub_download(repo_id="Viglong/Orient-Anything", filename="croplargeEX2/dino_weight.pt", repo_type="model", cache_dir='./', resume_download=True)
12
  print(ckpt_path)
13
 
14
  save_path = './'
 
16
  dino = DINOv2_MLP(
17
  dino_mode = 'large',
18
  in_dim = 1024,
19
+ out_dim = 360+180+180+2,
20
  evaluate = True,
21
  mask_dino = False,
22
  frozen_back = False
23
+ )
24
 
25
  dino.eval()
26
  print('model create')
27
  dino.load_state_dict(torch.load(ckpt_path, map_location='cpu'))
28
+ dino = dino.to(device)
29
  print('weight loaded')
30
  val_preprocess = AutoImageProcessor.from_pretrained(DINO_LARGE, cache_dir='./')
31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  def infer_func(img, do_rm_bkg, do_infer_aug):
33
  origin_img = Image.fromarray(img)
34
  if do_infer_aug:
35
  rm_bkg_img = background_preprocess(origin_img, True)
36
+ angles = get_3angle_infer_aug(origin_img, rm_bkg_img, dino, val_preprocess, device)
37
  else:
38
  rm_bkg_img = background_preprocess(origin_img, do_rm_bkg)
39
+ angles = get_3angle(rm_bkg_img, dino, val_preprocess, device)
40
 
41
  phi = np.radians(angles[0])
42
  theta = np.radians(angles[1])
43
  gamma = angles[2]
44
+ confidence = float(angles[3])
45
+ if confidence > 0.5:
46
+ render_axis = render_3D_axis(phi, theta, gamma)
47
+ res_img = overlay_images_with_scaling(render_axis, rm_bkg_img)
48
+ else:
49
+ res_img = img
50
 
51
  # axis_model = "axis.obj"
52
  return [res_img, round(float(angles[0]), 2), round(float(angles[1]), 2), round(float(angles[2]), 2), round(float(angles[3]), 2)]
 
62
  outputs=[
63
  gr.Image(height=512, width=512, label="result image"),
64
  # gr.Model3D(clear_color=[0.0, 0.0, 0.0, 0.0], label="3D Model"),
65
+ gr.Textbox(lines=1, label='Azimuth(0~360°) represents the position of the viewer in the xy plane'),
66
+ gr.Textbox(lines=1, label='Polar(-90~90°) indicating the height at which the viewer is located'),
67
+ gr.Textbox(lines=1, label='Rotation(-90~90°) represents the angle of rotation of the viewer'),
68
+ gr.Textbox(lines=1, label='Confidence(0~1) indicating whether the object has a meaningful orientation')
69
  ]
70
  )
71
 
inference.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from PIL import Image
3
+ from utils import *
4
+ import torch.nn.functional as F
5
+ import numpy as np
6
+
7
+ def get_3angle(image, dino, val_preprocess, device):
8
+
9
+ # image = Image.open(image_path).convert('RGB')
10
+ image_inputs = val_preprocess(images = image)
11
+ image_inputs['pixel_values'] = torch.from_numpy(np.array(image_inputs['pixel_values'])).to(device)
12
+ with torch.no_grad():
13
+ dino_pred = dino(image_inputs)
14
+
15
+ gaus_ax_pred = torch.argmax(dino_pred[:, 0:360], dim=-1)
16
+ gaus_pl_pred = torch.argmax(dino_pred[:, 360:360+180], dim=-1)
17
+ gaus_ro_pred = torch.argmax(dino_pred[:, 360+180:360+180+180], dim=-1)
18
+ confidence = F.softmax(dino_pred[:, -2:], dim=-1)[0][0]
19
+ angles = torch.zeros(4)
20
+ angles[0] = gaus_ax_pred
21
+ angles[1] = gaus_pl_pred - 90
22
+ angles[2] = gaus_ro_pred - 90
23
+ angles[3] = confidence
24
+ return angles
25
+
26
+ def get_3angle_infer_aug(origin_img, rm_bkg_img, dino, val_preprocess, device):
27
+
28
+ # image = Image.open(image_path).convert('RGB')
29
+ image = get_crop_images(origin_img, num=3) + get_crop_images(rm_bkg_img, num=3)
30
+ image_inputs = val_preprocess(images = image)
31
+ image_inputs['pixel_values'] = torch.from_numpy(np.array(image_inputs['pixel_values'])).to(device)
32
+ with torch.no_grad():
33
+ dino_pred = dino(image_inputs)
34
+
35
+ gaus_ax_pred = torch.argmax(dino_pred[:, 0:360], dim=-1).to(torch.float32)
36
+ gaus_pl_pred = torch.argmax(dino_pred[:, 360:360+180], dim=-1).to(torch.float32)
37
+ gaus_ro_pred = torch.argmax(dino_pred[:, 360+180:360+180+180], dim=-1).to(torch.float32)
38
+
39
+ gaus_ax_pred = remove_outliers_and_average_circular(gaus_ax_pred)
40
+ gaus_pl_pred = remove_outliers_and_average(gaus_pl_pred)
41
+ gaus_ro_pred = remove_outliers_and_average(gaus_ro_pred)
42
+
43
+ confidence = torch.mean(F.softmax(dino_pred[:, -2:], dim=-1), dim=0)[0]
44
+ angles = torch.zeros(4)
45
+ angles[0] = gaus_ax_pred
46
+ angles[1] = gaus_pl_pred - 90
47
+ angles[2] = gaus_ro_pred - 90
48
+ angles[3] = confidence
49
+ return angles