AINxtGen's picture
Init
c22961b
# Fixed key renames from the 'model.N' naming to the basicsr-style naming.
# Both conversion functions in this file consult this table before applying
# their prefix-based rules.
to_basicsr_dict = {
    # first convolution
    'model.0.weight': 'conv_first.weight',
    'model.0.bias': 'conv_first.bias',
    # 'model.1.sub.23.*' is listed explicitly so it is matched here rather
    # than by the generic 'model.1.sub.' rule in the conversion function
    'model.1.sub.23.weight': 'conv_body.weight',
    'model.1.sub.23.bias': 'conv_body.bias',
    # conv_up1 / conv_up2
    'model.3.weight': 'conv_up1.weight',
    'model.3.bias': 'conv_up1.bias',
    'model.6.weight': 'conv_up2.weight',
    'model.6.bias': 'conv_up2.bias',
    # conv_hr / conv_last
    'model.8.weight': 'conv_hr.weight',
    'model.8.bias': 'conv_hr.bias',
    'model.10.bias': 'conv_last.bias',
    'model.10.weight': 'conv_last.weight',
    # Other 'model.1.sub.*' keys are intentionally not enumerated; they are
    # renamed by rule, e.g.
    #   'model.1.sub.0.RDB1.conv1.0.weight' -> 'body.0.rdb1.conv1.weight'
}
def convert_state_dict_to_basicsr(state_dict):
    """Return a copy of ``state_dict`` with keys renamed to basicsr-style names.

    Each key is handled by the first rule that applies:

    1. A key present in ``to_basicsr_dict`` is replaced by its mapped name.
    2. A key starting with ``'model.1.sub.'`` has that prefix replaced by
       ``'body.'``, is lowercased, and has ``'.0.weight'`` / ``'.0.bias'``
       shortened to ``'.weight'`` / ``'.bias'``.
    3. Any other key is kept unchanged.

    Values are carried over as-is and the input mapping is not modified.

    Args:
        state_dict: Mapping of parameter names to values.

    Returns:
        A new dict with the renamed keys.
    """
    sub_prefix = 'model.1.sub.'
    renamed = {}
    for name, param in state_dict.items():
        if name in to_basicsr_dict:
            # Explicit table entries take priority over the prefix rule.
            target = to_basicsr_dict[name]
        elif name.startswith(sub_prefix):
            target = (
                name.replace(sub_prefix, 'body.')
                .lower()
                .replace('.0.weight', '.weight')
                .replace('.0.bias', '.bias')
            )
        else:
            target = name
        renamed[target] = param
    return renamed
# just matching a commonly used format
def convert_basicsr_state_dict_to_save_format(state_dict):
    """Return a copy of ``state_dict`` with basicsr-style keys renamed back.

    Each key is handled by the first rule that applies:

    1. A key that appears as a value in ``to_basicsr_dict`` is replaced by
       the corresponding ``'model.N...'`` key of that table.
    2. A key starting with ``'body.'`` has that prefix replaced by
       ``'model.1.sub.'``, is lowercased, has ``'rdb'`` turned into
       ``'RDB'``, and has ``'.weight'`` / ``'.bias'`` expanded to
       ``'.0.weight'`` / ``'.0.bias'``.
    3. Any other key is kept unchanged.

    Values are carried over as-is and the input mapping is not modified.

    Args:
        state_dict: Mapping of parameter names to values.

    Returns:
        A new dict with the renamed keys.
    """
    # Invert the rename table once per call so each key costs a single dict
    # lookup instead of a list membership test plus a scan over all items.
    # The table's values are unique, so the inversion loses nothing.
    from_basicsr = {
        bsr_name: save_name for save_name, bsr_name in to_basicsr_dict.items()
    }

    new_state_dict = {}
    for k, v in state_dict.items():
        if k in from_basicsr:
            new_state_dict[from_basicsr[k]] = v
        elif k.startswith('body.'):
            save_name = k.replace('body.', 'model.1.sub.').lower()
            # Restore the upper-case block name after lowercasing.
            save_name = save_name.replace('rdb', 'RDB')
            save_name = save_name.replace('.weight', '.0.weight')
            save_name = save_name.replace('.bias', '.0.bias')
            new_state_dict[save_name] = v
        else:
            new_state_dict[k] = v
    return new_state_dict