File size: 2,141 Bytes
9ed3b66 4e8d679 9ed3b66 4e8d679 9ed3b66 4e8d679 9ed3b66 4e8d679 9ed3b66 4e8d679 9ed3b66 4e8d679 9ed3b66 4e8d679 9ed3b66 4e8d679 9ed3b66 4e8d679 9ed3b66 4e8d679 9ed3b66 4e8d679 9ed3b66 4e8d679 9ed3b66 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 |
from transformers import ResNetConfig, FlaxResNetForImageClassification, ResNetForImageClassification, FlaxResNetModel, ResNetModel
from flax.traverse_util import flatten_dict, unflatten_dict
from flax.core.frozen_dict import unfreeze
import re
import jax.numpy as jnp
import torch
# Conversion rules: (PyTorch key suffix, Flax collection, Flax key suffix, transpose axes).
# A rule matches when the PyTorch key ends with the suffix on a "." boundary, and only
# that trailing suffix is rewritten. The rest of the dotted module path is kept as-is.
_PT_TO_FLAX_RULES = (
    # torch.nn.Conv2d weight is (out, in, kh, kw); flax.linen.Conv kernel is (kh, kw, in, out).
    ("convolution.weight", "params", "convolution.kernel", (2, 3, 1, 0)),
    ("normalization.weight", "params", "normalization.scale", None),
    ("normalization.bias", "params", "normalization.bias", None),
    # torch.nn.Linear weight is (out, in); flax.linen.Dense kernel is (in, out).
    ("classifier.1.weight", "params", "classifier.1.kernel", (1, 0)),
    ("classifier.1.bias", "params", "classifier.1.bias", None),
    # BatchNorm running statistics live in the separate "batch_stats" collection in Flax.
    ("normalization.running_mean", "batch_stats", "normalization.mean", None),
    ("normalization.running_var", "batch_stats", "normalization.var", None),
)


def convert_pt_state(pt_state):
    """Map a PyTorch ResNet state dict onto flattened Flax variable keys.

    Args:
        pt_state: mapping of dotted PyTorch parameter names to array objects
            supporting ``.transpose(axes)`` (e.g. NumPy arrays).

    Returns:
        A dict mapping Flax key tuples, such as
        ``("params", "classifier", "1", "kernel")``, to arrays laid out the way
        Flax expects. Entries that match no rule are dropped (for example
        BatchNorm's ``num_batches_tracked`` counter).
    """
    converted = {}
    for pt_key, array in pt_state.items():
        for pt_suffix, collection, flax_suffix, axes in _PT_TO_FLAX_RULES:
            # Require a component boundary so e.g. "xconvolution.weight" cannot match.
            if pt_key != pt_suffix and not pt_key.endswith("." + pt_suffix):
                continue
            flax_key = pt_key[: len(pt_key) - len(pt_suffix)] + flax_suffix
            if axes is not None:
                array = array.transpose(axes)
            converted[(collection, *flax_key.split("."))] = array
            break
    return converted


def merge_into_flax_state(flax_state, converted):
    """Overwrite entries of a flattened Flax variable dict with converted arrays.

    Every converted entry is validated before anything is written, so a failure
    leaves ``flax_state`` unchanged. On success ``flax_state`` is mutated in place.

    Args:
        flax_state: flattened Flax variables (key tuple -> array), as produced
            by ``flax.traverse_util.flatten_dict``.
        converted: output of :func:`convert_pt_state`.

    Returns:
        Sorted list of ``flax_state`` keys that were not overwritten, i.e. that
        still hold their random initialisation.

    Raises:
        KeyError: a converted key has no counterpart in ``flax_state``.
        ValueError: a converted array's shape differs from the Flax variable's.
    """
    for flax_key, array in converted.items():
        if flax_key not in flax_state:
            raise KeyError(f"no Flax variable for converted key {'.'.join(flax_key)!r}")
        expected = tuple(flax_state[flax_key].shape)
        if tuple(array.shape) != expected:
            raise ValueError(
                f"shape mismatch for {'.'.join(flax_key)!r}: "
                f"converted {tuple(array.shape)}, Flax expects {expected}"
            )
    flax_state.update(converted)
    return sorted(set(flax_state) - set(converted))


def main(model_name="microsoft/resnet-50", output_dir="resnet_50_flax"):
    """Convert a PyTorch ResNet image classifier checkpoint to Flax and save it.

    Args:
        model_name: HuggingFace hub id (or local path) of the PyTorch checkpoint.
        output_dir: directory passed to ``save_pretrained`` for the Flax model.
    """
    import sys

    pt_resnet = ResNetForImageClassification.from_pretrained(model_name)
    # Randomly initialised Flax model with the same config; its variables are overwritten below.
    flax_resnet = FlaxResNetForImageClassification(pt_resnet.config)

    pt_state = {key: tensor.numpy() for key, tensor in pt_resnet.state_dict().items()}
    flax_state = flatten_dict(unfreeze(flax_resnet.params))

    untouched = merge_into_flax_state(flax_state, convert_pt_state(pt_state))
    if untouched:
        # Not fatal, but these variables would be saved with random values.
        names = ", ".join(".".join(key) for key in untouched[:5])
        print(
            f"warning: {len(untouched)} Flax variable(s) not converted from PyTorch, e.g. {names}",
            file=sys.stderr,
        )

    flax_resnet.save_pretrained(output_dir, params=unflatten_dict(flax_state))


if __name__ == "__main__":
    main()
|