pictero / config.py
nickyreinert-vml
adding pre compile feature
e4a20eb
raw
history blame
10.8 kB
import ast
import base64
import json

import gradio as gr
import torch
from gradio import Request
from gradio.context import Context
def persist(component):
    """Remember *component*'s value per user and restore it on page load.

    Two event handlers are registered: a load handler on the root Blocks
    (``Context.root_block``) that returns the value saved for
    ``request.username`` (or the component's current value when nothing is
    saved yet), and a change handler that saves every new value under that
    username. Saved values live in a plain in-memory dict captured by the
    handlers, so they are lost when the process restarts.

    NOTE(review): ``request.username`` is presumably None when the app has
    no authentication; if so, all anonymous visitors share one saved value.
    Confirm before relying on this in a public app.

    Returns:
        The same component, so the call can wrap a component's creation.
    """
    saved_by_user = {}

    def resume_session(value, request: Request):
        # Users with nothing saved keep the value the component already has.
        return saved_by_user.get(request.username, value)

    def update_session(value, request: Request):
        saved_by_user[request.username] = value

    Context.root_block.load(resume_session, inputs=[component], outputs=component)
    component.change(update_session, inputs=[component])
    return component
def get_initial_config():
    """Return a fresh dict of default generation settings.

    The device is picked in order of preference: CUDA, then MPS, then CPU.
    On CUDA the defaults are bfloat16 with TensorFloat-32 allowed; on every
    other device they are float16 with TensorFloat-32 off. Boolean options
    are stored as the strings "True"/"False", and unset options as None or
    "none". Every call builds new list/dict objects, so callers may mutate
    the result freely.
    """
    if torch.cuda.is_available():
        device, data_type, allow_tensorfloat32 = "cuda", 'bfloat16', "True"
    else:
        device = "mps" if torch.backends.mps.is_available() else "cpu"
        data_type, allow_tensorfloat32 = 'float16', "False"

    return {
        "device": device,
        "model": None,
        "cpu_offload": "False",
        "scheduler": None,
        "variant": None,
        "attention_slicing": "False",
        "pre_compile_unet": "False",
        "allow_tensorfloat32": allow_tensorfloat32,
        "use_safetensors": "False",
        "data_type": data_type,
        "refiner": "none",
        "safety_checker": "False",
        "requires_safety_checker": "False",
        "auto_encoder": None,
        "enable_vae_slicing": "True",
        "enable_vae_tiling": "True",
        "manual_seed": 42,
        "inference_steps": 10,
        "guidance_scale": 5,
        "adapter_textual_inversion": None,
        "adapter_textual_inversion_token": None,
        "adapter_lora": [],
        "adapter_lora_token": [],
        "adapter_lora_weight": [],
        "adapter_lora_balancing": {},
        "lora_scale": 0.5,
        "prompt": 'A white rabbit',
        "trigger_token": '',
        "negative_prompt": 'lowres, cropped, worst quality, low quality',
    }
# Values returned by get_config_from_url, in this order. The order presumably
# mirrors the output components the function is wired to (not visible in this
# file), so keep both in sync.
_URL_CONFIG_OUTPUT_KEYS = (
    'model',
    'device',
    'cpu_offload',
    'use_safetensors',
    'data_type',
    'refiner',
    'variant',
    'attention_slicing',
    'pre_compile_unet',
    'safety_checker',
    'requires_safety_checker',
    'auto_encoder',
    'enable_vae_slicing',
    'enable_vae_tiling',
    'scheduler',
    'prompt',
    'trigger_token',
    'negative_prompt',
    'inference_steps',
    'manual_seed',
    'guidance_scale',
    'adapter_textual_inversion',
    'adapter_textual_inversion_token',
    'adapter_lora',
    'adapter_lora_token',
    'adapter_lora_weight',
    'adapter_lora_balancing',
    'lora_scale',
)


def _decode_url_config(encoded_params):
    """Decode the base64 ``config`` query parameter into a dict.

    The payload may be JSON or a Python dict literal (single quotes,
    None/True/False). Real parsers are tried first, so quotes or the word
    "None" inside values survive intact. As a last resort, the legacy text
    substitution (' -> ", None -> null) followed by JSON parsing is tried.

    Returns:
        The decoded dict, or None if the payload cannot be decoded.
    """
    try:
        decoded = base64.b64decode(encoded_params).decode('utf-8')
    except ValueError as err:  # includes binascii.Error and UnicodeDecodeError
        print("Could not decode config from URL:", str(err))
        return None

    legacy = decoded.replace("'", '"').replace('None', 'null')
    # ast.literal_eval only evaluates literals, so the untrusted URL payload
    # cannot execute code; per its docs, malformed input raises one of the
    # exceptions caught below.
    attempts = (
        (json.loads, decoded),
        (ast.literal_eval, decoded),
        (json.loads, legacy),
    )
    for parse, text in attempts:
        try:
            parsed = parse(text)
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
            continue
        if isinstance(parsed, dict):
            return parsed
    print("Could not parse config from URL.")
    return None


def _read_cookie_config(initial_config, cookies):
    """Return a copy of *initial_config* overridden by matching cookies.

    Only keys already present in *initial_config* are read. The cookie
    values 'null' and '' become None. Keys whose default is a list or dict
    are JSON-decoded; a None, malformed or wrongly-typed value falls back to
    an empty container of the expected type. *initial_config* itself is not
    modified.
    """
    config = dict(initial_config)
    for key, default in initial_config.items():
        if key not in cookies:
            continue
        value = cookies[key]
        # Transform empty values to a Python None.
        if value == 'null' or value == '':
            value = None
        if isinstance(default, (list, dict)):
            container_type = list if isinstance(default, list) else dict
            try:
                parsed = json.loads(value) if value is not None else None
            except ValueError:
                print(f"Ignoring malformed cookie value for {key!r}.")
                parsed = None
            value = parsed if isinstance(parsed, container_type) else container_type()
        config[key] = value
    return config


def get_config_from_url(initial_config, request: Request):
    """Resolve the settings to show when a page loads.

    If the GET parameter ``config`` is present and decodes to a dict (see
    ``_decode_url_config``), its values are layered over *initial_config*.
    Keys missing from the shared config, such as settings added after the
    link was created, keep their defaults. Otherwise, including when the
    parameter is malformed, cookies named after the config keys override
    *initial_config* (see ``_read_cookie_config``).

    Args:
        initial_config: Default settings dict, e.g. from
            ``get_initial_config``. It is not modified.
        request: The Gradio request; its ``config`` query parameter and its
            cookies are read.

    Returns:
        A list of config values ordered as ``_URL_CONFIG_OUTPUT_KEYS``.
    """
    encoded_params = request.request.query_params.get('config')
    url_config = None
    if encoded_params is not None:
        url_config = _decode_url_config(encoded_params)

    if url_config is not None:
        return_config = {**initial_config, **url_config}
    else:
        return_config = _read_cookie_config(initial_config, request.cookies)

    return [return_config[key] for key in _URL_CONFIG_OUTPUT_KEYS]
def load_app_config(path='appConfig.json'):
    """Load the application settings from a JSON file.

    Loading is best-effort: problems are printed and never raised.

    Args:
        path: File to read. Defaults to ``appConfig.json`` relative to the
            current working directory.

    Returns:
        The parsed JSON content, or an empty dict if the file is missing,
        is not valid JSON, or cannot be read for any other reason.
    """
    try:
        # JSON is UTF-8 by spec; don't depend on the platform's locale encoding.
        with open(path, 'r', encoding='utf-8') as f:
            return json.load(f)
    except FileNotFoundError:
        print("App config file not found.")
    except json.JSONDecodeError:
        print("Error decoding JSON in app config file.")
    except Exception as e:
        print("An error occurred while loading app config:", str(e))
    return {}
def set_config(config, key, value):
    """Store *value* under *key* in *config* and return that same dict.

    ``None`` and the strings 'null' / 'none' (in any letter case) are stored
    as an empty string instead. All other values are stored unchanged.
    """
    null_like = str(value).lower() in ('null', 'none')
    config[key] = '' if null_like else value
    return config
def _is_unset(value):
    """Return True if *value* means "not set" in this config.

    Matches None and the strings 'none', 'null' and '' in any letter case.
    ``get_initial_config`` uses None or 'none' for unset options, and
    ``set_config`` stores '' for them.
    """
    return str(value).lower() in ('none', 'null', '')


def _is_enabled(value):
    """Return True unless *value* is the string 'false' (any letter case).

    Flags are stored as "True"/"False" strings; any other value, including
    None, counts as enabled.
    """
    return str(value).lower() != 'false'


def _py_str(value):
    """Render ``str(value)`` as a double-quoted Python string literal.

    json.dumps escapes quotes, backslashes and control characters with
    sequences Python also accepts, so user text cannot break the generated
    code. Plain text renders exactly as "text".
    """
    return json.dumps(str(value), ensure_ascii=False)


def assemble_code(str_config):
    """Generate a diffusers script that reproduces the settings in *str_config*.

    Args:
        str_config: Settings dict with the keys produced by
            ``get_initial_config``. Flags are the strings "True"/"False";
            unset values may be None, '', 'none' or 'null'.

    Returns:
        The script as one string, with statements joined by CR LF. The
        snippet imports nothing: ``torch``, ``DiffusionPipeline``,
        ``AutoencoderKL`` and the selected scheduler class must already be
        in scope where it runs. It ends with the bare expression
        ``image[0]`` (presumably so a notebook displays the result).
    """
    config = str_config
    code = []

    # DEVICE / PRECISION
    code.append(f'device = {_py_str(config["device"])}')
    if config['data_type'] == "bfloat16":
        code.append('data_type = torch.bfloat16')
    else:
        code.append('data_type = torch.float16')
    code.append(f'torch.backends.cuda.matmul.allow_tf32 = {config["allow_tensorfloat32"]}')
    if _is_unset(config["variant"]):
        code.append('variant = None')
    else:
        code.append(f'variant = {_py_str(config["variant"])}')
    code.append(f'use_safetensors = {config["use_safetensors"]}')

    # INIT PIPELINE
    code.append(f'''pipeline = DiffusionPipeline.from_pretrained(
    {_py_str(config["model"])},
    use_safetensors=use_safetensors,
    torch_dtype=data_type,
    variant=variant).to(device)''')
    if _is_enabled(config["attention_slicing"]):
        code.append("pipeline.enable_attention_slicing()")
    # .get: configs from before this option existed may lack the key.
    if _is_enabled(config.get("pre_compile_unet", "False")):
        # Single-quoted literal because the emitted line contains double quotes.
        code.append('pipeline.unet = torch.compile(pipeline.unet, mode="reduce-overhead", fullgraph=True)')
    if _is_enabled(config["cpu_offload"]):
        code.append("pipeline.enable_model_cpu_offload()")

    # AUTO ENCODER
    if not _is_unset(config["auto_encoder"]):
        code.append(f'pipeline.vae = AutoencoderKL.from_pretrained({_py_str(config["auto_encoder"])}, torch_dtype=data_type).to(device)')
    if _is_enabled(config["enable_vae_slicing"]):
        code.append("pipeline.enable_vae_slicing()")
    if _is_enabled(config["enable_vae_tiling"]):
        code.append("pipeline.enable_vae_tiling()")

    # INIT REFINER
    use_refiner = not _is_unset(config['refiner'])
    if use_refiner:
        # The refiner reuses the base pipeline's second text encoder and VAE.
        code.append(f'''refiner = DiffusionPipeline.from_pretrained(
    {_py_str(config["refiner"])},
    text_encoder_2 = pipeline.text_encoder_2,
    vae = pipeline.vae,
    torch_dtype = data_type,
    use_safetensors = use_safetensors,
    variant=variant,
    ).to(device)''')
        if _is_enabled(config["cpu_offload"]):
            code.append("refiner.enable_model_cpu_offload()")
        if _is_enabled(config["enable_vae_slicing"]):
            code.append("refiner.enable_vae_slicing()")
        if _is_enabled(config["enable_vae_tiling"]):
            code.append("refiner.enable_vae_tiling()")

    # SAFETY CHECKER
    code.append(f'pipeline.requires_safety_checker = {config["requires_safety_checker"]}')
    if str(config["safety_checker"]).lower() == 'false':
        code.append('pipeline.safety_checker = None')

    # SCHEDULER/SOLVER
    if not _is_unset(config["scheduler"]):
        code.append(f'pipeline.scheduler = {config["scheduler"]}.from_config(pipeline.scheduler.config)')

    # MANUAL SEED/GENERATOR: an empty or negative seed means "no fixed seed".
    if config['manual_seed'] is None or config['manual_seed'] == '' or int(config['manual_seed']) < 0:
        code.append(f'# manual_seed = {config["manual_seed"]}')
        code.append('generator = None')
    else:
        code.append(f'manual_seed = {config["manual_seed"]}')
        code.append('generator = torch.manual_seed(manual_seed)')

    # ADAPTER
    textual_inversion_token = config["adapter_textual_inversion_token"]
    if not _is_unset(config["adapter_textual_inversion"]):
        code.append(f'pipeline.load_textual_inversion({_py_str(config["adapter_textual_inversion"])}, token={_py_str(textual_inversion_token)})')

    lora_names = config["adapter_lora"]
    lora_tokens = config["adapter_lora_token"]
    lora_weights = config["adapter_lora_weight"]
    if len(lora_names) > 0 and len(lora_names) == len(lora_weights):
        adapter_lora_balancing = []
        for index, lora_name in enumerate(lora_names):
            lora_token = lora_tokens[index]
            weight_name = lora_weights[index]
            if not _is_unset(weight_name):
                code.append(f'pipeline.load_lora_weights({_py_str(lora_name)}, weight_name={_py_str(weight_name)}, adapter_name={_py_str(lora_token)})')
            else:
                code.append(f'pipeline.load_lora_weights({_py_str(lora_name)}, adapter_name={_py_str(lora_token)})')
            adapter_lora_balancing.append(config["adapter_lora_balancing"][lora_name])
        code.append(f'adapter_weights = {adapter_lora_balancing}')
        code.append(f'pipeline.set_adapters({lora_tokens}, adapter_weights=adapter_weights)')
        # str(): lora_scale defaults to the float 0.5, which cannot be concatenated.
        cross_attention_kwargs = '{"scale": ' + str(config["lora_scale"]) + '}'
    else:
        cross_attention_kwargs = 'None'

    # PROMPT: append trigger/adapter tokens, skipping empty or None parts so
    # the literal text "None" never ends up in the prompt.
    prompt_parts = (
        config["prompt"],
        config["trigger_token"],
        textual_inversion_token,
        ", ".join(lora_tokens),
    )
    prompt_text = " ".join(
        str(part).strip() for part in prompt_parts if part is not None and str(part).strip()
    )
    code.append(f'prompt = {_py_str(prompt_text)}')
    code.append(f'negative_prompt = {_py_str(config["negative_prompt"])}')
    code.append(f'inference_steps = {config["inference_steps"]}')
    code.append(f'guidance_scale = {config["guidance_scale"]}')
    code.append(f'''image = pipeline(
    prompt=prompt,
    negative_prompt=negative_prompt,
    generator=generator,
    num_inference_steps=inference_steps,
    cross_attention_kwargs={cross_attention_kwargs},
    guidance_scale=guidance_scale).images
''')
    if use_refiner:
        # Keep `.images` as a list so the trailing `image[0]` works in both paths.
        code.append('''image = refiner(
    prompt=prompt,
    negative_prompt=negative_prompt,
    num_inference_steps=inference_steps,
    image=image
    ).images''')
    code.append('image[0]')
    return '\r\n'.join(code)