|
import torch |
|
from diffusers import ( |
|
DDPMScheduler, |
|
DDIMScheduler, |
|
PNDMScheduler, |
|
LMSDiscreteScheduler, |
|
EulerAncestralDiscreteScheduler, |
|
EulerDiscreteScheduler, |
|
DPMSolverMultistepScheduler, |
|
) |
|
import base64 |
|
|
|
def get_variant(str_variant):
    """Normalise a model-variant setting.

    The textual value ``'none'`` (compared case-insensitively after ``str()``
    conversion, so the ``None`` object itself also matches) becomes ``None``.
    Every other value is handed back unchanged.
    """
    is_none_marker = str(str_variant).lower() == 'none'
    return None if is_none_marker else str_variant
|
|
|
def get_bool(str_bool):
    """Interpret a setting as a boolean, defaulting to True.

    Only a value whose ``str()`` form equals ``'false'`` (case-insensitive)
    yields False. Anything else, including ``None``, ``'0'`` or an empty
    string, yields True.
    """
    lowered = str(str_bool).lower()
    return lowered != 'false'
|
|
|
|
|
def get_data_type(str_data_type):
    """Map a dtype name to a torch dtype.

    Exactly ``"bfloat16"`` selects ``torch.bfloat16``. Every other value
    (including other dtype names) falls back to ``torch.float16``.
    """
    wants_bfloat16 = str_data_type == "bfloat16"
    return torch.bfloat16 if wants_bfloat16 else torch.float16
|
|
|
def get_tensorfloat32(allow_tensorfloat32):
    """Interpret the "allow TensorFloat32" setting as a boolean, defaulting to False.

    Returns True only when the ``str()`` form of the value equals ``'true'``
    (case-insensitive). Anything else, including ``None``, returns False.
    This is the opposite default from ``get_bool``.
    """
    # The comparison already yields a bool; no ternary needed.
    return str(allow_tensorfloat32).lower() == 'true'
|
|
|
def get_scheduler(scheduler, pipeline_config):
    """Instantiate a diffusers scheduler by class name.

    ``scheduler`` is compared with ``==`` against the supported class names.
    The matching class is built with ``from_config(pipeline_config)``. Any
    unrecognised name falls back to ``DPMSolverMultistepScheduler``.
    """
    # Ordered (name, class) pairs scanned with ``==``. This mirrors the
    # original if/elif chain exactly, including for unhashable inputs.
    known_schedulers = (
        ("DDPMScheduler", DDPMScheduler),
        ("DDIMScheduler", DDIMScheduler),
        ("PNDMScheduler", PNDMScheduler),
        ("LMSDiscreteScheduler", LMSDiscreteScheduler),
        ("EulerAncestralDiscreteScheduler", EulerAncestralDiscreteScheduler),
        ("EulerDiscreteScheduler", EulerDiscreteScheduler),
        ("DPMSolverMultistepScheduler", DPMSolverMultistepScheduler),
    )
    chosen_cls = next(
        (cls for name, cls in known_schedulers if scheduler == name),
        DPMSolverMultistepScheduler,
    )
    return chosen_cls.from_config(pipeline_config)
|
|
|
def dict_list_to_markdown_table(config_history):
    """Render a list of config dicts as a horizontally scrollable Markdown table.

    Columns are the keys of the first dict; later dicts missing a key get an
    empty cell. Each row starts with a "share" link whose ``config`` query
    parameter is the base64 of ``str(config)``. The result is wrapped in an
    ``overflow-x: auto`` ``<div>``.

    Args:
        config_history: list of dicts (one per row). Falsy input yields "".

    Returns:
        The HTML-wrapped Markdown table as a string.
    """
    from urllib.parse import quote

    if not config_history:
        return ""

    headers = list(config_history[0].keys())

    def _cell(value):
        # A literal '|' or a line break would otherwise split/terminate the table row.
        text = str(value).replace("|", "\\|")
        return text.replace("\r\n", "<br>").replace("\n", "<br>").replace("\r", "<br>")

    lines = [
        "| share | " + " | ".join(_cell(header) for header in headers) + " |",
        "| --- | " + " | ".join(["---"] * len(headers)) + " |",
    ]

    for config in config_history:
        encoded_config = base64.b64encode(str(config).encode()).decode()
        # Standard base64 may contain '+', '/' and '='. Query-string parsers
        # decode a raw '+' as a space, corrupting the payload, so
        # percent-encode it; any standard decoder restores the exact base64.
        url_safe_config = quote(encoded_config, safe="")
        share_link = f'<a target="_blank" href="?config={url_safe_config}">π</a>'
        row_cells = " | ".join(_cell(config.get(key, "")) for key in headers)
        lines.append(f"| {share_link} | " + row_cells + " |")

    # Build once with join instead of repeated string concatenation.
    markdown_table = "\n".join(lines) + "\n"
    return '<div style="overflow-x: auto;">\n\n' + markdown_table + '</div>'
|
|
|
|