|
from typing import List, Optional, Union |
|
from torchvision import transforms |
|
from PIL import Image |
|
|
|
from transformers.image_processing_utils import BaseImageProcessor |
|
from transformers import PreTrainedModel, PretrainedConfig |
|
import os |
|
from huggingface_hub import hf_hub_download |
|
import torch |
|
import torch.nn as nn |
|
class SscdImageProcessor(BaseImageProcessor):
    """Image processor that turns one PIL image into a normalized, batched tensor.

    The pipeline is: optional RGB conversion -> optional resize -> tensor
    conversion -> per-channel normalization -> add a leading batch dimension.
    """

    def __init__(
        self,
        do_resize: bool = True,
        size: int = 288,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        do_convert_rgb: bool = True,
        **kwargs,
    ) -> None:
        """Store the preprocessing options.

        Args:
            do_resize: Default for whether ``preprocess`` resizes the image.
            size: Target size handed to ``transforms.Resize``.
            image_mean: Per-channel mean for normalization. Falls back to
                ``[0.485, 0.456, 0.406]`` when ``None``.
            image_std: Per-channel standard deviation for normalization.
                Falls back to ``[0.229, 0.224, 0.225]`` when ``None``.
            do_convert_rgb: Whether ``preprocess`` converts the image to RGB
                before applying the transforms.
            **kwargs: Forwarded unchanged to ``BaseImageProcessor.__init__``.
        """
        super().__init__(**kwargs)
        self.size = size
        # None (rather than a list literal) is used as the default so that no
        # mutable default argument is shared between instances.
        self.image_mean = image_mean if image_mean is not None else [0.485, 0.456, 0.406]
        self.image_std = image_std if image_std is not None else [0.229, 0.224, 0.225]
        self.do_convert_rgb = do_convert_rgb
        self.do_resize = do_resize

    def preprocess(
        self,
        image: Image.Image,
        do_resize: Optional[bool] = None,
        **kwargs,
    ):
        """Convert a single PIL image into a model-ready tensor.

        Args:
            image: The PIL image to process.
            do_resize: Per-call override of ``self.do_resize``. ``None`` means
                "use the value configured on the processor".
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            The transformed image tensor with an extra leading dimension of
            size 1 added by ``unsqueeze(0)``.
        """
        if do_resize is None:
            do_resize = self.do_resize

        # Build exactly one pipeline. Resize is applied to the PIL image,
        # before tensor conversion, and only when resizing is requested.
        steps = []
        if do_resize:
            steps.append(transforms.Resize(self.size))
        steps.append(transforms.ToTensor())
        steps.append(
            transforms.Normalize(mean=self.image_mean, std=self.image_std)
        )
        pipeline = transforms.Compose(steps)

        if self.do_convert_rgb:
            image = image.convert('RGB')
        return pipeline(image).unsqueeze(0)
|
|
|
class SscdConfig(PretrainedConfig):
    """Configuration for the SSCD copy-detection model.

    The only setting defined here is ``model_path``: the file name of the
    serialized model, which is forwarded to the base configuration
    constructor as a keyword argument.
    """

    model_type = 'sscd-copy-detection'

    def __init__(self, model_path: str = None, **kwargs):
        """Create the configuration.

        Args:
            model_path: File name of the serialized model. When ``None``,
                ``'sscd_disc_mixup.torchscript.pt'`` is used.
            **kwargs: Passed through to ``PretrainedConfig.__init__``.
        """
        resolved_path = (
            'sscd_disc_mixup.torchscript.pt' if model_path is None else model_path
        )
        super().__init__(model_path=resolved_path, **kwargs)
|
|
|
class SscdModel(PreTrainedModel):
    """``PreTrainedModel`` wrapper around a serialized TorchScript module.

    The TorchScript file named by ``config.model_path`` is located either in
    a local directory or in a Hugging Face Hub repository, loaded with
    ``torch.jit.load`` and stored as ``self.model``; ``forward`` delegates
    to it.
    """

    config_class = SscdConfig

    def __init__(self, config):
        """Locate and load the TorchScript module described by ``config``.

        Args:
            config: Configuration object. ``config.name_or_path`` is treated
                as a local directory when it exists on disk, otherwise as a
                Hub repository id. ``config.model_path`` is the file name of
                the TorchScript module. ``config.base_path`` is set as a side
                effect to the directory the files are resolved against.
        """
        super().__init__(config)
        # Zero-element parameter registered on the wrapper itself.
        # NOTE(review): presumably present so the module always owns at least
        # one parameter for device/dtype helpers -- confirm.
        self.dummy_param = nn.Parameter(torch.zeros(0))

        print("______", config.name_or_path)

        if os.path.isdir(config.name_or_path):
            config.base_path = config.name_or_path
            model_path = os.path.join(config.base_path, config.model_path)
        else:
            config_path = hf_hub_download(
                repo_id=config.name_or_path, filename='config.json'
            )
            config.base_path = os.path.dirname(config_path)
            # Downloading config.json alone does not fetch the weights file,
            # so request it explicitly and use the path that is returned.
            model_path = hf_hub_download(
                repo_id=config.name_or_path, filename=config.model_path
            )

        print("___model_path___", model_path)

        # forward() relies on self.model; without this assignment every call
        # would fail with AttributeError.
        self.model = torch.jit.load(model_path)

    def forward(self, inputs):
        """Run the loaded TorchScript module on ``inputs`` and return its output."""
        return self.model(inputs)
|
|
|
# Build the image processor and the configuration with their default settings.
sscd_processor = SscdImageProcessor()
sscd_config = SscdConfig(model_path='sscd_disc_mixup.torchscript.pt')

# Write both into the same output directory.
sscd_processor.save_pretrained('new_model')
sscd_config.save_pretrained('new_model')

# Instantiate the model from the directory that was just written.
model = SscdModel.from_pretrained('new_model')
|
|
|
|
|
|
|
|