import math
import re
from collections import OrderedDict

import torch
import torch.nn.functional as F
from transformers import GPT2Config, GPTBigCodeConfig, PretrainedConfig


def remap_state_dict_hf_bigcode(state_dict, config: PretrainedConfig):
    """
    Map the state_dict of a Huggingface BigCode model to be flash_attn compatible.
    """

    # Word embedding and position embedding
    def key_mapping_pos_emb(key):
        return re.sub(r"^transformer.wpe.", "transformer.embeddings.position_embeddings.", key)

    state_dict = OrderedDict((key_mapping_pos_emb(k), v) for k, v in state_dict.items())
    word_embeddings = state_dict.pop("transformer.wte.weight")
    # It's possible that vocab_size is padded to be a multiple of 8, for example.
    pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
    vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
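    # Illustrative numbers (not from any specific checkpoint): with config.vocab_size = 50257
    # and pad_vocab_size_multiple = 8, this rounds up to math.ceil(50257 / 8) * 8 = 50264,
    # so the embedding matrix below gains 7 zero rows of padding.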
    state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad(
        word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
    )
    state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"]

    # LayerNorm
    def key_mapping_ln(key):
        key = re.sub(r"^transformer.ln_f.(weight|bias)", r"transformer.ln_f.\1", key)
        key = re.sub(
            r"^transformer.h.(\d+).ln_(1|2).(weight|bias)",
            r"transformer.layers.\1.norm\2.\3",
            key,
        )
        return key

    state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())
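    # e.g. "transformer.h.0.ln_1.weight" -> "transformer.layers.0.norm1.weight"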

    def key_mapping_mlp(key):
        key = re.sub(
            r"^transformer.h.(\d+).mlp.c_fc.weight",
            r"transformer.layers.\1.mlp.fc1.weight",
            key,
        )
        key = re.sub(
            r"^transformer.h.(\d+).mlp.c_proj.weight",
            r"transformer.layers.\1.mlp.fc2.weight",
            key,
        )
        key = re.sub(
            r"^transformer.h.(\d+).mlp.c_fc.bias",
            r"transformer.layers.\1.mlp.fc1.bias",
            key,
        )
        key = re.sub(
            r"^transformer.h.(\d+).mlp.c_proj.bias",
            r"transformer.layers.\1.mlp.fc2.bias",
            key,
        )
        return key

    state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())

    # TODO: add support for multi-head attention
    assert config.multi_query, "Only multi-query attention is supported"
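    # In multi-query attention, all query heads share a single key head and a single value
    # head, so the checkpoint stores only one head_dim-sized k and one head_dim-sized v projection.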

    # Attention
    for d in range(config.num_hidden_layers):
        embed_dim = config.n_embd
        head_dim = embed_dim // config.n_head

        c_attn_weight = state_dict.pop(f"transformer.h.{d}.attn.c_attn.weight")
        # With multi-query attention, c_attn.weight has shape
        # ((n_head + 2) * head_dim, embed_dim) = (embed_dim + 2 * head_dim, embed_dim):
        # n_head query projections followed by one shared key and one shared value projection.
        # see https://github.com/huggingface/transformers/blob/95b374952dc27d8511541d6f5a4e22c9ec11fb24/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py#L112
        # see also https://github.com/ggerganov/ggml/blob/dd1d575956e54c5bdc07632f25506b3b1884dbd2/examples/starcoder/convert-hf-to-ggml.py#L183
        # It is expanded here to (3 * n_head * head_dim, embed_dim) = (3 * embed_dim, embed_dim).
        q, k, v = torch.split(c_attn_weight, [embed_dim, head_dim, head_dim], dim=0)
        # duplicate k, v along the first axis: (head_dim, embed_dim) -> (n_head * head_dim, embed_dim)
        k = torch.tile(k, (config.n_head, 1))
        v = torch.tile(v, (config.n_head, 1))
        state_dict[f"transformer.layers.{d}.mixer.Wqkv.weight"] = torch.cat((q, k, v), dim=0)

        # same deal with the bias
        c_attn_bias = state_dict.pop(f"transformer.h.{d}.attn.c_attn.bias")
        # c_attn.bias has shape ((n_head + 2) * head_dim,) = (embed_dim + 2 * head_dim,)
        q, k, v = torch.split(c_attn_bias, [embed_dim, head_dim, head_dim], dim=0)
        # duplicate k, v: (head_dim,) -> (n_head * head_dim,)
        k = torch.tile(k, (config.n_head,))
        v = torch.tile(v, (config.n_head,))
        state_dict[f"transformer.layers.{d}.mixer.Wqkv.bias"] = torch.cat((q, k, v), dim=0)

    def key_mapping_attn(key):
        key = re.sub(
            r"^transformer.h.(\d+).attn.c_proj.weight",
            r"transformer.layers.\1.mixer.out_proj.weight",
            key,
        )
        key = re.sub(
            r"^transformer.h.(\d+).attn.c_proj.bias",
            r"transformer.layers.\1.mixer.out_proj.bias",
            key,
        )
        return key

    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())

    return state_dict


def inv_remap_state_dict_hf_bigcode(state_dict, config: PretrainedConfig):
    """
    Map the state_dict of a flash_attn model to be Huggingface BigCode compatible.

    This function is meant to be the inverse of remap_state_dict_hf_bigcode.
    """

    # Word embedding and position embeddings
    def inv_key_mapping_pos_emb(key):
        return re.sub(r"^transformer.embeddings.position_embeddings.", "transformer.wpe.", key)

    state_dict = OrderedDict((inv_key_mapping_pos_emb(k), v) for k, v in state_dict.items())
    word_embeddings = state_dict.pop("transformer.embeddings.word_embeddings.weight")

    # The forward remap pads along the vocab (row) dimension, so truncate rows back to
    # the original vocabulary size here.
    word_embeddings = word_embeddings[: config.vocab_size, :]
    state_dict["transformer.wte.weight"] = word_embeddings
    state_dict["lm_head.weight"] = word_embeddings

    # LayerNorm
    def inv_key_mapping_ln(key):
        key = re.sub(r"^transformer.ln_f.(weight|bias)", r"transformer.ln_f.\1", key)
        key = re.sub(
            r"^transformer.layers.(\d+).norm(1|2).(weight|bias)",
            r"transformer.h.\1.ln_\2.\3",
            key,
        )
        return key

    state_dict = OrderedDict((inv_key_mapping_ln(k), v) for k, v in state_dict.items())

    # MLPs
    def inv_key_mapping_mlp(key):
        key = re.sub(
            r"^transformer.layers.(\d+).mlp.fc1.weight",
            r"transformer.h.\1.mlp.c_fc.weight",
            key,
        )
        key = re.sub(
            r"^transformer.layers.(\d+).mlp.fc2.weight",
            r"transformer.h.\1.mlp.c_proj.weight",
            key,
        )
        key = re.sub(
            r"^transformer.layers.(\d+).mlp.fc1.bias",
            r"transformer.h.\1.mlp.c_fc.bias",
            key,
        )
        key = re.sub(
            r"^transformer.layers.(\d+).mlp.fc2.bias",
            r"transformer.h.\1.mlp.c_proj.bias",
            key,
        )
        return key

    state_dict = OrderedDict((inv_key_mapping_mlp(k), v) for k, v in state_dict.items())

    # Attention
    for d in range(config.num_hidden_layers):
        embed_dim = config.n_embd
        head_dim = embed_dim // config.n_head

        Wqkv_weight = state_dict.pop(f"transformer.layers.{d}.mixer.Wqkv.weight")
        q, k, v = torch.split(
            Wqkv_weight, [embed_dim, head_dim * config.n_head, head_dim * config.n_head], dim=0
        )
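        # k and v were tiled config.n_head times in remap_state_dict_hf_bigcode, so every
        # head_dim-sized block is identical; the first block recovers the original shared
        # key/value projection.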
        c_attn_weight = torch.cat((q, k[:head_dim], v[:head_dim]), dim=0)
        state_dict[f"transformer.h.{d}.attn.c_attn.weight"] = c_attn_weight

        # Same deal with the bias
        Wqkv_bias = state_dict.pop(f"transformer.layers.{d}.mixer.Wqkv.bias")
        q, k, v = torch.split(
            Wqkv_bias, [embed_dim, head_dim * config.n_head, head_dim * config.n_head], dim=0
        )
        c_attn_bias = torch.cat((q, k[:head_dim], v[:head_dim]), dim=0)
        state_dict[f"transformer.h.{d}.attn.c_attn.bias"] = c_attn_bias

    def inv_key_mapping_attn(key):
        key = re.sub(
            r"^transformer.layers.(\d+).mixer.out_proj.weight",
            r"transformer.h.\1.attn.c_proj.weight",
            key,
        )
        key = re.sub(
            r"^transformer.layers.(\d+).mixer.out_proj.bias",
            r"transformer.h.\1.attn.c_proj.bias",
            key,
        )
        return key

    state_dict = OrderedDict((inv_key_mapping_attn(k), v) for k, v in state_dict.items())

    return state_dict


def bigcode_config_to_gpt2_config(bigcode_config: GPTBigCodeConfig) -> GPT2Config:
    return GPT2Config(
        activation_function=bigcode_config.activation_function,
        attn_pdrop=bigcode_config.attn_pdrop,
        bos_token_id=bigcode_config.bos_token_id,
        embd_pdrop=bigcode_config.embd_pdrop,
        eos_token_id=bigcode_config.eos_token_id,
        initializer_range=bigcode_config.initializer_range,
        layer_norm_epsilon=bigcode_config.layer_norm_epsilon,
        max_batch_size=bigcode_config.max_batch_size,
        max_sequence_length=bigcode_config.max_sequence_length,
        model_type=bigcode_config.model_type,
        multi_query=bigcode_config.multi_query,
        n_embd=bigcode_config.n_embd,
        n_head=bigcode_config.n_head,
        n_inner=bigcode_config.n_inner,
        n_layer=bigcode_config.n_layer,
        n_positions=bigcode_config.n_positions,
        resid_pdrop=bigcode_config.resid_pdrop,
        scale_attn_weights=bigcode_config.scale_attn_weights,
        summary_activation=bigcode_config.summary_activation,
        summary_first_dropout=bigcode_config.summary_first_dropout,
        summary_proj_to_labels=bigcode_config.summary_proj_to_labels,
        summary_type=bigcode_config.summary_type,
        summary_use_proj=bigcode_config.summary_use_proj,
        use_cache=bigcode_config.use_cache,
        vocab_size=bigcode_config.vocab_size,
    )
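

# The block below is an illustrative sketch and not part of the conversion utilities above.
# It shows one way the helpers might be combined to load a Huggingface BigCode checkpoint
# into flash_attn's GPT implementation. The checkpoint name and the GPTLMHeadModel usage
# are assumptions; adapt them to your setup.
if __name__ == "__main__":
    from transformers import AutoConfig, AutoModelForCausalLM

    from flash_attn.models.gpt import GPTLMHeadModel

    # Hypothetical example checkpoint; any GPTBigCode-architecture model should work.
    model_name = "bigcode/gpt_bigcode-santacoder"

    # Convert the HF BigCode config to the GPT2-style config flash_attn expects.
    gpt2_config = bigcode_config_to_gpt2_config(AutoConfig.from_pretrained(model_name))

    # Remap the HF state dict into flash_attn's naming and QKV layout, then load it.
    hf_state_dict = AutoModelForCausalLM.from_pretrained(model_name).state_dict()
    model = GPTLMHeadModel(gpt2_config)
    model.load_state_dict(remap_state_dict_hf_bigcode(hf_state_dict, gpt2_config))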