- app.py +37 -0
- load_checkpoint.py +0 -35
app.py
CHANGED
@@ -9,6 +9,43 @@ import gradio as gr
|
|
9 |
from gradio_helper import *
|
10 |
from load_checkpoint import *
|
11 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
AUDIO_EXAMPLES = glob.glob('/content/examples/*.*', recursive=True)
|
13 |
YOUTUBE_EXAMPLES = ["https://www.youtube.com/watch?v=vMboypSkj3c"]
|
14 |
|
|
|
9 |
from gradio_helper import *
|
10 |
from load_checkpoint import *
|
11 |
|
12 |
+
|
13 |
+
# @title Load Checkpoint
|
14 |
+
model_name = 'YPTF.MoE+Multi (noPS)' # @param ["YMT3+", "YPTF+Single (noPS)", "YPTF+Multi (PS)", "YPTF.MoE+Multi (noPS)", "YPTF.MoE+Multi (PS)"]
|
15 |
+
precision = '16' # @param ["32", "bf16-mixed", "16"]
|
16 |
+
project = '2024'
|
17 |
+
|
18 |
+
if model_name == "YMT3+":
|
19 |
+
checkpoint = "[email protected]"
|
20 |
+
args = [checkpoint, '-p', project, '-pr', precision]
|
21 |
+
elif model_name == "YPTF+Single (noPS)":
|
22 |
+
checkpoint = "ptf_all_cross_rebal5_mirst_xk2_edr005_attend_c_full_plus_b100@model.ckpt"
|
23 |
+
args = [checkpoint, '-p', project, '-enc', 'perceiver-tf', '-ac', 'spec',
|
24 |
+
'-hop', '300', '-atc', '1', '-pr', precision]
|
25 |
+
elif model_name == "YPTF+Multi (PS)":
|
26 |
+
checkpoint = "mc13_256_g4_all_v7_mt3f_sqr_rms_moe_wf4_n8k2_silu_rope_rp_b80_ps2@model.ckpt"
|
27 |
+
args = [checkpoint, '-p', project, '-tk', 'mc13_full_plus_256',
|
28 |
+
'-dec', 'multi-t5', '-nl', '26', '-enc', 'perceiver-tf',
|
29 |
+
'-ac', 'spec', '-hop', '300', '-atc', '1', '-pr', precision]
|
30 |
+
elif model_name == "YPTF.MoE+Multi (noPS)":
|
31 |
+
checkpoint = "mc13_256_g4_all_v7_mt3f_sqr_rms_moe_wf4_n8k2_silu_rope_rp_b36_nops@last.ckpt"
|
32 |
+
args = [checkpoint, '-p', project, '-tk', 'mc13_full_plus_256', '-dec', 'multi-t5',
|
33 |
+
'-nl', '26', '-enc', 'perceiver-tf', '-sqr', '1', '-ff', 'moe',
|
34 |
+
'-wf', '4', '-nmoe', '8', '-kmoe', '2', '-act', 'silu', '-epe', 'rope',
|
35 |
+
'-rp', '1', '-ac', 'spec', '-hop', '300', '-atc', '1', '-pr', precision]
|
36 |
+
elif model_name == "YPTF.MoE+Multi (PS)":
|
37 |
+
checkpoint = "mc13_256_g4_all_v7_mt3f_sqr_rms_moe_wf4_n8k2_silu_rope_rp_b80_ps2@model.ckpt"
|
38 |
+
args = [checkpoint, '-p', project, '-tk', 'mc13_full_plus_256', '-dec', 'multi-t5',
|
39 |
+
'-nl', '26', '-enc', 'perceiver-tf', '-sqr', '1', '-ff', 'moe',
|
40 |
+
'-wf', '4', '-nmoe', '8', '-kmoe', '2', '-act', 'silu', '-epe', 'rope',
|
41 |
+
'-rp', '1', '-ac', 'spec', '-hop', '300', '-atc', '1', '-pr', precision]
|
42 |
+
else:
|
43 |
+
raise ValueError(model_name)
|
44 |
+
|
45 |
+
model = load_model_checkpoint(args=args)
|
46 |
+
|
47 |
+
|
48 |
+
|
49 |
AUDIO_EXAMPLES = glob.glob('/content/examples/*.*', recursive=True)
|
50 |
YOUTUBE_EXAMPLES = ["https://www.youtube.com/watch?v=vMboypSkj3c"]
|
51 |
|
load_checkpoint.py
DELETED
@@ -1,35 +0,0 @@
|
|
1 |
-
# @title Load Checkpoint
|
2 |
-
from model_helper import laod_model_checkpoint
|
3 |
-
|
4 |
-
model_name = 'YPTF.MoE+Multi (noPS)' # @param ["YMT3+", "YPTF+Single (noPS)", "YPTF+Multi (PS)", "YPTF.MoE+Multi (noPS)", "YPTF.MoE+Multi (PS)"]
|
5 |
-
precision = '16' # @param ["32", "bf16-mixed", "16"]
|
6 |
-
project = '2024'
|
7 |
-
|
8 |
-
if model_name == "YMT3+":
|
9 |
-
checkpoint = "[email protected]"
|
10 |
-
args = [checkpoint, '-p', project, '-pr', precision]
|
11 |
-
elif model_name == "YPTF+Single (noPS)":
|
12 |
-
checkpoint = "ptf_all_cross_rebal5_mirst_xk2_edr005_attend_c_full_plus_b100@model.ckpt"
|
13 |
-
args = [checkpoint, '-p', project, '-enc', 'perceiver-tf', '-ac', 'spec',
|
14 |
-
'-hop', '300', '-atc', '1', '-pr', precision]
|
15 |
-
elif model_name == "YPTF+Multi (PS)":
|
16 |
-
checkpoint = "mc13_256_g4_all_v7_mt3f_sqr_rms_moe_wf4_n8k2_silu_rope_rp_b80_ps2@model.ckpt"
|
17 |
-
args = [checkpoint, '-p', project, '-tk', 'mc13_full_plus_256',
|
18 |
-
'-dec', 'multi-t5', '-nl', '26', '-enc', 'perceiver-tf',
|
19 |
-
'-ac', 'spec', '-hop', '300', '-atc', '1', '-pr', precision]
|
20 |
-
elif model_name == "YPTF.MoE+Multi (noPS)":
|
21 |
-
checkpoint = "mc13_256_g4_all_v7_mt3f_sqr_rms_moe_wf4_n8k2_silu_rope_rp_b36_nops@last.ckpt"
|
22 |
-
args = [checkpoint, '-p', project, '-tk', 'mc13_full_plus_256', '-dec', 'multi-t5',
|
23 |
-
'-nl', '26', '-enc', 'perceiver-tf', '-sqr', '1', '-ff', 'moe',
|
24 |
-
'-wf', '4', '-nmoe', '8', '-kmoe', '2', '-act', 'silu', '-epe', 'rope',
|
25 |
-
'-rp', '1', '-ac', 'spec', '-hop', '300', '-atc', '1', '-pr', precision]
|
26 |
-
elif model_name == "YPTF.MoE+Multi (PS)":
|
27 |
-
checkpoint = "mc13_256_g4_all_v7_mt3f_sqr_rms_moe_wf4_n8k2_silu_rope_rp_b80_ps2@model.ckpt"
|
28 |
-
args = [checkpoint, '-p', project, '-tk', 'mc13_full_plus_256', '-dec', 'multi-t5',
|
29 |
-
'-nl', '26', '-enc', 'perceiver-tf', '-sqr', '1', '-ff', 'moe',
|
30 |
-
'-wf', '4', '-nmoe', '8', '-kmoe', '2', '-act', 'silu', '-epe', 'rope',
|
31 |
-
'-rp', '1', '-ac', 'spec', '-hop', '300', '-atc', '1', '-pr', precision]
|
32 |
-
else:
|
33 |
-
raise ValueError(model_name)
|
34 |
-
|
35 |
-
model = load_model_checkpoint(args=args)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|