import torch


def load_state_dict(model, sd, ignore_errors=None, log_name=None, ignore_start=None):
    # Load a state dict non-strictly and report any missing/unexpected keys,
    # skipping keys listed in ignore_errors or starting with ignore_start.
    ignore_errors = ignore_errors or []
    missing, unexpected = model.load_state_dict(sd, strict=False)
    missing = [x for x in missing if x not in ignore_errors]
    unexpected = [x for x in unexpected if x not in ignore_errors]
    if isinstance(ignore_start, str):
        missing = [x for x in missing if not x.startswith(ignore_start)]
        unexpected = [x for x in unexpected if not x.startswith(ignore_start)]
    log_name = log_name or type(model).__name__
    if len(missing) > 0:
        print(f'{log_name} Missing: {missing}')
    if len(unexpected) > 0:
        print(f'{log_name} Unexpected: {unexpected}')


def state_dict_has(sd, prefix):
    # True if any key in the state dict starts with the given prefix.
    return any(x.startswith(prefix) for x in sd.keys())


def filter_state_dict_with_prefix(sd, prefix, new_prefix=''):
    # Move every key starting with `prefix` out of `sd` into a new dict,
    # replacing the prefix with `new_prefix`. Mutates `sd` in place.
    new_sd = {}
    for k, v in list(sd.items()):
        if k.startswith(prefix):
            new_sd[new_prefix + k[len(prefix):]] = v
            del sd[k]
    return new_sd


def try_filter_state_dict(sd, prefix_list, new_prefix=''):
    # Extract keys under the first prefix in prefix_list that is present in sd;
    # return an empty dict if none of the prefixes match.
    for prefix in prefix_list:
        if state_dict_has(sd, prefix):
            return filter_state_dict_with_prefix(sd, prefix, new_prefix)
    return {}


def transformers_convert(sd, prefix_from, prefix_to, number):
    # Rename CLIP-style transformer keys (resblocks, fused attention in_proj
    # tensors) to the transformers-library layout (encoder.layers, separate
    # q/k/v projections). `number` is the number of transformer layers.
    keys_to_replace = {
        "{}positional_embedding": "{}embeddings.position_embedding.weight",
        "{}token_embedding.weight": "{}embeddings.token_embedding.weight",
        "{}ln_final.weight": "{}final_layer_norm.weight",
        "{}ln_final.bias": "{}final_layer_norm.bias",
    }

    for k in keys_to_replace:
        x = k.format(prefix_from)
        if x in sd:
            sd[keys_to_replace[k].format(prefix_to)] = sd.pop(x)

    resblock_to_replace = {
        "ln_1": "layer_norm1",
        "ln_2": "layer_norm2",
        "mlp.c_fc": "mlp.fc1",
        "mlp.c_proj": "mlp.fc2",
        "attn.out_proj": "self_attn.out_proj",
    }

    for resblock in range(number):
        for x in resblock_to_replace:
            for y in ["weight", "bias"]:
                k = "{}transformer.resblocks.{}.{}.{}".format(prefix_from, resblock, x, y)
                k_to = "{}encoder.layers.{}.{}.{}".format(prefix_to, resblock, resblock_to_replace[x], y)
                if k in sd:
                    sd[k_to] = sd.pop(k)

        for y in ["weight", "bias"]:
            # The source checkpoint fuses q/k/v into a single in_proj tensor;
            # split it into three equal chunks for q_proj, k_proj and v_proj.
            k_from = "{}transformer.resblocks.{}.attn.in_proj_{}".format(prefix_from, resblock, y)
            if k_from in sd:
                weights = sd.pop(k_from)
                shape_from = weights.shape[0] // 3
                p = ["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"]
                for x in range(3):
                    k_to = "{}encoder.layers.{}.{}.{}".format(prefix_to, resblock, p[x], y)
                    sd[k_to] = weights[shape_from * x:shape_from * (x + 1)]

    return sd


def state_dict_key_replace(state_dict, keys_to_replace):
    # Rename individual keys according to an old-name -> new-name mapping.
    for x in keys_to_replace:
        if x in state_dict:
            state_dict[keys_to_replace[x]] = state_dict.pop(x)
    return state_dict


def state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=False):
    # Rename key prefixes according to an old-prefix -> new-prefix mapping.
    # With filter_keys=True, only the renamed keys are returned; otherwise the
    # state dict is updated in place and returned.
    if filter_keys:
        out = {}
    else:
        out = state_dict
    for rp in replace_prefix:
        replace = [(a, "{}{}".format(replace_prefix[rp], a[len(rp):]))
                   for a in list(state_dict.keys()) if a.startswith(rp)]
        for x in replace:
            w = state_dict.pop(x[0])
            out[x[1]] = w
    return out
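

# A minimal usage sketch of the helpers above. The key names, prefixes and
# tensor shapes here are hypothetical and not taken from any real checkpoint;
# they only illustrate the prefix-stripping and q/k/v split behavior.
if __name__ == "__main__":
    sd = {
        "model.positional_embedding": torch.zeros(77, 8),
        "model.token_embedding.weight": torch.zeros(49408, 8),
        "model.transformer.resblocks.0.attn.in_proj_weight": torch.zeros(24, 8),
        "model.transformer.resblocks.0.attn.in_proj_bias": torch.zeros(24),
    }
    # Strip the hypothetical "model." wrapper prefix in place.
    sd = state_dict_prefix_replace(sd, {"model.": ""})
    # Convert to the transformers layout for a single-layer (number=1) model:
    # the fused (24, 8) in_proj tensor is split into three (8, 8) q/k/v tensors.
    sd = transformers_convert(sd, "", "text_model.", 1)
    print(sorted(sd.keys()))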