enzostvs HF Staff commited on
Commit
6757563
·
1 Parent(s): 9efd7bc

add prompt checker

Browse files
app/api/route.ts CHANGED
@@ -1,6 +1,9 @@
1
  import { PrismaClient } from '@prisma/client'
2
 
 
3
  import { UploaderDataset } from './uploader'
 
 
4
 
5
  const prisma = new PrismaClient()
6
 
@@ -13,12 +16,17 @@ export async function POST(
13
  ['x-use-cache']: "0"
14
  }
15
 
16
- const { inputs, negative_prompt } = await request.json()
 
 
 
 
 
17
  const response = await fetch(`${process.env.NEXT_PUBLIC_API_URL}/models/stabilityai/stable-diffusion-xl-base-1.0`, {
18
  method: 'POST',
19
  body: JSON.stringify({
20
- inputs,
21
- negative_prompt
22
  }),
23
  headers: global_headers,
24
  })
@@ -27,37 +35,9 @@ export async function POST(
27
  if (res?.error) return Response.json({ status: response.status, ok: false, message: res.error });
28
 
29
  const blob = await response.blob()
30
- const headers = new Headers();
31
- headers.set("Content-Type", "image/*");
32
-
33
- const checkIfIsNSFW = (blob: Blob) => {
34
- return new Promise(async (resolve, reject) => {
35
- const response_nsfw = await fetch(`${process.env.NEXT_PUBLIC_API_URL}/models/DamarJati/NSFW-Filterization-DecentScan`, {
36
- method: 'POST',
37
- headers: {
38
- ...global_headers,
39
- ...headers,
40
- },
41
- body: blob,
42
- })
43
- const isNSFW = await response_nsfw.clone().json().catch(() => ({}));
44
-
45
- if (isNSFW?.error && isNSFW?.estimated_time) {
46
- setTimeout(() => {
47
- checkIfIsNSFW(blob)
48
- }, isNSFW?.estimated_time * 100);
49
- } else resolve(isNSFW)
50
- })
51
- }
52
 
53
- const isNSFW: any = await checkIfIsNSFW(blob)
54
- if (isNSFW?.error) return Response.json({ status: 500, ok: false, message: isNSFW?.error });
55
- if (isNSFW?.length) {
56
- const scoreNotSafe = isNSFW?.find((n: { label: string }) => n.label === "no_safe");
57
- if (scoreNotSafe?.score > 0.85) {
58
- return Response.json({ status: 401, ok: false, message: "Image is not safe for work." });
59
- }
60
- }
61
 
62
  const name = Date.now() + `-${inputs.replace(/[^a-zA-Z0-9]/g, '-').slice(0, 10).toLowerCase()}`
63
  const { ok, message } = await UploaderDataset(blob, name)
@@ -71,6 +51,6 @@ export async function POST(
71
  },
72
  })
73
 
74
- return Response.json({ image: new_image, status: 200, ok: true, headers });
75
 
76
  }
 
1
  import { PrismaClient } from '@prisma/client'
2
 
3
+ import list_styles from "@/assets/list_styles.json"
4
  import { UploaderDataset } from './uploader'
5
+ import { isTextNSFW } from '@/utils/checker/prompt'
6
+ import { isImageNSFW } from '@/utils/checker/image'
7
 
8
  const prisma = new PrismaClient()
9
 
 
16
  ['x-use-cache']: "0"
17
  }
18
 
19
+ const { inputs, style } = await request.json()
20
+ const findStyle = list_styles.find((item) => item.name === style)
21
+
22
+ const textIsNSFW = await isTextNSFW(inputs, global_headers)
23
+ if (textIsNSFW) return Response.json({ status: 401, ok: false, message: "Prompt is not safe for work." });
24
+
25
  const response = await fetch(`${process.env.NEXT_PUBLIC_API_URL}/models/stabilityai/stable-diffusion-xl-base-1.0`, {
26
  method: 'POST',
27
  body: JSON.stringify({
28
+ inputs: findStyle?.prompt.replace("{prompt}", inputs) ?? inputs,
29
+ negative_prompt: findStyle?.negative_prompt ?? "",
30
  }),
31
  headers: global_headers,
32
  })
 
35
  if (res?.error) return Response.json({ status: response.status, ok: false, message: res.error });
36
 
37
  const blob = await response.blob()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
+ const imageIsNSFW = await isImageNSFW(blob, global_headers)
40
+ if (imageIsNSFW) return Response.json({ status: 401, ok: false, message: "Image is not safe for work." });
 
 
 
 
 
 
41
 
42
  const name = Date.now() + `-${inputs.replace(/[^a-zA-Z0-9]/g, '-').slice(0, 10).toLowerCase()}`
43
  const { ok, message } = await UploaderDataset(blob, name)
 
51
  },
52
  })
53
 
54
+ return Response.json({ image: new_image, status: 200, ok: true });
55
 
56
  }
components/main/hooks/useInputGeneration.ts CHANGED
@@ -50,13 +50,11 @@ export const useInputGeneration = () => {
50
  }
51
  })
52
 
53
- const findStyle = list_styles.find((item) => item.name === style)
54
-
55
  const response = await fetch("/api", {
56
  method: "POST",
57
  body: JSON.stringify({
58
- inputs: findStyle?.prompt.replace("{prompt}", prompt) ?? prompt,
59
- negative_prompt: findStyle?.negative_prompt ?? "",
60
  }),
61
  })
62
  const data = await response.json()
 
50
  }
51
  })
52
 
 
 
53
  const response = await fetch("/api", {
54
  method: "POST",
55
  body: JSON.stringify({
56
+ inputs: prompt,
57
+ style,
58
  }),
59
  })
60
  const data = await response.json()
utils/checker/image.ts ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ export const isImageNSFW = async (blob: Blob, global_headers: any) => {
2
+ return new Promise(async (resolve, reject) => {
3
+ const headers = new Headers();
4
+ headers.set("Content-Type", "image/*");
5
+ const request = await fetch(`${process.env.NEXT_PUBLIC_API_URL}/models/DamarJati/NSFW-Filterization-DecentScan`, {
6
+ method: 'POST',
7
+ headers: {
8
+ ...global_headers,
9
+ ...headers,
10
+ },
11
+ body: blob,
12
+ })
13
+ const res = await request.clone().json().catch(() => ({}));
14
+
15
+ if (res?.error && res?.estimated_time) {
16
+ setTimeout(() => {
17
+ isImageNSFW(blob, global_headers)
18
+ }, res?.estimated_time * 100);
19
+ } else {
20
+ if (res?.error) return Response.json({ status: 500, ok: false, message: res?.error });
21
+ if (res?.length) {
22
+ const isNSFW = res?.find((n: { label: string }) => n.label === "no_safe")?.score > 0.85 ?? false;
23
+ resolve(isNSFW)
24
+ }
25
+ resolve(true)
26
+ }
27
+ })
28
+ }
utils/checker/prompt.ts ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ export const isTextNSFW = async (prompt: string, headers: any) => {
2
+ return new Promise(async (resolve, reject) => {
3
+ const request = await fetch(`${process.env.NEXT_PUBLIC_API_URL}/models/michellejieli/NSFW_text_classifier`, {
4
+ method: 'POST',
5
+ body: JSON.stringify({
6
+ inputs: prompt,
7
+ }),
8
+ headers: headers,
9
+ })
10
+ const res = await request.clone().json().catch(() => ({}));
11
+ const isNSFW = res?.[0]?.find((item: { label: string, score: number }) => item?.label === "NSFW")?.score > 0.92 ?? false
12
+ resolve(isNSFW)
13
+ })
14
+ }