import base64
import json

import torch

import gradio as gr
from gradio import Request
from gradio.context import Context


def persist(component):
    """Persist a component's value across page reloads, keyed by the logged-in username."""
    sessions = {}

    def resume_session(value, request: Request):
        # Restore this user's stored value, falling back to the current one.
        return sessions.get(request.username, value)

    def update_session(value, request: Request):
        # Remember the latest value whenever the component changes.
        sessions[request.username] = value

    Context.root_block.load(resume_session, inputs=[component], outputs=component)
    component.change(update_session, inputs=[component])

    return component
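
# A minimal usage sketch (hypothetical layout; `request.username` is only
# populated when the app is launched with authentication):
#
#     with gr.Blocks() as demo:
#         prompt_box = persist(gr.Textbox(label="Prompt"))
#     demo.launch(auth=[("user", "password")])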


def get_initial_config():
    # Conservative defaults; upgraded below when CUDA is available.
    data_type = 'float16'
    allow_tensorfloat32 = "False"

    if torch.cuda.is_available():
        device = "cuda"
        data_type = 'bfloat16'
        allow_tensorfloat32 = "True"
    elif torch.backends.mps.is_available():
        device = "mps"
    else:
        device = "cpu"

    # Boolean-like settings are kept as strings so assemble_code can
    # interpolate them verbatim into the generated snippet.
    config = {
        "device": device,
        "model": None,
        "scheduler": None,
        "variant": None,
        "allow_tensorfloat32": allow_tensorfloat32,
        "use_safetensors": "False",
        "data_type": data_type,
        "safety_checker": "False",
        "requires_safety_checker": "False",
        "manual_seed": 42,
        "inference_steps": 10,
        "guidance_scale": 0.5,
        "prompt": 'A white rabbit',
        "negative_prompt": 'lowres, cropped, worst quality, low quality, chat bubble, chat bubbles, ugly',
    }

    return config


def get_config_from_url(initial_config, request: Request):
    # A shared link may carry the whole config as a base64-encoded
    # `?config=...` query parameter; otherwise fall back to the defaults.
    encoded_params = request.request.query_params.get('config')

    if encoded_params is not None:
        # The parameter holds the str() of a Python dict; normalise it into
        # valid JSON (single -> double quotes, None -> null) before parsing.
        decoded_params = base64.b64decode(encoded_params).decode('utf-8')
        decoded_params = decoded_params.replace("'", '"').replace('None', 'null')
        return_config = json.loads(decoded_params)
    else:
        return_config = initial_config

    # Return values in the order of the UI components they feed.
    return [return_config['model'],
            return_config['device'],
            return_config['use_safetensors'],
            return_config['data_type'],
            return_config['variant'],
            return_config['safety_checker'],
            return_config['requires_safety_checker'],
            return_config['scheduler'],
            return_config['prompt'],
            return_config['negative_prompt'],
            return_config['inference_steps'],
            return_config['manual_seed'],
            return_config['guidance_scale']
            ]
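
# A hedged sketch of building the `?config=...` parameter this function
# expects: base64-encode the dict's str() form (host and port are
# assumptions, and the value may additionally need URL-escaping):
#
#     encoded = base64.b64encode(str(get_initial_config()).encode('utf-8')).decode('utf-8')
#     share_url = f"http://localhost:7860/?config={encoded}"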


def load_app_config():
    appConfig = {}  # fall back to an empty config if loading fails
    try:
        with open('appConfig.json', 'r') as f:
            appConfig = json.load(f)
    except FileNotFoundError:
        print("App config file not found.")
    except json.JSONDecodeError:
        print("Error decoding JSON in app config file.")
    except Exception as e:
        print("An error occurred while loading app config:", str(e))

    return appConfig
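
# appConfig.json's schema isn't fixed by this module; a purely hypothetical
# example listing selectable models and schedulers might look like:
#
#     {"models": ["runwayml/stable-diffusion-v1-5"],
#      "schedulers": ["DDIMScheduler", "EulerDiscreteScheduler"]}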


def set_config(config, key, value):
    config[key] = value

    return config


def assemble_code(str_config):
    config = str_config

    # Snippets keyed by a sortable prefix; joined in key order at the end.
    code = {}

    code['001_device'] = f'''device = "{config['device']}"'''
    if config['data_type'] == "bfloat16":
        code['002_data_type'] = 'data_type = torch.bfloat16'
    else:
        code['002_data_type'] = 'data_type = torch.float16'
    code['003_tf32'] = f'torch.backends.cuda.matmul.allow_tf32 = {config["allow_tensorfloat32"]}'

    if str(config["variant"]) == 'None':
        code['004_variant'] = f'variant = {config["variant"]}'
    else:
        code['004_variant'] = f'variant = "{config["variant"]}"'

    code['005_use_safetensors'] = f'use_safetensors = {config["use_safetensors"]}'

    # The generated snippet assumes the consumer imports torch and
    # diffusers' DiffusionPipeline.
    code['050_init_pipe'] = f'''pipeline = DiffusionPipeline.from_pretrained(
    "{config['model']}",
    use_safetensors=use_safetensors,
    torch_dtype=data_type,
    variant=variant).to(device)'''

    code['054_requires_safety_checker'] = f'pipeline.requires_safety_checker = {config["requires_safety_checker"]}'

    if str(config["safety_checker"]).lower() == 'false':
        code['055_safety_checker'] = 'pipeline.safety_checker = None'
    else:
        code['055_safety_checker'] = ''

    code['060_scheduler'] = f'pipeline.scheduler = {config["scheduler"]}.from_config(pipeline.scheduler.config)'

    # Check for "unset" first: comparing None or '' with < would raise a TypeError.
    if config['manual_seed'] is None or config['manual_seed'] == '' or int(config['manual_seed']) < 0:
        code['091_manual_seed'] = f'# manual_seed = {config["manual_seed"]}'
        code['092_generator'] = f'generator = torch.Generator("{config["device"]}")'
    else:
        code['091_manual_seed'] = f'manual_seed = {config["manual_seed"]}'
        code['092_generator'] = 'generator = torch.manual_seed(manual_seed)'

    code["080_prompt"] = f'prompt = "{config["prompt"]}"'
    code["085_negative_prompt"] = f'negative_prompt = "{config["negative_prompt"]}"'
    code["090_inference_steps"] = f'inference_steps = {config["inference_steps"]}'
    code["095_guidance_scale"] = f'guidance_scale = {config["guidance_scale"]}'

    code["100_run_inference"] = f'''image = pipeline(
    prompt=prompt,
    negative_prompt=negative_prompt,
    generator=generator,
    num_inference_steps=inference_steps,
    guidance_scale=guidance_scale).images[0]'''

    return '\r\n'.join(value[1] for value in sorted(code.items()))
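
# A small usage sketch: fill in the fields that default to None, then
# assemble the snippet (the model and scheduler names are just examples):
#
#     cfg = get_initial_config()
#     cfg = set_config(cfg, 'model', 'runwayml/stable-diffusion-v1-5')
#     cfg = set_config(cfg, 'scheduler', 'DDIMScheduler')
#     print(assemble_code(cfg))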