alessandro trinca tornidor
refactor: remove unuseful app_helpers.get_inference_model_by_args(args, inference_decorator=SPACES_GPU), fix logs, initialize gpu within infer_lisa_gradio()
a5e4002
import json | |
import os | |
import pathlib | |
import uuid | |
from typing import Callable, NoReturn | |
import gradio as gr | |
import spaces | |
import uvicorn | |
from fastapi import FastAPI, HTTPException, Request, status | |
from fastapi.exceptions import RequestValidationError | |
from fastapi.responses import FileResponse, HTMLResponse, JSONResponse | |
from fastapi.staticfiles import StaticFiles | |
from fastapi.templating import Jinja2Templates | |
from lisa_on_cuda.utils import app_helpers, frontend_builder, create_folders_and_variables_if_not_exists | |
from pydantic import ValidationError | |
from samgis_core.utilities.fastapi_logger import setup_logging | |
from samgis_lisa_on_zero import PROJECT_ROOT_FOLDER, WORKDIR | |
from samgis_lisa_on_zero.utilities.type_hints import ApiRequestBody, StringPromptApiRequestBody | |
# Logging verbosity is read from the LOGLEVEL environment variable (default: INFO);
# debug logging is enabled only when it is exactly "DEBUG" (case-insensitive).
loglevel = os.environ.get('LOGLEVEL', 'INFO').upper()
app_logger = setup_logging(debug=(loglevel == "DEBUG"))

# URL paths used further down when mounting the static pages and the gradio app;
# each one can be overridden through the environment variable of the same name.
VITE_INDEX_URL = os.environ.get("VITE_INDEX_URL", "/")
VITE_SAMGIS_URL = os.environ.get("VITE_SAMGIS_URL", "/samgis")
VITE_LISA_URL = os.environ.get("VITE_LISA_URL", "/lisa")
VITE_GRADIO_URL = os.environ.get("VITE_GRADIO_URL", "/gradio")

# The FastAPI application object shared by every route and mount in this module.
FASTAPI_TITLE = "samgis-lisa-on-zero"
app = FastAPI(title=FASTAPI_TITLE, version="1.0")
def gpu_initialization() -> None:
    """Emit an info-level log line announcing the GPU initialization step.

    The body itself only logs; it is called by ``infer_lisa_gradio`` right
    before the prediction starts.
    """
    app_logger.info("GPU initialization...")
def get_gradio_interface_geojson(
        fn_inference: Callable
):
    """Create a gradio interface with one text input and one text output.

    Args:
        fn_inference: callable handed to ``gr.Interface`` as the function to run.

    Returns:
        The ``gr.Interface`` instance wrapping ``fn_inference``.
    """
    payload_textbox = gr.Textbox(lines=1, placeholder=None, label="Payload input")
    geojson_textbox = gr.Textbox(lines=1, placeholder=None, label="Geojson Output")
    return gr.Interface(
        fn_inference,
        inputs=[payload_textbox],
        outputs=[geojson_textbox],
    )
def handle_exception_response(exception: Exception) -> NoReturn:
    """Log diagnostics about a failed inference, then raise a generic HTTP 500 error.

    The content of PROJECT_ROOT_FOLDER and WORKDIR is listed with ``ls -l`` and
    written to the error log together with the original exception.

    Args:
        exception: the exception caught by the caller during inference.

    Raises:
        HTTPException: always, with status 500 and a generic detail message;
            the original exception is attached as its cause.
    """
    import subprocess

    try:
        # Argument lists (no shell) keep paths with spaces or shell metacharacters intact.
        project_root_folder_content = subprocess.run(
            ["ls", "-l", f"{PROJECT_ROOT_FOLDER}/"], universal_newlines=True, stdout=subprocess.PIPE
        )
        app_logger.error(f"project_root folder 'ls -l' command output: {project_root_folder_content.stdout}.")
        # stderr must be captured explicitly, otherwise the '.stderr' logged below is always None
        workdir_folder_content = subprocess.run(
            ["ls", "-l", f"{WORKDIR}/"], universal_newlines=True,
            stdout=subprocess.PIPE, stderr=subprocess.PIPE
        )
        app_logger.error(f"workdir folder 'ls -l' command stdout: {workdir_folder_content.stdout}.")
        app_logger.error(f"workdir folder 'ls -l' command stderr: {workdir_folder_content.stderr}.")
    except OSError as listing_error:
        # The listings are best-effort diagnostics: they must never hide the HTTP error below.
        app_logger.error(f"could not list folders for diagnostics: {listing_error}.")
    app_logger.error(f"inference error:{exception}.")
    raise HTTPException(
        status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal server error on inference"
    ) from exception
async def request_middleware(request, call_next):
    """Tag every request with a fresh UUID and log its start and end.

    The id is passed to ``app_logger.contextualize`` for the duration of the
    request and returned to the client in the ``X-Request-ID`` response header.
    An exception raised by ``call_next`` is logged and replaced by a JSON
    response with status 500.

    Args:
        request: the incoming request, forwarded untouched to ``call_next``.
        call_next: awaitable callable producing the response for ``request``.

    Returns:
        The downstream response (or the fallback 500 response) with the
        ``X-Request-ID`` header set.
    """
    correlation_id = str(uuid.uuid4())
    with app_logger.contextualize(request_id=correlation_id):
        app_logger.info("Request started")
        try:
            response = await call_next(request)
        except Exception as ex_middleware_http:
            app_logger.error(f"Request failed, ex_middleware_http: {ex_middleware_http}")
            response = JSONResponse(content={"success": False}, status_code=500)
        finally:
            response.headers["X-Request-ID"] = correlation_id
            app_logger.info("Request ended")
        return response
def post_test_dictlist2(request_input: ApiRequestBody) -> JSONResponse:
    """Parse a dict-list prompt request and echo the parsed result back.

    Args:
        request_input: request payload, passed to
            ``get_parsed_bbox_points_with_dictlist_prompt``.

    Returns:
        A 200 JSON response whose content is the parsed request body.
    """
    from samgis_lisa_on_zero.io_package.wrappers_helpers import get_parsed_bbox_points_with_dictlist_prompt

    parsed_payload = get_parsed_bbox_points_with_dictlist_prompt(request_input)
    app_logger.info(f"request_body:{parsed_payload}.")
    return JSONResponse(status_code=200, content=parsed_payload)
async def health() -> JSONResponse:
    """Liveness endpoint: log the installed package versions and answer 200.

    The versions of ``samgis_core``, ``lisa-on-cuda`` and ``samgis-lisa-on-zero``
    are looked up in that order; the first missing package is logged as an error
    and the remaining versions stay empty. The versions are only logged, the
    response content is a fixed message.

    Returns:
        A 200 JSON response with content ``{"msg": "still alive..."}``.
    """
    import importlib.metadata

    # insertion order matters: it is the lookup order, which stops at the first missing package
    installed_versions = {"samgis_core": "", "lisa-on-cuda": "", "samgis-lisa-on-zero": ""}
    try:
        for package_name in installed_versions:
            installed_versions[package_name] = importlib.metadata.version(package_name)
    except importlib.metadata.PackageNotFoundError as pe:
        app_logger.error(f"pe:{pe}.")
    msg = (
        "still alive, "
        f"version:{installed_versions['samgis-lisa-on-zero']}, core version:{installed_versions['samgis_core']},"
        f"lisa-on-cuda version:{installed_versions['lisa-on-cuda']},"
    )
    app_logger.info(msg)
    return JSONResponse(status_code=200, content={"msg": "still alive..."})
def post_test_string(request_input: StringPromptApiRequestBody) -> JSONResponse:
    """Parse a string-prompt request and echo it back with the default precision.

    The parsed body gets an extra ``"content"`` entry: a copy of the parsed body
    plus a ``"precision"`` key taken from ``app_helpers.parse_args([])``.

    Args:
        request_input: request payload, passed to
            ``get_parsed_bbox_points_with_string_prompt``.

    Returns:
        A 200 JSON response whose content is the enriched parsed body.
    """
    from lisa_on_cuda.utils import app_helpers
    from samgis_lisa_on_zero.io_package.wrappers_helpers import get_parsed_bbox_points_with_string_prompt

    parsed_payload = get_parsed_bbox_points_with_string_prompt(request_input)
    app_logger.info(f"request_body:{parsed_payload}.")
    default_args = app_helpers.parse_args([])
    # snapshot the parsed body before the "content" key is added to it
    payload_snapshot = {**parsed_payload}
    payload_snapshot["precision"] = str(default_args.precision)
    parsed_payload["content"] = payload_snapshot
    return JSONResponse(status_code=200, content=parsed_payload)
def infer_lisa_gradio(request_input: StringPromptApiRequestBody) -> str:
    """Run a LISA inference for a string-prompt request and return the result as JSON text.

    Args:
        request_input: request payload, passed to
            ``get_parsed_bbox_points_with_string_prompt``.

    Returns:
        The ``json.dumps`` string of a dict with the keys ``"duration_run"``
        (seconds elapsed since the start of the request handling) and ``"output"``
        (the value returned by ``lisa.lisa_predict``).

    Raises:
        HTTPException: with status 422 when parsing the request raises a pydantic
            ``ValidationError``; with status 500 (through
            ``handle_exception_response``) when anything fails during inference.
    """
    import time

    from samgis_lisa_on_zero.io_package.wrappers_helpers import get_parsed_bbox_points_with_string_prompt
    from samgis_lisa_on_zero.prediction_api import lisa
    from samgis_lisa_on_zero.utilities.constants import LISA_INFERENCE_FN

    app_logger.info("starting lisa inference request...")
    time_start_run = time.time()
    try:
        body_request = get_parsed_bbox_points_with_string_prompt(request_input)
    except ValidationError as va1:
        app_logger.error(f"validation error: {str(va1)}.")
        # A pydantic ValidationError cannot be built from a plain message (the attempt
        # itself raises TypeError), so the failure is reported as an HTTP 422 instead.
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail="Unprocessable Entity"
        ) from va1
    app_logger.info(f"lisa body_request:{body_request}.")
    try:
        source = body_request["source"]
        source_name = body_request["source_name"]
        app_logger.debug(f"body_request:type(source):{type(source)}, source:{source}.")
        app_logger.debug(f"body_request:type(source_name):{type(source_name)}, source_name:{source_name}.")
        app_logger.debug(f"lisa module:{lisa}.")
        gpu_initialization()
        output = lisa.lisa_predict(
            bbox=body_request["bbox"], prompt=body_request["prompt"], zoom=body_request["zoom"],
            source=source, source_name=source_name, inference_function_name_key=LISA_INFERENCE_FN
        )
        duration_run = time.time() - time_start_run
        app_logger.info(f"duration_run:{duration_run}.")
        body = {
            "duration_run": duration_run,
            "output": output
        }
        dumped = json.dumps(body)
        app_logger.info(f"json.dumps(body) type:{type(dumped)}, len:{len(dumped)}.")
        app_logger.debug(f"complete json.dumps(body):{dumped}.")
        return dumped
    except Exception as inference_exception:
        # always raises: logs diagnostics and converts the failure into an HTTP 500
        handle_exception_response(inference_exception)
def infer_lisa(request_input: StringPromptApiRequestBody) -> JSONResponse:
    """Run a LISA inference and wrap the JSON result in a 200 response.

    Args:
        request_input: request payload, forwarded to ``infer_lisa_gradio``.

    Returns:
        A 200 JSON response with content ``{"body": <json string>}``, the same
        shape ``infer_samgis`` produces.
    """
    # infer_lisa_gradio already returns the json.dumps() string: serializing it a second
    # time would double-encode the body, unlike the response built by infer_samgis
    dumped = infer_lisa_gradio(request_input=request_input)
    return JSONResponse(status_code=200, content={"body": dumped})
def infer_samgis(request_input: ApiRequestBody) -> JSONResponse:
    """Run a plain SamGIS inference for a dict-list prompt request.

    Args:
        request_input: request payload, passed to
            ``get_parsed_bbox_points_with_dictlist_prompt``.

    Returns:
        A 200 JSON response with content ``{"body": <json string>}``, where the
        JSON string encodes a dict with the keys ``"duration_run"`` and ``"output"``
        (the value returned by ``predictors.samexporter_predict``).

    Raises:
        HTTPException: with status 422 when parsing the request raises a pydantic
            ``ValidationError``; with status 500 (through
            ``handle_exception_response``) when anything fails during inference.
    """
    import time

    from samgis_lisa_on_zero.prediction_api import predictors
    from samgis_lisa_on_zero.io_package.wrappers_helpers import get_parsed_bbox_points_with_dictlist_prompt

    app_logger.info("starting plain samgis inference request...")
    time_start_run = time.time()
    try:
        body_request = get_parsed_bbox_points_with_dictlist_prompt(request_input)
    except ValidationError as va1:
        app_logger.error(f"validation error: {str(va1)}.")
        # A pydantic ValidationError cannot be built from a plain message (the attempt
        # itself raises TypeError), so the failure is reported as an HTTP 422 instead.
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail="Unprocessable Entity"
        ) from va1
    app_logger.info(f"body_request:{body_request}.")
    try:
        source_name = body_request["source_name"]
        app_logger.info(f"source_name = {source_name}.")
        output = predictors.samexporter_predict(
            bbox=body_request["bbox"], prompt=body_request["prompt"], zoom=body_request["zoom"],
            source=body_request["source"], source_name=source_name
        )
        duration_run = time.time() - time_start_run
        app_logger.info(f"duration_run:{duration_run}.")
        body = {
            "duration_run": duration_run,
            "output": output
        }
        return JSONResponse(status_code=200, content={"body": json.dumps(body)})
    except Exception as inference_exception:
        # always raises: logs diagnostics and converts the failure into an HTTP 500
        handle_exception_response(inference_exception)
async def request_validation_exception_handler(request: Request, exc: RequestValidationError) -> JSONResponse:
    """Log a request validation failure and answer with a generic 422 response.

    Args:
        request: the request that failed validation; its headers and query
            parameters are written to the error log.
        exc: the validation error; its ``errors()`` and ``body`` are logged.

    Returns:
        A JSON response with status 422 and a fixed error message.
    """
    app_logger.error(f"exception errors: {exc.errors()}.")
    app_logger.error(f"exception body: {exc.body}.")
    header_map = dict(request.headers.items())
    app_logger.error(f'request header: {header_map}.')
    query_map = dict(request.query_params.items())
    app_logger.error(f'request query params: {query_map}.')
    return JSONResponse(
        content={"msg": "Error - Unprocessable Entity"},
        status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
    )
async def http_exception_handler(request: Request, exc: HTTPException) -> JSONResponse:
    """Log an HTTP exception and answer with a generic 500 response.

    The response status is always 500, whatever status code ``exc`` carries.

    Args:
        request: the request being served; its headers and query parameters
            are written to the error log.
        exc: the HTTP exception, logged through ``str(exc)``.

    Returns:
        A JSON response with status 500 and a fixed error message.
    """
    app_logger.error(f"exception: {str(exc)}.")
    header_map = dict(request.headers.items())
    app_logger.error(f'request header: {header_map}.')
    query_map = dict(request.query_params.items())
    app_logger.error(f'request query params: {query_map}.')
    return JSONResponse(
        content={"msg": "Error - Internal Server Error"},
        status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
    )
# Optional folder (WRITE_TMP_ON_DISK environment variable) exposed under /vis_output;
# nothing is mounted when the variable is unset or empty.
write_tmp_on_disk = os.getenv("WRITE_TMP_ON_DISK", "")
app_logger.info(f"write_tmp_on_disk:{write_tmp_on_disk}.")
if write_tmp_on_disk:
    try:
        path_write_tmp_on_disk = pathlib.Path(write_tmp_on_disk)
        try:
            # remove a leftover file/symlink with the same name before creating the folder
            path_write_tmp_on_disk.unlink(missing_ok=True)
        except OSError as err:  # IsADirectoryError and PermissionError are OSError subclasses
            app_logger.error(f"{err} while removing old write_tmp_on_disk:{write_tmp_on_disk}.")
            app_logger.error(f"is file?{path_write_tmp_on_disk.is_file()}.")
            app_logger.error(f"is symlink?{path_write_tmp_on_disk.is_symlink()}.")
            app_logger.error(f"is folder?{path_write_tmp_on_disk.is_dir()}.")
        os.makedirs(write_tmp_on_disk, exist_ok=True)
        app.mount("/vis_output", StaticFiles(directory=write_tmp_on_disk), name="vis_output")
    except RuntimeError as runtime_error:
        app_logger.error(f"{runtime_error} while loading the folder write_tmp_on_disk:{write_tmp_on_disk}...")
        raise runtime_error

    # NOTE(review): 'templates' is assigned again later at module level and list_files()
    # resolves the name when it is called - confirm which template folder is intended.
    templates = Jinja2Templates(directory=WORKDIR / "static")

    def list_files(request: Request):
        """Render 'list_files.html' with the sorted urls of the files found in write_tmp_on_disk."""
        files_paths = sorted(f"{request.url._url}/{f}" for f in os.listdir(write_tmp_on_disk))
        print(files_paths)
        template_context = {"request": request, "files": files_paths}
        return templates.TemplateResponse("list_files.html", template_context)
# Build the frontend into WORKDIR/static/dist at import time, then expose that folder under /static.
static_dist_folder = WORKDIR / "static" / "dist"
frontend_builder.build_frontend(
    output_dist_folder=static_dist_folder,
    project_root_folder=frontend_builder.env_project_root_folder,
    input_css_path=frontend_builder.env_input_css_path,
)
create_folders_and_variables_if_not_exists.folders_creation()
app_logger.info("build_frontend ok!")

# module-level template loader pointing at the "templates" folder
templates = Jinja2Templates(directory="templates")

app.mount(
    "/static",
    StaticFiles(directory=static_dist_folder, html=True),
    name="static",
)
# important: the index() function and the app.mount MUST be at the end
# samgis.html
app.mount(
    VITE_SAMGIS_URL,
    StaticFiles(directory=static_dist_folder, html=True),
    name="samgis",
)


async def samgis() -> FileResponse:
    """Return the 'samgis.html' page from the frontend dist folder as an HTML file response."""
    samgis_page = static_dist_folder / "samgis.html"
    return FileResponse(path=samgis_page, media_type="text/html")
# lisa.html
app.mount(
    VITE_LISA_URL,
    StaticFiles(directory=static_dist_folder, html=True),
    name="lisa",
)


async def lisa() -> FileResponse:
    """Return the 'lisa.html' page from the frontend dist folder as an HTML file response."""
    lisa_page = static_dist_folder / "lisa.html"
    return FileResponse(path=lisa_page, media_type="text/html")
# index.html (lisa.html copy)
app.mount(
    VITE_INDEX_URL,
    StaticFiles(directory=static_dist_folder, html=True),
    name="index",
)


async def index() -> FileResponse:
    """Return the 'index.html' page from the frontend dist folder as an HTML file response."""
    index_page = static_dist_folder / "index.html"
    return FileResponse(path=index_page, media_type="text/html")
# Wrap the LISA inference function in a gradio interface and mount it within the FastAPI app.
# The log line reads the name from infer_lisa_gradio, the function actually handed to gradio:
# no 'inference_fn' name is defined in this module, so referencing it raised a NameError at import.
app_helpers.app_logger.info(
    f"prepared inference_fn function:{infer_lisa_gradio.__name__}, creating gradio interface...")
io = get_gradio_interface_geojson(infer_lisa_gradio)
app_helpers.app_logger.info(
    f"created gradio interface, mounting gradio app on url {VITE_GRADIO_URL} within FastAPI...")
# mount_gradio_app returns the application object to serve, so 'app' is rebound to it
app = gr.mount_gradio_app(app, io, path=VITE_GRADIO_URL)
app_helpers.app_logger.info("mounted gradio app within fastapi")
# Script entry point: serve the application with uvicorn on every interface, port 7860.
if __name__ == '__main__':
    try:
        uvicorn.run(app=app, host="0.0.0.0", port=7860)
    except Exception as ex:
        import logging

        # report the failure through both the stdlib logging module and stdout, then re-raise
        failure_message = f"fastapi/gradio application {FASTAPI_TITLE}, exception:{ex}."
        logging.error(failure_message)
        print(failure_message)
        raise ex