zhang-ziang committed on
Commit 738bdfa · 1 Parent(s): 0366edb

render axis

Files changed (2)
  1. app.py +5 -223
  2. utils.py +290 -0
app.py CHANGED
@@ -8,11 +8,9 @@ import os
  import matplotlib.pyplot as plt
  import io
  from PIL import Image
- import random
- import rembg
- from typing import Any
- import torch.nn.functional as F
 
+ import torch.nn.functional as F
+ from utils import *
 
  from huggingface_hub import hf_hub_download
  ckpt_path = hf_hub_download(repo_id="Viglong/OriNet", filename="celarge/dino_weight.pt", repo_type="model", cache_dir='./', resume_download=True)
@@ -35,99 +33,6 @@ dino.load_state_dict(torch.load(ckpt_path, map_location='cpu'))
  print('weight loaded')
  val_preprocess = AutoImageProcessor.from_pretrained(DINO_LARGE, cache_dir='./')
 
- def background_preprocess(input_image, do_remove_background):
-
-     rembg_session = rembg.new_session() if do_remove_background else None
-
-     if do_remove_background:
-         input_image = remove_background(input_image, rembg_session)
-         input_image = resize_foreground(input_image, 0.85)
-
-     return input_image
-
- def resize_foreground(
-     image: Image,
-     ratio: float,
- ) -> Image:
-     image = np.array(image)
-     assert image.shape[-1] == 4
-     alpha = np.where(image[..., 3] > 0)
-     y1, y2, x1, x2 = (
-         alpha[0].min(),
-         alpha[0].max(),
-         alpha[1].min(),
-         alpha[1].max(),
-     )
-     # crop the foreground
-     fg = image[y1:y2, x1:x2]
-     # pad to square
-     size = max(fg.shape[0], fg.shape[1])
-     ph0, pw0 = (size - fg.shape[0]) // 2, (size - fg.shape[1]) // 2
-     ph1, pw1 = size - fg.shape[0] - ph0, size - fg.shape[1] - pw0
-     new_image = np.pad(
-         fg,
-         ((ph0, ph1), (pw0, pw1), (0, 0)),
-         mode="constant",
-         constant_values=((0, 0), (0, 0), (0, 0)),
-     )
-
-     # compute padding according to the ratio
-     new_size = int(new_image.shape[0] / ratio)
-     # pad to size, double side
-     ph0, pw0 = (new_size - size) // 2, (new_size - size) // 2
-     ph1, pw1 = new_size - size - ph0, new_size - size - pw0
-     new_image = np.pad(
-         new_image,
-         ((ph0, ph1), (pw0, pw1), (0, 0)),
-         mode="constant",
-         constant_values=((0, 0), (0, 0), (0, 0)),
-     )
-     new_image = Image.fromarray(new_image)
-     return new_image
-
- def remove_background(image: Image,
-                       rembg_session: Any = None,
-                       force: bool = False,
-                       **rembg_kwargs,
- ) -> Image:
-     do_remove = True
-     if image.mode == "RGBA" and image.getextrema()[3][0] < 255:
-         do_remove = False
-     do_remove = do_remove or force
-     if do_remove:
-         image = rembg.remove(image, session=rembg_session, **rembg_kwargs)
-     return image
-
- def random_crop(image, crop_scale=(0.8, 0.95)):
-     """
-     Randomly crop the image.
-     image (PIL.Image.Image): input image.
-     crop_scale (tuple): (min_scale, max_scale).
-     """
-     assert isinstance(image, Image.Image), "input must be PIL.Image.Image"
-     assert len(crop_scale) == 2 and 0 < crop_scale[0] <= crop_scale[1] <= 1
-
-     width, height = image.size
-
-     # compute the crop width and height
-     crop_width = random.randint(int(width * crop_scale[0]), int(width * crop_scale[1]))
-     crop_height = random.randint(int(height * crop_scale[0]), int(height * crop_scale[1]))
-
-     # randomly choose the top-left corner of the crop
-     left = random.randint(0, width - crop_width)
-     top = random.randint(0, height - crop_height)
-
-     # crop the image
-     cropped_image = image.crop((left, top, left + crop_width, top + crop_height))
-
-     return cropped_image
-
- def get_crop_images(img, num=3):
-     cropped_images = []
-     for i in range(num):
-         cropped_images.append(random_crop(img))
-     return cropped_images
-
 
  def get_3angle(image):
 
@@ -148,68 +53,6 @@ def get_3angle(image):
  angles[3] = confidence
  return angles
 
- def remove_outliers_and_average(tensor, threshold=1.5):
-     assert tensor.dim() == 1, "input tensor must be 1-dimensional"
-
-     q1 = torch.quantile(tensor, 0.25)
-     q3 = torch.quantile(tensor, 0.75)
-     iqr = q3 - q1
-
-     lower_bound = q1 - threshold * iqr
-     upper_bound = q3 + threshold * iqr
-
-     non_outliers = tensor[(tensor >= lower_bound) & (tensor <= upper_bound)]
-
-     if len(non_outliers) == 0:
-         return tensor.mean().item()
-
-     return non_outliers.mean().item()
-
-
- def remove_outliers_and_average_circular(tensor, threshold=1.5):
-     assert tensor.dim() == 1, "input tensor must be 1-dimensional"
-
-     # convert the angles to points on the unit circle
-     radians = tensor * torch.pi / 180.0
-     x_coords = torch.cos(radians)
-     y_coords = torch.sin(radians)
-
-     # compute the mean vector
-     mean_x = torch.mean(x_coords)
-     mean_y = torch.mean(y_coords)
-
-     differences = torch.sqrt((x_coords - mean_x) * (x_coords - mean_x) + (y_coords - mean_y) * (y_coords - mean_y))
-
-     # compute the quartiles and the IQR
-     q1 = torch.quantile(differences, 0.25)
-     q3 = torch.quantile(differences, 0.75)
-     iqr = q3 - q1
-
-     # compute the lower and upper bounds
-     lower_bound = q1 - threshold * iqr
-     upper_bound = q3 + threshold * iqr
-
-     # keep the non-outlier angles
-     non_outliers = tensor[(differences >= lower_bound) & (differences <= upper_bound)]
-
-     if len(non_outliers) == 0:
-         # every angle is an outlier: fall back to the circular mean of all angles
-         mean_angle = torch.atan2(mean_y, mean_x) * 180.0 / torch.pi
-         mean_angle = (mean_angle + 360) % 360
-         return mean_angle
-
-     # recompute the mean vector over the non-outliers
-     radians = non_outliers * torch.pi / 180.0
-     x_coords = torch.cos(radians)
-     y_coords = torch.sin(radians)
-
-     mean_x = torch.mean(x_coords)
-     mean_y = torch.mean(y_coords)
-
-     mean_angle = torch.atan2(mean_y, mean_x) * 180.0 / torch.pi
-     mean_angle = (mean_angle + 360) % 360
-
-     return mean_angle
-
  def get_3angle_infer_aug(origin_img, rm_bkg_img):
 
      # image = Image.open(image_path).convert('RGB')
@@ -235,29 +78,6 @@ def get_3angle_infer_aug(origin_img, rm_bkg_img):
  angles[3] = confidence
  return angles
 
- def scale(x):
-     # print(x)
-     # if abs(x[0])<0.1 and abs(x[1])<0.1:
-     #     return x*5
-     # else:
-     #     return x
-     return x*3
-
- def get_proj2D_XYZ(phi, theta, gamma):
-     x = np.array([-1*np.sin(phi)*np.cos(gamma) - np.cos(phi)*np.sin(theta)*np.sin(gamma), np.sin(phi)*np.sin(gamma) - np.cos(phi)*np.sin(theta)*np.cos(gamma)])
-     y = np.array([-1*np.cos(phi)*np.cos(gamma) + np.sin(phi)*np.sin(theta)*np.sin(gamma), np.cos(phi)*np.sin(gamma) + np.sin(phi)*np.sin(theta)*np.cos(gamma)])
-     z = np.array([np.cos(theta)*np.sin(gamma), np.cos(theta)*np.cos(gamma)])
-     x = scale(x)
-     y = scale(y)
-     z = scale(z)
-     return x, y, z
-
- # draw one 3D coordinate axis as a 2D arrow
- def draw_axis(ax, origin, vector, color, label=None):
-     ax.quiver(origin[0], origin[1], vector[0], vector[1], angles='xy', scale_units='xy', scale=1, color=color)
-     if label is not None:
-         ax.text(origin[0] + vector[0] * 1.1, origin[1] + vector[1] * 1.1, label, color=color, fontsize=12)
 
  def figure_to_img(fig):
      with io.BytesIO() as buf:
@@ -275,52 +95,14 @@ def infer_func(img, do_rm_bkg, do_infer_aug):
      rm_bkg_img = background_preprocess(origin_img, do_rm_bkg)
      angles = get_3angle(rm_bkg_img)
 
-     fig, ax = plt.subplots(figsize=(8, 8))
-
-     w, h = rm_bkg_img.size
-     if h > w:
-         extent = [-5*w/h, 5*w/h, -5, 5]
-     else:
-         extent = [-5, 5, -5*h/w, 5*h/w]
-     ax.imshow(rm_bkg_img, extent=extent, zorder=0, aspect='auto')  # extent sets the display range of the image
-
-     origin = np.array([0, 0])
-
-     # set the rotation angles
      phi = np.radians(angles[0])
      theta = np.radians(angles[1])
-     gamma = np.radians(-1*angles[2])
-
-     # rotated axis vectors
-     rot_x, rot_y, rot_z = get_proj2D_XYZ(phi, theta, gamma)
-
-     # draw arrows
-     arrow_attr = [{'point':rot_x, 'color':'r', 'label':'front'},
-                   {'point':rot_y, 'color':'g', 'label':'right'},
-                   {'point':rot_z, 'color':'b', 'label':'top'}]
+     gamma = angles[2]
 
-     if phi > 45 and phi <= 225:
-         order = [0,1,2]
-     elif phi > 225 and phi < 315:
-         order = [2,0,1]
-     else:
-         order = [2,1,0]
 
-     for i in range(3):
-         draw_axis(ax, origin, arrow_attr[order[i]]['point'], arrow_attr[order[i]]['color'], arrow_attr[order[i]]['label'])
-     # draw_axis(ax, origin, rot_y, 'g', label='right')
-     # draw_axis(ax, origin, rot_z, 'b', label='top')
-     # draw_axis(ax, origin, rot_x, 'r', label='front')
-
-     # turn off the axes and grid
-     ax.set_axis_off()
-     ax.grid(False)
-
-     # set the coordinate range
-     ax.set_xlim(-5, 5)
-     ax.set_ylim(-5, 5)
+     render_axis = render_3D_axis(phi, theta, gamma)
+     res_img = overlay_images_with_scaling(render_axis, rm_bkg_img)
 
-     res_img = figure_to_img(fig)
      # axis_model = "axis.obj"
      return [res_img, round(float(angles[0]), 2), round(float(angles[1]), 2), round(float(angles[2]), 2), round(float(angles[3]), 2)]
 
utils.py ADDED
@@ -0,0 +1,290 @@
+ import rembg
+ import random
+ import torch
+ import numpy as np
+ from PIL import Image
+ import PIL
+ from typing import Any
+ import matplotlib.pyplot as plt
+
+ def resize_foreground(
+     image: Image,
+     ratio: float,
+ ) -> Image:
+     image = np.array(image)
+     assert image.shape[-1] == 4
+     alpha = np.where(image[..., 3] > 0)
+     y1, y2, x1, x2 = (
+         alpha[0].min(),
+         alpha[0].max(),
+         alpha[1].min(),
+         alpha[1].max(),
+     )
+     # crop the foreground
+     fg = image[y1:y2, x1:x2]
+     # pad to square
+     size = max(fg.shape[0], fg.shape[1])
+     ph0, pw0 = (size - fg.shape[0]) // 2, (size - fg.shape[1]) // 2
+     ph1, pw1 = size - fg.shape[0] - ph0, size - fg.shape[1] - pw0
+     new_image = np.pad(
+         fg,
+         ((ph0, ph1), (pw0, pw1), (0, 0)),
+         mode="constant",
+         constant_values=((0, 0), (0, 0), (0, 0)),
+     )
+
+     # compute padding according to the ratio
+     new_size = int(new_image.shape[0] / ratio)
+     # pad to size, double side
+     ph0, pw0 = (new_size - size) // 2, (new_size - size) // 2
+     ph1, pw1 = new_size - size - ph0, new_size - size - pw0
+     new_image = np.pad(
+         new_image,
+         ((ph0, ph1), (pw0, pw1), (0, 0)),
+         mode="constant",
+         constant_values=((0, 0), (0, 0), (0, 0)),
+     )
+     new_image = Image.fromarray(new_image)
+     return new_image
+
+ def remove_background(image: Image,
+                       rembg_session: Any = None,
+                       force: bool = False,
+                       **rembg_kwargs,
+ ) -> Image:
+     do_remove = True
+     if image.mode == "RGBA" and image.getextrema()[3][0] < 255:
+         # the image already has a meaningful alpha channel, so skip rembg
+         do_remove = False
+     do_remove = do_remove or force
+     if do_remove:
+         image = rembg.remove(image, session=rembg_session, **rembg_kwargs)
+     return image
+
+ def random_crop(image, crop_scale=(0.8, 0.95)):
+     """
+     Randomly crop the image.
+     image (PIL.Image.Image): input image.
+     crop_scale (tuple): (min_scale, max_scale).
+     """
+     assert isinstance(image, Image.Image), "input must be PIL.Image.Image"
+     assert len(crop_scale) == 2 and 0 < crop_scale[0] <= crop_scale[1] <= 1
+
+     width, height = image.size
+
+     # compute the crop width and height
+     crop_width = random.randint(int(width * crop_scale[0]), int(width * crop_scale[1]))
+     crop_height = random.randint(int(height * crop_scale[0]), int(height * crop_scale[1]))
+
+     # randomly choose the top-left corner of the crop
+     left = random.randint(0, width - crop_width)
+     top = random.randint(0, height - crop_height)
+
+     # crop the image
+     cropped_image = image.crop((left, top, left + crop_width, top + crop_height))
+
+     return cropped_image
+
+ def get_crop_images(img, num=3):
+     cropped_images = []
+     for i in range(num):
+         cropped_images.append(random_crop(img))
+     return cropped_images
+
+ def background_preprocess(input_image, do_remove_background):
+
+     rembg_session = rembg.new_session() if do_remove_background else None
+
+     if do_remove_background:
+         input_image = remove_background(input_image, rembg_session)
+         input_image = resize_foreground(input_image, 0.85)
+
+     return input_image
+
+ def remove_outliers_and_average(tensor, threshold=1.5):
+     assert tensor.dim() == 1, "input tensor must be 1-dimensional"
+
+     q1 = torch.quantile(tensor, 0.25)
+     q3 = torch.quantile(tensor, 0.75)
+     iqr = q3 - q1
+
+     lower_bound = q1 - threshold * iqr
+     upper_bound = q3 + threshold * iqr
+
+     non_outliers = tensor[(tensor >= lower_bound) & (tensor <= upper_bound)]
+
+     if len(non_outliers) == 0:
+         return tensor.mean().item()
+
+     return non_outliers.mean().item()
+
+
+ def remove_outliers_and_average_circular(tensor, threshold=1.5):
+     assert tensor.dim() == 1, "input tensor must be 1-dimensional"
+
+     # convert the angles to points on the unit circle
+     radians = tensor * torch.pi / 180.0
+     x_coords = torch.cos(radians)
+     y_coords = torch.sin(radians)
+
+     # compute the mean vector
+     mean_x = torch.mean(x_coords)
+     mean_y = torch.mean(y_coords)
+
+     differences = torch.sqrt((x_coords - mean_x) * (x_coords - mean_x) + (y_coords - mean_y) * (y_coords - mean_y))
+
+     # compute the quartiles and the IQR
+     q1 = torch.quantile(differences, 0.25)
+     q3 = torch.quantile(differences, 0.75)
+     iqr = q3 - q1
+
+     # compute the lower and upper bounds
+     lower_bound = q1 - threshold * iqr
+     upper_bound = q3 + threshold * iqr
+
+     # keep the non-outlier angles
+     non_outliers = tensor[(differences >= lower_bound) & (differences <= upper_bound)]
+
+     if len(non_outliers) == 0:
+         # every angle is an outlier: fall back to the circular mean of all angles
+         mean_angle = torch.atan2(mean_y, mean_x) * 180.0 / torch.pi
+         mean_angle = (mean_angle + 360) % 360
+         return mean_angle
+
+     # recompute the mean vector over the non-outliers
+     radians = non_outliers * torch.pi / 180.0
+     x_coords = torch.cos(radians)
+     y_coords = torch.sin(radians)
+
+     mean_x = torch.mean(x_coords)
+     mean_y = torch.mean(y_coords)
+
+     mean_angle = torch.atan2(mean_y, mean_x) * 180.0 / torch.pi
+     mean_angle = (mean_angle + 360) % 360
+
+     return mean_angle
+
+ def scale(x):
+     # print(x)
+     # if abs(x[0])<0.1 and abs(x[1])<0.1:
+     #     return x*5
+     # else:
+     #     return x
+     return x*3
+
+ def get_proj2D_XYZ(phi, theta, gamma):
+     x = np.array([-1*np.sin(phi)*np.cos(gamma) - np.cos(phi)*np.sin(theta)*np.sin(gamma), np.sin(phi)*np.sin(gamma) - np.cos(phi)*np.sin(theta)*np.cos(gamma)])
+     y = np.array([-1*np.cos(phi)*np.cos(gamma) + np.sin(phi)*np.sin(theta)*np.sin(gamma), np.cos(phi)*np.sin(gamma) + np.sin(phi)*np.sin(theta)*np.cos(gamma)])
+     z = np.array([np.cos(theta)*np.sin(gamma), np.cos(theta)*np.cos(gamma)])
+     x = scale(x)
+     y = scale(y)
+     z = scale(z)
+     return x, y, z
+
+ # draw one 3D coordinate axis as a 2D arrow
+ def draw_axis(ax, origin, vector, color, label=None):
+     ax.quiver(origin[0], origin[1], vector[0], vector[1], angles='xy', scale_units='xy', scale=1, color=color)
+     if label is not None:
+         ax.text(origin[0] + vector[0] * 1.1, origin[1] + vector[1] * 1.1, label, color=color, fontsize=12)
+
+ def matplotlib_2D_arrow(angles, rm_bkg_img):
+     fig, ax = plt.subplots(figsize=(8, 8))
+
+     # set the rotation angles
+     phi = np.radians(angles[0])
+     theta = np.radians(angles[1])
+     gamma = np.radians(-1*angles[2])
+
+     w, h = rm_bkg_img.size
+     if h > w:
+         extent = [-5*w/h, 5*w/h, -5, 5]
+     else:
+         extent = [-5, 5, -5*h/w, 5*h/w]
+     ax.imshow(rm_bkg_img, extent=extent, zorder=0, aspect='auto')  # extent sets the display range of the image
+
+     origin = np.array([0, 0])
+
+     # rotated axis vectors
+     rot_x, rot_y, rot_z = get_proj2D_XYZ(phi, theta, gamma)
+
+     # draw arrows
+     arrow_attr = [{'point':rot_x, 'color':'r', 'label':'front'},
+                   {'point':rot_y, 'color':'g', 'label':'right'},
+                   {'point':rot_z, 'color':'b', 'label':'top'}]
+
+     if phi > 45 and phi <= 225:
+         order = [0,1,2]
+     elif phi > 225 and phi < 315:
+         order = [2,0,1]
+     else:
+         order = [2,1,0]
+
+     for i in range(3):
+         draw_axis(ax, origin, arrow_attr[order[i]]['point'], arrow_attr[order[i]]['color'], arrow_attr[order[i]]['label'])
+     # draw_axis(ax, origin, rot_y, 'g', label='right')
+     # draw_axis(ax, origin, rot_z, 'b', label='top')
+     # draw_axis(ax, origin, rot_x, 'r', label='front')
+
+     # turn off the axes and grid
+     ax.set_axis_off()
+     ax.grid(False)
+
+     # set the coordinate range
+     ax.set_xlim(-5, 5)
+     ax.set_ylim(-5, 5)
+
+ from render import render, Model
+ import math
+ def render_3D_axis(phi, theta, gamma):
+     radius = 240
+     # camera_location = [radius * math.cos(phi), radius * math.sin(phi), radius * math.tan(theta)]
+     # print(camera_location)
+     camera_location = [-1*radius * math.cos(phi), -1*radius * math.tan(theta), radius * math.sin(phi)]
+     img = render(
+         # Model("res/jinx.obj", texture_filename="res/jinx.tga"),
+         Model("./axis.obj", texture_filename="./axis.png"),
+         height=512,
+         width=512,
+         filename="tmp_render.png",
+         cam_loc=camera_location
+     )
+     img = img.rotate(gamma)
+     return img
+
+ def overlay_images_with_scaling(center_image: Image.Image, background_image, target_size=(512, 512)):
+     """
+     Resize the foreground image to 512x512, scale the background image to cover it, and overlay them center-aligned.
+     :param center_image: foreground image
+     :param background_image: background image
+     :param target_size: target size of the foreground image, default (512, 512)
+     :return: the composited image
+     """
+     # make sure both images are in RGBA mode
+     if center_image.mode != "RGBA":
+         center_image = center_image.convert("RGBA")
+     if background_image.mode != "RGBA":
+         background_image = background_image.convert("RGBA")
+
+     # resize the foreground image
+     center_image = center_image.resize(target_size)
+
+     # scale the background image so that it covers the foreground size
+     bg_width, bg_height = background_image.size
+     target_width, target_height = target_size
+
+     # scale the background proportionally by width or height
+     scale = max(target_width / bg_width, target_height / bg_height)
+     new_size = (int(bg_width * scale), int(bg_height * scale))
+     resized_background = background_image.resize(new_size)
+
+     # center-crop the background to the target size
+     left = (new_size[0] - target_width) // 2
+     top = (new_size[1] - target_height) // 2
+     right = left + target_width
+     bottom = top + target_height
+     cropped_background = resized_background.crop((left, top, right, bottom))
+
+     # paste the foreground onto the background, using its alpha channel as the mask
+     result = cropped_background.copy()
+     result.paste(center_image, (0, 0), mask=center_image)
+
+     return result
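
For reference, a minimal sketch of how the helpers added in utils.py are consumed after this commit, mirroring the call order in app.py's infer_func above. The input and output paths and the direct (non-Gradio) invocation are assumptions for illustration; get_3angle is defined in app.py, so the snippet assumes it runs in that module, where `from utils import *` is already in effect.

# Sketch only: assumed to run inside app.py after this commit.
import numpy as np
from PIL import Image

origin_img = Image.open("example.png").convert("RGB")      # placeholder input path
rm_bkg_img = background_preprocess(origin_img, True)       # optional rembg removal + foreground resize
angles = get_3angle(rm_bkg_img)                             # model prediction: three angles plus confidence (angles[3])
phi = np.radians(float(angles[0]))
theta = np.radians(float(angles[1]))
gamma = float(angles[2])                                    # kept in degrees; render_3D_axis rotates the render by it
axis_img = render_3D_axis(phi, theta, gamma)                # renders ./axis.obj from the estimated viewpoint
result = overlay_images_with_scaling(axis_img, rm_bkg_img)  # 512x512 overlay of the axis render on the input
result.save("axis_overlay.png")                             # placeholder output path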