# Source metadata (blame-view extraction artifact): 1,722 bytes, commit 5734b5e
# Convert whisper into smaller model using layer pruning
import copy

import torch
from transformers import WhisperProcessor, GenerationConfig, WhisperForConditionalGeneration, WhisperTokenizer
# ---------------------------------------------------------------------------
# Distill a Whisper checkpoint into a smaller student by layer pruning:
# the encoder is copied verbatim, and only the teacher decoder layers listed
# in `mapping` are kept (student layer i is initialized from teacher layer
# mapping[i]). The student model, processor, generation config and tokenizer
# are then saved to SAVE_DIR.
# ---------------------------------------------------------------------------
TEACHER_CKPT = "large-v2"  # NOTE(review): not a HF hub id ("openai/whisper-large-v2"); presumably a local path/alias — confirm
DECODER_LAYERS = 8         # number of decoder layers kept in the student
SAVE_DIR = "."             # output directory for the distilled student
CACHE_DIR = "."            # HF download cache directory

teacher_model = WhisperForConditionalGeneration.from_pretrained(TEACHER_CKPT, cache_dir=CACHE_DIR)
teacher_config = teacher_model.config
teacher_layers = teacher_config.decoder_layers

# Deep-copy the config: a plain assignment would alias the teacher's config
# object, so setting decoder_layers below would silently mutate
# teacher_model.config as well.
student_config = copy.deepcopy(teacher_config)
student_config.decoder_layers = DECODER_LAYERS

# Teacher decoder layers to keep (student layer i <- teacher layer mapping[i]).
mapping = [0, 1, 4, 8, 16, 24, 30, 31]
if len(mapping) != DECODER_LAYERS:
    raise ValueError(
        f"mapping has {len(mapping)} entries but DECODER_LAYERS={DECODER_LAYERS}"
    )
if max(mapping) >= teacher_layers:
    raise ValueError(
        f"mapping references layer {max(mapping)} but the teacher only has {teacher_layers} decoder layers"
    )

student_model = WhisperForConditionalGeneration(student_config)

# Initialize the student from the teacher's weights. strict=False because the
# teacher's extra decoder-layer keys have no counterpart in the student and
# are skipped; everything else (encoder, embeddings, ...) is copied.
student_model.load_state_dict(teacher_model.state_dict(), strict=False)

# Sanity check: the entire encoder must be an exact copy of the teacher's.
for s_param, t_param in zip(student_model.model.encoder.parameters(),
                            teacher_model.model.encoder.parameters()):
    if not torch.equal(s_param.data, t_param.data):
        raise RuntimeError("encoder weights were not copied from the teacher")

# Copy the selected teacher decoder layers into the student. This load is
# strict (default), so every key has to match: <All keys matched successfully>
layers_to_copy = torch.nn.ModuleList(
    teacher_model.model.decoder.layers[i] for i in mapping
)
student_model.model.decoder.layers.load_state_dict(layers_to_copy.state_dict())

# Persist the student model.
student_model.save_pretrained(SAVE_DIR)

# Also save the processor, generation config and tokenizer so the student is
# directly usable for inference from SAVE_DIR.
processor = WhisperProcessor.from_pretrained(TEACHER_CKPT, cache_dir=CACHE_DIR)
processor.save_pretrained(SAVE_DIR)
generation_config = GenerationConfig.from_pretrained(TEACHER_CKPT, cache_dir=CACHE_DIR)
generation_config.save_pretrained(SAVE_DIR)
tokenizer = WhisperTokenizer.from_pretrained(TEACHER_CKPT, cache_dir=CACHE_DIR)
tokenizer.save_pretrained(SAVE_DIR)