Raphaël Bournhonesque
Add new model
a332564
raw
history blame
3.22 kB
import copy
import enum
import pandas as pd
from typing import List, Optional
import requests
import streamlit as st
# Markdown rendered in the sidebar, one entry per NeuralCategoryClassifierModel
# member. Each "\n" escape followed by the literal line break produces a blank
# line, so every entry renders as its own Markdown paragraph.
MODEL_DESCRIPTION = """
**keras_2_0**: Current production model\n
**keras_image_embeddings_3_0**: same as `keras_300_epochs_3_0` but with image embedding as input\n
**keras_300_epochs_3_0**: trained on 300 epochs with product name, ingredients, OCR-extracted ingredients and nutriments as input\n
**keras_ingredient_ocr_3_0**: same as `keras_300_epochs_3_0`, but trained on less epochs\n
**keras_baseline_3_0**: model trained with product name, ingredients and nutriments as input\n
**keras_original_3_0**: same inputs as **keras_2_0** (product name + ingredients), but retrained\n
**keras_product_name_only_3_0**: model with only product name as input
"""
# Module-level requests session, presumably intended to be reused for calls to
# the Robotoff API — confirm whether any code path actually uses it.
http_session = requests.Session()
@enum.unique
class NeuralCategoryClassifierModel(enum.Enum):
    """Neural category-classifier models that can be selected in this demo.

    The UI lists member *names* (not values), and those names are what gets
    sent to Robotoff as ``neural_model_name``. Member declaration order is the
    order in which models appear in the model multiselect. ``@enum.unique``
    rejects duplicate values at class-creation time.
    """

    keras_2_0 = "keras-2.0"
    # NOTE(review): the next two values use "3-0" while the other v3 members use
    # "3.0"; the values are not read in the visible code — confirm which form
    # is canonical before relying on them.
    keras_image_embeddings_3_0 = "keras-image-embeddings-3-0"
    keras_300_epochs_3_0 = "keras-300-epochs-3-0"
    keras_ingredient_ocr_3_0 = "keras-ingredient-ocr-3.0"
    keras_baseline_3_0 = "keras-baseline-3.0"
    keras_original_3_0 = "keras-original-3.0"
    keras_product_name_only_3_0 = "keras-product-name-only-3.0"
# Set to True to target a Robotoff instance served on localhost:5500 instead of
# the public production API.
LOCAL_DB = False

ROBOTOFF_BASE_URL = (
    "http://localhost:5500/api/v1"
    if LOCAL_DB
    else "https://robotoff.openfoodfacts.org/api/v1"
)
# Endpoint queried by the category-prediction requests.
PREDICTION_URL = f"{ROBOTOFF_BASE_URL}/predict/category"
@st.cache
def get_predictions(
    barcode: str,
    model_name: str,
    threshold: Optional[float] = None,
    timeout: float = 30.0,
) -> dict:
    """Fetch neural category predictions for one product from Robotoff.

    The result is memoized by ``st.cache``, so callers should copy it before
    mutating it.

    Args:
        barcode: barcode of the product to classify.
        model_name: model identifier, sent as ``neural_model_name``.
        threshold: optional score threshold; when None the field is omitted
            from the request (presumably letting the server pick its default).
        timeout: seconds to wait for the API before giving up, so a stalled
            server cannot block the app indefinitely.

    Returns:
        The ``"neural"`` entry of the JSON response.

    Raises:
        requests.HTTPError: if the API answers with an error status.
        requests.Timeout: if no answer arrives within ``timeout`` seconds.
    """
    data = {
        "barcode": barcode,
        "predictors": ["neural"],
        "neural_model_name": model_name,
    }
    if threshold is not None:
        data["threshold"] = threshold
    r = requests.post(PREDICTION_URL, json=data, timeout=timeout)
    r.raise_for_status()
    return r.json()["neural"]
def display_predictions(
    barcode: str,
    model_names: List[str],
    threshold: Optional[float] = None,
) -> None:
    """Render one prediction table per selected model, then v3 debug info.

    Args:
        barcode: barcode of the product to classify.
        model_names: names of ``NeuralCategoryClassifierModel`` members to query,
            displayed in the given order.
        threshold: optional score threshold forwarded to ``get_predictions``.
    """
    debug = None
    for model_name in model_names:
        st.markdown(f"**{model_name}**")
        try:
            response = get_predictions(barcode, model_name, threshold)
        except requests.RequestException as e:
            # Report the failure inline so the remaining models still render.
            st.error(f"Prediction request failed: {e}")
            continue
        # Keep the first "debug" payload seen from a non-keras_2_0 model.
        # ``response`` comes from the st.cache'd function: read it, never mutate it.
        if (
            debug is None
            and model_name != NeuralCategoryClassifierModel.keras_2_0.name
            and "debug" in response
        ):
            debug = response["debug"]
        st.write(pd.DataFrame(response["predictions"]))
    if debug is not None:
        st.markdown("**v3 debug information**")
        st.write(debug)
st.sidebar.title("Category Prediction Demo")

# The barcode can be pre-filled from the URL query string (``?barcode=...``).
query_params = st.experimental_get_query_params()
default_barcode = query_params["barcode"][0] if "barcode" in query_params else ""
barcode = st.sidebar.text_input("Product barcode", default_barcode)
# NOTE(review): ``or None`` maps a threshold of 0.0 to None, so no threshold is
# sent at all — confirm this is intended rather than "return every category".
threshold = st.sidebar.number_input("Threshold", format="%f", value=0.5) or None
st.sidebar.write("---\n# Model description\n" + MODEL_DESCRIPTION)

# All models are offered and pre-selected, in enum declaration order.
all_model_names = [x.name for x in NeuralCategoryClassifierModel]
model_names = st.multiselect(
    "Name of the model",
    all_model_names,
    default=all_model_names,
)

# Strip before the emptiness check so whitespace-only input does not send a
# request with an empty barcode.
barcode = barcode.strip()
if barcode:
    display_predictions(
        barcode=barcode,
        threshold=threshold,
        model_names=model_names,
    )