import os
from typing import List, Optional, Union

import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from PIL import Image
from torchvision import transforms
from transformers import PretrainedConfig, PreTrainedModel, AutoConfig, AutoImageProcessor, AutoModel
from transformers.image_processing_utils import BaseImageProcessor
from transformers.pipelines import PIPELINE_REGISTRY
from transformers.pipelines.base import Pipeline, build_pipeline_init_args
from transformers.utils import add_end_docstrings


class SscdImageProcessor(BaseImageProcessor):
    """Turn PIL images into normalized, batched tensors for the SSCD model.

    Args:
        do_resize: Default for whether ``preprocess`` resizes the image.
        size: Value passed to ``torchvision.transforms.Resize``. An int resizes
            the shorter edge to this length and keeps the aspect ratio.
        image_mean: Per-channel normalization mean. Defaults to
            ``[0.485, 0.456, 0.406]``.
        image_std: Per-channel normalization std. Defaults to
            ``[0.229, 0.224, 0.225]``.
        do_convert_rgb: Whether to convert the input image to RGB first.
        **kwargs: Forwarded to ``BaseImageProcessor``.
    """

    def __init__(
        self,
        do_resize: bool = True,
        size: int = 288,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        do_convert_rgb: bool = True,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.size = size
        self.image_mean = image_mean if image_mean is not None else [0.485, 0.456, 0.406]
        self.image_std = image_std if image_std is not None else [0.229, 0.224, 0.225]
        self.do_convert_rgb = do_convert_rgb
        self.do_resize = do_resize

    def preprocess(
        self,
        image: Image.Image,
        do_resize: Optional[bool] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Convert one PIL image into a normalized tensor with a batch axis.

        Args:
            image: Input PIL image.
            do_resize: When not ``None``, overrides ``self.do_resize`` for this call.
            **kwargs: Accepted for API compatibility and ignored.

        Returns:
            Tensor of shape ``(1, C, H, W)``.
        """
        if do_resize is None:
            do_resize = self.do_resize

        steps = []
        if do_resize:
            # Resize the PIL image *before* tensor conversion/normalization so
            # interpolation operates on raw pixel values.
            steps.append(transforms.Resize(self.size))
        steps.extend([
            transforms.ToTensor(),
            transforms.Normalize(
                mean=self.image_mean,
                std=self.image_std,
            ),
        ])

        if self.do_convert_rgb:
            image = image.convert('RGB')
        return transforms.Compose(steps)(image).unsqueeze(0)


class SscdConfig(PretrainedConfig):
    """Configuration for ``SscdModel``.

    Args:
        model_path: File name of the TorchScript checkpoint inside the model
            directory or Hub repo. Defaults to ``'sscd_disc_mixup.torchscript.pt'``.
        **kwargs: Forwarded to ``PretrainedConfig``.
    """

    model_type = 'sscd-copy-detection'

    def __init__(self, model_path: str = None, **kwargs):
        if model_path is None:
            model_path = 'sscd_disc_mixup.torchscript.pt'
        super().__init__(model_path=model_path, **kwargs)


class SscdModel(PreTrainedModel):
    """``PreTrainedModel`` wrapper around an SSCD TorchScript checkpoint."""

    config_class = SscdConfig

    def __init__(self, config, model_path: str = None):
        """Locate and load the TorchScript checkpoint.

        Args:
            config: An ``SscdConfig``. ``config.name_or_path`` is either a local
                directory or a Hugging Face Hub repo id.
            model_path: Checkpoint file name relative to that directory or repo.
                Defaults to ``config.model_path``.

        Side effects:
            Sets ``config.base_path`` to the directory holding the checkpoint.
            Downloads the file from the Hub when ``name_or_path`` is not a
            local directory.
        """
        super().__init__(config)
        # Zero-size parameter. Presumably this gives the wrapper a registered
        # parameter to report device/dtype from — TODO confirm.
        self.dummy_param = nn.Parameter(torch.zeros(0))
        if model_path is None:
            model_path = config.model_path

        if os.path.isdir(config.name_or_path):
            config.base_path = config.name_or_path
        else:
            file_path = hf_hub_download(repo_id=config.name_or_path, filename=model_path)
            config.base_path = os.path.dirname(file_path)

        self.model = torch.jit.load(os.path.join(config.base_path, model_path))

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Build the model from a local directory or Hub repo id.

        ``kwargs`` are forwarded to ``AutoConfig.from_pretrained``.
        ``model_args`` are accepted for signature compatibility and ignored.
        """
        return cls(AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs))

    def forward(self, inputs):
        """Run the TorchScript module and return row 0 of its output.

        NOTE(review): only the first batch element is returned. This matches
        ``SscdImageProcessor.preprocess``, which yields a batch of one, but
        larger batches would be truncated.
        """
        return self.model(inputs)[0, :]


@add_end_docstrings(build_pipeline_init_args(has_image_processor=True))
class SscdPipeline(Pipeline):
    """Pipeline mapping a PIL image to its SSCD descriptor tensor."""

    def __init__(self, model, **kwargs):
        # ``device`` is optional for pipelines. Indexing kwargs directly raised
        # KeyError when the caller did not supply it.
        self.device_id = kwargs.get('device')
        super().__init__(model=model, **kwargs)

    def _sanitize_parameters(self, **kwargs):
        # No configurable preprocess/forward/postprocess parameters.
        return {}, {}, {}

    def preprocess(self, input):
        return self.image_processor.preprocess(input)

    def _forward(self, inputs):
        return self.model(inputs)

    def postprocess(self, model_outputs):
        # Return the raw descriptor tensor unchanged.
        return model_outputs


AutoConfig.register('sscd-copy-detection', SscdConfig)
AutoModel.register(SscdConfig, SscdModel)
AutoImageProcessor.register(SscdConfig, slow_image_processor_class=SscdImageProcessor)

# NOTE(review): this downloads and loads the checkpoint at import time, and
# ``models`` is not used in this file. Consider moving it behind
# ``if __name__ == '__main__':`` if nothing imports it.
models = AutoModel.from_pretrained('m3/sscd-copy-detection')

PIPELINE_REGISTRY.register_pipeline(
    task='sscd-copy-detection',
    pipeline_class=SscdPipeline,
    pt_model=SscdModel,
)