File size: 2,755 Bytes
02a3b66 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 |
from copy import deepcopy
from typing import Optional, Type

import torch
from transformers import AutoConfig, VisionTextDualEncoderConfig
from transformers.utils import logging
# Module-level logger obtained through the transformers logging utilities.
logger = logging.get_logger(__name__)
class CustomCLIPPooler(torch.nn.Module):
    """Pool a sequence of hidden states by keeping only its first position."""

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return the slice of ``hidden_states`` at index 0 of dimension 1.

        Args:
            hidden_states: Tensor indexed as ``[:, 0, :]``, so it must have
                at least three dimensions (presumably
                ``(batch, sequence, hidden)`` -- confirm against the caller).

        Returns:
            ``hidden_states[:, 0, :]``; dimension 1 is dropped from the
            result.
        """
        return hidden_states[:, 0, :]
def get_text_model_pooler(text_model_pooler: str) -> Type[torch.nn.Module]:
    """Look up the pooler class registered under ``text_model_pooler``.

    The class itself is returned, not an instance; this function does not
    instantiate it.

    Args:
        text_model_pooler: Name of the pooler. Only ``"CustomCLIPPooler"``
            is recognized.

    Returns:
        The pooler class matching the given name.

    Raises:
        ValueError: If ``text_model_pooler`` is not a recognized name.
    """
    if text_model_pooler == "CustomCLIPPooler":
        return CustomCLIPPooler
    raise ValueError(f"Unrecognized text model pooler type {text_model_pooler!r}.")
def is_valid_text_model_pooler(
    text_model_pooler: str, suppress_error: bool = False
) -> bool:
    """Check whether ``text_model_pooler`` names a known pooler.

    Args:
        text_model_pooler: Pooler name, passed to ``get_text_model_pooler``.
        suppress_error: When true, an unknown name yields ``False``. When
            false (the default), the ``ValueError`` from the lookup is
            re-raised instead.

    Returns:
        ``True`` if the lookup succeeds; ``False`` only when the name is
        unknown and ``suppress_error`` is true.

    Raises:
        ValueError: If the name is unknown and ``suppress_error`` is false.
    """
    try:
        get_text_model_pooler(text_model_pooler)
    except ValueError:
        if suppress_error:
            return False
        raise
    return True
class CustomCLIPConfig(VisionTextDualEncoderConfig):
    """``VisionTextDualEncoderConfig`` extended with text-model pooler settings.

    Two extra attributes are stored on every instance:
    ``text_model_pooler`` (the pooler name) and ``text_model_pooler_kwargs``
    (a dict of keyword arguments associated with that pooler).
    """

    model_type = "custom-clip-model"

    # Pooler name used when the constructor receives ``None``.
    DEFAULT_TEXT_MODEL_POOLER_STR: str = "CustomCLIPPooler"
    # Template for the default kwargs. Instances receive a copy so that
    # mutating one config never leaks into this class-level dict.
    DEFAULT_TEXT_MODEL_POOLER_KWARGS: dict = {}

    def __init__(
        self,
        *args,
        text_model_pooler: Optional[str] = None,
        text_model_pooler_kwargs: Optional[dict] = None,
        **kwargs,
    ):
        """Build the config and record the pooler settings.

        Args:
            *args: Forwarded unchanged to the parent constructor.
            text_model_pooler: Pooler name; ``None`` selects
                ``DEFAULT_TEXT_MODEL_POOLER_STR``.
            text_model_pooler_kwargs: Pooler keyword arguments; ``None``
                selects a fresh copy of ``DEFAULT_TEXT_MODEL_POOLER_KWARGS``.
            **kwargs: Forwarded unchanged to the parent constructor.

        Raises:
            ValueError: If the resulting pooler name is not recognized.
        """
        super().__init__(*args, **kwargs)
        self.text_model_pooler = (
            self.DEFAULT_TEXT_MODEL_POOLER_STR
            if text_model_pooler is None
            else text_model_pooler
        )
        # Called for its side effect only: raises on an unknown pooler name.
        is_valid_text_model_pooler(self.text_model_pooler, suppress_error=False)
        self.text_model_pooler_kwargs = (
            # Copy the class-level default so instances never share one dict.
            deepcopy(self.DEFAULT_TEXT_MODEL_POOLER_KWARGS)
            if text_model_pooler_kwargs is None
            else text_model_pooler_kwargs
        )

    @classmethod
    def from_base(cls, obj: VisionTextDualEncoderConfig) -> "CustomCLIPConfig":
        """Convert a ``VisionTextDualEncoderConfig`` into an instance of ``cls``.

        Args:
            obj: Config to convert.

        Returns:
            ``obj`` itself when it already is an instance of ``cls``.
            Otherwise a deep copy of ``obj`` whose class is switched to
            ``cls`` and whose pooler attributes are set to the class
            defaults; the argument is left unmodified in that case.

        Raises:
            TypeError: If ``obj`` is neither an instance of ``cls`` nor of
                ``VisionTextDualEncoderConfig``.
        """
        if isinstance(obj, cls):
            return obj

        base = VisionTextDualEncoderConfig
        if not isinstance(obj, base):
            raise TypeError(f"obj must be of type {cls!r} or {base!r}.")

        # Work on a copy so the caller's config keeps its original class.
        obj = deepcopy(obj)
        logger.warning(f"Changing config class from {obj.__class__!r} to {cls!r}.")
        obj.__class__ = cls

        def setattr_with_warning(target, name, value):
            # Log every attribute injected during the conversion.
            logger.warning(f"Setting {name!r} to {value!r}.")
            setattr(target, name, value)

        setattr_with_warning(
            obj, "text_model_pooler", cls.DEFAULT_TEXT_MODEL_POOLER_STR
        )
        setattr_with_warning(
            obj,
            "text_model_pooler_kwargs",
            # Fresh copy: the converted config must not alias the class default.
            deepcopy(cls.DEFAULT_TEXT_MODEL_POOLER_KWARGS),
        )
        return obj
# Register the config class with AutoConfig under its ``model_type`` string.
AutoConfig.register(CustomCLIPConfig.model_type, CustomCLIPConfig)
|