liuyizhang commited on
Commit
634639d
Β·
1 Parent(s): ed4b6a1

update app.py: add outpainting, remove relate anything

Browse files
Files changed (1) hide show
  1. app.py +17 -9
app.py CHANGED
@@ -53,7 +53,7 @@ plt = matplotlib.pyplot
53
  groundingdino_enable = True
54
  sam_enable = True
55
  inpainting_enable = True
56
- ram_enable = True
57
 
58
  lama_cleaner_enable = True
59
 
@@ -620,7 +620,7 @@ def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_t
620
  return output_images, gr.Gallery.update(label='relate images'), time_cost_str, gr.Textbox.update(visible=(time_cost_str !='')), None, None, None
621
 
622
  text_prompt = text_prompt.strip()
623
- if not ((task_type == 'inpainting' or task_type == 'remove') and mask_source_radio == mask_source_draw):
624
  if text_prompt == '':
625
  return [], gr.Gallery.update(label='Detection prompt is not found!πŸ˜‚πŸ˜‚πŸ˜‚πŸ˜‚'), time_cost_str, gr.Textbox.update(visible=(time_cost_str !='')), None, None, None
626
 
@@ -652,7 +652,7 @@ def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_t
652
  H, W = size[1], size[0]
653
 
654
  # run grounding dino model
655
- if (task_type == 'inpainting' or task_type == 'remove') and mask_source_radio == mask_source_draw:
656
  pass
657
  else:
658
  groundingdino_device = 'cpu'
@@ -682,7 +682,7 @@ def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_t
682
  run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
683
 
684
  logger.info(f'run_anything_task_[{file_temp}]_{task_type}_2_')
685
- if task_type == 'segment' or ((task_type == 'inpainting' or task_type == 'remove') and mask_source_radio == mask_source_segment):
686
  image = np.array(input_img)
687
  if sam_predictor:
688
  sam_predictor.set_image(image)
@@ -734,7 +734,7 @@ def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_t
734
  if task_type == 'detection' or task_type == 'segment':
735
  logger.info(f'run_anything_task_[{file_temp}]_{task_type}_9_')
736
  return output_images, gr.Gallery.update(label='result images'), time_cost_str, gr.Textbox.update(visible=(time_cost_str !='')), None, None, None
737
- elif task_type == 'inpainting' or task_type == 'remove':
738
  if inpaint_prompt.strip() == '' and mask_source_radio == mask_source_segment:
739
  task_type = 'remove'
740
 
@@ -752,10 +752,17 @@ def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_t
752
  output_images.append(mask_pil.convert("RGB"))
753
  run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
754
 
755
- if task_type == 'inpainting':
756
  # inpainting pipeline
757
  image_source_for_inpaint = image_pil.resize((512, 512))
758
  image_mask_for_inpaint = mask_pil.resize((512, 512))
 
 
 
 
 
 
 
759
  image_inpainting = sd_model(prompt=inpaint_prompt, image=image_source_for_inpaint, mask_image=image_mask_for_inpaint).images[0]
760
  else:
761
  # remove from mask
@@ -828,9 +835,9 @@ def change_radio_display(task_type, mask_source_radio):
828
  kosmos_output_visible = True
829
  kosmos_text_output_visible = True
830
 
831
- if task_type == "inpainting":
832
  inpaint_prompt_visible = True
833
- if task_type == "inpainting" or task_type == "remove":
834
  mask_source_radio_visible = True
835
  if mask_source_radio == mask_source_draw:
836
  text_prompt_visible = False
@@ -872,6 +879,7 @@ def main_gradio(args):
872
  task_types.append("segment")
873
  if inpainting_enable:
874
  task_types.append("inpainting")
 
875
  if lama_cleaner_enable:
876
  task_types.append("remove")
877
  if ram_enable:
@@ -887,7 +895,7 @@ def main_gradio(args):
887
  value=mask_source_segment, label="Mask from",
888
  visible=False)
889
  text_prompt = gr.Textbox(label="Detection Prompt[To detect multiple objects, seperating each with '.', like this: cat . dog . chair ]", placeholder="Cannot be empty")
890
- inpaint_prompt = gr.Textbox(label="Inpaint Prompt (if this is empty, then remove)", visible=False)
891
  num_relation = gr.Slider(label="How many relations do you want to see", minimum=1, maximum=20, value=5, step=1, visible=False)
892
 
893
  kosmos_input = gr.Radio(["Brief", "Detailed"], label="Kosmos Description Type", value="Brief", visible=False)
 
53
  groundingdino_enable = True
54
  sam_enable = True
55
  inpainting_enable = True
56
+ ram_enable = False
57
 
58
  lama_cleaner_enable = True
59
 
 
620
  return output_images, gr.Gallery.update(label='relate images'), time_cost_str, gr.Textbox.update(visible=(time_cost_str !='')), None, None, None
621
 
622
  text_prompt = text_prompt.strip()
623
+ if not ((task_type in ['inpainting', 'outpainting'] or task_type == 'remove') and mask_source_radio == mask_source_draw):
624
  if text_prompt == '':
625
  return [], gr.Gallery.update(label='Detection prompt is not found!πŸ˜‚πŸ˜‚πŸ˜‚πŸ˜‚'), time_cost_str, gr.Textbox.update(visible=(time_cost_str !='')), None, None, None
626
 
 
652
  H, W = size[1], size[0]
653
 
654
  # run grounding dino model
655
+ if (task_type in ['inpainting', 'outpainting'] or task_type == 'remove') and mask_source_radio == mask_source_draw:
656
  pass
657
  else:
658
  groundingdino_device = 'cpu'
 
682
  run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
683
 
684
  logger.info(f'run_anything_task_[{file_temp}]_{task_type}_2_')
685
+ if task_type == 'segment' or ((task_type in ['inpainting', 'outpainting'] or task_type == 'remove') and mask_source_radio == mask_source_segment):
686
  image = np.array(input_img)
687
  if sam_predictor:
688
  sam_predictor.set_image(image)
 
734
  if task_type == 'detection' or task_type == 'segment':
735
  logger.info(f'run_anything_task_[{file_temp}]_{task_type}_9_')
736
  return output_images, gr.Gallery.update(label='result images'), time_cost_str, gr.Textbox.update(visible=(time_cost_str !='')), None, None, None
737
+ elif task_type in ['inpainting', 'outpainting'] or task_type == 'remove':
738
  if inpaint_prompt.strip() == '' and mask_source_radio == mask_source_segment:
739
  task_type = 'remove'
740
 
 
752
  output_images.append(mask_pil.convert("RGB"))
753
  run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
754
 
755
+ if task_type in ['inpainting', 'outpainting']:
756
  # inpainting pipeline
757
  image_source_for_inpaint = image_pil.resize((512, 512))
758
  image_mask_for_inpaint = mask_pil.resize((512, 512))
759
+ if task_type in ['outpainting']:
760
+ # reverse mask
761
+ img_arr = np.array(image_mask_for_inpaint)
762
+ img_arr = 1 - img_arr
763
+ image_mask_for_inpaint = Image.fromarray(255*img_arr.astype('uint8'))
764
+ output_images.append(image_mask_for_inpaint.convert("RGB"))
765
+ run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
766
  image_inpainting = sd_model(prompt=inpaint_prompt, image=image_source_for_inpaint, mask_image=image_mask_for_inpaint).images[0]
767
  else:
768
  # remove from mask
 
835
  kosmos_output_visible = True
836
  kosmos_text_output_visible = True
837
 
838
+ if task_type in ['inpainting', 'outpainting']:
839
  inpaint_prompt_visible = True
840
+ if task_type in ['inpainting', 'outpainting'] or task_type == "remove":
841
  mask_source_radio_visible = True
842
  if mask_source_radio == mask_source_draw:
843
  text_prompt_visible = False
 
879
  task_types.append("segment")
880
  if inpainting_enable:
881
  task_types.append("inpainting")
882
+ task_types.append("outpainting")
883
  if lama_cleaner_enable:
884
  task_types.append("remove")
885
  if ram_enable:
 
895
  value=mask_source_segment, label="Mask from",
896
  visible=False)
897
  text_prompt = gr.Textbox(label="Detection Prompt[To detect multiple objects, seperating each with '.', like this: cat . dog . chair ]", placeholder="Cannot be empty")
898
+ inpaint_prompt = gr.Textbox(label="Inpaint/Outpaint Prompt (if this is empty, then remove)", visible=False)
899
  num_relation = gr.Slider(label="How many relations do you want to see", minimum=1, maximum=20, value=5, step=1, visible=False)
900
 
901
  kosmos_input = gr.Radio(["Brief", "Detailed"], label="Kosmos Description Type", value="Brief", visible=False)