MyoQuant / pages /4_ATP_Staining_Analysis.py
lambda scientist
update
07cafd6
import streamlit as st
from streamlit.components.v1 import html
import matplotlib
try:
from imageio.v2 import imread
except:
from imageio import imread
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
from myoquant.src.common_func import (
load_cellpose,
run_cellpose,
is_gpu_availiable,
df_from_cellpose_mask,
)
from myoquant.src.ATP_analysis import (
get_all_intensity,
estimate_threshold,
plot_density,
predict_all_cells,
paint_full_image,
)
# Human-readable names for the integer labels returned by st_predict_all_cells;
# used when reporting per-label counts below.
labels_predict = {
    1: "fiber type 1",
    2: "fiber type 2",
}
# NOTE(review): not referenced anywhere else in this page's visible code.
use_GPU = is_gpu_availiable()
# Seed NumPy's global RNG so any NumPy randomness used downstream is repeatable.
np.random.seed(42)
# Must remain the first Streamlit command executed by the page.
st.set_page_config(page_title="MyoQuant ATP Analysis", page_icon="🔬")
@st.cache_resource
def st_load_cellpose():
    """Return the CellPose model from ``load_cellpose``.

    Wrapped in ``st.cache_resource`` so the model object is reused across
    script reruns instead of being reloaded each time.
    """
    model = load_cellpose()
    return model
@st.cache_data
def st_run_cellpose(image_atp, _model):
    """Segment ``image_atp`` with CellPose and memoize the resulting mask.

    The leading underscore on ``_model`` tells Streamlit to exclude the model
    from the cache key, so only the image is hashed.
    """
    segmentation_mask = run_cellpose(image_atp, _model)
    return segmentation_mask
@st.cache_data
def st_df_from_cellpose_mask(mask):
    """Convert a CellPose label mask to a DataFrame via ``df_from_cellpose_mask`` (cached)."""
    mask_table = df_from_cellpose_mask(mask)
    return mask_table
@st.cache_data
def st_get_all_intensity(image_atp, df_cellpose):
    """Return per-cell intensity values for the segmented cells (cached).

    The page passes the result to the density plot as
    ``all_cell_median_intensity``.
    """
    intensities = get_all_intensity(image_atp, df_cellpose)
    return intensities
@st.cache_data
def st_estimate_threshold(intensity_list):
    """Cached wrapper around ``estimate_threshold``.

    NOTE(review): not called elsewhere in this page's visible code.
    """
    threshold = estimate_threshold(intensity_list)
    return threshold
@st.cache_data
def st_plot_density(all_cell_median_intensity, intensity_threshold):
    """Build the intensity density figure (rendered by the page with ``st.pyplot``), cached."""
    density_figure = plot_density(all_cell_median_intensity, intensity_threshold)
    return density_figure
@st.cache_data
def st_predict_all_cells(image_atp, cellpose_df, intensity_threshold):
    """Classify every segmented cell (cached).

    The page unpacks the result as ``(fiber_type_labels, median_intensities)``
    and passes ``intensity_threshold=None`` when the slider is at 0 ("auto").
    """
    prediction = predict_all_cells(image_atp, cellpose_df, intensity_threshold)
    return prediction
@st.cache_data
def st_paint_full_image(image_atp, df_cellpose, class_predicted_all):
    """Render the per-cell predictions as an image for overlay display (cached)."""
    painted = paint_full_image(image_atp, df_cellpose, class_predicted_all)
    return painted
# Obtain the (cached) CellPose model before building the rest of the page.
model_cellpose = st_load_cellpose()

# Sidebar controls: a slider value of 0 requests automatic thresholding.
st.sidebar.write("Threshold Parameters")
intensity_threshold = st.sidebar.slider(
    "Intensity Threshold (0=auto)",
    min_value=0,
    max_value=255,
    value=0,
    step=5,
)

st.title("ATP Staining Analysis")
st.write(
    "This demo will automatically quantify the number of type 1 muscle fibers vs the number of type 2 muscle fiber on ATP stained images."
)
st.write("Upload your ATP Staining image")
uploaded_file_atp = st.file_uploader("Choose a file")
def _show_and_close(fig):
    """Render a Matplotlib figure on the page, then close it.

    Streamlit re-executes this script on every interaction, and each run
    creates new figures. Closing them once rendered keeps pyplot from
    accumulating open figures (and their memory) across reruns.
    """
    st.pyplot(fig)
    plt.close(fig)


if uploaded_file_atp is not None:
    image_ndarray_atp = imread(uploaded_file_atp)
    st.write("Raw Image")
    st.image(uploaded_file_atp)

    # Segment the image with CellPose and display the label mask.
    mask_cellpose = st_run_cellpose(image_ndarray_atp, model_cellpose)
    st.header("Segmentation Results")
    st.subheader("CellPose results")
    fig, ax = plt.subplots(1, 1)
    ax.imshow(mask_cellpose, cmap="viridis")
    ax.axis("off")
    _show_and_close(fig)

    df_cellpose = st_df_from_cellpose_mask(mask_cellpose)

    st.header("Cell Intensity Plot")
    all_cell_median_intensity = st_get_all_intensity(image_ndarray_atp, df_cellpose)
    figure_intensity = st_plot_density(all_cell_median_intensity, intensity_threshold)
    _show_and_close(figure_intensity)

    st.header("ATP Cell Classification Results")
    # The slider documents 0 as "auto"; that is forwarded as None, presumably so
    # predict_all_cells estimates the threshold itself instead of using 0.
    muscle_fiber_type_all, all_cell_median_intensity = st_predict_all_cells(
        image_ndarray_atp,
        df_cellpose,
        intensity_threshold=intensity_threshold or None,
    )
    df_cellpose["muscle_fiber_type"] = muscle_fiber_type_all
    df_cellpose["median_intensity"] = all_cell_median_intensity
    st.dataframe(
        df_cellpose.drop(
            [
                "centroid-0",
                "centroid-1",
                "bbox-0",
                "bbox-1",
                "bbox-2",
                "bbox-3",
                "image",
            ],
            axis=1,
        )
    )

    # Per-label counts and their share of all detected cells.
    n_cells = len(muscle_fiber_type_all)
    st.write("Total number of cells detected: ", n_cells)
    unique_labels, label_counts = np.unique(muscle_fiber_type_all, return_counts=True)
    for label, count in zip(unique_labels, label_counts):
        st.write(
            "Number of cells classified as ",
            labels_predict[int(label)],
            ": ",
            count,
            " ",
            100 * count / n_cells,
            "%",
        )

    st.header("Painted predicted image")
    # Overlay colors indexed by painted value: 0 -> white, 1 -> green, 2 -> red.
    paint_colors = ["white", "green", "red"]
    st.write(
        f"Green color indicates cells classified as {labels_predict[1]}, "
        f"red color indicates cells classified as {labels_predict[2]}"
    )
    paint_img = st_paint_full_image(
        image_ndarray_atp, df_cellpose, muscle_fiber_type_all
    )
    fig3, ax3 = plt.subplots(1, 1)
    cmap = matplotlib.colors.LinearSegmentedColormap.from_list("", paint_colors)
    ax3.imshow(image_ndarray_atp)
    # NOTE(review): assumes paint_full_image writes 0 for background and the
    # fiber-type label (1 or 2) inside each cell - confirm. Pinning vmin/vmax
    # keeps value -> color fixed. With default autoscaling, an image whose
    # cells are all type 1 would be painted with the top color (red).
    ax3.imshow(
        paint_img, cmap=cmap, alpha=0.5, vmin=0, vmax=len(paint_colors) - 1
    )
    ax3.axis("off")
    _show_and_close(fig3)
# Page-view analytics script (served from plausible.cmeyer.fr) embedded via
# a Streamlit HTML component. No interpolation is needed, so a plain string.
_ANALYTICS_SNIPPET = """
<script defer data-domain="lbgi.fr/myoquant" src="https://plausible.cmeyer.fr/js/script.js"></script>
"""
html(_ANALYTICS_SNIPPET)