minchul commited on
Commit
602cee8
1 Parent(s): 4e19122

Upload directory

Browse files
Files changed (1) hide show
  1. models/vit_kprpe/__init__.py +65 -0
models/vit_kprpe/__init__.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ..base import BaseModel
2
+ from .vit import VisionTransformerWithKPRPE
3
+ from torchvision import transforms
4
+
5
+
6
+ class ViTKPRPEModel(BaseModel):
7
+
8
+
9
+ """
10
+ Vision Transformer for face recognition model with KeyPoint Relative Position Encoding (KP-RPE).
11
+
12
+ ```
13
+ @article{kim2024keypoint,
14
+ title={KeyPoint Relative Position Encoding for Face Recognition},
15
+ author={Kim, Minchul and Su, Yiyang and Liu, Feng and Jain, Anil and Liu, Xiaoming},
16
+ journal={CVPR},
17
+ year={2024}
18
+ }
19
+ ```
20
+ """
21
+ def __init__(self, net, config):
22
+ super(ViTKPRPEModel, self).__init__(config)
23
+ self.net = net
24
+
25
+
26
+ @classmethod
27
+ def from_config(cls, config):
28
+
29
+ if config.name == 'small':
30
+ net = VisionTransformerWithKPRPE(img_size=112, patch_size=8, num_classes=config.output_dim, embed_dim=512, depth=12,
31
+ mlp_ratio=5, num_heads=8, drop_path_rate=0.1, norm_layer="ln",
32
+ mask_ratio=config.mask_ratio, rpe_config=config.rpe_config)
33
+ elif config.name == 'base':
34
+ net = VisionTransformerWithKPRPE(img_size=112, patch_size=8, num_classes=config.output_dim, embed_dim=512, depth=24,
35
+ mlp_ratio=3, num_heads=16, drop_path_rate=0.1, norm_layer="ln",
36
+ mask_ratio=config.mask_ratio, rpe_config=config.rpe_config)
37
+ else:
38
+ raise NotImplementedError
39
+
40
+ model = cls(net, config)
41
+ model.eval()
42
+ return model
43
+
44
+ def forward(self, x, *args, **kwargs):
45
+ if self.input_color_flip:
46
+ x = x.flip(1)
47
+ return self.net(x, *args, **kwargs)
48
+
49
+ def make_train_transform(self):
50
+ transform = transforms.Compose([
51
+ transforms.ToTensor(),
52
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
53
+ ])
54
+ return transform
55
+
56
+ def make_test_transform(self):
57
+ transform = transforms.Compose([
58
+ transforms.ToTensor(),
59
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
60
+ ])
61
+ return transform
62
+
63
+ def load_model(model_config):
64
+ model = ViTKPRPEModel.from_config(model_config)
65
+ return model