"""Gradio demo: locate a point on a world map and run biodiversity segmentation.

Running this file as a script loads the segmentation model, builds a two-tab
Gradio interface ("Single Image" and "TimeLapse") and launches it.
"""
# Kept first so that the explicit imports below win over any name it exports.
# NOTE(review): `go` (used in get_geomap) is not imported explicitly in this
# file; presumably it is plotly.graph_objects re-exported here — confirm.
from plot_functions import *

import logging
import os
from functools import partial

import hydra
import torch
from model import LitUnsupervisedSegmenter
from helper import inference_on_location_and_month, inference_on_location
from plot_functions import segment_region
import gradio as gr
import geopandas as gpd

# The token can be overridden through the MAPBOX_ACCESS_TOKEN environment
# variable; the literal below is the original default so existing deployments
# keep working unchanged.
mapbox_access_token = os.environ.get(
    "MAPBOX_ACCESS_TOKEN",
    "pk.eyJ1IjoiamVyZW15LWVraW1ldHJpY3MiLCJhIjoiY2xrNjBwNGU2MDRhMjNqbWw0YTJrbnpvNCJ9.poVyIzhJuJmD6ffrL9lm2w",
)

# City points used as background markers on the map.
# NOTE(review): `gpd.datasets` is deprecated in recent GeoPandas releases —
# confirm the pinned GeoPandas version still ships it.
geo_df = gpd.read_file(gpd.datasets.get_path('naturalearth_cities'))


def get_geomap(long, lat):
    """Build a Mapbox scatter figure centred on the given position.

    The figure contains one trace with a marker per row of ``geo_df`` (hover
    text is the city name) and a second trace with a single green marker at
    the requested position, labelled 'Actual position'.

    Args:
        long: Longitude of the position to highlight and centre on.
        lat: Latitude of the position to highlight and centre on.

    Returns:
        The figure object created by ``go.Figure``.
    """
    fig = go.Figure(
        go.Scattermapbox(
            lat=geo_df.geometry.y,
            lon=geo_df.geometry.x,
            mode='markers',
            marker=go.scattermapbox.Marker(size=14),
            text=geo_df.name,
        )
    )

    # Highlight the requested position on top of the city markers.
    fig.add_trace(
        go.Scattermapbox(
            lat=[lat],
            lon=[long],
            mode='markers',
            marker=go.scattermapbox.Marker(size=14),
            marker_color="green",
            text=['Actual position'],
        )
    )

    fig.update_layout(
        showlegend=False,
        hovermode='closest',
        mapbox=dict(
            accesstoken=mapbox_access_token,
            center=go.layout.mapbox.Center(lat=lat, lon=long),
            zoom=3,
        ),
    )
    return fig


def _load_model(cfg, model_path="biomap/checkpoint/model/model.pt"):
    """Instantiate the segmenter from ``cfg`` and load its weights on CPU.

    Args:
        cfg: Composed Hydra config; ``cfg.dir_dataset_n_classes`` gives the
            number of classes passed to the model constructor.
        model_path: Path of the saved state dict.

    Returns:
        The ``LitUnsupervisedSegmenter`` with the state dict loaded.
    """
    nbclasses = cfg.dir_dataset_n_classes
    model = LitUnsupervisedSegmenter(nbclasses, cfg)
    logging.info("Model Initialiazed")

    saved_state_dict = torch.load(model_path, map_location=torch.device("cpu"))
    logging.info("Model weights Loaded")

    model.load_state_dict(saved_state_dict)
    logging.info("Model Loaded")
    return model


def _build_demo(model):
    """Create the Gradio Blocks UI and wire its events to ``model``.

    NOTE(review): the nesting of tabs/rows/columns below was reconstructed
    from a source file whose indentation had been lost — confirm the layout.

    NOTE(review): ``get_geomap`` takes ``(long, lat)`` but is wired with the
    ``[latitude, longitude]`` components in that order, so the value of the
    field labelled latitude is plotted as the longitude. The wiring is left
    exactly as it was because the inference helpers receive the same
    components — confirm the intended convention before changing it.

    Args:
        model: Model bound as first argument of
            ``inference_on_location_and_month``.

    Returns:
        The ``gr.Blocks`` application, not yet launched.
    """
    # css=".VIDEO video{height: 100%;width:50%;margin:auto};.VIDEO{height: 50%;};.svelte-1vnmhm4{height:auto}"
    with gr.Blocks() as demo:
        gr.Markdown("Estimate Biodiversity in the world.")

        with gr.Tab("Single Image"):
            with gr.Row():
                input_map = gr.Plot()
                with gr.Column():
                    input_latitude = gr.Number(label="lattitude", value=2.98)
                    input_longitude = gr.Number(label="longitude", value=48.81)
                    input_date = gr.Textbox(label="start_date", value="2020-03-20")
                    single_button = gr.Button("Predict")
            with gr.Row():
                raw_image = gr.Image(label="Localisation visualization")
                output_image = gr.Image(label="Labeled visualisation")
                score_biodiv = gr.Number(label="Biodiversity score")

        with gr.Tab("TimeLapse"):
            with gr.Row():
                input_map_2 = gr.Plot()
            with gr.Row():
                timelapse_input_latitude = gr.Number(value=2.98, label="Latitude")
                timelapse_input_longitude = gr.Number(value=48.81, label="Longitude")
                timelapse_start_date = gr.Textbox(value='2020-05-01', label="Start Date")
                timelapse_end_date = gr.Textbox(value='2020-06-30', label="End Date")
            segmentation = gr.CheckboxGroup(
                choices=['month', 'year', '2months'],
                value=['month'],
                label="Select Segmentation Level:",
            )
            timelapse_button = gr.Button(value="Predict")
            # Named so that it does not shadow the builtin `map`.
            timelapse_map = gr.Plot()

        # Single-image tab: draw the map on page load and on every prediction.
        demo.load(get_geomap, [input_latitude, input_longitude], input_map)
        single_button.click(get_geomap, [input_latitude, input_longitude], input_map)
        single_button.click(
            partial(inference_on_location_and_month, model),
            inputs=[input_latitude, input_longitude, input_date],
            outputs=[raw_image, output_image, score_biodiv],
        )

        # Time-lapse tab: same map refresh, then the region segmentation plot.
        demo.load(
            get_geomap,
            [timelapse_input_latitude, timelapse_input_longitude],
            input_map_2,
        )
        timelapse_button.click(
            get_geomap,
            [timelapse_input_latitude, timelapse_input_longitude],
            input_map_2,
        )
        timelapse_button.click(
            segment_region,
            inputs=[
                timelapse_input_latitude,
                timelapse_input_longitude,
                timelapse_start_date,
                timelapse_end_date,
                segmentation,
            ],
            outputs=[timelapse_map],
        )
    return demo


if __name__ == "__main__":
    logging.basicConfig(filename='example.log', encoding='utf-8', level=logging.INFO)

    # Initialize hydra with configs
    # NOTE(review): the initialize call below is disabled in this file;
    # presumably Hydra is initialised elsewhere before compose() — confirm.
    #hydra.initialize(config_path="configs", job_name="corine")
    cfg = hydra.compose(config_name="my_train_config.yml")
    logging.info("config : %s", cfg)

    # Load the model
    model = _load_model(cfg)

    demo = _build_demo(model)
    demo.launch()