OwenElliott commited on
Commit
63b4882
1 Parent(s): d3e40ba

Delete modeling_nsfw_image_detection.py

Browse files
Files changed (1) hide show
  1. modeling_nsfw_image_detection.py +0 -38
modeling_nsfw_image_detection.py DELETED
@@ -1,38 +0,0 @@
1
- import torch
2
- from PIL import Image
3
- import timm
4
- from timm import create_model
5
- from transformers import PreTrainedModel
6
- from typing import List
7
- from .configuration_nsfw_image_detection import NSFWImageDetectionConfig
8
-
9
-
10
- class NSFWImageDetector(PreTrainedModel):
11
- config_class = NSFWImageDetectionConfig
12
-
13
- def __init__(self, config: NSFWImageDetectionConfig):
14
- super().__init__(config)
15
- self.model = create_model(
16
- 'hf-hub:Marqo/nsfw-image-detection-384',
17
- pretrained=True,
18
- ).eval()
19
- self.data_config = timm.data.resolve_model_data_config(self.model)
20
- self.transforms = timm.data.create_transform(**self.data_config, is_training=False)
21
- self.model = self.model.to(self.device)
22
-
23
- def forward(
24
- self,
25
- images: List[Image.Image]
26
- ) -> torch.Tensor:
27
- image_inputs = [self.transforms(image) for image in images]
28
- image_tensor = torch.stack(image_inputs, dim=0).to(self.device)
29
- outputs = self.model(image_tensor).softmax(dim=-1).cpu()
30
- return outputs
31
-
32
- def __call__(self, images: List[Image.Image]):
33
- return self.forward(images)
34
-
35
- @classmethod
36
- def from_pretrained(self, *args, **kwargs):
37
- config = NSFWImageDetectionConfig()
38
- return NSFWImageDetector(config)