File size: 7,087 Bytes
19b3da3 ae524a9 19b3da3 b71808f 19b3da3 b71808f 19b3da3 b71808f 19b3da3 ae524a9 22df957 b71808f 19b3da3 ae524a9 19b3da3 ae524a9 19b3da3 ae524a9 19b3da3 ae524a9 19b3da3 ae524a9 19b3da3 ae524a9 19b3da3 ae524a9 19b3da3 ae524a9 19b3da3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 |
from re import L
import cv2
import numpy as np
import torch
import torch.nn as nn
from scipy.ndimage.filters import gaussian_filter
from transformers import CLIPConfig, CLIPVisionModel, PreTrainedModel
from internals.pipelines.commons import AbstractPipeline
from internals.util.config import get_nsfw_access, get_nsfw_threshold
def cosine_distance(image_embeds, text_embeds):
    """Return the pairwise cosine similarity between two embedding matrices.

    Despite the name, this is a *similarity* (1 = same direction), not a
    distance. Each row of both inputs is L2-normalised along ``dim=1``
    (the ``nn.functional.normalize`` default). The normalised matrices are
    then multiplied, so entry ``[i, j]`` is the cosine of the angle between
    ``image_embeds[i]`` and ``text_embeds[j]``.
    """
    unit_image, unit_text = (
        nn.functional.normalize(embeds) for embeds in (image_embeds, text_embeds)
    )
    return unit_image.mm(unit_text.t())
class SafetyChecker:
    """Lazily loads the NSFW safety-checker model and attaches it to pipelines."""

    # Class-level default; ``load`` sets the per-instance flag to True.
    __loaded = False

    def load(self):
        """Load the safety-checker weights (float16, on CUDA) once per instance."""
        if self.__loaded:
            return
        self.model = StableDiffusionSafetyCheckerV2.from_pretrained(
            "CompVis/stable-diffusion-safety-checker", torch_dtype=torch.float16
        ).to("cuda")
        self.__loaded = True

    def apply(self, pipeline: AbstractPipeline):
        """Install or remove the safety checker on ``pipeline``.

        If ``get_nsfw_access()`` is truthy, the checker is disabled by setting
        ``safety_checker`` to ``None``. Otherwise the model is loaded on demand
        and attached. Both ``pipeline.pipe`` and ``pipeline.pipe2`` are updated
        when present. A falsy ``pipeline`` is ignored, but the model is still
        loaded.
        """
        # Load *before* touching ``self.model``: the attribute only exists
        # after the first successful ``load()`` call. The previous version read
        # it first and raised AttributeError on a fresh instance.
        model = None
        if not get_nsfw_access():
            self.load()
            model = self.model

        if not pipeline:
            return

        for attr in ("pipe", "pipe2"):
            if hasattr(pipeline, attr):
                getattr(pipeline, attr).safety_checker = model
# NOTE: a verbatim second copy of ``cosine_distance`` used to be defined here,
# silently shadowing the identical module-level definition earlier in this
# file. It was removed so there is a single definition to maintain.
class StableDiffusionSafetyCheckerV2(PreTrainedModel):
    """CLIP-based NSFW checker that Gaussian-blurs flagged images.

    A CLIP vision tower plus a linear projection embeds each image. The
    embedding is compared (cosine similarity) against 17 concept embeddings
    and 3 "special care" embeddings, each with its own score threshold. An
    image with any concept score above its threshold is reported as NSFW.
    Unless ``get_nsfw_access()`` is truthy, flagged images are also blurred in
    place.
    """

    config_class = CLIPConfig

    _no_split_modules = ["CLIPEncoderLayer"]

    def __init__(self, config: CLIPConfig):
        super().__init__(config)

        self.vision_model = CLIPVisionModel(config.vision_config)
        self.visual_projection = nn.Linear(
            config.vision_config.hidden_size, config.projection_dim, bias=False
        )

        # Initialised to ones here; presumably overwritten by the checkpoint
        # loaded through ``from_pretrained`` — TODO confirm.
        self.concept_embeds = nn.Parameter(
            torch.ones(17, config.projection_dim), requires_grad=False
        )
        self.special_care_embeds = nn.Parameter(
            torch.ones(3, config.projection_dim), requires_grad=False
        )

        # Per-embedding score thresholds, subtracted from the cosine similarity.
        self.concept_embeds_weights = nn.Parameter(torch.ones(17), requires_grad=False)
        self.special_care_embeds_weights = nn.Parameter(
            torch.ones(3), requires_grad=False
        )

    @staticmethod
    def _blur_image(image):
        """Return a Gaussian-blurred (sigma=7) version of a single image.

        Tensors are blurred on the CPU in float32, then moved back to their
        original device and dtype. Anything else is passed directly to
        ``gaussian_filter``. The scalar sigma applies to every axis of
        ``image``, including any channel axis.
        """
        if torch.is_tensor(image):
            blurred = gaussian_filter(image.cpu().numpy().astype(np.float32), sigma=7)
            return torch.from_numpy(blurred).to(device=image.device, dtype=image.dtype)
        return gaussian_filter(image, sigma=7)

    @torch.no_grad()
    def forward(self, clip_input, images):
        """Score ``images`` for NSFW content and blur the flagged ones.

        Args:
            clip_input: pixel values fed to ``self.vision_model``.
            images: batch of images (a tensor, a list of tensors, or arrays).
                Flagged entries are replaced in place.

        Returns:
            ``(images, has_nsfw_concepts)``, where ``has_nsfw_concepts`` holds
            one bool per image.
        """
        pooled_output = self.vision_model(clip_input)[1]  # pooled_output
        image_embeds = self.visual_projection(pooled_output)

        # we always cast to float32 as this does not cause significant overhead
        # and is compatible with bfloat16
        special_cos_dist = (
            cosine_distance(image_embeds, self.special_care_embeds).cpu().float().numpy()
        )
        cos_dist = cosine_distance(image_embeds, self.concept_embeds).cpu().float().numpy()

        # Read thresholds once as Python floats (same values ``.item()`` gives)
        # instead of one device sync per (image, concept) pair.
        special_thresholds = self.special_care_embeds_weights.tolist()
        concept_thresholds = self.concept_embeds_weights.tolist()

        result = []
        for special_row, concept_row in zip(special_cos_dist, cos_dist):
            result_img = {
                "special_scores": {},
                "special_care": [],
                "concept_scores": {},
                "bad_concepts": [],
            }

            # increase this value to create a stronger `nsfw` filter
            # at the cost of increasing the possibility of filtering benign images
            adjustment = 0.0

            for concept_idx, (concept_cos, concept_threshold) in enumerate(
                zip(special_row, special_thresholds)
            ):
                score = round(concept_cos - concept_threshold + adjustment, 3)
                result_img["special_scores"][concept_idx] = score
                if score > 0:
                    # (index, score) pair; previously a set literal, which is
                    # unordered and collapses equal values.
                    result_img["special_care"].append((concept_idx, score))
                    # Any special-care hit makes all later checks stricter.
                    adjustment = 0.01

            for concept_idx, (concept_cos, concept_threshold) in enumerate(
                zip(concept_row, concept_thresholds)
            ):
                score = round(concept_cos - concept_threshold + adjustment, 3)
                result_img["concept_scores"][concept_idx] = score
                if score > 0:
                    result_img["bad_concepts"].append(concept_idx)

            result.append(result_img)

        has_nsfw_concepts = [len(res["bad_concepts"]) > 0 for res in result]

        if any(has_nsfw_concepts):
            # Blur only the flagged images. The previous check used
            # ``any(has_nsfw_concepts)`` inside the loop, which blurred the
            # whole batch as soon as a single image was flagged.
            if not get_nsfw_access():
                for idx, has_nsfw_concept in enumerate(has_nsfw_concepts):
                    if has_nsfw_concept:
                        images[idx] = self._blur_image(images[idx])
            print("NSFW")

        return images, has_nsfw_concepts

    @torch.no_grad()
    def forward_onnx(self, clip_input: torch.FloatTensor, images: torch.FloatTensor):
        """Vectorised variant of ``forward`` that keeps scores as tensors.

        Unlike ``forward``, scores are not rounded to 3 decimals. The special
        care adjustment is applied to every concept score at once, so results
        can differ from ``forward`` near the thresholds.

        Returns:
            ``(images, has_nsfw_concepts)``, where ``has_nsfw_concepts`` is a
            boolean tensor with one entry per image.
        """
        pooled_output = self.vision_model(clip_input)[1]  # pooled_output
        image_embeds = self.visual_projection(pooled_output)

        special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds)
        cos_dist = cosine_distance(image_embeds, self.concept_embeds)

        # increase this value to create a stronger `nsfw` filter
        # at the cost of increasing the possibility of filtering benign images
        adjustment = 0.0

        special_scores = special_cos_dist - self.special_care_embeds_weights + adjustment
        special_care = torch.any(special_scores > 0, dim=1)
        special_adjustment = special_care * 0.01
        special_adjustment = special_adjustment.unsqueeze(1).expand(
            -1, cos_dist.shape[1]
        )

        concept_scores = (cos_dist - self.concept_embeds_weights) + special_adjustment
        has_nsfw_concepts = torch.any(concept_scores > 0, dim=1)

        if not get_nsfw_access():
            # Blur each flagged image on its own. Filtering the stacked subset
            # in one call also smoothed along the batch axis, bleeding images
            # into each other. The boolean-mask assignment also failed when
            # ``images`` was not a CPU float32 tensor.
            for idx in torch.nonzero(has_nsfw_concepts).flatten().tolist():
                images[idx] = self._blur_image(images[idx])

        return images, has_nsfw_concepts
|