Clement Vachet
Improve code based on pylint and black suggestions
9ecca49
raw
history blame
4.49 kB
"""
Gradio web application
"""
import os
import json
import requests
import gradio as gr
from dotenv import load_dotenv, find_dotenv
from classification.classifier import Classifier
# Base URL of the prediction API. Stays None until retrieve_api() assigns it
# from the AWS_API environment variable (or its built-in default).
AWS_API = None
# Initialize API URLs from env file or global settings
def retrieve_api():
    """Load API settings and store the endpoint base URL in ``AWS_API``.

    Looks for a ``config_api.env`` file with ``find_dotenv``; when one is
    found its variables are loaded into the environment. The module-level
    ``AWS_API`` is then read from the ``AWS_API`` environment variable,
    falling back to ``http://localhost:8000`` when it is not set.
    """
    global AWS_API

    dotenv_file = find_dotenv("config_api.env")
    if not dotenv_file:
        print("config_api.env file not found.")
    else:
        load_dotenv(dotenv_path=dotenv_file)
        print("config_api.env file loaded successfully.")

    # Use of AWS endpoint or local container by default
    AWS_API = os.environ.get("AWS_API", "http://localhost:8000")
def initialize_classifier():
    """Create and return a new ``Classifier`` instance."""
    return Classifier()
def predict_class_local(sepl, sepw, petl, petw):
    """Run the prediction in-process with a freshly created classifier.

    The four measurements are converted to floats (in the order sepal
    length, sepal width, petal length, petal width) and handed to the
    classifier's ``load_and_test`` method, whose result is returned as is.
    """
    features = [float(value) for value in (sepl, sepw, petl, petw)]
    return initialize_classifier().load_and_test(features)
def predict_class_aws(sepl, sepw, petl, petw):
    """ML prediction using AWS API endpoint.

    Posts the four measurements (converted to floats) as
    ``{"features": [[sepl, sepw, petl, petw]]}`` to the endpoint derived
    from the module-level ``AWS_API`` base URL, then decodes the JSON string
    stored under the ``"body"`` key of the JSON response.

    Args:
        sepl: Sepal length.
        sepw: Sepal width.
        petl: Petal length.
        petw: Petal width.

    Returns:
        The object decoded from the ``"body"`` JSON string of the response.

    Raises:
        gr.Error: If ``AWS_API`` is not set, the request fails, the API
            answers with a status other than 200, or the response does not
            contain a decodable ``"body"``.
        ValueError, TypeError: If a measurement cannot be converted to float.
    """
    if not AWS_API:
        raise gr.Error("\t API Error: endpoint not configured")

    # Tolerate a trailing slash in the configured base URL.
    base_url = AWS_API.rstrip("/")

    # NOTE(review): this check uses port 8080 while retrieve_api() defaults
    # to http://localhost:8000, so the default URL takes the "/test/classify"
    # branch -- confirm which port the local container is meant to use.
    if base_url == "http://localhost:8080":
        api_endpoint = base_url + "/2015-03-31/functions/function/invocations"
    else:
        api_endpoint = base_url + "/test/classify"

    features = [float(value) for value in (sepl, sepw, petl, petw)]
    payload = {"features": [features]}

    try:
        response = requests.post(api_endpoint, json=payload, timeout=60)
    except requests.RequestException as err:
        # Connection problems and timeouts are reported like API errors.
        raise gr.Error(f"\t API Error: {err}") from err

    if response.status_code != 200:
        # gr.Error is an exception: it only reaches the user when raised.
        raise gr.Error(f"\t API Error: {response.status_code}")

    # Process the response
    try:
        return json.loads(response.json()["body"])
    except (KeyError, TypeError, ValueError) as err:
        raise gr.Error("\t API Error: malformed response") from err
def predict(sepl, sepw, petl, petw, execution_type):
    """ML prediction - local or via API endpoint.

    Args:
        sepl: Sepal length.
        sepw: Sepal width.
        petl: Petal length.
        petw: Petal width.
        execution_type: ``"Local"`` to predict in-process, ``"AWS API"`` to
            call the remote endpoint.

    Returns:
        A display string holding the first predicted class and the highest
        probability of the first probability row, formatted to 3 decimals.

    Raises:
        gr.Error: If ``execution_type`` is not one of the supported values,
            or the prediction results lack usable ``"predictions"`` /
            ``"probabilities"`` entries.
    """
    print("ML prediction type: ", execution_type)

    if execution_type == "Local":
        results = predict_class_local(sepl, sepw, petl, petw)
    elif execution_type == "AWS API":
        results = predict_class_aws(sepl, sepw, petl, petw)
    else:
        # Previously fell through with results = None and crashed below.
        raise gr.Error(f"Unknown prediction type: {execution_type}")

    try:
        prediction = results["predictions"][0]
        confidence = max(results["probabilities"][0])
    except (KeyError, IndexError, TypeError, ValueError) as err:
        # E.g. an error payload instead of predictions, or empty outputs.
        raise gr.Error(f"Prediction failed: {results}") from err

    return f"Prediction: {prediction} \t - \t Confidence: {confidence:.3f}"
# Define the Gradio interface
def user_interface():
    """Build the Gradio Blocks page and launch it.

    The page shows four measurement sliders in one column and, in a second
    column, the prediction-type selector, the output textbox and the
    Submit / Clear buttons. Submit calls ``predict``; Clear empties the
    output textbox. The app is launched with queuing enabled and
    ``debug=True``.
    """
    description = """
Aims: Categorization of different species of iris flowers (Setosa, Versicolor, and Virginica)
based on measurements of physical characteristics (sepals and petals).
Notes: This web application uses two types of machine learning predictions:
- local prediction (direct source code)
- cloud prediction via an AWS API (i.e. use of ECR, Lambda function and API Gateway)
"""

    # (minimum, maximum, label) for each slider, in display and input order.
    slider_specs = (
        (4.0, 8.0, "Sepal Length (in cm)"),
        (2.0, 5.0, "Sepal Width (in cm)"),
        (1.0, 7.0, "Petal Length (in cm)"),
        (0.1, 2.8, "Petal Width (in cm)"),
    )

    def reset_output():
        """Return None so that the output textbox is emptied."""
        return None

    with gr.Blocks() as demo:
        gr.Markdown("# IRIS classification task - use of AWS Lambda")
        gr.Markdown(description)

        with gr.Row():
            with gr.Column():
                with gr.Group():
                    sliders = [
                        gr.Slider(minimum=low, maximum=high, step=0.1, label=text)
                        for low, high, text in slider_specs
                    ]
            with gr.Column():
                with gr.Row():
                    mode_radio = gr.Radio(
                        ["Local", "AWS API"], value="Local", label="Prediction type"
                    )
                with gr.Row():
                    output_box = gr.Textbox(label="Prediction output")
                with gr.Row():
                    run_button = gr.Button("Submit")
                    reset_button = gr.ClearButton()

        run_button.click(
            fn=predict,
            inputs=[*sliders, mode_radio],
            outputs=[output_box],
        )
        reset_button.click(reset_output, inputs=None, outputs=[output_box], queue=False)

    queued_app = demo.queue()
    queued_app.launch(debug=True)
def main():
    """Configure the API endpoint, then start the Gradio application."""
    retrieve_api()
    user_interface()


if __name__ == "__main__":
    main()