File size: 3,554 Bytes
19b3da3
 
 
 
 
 
 
 
 
 
 
 
 
04c5fb7
19b3da3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
04c5fb7
 
 
 
 
 
19b3da3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
04c5fb7
 
19b3da3
 
 
 
 
 
 
 
 
 
04c5fb7
19b3da3
 
 
 
 
04c5fb7
19b3da3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import functools
import traceback
from typing import Dict, List, Optional

import requests
from pydash import includes

from internals.data.task import Task
from internals.util.config import api_endpoint, api_headers
from internals.util.slack import Slack


def updateSource(sourceId, userId, state):
    """Send a PATCH that sets the ``state`` field of a source.

    The request targets ``<api_endpoint>/autodraft-crecoai/source/<sourceId>``
    with a 10-second timeout. ``requests`` errors are printed rather than
    raised, and the function always returns ``None``.

    Args:
        sourceId: Identifier interpolated into the request URL.
        userId: Sent (stringified) in the ``user-id`` header.
        state: Value sent as ``state`` in the JSON body.
    """
    print("update source is called")
    endpoint = api_endpoint() + f"/autodraft-crecoai/source/{sourceId}"

    # Shared API headers are spread last, so they win on any key collision.
    request_headers = {
        "Content-Type": "application/json",
        "user-id": str(userId),
        **api_headers(),
    }
    payload = {"state": state}

    try:
        result = requests.patch(
            endpoint, headers=request_headers, json=payload, timeout=10
        )
    except requests.exceptions.Timeout:
        print("Request timed out while updating source")
    except requests.exceptions.RequestException as err:
        print(f"Error while updating source: {err}")
    else:
        print("update source response", result)


def saveGeneratedImages(sourceId, userId, has_nsfw: bool):
    """Send a PATCH to a source's ``generatedImages`` endpoint.

    The JSON body is ``{"state": "ACTIVE", "has_nsfw": has_nsfw}``.
    ``requests`` errors are printed rather than raised, and the function
    always returns ``None``.

    Args:
        sourceId: Identifier interpolated (stringified) into the request URL.
        userId: Sent (stringified) in the ``user-id`` header.
        has_nsfw: Forwarded as the ``has_nsfw`` field of the JSON body.
    """
    print("save generation called")
    url = (
        api_endpoint()
        + "/autodraft-crecoai/source/"
        + str(sourceId)
        + "/generatedImages"
    )
    headers = {
        "Content-Type": "application/json",
        "user-id": str(userId),
        **api_headers(),
    }
    data = {"state": "ACTIVE", "has_nsfw": has_nsfw}

    try:
        # requests waits indefinitely unless a timeout is given, which would
        # make the Timeout handler below unreachable in practice. The 10 s
        # value matches the one used by updateSource.
        requests.patch(url, headers=headers, json=data, timeout=10)
    except requests.exceptions.Timeout:
        print("Request timed out while saving image")
    except requests.exceptions.RequestException as e:
        print("Failed to mark source as active: ", e)


def getStyles() -> Optional[Dict]:
    """Fetch the style listing from the API.

    Issues ``GET <api_endpoint>/autodraft-crecoai/style`` with a 10-second
    timeout and prints the URL before the request.

    Returns:
        The decoded JSON body of the response, or ``None`` if the request
        timed out.

    Raises:
        requests.exceptions.RequestException: For any non-timeout request
            failure. The error is printed before being re-raised.
    """
    url = api_endpoint() + "/autodraft-crecoai/style"
    print(url)
    try:
        response = requests.get(
            url,
            timeout=10,
            # NOTE(security): this API key is hard-coded in source control.
            # It should be moved to configuration / secret storage and
            # rotated; left in place here so the request is unchanged.
            headers={"x-api-key": "kGyEMp)oHB(zf^E5>-{o]I%go", **api_headers()},
        )
        return response.json()
    except requests.exceptions.Timeout:
        print("Request timed out while fetching styles")
    except requests.exceptions.RequestException as e:
        # Log first, then re-raise with the original traceback; previously
        # the log line sat after the raise and could never run.
        print(f"Error while fetching styles: {e}")
        raise
    return None


def getCharacters(model_id: str) -> Optional[List]:
    """Fetch the ``characters`` entry recorded for a model.

    Issues ``GET <api_endpoint>/autodraft-crecoai/model/<model_id>`` with a
    10-second timeout and reads ``["data"]["characters"]`` from the JSON body.

    Args:
        model_id: Identifier appended to the model URL.

    Returns:
        The value found at ``data.characters`` in the response, or ``None``
        when the request times out or any other exception is raised while
        fetching or parsing (the error is printed in both cases).
    """
    url = api_endpoint() + f"/autodraft-crecoai/model/{model_id}"
    try:
        payload = requests.get(url, timeout=10, headers=api_headers()).json()
        return payload["data"]["characters"]
    except requests.exceptions.Timeout:
        print("Request timed out while fetching characters")
        return None
    except Exception as err:
        # Best-effort lookup: any failure (network, JSON, missing keys) is
        # reported and turned into a None result.
        print(f"Error while fetching characters: {err}")
        return None


def update_db(func):
    """Decorate a Task-processing function with source state bookkeeping.

    The wrapped call marks the task's source ``INPROGRESS``, runs ``func``,
    then marks the source ``COMPLETED`` and saves the generated images. If
    anything in that sequence raises, the error is printed, reported through
    ``Slack`` and the source is marked ``FAILED``; the wrapper then returns
    ``None`` instead of propagating the original exception.

    Args:
        func: Callable whose first positional argument is a ``Task`` and
            whose return value supports ``.get`` (``has_nsfw`` is read from
            it, defaulting to ``False``). A return value without ``.get``
            (e.g. ``None``) is handled as a failure.

    Returns:
        The wrapping callable, carrying ``func``'s name and docstring.

    Raises:
        TypeError: From the wrapper, when it is called without a ``Task``
            instance as its first positional argument.
    """

    @functools.wraps(func)
    def caller(*args, **kwargs):
        # isinstance (rather than an exact type match) also accepts Task
        # subclasses; the emptiness check avoids a bare IndexError.
        if not args or not isinstance(args[0], Task):
            raise TypeError("First argument must be a Task object")
        task = args[0]
        try:
            updateSource(task.get_sourceId(), task.get_userId(), "INPROGRESS")
            rargs = func(*args, **kwargs)
            has_nsfw = rargs.get("has_nsfw", False)
            updateSource(task.get_sourceId(), task.get_userId(), "COMPLETED")
            saveGeneratedImages(task.get_sourceId(), task.get_userId(), has_nsfw)
            return rargs
        except Exception as e:
            print("Error processing image: {}".format(str(e)))
            traceback.print_exc()
            try:
                slack = Slack()
                slack.error_alert(task, e)
            finally:
                # Record the failure even if alerting itself raises, so the
                # source is not left in INPROGRESS.
                updateSource(task.get_sourceId(), task.get_userId(), "FAILED")
            return None

    return caller