enzostvs HF Staff commited on
Commit
7dc9f9f
·
1 Parent(s): 6def2c4

add sdxl google

Browse files
app/api/route.ts CHANGED
@@ -22,19 +22,21 @@ export async function POST(
22
  const textIsNSFW = await isTextNSFW(inputs, global_headers)
23
  if (textIsNSFW) return Response.json({ status: 401, ok: false, message: "Prompt is not safe for work." });
24
 
25
- const response = await fetch(`${process.env.NEXT_PUBLIC_API_URL}/models/stabilityai/stable-diffusion-xl-base-1.0`, {
26
  method: 'POST',
27
  body: JSON.stringify({
28
- inputs: findStyle?.prompt.replace("{prompt}", inputs) ?? inputs,
29
  negative_prompt: findStyle?.negative_prompt ?? "",
30
  }),
31
  headers: global_headers,
32
  })
33
 
 
34
  const res = await response.clone().json().catch(() => ({}));
35
  if (res?.error) return Response.json({ status: response.status, ok: false, message: res.error });
36
-
37
- const blob = await response.blob()
 
38
 
39
  const imageIsNSFW = await isImageNSFW(blob, global_headers)
40
  if (imageIsNSFW) return Response.json({ status: 401, ok: false, message: "Image is not safe for work." });
@@ -51,7 +53,7 @@ export async function POST(
51
  userId: userId ?? "",
52
  },
53
  })
54
-
55
  return Response.json({ image: new_image, status: 200, ok: true });
56
 
57
  }
 
22
  const textIsNSFW = await isTextNSFW(inputs, global_headers)
23
  if (textIsNSFW) return Response.json({ status: 401, ok: false, message: "Prompt is not safe for work." });
24
 
25
+ const response = await fetch(`${process.env.API_SDXL_URL}`, {
26
  method: 'POST',
27
  body: JSON.stringify({
28
+ prompt: findStyle?.prompt.replace("{prompt}", inputs) ?? inputs,
29
  negative_prompt: findStyle?.negative_prompt ?? "",
30
  }),
31
  headers: global_headers,
32
  })
33
 
34
+
35
  const res = await response.clone().json().catch(() => ({}));
36
  if (res?.error) return Response.json({ status: response.status, ok: false, message: res.error });
37
+
38
+ const base64Image = res.images[0];
39
+ const blob = await fetch(`data:image/png;base64,${base64Image}`).then((r) => r.blob());
40
 
41
  const imageIsNSFW = await isImageNSFW(blob, global_headers)
42
  if (imageIsNSFW) return Response.json({ status: 401, ok: false, message: "Image is not safe for work." });
 
53
  userId: userId ?? "",
54
  },
55
  })
56
+
57
  return Response.json({ image: new_image, status: 200, ok: true });
58
 
59
  }
app/api/uploader.ts CHANGED
@@ -1,19 +1,18 @@
1
- import { uploadFiles } from "./../../node_modules/@huggingface/hub/dist";
2
  import type { RepoDesignation, Credentials } from "./../../node_modules/@huggingface/hub/dist";
3
 
4
  export const UploaderDataset = async (blob: Blob, name: string) => {
5
  const repo: RepoDesignation = { type: "dataset", name: "enzostvs/stable-diffusion-tpu-generations" };
6
  const credentials: Credentials = { accessToken: process.env.NEXT_APP_HF_TOKEN as string };
7
 
8
- const res: any = await uploadFiles({
9
  repo,
10
  credentials,
11
- files: [
12
  {
13
  path: `images/${name}.png`,
14
  content: blob,
15
  },
16
- ],
17
  });
18
 
19
  if (res?.error) return {
 
1
+ import { uploadFile } from "./../../node_modules/@huggingface/hub/dist";
2
  import type { RepoDesignation, Credentials } from "./../../node_modules/@huggingface/hub/dist";
3
 
4
  export const UploaderDataset = async (blob: Blob, name: string) => {
5
  const repo: RepoDesignation = { type: "dataset", name: "enzostvs/stable-diffusion-tpu-generations" };
6
  const credentials: Credentials = { accessToken: process.env.NEXT_APP_HF_TOKEN as string };
7
 
8
+ const res: any = await uploadFile({
9
  repo,
10
  credentials,
11
+ file:
12
  {
13
  path: `images/${name}.png`,
14
  content: blob,
15
  },
 
16
  });
17
 
18
  if (res?.error) return {
components/button/index.tsx CHANGED
@@ -15,7 +15,7 @@ export const Button: React.FC<Props> = ({
15
  return (
16
  <button
17
  className={classNames(
18
- "rounded-full px-6 py-2.5 font-semibold flex items-center justify-center gap-2.5 border-[2px] transition-all duration-200 max-w-max",
19
  {
20
  "bg-primary text-white border-primary": theme === "primary",
21
  "bg-white text-gray-900 border-white": theme === "white",
 
15
  return (
16
  <button
17
  className={classNames(
18
+ "rounded-full px-4 py-1.5 lg:px-6 lg:py-2.5 text-sm lg:text-base font-semibold flex items-center justify-center gap-2.5 border-[2px] transition-all duration-200 max-w-max",
19
  {
20
  "bg-primary text-white border-primary": theme === "primary",
21
  "bg-white text-gray-900 border-white": theme === "white",
components/main/index.tsx CHANGED
@@ -4,6 +4,8 @@ import { useState } from "react";
4
  import { HiUserGroup, HiHeart, HiAdjustmentsHorizontal } from "react-icons/hi2";
5
  import Link from "next/link";
6
  import Image from "next/image";
 
 
7
 
8
  import { InputGeneration } from "@/components/input-generation";
9
  import { Button } from "@/components/button";
@@ -17,35 +19,38 @@ const categories = [
17
  {
18
  key: "community",
19
  label: "Community",
20
- icon: <HiUserGroup className="text-2xl" />,
21
  },
22
  {
23
  key: "my-own",
24
  label: "My generations",
25
  isLogged: true,
26
- icon: <HiHeart className="text-2xl" />,
27
  },
28
  ];
29
 
 
 
30
  export const Main = () => {
31
  const { openWindowLogin, user } = useUser();
 
32
  const { list_styles, style, setStyle, loading } = useInputGeneration();
33
  const [category, setCategory] = useState<string>("community");
34
  const [advancedSettings, setAdvancedSettings] = useState<boolean>(false);
35
 
36
- console.log("user", user);
37
-
38
  return (
39
  <main className="px-6 z-[2] relative max-w-[1722px] mx-auto">
40
  <div className="py-2 pl-2 pr-2 lg:pr-4 bg-black bg-opacity-30 backdrop-blur-sm lg:sticky lg:top-4 z-10 rounded-full">
41
  <div className="flex flex-col lg:flex-row items-center justify-between w-full">
42
  <InputGeneration />
43
- <div className="items-center justify-center lg:justify-end gap-5 w-full mt-6 lg:mt-0 flex">
44
  {categories.map(({ key, label, icon, isLogged }) =>
45
  isLogged && !user ? (
46
  <img
47
  key={key}
48
- src="https://huggingface.co/datasets/huggingface/badges/resolve/main/sign-in-with-huggingface-xl.svg"
 
 
49
  className="cursor-pointer hover:-translate-y-1 transition-all duration-200"
50
  onClick={openWindowLogin}
51
  />
@@ -63,25 +68,37 @@ export const Main = () => {
63
  )}
64
  </div>
65
  </div>
66
- {user && (
67
- <div className="flex items-center justify-center lg:justify-end text-white text-right text-sm gap-1 mt-4 lg:mt-0">
68
- Logged as
69
- <Link
70
- href={user?.profile}
71
- target="_blank"
72
- className="hover:text-blue-500 flex items-center justify-end gap-2"
73
- >
74
- @{user?.preferred_username}
75
- <Image
76
- src={user?.picture}
77
- width={24}
78
- height={24}
79
- className="rounded-full ring-1 ring-white/60 border border-black"
80
- alt={user?.preferred_username}
81
- />
82
- </Link>
83
- </div>
84
- )}
 
 
 
 
 
 
 
 
 
 
 
 
85
  </div>
86
  <p
87
  className="text-white/70 font-medium text-sm flex items-center justify-center lg:justify-start gap-2 hover:text-white cursor-pointer mt-3"
 
4
  import { HiUserGroup, HiHeart, HiAdjustmentsHorizontal } from "react-icons/hi2";
5
  import Link from "next/link";
6
  import Image from "next/image";
7
+ import classNames from "classnames";
8
+ import { createBreakpoint } from "react-use";
9
 
10
  import { InputGeneration } from "@/components/input-generation";
11
  import { Button } from "@/components/button";
 
19
  {
20
  key: "community",
21
  label: "Community",
22
+ icon: <HiUserGroup className="text-lg lg:text-2xl" />,
23
  },
24
  {
25
  key: "my-own",
26
  label: "My generations",
27
  isLogged: true,
28
+ icon: <HiHeart className="text-lg lg:text-2xl" />,
29
  },
30
  ];
31
 
32
+ const useBreakpoint = createBreakpoint({ L: 1024, XS: 640 });
33
+
34
  export const Main = () => {
35
  const { openWindowLogin, user } = useUser();
36
+ const breakpoint = useBreakpoint();
37
  const { list_styles, style, setStyle, loading } = useInputGeneration();
38
  const [category, setCategory] = useState<string>("community");
39
  const [advancedSettings, setAdvancedSettings] = useState<boolean>(false);
40
 
 
 
41
  return (
42
  <main className="px-6 z-[2] relative max-w-[1722px] mx-auto">
43
  <div className="py-2 pl-2 pr-2 lg:pr-4 bg-black bg-opacity-30 backdrop-blur-sm lg:sticky lg:top-4 z-10 rounded-full">
44
  <div className="flex flex-col lg:flex-row items-center justify-between w-full">
45
  <InputGeneration />
46
+ <div className="items-center justify-center flex-col lg:flex-row lg:justify-end gap-5 w-full mt-6 lg:mt-0 flex">
47
  {categories.map(({ key, label, icon, isLogged }) =>
48
  isLogged && !user ? (
49
  <img
50
  key={key}
51
+ src={`https://huggingface.co/datasets/huggingface/badges/resolve/main/sign-in-with-huggingface-${
52
+ breakpoint === "XS" ? "lg" : "xl"
53
+ }.svg`}
54
  className="cursor-pointer hover:-translate-y-1 transition-all duration-200"
55
  onClick={openWindowLogin}
56
  />
 
68
  )}
69
  </div>
70
  </div>
71
+ <div
72
+ className={classNames(
73
+ "flex items-center justify-center lg:justify-end text-right gap-1 mt-4 lg:mt-0",
74
+ {
75
+ "text-gray-300 text-xs": !user?.sub,
76
+ "text-white text-sm": user?.sub,
77
+ }
78
+ )}
79
+ >
80
+ {user?.sub ? (
81
+ <>
82
+ Logged as
83
+ <Link
84
+ href={user?.profile}
85
+ target="_blank"
86
+ className="hover:text-blue-500 flex items-center justify-end gap-2"
87
+ >
88
+ @{user?.preferred_username}
89
+ <Image
90
+ src={user?.picture}
91
+ width={24}
92
+ height={24}
93
+ className="rounded-full ring-1 ring-white/60 border border-black"
94
+ alt={user?.preferred_username}
95
+ />
96
+ </Link>
97
+ </>
98
+ ) : (
99
+ "to save your generations in your own gallery"
100
+ )}
101
+ </div>
102
  </div>
103
  <p
104
  className="text-white/70 font-medium text-sm flex items-center justify-center lg:justify-start gap-2 hover:text-white cursor-pointer mt-3"
prisma/schema.prisma CHANGED
@@ -4,7 +4,7 @@ generator client {
4
 
5
  datasource db {
6
  provider = "sqlite"
7
- url = "file://data/dev.db"
8
  }
9
 
10
  model Collection {
 
4
 
5
  datasource db {
6
  provider = "sqlite"
7
+ url = "file:../data/dev.db"
8
  }
9
 
10
  model Collection {
utils/checker/image.ts CHANGED
@@ -1,8 +1,9 @@
1
  export const isImageNSFW = async (blob: Blob, global_headers: any) => {
2
  return new Promise(async (resolve, reject) => {
 
3
  const headers = new Headers();
4
  headers.set("Content-Type", "image/*");
5
- const request = await fetch(`${process.env.NEXT_PUBLIC_API_URL}/models/DamarJati/NSFW-Filterization-DecentScan`, {
6
  method: 'POST',
7
  headers: {
8
  ...global_headers,
@@ -15,10 +16,11 @@ export const isImageNSFW = async (blob: Blob, global_headers: any) => {
15
  if (res?.error && res?.estimated_time) {
16
  setTimeout(() => {
17
  isImageNSFW(blob, global_headers)
18
- }, res?.estimated_time * 100);
19
  } else {
20
  if (res?.error) return Response.json({ status: 500, ok: false, message: res?.error });
21
  if (res?.length) {
 
22
  const isNSFW = res?.find((n: { label: string }) => n.label === "no_safe")?.score > 0.85 ?? false;
23
  resolve(isNSFW)
24
  }
 
1
  export const isImageNSFW = async (blob: Blob, global_headers: any) => {
2
  return new Promise(async (resolve, reject) => {
3
+ console.log("isImageNSFW running...")
4
  const headers = new Headers();
5
  headers.set("Content-Type", "image/*");
6
+ const request = await fetch(`${process.env.INFERENCE_API_URL}/models/DamarJati/NSFW-Filterization-DecentScan`, {
7
  method: 'POST',
8
  headers: {
9
  ...global_headers,
 
16
  if (res?.error && res?.estimated_time) {
17
  setTimeout(() => {
18
  isImageNSFW(blob, global_headers)
19
+ }, res?.estimated_time);
20
  } else {
21
  if (res?.error) return Response.json({ status: 500, ok: false, message: res?.error });
22
  if (res?.length) {
23
+ console.log(res)
24
  const isNSFW = res?.find((n: { label: string }) => n.label === "no_safe")?.score > 0.85 ?? false;
25
  resolve(isNSFW)
26
  }
utils/checker/prompt.ts CHANGED
@@ -1,6 +1,6 @@
1
  export const isTextNSFW = async (prompt: string, headers: any) => {
2
  return new Promise(async (resolve, reject) => {
3
- const request = await fetch(`${process.env.NEXT_PUBLIC_API_URL}/models/michellejieli/NSFW_text_classifier`, {
4
  method: 'POST',
5
  body: JSON.stringify({
6
  inputs: prompt,
 
1
  export const isTextNSFW = async (prompt: string, headers: any) => {
2
  return new Promise(async (resolve, reject) => {
3
+ const request = await fetch(`${process.env.INFERENCE_API_URL}/models/michellejieli/NSFW_text_classifier`, {
4
  method: 'POST',
5
  body: JSON.stringify({
6
  inputs: prompt,