import ast
import base64
import json

import gradio as gr
import torch


class Config:
    """Holds the app's current settings and assembles a matching diffusers snippet.

    ``set_inital_config`` must be called before ``set_config`` or
    ``assemble_code``, because it is what creates ``self.current``.
    """

    def __init__(self):
        # Generated code fragments keyed by "NNN_name"; assemble_code() joins
        # them in sorted-key order, so the numeric prefix fixes line order.
        self.code = {}
        self.history = []
        self.devices = []
        # Filled from appConfig.json by set_inital_config(); defaulted here so
        # get_scheduler_description() works before that call.
        self.model_configs = {}
        self.scheduler_configs = {}

    def load_app_config(self):
        """Read and parse ``appConfig.json`` from the working directory.

        Returns:
            The parsed JSON value, or an empty dict if the file is missing or
            cannot be read/parsed (the problem is printed to stdout).
        """
        try:
            with open('appConfig.json', 'r', encoding='utf-8') as f:
                return json.load(f)
        except FileNotFoundError:
            print("App config file not found.")
        except json.JSONDecodeError:
            print("Error decoding JSON in app config file.")
        except Exception as e:
            print("An error occurred while loading app config:", str(e))
        # Fall back to an empty config so callers can still use .get() on it.
        return {}

    def set_inital_config(self):
        """Load the app config and initialise ``self.current`` with defaults.

        Device selection: CUDA (bfloat16, TF32 allowed) if available, else
        Apple MPS, else CPU (both float16, TF32 off). Also stores the
        ``models``, ``schedulers`` and ``devices`` sections of the app config
        and builds the initial code snippet.
        """
        appConfig = self.load_app_config()
        self.model_configs = appConfig.get("models", {})
        self.scheduler_configs = appConfig.get("schedulers", {})
        self.devices = appConfig.get("devices", [])

        # default device
        data_type = 'float16'
        allow_tensorfloat32 = False
        if torch.cuda.is_available():
            device = "cuda"
            data_type = 'bfloat16'
            allow_tensorfloat32 = True
        elif torch.backends.mps.is_available():
            device = "mps"
        else:
            device = "cpu"

        self.current = {
            "device": device,
            "model": None,
            "scheduler": None,
            "variant": None,
            "allow_tensorfloat32": allow_tensorfloat32,
            "use_safetensors": False,
            "data_type": data_type,
            "safety_checker": False,
            "requires_safety_checker": False,
            "manual_seed": 42,
            "inference_steps": 10,
            "guidance_scale": 0.5,
            "prompt": 'A white rabbit',
            "negative_prompt": 'lowres, cropped, worst quality, low quality, chat bubble, chat bubbles, ugly',
        }

        self.assemble_code()

    @staticmethod
    def _parse_config_string(config):
        """Parse a config given as JSON text or as a Python dict ``repr``.

        ``str(self.current)`` (what ``set_config`` returns) uses Python literal
        syntax (single quotes, None/True/False) that ``json.loads`` rejects,
        so that form falls back to ``ast.literal_eval``. ``literal_eval`` only
        evaluates literals and never executes code. An already-parsed dict is
        returned unchanged.
        """
        if isinstance(config, dict):
            return config
        try:
            return json.loads(config)
        except json.JSONDecodeError:
            return ast.literal_eval(config)

    def init_config(self, request: gr.Request, inital_config):
        """Return the field values from a URL-shared config or the defaults.

        If the request has a ``config`` query parameter, it is base64-decoded
        (UTF-8) and parsed; otherwise ``inital_config`` is parsed. Both may be
        JSON or a Python dict repr.

        Returns:
            A list in this fixed order: model, device, use_safetensors,
            data_type, variant, safety_checker, requires_safety_checker,
            scheduler, prompt, negative_prompt, inference_steps, manual_seed,
            guidance_scale.

        Raises:
            KeyError: if the parsed config lacks one of those keys.
            ValueError / SyntaxError: if the text is neither JSON nor a
                Python literal.
        """
        encoded_params = request.request.query_params.get('config')

        # get configuration from URL if GET parameter `config` is set
        if encoded_params is not None:
            # URL-supplied (untrusted) text: parsed only as a literal, never eval'd.
            config_text = base64.b64decode(encoded_params).decode('utf-8')
        # otherwise use default initial config
        else:
            config_text = inital_config

        return_config = self._parse_config_string(config_text)

        keys = (
            'model', 'device', 'use_safetensors', 'data_type', 'variant',
            'safety_checker', 'requires_safety_checker', 'scheduler',
            'prompt', 'negative_prompt', 'inference_steps', 'manual_seed',
            'guidance_scale',
        )
        return [return_config[key] for key in keys]

    def set_config(self, key, value):
        """Store ``value`` under ``key`` in ``self.current``; return its ``str()``."""
        self.current[key] = value
        return str(self.current)

    def get_scheduler_description(self, scheduler):
        """Return the app-config description for ``scheduler``.

        Returns '' when ``scheduler`` is None, a list, or not present in the
        ``schedulers`` section of the app config.
        """
        if scheduler is None or isinstance(scheduler, list):
            return ''
        return self.scheduler_configs.get(scheduler, '')

    def assemble_code(self):
        """Rebuild ``self.code`` from ``self.current`` and return the snippet.

        Each fragment is stored under an "NNN_name" key; the fragments are
        joined with CRLF in sorted-key order, so every variable is assigned
        before a later fragment references it.
        """
        cur = self.current
        device = cur['device']
        model = cur['model']
        variant = cur['variant']
        safety_checker = cur['safety_checker']
        seed = cur['manual_seed']

        self.code['001_code'] = f'device = "{device}"'
        if cur['data_type'] == "bfloat16":
            self.code['002_data_type'] = 'data_type = torch.bfloat16'
        else:
            self.code['002_data_type'] = 'data_type = torch.float16'
        self.code['003_tf32'] = f'torch.backends.cuda.matmul.allow_tf32 = {cur["allow_tensorfloat32"]}'
        # str() comparison so both None and the string 'None' emit a bare None.
        if str(variant) == 'None':
            self.code['004_variant'] = f'variant = {variant}'
        else:
            self.code['004_variant'] = f'variant = "{variant}"'
        # from_pretrained below references use_safetensors, so define it.
        self.code['005_use_safetensors'] = f'use_safetensors = {cur["use_safetensors"]}'

        self.code['050_init_pipe'] = f'''pipeline = DiffusionPipeline.from_pretrained(
    "{model}",
    use_safetensors=use_safetensors,
    torch_dtype=data_type,
    variant=variant).to(device)'''

        self.code['054_requires_safety_checker'] = f'pipeline.requires_safety_checker = {cur["requires_safety_checker"]}'

        if not safety_checker or str(safety_checker).lower() == 'false':
            self.code['055_safety_checker'] = 'pipeline.safety_checker = None'
        else:
            self.code['055_safety_checker'] = ''

        self.code['060_scheduler'] = f'pipeline.scheduler = {cur["scheduler"]}.from_config(pipeline.scheduler.config)'

        # Check None/'' before `< 0`: comparing None with an int raises TypeError.
        # The seed fragment uses a key that sorts before 070 because the seeded
        # generator line reads manual_seed.
        if seed is None or seed == '' or seed < 0:
            self.code['065_manual_seed'] = f'# manual_seed = {seed}'
            self.code['070_generator'] = f'generator = torch.Generator("{device}")'
        else:
            self.code['065_manual_seed'] = f'manual_seed = {seed}'
            self.code['070_generator'] = 'generator = torch.manual_seed(manual_seed)'

        # !r emits valid, correctly escaped Python string literals.
        self.code["080_prompt"] = f'prompt = {cur["prompt"]!r}'
        self.code["085_negative_prompt"] = f'negative_prompt = {cur["negative_prompt"]!r}'
        self.code["090_inference_steps"] = f'inference_steps = {cur["inference_steps"]}'
        self.code["095_guidance_scale"] = f'guidance_scale = {cur["guidance_scale"]}'
        # Pass the generator variable defined by the 070 fragment, not its source line.
        self.code["100_run_inference"] = '''image = pipeline(
    prompt=prompt,
    negative_prompt=negative_prompt,
    generator=generator,
    num_inference_steps=inference_steps,
    guidance_scale=guidance_scale).images[0]'''

        return '\r\n'.join(fragment for _, fragment in sorted(self.code.items()))