"""Segment Earth Engine imagery of a location over time and plot the result.

Importing this module initialises Hydra, loads the segmentation checkpoint and
builds the preprocessing pipeline. The public entry point is segment_region.
"""
import os

# Allow several copies of the (Intel) OpenMP runtime to be loaded. Set before
# the numerical libraries below are imported so it is already in effect when
# they load.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

import hydra
import matplotlib as mpl
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
import torch
import torch.multiprocessing
import torchvision.transforms as T
from PIL import Image
from plotly.subplots import make_subplots

from utils import prep_for_plot
from model import LitUnsupervisedSegmenter
from utils_gee import extract_img, transform_ee_img

# Display colour and name of each segmentation class. Entry k describes the
# predicted class id k + 1 (see transform_to_pil and values_from_output).
colors = ('red', 'palegreen', 'green', 'steelblue', 'blue', 'yellow', 'lightgrey')
class_names = ('Buildings', 'Cultivation', 'Natural green', 'Wetland', 'Water',
               'Infrastructure', 'Background')
cmap = mpl.colors.ListedColormap(colors)

# Per-class weight used for the biodiversity score, aligned with class_names.
scores_init = [2, 3, 4, 3, 1, 4, 0]

# Import model configs
hydra.initialize(config_path="configs", job_name="corine")
cfg = hydra.compose(config_name="my_train_config.yml")
nbclasses = cfg.dir_dataset_n_classes

# Load model
# NOTE(review): model_path ends with '/', i.e. it names a directory, but
# torch.load expects the checkpoint file itself. TODO confirm the filename.
model_path = "checkpoint/model/"
saved_state_dict = torch.load(model_path, map_location=torch.device('cpu'))
model = LitUnsupervisedSegmenter(nbclasses, cfg)
model.load_state_dict(saved_state_dict)
# Inference only: disable training-mode behaviour (dropout, batch-norm
# statistics updates) so predictions are deterministic.
model.eval()

# Resize to 320x320 and normalise with the ImageNet mean/std.
preprocess = T.Compose([
    T.ToPILImage(),
    T.Resize((320, 320)),
    T.ToTensor(),
    T.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])


def segment_loc(location, month, year, how="month", month_end='12', year_end=None):
    """Fetch one Earth Engine image for *location* and segment it.

    Using a whole time window (rather than a single date) limits cloud cover.

    Args:
        location: passed unchanged to extract_img.
        month, year: zero-padded strings such as '03' and '2021'.
        how: 'month' -> window from day 01 to day 28 of month/year;
             'year'  -> window from month/year to month_end/year_end
             (year_end defaults to year), fetched with width=len=0.04.
             segment_group uses 'year' for both yearly and bi-monthly windows.
        month_end, year_end: end of the window when how == 'year'.

    Returns:
        dict with 'img' (the preprocessed input batch) and 'linear_preds'
        (per-pixel argmax class ids), both detached on the CPU and truncated
        to model.cfg.n_images.

    Raises:
        ValueError: if *how* is neither 'month' nor 'year'.
    """
    start = f"{year}-{month}-01"
    if how == 'month':
        img = extract_img(location, start, f"{year}-{month}-28")
    elif how == 'year':
        if year_end is None:
            year_end = year
        img = extract_img(location, start, f"{year_end}-{month_end}-28",
                          width=0.04, len=0.04)
    else:
        raise ValueError(f"Unknown how={how!r}; expected 'month' or 'year'")

    img_test = transform_ee_img(img, max=0.25)

    # Preprocess the image and add a batch dimension.
    x = torch.unsqueeze(preprocess(img_test), dim=0).cpu()

    with torch.no_grad():
        # NOTE(review): assumes model.net(x) returns (features, code) as in
        # STEGO's LitUnsupervisedSegmenter -- TODO confirm against model.py.
        _, code = model.net(x)
        linear_preds = model.linear_probe(x, code).argmax(1)

    n_images = model.cfg.n_images
    return {
        'img': x[:n_images].detach().cpu(),
        'linear_preds': linear_preds[:n_images].detach().cpu(),
    }


def segment_group(location, start_date, end_date, how='month'):
    """Segment every time window between two dates.

    Args:
        location: passed unchanged to segment_loc.
        start_date, end_date: strings beginning with 'YYYY-MM'; only the year
            and month are read.
        how: 'month'   -> one image per month;
             '2months' -> one image per month, covering that month up to the
                          28th of the following month;
             'year'    -> one image per calendar year, clipped to the
                          requested start/end months.

    Returns:
        list of (label, output) tuples, where label is 'YYYY-MM' of the window
        start and output is the segment_loc result.

    Raises:
        ValueError: if *how* is not one of the values above.
    """
    if how not in ('month', '2months', 'year'):
        raise ValueError(
            f"Unknown how={how!r}; expected 'month', '2months' or 'year'")

    st_year, st_month = int(start_date[0:4]), int(start_date[5:7])
    end_year, end_month = int(end_date[0:4]), int(end_date[5:7])

    outputs = []
    for year in range(st_year, end_year + 1):
        first = st_month if year == st_year else 1
        last = end_month if year == end_year else 12
        year_str = str(year)

        if how == 'year':
            first_str = f"{first:02d}"
            outputs.append((f"{year_str}-{first_str}",
                            segment_loc(location, first_str, year_str, how='year',
                                        month_end=f"{last:02d}")))
            continue

        for month in range(first, last + 1):
            month_str = f"{month:02d}"
            if how == 'month':
                output = segment_loc(location, month_str, year_str)
            else:  # '2months': December rolls over into January of next year.
                next_month = month % 12 + 1
                next_year = year + 1 if next_month < month else year
                output = segment_loc(location, month_str, year_str, how='year',
                                     month_end=f"{next_month:02d}",
                                     year_end=str(next_year))
            outputs.append((f"{year_str}-{month_str}", output))
    return outputs


def transform_to_pil(outputs, alpha=0.3):
    """Convert one segment_loc output into PIL images.

    Returns:
        (img, label, labeled_img): the input image, the colour-coded label
        map, and the label map blended over the image with weight *alpha*.
    """
    img = torch.moveaxis(prep_for_plot(outputs['img'][0]), -1, 0)
    img = T.ToPILImage()(img)

    # Colour lookup table, one row per class colour. Class id k uses row k-1,
    # so id 0 indexes row -1, i.e. the last colour (Background).
    cmaplist = np.array([np.array(cmap(i)) for i in range(cmap.N)])
    labels = np.array(outputs['linear_preds'][0]) - 1
    label = T.ToPILImage()((cmaplist[labels] * 255).astype(np.uint8))

    # Overlay the labels on the image with transparency alpha.
    background = img.convert("RGBA")
    overlay = label.convert("RGBA")
    labeled_img = Image.blend(background, overlay, alpha)
    return img, label, labeled_img


def values_from_output(output):
    """Extract plot-ready values from one segment_loc output.

    Returns:
        (img, labeled_img, nb_values, score): RGB arrays of the image and of
        the labelled overlay, the pixel count of each class id 1..7 (aligned
        with class_names), and the weighted score normalised to [0, 1] by
        max(scores_init). The score is 0.0 when no pixel has id 1..7.
    """
    img, _, labeled_img = transform_to_pil(output, alpha=0.3)
    img = np.array(img.convert('RGB'))
    labeled_img = np.array(labeled_img.convert('RGB'))

    preds = output['linear_preds'][0]
    nb_values = [np.count_nonzero(preds == k + 1) for k in range(len(scores_init))]

    total = sum(nb_values)
    if total == 0:
        # Avoid ZeroDivisionError when no pixel falls into a scored class.
        score = 0.0
    else:
        weighted = sum(w * n for w, n in zip(scores_init, nb_values))
        score = weighted / total / max(scores_init)
    return img, labeled_img, nb_values, score


def values_from_outputs(outputs):
    """Apply values_from_output to every (label, output) pair of segment_group.

    Returns:
        (months, imgs, imgs_label, nb_values, scores): parallel lists.
    """
    months, imgs, imgs_label, nb_values, scores = [], [], [], [], []
    for month, output in outputs:
        img, labeled_img, nb_value, score = values_from_output(output)
        months.append(month)
        imgs.append(img)
        imgs_label.append(labeled_img)
        nb_values.append(nb_value)
        scores.append(score)
    return months, imgs, imgs_label, nb_values, scores


def plot_imgs_labels(months, imgs, imgs_label, nb_values, scores):
    """Build the animated 4-panel figure, one frame per time window.

    Panels: image, labelled overlay, class-repartition pie, and the score
    history up to the current frame.

    Args:
        months: frame labels (slider steps and frame titles).
        imgs, imgs_label: per-frame RGB arrays, all of the same size.
        nb_values: per-frame pixel counts per class, aligned with class_names.
        scores: per-frame scores.

    Returns:
        plotly.graph_objects.Figure.

    Raises:
        ValueError: if *imgs* is empty.
    """
    if len(imgs) == 0:
        raise ValueError("plot_imgs_labels needs at least one image")

    fig2 = px.imshow(np.array(imgs), animation_frame=0, binary_string=True)
    fig3 = px.imshow(np.array(imgs_label), animation_frame=0, binary_string=True)

    def make_pie(values):
        return go.Pie(labels=class_names,
                      values=values,
                      marker_colors=colors,
                      name="Segment repartition",
                      textposition='inside',
                      texttemplate="%{percent:.0%}",
                      textfont_size=14)

    # Score history: frame k plots scores[0..k] at x = 1..k+1 (inside the
    # xaxis3 range set below) and annotates only the latest point.
    scatters = []
    for k in range(len(scores)):
        history = list(scores[:k + 1])
        text = [""] * len(history)
        text[-1] = str(round(history[-1], 2))
        scatters.append(go.Scatter(x=list(range(1, len(history) + 1)),
                                   y=history,
                                   mode="lines+markers+text",
                                   marker_color="black",
                                   text=text,
                                   textposition="top center"))

    # The pie needs a 'domain' cell (an 'xy' cell rejects Pie traces). The
    # other cells are cartesian, so the score plot in column 4 receives
    # xaxis3/yaxis3, which the layout update below configures.
    fig = make_subplots(
        rows=1, cols=4,
        specs=[[{"type": "xy"}, {"type": "xy"}, {"type": "domain"}, {"type": "xy"}]],
        column_widths=[0.6, 0.6, 0.3, 0.3],
        subplot_titles=("Localisation visualization", "labeled visualisation",
                        "Segments repartition", "Biodiversity scores"),
    )
    fig.add_trace(fig2["frames"][0]["data"][0], row=1, col=1)
    fig.add_trace(fig3["frames"][0]["data"][0], row=1, col=2)
    fig.add_trace(make_pie(nb_values[0]), row=1, col=3)
    fig.add_trace(scatters[0], row=1, col=4)

    number_frames = len(imgs)
    frames = [dict(
        name=k,
        data=[
            fig2["frames"][k]["data"][0],
            fig3["frames"][k]["data"][0],
            make_pie(nb_values[k]),
            scatters[k],
        ],
        traces=[0, 1, 2, 3],  # indices of the four traces added above, in order
    ) for k in range(number_frames)]

    updatemenus = [dict(type='buttons',
                        buttons=[dict(label='Play',
                                      method='animate',
                                      args=[[f'{k}' for k in range(number_frames)],
                                            dict(frame=dict(duration=500, redraw=False),
                                                 transition=dict(duration=0),
                                                 easing='linear',
                                                 fromcurrent=True,
                                                 mode='immediate')])],
                        direction='left',
                        pad=dict(r=10, t=85),
                        showactive=True, x=0.1, y=0.13,
                        xanchor='right', yanchor='top')]

    sliders = [{'yanchor': 'top',
                'xanchor': 'left',
                'currentvalue': {'font': {'size': 16}, 'prefix': 'Frame: ',
                                 'visible': False, 'xanchor': 'right'},
                'transition': {'duration': 500.0, 'easing': 'linear'},
                'pad': {'b': 10, 't': 50},
                'len': 0.9, 'x': 0.1, 'y': 0,
                'steps': [{'args': [[k], {'frame': {'duration': 500.0, 'easing': 'linear', 'redraw': False},
                                          'transition': {'duration': 0, 'easing': 'linear'}}],
                           'label': months[k], 'method': 'animate'} for k in range(number_frames)]}]

    fig.update(frames=frames)

    height, width = imgs[0].shape[0], imgs[0].shape[1]
    for k, frame in enumerate(fig["frames"]):
        # NOTE(review): the tiny per-frame offset makes every frame's axis
        # range distinct, presumably to force plotly to re-apply it on each
        # frame -- TODO confirm it is still needed.
        frame.update(layout={
            "xaxis": {"range": [0, width + k / 100000]},
            "yaxis": {"range": [height + k / 100000, 0]},
        })
        frame.update(layout_title_text=months[k])

    last = number_frames - 1  # initial layout matches the last frame's ranges
    fig.update(layout_title_text='tot')
    fig.update(
        layout={
            "xaxis": {
                "range": [0, width + last / 100000],
                'showgrid': False,  # thin lines in the background
                'zeroline': False,  # thick line at x=0
                'visible': False,  # numbers below
            },
            "yaxis": {
                "range": [height + last / 100000, 0],
                'showgrid': False,
                'zeroline': False,
                'visible': False,
            },
            # Score plot (column 4): one x slot per frame, scores in [0, 1].
            "xaxis3": {
                "range": [0, len(scores) + 1],
                'autorange': False,
                'showgrid': False,
                'zeroline': False,
                'visible': False,
            },
            "yaxis3": {
                "range": [0, 1.5],
                'autorange': False,
                'showgrid': False,
                'zeroline': False,
                'visible': False,
            },
        },
        legend=dict(
            yanchor="bottom",
            y=0.99,
            xanchor="center",
            x=0.01,
        ),
    )
    fig.update_layout(updatemenus=updatemenus, sliders=sliders)
    fig.update_layout(margin=dict(b=0, r=0))
    return fig


def segment_region(latitude, longitude, start_date, end_date, how='month'):
    """Segment a location over a date range and return the animated figure.

    Args:
        latitude, longitude: coordinates (anything float() accepts).
        start_date, end_date: strings beginning with 'YYYY-MM'.
        how: 'month', '2months' or 'year'. A list/tuple (e.g. the value of a
            UI multi-select) is also accepted; its first element is used.

    Returns:
        plotly.graph_objects.Figure built by plot_imgs_labels.
    """
    location = [float(latitude), float(longitude)]
    # Taking how[0] of a plain string would yield a single character
    # ('month' -> 'm'), so only index into non-string sequences.
    if not isinstance(how, str):
        how = how[0]

    # Extract the outputs for each image.
    outputs = segment_group(location, start_date, end_date, how=how)
    # Extract the interesting values from each image.
    months, imgs, imgs_label, nb_values, scores = values_from_outputs(outputs)
    # Create the figure.
    return plot_imgs_labels(months, imgs, imgs_label, nb_values, scores)