"""Streamlit demo comparing Robotoff neural category predictions across models."""

import copy
import enum
from typing import List, Optional

import pandas as pd
import requests
import streamlit as st

# Markdown shown in the sidebar.
# NOTE(review): the entry for keras_ingredient_ocr_3_0 refers to
# `keras_sota_3_0`, which is not a member of NeuralCategoryClassifierModel
# below -- presumably a stale model name; confirm and update the text.
MODEL_DESCRIPTION = """
**keras_2_0**: Current production model\n
**keras_image_embeddings_3_0**: same as `keras_300_epochs_3_0` but with image embedding as input\n
**keras_300_epochs_3_0**: trained on 300 epochs with product name, ingredients, OCR-extracted ingredients and nutriments as input\n
**keras_ingredient_ocr_3_0**: same as `keras_sota_3_0`, but trained on less epochs\n
**keras_baseline_3_0**: model trained with product name, ingredients and nutriments as input\n
**keras_original_3_0**: same inputs as **keras_2_0** (product name + ingredients), but retrained\n
**keras_product_name_only_3_0**: model with only product name as input
"""

# NOTE(review): this session is not used by get_predictions, which calls
# requests.post directly. Referencing a Session object from inside an
# @st.cache function may interfere with Streamlit's hashing of referenced
# globals -- confirm before switching the call over.
http_session = requests.Session()


@enum.unique
class NeuralCategoryClassifierModel(enum.Enum):
    """Neural category classifier models selectable in the demo.

    Only the member *names* are used by this script (as multiselect options
    and as the ``neural_model_name`` sent to the API).
    """

    keras_2_0 = "keras-2.0"
    keras_image_embeddings_3_0 = "keras-image-embeddings-3-0"
    keras_300_epochs_3_0 = "keras-300-epochs-3-0"
    keras_ingredient_ocr_3_0 = "keras-ingredient-ocr-3.0"
    keras_baseline_3_0 = "keras-baseline-3.0"
    keras_original_3_0 = "keras-original-3.0"
    keras_product_name_only_3_0 = "keras-product-name-only-3.0"


# Flip to True to query a Robotoff instance running on localhost.
LOCAL_DB = False

if LOCAL_DB:
    ROBOTOFF_BASE_URL = "http://localhost:5500/api/v1"
else:
    ROBOTOFF_BASE_URL = "https://robotoff.openfoodfacts.org/api/v1"

PREDICTION_URL = ROBOTOFF_BASE_URL + "/predict/category"

# Seconds to wait for the prediction endpoint before giving up, so that an
# unresponsive server cannot hang the app indefinitely.
REQUEST_TIMEOUT = 30


@st.cache
def get_predictions(barcode: str, model_name: str, threshold: Optional[float] = None):
    """Fetch the neural category predictions for one product and one model.

    Args:
        barcode: barcode of the product to classify.
        model_name: value sent as ``neural_model_name`` in the request.
        threshold: optional value sent as ``threshold``; the key is omitted
            from the request when this is None.

    Returns:
        The ``"neural"`` entry of the JSON response.

    Raises:
        requests.HTTPError: if the API answers with an error status.
        requests.Timeout: if the API does not answer within REQUEST_TIMEOUT.
    """
    data = {
        "barcode": barcode,
        "predictors": ["neural"],
        "neural_model_name": model_name,
    }
    if threshold is not None:
        data["threshold"] = threshold

    r = requests.post(PREDICTION_URL, json=data, timeout=REQUEST_TIMEOUT)
    r.raise_for_status()
    return r.json()["neural"]


def display_predictions(
    barcode: str,
    model_names: List[str],
    threshold: Optional[float] = None,
) -> None:
    """Render one predictions table per model, then the debug payload if any.

    Args:
        barcode: barcode of the product to classify.
        model_names: names of the models to query, displayed in this order.
        threshold: optional threshold forwarded to ``get_predictions``.
    """
    legacy_model_name = NeuralCategoryClassifierModel.keras_2_0.name
    debug = None

    for model_name in model_names:
        # Work on a copy so that removing "debug" below does not mutate the
        # object returned by the cached function.
        response = copy.deepcopy(get_predictions(barcode, model_name, threshold))

        if model_name != legacy_model_name and "debug" in response:
            model_debug = response.pop("debug")
            # Only the first debug payload encountered is kept for display.
            if debug is None:
                debug = model_debug

        st.markdown(f"**{model_name}**")
        st.write(pd.DataFrame(response["predictions"]))

    if debug is not None:
        st.markdown("**v3 debug information**")
        st.write(debug)


st.sidebar.title("Category Prediction Demo")

# A barcode can be pre-filled through the `barcode` query parameter.
query_params = st.experimental_get_query_params()
default_barcode = query_params["barcode"][0] if "barcode" in query_params else ""

# Strip before testing for emptiness so that a whitespace-only input does not
# trigger a request with an empty barcode.
barcode = st.sidebar.text_input("Product barcode", default_barcode).strip()

# A threshold of 0 is falsy and becomes None, in which case no threshold is
# sent with the request.
threshold = st.sidebar.number_input("Threshold", format="%f", value=0.5) or None
st.sidebar.write("---\n# Model description\n" + MODEL_DESCRIPTION)

all_model_names = [x.name for x in NeuralCategoryClassifierModel]
model_names = st.multiselect(
    "Name of the model",
    all_model_names,
    default=all_model_names,
)

if barcode:
    display_predictions(
        barcode=barcode,
        threshold=threshold,
        model_names=model_names,
    )