# pictero / config.py (author: n42)
# Last change: "adding trigger token config param" (commit a447492), 7.03 kB.
import ast
import base64
import json

import gradio as gr
import torch
from gradio import Request
from gradio.context import Context
def persist(component):
    """Remember ``component``'s value per user across page loads.

    Every change of the component is stored under ``request.username``; on
    each page load the stored value (or the incoming value when nothing has
    been stored for that user yet) is pushed back into the component.

    NOTE(review): values are kept only in this closure's dict, i.e. in this
    process's memory. ``request.username`` is presumably None when the app runs
    without authentication, in which case all anonymous visitors would share a
    single slot — confirm.

    NOTE(review): reads ``Context.root_block`` at call time, so this presumably
    has to be called inside an active ``gr.Blocks`` context — confirm.

    Returns:
        The same component, so the call can wrap component construction.
    """
    saved_values = {}

    def update_session(value, request: Request):
        # Record the latest value for this user.
        saved_values[request.username] = value

    def resume_session(value, request: Request):
        # Fall back to the incoming value when this user has nothing saved.
        return saved_values.get(request.username, value)

    root = Context.root_block
    root.load(resume_session, inputs=[component], outputs=component)
    component.change(update_session, inputs=[component])
    return component
def get_initial_config():
    """Build the default configuration dict, adapted to the available hardware.

    CUDA gets ``bfloat16`` with TF32 matmuls allowed; otherwise the device is
    ``mps`` when available, else ``cpu``, both with ``float16`` and TF32
    disallowed. Boolean-like settings are stored as the strings
    ``"True"``/``"False"``.
    """
    if torch.cuda.is_available():
        device, data_type, allow_tensorfloat32 = "cuda", "bfloat16", "True"
    else:
        device = "mps" if torch.backends.mps.is_available() else "cpu"
        data_type, allow_tensorfloat32 = "float16", "False"

    return {
        "device": device,
        "model": None,
        "cpu_offload": "False",
        "scheduler": None,
        "variant": None,
        "allow_tensorfloat32": allow_tensorfloat32,
        "use_safetensors": "False",
        "data_type": data_type,
        "refiner": "none",
        "safety_checker": "False",
        "requires_safety_checker": "False",
        "manual_seed": 42,
        "inference_steps": 10,
        "guidance_scale": 0.5,
        "prompt": 'A white rabbit',
        "trigger_token": '',
        "negative_prompt": 'lowres, cropped, worst quality, low quality, chat bubble, chat bubbles, ugly',
    }
def _decode_shared_config(encoded_params):
    """Decode a base64 ``config`` query value into a dict.

    The decoded text may be JSON or the ``str(dict)`` form of a Python dict
    (single quotes, None/True/False). Returns None when the value is not valid
    base64/UTF-8, cannot be parsed, or does not hold a dict.
    """
    # Form-urlencoded query parsing turns an unescaped '+' into a space, and
    # b64decode silently discards spaces; base64 never contains spaces, so put
    # the '+' characters back before decoding.
    encoded_params = encoded_params.replace(' ', '+')
    try:
        text = base64.b64decode(encoded_params).decode('utf-8')
    except ValueError:  # binascii.Error and UnicodeDecodeError subclass ValueError
        return None
    try:
        parsed = json.loads(text)
    except ValueError:
        # Not JSON: try a Python literal. literal_eval only evaluates literals,
        # so untrusted URL input cannot execute code here.
        try:
            parsed = ast.literal_eval(text)
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
            return None
    return parsed if isinstance(parsed, dict) else None


def get_config_from_url(initial_config, request: Request):
    """Return the UI field values for this page load.

    If the URL has a ``config`` query parameter holding a base64-encoded config
    (JSON or Python-dict text), its keys override ``initial_config``. Keys the
    shared config lacks (e.g. ``trigger_token`` in links created before that
    setting existed) keep their ``initial_config`` value. A parameter that
    cannot be decoded is reported and ignored.

    Args:
        initial_config: dict of default settings (presumably the one built by
            get_initial_config()).
        request: Gradio request whose query string is inspected.

    Returns:
        List of values in this order: model, device, cpu_offload,
        use_safetensors, data_type, refiner, variant, safety_checker,
        requires_safety_checker, scheduler, prompt, trigger_token,
        negative_prompt, inference_steps, manual_seed, guidance_scale.

    Raises:
        KeyError: if one of those keys is present in neither config.
    """
    return_config = dict(initial_config)
    # Use the configuration from the URL if the GET parameter `config` is set.
    encoded_params = request.request.query_params.get('config')
    if encoded_params is not None:
        shared_config = _decode_shared_config(encoded_params)
        if shared_config is None:
            print("Could not decode shared config from URL; using defaults.")
        else:
            return_config.update(shared_config)

    output_keys = (
        'model',
        'device',
        'cpu_offload',
        'use_safetensors',
        'data_type',
        'refiner',
        'variant',
        'safety_checker',
        'requires_safety_checker',
        'scheduler',
        'prompt',
        'trigger_token',
        'negative_prompt',
        'inference_steps',
        'manual_seed',
        'guidance_scale',
    )
    return [return_config[key] for key in output_keys]
def load_app_config(path='appConfig.json'):
    """Load the application's JSON configuration file.

    Args:
        path: location of the config file; defaults to 'appConfig.json'
            relative to the current working directory.

    Returns:
        The parsed JSON value, or an empty dict when the file is missing,
        unreadable, or not valid JSON (a message is printed in that case).
    """
    # Pre-bind the result so a handled error returns {} instead of raising
    # UnboundLocalError at the return statement.
    app_config = {}
    try:
        with open(path, 'r', encoding='utf-8') as f:
            app_config = json.load(f)
    except FileNotFoundError:
        print("App config file not found.")
    except json.JSONDecodeError:
        print("Error decoding JSON in app config file.")
    except Exception as e:
        print("An error occurred while loading app config:", str(e))
    return app_config
def set_config(config, key, value):
    """Store ``value`` under ``key`` in ``config`` and return that same dict.

    The dict is modified in place; it is also returned so the caller can use
    the result directly (e.g. as a new state value).
    """
    config.update({key: value})
    return config
def assemble_code(str_config):
    """Generate a Python/diffusers snippet that reproduces a configuration.

    Despite its name, ``str_config`` is a config dict with the keys built by
    get_initial_config(). Flag values may be the strings "True"/"False" (any
    case) or real booleans. For variant, refiner and scheduler, None, "none"
    and "" mean "not set". Text values (device, model ids, prompts) are
    emitted as escaped string literals, so quotes or backslashes in a prompt
    still yield valid Python.

    Returns:
        The snippet as one string whose statements are joined with CRLF.
    """
    config = str_config

    def is_true(value):
        # Accept both the "True"/"False" strings used by the UI and bools.
        return str(value).lower() == 'true'

    def is_unset(value):
        return value is None or str(value).strip().lower() in ('', 'none')

    def py_str(value):
        # A JSON string literal is also a valid Python string literal.
        return json.dumps(str(value), ensure_ascii=False)

    offload = is_true(config['cpu_offload'])
    use_refiner = not is_unset(config['refiner'])

    code = []
    code.append(f'device = {py_str(config["device"])}')
    if config['data_type'] == 'bfloat16':
        code.append('data_type = torch.bfloat16')
    else:
        code.append('data_type = torch.float16')
    code.append(f'torch.backends.cuda.matmul.allow_tf32 = {is_true(config["allow_tensorfloat32"])}')
    if is_unset(config['variant']):
        code.append('variant = None')
    else:
        code.append(f'variant = {py_str(config["variant"])}')
    code.append(f'use_safetensors = {is_true(config["use_safetensors"])}')
    code.append(f'''pipeline = DiffusionPipeline.from_pretrained(
    {py_str(config["model"])},
    use_safetensors=use_safetensors,
    torch_dtype=data_type,
    variant=variant).to(device)''')
    # Offload only when cpu_offload is enabled.
    if offload:
        code.append('pipeline.enable_model_cpu_offload()')
    if use_refiner:
        # The refiner shares the second text encoder and the VAE of the base
        # pipeline, which this snippet names `pipeline`.
        code.append(f'''refiner = DiffusionPipeline.from_pretrained(
    {py_str(config["refiner"])},
    text_encoder_2=pipeline.text_encoder_2,
    vae=pipeline.vae,
    torch_dtype=data_type,
    use_safetensors=use_safetensors,
    variant=variant,
).to(device)''')
        if offload:
            code.append('refiner.enable_model_cpu_offload()')
    code.append(f'pipeline.requires_safety_checker = {is_true(config["requires_safety_checker"])}')
    if str(config['safety_checker']).lower() == 'false':
        code.append('pipeline.safety_checker = None')
    # Without this guard a scheduler of None would render as `None.from_config(...)`.
    if not is_unset(config['scheduler']):
        code.append(f'pipeline.scheduler = {config["scheduler"]}.from_config(pipeline.scheduler.config)')

    seed = config['manual_seed']
    # Check None/'' before comparing: `None < 0` and `'' < 0` raise TypeError.
    if seed is None or seed == '' or float(seed) < 0:
        code.append(f'# manual_seed = {seed}')
        code.append(f'generator = torch.Generator({py_str(config["device"])})')
    else:
        code.append(f'manual_seed = {seed}')
        code.append('generator = torch.manual_seed(manual_seed)')

    prompt = f'{config["prompt"]} {config.get("trigger_token", "")}'
    code.append(f'prompt = {py_str(prompt)}')
    code.append(f'negative_prompt = {py_str(config["negative_prompt"])}')
    code.append(f'inference_steps = {config["inference_steps"]}')
    code.append(f'guidance_scale = {config["guidance_scale"]}')
    code.append('''image = pipeline(
    prompt=prompt,
    negative_prompt=negative_prompt,
    generator=generator,
    num_inference_steps=inference_steps,
    guidance_scale=guidance_scale).images
''')
    if use_refiner:
        # Keep `.images` (a list) so the final `image[0]` works on both paths.
        code.append('''image = refiner(
    prompt=prompt,
    negative_prompt=negative_prompt,
    num_inference_steps=inference_steps,
    image=image,
).images''')
    code.append('image[0]')
    return '\r\n'.join(code)