alessandro trinca tornidor
refactor: remove unuseful app_helpers.get_inference_model_by_args(args, inference_decorator=SPACES_GPU), fix logs, initialize gpu within infer_lisa_gradio()
a5e4002
from datetime import datetime | |
from spaces import GPU as SPACES_GPU | |
from samgis_core.utilities.type_hints import LlistFloat, DictStrInt | |
from samgis_lisa_on_zero.io_package.geo_helpers import get_vectorized_raster_as_geojson | |
from samgis_lisa_on_zero.io_package.raster_helpers import write_raster_png, write_raster_tiff | |
from samgis_lisa_on_zero.io_package.tms2geotiff import download_extent | |
from samgis_lisa_on_zero.utilities.constants import DEFAULT_URL_TILES, LISA_INFERENCE_FN | |
# Prefix of the info log emitted by lisa_predict() when the WRITE_TMP_ON_DISK env variable is set.
msg_write_tmp_on_disk: str = "found option to write images and geojson output..."
def load_model_and_inference_fn(inference_function_name_key: str):
    """
    Ensure models_dict holds an inference function for the given key.

    If models_dict[inference_function_name_key]["inference"] is None, build the inference
    function with lisa_on_cuda's app_helpers.get_inference_model_by_args(). It receives the
    result of app_helpers.parse_args([]) (presumably the default CLI arguments - confirm in
    lisa_on_cuda) and SPACES_GPU as inference decorator. The result is stored back into
    models_dict. If a function is already present, nothing happens.

    Args:
        inference_function_name_key: key of the entry inside models_dict

    Returns:
        None; the inference function is stored in models_dict as a side effect
    """
    from samgis_lisa_on_zero import app_logger
    from lisa_on_cuda.utils import app_helpers
    from samgis_lisa_on_zero.prediction_api.global_models import models_dict

    # already instantiated by a previous call: nothing to do
    if models_dict[inference_function_name_key]["inference"] is not None:
        return

    app_logger.info(
        f"missing inference function {inference_function_name_key}, "
        f"instantiating it now using inference_decorator {SPACES_GPU}!"
    )
    default_args = app_helpers.parse_args([])
    # the right-hand side runs first, then the result is stored into the shared registry
    models_dict[inference_function_name_key]["inference"] = app_helpers.get_inference_model_by_args(
        default_args,
        internal_logger0=app_logger,
        inference_decorator=SPACES_GPU,
    )
def lisa_predict(
        bbox: LlistFloat,
        prompt: str,
        zoom: float,
        inference_function_name_key: str = LISA_INFERENCE_FN,
        source: str = DEFAULT_URL_TILES,
        source_name: str = None
) -> DictStrInt:
    """
    Return predictions as a geojson from a geo-referenced image using the given input prompt.

    1. if necessary, instantiate the inference function registered under inference_function_name_key
    2. download a geo-referenced raster image delimited by the coordinates bounding box (bbox)
    3. get a prediction mask and an output string from the inference function using the input prompt
    4. get a geo-referenced geojson from the prediction mask

    If the WRITE_TMP_ON_DISK environment variable is set to a non-empty value, it is used as
    a folder where the downloaded raster is also written: as tiff for 2-dimensional images,
    as png for 3-dimensional images with 3 channels.

    Args:
        bbox: coordinates bounding box, unpacked as two points (pt0, pt1). pt0 provides the
            north ([0]) and east ([1]) values, pt1 the south ([0]) and west ([1]) values.
        prompt: machine learning input prompt
        zoom: Level of detail
        inference_function_name_key: key of the inference function inside models_dict
        source: xyz tile source
        source_name: name of tile provider; defaults to str(source)

    Returns:
        dict with the "output_string" produced by the inference function, merged with the
        entries returned by get_vectorized_raster_as_geojson(mask, transform)
    """
    from os import getenv
    from samgis_lisa_on_zero import app_logger
    from samgis_lisa_on_zero.prediction_api.global_models import models_dict

    if source_name is None:
        source_name = str(source)

    app_logger.info("start lisa inference...")
    app_logger.debug(f"type(source):{type(source)}, source:{source},")
    app_logger.debug(f"type(source_name):{type(source_name)}, source_name:{source_name}.")

    # instantiate the inference function on first use, then read it back from the shared registry
    load_model_and_inference_fn(inference_function_name_key)
    app_logger.debug(f"using a '{inference_function_name_key}' instance model...")
    inference_fn = models_dict[inference_function_name_key]["inference"]
    app_logger.info(f"loaded inference function '{inference_fn.__name__}'.")

    pt0, pt1 = bbox
    north, east = pt0[0], pt0[1]
    south, west = pt1[0], pt1[1]
    app_logger.info(f"tile_source: {source}: downloading geo-referenced raster with bbox {bbox}, zoom {zoom}.")
    img, transform = download_extent(w=west, s=south, e=east, n=north, zoom=zoom, source=source)
    app_logger.info(
        f"img type {type(img)} with shape/size:{img.size}, transform type: {type(transform)}, transform:{transform}.")

    folder_write_tmp_on_disk = getenv("WRITE_TMP_ON_DISK", "")
    prefix = f"w{west},s{south},e{east},n{north}_"
    if folder_write_tmp_on_disk:
        now = datetime.now().strftime('%Y%m%d_%H%M%S')
        # explicit separator: the constant message has no trailing space
        app_logger.info(f"{msg_write_tmp_on_disk} with coords {prefix}, shape:{img.shape}, {len(img.shape)}.")
        filename_prefix = f"{source_name}_{prefix}_{now}_"
        if img.shape and len(img.shape) == 2:
            write_raster_tiff(img, transform, filename_prefix, "raw_tiff", folder_write_tmp_on_disk)
        elif img.shape and len(img.shape) == 3 and img.shape[2] == 3:
            write_raster_png(img, transform, filename_prefix, "raw_img", folder_write_tmp_on_disk)
    else:
        app_logger.info("keep all temp data in memory...")

    app_logger.info(f"lisa_zero, source_name:{source_name}, source_name type:{type(source_name)}.")
    app_logger.info(f"lisa_zero, prompt type:{type(prompt)}.")
    app_logger.info(f"lisa_zero, prompt:{prompt}.")
    prompt_str = str(prompt)
    app_logger.info(f"lisa_zero, img type:{type(img)}.")
    # identifies this image (tile source + zoom + bbox) for the inference function;
    # NOTE(review): presumably used to reuse cached embeddings - confirm in lisa_on_cuda
    embedding_key = f"{source_name}_z{zoom}_{prefix}"
    _, mask, output_string = inference_fn(input_str=prompt_str, input_image=img, embedding_key=embedding_key)
    app_logger.info(f"lisa_zero, output_string type:{type(output_string)}.")
    app_logger.info(f"lisa_zero, mask_output type:{type(mask)}.")
    app_logger.info(f"created output_string '{output_string}', preparing conversion to geojson...")
    return {
        "output_string": output_string,
        **get_vectorized_raster_as_geojson(mask, transform)
    }