"""Gradio app that flags model drift and data drift for a machine-failure model.

Training data is fetched from OpenML once at import time. Live prediction logs
are pulled from a Hugging Face dataset each time a data-drift check is run.
"""

import math

import gradio as gr
import numpy as np
from datasets import load_dataset
from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split

LOGS_DATASET_URI = 'pgurazada1/machine-failure-mlops-demo-logs'

# PSI above this value is reported as drift.
PSI_DRIFT_THRESHOLD = 0.1
# A live mean further than this many training standard deviations from the
# training mean is reported as drift.
MEAN_SHIFT_STD_MULTIPLIER = 2
# Floor applied to proportions so that log(0) and division by zero cannot occur.
PSI_EPSILON = 1e-6

# Load and cache training data
dataset = fetch_openml(data_id=42890, as_frame=True, parser="auto")
data_df = dataset.data

target = 'Machine failure'
numeric_features = [
    'Air temperature [K]',
    'Process temperature [K]',
    'Rotational speed [rpm]',
    'Torque [Nm]',
    'Tool wear [min]',
]
categorical_features = ['Type']

X = data_df[numeric_features + categorical_features]
y = data_df[target]

Xtrain, Xtest, ytrain, ytest = train_test_split(
    X, y,
    test_size=0.2,
    random_state=42
)


# Access latest logs & return a sample
def extract_log_sample():
    """Download the prediction logs and return a fixed-seed 30% sample.

    Returns:
        A pandas DataFrame sampled from the 'train' split of the logs dataset.
    """
    prediction_logs = load_dataset(LOGS_DATASET_URI)
    prediction_logs_df = prediction_logs['train'].to_pandas()
    return prediction_logs_df.sample(frac=0.3, random_state=42)


def psi(actual_proportions, expected_proportions):
    """Return the Population Stability Index between two distributions.

    Args:
        actual_proportions: Live proportions, one entry per bucket.
        expected_proportions: Reference proportions for the same buckets, in
            the same order.

    Returns:
        The PSI as a float (sum of (a - e) * ln(a / e) over buckets).

    Raises:
        ValueError: If the two inputs do not have the same shape.
    """
    actual = np.asarray(actual_proportions, dtype=float)
    expected = np.asarray(expected_proportions, dtype=float)
    if actual.shape != expected.shape:
        raise ValueError(
            "PSI needs proportions over the same buckets; got shapes "
            f"{actual.shape} and {expected.shape}"
        )

    # A bucket that is empty on either side would otherwise yield inf or nan.
    actual = np.clip(actual, PSI_EPSILON, None)
    expected = np.clip(expected, PSI_EPSILON, None)

    return float(np.sum((actual - expected) * np.log(actual / expected)))


# Model Drift
def check_model_drift(p_pos_label_sample_logs):
    """
    Check PSI between live and training label proportions.
    If PSI is more than 0.1, flag model drift.

    Args:
        p_pos_label_sample_logs: Proportion of positive labels in the live
            sample; must be a number in [0, 1].

    Returns:
        A status message that includes the training positive-label proportion.

    Raises:
        gr.Error: If the input is missing or outside [0, 1].
    """
    # The chained comparison is False for NaN, so NaN is rejected here too.
    if p_pos_label_sample_logs is None or not 0 <= p_pos_label_sample_logs <= 1:
        raise gr.Error("Enter a proportion of positive labels between 0 and 1.")

    live_proportions = np.array(
        [1 - p_pos_label_sample_logs, p_pos_label_sample_logs]
    )
    # Sort by label so the order is [negative, positive] regardless of which
    # class is more frequent (assumes the negative label sorts first, e.g. 0 < 1).
    training_proportions = (
        ytrain.value_counts(normalize=True).sort_index().values
    )

    psi_value = psi(live_proportions, training_proportions)

    if psi_value > PSI_DRIFT_THRESHOLD:
        return f"Model Drift Detected! Check Logs!(proportion of positive labels in training data = {training_proportions[1]})"
    else:
        return f"No Model Drift (proportion of positive labels in training data = {training_proportions[1]})"


def _aligned_category_proportions(training_values, live_values):
    """Return (live, training) proportion arrays over the same ordered categories."""
    # Cast to str so categorical-dtype training data and plain logged values
    # are compared on the same keys.
    training = training_values.dropna().astype(str).value_counts(normalize=True)
    live = live_values.dropna().astype(str).value_counts(normalize=True)

    categories = sorted(set(training.index) | set(live.index))

    return (
        live.reindex(categories, fill_value=0.0).to_numpy(),
        training.reindex(categories, fill_value=0.0).to_numpy(),
    )


def check_data_drift(feature):
    """
    Compare training data features and live features.

    Numeric features: flag drift if the live mean is more than 2 training
    standard deviations from the training mean.
    Categorical features: flag drift if the PSI between live and training
    category proportions is more than 0.1.

    Args:
        feature: Name of one of the numeric or categorical model features.

    Returns:
        "Data Drift Detected! Check Logs!" or "No Data Drift!".

    Raises:
        gr.Error: If the feature is not a known model feature, is absent from
            the prediction logs, or has no logged values in the sample.
    """
    if feature not in numeric_features and feature not in categorical_features:
        raise gr.Error("Select a feature to check for data drift.")

    sample_df = extract_log_sample()
    if feature not in sample_df.columns:
        raise gr.Error(f"Feature '{feature}' is missing from the prediction logs.")

    live_values = sample_df[feature].dropna()
    if live_values.empty:
        # Without this guard the comparisons below would involve NaN and
        # silently report "no drift".
        raise gr.Error(f"No logged values available for feature '{feature}'.")

    if feature in numeric_features:
        training_values = Xtrain[feature]
        mean_diff = abs(training_values.mean() - live_values.mean())
        drifted = mean_diff > MEAN_SHIFT_STD_MULTIPLIER * training_values.std()
    else:
        live_proportions, training_proportions = _aligned_category_proportions(
            Xtrain[feature], live_values
        )
        drifted = psi(live_proportions, training_proportions) > PSI_DRIFT_THRESHOLD

    if drifted:
        return "Data Drift Detected! Check Logs!"
    else:
        return "No Data Drift!"


with gr.Blocks() as demo:
    gr.Markdown("# Model Drift Detection")
    gr.Markdown("*Ground-truth is not available, comparing live data with training data*")

    with gr.Row():
        with gr.Column():
            model_drift_input = gr.Number(label='Proportion of positive labels (1) in the live data sample')
            model_drift_check_btn = gr.Button(value="Check Model Drift")
        with gr.Column():
            model_drift_check_output = gr.Label(label="Model Drift Status")

    model_drift_check_btn.click(
        check_model_drift,
        inputs=model_drift_input,
        outputs=model_drift_check_output,
        api_name="check-model-drift"
    )

    examples = gr.Examples(examples=[0.0008, 0.035], inputs=[model_drift_input])

    gr.Markdown("# Data Drift Detection")
    gr.Markdown("*Compare the distribution of feature in training data and live data*")

    with gr.Row():
        with gr.Column():
            data_drift_input = gr.Dropdown(
                # Same six names, in the same order, as the model's feature lists.
                choices=numeric_features + categorical_features,
                label='Feature'
            )
            data_drift_check_btn = gr.Button(value="Check Data Drift")
        with gr.Column():
            data_drift_check_output = gr.Label(label="Data Drift Status")

    data_drift_check_btn.click(
        check_data_drift,
        inputs=data_drift_input,
        outputs=data_drift_check_output,
        api_name="check-data-drift"
    )

demo.queue().launch()