shichen1231 hysts HF staff committed on
Commit
dc6b9bd
•
0 Parent(s):

Duplicate from AttendAndExcite/Attend-and-Excite

Browse files

Co-authored-by: hysts <[email protected]>

Files changed (11) hide show
  1. .gitattributes +34 -0
  2. .gitignore +163 -0
  3. .gitmodules +3 -0
  4. .pre-commit-config.yaml +37 -0
  5. .style.yapf +5 -0
  6. Attend-and-Excite +1 -0
  7. README.md +15 -0
  8. app.py +191 -0
  9. model.py +85 -0
  10. requirements.txt +7 -0
  11. style.css +3 -0
.gitattributes ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tflite filter=lfs diff=lfs merge=lfs -text
29
+ *.tgz filter=lfs diff=lfs merge=lfs -text
30
+ *.wasm filter=lfs diff=lfs merge=lfs -text
31
+ *.xz filter=lfs diff=lfs merge=lfs -text
32
+ *.zip filter=lfs diff=lfs merge=lfs -text
33
+ *.zst filter=lfs diff=lfs merge=lfs -text
34
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ gradio_cached_examples/
2
+
3
+
4
+ # Byte-compiled / optimized / DLL files
5
+ __pycache__/
6
+ *.py[cod]
7
+ *$py.class
8
+
9
+ # C extensions
10
+ *.so
11
+
12
+ # Distribution / packaging
13
+ .Python
14
+ build/
15
+ develop-eggs/
16
+ dist/
17
+ downloads/
18
+ eggs/
19
+ .eggs/
20
+ lib/
21
+ lib64/
22
+ parts/
23
+ sdist/
24
+ var/
25
+ wheels/
26
+ share/python-wheels/
27
+ *.egg-info/
28
+ .installed.cfg
29
+ *.egg
30
+ MANIFEST
31
+
32
+ # PyInstaller
33
+ # Usually these files are written by a python script from a template
34
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
35
+ *.manifest
36
+ *.spec
37
+
38
+ # Installer logs
39
+ pip-log.txt
40
+ pip-delete-this-directory.txt
41
+
42
+ # Unit test / coverage reports
43
+ htmlcov/
44
+ .tox/
45
+ .nox/
46
+ .coverage
47
+ .coverage.*
48
+ .cache
49
+ nosetests.xml
50
+ coverage.xml
51
+ *.cover
52
+ *.py,cover
53
+ .hypothesis/
54
+ .pytest_cache/
55
+ cover/
56
+
57
+ # Translations
58
+ *.mo
59
+ *.pot
60
+
61
+ # Django stuff:
62
+ *.log
63
+ local_settings.py
64
+ db.sqlite3
65
+ db.sqlite3-journal
66
+
67
+ # Flask stuff:
68
+ instance/
69
+ .webassets-cache
70
+
71
+ # Scrapy stuff:
72
+ .scrapy
73
+
74
+ # Sphinx documentation
75
+ docs/_build/
76
+
77
+ # PyBuilder
78
+ .pybuilder/
79
+ target/
80
+
81
+ # Jupyter Notebook
82
+ .ipynb_checkpoints
83
+
84
+ # IPython
85
+ profile_default/
86
+ ipython_config.py
87
+
88
+ # pyenv
89
+ # For a library or package, you might want to ignore these files since the code is
90
+ # intended to run in multiple environments; otherwise, check them in:
91
+ # .python-version
92
+
93
+ # pipenv
94
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
95
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
96
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
97
+ # install all needed dependencies.
98
+ #Pipfile.lock
99
+
100
+ # poetry
101
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
102
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
103
+ # commonly ignored for libraries.
104
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
105
+ #poetry.lock
106
+
107
+ # pdm
108
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
109
+ #pdm.lock
110
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
111
+ # in version control.
112
+ # https://pdm.fming.dev/#use-with-ide
113
+ .pdm.toml
114
+
115
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
116
+ __pypackages__/
117
+
118
+ # Celery stuff
119
+ celerybeat-schedule
120
+ celerybeat.pid
121
+
122
+ # SageMath parsed files
123
+ *.sage.py
124
+
125
+ # Environments
126
+ .env
127
+ .venv
128
+ env/
129
+ venv/
130
+ ENV/
131
+ env.bak/
132
+ venv.bak/
133
+
134
+ # Spyder project settings
135
+ .spyderproject
136
+ .spyproject
137
+
138
+ # Rope project settings
139
+ .ropeproject
140
+
141
+ # mkdocs documentation
142
+ /site
143
+
144
+ # mypy
145
+ .mypy_cache/
146
+ .dmypy.json
147
+ dmypy.json
148
+
149
+ # Pyre type checker
150
+ .pyre/
151
+
152
+ # pytype static type analyzer
153
+ .pytype/
154
+
155
+ # Cython debug symbols
156
+ cython_debug/
157
+
158
+ # PyCharm
159
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
160
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
161
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
162
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
163
+ #.idea/
.gitmodules ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ [submodule "Attend-and-Excite"]
2
+ path = Attend-and-Excite
3
+ url = https://github.com/AttendAndExcite/Attend-and-Excite
.pre-commit-config.yaml ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ exclude: patch
2
+ repos:
3
+ - repo: https://github.com/pre-commit/pre-commit-hooks
4
+ rev: v4.2.0
5
+ hooks:
6
+ - id: check-executables-have-shebangs
7
+ - id: check-json
8
+ - id: check-merge-conflict
9
+ - id: check-shebang-scripts-are-executable
10
+ - id: check-toml
11
+ - id: check-yaml
12
+ - id: double-quote-string-fixer
13
+ - id: end-of-file-fixer
14
+ - id: mixed-line-ending
15
+ args: ['--fix=lf']
16
+ - id: requirements-txt-fixer
17
+ - id: trailing-whitespace
18
+ - repo: https://github.com/myint/docformatter
19
+ rev: v1.4
20
+ hooks:
21
+ - id: docformatter
22
+ args: ['--in-place']
23
+ - repo: https://github.com/pycqa/isort
24
+ rev: 5.12.0
25
+ hooks:
26
+ - id: isort
27
+ - repo: https://github.com/pre-commit/mirrors-mypy
28
+ rev: v0.991
29
+ hooks:
30
+ - id: mypy
31
+ args: ['--ignore-missing-imports']
32
+ additional_dependencies: ['types-python-slugify']
33
+ - repo: https://github.com/google/yapf
34
+ rev: v0.32.0
35
+ hooks:
36
+ - id: yapf
37
+ args: ['--parallel', '--in-place']
.style.yapf ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ [style]
2
+ based_on_style = pep8
3
+ blank_line_before_nested_class_or_def = false
4
+ spaces_before_comment = 2
5
+ split_before_logical_operator = true
Attend-and-Excite ADDED
@@ -0,0 +1 @@
 
 
1
+ Subproject commit 1b67cfc19cd3952e390dbb8047ccd126471567f2
README.md ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Attend And Excite
3
+ emoji: 💻
4
+ colorFrom: gray
5
+ colorTo: pink
6
+ sdk: gradio
7
+ sdk_version: 3.17.0
8
+ python_version: 3.10.9
9
+ app_file: app.py
10
+ pinned: false
11
+ license: mit
12
+ duplicated_from: AttendAndExcite/Attend-and-Excite
13
+ ---
14
+
15
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ from __future__ import annotations
4
+
5
+ import gradio as gr
6
+ import PIL.Image
7
+
8
+ from model import Model
9
+
10
# Markdown shown at the top of the demo page (passed to gr.Markdown below).
DESCRIPTION = '''# Attend-and-Excite
This is a demo for [Attend-and-Excite](https://arxiv.org/abs/2301.13826).
Attend-and-Excite performs attention-based generative semantic guidance to mitigate subject neglect in Stable Diffusion.
Select a prompt and a set of indices matching the subjects you wish to strengthen (the `Check token indices` cell can help map between a word and its index).
'''

# Created once at import time and shared by every callback in this module;
# Model.__init__ loads the 'CompVis/stable-diffusion-v1-4' pipeline.
model = Model()
17
+
18
+
19
def process_example(
    prompt: str,
    indices_to_alter_str: str,
    seed: int,
    apply_attend_and_excite: bool,
) -> tuple[list[tuple[int, str]], PIL.Image.Image]:
    """Generate the output for one example row of the demo.

    The example rows only carry the prompt, the token indices, the seed and
    the Attend-and-Excite switch, so the remaining ``model.run`` arguments
    (model ID, number of steps, guidance scale) are fixed here.

    Returns:
        Whatever ``model.run`` returns: the token table and the image.
    """
    return model.run(
        model_id='CompVis/stable-diffusion-v1-4',
        prompt=prompt,
        indices_to_alter_str=indices_to_alter_str,
        seed=seed,
        apply_attend_and_excite=apply_attend_and_excite,
        num_steps=50,
        guidance_scale=7.5,
    )
30
+
31
+
32
# NOTE(review): the nesting below was reconstructed from a flattened listing;
# confirm that the token-index textbox sits outside the accordion and that the
# examples row and event wiring live directly under the Blocks context.
with gr.Blocks(css='style.css') as demo:
    gr.Markdown(DESCRIPTION)

    with gr.Row():
        # Left column: every input control plus the "Generate" button.
        with gr.Column():
            # Hidden field that feeds the model ID into the callbacks.
            model_id = gr.Text(
                label='Model ID',
                value='CompVis/stable-diffusion-v1-4',
                visible=False,
            )
            prompt = gr.Text(
                label='Prompt',
                max_lines=1,
                placeholder=
                'A pod of dolphins leaping out of the water in an ocean with a ship on the background',
            )
            # Helper for looking up which index belongs to which token.
            with gr.Accordion(label='Check token indices', open=False):
                show_token_indices_button = gr.Button('Show token indices')
                token_indices_table = gr.Dataframe(
                    label='Token indices',
                    headers=['Index', 'Token'],
                    col_count=2,
                )
            token_indices_str = gr.Text(
                label=
                'Token indices (a comma-separated list indices of the tokens you wish to alter)',
                max_lines=1,
                placeholder='4,16',
            )
            seed = gr.Slider(
                label='Seed',
                minimum=0,
                maximum=100000,
                value=0,
                step=1,
            )
            apply_attend_and_excite = gr.Checkbox(
                label='Apply Attend-and-Excite',
                value=True,
            )
            num_steps = gr.Slider(
                label='Number of steps',
                minimum=0,
                maximum=100,
                step=1,
                value=50,
            )
            guidance_scale = gr.Slider(
                label='CFG scale',
                minimum=0,
                maximum=50,
                step=0.1,
                value=7.5,
            )
            run_button = gr.Button('Generate')
        # Right column: the generated image.
        with gr.Column():
            result = gr.Image(label='Result')

    with gr.Row():
        # (prompt, token indices, seed) triples. Each one is listed twice,
        # first with Attend-and-Excite enabled and then with it disabled, so
        # the two results can be compared.
        example_settings = [
            ('A mouse and a red car', '2,6', 2098),
            ('A horse and a dog', '2,5', 123),
            ('A painting of an elephant with glasses', '5,7', 123),
            (
                'A playful kitten chasing a butterfly in a wildflower meadow',
                '3,6,10',
                123,
            ),
            (
                'A grizzly bear catching a salmon in a crystal clear river surrounded by a forest',
                '2,6,15',
                123,
            ),
            (
                'A pod of dolphins leaping out of the water in an ocean with a ship on the background',
                '4,16',
                123,
            ),
        ]
        examples = [
            [example_prompt, example_indices, example_seed, use_method]
            for example_prompt, example_indices, example_seed
            in example_settings
            for use_method in (True, False)
        ]
        gr.Examples(
            examples=examples,
            inputs=[
                prompt,
                token_indices_str,
                seed,
                apply_attend_and_excite,
            ],
            outputs=[
                token_indices_table,
                result,
            ],
            fn=process_example,
            cache_examples=True,
            examples_per_page=20,
        )

    # The token table can be refreshed on its own, without generating an image.
    show_token_indices_button.click(
        fn=model.get_token_table,
        inputs=[
            model_id,
            prompt,
        ],
        outputs=token_indices_table,
    )

    inputs = [
        model_id,
        prompt,
        token_indices_str,
        seed,
        apply_attend_and_excite,
        num_steps,
        guidance_scale,
    ]
    outputs = [
        token_indices_table,
        result,
    ]
    # Submitting either textbox or clicking the button starts the same run.
    for register in (
            prompt.submit,
            token_indices_str.submit,
            run_button.click,
    ):
        register(fn=model.run, inputs=inputs, outputs=outputs)

demo.queue(max_size=50).launch(share=False)
model.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import sys
4
+
5
+ import gradio as gr
6
+ import PIL.Image
7
+ import torch
8
+
9
+ sys.path.append('Attend-and-Excite')
10
+
11
+ from config import RunConfig
12
+ from pipeline_attend_and_excite import AttendAndExcitePipeline
13
+ from run import run_on_prompt
14
+ from utils.ptp_utils import AttentionStore
15
+
16
+
17
class Model:
    """Holds one loaded ``AttendAndExcitePipeline`` and runs prompts on it.

    The pipeline for the most recently requested model ID is kept on
    ``self.device`` and is only reloaded when a different ID is requested.
    """

    def __init__(self):
        # Use the first CUDA device when one is available, otherwise the CPU.
        self.device = torch.device(
            'cuda:0' if torch.cuda.is_available() else 'cpu')
        # An empty ID never equals a real one, so the first load_model call
        # always loads.
        self.model_id = ''
        self.model = None
        self.tokenizer = None

        self.load_model('CompVis/stable-diffusion-v1-4')

    def load_model(self, model_id: str) -> None:
        """Load the pipeline for ``model_id`` unless it is already loaded."""
        if model_id == self.model_id:
            return
        self.model = AttendAndExcitePipeline.from_pretrained(model_id).to(
            self.device)
        self.tokenizer = self.model.tokenizer
        # Recorded last, so a failed load leaves the previous ID in place.
        self.model_id = model_id

    def get_token_table(self, model_id: str,
                        prompt: str) -> list[tuple[int, str]]:
        """Return ``(index, token)`` pairs for ``prompt``, counted from 1.

        Args:
            model_id: Model whose tokenizer is used (loaded if necessary).
            prompt: Text to tokenize.
        """
        self.load_model(model_id)
        token_ids = self.tokenizer(prompt)['input_ids']
        tokens = [self.tokenizer.decode(token_id) for token_id in token_ids]
        # Drop the first and last entries - presumably the start/end markers
        # the tokenizer adds; the remaining tokens are numbered from 1.
        return list(enumerate(tokens[1:-1], start=1))

    def run(
        self,
        model_id: str,
        prompt: str,
        indices_to_alter_str: str,
        seed: int,
        apply_attend_and_excite: bool,
        num_steps: int,
        guidance_scale: float,
        scale_factor: int = 20,
        thresholds: dict[int, float] | None = None,
        max_iter_to_alter: int = 25,
    ) -> tuple[list[tuple[int, str]], PIL.Image.Image]:
        """Generate an image for ``prompt`` and return it with the token table.

        Args:
            model_id: Model to run (loaded if necessary).
            prompt: Text prompt.
            indices_to_alter_str: Comma-separated token indices, e.g. '2,6'.
            seed: Seed for the random generator.
            apply_attend_and_excite: When false, ``run_standard_sd`` is set in
                the run configuration.
            num_steps: Passed on as ``n_inference_steps``.
            guidance_scale: Passed on to the run configuration.
            scale_factor: Passed on to the run configuration.
            thresholds: Passed on to the run configuration; ``None`` means
                ``{10: 0.5, 20: 0.8}``.
            max_iter_to_alter: Passed on to the run configuration.

        Returns:
            The token table for ``prompt`` and the generated image.

        Raises:
            gr.Error: If ``indices_to_alter_str`` is not a comma-separated
                list of integers.
        """
        # Validate the user input before doing any model work. Only a failed
        # integer conversion is reported as a user error; a bare ``except``
        # would also swallow KeyboardInterrupt and SystemExit.
        try:
            indices_to_alter = [
                int(index)
                for index in (indices_to_alter_str or '').split(',')
            ]
        except ValueError as err:
            raise gr.Error('Invalid token indices.') from err

        # Build the default per call so that no dict is shared between calls.
        if thresholds is None:
            thresholds = {10: 0.5, 20: 0.8}

        self.load_model(model_id)

        token_table = self.get_token_table(model_id, prompt)

        generator = torch.Generator(device=self.device).manual_seed(seed)
        controller = AttentionStore()
        config = RunConfig(prompt=prompt,
                           n_inference_steps=num_steps,
                           guidance_scale=guidance_scale,
                           run_standard_sd=not apply_attend_and_excite,
                           scale_factor=scale_factor,
                           thresholds=thresholds,
                           max_iter_to_alter=max_iter_to_alter)
        image = run_on_prompt(model=self.model,
                              prompt=[prompt],
                              controller=controller,
                              token_indices=indices_to_alter,
                              seed=generator,
                              config=config)

        return token_table, image
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ diffusers==0.3.0
2
+ ftfy==6.1.1
3
+ jupyter
4
+ opencv-python-headless==4.7.0.68
5
+ pyrallis==0.3.1
6
+ torch==1.13.1
7
+ transformers==4.23.1
style.css ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ h1 {
2
+ text-align: center;
3
+ }