"""Hugging Face wrapper for the SafeSearch Mini v2 image classifier.

Defines ``SafeSearchConfig`` (preprocessing / label metadata) and
``SafeSearchModel``, which wraps a timm ``inception_resnet_v2`` network and
downloads its checkpoint on first construction if it is not present locally.
"""
import os
import tempfile
from typing import Optional
from urllib.request import urlretrieve

import timm
import torch
from transformers import PretrainedConfig, PreTrainedModel

# Mapping from class index (string key) to label name.
classes = {
    '0': 'nsfw_gore',
    '1': 'nsfw_suggestive',
    '2': 'safe'
}

# Local checkpoint file; a relative path, so it is resolved against the
# current working directory.
model_path = "safesearch_mini_v2.bin"

# Where the checkpoint is fetched from when ``model_path`` does not exist.
_WEIGHTS_URL = (
    "https://huggingface.co/FredZhang7/google-safesearch-mini-v2"
    "/resolve/main/pytorch_model.bin"
)

# Immutable defaults; SafeSearchConfig copies them into fresh lists/dicts so
# no two config instances (or the module) share a mutable default object.
_DEFAULT_INPUT_SIZE = (3, 299, 299)
_DEFAULT_POOL_SIZE = (8, 8)
_DEFAULT_MEAN = (0.5, 0.5, 0.5)
_DEFAULT_STD = (0.5, 0.5, 0.5)
_DEFAULT_CLASSES = classes


class SafeSearchConfig(PretrainedConfig):
    """Configuration for :class:`SafeSearchModel`.

    Apart from ``num_classes`` and ``device`` (read by ``SafeSearchModel``),
    the values are stored as metadata on the config and are not read by the
    model code in this file.

    List/dict parameters default to ``None`` and are replaced by a *fresh*
    copy of their documented default, so mutating one config's lists or label
    mapping never affects another config or the module-level ``classes``.

    Args:
        model_name: Identifier stored on the config.
        input_channels: Number of input channels (default 3).
        num_classes: Number of output classes; passed to ``timm.create_model``.
        input_size: Defaults to ``[3, 299, 299]``.
        pool_size: Defaults to ``[8, 8]``.
        crop_pct: Crop percentage metadata (default 0.875).
        interpolation: Resize interpolation name (default ``"bicubic"``).
        mean: Normalization mean; defaults to ``[0.5, 0.5, 0.5]``.
        std: Normalization std; defaults to ``[0.5, 0.5, 0.5]``.
        first_conv: Metadata string (default ``"conv2d"``).
        classifier: Metadata string (default ``"default"``).
        has_aux: Metadata flag (default False).
        label_offset: Metadata int (default 0).
        classes: Index -> label mapping; defaults to a copy of the
            module-level ``classes``.
        output_channels: Metadata int (default 1536).
        device: Used as ``map_location`` when the checkpoint is loaded.
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.
    """

    model_type = "safesearch_mini_v2"

    def __init__(
        self,
        model_name: str = "safesearch_mini_v2",
        input_channels: int = 3,
        num_classes: int = 3,
        input_size: Optional[list] = None,
        pool_size: Optional[list] = None,
        crop_pct: float = 0.875,
        interpolation: str = "bicubic",
        mean: Optional[list] = None,
        std: Optional[list] = None,
        first_conv: str = "conv2d",
        classifier: str = "default",
        has_aux: bool = False,
        label_offset: int = 0,
        classes: Optional[dict] = None,
        output_channels: int = 1536,
        device: str = "cpu",
        **kwargs,
    ):
        self.model_name = model_name
        self.input_channels = input_channels
        self.num_classes = num_classes
        self.input_size = list(_DEFAULT_INPUT_SIZE) if input_size is None else input_size
        self.pool_size = list(_DEFAULT_POOL_SIZE) if pool_size is None else pool_size
        self.crop_pct = crop_pct
        self.interpolation = interpolation
        self.mean = list(_DEFAULT_MEAN) if mean is None else mean
        self.std = list(_DEFAULT_STD) if std is None else std
        self.first_conv = first_conv
        self.classifier = classifier
        self.has_aux = has_aux
        self.label_offset = label_offset
        self.classes = dict(_DEFAULT_CLASSES) if classes is None else classes
        self.output_channels = output_channels
        self.device = device
        super().__init__(**kwargs)


# Example: write the default configuration to disk.
#   safesearch_config = SafeSearchConfig()
#   safesearch_config.save_pretrained("safesearch_config")


def _download_checkpoint(url: str, destination: str) -> None:
    """Download *url* to *destination*, renaming into place only on success.

    The data is written to a temporary file in the destination's directory and
    moved with ``os.replace`` once the transfer finishes. An interrupted
    download therefore never leaves a truncated file at *destination* that a
    later ``os.path.exists`` check would mistake for a complete checkpoint.
    """
    directory = os.path.dirname(os.path.abspath(destination))
    fd, tmp_path = tempfile.mkstemp(dir=directory, prefix=".download-", suffix=".part")
    os.close(fd)  # urlretrieve reopens the path itself
    try:
        urlretrieve(url, tmp_path)
        os.replace(tmp_path, destination)
    finally:
        # Only still present if the download or rename failed.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)


class SafeSearchModel(PreTrainedModel):
    """``PreTrainedModel`` wrapping a timm ``inception_resnet_v2`` classifier.

    On construction, downloads the checkpoint to ``model_path`` if it is
    missing, builds the timm network with ``config.num_classes`` outputs and
    loads the checkpoint's state dict into it.
    """

    config_class = SafeSearchConfig

    def __init__(self, config: SafeSearchConfig):
        super().__init__(config)
        if not os.path.exists(model_path):
            _download_checkpoint(_WEIGHTS_URL, model_path)
        self.model = timm.create_model(
            "inception_resnet_v2", pretrained=False, num_classes=config.num_classes
        )
        # SECURITY(review): torch.load unpickles the file, which was fetched
        # from the network. Consider ``weights_only=True`` once the minimum
        # supported torch version allows it.
        # NOTE(review): config.device is only the map_location for the loaded
        # tensors; load_state_dict copies them into self.model's existing
        # parameters, so the module is presumably not moved to that device
        # here. Callers likely need ``.to(...)``; confirm.
        state_dict = torch.load(model_path, map_location=torch.device(config.device))
        self.model.load_state_dict(state_dict)

    def forward(self, input_ids: torch.Tensor):
        """Run the wrapped timm network on *input_ids* and return its output.

        The parameter name is kept for interface compatibility; the tensor is
        passed straight to the image model. Presumably it is a float image
        batch shaped ``(N, *config.input_size)`` and the output holds per-class
        logits over ``config.classes``; confirm.
        """
        return self.model(input_ids)