Spaces:
Runtime error
Runtime error
add sdxl google
Browse files- app/api/route.ts +7 -5
- app/api/uploader.ts +3 -4
- components/button/index.tsx +1 -1
- components/main/index.tsx +42 -25
- prisma/schema.prisma +1 -1
- utils/checker/image.ts +4 -2
- utils/checker/prompt.ts +1 -1
app/api/route.ts
CHANGED
@@ -22,19 +22,21 @@ export async function POST(
|
|
22 |
const textIsNSFW = await isTextNSFW(inputs, global_headers)
|
23 |
if (textIsNSFW) return Response.json({ status: 401, ok: false, message: "Prompt is not safe for work." });
|
24 |
|
25 |
-
const response = await fetch(`${process.env.
|
26 |
method: 'POST',
|
27 |
body: JSON.stringify({
|
28 |
-
|
29 |
negative_prompt: findStyle?.negative_prompt ?? "",
|
30 |
}),
|
31 |
headers: global_headers,
|
32 |
})
|
33 |
|
|
|
34 |
const res = await response.clone().json().catch(() => ({}));
|
35 |
if (res?.error) return Response.json({ status: response.status, ok: false, message: res.error });
|
36 |
-
|
37 |
-
const
|
|
|
38 |
|
39 |
const imageIsNSFW = await isImageNSFW(blob, global_headers)
|
40 |
if (imageIsNSFW) return Response.json({ status: 401, ok: false, message: "Image is not safe for work." });
|
@@ -51,7 +53,7 @@ export async function POST(
|
|
51 |
userId: userId ?? "",
|
52 |
},
|
53 |
})
|
54 |
-
|
55 |
return Response.json({ image: new_image, status: 200, ok: true });
|
56 |
|
57 |
}
|
|
|
22 |
const textIsNSFW = await isTextNSFW(inputs, global_headers)
|
23 |
if (textIsNSFW) return Response.json({ status: 401, ok: false, message: "Prompt is not safe for work." });
|
24 |
|
25 |
+
const response = await fetch(`${process.env.API_SDXL_URL}`, {
|
26 |
method: 'POST',
|
27 |
body: JSON.stringify({
|
28 |
+
prompt: findStyle?.prompt.replace("{prompt}", inputs) ?? inputs,
|
29 |
negative_prompt: findStyle?.negative_prompt ?? "",
|
30 |
}),
|
31 |
headers: global_headers,
|
32 |
})
|
33 |
|
34 |
+
|
35 |
const res = await response.clone().json().catch(() => ({}));
|
36 |
if (res?.error) return Response.json({ status: response.status, ok: false, message: res.error });
|
37 |
+
|
38 |
+
const base64Image = res.images[0];
|
39 |
+
const blob = await fetch(`data:image/png;base64,${base64Image}`).then((r) => r.blob());
|
40 |
|
41 |
const imageIsNSFW = await isImageNSFW(blob, global_headers)
|
42 |
if (imageIsNSFW) return Response.json({ status: 401, ok: false, message: "Image is not safe for work." });
|
|
|
53 |
userId: userId ?? "",
|
54 |
},
|
55 |
})
|
56 |
+
|
57 |
return Response.json({ image: new_image, status: 200, ok: true });
|
58 |
|
59 |
}
|
app/api/uploader.ts
CHANGED
@@ -1,19 +1,18 @@
|
|
1 |
-
import {
|
2 |
import type { RepoDesignation, Credentials } from "./../../node_modules/@huggingface/hub/dist";
|
3 |
|
4 |
export const UploaderDataset = async (blob: Blob, name: string) => {
|
5 |
const repo: RepoDesignation = { type: "dataset", name: "enzostvs/stable-diffusion-tpu-generations" };
|
6 |
const credentials: Credentials = { accessToken: process.env.NEXT_APP_HF_TOKEN as string };
|
7 |
|
8 |
-
const res: any = await
|
9 |
repo,
|
10 |
credentials,
|
11 |
-
|
12 |
{
|
13 |
path: `images/${name}.png`,
|
14 |
content: blob,
|
15 |
},
|
16 |
-
],
|
17 |
});
|
18 |
|
19 |
if (res?.error) return {
|
|
|
1 |
+
import { uploadFile } from "./../../node_modules/@huggingface/hub/dist";
|
2 |
import type { RepoDesignation, Credentials } from "./../../node_modules/@huggingface/hub/dist";
|
3 |
|
4 |
export const UploaderDataset = async (blob: Blob, name: string) => {
|
5 |
const repo: RepoDesignation = { type: "dataset", name: "enzostvs/stable-diffusion-tpu-generations" };
|
6 |
const credentials: Credentials = { accessToken: process.env.NEXT_APP_HF_TOKEN as string };
|
7 |
|
8 |
+
const res: any = await uploadFile({
|
9 |
repo,
|
10 |
credentials,
|
11 |
+
file:
|
12 |
{
|
13 |
path: `images/${name}.png`,
|
14 |
content: blob,
|
15 |
},
|
|
|
16 |
});
|
17 |
|
18 |
if (res?.error) return {
|
components/button/index.tsx
CHANGED
@@ -15,7 +15,7 @@ export const Button: React.FC<Props> = ({
|
|
15 |
return (
|
16 |
<button
|
17 |
className={classNames(
|
18 |
-
"rounded-full px-6 py-2.5 font-semibold flex items-center justify-center gap-2.5 border-[2px] transition-all duration-200 max-w-max",
|
19 |
{
|
20 |
"bg-primary text-white border-primary": theme === "primary",
|
21 |
"bg-white text-gray-900 border-white": theme === "white",
|
|
|
15 |
return (
|
16 |
<button
|
17 |
className={classNames(
|
18 |
+
"rounded-full px-4 py-1.5 lg:px-6 lg:py-2.5 text-sm lg:text-base font-semibold flex items-center justify-center gap-2.5 border-[2px] transition-all duration-200 max-w-max",
|
19 |
{
|
20 |
"bg-primary text-white border-primary": theme === "primary",
|
21 |
"bg-white text-gray-900 border-white": theme === "white",
|
components/main/index.tsx
CHANGED
@@ -4,6 +4,8 @@ import { useState } from "react";
|
|
4 |
import { HiUserGroup, HiHeart, HiAdjustmentsHorizontal } from "react-icons/hi2";
|
5 |
import Link from "next/link";
|
6 |
import Image from "next/image";
|
|
|
|
|
7 |
|
8 |
import { InputGeneration } from "@/components/input-generation";
|
9 |
import { Button } from "@/components/button";
|
@@ -17,35 +19,38 @@ const categories = [
|
|
17 |
{
|
18 |
key: "community",
|
19 |
label: "Community",
|
20 |
-
icon: <HiUserGroup className="text-2xl" />,
|
21 |
},
|
22 |
{
|
23 |
key: "my-own",
|
24 |
label: "My generations",
|
25 |
isLogged: true,
|
26 |
-
icon: <HiHeart className="text-2xl" />,
|
27 |
},
|
28 |
];
|
29 |
|
|
|
|
|
30 |
export const Main = () => {
|
31 |
const { openWindowLogin, user } = useUser();
|
|
|
32 |
const { list_styles, style, setStyle, loading } = useInputGeneration();
|
33 |
const [category, setCategory] = useState<string>("community");
|
34 |
const [advancedSettings, setAdvancedSettings] = useState<boolean>(false);
|
35 |
|
36 |
-
console.log("user", user);
|
37 |
-
|
38 |
return (
|
39 |
<main className="px-6 z-[2] relative max-w-[1722px] mx-auto">
|
40 |
<div className="py-2 pl-2 pr-2 lg:pr-4 bg-black bg-opacity-30 backdrop-blur-sm lg:sticky lg:top-4 z-10 rounded-full">
|
41 |
<div className="flex flex-col lg:flex-row items-center justify-between w-full">
|
42 |
<InputGeneration />
|
43 |
-
<div className="items-center justify-center lg:justify-end gap-5 w-full mt-6 lg:mt-0 flex">
|
44 |
{categories.map(({ key, label, icon, isLogged }) =>
|
45 |
isLogged && !user ? (
|
46 |
<img
|
47 |
key={key}
|
48 |
-
src=
|
|
|
|
|
49 |
className="cursor-pointer hover:-translate-y-1 transition-all duration-200"
|
50 |
onClick={openWindowLogin}
|
51 |
/>
|
@@ -63,25 +68,37 @@ export const Main = () => {
|
|
63 |
)}
|
64 |
</div>
|
65 |
</div>
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
85 |
</div>
|
86 |
<p
|
87 |
className="text-white/70 font-medium text-sm flex items-center justify-center lg:justify-start gap-2 hover:text-white cursor-pointer mt-3"
|
|
|
4 |
import { HiUserGroup, HiHeart, HiAdjustmentsHorizontal } from "react-icons/hi2";
|
5 |
import Link from "next/link";
|
6 |
import Image from "next/image";
|
7 |
+
import classNames from "classnames";
|
8 |
+
import { createBreakpoint } from "react-use";
|
9 |
|
10 |
import { InputGeneration } from "@/components/input-generation";
|
11 |
import { Button } from "@/components/button";
|
|
|
19 |
{
|
20 |
key: "community",
|
21 |
label: "Community",
|
22 |
+
icon: <HiUserGroup className="text-lg lg:text-2xl" />,
|
23 |
},
|
24 |
{
|
25 |
key: "my-own",
|
26 |
label: "My generations",
|
27 |
isLogged: true,
|
28 |
+
icon: <HiHeart className="text-lg lg:text-2xl" />,
|
29 |
},
|
30 |
];
|
31 |
|
32 |
+
const useBreakpoint = createBreakpoint({ L: 1024, XS: 640 });
|
33 |
+
|
34 |
export const Main = () => {
|
35 |
const { openWindowLogin, user } = useUser();
|
36 |
+
const breakpoint = useBreakpoint();
|
37 |
const { list_styles, style, setStyle, loading } = useInputGeneration();
|
38 |
const [category, setCategory] = useState<string>("community");
|
39 |
const [advancedSettings, setAdvancedSettings] = useState<boolean>(false);
|
40 |
|
|
|
|
|
41 |
return (
|
42 |
<main className="px-6 z-[2] relative max-w-[1722px] mx-auto">
|
43 |
<div className="py-2 pl-2 pr-2 lg:pr-4 bg-black bg-opacity-30 backdrop-blur-sm lg:sticky lg:top-4 z-10 rounded-full">
|
44 |
<div className="flex flex-col lg:flex-row items-center justify-between w-full">
|
45 |
<InputGeneration />
|
46 |
+
<div className="items-center justify-center flex-col lg:flex-row lg:justify-end gap-5 w-full mt-6 lg:mt-0 flex">
|
47 |
{categories.map(({ key, label, icon, isLogged }) =>
|
48 |
isLogged && !user ? (
|
49 |
<img
|
50 |
key={key}
|
51 |
+
src={`https://huggingface.co/datasets/huggingface/badges/resolve/main/sign-in-with-huggingface-${
|
52 |
+
breakpoint === "XS" ? "lg" : "xl"
|
53 |
+
}.svg`}
|
54 |
className="cursor-pointer hover:-translate-y-1 transition-all duration-200"
|
55 |
onClick={openWindowLogin}
|
56 |
/>
|
|
|
68 |
)}
|
69 |
</div>
|
70 |
</div>
|
71 |
+
<div
|
72 |
+
className={classNames(
|
73 |
+
"flex items-center justify-center lg:justify-end text-right gap-1 mt-4 lg:mt-0",
|
74 |
+
{
|
75 |
+
"text-gray-300 text-xs": !user?.sub,
|
76 |
+
"text-white text-sm": user?.sub,
|
77 |
+
}
|
78 |
+
)}
|
79 |
+
>
|
80 |
+
{user?.sub ? (
|
81 |
+
<>
|
82 |
+
Logged as
|
83 |
+
<Link
|
84 |
+
href={user?.profile}
|
85 |
+
target="_blank"
|
86 |
+
className="hover:text-blue-500 flex items-center justify-end gap-2"
|
87 |
+
>
|
88 |
+
@{user?.preferred_username}
|
89 |
+
<Image
|
90 |
+
src={user?.picture}
|
91 |
+
width={24}
|
92 |
+
height={24}
|
93 |
+
className="rounded-full ring-1 ring-white/60 border border-black"
|
94 |
+
alt={user?.preferred_username}
|
95 |
+
/>
|
96 |
+
</Link>
|
97 |
+
</>
|
98 |
+
) : (
|
99 |
+
"to save your generations in your own gallery"
|
100 |
+
)}
|
101 |
+
</div>
|
102 |
</div>
|
103 |
<p
|
104 |
className="text-white/70 font-medium text-sm flex items-center justify-center lg:justify-start gap-2 hover:text-white cursor-pointer mt-3"
|
prisma/schema.prisma
CHANGED
@@ -4,7 +4,7 @@ generator client {
|
|
4 |
|
5 |
datasource db {
|
6 |
provider = "sqlite"
|
7 |
-
url = "file
|
8 |
}
|
9 |
|
10 |
model Collection {
|
|
|
4 |
|
5 |
datasource db {
|
6 |
provider = "sqlite"
|
7 |
+
url = "file:../data/dev.db"
|
8 |
}
|
9 |
|
10 |
model Collection {
|
utils/checker/image.ts
CHANGED
@@ -1,8 +1,9 @@
|
|
1 |
export const isImageNSFW = async (blob: Blob, global_headers: any) => {
|
2 |
return new Promise(async (resolve, reject) => {
|
|
|
3 |
const headers = new Headers();
|
4 |
headers.set("Content-Type", "image/*");
|
5 |
-
const request = await fetch(`${process.env.
|
6 |
method: 'POST',
|
7 |
headers: {
|
8 |
...global_headers,
|
@@ -15,10 +16,11 @@ export const isImageNSFW = async (blob: Blob, global_headers: any) => {
|
|
15 |
if (res?.error && res?.estimated_time) {
|
16 |
setTimeout(() => {
|
17 |
isImageNSFW(blob, global_headers)
|
18 |
-
}, res?.estimated_time
|
19 |
} else {
|
20 |
if (res?.error) return Response.json({ status: 500, ok: false, message: res?.error });
|
21 |
if (res?.length) {
|
|
|
22 |
const isNSFW = res?.find((n: { label: string }) => n.label === "no_safe")?.score > 0.85 ?? false;
|
23 |
resolve(isNSFW)
|
24 |
}
|
|
|
1 |
export const isImageNSFW = async (blob: Blob, global_headers: any) => {
|
2 |
return new Promise(async (resolve, reject) => {
|
3 |
+
console.log("isImageNSFW running...")
|
4 |
const headers = new Headers();
|
5 |
headers.set("Content-Type", "image/*");
|
6 |
+
const request = await fetch(`${process.env.INFERENCE_API_URL}/models/DamarJati/NSFW-Filterization-DecentScan`, {
|
7 |
method: 'POST',
|
8 |
headers: {
|
9 |
...global_headers,
|
|
|
16 |
if (res?.error && res?.estimated_time) {
|
17 |
setTimeout(() => {
|
18 |
isImageNSFW(blob, global_headers)
|
19 |
+
}, res?.estimated_time);
|
20 |
} else {
|
21 |
if (res?.error) return Response.json({ status: 500, ok: false, message: res?.error });
|
22 |
if (res?.length) {
|
23 |
+
console.log(res)
|
24 |
const isNSFW = res?.find((n: { label: string }) => n.label === "no_safe")?.score > 0.85 ?? false;
|
25 |
resolve(isNSFW)
|
26 |
}
|
utils/checker/prompt.ts
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
export const isTextNSFW = async (prompt: string, headers: any) => {
|
2 |
return new Promise(async (resolve, reject) => {
|
3 |
-
const request = await fetch(`${process.env.
|
4 |
method: 'POST',
|
5 |
body: JSON.stringify({
|
6 |
inputs: prompt,
|
|
|
1 |
export const isTextNSFW = async (prompt: string, headers: any) => {
|
2 |
return new Promise(async (resolve, reject) => {
|
3 |
+
const request = await fetch(`${process.env.INFERENCE_API_URL}/models/michellejieli/NSFW_text_classifier`, {
|
4 |
method: 'POST',
|
5 |
body: JSON.stringify({
|
6 |
inputs: prompt,
|