File size: 2,701 Bytes
b96c8c5 56914a9 b96c8c5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 |
import torch
from diffusers import (
DDPMScheduler,
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
EulerAncestralDiscreteScheduler,
EulerDiscreteScheduler,
DPMSolverMultistepScheduler,
)
import base64
def get_variant(str_variant):
    """Return *str_variant* unchanged, or ``None`` when it spells "none".

    The check is case-insensitive and applied to ``str(str_variant)``, so the
    strings ``"None"``/``"none"`` and the ``None`` object all yield ``None``.
    """
    is_none_marker = str(str_variant).lower() == 'none'
    return None if is_none_marker else str_variant
def get_bool(str_bool):
    """Interpret *str_bool* as a flag that defaults to ``True``.

    Only values whose ``str()`` lower-cases to ``"false"`` (e.g. ``"False"``
    or the ``False`` object) give ``False``; anything else -- ``"true"``,
    ``""``, ``None``, ``0`` -- gives ``True``.
    """
    return str(str_bool).lower() != 'false'
def get_data_type(str_data_type):
    """Map a dtype name to the corresponding ``torch`` dtype.

    Args:
        str_data_type: ``"bfloat16"`` or ``"float32"`` (exact, case-sensitive
            match). Any other value, including ``"float16"``, falls back to
            half precision.

    Returns:
        ``torch.bfloat16``, ``torch.float32`` or ``torch.float16``.
    """
    if str_data_type == "bfloat16":
        # Note: BFloat16 is not supported on MPS (as of 01/2024).
        return torch.bfloat16
    if str_data_type == "float32":
        # Full single precision.
        return torch.float32
    # Default: half-precision weights save GPU memory, see
    # https://huggingface.co/docs/diffusers/main/en/optimization/fp16
    return torch.float16
def get_tensorfloat32(allow_tensorfloat32):
    """Return ``True`` only when *allow_tensorfloat32* reads as "true".

    The check is case-insensitive on ``str(allow_tensorfloat32)``, so both the
    ``True`` object and strings like ``"TRUE"`` enable it; every other value
    (including ``None`` and ``""``) disables it.
    """
    # The comparison already yields a bool; a ternary is redundant.
    return str(allow_tensorfloat32).lower() == 'true'
def get_scheduler(scheduler, pipeline_config):
    """Create a diffusers scheduler selected by class name.

    Args:
        scheduler: Name of the scheduler class, e.g. ``"EulerDiscreteScheduler"``.
            Unrecognised names fall back to ``DPMSolverMultistepScheduler``.
        pipeline_config: Config handed to the chosen class's ``from_config``
            (presumably the current pipeline scheduler's config -- confirm at
            the call site).

    Returns:
        The scheduler instance returned by ``<SchedulerClass>.from_config``.
    """
    # Name -> class dispatch table; replaces a long if/elif chain whose
    # explicit DPMSolverMultistepScheduler branch duplicated the fallback.
    scheduler_classes = {
        "DDPMScheduler": DDPMScheduler,
        "DDIMScheduler": DDIMScheduler,
        "PNDMScheduler": PNDMScheduler,
        "LMSDiscreteScheduler": LMSDiscreteScheduler,
        "EulerAncestralDiscreteScheduler": EulerAncestralDiscreteScheduler,
        "EulerDiscreteScheduler": EulerDiscreteScheduler,
        "DPMSolverMultistepScheduler": DPMSolverMultistepScheduler,
    }
    scheduler_class = scheduler_classes.get(scheduler, DPMSolverMultistepScheduler)
    return scheduler_class.from_config(pipeline_config)
def dict_list_to_markdown_table(config_history):
    """Render a list of config dicts as a horizontally scrollable Markdown table.

    Column headers are the keys of the first dict; a key missing from a later
    dict renders as an empty cell, and extra keys in later dicts are ignored.
    Each row starts with a "share" cell containing a link whose ``config``
    query parameter is the base64 encoding of ``str()`` of that row's dict.

    Args:
        config_history: Sequence of dicts, one per table row.

    Returns:
        The Markdown table wrapped in a ``<div style="overflow-x: auto;">``,
        or ``""`` when *config_history* is empty.
    """
    from urllib.parse import quote

    if not config_history:
        return ""

    def _cell(value):
        # Escape what would break the table layout: an unescaped "|" starts a
        # new column and a raw newline terminates the row early.
        text = str(value).replace("|", "\\|")
        return text.replace("\r\n", "<br>").replace("\n", "<br>")

    headers = list(config_history[0].keys())
    rows = [
        "| share | " + " | ".join(_cell(header) for header in headers) + " |",
        "| --- | " + " | ".join(["---"] * len(headers)) + " |",
    ]
    for config in config_history:
        encoded_config = base64.b64encode(str(config).encode()).decode()
        # Standard base64 may contain "+", "/" and "="; query-string parsers
        # decode a bare "+" as a space, corrupting the payload. Percent-encode
        # it so that decoding the query parameter yields the exact base64 text.
        config_param = quote(encoded_config, safe="")
        share_link = f'<a target="_blank" href="?config={config_param}">📎</a>'
        cells = " | ".join(_cell(config.get(key, "")) for key in headers)
        rows.append(f"| {share_link} | {cells} |")
    markdown_table = "\n".join(rows) + "\n"
    return '<div style="overflow-x: auto;">\n\n' + markdown_table + "</div>"
|