File size: 2,491 Bytes
2a63a7e
5881efa
cca515d
2a63a7e
cca515d
 
02b9873
2a63a7e
cca515d
02b9873
cca515d
02b9873
 
 
 
 
 
 
 
cca515d
02b9873
5881efa
 
 
 
 
4a320f9
5881efa
 
 
 
 
 
 
 
 
6b92c38
5881efa
02b9873
 
5881efa
 
 
02b9873
 
5881efa
2a63a7e
5881efa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cca515d
 
 
5881efa
 
2a63a7e
6f0b822
 
 
 
 
 
 
 
cca515d
4a320f9
2a63a7e
 
 
 
 
02b9873
 
 
 
2a63a7e
 
4f2c36e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import { useState } from "react"
import { useMutation, useQuery, useQueryClient } from "@tanstack/react-query"
import { useLocalStorage } from 'react-use';

import { Collection } from "@/type"
import list_styles from "@/assets/list_styles.json"

// Id of the optimistic entry shown in the "collections" cache while a
// generation request is in flight. NOTE(review): assumes real collection
// entries never use this id — confirm against the API.
const PLACEHOLDER_ID = -1

/**
 * State and actions for the image-generation input form.
 *
 * - `prompt` / `setPrompt`: the prompt text, stored in the react-query cache
 *   under `["prompt"]` (refetch triggers disabled), so every component using
 *   this hook reads the same value.
 * - `style` / `setStyle`: the selected `list_styles` entry (by `name`),
 *   defaulting to the first one.
 * - `submit` / `loading`: a mutation that POSTs the styled prompt to `/api`.
 *   While the request runs, a placeholder is shown at the top of the
 *   `["collections"]` cache; on success it is swapped for the returned `blob`
 *   and the blob id is appended to the `my-own-generations` local-storage list.
 *   On failure the placeholder is removed and the error (the server's
 *   `message` when the response is not ok) surfaces via the mutation state.
 * - `hasMadeFirstGeneration`: flips to `true` on the first `submit` call.
 */
export const useInputGeneration = () => {
  // Declared first: the setters and the mutation below all close over it.
  const client = useQueryClient()
  const [myGenerationsId, setGenerationsId] = useLocalStorage<any>('my-own-generations', []);
  const [style, setStyle] = useState<string>(list_styles[0].name)

  const { data: prompt } = useQuery(["prompt"], () => {
    return ''
  }, {
    refetchOnWindowFocus: false,
    refetchOnMount: false,
    refetchOnReconnect: false,
    initialData: ''
  })
  const setPrompt = (str: string) => client.setQueryData(["prompt"], () => str)

  const { data: hasMadeFirstGeneration } = useQuery(["firstGenerationDone"], () => {
    return false
  }, {
    refetchOnWindowFocus: false,
    refetchOnMount: false,
    refetchOnReconnect: false,
    initialData: false
  })
  const setFirstGenerationDone = () => client.setQueryData(["firstGenerationDone"], () => true)

  // Copies the cached "collections" value into a fresh array, treating a
  // not-yet-populated cache as empty (spreading `undefined` would throw).
  const readCollections = (old: unknown): Collection[] =>
    [...((old as Collection[] | undefined) ?? [])]

  // Resolves the first pending placeholder: replaces it with `item`, or drops
  // it when `item` is undefined (failed request / response without a blob).
  // If no placeholder is left, a result is prepended instead of being written
  // to index -1, which would only add a non-element "-1" property.
  const settlePlaceholder = (item?: Collection) => {
    client.setQueryData(["collections"], (old) => {
      const next = readCollections(old)
      const index = next.findIndex((entry) => entry.id === PLACEHOLDER_ID)
      if (index === -1) return item ? [item, ...next] : next
      if (item) {
        next.splice(index, 1, item)
      } else {
        next.splice(index, 1)
      }
      return next
    })
  }

  const { mutate: submit, isLoading: loading } = useMutation(
    ["generation"],
    async () => {
      if (!hasMadeFirstGeneration) setFirstGenerationDone()

      // Optimistic empty entry so the gallery shows the pending generation.
      client.setQueryData(["collections"], (old) => [
        {
          id: PLACEHOLDER_ID,
          blob: {
            type: "image/png",
            data: new ArrayBuffer(0),
          },
          prompt
        },
        ...readCollections(old),
      ])

      const findStyle = list_styles.find((item) => item.name === style)

      const response = await fetch("/api", {
        method: "POST",
        body: JSON.stringify({
          inputs: findStyle?.prompt.replace("{prompt}", prompt) ?? prompt,
          negative_prompt: findStyle?.negative_prompt ?? "",
        }),
      })
      const data = await response.json()

      if (!response.ok) {
        throw new Error(data.message)
      }

      settlePlaceholder(data?.blob)

      setGenerationsId(myGenerationsId?.length ? [...myGenerationsId, data?.blob?.id] : [data?.blob?.id])

      return data ?? {}
    },
    {
      // Roll back the optimistic entry on any failure (network error,
      // non-JSON body, or non-ok response) so it does not linger forever.
      onError: () => settlePlaceholder(),
    },
  )

  return {
    prompt,
    setPrompt,
    loading,
    submit,
    hasMadeFirstGeneration,
    list_styles,
    style,
    setStyle
  }

}