# Olive_Whisper_ASR / utils /model_utils.py
# sam2ai's picture
# Synced repo using 'sync_with_huggingface' Github Action
# 6de3e11
# raw
# history blame contribute delete
# 622 Bytes
import bitsandbytes as bnb
import torch
from transformers.trainer_pt_utils import LabelSmoother
# Label id copied from transformers' LabelSmoother.ignore_index.
# NOTE(review): presumably marks label positions excluded from the loss —
# confirm against the callers that import this constant.
IGNORE_TOKEN_ID = LabelSmoother.ignore_index
def find_all_linear_names(use_8bit, model):
    """Collect the leaf names of every linear layer found inside ``model``.

    Walks ``model.named_modules()`` and, for each submodule that is an
    instance of the selected linear class, records the last dot-separated
    component of its qualified name (``"encoder.layers.0.q_proj"`` becomes
    ``"q_proj"``). Duplicate leaf names are collapsed.

    NOTE(review): the result is presumably passed as LoRA ``target_modules``;
    confirm against the caller.

    Args:
        use_8bit: When truthy, match ``bnb.nn.Linear8bitLt`` layers;
            otherwise match ``torch.nn.Linear`` layers.
        model: Object exposing ``named_modules()`` that yields
            ``(qualified_name, module)`` pairs.

    Returns:
        The unique leaf names as a list, sorted alphabetically so the result
        is identical from one interpreter run to the next.
    """
    # Resolve the class lazily so ``bnb`` is only touched on the 8-bit path.
    if use_8bit:
        linear_cls = bnb.nn.Linear8bitLt
    else:
        linear_cls = torch.nn.Linear

    # rpartition('.')[2] is the text after the final dot, or the whole name
    # when it contains no dot at all.
    leaf_names = {
        name.rpartition('.')[2]
        for name, module in model.named_modules()
        if isinstance(module, linear_cls)
    }

    # A set of str iterates in hash order, which varies between processes
    # under hash randomization; sorting makes the output reproducible.
    return sorted(leaf_names)
def load_from_checkpoint(resume_from_checkpoint, model=None):
    """Placeholder for checkpoint loading; currently does nothing.

    Both arguments are accepted but ignored, and the function always
    returns ``None``.

    Args:
        resume_from_checkpoint: Unused by the current implementation.
        model: Unused by the current implementation.
    """
    return None