from typing import Any, Dict, List, Optional, Tuple, Union

import torch
from torch import nn
from transformers import AutoConfig, AutoModel, AutoProcessor


class VisionTransformer(nn.Module):
    """Hugging Face vision model wrapper that produces patch/token embeddings.

    Uses ``AutoModel``/``AutoProcessor`` to load the appropriate vision
    backbone (e.g. ViT, CLIP vision encoder) from a model name or path.

    Args:
        model_name_or_path: Hugging Face model name (https://huggingface.co/models)
            or a local path to a saved model.
        model_args: Keyword arguments passed to the Hugging Face Transformers model.
        tokenizer_args: Keyword arguments passed to the Hugging Face Transformers processor.
        config_args: Keyword arguments passed to the Hugging Face Transformers config.
        cache_dir: Cache dir for Hugging Face Transformers to store/load models.
    """

    def __init__(
        self,
        model_name_or_path: str,
        model_args: Optional[Dict[str, Any]] = None,
        tokenizer_args: Optional[Dict[str, Any]] = None,
        config_args: Optional[Dict[str, Any]] = None,
        cache_dir: Optional[str] = None,
    ) -> None:
        super().__init__()
        # Avoid mutable default arguments: normalize None to fresh dicts.
        if model_args is None:
            model_args = {}
        if tokenizer_args is None:
            tokenizer_args = {}
        if config_args is None:
            config_args = {}
        # Names of constructor attributes serialized by get_config_dict().
        # BUGFIX: this attribute was never initialized, so get_config_dict()
        # raised AttributeError. This module has no extra config to persist,
        # so the list is empty (model/processor state is saved via save()).
        self.config_keys: List[str] = []
        self.config = AutoConfig.from_pretrained(model_name_or_path, **config_args, cache_dir=cache_dir)
        self.model = AutoModel.from_pretrained(model_name_or_path, config=self.config, **model_args, cache_dir=cache_dir)
        self.processor = AutoProcessor.from_pretrained(
            model_name_or_path, config=self.config, **tokenizer_args, cache_dir=cache_dir
        )

    def forward(self, features: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Run the vision backbone and add ``token_embeddings`` to *features*.

        Args:
            features: Dict containing at least ``pixel_values`` (the tensor
                produced by the processor).

        Returns:
            The same dict, updated in place with ``token_embeddings`` — the
            backbone's last hidden state (first element of the tuple output).
        """
        output_states = self.model(pixel_values=features["pixel_values"], return_dict=False)[0]
        features.update({"token_embeddings": output_states})
        return features

    def get_word_embedding_dimension(self) -> int:
        """Return the embedding dimension (``hidden_size``) of the backbone."""
        return self.config.hidden_size

    def tokenize(
        self, texts: Union[List[str], List[Dict], List[Tuple[str, str]]], padding: Union[str, bool] = True
    ) -> Dict[str, torch.Tensor]:
        """Preprocess inputs (images) into model-ready tensors.

        NOTE: ``padding`` is accepted for interface compatibility with text
        modules but is intentionally not forwarded — image processors do not
        pad token sequences.
        """
        return self.processor(texts, return_tensors="pt")

    def get_config_dict(self) -> Dict[str, Any]:
        """Return the serializable constructor config (keys in ``config_keys``)."""
        return {key: self.__dict__[key] for key in self.config_keys}

    def save(self, output_path: str, safe_serialization: bool = True) -> None:
        """Save the model weights and processor config to *output_path*.

        Args:
            output_path: Directory to write the model/processor files into.
            safe_serialization: Use safetensors format when True.
        """
        self.model.save_pretrained(output_path, safe_serialization=safe_serialization)
        self.processor.save_pretrained(output_path)