|
import ast
import base64
import json

import gradio as gr
import torch
from gradio import Request
from gradio.context import Context
|
|
|
|
|
def persist(component):
    """Make ``component``'s value survive page reloads, per user.

    Registers two handlers: on page load the stored value (if any) is pushed
    back into ``component``; on every change of ``component`` the new value
    is stored.

    Args:
        component: The Gradio component whose value should be persisted.
            NOTE(review): ``Context.root_block`` is presumably the Blocks
            currently being built, so this should be called inside a
            ``with gr.Blocks()`` context — confirm.

    Returns:
        ``component`` itself, so the call can wrap component creation.
    """
    # Plain in-process dict shared by both closures below; it is not written
    # anywhere, so its contents are lost when the process restarts.
    sessions = {}

    def resume_session(value, request: Request):
        # Return the stored value for this user, or the component's current
        # value if nothing has been stored yet.
        # NOTE(review): request.username is presumably None when auth is not
        # enabled, in which case all visitors share a single entry — confirm.
        return sessions.get(request.username, value)

    def update_session(value, request: Request):
        # Remember the latest value under this user's name.
        sessions[request.username] = value

    # NOTE(review): Gradio may derive API endpoint names from these handler
    # function names; confirm before renaming them.
    Context.root_block.load(resume_session, inputs=[component], outputs=component)

    component.change(update_session, inputs=[component])

    return component
|
|
|
def get_initial_config():
    """Build the default pipeline configuration for the current machine.

    The device is chosen in priority order CUDA -> MPS -> CPU. On CUDA the
    defaults switch to ``bfloat16`` with TensorFloat-32 allowed; otherwise
    ``float16`` without TF32.

    Boolean-like options are stored as the strings ``"True"``/``"False"``
    because they are later pasted verbatim into generated code.

    Returns:
        dict: A fresh configuration dictionary.
    """
    if torch.cuda.is_available():
        device, data_type, allow_tensorfloat32 = "cuda", 'bfloat16', "True"
    elif torch.backends.mps.is_available():
        device, data_type, allow_tensorfloat32 = "mps", 'float16', "False"
    else:
        device, data_type, allow_tensorfloat32 = "cpu", 'float16', "False"

    return dict(
        device=device,
        model=None,
        cpu_offload="False",
        scheduler=None,
        variant=None,
        allow_tensorfloat32=allow_tensorfloat32,
        use_safetensors="False",
        data_type=data_type,
        refiner="none",
        safety_checker="False",
        requires_safety_checker="False",
        manual_seed=42,
        inference_steps=10,
        guidance_scale=0.5,
        prompt='A white rabbit',
        trigger_token='',
        negative_prompt='lowres, cropped, worst quality, low quality, chat bubble, chat bubbles, ugly',
    )
|
|
|
def _decode_url_config(encoded_params):
    """Decode a base64-encoded ``config`` query value into a dict.

    Accepts either JSON or a Python ``repr`` of a dict (single quotes,
    ``None``/``True``/``False``). Returns ``None`` (after printing a message)
    when the value cannot be decoded or does not describe a dict.
    """
    # Query-string parsing turns an unescaped '+' into a space; standard
    # base64 never contains spaces, so restoring them is lossless.
    encoded_params = encoded_params.replace(' ', '+')
    try:
        # binascii.Error and UnicodeDecodeError are both ValueError subclasses.
        text = base64.b64decode(encoded_params).decode('utf-8')
    except ValueError as e:
        print("Could not decode config URL parameter:", str(e))
        return None

    try:
        parsed = json.loads(text)
    except (ValueError, RecursionError):
        # Fall back to Python-literal syntax. ast.literal_eval only evaluates
        # literals, so it is safe on untrusted URL input (never use eval here).
        try:
            parsed = ast.literal_eval(text)
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError) as e:
            print("Could not parse config URL parameter:", str(e))
            return None

    if not isinstance(parsed, dict):
        print("Config URL parameter is not a dictionary; ignoring it.")
        return None
    return parsed


def get_config_from_url(initial_config, request: Request):
    """Return the UI field values, optionally overridden by a ``?config=`` URL parameter.

    The ``config`` query parameter, when present, is a base64-encoded JSON
    object or Python dict literal. Its entries are overlaid on
    ``initial_config``, so a link that omits some keys still loads. An
    undecodable parameter is reported and ignored, and the defaults are used.

    Args:
        initial_config: Default configuration dict (e.g. from
            ``get_initial_config``).
        request: The incoming Gradio request; its underlying request's query
            parameters are read.

    Returns:
        list: Sixteen values in the order model, device, cpu_offload,
        use_safetensors, data_type, refiner, variant, safety_checker,
        requires_safety_checker, scheduler, prompt, trigger_token,
        negative_prompt, inference_steps, manual_seed, guidance_scale.
        NOTE(review): presumably this order matches the output components
        wired to this handler — confirm when changing either side.
    """
    encoded_params = request.request.query_params.get('config')

    return_config = initial_config
    if encoded_params is not None:
        url_config = _decode_url_config(encoded_params)
        if url_config is not None:
            # Overlay on the defaults so a link missing some keys (e.g. one
            # created by an older version) does not raise KeyError below.
            return_config = {**(initial_config or {}), **url_config}

    keys = (
        'model',
        'device',
        'cpu_offload',
        'use_safetensors',
        'data_type',
        'refiner',
        'variant',
        'safety_checker',
        'requires_safety_checker',
        'scheduler',
        'prompt',
        'trigger_token',
        'negative_prompt',
        'inference_steps',
        'manual_seed',
        'guidance_scale',
    )
    return [return_config[key] for key in keys]
|
|
|
def load_app_config(path='appConfig.json'):
    """Load the application's JSON configuration file.

    Args:
        path: File to read; defaults to ``appConfig.json`` relative to the
            current working directory.

    Returns:
        The parsed JSON value. If the file is missing, unreadable, or not
        valid JSON, a message is printed and an empty dict is returned.
    """
    try:
        with open(path, 'r', encoding='utf-8') as f:
            return json.load(f)
    except FileNotFoundError:
        print("App config file not found.")
    except json.JSONDecodeError:
        print("Error decoding JSON in app config file.")
    except Exception as e:
        # Best-effort boundary: report and continue with an empty config.
        print("An error occurred while loading app config:", str(e))
    # The original fell through to `return appConfig` with the name unbound,
    # so every handled error surfaced as UnboundLocalError instead.
    return {}
|
|
|
def set_config(config, key, value):
    """Store ``value`` under ``key`` in ``config`` and hand the same object back.

    ``config`` is mutated in place; the return value is that same object,
    which allows the function to serve directly as an event handler output.
    """
    updated = config
    updated[key] = value
    return updated
|
|
|
def _py_str(value):
    """Render ``value`` as a double-quoted Python string literal for generated code.

    json.dumps escapes quotes, backslashes and control characters using
    escape sequences that are also valid in Python string literals, so
    user-supplied text (prompts, model ids) cannot break the generated code.
    """
    return json.dumps(str(value), ensure_ascii=False)


def assemble_code(str_config):
    """Generate a diffusers script that reproduces the given configuration.

    Args:
        str_config: Configuration dict with the keys produced by
            ``get_initial_config`` plus ``model``. Boolean-like options are
            the strings ``"True"``/``"False"`` (compared case-insensitively)
            and are pasted verbatim into the generated code.

    Returns:
        str: The generated source lines joined with ``"\\r\\n"``.
    """
    config = str_config
    code = []

    def _is_true(key):
        # Options may be "True"/"False" strings or real bools.
        return str(config[key]).lower() == 'true'

    use_refiner = str(config['refiner']).lower() != 'none'

    code.append(f'device = {_py_str(config["device"])}')
    if config['data_type'] == "bfloat16":
        code.append('data_type = torch.bfloat16')
    else:
        code.append('data_type = torch.float16')
    code.append(f'torch.backends.cuda.matmul.allow_tf32 = {config["allow_tensorfloat32"]}')

    if str(config["variant"]) == 'None':
        code.append('variant = None')
    else:
        code.append(f'variant = {_py_str(config["variant"])}')

    code.append(f'use_safetensors = {config["use_safetensors"]}')

    model = _py_str(config['model'])
    code.append(f'''pipeline = DiffusionPipeline.from_pretrained(
    {model},
    use_safetensors=use_safetensors,
    torch_dtype=data_type,
    variant=variant).to(device)''')

    # Offload only when requested (the original emitted it when cpu_offload was False).
    if _is_true('cpu_offload'):
        code.append('pipeline.enable_model_cpu_offload()')

    if use_refiner:
        refiner = _py_str(config['refiner'])
        # Share components with the base pipeline, which the generated code
        # names `pipeline` (the original referenced an undefined `base`).
        code.append(f'''refiner = DiffusionPipeline.from_pretrained(
    {refiner},
    text_encoder_2=pipeline.text_encoder_2,
    vae=pipeline.vae,
    torch_dtype=data_type,
    use_safetensors=use_safetensors,
    variant=variant,
).to(device)''')
        # Same offload switch as the base pipeline (the original checked
        # use_safetensors, inverted).
        if _is_true('cpu_offload'):
            code.append('refiner.enable_model_cpu_offload()')

    code.append(f'pipeline.requires_safety_checker = {config["requires_safety_checker"]}')

    if str(config["safety_checker"]).lower() == 'false':
        code.append('pipeline.safety_checker = None')

    # Without a chosen scheduler keep the pipeline's default instead of
    # emitting `None.from_config(...)`.
    scheduler = config['scheduler']
    if scheduler is not None and str(scheduler) not in ('', 'None'):
        code.append(f'pipeline.scheduler = {scheduler}.from_config(pipeline.scheduler.config)')

    # Check for "unset" before comparing, otherwise `None < 0` raises TypeError.
    seed = config['manual_seed']
    if seed is None or seed == '' or float(seed) < 0:
        code.append(f'# manual_seed = {seed}')
        code.append(f'generator = torch.Generator({_py_str(config["device"])})')
    else:
        code.append(f'manual_seed = {seed}')
        code.append('generator = torch.manual_seed(manual_seed)')

    prompt_text = f"{config['prompt']} {config['trigger_token']}"
    code.append(f'prompt = {_py_str(prompt_text)}')
    code.append(f'negative_prompt = {_py_str(config["negative_prompt"])}')
    code.append(f'inference_steps = {config["inference_steps"]}')
    code.append(f'guidance_scale = {config["guidance_scale"]}')

    code.append('''image = pipeline(
    prompt=prompt,
    negative_prompt=negative_prompt,
    generator=generator,
    num_inference_steps=inference_steps,
    guidance_scale=guidance_scale).images''')

    if use_refiner:
        # Keep the list of images so the final `image[0]` works in both paths.
        code.append('''image = refiner(
    prompt=prompt,
    negative_prompt=negative_prompt,
    num_inference_steps=inference_steps,
    image=image,
).images''')

    code.append('image[0]')

    return '\r\n'.join(code)