new attempt
Browse files- config.json +3 -4
- convert_t5_checkpoint_to_flax.py +0 -144
- directly_from_t5x/config.json +27 -0
- directly_from_t5x/flax_model.msgpack +3 -0
- flax_model.msgpack +2 -2
config.json
CHANGED
@@ -1,7 +1,6 @@
|
|
1 |
{
|
2 |
-
"_name_or_path": ".",
|
3 |
"architectures": [
|
4 |
-
"
|
5 |
],
|
6 |
"d_ff": 2048,
|
7 |
"d_kv": 64,
|
@@ -20,6 +19,7 @@
|
|
20 |
"num_layers": 6,
|
21 |
"output_past": true,
|
22 |
"pad_token_id": 0,
|
|
|
23 |
"relative_attention_num_buckets": 32,
|
24 |
"task_specific_params": {
|
25 |
"summarization": {
|
@@ -50,8 +50,7 @@
|
|
50 |
"prefix": "translate English to Romanian: "
|
51 |
}
|
52 |
},
|
53 |
-
"
|
54 |
-
"transformers_version": "4.16.2",
|
55 |
"use_cache": true,
|
56 |
"vocab_size": 32128
|
57 |
}
|
|
|
1 |
{
|
|
|
2 |
"architectures": [
|
3 |
+
"T5Model"
|
4 |
],
|
5 |
"d_ff": 2048,
|
6 |
"d_kv": 64,
|
|
|
19 |
"num_layers": 6,
|
20 |
"output_past": true,
|
21 |
"pad_token_id": 0,
|
22 |
+
"relative_attention_max_distance": 128,
|
23 |
"relative_attention_num_buckets": 32,
|
24 |
"task_specific_params": {
|
25 |
"summarization": {
|
|
|
50 |
"prefix": "translate English to Romanian: "
|
51 |
}
|
52 |
},
|
53 |
+
"transformers_version": "4.18.0.dev0",
|
|
|
54 |
"use_cache": true,
|
55 |
"vocab_size": 32128
|
56 |
}
|
convert_t5_checkpoint_to_flax.py
DELETED
@@ -1,144 +0,0 @@
|
|
1 |
-
import argparse
|
2 |
-
|
3 |
-
from t5x import checkpoints
|
4 |
-
from transformers import T5Config, FlaxT5Model
|
5 |
-
|
6 |
-
|
7 |
-
def convert_t5x_checkpoint_to_flax(t5x_checkpoint_path, config_name, flax_dump_folder_path):
|
8 |
-
config = T5Config.from_pretrained(config_name)
|
9 |
-
flax_model = FlaxT5Model(config=config)
|
10 |
-
t5x_model = checkpoints.load_t5x_checkpoint(t5x_checkpoint_path)
|
11 |
-
|
12 |
-
|
13 |
-
#breakpoint()
|
14 |
-
# Encoder
|
15 |
-
for layer_index in range(config.num_layers):
|
16 |
-
layer_name = f"layers_{str(layer_index)}"
|
17 |
-
|
18 |
-
# Self-Attention
|
19 |
-
t5x_attention_key = t5x_model["target"]["encoder"][layer_name]["attention"]["key"]["kernel"]
|
20 |
-
t5x_attention_out = t5x_model["target"]["encoder"][layer_name]["attention"]["out"]["kernel"]
|
21 |
-
t5x_attention_query = t5x_model["target"]["encoder"][layer_name]["attention"]["query"]["kernel"]
|
22 |
-
t5x_attention_value = t5x_model["target"]["encoder"][layer_name]["attention"]["value"]["kernel"]
|
23 |
-
|
24 |
-
## Layer Normalization
|
25 |
-
t5x_attention_layer_norm = t5x_model["target"]["encoder"][layer_name]["pre_attention_layer_norm"]["scale"]
|
26 |
-
|
27 |
-
# MLP
|
28 |
-
#t5x_mlp_wi_0 = t5x_model["target"]["encoder"][layer_name]["mlp"]["wi_0"]["kernel"]
|
29 |
-
#t5x_mlp_wi_1 = t5x_model["target"]["encoder"][layer_name]["mlp"]["wi_1"]["kernel"]
|
30 |
-
t5x_mlp_wi = t5x_model["target"]["encoder"][layer_name]["mlp"]["wi"]["kernel"]
|
31 |
-
t5x_mlp_wo = t5x_model["target"]["encoder"][layer_name]["mlp"]["wo"]["kernel"]
|
32 |
-
|
33 |
-
## Layer Normalization
|
34 |
-
t5x_mlp_layer_norm = t5x_model["target"]["encoder"][layer_name]["pre_mlp_layer_norm"]["scale"]
|
35 |
-
|
36 |
-
# Assigning
|
37 |
-
flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["k"]["kernel"] = t5x_attention_key
|
38 |
-
flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["o"]["kernel"] = t5x_attention_out
|
39 |
-
flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["q"]["kernel"] = t5x_attention_query
|
40 |
-
flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["v"]["kernel"] = t5x_attention_value
|
41 |
-
|
42 |
-
flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["0"]["layer_norm"]["weight"] = t5x_attention_layer_norm
|
43 |
-
|
44 |
-
#flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["1"]["DenseReluDense"]["wi_0"]["kernel"] = t5x_mlp_wi_0
|
45 |
-
#flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["1"]["DenseReluDense"]["wi_1"]["kernel"] = t5x_mlp_wi_1
|
46 |
-
flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["1"]["DenseReluDense"]["wi"]["kernel"] = t5x_mlp_wi
|
47 |
-
flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["1"]["DenseReluDense"]["wo"]["kernel"] = t5x_mlp_wo
|
48 |
-
flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["1"]["layer_norm"]["weight"] = t5x_mlp_layer_norm
|
49 |
-
|
50 |
-
t5x_encoder_norm = t5x_model["target"]["encoder"]["encoder_norm"]["scale"]
|
51 |
-
|
52 |
-
# Only for layer 0:
|
53 |
-
t5x_encoder_rel_embedding = t5x_model["target"]["encoder"]["relpos_bias"]["rel_embedding"]
|
54 |
-
x, y = t5x_encoder_rel_embedding.shape
|
55 |
-
|
56 |
-
# Assigning
|
57 |
-
flax_model.params["encoder"]["block"]["0"]["layer"]["0"]["SelfAttention"]["relative_attention_bias"]["embedding"] = t5x_encoder_rel_embedding.reshape(y, x)
|
58 |
-
flax_model.params["encoder"]["final_layer_norm"]["weight"] = t5x_encoder_norm
|
59 |
-
|
60 |
-
# Decoder
|
61 |
-
for layer_index in range(config.num_layers):
|
62 |
-
layer_name = f"layers_{str(layer_index)}"
|
63 |
-
|
64 |
-
# Self-Attention
|
65 |
-
t5x_attention_key = t5x_model["target"]["decoder"][layer_name]["self_attention"]["key"]["kernel"]
|
66 |
-
t5x_attention_out = t5x_model["target"]["decoder"][layer_name]["self_attention"]["out"]["kernel"]
|
67 |
-
t5x_attention_query = t5x_model["target"]["decoder"][layer_name]["self_attention"]["query"]["kernel"]
|
68 |
-
t5x_attention_value = t5x_model["target"]["decoder"][layer_name]["self_attention"]["value"]["kernel"]
|
69 |
-
|
70 |
-
## Layer Normalization
|
71 |
-
t5x_pre_attention_layer_norm = t5x_model["target"]["decoder"][layer_name]["pre_self_attention_layer_norm"]["scale"]
|
72 |
-
|
73 |
-
# Encoder-Decoder-Attention
|
74 |
-
t5x_enc_dec_attention_key = t5x_model["target"]["decoder"][layer_name]["encoder_decoder_attention"]["key"]["kernel"]
|
75 |
-
t5x_enc_dec_attention_out = t5x_model["target"]["decoder"][layer_name]["encoder_decoder_attention"]["out"]["kernel"]
|
76 |
-
t5x_enc_dec_attention_query = t5x_model["target"]["decoder"][layer_name]["encoder_decoder_attention"]["query"]["kernel"]
|
77 |
-
t5x_enc_dec_attention_value = t5x_model["target"]["decoder"][layer_name]["encoder_decoder_attention"]["value"]["kernel"]
|
78 |
-
|
79 |
-
## Layer Normalization
|
80 |
-
t5x_cross_layer_norm = t5x_model["target"]["decoder"][layer_name]["pre_cross_attention_layer_norm"]["scale"]
|
81 |
-
|
82 |
-
# MLP
|
83 |
-
#t5x_mlp_wi_0 = t5x_model["target"]["decoder"][layer_name]["mlp"]["wi_0"]["kernel"]
|
84 |
-
#t5x_mlp_wi_1 = t5x_model["target"]["decoder"][layer_name]["mlp"]["wi_1"]["kernel"]
|
85 |
-
t5x_mlp_wi = t5x_model["target"]["decoder"][layer_name]["mlp"]["wi"]["kernel"]
|
86 |
-
t5x_mlp_wo = t5x_model["target"]["decoder"][layer_name]["mlp"]["wo"]["kernel"]
|
87 |
-
|
88 |
-
## Layer Normalization
|
89 |
-
tx5_mlp_layer_norm = t5x_model["target"]["decoder"][layer_name]["pre_mlp_layer_norm"]["scale"]
|
90 |
-
|
91 |
-
#Assigning
|
92 |
-
flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["k"]["kernel"] = t5x_attention_key
|
93 |
-
flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["o"]["kernel"] = t5x_attention_out
|
94 |
-
flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["q"]["kernel"] = t5x_attention_query
|
95 |
-
flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["v"]["kernel"] = t5x_attention_value
|
96 |
-
|
97 |
-
flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["0"]["layer_norm"]["weight"] = t5x_pre_attention_layer_norm
|
98 |
-
|
99 |
-
flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["1"]["EncDecAttention"]["k"]["kernel"] = t5x_enc_dec_attention_key
|
100 |
-
flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["1"]["EncDecAttention"]["o"]["kernel"] = t5x_enc_dec_attention_out
|
101 |
-
flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["1"]["EncDecAttention"]["q"]["kernel"] = t5x_enc_dec_attention_query
|
102 |
-
flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["1"]["EncDecAttention"]["v"]["kernel"] = t5x_enc_dec_attention_value
|
103 |
-
|
104 |
-
flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["1"]["layer_norm"]["weight"] = t5x_cross_layer_norm
|
105 |
-
|
106 |
-
#flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["2"]["DenseReluDense"]["wi_0"]["kernel"] = t5x_mlp_wi_0
|
107 |
-
#flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["2"]["DenseReluDense"]["wi_1"]["kernel"] = t5x_mlp_wi_1
|
108 |
-
flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["2"]["DenseReluDense"]["wi"]["kernel"] = t5x_mlp_wi
|
109 |
-
flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["2"]["DenseReluDense"]["wo"]["kernel"] = t5x_mlp_wo
|
110 |
-
|
111 |
-
flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["2"]["layer_norm"]["weight"] = tx5_mlp_layer_norm
|
112 |
-
|
113 |
-
# Decoder Normalization
|
114 |
-
tx5_decoder_norm = t5x_model["target"]["decoder"]["decoder_norm"]["scale"]
|
115 |
-
flax_model.params["decoder"]["final_layer_norm"]["weight"] = tx5_decoder_norm
|
116 |
-
|
117 |
-
# Only for layer 0:
|
118 |
-
t5x_decoder_rel_embedding = t5x_model["target"]["decoder"]["relpos_bias"]["rel_embedding"]
|
119 |
-
x, y = t5x_decoder_rel_embedding.shape
|
120 |
-
|
121 |
-
flax_model.params["decoder"]["block"]["0"]["layer"]["0"]["SelfAttention"]["relative_attention_bias"] = t5x_decoder_rel_embedding.reshape(y, x)
|
122 |
-
|
123 |
-
# Token Embeddings
|
124 |
-
tx5_token_embeddings = t5x_model["target"]["token_embedder"]["embedding"]
|
125 |
-
flax_model.params["shared"]["embedding"] = tx5_token_embeddings
|
126 |
-
|
127 |
-
flax_model.save_pretrained(flax_dump_folder_path)
|
128 |
-
|
129 |
-
|
130 |
-
if __name__ == "__main__":
|
131 |
-
parser = argparse.ArgumentParser()
|
132 |
-
# Required parameters
|
133 |
-
parser.add_argument(
|
134 |
-
"--t5x_checkpoint_path", default=None, type=str, required=True, help="Path the TX5 checkpoint."
|
135 |
-
)
|
136 |
-
parser.add_argument(
|
137 |
-
"--config_name", default=None, type=str, required=True, help="Config name of T5 model."
|
138 |
-
)
|
139 |
-
parser.add_argument(
|
140 |
-
"--flax_dump_folder_path", default=None, type=str, required=True, help="Path to the output FLAX model."
|
141 |
-
)
|
142 |
-
args = parser.parse_args()
|
143 |
-
convert_t5x_checkpoint_to_flax(args.t5x_checkpoint_path, args.config_name, args.flax_dump_folder_path)
|
144 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
directly_from_t5x/config.json
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"architectures": [
|
3 |
+
"T5Model"
|
4 |
+
],
|
5 |
+
"d_ff": 1024,
|
6 |
+
"d_kv": 64,
|
7 |
+
"d_model": 512,
|
8 |
+
"decoder_start_token_id": 0,
|
9 |
+
"dropout_rate": 0.1,
|
10 |
+
"eos_token_id": 1,
|
11 |
+
"feed_forward_proj": "gated-gelu",
|
12 |
+
"initializer_factor": 1.0,
|
13 |
+
"is_encoder_decoder": true,
|
14 |
+
"layer_norm_epsilon": 1e-06,
|
15 |
+
"model_type": "t5",
|
16 |
+
"num_decoder_layers": 8,
|
17 |
+
"num_heads": 6,
|
18 |
+
"num_layers": 8,
|
19 |
+
"pad_token_id": 0,
|
20 |
+
"relative_attention_max_distance": 128,
|
21 |
+
"relative_attention_num_buckets": 32,
|
22 |
+
"tie_word_embeddings": false,
|
23 |
+
"tokenizer_class": "T5Tokenizer",
|
24 |
+
"transformers_version": "4.18.0.dev0",
|
25 |
+
"use_cache": true,
|
26 |
+
"vocab_size": 250112
|
27 |
+
}
|
directly_from_t5x/flax_model.msgpack
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:0598b1d8dba89fd539c9a1776cf0b4f7b3e45c4ac8ec2f498f99c7184c57baa0
|
3 |
+
size 688485886
|
flax_model.msgpack
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
-
size
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:8e3f34ed523c4ea968bd610811cf91e9f68553eceebb03c7cbe8eae03be023f9
|
3 |
+
size 242032202
|