"""Streamlit app that masks the fish in an uploaded photo and classifies it.

Pipeline: a fastai U-Net predicts a per-pixel class map for the photo, pixels
not labelled as fish are blacked out, and a fastai classifier predicts the
class of the masked image. The top predictions are shown as an Altair chart.
"""
import io
import os
import pathlib
import platform
import tempfile
import urllib
from pathlib import Path

import altair as alt
import fastai.vision.all as fai_vision
import frontmatter
import numpy as np
import pandas as pd
import streamlit as st
from PIL import Image

# File extensions accepted by the uploader widget.
IMAGE_TYPES = ["png", "jpg", "jpeg"]


def main():
    """Render the app: README text, uploader, masked image and class chart."""
    st.title('Fish Masker and Classifier')

    with open('README.md', encoding='utf-8') as readme_file:
        readme = frontmatter.load(readme_file)
    st.markdown(readme.content)

    _data_loader, segmenter = load_unet_model()
    classification_model = load_classification_model()

    st.markdown("## Instructions")
    st.markdown("Upload an Amazonian fish photo for masking.")
    uploaded_image = st.file_uploader("", IMAGE_TYPES)
    if not uploaded_image:
        return

    image_data = uploaded_image.read()
    st.markdown('## Original image')
    st.image(image_data, use_column_width=True)

    # Use a per-request scratch directory: Streamlit serves every session from
    # one process, so fixed file names in the CWD would let concurrent
    # uploads overwrite each other's intermediate images.
    with tempfile.TemporaryDirectory() as work_dir:
        original_path = Path(work_dir) / 'original.jpg'
        masked_path = Path(work_dir) / 'masked.jpg'

        # Decode from the bytes already read (the upload stream is consumed).
        # JPEG cannot store alpha or palettes, so RGBA/P-mode PNG uploads
        # must be converted or save() raises.
        original_pil = Image.open(io.BytesIO(image_data))
        original_pil.convert('RGB').save(original_path)

        # Re-open the saved JPEG so the mask is applied to the same pixels the
        # segmenter reads; load() pulls them in now so the file handle is
        # released before the scratch directory is deleted.
        single_pil = Image.open(original_path)
        single_pil.load()

        input_dl = segmenter.dls.test_dl([original_path])
        masks, _ = segmenter.get_preds(dl=input_dl)
        masked_pil, percentage_fish = mask_fish_pil(single_pil, masks[0])

        st.markdown('## Masked image')
        st.markdown(f'**{percentage_fish:.1f}%** of pixels were labeled as "fish"')
        st.image(masked_pil, use_column_width=True)
        masked_pil.save(masked_path)

        st.markdown('## Classification')
        prediction = classification_model.predict(str(masked_path))

    pred_chart = predictions_to_chart(prediction, classes=classification_model.dls.vocab)
    st.altair_chart(pred_chart, use_container_width=True)


def mask_fish_pil(unmasked_fish, fastai_mask):
    """Black out every pixel the segmenter did not label as fish.

    Args:
        unmasked_fish: PIL image to mask (greyscale or multi-channel).
        fastai_mask: per-class score tensor for one image, as returned by
            ``get_preds``; class axis is dim 0 (presumably
            ``(n_classes, H, W)`` at the model's input size -- TODO confirm).

    Returns:
        ``(masked PIL image, percentage of mask pixels labelled as fish)``.
        The percentage is measured on the model-resolution mask, before it
        is resized to the image.
    """
    unmasked_np = np.array(unmasked_fish)

    # Most likely class per pixel; with codes ["Photo", "Masks"] any non-zero
    # class index counts as "fish".
    class_map = fastai_mask.argmax(dim=0).numpy()
    fish_pixels = np.count_nonzero(class_map)
    percentage_fish = (fish_pixels / class_map.size) * 100

    # Binary 0/255 mask derived from the same rule as the percentage. The
    # previous min/max rescaling divided by zero when no fish was found and
    # produced an all-black mask when every pixel was fish.
    mask_u8 = np.where(class_map != 0, 255, 0).astype(np.uint8)

    # PIL's resize takes (width, height); numpy shape is (height, width, ...).
    resized = Image.fromarray(mask_u8).resize(unmasked_np.shape[1::-1], Image.BILINEAR)
    weights = np.array(resized) / 255
    if unmasked_np.ndim == 3:
        # Broadcast the single-channel mask across colour channels.
        weights = weights[..., np.newaxis]

    masked_fish_np = (unmasked_np * weights).astype(np.uint8)
    return Image.fromarray(masked_fish_np), percentage_fish


def predictions_to_chart(prediction, classes, top_k=4):
    """Build a horizontal bar chart of the ``top_k`` most probable classes.

    Args:
        prediction: result of the classifier's ``predict``; its third element
            is read as per-class probabilities in [0, 1], ordered like
            ``classes``.
        classes: class labels indexed by prediction position.
        top_k: number of bars to show (default 4).

    Returns:
        An Altair chart with probabilities expressed as percentages (0-100).
    """
    pred_df = pd.DataFrame(
        [
            {'class': classes[i], 'probability': round(float(conf) * 100, 2)}
            for i, conf in enumerate(prediction[2])
        ],
        columns=['class', 'probability'],
    )
    top_probs = pred_df.sort_values('probability', ascending=False).head(top_k)
    chart = (
        alt.Chart(top_probs)
        .mark_bar()
        .encode(
            x=alt.X("probability:Q", scale=alt.Scale(domain=(0, 100))),
            y=alt.Y("class:N", sort=alt.EncodingSortField(field="probability", order="descending")),
        )
    )
    return chart


# NOTE(review): st.cache is deprecated in newer Streamlit releases in favour
# of st.cache_resource; kept here for compatibility with the pinned version.
@st.cache(allow_output_mutation=True)
def load_unet_model():
    """Build the U-Net segmenter and load the ``fish_mask_model`` weights.

    The DataLoaders are built from a single local file
    (``test_fish.jpg``) and use the image itself as its label. Presumably
    they only exist so ``unet_learner`` can infer the architecture before
    trained weights are loaded -- confirm before reusing them for training.

    Returns:
        ``(data_loader, segmenter)``.
    """
    data_loader = fai_vision.SegmentationDataLoaders.from_label_func(
        path=Path("."),
        bs=1,
        fnames=[Path('test_fish.jpg')],
        label_func=lambda x: x,
        codes=np.array(["Photo", "Masks"], dtype=str),
        item_tfms=[fai_vision.Resize(256, method='squish'),],
        batch_tfms=[fai_vision.IntToFloatTensor(div_mask=255)],
        valid_pct=0.2,
        num_workers=0,
    )
    segmenter = fai_vision.unet_learner(data_loader, fai_vision.resnet34)
    segmenter.load('fish_mask_model')
    return data_loader, segmenter


@st.cache(allow_output_mutation=True)
def load_classification_model():
    """Load the exported fish classifier onto the CPU."""
    # On Linux/macOS, alias WindowsPath to PosixPath while unpickling:
    # presumably the learner was exported on Windows, and its pickle
    # references WindowsPath, which cannot be instantiated on POSIX. The
    # alias is restored afterwards so it does not leak into the rest of the
    # process. (load_learner unpickles; only use trusted model files.)
    original_windows_path = pathlib.WindowsPath
    if platform.system() in ('Linux', 'Darwin'):
        pathlib.WindowsPath = pathlib.PosixPath
    try:
        inf_model = fai_vision.load_learner('models/fish_classification_model.pkl', cpu=True)
    finally:
        pathlib.WindowsPath = original_windows_path
    return inf_model


if __name__ == "__main__":
    main()