RohitGandikota commited on
Commit
500d414
1 Parent(s): 376e5cc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -1
app.py CHANGED
@@ -4,7 +4,34 @@ from finetuning import FineTunedModel
4
  from StableDiffuser import StableDiffuser
5
  from train import train
6
  from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
7
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  import os
9
  model_map = {'Van Gogh' : 'models/vangogh.pt',
10
  'Pablo Picasso': 'models/pablopicasso.pt',
 
4
  from StableDiffuser import StableDiffuser
5
  from train import train
6
  from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
7
+ from transformers import CLIPFeatureExtractor
8
+
9
+
10
+ def numpy_to_pil(images):
11
+ """
12
+ Convert a numpy image or a batch of images to a PIL image.
13
+ """
14
+ if images.ndim == 3:
15
+ images = images[None, ...]
16
+ images = (images * 255).round().astype("uint8")
17
+ if images.shape[-1] == 1:
18
+ # special case for grayscale (single channel) images
19
+ pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
20
+ else:
21
+ pil_images = [Image.fromarray(image) for image in images]
22
+
23
+ return pil_images
24
+
25
+ def run_safety_checker(self, image, device, dtype):
26
+
27
+ feature_extractor = CLIPFeatureExtractor()
28
+ safety_checker = StableDiffusionSafetyChecker()
29
+ safety_checker_input = feature_extractor(numpy_to_pil(image), return_tensors="pt").to('cuda')
30
+ image, has_nsfw_concept = safety_checker(
31
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
32
+ )
33
+ return image, has_nsfw_concept
34
+
35
  import os
36
  model_map = {'Van Gogh' : 'models/vangogh.pt',
37
  'Pablo Picasso': 'models/pablopicasso.pt',