JustinLin610
update
8437114
raw
history blame
873 Bytes
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
src_ckpt = "/checkpoint/wnhsu/w2v/archived/hubert_base_ls960_it2.pt"
ref_ckpt = "/checkpoint/wnhsu/w2v/hubert_icassp_oss_v3/iter2_km100-400k-grp-L6/oss.km500_p0_1_s334.pmw1_0.puw0_0.grpnorm.ml10.mp0_8.untie.mxsz250000.ufreq1.maxtok1400000.MU100k.s1337.ngpu32/checkpoint_last.pt"
new_ckpt = "/checkpoint/wnhsu/w2v/archived/hubert_base_ls960_it2_updated.pt"
def update_state(state):
state["model"]["label_embs_concat"] = state["model"].pop("label_embs")
state["args"].task = "hubert_pretraining"
state["args"].labels = f"['{state['args'].labels}']"
return state
src_state = torch.load(src_ckpt)
src_state = update_state(src_state)
torch.save(src_state, new_ckpt)