Spaces:
Paused
Paused
import numpy as np | |
import torch | |
import torch.nn as nn | |
from transformers import CLIPConfig, CLIPVisionModelWithProjection, PreTrainedModel | |
from ...utils import logging | |
logger = logging.get_logger(__name__) | |
class IFSafetyChecker(PreTrainedModel):
    """CLIP-based checker that flags images as NSFW and/or watermarked.

    A CLIP vision model with projection embeds each input, and two independent
    single-output linear heads score that embedding: ``p_head`` drives the NSFW
    decision and ``w_head`` drives the watermark decision. Images flagged by
    either head are replaced with all-zero (black) arrays.
    """

    config_class = CLIPConfig

    _no_split_modules = ["CLIPEncoderLayer"]

    def __init__(self, config: CLIPConfig):
        """Build the vision model and both scoring heads from ``config.vision_config``."""
        super().__init__(config)

        self.vision_model = CLIPVisionModelWithProjection(config.vision_config)

        # Each head maps a projected embedding (projection_dim) to one score.
        self.p_head = nn.Linear(config.vision_config.projection_dim, 1)
        self.w_head = nn.Linear(config.vision_config.projection_dim, 1)

    @staticmethod
    def _black_out(images, flags):
        """Replace ``images[idx]`` with zeros of the same shape wherever ``flags[idx]`` is truthy.

        ``images`` is mutated in place. The replacement comes from ``np.zeros``
        with its default dtype, so the resulting dtype depends on the container
        (presumably a NumPy batch that casts on assignment -- confirm with caller).
        """
        for idx, flagged in enumerate(flags):
            if flagged:
                images[idx] = np.zeros(images[idx].shape)

    # Inference only: the results are Python lists and mutated images, so no
    # autograd graph is needed for the vision model or the heads.
    @torch.no_grad()
    def forward(self, clip_input, images, p_threshold=0.5, w_threshold=0.5):
        """Score a batch, black out flagged images, and report the flags.

        Args:
            clip_input: Input passed directly to the CLIP vision model.
            images: Indexable batch of images aligned with ``clip_input``;
                flagged entries are overwritten in place.
            p_threshold: An image is NSFW-flagged when its ``p_head`` output is
                strictly greater than this value.
            w_threshold: An image is watermark-flagged when its ``w_head``
                output is strictly greater than this value.

        Returns:
            A tuple ``(images, nsfw_detected, watermark_detected)`` where the
            last two are flat Python lists of booleans, one per image.
        """
        # First element of the vision model output is used as the embedding.
        image_embeds = self.vision_model(clip_input)[0]

        # The raw head output is compared to the threshold directly.
        nsfw_detected = (self.p_head(image_embeds).flatten() > p_threshold).tolist()

        if any(nsfw_detected):
            logger.warning(
                "Potential NSFW content was detected in one or more images. A black image will be returned instead."
                " Try again with a different prompt and/or seed."
            )

        self._black_out(images, nsfw_detected)

        # Watermark scoring reuses the embeddings computed above, so it is
        # unaffected by any images already blacked out.
        watermark_detected = (self.w_head(image_embeds).flatten() > w_threshold).tolist()

        if any(watermark_detected):
            logger.warning(
                "Potential watermarked content was detected in one or more images. A black image will be returned instead."
                " Try again with a different prompt and/or seed."
            )

        self._black_out(images, watermark_detected)

        return images, nsfw_detected, watermark_detected