Gordonkl committed on
Commit
8c72aec
1 Parent(s): 5d83894

Upload 4 files

Files changed (4)
  1. README.md +7 -7
  2. app.py +381 -0
  3. gitattributes +38 -0
  4. requirements.txt +18 -0
README.md CHANGED
@@ -1,12 +1,12 @@
  ---
- title: TEXT
- emoji: 🏢
- colorFrom: purple
- colorTo: gray
+ title: CSGO
+ emoji: 🏔️
+ colorFrom: blue
+ colorTo: purple
  sdk: gradio
- sdk_version: 5.3.0
+ sdk_version: 4.26.0
  app_file: app.py
  pinned: false
+ license: apache-2.0
+ short_description: Content-Style Composition (GoGoGo)
  ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
app.py ADDED
@@ -0,0 +1,381 @@
+ import sys
+ sys.path.append('./')
+ import spaces
+ import gradio as gr
+ import torch
+ import cv2
+ import numpy as np
+ import random
+ from PIL import Image
+ from ip_adapter.utils import BLOCKS, controlnet_BLOCKS, resize_content
+ from ip_adapter import CSGO
+ from transformers import BlipProcessor, BlipForConditionalGeneration
+ from diffusers import (
+     AutoencoderKL,
+     ControlNetModel,
+     StableDiffusionXLControlNetPipeline,
+ )
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ dtype = torch.float16 if "cuda" in device else torch.float32
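+
+ # NOTE: the setup below shells out to git/huggingface_hub to fetch the
+ # IP-Adapter SDXL image encoder, the CSGO checkpoint, and a tile
+ # ControlNet before any of the models are instantiated.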
+ import os
+ os.system("git lfs install")
+ os.system("git clone https://huggingface.co/h94/IP-Adapter")
+ os.system("mv IP-Adapter/sdxl_models sdxl_models")
+
+ from huggingface_hub import hf_hub_download
+
+ hf_hub_download(repo_id="InstantX/CSGO", filename="csgo_4_32.bin", local_dir="./CSGO/")
+ os.system('rm -rf IP-Adapter/models')
+
+ base_model_path = "stabilityai/stable-diffusion-xl-base-1.0"
+ image_encoder_path = "sdxl_models/image_encoder"
+ csgo_ckpt = './CSGO/csgo_4_32.bin'
+ pretrained_vae_name_or_path = 'madebyollin/sdxl-vae-fp16-fix'
+ weight_dtype = torch.float16
+
+ # Rename the v2 checkpoint so diffusers can load the ControlNet directory directly.
+ os.system("git clone https://huggingface.co/TTPlanet/TTPLanet_SDXL_Controlnet_Tile_Realistic")
+ os.system("mv TTPLanet_SDXL_Controlnet_Tile_Realistic/TTPLANET_Controlnet_Tile_realistic_v2_fp16.safetensors TTPLanet_SDXL_Controlnet_Tile_Realistic/diffusion_pytorch_model.safetensors")
+ os.system('rm -rf TTPLanet_SDXL_Controlnet_Tile_Realistic/TTPLANET_Controlnet_Tile_realistic_v1_fp16.safetensors')
+ controlnet_path = "./TTPLanet_SDXL_Controlnet_Tile_Realistic"
+ vae = AutoencoderKL.from_pretrained(pretrained_vae_name_or_path, torch_dtype=torch.float16)
+ controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16, use_safetensors=True)
+ pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
+     base_model_path,
+     controlnet=controlnet,
+     torch_dtype=torch.float16,
+     add_watermarker=False,
+     vae=vae,
+ )
+ pipe.enable_vae_tiling()
+
+ # BLIP captions the content image whenever the user leaves the prompt empty.
+ blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
+ blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large").to(device)
+
+ target_content_blocks = BLOCKS['content']
+ target_style_blocks = BLOCKS['style']
+ controlnet_target_content_blocks = controlnet_BLOCKS['content']
+ controlnet_target_style_blocks = controlnet_BLOCKS['style']
+
+ csgo = CSGO(pipe, image_encoder_path, csgo_ckpt, device, num_content_tokens=4, num_style_tokens=32,
+             target_content_blocks=target_content_blocks, target_style_blocks=target_style_blocks,
+             controlnet_adapter=True,
+             controlnet_target_content_blocks=controlnet_target_content_blocks,
+             controlnet_target_style_blocks=controlnet_target_style_blocks,
+             content_model_resampler=True,
+             style_model_resampler=True,
+             )
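+ # (Assumption, inferred from the names and the checkpoint filename csgo_4_32.bin:
+ # the block lists above appear to select which UNet/ControlNet attention blocks
+ # receive the 4 content tokens and 32 style tokens from the two resamplers.)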
+
+ MAX_SEED = np.iinfo(np.int32).max
+
+ def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
+     if randomize_seed:
+         seed = random.randint(0, MAX_SEED)
+     return seed
+
+
+ def get_example():
+     case = [
+         [
+             "./assets/img_0.png",
+             './assets/img_1.png',
+             "Image-Driven Style Transfer",
+             "there is a small house with a sheep statue on top of it",
+             0.6,
+             1.0,
+             7.0,
+             42,
+         ],
+         [
+             None,
+             './assets/img_1.png',
+             "Text-Driven Style Synthesis",
+             "a cat",
+             0.01,
+             1.0,
+             7.0,
+             42,
+         ],
+         [
+             None,
+             './assets/img_2.png',
+             "Text-Driven Style Synthesis",
+             "a cat",
+             0.01,
+             1.0,
+             7.0,
+             42,
+         ],
+         [
+             "./assets/img_0.png",
+             './assets/img_1.png',
+             "Text Edit-Driven Style Synthesis",
+             "there is a small house",
+             0.4,
+             1.0,
+             7.0,
+             42,
+         ],
+     ]
+     return case
+
+
+ def run_for_examples(content_image_pil, style_image_pil, target, prompt, scale_c, scale_s, guidance_scale, seed):
+     return create_image(
+         content_image_pil=content_image_pil,
+         style_image_pil=style_image_pil,
+         prompt=prompt,
+         scale_c=scale_c,
+         scale_s=scale_s,
+         guidance_scale=guidance_scale,
+         num_samples=2,
+         num_inference_steps=50,
+         seed=seed,
+         target=target,
+     )
+
+
+ def image_grid(imgs, rows, cols):
+     # Paste the generated images into a single rows x cols contact sheet.
+     assert len(imgs) == rows * cols
+
+     w, h = imgs[0].size
+     grid = Image.new('RGB', size=(cols * w, rows * h))
+
+     for i, img in enumerate(imgs):
+         grid.paste(img, box=(i % cols * w, i // cols * h))
+     return grid
+
+ @spaces.GPU
+ def create_image(content_image_pil,
+                  style_image_pil,
+                  prompt,
+                  scale_c,
+                  scale_s,
+                  guidance_scale,
+                  num_samples,
+                  num_inference_steps,
+                  seed,
+                  target="Image-Driven Style Transfer",
+                  ):
+     if content_image_pil is None:
+         content_image_pil = Image.fromarray(
+             np.zeros((1024, 1024, 3), dtype=np.uint8)).convert('RGB')
+
+     # With an empty prompt, fall back to a BLIP caption of the content image.
+     if prompt == '':
+         inputs = blip_processor(content_image_pil, return_tensors="pt").to(device)
+         out = blip_model.generate(**inputs)
+         prompt = blip_processor.decode(out[0], skip_special_tokens=True)
+
+     width, height, content_image = resize_content(content_image_pil)
+     style_image = style_image_pil
+     neg_content_prompt = 'text, watermark, lowres, low quality, worst quality, deformed, glitch, low contrast, noisy, saturation, blurry'
+
+     if target == "Image-Driven Style Transfer":
+         images = csgo.generate(pil_content_image=content_image, pil_style_image=style_image,
+                                prompt=prompt,
+                                negative_prompt=neg_content_prompt,
+                                height=height,
+                                width=width,
+                                content_scale=1.0,
+                                style_scale=scale_s,
+                                guidance_scale=guidance_scale,
+                                num_images_per_prompt=num_samples,
+                                num_inference_steps=num_inference_steps,
+                                num_samples=1,
+                                seed=seed,
+                                image=content_image.convert('RGB'),
+                                controlnet_conditioning_scale=scale_c,
+                                )
+     elif target == "Text-Driven Style Synthesis":
+         # No content image in this mode: condition on a blank canvas instead.
+         # guidance_scale and seed are pinned here, as in the original demo.
+         content_image = Image.fromarray(
+             np.zeros((1024, 1024, 3), dtype=np.uint8)).convert('RGB')
+         images = csgo.generate(pil_content_image=content_image, pil_style_image=style_image,
+                                prompt=prompt,
+                                negative_prompt=neg_content_prompt,
+                                height=height,
+                                width=width,
+                                content_scale=0.5,
+                                style_scale=scale_s,
+                                guidance_scale=7,
+                                num_images_per_prompt=num_samples,
+                                num_inference_steps=num_inference_steps,
+                                num_samples=1,
+                                seed=42,
+                                image=content_image.convert('RGB'),
+                                controlnet_conditioning_scale=scale_c,
+                                )
+     elif target == "Text Edit-Driven Style Synthesis":
+         images = csgo.generate(pil_content_image=content_image, pil_style_image=style_image,
+                                prompt=prompt,
+                                negative_prompt=neg_content_prompt,
+                                height=height,
+                                width=width,
+                                content_scale=1.0,
+                                style_scale=scale_s,
+                                guidance_scale=guidance_scale,
+                                num_images_per_prompt=num_samples,
+                                num_inference_steps=num_inference_steps,
+                                num_samples=1,
+                                seed=seed,
+                                image=content_image.convert('RGB'),
+                                controlnet_conditioning_scale=scale_c,
+                                )
+
+     return [image_grid(images, 1, num_samples)]
+
+
+ def pil_to_cv2(image_pil):
+     image_np = np.array(image_pil)
+     image_cv2 = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
+     return image_cv2
+
+
+ # Description
+ title = r"""
+ <h1 align="center">CSGO: Content-Style Composition in Text-to-Image Generation</h1>
+ """
+
+ description = r"""
+ <b>Official 🤗 Gradio demo</b> for <a href='https://github.com/instantX-research/CSGO' target='_blank'><b>CSGO: Content-Style Composition in Text-to-Image Generation</b></a>.<br>
+ How to use:<br>
+ 1. Upload a content image if you want to use image-driven style transfer.
+ 2. Upload a style image.
+ 3. Set the type of task to perform; by default, image-driven style transfer is performed. The options are <b>Image-Driven Style Transfer, Text-Driven Style Synthesis, and Text Edit-Driven Style Synthesis</b>.
+ 4. <b>If you choose a text-driven task, enter your desired prompt</b>.
+ 5. If you don't provide a prompt, a caption generated by the BLIP model is used by default. We suggest providing a detailed prompt for the content image, as this helps CSGO preserve the content more faithfully.
+ 6. Click the <b>Generate Image</b> button to begin customization.
+ 7. Share your stylized photo with your friends and enjoy! 😊
+
+ Advanced usage:<br>
+ 1. Click advanced options.
+ 2. Choose different guidance scales and step counts.
+ """
+
+ article = r"""
+ ---
+ 📝 **Tips**
+ In CSGO, the more accurate the text prompt for the content image, the better the content retention.
+ Text-driven style synthesis and text-edit-driven style synthesis are expected to be more stable in the next release.
+ ---
+ 📝 **Citation**
+ <br>
+ If our work is helpful for your research or applications, please cite us via:
+ ```bibtex
+ @article{xing2024csgo,
+   title={CSGO: Content-Style Composition in Text-to-Image Generation},
+   author={Peng Xing and Haofan Wang and Yanpeng Sun and Qixun Wang and Xu Bai and Hao Ai and Renyuan Huang and Zechao Li},
+   year={2024},
+   journal={arXiv preprint arXiv:2408.16766},
+ }
+ ```
+ 📧 **Contact**
+ <br>
+ If you have any questions, please feel free to open an issue or reach out to us directly at <b>[email protected]</b>.
+ """
+
+ block = gr.Blocks(css="footer {visibility: hidden}").queue(max_size=10, api_open=False)
+ with block:
+     # description
+     gr.Markdown(title)
+     gr.Markdown(description)
+
+     with gr.Tabs():
+         with gr.Row():
+             with gr.Column():
+                 with gr.Row():
+                     with gr.Column():
+                         content_image_pil = gr.Image(label="Content Image (optional)", type='pil')
+                         style_image_pil = gr.Image(label="Style Image", type='pil')
+
+                         target = gr.Radio(["Image-Driven Style Transfer", "Text-Driven Style Synthesis", "Text Edit-Driven Style Synthesis"],
+                                           value="Image-Driven Style Transfer",
+                                           label="task")
+
+                         prompt = gr.Textbox(label="Prompt",
+                                             value="there is a small house with a sheep statue on top of it")
+                         prompt_type = gr.CheckboxGroup(
+                             ["caption of Blip", "user input"], label="prompt_type", value=["caption of Blip"],
+                             info="Choose to enter a more detailed prompt yourself or let the BLIP model describe the content image."
+                         )
+                         # NOTE: clearing the Prompt box makes create_image fall back
+                         # to a BLIP caption of the content image.
+
+                         scale_c = gr.Slider(minimum=0, maximum=2.0, step=0.01, value=0.6, label="Content Scale")
+                         scale_s = gr.Slider(minimum=0, maximum=2.0, step=0.01, value=1.0, label="Style Scale")
+                         with gr.Accordion(open=False, label="Advanced Options"):
+                             guidance_scale = gr.Slider(minimum=1, maximum=15.0, step=0.01, value=7.0, label="guidance scale")
+                             num_samples = gr.Slider(minimum=1, maximum=4.0, step=1.0, value=1.0, label="num samples")
+                             num_inference_steps = gr.Slider(minimum=5, maximum=100.0, step=1.0, value=50,
+                                                             label="num inference steps")
+                             seed = gr.Slider(minimum=-1000000, maximum=1000000, value=1, step=1, label="Seed Value")
+                             randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
+
+                         generate_button = gr.Button("Generate Image")
+
+             with gr.Column():
+                 generated_image = gr.Gallery(label="Generated Image")
+
+         generate_button.click(
+             fn=randomize_seed_fn,
+             inputs=[seed, randomize_seed],
+             outputs=seed,
+             queue=False,
+             api_name=False,
+         ).then(
+             fn=create_image,
+             inputs=[content_image_pil,
+                     style_image_pil,
+                     prompt,
+                     scale_c,
+                     scale_s,
+                     guidance_scale,
+                     num_samples,
+                     num_inference_steps,
+                     seed,
+                     target],
+             outputs=[generated_image])
+
+     gr.Examples(
+         examples=get_example(),
+         inputs=[content_image_pil, style_image_pil, target, prompt, scale_c, scale_s, guidance_scale, seed],
+         fn=run_for_examples,
+         outputs=[generated_image],
+         cache_examples=False,
+     )
+
+     gr.Markdown(article)
+
+ block.launch()
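
For readers who want to drive the model without the Gradio UI, here is a minimal sketch of a programmatic call. It reuses the `csgo` adapter and helpers built in app.py above; the asset paths and output filename are placeholders, and the keyword arguments simply mirror the ones the demo itself passes to `csgo.generate`.

```python
# Hypothetical batch usage of the objects defined in app.py (paths are placeholders).
from PIL import Image
from ip_adapter.utils import resize_content

content = Image.open("assets/img_0.png").convert("RGB")
style = Image.open("assets/img_1.png").convert("RGB")
width, height, content = resize_content(content)  # snap to a pipeline-friendly size

images = csgo.generate(
    pil_content_image=content,
    pil_style_image=style,
    prompt="there is a small house with a sheep statue on top of it",
    negative_prompt="text, watermark, lowres, low quality, worst quality, deformed, glitch, low contrast, noisy, saturation, blurry",
    height=height, width=width,
    content_scale=1.0, style_scale=1.0,
    guidance_scale=7.0,
    num_images_per_prompt=1, num_inference_steps=50,
    num_samples=1, seed=42,
    image=content,  # condition for the tile ControlNet
    controlnet_conditioning_scale=0.6,
)
images[0].save("stylized.png")
```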
gitattributes ADDED
@@ -0,0 +1,38 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.jpg filter=lfs diff=lfs merge=lfs -text
+ *.png filter=lfs diff=lfs merge=lfs -text
+ *.webp filter=lfs diff=lfs merge=lfs -text
requirements.txt ADDED
@@ -0,0 +1,18 @@
+ diffusers==0.25.1
+ torch==2.0.1
+ torchaudio==2.0.2
+ torchvision==0.15.2
+ transformers==4.40.2
+ accelerate
+ safetensors
+ einops
+ spaces==0.19.4
+ omegaconf
+ peft
+ huggingface-hub==0.24.5
+ opencv-python
+ insightface
+ gradio
+ controlnet_aux
+ gdown