OwenElliott
commited on
Commit
•
63b4882
1
Parent(s):
d3e40ba
Delete modeling_nsfw_image_detection.py
Browse files
modeling_nsfw_image_detection.py
DELETED
@@ -1,38 +0,0 @@
|
|
1 |
-
import torch
|
2 |
-
from PIL import Image
|
3 |
-
import timm
|
4 |
-
from timm import create_model
|
5 |
-
from transformers import PreTrainedModel
|
6 |
-
from typing import List
|
7 |
-
from .configuration_nsfw_image_detection import NSFWImageDetectionConfig
|
8 |
-
|
9 |
-
|
10 |
-
class NSFWImageDetector(PreTrainedModel):
|
11 |
-
config_class = NSFWImageDetectionConfig
|
12 |
-
|
13 |
-
def __init__(self, config: NSFWImageDetectionConfig):
|
14 |
-
super().__init__(config)
|
15 |
-
self.model = create_model(
|
16 |
-
'hf-hub:Marqo/nsfw-image-detection-384',
|
17 |
-
pretrained=True,
|
18 |
-
).eval()
|
19 |
-
self.data_config = timm.data.resolve_model_data_config(self.model)
|
20 |
-
self.transforms = timm.data.create_transform(**self.data_config, is_training=False)
|
21 |
-
self.model = self.model.to(self.device)
|
22 |
-
|
23 |
-
def forward(
|
24 |
-
self,
|
25 |
-
images: List[Image.Image]
|
26 |
-
) -> torch.Tensor:
|
27 |
-
image_inputs = [self.transforms(image) for image in images]
|
28 |
-
image_tensor = torch.stack(image_inputs, dim=0).to(self.device)
|
29 |
-
outputs = self.model(image_tensor).softmax(dim=-1).cpu()
|
30 |
-
return outputs
|
31 |
-
|
32 |
-
def __call__(self, images: List[Image.Image]):
|
33 |
-
return self.forward(images)
|
34 |
-
|
35 |
-
@classmethod
|
36 |
-
def from_pretrained(self, *args, **kwargs):
|
37 |
-
config = NSFWImageDetectionConfig()
|
38 |
-
return NSFWImageDetector(config)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|