---
license: apache-2.0
datasets:
  - scikit-learn/iris
metrics:
  - accuracy
library_name: pytorch
pipeline_tag: tabular-classification
---

logistic-regression-iris

A logistic regression model trained on the Iris dataset.

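It takes two inputs, 'PetalLengthCm' and 'PetalWidthCm' (in centimetres), which are standardized with the training-set mean and standard deviation. It predicts whether the species is 'Iris-setosa': the model outputs a single logit, and applying a sigmoid to it gives the probability of 'Iris-setosa'.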
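In equation form, the model is one affine layer followed by a sigmoid. This is a sketch that follows the Usage code below. The symbols are introduced here for illustration: $x_1$ is PetalLengthCm and $x_2$ is PetalWidthCm. $\mu_j$ and $s_j$ are the training-set means and (population) standard deviations, `X_means` and `X_stds`. $w_1, w_2, b$ are the learned parameters of `nn.Linear(2, 1)`.

$$
\begin{aligned}
z_j &= \frac{x_j - \mu_j}{s_j}, \qquad j = 1, 2 \\
\hat{p} &= \sigma(w_1 z_1 + w_2 z_2 + b) = \frac{1}{1 + e^{-(w_1 z_1 + w_2 z_2 + b)}} \\
\hat{y} &= \mathbb{1}[\hat{p} > 0.5] = \mathbb{1}[w_1 z_1 + w_2 z_2 + b > 0]
\end{aligned}
$$

The sigmoid exceeds 0.5 exactly when its argument is positive. The decision boundary is therefore the straight line $w_1 z_1 + w_2 z_2 + b = 0$ in the standardized petal length/width plane.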

It is a PyTorch adaptation of the scikit-learn model in Chapter 10 of Aurélien Géron's book 'Hands-On Machine Learning with Scikit-Learn, Keras, and TensorFlow'.

Code: https://github.com/sambitmukherjee/handson-ml3-pytorch/blob/main/chapter10/logistic_regression_iris.ipynb

Experiment tracking: https://wandb.ai/sadhaklal/logistic-regression-iris

Usage

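```python
!pip install -q datasets
# torch and scikit-learn are also required; huggingface_hub is installed as a dependency of datasets.

from datasets import load_dataset

# Load Iris from the Hugging Face Hub and view the train split as a pandas DataFrame.
iris = load_dataset("scikit-learn/iris")
iris.set_format("pandas")
iris_df = iris['train'][:]
X = iris_df[['PetalLengthCm', 'PetalWidthCm']]
y = (iris_df['Species'] == "Iris-setosa").astype(int)  # 1 = Iris-setosa, 0 = any other species

class_names = ["Not Iris-setosa", "Iris-setosa"]

from sklearn.model_selection import train_test_split

# 70/30 stratified split; new inputs are standardized with the training-set statistics.
X_train, X_val, y_train, y_val = train_test_split(X.values, y.values, test_size=0.3, stratify=y, random_state=42)
X_means, X_stds = X_train.mean(axis=0), X_train.std(axis=0)

import torch
import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin

device = torch.device("cpu")

# Same architecture as the uploaded model: one linear layer mapping the two
# standardized features to a single logit. PyTorchModelHubMixin provides from_pretrained().
class LinearModel(nn.Module, PyTorchModelHubMixin):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(2, 1)

    def forward(self, x):
        out = self.fc(x)
        return out

model = LinearModel.from_pretrained("sadhaklal/logistic-regression-iris")
model.to(device)

# Inference on new data:
import numpy as np

X_new = np.array([[2.0, 0.5], [3.0, 1.0]])  # 2 new flowers: [PetalLengthCm, PetalWidthCm].
X_new = (X_new - X_means) / X_stds  # Standardize with the training-set statistics.
X_new = torch.from_numpy(X_new).float()

model.eval()
X_new = X_new.to(device)
with torch.no_grad():
    logits = model(X_new)
proba = torch.sigmoid(logits.squeeze())  # Logits -> probability of Iris-setosa.
preds = (proba > 0.5).long()  # Threshold at 0.5 -> class index.

print(f"Predicted classes: {preds}")
print(f"Predicted labels: {[class_names[i] for i in preds.tolist()]}")
print(f"Predicted probabilities of being Iris-setosa: {proba}")
```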
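The following minimal NumPy sketch shows what the forward pass computes without PyTorch. It covers the same steps as above: standardize, apply the affine map, apply the sigmoid, and threshold at 0.5. The function name `predict_setosa` is hypothetical. The weight, bias, means and standard deviations are placeholder values, not the trained model's parameters.

```python
import numpy as np


def predict_setosa(X, weight, bias, means, stds, threshold=0.5):
    """Hypothetical NumPy re-implementation of LinearModel + sigmoid + threshold.

    X:           (n, 2) array of [PetalLengthCm, PetalWidthCm] in centimetres.
    weight:      (2,) array and bias: float, playing the role of nn.Linear(2, 1).
    means, stds: (2,) training-set statistics used for standardization.
    Returns the probability of Iris-setosa per row and the 0/1 predicted class.
    """
    Z = (X - means) / stds                 # standardize with training-set statistics
    logits = Z @ weight + bias             # same affine map as nn.Linear(2, 1)
    proba = 1.0 / (1.0 + np.exp(-logits))  # sigmoid -> probability of Iris-setosa
    preds = (proba > threshold).astype(int)
    return proba, preds


# Placeholder values for illustration only, NOT the trained model's parameters.
weight = np.array([-3.0, -3.0])
bias = -1.0
means = np.array([3.7, 1.2])
stds = np.array([1.75, 0.75])

X_new = np.array([[2.0, 0.5], [3.0, 1.0]])  # the two flowers from the Usage code
proba, preds = predict_setosa(X_new, weight, bias, means, stds)
print(proba, preds)
```

To reproduce the real predictions, substitute these values from the code above:

- `model.fc.weight.detach().numpy().ravel()` for `weight`
- `model.fc.bias.item()` for `bias`
- `X_means` for `means`
- `X_stds` for `stds`

Because the sigmoid exceeds 0.5 exactly when its argument is positive, thresholding `proba` at 0.5 is equivalent to checking `logits > 0`.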

Metric

As shown in the Usage code above, the validation set contains 30% of the examples, selected at random and stratified on the target (`train_test_split(..., test_size=0.3, stratify=y, random_state=42)`).

Accuracy on the validation set: 1.0
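The split sizes can be worked out by hand to make the 30% figure concrete. This small sketch assumes the dataset has the standard 150 rows with 50 examples per species. It also relies on scikit-learn rounding a fractional `test_size` up to a whole number of examples.

```python
from math import ceil

# Assumption: the standard Iris table, 150 rows with 50 per species.
n_total = 150
n_setosa = 50
n_other = n_total - n_setosa  # Iris-versicolor + Iris-virginica

test_size = 0.3
n_val = ceil(test_size * n_total)  # scikit-learn rounds a fractional test_size up
n_train = n_total - n_val

# Stratifying on y keeps the class ratio (1/3 Iris-setosa) in the validation set;
# the integer division is exact here because 45 * 50 / 150 is a whole number.
n_val_setosa = n_val * n_setosa // n_total
n_val_other = n_val - n_val_setosa

print(f"train={n_train}, val={n_val} (setosa={n_val_setosa}, other={n_val_other})")
```

Under these assumptions, an accuracy of 1.0 means all 45 validation flowers (15 Iris-setosa, 30 others) were classified correctly.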