Raphaël Bournhonesque commited on
Commit
062bda3
·
1 Parent(s): a894b74

improve app

Browse files
Files changed (1) hide show
  1. app.py +16 -10
app.py CHANGED
@@ -1,6 +1,6 @@
1
  import copy
2
  import enum
3
- import os
4
  from typing import List, Optional
5
 
6
  import requests
@@ -44,29 +44,35 @@ def display_predictions(
44
  model_names: List[str],
45
  threshold: Optional[float] = None,
46
  ):
47
- debug_showed = False
48
  for model_name in model_names:
49
  response = get_predictions(barcode, model_name, threshold)
50
  response = copy.deepcopy(response)
51
- if "debug" in response:
52
- if not debug_showed:
53
- debug_showed = True
54
- st.write(response["debug"])
55
  response.pop("debug")
56
- st.write(f"** {model_name} **")
57
- st.write(response)
 
 
 
 
58
 
59
 
60
 
61
  st.sidebar.title("Category Prediction Demo")
 
 
 
62
  barcode = st.sidebar.text_input(
63
- "Product barcode"
64
  )
65
  threshold = st.sidebar.number_input("Threshold", format="%f") or None
66
  model_names = st.multiselect(
67
  "Name of the model",
68
  [x.name for x in NeuralCategoryClassifierModel],
69
- default=NeuralCategoryClassifierModel.keras_sota_3_0.name,
70
  )
71
 
72
  if barcode:
 
1
  import copy
2
  import enum
3
+ import pandas as pd
4
  from typing import List, Optional
5
 
6
  import requests
 
44
  model_names: List[str],
45
  threshold: Optional[float] = None,
46
  ):
47
+ debug = None
48
  for model_name in model_names:
49
  response = get_predictions(barcode, model_name, threshold)
50
  response = copy.deepcopy(response)
51
+ if model_name != NeuralCategoryClassifierModel.keras_2_0.name and "debug" in response:
52
+ if debug is None:
53
+ debug = response["debug"]
 
54
  response.pop("debug")
55
+ st.markdown(f"**{model_name}**")
56
+ st.write(pd.DataFrame(response["predictions"]))
57
+
58
+ if debug is not None:
59
+ st.markdown("**v3 debug information**")
60
+ st.write(debug)
61
 
62
 
63
 
64
  st.sidebar.title("Category Prediction Demo")
65
+ query_params = st.experimental_get_query_params()
66
+
67
+ default_barcode = query_params["barcode"][0] if "barcode" in query_params else ""
68
  barcode = st.sidebar.text_input(
69
+ "Product barcode", default_barcode
70
  )
71
  threshold = st.sidebar.number_input("Threshold", format="%f") or None
72
  model_names = st.multiselect(
73
  "Name of the model",
74
  [x.name for x in NeuralCategoryClassifierModel],
75
+ default=[x.name for x in NeuralCategoryClassifierModel],
76
  )
77
 
78
  if barcode: