"""Gradio web app that runs ``GranaAnalyser`` on uploaded images and presents the results."""
import os
import shutil
import time
import uuid
from datetime import datetime
from decimal import Decimal

import gradio as gr
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
from settings import DEMO

plt.switch_backend("agg")  # fix for "RuntimeError: main thread is not in main loop"
import numpy as np
import pandas as pd
from PIL import Image

from model import GranaAnalyser

ga = GranaAnalyser(
    "weights/yolo/20240604_yolov8_segm_ABRCR1_all_train4_best.pt",
    "weights/AS_square_v16.ckpt",
    "weights/period_measurer_weights-1.298_real_full-fa12970.ckpt",
)

# File-name suffixes accepted by the upload validation (compared case-insensitively).
_SUPPORTED_EXTENSIONS = (".png", ".tif", ".tiff", ".jpg", ".jpeg")

# Maximum number of images per request when the DEMO setting is enabled.
_DEMO_MAX_IMAGES = 5

# Result directories older than this many seconds are deleted on the next request.
_RESULTS_MAX_AGE_SECONDS = 60 * 60


def calc_ratio(pixels, nano):
    """
    Calculates the pixel-per-nanometer ratio used to populate ratio_input

    :param pixels: length of the scale bar in pixels
    :param nano: length of the scale bar in nanometers
    :return: ``pixels / nano``, or None when either value is missing or zero
    """
    if not (pixels and nano):
        return None
    return pixels / nano


# https://jakevdp.github.io/PythonDataScienceHandbook/05.13-kernel-density-estimation.html
def KDE(dataset, h):
    """
    Gaussian kernel density estimate of ``dataset`` evaluated at its own points

    :param dataset: 1-D samples exposing ``.size`` and element-wise arithmetic
        (numpy array or pandas Series)
    :param h: kernel bandwidth; must be non-zero, otherwise the result is NaN
    :return: one density value per sample, in the same container type as ``dataset``
    """

    # the Kernel function
    def K(x):
        return np.exp(-(x ** 2) / 2) / np.sqrt(2 * np.pi)

    n_samples = dataset.size

    total_sum = 0
    # every datapoint contributes one kernel, evaluated at all datapoints at once
    for xi in dataset:
        total_sum += K((dataset - xi) / h)

    return total_sum / (h * n_samples)


def prepare_files_for_download(
    dir_name,
    grana_data,
    aggregated_data,
    detection_visualizations_dict,
    images_grana_dict,
):
    """
    Save and zip files for download

    :param dir_name: directory in which the archive is created
    :param grana_data: DataFrame containing all grana measurements; must have the
        "source image" and "granum ID" columns
    :param aggregated_data: DataFrame containing aggregated measurements
    :param detection_visualizations_dict: maps source file name -> annotated image
    :param images_grana_dict: maps granum ID -> image of that single granum
    :return: path of the created zip file
    """
    dir_to_zip = f"{dir_name}/to_zip"
    # normally created by the caller; creating it here keeps the function usable on its own
    os.makedirs(dir_to_zip, exist_ok=True)

    # raw data
    grana_data.to_csv(f"{dir_to_zip}/grana_raw_data.csv", index=False)

    # aggregated measurements
    aggregated_data.to_csv(f"{dir_to_zip}/grana_aggregated_data.csv")

    # annotated pictures
    masked_images_dir = f"{dir_to_zip}/annotated_images"
    os.makedirs(masked_images_dir)
    for img_name, img in detection_visualizations_dict.items():
        # splitext keeps the stem intact for dotted names and for names without extension
        stem, extension = os.path.splitext(img_name)
        img.save(f"{masked_images_dir}/{stem}_annotated{extension}")

    # single_grana images
    grana_images_dir = f"{dir_to_zip}/single_grana_images"
    os.makedirs(grana_images_dir)
    source_by_granum = pd.Series(
        grana_data["source image"].values, index=grana_data["granum ID"]
    ).to_dict()
    for granum_id, img in images_grana_dict.items():
        source_stem = os.path.splitext(source_by_granum[granum_id])[0]
        img.save(f"{grana_images_dir}/{source_stem}_granum_{granum_id}.png")

    # zip all files
    date_str = datetime.today().strftime("%Y-%m-%d")
    zip_path = f"{dir_name}/GRANA_results_{date_str}"
    shutil.make_archive(zip_path, "zip", dir_to_zip)

    # delete to_zip dir
    shutil.rmtree(os.path.join(os.getcwd(), dir_to_zip))

    return f"{zip_path}.zip"


def show_info_on_submit(s):
    """
    Component updates applied as soon as Submit is clicked

    :param s: value of the submit button (unused)
    :return: updates for (submit button, clear button, loading row, output row):
        both buttons disabled, loading row shown, output row hidden
    """
    return (
        gr.Button(interactive=False),
        gr.Button(interactive=False),
        gr.Row(visible=True),
        gr.Row(visible=False),
    )


def load_css():
    """Return the contents of ``styles.css`` from the current working directory."""
    with open("styles.css", "r", encoding="utf-8") as f:
        return f.read()


primary_hue = gr.themes.Color(
    c50="#e1f8ee",
    c100="#b7efd5",
    c200="#8de6bd",
    c300="#63dda5",
    c400="#39d48d",
    c500="#27b373",
    c600="#1e8958",
    c700="#155f3d",
    c800="#0c3522",
    c900="#030b07",
    c950="#000",
)

theme = gr.themes.Default(
    primary_hue=primary_hue,
    font=[gr.themes.GoogleFont("Ubuntu"), "ui-sans-serif", "system-ui", "sans-serif"],
)


def draw_violin_plot(y, ylabel, title):
    """
    Draw a violin plot with jittered data points, a box plot and the mean

    :param y: pandas Series of values to plot (NaNs already dropped by the caller)
    :param ylabel: label of the y axis
    :param title: title of the plot
    :return: matplotlib Figure, or None when fewer than 3 values are available
    """
    # only generate plot for 3 or more values
    if y.count() < 3:
        return None

    # Colors
    RED_DARK = "#850e00"
    DARK_GREEN = "#0c3522"
    BRIGHT_GREEN = "#8de6bd"

    # Create jittered version of "x" (which is only 1); the jitter of each point
    # is bounded by the (rescaled) density at that point
    value_range = y.max() - y.min()
    if value_range > 0:
        kde = KDE(y, value_range / y.size / 2)
        half_widths = np.asarray(kde / kde.max() * 0.2)
        x_jittered = 1 + np.random.uniform(-half_widths, half_widths)
    else:
        # all values identical: the bandwidth would be 0 and the KDE NaN,
        # so plot the points without jitter instead of failing
        x_jittered = np.ones(y.size)

    # A standalone Figure is not registered with pyplot, so figures do not
    # accumulate in pyplot's global state across requests.
    fig = Figure()
    ax = fig.add_subplot(1, 1, 1)
    ax.scatter(x=x_jittered, y=y, s=20, alpha=0.4, c=DARK_GREEN)

    violins = ax.violinplot(
        y,
        widths=0.45,
        bw_method="silverman",
        showmeans=False,
        showmedians=False,
        showextrema=False,
    )

    # change violin color
    for pc in violins["bodies"]:
        pc.set_facecolor(BRIGHT_GREEN)

    # add a boxplot to ax whose whiskers mark mean -/+ 1 SD instead of the
    # usual IQR-based whiskers
    lower = np.mean(y) - 1 * np.std(y)
    upper = np.mean(y) + 1 * np.std(y)

    medianprops = dict(linewidth=1, color="black", solid_capstyle="butt")
    boxplot_stats = [
        {
            "med": np.median(y),
            "q1": np.percentile(y, 25),
            "q3": np.percentile(y, 75),
            "whislo": lower,
            "whishi": upper,
        }
    ]

    ax.bxp(
        boxplot_stats,  # data for the boxplot
        showfliers=False,  # do not show the outliers beyond the caps.
        showcaps=True,  # show the caps
        medianprops=medianprops,
    )

    # Add mean value point
    ax.scatter(1, y.mean(), s=30, color=RED_DARK, zorder=3)

    ax.set_xticks([])
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    fig.tight_layout()

    return fig


def transform_aggregated_results_table(results_dict):
    """
    Rearrange aggregated measurements into a two-column "measurement / value" table

    :param results_dict: mapping with one value and one "std" entry per measurement,
        plus "N grana"
    :return: dict with two equally long lists, keyed by the column headers
    """
    MEASUREMENT_HEADER = "measurement [unit]"
    VALUE_HEADER = "value +-SD"

    def get_value_str(value, std):
        # a missing value or SD is shown as a dash
        if np.isnan(value) or np.isnan(std):
            return "-"
        value_str = str(Decimal(str(value)).quantize(Decimal("0.01")))
        std_str = str(Decimal(str(std)).quantize(Decimal("0.01")))
        return f"{value_str} +-{std_str}"

    # (displayed name, key of the value, key of its SD), in display order
    rows = [
        ("area [nm^2]", "area nm^2", "area nm^2 std"),
        ("perimeter [nm]", "perimeter nm", "perimeter nm std"),
        ("diameter [nm]", "diameter nm", "diameter nm std"),
        ("height [nm]", "height nm", "height nm std"),
        ("number of thylakoids", "Number of layers", "Number of layers std"),
        ("SRD [nm]", "period nm", "period nm std"),
        ("GSI", "GSI", "GSI std"),
    ]

    aggregated_dict = {MEASUREMENT_HEADER: [], VALUE_HEADER: []}
    for display_name, value_key, sd_key in rows:
        aggregated_dict[MEASUREMENT_HEADER].append(display_name)
        aggregated_dict[VALUE_HEADER].append(
            get_value_str(results_dict[value_key], results_dict[sd_key])
        )

    # N grana is a plain count without SD
    aggregated_dict[MEASUREMENT_HEADER].append("number of grana")
    aggregated_dict[VALUE_HEADER].append(str(int(results_dict["N grana"])))

    return aggregated_dict


def rename_columns_in_results_table(results_table):
    """
    Rename the columns of the full results table to their displayed names

    :param results_table: DataFrame with per-granum measurements
    :return: new DataFrame with renamed columns; columns not listed are left as they are
    """
    column_names = {
        "Granum ID": "granum ID",
        "File name": "source image",
        "area nm^2": "area [nm^2]",
        "perimeter nm": "perimeter [nm]",
        "diameter nm": "diameter [nm]",
        "height nm": "height [nm]",
        "Number of layers": "number of thylakoids",
        "period nm": "SRD [nm]",
        "period SD nm": "SRD SD [nm]",
    }
    return results_table.rename(columns=column_names)


def enable_submit():
    """
    Component updates applied after the analysis step has finished

    :return: updates for (submit button, clear button, loading row):
        both buttons enabled again, loading row hidden
    """
    return (
        gr.Button(interactive=True),
        gr.Button(interactive=True),
        gr.Row(visible=False),
    )


def _remove_stale_results(max_age_seconds=_RESULTS_MAX_AGE_SECONDS):
    """Delete "results_*" directories in the working directory older than max_age_seconds."""
    now = time.time()
    for directory_name in os.listdir():
        if not directory_name.startswith("results_"):
            continue
        dir_path = os.path.join(os.getcwd(), directory_name)
        if os.path.isdir(dir_path) and now - os.path.getctime(dir_path) > max_age_seconds:
            # best effort: a failed cleanup must not fail the current analysis
            shutil.rmtree(dir_path, ignore_errors=True)


def gradio_analize_image(images, scale):
    """
    Validate the input, run the model and prepare everything shown in the results section

    Model accepts following parameters:
    :param images: list of uploaded files to be processed, in png, tiff or jpg format
    :param scale: float, value of ratio_input ("pixel per nm")

    Model returns the following objects:
    - detection_visualizations: images with masks to be displayed as gallery and served
      to download as zip of images
    - grana_data: dataframe with measurements for each granum to be served to download
      as a csv file
    - images_grana: images with single grana to be served to download as zip of images
    - aggregated_data: aggregated measurements for all images to be displayed as table
      and served to download as csv

    :return: list of values for the outputs wired to this handler, in that order
    :raises gr.Error: when no image, too many images (demo), no valid scale or an
        unsupported file type has been provided
    """
    # validate that at least one image has been uploaded
    if not images:
        raise gr.Error("Please upload at least one image")

    # on demo instance, we limit the number of images to 5
    if DEMO and len(images) > _DEMO_MAX_IMAGES:
        raise gr.Error("In demo version it is possible to analyze up to 5 images.")

    # validate that scale has been provided correctly (a ratio must be positive)
    if scale is None or scale <= 0:
        raise gr.Error("Please provide scale. Use dot as decimal separator")

    # validate that all images are png, tiff or jpg
    for image in images:
        if not image.name.lower().endswith(_SUPPORTED_EXTENSIONS):
            raise gr.Error("Only png, tiff, jpg and jpeg images are supported")

    # clean up results of previous runs that are older than 1 hour
    _remove_stale_results()

    # create a directory for results
    results_dir_name = "results_{uuid}".format(uuid=uuid.uuid4().hex)
    os.makedirs(results_dir_name)
    os.makedirs(f"{results_dir_name}/to_zip")

    # model takes a dict of images keyed by file name
    images_dict = {
        os.path.basename(image.name): Image.open(image.name) for image in images
    }

    # model works here
    (
        detection_visualizations_dict,
        grana_data,
        images_grana_dict,
        aggregated_data,
    ) = ga.predict(images_dict, scale)
    detection_visualizations = list(detection_visualizations_dict.values())
    images_grana = list(images_grana_dict.values())

    # rearrange aggregated data to be displayed as table
    aggregated_dict = transform_aggregated_results_table(aggregated_data)
    aggregated_df_transposed = pd.DataFrame.from_dict(aggregated_dict)

    # rename columns in full results
    grana_data = rename_columns_in_results_table(grana_data)

    # save files returned by model to disk so they can be retrieved for downloading
    download_file_path = prepare_files_for_download(
        results_dir_name,
        grana_data,
        aggregated_df_transposed,
        detection_visualizations_dict,
        images_grana_dict,
    )

    # generate plots: (column, y label, title), in the order of the plot outputs
    plot_specs = [
        ("area [nm^2]", "Granum area [nm^2]", "Grana areas from all uploaded images"),
        (
            "perimeter [nm]",
            "Granum perimeter [nm]",
            "Grana perimeters from all uploaded images",
        ),
        ("GSI", "GSI", "GSI from all uploaded images"),
        (
            "diameter [nm]",
            "Granum diameter [nm]",
            "Grana diameters from all uploaded images",
        ),
        ("height [nm]", "Granum height [nm]", "Grana heights from all uploaded images"),
        ("SRD [nm]", "SRD [nm]", "SRD from all uploaded images"),
    ]
    figures = [
        draw_violin_plot(grana_data[column].dropna(), ylabel, title)
        for column, ylabel, title in plot_specs
    ]

    return [
        gr.Row(visible=True),
        gr.Row(visible=True),
        download_file_path,
        detection_visualizations,
        aggregated_df_transposed,
        *figures,
        images_grana,
        grana_data,
    ]


# NOTE(review): in this copy of the file the HTML markup (and the inline SVG) inside
# the string literals below appears to have been stripped; the remaining text is
# reproduced as found, with the embedded line breaks written as "\n". Restore the
# tags from the original source. The nesting of the layout containers was
# reconstructed from the order of the components - confirm against the original.
with gr.Blocks(css=load_css(), theme=theme) as demo:
    svg = """ """

    gr.HTML(f"\nGRANA\n")

    with gr.Row(elem_classes="input-row"):  # input
        with gr.Column():
            gr.HTML(
                "\n\n1. Choose images to upload. All the images need to be of the "
                "same scale and experimental variant.\n\n"
            )
            img_input = gr.File(file_count="multiple")

            gr.HTML("\n\n2. Set the scale of the images for the measurements.\n\n")
            with gr.Row():
                with gr.Column():
                    gr.HTML("Either provide pixel per nanometer ratio...")
                    ratio_input = gr.Number(
                        label="pixel per nm", precision=3, step=0.001
                    )

                with gr.Column():
                    gr.HTML("...or length of the scale bar in pixels and nanometers.")
                    pixels_input = gr.Number(label="Length in pixels")
                    nano_input = gr.Number(label="Length in nanometers")

                    # changing either length recomputes the ratio field
                    pixels_input.change(
                        calc_ratio,
                        inputs=[pixels_input, nano_input],
                        outputs=ratio_input,
                    )
                    nano_input.change(
                        calc_ratio,
                        inputs=[pixels_input, nano_input],
                        outputs=ratio_input,
                    )

            with gr.Row():
                clear_btn = gr.ClearButton(img_input, "Clear")
                submit_btn = gr.Button("Submit", variant="primary")

    with gr.Row(visible=False) as loading_row:
        with gr.Column():
            gr.HTML("\nImages are being processed. This may take a while...\n")

    with gr.Row(visible=False) as output_row:
        with gr.Column():
            gr.HTML(
                "\nResults\n"
                "\n\nFull results are a zip file containing:\n\n"
                "\n\n"
                ""
                ""
                ""
                "\n\nNote that GRANA only stores the result files for 1 hour.\n\n",
                elem_classes="input-row",
            )
            with gr.Row(elem_classes="input-row"):
                download_file_out = gr.DownloadButton(
                    label="Download results",
                    variant="primary",
                    elem_classes="margin-bottom",
                )
            with gr.Row():
                gr.HTML(
                    "\n\nAnnotated images\n\n"
                    "Gallery of uploaded images with masks of recognized grana structures. "
                    "Each granum mask is "
                    "labeled with its number. Note that only fully visible grana in the image are masked."
                )
            with gr.Row(elem_classes="margin-bottom"):
                gallery_out = gr.Gallery(
                    columns=4,
                    rows=2,
                    object_fit="contain",
                    label="Detection visualizations",
                    show_download_button=False,
                )

            with gr.Row(elem_classes="input-row"):
                gr.HTML("\n\nAggregated results for all uploaded images\n\n")
            with gr.Row(elem_classes=["input-row", "margin-bottom"]):
                table_out = gr.Dataframe(label="Aggregated data")

            with gr.Row():
                gr.HTML(
                    "\n\nViolin graphs\n\n"
                    "These graphs present aggregated results for selected structural parameters. "
                    "The graph for each parameter is only generated if three or more values are available. "
                    "Each graph "
                    "displays individual data points, a box plot indicating the first and third quartiles, whiskers "
                    "marking the standard deviation (SD), the median value (horizontal line on the box plot), "
                    "the mean value (red dot), and a density plot where the width represents the frequency."
                )
            with gr.Row():
                area_plot_out = gr.Plot(label="Area")
                perimeter_plot_out = gr.Plot(label="Perimeter")
                gsi_plot_out = gr.Plot(label="GSI")
            with gr.Row(elem_classes="margin-bottom"):
                diameter_plot_out = gr.Plot(label="Diameter")
                height_plot_out = gr.Plot(label="Height")
                srd_plot_out = gr.Plot(label="SRD")

            with gr.Row():
                gr.HTML("\n\nRecognized and rotated grana structures\n\n")
            with gr.Row(elem_classes="margin-bottom"):
                gallery_single_grana_out = gr.Gallery(
                    columns=4,
                    rows=2,
                    object_fit="contain",
                    label="Single grana images",
                    show_download_button=False,
                )

            with gr.Row():
                gr.HTML(
                    "\n\nFull results\n\n"
                    "Note that structural parameters other than area and perimeter are only calculated for the grana "
                    "whose direction and/or SRD could be estimated."
                )
            with gr.Row():
                table_full_out = gr.Dataframe(label="Full measurements data")

    # first listener: lock the buttons and show the loading row
    submit_btn.click(
        show_info_on_submit,
        inputs=[submit_btn],
        outputs=[submit_btn, clear_btn, loading_row, output_row],
    )

    # second listener: run the analysis, then unlock the buttons again
    submit_btn.click(
        fn=gradio_analize_image,
        inputs=[
            img_input,
            ratio_input,
        ],
        outputs=[
            loading_row,
            output_row,
            download_file_out,
            gallery_out,
            table_out,
            area_plot_out,
            perimeter_plot_out,
            gsi_plot_out,
            diameter_plot_out,
            height_plot_out,
            srd_plot_out,
            gallery_single_grana_out,
            table_full_out,
        ],
    ).then(fn=enable_submit, inputs=[], outputs=[submit_btn, clear_btn, loading_row])

demo.launch(
    share=False, debug=True, server_name="0.0.0.0", allowed_paths=["images/logo.svg"]
)