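# Source: audioldm_train/config/mos_as_token/qa_mdt.yaml (revision 15fc9cd).
# Page-header text from the hosting site ("jadechoghari's picture", "verified")
# was captured with this file; it is kept as a comment so the file parses as YAML.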
# General run settings.
log_directory: './log/latent_diffusion'
project: 'audioldm'
precision: 'high'
# TODO: replace with the path to your project root.
base_root: './qa_mdt'
# Paths to the pretrained checkpoints, one entry per component.
# TODO: replace these with your own pretrained paths.
# TODO: the pretrained path is also needed in
#       "base_root/offset_pretrained_checkpoints.json".
pretrained:
  clap_music: './qa_mdt/checkpoints/clap_music'
  flan_t5: './qa_mdt/checkpoints/flant5'
  hifi-gan: './qa_mdt/checkpoints/hifi-gan/checkpoints'
  roberta-base: './qa_mdt/checkpoints/robertabase'
# TODO: path to the lmdb dataset that stores the pMOS of the training dataset.
# NOT needed for inference -- it can stay empty in that case.
mos_path: ''
# Training data locations.
train_path:
  train_lmdb_path: []  # path list of training lmdb folders
# Validation data locations.
val_path:
  val_lmdb_path: []  # path list of validation lmdb folders
  val_key_path: []  # path list of validation lmdb key files
# Shared values. Each one is anchored (&name) so that other sections can
# reference it through an alias (*name) instead of repeating the literal.
variables:
  sampling_rate: &sampling_rate 16000
  mel_bins: &mel_bins 64
  latent_embed_dim: &latent_embed_dim 8
  latent_t_size: &latent_t_size 256  # TODO might need to change
  latent_f_size: &latent_f_size 16  # TODO might need to change
  in_channels: &unet_in_channels 8  # TODO might need to change
  optimize_ddpm_parameter: &optimize_ddpm_parameter true
  optimize_gpt: &optimize_gpt true
  warmup_steps: &warmup_steps 2000
# The dataset code was rewritten, so this section may no longer be needed.
data:
  train: ['audiocaps']
  val: 'audiocaps'
  test: 'audiocaps'
  class_label_indices: 'audioset_eval_subset'
  dataloader_add_ons: ['waveform_rs_48k']
# Validation / checkpointing cadence and the overall step budget.
step:
  validation_every_n_epochs: 10000
  save_checkpoint_every_n_steps: 1000
  # limit_val_batches: 2
  max_steps: 8000000
  save_top_k: 1000
# Audio front-end settings, grouped into waveform, STFT and mel sections.
# sampling_rate and n_mel_channels reuse the anchors defined under `variables`.
preprocessing:
  audio:
    sampling_rate: *sampling_rate
    max_wav_value: 32768.0
    duration: 10.24
  stft:
    filter_length: 1024
    hop_length: 160
    win_length: 1024
  mel:
    n_mel_channels: *mel_bins
    mel_fmin: 0
    mel_fmax: 8000
# Data augmentation settings.
augmentation:
  mixup: 0.0
# Model definition. Each sub-config pairs a dotted `target` path with a
# `params` mapping.
# NOTE(review): the nesting below was reconstructed from a copy of this file
# that had lost its indentation; the placement of `cond_stage_config` and
# `evaluation_params` under `model.params` is assumed -- confirm against the
# code that loads this config.
model:
  target: qa_mdt.audioldm_train.modules.latent_diffusion.ddpm.LatentDiffusion
  params:
    # Autoencoder
    first_stage_config:
      base_learning_rate: 8.0e-06
      target: qa_mdt.audioldm_train.modules.latent_encoder.autoencoder.AutoencoderKL
      params:
        # TODO: replace with the path to your VAE checkpoint
        reload_from_ckpt: './qa_mdt/checkpoints/hifi-gan/checkpoints/vae_mel_16k_64bins.ckpt'
        sampling_rate: *sampling_rate
        batchsize: 1
        monitor: val/rec_loss
        image_key: fbank
        subband: 1
        embed_dim: *latent_embed_dim
        time_shuffle: 1
        lossconfig:
          target: qa_mdt.audioldm_train.losses.LPIPSWithDiscriminator
          params:
            disc_start: 50001
            kl_weight: 1000.0
            disc_weight: 0.5
            disc_in_channels: 1
        ddconfig:
          double_z: true
          mel_bins: *mel_bins
          z_channels: 8
          resolution: 256
          downsample_time: false
          in_channels: 1
          out_ch: 1
          ch: 128
          ch_mult:
            - 1
            - 2
            - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0

    # Other parameters
    base_learning_rate: 8.0e-5
    warmup_steps: *warmup_steps
    optimize_ddpm_parameter: *optimize_ddpm_parameter
    sampling_rate: *sampling_rate
    batchsize: 16
    linear_start: 0.0015
    linear_end: 0.0195
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    unconditional_prob_cfg: 0.1
    parameterization: eps  # [eps, x0, v]
    first_stage_key: fbank
    latent_t_size: *latent_t_size
    latent_f_size: *latent_f_size
    channels: *latent_embed_dim
    monitor: val/loss_simple_ema
    scale_by_std: true

    unet_config:
      # TODO: choose your class. Default: MDT_MOS_AS_TOKEN
      # (Note: the 2D-RoPE, SwiGLU and MDT variants live in two classes; to
      # train with all of them, the classes must be changed and merged.)
      target: qa_mdt.audioldm_train.modules.diffusionmodules.PixArt.PixArt_MDT_MOS_AS_TOKEN
      params:
        # Same values as the literal [256, 16] / 8, now taken from the anchors
        # under `variables` so they cannot drift apart.
        input_size: [*latent_t_size, *latent_f_size]
        # patch_size: [16, 4]
        patch_size: [4, 1]
        overlap_size: [0, 0]
        in_channels: *unet_in_channels
        hidden_size: 1152
        depth: 28
        num_heads: 16
        mlp_ratio: 4.0
        class_dropout_prob: 0.1
        pred_sigma: true
        drop_path: 0.0
        window_size: 0
        # Was `None`, which YAML reads as the string "None" rather than null.
        window_block_indexes: null
        use_rel_pos: false
        cond_dim: 1024
        lewei_scale: 1.0
        overlap: [0, 0]
        use_cfg: true
        mask_ratio: 0.30
        decode_layer: 8

    cond_stage_config:
      crossattn_flan_t5:
        cond_stage_key: text
        conditioning_key: crossattn
        target: qa_mdt.audioldm_train.conditional_models.FlanT5HiddenState

    evaluation_params:
      unconditional_guidance_scale: 3.5
      ddim_sampling_steps: 200
      n_candidates_per_samples: 3