fix launch
Browse files- sagemaker/fixing/webui/styles.py +240 -0
sagemaker/fixing/webui/styles.py
ADDED
@@ -0,0 +1,240 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# FIX ERROR LUNCH by ANXETY for SAGEMAKER
|
2 |
+
|
3 |
+
|
4 |
+
from pathlib import Path
|
5 |
+
from modules import errors
|
6 |
+
import csv
|
7 |
+
import os
|
8 |
+
import typing
|
9 |
+
import shutil
|
10 |
+
|
11 |
+
|
12 |
+
from typing import Optional, Union, List
|
13 |
+
|
14 |
+
|
15 |
+
class PromptStyle(typing.NamedTuple): # FIXING
|
16 |
+
name: str
|
17 |
+
prompt: Optional[str]
|
18 |
+
negative_prompt: Optional[str]
|
19 |
+
path: Optional[str] = None
|
20 |
+
|
21 |
+
def merge_prompts(style_prompt: str, prompt: str) -> str:
|
22 |
+
if "{prompt}" in style_prompt:
|
23 |
+
res = style_prompt.replace("{prompt}", prompt)
|
24 |
+
else:
|
25 |
+
parts = filter(None, (prompt.strip(), style_prompt.strip()))
|
26 |
+
res = ", ".join(parts)
|
27 |
+
|
28 |
+
return res
|
29 |
+
|
30 |
+
|
31 |
+
def apply_styles_to_prompt(prompt, styles):
|
32 |
+
for style in styles:
|
33 |
+
prompt = merge_prompts(style, prompt)
|
34 |
+
|
35 |
+
return prompt
|
36 |
+
|
37 |
+
|
38 |
+
def extract_style_text_from_prompt(style_text, prompt):
|
39 |
+
"""This function extracts the text from a given prompt based on a provided style text. It checks if the style text contains the placeholder {prompt} or if it appears at the end of the prompt. If a match is found, it returns True along with the extracted text. Otherwise, it returns False and the original prompt.
|
40 |
+
|
41 |
+
extract_style_text_from_prompt("masterpiece", "1girl, art by greg, masterpiece") outputs (True, "1girl, art by greg")
|
42 |
+
extract_style_text_from_prompt("masterpiece, {prompt}", "masterpiece, 1girl, art by greg") outputs (True, "1girl, art by greg")
|
43 |
+
extract_style_text_from_prompt("masterpiece, {prompt}", "exquisite, 1girl, art by greg") outputs (False, "exquisite, 1girl, art by greg")
|
44 |
+
"""
|
45 |
+
|
46 |
+
stripped_prompt = prompt.strip()
|
47 |
+
stripped_style_text = style_text.strip()
|
48 |
+
|
49 |
+
if "{prompt}" in stripped_style_text:
|
50 |
+
left, right = stripped_style_text.split("{prompt}", 2)
|
51 |
+
if stripped_prompt.startswith(left) and stripped_prompt.endswith(right):
|
52 |
+
prompt = stripped_prompt[len(left):len(stripped_prompt)-len(right)]
|
53 |
+
return True, prompt
|
54 |
+
else:
|
55 |
+
if stripped_prompt.endswith(stripped_style_text):
|
56 |
+
prompt = stripped_prompt[:len(stripped_prompt)-len(stripped_style_text)]
|
57 |
+
|
58 |
+
if prompt.endswith(', '):
|
59 |
+
prompt = prompt[:-2]
|
60 |
+
|
61 |
+
return True, prompt
|
62 |
+
|
63 |
+
return False, prompt
|
64 |
+
|
65 |
+
|
66 |
+
def extract_original_prompts(style: PromptStyle, prompt, negative_prompt):
|
67 |
+
"""
|
68 |
+
Takes a style and compares it to the prompt and negative prompt. If the style
|
69 |
+
matches, returns True plus the prompt and negative prompt with the style text
|
70 |
+
removed. Otherwise, returns False with the original prompt and negative prompt.
|
71 |
+
"""
|
72 |
+
if not style.prompt and not style.negative_prompt:
|
73 |
+
return False, prompt, negative_prompt
|
74 |
+
|
75 |
+
match_positive, extracted_positive = extract_style_text_from_prompt(style.prompt, prompt)
|
76 |
+
if not match_positive:
|
77 |
+
return False, prompt, negative_prompt
|
78 |
+
|
79 |
+
match_negative, extracted_negative = extract_style_text_from_prompt(style.negative_prompt, negative_prompt)
|
80 |
+
if not match_negative:
|
81 |
+
return False, prompt, negative_prompt
|
82 |
+
|
83 |
+
return True, extracted_positive, extracted_negative
|
84 |
+
|
85 |
+
|
86 |
+
class StyleDatabase:
|
87 |
+
# def __init__(self, paths: list[str | Path]):
|
88 |
+
def __init__(self, paths: List[Union[str, Path]]): # FIXING
|
89 |
+
self.no_style = PromptStyle("None", "", "", None)
|
90 |
+
self.styles = {}
|
91 |
+
self.paths = paths
|
92 |
+
self.all_styles_files: list[Path] = []
|
93 |
+
|
94 |
+
folder, file = os.path.split(self.paths[0])
|
95 |
+
if '*' in file or '?' in file:
|
96 |
+
# if the first path is a wildcard pattern, find the first match else use "folder/styles.csv" as the default path
|
97 |
+
self.default_path = next(Path(folder).glob(file), Path(os.path.join(folder, 'styles.csv')))
|
98 |
+
self.paths.insert(0, self.default_path)
|
99 |
+
else:
|
100 |
+
self.default_path = Path(self.paths[0])
|
101 |
+
|
102 |
+
self.prompt_fields = [field for field in PromptStyle._fields if field != "path"]
|
103 |
+
|
104 |
+
self.reload()
|
105 |
+
|
106 |
+
def reload(self):
|
107 |
+
"""
|
108 |
+
Clears the style database and reloads the styles from the CSV file(s)
|
109 |
+
matching the path used to initialize the database.
|
110 |
+
"""
|
111 |
+
self.styles.clear()
|
112 |
+
|
113 |
+
# scans for all styles files
|
114 |
+
all_styles_files = []
|
115 |
+
for pattern in self.paths:
|
116 |
+
folder, file = os.path.split(pattern)
|
117 |
+
if '*' in file or '?' in file:
|
118 |
+
found_files = Path(folder).glob(file)
|
119 |
+
[all_styles_files.append(file) for file in found_files]
|
120 |
+
else:
|
121 |
+
# if os.path.exists(pattern):
|
122 |
+
all_styles_files.append(Path(pattern))
|
123 |
+
|
124 |
+
# Remove any duplicate entries
|
125 |
+
seen = set()
|
126 |
+
self.all_styles_files = [s for s in all_styles_files if not (s in seen or seen.add(s))]
|
127 |
+
|
128 |
+
for styles_file in self.all_styles_files:
|
129 |
+
if len(all_styles_files) > 1:
|
130 |
+
# add divider when more than styles file
|
131 |
+
# '---------------- STYLES ----------------'
|
132 |
+
divider = f' {styles_file.stem.upper()} '.center(40, '-')
|
133 |
+
self.styles[divider] = PromptStyle(f"{divider}", None, None, "do_not_save")
|
134 |
+
if styles_file.is_file():
|
135 |
+
self.load_from_csv(styles_file)
|
136 |
+
|
137 |
+
# def load_from_csv(self, path: str | Path):
|
138 |
+
def load_from_csv(self, path: Union[str, Path]): # FIXING
|
139 |
+
try:
|
140 |
+
with open(path, "r", encoding="utf-8-sig", newline="") as file:
|
141 |
+
reader = csv.DictReader(file, skipinitialspace=True)
|
142 |
+
for row in reader:
|
143 |
+
# Ignore empty rows or rows starting with a comment
|
144 |
+
if not row or row["name"].startswith("#"):
|
145 |
+
continue
|
146 |
+
# Support loading old CSV format with "name, text"-columns
|
147 |
+
prompt = row["prompt"] if "prompt" in row else row["text"]
|
148 |
+
negative_prompt = row.get("negative_prompt", "")
|
149 |
+
# Add style to database
|
150 |
+
self.styles[row["name"]] = PromptStyle(
|
151 |
+
row["name"], prompt, negative_prompt, str(path)
|
152 |
+
)
|
153 |
+
except Exception:
|
154 |
+
errors.report(f'Error loading styles from {path}: ', exc_info=True)
|
155 |
+
|
156 |
+
def get_style_paths(self) -> set:
|
157 |
+
"""Returns a set of all distinct paths of files that styles are loaded from."""
|
158 |
+
# Update any styles without a path to the default path
|
159 |
+
for style in list(self.styles.values()):
|
160 |
+
if not style.path:
|
161 |
+
self.styles[style.name] = style._replace(path=str(self.default_path))
|
162 |
+
|
163 |
+
# Create a list of all distinct paths, including the default path
|
164 |
+
style_paths = set()
|
165 |
+
style_paths.add(str(self.default_path))
|
166 |
+
for _, style in self.styles.items():
|
167 |
+
if style.path:
|
168 |
+
style_paths.add(style.path)
|
169 |
+
|
170 |
+
# Remove any paths for styles that are just list dividers
|
171 |
+
style_paths.discard("do_not_save")
|
172 |
+
|
173 |
+
return style_paths
|
174 |
+
|
175 |
+
def get_style_prompts(self, styles):
|
176 |
+
return [self.styles.get(x, self.no_style).prompt for x in styles]
|
177 |
+
|
178 |
+
def get_negative_style_prompts(self, styles):
|
179 |
+
return [self.styles.get(x, self.no_style).negative_prompt for x in styles]
|
180 |
+
|
181 |
+
def apply_styles_to_prompt(self, prompt, styles):
|
182 |
+
return apply_styles_to_prompt(
|
183 |
+
prompt, [self.styles.get(x, self.no_style).prompt for x in styles]
|
184 |
+
)
|
185 |
+
|
186 |
+
def apply_negative_styles_to_prompt(self, prompt, styles):
|
187 |
+
return apply_styles_to_prompt(
|
188 |
+
prompt, [self.styles.get(x, self.no_style).negative_prompt for x in styles]
|
189 |
+
)
|
190 |
+
|
191 |
+
def save_styles(self, path: str = None) -> None:
|
192 |
+
# The path argument is deprecated, but kept for backwards compatibility
|
193 |
+
|
194 |
+
style_paths = self.get_style_paths()
|
195 |
+
|
196 |
+
csv_names = [os.path.split(path)[1].lower() for path in style_paths]
|
197 |
+
|
198 |
+
for style_path in style_paths:
|
199 |
+
# Always keep a backup file around
|
200 |
+
if os.path.exists(style_path):
|
201 |
+
shutil.copy(style_path, f"{style_path}.bak")
|
202 |
+
|
203 |
+
# Write the styles to the CSV file
|
204 |
+
with open(style_path, "w", encoding="utf-8-sig", newline="") as file:
|
205 |
+
writer = csv.DictWriter(file, fieldnames=self.prompt_fields)
|
206 |
+
writer.writeheader()
|
207 |
+
for style in (s for s in self.styles.values() if s.path == style_path):
|
208 |
+
# Skip style list dividers, e.g. "STYLES.CSV"
|
209 |
+
if style.name.lower().strip("# ") in csv_names:
|
210 |
+
continue
|
211 |
+
# Write style fields, ignoring the path field
|
212 |
+
writer.writerow(
|
213 |
+
{k: v for k, v in style._asdict().items() if k != "path"}
|
214 |
+
)
|
215 |
+
|
216 |
+
def extract_styles_from_prompt(self, prompt, negative_prompt):
|
217 |
+
extracted = []
|
218 |
+
|
219 |
+
applicable_styles = list(self.styles.values())
|
220 |
+
|
221 |
+
while True:
|
222 |
+
found_style = None
|
223 |
+
|
224 |
+
for style in applicable_styles:
|
225 |
+
is_match, new_prompt, new_neg_prompt = extract_original_prompts(
|
226 |
+
style, prompt, negative_prompt
|
227 |
+
)
|
228 |
+
if is_match:
|
229 |
+
found_style = style
|
230 |
+
prompt = new_prompt
|
231 |
+
negative_prompt = new_neg_prompt
|
232 |
+
break
|
233 |
+
|
234 |
+
if not found_style:
|
235 |
+
break
|
236 |
+
|
237 |
+
applicable_styles.remove(found_style)
|
238 |
+
extracted.append(found_style.name)
|
239 |
+
|
240 |
+
return list(reversed(extracted)), prompt, negative_prompt
|