jennysun committed
Commit 5037c92 • 1 Parent(s): 81ba850

Automating the grounding instruction step so users type in a prompt and the model automatically parses the subjects for them
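In short, the user's prompt is now sent to GPT-3 and the completion is used as the grounding instruction. A minimal sketch of that step, mirroring the separate_subjects helper added in this commit (it assumes the legacy openai<1.0 SDK that app.py imports, and that OPENAI_API_KEY is set in the environment):

    import os
    import openai

    openai.api_key = os.environ["OPENAI_API_KEY"]

    # Few-shot prompt from this commit: ask GPT-3 to split a caption into its subjects.
    prompt_base = (
        'Separate the subjects in this sentence by semicolons. For example, the sentence '
        '"a tiger and a horse running in a greenland" should output "tiger; horse". '
        'If there are numbers, make each subject unique. For example, "2 dogs and 1 duck" '
        'would be "dog; dog; duck." Do the same for the following sentence: \n'
    )

    def separate_subjects(input_text):
        # One completion call; the stripped text is the semicolon-separated subject list.
        response = openai.Completion.create(
            engine="text-davinci-002",
            prompt=prompt_base + input_text,
            max_tokens=1024,
            n=1,
            stop=None,
            temperature=0.7,
        )
        return response.choices[0].text.strip()

For example, separate_subjects("2 horses running") is expected to return "horse; horse", which the demo then pairs with boxes drawn on the sketch pad.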

Files changed (1):
  1. app.py +233 -204
app.py CHANGED
--- app.py (before)
@@ -1,118 +1,91 @@
  import gradio as gr
  import torch
  from omegaconf import OmegaConf
- from gligen.task_grounded_generation import grounded_generation_box, load_ckpt, load_common_ckpt

  import json
  import numpy as np
  from PIL import Image, ImageDraw, ImageFont
  from functools import partial
- from collections import Counter
  import math
- import gc

  from gradio import processing_utils
  from typing import Optional

- import warnings
-
- from datetime import datetime
-
  from huggingface_hub import hf_hub_download
  hf_hub_download = partial(hf_hub_download, library_name="gligen_demo")

- import sys
- sys.tracebacklimit = 0
-
-
- def load_from_hf(repo_id, filename='diffusion_pytorch_model.bin', subfolder=None):
-     cache_file = hf_hub_download(repo_id=repo_id, filename=filename, subfolder=subfolder)
      return torch.load(cache_file, map_location='cpu')

  def load_ckpt_config_from_hf(modality):
-     ckpt = load_from_hf('gligen/demo_ckpts_legacy', filename=f'{modality}.pth', subfolder='model')
-     config = load_from_hf('gligen/demo_ckpts_legacy', filename=f'{modality}.pth', subfolder='config')
      return ckpt, config


- def ckpt_load_helper(modality, is_inpaint, is_style, common_instances=None):
-     pretrained_ckpt_gligen, config = load_ckpt_config_from_hf(modality)
      config = OmegaConf.create( config["_content"] ) # config used in training
-     config.alpha_scale = 1.0
-     config.model['params']['is_inpaint'] = is_inpaint
-     config.model['params']['is_style'] = is_style
-
-     if common_instances is None:
-         common_ckpt = load_from_hf('gligen/demo_ckpts_legacy', filename=f'common.pth', subfolder='model')
-         common_instances = load_common_ckpt(config, common_ckpt)
-
-     loaded_model_list = load_ckpt(config, pretrained_ckpt_gligen, common_instances)

-     return loaded_model_list, common_instances


- class Instance:
-     def __init__(self, capacity = 2):
-         self.model_type = 'base'
-         self.loaded_model_list = {}
-         self.counter = Counter()
-         self.global_counter = Counter()
-         self.loaded_model_list['base'], self.common_instances = ckpt_load_helper(
-             'gligen-generation-text-box',
-             is_inpaint=False, is_style=False, common_instances=None
-         )
-         self.capacity = capacity
-
-     def _log(self, model_type, batch_size, instruction, phrase_list):
-         self.counter[model_type] += 1
-         self.global_counter[model_type] += 1
-         current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
-         print('[{}] Current: {}, All: {}. Samples: {}, prompt: {}, phrases: {}'.format(
-             current_time, dict(self.counter), dict(self.global_counter), batch_size, instruction, phrase_list
-         ))
-
-     def get_model(self, model_type, batch_size, instruction, phrase_list):
-         if model_type in self.loaded_model_list:
-             self._log(model_type, batch_size, instruction, phrase_list)
-             return self.loaded_model_list[model_type]
-
-         if self.capacity == len(self.loaded_model_list):
-             least_used_type = self.counter.most_common()[-1][0]
-             del self.loaded_model_list[least_used_type]
-             del self.counter[least_used_type]
-             gc.collect()
-             torch.cuda.empty_cache()
-
-         self.loaded_model_list[model_type] = self._get_model(model_type)
-         self._log(model_type, batch_size, instruction, phrase_list)
-         return self.loaded_model_list[model_type]
-
-     def _get_model(self, model_type):
-         if model_type == 'base':
-             return ckpt_load_helper(
-                 'gligen-generation-text-box',
-                 is_inpaint=False, is_style=False, common_instances=self.common_instances
-             )[0]
-         elif model_type == 'inpaint':
-             return ckpt_load_helper(
-                 'gligen-inpainting-text-box',
-                 is_inpaint=True, is_style=False, common_instances=self.common_instances
-             )[0]
-         elif model_type == 'style':
-             return ckpt_load_helper(
-                 'gligen-generation-text-image-box',
-                 is_inpaint=False, is_style=True, common_instances=self.common_instances
-             )[0]
-
-         assert False

- instance = Instance()


  def load_clip_model():
      from transformers import CLIPProcessor, CLIPModel
      version = "openai/clip-vit-large-patch14"
-     model = CLIPModel.from_pretrained(version).cuda()
      processor = CLIPProcessor.from_pretrained(version)

      return {
@@ -162,11 +135,10 @@ class Blocks(gr.Blocks):
          self.extra_configs = {
              'thumbnail': kwargs.pop('thumbnail', ''),
              'url': kwargs.pop('url', 'https://gradio.app/'),
-             'creator': kwargs.pop('creator', '@teamGradio'),
          }

          super(Blocks, self).__init__(theme, analytics_enabled, mode, title, css, **kwargs)
-         warnings.filterwarnings("ignore")

      def get_config_file(self):
          config = super(Blocks, self).get_config_file()
@@ -232,21 +204,17 @@ def inference(task, language_instruction, grounding_instruction, inpainting_boxe
          inpainting_boxes_nodrop = inpainting_boxes_nodrop,
      )

-     get_model = partial(instance.get_model,
-                         batch_size=batch_size,
-                         instruction=language_instruction,
-                         phrase_list=phrase_list)
-
-     with torch.autocast(device_type='cuda', dtype=torch.float16):
          if task == 'Grounded Generation':
              if style_image == None:
-                 return grounded_generation_box(get_model('base'), instruction, *args, **kwargs)
              else:
-                 return grounded_generation_box(get_model('style'), instruction, *args, **kwargs)
          elif task == 'Grounded Inpainting':
              assert image is not None
              instruction['input_image'] = image.convert("RGB")
-             return grounded_generation_box(get_model('inpaint'), instruction, *args, **kwargs)


  def draw_box(boxes=[], texts=[], img=None):
@@ -283,11 +251,9 @@ def auto_append_grounding(language_instruction, grounding_texts):
      for grounding_text in grounding_texts:
          if grounding_text not in language_instruction and grounding_text != 'auto':
              language_instruction += "; " + grounding_text
      return language_instruction

-
-
-
  def generate(task, language_instruction, grounding_texts, sketch_pad,
               alpha_sample, guidance_scale, batch_size,
               fix_seed, rand_seed, use_actual_mask, append_grounding, style_cond_image,
@@ -297,14 +263,7 @@ def generate(task, language_instruction, grounding_texts, sketch_pad,

      boxes = state['boxes']
      grounding_texts = [x.strip() for x in grounding_texts.split(';')]
-     # assert len(boxes) == len(grounding_texts)
-     if len(boxes) != len(grounding_texts):
-         if len(boxes) < len(grounding_texts):
-             raise ValueError("""The number of boxes should be equal to the number of grounding objects.
-             Number of boxes drawn: {}, number of grounding tokens: {}.
-             Please draw boxes accordingly on the sketch pad.""".format(len(boxes), len(grounding_texts)))
-         grounding_texts = grounding_texts + [""] * (len(boxes) - len(grounding_texts))
-
      boxes = (np.asarray(boxes) / 512).tolist()
      grounding_instruction = json.dumps({obj: box for obj,box in zip(grounding_texts, boxes)})
@@ -488,19 +447,30 @@ def clear(task, sketch_pad_trigger, batch_size, state, switch_task=False):
      return [None, sketch_pad_trigger, None, 1.0] + out_images + [state]

  css = """
- #img2img_image, #img2img_image > .fixed-height, #img2img_image > .fixed-height > div, #img2img_image > .fixed-height > div > img
  {
      height: var(--height) !important;
      max-height: var(--height) !important;
      min-height: var(--height) !important;
  }
  #paper-info a {
      color:#008AD7;
-     text-decoration: none;
  }
  #paper-info a:hover {
      cursor: pointer;
-     text-decoration: none;
  }
  """
@@ -517,28 +487,95 @@ function(x) {
  }
  """

  with Blocks(
      css=css,
      analytics_enabled=False,
      title="GLIGen demo",
  ) as main:
-     description = """<p style="text-align: center; font-weight: bold;">
-     <span style="font-size: 28px">GLIGen: Open-Set Grounded Text-to-Image Generation</span>
-     <br>
-     <span style="font-size: 18px" id="paper-info">
-     [<a href="https://gligen.github.io" target="_blank">Project Page</a>]
-     [<a href="https://arxiv.org/abs/2301.07093" target="_blank">Paper</a>]
-     [<a href="https://github.com/gligen/GLIGEN" target="_blank">GitHub</a>]
-     </span>
-     </p>
-     <p>
-     To ground concepts of interest with desired spatial specification, please (1) &#9000;&#65039; enter the concept names in <em> Grounding Instruction</em>, and (2) &#128433;&#65039; draw their corresponding bounding boxes one by one using <em> Sketch Pad</em> -- the parsed boxes will be displayed automatically.
-     <br>
-     For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings. <a href="https://huggingface.co/spaces/gligen/demo?duplicate=true"><img style="display: inline; margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space" /></a>
-     </p>
-     """
-     gr.HTML(description)
-
      with gr.Row():
          with gr.Column(scale=4):
              sketch_pad_trigger = gr.Number(value=0, visible=False)
@@ -547,45 +584,89 @@ with Blocks(
              image_scale = gr.Number(value=0, elem_id="image_scale", visible=False)
              new_image_trigger = gr.Number(value=0, visible=False)

              task = gr.Radio(
                  choices=["Grounded Generation", 'Grounded Inpainting'],
                  type="value",
                  value="Grounded Generation",
                  label="Task",
              )
-             language_instruction = gr.Textbox(
-                 label="Language instruction",
-             )
-             grounding_instruction = gr.Textbox(
-                 label="Grounding instruction (Separated by semicolon)",
-             )
              with gr.Row():
                  sketch_pad = ImageMask(label="Sketch Pad", elem_id="img2img_image")
                  out_imagebox = gr.Image(type="pil", label="Parsed Sketch Pad")
              with gr.Row():
                  clear_btn = gr.Button(value='Clear')
-                 gen_btn = gr.Button(value='Generate')
              with gr.Accordion("Advanced Options", open=False):
                  with gr.Column():
-                     alpha_sample = gr.Slider(minimum=0, maximum=1.0, step=0.1, value=0.3, label="Scheduled Sampling (τ)")
-                     guidance_scale = gr.Slider(minimum=0, maximum=50, step=0.5, value=7.5, label="Guidance Scale")
-                     batch_size = gr.Slider(minimum=1, maximum=4, step=1, value=2, label="Number of Samples")
-                     append_grounding = gr.Checkbox(value=True, label="Append grounding instructions to the caption")
                      use_actual_mask = gr.Checkbox(value=False, label="Use actual mask for inpainting", visible=False)
                      with gr.Row():
-                         fix_seed = gr.Checkbox(value=True, label="Fixed seed")
-                         rand_seed = gr.Slider(minimum=0, maximum=1000, step=1, value=0, label="Seed")
                      with gr.Row():
-                         use_style_cond = gr.Checkbox(value=False, label="Enable Style Condition")
-                         style_cond_image = gr.Image(type="pil", label="Style Condition", visible=False, interactive=True)
          with gr.Column(scale=4):
-             gr.HTML('<span style="font-size: 20px; font-weight: bold">Generated Images</span>')
              with gr.Row():
                  out_gen_1 = gr.Image(type="pil", visible=True, show_label=False)
                  out_gen_2 = gr.Image(type="pil", visible=True, show_label=False)
              with gr.Row():
-                 out_gen_3 = gr.Image(type="pil", visible=False, show_label=False)
-                 out_gen_4 = gr.Image(type="pil", visible=False, show_label=False)

      state = gr.State({})
@@ -687,6 +768,14 @@ with Blocks(
          outputs=[out_gen_1, out_gen_2, out_gen_3, out_gen_4, state],
          queue=True
      )
      sketch_pad_resize_trigger.change(
          None,
          None,
@@ -710,65 +799,5 @@ with Blocks(
          outputs=[use_style_cond, style_cond_image, alpha_sample, use_actual_mask],
          queue=False)

-     with gr.Column():
-         gr.Examples(
-             examples=[
-                 [
-                     "images/blank.png",
-                     "Grounded Generation",
-                     "a dog and an apple",
-                     "a dog;an apple",
-                 ],
-                 [
-                     "images/blank.png",
-                     "Grounded Generation",
-                     "John Lennon is using a pc",
-                     "John Lennon;a pc",
-                     [
-                         "images/blank.png",
-                         "Grounded Generation",
-                         "a painting of a fox sitting in a field at sunrise in the style of Claude Monet",
-                         "fox;sunrise",
-                     ],
-                 ],
-                 [
-                     "images/blank.png",
-                     "Grounded Generation",
-                     "a beautiful painting of hot dog by studio ghibli, octane render, brilliantly coloured",
-                     "hot dog",
-                 ],
-                 [
-                     "images/blank.png",
-                     "Grounded Generation",
-                     "a sport car, unreal engine, global illumination, ray tracing",
-                     "a sport car",
-                 ],
-                 [
-                     "images/flower_beach.jpg",
-                     "Grounded Inpainting",
-                     "a squirrel and the space needle",
-                     "a squirrel;the space needle",
-                 ],
-                 [
-                     "images/arg_corgis.jpeg",
-                     "Grounded Inpainting",
-                     "a dog and a birthday cake",
-                     "a dog; a birthday cake",
-                 ],
-                 [
-                     "images/teddy.jpg",
-                     "Grounded Inpainting",
-                     "a teddy bear wearing a santa claus red shirt; holding a Christmas gift box on hand",
-                     "a santa claus shirt; a Christmas gift box",
-                 ],
-             ],
-             inputs=[sketch_pad, task, language_instruction, grounding_instruction],
-             outputs=None,
-             fn=None,
-             cache_examples=False,
-         )
-
  main.queue(concurrency_count=1, api_open=False)
- main.launch(share=False, show_api=False, show_error=True)
-
-
 
+++ app.py (after)

  import gradio as gr
  import torch
+ import argparse
  from omegaconf import OmegaConf
+ from gligen.task_grounded_generation import grounded_generation_box, load_ckpt
+ from ldm.util import default_device

  import json
  import numpy as np
  from PIL import Image, ImageDraw, ImageFont
  from functools import partial
  import math
+ from contextlib import nullcontext

  from gradio import processing_utils
  from typing import Optional

  from huggingface_hub import hf_hub_download
  hf_hub_download = partial(hf_hub_download, library_name="gligen_demo")

+ import openai
+ from gradio.components import Textbox, Text
+ import os
+
+ arg_bool = lambda x: x.lower() == 'true'
+ device = default_device()
+
+ print(f"GLIGEN uses {device.upper()} device.")
+ if device == "cpu":
+     print("It will be sloooow. Consider using GPU support with CUDA or (in case of M1/M2 Apple Silicon) MPS.")
+ elif device == "mps":
+     print("The fastest you can get on M1/M2 Apple Silicon. Yet, many optimizations are still switched off, and it will be much slower than CUDA.")
+
+ def parse_option():
+     parser = argparse.ArgumentParser('GLIGen Demo', add_help=False)
+     parser.add_argument("--folder", type=str, default="create_samples", help="path to OUTPUT")
+     parser.add_argument("--official_ckpt", type=str, default='ckpts/sd-v1-4.ckpt', help="")
+     parser.add_argument("--guidance_scale", type=float, default=5, help="")
+     parser.add_argument("--alpha_scale", type=float, default=1, help="scale tanh(alpha). If 0, the behaviour is same as original model")
+     parser.add_argument("--load-text-box-generation", type=arg_bool, default=True, help="Load text-box generation pipeline.")
+     parser.add_argument("--load-text-box-inpainting", type=arg_bool, default=False, help="Load text-box inpainting pipeline.")
+     parser.add_argument("--load-text-image-box-generation", type=arg_bool, default=False, help="Load text-image-box generation pipeline.")
+     args = parser.parse_args()
+     return args
+ args = parse_option()
+
+
+ def load_from_hf(repo_id, filename='diffusion_pytorch_model.bin'):
+     cache_file = hf_hub_download(repo_id=repo_id, filename=filename)
      return torch.load(cache_file, map_location='cpu')

  def load_ckpt_config_from_hf(modality):
+     ckpt = load_from_hf(f'gligen/{modality}')
+     config = load_from_hf('gligen/demo_config_legacy', filename=f'{modality}.pth')
      return ckpt, config


+ if args.load_text_box_generation:
+     pretrained_ckpt_gligen, config = load_ckpt_config_from_hf('gligen-generation-text-box')
      config = OmegaConf.create( config["_content"] ) # config used in training
+     config.update( vars(args) )
+     config.model['params']['is_inpaint'] = False
+     config.model['params']['is_style'] = False
+     loaded_model_list = load_ckpt(config, pretrained_ckpt_gligen)


+ if args.load_text_box_inpainting:
+     pretrained_ckpt_gligen_inpaint, config = load_ckpt_config_from_hf('gligen-inpainting-text-box')
+     config = OmegaConf.create( config["_content"] ) # config used in training
+     config.update( vars(args) )
+     config.model['params']['is_inpaint'] = True
+     config.model['params']['is_style'] = False
+     loaded_model_list_inpaint = load_ckpt(config, pretrained_ckpt_gligen_inpaint)


+ if args.load_text_image_box_generation:
+     pretrained_ckpt_gligen_style, config = load_ckpt_config_from_hf('gligen-generation-text-image-box')
+     config = OmegaConf.create( config["_content"] ) # config used in training
+     config.update( vars(args) )
+     config.model['params']['is_inpaint'] = False
+     config.model['params']['is_style'] = True
+     loaded_model_list_style = load_ckpt(config, pretrained_ckpt_gligen_style)


  def load_clip_model():
      from transformers import CLIPProcessor, CLIPModel
      version = "openai/clip-vit-large-patch14"
+     model = CLIPModel.from_pretrained(version).to(device)
      processor = CLIPProcessor.from_pretrained(version)

      return {
 
          self.extra_configs = {
              'thumbnail': kwargs.pop('thumbnail', ''),
              'url': kwargs.pop('url', 'https://gradio.app/'),
+             'creator': kwargs.pop('creator', 'Jenny Sun'),
          }

          super(Blocks, self).__init__(theme, analytics_enabled, mode, title, css, **kwargs)

      def get_config_file(self):
          config = super(Blocks, self).get_config_file()
 
          inpainting_boxes_nodrop = inpainting_boxes_nodrop,
      )

+     # float16 autocast is available only on CUDA devices
+     with torch.autocast(device_type='cuda', dtype=torch.float16) if device == "cuda" else nullcontext():
          if task == 'Grounded Generation':
              if style_image == None:
+                 return grounded_generation_box(loaded_model_list, instruction, *args, **kwargs)
              else:
+                 return grounded_generation_box(loaded_model_list_style, instruction, *args, **kwargs)
          elif task == 'Grounded Inpainting':
              assert image is not None
              instruction['input_image'] = image.convert("RGB")
+             return grounded_generation_box(loaded_model_list_inpaint, instruction, *args, **kwargs)


  def draw_box(boxes=[], texts=[], img=None):
 
      for grounding_text in grounding_texts:
          if grounding_text not in language_instruction and grounding_text != 'auto':
              language_instruction += "; " + grounding_text
+     print(language_instruction)
      return language_instruction

  def generate(task, language_instruction, grounding_texts, sketch_pad,
               alpha_sample, guidance_scale, batch_size,
               fix_seed, rand_seed, use_actual_mask, append_grounding, style_cond_image,
 

      boxes = state['boxes']
      grounding_texts = [x.strip() for x in grounding_texts.split(';')]
+     assert len(boxes) == len(grounding_texts)
      boxes = (np.asarray(boxes) / 512).tolist()
      grounding_instruction = json.dumps({obj: box for obj,box in zip(grounding_texts, boxes)})
 
      return [None, sketch_pad_trigger, None, 1.0] + out_images + [state]

  css = """
+ #generate-btn {
+     --tw-border-opacity: 1;
+     border-color: rgb(255 216 180 / var(--tw-border-opacity));
+     --tw-gradient-from: rgb(255 216 180 / .7);
+     --tw-gradient-to: rgb(255 216 180 / 0);
+     --tw-gradient-stops: var(--tw-gradient-from), var(--tw-gradient-to);
+     --tw-gradient-to: rgb(255 176 102 / .8);
+     --tw-text-opacity: 1;
+     color: rgb(238 116 0 / var(--tw-text-opacity));
+ }
+ #img2img_image, #img2img_image > .h-60, #img2img_image > .h-60 > div, #img2img_image > .h-60 > div > img
  {
      height: var(--height) !important;
      max-height: var(--height) !important;
      min-height: var(--height) !important;
  }
+ #mirrors a:hover {
+     cursor: pointer;
+ }
  #paper-info a {
      color:#008AD7;
  }
  #paper-info a:hover {
      cursor: pointer;
  }
  """
 
 
  }
  """

+ mirror_js = """
+ function () {
+     const root = document.querySelector('gradio-app').shadowRoot || document.querySelector('gradio-app');
+     const mirrors_div = root.querySelector('#mirrors');
+     const current_url = window.location.href;
+     const mirrors = [
+         'https://dev.hliu.cc/gligen_mirror1/',
+         'https://dev.hliu.cc/gligen_mirror2/',
+     ];
+
+     let mirror_html = '';
+     mirror_html += '[<a href="https://gligen.github.io" target="_blank" style="">Project Page</a>]';
+     mirror_html += '[<a href="https://arxiv.org/abs/2301.07093" target="_blank" style="">Paper</a>]';
+     mirror_html += '[<a href="https://github.com/gligen/GLIGEN" target="_blank" style="">GitHub Repo</a>]';
+     mirror_html += '&nbsp;&nbsp;&nbsp;|&nbsp;&nbsp;&nbsp;';
+     mirror_html += 'Mirrors: ';
+
+     mirrors.forEach((e, index) => {
+         let cur_index = index + 1;
+         if (current_url.includes(e)) {
+             mirror_html += `[Mirror ${cur_index}] `;
+         } else {
+             mirror_html += `[<a onclick="window.location.href = '${e}'">Mirror ${cur_index}</a>] `;
+         }
+     });
+
+     mirror_html = `<div class="output-markdown gr-prose" style="max-width: 100%;"><h3 style="text-align: center" id="paper-info">${mirror_html}</h3></div>`;
+
+     mirrors_div.innerHTML = mirror_html;
+ }
+ """
+
+ # Set up OpenAI API key
+ openai.api_key = os.environ['OPENAI_API_KEY']
+
+ prompt_base = 'Separate the subjects in this sentence by semicolons. For example, the sentence "a tiger and a horse running in a greenland" should output "tiger; horse". If there are numbers, make each subject unique. For example, "2 dogs and 1 duck" would be "dog; dog; duck." Do the same for the following sentence: \n'
+
+ original_input = ""
+ separated_subjects = ""
+
+ language_instruction = gr.Textbox(
+     label="Language Instruction by User",
+     value="2 horses running",
+     visible=False
+ )
+ grounding_instruction = gr.Textbox(
+     label="Subjects in image (Separated by semicolon)",
+     value="horse; horse",
+     visible=False
+ )
+
+ def separate_subjects(input_text):
+     prompt = prompt_base + input_text
+     response = openai.Completion.create(
+         engine="text-davinci-002",
+         prompt=prompt,
+         max_tokens=1024,
+         n=1,
+         stop=None,
+         temperature=0.7,
+     )
+     output_text = response.choices[0].text.strip()
+     return output_text
+
+ # def update_original_input():
+ #     print("start update_original_input")
+ #     global original_input
+ #     original_input = language_instruction.value
+ #     print("original_input in update:", original_input)
+
+ # def update_grounding_instruction():
+ #     print("start update_grounding_instruction")
+ #     # global original_input  # declare you want to use the outer variable
+ #     global separated_subjects
+ #     update_original_input()
+ #     separated_subjects = separate_subjects(language_instruction.value)
+ #     # separated_subjects = separate_subjects(original_input)
+ #     grounding_instruction.value = separated_subjects
+ #     print("original_input:", original_input)
+ #     print("separated_subjects", separated_subjects)
+
  with Blocks(
      css=css,
      analytics_enabled=False,
      title="GLIGen demo",
  ) as main:
+     gr.Markdown('<h1 style="text-align: center;">MSR: MultiSubject Render</h1>')
+     gr.Markdown('<h3 style="text-align: center;">Using NLP and Grounding Processing Techniques to improve image generation of multiple subjects with a base Stable Diffusion Model</h3>')
+
      with gr.Row():
          with gr.Column(scale=4):
              sketch_pad_trigger = gr.Number(value=0, visible=False)

              image_scale = gr.Number(value=0, elem_id="image_scale", visible=False)
              new_image_trigger = gr.Number(value=0, visible=False)

+             # UNCOMMENT THIS WHEN YOU WANT TO TOGGLE THE INPAINTING OPTION
              task = gr.Radio(
                  choices=["Grounded Generation", 'Grounded Inpainting'],
                  type="value",
                  value="Grounded Generation",
                  label="Task",
              )
+
+             # language_instruction = gr.Textbox(
+             #     label="Enter your prompt here",
+             # )
+             # grounding_instruction = gr.Textbox(
+             #     label="Grounding instruction (Separated by semicolon)",
+             # )
+             # grounding_instruction = separate_subjects(language_instruction.value)
+             # print(f"The user entered: {language_instruction}")
+             # print(f"Our function gave: {grounding_instruction}")
+
+             # EXPERIMENTING:
+             with gr.Column():
+                 seed = gr.Text(label="Enter your prompt here:")
+                 gr.Examples(["2 horses running", "A cowboy and ninja fighting", "An apple and an orange on a table"], inputs=[seed])
+             with gr.Column():
+                 btn = gr.Button("Gen")
+             with gr.Column():
+                 separated_text = gr.Text(label="Subjects Separated by Semicolon")
+             btn.click(separate_subjects, inputs=[seed], outputs=[separated_text])
+
+             language_instruction.value = seed
+             grounding_instruction.value = separated_text
+             ####################
+             # language_instruction = gr.Textbox(
+             #     label="Enter your prompt here",
+             # )
+             # original_input = language_instruction.value
+             # start_btn = gr.Button('Start')
+             # start_btn.click(update_grounding_instruction)
+             # print("separated subjects 2:", separated_subjects)
+
+             # language_instruction = gr.Textbox(
+             #     label="just needs to be here",
+             #     value=seed,
+             #     visible=False
+             # )
+             # grounding_instruction = gr.Textbox(
+             #     label="Subjects in image (Separated by semicolon)",
+             #     value=separated_text,
+             #     visible=False
+             # )
+
+             print("Language instruction:", language_instruction.value)
+             print("Grounding instruction:", grounding_instruction.value)
+
+             ####################
+
              with gr.Row():
                  sketch_pad = ImageMask(label="Sketch Pad", elem_id="img2img_image")
                  out_imagebox = gr.Image(type="pil", label="Parsed Sketch Pad")
              with gr.Row():
                  clear_btn = gr.Button(value='Clear')
+                 gen_btn = gr.Button(value='Generate', elem_id="generate-btn")
              with gr.Accordion("Advanced Options", open=False):
                  with gr.Column():
+                     alpha_sample = gr.Slider(minimum=0, maximum=1.0, step=0.1, value=0.3, label="Scheduled Sampling (τ)", visible=False)
+                     guidance_scale = gr.Slider(minimum=0, maximum=50, step=0.5, value=20, label="Guidance Scale (how closely it adheres to your prompt)")
+                     batch_size = gr.Slider(minimum=1, maximum=4, step=1, value=4, label="Number of Images")
+                     append_grounding = gr.Checkbox(value=True, label="Append grounding instructions to the caption", visible=False)
                      use_actual_mask = gr.Checkbox(value=False, label="Use actual mask for inpainting", visible=False)
                      with gr.Row():
+                         fix_seed = gr.Checkbox(value=False, label="Fixed seed", visible=False)
+                         rand_seed = gr.Slider(minimum=0, maximum=1000, step=1, value=0, label="Seed", visible=False)
                      with gr.Row():
+                         use_style_cond = gr.Checkbox(value=False, label="Enable Style Condition", visible=False)
+                         style_cond_image = gr.Image(type="pil", label="Style Condition", interactive=True, visible=False)
          with gr.Column(scale=4):
+             gr.Markdown("### Generated Images")
              with gr.Row():
                  out_gen_1 = gr.Image(type="pil", visible=True, show_label=False)
                  out_gen_2 = gr.Image(type="pil", visible=True, show_label=False)
              with gr.Row():
+                 out_gen_3 = gr.Image(type="pil", visible=True, show_label=False)
+                 out_gen_4 = gr.Image(type="pil", visible=True, show_label=False)

      state = gr.State({})

          outputs=[out_gen_1, out_gen_2, out_gen_3, out_gen_4, state],
          queue=True
      )
+     # start_btn.click(
+     #     update_grounding_instruction,
+     #     # inputs=[
+     #     #     original_input,
+     #     # ],
+     #     # outputs=[separated_subjects],
+     #     # queue=True
+     # )
      sketch_pad_resize_trigger.change(
          None,
          None,

          outputs=[use_style_cond, style_cond_image, alpha_sample, use_actual_mask],
          queue=False)

  main.queue(concurrency_count=1, api_open=False)
+ main.launch(share=False, show_api=False)
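
For context, a minimal sketch of how the parsed subjects are consumed downstream: generate() above pairs each subject with a box drawn on the 512-px sketch pad. The values for separated and state below are hypothetical stand-ins, not values from the diff:

    import json
    import numpy as np

    separated = "horse; horse"  # output of separate_subjects (hypothetical example)
    state = {'boxes': [[60, 120, 240, 300], [280, 120, 460, 300]]}  # sketch-pad boxes (hypothetical)

    grounding_texts = [x.strip() for x in separated.split(';')]
    boxes = (np.asarray(state['boxes']) / 512).tolist()  # normalize pixel boxes to [0, 1]
    grounding_instruction = json.dumps({obj: box for obj, box in zip(grounding_texts, boxes)})

Note that the assert len(boxes) == len(grounding_texts) restored in this commit requires exactly one drawn box per parsed subject.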