File size: 3,651 Bytes
76a12b2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import functools
import random
from pathlib import Path

import yaml

from modules import shared
from modules.loaders import loaders_samplers
from modules.logging_colors import logger


def default_preset():
    """Return a fresh dict holding the baseline generation parameters.

    A new dict is built on every call, so callers are free to mutate it.
    Key order is significant: ``presets_params()`` lists parameter names in
    this order, and the UI helpers emit values positionally in that order.
    """
    return dict(
        temperature=1,
        temperature_last=False,
        dynamic_temperature=False,
        dynatemp_low=1,
        dynatemp_high=1,
        dynatemp_exponent=1,
        top_p=1,
        min_p=0,
        top_k=0,
        repetition_penalty=1,
        presence_penalty=0,
        frequency_penalty=0,
        repetition_penalty_range=1024,
        typical_p=1,
        tfs=1,
        top_a=0,
        epsilon_cutoff=0,
        eta_cutoff=0,
        guidance_scale=1,
        penalty_alpha=0,
        mirostat_mode=0,
        mirostat_tau=5,
        mirostat_eta=0.1,
        do_sample=True,
        encoder_repetition_penalty=1,
        no_repeat_ngram_size=0,
        min_length=0,
        num_beams=1,
        length_penalty=1,
        early_stopping=False,
    )


def presets_params():
    """Return the names of all preset parameters, in ``default_preset()`` order.

    Callers rely on this order to map parameter values to positional outputs.
    """
    # list() copies the dict's keys in insertion order; the previous identity
    # comprehension did the same thing less directly.
    return list(default_preset())


def load_preset(name):
    """Return generation parameters for the preset called *name*.

    Starts from ``default_preset()`` and overlays every key found in
    ``presets/<name>.yaml``. The values ``'None'``, ``None`` and ``''`` mean
    "no preset" and yield the defaults unchanged.

    A missing preset file is logged as an error and the defaults are returned.
    An empty preset file is treated as having no overrides. A file whose
    top-level YAML value is not a mapping is logged and ignored.
    """
    generate_params = default_preset()
    if name in ['None', None, '']:
        return generate_params

    path = Path(f'presets/{name}.yaml')
    if not path.exists():
        logger.error(f"The preset \"{name}\" does not exist under \"{path}\". Using the default parameters.")
        return generate_params

    # Read as UTF-8 explicitly so behavior doesn't depend on the platform's
    # locale encoding (e.g. cp1252 on Windows).
    with open(path, 'r', encoding='utf-8') as infile:
        preset = yaml.safe_load(infile)

    # safe_load yields None for an empty document; iterating that used to raise
    # TypeError. Only a mapping can meaningfully override parameters.
    if isinstance(preset, dict):
        generate_params.update(preset)
    elif preset is not None:
        logger.error(f"The preset \"{name}\" at \"{path}\" is not a mapping of parameter names to values. Using the default parameters.")

    return generate_params


@functools.lru_cache(maxsize=None)
def load_preset_memoized(name):
    """Cached wrapper around ``load_preset(name)``, keyed on *name*.

    The same dict object is returned for repeated calls with the same name,
    so callers should treat it as read-only. Edits to a preset file on disk
    are not picked up after the first load of that name.
    """
    params = load_preset(name)
    return params


def load_preset_for_ui(name, state):
    """Apply preset *name* to *state* and return the values the UI needs.

    *state* is updated in place with the preset's parameters. The result is
    a tuple: the updated *state* followed by one value per entry of
    ``presets_params()``, in that order.
    """
    preset_values = load_preset(name)
    state.update(preset_values)
    ordered = tuple(preset_values[param] for param in presets_params())
    return (state,) + ordered


def random_preset(state):
    """Randomize one sampler per category on top of the defaults.

    For each category below, one parameter is picked at random and set to a
    randomly chosen value from its candidate list. All other parameters keep
    their ``default_preset()`` values. When ``shared.args.loader`` is set,
    only parameters listed in ``loaders_samplers`` for that loader are
    eligible. A category with no eligible parameter is skipped.

    *state* is updated in place. The result is a tuple: *state* followed by
    one value per entry of ``presets_params()``, in that order.
    """
    samplers_by_category = {
        'remove_tail_tokens': {
            'top_p': [0.5, 0.8, 0.9, 0.95, 0.99],
            'min_p': [0.5, 0.2, 0.1, 0.05, 0.01],
            'top_k': [3, 5, 10, 20, 30, 40],
            'typical_p': [0.2, 0.575, 0.95],
            'tfs': [0.5, 0.8, 0.9, 0.95, 0.99],
            'top_a': [0.5, 0.2, 0.1, 0.05, 0.01],
            'epsilon_cutoff': [1, 3, 5, 7, 9],
            'eta_cutoff': [3, 6, 9, 12, 15, 18],
        },
        'flatten_distribution': {
            'temperature': [0.5, 0.7, 0.8, 1, 1.2, 1.5, 2.0],
        },
        'repetition': {
            'repetition_penalty': [1, 1.05, 1.1, 1.15, 1.20, 1.25],
            'presence_penalty': [0, 0.1, 0.2, 0.4, 0.6, 0.8, 1.0, 2.0],
            'frequency_penalty': [0, 0.1, 0.2, 0.4, 0.6, 0.8, 1.0, 2.0],
        },
        'other': {
            'temperature_last': [True, False],
        }
    }

    loader = shared.args.loader
    generate_params = default_preset()
    for options in samplers_by_category.values():
        candidates = list(options)
        if loader is not None:
            supported = loaders_samplers[loader]
            candidates = [param for param in candidates if param in supported]

        if not candidates:
            continue

        # Two draws per category, in this order: first the parameter, then its value.
        picked = random.choice(candidates)
        generate_params[picked] = random.choice(options[picked])

    state.update(generate_params)
    return (state, *(generate_params[param] for param in presets_params()))


def generate_preset_yaml(state):
    """Serialize the preset parameters held in *state* as a YAML string.

    Only parameters whose value in *state* differs from ``default_preset()``
    are emitted, in ``presets_params()`` order (keys are not sorted).
    A missing parameter in *state* raises KeyError.
    """
    baseline = default_preset()
    # Collect every parameter first so a missing key fails before any comparison.
    snapshot = {param: state[param] for param in presets_params()}
    # Keep only the non-default entries.
    changed = {
        param: value
        for param, value in snapshot.items()
        if not (value == baseline[param])
    }
    return yaml.dump(changed, sort_keys=False)