import re


def sanitize(prompt):
    # keep only ASCII letters and spaces, then turn spaces into underscores
    # so the prompt can safely be used as a filename
    whitelist = set('abcdefghijklmnopqrstuvwxyz ABCDEFGHIJKLMNOPQRSTUVWXYZ')
    tmp = ''.join(filter(whitelist.__contains__, prompt))
    return tmp.replace(' ', '_')
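
# A quick illustration of the filter above (derived from the code, not part
# of the original file):
#   sanitize("bacon, eggs & spam:0.5")  ->  'bacon_eggs__spam'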

def check_is_number(value):
    # matches an optional sign, optional integer part, and optional fractional
    # part; the lookahead (?=.) rejects the empty string
    float_pattern = r'^(?=.)([+-]?([0-9]*)(\.([0-9]+))?)$'
    return re.match(float_pattern, value)
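
# For instance (illustrative):
#   check_is_number("0.75")    ->  re.Match object (truthy)
#   check_is_number("`t/10`")  ->  None (falsy), so it falls through to numexpr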

# prompt weighting with colons and number coefficients (like 'bacon:0.75 eggs:0.25')
# borrowed from https://github.com/kylewlacy/stable-diffusion/blob/0a4397094eb6e875f98f9d71193e350d859c4220/ldm/dream/conditioning.py
# and https://github.com/raefu/stable-diffusion-automatic/blob/unstablediffusion/modules/processing.py
def get_uc_and_c(prompts, model, args, frame=0):
    prompt = prompts[0]  # they are the same in a batch anyway
    # get weighted sub-prompts
    negative_subprompts, positive_subprompts = split_weighted_subprompts(
        prompt, frame, not args.normalize_prompt_weights
    )
    uc = get_learned_conditioning(model, negative_subprompts, "", args, -1)
    c = get_learned_conditioning(model, positive_subprompts, prompt, args, 1)
    return (uc, c)

def get_learned_conditioning(model, weighted_subprompts, text, args, sign=1):
    if len(weighted_subprompts) < 1:
        log_tokenization(text, model, args.log_weighted_subprompts, sign)
        c = model.get_learned_conditioning(args.n_samples * [text])
    else:
        c = None
        for subtext, subweight in weighted_subprompts:
            log_tokenization(subtext, model, args.log_weighted_subprompts, sign * subweight)
            if c is None:
                c = model.get_learned_conditioning(args.n_samples * [subtext])
                c *= subweight
            else:
                # accumulate the weighted sum of conditioning tensors in place
                c.add_(model.get_learned_conditioning(args.n_samples * [subtext]), alpha=subweight)
    return c
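
# The loop above builds a weighted sum of conditioning tensors. A minimal
# sketch of the same arithmetic with plain torch tensors (the tensors here
# are stand-ins for model conditionings, purely for illustration):
#   import torch
#   a, b = torch.ones(2), torch.full((2,), 2.0)
#   c = a.clone()
#   c *= 0.75              # first subprompt, weight 0.75
#   c.add_(b, alpha=0.25)  # second subprompt, weight 0.25
#   # c is now tensor([1.2500, 1.2500]) == 0.75*a + 0.25*b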

def parse_weight(match, frame=0) -> float:
    import numexpr
    w_raw = match.group("weight")
    if w_raw is None:
        return 1
    if check_is_number(w_raw):
        return float(w_raw)
    else:
        # expose the frame index as `t` so expressions like `cos(t/10)` can
        # animate the weight; numexpr picks `t` up from the local scope
        t = frame
        if len(w_raw) < 3:
            print('the value between backticks is too short to be a math expression; defaulting to a weight of 1')
            return 1
        # strip the surrounding backticks and evaluate the expression
        return float(numexpr.evaluate(w_raw[1:-1]))

def normalize_prompt_weights(parsed_prompts):
    if len(parsed_prompts) == 0:
        return parsed_prompts
    weight_sum = sum(map(lambda x: x[1], parsed_prompts))
    if weight_sum == 0:
        print(
            "Warning: Subprompt weights add up to zero. Discarding and using even weights instead.")
        equal_weight = 1 / max(len(parsed_prompts), 1)
        return [(x[0], equal_weight) for x in parsed_prompts]
    return [(x[0], x[1] / weight_sum) for x in parsed_prompts]
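
# A quick illustration of the normalization (weights are rescaled to sum to 1):
#   normalize_prompt_weights([('bacon', 1.0), ('eggs', 3.0)])
#   ->  [('bacon', 0.25), ('eggs', 0.75)]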

def split_weighted_subprompts(text, frame=0, skip_normalize=False):
    """
    grabs all text up to the first occurrence of ':',
    uses the grabbed text as a sub-prompt, and takes the value following ':' as weight.
    if ':' has no value defined, defaults to 1.0.
    repeats until no text remains
    """
    prompt_parser = re.compile(r"""
            (?P<prompt>         # capture group for 'prompt'
            (?:\\\:|[^:])+      # match one or more non ':' characters or escaped colons '\:'
            )                   # end 'prompt'
            (?:                 # non-capture group
            :+                  # match one or more ':' characters
            (?P<weight>((       # capture group for 'weight'
            -?\d+(?:\.\d+)?     # match positive or negative integer or decimal number
            )|(                 # or
            `[\S\s]*?`          # a math function enclosed in backticks
            )))?                # end weight capture group, make optional
            \s*                 # strip spaces after weight
            |                   # OR
            $                   # else, if no ':' then match end of line
            )                   # end non-capture group
            """, re.VERBOSE)
    negative_prompts = []
    positive_prompts = []
    for match in re.finditer(prompt_parser, text):
        w = parse_weight(match, frame)
        if w < 0:
            # negating the sign as we'll feed this to uc
            negative_prompts.append((match.group("prompt").replace("\\:", ":"), -w))
        elif w > 0:
            positive_prompts.append((match.group("prompt").replace("\\:", ":"), w))
    if skip_normalize:
        return (negative_prompts, positive_prompts)
    return (normalize_prompt_weights(negative_prompts), normalize_prompt_weights(positive_prompts))
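
# Worked examples of the splitter (illustrative; at frame 0, `cos(t/10)`
# evaluates to cos(0) == 1.0):
#   split_weighted_subprompts("bacon:0.75 eggs:0.25")
#   ->  ([], [('bacon', 0.75), ('eggs', 0.25)])
#   split_weighted_subprompts("stew:`cos(t/10)`", frame=0, skip_normalize=True)
#   ->  ([], [('stew', 1.0)])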

# shows how the prompt is tokenized
# usually tokens have '</w>' to indicate end-of-word,
# but for readability it has been replaced with ' '
def log_tokenization(text, model, log=False, weight=1):
    if not log:
        return
    tokens = model.cond_stage_model.tokenizer._tokenize(text)
    tokenized = ""
    discarded = ""
    usedTokens = 0
    totalTokens = len(tokens)
    for i in range(0, totalTokens):
        token = tokens[i].replace('</w>', ' ')
        # alternate ANSI foreground colors (31-36) so adjacent tokens are distinguishable
        s = (usedTokens % 6) + 1
        if i < model.cond_stage_model.max_length:
            tokenized = tokenized + f"\x1b[0;3{s};40m{token}"
            usedTokens += 1
        else:  # over max token length
            discarded = discarded + f"\x1b[0;3{s};40m{token}"
    print(f"\n>> Tokens ({usedTokens}), Weight ({weight:.2f}):\n{tokenized}\x1b[0m")
    if discarded != "":
        print(
            f">> Tokens Discarded ({totalTokens-usedTokens}):\n{discarded}\x1b[0m"
        )