|
"""This lobe enables the integration of huggingface pretrained Llama2 Model model plus the expanding embedding layer for additional PAD tokens . |
|
|
|
Transformer from HuggingFace needs to be installed: |
|
https://huggingface.co/transformers/installation.html |
|
|
|
Authors |
|
* Pooneh Mousavi 2023 |
|
""" |
|
|
|
import logging

import torch
import torch.nn as nn
from torch import Tensor

from speechbrain.lobes.models.huggingface_transformers.llama2 import LLAMA2

logger = logging.getLogger(__name__)
|
|
|
|
|
class LLAMA2_expanded(LLAMA2):
    """This lobe enables the integration of a HuggingFace pretrained LLAMA2
    model whose embedding layer is expanded with an additional PAD token.

    Source paper LLAMA2:
    https://arxiv.org/abs/2307.09288
    Transformers from HuggingFace needs to be installed:
    https://huggingface.co/transformers/installation.html
|
    The model can be finetuned. It will automatically download the model from
    HuggingFace or use a local path.

    Arguments
    ---------
    source : str
        HuggingFace hub name: e.g. "meta-llama/Llama-2-7b-chat-hf"
    save_path : str
        Path (dir) of the downloaded model.
    freeze : bool (default: False)
        If True, the model is frozen. If False, the model will be trained
        alongside the rest of the pipeline.
|
    Example
    -------
    >>> model_hub = "meta-llama/Llama-2-7b-chat-hf"
    >>> save_path = "savedir"
    >>> model = LLAMA2_expanded(model_hub, save_path)
    >>> tokens = torch.tensor([[1, 1]])
    >>> attention_mask = torch.tensor([[1, 1]])
    >>> outputs = model(tokens, attention_mask)
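    >>> # The tokenizer now exposes the added PAD token (a sketch; it assumes
    >>> # access to the gated Llama-2 checkpoint above):
    >>> model.tokenizer.pad_token
    '<pad>'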
|
""" |
|
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # The Llama 2 tokenizer does not define a padding token by default,
        # so register a dedicated <pad> token and resize the embedding layer.
        self.add_special_tokens_({"pad_token": "<pad>"})
|
    def add_special_tokens_(self, attr_to_special_token) -> None:
        """Add special tokens to the tokenizer and resize token embeddings.

        Arguments
        ---------
        attr_to_special_token : dict
            Mapping of special-token attributes to token strings,
            e.g. {"pad_token": "<pad>"}.
        """
        orig_num_tokens = len(self.tokenizer)
        num_added_tokens = self.tokenizer.add_special_tokens(
            attr_to_special_token
        )
        if num_added_tokens > 0:
            # Resize the embedding matrix so each new token gets its own row.
            self.model.resize_token_embeddings(
                new_num_tokens=orig_num_tokens + num_added_tokens
            )
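
# Minimal usage sketch (illustrative only, not executed on import): once the
# PAD token is registered, batches of uneven length can be padded with the
# standard HuggingFace tokenizer call, e.g.
#   model = LLAMA2_expanded("meta-llama/Llama-2-7b-chat-hf", "savedir")
#   batch = model.tokenizer(["hello", "hi there"], padding=True, return_tensors="pt")
# Shorter sequences are filled with the id of the added <pad> token.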
|
|
|
|