# Select a pretrained checkpoint and its argument list by model name,
# then load the model through the project's `model_helper` module.
from model_helper import load_model_checkpoint

# --- User-selectable configuration -----------------------------------------
# `model_name` must match one of the branches below exactly; anything else
# raises ValueError.
model_name = 'YPTF.MoE+Multi (noPS)'
# Forwarded verbatim as the value of the '-pr' flag.
precision = '16'
# Forwarded verbatim as the value of the '-p' flag.
project = '2024'

# --- Checkpoint / argument selection ---------------------------------------
# Each branch sets `checkpoint` and builds `args`, a flat list of strings:
# the checkpoint name first, followed by flag/value pairs.
if model_name == "YMT3+":
    # NOTE(review): "[email protected]" looks like email-obfuscation residue
    # of a "<name>@model.ckpt" string. The real checkpoint name cannot be
    # recovered from this file -- restore it from the upstream source.
    checkpoint = "[email protected]"
    args = [checkpoint, '-p', project, '-pr', precision]
elif model_name == "YPTF+Single (noPS)":
    checkpoint = "ptf_all_cross_rebal5_mirst_xk2_edr005_attend_c_full_plus_b100@model.ckpt"
    args = [checkpoint, '-p', project, '-enc', 'perceiver-tf', '-ac', 'spec',
            '-hop', '300', '-atc', '1', '-pr', precision]
elif model_name == "YPTF+Multi (PS)":
    # NOTE(review): this is the same checkpoint string as the
    # "YPTF.MoE+Multi (PS)" branch, but without the MoE flags -- confirm
    # that this is intended and not a copy/paste slip.
    checkpoint = "mc13_256_g4_all_v7_mt3f_sqr_rms_moe_wf4_n8k2_silu_rope_rp_b80_ps2@model.ckpt"
    args = [checkpoint, '-p', project, '-tk', 'mc13_full_plus_256',
            '-dec', 'multi-t5', '-nl', '26', '-enc', 'perceiver-tf',
            '-ac', 'spec', '-hop', '300', '-atc', '1', '-pr', precision]
elif model_name == "YPTF.MoE+Multi (noPS)":
    checkpoint = "mc13_256_g4_all_v7_mt3f_sqr_rms_moe_wf4_n8k2_silu_rope_rp_b36_nops@last.ckpt"
    args = [checkpoint, '-p', project, '-tk', 'mc13_full_plus_256', '-dec', 'multi-t5',
            '-nl', '26', '-enc', 'perceiver-tf', '-sqr', '1', '-ff', 'moe',
            '-wf', '4', '-nmoe', '8', '-kmoe', '2', '-act', 'silu', '-epe', 'rope',
            '-rp', '1', '-ac', 'spec', '-hop', '300', '-atc', '1', '-pr', precision]
elif model_name == "YPTF.MoE+Multi (PS)":
    checkpoint = "mc13_256_g4_all_v7_mt3f_sqr_rms_moe_wf4_n8k2_silu_rope_rp_b80_ps2@model.ckpt"
    args = [checkpoint, '-p', project, '-tk', 'mc13_full_plus_256', '-dec', 'multi-t5',
            '-nl', '26', '-enc', 'perceiver-tf', '-sqr', '1', '-ff', 'moe',
            '-wf', '4', '-nmoe', '8', '-kmoe', '2', '-act', 'silu', '-epe', 'rope',
            '-rp', '1', '-ac', 'spec', '-hop', '300', '-atc', '1', '-pr', precision]
else:
    # Unknown model name: fail loudly instead of calling the loader with
    # `args` undefined.
    raise ValueError(model_name)

# --- Load -------------------------------------------------------------------
model = load_model_checkpoint(args=args)