Spaces:
Running
Running
import streamlit.components.v1 as components | |
import streamlit as st | |
import numpy | |
import sahi.predict | |
import sahi.utils | |
from PIL import Image | |
import pathlib | |
import os | |
import uuid | |
# The "static" directory inside the installed streamlit package. imagecompare()
# writes its PNGs here and references them by bare file name from the
# generated HTML, so presumably Streamlit's web server serves this directory
# (NOTE(review): confirm for the pinned Streamlit version).
STREAMLIT_STATIC_PATH = pathlib.Path(st.__path__[0]).joinpath("static")
def sahi_mmdet_inference(
    image,
    detection_model,
    slice_height=512,
    slice_width=512,
    overlap_height_ratio=0.2,
    overlap_width_ratio=0.2,
    image_size=640,
    postprocess_type="UNIONMERGE",
    postprocess_match_metric="IOS",
    postprocess_match_threshold=0.5,
    postprocess_class_agnostic=False,
):
    """Run standard and sliced SAHI inference on one image and draw both results.

    Parameters
    ----------
    image: PIL.Image (assumed)
        Input image. It is passed to both SAHI predictors and converted with
        ``numpy.array`` for drawing, so presumably a PIL image (TODO confirm
        against callers).
    detection_model:
        SAHI detection model used by both predictors.
    slice_height, slice_width: int
        Slice size forwarded to ``get_sliced_prediction``.
    overlap_height_ratio, overlap_width_ratio: float
        Slice overlap ratios forwarded to ``get_sliced_prediction``.
    image_size: int
        Forwarded to both ``get_prediction`` and ``get_sliced_prediction``.
    postprocess_type, postprocess_match_metric, postprocess_match_threshold,
    postprocess_class_agnostic:
        Post-processing settings forwarded unchanged to ``get_sliced_prediction``.

    Returns
    -------
    tuple of PIL.Image
        ``(standard_visual, sliced_visual)``: the object predictions of the
        standard pass and of the sliced pass, each drawn onto its own array
        built from ``image``.
    """

    def _render(prediction_result):
        # Each call converts ``image`` to a new numpy array before drawing.
        drawn = sahi.utils.cv.visualize_object_predictions(
            image=numpy.array(image),
            object_prediction_list=prediction_result.object_prediction_list,
        )
        return Image.fromarray(drawn["image"])

    # Pass 1: standard (whole-image) prediction.
    standard_result = sahi.predict.get_prediction(
        image=image,
        detection_model=detection_model,
        image_size=image_size,
    )
    standard_visual = _render(standard_result)

    # Pass 2: sliced prediction with the caller's slicing/post-processing setup.
    sliced_result = sahi.predict.get_sliced_prediction(
        image=image,
        detection_model=detection_model,
        image_size=image_size,
        slice_height=slice_height,
        slice_width=slice_width,
        overlap_height_ratio=overlap_height_ratio,
        overlap_width_ratio=overlap_width_ratio,
        postprocess_type=postprocess_type,
        postprocess_match_metric=postprocess_match_metric,
        postprocess_match_threshold=postprocess_match_threshold,
        postprocess_class_agnostic=postprocess_class_agnostic,
    )
    sliced_visual = _render(sliced_result)

    return standard_visual, sliced_visual
def _save_image_to_static(img):
    """Load ``img`` via SAHI and save it as a uniquely named PNG in the static dir.

    Parameters
    ----------
    img: str, PosixPath, PIL.Image or URL
        Anything ``sahi.utils.cv.read_image_as_pil`` accepts.

    Returns
    -------
    tuple
        ``(file_name, absolute_path, pil_image)`` of the saved image.
    """
    pil_image = sahi.utils.cv.read_image_as_pil(img)
    file_name = f"{uuid.uuid4()}.png"
    file_path = str((STREAMLIT_STATIC_PATH / file_name).resolve())
    pil_image.save(file_path)
    return file_name, file_path, pil_image


def _js_string(value):
    """Return ``str(value)`` as a JavaScript string literal safe inside <script>."""
    import json

    # json.dumps escapes quotes, backslashes and control characters. Also
    # escaping "<" stops a value containing "</script>" from ending the script.
    return json.dumps(str(value)).replace("<", "\\u003c")


def imagecompare(
    img1,
    img2,
    label1: str = "1",
    label2: str = "2",
    width: int = 700,
    show_labels: bool = True,
    starting_position: int = 50,
    make_responsive: bool = True,
):
    """Create a new juxtapose (before/after slider) component.

    Both images are written as PNGs into Streamlit's static directory and
    shown side by side with the Knight Lab juxtapose slider.

    Parameters
    ----------
    img1: str, PosixPath, PIL.Image or URL
        Input image to compare. Its aspect ratio sets the component height.
    img2: str, PosixPath, PIL.Image or URL
        Input image to compare
    label1: str or None
        Label for image 1
    label2: str or None
        Label for image 2
    width: int or None
        Width of the component in px. If None, the container uses 100% width
        and the height is derived from image 1's native size.
    show_labels: bool or None
        Show given labels on images
    starting_position: int or None
        Starting position of the slider as percent (0-100)
    make_responsive: bool or None
        Enable responsive mode

    Returns
    -------
    tuple
        ``(static_component, image_1_path, image_2_path)``: the value returned
        by ``components.html`` and the absolute paths of the two PNG files
        written to the static directory.
    """
    import contextlib

    # Remove PNGs left over from earlier calls so the static dir does not grow
    # without bound. NOTE(review): this deletes *every* non-favicon PNG there,
    # including images another concurrent session may still be displaying.
    for file_name in os.listdir(STREAMLIT_STATIC_PATH):
        if file_name.endswith(".png") and "favicon" not in file_name:
            # Another session may already have deleted the same file.
            with contextlib.suppress(FileNotFoundError):
                os.remove(str(STREAMLIT_STATIC_PATH / file_name))

    image_1_name, image_1_path, pil_image_1 = _save_image_to_static(img1)
    image_2_name, image_2_path, _ = _save_image_to_static(img2)

    # Use the loaded PIL image: img1 itself may be a str/path/URL with no .size.
    img_width, img_height = pil_image_1.size
    h_to_w = img_height / img_width
    # components.html documents height as an int. The -20 px offset is kept
    # from the original; its rationale is not documented here.
    height = int((width or img_width) * h_to_w - 20)

    # load css + js
    cdn_path = "https://cdn.knightlab.com/libs/juxtapose/latest"
    css_block = f'<link rel="stylesheet" href="{cdn_path}/css/juxtapose.css">'
    js_block = f'<script src="{cdn_path}/js/juxtapose.min.js"></script>'
    container_width = f"{width}px" if width else "100%"

    # write html block; labels are escaped because they are caller-supplied text
    htmlcode = f"""
    {css_block}
    {js_block}
    <div id="foo" style="height: 100%; width: {container_width};"></div>
    <script>
    slider = new juxtapose.JXSlider('#foo',
        [
            {{
                src: '{image_1_name}',
                label: {_js_string(label1)},
            }},
            {{
                src: '{image_2_name}',
                label: {_js_string(label2)},
            }}
        ],
        {{
            animate: true,
            showLabels: {'true' if show_labels else 'false'},
            showCredits: true,
            startingPosition: "{starting_position}%",
            makeResponsive: {'true' if make_responsive else 'false'},
        }});
    </script>
    """
    static_component = components.html(htmlcode, height=height, width=width)
    return static_component, image_1_path, image_2_path