File size: 8,402 Bytes
77771e4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 |
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import glob
import os
import re
import subprocess
# All paths below are relative, so this script is meant to be run from the root of the repo with the command
#     python utils/check_copies.py

# Directory scanned for `.py` files; module paths from `# Copied from diffusers.` comments are resolved against it.
DIFFUSERS_PATH = "src/diffusers"
# NOTE(review): not referenced in this part of the file; presumably the repository root - confirm before removing.
REPO_PATH = "."
def _should_continue(line, indent):
return line.startswith(indent) or len(line) <= 1 or re.search(r"^\s*\)(\s*->.*:|:)\s*$", line) is not None
def find_code_in_diffusers(object_name):
    """
    Find and return the source code of `object_name`.

    Args:
        object_name (`str`):
            Dotted path relative to `DIFFUSERS_PATH`: a module path followed by the (possibly nested) class /
            function names, e.g. `module.submodule.ClassName.method_name`.

    Returns:
        `str`: The lines that follow the line on which the definition starts, up to (but excluding) the first
        line that no longer belongs to it, with trailing blank lines removed.

    Raises:
        ValueError: If no prefix of `object_name` names a `.py` file under `DIFFUSERS_PATH`, or if the class /
            function cannot be located in that file.
    """
    parts = object_name.split(".")
    i = 0

    # First let's find the module where our object lives: extend the path one component at a time until it
    # names an existing `.py` file.
    module = parts[i]
    while i < len(parts) and not os.path.isfile(os.path.join(DIFFUSERS_PATH, f"{module}.py")):
        i += 1
        if i < len(parts):
            module = os.path.join(module, parts[i])
    if i >= len(parts):
        raise ValueError(f"`object_name` should begin with the name of a module of diffusers but got {object_name}.")

    with open(
        os.path.join(DIFFUSERS_PATH, f"{module}.py"),
        "r",
        encoding="utf-8",
        newline="\n",
    ) as f:
        lines = f.readlines()

    # Now let's find the class / func in the code! Each remaining name is searched one nesting level deeper
    # than the previous one, starting from the line where the previous one was found.
    indent = ""
    line_index = 0
    for name in parts[i + 1 :]:
        # Escape the name so it is always matched literally inside the pattern.
        header = re.compile(rf"^{indent}(class|def)\s+{re.escape(name)}(\(|\:)")
        while line_index < len(lines) and header.search(lines[line_index]) is None:
            line_index += 1
        # Nested definitions are indented by four spaces per level.
        indent += "    "
        line_index += 1

    if line_index >= len(lines):
        raise ValueError(f" {object_name} does not match any function or class in {module}.")

    # We found the beginning of the class / func, now let's find the end (when the indent diminishes).
    start_index = line_index
    while line_index < len(lines) and _should_continue(lines[line_index], indent):
        line_index += 1
    # Clean up empty lines at the end (if any), without ever walking back past the start of the block.
    while line_index > start_index and len(lines[line_index - 1]) <= 1:
        line_index -= 1

    return "".join(lines[start_index:line_index])
# Matches a `# Copied from diffusers.<object>` comment line. Groups: (1) the whitespace before `#`, (2) the dotted
# object path that follows `diffusers.` (it must contain at least one dot), (3) the remainder of the line, either
# empty or starting at the first non-whitespace character after the path.
_re_copy_warning = re.compile(r"^(\s*)#\s*Copied from\s+diffusers\.(\S+\.\S+)\s*($|\S.*$)")
# Matches one replacement directive of the form `old->new`, optionally followed by whitespace and extra text.
# Groups: (1) `old`, (2) `new`, (3) the trailing text (possibly empty).
_re_replace_pattern = re.compile(r"^\s*(\S+)->(\S+)(\s+.*|$)")
# Matches a `<FILL ...>` placeholder. Not referenced by the functions visible in this part of the file.
_re_fill_pattern = re.compile(r"<FILL\s+[^>]*>")
def get_indent(code):
    """
    Return the leading whitespace of the first line of `code` that has some non-whitespace content.

    Args:
        code (`str`): The code to inspect.

    Returns:
        `str`: The whitespace prefix of that line, or `""` when `code` has no such line (empty string, or only
        blank / whitespace-only lines).
    """
    for line in code.split("\n"):
        # Empty and whitespace-only lines carry no indentation information, so they are skipped rather than
        # being fed to `.group()` on a failed match.
        match = re.match(r"(\s*)\S", line)
        if match is not None:
            return match.group(1)
    return ""
def run_ruff(code):
    """
    Format `code` by piping it through the `ruff format` command line.

    The code is sent on the standard input of the process and the formatted result is read back from its
    standard output. The configuration is read from `pyproject.toml` in the current working directory.

    Args:
        code (`str`): The code to format.

    Returns:
        `str`: The standard output of the `ruff` process, decoded.

    Raises:
        RuntimeError: If the `ruff` process exits with a non-zero status.
    """
    command = ["ruff", "format", "-", "--config", "pyproject.toml", "--silent"]
    result = subprocess.run(command, input=code.encode(), stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    # A failed run produces no usable output; returning it would make callers treat an empty string as the
    # formatted code, so fail loudly instead.
    if result.returncode != 0:
        details = result.stderr.decode(errors="replace").strip()
        raise RuntimeError(f"`{' '.join(command)}` failed with exit code {result.returncode}: {details}")
    return result.stdout.decode()
def stylify(code: str) -> str:
    """
    Format some code with `run_ruff`, the ruff part of our `make style` command.

    Code whose first line is indented cannot be formatted on its own, so it is temporarily placed under a dummy
    class header, and that header is stripped from the formatted result.

    Args:
        code (`str`): The code to format.

    Returns:
        `str`: The formatted code.
    """
    # Top-level code can be handed to the formatter as is.
    if len(get_indent(code)) == 0:
        return run_ruff(code)

    dummy_header = "class Bla:\n"
    formatted_with_header = run_ruff(dummy_header + code)
    return formatted_with_header[len(dummy_header) :]
def is_copy_consistent(filename, overwrite=False):
    """
    Check if the code commented as a copy in `filename` matches the original.

    Return the differences or overwrites the content depending on `overwrite`.

    Args:
        filename: Path of the file to check.
        overwrite: When `True`, every mismatching copy is replaced in memory by the expected code and, if at least
            one mismatch was found, `filename` is rewritten.

    Returns:
        `list`: One `[object_name, start_index]` entry per mismatching copy. `start_index` is the 0-based index, in
        the in-memory list of lines at the time of the check, of the first line of the copied code.
    """
    with open(filename, "r", encoding="utf-8", newline="\n") as f:
        lines = f.readlines()
    diffs = []
    line_index = 0
    # Not a for loop cause `lines` is going to change (if `overwrite=True`).
    while line_index < len(lines):
        search = _re_copy_warning.search(lines[line_index])
        if search is None:
            line_index += 1
            continue

        # There is some copied code here, let's retrieve the original.
        indent, object_name, replace_pattern = search.groups()
        theoretical_code = find_code_in_diffusers(object_name)
        theoretical_indent = get_indent(theoretical_code)

        # When the comment is indented like the original body, the copy starts on the very next line; otherwise
        # one extra line is skipped (presumably the class / function header that follows the comment).
        start_index = line_index + 1 if indent == theoretical_indent else line_index + 2
        indent = theoretical_indent
        line_index = start_index

        # Loop to check the observed code, stop when indentation diminishes or if we see a End copy comment.
        # The index is advanced before the line is inspected, so the line at `start_index` is never tested.
        should_continue = True
        while line_index < len(lines) and should_continue:
            line_index += 1
            if line_index >= len(lines):
                break
            line = lines[line_index]
            should_continue = _should_continue(line, indent) and re.search(f"^{indent}# End copy", line) is None
        # Clean up empty lines at the end (if any).
        while len(lines[line_index - 1]) <= 1:
            line_index -= 1

        observed_code_lines = lines[start_index:line_index]
        observed_code = "".join(observed_code_lines)

        # Remove any nested `Copied from` comments to avoid circular copies
        theoretical_code = [line for line in theoretical_code.split("\n") if _re_copy_warning.search(line) is None]
        theoretical_code = "\n".join(theoretical_code)

        # Before comparing, use the `replace_pattern` on the original code.
        if len(replace_pattern) > 0:
            # Every occurrence of "with" is dropped, then the remainder is split on commas into `old->new` directives.
            patterns = replace_pattern.replace("with", "").split(",")
            patterns = [_re_replace_pattern.search(p) for p in patterns]
            for pattern in patterns:
                # Chunks that are not of the form `old->new` are ignored.
                if pattern is None:
                    continue
                obj1, obj2, option = pattern.groups()
                # `obj1` is used as a regular expression pattern here, not as a literal string.
                theoretical_code = re.sub(obj1, obj2, theoretical_code)
                if option.strip() == "all-casing":
                    # Also apply the replacement to the lower-cased and upper-cased spellings.
                    theoretical_code = re.sub(obj1.lower(), obj2.lower(), theoretical_code)
                    theoretical_code = re.sub(obj1.upper(), obj2.upper(), theoretical_code)

            # stylify after replacement. To be able to do that, we need the header (class or function definition)
            # from the previous line
            theoretical_code = stylify(lines[start_index - 1] + theoretical_code)
            # Drop the prepended line again; this slices by the length of the unformatted line.
            theoretical_code = theoretical_code[len(lines[start_index - 1]) :]

        # Test for a diff and act accordingly.
        if observed_code != theoretical_code:
            diffs.append([object_name, start_index])
            if overwrite:
                # The expected code is inserted as a single list element, so scanning resumes right after it.
                lines = lines[:start_index] + [theoretical_code] + lines[line_index:]
                line_index = start_index + 1

    if overwrite and len(diffs) > 0:
        # Warn the user a file has been modified.
        print(f"Detected changes, rewriting {filename}.")
        with open(filename, "w", encoding="utf-8", newline="\n") as f:
            f.writelines(lines)
    return diffs
def check_copies(overwrite: bool = False):
    """
    Run `is_copy_consistent` on every `.py` file found recursively under `DIFFUSERS_PATH`.

    Args:
        overwrite (`bool`, *optional*, defaults to `False`):
            Forwarded to `is_copy_consistent`. When `True`, inconsistencies are not reported as an error.

    Raises:
        Exception: If `overwrite` is `False` and at least one inconsistency was found. The message lists every
            inconsistency, one per line.
    """
    search_pattern = os.path.join(DIFFUSERS_PATH, "**/*.py")
    report = []
    for path in glob.glob(search_pattern, recursive=True):
        for object_name, position in is_copy_consistent(path, overwrite):
            report.append(f"- {path}: copy does not match {object_name} at line {position}")

    # Nothing to complain about when fixing in place or when everything is consistent.
    if overwrite or len(report) == 0:
        return

    raise Exception(
        "Found the following copy inconsistencies:\n"
        + "\n".join(report)
        + "\nRun `make fix-copies` or `python utils/check_copies.py --fix_and_overwrite` to fix them."
    )
if __name__ == "__main__":
    # Command-line entry point: a single optional flag selects between checking and fixing in place.
    cli = argparse.ArgumentParser()
    cli.add_argument("--fix_and_overwrite", action="store_true", help="Whether to fix inconsistencies.")
    options = cli.parse_args()

    check_copies(options.fix_and_overwrite)
|