import os
from typing import List, Optional, Union

import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from PIL import Image
from torchvision import transforms
from transformers import PreTrainedModel, PretrainedConfig
from transformers.image_processing_utils import BaseImageProcessor


class SscdImageProcessor(BaseImageProcessor):
    def __init__(
            self,
            do_resize: bool = True,
            size: int = 288,
            image_mean: Optional[Union[float, List[float]]] = None,
            image_std: Optional[Union[float, List[float]]] = None,
            do_convert_rgb: bool = True,
            **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.size = size
        self.image_mean = image_mean if image_mean is not None else [0.485, 0.456, 0.406]
        self.image_std = image_std if image_std is not None else [0.229, 0.224, 0.225]
        self.do_convert_rgb = do_convert_rgb
        self.do_resize = do_resize

    def preprocess(
            self,
            image: Image.Image,
            do_resize: Optional[bool] = None,
            **kwargs,
    ):
        if do_resize is None:
            do_resize = self.do_resize
        # Build the transform pipeline; resizing is optional and applied first.
        steps = []
        if do_resize:
            steps.append(transforms.Resize(self.size))
        steps += [
            transforms.ToTensor(),
            transforms.Normalize(mean=self.image_mean, std=self.image_std),
        ]
        preprocess = transforms.Compose(steps)
        if self.do_convert_rgb:
            image = image.convert('RGB')
        # Return a batched tensor of shape (1, 3, H, W).
        return preprocess(image).unsqueeze(0)

class SscdConfig(PretrainedConfig):
    model_type = 'sscd-copy-detection'

    def __init__(self, model_path: Optional[str] = None, **kwargs):
        if model_path is None:
            model_path = 'sscd_disc_mixup.torchscript.pt'
        super().__init__(model_path=model_path, **kwargs)

class SscdModel(PreTrainedModel):
    config_class = SscdConfig

    def __init__(self, config):
        super().__init__(config)
        # Dummy parameter so the wrapper exposes a device/dtype to the HF API.
        self.dummy_param = nn.Parameter(torch.zeros(0))

        # Resolve the TorchScript checkpoint, either from a local directory
        # or from the Hugging Face Hub cache.
        is_local = os.path.isdir(config.name_or_path)
        if is_local:
            config.base_path = config.name_or_path
        else:
            checkpoint_path = hf_hub_download(repo_id=config.name_or_path, filename=config.model_path)
            config.base_path = os.path.dirname(checkpoint_path)
        model_path = os.path.join(config.base_path, config.model_path)

        # Load the TorchScript SSCD descriptor model used by forward().
        self.model = torch.jit.load(model_path)

    def forward(self, inputs):
        return self.model(inputs)

# Save the processor and config to a local directory. The TorchScript
# checkpoint ('sscd_disc_mixup.torchscript.pt') must also be present in that
# directory (or in the Hub repo) for SscdModel to load it.
sscd_processor = SscdImageProcessor()
sscd_processor.save_pretrained('new_model')
sscd_config = SscdConfig(model_path='sscd_disc_mixup.torchscript.pt')
sscd_config.save_pretrained('new_model')

model = SscdModel.from_pretrained('new_model')
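
# A minimal inference sketch, assuming an example image at 'example.jpg'
# (hypothetical path): preprocess a single image and extract its SSCD
# copy-detection descriptor with the TorchScript model loaded above.
image = Image.open('example.jpg')
inputs = sscd_processor.preprocess(image)  # (1, 3, H, W) float tensor
with torch.no_grad():
    embedding = model(inputs)  # descriptor vector, e.g. shape (1, 512)
print(embedding.shape)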