feizhengcong's picture
Upload 198 files
074c857
raw
history blame
5.5 kB
import re
def sanitize(prompt):
    """Reduce *prompt* to ASCII letters and underscores.

    Every character that is not ``a-z``, ``A-Z`` or a plain space is dropped,
    then each remaining space is replaced with ``_``.
    """
    allowed = frozenset('abcdefghijklmnopqrstuvwxyz ABCDEFGHIJKLMNOPQRSTUVWXYZ')
    kept = [ch for ch in prompt if ch in allowed]
    return ''.join(kept).replace(' ', '_')
def check_is_number(value):
    """Return a match object if *value* is a plain decimal number, else ``None``.

    Accepted forms: an optional leading ``+``/``-`` followed by digits with an
    optional fractional part (``"3"``, ``"-0.25"``, ``"+1.5"``), or a bare
    fraction (``".5"``, ``"-.5"``). A dot with no digits after it (``"1."``),
    a lone sign, the empty string and strings with a trailing newline are
    rejected. Group 1 of the returned match holds the whole number.
    """
    # The lookahead demands at least one digit (after an optional sign and an
    # optional '.'), so "+", "-" and "" no longer count as numbers; float()
    # would raise on them. \Z (not $) so "1\n" is not accepted either.
    float_pattern = r'^(?=[+-]?\.?[0-9])([+-]?([0-9]*)(\.([0-9]+))?)\Z'
    return re.match(float_pattern, value)
# prompt weighting with colons and number coefficients (like 'bacon:0.75 eggs:0.25')
# borrowed from https://github.com/kylewlacy/stable-diffusion/blob/0a4397094eb6e875f98f9d71193e350d859c4220/ldm/dream/conditioning.py
# and https://github.com/raefu/stable-diffusion-automatic/blob/unstablediffusion/modules/processing.py
def get_uc_and_c(prompts, model, args, frame=0):
    """Build the ``(uc, c)`` conditioning pair for a batch of prompts.

    Only ``prompts[0]`` is parsed, on the assumption that every prompt in a
    batch is identical. Sub-prompts with negative weights go into ``uc`` (logged
    with sign -1, encoded with ``""`` as the fallback text). Positive ones go
    into ``c`` (sign +1, falling back to the full prompt text). Weight
    normalization is skipped unless ``args.normalize_prompt_weights`` is truthy.

    Args:
        prompts: sequence of prompt strings; only the first element is used.
        model: model object forwarded to ``get_learned_conditioning``.
        args: options object; reads ``normalize_prompt_weights`` here, and
            further attributes inside ``get_learned_conditioning``.
        frame: frame index passed through to weight-expression evaluation.

    Returns:
        Tuple ``(uc, c)``.
    """
    prompt = prompts[0]  # all prompts in a batch are the same
    skip_normalize = not args.normalize_prompt_weights
    negatives, positives = split_weighted_subprompts(prompt, frame, skip_normalize)
    unconditional = get_learned_conditioning(model, negatives, "", args, -1)
    conditional = get_learned_conditioning(model, positives, prompt, args, 1)
    return unconditional, conditional
def get_learned_conditioning(model, weighted_subprompts, text, args, sign=1):
    """Encode a prompt, or a weighted blend of sub-prompts, with *model*.

    With no sub-prompts, *text* itself is encoded. Otherwise each sub-prompt is
    encoded and the results are summed, each scaled by its weight. The first
    result is scaled in place with ``*=`` and the rest are accumulated with
    ``.add_(..., alpha=weight)``. That presumably means
    ``model.get_learned_conditioning`` returns a torch-style tensor (TODO
    confirm).

    Every encoded string is replicated ``args.n_samples`` times to form the
    batch. ``sign`` only affects the weight shown by ``log_tokenization``.

    Args:
        model: object exposing ``get_learned_conditioning(list_of_str)``.
        weighted_subprompts: list of ``(sub_prompt, weight)`` pairs.
        text: fallback text encoded when *weighted_subprompts* is empty.
        args: options object; reads ``n_samples`` and ``log_weighted_subprompts``.
        sign: multiplier applied to logged weights (e.g. -1 for negatives).

    Returns:
        The (weighted) conditioning produced by the model.
    """
    if not weighted_subprompts:
        log_tokenization(text, model, args.log_weighted_subprompts, sign)
        return model.get_learned_conditioning(args.n_samples * [text])

    blended = None
    for sub_text, sub_weight in weighted_subprompts:
        log_tokenization(sub_text, model, args.log_weighted_subprompts, sign * sub_weight)
        encoded = model.get_learned_conditioning(args.n_samples * [sub_text])
        if blended is not None:
            blended.add_(encoded, alpha=sub_weight)
        else:
            blended = encoded
            blended *= sub_weight
    return blended
def parse_weight(match, frame=0) -> float:
    """Return the numeric weight captured by a sub-prompt regex match.

    Args:
        match: a match object with a ``"weight"`` named group, as produced by
            the parser in ``split_weighted_subprompts``. The group is ``None``,
            a plain number, or a backtick-quoted math expression.
        frame: frame index, exposed to math expressions as the variable ``t``.

    Returns:
        ``1.0`` when no weight was given, or when the backtick expression is too
        short to hold anything. ``float(weight)`` for plain numbers. Otherwise
        the ``numexpr`` evaluation of the text between the backticks.
    """
    w_raw = match.group("weight")
    if w_raw is None:
        return 1.0
    if check_is_number(w_raw):
        return float(w_raw)
    if len(w_raw) < 3:
        # With the split_weighted_subprompts parser this is the empty "``".
        print('the value inside `-characters cannot represent a math function')
        return 1.0
    # Imported lazily so plain numeric weights work without numexpr installed.
    import numexpr
    # Keep this "unused" local: numexpr.evaluate, called without local_dict,
    # looks names up in the calling frame's locals, so this is what makes `t`
    # usable inside expressions such as `0.5*t`.
    t = frame  # noqa: F841
    # NOTE(review): the expression comes straight from the user's prompt.
    # numexpr is far more restricted than eval(), but older numexpr releases
    # were reported unsafe on untrusted strings (CVE-2023-39631); keep it updated.
    return float(numexpr.evaluate(w_raw[1:-1]))
def normalize_prompt_weights(parsed_prompts):
    """Rescale sub-prompt weights so that they sum to 1.

    Args:
        parsed_prompts: list of ``(sub_prompt, weight)`` pairs.

    Returns:
        The same (empty) list object when *parsed_prompts* is empty. Otherwise a
        new list of ``(sub_prompt, weight / total)`` pairs. If the weights sum to
        exactly zero, a warning is printed and every entry receives the equal
        share ``1 / len(parsed_prompts)``.
    """
    if not parsed_prompts:
        return parsed_prompts

    total = sum(entry[1] for entry in parsed_prompts)
    if total != 0:
        return [(entry[0], entry[1] / total) for entry in parsed_prompts]

    print(
        "Warning: Subprompt weights add up to zero. Discarding and using even weights instead.")
    share = 1 / len(parsed_prompts)
    return [(entry[0], share) for entry in parsed_prompts]
# Regex used by split_weighted_subprompts. It is compiled once at import time
# and written as a raw string so the regex escapes (\d, \s, \S, \:) are not
# parsed as (invalid) Python string escape sequences.
_SUBPROMPT_PARSER = re.compile(r"""
    (?P<prompt>         # capture group for 'prompt'
    (?:\\:|[^:])+       # match one or more non ':' characters or escaped colons '\:'
    )                   # end 'prompt'
    (?:                 # non-capture group
    :+                  # match one or more ':' characters
    (?P<weight>((       # capture group for 'weight'
    -?\d+(?:\.\d+)?     # match positive or negative integer or decimal number
    )|(                 # or
    `[\S\s]*?`          # a math function
    )))?                # end weight capture group, make optional
    \s*                 # strip spaces after weight
    |                   # OR
    $                   # else, if no ':' then match end of line
    )                   # end non-capture group
    """, re.VERBOSE)


def split_weighted_subprompts(text, frame=0, skip_normalize=False):
    """Split *text* into negatively and positively weighted sub-prompts.

    Grabs all text up to the first unescaped ':' and uses it as a sub-prompt,
    taking the value after the ':' as its weight. The weight is a number, or a
    math expression in backticks evaluated with ``t = frame``. A missing weight
    defaults to 1. This repeats until no text remains. Whitespace after a weight
    is consumed; any other whitespace stays part of the sub-prompt text. An
    escaped ``\\:`` is turned back into a literal ':'.

    Args:
        text: the full prompt string, e.g. ``"bacon:0.75 eggs:0.25"``.
        frame: frame index made available to backtick expressions.
        skip_normalize: when true, return raw weights instead of weights
            rescaled to sum to 1 within each list.

    Returns:
        ``(negative_prompts, positive_prompts)``: two lists of
        ``(sub_prompt, weight)`` pairs. Negative weights are stored negated
        (i.e. as positive magnitudes). Zero-weighted sub-prompts are dropped.
    """
    negative_prompts = []
    positive_prompts = []
    for match in _SUBPROMPT_PARSER.finditer(text):
        w = parse_weight(match, frame)
        sub_prompt = match.group("prompt").replace("\\:", ":")
        if w < 0:
            # negating the sign as we'll feed this to uc
            negative_prompts.append((sub_prompt, -w))
        elif w > 0:
            positive_prompts.append((sub_prompt, w))
    if skip_normalize:
        return (negative_prompts, positive_prompts)
    return (normalize_prompt_weights(negative_prompts), normalize_prompt_weights(positive_prompts))
# shows how the prompt is tokenized
# usually tokens have '</w>' to indicate end-of-word,
# but for readability it has been replaced with ' '
def log_tokenization(text, model, log=False, weight=1):
    """Print how *text* is tokenized by the model's conditioning tokenizer.

    Does nothing unless *log* is truthy. Tokens come from
    ``model.cond_stage_model.tokenizer._tokenize(text)``. The ``'</w>'``
    end-of-word marker is replaced by a space for readability. Adjacent tokens
    are coloured with ANSI foreground colours 31-36 in rotation. Tokens at
    positions ``>= model.cond_stage_model.max_length`` are listed separately as
    discarded.

    Args:
        text: the (sub-)prompt to tokenize.
        model: object exposing ``cond_stage_model.tokenizer._tokenize`` and
            ``cond_stage_model.max_length``.
        log: whether to print anything at all.
        weight: weight shown in the header line (formatted with 2 decimals).
    """
    if not log:
        return
    tokens = model.cond_stage_model.tokenizer._tokenize(text)
    max_length = model.cond_stage_model.max_length
    used_parts = []
    discarded_parts = []
    for i, raw_token in enumerate(tokens):
        token = raw_token.replace('</w>', ' ')
        # Alternate colour by token position. Using the position, rather than
        # the count of kept tokens, keeps discarded tokens alternating too.
        colored = f"\x1b[0;3{i % 6 + 1};40m{token}"
        if i < max_length:
            used_parts.append(colored)
        else:  # over max token length
            discarded_parts.append(colored)
    used_tokens = len(used_parts)
    tokenized = "".join(used_parts)
    discarded = "".join(discarded_parts)
    print(f"\n>> Tokens ({used_tokens}), Weight ({weight:.2f}):\n{tokenized}\x1b[0m")
    if discarded:
        print(
            f">> Tokens Discarded ({len(tokens) - used_tokens}):\n{discarded}\x1b[0m"
        )