Spaces:
Running
on
Zero
Running
on
Zero
# Copyright 2023-present the HuggingFace Inc. team. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
import torch | |
from peft.tuners.prompt_tuning import PromptEmbedding | |
from peft.utils import TaskType | |
from .config import MultitaskPromptTuningConfig, MultitaskPromptTuningInit | |
# This code is adapted for the paper: https://arxiv.org/abs/2303.02861 and | |
# constitutes the work done at MIT-IBM Watson Research Lab. | |
class MultitaskPromptEmbedding(PromptEmbedding):
    """Prompt embedding that is rescaled by a per-task low-rank product.

    In addition to whatever ``PromptEmbedding`` sets up, the module owns two
    learnable factors:

    * ``prefix_task_cols`` of size ``(num_tasks, total_virtual_tokens, num_ranks)``
    * ``prefix_task_rows`` of size ``(num_tasks, num_ranks, token_dim)``

    with ``total_virtual_tokens = num_virtual_tokens * num_transformer_submodules``.
    ``forward`` picks the factors of the requested tasks, multiplies them
    together and uses the product to scale the prompt embeddings element-wise.
    """

    def __init__(self, config: MultitaskPromptTuningConfig, word_embeddings):
        """Create the per-task factors and optionally seed them from a checkpoint.

        Args:
            config: Supplies ``num_tasks``, ``num_ranks``, ``num_virtual_tokens``,
                ``num_transformer_submodules``, ``token_dim``, ``task_type`` and
                the ``prompt_tuning_init*`` options read below.
            word_embeddings: Forwarded to ``PromptEmbedding.__init__``; its
                ``weight.device`` is used as ``map_location`` when a
                non-safetensors checkpoint is loaded.

        Raises:
            ValueError: If a source-based init method is selected but
                ``config.prompt_tuning_init_state_dict_path`` is ``None``.
        """
        super().__init__(config, word_embeddings)

        self.num_tasks = config.num_tasks
        self.num_ranks = config.num_ranks
        self.num_virtual_tokens = config.num_virtual_tokens

        # When the config leaves the submodule count open, seq2seq tasks get
        # two submodules and every other task type gets one.
        submodule_count = config.num_transformer_submodules
        if submodule_count is None:
            submodule_count = 2 if config.task_type == TaskType.SEQ_2_SEQ_LM else 1
        self.num_transformer_submodules = submodule_count

        self.token_dim = config.token_dim

        prompt_length = self.num_virtual_tokens * self.num_transformer_submodules

        def gaussian_parameter(*shape):
            # Both factors start from N(0, 0.02**2) noise.
            return torch.nn.Parameter(torch.normal(mean=0, std=0.02, size=shape))

        # Keep this creation order: it fixes both the parameter registration
        # order and the order in which random numbers are drawn.
        self.prefix_task_cols = gaussian_parameter(self.num_tasks, prompt_length, self.num_ranks)
        self.prefix_task_rows = gaussian_parameter(self.num_tasks, self.num_ranks, self.token_dim)

        init_mode = config.prompt_tuning_init
        source_based_inits = (
            MultitaskPromptTuningInit.AVERAGE_SOURCE_TASKS,
            MultitaskPromptTuningInit.EXACT_SOURCE_TASK,
            MultitaskPromptTuningInit.ONLY_SOURCE_SHARED,
        )
        if init_mode not in source_based_inits:
            # Any other init method needs nothing beyond the random factors.
            return

        checkpoint_path = config.prompt_tuning_init_state_dict_path
        if checkpoint_path is None:
            raise ValueError(
                f"prompt_tuning_init_state_dict_path needs to be specified with {config.prompt_tuning_init} "
                "init method"
            )

        if checkpoint_path.endswith(".safetensors"):
            from safetensors.torch import load_file

            source_state: dict = load_file(checkpoint_path)
        else:
            # NOTE(review): depending on the installed torch version,
            # torch.load may unpickle arbitrary objects -- only point this at
            # trusted checkpoints.
            source_state = torch.load(
                checkpoint_path,
                map_location=word_embeddings.weight.device,
            )

        if init_mode in (
            MultitaskPromptTuningInit.AVERAGE_SOURCE_TASKS,
            MultitaskPromptTuningInit.EXACT_SOURCE_TASK,
        ):
            source_cols: torch.Tensor = source_state["prefix_task_cols"]
            source_rows: torch.Tensor = source_state["prefix_task_rows"]

            # Collapse the source factors to a single leading entry, either by
            # averaging over dim 0 or by picking one source task.
            if init_mode == MultitaskPromptTuningInit.AVERAGE_SOURCE_TASKS:
                source_cols = source_cols.mean(0, keepdim=True)
                source_rows = source_rows.mean(0, keepdim=True)
            elif init_mode == MultitaskPromptTuningInit.EXACT_SOURCE_TASK:
                source_task = config.prompt_tuning_init_task
                source_cols = source_cols[source_task, ...].unsqueeze(0)
                source_rows = source_rows[source_task, ...].unsqueeze(0)

            # strict=True: the three entries must match this module's state
            # dict exactly.
            self.load_state_dict(
                {
                    "embedding.weight": source_state["prompt_embeddings"],
                    "prefix_task_cols": source_cols,
                    "prefix_task_rows": source_rows,
                },
                strict=True,
            )
        elif init_mode == MultitaskPromptTuningInit.ONLY_SOURCE_SHARED:
            # Only the shared prompt is taken over; strict=False lets the
            # per-task factors keep their random initialisation.
            self.load_state_dict(
                {"embedding.weight": source_state["prompt_embeddings"]},
                strict=False,
            )

    def forward(self, indices, task_ids):
        """Return the prompt embeddings scaled by the selected tasks' factors.

        Args:
            indices: Passed unchanged to ``self.embedding``.
            task_ids: Index tensor handed to ``index_select`` along dim 0 of
                both factor parameters (presumably one task id per batch
                element -- confirm against the caller).

        Returns:
            The tensor produced by ``self.embedding(indices)``, multiplied in
            place by ``prefix_task_cols[task_ids] @ prefix_task_rows[task_ids]``.

        Raises:
            ValueError: If ``task_ids`` is ``None``.
        """
        if task_ids is None:
            raise ValueError("task_ids cannot be None")

        embedded = self.embedding(indices)

        selected_cols = self.prefix_task_cols.index_select(0, task_ids)
        selected_rows = self.prefix_task_rows.index_select(0, task_ids)

        # (k, tokens, ranks) @ (k, ranks, token_dim) -> (k, tokens, token_dim),
        # where k is the number of selected task ids.
        embedded *= selected_cols @ selected_rows

        return embedded