File size: 4,101 Bytes
19b3da3
 
 
 
 
7fbdac4
19b3da3
 
 
 
 
 
10230ea
 
 
 
 
 
 
 
19b3da3
f1235a4
19b3da3
 
 
 
 
 
 
 
 
10230ea
 
19b3da3
 
 
 
 
 
 
 
 
f1235a4
 
 
 
 
 
19b3da3
 
 
 
 
 
 
 
10230ea
 
19b3da3
 
 
 
 
 
 
 
 
 
f1235a4
19b3da3
10230ea
 
 
 
 
 
19b3da3
 
 
 
 
42ef134
19b3da3
 
 
 
f1235a4
19b3da3
10230ea
 
 
 
19b3da3
 
 
 
 
 
 
 
10230ea
 
 
 
19b3da3
 
7fbdac4
 
 
 
 
 
19b3da3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import functools
import traceback
from typing import Dict, List, Optional

import requests
from pydash import includes
from requests.adapters import HTTPAdapter, Retry

from internals.data.task import Task
from internals.util.config import api_endpoint, api_headers
from internals.util.slack import Slack


class RetryRequest:
    """Factory for a ``requests.Session`` configured with automatic retries.

    ``RetryRequest()`` does not return a ``RetryRequest`` instance: ``__new__``
    builds and returns a ``requests.Session`` (which callers use as a context
    manager). The session retries up to 5 times with ``backoff_factor=2`` and
    treats HTTP 500/502/503/504 responses as retryable.
    """

    # urllib3's default set of retryable methods omits PATCH, which made the
    # status_forcelist a silent no-op for the PATCH calls made through this
    # session. The PATCH bodies sent here set a state field and are assumed
    # idempotent — NOTE(review): confirm server-side before relying on it.
    _ALLOWED_METHODS = Retry.DEFAULT_ALLOWED_METHODS | {"PATCH"}

    def __new__(cls):
        policy = Retry(
            total=5,
            backoff_factor=2,
            status_forcelist=[500, 502, 503, 504],
            allowed_methods=cls._ALLOWED_METHODS,
        )
        adapter = HTTPAdapter(max_retries=policy)
        session = requests.Session()
        # Mount for both schemes so a plain-http api_endpoint() also retries.
        session.mount("https://", adapter)
        session.mount("http://", adapter)
        return session


def updateSource(sourceId, userId, state):
    """Best-effort PATCH of a source's ``state`` on the autodraft API.

    Args:
        sourceId: Identifier interpolated into the request path.
        userId: Sent (stringified) as the ``user-id`` header.
        state: New state value, e.g. ``"INPROGRESS"``, ``"COMPLETED"`` or
            ``"FAILED"``.

    Request failures, including non-2xx responses, are printed rather than
    raised, so a failed status update never aborts the caller.
    """
    url = api_endpoint() + f"/autodraft-crecoai/source/{sourceId}"
    headers = {
        "Content-Type": "application/json",
        "user-id": str(userId),
        **api_headers(),
    }
    data = {"state": state}

    try:
        with RetryRequest() as session:
            response = session.patch(url, headers=headers, json=data, timeout=10)
            # Surface 4xx/5xx replies through the RequestException handler
            # below instead of ignoring them silently.
            response.raise_for_status()
    except requests.exceptions.Timeout:
        print("Request timed out while updating source")
    except requests.exceptions.RequestException as e:
        print(f"Error while updating source: {e}")


def saveGeneratedImages(sourceId, userId, has_nsfw: bool):
    """Best-effort PATCH marking a source's generated images as ``"ACTIVE"``.

    Args:
        sourceId: Identifier interpolated into the request path.
        userId: Sent (stringified) as the ``user-id`` header.
        has_nsfw: Forwarded in the JSON body under ``"has_nsfw"``.

    Request failures, including non-2xx responses, are printed, never raised.
    """
    url = api_endpoint() + f"/autodraft-crecoai/source/{sourceId}/generatedImages"
    headers = {
        "Content-Type": "application/json",
        "user-id": str(userId),
        **api_headers(),
    }
    data = {"state": "ACTIVE", "has_nsfw": has_nsfw}

    try:
        with RetryRequest() as session:
            # An explicit timeout (matching the sibling calls) is required:
            # without one requests may block indefinitely, and the Timeout
            # handler below would effectively never fire.
            response = session.patch(url, headers=headers, json=data, timeout=10)
            response.raise_for_status()
    except requests.exceptions.Timeout:
        print("Request timed out while saving image")
    except requests.exceptions.RequestException as e:
        print("Failed to mark source as active: ", e)


def getStyles() -> Optional[Dict]:
    """Fetch the style catalogue from ``/autodraft-crecoai/style``.

    Returns:
        The decoded JSON body of the response (whatever its HTTP status), or
        ``None`` if the request timed out.

    Raises:
        requests.exceptions.RequestException: Any non-timeout request failure
            is printed and then re-raised with its original traceback.
    """
    url = api_endpoint() + "/autodraft-crecoai/style"
    # NOTE(security): a literal API key is committed to source here. It should
    # be moved into configuration and rotated; it is kept verbatim only so the
    # request is unchanged. Headers from api_headers() are merged after it and
    # win if they also define "x-api-key".
    headers = {"x-api-key": "kGyEMp)oHB(zf^E5>-{o]I%go", **api_headers()}
    try:
        with RetryRequest() as session:
            response = session.get(url, timeout=10, headers=headers)
        return response.json()
    except requests.exceptions.Timeout:
        print("Request timed out while fetching styles")
    except requests.exceptions.RequestException as e:
        print(f"Error while fetching styles: {e}")
        # Bare raise keeps the original traceback intact.
        raise
    return None


def getCharacters(model_id: str) -> Optional[List]:
    """Return the ``characters`` of a model, or ``None`` on any failure.

    Issues ``GET /autodraft-crecoai/model/<model_id>`` and reads
    ``body["data"]["characters"]`` from the JSON reply. Every error raised
    while fetching or unpacking (timeouts, other request errors, bad JSON,
    missing keys) is printed and turned into a ``None`` result.
    """
    endpoint = api_endpoint() + f"/autodraft-crecoai/model/{model_id}"
    try:
        with RetryRequest() as session:
            payload = session.get(endpoint, timeout=10, headers=api_headers()).json()
            characters = payload["data"]["characters"]
    except requests.exceptions.Timeout:
        print("Request timed out while fetching characters")
        return None
    except Exception as err:
        print(f"Error while fetching characters: {err}")
        return None
    return characters


def update_db_source_failed(sourceId, userId):
    """Set the source's state to ``"FAILED"`` through ``updateSource``."""
    updateSource(sourceId=sourceId, userId=userId, state="FAILED")


def _find_task(args, kwargs):
    """Return the first ``Task`` among positional, then keyword, arguments, or ``None``."""
    for candidate in (*args, *kwargs.values()):
        if isinstance(candidate, Task):
            return candidate
    return None


def update_db(func):
    """Decorate a task handler so the source's lifecycle state is reported.

    The wrapped call must receive a ``Task`` (positionally or by keyword).
    The source is marked ``"INPROGRESS"`` before the handler runs. On success
    it is marked ``"COMPLETED"`` and its generated images are saved with
    ``has_nsfw`` taken from the handler's returned dict (default ``False``),
    and the handler's result is returned. If anything inside that sequence
    raises, the error is printed with its traceback, a Slack alert is
    attempted, the source is marked ``"FAILED"``, and the wrapper returns
    ``None`` (the exception is not re-raised).

    Raises:
        Exception: If no ``Task`` argument is present (before any update).
    """

    @functools.wraps(func)
    def caller(*args, **kwargs):
        task = _find_task(args, kwargs)
        if task is None:
            raise Exception("A Task object must be passed as an argument")
        try:
            updateSource(task.get_sourceId(), task.get_userId(), "INPROGRESS")
            rargs = func(*args, **kwargs)
            # NOTE(review): a handler returning a non-dict (e.g. None) raises
            # AttributeError here and is reported as FAILED — confirm intended.
            has_nsfw = rargs.get("has_nsfw", False)
            updateSource(task.get_sourceId(), task.get_userId(), "COMPLETED")
            saveGeneratedImages(task.get_sourceId(), task.get_userId(), has_nsfw)
            return rargs
        except Exception as e:
            print("Error processing image: {}".format(str(e)))
            traceback.print_exc()
            try:
                Slack().error_alert(task, e)
            except Exception as alert_error:
                # Alerting is best-effort: a Slack failure must not prevent
                # the FAILED update below or escape from the decorator.
                print(f"Failed to send Slack alert: {alert_error}")
            updateSource(task.get_sourceId(), task.get_userId(), "FAILED")

    return caller