Dan Biagini commited on
Commit
0200926
·
1 Parent(s): ddb9a2a

add try it inference for hockey breeds

Browse files
Files changed (2) hide show
  1. requirements-cpu.txt +91 -0
  2. src/Hockey_Breeds.py +47 -5
requirements-cpu.txt ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ altair==5.4.1
2
+ annotated-types==0.7.0
3
+ attrs==24.2.0
4
+ blinker==1.8.2
5
+ blis==0.7.11
6
+ cachetools==5.5.0
7
+ catalogue==2.0.10
8
+ certifi==2024.8.30
9
+ charset-normalizer==3.3.2
10
+ click==8.1.7
11
+ cloudpathlib==0.19.0
12
+ confection==0.1.5
13
+ contourpy==1.3.0
14
+ cycler==0.12.1
15
+ cymem==2.0.8
16
+ fastai==2.7.17
17
+ fastcore==1.7.4
18
+ fastdownload==0.0.7
19
+ fastprogress==1.0.3
20
+ filelock==3.13.1
21
+ fonttools==4.53.1
22
+ fsspec==2024.2.0
23
+ gitdb==4.0.11
24
+ GitPython==3.1.43
25
+ huggingface-hub==0.24.6
26
+ idna==3.8
27
+ Jinja2==3.1.4
28
+ joblib==1.4.2
29
+ jsonschema==4.23.0
30
+ jsonschema-specifications==2023.12.1
31
+ kiwisolver==1.4.7
32
+ langcodes==3.4.0
33
+ language_data==1.2.0
34
+ marisa-trie==1.2.0
35
+ markdown-it-py==3.0.0
36
+ MarkupSafe==2.1.5
37
+ matplotlib==3.9.2
38
+ mdurl==0.1.2
39
+ mpmath==1.3.0
40
+ murmurhash==1.0.10
41
+ narwhals==1.6.0
42
+ networkx==3.2.1
43
+ numpy==1.26.4
44
+ packaging==24.1
45
+ pandas==2.2.2
46
+ pillow==10.4.0
47
+ preshed==3.0.9
48
+ protobuf==5.28.0
49
+ pyarrow==17.0.0
50
+ pydantic==2.9.0
51
+ pydantic_core==2.23.2
52
+ pydeck==0.9.1
53
+ Pygments==2.18.0
54
+ pyparsing==3.1.4
55
+ python-dateutil==2.9.0.post0
56
+ pytz==2024.1
57
+ PyYAML==6.0.2
58
+ referencing==0.35.1
59
+ requests==2.32.3
60
+ rich==13.8.0
61
+ rpds-py==0.20.0
62
+ scikit-learn==1.5.1
63
+ scipy==1.14.1
64
+ shellingham==1.5.4
65
+ six==1.16.0
66
+ smart-open==7.0.4
67
+ smmap==5.0.1
68
+ spacy==3.7.6
69
+ spacy-legacy==3.0.12
70
+ spacy-loggers==1.0.5
71
+ srsly==2.4.8
72
+ streamlit==1.38.0
73
+ streamlit-image-select==0.6.0
74
+ sympy==1.12
75
+ tenacity==8.5.0
76
+ thinc==8.2.5
77
+ threadpoolctl==3.5.0
78
+ toml==0.10.2
79
+ torch==2.4.1+cpu
80
+ torchaudio==2.4.1+cpu
81
+ torchvision==0.19.1+cpu
82
+ tornado==6.4.1
83
+ tqdm==4.66.5
84
+ typer==0.12.5
85
+ typing_extensions==4.12.2
86
+ tzdata==2024.1
87
+ urllib3==2.2.2
88
+ wasabi==1.1.3
89
+ watchdog==4.0.2
90
+ weasel==0.4.1
91
+ wrapt==1.16.0
src/Hockey_Breeds.py CHANGED
@@ -1,7 +1,35 @@
1
  import streamlit as st
2
  from streamlit_image_select import image_select
 
3
 
 
4
  import logging
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
  st.set_page_config(page_title='Hockey Breeds', layout="wide",
7
  page_icon=":frame_with_picture:")
@@ -22,13 +50,27 @@ desc = '''This "Hockey Breeds" model was built using 50 hockey related images fo
22
  The total training time for this was approximately 5 minutes running on a low-end GPU. It’s impressive how accurate this quick / small model can be!'''
23
 
24
  st.markdown(desc)
25
-
26
  st.subheader("Validation Results")
27
  st.markdown('Validation of the model\'s performance was done using 26 images not included in the training set. The model performed fairly well against the validation dataset, with only 1 misclassified image.')
28
- st.image("src/images/samples/confusion_matrix.png", caption="Confusion Matrix for Hockey Breeds ")
 
 
 
 
29
 
30
- st.subheader("Try it out")
31
 
32
- # unzip the sample images
33
 
34
- img = image_select(label="Select an image and hockey breeds will guess a label", images=["src/images/"])
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
  from streamlit_image_select import image_select
3
+ import zip_files
4
 
5
+ import random
6
  import logging
7
+ from huggingface_hub import from_pretrained_fastai
8
+
9
+ @st.cache_resource
10
+ def get_model():
11
+ repo_id = "danbiagini/hockey_breeds"
12
+ return from_pretrained_fastai(repo_id)
13
+
14
+ def classify_image(learn, img):
15
+ categories = ('Hockey Goalie', 'Hockey Player', "Hockey Referee")
16
+ pred,idx,prob = learn.predict(img)
17
+ return dict(zip(categories, map(float, prob)))
18
+
19
+ def reroll_sample_images():
20
+ # unzip the sample images
21
+ players = zip_files.extract_files_from_zip("src/images/samples/player-samples.zip")
22
+ goalies = zip_files.extract_files_from_zip("src/images/samples/goalie-samples.zip")
23
+ referees = zip_files.extract_files_from_zip("src/images/samples/referee-samples.zip")
24
+
25
+ #randomize a single file from players, goalies and referee for samples
26
+ st.session_state.sample = dict()
27
+ st.session_state.sample["player"] = players[list(players.keys())[random.randint(0, len(players) - 1)]]
28
+ st.session_state.sample["goalie"] = goalies[list(goalies.keys())[random.randint(0, len(goalies) - 1)]]
29
+ st.session_state.sample["referee"] = referees[list(referees.keys())[random.randint(0, len(referees) - 1)]]
30
+
31
+ if 'sample' not in st.session_state:
32
+ reroll_sample_images()
33
 
34
  st.set_page_config(page_title='Hockey Breeds', layout="wide",
35
  page_icon=":frame_with_picture:")
 
50
  The total training time for this was approximately 5 minutes running on a low-end GPU. It’s impressive how accurate this quick / small model can be!'''
51
 
52
  st.markdown(desc)
53
+ st.image("src/images/samples/sampl_batch.png")
54
  st.subheader("Validation Results")
55
  st.markdown('Validation of the model\'s performance was done using 26 images not included in the training set. The model performed fairly well against the validation dataset, with only 1 misclassified image.')
56
+ st.image("src/images/artifacts/confusion_matrix.png", caption="Confusion Matrix for Hockey Breeds ")
57
+
58
+ st.subheader("Try It Out")
59
+
60
+ img = image_select(label="Select an image and hockey breeds will guess a label", images=list(st.session_state.sample.values()))
61
 
62
+ st.button("Re-roll Samples", on_click=reroll_sample_images)
63
 
64
+ model = get_model()
65
 
66
+ if img:
67
+ res = classify_image(model, img)
68
+ # Sort the dictionary items by value in descending order
69
+ max = 0
70
+ max_label = ""
71
+ for k,v in res.items():
72
+ prob = round(v*100, 2)
73
+ if prob > max:
74
+ max = prob
75
+ max_label = k
76
+ st.metric(label=max_label, value=max)