Spaces:
Sleeping
Sleeping
import copy | |
import enum | |
import pandas as pd | |
from typing import List, Optional | |
import requests | |
import streamlit as st | |
# Markdown shown in the sidebar describing each selectable model. Each entry
# ends with an explicit "\n" escape in addition to the literal line break,
# which yields a blank line between entries (separate Markdown paragraphs).
MODEL_DESCRIPTION = """
**keras_2_0**: Current production model\n
**keras_image_embeddings_3_0**: same as `keras_300_epochs_3_0` but with image embedding as input\n
**keras_300_epochs_3_0**: trained on 300 epochs with product name, ingredients, OCR-extracted ingredients and nutriments as input\n
**keras_ingredient_ocr_3_0**: same as `keras_300_epochs_3_0`, but trained on fewer epochs\n
**keras_baseline_3_0**: model trained with product name, ingredients and nutriments as input\n
**keras_original_3_0**: same inputs as **keras_2_0** (product name + ingredients), but retrained\n
**keras_product_name_only_3_0**: model with only product name as input
"""
# Shared HTTP session for talking to Robotoff. A requests.Session keeps a
# connection pool, so calls made through it to the same host can reuse
# connections instead of opening a new one each time.
http_session: requests.Session = requests.Session()
class NeuralCategoryClassifierModel(enum.Enum):
    """Neural category classifier models that the demo can query.

    The UI lists the member *names*, and those names are what gets sent to
    Robotoff as ``neural_model_name``. Member order is the order shown in
    the model picker.
    """

    # v2 model (special-cased when collecting debug output).
    keras_2_0 = "keras-2.0"

    # v3 models.
    # NOTE(review): the values mix "3-0" and "3.0" separators; confirm which
    # form is expected if these values are ever sent anywhere.
    keras_image_embeddings_3_0 = "keras-image-embeddings-3-0"
    keras_300_epochs_3_0 = "keras-300-epochs-3-0"
    keras_ingredient_ocr_3_0 = "keras-ingredient-ocr-3.0"
    keras_baseline_3_0 = "keras-baseline-3.0"
    keras_original_3_0 = "keras-original-3.0"
    keras_product_name_only_3_0 = "keras-product-name-only-3.0"
# Set to True to query a Robotoff instance running on localhost instead of
# the public production API.
LOCAL_DB = False

ROBOTOFF_BASE_URL = (
    "http://localhost:5500/api/v1"
    if LOCAL_DB
    else "https://robotoff.openfoodfacts.org/api/v1"
)

# Category prediction endpoint, derived from the selected base URL.
PREDICTION_URL = f"{ROBOTOFF_BASE_URL}/predict/category"
def get_predictions(
    barcode: str,
    model_name: str,
    threshold: Optional[float] = None,
    *,
    timeout: float = 30.0,
):
    """Request neural category predictions for a product from Robotoff.

    Args:
        barcode: Barcode of the product to classify.
        model_name: Sent to the API as ``neural_model_name``.
        threshold: Optional threshold forwarded to the API. When None, the
            key is omitted from the request body entirely.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout, a stalled server would block the app indefinitely.

    Returns:
        The ``"neural"`` entry of the JSON response body.

    Raises:
        requests.RequestException: On connection failure, timeout, or an
            HTTP error status (via ``raise_for_status``).
    """
    payload = {
        "barcode": barcode,
        "predictors": ["neural"],
        "neural_model_name": model_name,
    }
    if threshold is not None:
        payload["threshold"] = threshold
    # Go through the module-level session so the per-model requests issued
    # for one page share a connection pool instead of reconnecting each time.
    response = http_session.post(PREDICTION_URL, json=payload, timeout=timeout)
    response.raise_for_status()
    return response.json()["neural"]
def display_predictions(
    barcode: str,
    model_names: List[str],
    threshold: Optional[float] = None,
):
    """Render each requested model's category predictions on the page.

    For every model name, a bold header is written, followed by a table built
    from the response's ``"predictions"``. For models other than
    ``keras_2_0``, the first ``"debug"`` entry found in a response is kept
    and shown once, after all tables, as "v3 debug information".

    If the request for one model fails, an error is shown under that model's
    header and the remaining models are still displayed.

    Args:
        barcode: Barcode of the product to classify.
        model_names: Model identifiers to query, in display order.
        threshold: Optional threshold forwarded to each prediction request.
    """
    debug = None
    for model_name in model_names:
        st.markdown(f"**{model_name}**")
        try:
            response = get_predictions(barcode, model_name, threshold)
        except requests.RequestException as exc:
            # One failing model should not prevent the others from rendering.
            st.error(f"Could not fetch predictions: {exc}")
            continue

        # Read the debug payload without mutating the response. Only
        # "predictions" is tabulated below, so neither copying nor popping
        # "debug" is needed.
        if debug is None and model_name != NeuralCategoryClassifierModel.keras_2_0.name:
            debug = response.get("debug")

        st.write(pd.DataFrame(response["predictions"]))

    if debug is not None:
        st.markdown("**v3 debug information**")
        st.write(debug)
st.sidebar.title("Category Prediction Demo")

# The barcode can be pre-filled from the page URL, e.g. "?barcode=...".
query_params = st.experimental_get_query_params()
default_barcode = query_params["barcode"][0] if "barcode" in query_params else ""
barcode = st.sidebar.text_input("Product barcode", default_barcode)

# "or None" maps a threshold of 0 to None, so no threshold is sent in that case.
threshold = st.sidebar.number_input("Threshold", format="%f", value=0.5) or None
st.sidebar.write("---\n# Model description\n" + MODEL_DESCRIPTION)

all_model_names = [x.name for x in NeuralCategoryClassifierModel]
model_names = st.multiselect(
    "Name of the model",
    all_model_names,
    default=all_model_names,
)

# Strip before the emptiness check so that a whitespace-only input (typed or
# taken from the URL) is treated as "no barcode" instead of triggering
# prediction requests for an empty barcode.
barcode = barcode.strip()
if barcode:
    display_predictions(
        barcode=barcode,
        threshold=threshold,
        model_names=model_names,
    )