import base64

import torch
from diffusers import (
    DDPMScheduler,
    DDIMScheduler,
    PNDMScheduler,
    LMSDiscreteScheduler,
    EulerAncestralDiscreteScheduler,
    EulerDiscreteScheduler,
    DPMSolverMultistepScheduler,
)

def get_variant(str_variant):
    """Convert the string 'None' (any casing) to a real None; pass other values through."""
    if str(str_variant).lower() == 'none':
        return None
    return str_variant

def get_bool(str_bool):
    """Interpret the string 'false' (any casing) as False; every other value is True."""
    return str(str_bool).lower() != 'false'

def get_data_type(str_data_type):
    """Map a dtype name to a torch dtype; anything other than 'bfloat16' falls back to float16."""
    if str_data_type == "bfloat16":
        return torch.bfloat16  # bfloat16 is not supported on MPS as of 01/2024
    # Half-precision weights save GPU memory, see
    # https://huggingface.co/docs/diffusers/main/en/optimization/fp16
    return torch.float16
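
# Usage sketch: `torch_dtype` and `variant` are standard `from_pretrained`
# arguments in diffusers; the model id below is only an illustrative
# assumption, not something this module prescribes.
#
#   from diffusers import DiffusionPipeline
#   pipe = DiffusionPipeline.from_pretrained(
#       "runwayml/stable-diffusion-v1-5",       # hypothetical model id
#       torch_dtype=get_data_type("float16"),
#       variant=get_variant("None"),            # "None" from a UI becomes None
#   )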

def get_tensorfloat32(allow_tensorfloat32):
    """Interpret the string 'true' (any casing) as True; every other value is False."""
    return str(allow_tensorfloat32).lower() == 'true'
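
# Usage sketch: TensorFloat-32 trades a little matmul precision for speed on
# Ampere and newer NVIDIA GPUs. `torch.backends.cuda.matmul.allow_tf32` is a
# real torch flag; the string would typically come from a UI control.
#
#   torch.backends.cuda.matmul.allow_tf32 = get_tensorfloat32("True")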

# Scheduler names accepted by get_scheduler; unknown names fall back to
# DPMSolverMultistepScheduler.
SCHEDULERS = {
    "DDPMScheduler": DDPMScheduler,
    "DDIMScheduler": DDIMScheduler,
    "PNDMScheduler": PNDMScheduler,
    "LMSDiscreteScheduler": LMSDiscreteScheduler,
    "EulerAncestralDiscreteScheduler": EulerAncestralDiscreteScheduler,
    "EulerDiscreteScheduler": EulerDiscreteScheduler,
    "DPMSolverMultistepScheduler": DPMSolverMultistepScheduler,
}

def get_scheduler(scheduler, pipeline_config):
    """Instantiate the named scheduler from a pipeline's scheduler config."""
    return SCHEDULERS.get(scheduler, DPMSolverMultistepScheduler).from_config(pipeline_config)
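
# Usage sketch: swap a loaded pipeline's scheduler in place. Reusing
# `pipe.scheduler.config` keeps the checkpoint's trained noise schedule;
# `pipe` is assumed to be a pipeline created elsewhere.
#
#   pipe.scheduler = get_scheduler("EulerDiscreteScheduler", pipe.scheduler.config)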

def dict_list_to_markdown_table(config_history):
    """Render a list of config dicts as a Markdown table wrapped in a scrollable div.

    Column headers come from the first dict; a leading 'share' column links
    back to the app with the config base64-encoded into the URL query string.
    """
    if not config_history:
        return ""

    headers = list(config_history[0].keys())
    markdown_table = "| share | " + " | ".join(headers) + " |\n"
    markdown_table += "| --- | " + " | ".join(["---"] * len(headers)) + " |\n"

    for config in config_history:
        encoded_config = base64.b64encode(str(config).encode()).decode()
        share_link = f'<a target="_blank" href="?config={encoded_config}">📎</a>'
        markdown_table += f"| {share_link} | " + " | ".join(str(config.get(key, "")) for key in headers) + " |\n"

    # Wrap in a scrollable container so wide tables do not break the page layout.
    return '<div style="overflow-x: auto;">\n\n' + markdown_table + '</div>'
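
# Usage sketch with a made-up config history (keys and values are
# illustrative):
#
#   history = [
#       {"scheduler": "DDIMScheduler", "steps": 20, "guidance_scale": 7.5},
#       {"scheduler": "EulerDiscreteScheduler", "steps": 30, "guidance_scale": 5.0},
#   ]
#   print(dict_list_to_markdown_table(history))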