|
|
|
|
|
|
|
|
|
|
|
import torch |
|
|
|
src_ckpt = "/checkpoint/wnhsu/w2v/archived/hubert_base_ls960_it2.pt" |
|
ref_ckpt = "/checkpoint/wnhsu/w2v/hubert_icassp_oss_v3/iter2_km100-400k-grp-L6/oss.km500_p0_1_s334.pmw1_0.puw0_0.grpnorm.ml10.mp0_8.untie.mxsz250000.ufreq1.maxtok1400000.MU100k.s1337.ngpu32/checkpoint_last.pt" |
|
new_ckpt = "/checkpoint/wnhsu/w2v/archived/hubert_base_ls960_it2_updated.pt" |
|
|
|
|
|
def update_state(state): |
|
state["model"]["label_embs_concat"] = state["model"].pop("label_embs") |
|
state["args"].task = "hubert_pretraining" |
|
state["args"].labels = f"['{state['args'].labels}']" |
|
return state |
|
|
|
|
|
src_state = torch.load(src_ckpt) |
|
src_state = update_state(src_state) |
|
torch.save(src_state, new_ckpt) |
|
|