# pictero/config.py
import gradio as gr
import base64
import json
import torch
from gradio import Request
from gradio.context import Context

def persist(component):
    sessions = {}

    def resume_session(value, request: Request):
        return sessions.get(request.username, value)

    def update_session(value, request: Request):
        sessions[request.username] = value

    Context.root_block.load(resume_session, inputs=[component], outputs=component)
    component.change(update_session, inputs=[component])

    return component
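
# Usage sketch (an assumption: persist() must be called inside a gr.Blocks
# context, since Context.root_block has to exist when load() is registered):
#
#     with gr.Blocks() as demo:
#         prompt = persist(gr.Textbox(label="Prompt"))
#
# The component's last value is then restored per authenticated user on reload.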

def get_initial_config():
    data_type = 'float16'
    allow_tensorfloat32 = "False"

    if torch.cuda.is_available():
        device = "cuda"
        data_type = 'bfloat16'
        allow_tensorfloat32 = "True"
    elif torch.backends.mps.is_available():
        device = "mps"
    else:
        device = "cpu"

    config = {
        "device": device,
        "model": None,
        "cpu_offload": "False",
        "scheduler": None,
        "variant": None,
        "attention_slicing": "False",
        "pre_compile_unet": "False",
        "allow_tensorfloat32": allow_tensorfloat32,
        "use_safetensors": "False",
        "data_type": data_type,
        "refiner": "none",
        "safety_checker": "False",
        "requires_safety_checker": "False",
        "auto_encoder": None,
        "enable_vae_slicing": "True",
        "enable_vae_tiling": "True",
        "manual_seed": 42,
        "inference_steps": 10,
        "guidance_scale": 5,
        "adapter_textual_inversion": None,
        "adapter_textual_inversion_token": None,
        "adapter_lora": [],
        "adapter_lora_token": [],
        "adapter_lora_weight": [],
        "adapter_lora_balancing": {},
        "lora_scale": 0.5,
        "prompt": 'A white rabbit',
        "trigger_token": '',
        "negative_prompt": 'lowres, cropped, worst quality, low quality',
    }

    return config
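
# The defaults above are hardware-dependent: a CUDA machine gets
# {'device': 'cuda', 'data_type': 'bfloat16', 'allow_tensorfloat32': 'True', ...},
# Apple Silicon gets 'mps', and anything else falls back to 'cpu', both with
# the float16 defaults.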

def get_config_from_url(initial_config, request: Request):
    encoded_params = request.request.query_params.get('config')
    return_config = {}

    # get configuration from the URL if the GET parameter `config` is set
    if encoded_params is not None:
        decoded_params = base64.b64decode(encoded_params).decode('utf-8')
        # turn the dict's repr back into valid JSON: single quotes become
        # double quotes and bare None becomes null
        decoded_params = decoded_params.replace("'", '"').replace('None', 'null')
        return_config = json.loads(decoded_params)

    # otherwise use the default initial config, patched from cookies
    else:
        # check if a cookie exists for each of our initial parameters
        for key in initial_config.keys():
            if key in request.cookies:
                value = request.cookies[key]
                # transform empty values to a "Python-like" None
                if value == 'null' or value == '':
                    value = None
                # if the value is expected to be a list, parse the string
                if type(initial_config[key]) == list:
                    value = json.loads(value)
                initial_config[key] = value
        return_config = initial_config

    return [return_config['model'],
            return_config['device'],
            return_config['cpu_offload'],
            return_config['use_safetensors'],
            return_config['data_type'],
            return_config['refiner'],
            return_config['variant'],
            return_config['attention_slicing'],
            return_config['pre_compile_unet'],
            return_config['safety_checker'],
            return_config['requires_safety_checker'],
            return_config['auto_encoder'],
            return_config['enable_vae_slicing'],
            return_config['enable_vae_tiling'],
            return_config['scheduler'],
            return_config['prompt'],
            return_config['trigger_token'],
            return_config['negative_prompt'],
            return_config['inference_steps'],
            return_config['manual_seed'],
            return_config['guidance_scale'],
            return_config['adapter_textual_inversion'],
            return_config['adapter_textual_inversion_token'],
            return_config['adapter_lora'],
            return_config['adapter_lora_token'],
            return_config['adapter_lora_weight'],
            return_config['adapter_lora_balancing'],
            return_config['lora_scale']]
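
# Sketch of the matching encoder for the `config` GET parameter (an
# assumption: the actual encoding happens elsewhere in the app). The
# parameter is the base64 of the config dict's repr, which the replaces
# above turn back into valid JSON:
#
#     encoded = base64.b64encode(str(config).encode('utf-8')).decode('utf-8')
#     share_url = f"?config={encoded}"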

def load_app_config():
    # default to an empty dict so a failed load doesn't raise a NameError below
    appConfig = {}
    try:
        with open('appConfig.json', 'r') as f:
            appConfig = json.load(f)
    except FileNotFoundError:
        print("App config file not found.")
    except json.JSONDecodeError:
        print("Error decoding JSON in app config file.")
    except Exception as e:
        print("An error occurred while loading app config:", str(e))

    return appConfig

def set_config(config, key, value):
    # normalize "null"/"none" strings to an empty value
    if str(value).lower() in ('null', 'none'):
        value = ''
    config[key] = value

    return config
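
# Example: set_config(config, 'scheduler', 'NULL') stores '' rather than the
# literal string, keeping "unset" values uniform across the config.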

def assemble_code(str_config):
    config = str_config
    code = []

    code.append(f'''device = "{config['device']}"''')

    if config['data_type'] == "bfloat16":
        code.append('data_type = torch.bfloat16')
    else:
        code.append('data_type = torch.float16')

    code.append(f'torch.backends.cuda.matmul.allow_tf32 = {config["allow_tensorfloat32"]}')

    # emit `variant = None` unquoted when unset, a concrete variant quoted
    if str(config["variant"]).lower() in ('none', 'null', ''):
        code.append('variant = None')
    else:
        code.append(f'variant = "{config["variant"]}"')

    code.append(f'use_safetensors = {config["use_safetensors"]}')

    # INIT PIPELINE
    code.append(f'''pipeline = DiffusionPipeline.from_pretrained(
        "{config['model']}",
        use_safetensors=use_safetensors,
        torch_dtype=data_type,
        variant=variant).to(device)''')

    if str(config["attention_slicing"]).lower() != 'false': code.append("pipeline.enable_attention_slicing()")
    if str(config["pre_compile_unet"]).lower() != 'false': code.append("pipeline.unet = torch.compile(pipeline.unet, mode='reduce-overhead', fullgraph=True)")
    if str(config["cpu_offload"]).lower() != 'false': code.append("pipeline.enable_model_cpu_offload()")

    # AUTO ENCODER
    if str(config["auto_encoder"]).lower() not in ('none', 'null', ''):
        code.append(f'pipeline.vae = AutoencoderKL.from_pretrained("{config["auto_encoder"]}", torch_dtype=data_type).to(device)')

    if str(config["enable_vae_slicing"]).lower() != 'false': code.append("pipeline.enable_vae_slicing()")
    if str(config["enable_vae_tiling"]).lower() != 'false': code.append("pipeline.enable_vae_tiling()")

    # INIT REFINER
    if str(config['refiner']).lower() not in ('none', 'null', ''):
        code.append(f'''refiner = DiffusionPipeline.from_pretrained(
        "{config['refiner']}",
        text_encoder_2=pipeline.text_encoder_2,
        vae=pipeline.vae,
        torch_dtype=data_type,
        use_safetensors=use_safetensors,
        variant=variant,
        ).to(device)''')

        if str(config["cpu_offload"]).lower() != 'false': code.append("refiner.enable_model_cpu_offload()")
        if str(config["enable_vae_slicing"]).lower() != 'false': code.append("refiner.enable_vae_slicing()")
        if str(config["enable_vae_tiling"]).lower() != 'false': code.append("refiner.enable_vae_tiling()")

    # SAFETY CHECKER
    code.append(f'pipeline.requires_safety_checker = {config["requires_safety_checker"]}')
    if str(config["safety_checker"]).lower() == 'false':
        code.append('pipeline.safety_checker = None')

    # SCHEDULER/SOLVER
    if str(config["scheduler"]).lower() not in ('none', 'null', ''):
        code.append(f'pipeline.scheduler = {config["scheduler"]}.from_config(pipeline.scheduler.config)')

    # MANUAL SEED/GENERATOR
    if config['manual_seed'] is None or config['manual_seed'] == '' or int(config['manual_seed']) < 0:
        code.append(f'# manual_seed = {config["manual_seed"]}')
        code.append('generator = None')
    else:
        code.append(f'manual_seed = {config["manual_seed"]}')
        code.append('generator = torch.manual_seed(manual_seed)')

    # ADAPTER
    if str(config["adapter_textual_inversion"]).lower() not in ('none', 'null', ''):
        code.append(f'pipeline.load_textual_inversion("{config["adapter_textual_inversion"]}", token="{config["adapter_textual_inversion_token"]}")')

    if len(config["adapter_lora"]) > 0 and len(config["adapter_lora"]) == len(config["adapter_lora_weight"]):
        adapter_lora_balancing = []
        for adapter_lora_index, adapter_lora in enumerate(config["adapter_lora"]):
            if str(config["adapter_lora_weight"][adapter_lora_index]).lower() != 'none':
                code.append(f'pipeline.load_lora_weights("{adapter_lora}", weight_name="{config["adapter_lora_weight"][adapter_lora_index]}", adapter_name="{config["adapter_lora_token"][adapter_lora_index]}")')
            else:
                code.append(f'pipeline.load_lora_weights("{adapter_lora}", adapter_name="{config["adapter_lora_token"][adapter_lora_index]}")')
            adapter_lora_balancing.append(config["adapter_lora_balancing"][adapter_lora])
        code.append(f'adapter_weights = {adapter_lora_balancing}')
        code.append(f'pipeline.set_adapters({config["adapter_lora_token"]}, adapter_weights=adapter_weights)')
        # lora_scale may arrive as a number, so cast it before concatenating
        cross_attention_kwargs = '{"scale": ' + str(config["lora_scale"]) + '}'
    else:
        cross_attention_kwargs = 'None'

    # drop unset tokens so the generated prompt doesn't contain literal "None"
    prompt_parts = [config["prompt"], config["trigger_token"],
                    config["adapter_textual_inversion_token"]] + list(config["adapter_lora_token"])
    prompt = ' '.join(str(part) for part in prompt_parts if part and str(part).lower() != 'none')
    code.append(f'prompt = "{prompt}"')
    code.append(f'negative_prompt = "{config["negative_prompt"]}"')
    code.append(f'inference_steps = {config["inference_steps"]}')
    code.append(f'guidance_scale = {config["guidance_scale"]}')

    code.append(f'''image = pipeline(
        prompt=prompt,
        negative_prompt=negative_prompt,
        generator=generator,
        num_inference_steps=inference_steps,
        cross_attention_kwargs={cross_attention_kwargs},
        guidance_scale=guidance_scale).images''')

    if str(config['refiner']).lower() not in ('none', 'null', ''):
        code.append(f'''image = refiner(
        prompt=prompt,
        negative_prompt=negative_prompt,
        num_inference_steps=inference_steps,
        image=image).images''')

    # keep `.images` (a list) in both branches so this final line always works
    code.append('image[0]')

    return '\r\n'.join(code)
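
# A quick way to inspect the generated script (a sketch; assumes the defaults
# from get_initial_config() suit the host machine):
#
#     if __name__ == '__main__':
#         print(assemble_code(get_initial_config()))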