import copy
import os
import random
from logging import WARNING
from typing import Any, List, Optional, Union

import torch
import torch.nn as nn
import transformers

# from mmagic.utils import try_import
from mmengine import print_log


class TokenizerWrapper:
    """Tokenizer wrapper for CLIPTokenizer. Only CLIPTokenizer is supported
    currently. This wrapper is modified from https://github.com/huggingface/diffusers/blob/e51f19aee82c8dd874b715a09dbc521d88835d68/src/diffusers/loaders.py#L358  # noqa

    Args:
        from_pretrained (Union[str, os.PathLike], optional): The *model id*
            of a pretrained model or a path to a *directory* containing
            model weights and config. Defaults to None.
        from_config (Union[str, os.PathLike], optional): The *model id*
            of a pretrained model or a path to a *directory* containing
            model weights and config. Defaults to None.
        *args, **kwargs: If `from_pretrained` is passed, *args and **kwargs
            will be passed to the `from_pretrained` function. Otherwise, *args
            and **kwargs will be used to initialize the model by
            `self._module_cls(*args, **kwargs)`.
    """

    def __init__(
        self,
        from_pretrained: Optional[Union[str, os.PathLike]] = None,
        from_config: Optional[Union[str, os.PathLike]] = None,
        *args,
        **kwargs,
    ):
        # transformers = try_import("transformers")
        module_cls = transformers.CLIPTokenizer

        assert not (from_pretrained and from_config), (
            "'from_pretrained' and 'from_config' should not be passed at the same time."
        )
        if from_config:
            print_log(
                "Tokenizers from Huggingface transformers do not support "
                "'from_config'. Will call 'from_pretrained' instead "
                "with the same argument.",
                "current",
                WARNING,
            )
            from_pretrained = from_config

        if from_pretrained:
            self.wrapped = module_cls.from_pretrained(from_pretrained, *args, **kwargs)
        else:
            self.wrapped = module_cls(*args, **kwargs)

        # keep references for '__repr__'
        self._module_cls = module_cls
        self._module_name = module_cls.__name__
        self._from_pretrained = from_pretrained
        self.token_map = {}

    def __getattr__(self, name: str) -> Any:
        if name == "wrapped":
            return super().__getattr__("wrapped")

        try:
            return getattr(self.wrapped, name)
        except AttributeError:
            try:
                return super().__getattr__(name)
            except AttributeError:
                raise AttributeError(
                    f"'{name}' cannot be found in both "
                    f"'{self.__class__.__name__}' and "
                    f"'{self.__class__.__name__}.tokenizer'."
                )

    def try_adding_tokens(self, tokens: Union[str, List[str]], *args, **kwargs):
        """Attempt to add tokens to the tokenizer.

        Args:
            tokens (Union[str, List[str]]): The tokens to be added.
        """
        num_added_tokens = self.wrapped.add_tokens(tokens, *args, **kwargs)
        assert num_added_tokens != 0, (
            f"The tokenizer already contains the token {tokens}. Please pass "
            "a different `placeholder_token` that is not already in the "
            "tokenizer."
        )

    def get_token_info(self, token: str) -> dict:
        """Get the information of a token, including its start and end index
        in the current tokenizer.

        Args:
            token (str): The token to be queried.

        Returns:
            dict: The information of the token, including its start and end
                index in the current tokenizer.
        """
        token_ids = self.__call__(token).input_ids
        start, end = token_ids[1], token_ids[-2] + 1
        return {"name": token, "start": start, "end": end}

    def add_placeholder_token(self, placeholder_token: str, *args, num_vec_per_token: int = 1, **kwargs):
        """Add a placeholder token to the tokenizer.

        Args:
            placeholder_token (str): The placeholder token to be added.
            num_vec_per_token (int, optional): The number of vectors of
                the added placeholder token. Defaults to 1.
            *args, **kwargs: The arguments for `self.wrapped.add_tokens`.
        """
        output = []
        if num_vec_per_token == 1:
            self.try_adding_tokens(placeholder_token, *args, **kwargs)
            output.append(placeholder_token)
        else:
            for i in range(num_vec_per_token):
                ith_token = placeholder_token + f"_{i}"
                self.try_adding_tokens(ith_token, *args, **kwargs)
                output.append(ith_token)

        for token in self.token_map:
            if token in placeholder_token:
                raise ValueError(
                    f"The tokenizer already has placeholder token {token} "
                    f"that can get confused with {placeholder_token}. "
                    "Please keep placeholder tokens independent."
                )
        self.token_map[placeholder_token] = output

    def replace_placeholder_tokens_in_text(
        self, text: Union[str, List[str]], vector_shuffle: bool = False, prop_tokens_to_load: float = 1.0
    ) -> Union[str, List[str]]:
        """Replace the placeholder keywords in text with their expanded
        placeholder tokens. This function will be called in `self.__call__`
        and `self.encode`.

        Args:
            text (Union[str, List[str]]): The text to be processed.
            vector_shuffle (bool, optional): Whether to shuffle the vectors.
                Defaults to False.
            prop_tokens_to_load (float, optional): The proportion of tokens to
                be loaded. If 1.0, all tokens will be loaded. Defaults to 1.0.

        Returns:
            Union[str, List[str]]: The processed text.
        """
        if isinstance(text, list):
            output = []
            for i in range(len(text)):
                output.append(
                    self.replace_placeholder_tokens_in_text(
                        text[i], vector_shuffle=vector_shuffle, prop_tokens_to_load=prop_tokens_to_load
                    )
                )
            return output

        for placeholder_token in self.token_map:
            if placeholder_token in text:
                tokens = self.token_map[placeholder_token]
                tokens = tokens[: 1 + int(len(tokens) * prop_tokens_to_load)]
                if vector_shuffle:
                    tokens = copy.copy(tokens)
                    random.shuffle(tokens)
                text = text.replace(placeholder_token, " ".join(tokens))
        return text

    def replace_text_with_placeholder_tokens(self, text: Union[str, List[str]]) -> Union[str, List[str]]:
        """Replace the placeholder tokens in text with the original keywords.
        This function will be called in `self.decode`.

        Args:
            text (Union[str, List[str]]): The text to be processed.

        Returns:
            Union[str, List[str]]: The processed text.
        """
        if isinstance(text, list):
            output = []
            for i in range(len(text)):
                output.append(self.replace_text_with_placeholder_tokens(text[i]))
            return output

        for placeholder_token, tokens in self.token_map.items():
            merged_tokens = " ".join(tokens)
            if merged_tokens in text:
                text = text.replace(merged_tokens, placeholder_token)
        return text

    def __call__(
        self,
        text: Union[str, List[str]],
        *args,
        vector_shuffle: bool = False,
        prop_tokens_to_load: float = 1.0,
        **kwargs,
    ):
        """The call function of the wrapper.

        Args:
            text (Union[str, List[str]]): The text to be tokenized.
            vector_shuffle (bool, optional): Whether to shuffle the vectors.
                Defaults to False.
            prop_tokens_to_load (float, optional): The proportion of tokens to
                be loaded. If 1.0, all tokens will be loaded. Defaults to 1.0.
            *args, **kwargs: The arguments for `self.wrapped.__call__`.
        """
        replaced_text = self.replace_placeholder_tokens_in_text(
            text, vector_shuffle=vector_shuffle, prop_tokens_to_load=prop_tokens_to_load
        )
        return self.wrapped.__call__(replaced_text, *args, **kwargs)

    def encode(self, text: Union[str, List[str]], *args, **kwargs):
        """Encode the passed text to token indices.

        Args:
            text (Union[str, List[str]]): The text to be encoded.
            *args, **kwargs: The arguments for `self.wrapped.__call__`.
        """
        replaced_text = self.replace_placeholder_tokens_in_text(text)
        return self.wrapped(replaced_text, *args, **kwargs)

    def decode(self, token_ids, return_raw: bool = False, *args, **kwargs) -> Union[str, List[str]]:
        """Decode the token indices to text.

        Args:
            token_ids: The token indices to be decoded.
            return_raw: Whether to keep the expanded placeholder tokens in the
                decoded text. Defaults to False.
            *args, **kwargs: The arguments for `self.wrapped.decode`.

        Returns:
            Union[str, List[str]]: The decoded text.
        """
        text = self.wrapped.decode(token_ids, *args, **kwargs)
        if return_raw:
            return text
        replaced_text = self.replace_text_with_placeholder_tokens(text)
        return replaced_text

    def __repr__(self):
        """The representation of the wrapper."""
        s = super().__repr__()
        prefix = f"Wrapped Module Class: {self._module_cls}\n"
        prefix += f"Wrapped Module Name: {self._module_name}\n"
        if self._from_pretrained:
            prefix += f"From Pretrained: {self._from_pretrained}\n"
        s = prefix + s
        return s
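

# A minimal usage sketch of `TokenizerWrapper` (illustrative only; the helper
# name and the '<cat-toy>' placeholder are arbitrary, and loading the
# 'openai/clip-vit-base-patch32' checkpoint assumes hub access). It shows how a
# multi-vector placeholder is expanded to '<cat-toy>_0 ... <cat-toy>_3' before
# the wrapped CLIPTokenizer is called.
def _tokenizer_wrapper_demo():
    tokenizer = TokenizerWrapper(from_pretrained="openai/clip-vit-base-patch32")
    tokenizer.add_placeholder_token("<cat-toy>", num_vec_per_token=4)
    # The placeholder is replaced by its sub-tokens inside `__call__`.
    ids = tokenizer("a photo of <cat-toy>", padding="max_length", truncation=True, return_tensors="pt")
    info = tokenizer.get_token_info("<cat-toy>")  # {'name': ..., 'start': ..., 'end': ...}
    return ids["input_ids"], info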


class EmbeddingLayerWithFixes(nn.Module):
    """The revised embedding layer to support external embeddings. The design
    of this class is inspired by https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/22bcc7be428c94e9408f589966c2040187245d81/modules/sd_hijack.py#L224  # noqa

    Args:
        wrapped (nn.Embedding): The embedding layer to be wrapped.
        external_embeddings (Union[dict, List[dict]], optional): The external
            embeddings added to this layer. Defaults to None.
    """

    def __init__(self, wrapped: nn.Embedding, external_embeddings: Optional[Union[dict, List[dict]]] = None):
        super().__init__()
        self.wrapped = wrapped
        self.num_embeddings = wrapped.weight.shape[0]

        # 'trainable_embeddings' must exist before 'add_embeddings' is called,
        # since trainable external embeddings are registered there.
        self.trainable_embeddings = nn.ParameterDict()

        self.external_embeddings = []
        if external_embeddings:
            self.add_embeddings(external_embeddings)

    @property
    def weight(self):
        """Get the weight of the wrapped embedding layer."""
        return self.wrapped.weight

    def check_duplicate_names(self, embeddings: List[dict]):
        """Check whether duplicate names exist in the list of external
        embeddings.

        Args:
            embeddings (List[dict]): A list of embeddings to be checked.
        """
        names = [emb["name"] for emb in embeddings]
        assert len(names) == len(set(names)), (
            f"Found duplicated names in 'external_embeddings'. Name list: '{names}'"
        )

    def check_ids_overlap(self, embeddings):
        """Check whether the token id ranges of the external embeddings
        overlap.

        Args:
            embeddings (List[dict]): A list of embeddings to be checked.
        """
        ids_range = [[emb["start"], emb["end"], emb["name"]] for emb in embeddings]
        ids_range.sort()  # sort by 'start'
        # check that adjacent [start, end) ranges do not overlap
        for idx in range(len(ids_range) - 1):
            name1, name2 = ids_range[idx][-1], ids_range[idx + 1][-1]
            assert ids_range[idx][1] <= ids_range[idx + 1][0], (
                f"Found ids overlapping between embeddings '{name1}' and '{name2}'."
            )

    def add_embeddings(self, embeddings: Optional[Union[dict, List[dict]]]):
        """Add external embeddings to this layer.

        Use case:

        >>> # 1. Add tokens to the tokenizer and get their token ids.
        >>> tokenizer = TokenizerWrapper('openai/clip-vit-base-patch32')
        >>> # 'how much' in kiswahili
        >>> tokenizer.add_placeholder_token('ngapi', num_vec_per_token=4)
        >>>
        >>> # 2. Add external embeddings to the model.
        >>> new_embedding = {
        >>>     'name': 'ngapi',  # 'how much' in kiswahili
        >>>     'embedding': torch.ones(4, 15) * 2.3,
        >>>     'start': tokenizer.get_token_info('ngapi')['start'],
        >>>     'end': tokenizer.get_token_info('ngapi')['end'],
        >>>     'trainable': False  # if True, will be registered as a parameter
        >>> }
        >>> embedding_layer = nn.Embedding(10, 15)
        >>> embedding_layer_wrapper = EmbeddingLayerWithFixes(embedding_layer)
        >>> embedding_layer_wrapper.add_embeddings(new_embedding)
        >>>
        >>> # 3. Forward the tokenizer and the embedding layer!
        >>> input_text = ['hello, ngapi!', 'hello my friend, ngapi?']
        >>> input_ids = tokenizer(
        >>>     input_text, padding='max_length', truncation=True,
        >>>     return_tensors='pt')['input_ids']
        >>> out_feat = embedding_layer_wrapper(input_ids)
        >>>
        >>> # 4. Let's validate the result!
        >>> assert (out_feat[0, 3:7] == 2.3).all()
        >>> assert (out_feat[1, 5:9] == 2.3).all()

        Args:
            embeddings (Union[dict, List[dict]]): The external embeddings to
                be added. Each dict must contain the following 4 fields: 'name'
                (the name of this embedding), 'embedding' (the embedding
                tensor), 'start' (the start token id of this embedding), 'end'
                (the end token id of this embedding). For example:
                `{name: NAME, start: START, end: END, embedding: torch.Tensor}`
        """
        if isinstance(embeddings, dict):
            embeddings = [embeddings]

        self.external_embeddings += embeddings
        self.check_duplicate_names(self.external_embeddings)
        self.check_ids_overlap(self.external_embeddings)

        # register trainable embeddings as parameters
        added_trainable_emb_info = []
        for embedding in embeddings:
            trainable = embedding.get("trainable", False)
            if trainable:
                name = embedding["name"]
                embedding["embedding"] = torch.nn.Parameter(embedding["embedding"])
                self.trainable_embeddings[name] = embedding["embedding"]
                added_trainable_emb_info.append(name)

        added_emb_info = [emb["name"] for emb in embeddings]
        added_emb_info = ", ".join(added_emb_info)
        print_log(f"Successfully added external embeddings: {added_emb_info}.", "current")

        if added_trainable_emb_info:
            added_trainable_emb_info = ", ".join(added_trainable_emb_info)
            print_log(f"Successfully added trainable external embeddings: {added_trainable_emb_info}", "current")

    def replace_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Replace the token ids of external embeddings with 0 so that they
        can be safely fed to the wrapped embedding layer.

        Args:
            input_ids (torch.Tensor): The input ids to be replaced.

        Returns:
            torch.Tensor: The replaced input ids.
        """
        input_ids_fwd = input_ids.clone()
        input_ids_fwd[input_ids_fwd >= self.num_embeddings] = 0
        return input_ids_fwd

    def replace_embeddings(
        self, input_ids: torch.Tensor, embedding: torch.Tensor, external_embedding: dict
    ) -> torch.Tensor:
        """Replace one external embedding in the output of the embedding
        layer. Note that `torch.cat` is used here to avoid in-place
        modification.

        Args:
            input_ids (torch.Tensor): The original token ids. Shape like
                [LENGTH, ].
            embedding (torch.Tensor): The embedding of token ids after the
                `replace_input_ids` function.
            external_embedding (dict): The external embedding to be replaced.

        Returns:
            torch.Tensor: The replaced embedding.
        """
        new_embedding = []

        name = external_embedding["name"]
        start = external_embedding["start"]
        end = external_embedding["end"]
        target_ids_to_replace = [i for i in range(start, end)]
        ext_emb = external_embedding["embedding"]

        # no replacement needed
        if not (input_ids == start).any():
            return embedding

        # start replacing
        s_idx, e_idx = 0, 0
        while e_idx < len(input_ids):
            if input_ids[e_idx] == start:
                if e_idx != 0:
                    # add the embeddings that do not need to be replaced
                    new_embedding.append(embedding[s_idx:e_idx])

                # check that the next span of ids to be replaced is valid
                actually_ids_to_replace = [int(i) for i in input_ids[e_idx : e_idx + end - start]]
                assert actually_ids_to_replace == target_ids_to_replace, (
                    f"Invalid 'input_ids' in position: {s_idx} to {e_idx}. "
                    f"Expect '{target_ids_to_replace}' for embedding "
                    f"'{name}' but found '{actually_ids_to_replace}'."
                )

                new_embedding.append(ext_emb)

                s_idx = e_idx + end - start
                e_idx = s_idx + 1
            else:
                e_idx += 1

            if e_idx == len(input_ids):
                new_embedding.append(embedding[s_idx:e_idx])

        return torch.cat(new_embedding, dim=0)

    def forward(self, input_ids: torch.Tensor, external_embeddings: Optional[List[dict]] = None):
        """The forward function.

        Args:
            input_ids (torch.Tensor): The token ids, shape like [bz, LENGTH]
                or [LENGTH, ].
            external_embeddings (Optional[List[dict]]): The external
                embeddings. If not passed, only `self.external_embeddings`
                will be used. Defaults to None.
        """
        assert input_ids.ndim in [1, 2]
        if input_ids.ndim == 1:
            input_ids = input_ids.unsqueeze(0)

        if external_embeddings is None and not self.external_embeddings:
            return self.wrapped(input_ids)

        input_ids_fwd = self.replace_input_ids(input_ids)
        inputs_embeds = self.wrapped(input_ids_fwd)

        vecs = []

        if external_embeddings is None:
            external_embeddings = []
        elif isinstance(external_embeddings, dict):
            external_embeddings = [external_embeddings]
        embeddings = self.external_embeddings + external_embeddings

        for input_id, embedding in zip(input_ids, inputs_embeds):
            new_embedding = embedding
            for external_embedding in embeddings:
                new_embedding = self.replace_embeddings(input_id, new_embedding, external_embedding)
            vecs.append(new_embedding)

        return torch.stack(vecs)
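

# A minimal sketch of how `EmbeddingLayerWithFixes.forward` splices an external
# embedding into the regular embedding output (illustrative only; the helper
# name, the 8-dim embeddings, and ids 10..13 are made-up values). Ids at or
# above `num_embeddings` are zeroed before the wrapped lookup, and their rows
# are then replaced by the external embedding.
def _embedding_layer_demo():
    layer = EmbeddingLayerWithFixes(nn.Embedding(10, 8))
    external = {
        "name": "placeholder",
        "embedding": torch.zeros(4, 8),  # one row per placeholder vector
        "start": 10,  # first id outside the wrapped embedding table
        "end": 14,
        "trainable": False,
    }
    layer.add_embeddings(external)
    input_ids = torch.tensor([[1, 2, 10, 11, 12, 13, 3]])
    out = layer(input_ids)  # shape: (1, 7, 8)
    assert (out[0, 2:6] == 0).all()  # rows 2..5 come from `external`
    return out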


def add_tokens(
    tokenizer,
    text_encoder,
    placeholder_tokens: list,
    initialize_tokens: Optional[list] = None,
    num_vectors_per_token: int = 1,
):
    """Add tokens for training.

    # TODO: support add tokens as dict, then we can load pretrained tokens.
    """
    if initialize_tokens is not None:
        assert len(initialize_tokens) == len(
            placeholder_tokens
        ), "placeholder_tokens should be the same length as initialize_tokens"
    for ii in range(len(placeholder_tokens)):
        tokenizer.add_placeholder_token(placeholder_tokens[ii], num_vec_per_token=num_vectors_per_token)

    # text_encoder.set_embedding_layer()
    embedding_layer = text_encoder.text_model.embeddings.token_embedding
    text_encoder.text_model.embeddings.token_embedding = EmbeddingLayerWithFixes(embedding_layer)
    embedding_layer = text_encoder.text_model.embeddings.token_embedding

    assert embedding_layer is not None, (
        "Getting the embedding layer of the current text encoder is not supported. Please check your configuration."
    )

    initialize_embedding = []
    if initialize_tokens is not None:
        for ii in range(len(placeholder_tokens)):
            init_id = tokenizer(initialize_tokens[ii]).input_ids[1]
            temp_embedding = embedding_layer.weight[init_id]
            initialize_embedding.append(temp_embedding[None, ...].repeat(num_vectors_per_token, 1))
    else:
        for ii in range(len(placeholder_tokens)):
            init_id = tokenizer("a").input_ids[1]
            temp_embedding = embedding_layer.weight[init_id]
            len_emb = temp_embedding.shape[0]
            init_weight = (torch.rand(num_vectors_per_token, len_emb) - 0.5) / 2.0
            initialize_embedding.append(init_weight)

    # initialize_embedding = torch.cat(initialize_embedding, dim=0)

    token_info_all = []
    for ii in range(len(placeholder_tokens)):
        token_info = tokenizer.get_token_info(placeholder_tokens[ii])
        token_info["embedding"] = initialize_embedding[ii]
        token_info["trainable"] = True
        token_info_all.append(token_info)
    embedding_layer.add_embeddings(token_info_all)
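

# A rough end-to-end sketch (illustrative only; assumes hub access to the
# 'openai/clip-vit-base-patch32' checkpoint and a text encoder that exposes
# `text_model.embeddings.token_embedding`, as `transformers.CLIPTextModel`
# does). The '<cat-toy>' placeholder and the 'cat' initializer token are
# arbitrary example values.
if __name__ == "__main__":
    from transformers import CLIPTextModel

    tokenizer = TokenizerWrapper(from_pretrained="openai/clip-vit-base-patch32")
    text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")

    # Register a 4-vector placeholder initialized from the embedding of 'cat'.
    add_tokens(tokenizer, text_encoder, ["<cat-toy>"], ["cat"], num_vectors_per_token=4)

    ids = tokenizer(
        ["a photo of <cat-toy>"], padding="max_length", truncation=True, return_tensors="pt"
    )["input_ids"]
    with torch.no_grad():
        feats = text_encoder(ids)[0]
    print(feats.shape)  # expected: torch.Size([1, 77, 512]) for this checkpoint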