File size: 3,207 Bytes
e6a4021
19b3da3
1bc457e
 
19b3da3
 
fd5252e
19b3da3
1bc457e
19b3da3
 
 
 
fd5252e
e6a4021
 
 
1bc457e
9387217
a3f5c82
19b3da3
86248f3
 
1bc457e
 
 
 
 
 
 
 
 
 
 
 
19b3da3
 
 
 
 
 
fd5252e
 
 
 
 
19b3da3
22df957
19b3da3
1bc457e
19b3da3
1bc457e
 
19b3da3
 
 
a3f5c82
22df957
19b3da3
 
b71808f
fd5252e
 
 
 
 
 
 
b71808f
 
a3f5c82
9387217
 
 
 
a3f5c82
 
10230ea
 
 
 
 
19b3da3
 
 
 
 
22df957
 
 
 
 
19b3da3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b71808f
 
 
 
 
22df957
 
 
 
 
 
 
bcaef47
22df957
 
 
 
 
 
 
19b3da3
 
 
 
 
 
 
 
f1235a4
19b3da3
5e62aa8
19b3da3
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
import base64
import os
from pathlib import Path
from typing import Union

from internals.data.task import Task
from internals.util.model_loader import ModelConfig

# Module-level runtime settings. Each name starts at the default below and is
# replaced at runtime by the set_* functions defined further down this module.

# Deployment environment; set_configs_from_task() switches it to "gamma" when
# the task's queue name starts with "gamma", otherwise back to "prod".
env = "prod"
# NSFW settings and API access token; all three are overwritten per task by
# set_configs_from_task().
nsfw_threshold = 0.0
nsfw_access = False
access_token = ""
# Directory containing the main file; filled in by set_root_dir().
root_dir = ""
# Active model configuration; stays None until set_model_config() is called,
# so the getters that read attributes from it fail before that point.
model_config = None
# NOTE(review): this literal decodes to what looks like a Hugging Face access
# token. Base64 is an encoding, not encryption, so the credential is
# effectively stored in plain text in the source. Consider rotating it and
# loading it from the environment or a secret store instead.
hf_token = base64.b64decode(
    b"aGZfVFZCTHNUam1tT3d6T0h1dlVZWkhEbEZ4WVdOSUdGamVCbA=="
).decode()
# Default cache directory; can be replaced through set_hf_cache_dir().
hf_cache_dir = "/tmp/hf_hub"

base_dimension = 512  # needed for high res

num_return_sequences = 4  # the number of results to generate

# Import-time side effect: make sure the default cache directory exists.
# Note that set_hf_cache_dir() does not create the directory it is given.
os.makedirs(hf_cache_dir, exist_ok=True)


def set_hf_cache_dir(dir: Union[str, Path]):
    """Replace the module-level ``hf_cache_dir`` with ``dir``.

    Args:
        dir: New cache location. It is converted with ``str()`` before being
            stored, so a ``Path`` ends up saved as a plain string.

    The directory is only recorded here; it is not created.
    """
    global hf_cache_dir
    cache_path = str(dir)
    hf_cache_dir = cache_path


def get_hf_cache_dir():
    """Return the current value of the module-level ``hf_cache_dir``."""
    # Reading a module global needs no ``global`` declaration.
    return hf_cache_dir


def set_root_dir(main_file: str):
    """Record the directory that contains ``main_file`` as ``root_dir``.

    Args:
        main_file: Path to a file; it is made absolute first, then its
            parent directory is stored in the module-level ``root_dir``.
    """
    global root_dir
    absolute_main = os.path.abspath(main_file)
    root_dir = os.path.dirname(absolute_main)


def set_model_config(config: ModelConfig):
    """Store ``config`` as the module-level ``model_config``.

    Args:
        config: The model configuration that the ``model_config``-backed
            getters in this module read their attributes from.
    """
    global model_config
    active_config = config
    model_config = active_config


def set_configs_from_task(task: Task):
    """Refresh the per-task module settings from ``task``.

    Args:
        task: Source of the new values. Its queue name picks the environment
            ("gamma" when the name starts with "gamma", otherwise "prod");
            the remaining settings are copied from its getters.
    """
    global env, nsfw_threshold, nsfw_access, access_token
    global base_dimension, num_return_sequences

    queue_name = task.get_queue_name()
    env = "gamma" if queue_name.startswith("gamma") else "prod"

    nsfw_threshold = task.get_nsfw_threshold()
    nsfw_access = task.can_access_nsfw()
    access_token = task.get_access_token()
    base_dimension = task.get_base_dimension()
    num_return_sequences = task.get_num_return_sequences()


def get_model_dir():
    """Return ``base_model_path`` from the active model config.

    ``model_config`` is ``None`` until ``set_model_config`` is called, so
    calling this earlier fails with ``AttributeError``.
    """
    config = model_config
    return config.base_model_path  # pyright: ignore


def get_inpaint_model_path():
    """Return ``base_inpaint_model_path`` from the active model config.

    ``model_config`` is ``None`` until ``set_model_config`` is called, so
    calling this earlier fails with ``AttributeError``.
    """
    config = model_config
    return config.base_inpaint_model_path  # pyright: ignore


def get_base_dimension():
    """Return the base dimension currently in effect.

    The module-level ``base_dimension`` wins whenever it is truthy. When it
    is falsy (``0`` or ``None``), the value falls back to
    ``model_config.base_dimension``; that fallback fails with
    ``AttributeError`` if ``model_config`` has not been set yet.
    """
    # The previous ``global`` statement also declared ``global_base_dimension``,
    # a name that exists nowhere in this module. Both names are only read
    # here, so no ``global`` declaration is needed at all.
    return base_dimension or model_config.base_dimension  # pyright: ignore


def get_is_sdxl():
    """Return the ``is_sdxl`` attribute of the active model config.

    ``model_config`` is ``None`` until ``set_model_config`` is called, so
    calling this earlier fails with ``AttributeError``.
    """
    config = model_config
    return config.is_sdxl  # pyright: ignore


def get_root_dir():
    """Return the module-level ``root_dir`` (``""`` until ``set_root_dir`` runs)."""
    # Reading a module global needs no ``global`` declaration.
    return root_dir


def get_num_return_sequences():
    """Return the module-level ``num_return_sequences`` (the number of results to generate)."""
    # Reading a module global needs no ``global`` declaration.
    return num_return_sequences


def get_environment():
    """Return the module-level ``env`` string ("prod" by default)."""
    # Reading a module global needs no ``global`` declaration.
    return env


def get_nsfw_threshold():
    """Return the module-level ``nsfw_threshold`` (``0.0`` by default)."""
    # Reading a module global needs no ``global`` declaration.
    return nsfw_threshold


def get_nsfw_access():
    """Return the module-level ``nsfw_access`` flag (``False`` by default)."""
    # Reading a module global needs no ``global`` declaration.
    return nsfw_access


def get_hf_token():
    """Return the module-level ``hf_token`` string."""
    # Reading a module global needs no ``global`` declaration.
    return hf_token


def get_low_gpu_mem():
    """Return the ``low_gpu_mem`` attribute of the active model config.

    ``model_config`` is ``None`` until ``set_model_config`` is called, so
    calling this earlier fails with ``AttributeError``.
    """
    config = model_config
    return config.low_gpu_mem  # pyright: ignore


def get_base_model_variant():
    """Return the ``base_model_variant`` attribute of the active model config.

    ``model_config`` is ``None`` until ``set_model_config`` is called, so
    calling this earlier fails with ``AttributeError``.
    """
    config = model_config
    return config.base_model_variant  # pyright: ignore


def get_base_inpaint_model_variant():
    """Return ``base_inpaint_model_variant`` from the active model config.

    ``model_config`` is ``None`` until ``set_model_config`` is called, so
    calling this earlier fails with ``AttributeError``.
    """
    config = model_config
    return config.base_inpaint_model_variant  # pyright: ignore


def api_headers():
    """Build the request headers carrying the current ``access_token``.

    Returns:
        A new dict with a single ``"Access-Token"`` entry.
    """
    headers = {}
    headers["Access-Token"] = access_token
    return headers


def api_endpoint():
    """Return the API base URL for the current ``env``.

    Only the exact value ``"prod"`` selects the production URL; every other
    value gets the gamma URL.
    """
    if env != "prod":
        return "https://gamma-api.autodraft.in"
    return "https://api.autodraft.in"


def comic_url():
    """Return the comic service URL for the current ``env``.

    Only the exact value ``"prod"`` selects the production URL; every other
    value gets the gamma URL.
    """
    if env != "prod":
        return "http://internal-k8s-gamma-internal-ea8e32da94-1997933257.ap-south-1.elb.amazonaws.com:80"
    return "http://internal-k8s-prod-internal-bb9c57a6bb-1524739074.ap-south-1.elb.amazonaws.com:80"