from typing import Optional

import torch
from torch import nn
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.trainer_pt_utils import LabelSmoother

# Constants from the training recipe. They are unused in this merging script
# but kept so the module mirrors the training-time model definition.
IGNORE_TOKEN_ID = LabelSmoother.ignore_index
DEFAULT_SPEECH_TOKEN = "<speech>"

class SPEECH_LLM(nn.Module):
    """
    A speech-to-text model consisting of a speech encoder, a language model,
    and an encoder projector.

    The encoder extracts speech features from the input speech signal, the
    encoder projector maps those features to the hidden dimension of the
    language model, and the language model generates text from the projected
    features.

    Args:
        encoder (:obj:`nn.Module`): The speech encoder module.
        llm (:obj:`nn.Module`): The language model module.
        encoder_projector (:obj:`nn.Module`): The encoder projector module.
    """

    def __init__(
        self,
        encoder: Optional[nn.Module] = None,
        llm: Optional[nn.Module] = None,
        encoder_projector: Optional[nn.Module] = None,
    ):
        super().__init__()
        self.encoder = encoder
        self.llm = llm
        self.encoder_projector = encoder_projector
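
    def encode_speech(self, feats: torch.Tensor) -> torch.Tensor:
        # Illustrative sketch, not part of the original script: it shows the
        # data flow described in the class docstring. The encoder call
        # signature is an assumption; the real training code may also pass
        # feature lengths or attention masks.
        encoder_out = self.encoder(feats)  # extract speech features
        return self.encoder_projector(encoder_out)  # project to LLM hidden dim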
if __name__ == "__main__":
    # Dimensions from the training configuration. They are not referenced by
    # the merging steps below and are kept only for reference.
    speech_encoder_dim = 1280
    encoder_projector_ds_rate = 8
    llm_config_hidden_size = 1536

    adapter_dir = "/home/scratch.yuekaiz_wwfo_1/icefall_asr_multi-hans_whisper_qwen2_1.5B/epoch-2-avg-6.pt"
    llm_dir = "/home/scratch.yuekaiz_wwfo_1/Qwen2-1.5B-Instruct"
    target_dir = "/home/scratch.yuekaiz_wwfo_1/Qwen2_1.5B_merged"

    # Load the base LLM in float16 and wrap it with the same LoRA
    # configuration that was used during fine-tuning, so the adapter weights
    # in the checkpoint line up with the wrapped module names.
    llm = AutoModelForCausalLM.from_pretrained(
        llm_dir,
        torch_dtype=torch.float16,
    )
    lora_config = LoraConfig(
        r=64,
        lora_alpha=16,
        target_modules=[
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "up_proj",
            "gate_proj",
            "down_proj",
        ],
        task_type="CAUSAL_LM",
    )
    llm = get_peft_model(llm, lora_config)
    model = SPEECH_LLM(llm=llm)

    # Load the fine-tuned checkpoint. strict=False is needed because only the
    # LLM was instantiated above, so any encoder / encoder_projector weights
    # present in the checkpoint are reported as unexpected keys instead of
    # raising an error.
    checkpoint = torch.load(adapter_dir, map_location="cpu")
    missing_keys, unexpected_keys = model.load_state_dict(checkpoint, strict=False)
    print("missing keys:", missing_keys)
    print("unexpected keys:", unexpected_keys)

    # Fold the LoRA weights into the base model and drop the PEFT wrappers,
    # then save the result as a standalone Hugging Face model.
    llm_merged = model.llm.merge_and_unload()
    llm_merged.save_pretrained(target_dir)
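
    # Not in the original script: also saving the tokenizer (assumed to live
    # in llm_dir alongside the base model) makes target_dir loadable on its
    # own via AutoModelForCausalLM / AutoTokenizer.from_pretrained.
    tokenizer = AutoTokenizer.from_pretrained(llm_dir)
    tokenizer.save_pretrained(target_dir)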