"""Process-wide runtime configuration.

Values live in module-level globals. They are populated by the ``set_*``
functions (e.g. ``set_configs_from_task``, ``set_model_config``) and read
through the ``get_*`` accessors elsewhere in the codebase.
"""

import base64
import os
from pathlib import Path
from typing import Union

from internals.data.task import Task
from internals.util.model_loader import ModelConfig

env = "prod"  # "gamma" or "prod"; chosen from the task's queue name
nsfw_threshold = 0.0
nsfw_access = False
access_token = ""  # sent as the "Access-Token" header by api_headers()
root_dir = ""
model_config = None  # set via set_model_config(); model getters assume it is set

# SECURITY(review): a Hugging Face token is embedded in source. Base64 is only
# obfuscation, not protection; this token should be rotated and supplied via
# the environment. HF_TOKEN, when set and non-empty, takes precedence; the
# embedded value remains only as a backward-compatible fallback.
hf_token = os.environ.get("HF_TOKEN") or base64.b64decode(
    b"aGZfVFZCTHNUam1tT3d6T0h1dlVZWkhEbEZ4WVdOSUdGamVCbA=="
).decode()

hf_cache_dir = "/tmp/hf_hub"

base_dimension = 512  # needed for high res
num_return_sequences = 4  # the number of results to generate

os.makedirs(hf_cache_dir, exist_ok=True)


def set_hf_cache_dir(dir: Union[str, Path]):
    """Set the Hugging Face cache directory, creating it if it does not exist.

    The directory is created for consistency with the default cache dir,
    which is created at import time.
    """
    global hf_cache_dir
    hf_cache_dir = str(dir)
    os.makedirs(hf_cache_dir, exist_ok=True)


def get_hf_cache_dir():
    """Return the Hugging Face cache directory as a string."""
    return hf_cache_dir


def set_root_dir(main_file: str):
    """Set the root directory to the absolute directory containing *main_file*."""
    global root_dir
    root_dir = os.path.dirname(os.path.abspath(main_file))


def set_model_config(config: ModelConfig):
    """Store the model configuration read by the model-related getters."""
    global model_config
    model_config = config


def set_configs_from_task(task: Task):
    """Populate per-task settings (environment, NSFW, auth, sizes) from *task*.

    The environment is "gamma" when the task's queue name starts with
    "gamma", otherwise "prod".
    """
    global env, nsfw_threshold, nsfw_access, access_token, base_dimension, num_return_sequences

    name = task.get_queue_name()
    env = "gamma" if name.startswith("gamma") else "prod"

    nsfw_threshold = task.get_nsfw_threshold()
    nsfw_access = task.can_access_nsfw()
    access_token = task.get_access_token()
    base_dimension = task.get_base_dimension()
    num_return_sequences = task.get_num_return_sequences()


def get_model_dir():
    """Return the base model path from the model config (must be set)."""
    return model_config.base_model_path  # pyright: ignore


def get_inpaint_model_path():
    """Return the base inpaint model path from the model config (must be set)."""
    return model_config.base_inpaint_model_path  # pyright: ignore


def get_base_dimension():
    """Return the task's base dimension, or the model config's when it is falsy."""
    if base_dimension:
        return base_dimension
    # Fall back to the model's own default (e.g. when the task supplied None/0).
    return model_config.base_dimension  # pyright: ignore


def get_is_sdxl():
    """Return the model config's ``is_sdxl`` flag (model config must be set)."""
    return model_config.is_sdxl  # pyright: ignore


def get_root_dir():
    """Return the root directory set by ``set_root_dir`` ("" if never set)."""
    return root_dir


def get_num_return_sequences():
    """Return the number of results to generate."""
    return num_return_sequences


def get_environment():
    """Return the current environment name ("prod" or "gamma")."""
    return env


def get_nsfw_threshold():
    """Return the NSFW threshold set from the task."""
    return nsfw_threshold


def get_nsfw_access():
    """Return whether the current task may access NSFW content."""
    return nsfw_access


def get_hf_token():
    """Return the Hugging Face token (HF_TOKEN env var, or the embedded fallback)."""
    return hf_token


def get_low_gpu_mem():
    """Return the model config's ``low_gpu_mem`` flag (model config must be set)."""
    return model_config.low_gpu_mem  # pyright: ignore


def get_base_model_variant():
    """Return the model config's ``base_model_variant`` (model config must be set)."""
    return model_config.base_model_variant  # pyright: ignore


def get_base_inpaint_model_variant():
    """Return the model config's ``base_inpaint_model_variant`` (model config must be set)."""
    return model_config.base_inpaint_model_variant  # pyright: ignore


def api_headers():
    """Return HTTP headers carrying the current task's access token."""
    return {
        "Access-Token": access_token,
    }


def api_endpoint():
    """Return the public API base URL for the current environment."""
    if env == "prod":
        return "https://api.autodraft.in"
    return "https://gamma-api.autodraft.in"


def comic_url():
    """Return the internal comic service base URL for the current environment."""
    if env == "prod":
        return "http://internal-k8s-prod-internal-bb9c57a6bb-1524739074.ap-south-1.elb.amazonaws.com:80"
    return "http://internal-k8s-gamma-internal-ea8e32da94-1997933257.ap-south-1.elb.amazonaws.com:80"