Spaces:
Runtime error
Runtime error
import streamlit as st | |
import os | |
import urllib | |
import fastai.vision.all as fai_vision | |
import numpy as np | |
from pathlib import Path | |
import pathlib | |
from PIL import Image | |
import platform | |
import altair as alt | |
import pandas as pd | |
import frontmatter | |
def main():
    """Run the Streamlit page: mask an uploaded fish photo, then classify it.

    The page shows the README body, loads the segmentation and classification
    models, and waits for an upload. Once an image is uploaded it is written
    to ``original.jpg``, segmented, masked with ``mask_fish_pil``, written to
    ``masked.jpg`` and classified; the top classes are drawn as a bar chart.

    Side effects: writes ``original.jpg`` and ``masked.jpg`` in the current
    working directory.
    """
    st.title('Fish Masker and Classifier')
    # Explicit UTF-8 so the README parses the same regardless of the
    # platform's default text encoding.
    with open('README.md', encoding='utf-8') as readme_file:
        readme = frontmatter.load(readme_file)
    st.markdown(readme.content)

    # Only the learner is needed here; the data loaders are reachable
    # through ``segmenter.dls``.
    _, segmenter = load_unet_model()
    classification_model = load_classification_model()

    st.markdown("## Instructions")
    st.markdown("Upload an Amazonian fish photo for masking.")
    uploaded_image = st.file_uploader("", IMAGE_TYPES)
    if not uploaded_image:
        return

    image_data = uploaded_image.read()
    st.markdown('## Original image')
    st.image(image_data, use_column_width=True)

    # PNG uploads may be RGBA / palette / grayscale-with-alpha, which the
    # JPEG writer rejects with OSError. Normalising to RGB also guarantees
    # the masking step receives a 3-channel image.
    original_pil = Image.open(uploaded_image).convert('RGB')
    original_pil.save('original.jpg')

    # Re-open the saved JPEG so the masked pixels come from exactly the file
    # that the segmenter reads.
    single_file = [Path('original.jpg')]
    single_pil = Image.open(single_file[0])
    input_dl = segmenter.dls.test_dl(single_file)
    masks, _ = segmenter.get_preds(dl=input_dl)
    masked_pil, percentage_fish = mask_fish_pil(single_pil, masks[0])

    st.markdown('## Masked image')
    st.markdown(f'**{percentage_fish:.1f}%** of pixels were labeled as "fish"')
    st.image(masked_pil, use_column_width=True)
    masked_pil.save('masked.jpg')

    st.markdown('## Classification')
    prediction = classification_model.predict('masked.jpg')
    pred_chart = predictions_to_chart(
        prediction, classes=classification_model.dls.vocab)
    st.altair_chart(pred_chart, use_container_width=True)
def mask_fish_pil(unmasked_fish, fastai_mask):
    """Black out the non-fish pixels of an image using a segmentation output.

    Args:
        unmasked_fish: PIL image to mask (grayscale or multi-channel).
        fastai_mask: Per-class score tensor for one image, with the class
            axis first; it is reduced with ``argmax(dim=0)``. Any pixel whose
            winning class index is non-zero is counted as "fish".

    Returns:
        A ``(masked_image, percentage_fish)`` tuple: the masked PIL image at
        the input's resolution, and the percentage (0-100) of mask pixels
        whose class index is non-zero.
    """
    unmasked_np = np.array(unmasked_fish)
    class_map = fastai_mask.argmax(dim=0).numpy()

    percentage_fish = (np.count_nonzero(class_map) / class_map.size) * 100

    # Scale class indices to 0-255 for PIL. Scaling by the peak only (rather
    # than min-max) keeps an all-fish prediction fully opaque instead of
    # collapsing it to zeros, and the explicit zero branch avoids dividing
    # by zero when no fish pixel was predicted.
    peak = class_map.max()
    if peak > 0:
        mask_u8 = (class_map * (255 / peak)).astype(np.uint8)
    else:
        mask_u8 = np.zeros(class_map.shape, dtype=np.uint8)

    # PIL sizes are (width, height): the first two array axes, reversed.
    target_size = unmasked_np.shape[1::-1]
    resized_mask = Image.fromarray(mask_u8).resize(target_size, Image.BILINEAR)
    weights = np.array(resized_mask) / 255

    # Add a trailing channel axis only when the image has one; a 2-D
    # grayscale array already lines up with the 2-D mask.
    if unmasked_np.ndim == 3:
        weights = weights[..., np.newaxis]

    masked_fish_np = (unmasked_np * weights).astype(np.uint8)
    return Image.fromarray(masked_fish_np), percentage_fish
def predictions_to_chart(prediction, classes):
    """Build a horizontal bar chart of the four most probable classes.

    Args:
        prediction: Indexable result whose element ``[2]`` iterates over
            per-class confidences as fractions (each is multiplied by 100
            here). Presumably the tuple returned by the learner's
            ``predict`` - confirm against the caller.
        classes: Sequence mapping a confidence's position to its class label.

    Returns:
        An Altair bar chart with probability (0-100) on the x axis and class
        name on the y axis, sorted by descending probability.
    """
    pred_rows = [
        {'class': classes[i], 'probability': round(float(conf) * 100, 2)}
        for i, conf in enumerate(prediction[2])
    ]
    # Explicit columns keep the frame sortable even when there are no rows.
    pred_df = pd.DataFrame(pred_rows, columns=['class', 'probability'])
    top_probs = pred_df.sort_values('probability', ascending=False).head(4)

    chart = (
        alt.Chart(top_probs)
        .mark_bar()
        .encode(
            x=alt.X("probability:Q", scale=alt.Scale(domain=(0, 100))),
            y=alt.Y(
                "class:N",
                sort=alt.EncodingSortField(
                    field="probability", order="descending"),
            ),
        )
    )
    return chart
def load_unet_model():
    """Create the segmentation learner and load its saved weights.

    A placeholder ``SegmentationDataLoaders`` is built around the single file
    ``test_fish.jpg`` so that a U-Net learner with a resnet34 backbone can be
    constructed; the weights named ``fish_mask_model`` are then loaded into it.

    Returns:
        A ``(data_loader, segmenter)`` tuple: the data loaders and the learner.
    """
    def _label_is_item(item):
        # The item doubles as its own label in this placeholder loader.
        return item

    mask_codes = np.array(["Photo", "Masks"], dtype=str)
    per_item_tfms = [fai_vision.Resize(256, method='squish')]
    per_batch_tfms = [fai_vision.IntToFloatTensor(div_mask=255)]

    dls = fai_vision.SegmentationDataLoaders.from_label_func(
        path=Path("."),
        bs=1,
        fnames=[Path('test_fish.jpg')],
        label_func=_label_is_item,
        codes=mask_codes,
        item_tfms=per_item_tfms,
        batch_tfms=per_batch_tfms,
        valid_pct=0.2,
        num_workers=0,
    )

    learner = fai_vision.unet_learner(dls, fai_vision.resnet34)
    learner.load('fish_mask_model')
    return dls, learner
def load_classification_model():
    """Load the exported fish classification learner on the CPU.

    On non-Windows hosts ``pathlib.WindowsPath`` is temporarily aliased to
    ``pathlib.PosixPath`` while the file is unpickled, then restored.
    Presumably the export contains Windows path objects, which cannot be
    instantiated elsewhere - confirm against how the model was exported.

    Returns:
        The learner loaded from ``models/fish_classification_model.pkl``.
    """
    model_file = 'models/fish_classification_model.pkl'

    # NOTE(security): load_learner unpickles the file; only load exports from
    # a trusted source.
    if platform.system() == 'Windows':
        return fai_vision.load_learner(model_file, cpu=True)

    # Scope the alias to the load and always undo it, so the rest of the
    # process keeps the real pathlib.WindowsPath.
    real_windows_path = pathlib.WindowsPath
    pathlib.WindowsPath = pathlib.PosixPath
    try:
        return fai_vision.load_learner(model_file, cpu=True)
    finally:
        pathlib.WindowsPath = real_windows_path
# File extensions passed to the upload widget in main().
IMAGE_TYPES = [
    "png",
    "jpg",
    "jpeg",
]

# Start the app only when this file is run as a script.
if __name__ == "__main__":
    main()