BertChristiaens commited on
Commit
32a644b
·
1 Parent(s): ef697d2
Files changed (2) hide show
  1. app.py +6 -5
  2. models.py +10 -3
app.py CHANGED
@@ -182,11 +182,6 @@ def make_editing_canvas(canvas_color, brush, _reset_state, generation_mode, pain
182
 
183
 
184
  elif generation_mode == "Re-generate objects":
185
- st.write("This mode allows you to choose which objects you want to re-generate in the image. "
186
- "Use the selection dropdown to add or remove objects. If you are ready, press the generate button"
187
- " to generate the image, which can take up to 30 seconds. If you want to improve the generated image, click"
188
- " the 'move image to input' button."
189
- )
190
  canvas = st_canvas(
191
  **canvas_dict,
192
  )
@@ -209,6 +204,12 @@ def make_editing_canvas(canvas_color, brush, _reset_state, generation_mode, pain
209
  default=st.session_state['unique_colors'],
210
  format_func=map_colors_rgb,
211
  )
 
 
 
 
 
 
212
 
213
  if st.button("generate image", key='generate_button'):
214
  image = get_image()
 
182
 
183
 
184
  elif generation_mode == "Re-generate objects":
 
 
 
 
 
185
  canvas = st_canvas(
186
  **canvas_dict,
187
  )
 
204
  default=st.session_state['unique_colors'],
205
  format_func=map_colors_rgb,
206
  )
207
+ with st.expander("Explanation", expanded=False):
208
+ st.write("This mode allows you to choose which objects you want to re-generate in the image. "
209
+ "Use the selection dropdown to add or remove objects. If you are ready, press the generate button"
210
+ " to generate the image, which can take up to 30 seconds. If you want to improve the generated image, click"
211
+ " the 'move image to input' button."
212
+ )
213
 
214
  if st.button("generate image", key='generate_button'):
215
  image = get_image()
models.py CHANGED
@@ -4,6 +4,7 @@ from typing import List, Tuple, Dict
4
 
5
  import streamlit as st
6
  import torch
 
7
  import time
8
  import numpy as np
9
  from PIL import Image
@@ -23,22 +24,26 @@ from stable_diffusion_controlnet_inpaint_img2img import StableDiffusionControlNe
23
 
24
  LOGGING = logging.getLogger(__name__)
25
 
 
 
 
26
 
27
  class ControlNetPipeline:
28
  def __init__(self):
29
  self.in_use = False
30
  self.controlnet = ControlNetModel.from_pretrained(
31
- "BertChristiaens/controlnet-seg-room", torch_dtype=torch.float16)
32
 
33
  self.pipe = StableDiffusionControlNetInpaintImg2ImgPipeline.from_pretrained(
34
  "runwayml/stable-diffusion-inpainting",
35
  controlnet=self.controlnet,
36
  safety_checker=None,
37
- torch_dtype=torch.float16
38
  )
39
 
40
  self.pipe.scheduler = UniPCMultistepScheduler.from_config(self.pipe.scheduler.config)
41
  self.pipe.enable_xformers_memory_efficient_attention()
 
42
  self.pipe = self.pipe.to("cuda")
43
 
44
  self.waiting_queue = []
@@ -59,8 +64,10 @@ class ControlNetPipeline:
59
  # it's your turn, so remove the number from the queue
60
  # and call the function
61
  print("It's the turn of", self.count)
 
62
  self.waiting_queue.pop(0)
63
- return self.pipe(**kwargs)
 
64
 
65
 
66
  @contextmanager
 
4
 
5
  import streamlit as st
6
  import torch
7
+ import gc
8
  import time
9
  import numpy as np
10
  from PIL import Image
 
24
 
25
  LOGGING = logging.getLogger(__name__)
26
 
27
+ def flush():
28
+ gc.collect()
29
+ torch.cuda.empty_cache()
30
 
31
  class ControlNetPipeline:
32
  def __init__(self):
33
  self.in_use = False
34
  self.controlnet = ControlNetModel.from_pretrained(
35
+ "BertChristiaens/controlnet-seg-room", torch_dtype=torch.float32)
36
 
37
  self.pipe = StableDiffusionControlNetInpaintImg2ImgPipeline.from_pretrained(
38
  "runwayml/stable-diffusion-inpainting",
39
  controlnet=self.controlnet,
40
  safety_checker=None,
41
+ torch_dtype=torch.float32
42
  )
43
 
44
  self.pipe.scheduler = UniPCMultistepScheduler.from_config(self.pipe.scheduler.config)
45
  self.pipe.enable_xformers_memory_efficient_attention()
46
+ self.pipe.enable_attention_slicing("max")
47
  self.pipe = self.pipe.to("cuda")
48
 
49
  self.waiting_queue = []
 
64
  # it's your turn, so remove the number from the queue
65
  # and call the function
66
  print("It's the turn of", self.count)
67
+ results = self.pipe(**kwargs)
68
  self.waiting_queue.pop(0)
69
+ flush()
70
+ return results
71
 
72
 
73
  @contextmanager