import streamlit as st from streamlit.components.v1 import html import matplotlib try: from imageio.v2 import imread except: from imageio import imread import matplotlib.pyplot as plt import pandas as pd import numpy as np from myoquant.src.common_func import ( load_cellpose, run_cellpose, is_gpu_availiable, df_from_cellpose_mask, ) from myoquant.src.ATP_analysis import ( get_all_intensity, estimate_threshold, plot_density, predict_all_cells, paint_full_image, ) labels_predict = {1: "fiber type 1", 2: "fiber type 2"} use_GPU = is_gpu_availiable() np.random.seed(42) st.set_page_config( page_title="MyoQuant ATP Analysis", page_icon="🔬", ) @st.cache_resource def st_load_cellpose(): return load_cellpose() @st.cache_data def st_run_cellpose(image_atp, _model): return run_cellpose(image_atp, _model) @st.cache_data def st_df_from_cellpose_mask(mask): return df_from_cellpose_mask(mask) @st.cache_data def st_get_all_intensity(image_atp, df_cellpose): return get_all_intensity(image_atp, df_cellpose) @st.cache_data def st_estimate_threshold(intensity_list): return estimate_threshold(intensity_list) @st.cache_data def st_plot_density(all_cell_median_intensity, intensity_threshold): return plot_density(all_cell_median_intensity, intensity_threshold) @st.cache_data def st_predict_all_cells(image_atp, cellpose_df, intensity_threshold): return predict_all_cells(image_atp, cellpose_df, intensity_threshold) @st.cache_data def st_paint_full_image(image_atp, df_cellpose, class_predicted_all): return paint_full_image(image_atp, df_cellpose, class_predicted_all) model_cellpose = st_load_cellpose() with st.sidebar: st.write("Threshold Parameters") intensity_threshold = st.slider("Intensity Threshold (0=auto)", 0, 255, 0, 5) st.title("ATP Staining Analysis") st.write( "This demo will automatically quantify the number of type 1 muscle fibers vs the number of type 2 muscle fiber on ATP stained images." ) st.write("Upload your ATP Staining image") uploaded_file_atp = st.file_uploader("Choose a file") if uploaded_file_atp is not None: image_ndarray_atp = imread(uploaded_file_atp) st.write("Raw Image") image = st.image(uploaded_file_atp) mask_cellpose = st_run_cellpose(image_ndarray_atp, model_cellpose) st.header("Segmentation Results") st.subheader("CellPose results") fig, ax = plt.subplots(1, 1) ax.imshow(mask_cellpose, cmap="viridis") ax.axis("off") st.pyplot(fig) df_cellpose = st_df_from_cellpose_mask(mask_cellpose) st.header("Cell Intensity Plot") all_cell_median_intensity = st_get_all_intensity(image_ndarray_atp, df_cellpose) figure_intensity = st_plot_density(all_cell_median_intensity, intensity_threshold) st.pyplot(figure_intensity) st.header("ATP Cell Classification Results") if intensity_threshold == 0: muscle_fiber_type_all, all_cell_median_intensity = st_predict_all_cells( image_ndarray_atp, df_cellpose, intensity_threshold=None ) else: muscle_fiber_type_all, all_cell_median_intensity = st_predict_all_cells( image_ndarray_atp, df_cellpose, intensity_threshold=intensity_threshold ) df_cellpose["muscle_fiber_type"] = muscle_fiber_type_all df_cellpose["median_intensity"] = all_cell_median_intensity count_per_label = np.unique(muscle_fiber_type_all, return_counts=True) st.dataframe( df_cellpose.drop( [ "centroid-0", "centroid-1", "bbox-0", "bbox-1", "bbox-2", "bbox-3", "image", ], axis=1, ) ) st.write("Total number of cells detected: ", len(muscle_fiber_type_all)) for index, elem in enumerate(count_per_label[0]): st.write( "Number of cells classified as ", labels_predict[int(elem)], ": ", count_per_label[1][int(index)], " ", 100 * count_per_label[1][int(index)] / len(muscle_fiber_type_all), "%", ) st.header("Painted predicted image") st.write( "Green color indicates cells classified as control, red color indicates cells classified as sick" ) paint_img = st_paint_full_image( image_ndarray_atp, df_cellpose, muscle_fiber_type_all ) fig3, ax3 = plt.subplots(1, 1) cmap = matplotlib.colors.LinearSegmentedColormap.from_list( "", ["white", "green", "red"] ) ax3.imshow(image_ndarray_atp) ax3.imshow(paint_img, cmap=cmap, alpha=0.5) ax3.axis("off") st.pyplot(fig3) html( f""" """ )