import torch
def load_state_dict(model, sd, ignore_errors=(), log_name=None, ignore_start=None):
    # Load a (possibly partial) state dict non-strictly and report any
    # missing/unexpected keys that are not explicitly ignored.
    # Note: the default for ignore_errors is a tuple rather than a mutable list.
    missing, unexpected = model.load_state_dict(sd, strict=False)
    missing = [x for x in missing if x not in ignore_errors]
    unexpected = [x for x in unexpected if x not in ignore_errors]
    if isinstance(ignore_start, str):
        # Also drop any reported keys under the given prefix.
        missing = [x for x in missing if not x.startswith(ignore_start)]
        unexpected = [x for x in unexpected if not x.startswith(ignore_start)]
    log_name = log_name or type(model).__name__
    if len(missing) > 0:
        print(f'{log_name} Missing: {missing}')
    if len(unexpected) > 0:
        print(f'{log_name} Unexpected: {unexpected}')
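
# Hypothetical usage sketch (not part of the original module): loading a
# partial checkpoint into a small model. The model and key names below are
# invented for illustration.
def _demo_load_state_dict():
    model = torch.nn.Linear(4, 2)
    partial_sd = {"weight": torch.zeros(2, 4)}  # "bias" is intentionally absent
    # Prints something like: Linear Missing: ['bias']
    load_state_dict(model, partial_sd)
    # Suppress the report for the known-missing key:
    load_state_dict(model, partial_sd, ignore_errors=("bias",))
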
def state_dict_has(sd, prefix):
    # True if any key in the state dict starts with the given prefix.
    return any(x.startswith(prefix) for x in sd)
def filter_state_dict_with_prefix(sd, prefix, new_prefix=''):
    # Move every key that starts with `prefix` out of `sd` into a new dict,
    # re-rooted under `new_prefix`. Note that `sd` is mutated: matching keys
    # are removed from it.
    new_sd = {}
    for k in list(sd.keys()):
        if k.startswith(prefix):
            new_sd[new_prefix + k[len(prefix):]] = sd.pop(k)
    return new_sd
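
# Hypothetical usage sketch (keys invented for illustration): extracting the
# text-encoder weights from a combined checkpoint.
def _demo_filter_state_dict_with_prefix():
    sd = {
        "text_model.embeddings.weight": torch.zeros(1),
        "visual_model.conv1.weight": torch.zeros(1),
    }
    text_sd = filter_state_dict_with_prefix(sd, "text_model.")
    assert list(text_sd) == ["embeddings.weight"]
    assert list(sd) == ["visual_model.conv1.weight"]  # matching keys were removed
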
def try_filter_state_dict(sd, prefix_list, new_prefix=''):
    # Try each candidate prefix in order; extract under the first one that
    # actually occurs in the state dict.
    for prefix in prefix_list:
        if state_dict_has(sd, prefix):
            return filter_state_dict_with_prefix(sd, prefix, new_prefix)
    return {}
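
# Hypothetical usage sketch: checkpoints from different exporters often root
# the same weights under different prefixes. The prefixes below are invented.
def _demo_try_filter_state_dict():
    sd = {"module.model.layer.weight": torch.zeros(1)}
    # "model." does not match the start of the key, "module." does.
    out = try_filter_state_dict(sd, ["model.", "module."])
    assert list(out) == ["model.layer.weight"]
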
def transformers_convert(sd, prefix_from, prefix_to, number):
    # Rename CLIP-style keys ("transformer.resblocks.*") to the HuggingFace
    # transformers layout ("encoder.layers.*"), for `number` transformer blocks.
    keys_to_replace = {
        "{}positional_embedding": "{}embeddings.position_embedding.weight",
        "{}token_embedding.weight": "{}embeddings.token_embedding.weight",
        "{}ln_final.weight": "{}final_layer_norm.weight",
        "{}ln_final.bias": "{}final_layer_norm.bias",
    }

    for k in keys_to_replace:
        x = k.format(prefix_from)
        if x in sd:
            sd[keys_to_replace[k].format(prefix_to)] = sd.pop(x)

    resblock_to_replace = {
        "ln_1": "layer_norm1",
        "ln_2": "layer_norm2",
        "mlp.c_fc": "mlp.fc1",
        "mlp.c_proj": "mlp.fc2",
        "attn.out_proj": "self_attn.out_proj",
    }

    for resblock in range(number):
        for x in resblock_to_replace:
            for y in ["weight", "bias"]:
                k = "{}transformer.resblocks.{}.{}.{}".format(prefix_from, resblock, x, y)
                k_to = "{}encoder.layers.{}.{}.{}".format(prefix_to, resblock, resblock_to_replace[x], y)
                if k in sd:
                    sd[k_to] = sd.pop(k)

        for y in ["weight", "bias"]:
            # The fused in_proj parameter stacks Q, K and V along dim 0; split
            # it into the three separate projections transformers expects.
            k_from = "{}transformer.resblocks.{}.attn.in_proj_{}".format(prefix_from, resblock, y)
            if k_from in sd:
                weights = sd.pop(k_from)
                shape_from = weights.shape[0] // 3
                p = ["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"]
                for x in range(3):
                    k_to = "{}encoder.layers.{}.{}.{}".format(prefix_to, resblock, p[x], y)
                    sd[k_to] = weights[shape_from * x:shape_from * (x + 1)]
    return sd
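
# Hypothetical usage sketch: converting one resblock worth of fused attention
# weights. dim=8 and the empty prefixes are arbitrary choices for the demo.
def _demo_transformers_convert():
    dim = 8
    sd = {"transformer.resblocks.0.attn.in_proj_weight": torch.randn(3 * dim, dim)}
    sd = transformers_convert(sd, prefix_from="", prefix_to="", number=1)
    # The fused [3*dim, dim] tensor is split into three [dim, dim] projections.
    assert sd["encoder.layers.0.self_attn.q_proj.weight"].shape == (dim, dim)
    assert sd["encoder.layers.0.self_attn.v_proj.weight"].shape == (dim, dim)
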
def state_dict_key_replace(state_dict, keys_to_replace):
    # Rename exact keys in place according to the old -> new mapping.
    for x in keys_to_replace:
        if x in state_dict:
            state_dict[keys_to_replace[x]] = state_dict.pop(x)
    return state_dict
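
# Hypothetical usage sketch with invented key names:
def _demo_state_dict_key_replace():
    sd = {"old_name.weight": torch.zeros(1)}
    sd = state_dict_key_replace(sd, {"old_name.weight": "new_name.weight"})
    assert list(sd) == ["new_name.weight"]
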
def state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=False):
    # Rewrite key prefixes according to the old -> new mapping. With
    # filter_keys=True only the rewritten keys are returned; otherwise the
    # renames are applied to `state_dict` itself and it is returned.
    if filter_keys:
        out = {}
    else:
        out = state_dict
    for rp in replace_prefix:
        replace = [(a, "{}{}".format(replace_prefix[rp], a[len(rp):]))
                   for a in list(state_dict.keys()) if a.startswith(rp)]
        for x in replace:
            w = state_dict.pop(x[0])
            out[x[1]] = w
    return out
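
# Hypothetical usage sketch: re-rooting diffusion-model keys. The prefixes are
# invented; note the in-place vs. filtered behaviour.
def _demo_state_dict_prefix_replace():
    sd = {"model.diffusion_model.x": torch.zeros(1), "other.y": torch.zeros(1)}
    out = state_dict_prefix_replace(dict(sd), {"model.diffusion_model.": "unet."})
    assert sorted(out) == ["other.y", "unet.x"]
    # filter_keys=True drops everything that did not match a prefix:
    out = state_dict_prefix_replace(dict(sd), {"model.diffusion_model.": "unet."}, filter_keys=True)
    assert list(out) == ["unet.x"]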