#!/usr/bin/env python3


from transformers import SpeechEncoderDecoderModel, AutoFeatureExtractor, AutoTokenizer, Wav2Vec2Processor


# Checkpoint identifiers for the two halves of the composite model:
# a speech encoder (wav2vec2) and a text decoder (BART).
encoder_id = "facebook/wav2vec2-base"
decoder_id = "facebook/bart-base"

# Build the speech-encoder-decoder model, setting training/generation
# hyper-parameters (adapter, dropout, layerdrop, beam search) up front.
model = SpeechEncoderDecoderModel.from_encoder_decoder_pretrained(
    encoder_id,
    decoder_id,
    encoder_add_adapter=True,
    encoder_feat_proj_dropout=0.0,
    encoder_layerdrop=0.0,
    max_length=200,
    num_beams=5,
)

# Mirror the decoder's special-token ids onto the top-level config so
# generation knows how sequences start, pad, and end.
decoder_config = model.decoder.config
model.config.decoder_start_token_id = decoder_config.bos_token_id
model.config.pad_token_id = decoder_config.pad_token_id
model.config.eos_token_id = decoder_config.eos_token_id

# Persist the assembled model into the current directory.
model.save_pretrained("./")

# Pair the encoder's feature extractor with the decoder's tokenizer into a
# single processor, then persist it alongside the model weights.
processor = Wav2Vec2Processor(
    AutoFeatureExtractor.from_pretrained(encoder_id),
    AutoTokenizer.from_pretrained(decoder_id),
)
processor.save_pretrained("./")