from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from transformers import CLIPFeatureExtractor
import numpy as np
import torch
from PIL import Image
from typing import Optional, Tuple, Union

# Module-level state populated by load_model(); every name stays None until
# load_model() has run.
device = None
torch_device = None
torch_dtype = None
safety_checker = None
feature_extractor = None


def load_model():
    """Load the Stable Diffusion safety checker and its CLIP preprocessor.

    Rebinds the module-level globals ``device``, ``torch_device``,
    ``torch_dtype``, ``safety_checker`` and ``feature_extractor``. The models
    are obtained with ``from_pretrained``, which may fetch weights on first
    use. The safety checker is moved to CUDA when available, else CPU.
    """
    global device, torch_device, torch_dtype, safety_checker, feature_extractor
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    torch_device = device
    # NOTE(review): torch_dtype is never read in this file; the safety checker
    # below is loaded with from_pretrained's default dtype. Kept for callers
    # that may read the global.
    torch_dtype = torch.float16
    safety_checker = StableDiffusionSafetyChecker.from_pretrained(
        "CompVis/stable-diffusion-safety-checker"
    ).to(device)
    # NOTE(review): preprocessing comes from the ViT-B/32 CLIP processor;
    # presumably its resize/normalization matches what the safety checker's
    # vision model expects — confirm.
    feature_extractor = CLIPFeatureExtractor.from_pretrained(
        "openai/clip-vit-base-patch32"
    )


def check(image):
    """Return the safety checker's NSFW flag for a single image.

    If the models have not been loaded yet, load_model() is called first, so
    callers no longer hit a ``'NoneType' object is not callable`` error when
    they forget to initialise the module.

    Args:
        image: one image accepted by the CLIP feature extractor and by
            ``np.array`` (presumably a PIL image — confirm against callers).

    Returns:
        The first (and only) entry of the safety checker's ``has_nsfw_concepts``
        output for ``image``.
    """
    # Lazily initialise instead of failing on the None placeholders.
    if safety_checker is None or feature_extractor is None:
        load_model()

    images = [image]
    # Preprocess for CLIP and move the tensor to the checker's device once.
    clip_input = feature_extractor(images, return_tensors="pt").pixel_values.to(
        torch_device
    )
    # The safety checker also takes the raw images as arrays; this function
    # discards the (possibly blacked-out) images it returns.
    images_np = [np.array(img) for img in images]
    _, has_nsfw_concepts = safety_checker(images=images_np, clip_input=clip_input)
    return has_nsfw_concepts[0]