alessandro trinca tornidor
feat: rename io package to 'io_package' to avoid errors on test execution. Add pytest dependency
00f8875
"""functions using machine learning instance model(s)""" | |
from samgis_lisa_on_zero import app_logger, MODEL_FOLDER | |
from samgis_lisa_on_zero.io_package.geo_helpers import get_vectorized_raster_as_geojson | |
from samgis_lisa_on_zero.io_package.raster_helpers import get_raster_terrain_rgb_like, get_rgb_prediction_image | |
from samgis_lisa_on_zero.io_package.tms2geotiff import download_extent | |
from samgis_lisa_on_zero.io_package.wrappers_helpers import check_source_type_is_terrain | |
from samgis_lisa_on_zero.prediction_api.global_models import models_dict, embedding_dict | |
from samgis_lisa_on_zero.utilities.constants import DEFAULT_URL_TILES, SLOPE_CELLSIZE | |
from samgis_core.prediction_api import sam_onnx2, sam_onnx_inference | |
from samgis_core.utilities.constants import MODEL_ENCODER_NAME, MODEL_DECODER_NAME, DEFAULT_INPUT_SHAPE | |
from samgis_core.utilities.type_hints import LlistFloat, DictStrInt, ListDict | |
def samexporter_predict(
        bbox: LlistFloat,
        prompt: ListDict,
        zoom: float,
        model_name_key: str = "fastsam",
        source: str = DEFAULT_URL_TILES,
        source_name: str = None
) -> DictStrInt:
    """
    Return predictions as a geojson from a geo-referenced image using the given input prompt.

    1. if necessary instantiate a segment anything machine learning instance model
    2. download a geo-referenced raster image delimited by the coordinates bounding box (bbox)
    3. if the tile source is terrain-rgb like, turn the raster into a DEM and then into the
       image used for the prediction (slope/curvature based, per the log messages)
    4. get a prediction image from the segment anything instance model using the input prompt
    5. get a geo-referenced geojson from the prediction image

    Args:
        bbox: coordinates bounding box as two points: the first provides the north (index 0)
            and east (index 1) edges, the second the south (index 0) and west (index 1) edges
        prompt: machine learning input prompt, forwarded unchanged to the inference function
        zoom: tile zoom level (level of detail) used to download the raster
        model_name_key: key of the `models_dict` entry holding the (lazily created) model instance
        source: tile provider used to download the raster; when it is a terrain-rgb like source
            it must also expose a `name` attribute
        source_name: name of the tile provider, used in logs and as part of the embedding cache key

    Returns:
        dict with the number of created masks under "n_predictions", merged with the
        key/value pairs returned by `get_vectorized_raster_as_geojson(mask, transform)`
    """
    if models_dict[model_name_key]["instance"] is None:
        app_logger.info(f"missing instance model {model_name_key}, instantiating it now!")
        # NOTE(review): the encoder/decoder paths are fixed and do not depend on model_name_key
        model_instance = sam_onnx2.SegmentAnythingONNX2(
            encoder_model_path=MODEL_FOLDER / MODEL_ENCODER_NAME,
            decoder_model_path=MODEL_FOLDER / MODEL_DECODER_NAME
        )
        models_dict[model_name_key]["instance"] = model_instance
    app_logger.debug(f"using a {model_name_key} instance model...")
    models_instance = models_dict[model_name_key]["instance"]

    # bbox layout: pt0 holds the north/east edges, pt1 the south/west ones.
    # Compute the edges once so the download extent and the embedding cache key always agree.
    pt0, pt1 = bbox
    north, east = pt0[0], pt0[1]
    south, west = pt1[0], pt1[1]

    app_logger.info(f"tile_source: {source}: downloading geo-referenced raster with bbox {bbox}, zoom {zoom}.")
    img, transform = download_extent(w=west, s=south, e=east, n=north, zoom=zoom, source=source)
    if check_source_type_is_terrain(source):
        app_logger.info("terrain-rgb like raster: transforms it into a DEM")
        dem = get_raster_terrain_rgb_like(img, source.name)
        # set a slope cell size proportional to the image width
        slope_cellsize = int(img.shape[1] * SLOPE_CELLSIZE / DEFAULT_INPUT_SHAPE[1])
        app_logger.info(f"terrain-rgb like raster: compute slope, curvature using {slope_cellsize} as cell size.")
        img = get_rgb_prediction_image(dem, slope_cellsize)
    app_logger.info(
        f"img type {type(img)} with shape/size:{img.size}, transform type: {type(transform)}, transform:{transform}.")
    app_logger.info(f"source_name:{source_name}, source_name type:{type(source_name)}.")

    # the key identifies the image (provider, zoom, extent); presumably used by the inference
    # helper to reuse previously computed embeddings stored in embedding_dict - TODO confirm
    embedding_key = f"{source_name}_z{zoom}_w{west},s{south},e{east},n{north}"
    mask, n_predictions = sam_onnx_inference.get_raster_inference_with_embedding_from_dict(
        img, prompt, models_instance, model_name_key, embedding_key, embedding_dict)
    app_logger.info(f"created {n_predictions} masks, preparing conversion to geojson...")
    return {
        "n_predictions": n_predictions,
        **get_vectorized_raster_as_geojson(mask, transform)
    }