import ast
import base64
import json

import gradio as gr
import torch

# Keys of the generated-code dictionary. get_sorted_code() sorts the entries
# by key, so the zero-padded numeric prefix decides the order in which the
# snippets appear in the generated script.
code_pos_device = '001_code'
code_pos_data_type = '002_data_type'
code_pos_tf32 = '003_tf32'
code_pos_variant = '004_variant'
code_pos_init_pipeline = '050_init_pipe'
code_pos_requires_safety_checker = '054_requires_safety_checker'
code_pos_safety_checker = '055_safety_checker'
code_pos_scheduler = '060_scheduler'
code_pos_generator = '070_generator'
code_pos_prompt = '080_prompt'
code_pos_negative_prompt = '085_negative_prompt'
code_pos_inference_steps = '090_inference_steps'
code_pos_guidance_scale = '095_guidance_scale'
code_pos_run_inference = '100_run_inference'

# Order in which init_config() returns the configuration values.
_CONFIG_OUTPUT_ORDER = (
    'model',
    'device',
    'use_safetensors',
    'data_type',
    'variant',
    'safety_checker',
    'requires_safety_checker',
    'scheduler',
    'prompt',
    'negative_prompt',
    'inference_steps',
    'manual_seed',
    'guidance_scale',
)


def load_app_config():
    """Load ``appConfig.json`` from the current working directory.

    Returns:
        The parsed JSON content, or an empty dict when the file is missing,
        is not valid JSON, or cannot be read for any other reason. A message
        describing the problem is printed in those cases.
    """
    try:
        with open('appConfig.json', 'r', encoding='utf-8') as f:
            return json.load(f)
    except FileNotFoundError:
        print("App config file not found.")
    except json.JSONDecodeError:
        print("Error decoding JSON in app config file.")
    except Exception as e:
        print("An error occurred while loading app config:", str(e))

    # Previously the function fell through to ``return appConfig`` with the
    # name unbound, raising UnboundLocalError. An empty config lets callers
    # continue with their ``.get(..., default)`` lookups.
    return {}


def get_inital_config():
    """Build the default configuration and the matching generated code.

    The device is picked from what torch reports as available (CUDA first,
    then MPS, then CPU). On CUDA the data type defaults to ``bfloat16`` and
    TensorFloat-32 is allowed; otherwise ``float16`` and no TF32.

    Returns:
        A tuple ``(initial_config, devices, models, schedulers, code)``:
        the default config dict, the ``devices`` entry of the app config,
        the names found under ``models`` and ``schedulers`` in the app
        config, and a dict mapping ``code_pos_*`` keys to code snippets.
    """
    app_config = load_app_config()

    # Only the names are needed here; no model or scheduler is preselected.
    models = list(app_config.get("models", {}).keys())
    schedulers = list(app_config.get("schedulers", {}).keys())
    devices = app_config.get("devices", [])

    data_type = 'float16'
    allow_tensorfloat32 = False
    if torch.cuda.is_available():
        device = "cuda"
        data_type = 'bfloat16'
        allow_tensorfloat32 = True
    elif torch.backends.mps.is_available():
        device = "mps"
    else:
        device = "cpu"

    initial_config = {
        "device": device,
        "model": None,
        "scheduler": None,
        "variant": None,
        # NOTE: the key is spelled "tensorflow32" (not "tensorfloat32"); it is
        # kept as-is because other code may read the config by this name.
        "allow_tensorflow32": allow_tensorfloat32,
        "use_safetensors": False,
        "data_type": data_type,
        "safety_checker": False,
        "requires_safety_checker": False,
        "manual_seed": 42,
        "inference_steps": 10,
        "guidance_scale": 0.5,
        "prompt": 'A white rabbit',
        "negative_prompt": 'lowres, cropped, worst quality, low quality, chat bubble, chat bubbles, ugly',
    }

    # Pull the values out first: nesting the same quote character inside an
    # f-string replacement field is a SyntaxError before Python 3.12.
    variant = initial_config['variant']
    requires_safety_checker = initial_config['requires_safety_checker']
    prompt = initial_config['prompt']
    negative_prompt = initial_config['negative_prompt']
    inference_steps = initial_config['inference_steps']
    guidance_scale = initial_config['guidance_scale']

    # code output order is defined by the code_pos_* keys, not by insertion
    code = {}
    code[code_pos_device] = f'device = "{device}"'
    code[code_pos_variant] = f'variant = {variant}'
    # Uses the local flag: the config dict has no 'allow_tensorfloat32' key,
    # so the former lookup raised KeyError.
    code[code_pos_tf32] = f'torch.backends.cuda.matmul.allow_tf32 = {allow_tensorfloat32}'
    # Follows the detected data type instead of always emitting bfloat16.
    code[code_pos_data_type] = f'data_type = torch.{data_type}'
    code[code_pos_init_pipeline] = 'sys.exit("No model selected!")'
    code[code_pos_safety_checker] = 'pipeline.safety_checker = None'
    code[code_pos_requires_safety_checker] = f'pipeline.requires_safety_checker = {requires_safety_checker}'
    code[code_pos_scheduler] = 'sys.exit("No scheduler selected!")'
    code[code_pos_generator] = f'generator = torch.Generator("{device}")'
    code[code_pos_prompt] = f'prompt = "{prompt}"'
    code[code_pos_negative_prompt] = f'negative_prompt = "{negative_prompt}"'
    code[code_pos_inference_steps] = f'inference_steps = {inference_steps}'
    code[code_pos_guidance_scale] = f'guidance_scale = {guidance_scale}'
    code[code_pos_run_inference] = '''image = pipeline(
    prompt=prompt,
    negative_prompt=negative_prompt,
    generator=generator.manual_seed(manual_seed),
    num_inference_steps=inference_steps,
    guidance_scale=guidance_scale).images[0]'''

    return initial_config, devices, models, schedulers, code


def _parse_config(raw):
    """Turn a serialized configuration into a dict.

    Accepts an already-parsed dict, a JSON document, or the ``str()`` form of
    a Python dict. As a last resort the historical quote/None/False
    substitution is applied so inputs that used to parse still do.

    Raises:
        json.JSONDecodeError: if none of the strategies can parse ``raw``.
    """
    if isinstance(raw, dict):
        return dict(raw)

    try:
        return json.loads(raw)
    except json.JSONDecodeError:
        pass

    # literal_eval only evaluates Python literals, so it understands
    # True/False/None and strings containing apostrophes without executing
    # any code from the (possibly user-supplied) text.
    try:
        parsed = ast.literal_eval(raw)
    except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
        parsed = None
    if isinstance(parsed, dict):
        return parsed

    legacy = (
        raw.replace("'", '"')
        .replace('None', 'null')
        .replace('False', 'false')
        .replace('True', 'true')
    )
    return json.loads(legacy)


def init_config(request: gr.Request, inital_config):
    """Resolve the active configuration for a page request.

    If the request URL carries a ``config`` query parameter, it is
    base64-decoded and parsed; otherwise ``inital_config`` is parsed.

    Args:
        request: the Gradio request; ``request.request.query_params`` is
            consulted for the ``config`` parameter.
        inital_config: the default configuration, serialized as a string
            (JSON or the ``str()`` of a dict); a dict is accepted as well.

    Returns:
        A list with the values for model, device, use_safetensors, data_type,
        variant, safety_checker, requires_safety_checker, scheduler, prompt,
        negative_prompt, inference_steps, manual_seed and guidance_scale,
        in that order.

    Raises:
        KeyError: if the resolved configuration lacks one of those keys.
        json.JSONDecodeError: if the configuration text cannot be parsed.
    """
    encoded_params = request.request.query_params.get('config')

    if encoded_params is not None:
        # get configuration from URL if GET parameter `config` is set
        decoded_params = base64.b64decode(encoded_params).decode('utf-8')
        return_config = _parse_config(decoded_params)
    else:
        # otherwise use default initial config
        return_config = _parse_config(inital_config)

    return [return_config[key] for key in _CONFIG_OUTPUT_ORDER]


def get_sorted_code():
    """Return all code snippets joined with CRLF, ordered by their key."""
    # NOTE(review): `code` is read from module scope; it is not defined in
    # the part of the file visible here - confirm where it is assigned.
    return '\r\n'.join(snippet for _, snippet in sorted(code.items()))