# jayparmr's picture
# Upload folder using huggingface_hub
# bcaef47 verified
# raw
# history blame
# 3.21 kB
import base64
import os
from pathlib import Path
from typing import Union
from internals.data.task import Task
from internals.util.model_loader import ModelConfig
# Module-level runtime configuration. These values are overwritten by the
# set_* functions in this module and read back through the get_* functions.
env = "prod"  # "prod" or "gamma"; set_configs_from_task derives it from the queue name
nsfw_threshold = 0.0
nsfw_access = False
access_token = ""  # sent as the "Access-Token" header by api_headers()
root_dir = ""
model_config = None  # installed by set_model_config(); model getters read attributes from it

# Prefer a token supplied through the environment so the credential can be
# rotated without a code change; the embedded value is only a
# backward-compatible fallback.
# NOTE(review): base64 is an encoding, not encryption - the embedded token is
# effectively committed in plain text. It should be revoked and provided via
# HF_TOKEN (or a secret store) instead.
hf_token = os.environ.get("HF_TOKEN") or base64.b64decode(
    b"aGZfVFZCTHNUam1tT3d6T0h1dlVZWkhEbEZ4WVdOSUdGamVCbA=="
).decode()
hf_cache_dir = "/tmp/hf_hub"
base_dimension = 512  # needed for high res
num_return_sequences = 4  # the number of results to generate

# Import-time side effect: make sure the default cache directory exists.
os.makedirs(hf_cache_dir, exist_ok=True)
def set_hf_cache_dir(dir: Union[str, Path]):
    """Point the Hugging Face cache at *dir* and ensure the directory exists.

    The default cache directory is created when the module is imported; the
    replacement is created here too, so both paths behave the same way.

    Args:
        dir: Target directory as a ``str`` or ``Path``; stored as ``str``.

    Raises:
        OSError: If the directory cannot be created.
    """
    global hf_cache_dir
    hf_cache_dir = str(dir)
    os.makedirs(hf_cache_dir, exist_ok=True)
def get_hf_cache_dir():
    """Return the currently configured Hugging Face cache directory."""
    # Read-only access to the module global; no ``global`` statement needed.
    return hf_cache_dir
def set_root_dir(main_file: str):
    """Record the directory that contains *main_file* as the root directory.

    ``main_file`` is made absolute first, so a relative path is resolved
    against the current working directory.
    """
    global root_dir
    absolute_file = os.path.abspath(main_file)
    root_dir = os.path.dirname(absolute_file)
def set_model_config(config: ModelConfig):
    """Install *config* as the active model configuration.

    The model getters in this module (paths, variants, SDXL and low-GPU flags,
    base-dimension fallback) read their values from this object.
    """
    global model_config
    model_config = config
def set_configs_from_task(task: Task):
    """Copy the per-task settings from *task* into this module's globals.

    The environment becomes "gamma" when the task's queue name starts with
    "gamma" and "prod" otherwise. The NSFW threshold and access flag, the
    access token, the base dimension and the number of results are taken
    directly from the task's getters.
    """
    global env, nsfw_threshold, nsfw_access, access_token
    global base_dimension, num_return_sequences

    queue_name = task.get_queue_name()
    env = "gamma" if queue_name.startswith("gamma") else "prod"

    nsfw_threshold = task.get_nsfw_threshold()
    nsfw_access = task.can_access_nsfw()
    access_token = task.get_access_token()
    base_dimension = task.get_base_dimension()
    num_return_sequences = task.get_num_return_sequences()
def get_model_dir():
    """Return ``base_model_path`` from the active model configuration.

    Expects set_model_config() to have been called; while the config is still
    ``None`` this raises AttributeError.
    """
    return model_config.base_model_path  # pyright: ignore
def get_inpaint_model_path():
    """Return ``base_inpaint_model_path`` from the active model configuration.

    Expects set_model_config() to have been called; while the config is still
    ``None`` this raises AttributeError.
    """
    return model_config.base_inpaint_model_path  # pyright: ignore
def get_base_dimension():
    """Return the base image dimension to use.

    A truthy per-task ``base_dimension`` (module default 512, overwritten by
    set_configs_from_task) takes precedence. Otherwise the value comes from the
    active model configuration.

    Raises:
        AttributeError: If the fallback is needed and no model config is set.
    """
    # Truthiness check: a None or 0 task value falls back to the model config.
    if base_dimension:
        return base_dimension
    return model_config.base_dimension  # pyright: ignore
def get_is_sdxl():
    """Return the ``is_sdxl`` value of the active model configuration.

    Expects set_model_config() to have been called.
    """
    return model_config.is_sdxl  # pyright: ignore
def get_root_dir():
    """Return the root directory recorded by set_root_dir() ("" if never set)."""
    return root_dir
def get_num_return_sequences():
    """Return how many results should be generated (module default 4)."""
    return num_return_sequences
def get_environment():
    """Return the current environment name ("prod" or "gamma")."""
    return env
def get_nsfw_threshold():
    """Return the NSFW threshold taken from the current task (default 0.0)."""
    return nsfw_threshold
def get_nsfw_access():
    """Return whether the current task may access NSFW content (default False)."""
    return nsfw_access
def get_hf_token():
    """Return the configured Hugging Face token (``hf_token``)."""
    return hf_token
def get_low_gpu_mem():
    """Return the ``low_gpu_mem`` value of the active model configuration.

    Expects set_model_config() to have been called.
    """
    return model_config.low_gpu_mem  # pyright: ignore
def get_base_model_variant():
    """Return ``base_model_variant`` from the active model configuration.

    Expects set_model_config() to have been called.
    """
    return model_config.base_model_variant  # pyright: ignore
def get_base_inpaint_model_variant():
    """Return ``base_inpaint_model_variant`` from the active model configuration.

    Expects set_model_config() to have been called.
    """
    return model_config.base_inpaint_model_variant  # pyright: ignore
def api_headers():
    """Return a fresh dict of request headers for API calls.

    The only header is ``Access-Token``, carrying the current access token.
    """
    headers = {}
    headers["Access-Token"] = access_token
    return headers
def api_endpoint():
    """Return the API base URL for the current environment.

    "prod" maps to the production host; any other environment maps to gamma.
    """
    is_prod = env == "prod"
    return (
        "https://api.autodraft.in"
        if is_prod
        else "https://gamma-api.autodraft.in"
    )
def comic_url():
    """Return the internal load-balancer URL of the comic service.

    "prod" maps to the prod ELB; any other environment maps to the gamma ELB.
    """
    if env != "prod":
        return "http://internal-k8s-gamma-internal-ea8e32da94-1997933257.ap-south-1.elb.amazonaws.com:80"
    return "http://internal-k8s-prod-internal-bb9c57a6bb-1524739074.ap-south-1.elb.amazonaws.com:80"