Spaces:
Runtime error
Runtime error
MikeTrizna
commited on
Commit
โข
4c907a8
1
Parent(s):
b201d17
Initial commit of app. Directly copied from miketrizna/amazonian_fish_classifier
Browse files- .streamlit/config.toml +2 -0
- README.md +6 -2
- app.py +116 -0
- models/fish_classification_model.pkl +3 -0
- models/fish_mask_model.pth +3 -0
- requirements.txt +6 -0
- test_fish.jpg +0 -0
.streamlit/config.toml
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
[client]
|
2 |
+
showErrorDetails = false
|
README.md
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
---
|
2 |
title: Amazonian Fish Classifier
|
3 |
-
emoji:
|
4 |
colorFrom: green
|
5 |
colorTo: pink
|
6 |
sdk: streamlit
|
@@ -10,4 +10,8 @@ pinned: false
|
|
10 |
license: mit
|
11 |
---
|
12 |
|
13 |
-
|
|
|
|
|
|
|
|
|
|
1 |
---
|
2 |
title: Amazonian Fish Classifier
|
3 |
+
emoji: ๐
|
4 |
colorFrom: green
|
5 |
colorTo: pink
|
6 |
sdk: streamlit
|
|
|
10 |
license: mit
|
11 |
---
|
12 |
|
13 |
+
This is a demonstration app of the two machine learning models described in the paper:
|
14 |
+
|
15 |
+
> Robillard, A., Trizna, M. G., Ruiz-Tafur, K., Panduro, E. D., de Santana, C. D., White, A. E., Dikow, R. B., Deichmann, J. 2023. Application of a Deep Learning Image Classifier for Identification of Amazonian Fishes. *Ecology and Evolution* [https://doi.org/10.1002/ece3.9987](https://doi.org/10.1002/ece3.9987)
|
16 |
+
|
17 |
+
The models weights and image data are available on FigShare at [https://doi.org/10.25573/data.c.5761097.v1](https://doi.org/10.25573/data.c.5761097.v1)
|
app.py
ADDED
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import os
|
3 |
+
import urllib
|
4 |
+
import fastai.vision.all as fai_vision
|
5 |
+
import numpy as np
|
6 |
+
from pathlib import Path
|
7 |
+
import pathlib
|
8 |
+
from PIL import Image
|
9 |
+
import platform
|
10 |
+
import altair as alt
|
11 |
+
import pandas as pd
|
12 |
+
import frontmatter
|
13 |
+
|
14 |
+
def main():
|
15 |
+
st.title('Fish Masker and Classifier')
|
16 |
+
|
17 |
+
with open('README.md') as readme_file:
|
18 |
+
readme = frontmatter.load(readme_file)
|
19 |
+
st.markdown(readme.content)
|
20 |
+
|
21 |
+
data_loader, segmenter = load_unet_model()
|
22 |
+
classification_model = load_classification_model()
|
23 |
+
|
24 |
+
st.markdown("## Instructions")
|
25 |
+
st.markdown("Upload an Amazonian fish photo for masking.")
|
26 |
+
uploaded_image = st.file_uploader("", IMAGE_TYPES)
|
27 |
+
if uploaded_image:
|
28 |
+
image_data = uploaded_image.read()
|
29 |
+
st.markdown('## Original image')
|
30 |
+
st.image(image_data, use_column_width=True)
|
31 |
+
|
32 |
+
original_pil = Image.open(uploaded_image)
|
33 |
+
|
34 |
+
original_pil.save('original.jpg')
|
35 |
+
|
36 |
+
single_file = [Path('original.jpg')]
|
37 |
+
single_pil = Image.open(single_file[0])
|
38 |
+
input_dl = segmenter.dls.test_dl(single_file)
|
39 |
+
masks, _ = segmenter.get_preds(dl=input_dl)
|
40 |
+
masked_pil, percentage_fish = mask_fish_pil(single_pil, masks[0])
|
41 |
+
|
42 |
+
st.markdown('## Masked image')
|
43 |
+
st.markdown(f'**{percentage_fish:.1f}%** of pixels were labeled as "fish"')
|
44 |
+
st.image(masked_pil, use_column_width=True)
|
45 |
+
|
46 |
+
masked_pil.save('masked.jpg')
|
47 |
+
|
48 |
+
st.markdown('## Classification')
|
49 |
+
|
50 |
+
prediction = classification_model.predict('masked.jpg')
|
51 |
+
pred_chart = predictions_to_chart(prediction, classes = classification_model.dls.vocab)
|
52 |
+
st.altair_chart(pred_chart, use_container_width=True)
|
53 |
+
|
54 |
+
|
55 |
+
def mask_fish_pil(unmasked_fish, fastai_mask):
|
56 |
+
unmasked_np = np.array(unmasked_fish)
|
57 |
+
np_mask = fastai_mask.argmax(dim=0).numpy()
|
58 |
+
total_pixels = np_mask.size
|
59 |
+
fish_pixels = np.count_nonzero(np_mask)
|
60 |
+
percentage_fish = (fish_pixels / total_pixels) * 100
|
61 |
+
np_mask = (255 / np_mask.max() * (np_mask - np_mask.min())).astype(np.uint8)
|
62 |
+
np_mask = np.array(Image.fromarray(np_mask).resize(unmasked_np.shape[1::-1], Image.BILINEAR))
|
63 |
+
np_mask = np_mask.reshape(*np_mask.shape, 1) / 255
|
64 |
+
masked_fish_np = (unmasked_np * np_mask).astype(np.uint8)
|
65 |
+
masked_fish_pil = Image.fromarray(masked_fish_np)
|
66 |
+
return masked_fish_pil, percentage_fish
|
67 |
+
|
68 |
+
def predictions_to_chart(prediction, classes):
|
69 |
+
pred_rows = []
|
70 |
+
for i, conf in enumerate(list(prediction[2])):
|
71 |
+
pred_row = {'class': classes[i],
|
72 |
+
'probability': round(float(conf) * 100,2)}
|
73 |
+
pred_rows.append(pred_row)
|
74 |
+
pred_df = pd.DataFrame(pred_rows)
|
75 |
+
pred_df.head()
|
76 |
+
top_probs = pred_df.sort_values('probability', ascending=False).head(4)
|
77 |
+
chart = (
|
78 |
+
alt.Chart(top_probs)
|
79 |
+
.mark_bar()
|
80 |
+
.encode(
|
81 |
+
x=alt.X("probability:Q", scale=alt.Scale(domain=(0, 100))),
|
82 |
+
y=alt.Y("class:N",
|
83 |
+
sort=alt.EncodingSortField(field="probability", order="descending"))
|
84 |
+
)
|
85 |
+
)
|
86 |
+
return chart
|
87 |
+
|
88 |
+
@st.cache(allow_output_mutation=True)
|
89 |
+
def load_unet_model():
|
90 |
+
data_loader = fai_vision.SegmentationDataLoaders.from_label_func(
|
91 |
+
path = Path("."),
|
92 |
+
bs = 1,
|
93 |
+
fnames = [Path('test_fish.jpg')],
|
94 |
+
label_func = lambda x: x,
|
95 |
+
codes = np.array(["Photo", "Masks"], dtype=str),
|
96 |
+
item_tfms = [fai_vision.Resize(256, method = 'squish'),],
|
97 |
+
batch_tfms = [fai_vision.IntToFloatTensor(div_mask = 255)],
|
98 |
+
valid_pct = 0.2, num_workers = 0)
|
99 |
+
segmenter = fai_vision.unet_learner(data_loader, fai_vision.resnet34)
|
100 |
+
segmenter.load('fish_mask_model')
|
101 |
+
return data_loader, segmenter
|
102 |
+
|
103 |
+
@st.cache(allow_output_mutation=True)
|
104 |
+
def load_classification_model():
|
105 |
+
plt = platform.system()
|
106 |
+
|
107 |
+
if plt == 'Linux' or plt == 'Darwin':
|
108 |
+
pathlib.WindowsPath = pathlib.PosixPath
|
109 |
+
inf_model = fai_vision.load_learner('models/fish_classification_model.pkl', cpu=True)
|
110 |
+
|
111 |
+
return inf_model
|
112 |
+
|
113 |
+
IMAGE_TYPES = ["png", "jpg","jpeg"]
|
114 |
+
|
115 |
+
if __name__ == "__main__":
|
116 |
+
main()
|
models/fish_classification_model.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:2ac16550590dd60da201ce13e2f1b057d5343ef490db8663c463f8bbefef610e
|
3 |
+
size 179319095
|
models/fish_mask_model.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:29b8afc516eb9f19e99dc53e924839a7157ac241d13f0945aec4717574c7908a
|
3 |
+
size 494929527
|
requirements.txt
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
streamlit==0.89
|
2 |
+
fastai==2.2
|
3 |
+
protobuf==3.20
|
4 |
+
altair
|
5 |
+
pandas
|
6 |
+
frontmatter
|
test_fish.jpg
ADDED