johnnv commited on
Commit
c3d2ad7
·
verified ·
1 Parent(s): fd245c0

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +55 -0
README.md CHANGED
@@ -1,3 +1,58 @@
1
  ---
2
  license: apache-2.0
 
3
  ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
  license: apache-2.0
3
+ pipeline_tag: image-classification
4
  ---
5
+
6
+ Pytorch weights for Kornia ViT converted from the original google JAX vision-transformer repo.
7
+
8
+ Original weights from https://github.com/google-research/vision_transformer: This weight is based on the
9
+ [Original ViT_B/32 pretrained on imagenet21k](https://storage.googleapis.com/vit_models/augreg/B_32-i21k-300ep-lr_0.001-aug_light1-wd_0.1-do_0.0-sd_0.0.npz)
10
+
11
+ Weights converted to PyTorch for Kornia ViT implementation (by [@gau-nernst](https://github.com/gau-nernst) in [kornia/kornia#2786](https://github.com/kornia/kornia/pull/2786#discussion_r1482339811))
12
+ <details>
13
+
14
+ <summary>Convert jax checkpoint function</summary>
15
+
16
+ ```
17
+ def convert_jax_checkpoint(np_state_dict: dict[str, np.ndarray]):
18
+
19
+ def get_weight(key: str) -> torch.Tensor:
20
+ return torch.from_numpy(np_state_dict[key])
21
+
22
+ state_dict = dict()
23
+ state_dict["patch_embedding.cls_token"] = get_weight("cls")
24
+ state_dict["patch_embedding.backbone.weight"] = get_weight("embedding/kernel").permute(3, 2, 0, 1) # conv »
25
+ state_dict["patch_embedding.backbone.bias"] = get_weight("embedding/bias")
26
+ state_dict["patch_embedding.positions"] = get_weight("Transformer/posembed_input/pos_embedding").squeeze(0)
27
+
28
+ # for i, block in enumerate(self.encoder.blocks):
29
+ for i in range(100):
30
+ prefix1 = f"encoder.blocks.{i}"
31
+ prefix2 = f"Transformer/encoderblock_{i}"
32
+
33
+ if f"{prefix2}/LayerNorm_0/scale" not in np_state_dict:
34
+ break
35
+
36
+ state_dict[f"{prefix1}.0.fn.0.weight"] = get_weight(f"{prefix2}/LayerNorm_0/scale")
37
+ state_dict[f"{prefix1}.0.fn.0.bias"] = get_weight(f"{prefix2}/LayerNorm_0/bias")
38
+
39
+ mha_prefix = f"{prefix2}/MultiHeadDotProductAttention_1"
40
+ qkv_weight = [get_weight(f"{mha_prefix}/{x}/kernel") for x in ["query", "key", "value"]]
41
+ qkv_bias = [get_weight(f"{mha_prefix}/{x}/bias") for x in ["query", "key", "value"]]
42
+ state_dict[f"{prefix1}.0.fn.1.qkv.weight"] = torch.cat(qkv_weight, 1).flatten(1).T
43
+ state_dict[f"{prefix1}.0.fn.1.qkv.bias"] = torch.cat(qkv_bias, 0).flatten()
44
+ state_dict[f"{prefix1}.0.fn.1.projection.weight"] = get_weight(f"{mha_prefix}/out/kernel").flatten(0, 1»
45
+ state_dict[f"{prefix1}.0.fn.1.projection.bias"] = get_weight(f"{mha_prefix}/out/bias")
46
+
47
+ state_dict[f"{prefix1}.1.fn.0.weight"] = get_weight(f"{prefix2}/LayerNorm_2/scale")
48
+ state_dict[f"{prefix1}.1.fn.0.bias"] = get_weight(f"{prefix2}/LayerNorm_2/bias")
49
+ state_dict[f"{prefix1}.1.fn.1.0.weight"] = get_weight(f"{prefix2}/MlpBlock_3/Dense_0/kernel").T
50
+ state_dict[f"{prefix1}.1.fn.1.0.bias"] = get_weight(f"{prefix2}/MlpBlock_3/Dense_0/bias")
51
+ state_dict[f"{prefix1}.1.fn.1.3.weight"] = get_weight(f"{prefix2}/MlpBlock_3/Dense_1/kernel").T
52
+ state_dict[f"{prefix1}.1.fn.1.3.bias"] = get_weight(f"{prefix2}/MlpBlock_3/Dense_1/bias")
53
+
54
+ state_dict["norm.weight"] = get_weight("Transformer/encoder_norm/scale")
55
+ state_dict["norm.bias"] = get_weight("Transformer/encoder_norm/bias")
56
+ return state_dict
57
+ ```
58
+ </details>