Ii commited on
Commit
6b68340
·
verified ·
1 Parent(s): f7a3e0f

Delete arcface_onnx.py.txt

Browse files
Files changed (1) hide show
  1. arcface_onnx.py.txt +0 -91
arcface_onnx.py.txt DELETED
@@ -1,91 +0,0 @@
1
- # -*- coding: utf-8 -*-
2
- # @Organization : insightface.ai
3
- # @Author : Jia Guo
4
- # @Time : 2021-05-04
5
- # @Function :
6
-
7
- import numpy as np
8
- import cv2
9
- import onnx
10
- import onnxruntime
11
- import face_align
12
-
13
- __all__ = [
14
- 'ArcFaceONNX',
15
- ]
16
-
17
-
18
- class ArcFaceONNX:
19
- def __init__(self, model_file=None, session=None):
20
- assert model_file is not None
21
- self.model_file = model_file
22
- self.session = session
23
- self.taskname = 'recognition'
24
- find_sub = False
25
- find_mul = False
26
- model = onnx.load(self.model_file)
27
- graph = model.graph
28
- for nid, node in enumerate(graph.node[:8]):
29
- #print(nid, node.name)
30
- if node.name.startswith('Sub') or node.name.startswith('_minus'):
31
- find_sub = True
32
- if node.name.startswith('Mul') or node.name.startswith('_mul'):
33
- find_mul = True
34
- if find_sub and find_mul:
35
- #mxnet arcface model
36
- input_mean = 0.0
37
- input_std = 1.0
38
- else:
39
- input_mean = 127.5
40
- input_std = 127.5
41
- self.input_mean = input_mean
42
- self.input_std = input_std
43
- #print('input mean and std:', self.input_mean, self.input_std)
44
- if self.session is None:
45
- self.session = onnxruntime.InferenceSession(self.model_file, providers=['CoreMLExecutionProvider','CUDAExecutionProvider'])
46
- input_cfg = self.session.get_inputs()[0]
47
- input_shape = input_cfg.shape
48
- input_name = input_cfg.name
49
- self.input_size = tuple(input_shape[2:4][::-1])
50
- self.input_shape = input_shape
51
- outputs = self.session.get_outputs()
52
- output_names = []
53
- for out in outputs:
54
- output_names.append(out.name)
55
- self.input_name = input_name
56
- self.output_names = output_names
57
- assert len(self.output_names)==1
58
- self.output_shape = outputs[0].shape
59
-
60
- def prepare(self, ctx_id, **kwargs):
61
- if ctx_id<0:
62
- self.session.set_providers(['CPUExecutionProvider'])
63
-
64
- def get(self, img, kps):
65
- aimg = face_align.norm_crop(img, landmark=kps, image_size=self.input_size[0])
66
- embedding = self.get_feat(aimg).flatten()
67
- return embedding
68
-
69
- def compute_sim(self, feat1, feat2):
70
- from numpy.linalg import norm
71
- feat1 = feat1.ravel()
72
- feat2 = feat2.ravel()
73
- sim = np.dot(feat1, feat2) / (norm(feat1) * norm(feat2))
74
- return sim
75
-
76
- def get_feat(self, imgs):
77
- if not isinstance(imgs, list):
78
- imgs = [imgs]
79
- input_size = self.input_size
80
-
81
- blob = cv2.dnn.blobFromImages(imgs, 1.0 / self.input_std, input_size,
82
- (self.input_mean, self.input_mean, self.input_mean), swapRB=True)
83
- net_out = self.session.run(self.output_names, {self.input_name: blob})[0]
84
- return net_out
85
-
86
- def forward(self, batch_data):
87
- blob = (batch_data - self.input_mean) / self.input_std
88
- net_out = self.session.run(self.output_names, {self.input_name: blob})[0]
89
- return net_out
90
-
91
-