# File size: 2,069 Bytes
# 70a3c4c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
# Label table keyed by string index. Presumably the index matches the position
# of the classifier's output logit — TODO confirm against the training setup.
classes = {
	'0': 'nsfw_gore',
	'1': 'nsfw_suggestive',
	'2': 'safe',
}

# Local filename the checkpoint is cached under (a bare relative path, so it
# resolves against the current working directory).
model_path = "safesearch_mini_v2.bin"

from transformers import PretrainedConfig, PreTrainedModel

class SafeSearchConfig(PretrainedConfig):
	"""Hugging Face configuration for the SafeSearch mini v2 image classifier.

	Each constructor argument is stored as an attribute of the same name;
	remaining keyword arguments are forwarded to ``PretrainedConfig.__init__``.
	"""

	model_type = "safesearch_mini_v2"

	def __init__(self,
				 model_name: str = "safesearch_mini_v2",
				 input_channels: int = 3,
				 num_classes: int = 3,
				 input_size: list[int] = [3, 299, 299],
				 pool_size: list[int] = [8, 8],
				 crop_pct: float = 0.875,
				 interpolation: str = "bicubic",
				 mean: list[float] = [0.5, 0.5, 0.5],
				 std: list[float] = [0.5, 0.5, 0.5],
				 first_conv: str = "conv2d",
				 classifier: str = "default",
				 has_aux: bool = False,
				 label_offset: int = 0,
				 classes: dict[str, str] = classes,
				 output_channels: int = 1536,
				 device: str = "cpu",
				 **kwargs):
		"""Store the model and preprocessing settings on the config.

		Container arguments (``input_size``, ``pool_size``, ``mean``, ``std``,
		``classes``) are shallow-copied. The signature defaults are single shared
		objects, and ``classes`` defaults to the module-level label table.
		Without the copy, mutating one config's attribute would silently change
		every later default-constructed config and the module constant.

		NOTE(review): fields such as ``crop_pct``, ``interpolation``, ``mean``,
		``std``, ``first_conv``, ``classifier``, ``has_aux`` and ``label_offset``
		look like timm-style pretrained-config metadata — TODO confirm which of
		them callers actually read.
		"""
		import copy  # local import keeps this block self-contained

		self.model_name = model_name
		self.input_channels = input_channels
		self.num_classes = num_classes
		self.input_size = copy.copy(input_size)
		self.pool_size = copy.copy(pool_size)
		self.crop_pct = crop_pct
		self.interpolation = interpolation
		self.mean = copy.copy(mean)
		self.std = copy.copy(std)
		self.first_conv = first_conv
		self.classifier = classifier
		self.has_aux = has_aux
		self.label_offset = label_offset
		self.classes = copy.copy(classes)
		self.output_channels = output_channels
		# Read by SafeSearchModel as the torch.load map_location.
		self.device = device
		super().__init__(**kwargs)

"""
safesearch_config = SafeSearchConfig()
safesearch_config.save_pretrained("safesearch_config")
"""

import torch, os, timm

class SafeSearchModel(PreTrainedModel):
	"""Hugging Face wrapper around a timm ``inception_resnet_v2`` classifier.

	On construction, the checkpoint at the module-level ``model_path`` is
	downloaded if it does not exist yet. Its state dict is then loaded into the
	timm model.
	"""

	config_class = SafeSearchConfig

	def __init__(self, config: SafeSearchConfig):
		"""Build the timm backbone and load the checkpoint weights.

		Args:
			config: model configuration. ``config.num_classes`` sets the size of
				the classifier head, and ``config.device`` is the
				``map_location`` passed to ``torch.load``.
		"""
		super().__init__(config)
		if not os.path.exists(model_path):
			self._download_checkpoint(
				"https://huggingface.co/FredZhang7/google-safesearch-mini-v2/resolve/main/pytorch_model.bin",
				model_path,
			)
		# Head size follows the config (default 3) instead of a hard-coded literal,
		# so a mismatched config fails loudly in load_state_dict.
		self.model = timm.create_model("inception_resnet_v2", pretrained=False, num_classes=config.num_classes)
		# NOTE(security): torch.load unpickles the file, and this checkpoint may
		# have been fetched from a remote URL. Consider weights_only=True if the
		# installed torch supports it.
		state_dict = torch.load(model_path, map_location=torch.device(config.device))
		self.model.load_state_dict(state_dict)

	@staticmethod
	def _download_checkpoint(url: str, destination: str) -> None:
		"""Download ``url`` to ``destination`` without leaving a truncated file.

		The data is written to ``destination + ".part"`` first. It is renamed
		into place only after the download completes. An interrupted download
		therefore cannot leave a partial checkpoint that the existence check in
		``__init__`` would later mistake for a complete one.
		"""
		from urllib.request import urlretrieve

		partial_path = destination + ".part"
		try:
			urlretrieve(url, partial_path)
			os.replace(partial_path, destination)
		except BaseException:
			# Remove the incomplete file, then propagate the original error.
			if os.path.exists(partial_path):
				os.remove(partial_path)
			raise

	def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
		"""Run the wrapped timm model on ``input_ids`` and return its output.

		NOTE(review): despite the name, ``input_ids`` is passed straight to the
		image model, so it is presumably an image batch tensor — TODO confirm.
		"""
		return self.model(input_ids)