m3's picture
chore: add sscd models
e97054d
raw
history blame
No virus
3.05 kB
from typing import List, Optional, Union
from torchvision import transforms
from PIL import Image
from transformers.image_processing_utils import BaseImageProcessor
from transformers import PreTrainedModel, PretrainedConfig
import os
from huggingface_hub import hf_hub_download
import torch
import torch.nn as nn
class SscdImageProcessor(BaseImageProcessor):
    """Image processor that turns a single PIL image into a normalized tensor.

    The pipeline is: optional RGB conversion, optional resize, conversion to
    a tensor, then per-channel normalization.
    """

    def __init__(
        self,
        do_resize: bool = True,
        size: int = 288,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        do_convert_rgb: bool = True,
        **kwargs,
    ) -> None:
        """Store the preprocessing settings.

        Args:
            do_resize: Default for whether ``preprocess`` resizes the image.
            size: Value handed to ``transforms.Resize``.
            image_mean: Normalization mean; ``[0.485, 0.456, 0.406]`` when
                ``None``.
            image_std: Normalization std; ``[0.229, 0.224, 0.225]`` when
                ``None``.
            do_convert_rgb: Whether ``preprocess`` converts the image to RGB
                first.
            **kwargs: Forwarded to ``BaseImageProcessor.__init__``.
        """
        super().__init__(**kwargs)
        self.size = size
        self.image_mean = image_mean if image_mean is not None else [0.485, 0.456, 0.406]
        self.image_std = image_std if image_std is not None else [0.229, 0.224, 0.225]
        self.do_convert_rgb = do_convert_rgb
        self.do_resize = do_resize

    def preprocess(
        self,
        image: Image.Image,
        do_resize: Optional[bool] = None,
        **kwargs,
    ):
        """Preprocess one image into a batch of size one.

        Args:
            image: The PIL image to process.
            do_resize: Per-call override of ``self.do_resize``; ``None`` means
                use the instance setting.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            The transformed image tensor with a leading dimension of size 1
            added by ``unsqueeze(0)``.
        """
        if do_resize is None:
            do_resize = self.do_resize

        # Build one pipeline; the resize step is included only when requested
        # so that ``do_resize=False`` is actually honoured.
        steps = []
        if do_resize:
            steps.append(transforms.Resize(self.size))
        steps.append(transforms.ToTensor())
        steps.append(
            transforms.Normalize(mean=self.image_mean, std=self.image_std)
        )
        pipeline = transforms.Compose(steps)

        if self.do_convert_rgb:
            image = image.convert('RGB')
        return pipeline(image).unsqueeze(0)
class SscdConfig(PretrainedConfig):
    """Configuration for the SSCD model wrapper.

    Carries a single extra setting, ``model_path``, which is passed through
    to ``PretrainedConfig.__init__`` as a keyword argument.
    """

    model_type = 'sscd-copy-detection'

    def __init__(self, model_path: str = None, **kwargs):
        """Create the config.

        Args:
            model_path: File name of the model file. ``None`` selects
                ``'sscd_disc_mixup.torchscript.pt'``.
            **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
        """
        resolved_path = (
            'sscd_disc_mixup.torchscript.pt' if model_path is None else model_path
        )
        super().__init__(model_path=resolved_path, **kwargs)
class SscdModel(PreTrainedModel):
    """``PreTrainedModel`` wrapper around a TorchScript SSCD model file.

    The file named by ``config.model_path`` is looked up either in a local
    directory or in a Hugging Face Hub repository, depending on whether
    ``config.name_or_path`` is an existing directory.
    """

    config_class = SscdConfig

    def __init__(self, config):
        """Locate and load the TorchScript module described by ``config``.

        Args:
            config: An ``SscdConfig``. ``config.name_or_path`` is treated as a
                local directory when it exists on disk, otherwise as a Hub
                repo id. ``config.base_path`` is set as a side effect.
        """
        super().__init__(config)
        # Zero-element parameter.
        # NOTE(review): presumably here so the wrapper owns at least one
        # parameter of its own -- confirm before removing.
        self.dummy_param = nn.Parameter(torch.zeros(0))

        if os.path.isdir(config.name_or_path):
            config.base_path = config.name_or_path
            model_path = os.path.join(config.base_path, config.model_path)
        else:
            config_path = hf_hub_download(
                repo_id=config.name_or_path, filename='config.json',
            )
            config.base_path = os.path.dirname(config_path)
            # Fetching config.json alone does not download the TorchScript
            # file, so request it explicitly and use the returned path.
            model_path = hf_hub_download(
                repo_id=config.name_or_path, filename=config.model_path,
            )

        # forward() depends on this attribute; without it every call fails
        # with AttributeError.
        self.model = torch.jit.load(model_path)

    def forward(self, inputs):
        """Run the loaded TorchScript module on ``inputs`` and return its output."""
        return self.model(inputs)
# Export script: write the processor settings and the model config into one
# directory, then load the model back from that same directory.
_export_dir = 'new_model'

# 1. Image processor settings.
sscd_processor = SscdImageProcessor()
sscd_processor.save_pretrained(_export_dir)

# 2. Model configuration pointing at the TorchScript file.
sscd_config = SscdConfig(model_path='sscd_disc_mixup.torchscript.pt')
sscd_config.save_pretrained(_export_dir)

# 3. Load the model from the directory that was just populated.
model = SscdModel.from_pretrained(_export_dir)