import streamlit as st
from pathlib import Path


# -----------------------------------------------------------------------------
# main
# -----------------------------------------------------------------------------
def main():
    """Render the SatVision few-shot comparison page.

    A slider picks how many training samples the models were trained on.
    For that setting, each example is shown as one row: the input image
    chip, the SatVision-B prediction, the UNet (CNN) prediction, and the
    MCD12Q1 land-cover target.
    """
    st.title("SatVision Few-Shot Comparison")

    selected_option = st.select_slider(
        "## Number of training samples",
        options=[10, 100, 500, 1000, 5000])

    st.markdown('Move slider to select how many training '
                'samples the models were trained on')

    images = load_images(selected_option, Path('./images/images'))
    labels = load_labels(selected_option, Path('./images/labels'))
    preds = load_predictions(selected_option, Path('./images/predictions'))

    # Pair the i-th entry of every source. This relies on each find_* helper
    # returning paths in the same sorted order (checked in _find_matching).
    examples = list(zip(images, preds['svb'], preds['unet'], labels))

    headers = (
        '## MOD09GA 3-2-1 Image Chip',
        '## SatVision-B Prediction',
        '## UNet (CNN) Prediction',
        '## MCD12Q1 LandCover Target',
    )

    # One grid row per example; keep at least one row so the headers render.
    grid = make_grid(max(len(examples), 1), len(headers))

    for column, header in zip(grid[0], headers):
        column.markdown(header)

    for grid_row, example in zip(grid, examples):
        for column, (path, caption) in zip(grid_row, example):
            # NOTE(review): recent Streamlit releases deprecate
            # use_column_width in favour of use_container_width; switch once
            # the pinned Streamlit version supports it.
            column.image(path, caption=caption, use_column_width=True)

    st.text("Additional Information:")
    st.text("This is a placeholder for additional information about the images.")


# -----------------------------------------------------------------------------
# _captioned
# -----------------------------------------------------------------------------
def _captioned(paths, title: str) -> list:
    """Pair each path with a caption ``"<title> Example <n>"``, n from 1.

    Returns a list of ``(str(path), caption)`` tuples in the order of *paths*.
    """
    return [(str(path), f"{title} Example {number}")
            for number, path in enumerate(paths, 1)]


# -----------------------------------------------------------------------------
# _find_matching
# -----------------------------------------------------------------------------
def _find_matching(directory: Path, pattern: str, description: str,
                   expected_count: int = 3,
                   expected_first_id: str = '1071') -> list:
    """Return the sorted paths in *directory* matching the glob *pattern*.

    The sorted order is what lets images, labels and predictions be zipped
    together by index, so the result is validated before it is returned.

    Args:
        directory: Directory to search (non-recursively).
        pattern: ``Path.glob`` pattern, e.g. ``ft_demo_10_*_img.png``.
        description: Noun used in the count error message.
        expected_count: Number of files that must match.
        expected_first_id: Substring that must appear in the file name of
            the first sorted match.

    Raises:
        ValueError: If the number of matches differs from *expected_count*
            or the first match's file name lacks *expected_first_id*.
    """
    matches = sorted(directory.glob(pattern))

    # Explicit raises (not assert) so the checks survive ``python -O``.
    if len(matches) != expected_count:
        raise ValueError(
            f"Should be {expected_count} {description} matching regex")
    # Check the file name only; the full path could contain the id by chance
    # in a directory name.
    if matches and expected_first_id not in matches[0].name:
        raise ValueError(f'Should be {expected_first_id}')
    return matches


# -----------------------------------------------------------------------------
# load_images
# -----------------------------------------------------------------------------
def load_images(selected_option: int, image_dir: Path):
    """Return ``(path, caption)`` pairs for the input image chips.

    Args:
        selected_option: Number of training samples chosen on the slider.
        image_dir: Directory containing ``ft_demo_<option>_*_img.png`` files.

    Returns:
        A list of ``(str(path), caption)`` tuples in sorted file order,
        suitable for passing to ``st.image``.
    """
    return _captioned(find_images(selected_option, image_dir),
                      "MOD09GA 3-2-1 H18v04 2019")


# -----------------------------------------------------------------------------
# find_images
# -----------------------------------------------------------------------------
def find_images(selected_option: int, image_dir: Path):
    """Return the sorted input-image paths for *selected_option*.

    Raises:
        ValueError: If there are not exactly 3 matches or the first one is
            not example 1071.
    """
    return _find_matching(image_dir,
                          f'ft_demo_{selected_option}_*_img.png',
                          'images')


# -----------------------------------------------------------------------------
# load_labels
# -----------------------------------------------------------------------------
def load_labels(selected_option: int, label_dir: Path):
    """Return ``(path, caption)`` pairs for the MCD12Q1 target images.

    Args:
        selected_option: Number of training samples chosen on the slider.
        label_dir: Directory containing ``ft_demo_<option>_*_label.png`` files.

    Returns:
        A list of ``(str(path), caption)`` tuples in sorted file order.
    """
    return _captioned(find_labels(selected_option, label_dir),
                      "MCD12Q1 LandCover Target")


# -----------------------------------------------------------------------------
# find_labels
# -----------------------------------------------------------------------------
def find_labels(selected_option: int, label_dir: Path):
    """Return the sorted label-image paths for *selected_option*.

    Raises:
        ValueError: If there are not exactly 3 matches or the first one is
            not example 1071.
    """
    return _find_matching(label_dir,
                          f'ft_demo_{selected_option}_*_label.png',
                          'label images')


# -----------------------------------------------------------------------------
# load_predictions
# -----------------------------------------------------------------------------
def load_predictions(selected_option: int, pred_dir: Path):
    """Return captioned prediction images for both models.

    Args:
        selected_option: Number of training samples chosen on the slider.
        pred_dir: Root directory; predictions are read from
            ``pred_dir/<option>/svb`` and ``pred_dir/<option>/cnn``.

    Returns:
        ``{'svb': [...], 'unet': [...]}`` where each value is a list of
        ``(str(path), caption)`` tuples in sorted file order. The ``'unet'``
        entry is loaded from the ``cnn`` model directory.
    """
    svb_paths = find_preds(selected_option, pred_dir, 'svb')
    unet_paths = find_preds(selected_option, pred_dir, 'cnn')
    return {
        'svb': _captioned(svb_paths, "SatVision-B Prediction"),
        'unet': _captioned(unet_paths, "Unet Prediction"),
    }


# -----------------------------------------------------------------------------
# find_preds
# -----------------------------------------------------------------------------
def find_preds(selected_option: int, pred_dir: Path, model: str):
    """Return the sorted prediction-image paths for one model.

    Args:
        selected_option: Number of training samples chosen on the slider.
        pred_dir: Root prediction directory.
        model: Model sub-directory name; ``'cnn'`` files use the
            ``ft_cnn_demo_`` prefix, every other model uses ``ft_demo_``.

    Raises:
        FileNotFoundError: If ``pred_dir/<option>/<model>`` does not exist.
        ValueError: If there are not exactly 3 matches or the first one is
            not example 1071.
    """
    prefix = 'ft_cnn_demo' if model == 'cnn' else 'ft_demo'
    model_specific_dir = pred_dir / str(selected_option) / model
    if not model_specific_dir.exists():
        raise FileNotFoundError(f'{model_specific_dir} does not exist')
    return _find_matching(model_specific_dir,
                          f'{prefix}_{selected_option}_*_pred.png',
                          'prediction images')


# -----------------------------------------------------------------------------
# make_grid
# -----------------------------------------------------------------------------
def make_grid(cols, rows):
    """Lay out a grid of Streamlit columns.

    NOTE: despite the parameter names, ``cols`` is the number of horizontal
    rows (one ``st.container`` each) and ``rows`` is the number of columns
    created inside each of them. The names are kept so existing callers
    keep working.

    Returns:
        A list of length ``cols`` whose element ``r`` is the list returned
        by ``st.columns(rows, gap='large')``; ``grid[r][c]`` addresses row
        ``r``, column ``c``.
    """
    grid = []
    for _ in range(cols):
        with st.container():
            grid.append(st.columns(rows, gap='large'))
    return grid


# -----------------------------------------------------------------------------
# Main execution
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    main()