Raphaël Bournhonesque commited on
Commit
4123d5a
·
1 Parent(s): a332564

update demo

Browse files
Files changed (2) hide show
  1. README.md +1 -1
  2. app.py +12 -45
README.md CHANGED
@@ -4,7 +4,7 @@ emoji: 👀
4
  colorFrom: purple
5
  colorTo: red
6
  sdk: streamlit
7
- sdk_version: 1.17.0
8
  app_file: app.py
9
  pinned: false
10
  ---
 
4
  colorFrom: purple
5
  colorTo: red
6
  sdk: streamlit
7
+ sdk_version: 1.25.0
8
  app_file: app.py
9
  pinned: false
10
  ---
app.py CHANGED
@@ -1,35 +1,13 @@
1
  import copy
2
- import enum
3
  import pandas as pd
4
- from typing import List, Optional
5
 
6
  import requests
7
  import streamlit as st
8
 
9
 
10
- MODEL_DESCRIPTION = """
11
- **keras_2_0**: Current production model\n
12
- **keras_image_embeddings_3_0**: same as `keras_300_epochs_3_0` but with image embedding as input\n
13
- **keras_300_epochs_3_0**: trained on 300 epochs with product name, ingredients, OCR-extracted ingredients and nutriments as input\n
14
- **keras_ingredient_ocr_3_0**: same as `keras_sota_3_0`, but trained on less epochs\n
15
- **keras_baseline_3_0**: model trained with product name, ingredients and nutriments as input\n
16
- **keras_original_3_0**: same inputs as **keras_2_0** (product name + ingredients), but retrained\n
17
- **keras_product_name_only_3_0**: model with only product name as input
18
- """
19
-
20
  http_session = requests.Session()
21
 
22
- @enum.unique
23
- class NeuralCategoryClassifierModel(enum.Enum):
24
- keras_2_0 = "keras-2.0"
25
- keras_image_embeddings_3_0 = "keras-image-embeddings-3-0"
26
- keras_300_epochs_3_0 = "keras-300-epochs-3-0"
27
- keras_ingredient_ocr_3_0 = "keras-ingredient-ocr-3.0"
28
- keras_baseline_3_0 = "keras-baseline-3.0"
29
- keras_original_3_0 = "keras-original-3.0"
30
- keras_product_name_only_3_0 = "keras-product-name-only-3.0"
31
-
32
-
33
  LOCAL_DB = False
34
 
35
  if LOCAL_DB:
@@ -40,9 +18,9 @@ else:
40
  PREDICTION_URL = ROBOTOFF_BASE_URL + "/predict/category"
41
 
42
 
43
- @st.cache
44
- def get_predictions(barcode: str, model_name: str, threshold: Optional[float] = None):
45
- data = {"barcode": barcode, "predictors": ["neural"], "neural_model_name": model_name}
46
  if threshold is not None:
47
  data["threshold"] = threshold
48
 
@@ -52,22 +30,19 @@ def get_predictions(barcode: str, model_name: str, threshold: Optional[float] =
52
 
53
  def display_predictions(
54
  barcode: str,
55
- model_names: List[str],
56
  threshold: Optional[float] = None,
57
  ):
58
  debug = None
59
- for model_name in model_names:
60
- response = get_predictions(barcode, model_name, threshold)
61
- response = copy.deepcopy(response)
62
- if model_name != NeuralCategoryClassifierModel.keras_2_0.name and "debug" in response:
63
- if debug is None:
64
- debug = response["debug"]
65
- response.pop("debug")
66
- st.markdown(f"**{model_name}**")
67
- st.write(pd.DataFrame(response["predictions"]))
68
 
69
  if debug is not None:
70
- st.markdown("**v3 debug information**")
71
  st.write(debug)
72
 
73
 
@@ -81,17 +56,9 @@ barcode = st.sidebar.text_input(
81
  )
82
  threshold = st.sidebar.number_input("Threshold", format="%f", value=0.5) or None
83
 
84
- st.sidebar.write("---\n# Model description\n" + MODEL_DESCRIPTION)
85
- model_names = st.multiselect(
86
- "Name of the model",
87
- [x.name for x in NeuralCategoryClassifierModel],
88
- default=[x.name for x in NeuralCategoryClassifierModel],
89
- )
90
-
91
  if barcode:
92
  barcode = barcode.strip()
93
  display_predictions(
94
  barcode=barcode,
95
  threshold=threshold,
96
- model_names=model_names,
97
  )
 
1
  import copy
 
2
  import pandas as pd
3
+ from typing import Optional
4
 
5
  import requests
6
  import streamlit as st
7
 
8
 
 
 
 
 
 
 
 
 
 
 
9
  http_session = requests.Session()
10
 
 
 
 
 
 
 
 
 
 
 
 
11
  LOCAL_DB = False
12
 
13
  if LOCAL_DB:
 
18
  PREDICTION_URL = ROBOTOFF_BASE_URL + "/predict/category"
19
 
20
 
21
+ @st.cache_data
22
+ def get_predictions(barcode: str, threshold: Optional[float] = None):
23
+ data = {"barcode": barcode, "predictors": ["neural"]}
24
  if threshold is not None:
25
  data["threshold"] = threshold
26
 
 
30
 
31
  def display_predictions(
32
  barcode: str,
 
33
  threshold: Optional[float] = None,
34
  ):
35
  debug = None
36
+ response = get_predictions(barcode, threshold)
37
+ response = copy.deepcopy(response)
38
+ if "debug" in response:
39
+ if debug is None:
40
+ debug = response["debug"]
41
+ response.pop("debug")
42
+ st.write(pd.DataFrame(response["predictions"]))
 
 
43
 
44
  if debug is not None:
45
+ st.markdown("**Debug information**")
46
  st.write(debug)
47
 
48
 
 
56
  )
57
  threshold = st.sidebar.number_input("Threshold", format="%f", value=0.5) or None
58
 
 
 
 
 
 
 
 
59
  if barcode:
60
  barcode = barcode.strip()
61
  display_predictions(
62
  barcode=barcode,
63
  threshold=threshold,
 
64
  )