NagisaNao commited on
Commit
b23f7ed
·
verified ·
1 Parent(s): 6014396

fix launch

Browse files
Files changed (1) hide show
  1. sagemaker/fixing/webui/styles.py +240 -0
sagemaker/fixing/webui/styles.py ADDED
@@ -0,0 +1,240 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # FIX ERROR LUNCH by ANXETY for SAGEMAKER
2
+
3
+
4
+ from pathlib import Path
5
+ from modules import errors
6
+ import csv
7
+ import os
8
+ import typing
9
+ import shutil
10
+
11
+
12
+ from typing import Optional, Union, List
13
+
14
+
15
+ class PromptStyle(typing.NamedTuple): # FIXING
16
+ name: str
17
+ prompt: Optional[str]
18
+ negative_prompt: Optional[str]
19
+ path: Optional[str] = None
20
+
21
+ def merge_prompts(style_prompt: str, prompt: str) -> str:
22
+ if "{prompt}" in style_prompt:
23
+ res = style_prompt.replace("{prompt}", prompt)
24
+ else:
25
+ parts = filter(None, (prompt.strip(), style_prompt.strip()))
26
+ res = ", ".join(parts)
27
+
28
+ return res
29
+
30
+
31
+ def apply_styles_to_prompt(prompt, styles):
32
+ for style in styles:
33
+ prompt = merge_prompts(style, prompt)
34
+
35
+ return prompt
36
+
37
+
38
+ def extract_style_text_from_prompt(style_text, prompt):
39
+ """This function extracts the text from a given prompt based on a provided style text. It checks if the style text contains the placeholder {prompt} or if it appears at the end of the prompt. If a match is found, it returns True along with the extracted text. Otherwise, it returns False and the original prompt.
40
+
41
+ extract_style_text_from_prompt("masterpiece", "1girl, art by greg, masterpiece") outputs (True, "1girl, art by greg")
42
+ extract_style_text_from_prompt("masterpiece, {prompt}", "masterpiece, 1girl, art by greg") outputs (True, "1girl, art by greg")
43
+ extract_style_text_from_prompt("masterpiece, {prompt}", "exquisite, 1girl, art by greg") outputs (False, "exquisite, 1girl, art by greg")
44
+ """
45
+
46
+ stripped_prompt = prompt.strip()
47
+ stripped_style_text = style_text.strip()
48
+
49
+ if "{prompt}" in stripped_style_text:
50
+ left, right = stripped_style_text.split("{prompt}", 2)
51
+ if stripped_prompt.startswith(left) and stripped_prompt.endswith(right):
52
+ prompt = stripped_prompt[len(left):len(stripped_prompt)-len(right)]
53
+ return True, prompt
54
+ else:
55
+ if stripped_prompt.endswith(stripped_style_text):
56
+ prompt = stripped_prompt[:len(stripped_prompt)-len(stripped_style_text)]
57
+
58
+ if prompt.endswith(', '):
59
+ prompt = prompt[:-2]
60
+
61
+ return True, prompt
62
+
63
+ return False, prompt
64
+
65
+
66
+ def extract_original_prompts(style: PromptStyle, prompt, negative_prompt):
67
+ """
68
+ Takes a style and compares it to the prompt and negative prompt. If the style
69
+ matches, returns True plus the prompt and negative prompt with the style text
70
+ removed. Otherwise, returns False with the original prompt and negative prompt.
71
+ """
72
+ if not style.prompt and not style.negative_prompt:
73
+ return False, prompt, negative_prompt
74
+
75
+ match_positive, extracted_positive = extract_style_text_from_prompt(style.prompt, prompt)
76
+ if not match_positive:
77
+ return False, prompt, negative_prompt
78
+
79
+ match_negative, extracted_negative = extract_style_text_from_prompt(style.negative_prompt, negative_prompt)
80
+ if not match_negative:
81
+ return False, prompt, negative_prompt
82
+
83
+ return True, extracted_positive, extracted_negative
84
+
85
+
86
+ class StyleDatabase:
87
+ # def __init__(self, paths: list[str | Path]):
88
+ def __init__(self, paths: List[Union[str, Path]]): # FIXING
89
+ self.no_style = PromptStyle("None", "", "", None)
90
+ self.styles = {}
91
+ self.paths = paths
92
+ self.all_styles_files: list[Path] = []
93
+
94
+ folder, file = os.path.split(self.paths[0])
95
+ if '*' in file or '?' in file:
96
+ # if the first path is a wildcard pattern, find the first match else use "folder/styles.csv" as the default path
97
+ self.default_path = next(Path(folder).glob(file), Path(os.path.join(folder, 'styles.csv')))
98
+ self.paths.insert(0, self.default_path)
99
+ else:
100
+ self.default_path = Path(self.paths[0])
101
+
102
+ self.prompt_fields = [field for field in PromptStyle._fields if field != "path"]
103
+
104
+ self.reload()
105
+
106
+ def reload(self):
107
+ """
108
+ Clears the style database and reloads the styles from the CSV file(s)
109
+ matching the path used to initialize the database.
110
+ """
111
+ self.styles.clear()
112
+
113
+ # scans for all styles files
114
+ all_styles_files = []
115
+ for pattern in self.paths:
116
+ folder, file = os.path.split(pattern)
117
+ if '*' in file or '?' in file:
118
+ found_files = Path(folder).glob(file)
119
+ [all_styles_files.append(file) for file in found_files]
120
+ else:
121
+ # if os.path.exists(pattern):
122
+ all_styles_files.append(Path(pattern))
123
+
124
+ # Remove any duplicate entries
125
+ seen = set()
126
+ self.all_styles_files = [s for s in all_styles_files if not (s in seen or seen.add(s))]
127
+
128
+ for styles_file in self.all_styles_files:
129
+ if len(all_styles_files) > 1:
130
+ # add divider when more than styles file
131
+ # '---------------- STYLES ----------------'
132
+ divider = f' {styles_file.stem.upper()} '.center(40, '-')
133
+ self.styles[divider] = PromptStyle(f"{divider}", None, None, "do_not_save")
134
+ if styles_file.is_file():
135
+ self.load_from_csv(styles_file)
136
+
137
+ # def load_from_csv(self, path: str | Path):
138
+ def load_from_csv(self, path: Union[str, Path]): # FIXING
139
+ try:
140
+ with open(path, "r", encoding="utf-8-sig", newline="") as file:
141
+ reader = csv.DictReader(file, skipinitialspace=True)
142
+ for row in reader:
143
+ # Ignore empty rows or rows starting with a comment
144
+ if not row or row["name"].startswith("#"):
145
+ continue
146
+ # Support loading old CSV format with "name, text"-columns
147
+ prompt = row["prompt"] if "prompt" in row else row["text"]
148
+ negative_prompt = row.get("negative_prompt", "")
149
+ # Add style to database
150
+ self.styles[row["name"]] = PromptStyle(
151
+ row["name"], prompt, negative_prompt, str(path)
152
+ )
153
+ except Exception:
154
+ errors.report(f'Error loading styles from {path}: ', exc_info=True)
155
+
156
+ def get_style_paths(self) -> set:
157
+ """Returns a set of all distinct paths of files that styles are loaded from."""
158
+ # Update any styles without a path to the default path
159
+ for style in list(self.styles.values()):
160
+ if not style.path:
161
+ self.styles[style.name] = style._replace(path=str(self.default_path))
162
+
163
+ # Create a list of all distinct paths, including the default path
164
+ style_paths = set()
165
+ style_paths.add(str(self.default_path))
166
+ for _, style in self.styles.items():
167
+ if style.path:
168
+ style_paths.add(style.path)
169
+
170
+ # Remove any paths for styles that are just list dividers
171
+ style_paths.discard("do_not_save")
172
+
173
+ return style_paths
174
+
175
+ def get_style_prompts(self, styles):
176
+ return [self.styles.get(x, self.no_style).prompt for x in styles]
177
+
178
+ def get_negative_style_prompts(self, styles):
179
+ return [self.styles.get(x, self.no_style).negative_prompt for x in styles]
180
+
181
+ def apply_styles_to_prompt(self, prompt, styles):
182
+ return apply_styles_to_prompt(
183
+ prompt, [self.styles.get(x, self.no_style).prompt for x in styles]
184
+ )
185
+
186
+ def apply_negative_styles_to_prompt(self, prompt, styles):
187
+ return apply_styles_to_prompt(
188
+ prompt, [self.styles.get(x, self.no_style).negative_prompt for x in styles]
189
+ )
190
+
191
+ def save_styles(self, path: str = None) -> None:
192
+ # The path argument is deprecated, but kept for backwards compatibility
193
+
194
+ style_paths = self.get_style_paths()
195
+
196
+ csv_names = [os.path.split(path)[1].lower() for path in style_paths]
197
+
198
+ for style_path in style_paths:
199
+ # Always keep a backup file around
200
+ if os.path.exists(style_path):
201
+ shutil.copy(style_path, f"{style_path}.bak")
202
+
203
+ # Write the styles to the CSV file
204
+ with open(style_path, "w", encoding="utf-8-sig", newline="") as file:
205
+ writer = csv.DictWriter(file, fieldnames=self.prompt_fields)
206
+ writer.writeheader()
207
+ for style in (s for s in self.styles.values() if s.path == style_path):
208
+ # Skip style list dividers, e.g. "STYLES.CSV"
209
+ if style.name.lower().strip("# ") in csv_names:
210
+ continue
211
+ # Write style fields, ignoring the path field
212
+ writer.writerow(
213
+ {k: v for k, v in style._asdict().items() if k != "path"}
214
+ )
215
+
216
+ def extract_styles_from_prompt(self, prompt, negative_prompt):
217
+ extracted = []
218
+
219
+ applicable_styles = list(self.styles.values())
220
+
221
+ while True:
222
+ found_style = None
223
+
224
+ for style in applicable_styles:
225
+ is_match, new_prompt, new_neg_prompt = extract_original_prompts(
226
+ style, prompt, negative_prompt
227
+ )
228
+ if is_match:
229
+ found_style = style
230
+ prompt = new_prompt
231
+ negative_prompt = new_neg_prompt
232
+ break
233
+
234
+ if not found_style:
235
+ break
236
+
237
+ applicable_styles.remove(found_style)
238
+ extracted.append(found_style.name)
239
+
240
+ return list(reversed(extracted)), prompt, negative_prompt