|
import gradio as gr |
|
import base64 |
|
import json |
|
import torch |
|
from gradio import Request |
|
from gradio.context import Context |
|
|
|
|
|
def persist(component):
    """Remember the value of *component* per user and restore it on page load.

    Values are kept in a dict that lives in this call's closure, keyed by
    ``request.username``. Two handlers are registered:

    * a ``load`` handler on ``Context.root_block`` that returns the stored
      value for the requesting user, or the component's current value when
      nothing has been stored yet;
    * a ``change`` handler on the component that stores the new value.

    NOTE(review): ``request.username`` is presumably ``None`` for requests
    without a login, in which case all such users would share one slot --
    confirm against the deployment's auth setup.

    Args:
        component: The Gradio component to make persistent.

    Returns:
        The same *component*, so the call can wrap a component expression.
    """
    saved_values = {}

    def resume_session(value, request: Request):
        # Fall back to the incoming value for users seen for the first time
        user = request.username
        return saved_values[user] if user in saved_values else value

    def update_session(value, request: Request):
        saved_values[request.username] = value

    # Registration order is kept: load handler first, then the change handler
    Context.root_block.load(resume_session, inputs=[component], outputs=component)
    component.change(update_session, inputs=[component])

    return component
|
|
|
def get_initial_config():
    """Return the default configuration dict for a fresh session.

    The compute backend is probed in the order CUDA, MPS, CPU. Only CUDA
    switches the data type to ``bfloat16`` and turns TensorFloat-32 on; the
    other backends use ``float16`` with TensorFloat-32 off. Boolean options
    are stored as the strings ``"True"`` / ``"False"``.

    Returns:
        A new dict holding every configuration key with its default value.
    """
    # (device, data_type, allow_tensorfloat32) of the first available backend
    if torch.cuda.is_available():
        backend = ("cuda", 'bfloat16', "True")
    elif torch.backends.mps.is_available():
        backend = ("mps", 'float16', "False")
    else:
        backend = ("cpu", 'float16', "False")
    device, data_type, allow_tensorfloat32 = backend

    return {
        "device": device,
        "model": None,
        "cpu_offload": "False",
        "scheduler": None,
        "variant": None,
        "attention_slicing": "False",
        "pre_compile_unet": "False",
        "allow_tensorfloat32": allow_tensorfloat32,
        "use_safetensors": "False",
        "data_type": data_type,
        "refiner": "none",
        "safety_checker": "False",
        "requires_safety_checker": "False",
        "auto_encoder": None,
        "enable_vae_slicing": "True",
        "enable_vae_tiling": "True",
        "manual_seed": 42,
        "inference_steps": 10,
        "guidance_scale": 5,
        "adapter_textual_inversion": None,
        "adapter_textual_inversion_token": None,
        "adapter_lora": [],
        "adapter_lora_token": [],
        "adapter_lora_weight": [],
        "adapter_lora_balancing": {},
        "lora_scale": 0.5,
        "prompt": 'A white rabbit',
        "trigger_token": '',
        "negative_prompt": 'lowres, cropped, worst quality, low quality',
    }
|
|
|
def _parse_config_text(text):
    """Turn the decoded ``config`` URL parameter into a Python object.

    Three spellings are tried in order: strict JSON, a Python literal (the
    form ``str(some_dict)`` produces), and finally the historical
    quote / ``None`` text substitution followed by JSON parsing.

    Raises:
        ValueError: If none of the three parsers accepts *text*
            (``json.JSONDecodeError`` is a ``ValueError``).
    """
    import ast  # local import: only this fallback parser needs it

    try:
        return json.loads(text)
    except ValueError:
        pass

    try:
        # literal_eval accepts literals only, so nothing from the URL is
        # executed. Unlike text substitution it leaves apostrophes and the
        # word "None" inside string values untouched.
        return ast.literal_eval(text)
    except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
        pass

    # Last resort: the original substitution, for mixed spellings such as
    # {'model': null} that neither parser above understands.
    return json.loads(text.replace("'", '"').replace('None', 'null'))


def get_config_from_url(initial_config, request: Request):
    """Resolve the session configuration from the URL or from cookies.

    When the request carries a ``config`` query parameter it is
    base64-decoded, parsed, and laid over *initial_config*, so keys the link
    omits keep their defaults. Otherwise every key of *initial_config* that
    also exists as a cookie is overridden by the cookie value: ``'null'`` and
    the empty string become ``None``, and keys whose default is a list are
    parsed as JSON (a ``None`` value keeps the default list instead).

    NOTE(review): a cookie for ``adapter_lora_balancing`` (dict default) is
    still passed through as a raw string -- confirm how that cookie is
    written before parsing it here.

    Args:
        initial_config: Dict of default values; it is not modified.
        request: The Gradio request; ``request.request.query_params`` and
            ``request.cookies`` are read.

    Returns:
        A list with the configuration values in the fixed order given by
        ``output_keys`` below.

    Raises:
        ValueError: If the ``config`` parameter or a list cookie cannot be
            decoded or parsed.
        KeyError: If a required key is missing from both sources.
    """
    # Order of the returned values; callers wire these to outputs by position
    output_keys = (
        'model',
        'device',
        'cpu_offload',
        'use_safetensors',
        'data_type',
        'refiner',
        'variant',
        'attention_slicing',
        'pre_compile_unet',
        'safety_checker',
        'requires_safety_checker',
        'auto_encoder',
        'enable_vae_slicing',
        'enable_vae_tiling',
        'scheduler',
        'prompt',
        'trigger_token',
        'negative_prompt',
        'inference_steps',
        'manual_seed',
        'guidance_scale',
        'adapter_textual_inversion',
        'adapter_textual_inversion_token',
        'adapter_lora',
        'adapter_lora_token',
        'adapter_lora_weight',
        'adapter_lora_balancing',
        'lora_scale',
    )

    encoded_params = request.request.query_params.get('config')

    if encoded_params is not None:
        decoded_params = base64.b64decode(encoded_params).decode('utf-8')
        # Defaults first, so a link that omits a key still yields a full config
        return_config = {**initial_config, **_parse_config_text(decoded_params)}
    else:
        # Work on a copy so one request's cookies never leak into the
        # caller's default dict
        return_config = dict(initial_config)
        cookies = request.cookies

        for key, default in initial_config.items():
            if key not in cookies:
                continue

            value = cookies[key]
            if value == 'null' or value == '':
                value = None

            if isinstance(default, list):
                # An empty cookie carries no list; json.loads(None) would raise
                value = list(default) if value is None else json.loads(value)

            return_config[key] = value

    return [return_config[key] for key in output_keys]
|
|
|
def load_app_config():
    """Read ``appConfig.json`` from the current working directory.

    Loading is best-effort: each failure prints a message and the function
    still returns a value instead of raising.

    Returns:
        The parsed JSON content, or an empty dict when the file is missing,
        is not valid JSON, or cannot be read for any other reason.
    """
    # Bound up front so the error paths below have something to return
    appConfig = {}

    try:
        with open('appConfig.json', 'r', encoding='utf-8') as f:
            appConfig = json.load(f)
    except FileNotFoundError:
        print("App config file not found.")
    except json.JSONDecodeError:
        print("Error decoding JSON in app config file.")
    except Exception as e:
        print("An error occurred while loading app config:", str(e))

    return appConfig
|
|
|
def set_config(config, key, value):
    """Store *value* under *key* in *config* and return the same dict.

    Any value whose string form is ``null`` or ``none`` (case-insensitive),
    including ``None`` itself, is stored as the empty string.

    Args:
        config: The configuration dict; it is updated in place.
        key: The entry to set.
        value: The new value.

    Returns:
        The updated *config* object.
    """
    is_placeholder = str(value).lower() in ('null', 'none')
    config[key] = '' if is_placeholder else value
    return config
|
|
|
def assemble_code(str_config):
    """Render a configuration dict as the text of a diffusers script.

    The script is built statement by statement in a fixed order: device and
    dtype setup, the base pipeline, optional memory/compile switches, an
    optional VAE, an optional refiner pipeline, safety-checker settings,
    scheduler, seed, textual inversion, LoRA adapters, the prompt variables,
    and finally the pipeline call(s). Nothing is executed here.

    Options holding ``None``, ``'none'``, ``'null'`` or ``''`` count as not
    set. Flags are on unless their string form is ``false``
    (case-insensitive).

    Args:
        str_config: Configuration dict with the keys produced by
            ``get_initial_config``.

    Returns:
        The generated statements joined with ``'\\r\\n'``.

    Raises:
        ValueError: If ``manual_seed`` is set but is not an integer.
        KeyError: If a required configuration key is missing.
    """
    config = str_config

    def unset(value):
        """True when *value* is one of the spellings of "no value"."""
        return str(value).lower() in ('none', 'null', '')

    def enabled(key):
        """True unless the flag stored under *key* reads as "false"."""
        return str(config[key]).lower() != 'false'

    def quoted(text):
        """Double-quoted literal for free text, with quotes/backslashes escaped."""
        return json.dumps(str(text), ensure_ascii=False)

    code = []

    # --- device / dtype -----------------------------------------------------
    device = config['device']
    code.append(f'device = "{device}"')

    if config['data_type'] == "bfloat16":
        code.append('data_type = torch.bfloat16')
    else:
        code.append('data_type = torch.float16')

    allow_tf32 = config["allow_tensorfloat32"]
    code.append(f'torch.backends.cuda.matmul.allow_tf32 = {allow_tf32}')

    # An emptied variant must become None, not the string ""
    variant = config["variant"]
    if unset(variant):
        code.append('variant = None')
    else:
        code.append(f'variant = "{variant}"')

    use_safetensors = config["use_safetensors"]
    code.append(f'use_safetensors = {use_safetensors}')

    # --- base pipeline ------------------------------------------------------
    model = config['model']
    code.append(f'''pipeline = DiffusionPipeline.from_pretrained(
    "{model}",
    use_safetensors=use_safetensors,
    torch_dtype=data_type,
    variant=variant).to(device)''')

    if enabled("attention_slicing"):
        code.append("pipeline.enable_attention_slicing()")
    if enabled("pre_compile_unet"):
        code.append("pipeline.unet = torch.compile(pipeline.unet, mode='reduce-overhead', fullgraph=True)")

    if enabled("cpu_offload"):
        code.append("pipeline.enable_model_cpu_offload()")

    # --- VAE ----------------------------------------------------------------
    auto_encoder = config["auto_encoder"]
    if not unset(auto_encoder):
        code.append(f'pipeline.vae = AutoencoderKL.from_pretrained("{auto_encoder}", torch_dtype=data_type).to(device)')

    if enabled("enable_vae_slicing"):
        code.append("pipeline.enable_vae_slicing()")
    if enabled("enable_vae_tiling"):
        code.append("pipeline.enable_vae_tiling()")

    # --- refiner ------------------------------------------------------------
    refiner_model = config['refiner']
    use_refiner = not unset(refiner_model)
    if use_refiner:
        # The shared components come from the base pipeline, which the
        # generated script names "pipeline"
        code.append(f'''refiner = DiffusionPipeline.from_pretrained(
    "{refiner_model}",
    text_encoder_2 = pipeline.text_encoder_2,
    vae = pipeline.vae,
    torch_dtype = data_type,
    use_safetensors = use_safetensors,
    variant=variant,
).to(device)''')

        if enabled("cpu_offload"):
            code.append("refiner.enable_model_cpu_offload()")
        if enabled("enable_vae_slicing"):
            code.append("refiner.enable_vae_slicing()")
        if enabled("enable_vae_tiling"):
            code.append("refiner.enable_vae_tiling()")

    # --- safety checker -----------------------------------------------------
    requires_safety_checker = config["requires_safety_checker"]
    code.append(f'pipeline.requires_safety_checker = {requires_safety_checker}')
    if not enabled("safety_checker"):
        code.append('pipeline.safety_checker = None')

    # --- scheduler ----------------------------------------------------------
    scheduler = config["scheduler"]
    if not unset(scheduler):
        code.append(f'pipeline.scheduler = {scheduler}.from_config(pipeline.scheduler.config)')

    # --- seed ---------------------------------------------------------------
    seed = config['manual_seed']
    if seed is None or seed == '' or int(seed) < 0:
        # No usable seed: keep the value visible as a comment only
        code.append(f'# manual_seed = {seed}')
        code.append('generator = None')
    else:
        code.append(f'manual_seed = {seed}')
        code.append('generator = torch.manual_seed(manual_seed)')

    # --- adapters -----------------------------------------------------------
    ti_model = config["adapter_textual_inversion"]
    ti_token = config["adapter_textual_inversion_token"]
    if not unset(ti_model):
        code.append(f'pipeline.load_textual_inversion("{ti_model}", token="{ti_token}")')

    loras = config["adapter_lora"]
    lora_weight_names = config["adapter_lora_weight"]
    lora_tokens = config["adapter_lora_token"]

    # The three LoRA lists are parallel; adapters are only emitted when they
    # line up, so an incomplete selection is skipped rather than crashing
    if len(loras) > 0 and len(loras) == len(lora_weight_names) == len(lora_tokens):
        balancing = config["adapter_lora_balancing"]
        adapter_weights = []

        for lora, weight_name, token in zip(loras, lora_weight_names, lora_tokens):
            if unset(weight_name):
                code.append(f'pipeline.load_lora_weights("{lora}", adapter_name="{token}")')
            else:
                code.append(f'pipeline.load_lora_weights("{lora}", weight_name="{weight_name}", adapter_name="{token}")')
            # 1.0 is the neutral weight for an adapter with no balancing entry
            adapter_weights.append(balancing.get(lora, 1.0))

        code.append(f'adapter_weights = {adapter_weights}')
        code.append(f'pipeline.set_adapters({lora_tokens}, adapter_weights=adapter_weights)')

        # lora_scale may be a number, so convert before concatenating
        cross_attention_kwargs = '{"scale": ' + str(config["lora_scale"]) + '}'
    else:
        cross_attention_kwargs = 'None'

    # --- prompt variables ---------------------------------------------------
    prompt_parts = [
        config["prompt"],
        config["trigger_token"],
        ti_token,
        ", ".join(lora_tokens),
    ]
    # Skip missing parts so the prompt never contains the word "None"
    prompt = " ".join(str(part) for part in prompt_parts if part is not None and str(part) != '')

    code.append(f'prompt = {quoted(prompt)}')
    code.append(f'negative_prompt = {quoted(config["negative_prompt"])}')
    code.append(f'inference_steps = {config["inference_steps"]}')
    code.append(f'guidance_scale = {config["guidance_scale"]}')

    # --- inference ----------------------------------------------------------
    code.append(f'''image = pipeline(
    prompt=prompt,
    negative_prompt=negative_prompt,
    generator=generator,
    num_inference_steps=inference_steps,
    cross_attention_kwargs={cross_attention_kwargs},
    guidance_scale=guidance_scale).images
''')

    if use_refiner:
        # Keep the list form so the closing "image[0]" works on both paths
        code.append('''image = refiner(
    prompt=prompt,
    negative_prompt=negative_prompt,
    num_inference_steps=inference_steps,
    image=image
).images''')

    code.append('image[0]')

    return '\r\n'.join(code)