import base64
import os
from pathlib import Path
from typing import Union

from internals.data.task import Task
from internals.util.model_loader import ModelConfig

# Mutable module-level settings. The functions in this module rebind these
# names through ``global`` statements and expose them through getters.
env = "prod"  # set to "gamma" or "prod" by set_configs_from_task()
nsfw_threshold = 0.0
nsfw_access = False
access_token = ""  # value sent under the "Access-Token" key by api_headers()
root_dir = ""  # filled in by set_root_dir()
model_config = None  # filled in by set_model_config()

# SECURITY(review): this credential is committed to source. Base64 is an
# encoding, not encryption, so anyone who can read this file can recover the
# value (it decodes to an "hf_..." string -- presumably a Hugging Face access
# token). Consider rotating it and loading it from the environment or a
# secret store instead. Behavior is intentionally left unchanged here.
hf_token = base64.b64decode(
    b"aGZfVFZCTHNUam1tT3d6T0h1dlVZWkhEbEZ4WVdOSUdGamVCbA=="
).decode()
hf_cache_dir = "/tmp/hf_hub"

# Default used by get_base_dimension() while this value is truthy.
base_dimension = 512

num_return_sequences = 4

# Import-time side effect: make sure the default cache directory exists.
os.makedirs(hf_cache_dir, exist_ok=True)
def set_hf_cache_dir(dir: Union[str, Path]) -> None:
    """Override the module-level ``hf_cache_dir`` setting.

    Args:
        dir: New cache location. It is stored as ``str(dir)``, so both plain
            strings and ``Path`` objects are accepted. The parameter name
            shadows the ``dir`` builtin; it is kept so keyword callers keep
            working.

    Note:
        Only the stored path changes. Unlike the default location, which is
        created when this module is imported, the new directory is not
        created here.
    """
    global hf_cache_dir
    hf_cache_dir = str(dir)
def get_hf_cache_dir() -> str:
    """Return the current value of the module-level ``hf_cache_dir``."""
    # Reading a module global needs no ``global`` declaration.
    return hf_cache_dir
def set_root_dir(main_file: str) -> None:
    """Record the directory that contains ``main_file`` as ``root_dir``.

    Args:
        main_file: Path to a file. It is made absolute (relative paths are
            resolved against the current working directory) before its
            parent directory is taken. Callers presumably pass the entry
            script's ``__file__`` -- TODO confirm.
    """
    global root_dir
    root_dir = os.path.dirname(os.path.abspath(main_file))
def set_model_config(config: ModelConfig) -> None:
    """Install ``config`` as the module-level ``model_config``.

    The ``model_config``-backed getters in this module read their values
    from whatever object is stored here.

    Args:
        config: The configuration object to store; it is kept by reference,
            not copied.
    """
    global model_config
    model_config = config
def set_configs_from_task(task: Task) -> None:
    """Refresh the per-task module settings from ``task``.

    Rebinds the module-level ``env``, ``nsfw_threshold``, ``nsfw_access``,
    ``access_token``, ``base_dimension`` and ``num_return_sequences``.

    Args:
        task: Source of the new values; each setting is read through the
            matching accessor method on ``task``.
    """
    global env, nsfw_threshold, nsfw_access, access_token
    global base_dimension, num_return_sequences

    # A queue name starting with "gamma" selects the gamma environment;
    # every other queue name is treated as production.
    env = "gamma" if task.get_queue_name().startswith("gamma") else "prod"

    nsfw_threshold = task.get_nsfw_threshold()
    nsfw_access = task.can_access_nsfw()
    access_token = task.get_access_token()
    base_dimension = task.get_base_dimension()
    num_return_sequences = task.get_num_return_sequences()
def get_model_dir():
    """Return ``model_config.base_model_path``.

    Raises:
        AttributeError: If no config has been installed yet, i.e. the
            module-level ``model_config`` is still ``None``.
    """
    return model_config.base_model_path
def get_inpaint_model_path():
    """Return ``model_config.base_inpaint_model_path``.

    Raises:
        AttributeError: If no config has been installed yet, i.e. the
            module-level ``model_config`` is still ``None``.
    """
    return model_config.base_inpaint_model_path
def get_base_dimension():
    """Return the effective base dimension.

    The module-level ``base_dimension`` is returned whenever it is truthy.
    A falsy value (``None`` or ``0``) falls back to
    ``model_config.base_dimension``.

    Raises:
        AttributeError: If the fallback is needed while ``model_config`` is
            still ``None``.
    """
    # A truthiness test (rather than ``is not None``) is kept on purpose so
    # that 0 also triggers the fallback, exactly as before.
    if base_dimension:
        return base_dimension
    return model_config.base_dimension
def get_is_sdxl():
    """Return ``model_config.is_sdxl``.

    Raises:
        AttributeError: If no config has been installed yet, i.e. the
            module-level ``model_config`` is still ``None``.
    """
    return model_config.is_sdxl
def get_root_dir() -> str:
    """Return the module-level ``root_dir``.

    This is the empty string until ``set_root_dir`` has been called.
    """
    return root_dir
def get_num_return_sequences():
    """Return the current module-level ``num_return_sequences`` value."""
    return num_return_sequences
def get_environment() -> str:
    """Return the current environment name held in the module-level ``env``.

    Within this module the value is only ever set to ``"prod"`` or
    ``"gamma"``.
    """
    return env
def get_nsfw_threshold():
    """Return the current module-level ``nsfw_threshold`` value."""
    return nsfw_threshold
def get_nsfw_access():
    """Return the current module-level ``nsfw_access`` value."""
    return nsfw_access
def get_hf_token() -> str:
    """Return the module-level ``hf_token`` string."""
    return hf_token
def get_low_gpu_mem():
    """Return ``model_config.low_gpu_mem``.

    Raises:
        AttributeError: If no config has been installed yet, i.e. the
            module-level ``model_config`` is still ``None``.
    """
    return model_config.low_gpu_mem
def get_base_model_variant():
    """Return ``model_config.base_model_variant``.

    Raises:
        AttributeError: If no config has been installed yet, i.e. the
            module-level ``model_config`` is still ``None``.
    """
    return model_config.base_model_variant
def get_base_inpaint_model_variant():
    """Return ``model_config.base_inpaint_model_variant``.

    Raises:
        AttributeError: If no config has been installed yet, i.e. the
            module-level ``model_config`` is still ``None``.
    """
    return model_config.base_inpaint_model_variant
def api_headers() -> dict:
    """Return the header dict carrying the current access token.

    Returns:
        A new dict on every call, with a single ``"Access-Token"`` entry set
        to the module-level ``access_token``.
    """
    return {"Access-Token": access_token}
def api_endpoint() -> str:
    """Return the API base URL for the current environment.

    Only an ``env`` of exactly ``"prod"`` selects the production URL; any
    other value resolves to the gamma URL.
    """
    if env == "prod":
        return "https://api.autodraft.in"
    return "https://gamma-api.autodraft.in"
def comic_url() -> str:
    """Return the comic service base URL for the current environment.

    Only an ``env`` of exactly ``"prod"`` selects the production host; any
    other value resolves to the gamma host. Both URLs use plain ``http`` on
    port 80.
    """
    if env == "prod":
        return "http://internal-k8s-prod-internal-bb9c57a6bb-1524739074.ap-south-1.elb.amazonaws.com:80"
    return "http://internal-k8s-gamma-internal-ea8e32da94-1997933257.ap-south-1.elb.amazonaws.com:80"