Spaces:
Running
on
T4
Running
on
T4
yizhangliu
commited on
Commit
·
5128046
1
Parent(s):
d829f40
update app.py
Browse files
app.py
CHANGED
@@ -774,10 +774,13 @@ def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_t
|
|
774 |
use_sam_predictor = True
|
775 |
if task_type == 'segment' or ((task_type in ['inpainting', 'outpainting'] or task_type == 'remove') and mask_source_radio == mask_source_segment):
|
776 |
image = np.array(input_img)
|
|
|
777 |
if task_type == 'remove' and remove_use_segment == False:
|
|
|
778 |
use_sam_predictor = False
|
779 |
|
780 |
if sam_predictor and use_sam_predictor:
|
|
|
781 |
sam_predictor.set_image(image)
|
782 |
|
783 |
for i in range(boxes_filt.size(0)):
|
@@ -786,6 +789,7 @@ def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_t
|
|
786 |
boxes_filt[i][2:] += boxes_filt[i][:2]
|
787 |
|
788 |
if sam_predictor and use_sam_predictor:
|
|
|
789 |
boxes_filt = boxes_filt.to(sam_device)
|
790 |
transformed_boxes = sam_predictor.transform.apply_boxes_torch(boxes_filt, image.shape[:2])
|
791 |
|
@@ -798,6 +802,7 @@ def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_t
|
|
798 |
# masks: [9, 1, 512, 512]
|
799 |
assert sam_checkpoint, 'sam_checkpoint is not found!'
|
800 |
else:
|
|
|
801 |
masks = torch.zeros(len(boxes_filt), 1, H, W)
|
802 |
mask_count = 0
|
803 |
for box in boxes_filt:
|
@@ -806,6 +811,7 @@ def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_t
|
|
806 |
masks = torch.where(masks > 0, True, False)
|
807 |
run_mode = "rectangle"
|
808 |
|
|
|
809 |
# draw output image
|
810 |
plt.figure(figsize=(10, 10))
|
811 |
plt.imshow(image)
|
|
|
774 |
use_sam_predictor = True
|
775 |
if task_type == 'segment' or ((task_type in ['inpainting', 'outpainting'] or task_type == 'remove') and mask_source_radio == mask_source_segment):
|
776 |
image = np.array(input_img)
|
777 |
+
logger.info(f'run_anything_task_[{file_temp}]_{task_type}_2_1_')
|
778 |
if task_type == 'remove' and remove_use_segment == False:
|
779 |
+
logger.info(f'run_anything_task_[{file_temp}]_{task_type}_2_2_')
|
780 |
use_sam_predictor = False
|
781 |
|
782 |
if sam_predictor and use_sam_predictor:
|
783 |
+
logger.info(f'run_anything_task_[{file_temp}]_{task_type}_2_3_')
|
784 |
sam_predictor.set_image(image)
|
785 |
|
786 |
for i in range(boxes_filt.size(0)):
|
|
|
789 |
boxes_filt[i][2:] += boxes_filt[i][:2]
|
790 |
|
791 |
if sam_predictor and use_sam_predictor:
|
792 |
+
logger.info(f'run_anything_task_[{file_temp}]_{task_type}_2_4_')
|
793 |
boxes_filt = boxes_filt.to(sam_device)
|
794 |
transformed_boxes = sam_predictor.transform.apply_boxes_torch(boxes_filt, image.shape[:2])
|
795 |
|
|
|
802 |
# masks: [9, 1, 512, 512]
|
803 |
assert sam_checkpoint, 'sam_checkpoint is not found!'
|
804 |
else:
|
805 |
+
logger.info(f'run_anything_task_[{file_temp}]_{task_type}_2_5_')
|
806 |
masks = torch.zeros(len(boxes_filt), 1, H, W)
|
807 |
mask_count = 0
|
808 |
for box in boxes_filt:
|
|
|
811 |
masks = torch.where(masks > 0, True, False)
|
812 |
run_mode = "rectangle"
|
813 |
|
814 |
+
logger.info(f'run_anything_task_[{file_temp}]_{task_type}_2_6_')
|
815 |
# draw output image
|
816 |
plt.figure(figsize=(10, 10))
|
817 |
plt.imshow(image)
|