import numpy as np
import gradio as gr
from datasets import load_dataset
from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split
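
# The prediction logs dataset is assumed to hold one row per scored request,
# with the same feature columns as the training data; samples of it are
# compared against the training data in the drift checks below.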
LOGS_DATASET_URI = 'pgurazada1/machine-failure-mlops-demo-logs'
# Load and cache training data
dataset = fetch_openml(data_id=42890, as_frame=True, parser="auto")
data_df = dataset.data
target = 'Machine failure'
numeric_features = [
    'Air temperature [K]',
    'Process temperature [K]',
    'Rotational speed [rpm]',
    'Torque [Nm]',
    'Tool wear [min]'
]
categorical_features = ['Type']
X = data_df[numeric_features + categorical_features]
y = data_df[target]
Xtrain, Xtest, ytrain, ytest = train_test_split(
    X, y,
    test_size=0.2,
    random_state=42
)
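
# Xtrain and ytrain act as the reference (expected) distributions for the
# drift checks below; the live prediction logs are compared against them.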

# Access the latest prediction logs and return a random sample
def extract_log_sample():
    prediction_logs = load_dataset(LOGS_DATASET_URI)
    prediction_logs_df = prediction_logs['train'].to_pandas()
    # A fixed 30% sample keeps the check cheap and reproducible across runs
    sample_df = prediction_logs_df.sample(frac=0.3, random_state=42)
    return sample_df

def psi(actual_proportions, expected_proportions):
    """Population Stability Index: sum((actual - expected) * ln(actual / expected))."""
    psi_values = (actual_proportions - expected_proportions) * \
                 np.log(actual_proportions / expected_proportions)
    return np.sum(psi_values)
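
# A commonly cited rule of thumb for interpreting PSI: below 0.1 means no
# significant shift, 0.1-0.25 a moderate shift, and above 0.25 a major shift.
# The checks below flag drift once PSI exceeds 0.1.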

# Model Drift
def check_model_drift(p_pos_label_sample_logs):
    """
    Compare the proportion of positive labels in the live sample with the
    training label balance using PSI. If PSI exceeds 0.1, flag model drift.
    """
    live_proportions = np.array([1 - p_pos_label_sample_logs, p_pos_label_sample_logs])
    # Order the training proportions as [negative, positive] so they line up
    # with live_proportions
    training_proportions = ytrain.value_counts(normalize=True).sort_index().values
    psi_value = psi(live_proportions, training_proportions)
    if psi_value > 0.1:
        return f"Model Drift Detected! Check Logs! (proportion of positive labels in training data = {training_proportions[1]})"
    else:
        return f"No Model Drift (proportion of positive labels in training data = {training_proportions[1]})"
def check_data_drift(feature):
    """
    Compare the distribution of a feature in the training data with its
    distribution in a sample of the live logs. Numeric and categorical
    features are handled separately: a numeric feature is flagged when its
    live mean deviates from the training mean by more than 2 training
    standard deviations; a categorical feature is flagged when its PSI
    exceeds 0.1.
    """
    sample_df = extract_log_sample()
    if feature in numeric_features:
        mean_feature_training_data = Xtrain[feature].mean()
        std_feature_training_data = Xtrain[feature].std()
        mean_feature_sample_logs = sample_df[feature].mean()
        mean_diff = abs(mean_feature_training_data - mean_feature_sample_logs)
        if mean_diff > 2 * std_feature_training_data:
            return "Data Drift Detected! Check Logs!"
        else:
            return "No Data Drift!"
    else:
        # Sort by category so both proportion arrays line up category-wise
        live_proportions = sample_df[feature].value_counts(normalize=True).sort_index().values
        training_proportions = Xtrain[feature].value_counts(normalize=True).sort_index().values
        psi_value = psi(live_proportions, training_proportions)
        if psi_value > 0.1:
            return "Data Drift Detected! Check Logs!"
        else:
            return "No Data Drift!"
with gr.Blocks() as demo:
    gr.Markdown("# Model Drift Detection")
    gr.Markdown("*Ground truth is not available, so the live data is compared with the training data*")
    with gr.Row():
        with gr.Column():
            model_drift_input = gr.Number(label='Proportion of positive labels (1) in the live data sample')
            model_drift_check_btn = gr.Button(value="Check Model Drift")
        with gr.Column():
            model_drift_check_output = gr.Label(label="Model Drift Status")
    model_drift_check_btn.click(
        check_model_drift,
        inputs=model_drift_input,
        outputs=model_drift_check_output,
        api_name="check-model-drift"
    )
    examples = gr.Examples(examples=[0.0008, 0.035],
                           inputs=[model_drift_input])
gr.Markdown("# Data Drift Detection")
gr.Markdown("*Compare the distribution of feature in training data and live data*")
with gr.Row():
with gr.Column():
data_drift_input = gr.Dropdown(
choices=['Air temperature [K]', 'Process temperature [K]',
'Rotational speed [rpm]', 'Torque [Nm]',
'Tool wear [min]', 'Type'
],
label='Feature'
)
data_drift_check_btn = gr.Button(value="Check Data Drift")
with gr.Column():
data_drift_check_output = gr.Label(label="Model Drift Status")
data_drift_check_btn.click(
check_data_drift,
inputs=data_drift_input,
outputs=data_drift_check_output,
api_name="check-data-drift"
)
demo.queue().launch()