Spaces:
Sleeping
Sleeping
import copy | |
import enum | |
import os | |
from typing import Optional | |
import requests | |
import streamlit as st | |
# Shared HTTP session for outgoing Robotoff API calls (keeps connections
# pooled across requests). Call sites should use it rather than the
# module-level `requests.*` helpers.
http_session: requests.Session = requests.Session()
class NeuralCategoryClassifierModel(enum.Enum):
    """Neural category-classifier models offered by this demo.

    In this script only the member *names* are used: they populate the
    model multiselect and are sent as ``neural_model_name`` to the
    prediction endpoint. Member order determines the option order shown
    in the UI.
    """

    keras_2_0 = "keras-2.0"
    # NOTE(review): this value uses "3-0" while its siblings use "3.0".
    # Harmless here since only member names are sent, but confirm against
    # the server-side model registry before relying on the value.
    keras_sota_3_0 = "keras-sota-3-0"
    keras_ingredient_ocr_3_0 = "keras-ingredient-ocr-3.0"
    keras_baseline_3_0 = "keras-baseline-3.0"
    keras_original_3_0 = "keras-original-3.0"
    keras_product_name_only_3_0 = "keras-product-name-only-3.0"
# When True, target a Robotoff API running on localhost:5500 instead of the
# production instance.
LOCAL_DB = False

ROBOTOFF_BASE_URL = (
    "http://localhost:5500/api/v1"
    if LOCAL_DB
    else "https://robotoff.openfoodfacts.org/api/v1"
)

# Category-prediction endpoint, derived from the selected base URL.
PREDICTION_URL = f"{ROBOTOFF_BASE_URL}/predict/category"
def get_predictions(
    barcode: str,
    model_name: str,
    threshold: Optional[float] = None,
    timeout: float = 30.0,
):
    """Request neural category predictions for one product from Robotoff.

    Args:
        barcode: barcode of the product to classify.
        model_name: neural model identifier, sent as ``neural_model_name``.
        threshold: optional score threshold; the field is omitted from the
            request entirely when this is None.
        timeout: seconds to wait for the server before giving up.

    Returns:
        The ``"neural"`` entry of the decoded JSON response.

    Raises:
        requests.HTTPError: if the API answers with an error status code.
        requests.RequestException: on transport failures, including
            ``requests.Timeout`` when ``timeout`` elapses.
    """
    payload: dict[str, object] = {
        "barcode": barcode,
        "predictors": ["neural"],
        "neural_model_name": model_name,
    }
    if threshold is not None:
        payload["threshold"] = threshold
    # Reuse the module-level session (connection pooling) and bound the wait
    # so an unresponsive server cannot hang the Streamlit script forever.
    response = http_session.post(PREDICTION_URL, json=payload, timeout=timeout)
    response.raise_for_status()
    return response.json()["neural"]
def display_predictions(
    barcode: str,
    model_names: list[str],
    threshold: Optional[float] = None,
):
    """Render each selected model's category predictions in Streamlit.

    For every model name, predictions are fetched via ``get_predictions``.
    If a response contains a ``"debug"`` entry, that entry is written once
    (from the first model that returns one) and removed from every response
    before the response itself is shown under a bold model-name heading.

    Args:
        barcode: barcode of the product to classify.
        model_names: model identifiers to query, in display order.
        threshold: optional score threshold forwarded to the API.
    """
    debug_shown = False
    for model_name in model_names:
        # Work on a copy so stripping "debug" never mutates data owned by
        # get_predictions (e.g. if its results are ever cached).
        response = copy.deepcopy(get_predictions(barcode, model_name, threshold))
        if "debug" in response:
            debug = response.pop("debug")
            if not debug_shown:
                debug_shown = True
                st.write(debug)
        # No whitespace inside the ** delimiters: "** name **" is not valid
        # Markdown emphasis and would render the asterisks literally.
        st.write(f"**{model_name}**")
        st.write(response)
st.sidebar.title("Category Prediction Demo")

# Strip before testing for emptiness so whitespace-only input does not
# trigger a request with an empty barcode.
barcode = st.sidebar.text_input(
    "Product barcode"
).strip()

# A 0.0 threshold is falsy and becomes None, so no threshold is sent.
threshold = st.sidebar.number_input("Threshold", format="%f") or None

# Options are enum member names; those names are what gets sent to the API.
model_names = st.multiselect(
    "Name of the model",
    [x.name for x in NeuralCategoryClassifierModel],
    default=NeuralCategoryClassifierModel.keras_sota_3_0.name,
)

if barcode:
    display_predictions(
        barcode=barcode,
        threshold=threshold,
        model_names=model_names,
    )