# File size: 2,668 Bytes
# Revision: 979f8b1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
from turtle import title
import gradio as gr
import matplotlib.pyplot as plt
from sklearn.datasets import load_diabetes
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import cross_val_predict
from sklearn.metrics import PredictionErrorDisplay


def predict_diabetes(subsample, plot_type):
    """Plot cross-validated predictions of a linear model on the diabetes data.

    Runs 10-fold ``cross_val_predict`` with ``LinearRegression`` on the
    scikit-learn diabetes dataset and draws the requested prediction-error
    panels side by side in one figure, which is saved as a PNG.

    Args:
        subsample: Number of samples shown in each scatter plot; forwarded to
            ``PredictionErrorDisplay.from_predictions``. An integral float
            (e.g. ``100.0``) is converted to ``int`` first.
        plot_type: Collection of selected panel names. Recognised names are
            ``"Actual vs. Predicted"`` and ``"Residuals vs. Predicted"``.
            ``None`` is treated as "nothing selected".

    Returns:
        The path of the saved image (``"predictions.png"``).
    """
    # Per scikit-learn's docs a float ``subsample`` is read as a fraction of
    # the dataset, so a UI value such as 100.0 must become the int 100 to keep
    # meaning "100 samples".
    if isinstance(subsample, float) and subsample.is_integer():
        subsample = int(subsample)
    selected = plot_type if plot_type is not None else ()

    X, y = load_diabetes(return_X_y=True)
    y_pred = cross_val_predict(LinearRegression(), X, y, cv=10)

    # (checkbox label, PredictionErrorDisplay kind, axes title), in axes order.
    panels = (
        (
            "Actual vs. Predicted",
            "actual_vs_predicted",
            "Actual vs. Predicted values",
        ),
        (
            "Residuals vs. Predicted",
            "residual_vs_predicted",
            "Residuals vs. Predicted Values",
        ),
    )

    # NOTE(review): a fixed file name means concurrent requests overwrite each
    # other's image; kept as-is so the returned path does not change.
    image_path = "predictions.png"

    fig, axs = plt.subplots(ncols=2, figsize=(8, 4))
    try:
        for ax, (label, kind, heading) in zip(axs, panels):
            if label not in selected:
                # Hide the panel instead of leaving an empty framed axes.
                ax.set_axis_off()
                continue
            PredictionErrorDisplay.from_predictions(
                y,
                y_pred=y_pred,
                kind=kind,
                subsample=subsample,
                ax=ax,
                random_state=0,
            )
            ax.set_title(heading)

        fig.suptitle("Plotting cross-validated predictions")
        # Use the figure's own method: pyplot's "current figure" is global
        # state and may not be ``fig`` when several requests are in flight.
        fig.tight_layout()
        # Save the figure as an image before releasing it.
        fig.savefig(image_path)
    finally:
        # Always release the figure, even if plotting or saving failed.
        plt.close(fig)

    return image_path


# Define the Gradio interface.
# Components come from the top-level ``gr`` namespace and take ``value=`` for
# their initial state; the ``gr.inputs`` / ``gr.outputs`` modules and the
# ``default=`` keyword are deprecated and removed in newer Gradio releases.
inputs = [
    gr.Slider(minimum=1, maximum=100, step=1, value=100, label="Subsample"),
    gr.CheckboxGroup(
        choices=["Actual vs. Predicted", "Residuals vs. Predicted"],
        value=["Actual vs. Predicted", "Residuals vs. Predicted"],
        label="Plot Types",
    ),
]
# predict_diabetes returns the path of a saved PNG, so declare a file path
# rather than a PIL image.
outputs = gr.Image(label="Cross-Validated Predictions", type="filepath")

title = "Plotting Cross-Validated Predictions"
description="This app plots cross-validated predictions for a linear regression model trained on the diabetes dataset. See the original scikit-learn example here: https://scikit-learn.org/stable/auto_examples/model_selection/plot_cv_predict.html"
# Each example row holds exactly one value per input component:
# [subsample, plot types].
examples = [
    # Plotting cross-validated predictions with Actual vs. Predicted plot.
    [100, ["Actual vs. Predicted"]],
    # Plotting cross-validated predictions with Residuals vs. Predicted plot.
    [50, ["Residuals vs. Predicted"]],
    # Plotting cross-validated predictions with both Actual vs. Predicted and
    # Residuals vs. Predicted plots.
    [75, ["Actual vs. Predicted", "Residuals vs. Predicted"]],
]

demo = gr.Interface(
    fn=predict_diabetes,
    title=title,
    description=description,
    examples=examples,
    inputs=inputs,
    outputs=outputs,
)
demo.launch()