File size: 3,207 Bytes
e6a4021 19b3da3 1bc457e 19b3da3 fd5252e 19b3da3 1bc457e 19b3da3 fd5252e e6a4021 1bc457e 9387217 a3f5c82 19b3da3 86248f3 1bc457e 19b3da3 fd5252e 19b3da3 22df957 19b3da3 1bc457e 19b3da3 1bc457e 19b3da3 a3f5c82 22df957 19b3da3 b71808f fd5252e b71808f a3f5c82 9387217 a3f5c82 10230ea 19b3da3 22df957 19b3da3 b71808f 22df957 bcaef47 22df957 19b3da3 f1235a4 19b3da3 5e62aa8 19b3da3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 |
import base64
import os
from pathlib import Path
from typing import Union
from internals.data.task import Task
from internals.util.model_loader import ModelConfig
# Deployment environment; set_configs_from_task() switches it between "gamma" and "prod".
env = "prod"
# Per-task settings, overwritten by set_configs_from_task().
nsfw_threshold = 0.0
nsfw_access = False
access_token = ""
# Directory of the entry-point file, set via set_root_dir().
root_dir = ""
# Active model configuration, installed via set_model_config(); the model
# getters dereference it, so they fail while it is still None.
model_config = None

# SECURITY(review): a credential is embedded in source (base64 is an encoding,
# not encryption). Supply it through the HF_TOKEN environment variable instead;
# the embedded value is only a backward-compatible fallback and should be
# rotated and removed.
hf_token = os.environ.get("HF_TOKEN") or base64.b64decode(
    b"aGZfVFZCTHNUam1tT3d6T0h1dlVZWkhEbEZ4WVdOSUdGamVCbA=="
).decode()
hf_cache_dir = "/tmp/hf_hub"
base_dimension = 512  # needed for high res
num_return_sequences = 4  # the number of results to generate

# Import-time side effect: make sure the default cache directory exists.
os.makedirs(hf_cache_dir, exist_ok=True)
def set_hf_cache_dir(dir: Union[str, Path]):
    """Point the module's Hugging Face cache setting at *dir*.

    The value is stored as a ``str`` (``Path`` objects are converted with
    ``str()``). The directory itself is not created here.
    """
    global hf_cache_dir
    as_text = str(dir)
    hf_cache_dir = as_text
def get_hf_cache_dir():
    """Return the currently configured Hugging Face cache directory string."""
    return hf_cache_dir
def set_root_dir(main_file: str):
    """Record the directory that contains *main_file* as the root directory.

    A relative *main_file* is resolved against the current working directory
    via ``os.path.abspath``; symlinks are not resolved.
    """
    global root_dir
    absolute_path = os.path.abspath(main_file)
    root_dir, _ = os.path.split(absolute_path)
def set_model_config(config: ModelConfig):
    """Install *config* as the active model configuration read by the model getters."""
    global model_config
    active = config
    model_config = active
def set_configs_from_task(task: Task):
    """Refresh the per-task module settings from *task*.

    A queue name starting with ``"gamma"`` selects the gamma environment; any
    other queue name maps to ``"prod"``. The NSFW threshold and access flag,
    access token, base dimension and number of return sequences are copied
    from the task's corresponding getters, in that order.
    """
    global env, nsfw_threshold, nsfw_access, access_token
    global base_dimension, num_return_sequences
    queue_name = task.get_queue_name()
    env = "gamma" if queue_name.startswith("gamma") else "prod"
    nsfw_threshold = task.get_nsfw_threshold()
    nsfw_access = task.can_access_nsfw()
    access_token = task.get_access_token()
    # A falsy value here makes get_base_dimension() fall back to the model config.
    base_dimension = task.get_base_dimension()
    num_return_sequences = task.get_num_return_sequences()
def get_model_dir():
    """Return ``base_model_path`` from the active model configuration.

    Raises ``AttributeError`` while ``model_config`` is still ``None``.
    """
    config = model_config
    return config.base_model_path  # pyright: ignore
def get_inpaint_model_path():
    """Return ``base_inpaint_model_path`` from the active model configuration.

    Raises ``AttributeError`` while ``model_config`` is still ``None``.
    """
    config = model_config
    return config.base_inpaint_model_path  # pyright: ignore
def get_base_dimension():
    """Return the base image dimension to use.

    The module-level ``base_dimension`` (overwritten per task by
    ``set_configs_from_task``) wins whenever it is truthy; a falsy value
    (``None``/``0``) falls back to ``model_config.base_dimension``.

    Raises ``AttributeError`` if the fallback is needed while
    ``model_config`` is still ``None``.
    """
    # Read-only access to module state: no ``global`` declaration is needed.
    if base_dimension:
        return base_dimension
    return model_config.base_dimension  # pyright: ignore
def get_is_sdxl():
    """Return the ``is_sdxl`` flag of the active model configuration.

    Raises ``AttributeError`` while ``model_config`` is still ``None``.
    """
    config = model_config
    return config.is_sdxl  # pyright: ignore
def get_root_dir():
    """Return the root directory recorded by ``set_root_dir`` (``""`` if never set)."""
    return root_dir
def get_num_return_sequences():
    """Return how many results the current task asked to generate."""
    return num_return_sequences
def get_environment():
    """Return the current deployment environment name (e.g. ``"prod"`` or ``"gamma"``)."""
    return env
def get_nsfw_threshold():
    """Return the NSFW threshold taken from the current task."""
    return nsfw_threshold
def get_nsfw_access():
    """Return whether the current task is allowed NSFW access."""
    return nsfw_access
def get_hf_token():
    """Return the module-level ``hf_token`` value."""
    return hf_token
def get_low_gpu_mem():
    """Return the ``low_gpu_mem`` setting of the active model configuration.

    Raises ``AttributeError`` while ``model_config`` is still ``None``.
    """
    config = model_config
    return config.low_gpu_mem  # pyright: ignore
def get_base_model_variant():
    """Return ``base_model_variant`` from the active model configuration.

    Raises ``AttributeError`` while ``model_config`` is still ``None``.
    """
    config = model_config
    return config.base_model_variant  # pyright: ignore
def get_base_inpaint_model_variant():
    """Return ``base_inpaint_model_variant`` from the active model configuration.

    Raises ``AttributeError`` while ``model_config`` is still ``None``.
    """
    config = model_config
    return config.base_inpaint_model_variant  # pyright: ignore
def api_headers():
    """Build a fresh headers dict carrying the module-level ``access_token``."""
    headers = dict()
    headers["Access-Token"] = access_token
    return headers
def api_endpoint():
    """Return the API base URL for the current environment.

    Only ``env == "prod"`` selects the production host; every other value
    (including ``"gamma"``) selects the gamma host.
    """
    if env != "prod":
        return "https://gamma-api.autodraft.in"
    return "https://api.autodraft.in"
def comic_url():
    """Return the internal comic-service base URL for the current environment.

    Only ``env == "prod"`` selects the production load balancer; every other
    value selects the gamma one.
    """
    if env != "prod":
        return "http://internal-k8s-gamma-internal-ea8e32da94-1997933257.ap-south-1.elb.amazonaws.com:80"
    return "http://internal-k8s-prod-internal-bb9c57a6bb-1524739074.ap-south-1.elb.amazonaws.com:80"
|