yoinked commited on
Commit
839c246
1 Parent(s): fdb1a74

primitive anti nsfw using wdtagger

Browse files
Files changed (1) hide show
  1. app.py +17 -4
app.py CHANGED
@@ -2,6 +2,7 @@ import spaces
2
  import os
3
  import gc
4
  import gradio as gr
 
5
  import numpy as np
6
  import torch
7
  import json
@@ -12,7 +13,7 @@ from PIL import Image, PngImagePlugin
12
  from datetime import datetime
13
  from diffusers.models import AutoencoderKL
14
  from diffusers import StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline
15
-
16
  logging.basicConfig(level=logging.INFO)
17
  logger = logging.getLogger(__name__)
18
 
@@ -33,7 +34,7 @@ MODEL = os.getenv(
33
  "OnomaAIResearch/Illustrious-xl-early-release-v0",
34
  )
35
 
36
- torch.backends.cudnn.deterministic = True
37
  torch.backends.cudnn.benchmark = False
38
 
39
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
@@ -192,7 +193,19 @@ def generate(
192
  pipe.scheduler = backup_scheduler
193
  utils.free_memory()
194
 
195
-
 
 
 
 
 
 
 
 
 
 
 
 
196
  if torch.cuda.is_available():
197
  pipe = load_pipeline(MODEL)
198
  logger.info("Loaded on Device!")
@@ -369,7 +382,7 @@ with gr.Blocks(css="style.css", theme="NoCrypt/[email protected]") as demo:
369
  queue=False,
370
  api_name=False,
371
  ).then(
372
- fn=generate,
373
  inputs=[
374
  prompt,
375
  negative_prompt,
 
2
  import os
3
  import gc
4
  import gradio as gr
5
+ import gradio_client as grcl
6
  import numpy as np
7
  import torch
8
  import json
 
13
  from datetime import datetime
14
  from diffusers.models import AutoencoderKL
15
  from diffusers import StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline
16
+ GRAD_CLIENT = grcl.Client("https://yoinked-da-nsfw-checker.hf.space/")
17
  logging.basicConfig(level=logging.INFO)
18
  logger = logging.getLogger(__name__)
19
 
 
34
  "OnomaAIResearch/Illustrious-xl-early-release-v0",
35
  )
36
 
37
+ torch.backends.cudnn.deterministic = True # maybe disable this? seems
38
  torch.backends.cudnn.benchmark = False
39
 
40
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
193
  pipe.scheduler = backup_scheduler
194
  utils.free_memory()
195
 
196
+ def genwrap(*args, **kwargs):
197
+ ipth, mtd = generate(*args, **kwargs)
198
+ r = GRAD_CLIENT(image=grcl.file(ipth), "chen-evangelion", 0.4, False, False, api_name="/classify"))
199
+ ratings = val[0]
200
+ rating = rating['confidences']
201
+ highestval, classtype = -1, "aa"
202
+ for o in rating:
203
+ if o['confidence'] > highestval:
204
+ highestval = o['confidence']
205
+ classtype = o['label']
206
+ if classtype not in ["general", "sensitive"]: #add "questionable" and "explicit" to enable nsfw, or just delete this func
207
+ return "https://upload.wikimedia.org/wikipedia/commons/b/bf/Bucephala-albeola-010.jpg", mtd
208
+ return ipth, mtd
209
  if torch.cuda.is_available():
210
  pipe = load_pipeline(MODEL)
211
  logger.info("Loaded on Device!")
 
382
  queue=False,
383
  api_name=False,
384
  ).then(
385
+ fn=genwrap,
386
  inputs=[
387
  prompt,
388
  negative_prompt,