# FredZhang7/google-safesearch-mini-v2 — commit 75208a0 ("Upload model"), 2.05 kB.
# Class index (as a string key) -> label name for the three classifier outputs.
classes = {
    str(index): label
    for index, label in enumerate(("nsfw_gore", "nsfw_suggestive", "safe"))
}
# Relative path (resolved against the current working directory) where the
# downloaded checkpoint is cached and later loaded from.
model_path = "safesearch_mini_v2.bin"
import copy

from transformers import PretrainedConfig, PreTrainedModel
class SafeSearchConfig(PretrainedConfig):
    """Configuration for the SafeSearch mini v2 image classifier.

    Holds the settings consumed by ``SafeSearchModel`` (which declares this
    class as its ``config_class``). Keyword arguments not listed below are
    forwarded unchanged to ``PretrainedConfig``.

    Args:
        model_name: Identifier stored on the config.
        input_channels: Number of input image channels.
        num_classes: Number of classifier outputs.
        input_size: Expected input shape, presumably ``[channels, height,
            width]`` — TODO confirm against the preprocessing code.
        pool_size, crop_pct, interpolation, mean, std, first_conv,
        classifier, has_aux, label_offset, output_channels: Model and
            preprocessing metadata; stored verbatim, not interpreted here.
        classes: Mapping from class index (as a string) to label name;
            defaults to the module-level ``classes`` mapping.
        device: Device string; ``SafeSearchModel`` uses it as the
            ``map_location`` when loading the checkpoint.
    """

    model_type = "safesearch_mini_v2"

    # The list/dict literals in this signature are kept for interface
    # compatibility; they are copied in the body so they are never shared.
    def __init__(self,
                 model_name: str = "safesearch_mini_v2",
                 input_channels: int = 3,
                 num_classes: int = 3,
                 input_size: list = [3, 299, 299],
                 pool_size: list = [8, 8],
                 crop_pct: float = 0.875,
                 interpolation: str = "bicubic",
                 mean: list = [0.5, 0.5, 0.5],
                 std: list = [0.5, 0.5, 0.5],
                 first_conv: str = "conv2d_1a.conv",
                 classifier: str = "default",
                 has_aux: bool = False,
                 label_offset: int = 0,
                 classes: object = classes,
                 output_channels: int = 1536,
                 device: str = "cpu",
                 **kwargs):
        self.model_name = model_name
        self.input_channels = input_channels
        self.num_classes = num_classes
        # Shallow copies: default values are single objects shared by every
        # call (and ``classes`` is the module-level mapping), so storing them
        # directly would let e.g. ``cfg.mean.append(...)`` leak into all
        # configs created afterwards.
        self.input_size = copy.copy(input_size)
        self.pool_size = copy.copy(pool_size)
        self.crop_pct = crop_pct
        self.interpolation = interpolation
        self.mean = copy.copy(mean)
        self.std = copy.copy(std)
        self.first_conv = first_conv
        self.classifier = classifier
        self.has_aux = has_aux
        self.label_offset = label_offset
        self.classes = copy.copy(classes)
        self.output_channels = output_channels
        self.device = device
        super().__init__(**kwargs)


# Example (not executed): save the default configuration with save_pretrained.
#     safesearch_config = SafeSearchConfig()
#     safesearch_config.save_pretrained("safesearch_config")
import torch, os, timm
class SafeSearchModel(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around a timm classifier.

    Construction builds a non-pretrained timm ``inception_resnet_v2`` with a
    ``config.num_classes``-way head and loads its weights from the local file
    ``model_path``. If that file does not exist yet, it is first downloaded
    from the Hugging Face Hub.
    """

    config_class = SafeSearchConfig

    def __init__(self, config: SafeSearchConfig):
        """Create the backbone and load the SafeSearch checkpoint.

        Args:
            config: Model configuration. ``config.num_classes`` sizes the
                classification head and ``config.device`` is used as the
                ``map_location`` when reading the checkpoint.

        Raises:
            urllib.error.URLError: If the checkpoint download fails.
            RuntimeError: From ``load_state_dict`` if the checkpoint does not
                match the model (e.g. ``config.num_classes`` differs from the
                checkpoint's head size).
        """
        super().__init__(config)
        if not os.path.exists(model_path):
            self._fetch_checkpoint(
                "https://huggingface.co/FredZhang7/google-safesearch-mini-v2"
                "/resolve/main/pytorch_model.bin",
                model_path,
            )
        # Weights start random; the strict load_state_dict below replaces them.
        self.model = timm.create_model(
            "inception_resnet_v2", pretrained=False, num_classes=config.num_classes
        )
        # NOTE(security): torch.load unpickles the file, which may have just
        # been fetched over the network. Consider weights_only=True once the
        # minimum supported torch version allows it.
        # NOTE(review): map_location only decides where the loaded tensors
        # live; load_state_dict copies them into self.model's existing
        # parameters, so the model itself is not moved to config.device here.
        state_dict = torch.load(model_path, map_location=torch.device(config.device))
        self.model.load_state_dict(state_dict)

    @staticmethod
    def _fetch_checkpoint(url: str, dest: str) -> None:
        """Download ``url`` to ``dest`` without ever exposing a partial file.

        The payload is written to a ``.part`` file next to ``dest``. It is
        moved into place with ``os.replace`` only after the transfer completes.
        A failed or interrupted download therefore cannot leave a truncated
        ``dest`` that the ``os.path.exists`` check in ``__init__`` would later
        accept as a complete checkpoint.
        """
        from urllib.request import urlretrieve

        # Per-process name so concurrent downloads do not clobber each other.
        tmp_path = f"{dest}.{os.getpid()}.part"
        try:
            urlretrieve(url, tmp_path)
            os.replace(tmp_path, dest)
        except BaseException:
            # Best-effort cleanup of the partial download; the original error
            # is re-raised so callers still see why the download failed.
            try:
                os.remove(tmp_path)
            except OSError:
                pass
            raise

    def forward(self, input_ids: torch.Tensor):
        """Run the wrapped timm model on a batch of images.

        Args:
            input_ids: Tensor passed unchanged to the timm image model. Despite
                its text-model style name, it holds pixel data, presumably
                ``(batch, 3, 299, 299)`` per the config's ``input_size``.
                TODO confirm.

        Returns:
            The timm model's output, presumably one score per entry of
            ``classes``. TODO confirm whether callers apply softmax.
        """
        return self.model(input_ids)