|
import os
|
|
import json
|
|
import numpy as np
|
|
import torch
|
|
from torch.utils.data import TensorDataset, DataLoader, random_split
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
import torch.optim as optim
|
|
from PIL import Image
|
|
import requests
|
|
import matplotlib.pyplot as plt
|
|
|
|
|
|
import matplotlib
|
|
matplotlib.use('Agg')
|
|
|
|
|
|
# Select the compute device once at import time: use the GPU when CUDA is
# available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f'Using device: {device}')
|
|
|
|
def load_config(config_file='config.json'):
    """
    Loads configuration parameters from a JSON file.

    Args:
        config_file (str): Path to the JSON config file.

    Returns:
        config: The parsed JSON content (a dict when the file's top-level
            value is a JSON object).

    Raises:
        FileNotFoundError: If config_file does not exist.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    # JSON text is UTF-8; pinning the encoding keeps parsing independent of
    # the platform's default locale encoding.
    with open(config_file, 'r', encoding='utf-8') as f:
        return json.load(f)
|
|
|
|
def download_quickdraw_data():
    """
    Downloads 'cat.npy' and 'dog.npy' files from the Quick, Draw! dataset.

    Files are stored in the local 'quickdraw_data' directory, which is
    created if needed. A category whose file already exists is skipped.
    Each download is streamed to a temporary '.part' file and only renamed
    to its final name once complete, so an interrupted download can never
    be mistaken for a finished one on the next run.

    A non-200 response is reported on stdout and the category is skipped.

    Raises:
        requests.exceptions.RequestException: On connection errors or when
            the server does not respond within the timeout.
    """
    data_dir = 'quickdraw_data'
    os.makedirs(data_dir, exist_ok=True)
    base_url = 'https://storage.googleapis.com/quickdraw_dataset/full/numpy_bitmap/'

    categories = ['cat', 'dog']
    for category in categories:
        url = f"{base_url}{category}.npy"
        save_path = os.path.join(data_dir, f"{category}.npy")

        if os.path.exists(save_path):
            print(f"{category}.npy already exists, skipping download.")
            continue

        print(f"Downloading {category}.npy...")
        partial_path = save_path + '.part'

        # (connect, read) timeouts so a stalled server cannot hang the script;
        # the context manager releases the connection on every exit path.
        with requests.get(url, stream=True, timeout=(10, 60)) as response:
            if response.status_code != 200:
                print(f"Failed to download {category}.npy. Status code: {response.status_code}")
                continue

            try:
                with open(partial_path, 'wb') as f:
                    for chunk in response.iter_content(chunk_size=8192):
                        if chunk:  # skip keep-alive chunks
                            f.write(chunk)
                # Atomic rename: the final file appears only when complete.
                os.replace(partial_path, save_path)
            finally:
                # Still present only if the download or rename failed.
                if os.path.exists(partial_path):
                    os.remove(partial_path)

        print(f"Downloaded {category}.npy")
|
|
|
|
def load_and_preprocess_data(num_samples=5000, batch_size=None, data_dir='quickdraw_data'):
    """
    Loads and preprocesses the data for 'cat' and 'dog' categories.

    The first num_samples bitmaps of each category are read, labelled
    (cat = 0, dog = 1), scaled from [0, 255] to [0.0, 1.0], reshaped to
    (N, 1, 28, 28) and randomly split 80/20 into training and test sets.

    Args:
        num_samples (int): Number of samples to load for each category.
        batch_size (int, optional): Batch size for both DataLoaders. When
            None, the value is read from load_config()['batch_size'].
        data_dir (str): Directory containing 'cat.npy' and 'dog.npy'.

    Returns:
        train_loader, test_loader: DataLoaders for training and testing.
    """

    def _load_category(name):
        # Reads the first num_samples rows of one category file.
        path = os.path.join(data_dir, f'{name}.npy')
        # Memory-map the file so only the requested rows are read from disk,
        # then copy them into an ordinary in-memory array.
        bitmaps = np.load(path, mmap_mode='r')
        return np.array(bitmaps[:num_samples])

    cat_data = _load_category('cat')
    dog_data = _load_category('dog')

    # Integer class labels as expected by nn.CrossEntropyLoss.
    cat_labels = np.zeros(len(cat_data), dtype=np.int64)
    dog_labels = np.ones(len(dog_data), dtype=np.int64)

    data = np.concatenate((cat_data, dog_data), axis=0)
    labels = np.concatenate((cat_labels, dog_labels), axis=0)

    # Normalise pixel values and add the single channel dimension.
    data = data.astype('float32') / 255.0
    data = data.reshape(-1, 1, 28, 28)

    dataset = TensorDataset(torch.from_numpy(data), torch.from_numpy(labels))

    train_size = int(0.8 * len(dataset))
    test_size = len(dataset) - train_size
    train_dataset, test_dataset = random_split(dataset, [train_size, test_size])

    if batch_size is None:
        batch_size = load_config()['batch_size']

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size)

    return train_loader, test_loader
|
|
|
|
class SimpleCNN(nn.Module):
    """
    Defines a simple Convolutional Neural Network for binary classification.

    Architecture: two 3x3 convolutions (32 and 64 channels, padding 1), each
    followed by ReLU and 2x2 max pooling, then two fully connected layers
    producing 2 logits. The first linear layer is sized for 1x28x28 inputs:
    the two pooling steps reduce 28 -> 14 -> 7, giving 64 * 7 * 7 features.
    """

    def __init__(self):
        super().__init__()

        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)

        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 2)

    def forward(self, x):
        """
        Computes class logits for a batch of images.

        Args:
            x: Tensor of shape (batch, 1, 28, 28).

        Returns:
            Tensor of shape (batch, 2) with unnormalised class scores.

        Raises:
            RuntimeError: If the spatial size of x does not yield the
                64 * 7 * 7 features expected by the first linear layer.
        """
        x = F.relu(self.conv1(x))
        x = self.pool(x)
        x = F.relu(self.conv2(x))
        x = self.pool(x)
        # Flatten everything except the batch axis. Unlike view(-1, 64*7*7),
        # this never folds extra spatial positions into the batch dimension,
        # so a wrongly sized input fails loudly in fc1 instead of silently
        # producing predictions for a different batch size.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x
|
|
|
|
def train_model(model, train_loader, num_epochs=5, learning_rate=0.001):
    """
    Trains the model using the training DataLoader.

    Uses cross-entropy loss with the Adam optimizer and prints the mean
    per-sample loss after every epoch.

    Args:
        model: The neural network model to train.
        train_loader: DataLoader for the training data.
        num_epochs (int): Number of epochs to train.
        learning_rate (float): Learning rate for the optimizer.

    Returns:
        list of float: Mean per-sample training loss for each epoch.

    Raises:
        ValueError: If train_loader yields no samples during an epoch.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    # Move batches to wherever the model's parameters live so inputs and
    # weights always share a device.
    model_device = next(model.parameters()).device

    epoch_losses = []
    model.train()
    for epoch in range(num_epochs):
        running_loss = 0.0
        samples_seen = 0

        for images, labels in train_loader:
            images = images.to(model_device)
            labels = labels.to(model_device)

            optimizer.zero_grad()

            outputs = model(images)
            loss = criterion(outputs, labels)

            loss.backward()
            optimizer.step()

            # criterion returns the batch mean; weight it by the batch size
            # so the epoch figure is a true per-sample mean.
            batch_count = images.size(0)
            running_loss += loss.item() * batch_count
            samples_seen += batch_count

        if samples_seen == 0:
            raise ValueError("train_loader yielded no samples; cannot train.")

        # Divide by the samples actually processed, which can differ from
        # len(train_loader.dataset) (e.g. drop_last or a custom sampler).
        epoch_loss = running_loss / samples_seen
        epoch_losses.append(epoch_loss)
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}')

    return epoch_losses
|
|
|
|
def evaluate_model(model, test_loader):
    """
    Evaluates the model on the test DataLoader.

    Puts the model in eval mode, counts correct top-1 predictions without
    tracking gradients, and prints the accuracy as a percentage.

    Args:
        model: The trained neural network model.
        test_loader: DataLoader for the test data.

    Returns:
        float: Accuracy in percent (0-100). Returns 0.0 when test_loader
            yields no samples.
    """
    model.eval()
    correct = 0
    total = 0

    # Evaluate on the device holding the model's parameters; fall back to
    # the module-level device for a parameter-free model.
    first_param = next(model.parameters(), None)
    eval_device = device if first_param is None else first_param.device

    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(eval_device)
            labels = labels.to(eval_device)

            outputs = model(images)
            predicted = outputs.argmax(dim=1)

            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    if total == 0:
        # Avoid dividing by zero on an empty loader.
        print('Test Accuracy: n/a (no samples evaluated)')
        return 0.0

    accuracy = 100 * correct / total
    print(f'Test Accuracy: {accuracy:.2f}%')
    return accuracy
|
|
|
|
def save_model(model, filepath='cat_dog_classifier.pth'):
    """
    Saves the trained model to a file.

    Only the model's state_dict (its parameters and buffers) is written,
    not the model class itself. Missing parent directories of filepath are
    created first.

    Args:
        model: The trained neural network model.
        filepath (str): The path where the model will be saved.
    """
    # Create the target directory for path-like destinations so saving to
    # e.g. 'checkpoints/model.pth' works on a fresh checkout. Non-path
    # targets are passed straight through to torch.save.
    if isinstance(filepath, (str, os.PathLike)):
        parent_dir = os.path.dirname(os.fspath(filepath))
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

    torch.save(model.state_dict(), filepath)
    print(f'Model saved to {filepath}')
|
|
|
|
def load_model(model, filepath='cat_dog_classifier.pth'):
    """
    Loads the model parameters from a file.

    The saved state_dict is mapped onto the module-level device, loaded into
    the given model in place, and the model is then moved to that device.

    Args:
        model: The neural network model to load parameters into.
        filepath (str): The path to the saved model file.
    """
    # NOTE(review): torch.load unpickles the file, so only load checkpoints
    # from trusted sources.
    state_dict = torch.load(filepath, map_location=device)
    model.load_state_dict(state_dict)
    model.to(device)
    print(f'Model loaded from {filepath}')
|
|
|
|
def predict_image(model, image):
    """
    Predicts the class of a single image.

    The input is converted to one 28x28 grayscale float32 image before
    being passed to the model:
      * PIL images are resized to 28x28 and converted to mode 'L'.
      * NumPy arrays of shape (28, 28) or (784,) are used directly
        (the flat form is reshaped to 28x28).
      * Other NumPy arrays are resized through PIL.
    Integer-valued pixels are scaled from [0, 255] to [0.0, 1.0]; floating
    point arrays are assumed to be normalised already and are only cast to
    float32.

    Args:
        model: The trained neural network model.
        image: A PIL Image or NumPy array.

    Returns:
        prediction (str): The predicted class label ('cat' or 'dog').

    Raises:
        ValueError: If image is neither a PIL Image nor a NumPy array.
    """
    if isinstance(image, Image.Image):
        pixels = np.array(image.resize((28, 28)).convert('L'))
    elif isinstance(image, np.ndarray):
        if image.shape == (28, 28):
            pixels = image
        elif image.shape == (28 * 28,):
            # Flat bitmap: same pixels, just not yet two-dimensional.
            pixels = image.reshape(28, 28)
        else:
            pixels = np.array(Image.fromarray(image).resize((28, 28)).convert('L'))
    else:
        raise ValueError("Image must be a PIL Image or NumPy array.")

    # Normalise on every path. Previously a (28, 28) array skipped this
    # step, so uint8 or float64 input reached the model with the wrong
    # dtype and scale.
    if np.issubdtype(pixels.dtype, np.integer):
        pixels = pixels.astype('float32') / 255.0
    else:
        pixels = pixels.astype('float32')

    # Add batch and channel axes: (1, 1, 28, 28).
    image_tensor = torch.tensor(pixels.reshape(1, 1, 28, 28)).to(device)

    model.eval()
    with torch.no_grad():
        output = model(image_tensor)
        predicted = output.argmax(dim=1)

    # Class index 0 is 'cat', anything else is 'dog'.
    return 'cat' if predicted.item() == 0 else 'dog'
|
|
|
|
def visualize_predictions(model, test_loader, num_images=8, save_path='sample_predictions.png'):
    """
    Visualizes sample predictions from the test set.

    Takes the first batch from test_loader, predicts its classes and saves
    a two-row grid of images titled with predicted and true labels.

    Args:
        model: The trained neural network model.
        test_loader: DataLoader for the test data.
        num_images (int): Number of images to display. Capped at the size
            of the first batch.
        save_path (str): File the figure is written to.

    Raises:
        StopIteration: If test_loader yields no batches.
    """
    model.eval()
    images, labels = next(iter(test_loader))

    images = images.to(device)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        predicted = model(images).argmax(dim=1)

    images = images.cpu().numpy()
    predicted = predicted.cpu().numpy()
    labels = labels.cpu().numpy()

    # Never index past the end of the batch.
    shown = max(0, min(num_images, len(images)))
    # Round the column count up so an odd number of images still fits in
    # the two-row grid (and the grid always has at least one column).
    n_cols = max(1, (shown + 1) // 2)

    fig = plt.figure(figsize=(10, 4))
    try:
        for idx in range(shown):
            ax = fig.add_subplot(2, n_cols, idx + 1)
            ax.imshow(images[idx][0], cmap='gray')
            pred_label = 'cat' if predicted[idx] == 0 else 'dog'
            true_label = 'cat' if labels[idx] == 0 else 'dog'
            ax.set_title(f'Pred: {pred_label}\nTrue: {true_label}')
            ax.axis('off')
        fig.tight_layout()
        fig.savefig(save_path)
    finally:
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)

    print(f'Sample predictions saved to {save_path}')
|
|
|
|
def main():
    """
    Runs the full pipeline: load configuration, download the data, build the
    DataLoaders, train and evaluate the CNN, save a grid of sample
    predictions, and write the trained weights to disk.

    Reads the keys 'num_samples', 'num_epochs', 'learning_rate' and
    'model_save_path' from the configuration.
    """
    settings = load_config()

    # Fetch the raw bitmaps (skipped for files that already exist).
    download_quickdraw_data()

    loaders = load_and_preprocess_data(num_samples=settings['num_samples'])
    train_loader, test_loader = loaders

    classifier = SimpleCNN()
    classifier = classifier.to(device)

    training_options = {
        'num_epochs': settings['num_epochs'],
        'learning_rate': settings['learning_rate'],
    }
    train_model(classifier, train_loader, **training_options)

    evaluate_model(classifier, test_loader)

    visualize_predictions(classifier, test_loader, num_images=8)

    save_model(classifier, settings['model_save_path'])


if __name__ == '__main__':
    main()
|
|
|