Spaces:
Sleeping
Sleeping
| import copy | |
| import enum | |
| import pandas as pd | |
| from typing import List, Optional | |
| import requests | |
| import streamlit as st | |
# Markdown shown in the sidebar describing every selectable model. Each entry
# ends with an explicit "\n" so that Streamlit's markdown renders one model per
# paragraph. Keep this list in sync with NeuralCategoryClassifierModel.
MODEL_DESCRIPTION = """
**keras_2_0**: Current production model\n
**keras_image_embeddings_3_0**: same as `keras_300_epochs_3_0` but with image embedding as input\n
**keras_300_epochs_3_0**: trained for 300 epochs with product name, ingredients, OCR-extracted ingredients and nutriments as input\n
**keras_ingredient_ocr_3_0**: same as `keras_300_epochs_3_0`, but trained for fewer epochs\n
**keras_baseline_3_0**: model trained with product name, ingredients and nutriments as input\n
**keras_original_3_0**: same inputs as **keras_2_0** (product name + ingredients), but retrained\n
**keras_product_name_only_3_0**: model with only product name as input
"""
# Shared HTTP session, created once at import time so that API calls made
# through it can reuse pooled connections instead of reconnecting each time.
http_session = requests.Session()
class NeuralCategoryClassifierModel(enum.Enum):
    """Neural category-classifier models selectable in this demo.

    The member *names* (e.g. ``keras_2_0``) are what the model multiselect
    lists and what is sent to the API as ``neural_model_name``; the member
    values are not read anywhere in the visible code. Declaration order is
    the order in which the models appear in the UI.
    """

    keras_2_0 = "keras-2.0"
    # NOTE(review): some values use "-3-0" and others "-3.0"; presumably they
    # mirror server-side identifiers — confirm before normalizing.
    keras_image_embeddings_3_0 = "keras-image-embeddings-3-0"
    keras_300_epochs_3_0 = "keras-300-epochs-3-0"
    keras_ingredient_ocr_3_0 = "keras-ingredient-ocr-3.0"
    keras_baseline_3_0 = "keras-baseline-3.0"
    keras_original_3_0 = "keras-original-3.0"
    keras_product_name_only_3_0 = "keras-product-name-only-3.0"
# Set to True to target a Robotoff instance running on this machine instead of
# the public production API.
LOCAL_DB = False

ROBOTOFF_BASE_URL = (
    "http://localhost:5500/api/v1"
    if LOCAL_DB
    else "https://robotoff.openfoodfacts.org/api/v1"
)
# Endpoint used for category predictions.
PREDICTION_URL = f"{ROBOTOFF_BASE_URL}/predict/category"
def get_predictions(
    barcode: str,
    model_name: str,
    threshold: Optional[float] = None,
    timeout: float = 30.0,
):
    """Fetch neural category predictions for one product and one model.

    Args:
        barcode: barcode of the product to classify.
        model_name: model identifier sent as ``neural_model_name`` (this
            script passes ``NeuralCategoryClassifierModel`` member names).
        threshold: optional score threshold; when None the ``threshold``
            field is omitted from the request body.
        timeout: seconds to wait for the API before giving up, so a stalled
            request cannot block the Streamlit script indefinitely.

    Returns:
        The ``"neural"`` entry of the JSON response body.

    Raises:
        requests.HTTPError: if the API answers with an error status.
        requests.RequestException: on connection failures or timeouts.
    """
    data = {
        "barcode": barcode,
        "predictors": ["neural"],
        "neural_model_name": model_name,
    }
    if threshold is not None:
        data["threshold"] = threshold
    # Go through the shared session: this is called once per selected model,
    # so pooled connections avoid re-establishing TLS on every call.
    r = http_session.post(PREDICTION_URL, json=data, timeout=timeout)
    r.raise_for_status()
    return r.json()["neural"]
def display_predictions(
    barcode: str,
    model_names: List[str],
    threshold: Optional[float] = None,
):
    """Render the category predictions of each selected model for a product.

    For every model, in the order given, a bold header and a table built from
    the response's ``"predictions"`` are written. The ``"debug"`` payload of
    the first model other than ``keras_2_0`` that returns one is shown once,
    after all tables, under "v3 debug information".

    A failed request for one model is reported with ``st.error`` and the
    remaining models are still queried. One unavailable model therefore does
    not hide the results of the others.

    Args:
        barcode: barcode of the product to classify.
        model_names: ``NeuralCategoryClassifierModel`` member names to query.
        threshold: optional score threshold forwarded to every request.
    """
    debug = None
    for model_name in model_names:
        st.markdown(f"**{model_name}**")
        try:
            response = get_predictions(barcode, model_name, threshold)
        except requests.RequestException as exc:
            st.error(f"Request for model {model_name} failed: {exc}")
            continue

        # Debug info is only taken from non-keras_2_0 models, and only once.
        if (
            debug is None
            and model_name != NeuralCategoryClassifierModel.keras_2_0.name
            and "debug" in response
        ):
            debug = response["debug"]

        # Only the predictions are tabulated, so the response is read without
        # being mutated (no copy / pop of "debug" needed).
        st.write(pd.DataFrame(response["predictions"]))

    if debug is not None:
        st.markdown("**v3 debug information**")
        st.write(debug)
st.sidebar.title("Category Prediction Demo")

# Allow links such as ?barcode=<code> to pre-fill the barcode field.
query_params = st.experimental_get_query_params()
default_barcode = query_params["barcode"][0] if "barcode" in query_params else ""
barcode = st.sidebar.text_input(
    "Product barcode", default_barcode
)

# NOTE(review): `or None` turns a threshold of 0.0 into None, which omits the
# threshold from the API request instead of sending 0.0 — confirm intended.
threshold = st.sidebar.number_input("Threshold", format="%f", value=0.5) or None
st.sidebar.write("---\n# Model description\n" + MODEL_DESCRIPTION)

model_names = st.multiselect(
    "Name of the model",
    [x.name for x in NeuralCategoryClassifierModel],
    default=[x.name for x in NeuralCategoryClassifierModel],
)

# Strip before the emptiness check: a whitespace-only input would otherwise
# pass the check and trigger API calls with an empty barcode.
barcode = barcode.strip()
if barcode:
    display_predictions(
        barcode=barcode,
        threshold=threshold,
        model_names=model_names,
    )